Lightning in 15 minutes
Install PyTorch Lightning
pip install lightning
conda install lightning -c conda-forge
Define a LightningModule
A LightningModule wraps your PyTorch nn.Module objects and organizes how they are used during training (and, optionally, validation and testing).
The following is an example of an autoencoder for handwritten digits (MNIST):
import os
import torch
from torch import optim, nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning.pytorch as pl
# Define two models, an encoder and a decoder; these are plain PyTorch nn.Module objects.
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
# Define the LightningModule.
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # It is independent of forward.
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Log the loss (requires TensorBoard to be installed).
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        # Configure the optimizer.
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# Initialize the autoencoder.
autoencoder = LitAutoEncoder(encoder, decoder)
Define a dataset
Lightning supports any iterable form of dataset (DataLoader, numpy arrays, and others).
# Set up the data.
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)
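Since any iterable works, you can also split the dataset with standard PyTorch utilities. A minimal optional sketch (the 55000/5000 split sizes are just an illustrative choice for MNIST's 60000 training samples):
# Optional: split MNIST into train/validation sets.
seed = torch.Generator().manual_seed(42)
train_set, valid_set = utils.data.random_split(dataset, [55000, 5000], generator=seed)
train_loader = utils.data.DataLoader(train_set)
valid_loader = utils.data.DataLoader(valid_set)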
Train the model
The Lightning Trainer can combine any LightningModule with any dataset, and it abstracts away the engineering details needed to train at scale.
# Train the model.
trainer = pl.Trainer(limit_train_batches=100, max_epochs=10)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
The Trainer object also implements many commonly used routines for you:
- Epoch and batch iteration
- Calling optimizer.step(), loss.backward(), and optimizer.zero_grad()
- Calling model.eval() and disabling gradients during validation
- Model checkpoint saving and loading
- TensorBoard logging
- Multi-GPU support
- TPU support
- 16-bit mixed precision
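For example, to get the automatic model.eval() handling during validation, add a validation_step to the module and pass a validation loader to fit(). A minimal sketch, assuming the valid_loader from the optional split above:
class LitAutoEncoder(pl.LightningModule):
    # ... __init__, training_step, configure_optimizers as above ...
    def validation_step(self, batch, batch_idx):
        # Runs under model.eval() with gradients disabled, managed by the Trainer.
        x, y = batch
        x = x.view(x.size(0), -1)
        x_hat = self.decoder(self.encoder(x))
        val_loss = nn.functional.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=valid_loader)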
[Note]: Under Jupyter, multi-GPU training may raise an error; in that case, try running the script directly with python instead.
Use the model
After training the model, you can export it to ONNX or TorchScript and put it into production, or simply load the weights and run predictions.
# Load the model from a checkpoint.
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# Pick out the trained encoder.
encoder = autoencoder.encoder
encoder.eval()

# Embed some fake images.
fake_image_batch = torch.randn(8, 28 * 28).to(next(encoder.parameters()).device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (8 image embeddings):\n", embeddings, "\n", "⚡" * 20)
Visualize training
If TensorBoard is installed, you can use it to inspect the experiment logs:
tensorboard --logdir .
Additional training settings
# Train on 4 GPUs.
trainer = pl.Trainer(
    devices=4,
    accelerator="gpu",
)

# Train 1TB+ parameter models with DeepSpeed/FSDP.
trainer = pl.Trainer(
    devices=4,
    accelerator="gpu",
    strategy="deepspeed_stage_2",
    precision=16,
)

# 20+ helpful flags for rapid idea iteration.
trainer = pl.Trainer(
    max_epochs=10,
    min_epochs=5,
    overfit_batches=1,
)

# Access the latest state-of-the-art techniques, e.g. Stochastic Weight Averaging.
from lightning.pytorch.callbacks import StochasticWeightAveraging

trainer = pl.Trainer(callbacks=[StochasticWeightAveraging(...)])
Some flexible settings
Custom training loops
The LightningModule defines more than 20 hooks that can be overridden to customize the training process, for example:
class LitAutoEncoder(pl.LightningModule):
    def backward(self, loss):
        loss.backward()
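Other hooks let you run code at specific points in the loop. A minimal sketch using two more LightningModule hooks, on_train_start and on_train_epoch_end (the print statements are just illustrative):
class LitAutoEncoder(pl.LightningModule):
    def on_train_start(self):
        # Called once when training begins.
        print("Starting training on device:", self.device)

    def on_train_epoch_end(self):
        # Called at the end of every training epoch.
        print(f"Finished epoch {self.current_epoch}")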
Extend the Trainer
Recurring logic, such as the model-storage handling above, can be implemented in a pl.Callback object and attached to the Trainer.
It can be passed in as follows:
trainer = pl.Trainer(callbacks=[AWSCheckpoints()])
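AWSCheckpoints is not a built-in Lightning callback; a minimal sketch of what such a callback might look like, with the S3 upload step left as a hypothetical helper:
class AWSCheckpoints(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # Save a checkpoint locally after each epoch...
        path = f"epoch={trainer.current_epoch}.ckpt"
        trainer.save_checkpoint(path)
        # ...then upload it to S3 (upload_to_s3 is a hypothetical helper).
        # upload_to_s3(path, bucket="my-bucket")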