『ignite』PyTorch好用的工具包

尽管 PyTorch 已经为我们实现神经网络提供了不少便利,但是人的惰性是无极限的,这里介绍一个进一步抽象的工具包——ignite,它将 PyTorch 训练过程更加简化了。

1. 安装

pip install pytorch-ignite

2. 基础示例

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

model = Net()
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
criterion = nn.NLLLoss()

trainer = create_supervised_trainer(model, optimizer, criterion)

val_metrics = {
    
    
    "accuracy": Accuracy(),
    "nll": Loss(criterion)
}
evaluator = create_supervised_evaluator(model, metrics=val_metrics)

@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(trainer):
    print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print("Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))

trainer.run(train_loader, max_epochs=100)

显然,这里先创建网络模型,Dataloader,优化器以及目标函数,然后用 ignite 的方法 create_supervised_trainer 和 create_supervised_evaluator 简化以往繁琐的循环写法,另外,ignite 还提供了面向切面的处理方法,可以在epoch、iteration等开始前、结束后位置执行你希望的操作

3. Engine

这是 ignite 的核心类,它是一种抽象,它在提供的数据上循环给定的次数,执行处理函数并返回结果

while epoch < max_epochs:
    # run an epoch on data
    data_iter = iter(data)
    while True:
        try:
            batch = next(data_iter)
            output = process_function(batch)
            iter_counter += 1
        except StopIteration:
            data_iter = iter(data)

        if iter_counter == epoch_length:
            break

因此,模型训练器只是一个引擎,它在训练数据集上循环多次并更新模型参数。例如:

def train_step(trainer, batch):
    model.train()
    optimizer.zero_grad()
    x, y = prepare_batch(batch)
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_step)
trainer.run(data, max_epochs=100)

【例 1】创建一个基本的训练器

def update_model(engine, batch):
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update_model)

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training(engine):
    batch_loss = engine.state.output
    lr = optimizer.param_groups[0]['lr']
    e = engine.state.epoch
    n = engine.state.max_epochs
    i = engine.state.iteration
    print("Epoch {}/{} : {} - batch loss: {}, lr: {}".format(e, n, i, batch_loss, lr))

trainer.run(data_loader, max_epochs=5)

【例 2】创建一个基本的评估器并计算指标

from ignite.metrics import Accuracy

def predict_on_batch(engine, batch)
    model.eval()
    with torch.no_grad():
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)

    return y_pred, y

evaluator = Engine(predict_on_batch)
Accuracy().attach(evaluator, "val_acc")
evaluator.run(val_dataloader)

【例 3】在训练数据集上计算图像均值/标准差

from ignite.metrics import Average

def compute_mean_std(engine, batch):
    b, c, *_ = batch['image'].shape
    data = batch['image'].reshape(b, c, -1).to(dtype=torch.float64)
    mean = torch.mean(data, dim=-1).sum(dim=0)
    mean2 = torch.mean(data ** 2, dim=-1).sum(dim=0)
    return {
    
    "mean": mean, "mean^2": mean2}

compute_engine = Engine(compute_mean_std)
img_mean = Average(output_transform=lambda output: output['mean'])
img_mean.attach(compute_engine, 'mean')
img_mean2 = Average(output_transform=lambda output: output['mean^2'])
img_mean2.attach(compute_engine, 'mean2')
state = compute_engine.run(train_loader)
state.metrics['std'] = torch.sqrt(state.metrics['mean2'] - state.metrics['mean'] ** 2)
mean = state.metrics['mean'].tolist()
std = state.metrics['std'].tolist()

【例 4】从状态恢复引擎的运行。用户可以加载state_dict并从加载的状态开始运行引擎

# Restore from an epoch
state_dict = {
    
    "epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)}
# or an iteration
# state_dict = {"iteration": 500, "max_epochs": 100, "epoch_length": len(data_loader)}

trainer = Engine(...)
trainer.load_state_dict(state_dict)
trainer.run(data)

Engine 对象还有以下方法:

  • terminate():向引擎发送终止信号,以便它在当前迭代之后完全终止运行。

  • terminate_epoch():向引擎发送终止信号,以便它在当前迭代之后终止当前epoch。

  • ignite.engine.create_supervised_trainer

    工厂功能,用于创建受监管模型的trainer。

    def create_supervised_trainer(
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        loss_fn: Union[Callable, torch.nn.Module],
        device: Optional[Union[str, torch.device]] = None,
        non_blocking: bool = False,
        prepare_batch: Callable = _prepare_batch,
        output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
        deterministic: bool = False,
    ) -> Engine:
    

    model:要训练的模型

    optimizer:要使用的优化器

    loss_fn:要使用的损失函数

    device:设备类型规范(默认值:无)Device can be CPU, GPU or TPU

    non_blocking:如果为True且此副本位于CPU和GPU之间,则该副本可能相对于主机异步发生。在其他情况下,此参数无效。

    prepare_batch:接收(batch,device,non_blocking)并输出张量元组(batch_x,batch_y)的函数

    output_transform:接收“ x”,“ y”,“ y_pred”,“ loss”并返回要分配给引擎状态的值的函数。每次迭代后输出。默认为returning loss.item()

    deterministic:如果为True,则返回类型为确定性的引擎DeterministicEngine,否则返回 Engine (默认值:False)

  • 类似地还有ignite.engine.create_supervised_evaluator,其参数少于trainer

    def create_supervised_evaluator(
        model: torch.nn.Module,
        metrics: Optional[Dict[str, Metric]] = None,
        device: Optional[Union[str, torch.device]] = None,
        non_blocking: bool = False,
        prepare_batch: Callable = _prepare_batch,
        output_transform: Callable = lambda x, y, y_pred: (y_pred, y),
    ) -> Engine:
    

    model:训练好的模型

    metrics:指标名称到指标的映射

    device:设备类型规范(默认值:无)Device can be CPU, GPU or TPU

    output_transform:接收“ x”,“ y”,“ y_pred” 并在每次迭代后返回要分配给引擎state.output的值的函数。默认为返回值(y_pred,y,),它适合度量期望的输出。如果更改它,则应在指标中使用output_transform

【例 5】断点恢复训练

有可能从一个检查点恢复训练,并大致重现原来的运行行为。使用Ignite,这可以通过使用检查点处理程序轻松完成。引擎提供了两个方法来序列化和反序列化其内部状态state_dict()和load_state_dict()。除了序列化模型,优化器,lr调度器等用户可以存储培训器,然后恢复培训。例如

from ignite.handlers import Checkpoint, DiskSaver

trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...
data_loader = ...

to_save = {
    
    'trainer': trainer, 
           'model': model, 
           'optimizer': optimizer, 
           'lr_scheduler': lr_scheduler}

handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
trainer.run(data_loader, max_epochs=100)

然后,我们可以从最后一个检查点恢复训练。

from ignite.handlers import Checkpoint

trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...
data_loader = ...

to_load = {
    
    'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
checkpoint = torch.load(checkpoint_file)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

trainer.run(train_loader, max_epochs=100)

4. Events & Handlers

为了提高 Engine 灵活性,引入了一个事件系统,该系统促进了运行的每个步骤之间的交互:

  • engine is started/completed
  • epoch is started/completed
  • batch iteration is started/completed

详细的事件可以进ignite.engine.events

下面展示了 Enginerun() 方法执行的细节:

fire_event(Events.STARTED)
while epoch < max_epochs:
    fire_event(Events.EPOCH_STARTED)
    # run once on data
    for batch in data:
        fire_event(Events.ITERATION_STARTED)

        output = process_function(batch)

        fire_event(Events.ITERATION_COMPLETED)
    fire_event(Events.EPOCH_COMPLETED)
fire_event(Events.COMPLETED)

上述代码展示了各个事件执行的位置

使用事件的方法又2种:add_event_handler()装饰器 on

trainer = Engine(update_model)

trainer.add_event_handler(Events.STARTED, lambda _: print("Start training"))

# or
@trainer.on(Events.STARTED)
def on_training_started(engine):
    print("Another message of start training")
    
# or even simpler, use only what you need !
@trainer.on(Events.STARTED)
def on_training_started():
    print("Another message of start training")

# attach handler with args, kwargs
mydata = [1, 2, 3, 4]

def on_training_ended(data):
    print("Training is ended. mydata={}".format(data))

trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)

利用add_event_handler()方法还可以动态添加事件

model = ...
train_loader, validation_loader, test_loader = ...

trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model, metrics={
    
    "acc": Accuracy()})

def log_metrics(engine, title):
    print("Epoch: {} - {} accuracy: {:.2f}"
           .format(trainer.state.epoch, title, engine.state.metrics["acc"]))

@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):
    with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "train"):
        evaluator.run(train_loader)

    with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "validation"):
        evaluator.run(validation_loader)

    with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "test"):
        evaluator.run(test_loader)

trainer.run(train_loader, max_epochs=100)

还可以将事件处理程序配置为以用户模式调用:每第n个事件一次,或使用自定义事件过滤功能:

model = ...
train_loader, validation_loader, test_loader = ...

trainer = create_supervised_trainer(model, optimizer, loss)

@trainer.on(Events.ITERATION_COMPLETED(every=50))
def log_training_loss_every_50_iterations():
    print("{} / {} : {} - loss: {:.2f}"
          .format(trainer.state.epoch, trainer.state.max_epochs, trainer.state.iteration, trainer.state.output))

@trainer.on(Events.EPOCH_STARTED(once=25))
def do_something_once_on_25_epoch():
    # do something

def custom_event_filter(engine, event):
    if event in [1, 2, 5, 10, 50, 100]:
        return True
    return False

@engine.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
def call_on_special_event(engine):
     # do something on 1, 2, 5, 10, 50, 100 iterations

trainer.run(train_loader, max_epochs=100)

也可以自定义Events:

class CustomEvents(EventEnum):
    """
    Custom events defined by user
    """
    CUSTOM_STARTED = 'custom_started'
    CUSTOM_COMPLETED = 'custom_completed'

engine.register_events(*CustomEvents)

可以同时对某个handler设置多个events:

events = Events.STARTED | Events.COMPLETED | Events.ITERATION_STARTED(every=3)
engine = ...

@engine.on(events)
def call_on_events(engine):
    # do something

这些事件可用于附加任何处理程序,并使用触发fire_event()

@engine.on(CustomEvents.CUSTOM_STARTED)
def call_on_custom_event(engine):
     # do something

@engine.on(Events.STARTED)
def fire_custom_events(engine):
     engine.fire_event(CustomEvents.CUSTOM_STARTED)

Handlers 函数的参数不一定非得是engine,不涉及可以空参,可以多个其他参数

也可以允许将事件过滤器传递给引擎:

engine = Engine()

# a) custom event filter
def custom_event_filter(engine, event):
    if event in [1, 2, 5, 10, 50, 100]:
        return True
    return False

@engine.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
def call_on_special_event(engine):
    # do something on 1, 2, 5, 10, 50, 100 iterations

# b) "every" event filter
@engine.on(Events.ITERATION_STARTED(every=10))
def call_every(engine):
    # do something every 10th iteration

# c) "once" event filter
@engine.on(Events.ITERATION_STARTED(once=50))
def call_once(engine):
    # do something on 50th iteration

5. 内置Handlers

库提供了一组内置处理程序,用于检查训练流水线,保存最佳模型,在没有改进的情况下停止训练,使用实验跟踪系统等。可以在以下两个模块中找到它们:

  • ignite.handlers
  • ignite.contrib.handlers

一些类可以简单地添加Engine为可调用函数。例如,

from ignite.handlers import TerminateOnNan

trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

另外还提供了attach()方法,咋程序执行中手动的添加handles给Engine

from ignite.contrib.handlers.tensorboard_logger import *

# Create a logger
tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")

# Attach the logger to the trainer to log model's weights as a histogram after each epoch
tb_logger.attach(
    trainer,
    event_name=Events.EPOCH_COMPLETED,
    log_handler=WeightsHistHandler(model)
)

6.State

State 是用来存储 Engine 的输出结果的,每一个Engine对象都有 State 属性

  • engine.state.seed: Seed to set at each data “epoch”.
  • engine.state.epoch: Number of epochs the engine has completed. Initializated as 0 and the first epoch is 1.
  • engine.state.iteration: Number of iterations the engine has completed. Initialized as 0 and the first iteration is 1.
  • engine.state.max_epochs: Number of epochs to run for. Initializated as 1.
  • engine.state.output: The output of the process_function defined for the Engine.
  • etc

其他的可在技术文档里查找

在下面的代码中,engine.state.output 将存储批次损失。此输出用于打印每次迭代的损失。

def update(engine, batch):
    x, y = batch
    y_pred = model(inputs)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

def on_iteration_completed(engine):
    iteration = engine.state.iteration
    epoch = engine.state.epoch
    loss = engine.state.output
    print("Epoch: {}, Iteration: {}, Loss: {}".format(epoch, iteration, loss))

trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)

在下面的代码中,engine.state.output将是已处理批次的损耗列表y_pred,y。如果要连接Accuracy到引擎,则需要output_transform来从engine.state.output获取y_pred和y

def update(engine, batch):
    x, y = batch
    y_pred = model(inputs)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), y_pred, y

trainer = Engine(update)

@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):
    epoch = engine.state.epoch
    loss = engine.state.output[0]
    print ('Epoch {epoch}: train_loss = {loss}'.format(epoch=epoch, loss=loss))

accuracy = Accuracy(output_transform=lambda x: [x[1], x[2]])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)

与上面类似,但是这次process_function的输出是处理后的批次的损耗字典y_pred,y,这是用户可以使用output_transform从engine.state.output获取y_pred和y的方式

def update(engine, batch):
    x, y = batch
    y_pred = model(inputs)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return {
    
    'loss': loss.item(),
            'y_pred': y_pred,
            'y': y}

trainer = Engine(update)

@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):
    epoch = engine.state.epoch
    loss = engine.state.output['loss']
    print ('Epoch {epoch}: train_loss = {loss}'.format(epoch=epoch, loss=loss))

accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)

优良作法是State还用作存储在更新或处理程序函数中创建的用户数据。例如,我们想将new_attribute保存为state:

def user_handler_function(engine):
 engine.state.new_attribute = 12345

7. Metrics

库提供了各种机器学习任务的现成指标列表。支持两种计算指标的方式:1)在线和2)存储整个输出历史记录

指标可以附加到 Engine:

from ignite.metrics import Accuracy

accuracy = Accuracy()

accuracy.attach(evaluator, "accuracy")

state = evaluator.run(validation_data)

print("Result:", state.metrics)
# > {"accuracy": 0.12345}

猜你喜欢

转载自blog.csdn.net/dreaming_coder/article/details/108548798