Understanding CNNs with a CAM - A Class Activation Map

In this article we will look at how Class Activation Maps (CAMs) can be used to understand and interpret the decisions that Convolutional Neural Networks (CNNs) make.
deep-learning
Author

Pranath Fernando

Published

June 19, 2021

1 Introduction

In this article we will look at how Class Activation Maps (CAMs) can be used to understand and interpret the decisions that Convolutional Neural Networks (CNNs) make.

2 CAM and Pytorch hooks

A Class Activation Map (CAM) can help us understand why Convolutional Neural Networks (CNNs) make the decisions they do. CAMs do this by looking at the outputs of the last convolutional layer, just before the average pooling layer, combined with the predictions, to give a heatmap visualisation of why the model made a particular decision.

At each position in our final convolutional layer, we have as many channels as there are inputs to the last linear layer. We can therefore compute the dot product of those activations with the final weights to get, for each location in our feature map, the score of the feature that was used to make the decision. In other words, we can identify which parts of the input were most active in generating the model's choice.
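As a shape-level sketch of that dot product (dummy tensors standing in for the real activations and weights; a ResNet-34 body gives 512 channels over a 7x7 map):

import torch

act = torch.randn(512, 7, 7)   # last conv layer activations (channels, height, width)
w = torch.randn(2, 512)        # final linear layer weights (classes, channels)

# Dot product over the channel dimension: a per-class score for each
# of the 7x7 spatial locations.
scores = torch.einsum('ck,kij->cij', w, act)
scores.shape                   # torch.Size([2, 7, 7])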

We can access activations inside the network using PyTorch hooks. Whereas fastai callbacks allow you to inject code into the training loop, PyTorch hooks allow you to inject code into the forward and backward calculations themselves.
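As a minimal sketch of the mechanism (using a toy nn.Linear rather than our real model), a forward hook is just a function that PyTorch calls with the module, its input and its output:

import torch
from torch import nn

def print_shape(m, inp, out):
    # Runs automatically on every forward pass of the hooked module
    print(type(m).__name__, out.shape)

layer = nn.Linear(10, 2)
handle = layer.register_forward_hook(print_shape)
layer(torch.randn(4, 10))      # prints: Linear torch.Size([4, 2])
handle.remove()                # detach the hook when done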

Let's see an example using a dataset of cats and dogs.

from fastai.vision.all import *

path = untar_data(URLs.PETS)/'images'

# In this dataset, cat images have filenames starting with an uppercase letter
def is_cat(x): return x[0].isupper()

dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2, seed=21,
    label_func=is_cat, item_tfms=Resize(224))
learn = cnn_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)
epoch  train_loss  valid_loss  error_rate  time
0      0.138940    0.025390    0.008796    00:48

epoch  train_loss  valid_loss  error_rate  time
0      0.047596    0.024207    0.007442    00:52

We can now get a cat image. For CAM we want to store the activations of the last convolutional layer, so let's create a hook function inside a class that can keep state.

img = PILImage.create(image_cat())   # image_cat() is a small helper from the fastbook package
x, = first(dls.test_dl([img]))

class Hook():
    # Store a detached copy of the layer's output on every forward pass
    def hook_func(self, m, i, o): self.stored = o.detach().clone()

We can then instantiate a hook and attach it to any layer, in this case the last layer of the CNN body.

hook_output = Hook()
hook = learn.model[0].register_forward_hook(hook_output.hook_func)

Then we can grab a batch of images and feed it through our model.

with torch.no_grad(): output = learn.model.eval()(x)

Then we can extract our stored activations.

act = hook_output.stored[0]

And check our predictions.

F.softmax(output, dim=-1)
tensor([[1.1078e-08, 1.0000e+00]], device='cuda:0')

So index 0 should mean dog and index 1 cat, but let's check the vocab to be sure.

dls.vocab
[False, True]

So the model seems quite confident the image is a cat.
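As a quick convenience check (not in the original notebook), we can map the predicted index straight to its label through the vocab:

pred_idx = output.argmax(dim=-1).item()   # 1
dls.vocab[pred_idx]                       # True, i.e. a cat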

To perform our dot product of the weight matrix with the activations we can use einsum.

x.shape
torch.Size([1, 3, 224, 224])
cam_map = torch.einsum('ck,kij->cij', learn.model[1][-1].weight, act)
cam_map.shape
torch.Size([2, 7, 7])
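The einsum here is just a matrix multiplication of the (2, 512) weight matrix with the activations flattened to (512, 49), reshaped back to (2, 7, 7). A quick sanity check (assuming the 512-channel ResNet-34 body used above):

k = act.shape[0]   # 512 channels
cam_alt = (learn.model[1][-1].weight @ act.view(k, -1)).view(-1, 7, 7)
torch.allclose(cam_map, cam_alt)   # True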

So for each of our two classes, we get a 7x7 map that tells us where activations were higher or lower, which will allow us to see what parts of the image most influenced the model's choice.

x_dec = TensorImage(dls.train.decode((x,))[0][0])   # decode back to a displayable image
_,ax = plt.subplots()
x_dec.show(ctx=ax)
ax.imshow(cam_map[1].detach().cpu(), alpha=0.6, extent=(0,224,224,0),
              interpolation='bilinear', cmap='magma');   # overlay the CAM for class 1 (cat)

The parts in bright yellow correspond to higher activations and the parts in purple to lower activations. So we can see the paws were the main area that made the model decide it was a cat. It's good to remove a hook once it has been used, as otherwise it can leak memory.

hook.remove()

We can manage hooks better by using a class that acts as a context manager, handling registration and removal automatically.

class Hook():
    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_func)
    def hook_func(self, m, i, o): self.stored = o.detach().clone()
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.hook.remove()   # remove the hook when the with block exits

with Hook(learn.model[0]) as hook:
    with torch.no_grad(): output = learn.model.eval()(x.cuda())
    act = hook.stored

fastai provides an equivalent Hook class for us. Note that this CAM approach only works for the last layer, since it relies on the weights of the final linear layer.
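fastai's built-in version can be used the same way; a sketch, assuming fastai v2's fastai.callback.hook module and its hook_output helper:

from fastai.callback.hook import hook_output

# hook_output registers a forward hook and returns a Hook usable as a context manager
with hook_output(learn.model[0]) as hook:
    with torch.no_grad(): output = learn.model.eval()(x.cuda())
    act = hook.stored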

3 Gradient CAM

The previous approach only works for the last layer, but what if we want to look at the activations of earlier layers? Gradient CAM (Grad-CAM) lets us do this. Gradients with respect to intermediate activations are not normally kept after the backward pass, but we can store them and pick them up with a backward hook.

class HookBwd():
    def __init__(self, m):
        # register_full_backward_hook replaces the deprecated register_backward_hook
        self.hook = m.register_full_backward_hook(self.hook_func)
    def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.hook.remove()

Let's try this approach on the last layer, as we did before. Note, however, that we can use it to calculate the gradients of the output with respect to any layer.

cls = 1   # class index 1 corresponds to True, i.e. cat
with HookBwd(learn.model[0]) as hookg:
    with Hook(learn.model[0]) as hook:
        output = learn.model.eval()(x.cuda())
        act = hook.stored
    output[0,cls].backward()
    grad = hookg.stored

The weights for the Grad-CAM approach are given by the spatial average of the gradients over each channel of the feature map.
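This matches the original Grad-CAM formulation (Selvaraju et al., 2017), where for class c with score y^c and feature map channels A^k:

$$
\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j} \frac{\partial y^c}{\partial A^k_{ij}}, \qquad
L^c_{\text{Grad-CAM}} = \sum_k \alpha_k^c A^k
$$

(The paper additionally applies a ReLU to keep only the positively contributing regions, which we skip here.)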

w = grad[0].mean(dim=[1,2], keepdim=True)
cam_map = (w * act[0]).sum(0)

_,ax = plt.subplots()
x_dec.show(ctx=ax)
ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),
              interpolation='bilinear', cmap='magma');

Let's now try this on a different layer: the second-to-last ResNet group.

with HookBwd(learn.model[0][-2]) as hookg:
    with Hook(learn.model[0][-2]) as hook:
        output = learn.model.eval()(x.cuda())
        act = hook.stored
    output[0,cls].backward()
    grad = hookg.stored

w = grad[0].mean(dim=[1,2], keepdim=True)
cam_map = (w * act[0]).sum(0)

_,ax = plt.subplots()
x_dec.show(ctx=ax)
ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),
              interpolation='bilinear', cmap='magma');

4 Conclusion

In this article we saw how we can use Class Activation Maps to understand and interpret the choices a CNN makes.
