特征图可视化（以下代码可以直接运行）

import torch
from torchvision import transforms
import matplotlib.pyplot as plt
# Demo input: a random tensor of shape (1, 32, 32, 32) with values drawn
# uniformly from [0, 1); it is passed to vis_feature() further down.
a = torch.rand(1, 32, 32, 32)


def shou_tensor_img(tensor_img):
    """Draw an image tensor with ``plt.imshow``.

    The tensor is moved to the CPU and cloned, converted to a PIL image via
    torchvision's ``ToPILImage`` transform, and handed to ``plt.imshow``.
    This function does not call ``plt.show()``; displaying the figure is
    left to the caller.

    Args:
        tensor_img: Image tensor accepted by ``transforms.ToPILImage``.
    """
    cpu_copy = tensor_img.cpu().clone()
    pil_image = transforms.ToPILImage()(cpu_copy)
    plt.imshow(pil_image)
def vis_feature(feature, rows=8):
    """Plot every channel of a feature map in a subplot grid and show it.

    Each channel is drawn as a single-channel image through
    ``shou_tensor_img``; ``plt.show()`` is called once after all channels
    have been placed.

    Args:
        feature: Tensor of shape ``(C, H, W)``, or ``(1, C, H, W)`` with a
            singleton batch axis that is stripped before plotting.
        rows: Upper bound on the number of subplot rows. The default of 8
            reproduces the previously hard-coded layout (e.g. 32 channels
            are laid out as 8 rows x 4 columns).

    Raises:
        ValueError: If ``rows`` is smaller than 1, or if ``feature`` is not
            3-D after removing a singleton batch axis.
    """
    if rows < 1:
        raise ValueError(f"rows must be >= 1, got {rows}")

    # Strip only a leading singleton *batch* axis. A blanket squeeze(0) would
    # also collapse a single-channel (1, H, W) map into (H, W).
    if feature.ndim == 4 and feature.shape[0] == 1:
        feature = feature[0]
    if feature.ndim != 3:
        raise ValueError(
            "feature must have shape (C, H, W) or (1, C, H, W), "
            f"got {tuple(feature.shape)}"
        )

    n_channels = feature.shape[0]
    # Never allocate more rows than there are channels (and at least one, so
    # the ceiling division below is well defined for an empty feature map).
    n_rows = max(1, min(rows, n_channels))
    # Ceiling division: channel counts that are not a multiple of n_rows get
    # an extra column instead of having the remainder silently dropped.
    n_cols = (n_channels + n_rows - 1) // n_rows

    for index in range(n_channels):
        # Subplot indices are 1-based and fill the grid row by row.
        plt.subplot(n_rows, n_cols, index + 1)
        # Re-add a channel axis so ToPILImage receives a (1, H, W) image.
        shou_tensor_img(feature[index].unsqueeze(0))
    plt.show()

# Demo entry point: plot the random tensor `a` defined above; vis_feature()
# finishes by calling plt.show().
vis_feature(a)

猜你喜欢

转载自blog.csdn.net/qq_40107571/article/details/127112472
今日推荐