PyTorch Deep Learning Practice (14) - Class Activation Map
0. Preface
We can now build neural network models that perform very well, but the decision-making process of a convolutional neural network remains difficult to explain and understand. A class activation map (Class Activation Map, CAM) is a visualization technique for explaining how a deep learning model reaches its decision in an image classification task. It highlights the image regions that have the greatest influence on the classification result, which makes the model's decisions interpretable. By inspecting the class activation map, we can see which regions and features the model attends to, analyze the basis of its decisions, and verify whether it focuses on the correct features. In this section, we introduce the basic concept of class activation maps and use a trained model to generate the class activation map of an image.
1. Class activation map
1.1 Basic concepts
A class activation map (Class Activation Map, CAM) is a technique for visualizing which local regions of an image matter for each class in a convolutional neural network (Convolutional Neural Network, CNN). Using CAM helps us understand the decision process of the CNN and which features are most important for classifying a given class. An example is shown below, where the input image is on the left and the pixels used for the class prediction are highlighted on the right:
According to the above activation map, the high-activation areas are concentrated in the parts of the image that help the model most in predicting the class. Next, we introduce how to generate a CAM once the model has been trained.
1.2 Class activation map generation
A feature map is the intermediate activation produced by a convolution operation. The shape of a feature map is usually n x height x width, where n is the number of channels, height is the feature map height, and width is the feature map width. If we simply average these activations over the channels, they show hotspots for all classes in the image. However, if we are only interested in the locations that matter for a specific class (such as cat), we need to find, among the n channels, the feature maps that are responsible for that class. For the convolutional layer that generates these feature maps, we can compute its gradients with respect to the cat class. Only the channels responsible for predicting cat will have large gradients, so we can use the gradient information to weight each of the n channels and obtain an activation map specific to cat.
Specifically, we can generate a CAM with the following process, which follows Grad-CAM: Gradient-weighted Class Activation Mapping:
- Decide which class to compute the CAM for, and at which convolutional layer of the network to compute it
- Compute the activations produced by that convolutional layer; assume the feature shape at this layer is 512 x 7 x 7
- Get the gradient values arising at this layer with respect to the class of interest; assume the gradient has shape 512 x 256 x 3 x 3 (the shape of the convolution weight tensor, that is, output channels x input channels x kernel height x kernel width)
- Compute the average of the gradients within each output channel, giving an output of shape 512
- Compute the weighted activation map by multiplying the 512 gradient means with the 512 activation channels; the output shape is 512 x 7 x 7
- Compute the average of the weighted activation map across the 512 channels, obtaining an output of shape 7 x 7
- Resize (enlarge) the weighted activation map so that it has the same size as the input image
- Overlay the weighted activation map onto the input image
The key to the whole process is step 5 (computing the weighted activation map), which rests on two observations:
- If a certain pixel is important, the convolutional neural network (Convolutional Neural Network, CNN) produces large activations at that pixel
- If a certain convolutional channel is important for the class of interest, the gradient of that channel is large
Multiplying the two gives an importance map over all pixels. An importance map (map of importance) expresses the importance of each pixel or feature to the neural network as a heat map or probability map, and is typically used to visualize which parts of the image matter most for identifying a specific class.
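The shape bookkeeping in the steps above can be traced on dummy tensors. The following is only a minimal sketch: it uses random values in place of real activations and gradients, assumes the PyTorch weight layout (output channels first), and the names `activations` and `weight_grad` as well as the 224 x 224 input size are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Step 2: activations of the chosen convolutional layer (random stand-in values)
activations = torch.rand(512, 7, 7)
# Step 3: gradient of the class score with respect to the layer weights,
# laid out as out-channels x in-channels x kernel height x kernel width
weight_grad = torch.randn(512, 256, 3, 3)

# Step 4: one average gradient per output channel -> shape (512,)
pooled_grads = weight_grad.mean(dim=(1, 2, 3))

# Step 5: scale every activation channel by its average gradient -> (512, 7, 7)
weighted = activations * pooled_grads[:, None, None]

# Step 6: average over the channel dimension -> (7, 7)
heatmap = weighted.mean(dim=0)

# Step 7: enlarge to the (assumed) 224 x 224 input size
upsampled = F.interpolate(heatmap[None, None], size=(224, 224),
                          mode='bilinear', align_corners=False)[0, 0]

print(pooled_grads.shape, weighted.shape, heatmap.shape, upsampled.shape)
```

Step 8 (the overlay) needs a real image, so it is left out of this sketch.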
2. Data set analysis
The Malaria Cell Images Dataset is a commonly used dataset for training and evaluating computer vision models on the task of malaria cell image classification. It contains images of infected (parasitized) and uninfected red blood cells and is widely used in research on algorithms that automatically detect and classify malaria cells. With this dataset, we can build a deep learning model that automatically identifies malaria-infected red blood cells and thereby helps doctors make accurate diagnoses. The Malaria Cell Images Dataset can be downloaded from Kaggle.
3. Generate CAM using PyTorch
Next, we use PyTorch to implement the CAM generation strategy, in order to understand why the CNN model predicts that an image may show malaria.
(1) Download the data set and import the relevant libraries:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from glob import glob
from random import randint
import cv2
from pathlib import Path
import torch.nn as nn
from torch import optim
from matplotlib import pyplot as plt
device = 'cuda' if torch.cuda.is_available() else 'cpu'
(2) Specify the index corresponding to the output category:
id2int = {'Parasitized': 0, 'Uninfected': 1}
(3) Define the image transformations for the training set:
from torchvision import transforms as T
trn_tfms = T.Compose([
    T.ToPILImage(),
    T.Resize(128),
    T.CenterCrop(128),
    T.ColorJitter(brightness=(0.95,1.05),
                  contrast=(0.95,1.05),
                  saturation=(0.95,1.05),
                  hue=0.05),
    T.RandomAffine(5, translate=(0.01,0.1)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]),
])
In the above code, a series of transformations is applied to the input image. First, the image is resized so that its shorter side is 128 pixels and then cropped to 128 x 128 around its center. In addition, we apply random color jitter and a random affine transformation, use the .ToTensor() method to scale the image (so that the pixel values lie between 0 and 1), and finally normalize the image.
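To make the last two transformations concrete, the arithmetic can be worked through for a few pixel values. This is a small NumPy sketch of the same formulas (division by 255, then (x - mean) / std with mean = std = 0.5), not a call into torchvision itself; the three example pixel values are arbitrary.

```python
import numpy as np

# Three example 8-bit pixel values: black, mid-gray, white
pixels = np.array([0, 128, 255], dtype=np.uint8)

# ToTensor-style scaling: 8-bit values divided by 255 -> range [0, 1]
scaled = pixels.astype(np.float32) / 255.0

# Normalize with mean 0.5 and std 0.5: (x - mean) / std -> range [-1, 1]
normalized = (scaled - 0.5) / 0.5

print(scaled)
print(normalized)
```

With these settings black maps to -1 and white maps to 1, so the network sees inputs centered around zero.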
Define the transformations for the validation set images (the same steps without the random augmentations):
val_tfms = T.Compose([
    T.ToPILImage(),
    T.Resize(128),
    T.CenterCrop(128),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]),
])
(4) Define the data set class MalariaImages:
class MalariaImages(Dataset):
    def __init__(self, files, transform=None):
        self.files = files
        self.transform = transform
    def __len__(self):
        return len(self.files)
    def __getitem__(self, ix):
        fpath = self.files[ix]
        # The name of the parent folder is the class label
        clss = os.path.basename(Path(fpath).parent)
        img = cv2.imread(fpath)
        # OpenCV loads images as BGR; convert to RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img, clss
    def choose(self):
        # random.randint needs both bounds and includes the upper one
        return self[randint(0, len(self) - 1)]
    def collate_fn(self, batch):
        _imgs, classes = list(zip(*batch))
        if self.transform:
            imgs = [self.transform(img)[None] for img in _imgs]
        classes = [torch.tensor([id2int[clss]]) for clss in classes]
        imgs, classes = [torch.cat(i).to(device) for i in [imgs, classes]]
        return imgs, classes, _imgs
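The batching logic in collate_fn can be checked in isolation. The sketch below uses random tensors as stand-ins for transformed images (the variable names are hypothetical) and shows how indexing with `[None]` adds a batch dimension before `torch.cat` joins the items.

```python
import torch

id2int = {'Parasitized': 0, 'Uninfected': 1}

# Stand-ins for three transformed images, each of shape (3, 128, 128)
imgs = [torch.rand(3, 128, 128) for _ in range(3)]
labels = ['Parasitized', 'Uninfected', 'Parasitized']

# img[None] adds a leading batch dimension: (3, 128, 128) -> (1, 3, 128, 128)
batch = torch.cat([img[None] for img in imgs])
# Each label becomes a one-element tensor; concatenation gives shape (3,)
targets = torch.cat([torch.tensor([id2int[name]]) for name in labels])

print(batch.shape)   # torch.Size([3, 3, 128, 128])
print(targets)       # tensor([0, 1, 0])
```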
(5) Obtain training and validation data sets and data loaders:
all_files = glob('cell_images/*/*.png')
np.random.shuffle(all_files)
from sklearn.model_selection import train_test_split
trn_files, val_files = train_test_split(all_files, random_state=1)
trn_ds = MalariaImages(trn_files, transform=trn_tfms)
val_ds = MalariaImages(val_files, transform=val_tfms)
trn_dl = DataLoader(trn_ds, 32, shuffle=True, collate_fn=trn_ds.collate_fn)
val_dl = DataLoader(val_ds, 32, shuffle=False, collate_fn=val_ds.collate_fn)
(6) Define the model MalariaClassifier:
def convBlock(ni, no):
    return nn.Sequential(
        nn.Dropout(0.2),
        nn.Conv2d(ni, no, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(no),
        nn.MaxPool2d(2),
    )
class MalariaClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            convBlock(3, 64),
            convBlock(64, 64),
            convBlock(64, 128),
            convBlock(128, 256),
            convBlock(256, 512),
            convBlock(512, 64),
            nn.Flatten(),
            nn.Linear(256, 256),
            nn.Dropout(0.2),
            nn.ReLU(inplace=True),
            nn.Linear(256, len(id2int))
        )
        self.loss_fn = nn.CrossEntropyLoss()
    def forward(self, x):
        return self.model(x)
    def compute_metrics(self, preds, targets):
        loss = self.loss_fn(preds, targets)
        acc = (torch.max(preds, 1)[1] == targets).float().mean()
        return loss, acc
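The input size of the first linear layer (256) follows from the shapes produced by the six convolution blocks. The sketch below rebuilds only the convolutional part under a hypothetical helper name (`conv_block`, with the same layer order as convBlock) and traces a 128 x 128 dummy input through it; each block halves the spatial size, so the last block outputs 64 x 2 x 2 = 256 values.

```python
import torch
import torch.nn as nn

def conv_block(ni, no):
    # Same layer order as convBlock in the listing above
    return nn.Sequential(
        nn.Dropout(0.2),
        nn.Conv2d(ni, no, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(no),
        nn.MaxPool2d(2),
    )

channels = [3, 64, 64, 128, 256, 512, 64]
blocks = nn.Sequential(*[conv_block(i, o)
                         for i, o in zip(channels[:-1], channels[1:])])
blocks.eval()

x = torch.zeros(1, 3, 128, 128)   # dummy image batch
shapes = []
with torch.no_grad():
    for block in blocks:
        x = block(x)
        shapes.append(tuple(x.shape))
        print(shapes[-1])

flat = nn.Flatten()(x)
print(flat.shape)   # torch.Size([1, 256])
```

The output of the fifth block, 512 x 4 x 4, is the spatial resolution at which the activation map is later computed.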
(7) Define functions for training and validating models using batch data:
def train_batch(model, data, optimizer, criterion):
    model.train()
    ims, labels, _ = data
    _preds = model(ims)
    optimizer.zero_grad()
    loss, acc = criterion(_preds, labels)
    loss.backward()
    optimizer.step()
    return loss.item(), acc.item()
@torch.no_grad()
def validate_batch(model, data, criterion):
    model.eval()
    ims, labels, _ = data
    _preds = model(ims)
    loss, acc = criterion(_preds, labels)
    return loss.item(), acc.item()
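Both functions receive their loss and accuracy from the criterion, which in this article is the model's compute_metrics method. As a small worked example of the accuracy part, the sketch below applies the same expression to hand-written logits (the values are made up for illustration).

```python
import torch

# Four samples, two classes; the values are arbitrary
logits = torch.tensor([[2.0, 1.0],
                       [0.5, 3.0],
                       [1.5, 0.2],
                       [0.1, 0.9]])
targets = torch.tensor([0, 1, 1, 1])

# torch.max(..., 1) returns (values, indices); the indices are the predicted classes
pred_classes = torch.max(logits, 1)[1]
acc = (pred_classes == targets).float().mean()

print(pred_classes)   # tensor([0, 1, 0, 1])
print(acc.item())     # 0.75
```

Three of the four predictions match the targets, which gives an accuracy of 0.75.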
(8) Train the model:
model = MalariaClassifier().to(device)
criterion = model.compute_metrics
optimizer = optim.Adam(model.parameters(), lr=1e-3)
n_epochs = 5
for ex in range(n_epochs):
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    # Train on every batch of the training set
    for bx, data in enumerate(trn_dl):
        loss, acc = train_batch(model, data, optimizer, criterion)
        train_loss.append(loss)
        train_acc.append(acc)
    # Evaluate on every batch of the validation set
    for bx, data in enumerate(val_dl):
        loss, acc = validate_batch(model, data, criterion)
        val_loss.append(loss)
        val_acc.append(acc)
    avg_train_loss = np.average(train_loss)
    avg_train_acc = np.average(train_acc)
    avg_val_loss = np.average(val_loss)
    avg_val_acc = np.average(val_acc)
    print(f"EPOCH: {ex} trn_loss: {avg_train_loss} trn_acc: {avg_train_acc} val_loss: {avg_val_loss} val_acc: {avg_val_acc}")
(9) Extract the part of the model up to and including the convolutional layer of the last convBlock:
im2fmap = nn.Sequential(*(list(model.model[:5].children()) + list(model.model[5][:2].children())))
In the above code, we take the first five convBlock modules of the model together with the first two layers (Dropout and Conv2d) of the sixth convBlock, so the output of im2fmap is the activation of the last Conv2d layer.
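The slicing pattern used for im2fmap can be tried on a much smaller network. The following sketch uses a hypothetical two-block toy model (`net`, `block`) rather than the article's classifier; it shows which modules end up in the truncated network and that the modules are shared with the original, not copied.

```python
import torch
import torch.nn as nn

def block(ni, no):
    # Toy block with the same layer order as convBlock
    return nn.Sequential(nn.Dropout(0.2), nn.Conv2d(ni, no, 3, padding=1),
                         nn.ReLU(inplace=True), nn.BatchNorm2d(no),
                         nn.MaxPool2d(2))

net = nn.Sequential(block(3, 8), block(8, 16),
                    nn.Flatten(), nn.Linear(16 * 4 * 4, 2))

# All of the first block, plus Dropout and Conv2d of the second block
truncated = nn.Sequential(*(list(net[:1].children()) +
                            list(net[1][:2].children())))
truncated.eval()

with torch.no_grad():
    fmap = truncated(torch.zeros(1, 3, 16, 16))

names = [type(m).__name__ for m in truncated]
print(names)                        # ['Sequential', 'Dropout', 'Conv2d']
print(fmap.shape)                   # torch.Size([1, 16, 8, 8])
print(truncated[2] is net[1][1])    # True: the Conv2d module is shared
```

Because the Conv2d module is shared, gradients accumulated on its weights by a backward pass through the full model are visible from either network.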
(10) Define the im2gradCAM function, which takes an input image and returns the heat map corresponding to the image activations:
def im2gradCAM(x):
    model.eval()
    logits = model(x)
    activations = im2fmap(x)
    print(activations.shape)
    # Get the model prediction (index of the largest logit)
    pred = logits.max(-1)[-1]
    model.zero_grad()
    # Compute gradients with respect to the logit the model is most confident about
    logits[0,pred].backward(retain_graph=True)
    # Get the gradients at the chosen convolutional layer and average them per output channel
    pooled_grads = model.model[-6][1].weight.grad.data.mean((1,2,3))
    # Multiply each activation map by its corresponding average gradient
    for i in range(activations.shape[1]):
        activations[:,i,:,:] *= pooled_grads[i]
    # Average all the weighted activation maps
    heatmap = torch.mean(activations, dim=1)[0].cpu().detach()
    return heatmap, 'Uninfected' if pred.item() else 'Parasitized'
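The per-channel loop in im2gradCAM can also be written as a single broadcast multiplication. The sketch below, using random stand-in tensors with the shapes this model produces (1 x 64 x 4 x 4 activations and 64 pooled gradients), checks that both forms give the same result; it is an illustration of the equivalence, not a change to the function above.

```python
import torch

torch.manual_seed(0)

# Random stand-ins for the layer activations and the pooled gradients
activations = torch.rand(1, 64, 4, 4)
pooled_grads = torch.randn(64)

# Loop form: scale one channel at a time, as in im2gradCAM
looped = activations.clone()
for i in range(looped.shape[1]):
    looped[:, i, :, :] *= pooled_grads[i]

# Broadcast form: reshape the weights to (1, 64, 1, 1) and multiply once
broadcast = activations * pooled_grads[None, :, None, None]

print(torch.allclose(looped, broadcast))   # True

# Averaging over the channel dimension gives the 4 x 4 heat map
heatmap = broadcast.mean(dim=1)[0]
print(heatmap.shape)                       # torch.Size([4, 4])
```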
(11) Define the upsampleHeatmap function to upsample the heat map to the size of the image:
SZ = 120
def upsampleHeatmap(hmap, img):
    # Min-max scale the heat map to the range 0-255
    m,M = hmap.min(), hmap.max()
    hmap = 255 * ((hmap-m) / (M-m))
    hmap = np.uint8(hmap)
    # Resize to SZ x SZ and color the inverted map with the JET color map
    hmap = cv2.resize(hmap, (SZ,SZ))
    hmap = cv2.applyColorMap(255-hmap, cv2.COLORMAP_JET)
    hmap = np.uint8(hmap)
    # Blend: 70% heat map, 30% original image
    hmap = np.uint8(hmap*0.7 + img*0.3)
    return hmap
In the above code, we rescale the heat map to the range 0 to 255, resize it to SZ x SZ, apply a color map, and overlay the result on the image (70% heat map, 30% image).
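One detail worth noting is that min-max scaling divides by the difference between the largest and smallest value, which is zero for a constant heat map. The sketch below shows one possible guard, using a hypothetical helper `to_uint8` in plain NumPy, followed by the same 70% / 30% blend on a tiny single-channel stand-in image; it is an illustration under these assumptions, not part of the article's pipeline.

```python
import numpy as np

def to_uint8(hmap):
    """Min-max scale a 2-D map to 0..255; a constant map becomes all zeros."""
    lo, hi = float(hmap.min()), float(hmap.max())
    if hi == lo:  # guard: the unguarded formula would divide by zero here
        return np.zeros(hmap.shape, dtype=np.uint8)
    return (255 * (hmap - lo) / (hi - lo)).astype(np.uint8)

hmap = np.array([[0.0, 0.5], [1.0, 2.0]])
scaled = to_uint8(hmap)
constant = to_uint8(np.zeros((2, 2)))

# Blend the scaled map with a gray stand-in image, 70% map and 30% image
image = np.full((2, 2), 100, dtype=np.uint8)
overlay = np.uint8(scaled * 0.7 + image * 0.3)

print(scaled)
print(constant)
print(overlay)
```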
(12) Call the above functions on a set of test images:
N = 20
_val_dl = DataLoader(val_ds, batch_size=N, shuffle=True, collate_fn=val_ds.collate_fn)
x,y,z = next(iter(_val_dl))
for i in range(N):
    image = cv2.resize(z[i], (SZ, SZ))
    heatmap, pred = im2gradCAM(x[i:i+1])
    # Only visualize images predicted as parasitized
    if pred == 'Uninfected':
        continue
    heatmap = upsampleHeatmap(heatmap, image)
    plt.figure(figsize=(5,3))
    plt.subplot(121)
    plt.imshow(image)
    plt.subplot(122)
    plt.imshow(heatmap)
    plt.suptitle(pred)
    plt.show()
As the figure shows, the prediction is driven by the content of the region highlighted in red (the region with the highest CAM values). Now that we know how to use a trained model to generate class activation heat maps for images, we can explain what causes the model to produce a particular classification result.
Summary
A class activation map is generated from the feature maps of a convolutional layer of the network (typically the last one) together with one weight per feature channel. First, the image is passed through the convolutional neural network by forward propagation, and the feature maps of the chosen convolutional layer are obtained. Then, a weight is computed for each feature channel to express the importance of that channel for the classification result; in the original CAM method these weights come from the layer that follows global average pooling, while in the gradient-based procedure used in this section they are the channel-wise averages of the gradients. Afterwards, the feature maps are multiplied by their weights and combined to obtain the class activation map. Areas with higher values in the class activation map are image areas that contribute more to the classification result, while areas with lower values contribute less or are irrelevant.
Series links
PyTorch Deep Learning Practice (1) - Neural Networks and the Model Training Process in Detail
PyTorch Deep Learning Practice (2) - PyTorch Basics
PyTorch Deep Learning Practice (3) - Building a Neural Network with PyTorch
PyTorch Deep Learning Practice (4) - Common Activation Functions and Loss Functions in Detail
PyTorch Deep Learning Practice (5) - Computer Vision Basics
PyTorch Deep Learning Practice (6) - Neural Network Performance Optimization Techniques
PyTorch Deep Learning Practice (7) - The Impact of Batch Size on Neural Network Training
PyTorch Deep Learning Practice (8) - Batch Normalization
PyTorch Deep Learning Practice (9) - Learning Rate Optimization
PyTorch Deep Learning Practice (10) - Overfitting and Its Solutions
PyTorch Deep Learning Practice (11) - Convolutional Neural Networks
PyTorch Deep Learning Practice (12) - Data Augmentation
PyTorch Deep Learning Practice (13) - Visualizing the Intermediate Layer Outputs of a Neural Network