Pytorch's visualization scheme for intermediate feature layers

This article mainly introduces how to use pytorch to obtain the intermediate feature layer of the trained network and convert it into a simple method of heat map

renderings

1. Modify the original test code

import matplotlib.pyplot as plt

2. Write a hook function casually (for details, you can search for "what does the hook (Hook) in pytorch do?")

# 用于保存信息
output_list = []
input_list = []


# 定义hook方法(类似一个插件函数)
def forward_hook(module, data_input, data_output):
    # 这里简单进行保存相关的特征层
    # 也可以对特征层进行操作
    input_list.append(data_input)
    output_list.append(data_output)

3. Then register the hook function (register in the convolutional layer you need to save)

        register_forward_hook works for forward pass

        How to look up the name of the relevant convolutional layer, use model.named_parameters() for traversal lookup (this article will not explain in detail)

# model.结构.某个卷积层.register_forward_hook(forward_hook)
model.det_head.conv2.register_forward_hook(forward_hook)

4. Visualize

# 特征输出可视化
for i in range(6):  # 可视化卷积相应的通道数量
    # 以下绘制了一个宽度为6,高度为1的展示区域
    plt.subplot(6, 1, i + 1)
    plt.axis('off')
    # 制定使用jet热力图展示,还有其他的展示形式
    plt.imshow(output_list[0].data.cpu().numpy()[0, i, :, :], cmap='jet')
# 保存起来,无白边保存
plt.savefig(file_path, bbox_inches='tight', pad_inches=0)  ## 保存图片
# 在批量操作时,每次都会弹出来
# plt.show()  

5. Code display (incomplete, just show how to add relevant code in test.py)

# 1、导入包
import matplotlib.pyplot as plt


# 2、定义保存信息的数组
output_list = []
input_list = []


# 3、定义hook方法
def forward_hook(module, data_input, data_output):
    input_list.append(data_input)
    output_list.append(data_output)


# 主要代码如下!!!!
def test(test_loader, model, cfg):
    model.eval()
    
    # 4、进行注册hook方法
    model.det_head.conv2.register_forward_hook(forward_hook)

    for idx, data in enumerate(test_loader):
        # 5、遍历操作、每次清空一下
        output_list.clear()
        input_list.clear()
        print('Testing %d/%d\r' % (idx, len(test_loader)), flush=True, end='')
        data.update(dict(cfg=cfg))


        # 6、forward,hook函数会生效
        with torch.no_grad():
            outputs = model(**data)

        # save result
        image_name, _ = osp.splitext(
            osp.basename(test_loader.dataset.img_paths[idx]))
        rf.write_result(image_name, outputs)

        # 7、生成热力图保存路径(自己按照自己的保存路径即可)
        tmp_folder = cfg.test_cfg.result_path.replace('.zip', '_visualization')
        file_name = '%s.jpg' % image_name
        file_path = osp.join(tmp_folder, file_name)
        # 8、特征输出可视化
        for i in range(6):  # 可视化了32通道
            plt.subplot(1, 6, i + 1)
            plt.axis('off')
            plt.imshow(output_list[0].data.cpu().numpy()[0, i, :, :], cmap='jet')
        # 9、保存中间特征层的热力图
        plt.savefig(file_path, bbox_inches='tight', pad_inches=0)  ## 保存图片
        # plt.show()  # 展示热力图,由于现在是批量操作,故注释


# 正常的模型加载操作,可忽略!!!按照自己的模型加载方法即可!!!
def main(args):
    # 读取配置文件
    cfg = Config.fromfile(args.config)

    # data loader数据加载
    data_loader = build_data_loader(cfg.data.test)
    test_loader = torch.utils.data.DataLoader(
        data_loader,
        batch_size=1,
        shuffle=False,
        num_workers=0,
    )
    
    # 模型加载
    model = build_model(cfg.model)
    model = model.cuda()

    # 加载预训练权重
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    d = dict()
    for key, value in checkpoint['state_dict'].items():
        tmp = key[7:]
        d[tmp] = value
    model.load_state_dict(d)

    # test
    test(test_loader, model, cfg)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('config', help='config file path')
    parser.add_argument('checkpoint', nargs='?', type=str, default=None)
    parser.add_argument('--report_speed', action='store_true')
    args = parser.parse_args()

    main(args)

Guess you like

Origin blog.csdn.net/kb16045125/article/details/127495243