This article mainly introduces how to use pytorch to obtain the intermediate feature layer of the trained network and convert it into a simple method of heat map
renderings
1. Modify the original test code
import matplotlib.pyplot as plt
2. Write a hook function casually (for details, you can search for "what does the hook (Hook) in pytorch do?")
# 用于保存信息
output_list = []
input_list = []
# 定义hook方法(类似一个插件函数)
def forward_hook(module, data_input, data_output):
# 这里简单进行保存相关的特征层
# 也可以对特征层进行操作
input_list.append(data_input)
output_list.append(data_output)
3. Then register the hook function (register in the convolutional layer you need to save)
register_forward_hook works for forward pass
How to look up the name of the relevant convolutional layer, use model.named_parameters() for traversal lookup (this article will not explain in detail)
# model.结构.某个卷积层.register_forward_hook(forward_hook)
model.det_head.conv2.register_forward_hook(forward_hook)
4. Visualize
# 特征输出可视化
for i in range(6): # 可视化卷积相应的通道数量
# 以下绘制了一个宽度为6,高度为1的展示区域
plt.subplot(6, 1, i + 1)
plt.axis('off')
# 制定使用jet热力图展示,还有其他的展示形式
plt.imshow(output_list[0].data.cpu().numpy()[0, i, :, :], cmap='jet')
# 保存起来,无白边保存
plt.savefig(file_path, bbox_inches='tight', pad_inches=0) ## 保存图片
# 在批量操作时,每次都会弹出来
# plt.show()
5. Code display (incomplete, just show how to add relevant code in test.py)
# 1、导入包
import matplotlib.pyplot as plt
# 2、定义保存信息的数组
output_list = []
input_list = []
# 3、定义hook方法
def forward_hook(module, data_input, data_output):
input_list.append(data_input)
output_list.append(data_output)
# 主要代码如下!!!!
def test(test_loader, model, cfg):
model.eval()
# 4、进行注册hook方法
model.det_head.conv2.register_forward_hook(forward_hook)
for idx, data in enumerate(test_loader):
# 5、遍历操作、每次清空一下
output_list.clear()
input_list.clear()
print('Testing %d/%d\r' % (idx, len(test_loader)), flush=True, end='')
data.update(dict(cfg=cfg))
# 6、forward,hook函数会生效
with torch.no_grad():
outputs = model(**data)
# save result
image_name, _ = osp.splitext(
osp.basename(test_loader.dataset.img_paths[idx]))
rf.write_result(image_name, outputs)
# 7、生成热力图保存路径(自己按照自己的保存路径即可)
tmp_folder = cfg.test_cfg.result_path.replace('.zip', '_visualization')
file_name = '%s.jpg' % image_name
file_path = osp.join(tmp_folder, file_name)
# 8、特征输出可视化
for i in range(6): # 可视化了32通道
plt.subplot(1, 6, i + 1)
plt.axis('off')
plt.imshow(output_list[0].data.cpu().numpy()[0, i, :, :], cmap='jet')
# 9、保存中间特征层的热力图
plt.savefig(file_path, bbox_inches='tight', pad_inches=0) ## 保存图片
# plt.show() # 展示热力图,由于现在是批量操作,故注释
# 正常的模型加载操作,可忽略!!!按照自己的模型加载方法即可!!!
def main(args):
# 读取配置文件
cfg = Config.fromfile(args.config)
# data loader数据加载
data_loader = build_data_loader(cfg.data.test)
test_loader = torch.utils.data.DataLoader(
data_loader,
batch_size=1,
shuffle=False,
num_workers=0,
)
# 模型加载
model = build_model(cfg.model)
model = model.cuda()
# 加载预训练权重
checkpoint = torch.load(args.checkpoint, map_location='cpu')
d = dict()
for key, value in checkpoint['state_dict'].items():
tmp = key[7:]
d[tmp] = value
model.load_state_dict(d)
# test
test(test_loader, model, cfg)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Hyperparams')
parser.add_argument('config', help='config file path')
parser.add_argument('checkpoint', nargs='?', type=str, default=None)
parser.add_argument('--report_speed', action='store_true')
args = parser.parse_args()
main(args)