This article was first published at https://mwhls.top/4475.html, where later updates also appear. If images, the table of contents, or formatting are missing here, or you want more information, please read the original page. Please check mwhls.top
for new updates. Questions and criticism are welcome — thank you very much!
Summary: draw a heatmap of the feature map at a specified layer of a model.
Setting up the visualization environment
- Versions known to work:
- mmseg 1.0.0rc5
- mmdet 3.0.0rc6
- mmcv 2.0.0rc4
- mmengine 0.6.0
- Note: do not overwrite these files with ones from a setup that runs a different version. My first attempt failed because I lazily copied my model over directly. The model called a method that existed in the old version but not in the new one, which caused an error.
- After installing the environment above, run inference normally with the code from the issue, shown below:
- Other issues also discuss feature maps (featmap). Search the mmseg GitHub issues for the keyword "cam", or click here.
import torch
import cv2
import numpy as np
from mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnorm
# Minimal demo from the mmseg issue: run PSPNet on one image and overlay the
# output logits on it as a heatmap.
config_path = '../mmsegv2/configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = '../mmsegv2/checkpoints/pspnet_r50-d8_512x1024_80k_cityscapes_20200606_112131-2376f12b.pth'
img_path = '../mmsegv2/demo/demo.png'

# Read the image first so a bad path fails before the (slow) model load.
ori_img = cv2.imread(img_path)
if ori_img is None:
    # cv2.imread reports failure by returning None rather than raising.
    raise FileNotFoundError(f'Could not read image: {img_path}')

# Register mmseg modules so init_model can build the model from the config.
register_all_modules()
model = init_model(config_path, checkpoint_path, device='cpu')
# Swap SyncBatchNorm layers for plain BatchNorm (presumably required for
# single-process CPU inference).
model = revert_sync_batchnorm(model)
vis = SegLocalVisualizer()

# HxWxC uint8 array -> 1xCxHxW float32 tensor.
img = torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)
# NOTE(review): calling model(img) directly appears to skip the model's data
# preprocessor (mean/std normalization, channel order) — confirm whether the
# heatmap should be computed on preprocessed input.
with torch.no_grad():  # inference only: no autograd graph needed
    logits = model(img)
out = vis.draw_featmap(logits[0], ori_img)
cv2.imshow('cam', out)
cv2.waitKey(0)
Visualizing a specified location in the model
- Modified visualization code, saved as Startup.py:
# Thanks to xiexinch: https://github.com/open-mmlab/mmsegmentation/issues/2434#issuecomment-1441392574
import torch
import cv2
import numpy as np
from mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnorm
# Optional directory prepended to every path below; set it when the script is
# not launched from the project root, e.g. "mmsegmentation-1.0.0rc5/".
prefix = ""

# Forward slashes work on both Windows and POSIX; the previous backslash
# raw strings only resolved on Windows.
#
# Alternative run (p2t_t); uncomment these two lines to use it instead of r50:
# config = prefix + "log/7_ttpla_p2t_t_20k/ttpla_p2t_t_20k.py"
# checkpoint = prefix + "log/7_ttpla_p2t_t_20k/iter_8000.pth"
config = prefix + "log/9_ttpla_r50_20k/ttpla_r50_20k.py"
checkpoint = prefix + "log/9_ttpla_r50_20k/iter_8000.pth"
# Image used as the heatmap background (and as model input in __main__).
img_path = prefix + "img.png"
def draw_heatmap(featmap, image_path=None, *, window_name='cam', wait_ms=0):
    """Overlay a feature map on an image as a heatmap and display it.

    Meant to be imported and called from inside a model's ``forward`` to
    inspect an intermediate feature map.

    Args:
        featmap (torch.Tensor): Feature map handed to
            ``SegLocalVisualizer.draw_featmap``, e.g. ``x[0]`` of a batched
            activation (presumably C x H x W — confirm against draw_featmap).
        image_path (str, optional): Image to draw underneath the heatmap.
            Defaults to the module-level ``img_path``.
        window_name (str): OpenCV window title. Defaults to ``'cam'``.
        wait_ms (int): Passed to ``cv2.waitKey``; ``0`` blocks until a key
            is pressed.

    Raises:
        FileNotFoundError: If the background image cannot be read.
    """
    if image_path is None:
        image_path = img_path
    ori_img = cv2.imread(image_path)
    if ori_img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f'Could not read image: {image_path}')
    vis = SegLocalVisualizer()
    out = vis.draw_featmap(featmap, ori_img)
    cv2.imshow(window_name, out)
    cv2.waitKey(wait_ms)
def generate_featmap(config, checkpoint, img_path):
    """Run a segmentor on one image and show its output logits as a heatmap.

    Any ``draw_heatmap`` calls inserted into the model's ``forward`` are
    presumably triggered during ``model(img)`` below, before the final
    heatmap is shown. Those calls overlay on the module-level ``img_path``,
    not on this function's ``img_path`` argument.

    Args:
        config (str): Path to the mmseg config file.
        checkpoint (str): Path to the model checkpoint.
        img_path (str): Image used both as model input and heatmap background.

    Raises:
        FileNotFoundError: If ``img_path`` cannot be read.
    """
    # Read the image first so a bad path fails before the (slow) model load.
    ori_img = cv2.imread(img_path)
    if ori_img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f'Could not read image: {img_path}')

    register_all_modules()
    model = init_model(config, checkpoint, device='cpu')
    # Swap SyncBatchNorm for plain BatchNorm (presumably needed on CPU).
    model = revert_sync_batchnorm(model)
    vis = SegLocalVisualizer()

    # HxWxC uint8 array -> 1xCxHxW float32 tensor.
    img = torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)
    # NOTE(review): model(img) appears to bypass the model's data preprocessor
    # (normalization / channel order) — confirm this is intended.
    with torch.no_grad():  # inference only: no autograd graph needed
        logits = model(img)
    out = vis.draw_featmap(logits[0], ori_img)
    cv2.imshow('cam', out)
    cv2.waitKey(0)
if __name__ == "__main__":
    # Script entry point: visualize the configured checkpoint on the
    # module-level image path.
    generate_featmap(config=config, checkpoint=checkpoint, img_path=img_path)
- Call draw_heatmap() inside the model as follows: import it from Startup, then pass the single feature map you want to see.
# Snippet, not a standalone script: paste these two lines at the point inside
# a model's forward() where you want to see the feature map. `x` is the local
# tensor of that forward(); x[0] presumably selects the first sample of the
# batch.
from Startup import draw_heatmap
draw_heatmap(x[0])
def forward(self, x):
    """Forward function, instrumented with heatmap debugging.

    Shows a heatmap of the input and of each returned stage output via
    ``Startup.draw_heatmap`` (each call blocks until a key is pressed in the
    OpenCV window).

    Args:
        x: Input batch; ``x[0]`` (presumably the first sample) is what gets
            visualized.

    Returns:
        tuple: The stage outputs whose index is in ``self.out_indices``.
    """
    # Import once per call instead of before every visualization point.
    # Kept inside the function (presumably to avoid a circular import, since
    # Startup itself imports mmseg).
    from Startup import draw_heatmap

    draw_heatmap(x[0])  # visualize the raw network input
    if self.deep_stem:
        x = self.stem(x)
    else:
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
    # NOTE(review): indentation was lost in this copy; maxpool is placed after
    # the stem branch (not inside `else`) — confirm against the upstream
    # ResNet.forward.
    x = self.maxpool(x)
    outs = []
    for i, layer_name in enumerate(self.res_layers):
        res_layer = getattr(self, layer_name)
        x = res_layer(x)
        if i in self.out_indices:
            outs.append(x)
            # NOTE(review): the flattened source does not show whether this
            # call sat inside the `if` or after it; here only returned stages
            # are visualized — move it out one level to show every stage.
            draw_heatmap(x[0])
    return tuple(outs)