『ignite』模型的训练过程

Trainer 基类(父类)

from typing import Mapping, Dict, Optional

import torch
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import TerminateOnNan, EarlyStopping
from ignite.metrics import Loss, RunningAverage

from src.exception import ModelNotFoundException
from src.experiment import Number


class Trainer(object):
    """Base class that wires a model into ignite ``Engine`` objects for training.

    Subclasses must implement :meth:`set_dataset` (which must set ``self.train_set``
    and ``self.test_set``), :meth:`create_trainer` (sets ``self.trainer``) and
    :meth:`create_evaluator` (sets ``self.evaluator``). :meth:`run` then attaches the
    shared handlers (NaN termination, progress bars, model saving on completion,
    periodic evaluation with early stopping) and starts training.

    Note: this class does not move ``self.model`` to ``self.device``; that is left
    to subclasses.
    """

    def __init__(self, *, model: torch.nn.Module = None, file: str = None, save: str = None, device: str = None):
        """
        :param model:   an already-constructed ``torch.nn.Module``; takes precedence over ``file``
        :param file:    path of a model stored with ``torch.save``; used only when ``model`` is None
        :param save:    path the whole model is written to (``torch.save``) when training completes
        :param device:  device string, e.g. "cuda" or "cpu"; auto-detected when None
        :raises ModelNotFoundException: if neither ``model`` nor ``file`` is given
        :raises ValueError:             if ``save`` is None
        """
        # Resolve the device before loading so a model read from ``file`` is mapped onto
        # the device actually used for training (with the raw ``device=None`` argument a
        # checkpoint saved on a GPU cannot be loaded on a CPU-only machine).
        self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")

        if model is not None:
            self.model = model
        elif file is not None:
            # torch.load unpickles arbitrary objects: only load files from trusted sources.
            # NOTE(review): newer torch releases default to weights_only=True, which may
            # reject a fully pickled Module — confirm against the installed torch version.
            self.model = torch.load(file, map_location=self.device)
        else:
            raise ModelNotFoundException("模型未定义,请传入 torch.nn.Module 对象或可加载的模型的文件路径.")

        if save is None:
            raise ValueError("模型存储路径未定义!")
        self.save: str = save

        # Evaluation metrics keyed by name. "MSE" must always be present because
        # score_function (used for early stopping) reads it.
        self.metrics: Dict = {"MSE": Loss(torch.nn.MSELoss())}

        # ignite engines; created by the subclass in create_trainer / create_evaluator.
        self.trainer: Optional[Engine] = None
        self.evaluator: Optional[Engine] = None

    def set_dataset(self, train_batch_size, val_batch_size=1) -> None:
        """Build the data loaders; implementations must set ``self.train_set`` and ``self.test_set``."""
        raise NotImplementedError("请重写 set_dataset.")

    def set_metrics(self, metric: Mapping) -> None:
        """
        Add or override evaluation metrics.

        :param metric: mapping of metric name -> ignite ``Metric``; merged into ``self.metrics``
        """
        self.metrics.update(metric)

    @staticmethod
    def score_function(engine: Engine) -> Number:
        """Early-stopping score: the negated evaluator MSE (EarlyStopping treats higher as better)."""
        return -engine.state.metrics["MSE"]

    def _require_engine(self, name: str) -> Engine:
        """Return ``self.<name>`` ("trainer" or "evaluator"), raising a clear error if it was not created yet."""
        engine = getattr(self, name)
        if engine is None:
            raise RuntimeError(f"self.{name} is not set; call create_{name}() first.")
        return engine

    def early_stop(self, every: int = 1, patience: int = 10, min_delta: float = 0,
                   output_transform=lambda x: {'MSE': torch.nn.MSELoss()(*x)}) -> None:
        """
        Periodically evaluate on the test set and stop training early when the
        test-set score stops improving.

        :param every:             run the evaluator every ``every`` training epochs
        :param patience:          number of evaluations without improvement before stopping
        :param min_delta:         minimum score increase that counts as an improvement
        :param output_transform:  maps the evaluator output to the values shown on its progress bar;
                                  the default assumes the output is a ``(y_pred, y)`` pair
        :raises RuntimeError:     if the trainer or evaluator engine has not been created
        """
        trainer = self._require_engine("trainer")
        evaluator = self._require_engine("evaluator")

        evaluator_bar_format = "\033[0;32m 测试集验证:{percentage:3.0f}%|{bar}{postfix} 【已执行时间:{elapsed},剩余时间:{remaining}】\033[0m"
        bar = ProgressBar(persist=True, bar_format=evaluator_bar_format)
        bar.attach(evaluator, output_transform=output_transform)

        # After each evaluation run, EarlyStopping compares score_function against the
        # best score so far and terminates ``trainer`` once patience is exhausted.
        handler = EarlyStopping(patience=patience, score_function=self.score_function,
                                trainer=trainer, min_delta=min_delta)
        evaluator.add_event_handler(Events.COMPLETED, handler)
        # self.test_set is read lazily, when the handler fires.
        # noinspection PyUnresolvedReferences
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=every), lambda: evaluator.run(self.test_set))

    def create_trainer(self) -> None:
        """Create the training engine and assign it to ``self.trainer``."""
        raise NotImplementedError("请重写 create_trainer.")

    def create_evaluator(self) -> None:
        """Create the evaluation engine and assign it to ``self.evaluator``."""
        raise NotImplementedError("请重写 create_evaluator.")

    def set_trainer(self):
        """
        Attach the shared training-side handlers to ``self.trainer``: NaN termination,
        a progress bar with a smoothed loss, model saving on completion and
        start/finish console messages.

        :raises RuntimeError: if the trainer engine has not been created
        """
        trainer = self._require_engine("trainer")

        # Stop training when the iteration output becomes NaN.
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

        # Console progress bar: coloured line with iteration counter, percentage,
        # elapsed/remaining time and the running-average loss as postfix.
        trainer_bar_format = "\033[0;34m{desc}【{n_fmt:0>5s}/{total_fmt:0>5s}】 {percentage:3.0f}%|{bar}{postfix} 【已执行时间:{elapsed},剩余时间:{remaining}】\033[0m"
        bar = ProgressBar(persist=True, bar_format=trainer_bar_format)

        # Exponential running average (alpha=0.98) of the engine output, exposed as the
        # "loss" metric; assumes the trainer output is the scalar loss.
        RunningAverage(output_transform=lambda x: x, alpha=0.98).attach(trainer, 'loss')
        bar.attach(trainer, metric_names=["loss"])

        # NOTE(review): depending on the ignite version, COMPLETED may also fire after
        # TerminateOnNan stops the run, in which case a diverged model is saved — confirm.
        trainer.add_event_handler(Events.COMPLETED, lambda: torch.save(self.model, self.save))
        trainer.add_event_handler(Events.COMPLETED, lambda: print("训练结束...."))
        trainer.add_event_handler(Events.STARTED, lambda: print("训练开始...."))

    def run(self, max_epochs, test_frequency=10) -> None:
        """
        Build the engines, attach all handlers and train.

        :param max_epochs:      number of training epochs
        :param test_frequency:  evaluate on the test set every ``test_frequency`` epochs
        :raises FileExistsError: if set_dataset has not been called (kept for backward compatibility)
        """
        if not (hasattr(self, "train_set") and hasattr(self, "test_set")):
            raise FileExistsError("请先通过 set_dataset 方法设置数据集.")
        self.create_trainer()
        self.create_evaluator()
        self.set_trainer()
        self.early_stop(every=test_frequency)
        # noinspection PyUnresolvedReferences
        self.trainer.run(self.train_set, max_epochs=max_epochs)


Trainer 的实现类(ConvLSTMTrainer)

import torch
from ignite.contrib.handlers import LRScheduler
from ignite.engine import create_supervised_trainer, Events, create_supervised_evaluator
from torch import nn, optim
from torch.optim.lr_scheduler import ExponentialLR

from src.data import get_data_loaders
from src.experiment.Trainer import Trainer
from src.model import ConvLSTM
from src.util import config
from src.util.patch import reshape_patch_back

# ConvLSTM hyper-parameters, loaded once at module import time. The __main__ script
# below indexes it like a mapping (cfg["in_channels"], cfg["hidden_channels_list"], ...).
cfg = config.load_model_parameters("ConvLSTM")


class ConvLSTMTrainer(Trainer):
    """Trainer for the ConvLSTM model.

    Uses ignite's supervised trainer/evaluator with MSE loss, Adam and an
    exponentially decaying learning rate. Evaluator outputs are passed through
    ``reshape_patch_back`` before metrics are computed.
    """

    # Hyper-parameters (previously hard-coded); override on a subclass or instance
    # before calling run().
    learning_rate: float = 0.01   # initial Adam learning rate
    lr_gamma: float = 0.98        # ExponentialLR decay factor per scheduler step
    patch_size: int = 4           # patch size passed to reshape_patch_back

    def __init__(self, *, model: torch.nn.Module = None, file: str = None, save: str = None, device: str = None):
        super().__init__(model=model, file=file, save=save, device=device)
        # Move to the *resolved* device: ``device`` may be None (then the base class
        # auto-detects "cuda"/"cpu"), and ``.to(None)`` would leave the model in place
        # while ignite moves each batch to ``self.device``, causing a device mismatch.
        self.model.to(self.device)

    def create_trainer(self) -> None:
        """Create the supervised training engine (MSE loss, Adam) with learning-rate decay."""
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

        self.trainer = create_supervised_trainer(model=self.model, optimizer=optimizer, loss_fn=criterion,
                                                 device=self.device)

        # NOTE(review): the scheduler is stepped on every ITERATION_COMPLETED, so the
        # learning rate is multiplied by ``lr_gamma`` once per batch, not per epoch —
        # confirm this fast decay is intended.
        step_scheduler = ExponentialLR(optimizer=optimizer, gamma=self.lr_gamma)
        self.trainer.add_event_handler(Events.ITERATION_COMPLETED, LRScheduler(step_scheduler))

    def create_evaluator(self) -> None:
        """Create the evaluation engine computing ``self.metrics`` on un-patched tensors."""
        patch_size = self.patch_size

        def to_unpatched(x, y, y_pred):
            # Emit (y_pred, y) after undoing the patch reshaping applied to the data;
            # presumably restores the original frame layout — see src.util.patch.
            return (reshape_patch_back(y_pred, patch_size=patch_size),
                    reshape_patch_back(y, patch_size=patch_size))

        self.evaluator = create_supervised_evaluator(model=self.model, metrics=self.metrics, device=self.device,
                                                     output_transform=to_unpatched)

    def set_dataset(self, train_batch_size, val_batch_size=1) -> None:
        """Load the ConvLSTM train/test loaders into ``self.train_set`` / ``self.test_set``."""
        self.train_set, self.test_set = get_data_loaders("ConvLSTM", train_batch_size, val_batch_size)

测试代码

if __name__ == '__main__':
    # Must match the patch size ConvLSTMTrainer uses in reshape_patch_back (4):
    # each patch_size x patch_size patch is folded into the channel dimension.
    patch_size = 4
    net = ConvLSTM(in_channels=cfg["in_channels"] * patch_size * patch_size,
                   hidden_channels_list=cfg["hidden_channels_list"],
                   kernel_size_list=cfg["kernel_size_list"], forget_bias=cfg["forget_bias"])
    # Fall back to CPU when CUDA is unavailable instead of failing on a hard-coded "cuda".
    device = "cuda" if torch.cuda.is_available() else "cpu"
    trainer = ConvLSTMTrainer(model=net, save="test.pth", device=device)
    trainer.set_dataset(train_batch_size=2, val_batch_size=1)
    trainer.run(max_epochs=3, test_frequency=1)

控制台显示

(原文此处为控制台输出截图,转载时图片未能保留。)

猜你喜欢

转载自 https://blog.csdn.net/dreaming_coder/article/details/108834238