import torch
from torchvision import transforms
import matplotlib.pyplot as plt
# Demo input for the vis_feature call at the end of the script: one sample
# holding 32 maps of 32x32 values drawn uniformly from [0, 1).
a = torch.rand((1, 32, 32, 32))
def shou_tensor_img(tensor_img):
    """Draw one channel-first image tensor on the current matplotlib axes.

    The name is a historical typo for ``show_tensor_img`` and is kept for
    existing callers. ``plt.show()`` is intentionally not called here so that
    many images can be placed in subplots and displayed once by the caller.

    Args:
        tensor_img: Tensor in a layout accepted by
            ``torchvision.transforms.ToPILImage`` (e.g. ``(1, H, W)`` for a
            single feature map). It may live on any device and may require
            grad.

    Returns:
        The ``PIL.Image`` that was drawn. The original returned ``None``;
        callers that ignore the return value are unaffected.
    """
    to_pil = transforms.ToPILImage()
    # Detach so activations taken straight from a forward pass can be
    # converted; clone so nothing below can touch the caller's storage.
    img = tensor_img.detach().cpu().clone()
    if img.is_floating_point():
        lo, hi = img.min(), img.max()
        # ToPILImage turns float tensors into uint8 via mul(255).byte() without
        # clamping, so raw activations outside [0, 1] come out corrupted.
        # Min-max rescale only in that case; in-range input is left untouched.
        if lo < 0 or hi > 1:
            img = img - lo
            if hi > lo:
                img = img / (hi - lo)
    pil_img = to_pil(img)
    plt.imshow(pil_img)
    # plt.show()
    return pil_img
def vis_feature(feature, nrows=8):
    """Plot every channel of a feature map in a grid of subplots, then show it.

    Args:
        feature: Tensor of shape ``(N, C, H, W)`` (only the first sample is
            plotted) or ``(C, H, W)``.
        nrows: Maximum number of grid rows. The default of 8 reproduces the
            original layout (8 rows of ``C // 8`` columns) whenever ``C`` is a
            multiple of 8.

    Raises:
        ValueError: If ``nrows`` < 1 or ``feature`` is not 3-D or 4-D.
    """
    if nrows < 1:
        raise ValueError(f"nrows must be >= 1, got {nrows}")
    if feature.dim() == 4:
        # Take the first sample explicitly: squeeze(0) would leave a batch of
        # size > 1 in place, and on a (1, H, W) input it would drop C instead.
        feature = feature[0]
    elif feature.dim() != 3:
        raise ValueError(
            "expected a (N, C, H, W) or (C, H, W) tensor, "
            f"got shape {tuple(feature.shape)}"
        )
    num_channels = feature.shape[0]
    # Ceil division so channels that do not fill a complete row are still
    # plotted; floor division dropped them, and drew nothing when C < nrows.
    # max(1, ...) keeps the grid valid for an empty channel dimension.
    ncols = max(1, (num_channels + nrows - 1) // nrows)
    rows = (num_channels + ncols - 1) // ncols
    for idx in range(num_channels):
        # subplot indices are 1-based and fill the grid row by row.
        plt.subplot(rows, ncols, idx + 1)
        shou_tensor_img(feature[idx].unsqueeze(0))
    plt.show()
if __name__ == "__main__":
    # Only render the demo when run as a script, so importing this module to
    # reuse vis_feature / shou_tensor_img does not open a plot window.
    vis_feature(a)
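# 特征图可视化(可以直接运行) — Feature-map visualization (runs as-is)
# 猜你喜欢 — "You may also like" (blog page navigation)
# 转载自blog.csdn.net/qq_40107571/article/details/127112472 — reposted from this CSDN article
# 今日推荐 — "Today's picks" (blog page navigation)
# 周排行 — "Weekly ranking" (blog page navigation)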