Get heatmap from CNN (Convolutional Neural Network), aka CAM

1. Description

        Convolutional Neural Networks (CNNs) are incredible. If you want to know how a CNN "sees" the world (an image), one way is to visualize which parts of the image drive its prediction.
        The idea is to take the weights of the last dense layer for the predicted class and use them to weight the feature maps of the final convolutional layer. This requires Global Average Pooling (GAP) between those two layers.

2. Select model

        In this tutorial, we use Keras with TensorFlow and ResNet50.

        Because ResNet50 has a Global Average Pooling (GAP) layer (explained later) between its last convolutional block and its final dense layer, it is well suited for this demonstration.

Figure: test image

3. How heat maps work

        A heatmap from a CNN is also called a Class Activation Map (CAM). The idea is to treat each channel of the convolutional layer's output as an image (a feature map) and combine all of them into a single map in one shot. (We will show the code step by step later.)
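Written as a formula (the notation is introduced here and is not from the original post): let f_k(x, y) be the value of channel k of the last convolutional layer at position (x, y), and let w_k^c be the weight of the final dense layer that connects channel k to class c. The class activation map for class c is then the weighted sum of the feature maps:

$$
M_c(x, y) = \sum_{k=1}^{K} w_k^{c}\, f_k(x, y)
$$

Here K is the number of channels, which is 2048 for ResNet50's conv5_block3_out layer.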

Figure: convolutional layer output

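        Here is how Global Average Pooling (GAP) and Global Max Pooling (GMP) work. Which one is used depends on the model, but the idea is the same: each feature map is reduced to a single number, its average for GAP or its maximum for GMP.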
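A minimal NumPy sketch of the two pooling operations. It assumes the channels-last layout (batch, height, width, channels) that Keras uses by default. In that layout this is what GlobalAveragePooling2D and GlobalMaxPooling2D compute. The 7×7 spatial size mirrors ResNet50's last block, and the 4 channels and random values are made up for illustration.

```python
import numpy as np

# Toy batch of feature maps in Keras "channels_last" layout: (batch, height, width, channels).
# The 7x7 spatial size mirrors ResNet50's last conv block; 4 channels keeps it small.
rng = np.random.default_rng(0)
feature_maps = rng.random((1, 7, 7, 4)).astype(np.float32)

# Global Average Pooling: mean over the spatial axes -> one value per channel.
gap = feature_maps.mean(axis=(1, 2))

# Global Max Pooling: max over the spatial axes -> one value per channel.
gmp = feature_maps.max(axis=(1, 2))

print(gap.shape, gmp.shape)  # (1, 4) (1, 4)
```

Either way, each 7×7 map collapses to one number per channel. That one-to-one link between channels and pooled values is what the next step relies on.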

        In some models, the extracted features are flattened and passed to fully connected layers to predict the result. But flattening effectively throws away the image's spatial layout, along with the information about which feature map contributed what.

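        In contrast, Global Average Pooling (GAP) or Global Max Pooling (GMP) works well here. Each channel (feature map) is pooled into one value, so each weight of the following dense layer belongs to exactly one channel. This lets the network learn which CNN channel (feature image) is more critical for the prediction result, and it lets us map those weights back onto the spatial feature maps.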
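The reason GAP makes this work can be checked numerically. With GAP, the pre-softmax score for class c is the dense weights applied to the spatial averages of the feature maps. Averaging and weighted sums are both linear, so that score equals the average of the class activation map plus the bias. The arrays below are random stand-ins with made-up sizes, not ResNet50's real weights. ResNet50 has 2048 channels and 1000 classes, and its predictions layer also applies a softmax on top of these scores.

```python
import numpy as np

# Toy stand-ins (hypothetical values): feature maps from the last conv layer
# and the weights/bias of a dense layer that follows GAP.
rng = np.random.default_rng(42)
num_channels, num_classes = 8, 3
feature_maps = rng.standard_normal((7, 7, num_channels))  # (H, W, K)
w = rng.standard_normal((num_channels, num_classes))      # (K, classes)
b = rng.standard_normal(num_classes)                      # (classes,)

# Path 1: GAP followed by the dense layer (pre-softmax class scores).
pooled = feature_maps.mean(axis=(0, 1))                   # (K,)
scores = pooled @ w + b                                   # (classes,)

# Path 2: build the class activation map for class c first, then average it.
c = int(np.argmax(scores))
cam = feature_maps @ w[:, c]                              # (H, W)
score_from_cam = cam.mean() + b[c]

# Both paths agree because averaging and the weighted sum are both linear.
print(np.isclose(scores[c], score_from_cam))  # True
```

This is why the CAM highlights the regions that raise the class score: the score is literally the map's average, shifted by the bias.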

4. Examples and Code

Let's start with ResNet50 in Keras.

from tensorflow.keras.applications import ResNet50
res_model = ResNet50()
res_model.summary() 

Figure: ResNet-50 summary

        As you can see (above):

  • Red: the layers whose outputs we will take for transfer learning (conv5_block3_out and predictions).
  • Green: the Global Average Pooling (GAP) layer. This layer is critical.

        Next, import the libraries and load the test image for later use.

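import numpy as np
import cv2
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions

# OpenCV loads images as BGR; convert to RGB for display and for preprocess_input
img = cv2.imread('./test_cat.png')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# ResNet50 expects 224x224 inputs
img = cv2.resize(img, (224, 224))
# Add a batch dimension: (1, 224, 224, 3)
X = np.expand_dims(img, axis=0).astype(np.float32)
X = preprocess_input(X)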

        We use zoom from scipy.ndimage to resize the heatmap. The feature maps extracted by the CNN are much smaller than the original image (7×7 versus 224×224 for ResNet50), so they have to be upsampled before they can be overlaid on it.
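A quick shape check, using a made-up 7×7 array in place of a real feature map. zoom resamples with spline interpolation (cubic by default, order=3). Upsampling by 224 / 7 = 32 along each axis gives a smooth 224×224 map that lines up with the input image.

```python
import numpy as np
from scipy.ndimage import zoom

# Hypothetical 7x7 map standing in for one ResNet50 feature map or the final heatmap.
small_map = np.arange(49, dtype=np.float64).reshape(7, 7)

# Upsample by the same factor used in this post: 224 / 7 = 32 along each axis.
scale = 224 / 7
big_map = zoom(small_map, zoom=(scale, scale))

print(small_map.shape, big_map.shape)  # (7, 7) (224, 224)
```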

4.1 Transfer learning

        Now extract the layers we will use.
        PS: You could also train your own model from scratch, but that takes a long time, and getting good feature extraction may require a lot of tuning.

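from tensorflow.keras.models import Model
# Output of the last convolutional block: shape (None, 7, 7, 2048)
conv_output = res_model.get_layer("conv5_block3_out").output
# Class probabilities from the final dense layer: shape (None, 1000)
pred_output = res_model.get_layer("predictions").output
# One model that returns both the feature maps and the predictions
model = Model(res_model.input, outputs=[conv_output, pred_output])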

        The new model has two outputs (the layers marked in red in the summary above):

  • The first is the convolutional network output
  • The second is the prediction result

        Then make predictions:

conv, pred = model.predict(X)
decode_predictions(pred)

The results are shown below. Not bad.

[[('n02123159', 'tiger_cat', 0.7185241),
  ('n02123045', 'tabby', 0.1784818),
  ('n02124075', 'Egyptian_cat', 0.034279127),
  ('n03958227', 'plastic_bag', 0.006443105),
  ('n03793489', 'mouse', 0.004671723)]]
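The same two-output pattern from 4.1 can be tried without downloading ResNet50's weights. The sketch below builds a tiny, randomly initialized toy model with the conv → GAP → dense structure that CAM relies on. The sizes, the layer name "last_conv", and the model itself are hypothetical. Only the name "predictions" mirrors ResNet50. Because the weights are untrained, the predicted scores mean nothing. The point is only the shapes of the two outputs.

```python
import numpy as np
from tensorflow.keras import Model, layers

# A tiny toy network (hypothetical sizes and layer name "last_conv"),
# shaped like the end of ResNet50: conv -> GAP -> dense softmax.
inputs = layers.Input(shape=(32, 32, 3))
x = layers.Conv2D(16, 3, padding="same", activation="relu", name="last_conv")(inputs)
pooled = layers.GlobalAveragePooling2D()(x)
outputs = layers.Dense(5, activation="softmax", name="predictions")(pooled)
toy = Model(inputs, outputs)

# Same pattern as with ResNet50: expose the conv maps and the predictions together.
cam_model = Model(
    toy.input,
    outputs=[toy.get_layer("last_conv").output, toy.get_layer("predictions").output],
)

# Random input image; the weights are untrained, so the scores are meaningless.
sample = np.random.default_rng(0).random((1, 32, 32, 3)).astype(np.float32)
conv_maps, probs = cam_model.predict(sample, verbose=0)

print(conv_maps.shape, probs.shape)  # (1, 32, 32, 16) (1, 5)
```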

4.2 Output

        Now, let's look at some of the CNN's feature maps (the first 36 of the 2048 channels).

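# Upsampling factor from the 7x7 feature maps to the 224x224 input
scale = 224 / 7
plt.figure(figsize=(16, 16))
# Plot the first 36 channels, each overlaid on the input image
for i in range(36):
    plt.subplot(6, 6, i + 1)
    plt.imshow(img)
    plt.imshow(zoom(conv[0, :, :, i], zoom=(scale, scale)), cmap='jet', alpha=0.3)
plt.show()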

Figure: CNN output

We draw the original image first (plt.imshow(img)) as a background, so each feature map can be compared with it.
(If you skip this, you get a result like this:)

Figure without background image

4.3 Combining the outputs in one shot

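        This is the critical step. We use the index of the predicted class (target) to get that class's weights from the final dense layer, then multiply each feature map by its weight and sum over all channels (a dot product along the channel axis).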

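# Index of the most likely class
target = int(np.argmax(pred, axis=1)[0])
# Kernel w has shape (2048, 1000); b is the bias
w, b = model.get_layer("predictions").weights
# The 2048 weights connecting each feature map to the target class
weights = w.numpy()[:, target]
# (7, 7, 2048) @ (2048,) -> (7, 7) heatmap
heatmap = conv.squeeze() @ weights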
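The one-liner conv.squeeze() @ weights can look opaque. The sketch below uses random stand-in arrays shaped like ResNet50's outputs (7×7×2048 feature maps and a 2048×1000 kernel). The values and the class index are made up. It spells out the same weighted sum as an explicit loop over channels and checks that both forms give the same 7×7 heatmap.

```python
import numpy as np

rng = np.random.default_rng(1)
# Random stand-ins shaped like ResNet50's outputs (values are not real).
conv = rng.random((1, 7, 7, 2048))         # like the first output of model.predict
w = rng.standard_normal((2048, 1000))      # like the "predictions" kernel
target = 5                                 # hypothetical class index

weights = w[:, target]                     # (2048,) weights for the target class

# Matrix form: (7, 7, 2048) @ (2048,) -> (7, 7)
heatmap = conv.squeeze() @ weights

# Explicit form: scale each feature map by its weight and add them up.
heatmap_loop = np.zeros((7, 7))
for k in range(conv.shape[-1]):
    heatmap_loop += weights[k] * conv[0, :, :, k]

print(heatmap.shape)                       # (7, 7)
print(np.allclose(heatmap, heatmap_loop))  # True
```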

Then display the heatmap on top of the original image.

# Upsample the 7x7 heatmap to 224x224 and overlay it on the image
scale = 224 / 7
plt.figure(figsize=(12, 12))
plt.imshow(img)
plt.imshow(zoom(heatmap, zoom=(scale, scale)), cmap='jet', alpha=0.5)
plt.show()

Figure: CNN heat map

        This is the result we want.



Origin blog.csdn.net/gongdiwudu/article/details/132899697