from torchvision import transforms
import argparse
import torch
import torch.backends.cudnn as cudnn
import PIL.Image as pil_image
from models import cnn
from utils import preprocess
import matplotlib.pyplot as plt
def shou_tensor_img(tensor_img):
    """Draw a tensor as an image on the current matplotlib axes.

    The tensor is copied to the CPU (``.cpu().clone()``), converted with
    torchvision's ``ToPILImage`` and handed to ``plt.imshow``.  The figure is
    not displayed here; the caller decides when to call ``plt.show()``.

    Args:
        tensor_img: Tensor to render; it must be in a form that
            ``transforms.ToPILImage`` accepts.
    """
    cpu_copy = tensor_img.cpu().clone()
    pil_img = transforms.ToPILImage()(cpu_copy)
    plt.imshow(pil_img)
def vis_feature(feature):
    """Plot every channel of a feature map as a grid of subplots.

    Channels are placed row by row on a grid with a fixed number of rows (8);
    the number of columns is derived from the channel count, rounded up so
    that no channel is left out.  ``plt.show()`` is called once after all
    subplots have been drawn.

    Args:
        feature: Tensor whose axis 0 is removed by ``squeeze(0)`` when it has
            size 1, after which axis 0 is iterated as the channel axis.
            Presumably a ``(1, C, H, W)`` layer output -- confirm against the
            caller.
    """
    # plt.subplot takes (nrows, ncols, index): this value is the row count.
    n_rows = 8
    feature = feature.squeeze(0)
    n_channels = feature.shape[0]
    # Ceiling division: channel counts that are not a multiple of n_rows
    # (including fewer than n_rows channels) are still drawn in full instead
    # of being silently truncated.
    n_cols = (n_channels + n_rows - 1) // n_rows
    for idx in range(n_channels):
        # Subplot indices are 1-based and run left to right, top to bottom.
        plt.subplot(n_rows, n_cols, idx + 1)
        # unsqueeze(0) adds a leading axis of size 1 before conversion --
        # presumably turning an (H, W) map into (1, H, W) for ToPILImage.
        shou_tensor_img(feature[idx].unsqueeze(0))
    plt.show()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--weights-file', type=str, required=True)
parser.add_argument('--image-file', type=str, requ
得到中间层的输出
猜你喜欢
转载自blog.csdn.net/qq_40107571/article/details/127114252
今日推荐
周排行