MMSeg は、モデルの指定されたレイヤーのヒートマップ ヒート マップを描画します

最初の記事とその後の更新: https://mwhls.top/4475.html、画像なし/ディレクトリなし/形式エラー/詳細については、最初のページに移動して表示してください。
新しい更新については、mwhls.topを確認してください。
質問や批判は大歓迎です、どうもありがとうございました!

概要: モデルの指定したレイヤーのヒート マップを描画する

視覚環境のインストール

  • 利用可能な環境バージョン:
    • mmseg 1.0.0rc5
    • mmdet 3.0.0rc6
    • mmcv 2.0.0rc4
    • mmengine 0.6.0
    • Note: Don't overwrite it with files running in other versions. 怠けて自分のモデルを直接コピーしたかったので、最初は成功しませんでしたが、モデルは元のバージョンには存在するが新しいバージョンには存在しないメソッドを呼び出しました。バージョン、エラーになります。
  • 上記の環境をインストールし、発行コードを参照して正常に推論します。コードは次のとおりです。
    • 他にも featmap に言及している問題があります。mmseg GitHub で cam キーワードを検索するか、ここをクリックしてください。
import torch
import cv2
import numpy as np

from mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnorm

config_path = '../mmsegv2/configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = '../mmsegv2/checkpoints/pspnet_r50-d8_512x1024_80k_cityscapes_20200606_112131-2376f12b.pth'
img_path = '../mmsegv2/demo/demo.png'

register_all_modules()

model = init_model(config_path, checkpoint_path, device='cpu')
model = revert_sync_batchnorm(model)
vis = SegLocalVisualizer()


ori_img = cv2.imread(img_path)
img = torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)

logits = model(img)
out = vis.draw_featmap(logits[0], ori_img)

cv2.imshow('cam', out)
cv2.waitKey(0)

指定場所の可視化

  • 変更された可視化コード Startup.py
# Thank xiexinch: https://github.com/open-mmlab/mmsegmentation/issues/2434#issuecomment-1441392574
import torch
import cv2
import numpy as np
from mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnorm


# prefix = "mmsegmentation-1.0.0rc5/"
prefix = ""
config = prefix + r"log\7_ttpla_p2t_t_20k\ttpla_p2t_t_20k.py"
checkpoint = prefix + r"log\7_ttpla_p2t_t_20k\iter_8000.pth"

config = prefix + r"log\9_ttpla_r50_20k\ttpla_r50_20k.py"
checkpoint = prefix + r"log\9_ttpla_r50_20k\iter_8000.pth"

img_path = prefix + r"img.png"

def draw_heatmap(featmap):
    vis = SegLocalVisualizer()
    ori_img = cv2.imread(img_path)
    out = vis.draw_featmap(featmap, ori_img)
    cv2.imshow('cam', out)
    cv2.waitKey(0)

def generate_featmap(config, checkpoint, img_path):
    register_all_modules()

    model = init_model(config, checkpoint, device='cpu')
    model = revert_sync_batchnorm(model)
    vis = SegLocalVisualizer()

    ori_img = cv2.imread(img_path)
    img = torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)

    logits = model(img)
    out = vis.draw_featmap(logits[0], ori_img)

    cv2.imshow('cam', out)
    cv2.waitKey(0)

if __name__ == "__main__":
    generate_featmap(config, checkpoint, img_path)
  • 次のようにモデルで呼び出されますdraw_heatmap()
from Startup import draw_heatmap
draw_heatmap(x[0])
def forward(self, x):
    """Forward function."""
    from Startup import draw_heatmap
    draw_heatmap(x[0])
    if self.deep_stem:
        x = self.stem(x)
    else:
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
    x = self.maxpool(x)
    outs = []
    for i, layer_name in enumerate(self.res_layers):
        res_layer = getattr(self, layer_name)
        x = res_layer(x)
        if i in self.out_indices:
            outs.append(x)
        from Startup import draw_heatmap
        draw_heatmap(x[0])

    return tuple(outs)

結果を示す

Heatmap1.png Heatmap2.png Heatmap3.png Heatmap4.png Heatmap5.png Heatmap6.png

おすすめ

転載: blog.csdn.net/asd123pwj/article/details/129346884