最初の記事とその後の更新: https://mwhls.top/4475.html、画像なし/ディレクトリなし/形式エラー/詳細については、最初のページに移動して表示してください。
新しい更新については、mwhls.topを確認してください。
質問や批判は大歓迎です、どうもありがとうございました!
概要: モデルの指定したレイヤーのヒート マップを描画する
視覚環境のインストール
- 利用可能な環境バージョン:
- mmseg 1.0.0rc5
- mmdet 3.0.0rc6
- mmcv 2.0.0rc4
- mmengine 0.6.0
- Note: Don't overwrite it with files running in other versions. 怠けて自分のモデルを直接コピーしたかったので、最初は成功しませんでしたが、モデルは元のバージョンには存在するが新しいバージョンには存在しないメソッドを呼び出しました。バージョン、エラーになります。
- 上記の環境をインストールし、発行コードを参照して正常に推論します。コードは次のとおりです。
- 他にも featmap に言及している問題があります。mmseg GitHub で cam キーワードを検索するか、ここをクリックしてください。
import torch
import cv2
import numpy as np
from mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnorm
config_path = '../mmsegv2/configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = '../mmsegv2/checkpoints/pspnet_r50-d8_512x1024_80k_cityscapes_20200606_112131-2376f12b.pth'
img_path = '../mmsegv2/demo/demo.png'
register_all_modules()
model = init_model(config_path, checkpoint_path, device='cpu')
model = revert_sync_batchnorm(model)
vis = SegLocalVisualizer()
ori_img = cv2.imread(img_path)
img = torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)
logits = model(img)
out = vis.draw_featmap(logits[0], ori_img)
cv2.imshow('cam', out)
cv2.waitKey(0)
指定場所の可視化
- 変更された可視化コード Startup.py
# Thank xiexinch: https://github.com/open-mmlab/mmsegmentation/issues/2434#issuecomment-1441392574
import torch
import cv2
import numpy as np
from mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnorm
# prefix = "mmsegmentation-1.0.0rc5/"
prefix = ""
config = prefix + r"log\7_ttpla_p2t_t_20k\ttpla_p2t_t_20k.py"
checkpoint = prefix + r"log\7_ttpla_p2t_t_20k\iter_8000.pth"
config = prefix + r"log\9_ttpla_r50_20k\ttpla_r50_20k.py"
checkpoint = prefix + r"log\9_ttpla_r50_20k\iter_8000.pth"
img_path = prefix + r"img.png"
def draw_heatmap(featmap):
vis = SegLocalVisualizer()
ori_img = cv2.imread(img_path)
out = vis.draw_featmap(featmap, ori_img)
cv2.imshow('cam', out)
cv2.waitKey(0)
def generate_featmap(config, checkpoint, img_path):
register_all_modules()
model = init_model(config, checkpoint, device='cpu')
model = revert_sync_batchnorm(model)
vis = SegLocalVisualizer()
ori_img = cv2.imread(img_path)
img = torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)
logits = model(img)
out = vis.draw_featmap(logits[0], ori_img)
cv2.imshow('cam', out)
cv2.waitKey(0)
if __name__ == "__main__":
generate_featmap(config, checkpoint, img_path)
- 次のようにモデルで呼び出されます
draw_heatmap()
from Startup import draw_heatmap
draw_heatmap(x[0])
def forward(self, x):
"""Forward function."""
from Startup import draw_heatmap
draw_heatmap(x[0])
if self.deep_stem:
x = self.stem(x)
else:
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
x = self.maxpool(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
from Startup import draw_heatmap
draw_heatmap(x[0])
return tuple(outs)