Anomalib实战之二:支持自定义模型

要将新的异常检测模型集成到Anomalib中,可以按照以下步骤进行操作:

1 Create a new sub-package

在anomalib/models中创建的一个新目录,用于存储与模型相关的文件。

./anomalib/models/<new-model>
├── __init__.py
├── config.yaml
├── torch_model.py
├── lightning_model.py
├── loss.py    # OPTIONAL
├── anomaly_map.py    # OPTIONAL
└── README.md

3 Create a config.yaml file.

config.yaml文件存储了所有的配置信息,包括数据和优化选项。下面是一个示例的yaml文件:

dataset:
    name: mvtec #options: [mvtec, btech, folder]
    format: mvtec
    ...
model:
    name: patchcore
    backbone: wide_resnet50_2
    ...
metrics:
    image:
        - F1Score
    ...
visualization:
    show_images: False # show images on the screen
    ...
# PL Trainer Args. Don't add extra parameter here.
trainer:
    accelerator: auto # <"cpu", "gpu", "tpu", "ipu", "hpu", "auto">
    ...

4 Create a torch_model.py file.

torch_model.py文件包含了继承自torch.nn.Module的torch模型实现,定义了模型的架构并执行基本的前向传播。将模型存储在一个独立的torch_model.py文件中的优势是,模型与anomalib的其他实现解耦,也可以在库之外使用。基本实现如下所示:

class NewModelModel(nn.Module):
    """New Model Module."""
    def __init__(self):
        pass
    def forward(self, x):
        pass

5 Create a lightning_model.py file.

lightning_model.py模块包含了继承自AnomalModule的lightning模型实现,AnomalModule已经具有与anomalib相关的属性和方法。用户不需要担心样板代码,只需要实现算法的训练和验证逻辑即可。

class NewModel(AnomalyModule):
    """PL Lightning Module for the New Model."""
    def __init__(self):
        super().__init__()
        pass
    def training_step(self, batch):
        pass
    ...
    def validation_step(self, batch):
        pass

6 [OPTIONAL] Create a loss.py file.

如果算法需要自定义的复杂损失函数,则需要实现loss.py文件。loss.py文件包含torch.nn.Module类实现的子类。然后,lightning模块将使用这个损失函数。

class NewModelLoss(nn.Module):
    """NewModel Loss."""

    def forward(self) -> Tensor:
        """Calculate the NewModel loss."""
        pass

7 [OPTIONAL] Create an anomaly_map.py file.

如果算法支持分割,那么可以实现这个模块,anomaly_map.py模块根据算法的能力以便逐像素地预测异常的位置。

class AnomalyMapGenerator(nn.Module):
    """Generate Anomaly Heatmap."""

    def __init__(self, input_size: Union[ListConfig, Tuple]):
        pass

    def forward(self, x: Tensor) -> Tensor:
        """Generate Anomaly Heatmap."""
        ...
        return anomaly_map

8 Create a README.md file.

编写readme

# Name of the Model

## Description
Brief description of the paper.

## Architecture
A diagram showing the high-level overview.

## Usage
python tools/train.py --model <newmodel>

## Benchmark
Benchmark results on MVTec categories.

猜你喜欢

转载自blog.csdn.net/shanglianlm/article/details/132845679
今日推荐