PyTorch Swin-Transformer 各层特征可视化

所用的 PyTorch 开源库（pytorch-grad-cam）：
https://gitee.com/hejuncheng1/pytorch-grad-cam

安装命令

pip install grad-cam

具体使用方法参考 CSDN 博客文章：
《Swin Transformer各层特征可视化》（作者：不高兴与没头脑Fire）

示例代码如下（包含 dataloader.py 和 main.py 两个文件）：

# dataloader.py
from torchvision import datasets, transforms
import os
import torch

# Side length (pixels) of the square images produced by the pipelines below.
input_size = 224

# ImageNet per-channel statistics used to normalise every split.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _build_transforms(size):
    """Return the train/val/test preprocessing pipelines for ``size`` x ``size`` inputs.

    A fresh dict of fresh ``Compose`` objects is built on every call, so
    rebuilding never mutates pipelines that were handed out earlier.

    Args:
        size: Target side length in pixels.

    Returns:
        dict mapping ``'train'``, ``'val'`` and ``'test'`` to a
        ``transforms.Compose`` pipeline.
    """
    def normalize():
        return transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

    def eval_pipeline():
        # Deterministic: resize to exactly size x size (aspect ratio is not
        # preserved), convert to a tensor, normalise.
        return transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            normalize(),
        ])

    return {
        'train': transforms.Compose([
            transforms.Resize((size, size)),
            # Light augmentation: random crop covering 70-100% of the area,
            # small translations (up to 5%), random horizontal/vertical flips.
            transforms.RandomResizedCrop(size=size, scale=(0.7, 1)),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            normalize(),
        ]),
        'val': eval_pipeline(),
        'test': eval_pipeline(),
    }


# Pipelines for the current ``input_size``; rebuilt by ``update``.
data_transforms = _build_transforms(input_size)


def update(new_input_size):
    """Switch to ``new_input_size`` and rebuild ``data_transforms`` for it.

    Rebinds the module globals ``input_size`` and ``data_transforms``. Code
    that reads them through the module (e.g. ``dataloader.data_transforms``)
    sees the new values; pipeline objects fetched earlier keep the old size.

    Args:
        new_input_size: New side length in pixels for all splits.
    """
    global input_size
    global data_transforms

    input_size = new_input_size
    data_transforms = _build_transforms(new_input_size)


def dataloader(data_dir, batch_size, set_name, shuffle):
    """Build a DataLoader for one dataset split.

    Images are read with ``datasets.ImageFolder`` from
    ``os.path.join(data_dir, set_name)`` and preprocessed with
    ``data_transforms[set_name]`` (looked up at call time, so a prior
    ``update()`` is honoured).

    Args:
        data_dir: Root directory containing one sub-directory per split.
        batch_size: Samples per batch.
        set_name: Split name; must be a key of ``data_transforms``.
        shuffle: Whether the loader reshuffles each epoch.

    Returns:
        ``({set_name: DataLoader}, number_of_images_in_split)``.
    """
    split_dir = os.path.join(data_dir, set_name)
    split_dataset = datasets.ImageFolder(split_dir, data_transforms[set_name])

    # One background worker when a GPU is present, otherwise load in-process.
    workers = 0
    if torch.cuda.is_available():
        workers = 1

    loader = torch.utils.data.DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=workers,
    )
    return {set_name: loader}, len(split_dataset)


if __name__ == '__main__':
    # Smoke test: build the 'train' loader and print it with its image count.
    # NOTE: set the dataset root first; '' makes the split path relative to the CWD.
    settings = dict(data_dir='', batch_size=16, set_name='train', shuffle=True)
    loaders, n_images = dataloader(**settings)
    print(loaders, n_images)
# main.py
import math
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image

import dataloader


def reshape_transform(tensor, height=None, width=None):
    """Turn a token sequence into a channels-first feature map for Grad-CAM.

    Grad-CAM expects CNN-style activations ``(batch, channels, height, width)``,
    while the hooked transformer layer produces tokens
    ``(batch, height * width, channels)``.

    Args:
        tensor: Activation of shape ``(batch, num_tokens, channels)``.
        height: Rows of the token grid. If both ``height`` and ``width`` are
            omitted, a square grid of side ``isqrt(num_tokens)`` is assumed
            (e.g. 144 tokens -> 12x12, the previously hard-coded value).
        width: Columns of the token grid; inferred from ``height`` when only
            ``height`` is given (and vice versa).

    Returns:
        Tensor of shape ``(batch, channels, height, width)``.

    Raises:
        RuntimeError: from ``Tensor.reshape`` if ``height * width`` does not
            equal ``num_tokens``.
    """
    num_tokens = tensor.size(1)
    if height is None and width is None:
        # Assume a square token grid, as produced by square (size x size) inputs.
        height = width = math.isqrt(num_tokens)
    elif height is None:
        height = num_tokens // width
    elif width is None:
        width = num_tokens // height

    result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
    # (B, H, W, C) -> (B, C, H, W); identical to .transpose(2, 3).transpose(1, 2).
    return result.permute(0, 3, 1, 2)


if __name__ == '__main__':
    # Grad-CAM heatmap for one image using a fine-tuned Swin classifier.
    net_name = 'swin_base_patch4_window12_384_22k'
    categories_size = 2
    model_ft = None

    if net_name == 'swin_base_patch4_window12_384_22k':
        from models import swintf

        model_ft = swintf.build_model('config/swin_base_patch4_window12_384_22k.yaml', use_checkpoint=True)
        # Replace the classification head; 1024 is the feature width this
        # script assumes for the Swin-B backbone.
        model_ft.head = nn.Linear(1024, categories_size)
        # Rebuild the preprocessing pipelines for 384x384 inputs.
        dataloader.update(384)

    if model_ft is None:
        raise ValueError('Unsupported net_name: {}'.format(net_name))

    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')
    model_ft = model_ft.to(device)

    load_path = os.path.join('./save', net_name + '.pth')
    if os.path.exists(load_path):
        msg = model_ft.load_state_dict(torch.load(load_path, map_location=device))
        print('msg:', msg)
    else:
        # Previously this fell through to print an undefined `msg` (NameError).
        print('msg: no checkpoint found at', load_path)

    model_ft.eval()

    # Final LayerNorm after the last stage; its token output is turned into a
    # 2-D map by reshape_transform.
    target_layer = [model_ft.norm]

    target_category = 0
    image_path = ''
    # Force 3-channel RGB: grayscale/RGBA/palette images would break the
    # 3-channel Normalize and show_cam_on_image. `with` closes the file handle.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    transformer = dataloader.data_transforms['test']
    inputs = transformer(image).unsqueeze(0)

    # use_cuda must match where the model lives, otherwise the CPU input meets
    # CUDA weights. NOTE(review): use_cuda / target_category follow the older
    # grad-cam API; newer releases take `targets=` and infer the device from
    # the model -- confirm against the installed version.
    cam = GradCAM(model=model_ft, target_layers=target_layer, use_cuda=use_gpu,
                  reshape_transform=reshape_transform)
    cam.batch_size = 1
    grayscale_cam = cam(input_tensor=inputs, target_category=target_category, eigen_smooth=True,
                        aug_smooth=True)
    grayscale_cam = grayscale_cam[0, :]

    # Resize the image to the network input size so it aligns with the CAM.
    size = dataloader.input_size
    rgb_image = np.asarray(image.resize((size, size)), dtype=np.float32) / 255.0
    # The PIL image is RGB: ask for an RGB heatmap so channels match, then
    # convert the blend to BGR because cv2.imwrite expects BGR.
    cam_image = show_cam_on_image(rgb_image, grayscale_cam, use_rgb=True)
    cv2.imwrite('cam.jpg', cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR))
    print('OK')

猜你喜欢

转载自：https://blog.csdn.net/u014134327/article/details/124043802