一行代码实现深度学习训练:inferno Pytorch 安装及介绍

inferno简介

Inferno是一个小库,提供了围绕PyTorch的实用程序和方便的函数/类。
其主要功能包括:

  • 一个基本的训练类,用来封装训练过程(迭代/epoch循环、验证和checkpoint创建)
  • 由networkx提供的用于构建复杂架构模型的图形API
  • 数据并行在多个gpu上更容易
  • Pytorch神经网络的子模块,模块级参数初始化
  • 数据预处理/变换的子模块
  • 支持Tensorboard
  • 一个回调API来支持与调参工程师的灵活交互
  • 未来将拥有更多功能

安装

Conda packages for python >= 3.6 for all distributions are availaible on conda-forge:

$ conda install -c pytorch -c conda-forge inferno

一个简单的示例

安装成功之后,可以直接运行下面的代码(来自官网 )。

import torch.nn as nn
from inferno.io.box.cifar import get_cifar10_loaders
from inferno.trainers.basic import Trainer
from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
from inferno.extensions.layers.convolutional import ConvELU2D
from inferno.extensions.layers.reshape import Flatten

# Fill these in:
LOG_DIRECTORY = '...'
SAVE_DIRECTORY = '...'
DATASET_DIRECTORY = '...'
DOWNLOAD_CIFAR = True
USE_CUDA = True

# Build torch model
model = nn.Sequential(
    ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    Flatten(),
    nn.Linear(in_features=(256 * 4 * 4), out_features=10),
    nn.LogSoftmax(dim=1)
)

# Load loaders
train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY,
                                                    download=DOWNLOAD_CIFAR)

# Build trainer
trainer = Trainer(model) \
  .build_criterion('NLLLoss') \
  .build_metric('CategoricalError') \
  .build_optimizer('Adam') \
  .validate_every((2, 'epochs')) \
  .save_every((5, 'epochs')) \
  .save_to_directory(SAVE_DIRECTORY) \
  .set_max_num_epochs(10) \
  .build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                  log_images_every='never'),
                log_directory=LOG_DIRECTORY)

# Bind loaders
trainer \
    .bind_loader('train', train_loader) \
    .bind_loader('validate', validate_loader)

if USE_CUDA:
  trainer.cuda()

# Go!
trainer.fit()

可视化方法

tensorboard --logdir="./" --port=6007

如果一切正常的话,会显示以下结果
在这里插入图片描述

以下是我目前(2020.11)总结的inferno(0.1.29)的各个模块的功能:

在这里插入图片描述
部分模块(上图画框部分)的源码介绍及使用见我的其他文章:
inferno Pytorch: inferno.io.transform 介绍及使用
inferno Pytorch: inferno.io.box.cifar下载cifar10 cifar100数据集 介绍及使用
inferno Pytorch: inferno.extensions.layers.convolutional 介绍及使用

参考链接
https://github.com/inferno-pytorch/inferno/

猜你喜欢

转载自blog.csdn.net/qq_36937684/article/details/110188199
今日推荐