[pytorch]可视化feature map

在计算机视觉的项目中,尤其是物体分类,关键点检测等的实验里,我们常常需要可视化中间的feature map来帮助判断我们的模型是否可以很好地提取到我们想要的特征,进而帮助我们调整模型或者参数。

可视化代码:

from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

def visualize_feature(x, model, layers=[0,1]):
    net = nn.Sequential(*list(model.children())[:layers[0]])
    img = net(x)
    transform1 = transforms.ToPILImage(mode='L')
    #img = torch.cpu().clone()
    for i in range(img.size(0)):
        image = img[i]
        #print(image.size())
        image = transform1(np.uint8(image.numpy().transpose(1,2,0)))
        image.show()
    

transform函数:

将Numpy的ndarray或者Tensor转化成PILImage类型【在数据类型上,两者都有明确的要求】

  1. ndarray的数据类型要求dtype=uint8, range[0, 255] and shape H x W x C
  2. Tensor 的shape为 C x H x W 要求是FloadTensor的,不允许DoubleTensor或者其他类型

numpy转为PIL:

#初始化随机数种子
np.random.seed(0)
 
data = np.random.randint(0, 255, 300)
print(data.dtype)
n_out = data.reshape(10,10,3)
 
#强制类型转换
n_out = n_out.astype(np.uint8)
print(n_out.dtype)
 
img2 = transforms.ToPILImage()(n_out)
img2.show()

tensor转为PIL:

t_out = torch.randn(3,10,10)
img1 = transforms.ToPILImage()(t_out)
img1.show()

训练过程中调用可视化函数

def train(epoch):
    cnn.train()
    for data in tqdm(train_loader, desc='Train: epoch {}'.format(epoch), leave=False, total=len(train_loader)):  # 对于训练集的每一个batch
        img, label = data
        if cuda_available:
            img = img.cuda()
            label = label.cuda()
        #visualize_feature(img, cnn)
 
    
        out = cnn( img )  # 送进网络进行输出
        #out = torch.nn.functional.softmax(out, dim=1)
        #print(out.size())
        #print(label.size())
        loss = loss_function( out, label )  # 获得损失
 
        optimizer.zero_grad()  # 梯度归零
        loss.backward()  # 反向传播获得梯度,但是参数还没有更新
        optimizer.step()  # 更新梯度

直接load预训练好的model并输出feature map

model = Residual_Model()
model.load_state_dict(torch.load('./model.pkl'))

output = get_features(model,x)## model是训练好的model,前面已经import 进来了Residual model
print('output.shape:',output.shape)
发布了10 篇原创文章 · 获赞 0 · 访问量 214

猜你喜欢

转载自blog.csdn.net/weixin_43844219/article/details/104482843