得到中间层的输出

from torchvision import transforms
import argparse

import torch
import torch.backends.cudnn as cudnn
import PIL.Image as pil_image
from models import cnn
from utils import  preprocess
import matplotlib.pyplot as plt

def shou_tensor_img(tensor_img):
    """Draw a tensor image on the current matplotlib axes.

    The input is copied to the CPU (``.cpu().clone()``), converted to a PIL
    image with torchvision's ``ToPILImage`` transform, and handed to
    ``plt.imshow``. This function does not call ``plt.show()``; displaying
    the figure is left to the caller.

    Args:
        tensor_img: Tensor accepted by ``transforms.ToPILImage``
            (presumably shaped (C, H, W) -- confirm against callers).
    """
    cpu_copy = tensor_img.cpu().clone()
    pil_img = transforms.ToPILImage()(cpu_copy)
    plt.imshow(pil_img)
def vis_feature(feature, rows=8):
    """Show every channel of a feature map as one tile in a subplot grid.

    A leading dimension of size 1 is removed with ``squeeze(0)``; the first
    remaining dimension is then treated as the channel axis, and each channel
    is drawn via ``shou_tensor_img`` as a single-channel image.

    The grid has at most ``rows`` rows and as many columns as are needed to
    fit all channels (ceil division), so channel counts that are not a
    multiple of ``rows`` are still fully displayed. For channel counts that
    are a multiple of ``rows`` the layout is a ``rows x (channels / rows)``
    grid filled row by row.

    Args:
        feature: Feature-map tensor, presumably shaped (1, C, H, W) or
            (C, H, W) -- confirm against callers.
        rows: Maximum number of grid rows (passed to ``plt.subplot`` as
            ``nrows``). Values below 1 are treated as 1.
    """
    feature = feature.squeeze(0)
    channels = feature.shape[0]

    # Never ask for more rows than there are channels, and never fewer than 1.
    n_rows = max(1, min(rows, channels))
    # Ceil division so trailing channels are not dropped from the grid.
    n_cols = -(-channels // n_rows)

    for idx in range(channels):
        # subplot indices are 1-based and run row by row.
        plt.subplot(n_rows, n_cols, idx + 1)
        # Restore a channel dimension so the tile is a (1, H, W) image.
        shou_tensor_img(feature[idx].unsqueeze(0))
    plt.show()
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, requ

猜你喜欢

转载自blog.csdn.net/qq_40107571/article/details/127114252