PyTorch: drawing a ViT heat map with Grad-CAM

The ViT pipeline:
Split the image into patches, prepend a class_token, add position encodings, pass the token sequence through the transformer blocks, and predict the class from the class_token.
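As a minimal sketch of the shapes involved (assuming the base 224/16 model: 16×16 patches, embedding dimension 768, so (224/16)² = 196 patch tokens; the class token and position embedding are learnable parameters in the real model, zeros here are placeholders):

import torch

img = torch.randn(1, 3, 224, 224)                     # input image
patch_embed = torch.nn.Conv2d(3, 768, kernel_size=16, stride=16)
tokens = patch_embed(img).flatten(2).transpose(1, 2)  # [1, 196, 768] patch tokens
cls_token = torch.zeros(1, 1, 768)                    # placeholder for the learnable class token
tokens = torch.cat([cls_token, tokens], dim=1)        # [1, 197, 768]
tokens = tokens + torch.zeros(1, 197, 768)            # placeholder for the learnable position embedding
# tokens now go through the transformer blocks; the head classifies from tokens[:, 0]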
When using ViT to draw a Grad-CAM heat map, note:
(1) When the code builds the CAM, ViT ultimately yields gradients per patch token, so they must be reshaped into a two-dimensional map.
That is, remove the class_token from the sequence, keep the tokens that correspond to the image patches, and reshape them back to the patch grid of the original image (224/16 = 14, so 196 tokens become a 14×14 map; see the ReshapeTransform class in the code below).
(Figure: schematic of the predicted output after the last transformer block.)
(2) Backward gradient propagation starts from the final prediction and runs back through the entire model. Dropout and the MLP operate on each token independently, so the gradient of the final score y_c cannot reach all tokens through them; only once the gradient has flowed back through self-attention, where tokens interact, does it reach every token.
By contrast, the last fully connected layer of a CNN passes the gradient directly to all spatial positions.
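The difference is easy to see on toy tensors (a minimal sketch, not the real model): a score computed from a single token sends gradient only to that token through a per-token layer, but to every token through self-attention.

import torch

# Per-token op (like ViT's MLP/LayerNorm/Dropout): gradient stays on one token
tokens = torch.randn(1, 4, 8, requires_grad=True)     # 4 tokens, dim 8
torch.nn.Linear(8, 8)(tokens)[:, 0].sum().backward()  # score from token 0 only
print(tokens.grad.abs().sum(dim=-1))                  # nonzero for token 0 only

# Self-attention mixes tokens: the same kind of score reaches every token
tokens = torch.randn(1, 4, 8, requires_grad=True)
attn = torch.nn.MultiheadAttention(embed_dim=8, num_heads=1, batch_first=True)
out, _ = attn(tokens, tokens, tokens)
out[:, 0].sum().backward()
print(tokens.grad.abs().sum(dim=-1))                  # nonzero for all 4 tokens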

import os
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from utils import GradCAM, show_cam_on_image, center_crop_img
from vit_model import vit_base_patch16_224


class ReshapeTransform:
    def __init__(self, model):
        input_size = model.patch_embed.img_size
        patch_size = model.patch_embed.patch_size
        self.h = input_size[0] // patch_size[0]
        self.w = input_size[1] // patch_size[1]

    def __call__(self, x):  # x is a sequence of tokens
        # remove cls token and reshape
        # [batch_size, num_tokens, token_dim]
        # take every token that came from an image patch and reshape back to the patch grid
        result = x[:, 1:, :].reshape(x.size(0),  # start at index 1 to skip the class_token
                                     self.h,
                                     self.w,
                                     x.size(2))

        # Bring the channels to the first dimension,
        # like in CNNs.
        # [batch_size, H, W, C] -> [batch_size, C, H, W]
        result = result.permute(0, 3, 1, 2)
        return result
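
# Quick sanity check for ReshapeTransform (hypothetical tensor; assuming the
# 224/16 model: 1 class_token + 14*14 patch tokens of dim 768):
#   rt = ReshapeTransform(vit_base_patch16_224())
#   rt(torch.randn(1, 197, 768)).shape  # -> torch.Size([1, 768, 14, 14])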


def main():
    model = vit_base_patch16_224()
    # download link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  password: eu9f
    weights_path = "./vit_base_patch16_224.pth"
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))

    target_layers = [model.blocks[-1].norm1]  # norm1 of the last transformer block
    # ViT predicts from the class_token alone, so only it contributes to the output
    # and, at first, only it has a gradient. The layers after the last self-attention
    # (MLP, LayerNorm, Dropout) act on each token separately; only multi-head
    # self-attention ties the class_token to the remaining tokens.
    # Backward propagation starts from the final prediction and runs through the
    # whole model; target_layers just tells GradCAM at which layers to record
    # gradient information.
    data_transform = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    # load image
    img_path = "both.png"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path).convert('RGB')
    img = np.array(img, dtype=np.uint8)
    img = center_crop_img(img, 224)
    # [C, H, W]
    img_tensor = data_transform(img)
    # expand batch dimension
    # [C, H, W] -> [N, C, H, W]
    input_tensor = torch.unsqueeze(img_tensor, dim=0)

    cam = GradCAM(model=model,
                  target_layers=target_layers,
                  use_cuda=False,
                  reshape_transform=ReshapeTransform(model))
    target_category = 281  # tabby, tabby cat
    # target_category = 254  # pug, pug-dog

    grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category)

    grayscale_cam = grayscale_cam[0, :]
    visualization = show_cam_on_image(img / 255., grayscale_cam, use_rgb=True)
    plt.imshow(visualization)
    plt.show()


if __name__ == '__main__':
    main()

Source: blog.csdn.net/weixin_44040169/article/details/128045559