PyTorch Deep Learning Practice (14) - Class Activation Map

0. Preface

We are now able to build neural network models with excellent performance, but the decision-making process of a convolutional neural network is difficult to explain and understand. The class activation map (Class Activation Map, CAM) is a visualization technique used to explain the decisions of deep learning models in image classification tasks. It shows the image regions that have the most significant influence on the classification result, thereby providing interpretability for model decisions. By observing a class activation map, we can understand which regions and features the model focuses on when making classification decisions, which helps us analyze and explain the basis of the model's decisions and verify whether the model attends to the correct features. In this section, we introduce the basic concept of the class activation map and use a trained model to generate class activation maps for images.

1. Class activation map

1.1 Basic concepts

The class activation map (Class Activation Map, CAM) is a technique for visualizing the local importance of each class in a convolutional neural network (Convolutional Neural Network, CNN). Using CAM can help us understand the CNN's decision process and which features matter most for the prediction of a given class. An example is shown below: the input image is on the left, and the pixels used for the class prediction are highlighted on the right.

Class activation map
From the activation map above, we can see that the high-activation regions are concentrated in the parts of the image that contribute most to the model's class prediction. Next, we introduce how to generate a CAM after the model has been trained.

1.2 Class activation map generation

A feature map is the intermediate activation produced by a convolution operation. For a single image, the feature map of a convolutional layer typically has the shape channels x height x width, where channels is the number of output channels, height is the feature map height, and width is the feature map width. If we simply average these activations, they show hotspots for all classes in the image. However, if we are only interested in the locations that matter for a specific class (such as cats), we need to find, among the n channels, the feature maps that are responsible only for that class. For the convolutional layer that generates these feature maps, we can compute its gradients with respect to the cat class; only the channels responsible for predicting cats will have large gradients. This means we can use the gradient information to weight each of the n channels and obtain an activation map specific to cats.
Specifically, we can generate the CAM with the following process, which follows Grad-CAM (Gradient-weighted Class Activation Mapping); a minimal code sketch of the core steps appears after the list:

  1. Decide which class to compute the CAM for, and which convolutional layer of the neural network to compute the CAM at
  2. Compute the activations produced by that convolutional layer - assume the feature shape of the convolutional layer is 512 x 7 x 7
  3. Get the gradients of the class of interest with respect to this layer; the gradient shape is 256 x 512 x 3 x 3 (the shape of the convolution weight tensor, where 3 x 3 is the kernel size)
  4. Compute the average of the gradients within each output channel, giving an output of shape 512
  5. Compute the weighted activation map - multiply the 512 gradient means by the 512 activation channels; the output shape is 512 x 7 x 7
  6. Compute the mean of the weighted activation map (across the 512 channels) to obtain an output of shape 7 x 7
  7. Resize (enlarge) the weighted activation map so that it has the same size as the input, in order to obtain an activation map the same size as the original image
  8. Overlay the weighted activation map on the input image
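The following minimal sketch illustrates steps 4-6 with randomly generated tensors whose shapes match the example above (a standalone illustration, not the code used later in this section):

import torch

# hypothetical activations of the chosen conv layer (step 2) and the
# gradients of the class score with respect to them (step 3)
acts = torch.randn(512, 7, 7)
grads = torch.randn(512, 7, 7)

weights = grads.mean(dim=(1, 2))          # step 4: one weight per channel, shape (512,)
weighted = acts * weights[:, None, None]  # step 5: weight each channel, shape (512, 7, 7)
cam = weighted.mean(dim=0)                # step 6: average over channels, shape (7, 7)
print(cam.shape)                          # torch.Size([7, 7])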

The key to the whole process is step 5, which relies on the following two observations:

  • If a certain pixel is important, the convolutional neural network (CNN) will produce larger activations at that pixel
  • If a certain convolutional channel is important for the class of interest, the gradients of that channel will be very large

Multiplying the two gives an importance map over all pixels. An importance map (map of importance) expresses the importance of each pixel or feature in the neural network as a heat map or probability distribution map, and is usually used to visualize which parts of the input are most important for the network to identify a specific class.

2. Dataset analysis

The Malaria Cell Images Dataset is a commonly used dataset for training and evaluating the performance of computer vision models on the task of malaria cell image classification. It contains images of infected and uninfected red blood cells and is commonly used in the research and development of algorithms that automatically detect and classify malaria cells. Using this dataset, a deep learning model can be built to automatically identify malaria-infected red blood cells, thereby helping doctors make accurate diagnoses. The Malaria Cell Images Dataset can be downloaded from the official Kaggle website.
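After extracting the archive, the code in this section assumes the following directory layout (inferred from the glob pattern and label extraction used below; the file names themselves are arbitrary):

cell_images/
├── Parasitized/    # infected red blood cell images (*.png)
└── Uninfected/     # uninfected red blood cell images (*.png)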

3. Generate CAM using PyTorch

Next, we implement the CAM generation strategy using PyTorch to understand why the CNN model is able to predict possible malaria infection in an image.

(1) Download the dataset and import the relevant libraries:

import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from glob import glob
from random import randint
import cv2
from pathlib import Path
import torch.nn as nn
from torch import optim
from matplotlib import pyplot as plt

device = 'cuda' if torch.cuda.is_available() else 'cpu'

(2) Specify the indices corresponding to the output categories:

id2int = {'Parasitized': 0, 'Uninfected': 1}

(3) Define the image transformation operations:

from torchvision import transforms as T

trn_tfms = T.Compose([
    T.ToPILImage(),
    T.Resize(128),
    T.CenterCrop(128),
    T.ColorJitter(brightness=(0.95,1.05), 
                  contrast=(0.95,1.05), 
                  saturation=(0.95,1.05), 
                  hue=0.05),
    T.RandomAffine(5, translate=(0.01,0.1)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], 
                std=[0.5, 0.5, 0.5]),
])

In the above code, a series of transformations is applied to the input image: the image is first resized so that its smaller side is 128 pixels and then center-cropped to 128 x 128. In addition, we apply random color jitter and a random affine transformation, convert the image to a tensor with ToTensor() (so that pixel values lie between 0 and 1), and finally normalize the image.
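As a quick sanity check (using a hypothetical file path), the transform should produce a 3 x 128 x 128 tensor whose values lie roughly in [-1, 1] after normalization:

sample = cv2.imread('cell_images/Parasitized/example.png')  # hypothetical path
sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)
out = trn_tfms(sample)
print(out.shape)  # torch.Size([3, 128, 128])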

Define the transformations for the validation set images:

val_tfms = T.Compose([
    T.ToPILImage(),
    T.Resize(128),
    T.CenterCrop(128),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], 
                std=[0.5, 0.5, 0.5]),
])

(4) Define the data set class MalariaImages:

class MalariaImages(Dataset):
    def __init__(self, files, transform=None):
        self.files = files
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, ix):
        fpath = self.files[ix]
        clss = os.path.basename(Path(fpath).parent)
        img = cv2.imread(fpath)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img, clss

    def choose(self):
        # return a random sample (randint is inclusive on both ends)
        return self[randint(0, len(self) - 1)]

    def collate_fn(self, batch):
        # stack the transformed images and integer class labels into batch
        # tensors, and also return the original (untransformed) images
        _imgs, classes = list(zip(*batch))
        if self.transform:
            imgs = [self.transform(img)[None] for img in _imgs]
        classes = [torch.tensor([id2int[clss]]) for clss in classes]
        imgs, classes = [torch.cat(i).to(device) for i in [imgs, classes]]
        return imgs, classes, _imgs

(5) Obtain the training and validation datasets and data loaders:

all_files = glob('cell_images/*/*.png')
np.random.shuffle(all_files)

from sklearn.model_selection import train_test_split
trn_files, val_files = train_test_split(all_files, random_state=1)

trn_ds = MalariaImages(trn_files, transform=trn_tfms)
val_ds = MalariaImages(val_files, transform=val_tfms)
trn_dl = DataLoader(trn_ds, 32, shuffle=True, collate_fn=trn_ds.collate_fn)
val_dl = DataLoader(val_ds, 32, shuffle=False, collate_fn=val_ds.collate_fn)
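As an optional sanity check, one batch from the training loader should contain 32 normalized image tensors, their integer labels, and the raw images returned by collate_fn (shapes assume the transforms defined above):

imgs, classes, raw_imgs = next(iter(trn_dl))
print(imgs.shape, classes.shape)  # torch.Size([32, 3, 128, 128]) torch.Size([32])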

(6) Define the model MalariaClassifier:

def convBlock(ni, no):
    return nn.Sequential(
        nn.Dropout(0.2),
        nn.Conv2d(ni, no, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(no),
        nn.MaxPool2d(2),
    )
    
class MalariaClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            convBlock(3, 64),
            convBlock(64, 64),
            convBlock(64, 128),
            convBlock(128, 256),
            convBlock(256, 512),
            convBlock(512, 64),
            nn.Flatten(),
            nn.Linear(256, 256),
            nn.Dropout(0.2),
            nn.ReLU(inplace=True),
            nn.Linear(256, len(id2int))
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def compute_metrics(self, preds, targets):
        loss = self.loss_fn(preds, targets)
        acc = (torch.max(preds, 1)[1] == targets).float().mean()
        return loss, acc
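A short shape check (a sketch, not part of the original pipeline) confirms why the first Linear layer takes 256 inputs: each of the six convBlocks halves the spatial size, so a 3 x 128 x 128 input is reduced to 64 x 2 x 2 = 256 features before Flatten:

m = MalariaClassifier().to(device).eval()
dummy = torch.zeros(1, 3, 128, 128).to(device)
print(m(dummy).shape)  # torch.Size([1, 2])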

(7) Define functions for training and validating models using batch data:

def train_batch(model, data, optimizer, criterion):
    model.train()
    ims, labels, _ = data
    _preds = model(ims)
    optimizer.zero_grad()
    loss, acc = criterion(_preds, labels)
    loss.backward()
    optimizer.step()
    return loss.item(), acc.item()

@torch.no_grad()
def validate_batch(model, data, criterion):
    model.eval()
    ims, labels, _ = data
    _preds = model(ims)
    loss, acc = criterion(_preds, labels)
    return loss.item(), acc.item()

(8) Train the model:

model = MalariaClassifier().to(device)
criterion = model.compute_metrics
optimizer = optim.Adam(model.parameters(), lr=1e-3)
n_epochs = 5

for ex in range(n_epochs):
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    N = len(trn_dl)
    for bx, data in enumerate(trn_dl):
        loss, acc = train_batch(model, data, optimizer, criterion)
        train_loss.append(loss)
        train_acc.append(acc)
    N = len(val_dl)
    for bx, data in enumerate(val_dl):
        loss, acc = validate_batch(model, data, criterion)
        val_loss.append(loss)
        val_acc.append(acc)
    avg_train_loss = np.average(train_loss)
    avg_train_acc = np.average(train_acc)
    avg_val_loss = np.average(val_loss)
    avg_val_acc = np.average(val_acc)
    print(f"EPOCH: {
      
      ex}	trn_loss: {
      
      avg_train_loss}	trn_acc: {
      
      avg_train_acc}	val_loss: {
      
      avg_val_loss}	val_acc: {
      
      avg_val_acc}")

(9) Obtain the layers of the model up to its last convolutional layer:

im2fmap = nn.Sequential(*(list(model.model[:5].children()) + list(model.model[5][:2].children())))

In the above code, we take the first five convBlocks of the model plus the first two layers (Dropout and Conv2d) of the next convBlock, so im2fmap outputs the activations of the model's last convolutional layer. A quick shape check is sketched below.
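A quick check under the same assumptions (128 x 128 inputs): the five convBlocks halve the spatial size five times (128 -> 4), and the sixth block's Conv2d maps 512 channels down to 64, so im2fmap should output feature maps of shape 1 x 64 x 4 x 4:

x = torch.zeros(1, 3, 128, 128).to(device)
print(im2fmap(x).shape)  # expected: torch.Size([1, 64, 4, 4])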

(10) Define the im2gradCAM function, which accepts an input image and returns the heatmap corresponding to the image activations:

def im2gradCAM(x):
    model.eval()
    logits = model(x)
    activations = im2fmap(x)
    print(activations.shape)
    # get the model prediction
    pred = logits.max(-1)[-1]
    model.zero_grad()
    # compute the gradients with respect to the highest-confidence logit
    logits[0, pred].backward(retain_graph=True)
    # fetch the gradients of the desired conv layer and average them per feature map
    pooled_grads = model.model[-6][1].weight.grad.data.mean((1,2,3))
    # multiply each activation map by its corresponding averaged gradient
    for i in range(activations.shape[1]):
        activations[:,i,:,:] *= pooled_grads[i]
    # take the mean of all the weighted activation maps
    heatmap = torch.mean(activations, dim=1)[0].cpu().detach()
    return heatmap, 'Uninfected' if pred.item() else 'Parasitized'

(11) Define the upsampleHeatmap function to upsample the heatmap to the same shape as the image:

SZ = 120
def upsampleHeatmap(map, img):
    # scale the heatmap to the 0-255 range
    m,M = map.min(), map.max()
    map = 255 * ((map-m) / (M-m))
    map = np.uint8(map)
    # resize to the display size and apply a color map
    map = cv2.resize(map, (SZ,SZ))
    map = cv2.applyColorMap(255-map, cv2.COLORMAP_JET)
    map = np.uint8(map)
    # overlay the heatmap on the image
    map = np.uint8(map*0.7 + img*0.3)
    return map

In the above code, we scale the heatmap to the 0-255 range, apply a color map, and overlay the heatmap on the image.

(12) Call the above functions on a set of test images:

N = 20
_val_dl = DataLoader(val_ds, batch_size=N, shuffle=True, collate_fn=val_ds.collate_fn)
x,y,z = next(iter(_val_dl))

for i in range(N):
    image = cv2.resize(z[i], (SZ, SZ))
    heatmap, pred = im2gradCAM(x[i:i+1])
    if pred == 'Uninfected':
        # skip uninfected predictions; only visualize infected cells
        continue
    heatmap = upsampleHeatmap(heatmap, image)
    plt.figure(figsize=(5,3))
    plt.subplot(121)
    plt.imshow(image)
    plt.subplot(122)
    plt.imshow(heatmap)
    plt.suptitle(pred)
    plt.show()

Test image heat map
As can be seen from the figure, the prediction is determined by the content of the red highlighted regions (the regions with the highest CAM values). Having learned how to use a trained model to generate class activation heatmaps for images, we can now explain what causes the model to produce a particular classification result.

Summary

The generation of a class activation map is mainly based on the last convolutional layer of the convolutional neural network model and a global average pooling step. First, the image is passed through the convolutional neural network in a forward pass, and the feature maps of the last convolutional layer are obtained. Then, global average pooling is used to compute the weight of each feature channel, i.e., the importance of each channel for the classification result. The feature maps are then multiplied by these weights and summed, finally yielding the class activation map. Regions with higher pixel values in the class activation map represent image regions that contribute more to the classification result, while regions with lower pixel values contribute little or are irrelevant.
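For reference, the Grad-CAM paper expresses this computation compactly as

\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A^k_{ij}}, \qquad L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\Big( \sum_k \alpha_k^c A^k \Big)

where A^k is the k-th feature map of the chosen convolutional layer, y^c is the score for class c, and Z is the number of spatial positions in the feature map. Note that the code in this section approximates the channel weights using the averaged gradients of the convolution weights rather than of the activations, takes a mean over channels instead of a sum, and omits the final ReLU.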

Series links

PyTorch Deep Learning Practice (1) - Neural Networks and the Model Training Process in Detail
PyTorch Deep Learning Practice (2) - PyTorch Basics
PyTorch Deep Learning Practice (3) - Building a Neural Network with PyTorch
PyTorch Deep Learning Practice (4) - Common Activation Functions and Loss Functions in Detail
PyTorch Deep Learning Practice (5) - Computer Vision Basics
PyTorch Deep Learning Practice (6) - Neural Network Performance Optimization Techniques
PyTorch Deep Learning Practice (7) - The Impact of Batch Size on Neural Network Training
PyTorch Deep Learning Practice (8) - Batch Normalization
PyTorch Deep Learning Practice (9) - Learning Rate Optimization
PyTorch Deep Learning Practice (10) - Overfitting and Its Solutions
PyTorch Deep Learning Practice (11) - Convolutional Neural Networks
PyTorch Deep Learning Practice (12) - Data Augmentation
PyTorch Deep Learning Practice (13) - Visualizing the Output of Intermediate Layers of a Neural Network

Origin: blog.csdn.net/LOVEmy134611/article/details/131928491