asdTools-ReID热力图可视化

文章首发见博客:https://mwhls.top/4869.html
无图/格式错误/后续更新请见首发页。
更多更新请到mwhls.top查看
欢迎留言提问或批评建议,私信不回。

Github - 开源代码及Readme
Blog - 工具介绍

摘要:基于TorchCam实现ReID的热力图可视化的工具介绍。

场景
  • 想为科研添砖加瓦,因此需要一个好看的可视化。
  • 或者只是想了解一下模型的效果。
使用示例
    from torchvision.transforms.functional import to_pil_image
    from asdTools.Classes.Image.ImageBase import ImageBase
    from torchcam.utils import overlay_mask
    from torchcam.methods import GradCAM
    
    class VisualizeHeatmapOfReID(ImageBase):
        """ Sample: Sample/VisualizeHeatmapOfReID
        使用torch-cam可视化热力图,仅测试于ReID模型。
        Visualize heatmap by torch-cam, test only on ReID.
    
        torch-cam: `pip install torchcam` or `conda install -c frgfm torchcam`, GitHub: https://github.com/frgfm/torch-cam
        """
        def __init__(self, **kwargs) -> None:
            super().__init__(multipleFiles=True, **kwargs)
    
        def __call__(self, 
                model, 
                imgs_dir:str, 
                transform,
                img_ext:list=["png", "jpg", "jpeg"],
                torchCamMethod=GradCAM,
                device="cuda") -> str:
            self.run(model, imgs_dir, transform, img_ext, torchCamMethod, device)
    
        def run(self, model, imgs_dir:str, transform, img_ext:list, torchCamMethod, device) -> str:
            self.begining()
            # get paths of imgs from img_dir
            imgs_path = self.get_paths_from_dir(imgs_dir)
            self.log(f"{
      
      len(imgs_path)} files found in {
      
      imgs_dir}")
            imgs_path = self.filter_ext(imgs_path, img_ext)
            self.log(f"{
      
      len(imgs_path)} images found after filter extension by {
      
      img_ext}")
            # init torchCam
            model.to(device).eval()
            cam_extractor = torchCamMethod(model)
    
            for i, img_path in enumerate(imgs_path):
                # model(x)
                img = self.read_img(img_path)
                x = transform(img).unsqueeze(0).to(device)
                out = model(x)
                # visualize feature
                activation_map = cam_extractor(class_idx=0, scores=out.unsqueeze(0))[0]
                result = overlay_mask(img, to_pil_image(activation_map, mode='F'), alpha=0.5)
                # save img
                save_path = self.remove_root_of_path(path=img_path, root=imgs_dir)
                save_middle_dir = self.get_dir_of_file(save_path)
                save_name = self.get_name_of_file(save_path, True)
                save_path = self.save_image(result, output_middle_dir=save_middle_dir, output_file=save_name)
                self.log(f"{
      
      i+1}/{
      
      len(imgs_path)}: the heatmap of {
      
      img_path} has been saved to {
      
      save_path}.")
            self.done()
    
    if __name__ == "__main__":
        """ 
        ---------- usage example ---------- 
        -- you can simply insert VisualizeHeatmapOfReID() in the train.py or main.py
        -- and then set value of imgs_dir
    
        from asdTools.Tools.Image.VisualizeHeatmapOfReID import VisualizeHeatmapOfReID
        from models.img_resnet import ResNet50
        import data.img_transforms as T
    
        imgs_dir = "../../../../Datasets/PRCC"
    
        weight = "logs/0/baseline.pth.tar"
        checkpoint = torch.load(weight)
        model = ResNet50(config)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.cuda().eval()
    
        transform_test = T.Compose([
            T.Resize((256, 128)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    
        color_heatmap = VisualizeHeatmapOfReID()
        color_heatmap(model, imgs_dir, transform_test)
        """
    

猜你喜欢

转载自blog.csdn.net/asd123pwj/article/details/133678468