可视化卷积网络每一层 feature map 的代码

import os
import sys
import pdb
import logging
import time
import torch
import argparse
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

import options.options as option
import utils.util as util
from data.util import bgr2ycbcr
from data import create_dataset, create_dataloader
from models import create_model
from models.modules import block as B

import matplotlib.pyplot as plt

# Options: parse the test-time option file and create every output directory
# it lists, except the pretrained-model path (which is an input, not an output).
# The option file defaults to the original hard-coded path and can be
# overridden on the command line with `-opt <path>`.
_parser = argparse.ArgumentParser(description='Visualise per-layer feature maps.')
_parser.add_argument('-opt', type=str, default='options/test/test_sr.json',
                     help='Path to the test option JSON file.')
# parse_known_args() ignores unrelated arguments (e.g. ones injected by a
# notebook kernel) instead of exiting with an error.
_args, _ = _parser.parse_known_args()
opt = option.parse(_args.opt, is_train=False)
util.mkdirs(path for key, path in opt['path'].items() if key != 'pretrain_model_G')
opt = option.dict_to_nonedict(opt)


# Build one test dataloader per entry of opt['datasets'], visiting the entries
# in sorted-key order, and report how many images each dataset contains.
test_loaders = []
for _phase, dataset_opt in sorted(opt['datasets'].items()):
    test_set = create_dataset(dataset_opt)
    test_loader = create_dataloader(test_set, dataset_opt)
    dataset_label = dataset_opt['name']
    image_count = len(test_set)
    print('Number of test images in [{:s}]: {:d}'.format(dataset_label, image_count))
    test_loaders.append(test_loader)

# Create the model from the parsed options. The rest of the script relies on it
# exposing `netG`, `feed_data`, `test` and `get_current_visuals`.
model = create_model(opt)

# Register hook for feature map.
def save_feature(name, store=None):
    """Return a forward hook that records a module's output under ``name``.

    Args:
        name: Key under which the module's output is stored.
        store: Mapping to write into. Defaults to the module-level
            ``featuremap`` dict, which is looked up when the hook fires, so it
            may be defined after this function.

    Returns:
        A callable with the ``(module, inputs, output)`` signature expected by
        ``nn.Module.register_forward_hook``.

    The stored value is detached from the autograd graph, so holding on to it
    does not also keep the forward pass's graph alive. Tuple/list outputs are
    detached element-wise; non-tensor elements are stored unchanged.
    """
    def hook(_module, _inputs, output):
        if isinstance(output, torch.Tensor):
            output = output.detach()
        elif isinstance(output, (tuple, list)):
            detached = [o.detach() if isinstance(o, torch.Tensor) else o for o in output]
            output = tuple(detached) if isinstance(output, tuple) else detached
        target = featuremap if store is None else store
        target[name] = output
    return hook

# Set register hook.
# Class name (as returned by nn.Module._get_name()) of the modules whose
# outputs are captured. The original post redacted this value ('********'),
# which matched no module, so no hooks were registered and nothing was ever
# plotted. 'Conv2d' captures every 2-D convolution, in line with the goal of
# visualising each conv layer. Set it to a project block's class name (e.g.
# one from models.modules.block) to capture that block's output instead.
HOOK_MODULE_NAME = 'Conv2d'

conv_idx = 0
featuremap = OrderedDict()  # hook key -> captured output, filled during forward passes
hook_handles = []  # keep the handles so hooks can later be removed via handle.remove()

# The original accessed model.netG.module, which presumably means the generator
# is wrapped (e.g. in DataParallel); fall back to the bare network when it is not.
netG = getattr(model.netG, 'module', model.netG)
for m in netG.model.modules():
    if m._get_name() == HOOK_MODULE_NAME:
        conv_idx += 1
        # A forward hook runs after the module's forward() and receives its output.
        hook_handles.append(m.register_forward_hook(save_feature('conv_' + str(conv_idx))))

if conv_idx == 0:
    print('Warning: no module of type {!r} found; no feature maps will be captured.'.format(HOOK_MODULE_NAME))

# Run every test image through the network so the registered hooks fire.
# Each hook overwrites its entry in `featuremap`, so after this loop
# `featuremap` holds the activations of the LAST image processed. Likewise
# `dataset_dir` points at the LAST dataset's result folder; the plotting code
# writes its figures there.
for test_loader in test_loaders:
    test_set_name = test_loader.dataset.opt['name']
    dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
    util.mkdir(dataset_dir)

    # HR ground truth is fed only when the dataset config provides it. This
    # depends only on the dataset, so it is computed once per loader.
    need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

    for data in test_loader:
        model.feed_data(data, need_HR=need_HR)
        model.test()  # test; presumably runs the forward pass that triggers the hooks
        visuals = model.get_current_visuals(need_HR=need_HR)
        sr_img = util.tensor2img(visuals['SR'])  # uint8 SR output (not saved by this script)

from mpl_toolkits.axes_grid1 import AxesGrid

# At most GRID_ROWS * GRID_COLS channels are drawn per layer; extra channels
# are not plotted.
GRID_ROWS, GRID_COLS = 2, 16

# Save one image grid per captured layer as <dataset_dir>/<n>.png, where n
# counts the featuremap entries starting at 1.
for gwp, (k, v) in enumerate(featuremap.items(), start=1):
    # A hook may have stored a single tensor or a tuple/list of tensors;
    # in the latter case visualise the first tensor.
    out = v[0] if isinstance(v, (tuple, list)) else v
    if out.dim() == 4:
        out = out[0]  # first sample of the batch -> (C, H, W)
    vals = out.detach().float().cpu().numpy()
    if vals.ndim == 2:
        vals = vals[np.newaxis]  # single-channel map: restore the channel axis
    print(k, vals.shape)
    if vals.ndim != 3 or vals.shape[0] == 0:
        print('Skipping {}: expected a (C, H, W) feature map.'.format(k))
        continue

    # Use one colour scale for all panels so the single shared colour bar is
    # accurate. With per-panel autoscaling it would describe only the last panel.
    shown = vals[:GRID_ROWS * GRID_COLS]
    vmin, vmax = float(shown.min()), float(shown.max())

    fig = plt.figure(figsize=(15, 5))
    grid = AxesGrid(fig, 111,
                    nrows_ncols=(GRID_ROWS, GRID_COLS),
                    axes_pad=0.05,
                    share_all=True,
                    label_mode="L",
                    cbar_location="right",
                    cbar_mode="single",
                    )
    for val, ax in zip(shown, grid):
        im = ax.imshow(val, vmin=vmin, vmax=vmax)

    grid.cbar_axes[0].colorbar(im)
    for cax in grid.cbar_axes:
        cax.toggle_label(True)
    # plt.show()
    fig.savefig(os.path.join(dataset_dir, str(gwp) + '.png'), dpi=400, bbox_inches='tight', transparent=True)
    # Close the figure. Otherwise one figure per layer stays open, memory grows
    # with network depth, and matplotlib warns after 20 open figures.
    plt.close(fig)



# from mpl_toolkits.axes_grid1 import AxesGrid

# for k,v in featuremap.items():
#     vals = []
#     if isinstance(v, tuple):
#         for i in range(v[0].shape[1]):
#             vals.append(v[0].squeeze().float().cpu().numpy()[i])
            
#         hr = F.upsample(v[1], scale_factor=2, mode='nearest')
#         for i in range(hr.shape[1]):
#             vals.append(hr.squeeze().float().cpu().numpy()[i])
#     else:
#         for i in range(v.shape[1]):
#             vals = v.squeeze().float().cpu().numpy()
    
#     fig = plt.figure(figsize=(15,5))
#     grid = AxesGrid(fig, 111,
#                     nrows_ncols=(2, 8),
#                     axes_pad=0.05,
#                     share_all=True,
#                     label_mode="L",
#                     cbar_location="right",
#                     cbar_mode="single",
#                     )
#     for val, ax in zip(vals,grid):
#         im = ax.imshow(val, vmin=0, vmax=2)

#     grid.cbar_axes[0].colorbar(im)
#     for cax in grid.cbar_axes:
#         cax.toggle_label(True)

#     plt.show()
#     # fig.savefig(os.path.join(dataset_dir, k.split('/')[1] + '.png'), dpi=400, bbox_inches='tight', transparent=True)

运行代码后，卷积网络每一层输出的 feature map 会以图片形式保存到结果目录（`results_root/<测试集名称>`）下，据此可以对网络做进一步的分析。注意：每处理一张图像，feature map 都会被覆盖，因此最终保存的是最后一张测试图像的结果。

发布了208 篇原创文章 · 获赞 198 · 访问量 23万+

猜你喜欢

转载自：https://blog.csdn.net/gwplovekimi/article/details/101443866