Getting Started with PyTorch Lightning Basics

Lightning in 15 minutes

Install PyTorch Lightning

pip install lightning

conda install lightning -c conda-forge

Define a LightningModule

A LightningModule wraps a PyTorch nn.Module together with its training process (and, optionally, validation and testing).

The following example defines an autoencoder for handwritten digit images (MNIST):

import os
import torch
from torch import optim, nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning.pytorch as pl

# Define two models, an encoder and a decoder.
# These are plain PyTorch nn.Module objects.
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # the training step; it is defined independently of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # log metrics (requires TensorBoard to be installed)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        # set up the optimizer
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# initialize the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)
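
configure_optimizers can also return learning-rate schedulers alongside the optimizers. Here is a minimal sketch; the LitAutoEncoderSched name is hypothetical and the StepLR settings are chosen only for illustration:

class LitAutoEncoderSched(LitAutoEncoder):
    # hypothetical subclass: same model, but with a learning-rate scheduler
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
        return [optimizer], [scheduler]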

Define a dataset

Lightning supports any iterable form of dataset (DataLoader, numpy arrays, and others).

# set up the MNIST training data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)
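
As a minimal sketch of the claim above, data that already lives in numpy arrays can be wrapped in a DataLoader as well (the random arrays here are placeholders):

import numpy as np
from torch.utils.data import DataLoader, TensorDataset

# placeholder data standing in for a real numpy dataset
features = torch.from_numpy(np.random.rand(100, 28 * 28).astype("float32"))
labels = torch.from_numpy(np.zeros(100, dtype="int64"))
numpy_loader = DataLoader(TensorDataset(features, labels), batch_size=32)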

Train the model

The Lightning Trainer combines any LightningModule with any dataset and abstracts away the engineering details.

# train the model
trainer = pl.Trainer(limit_train_batches=100, max_epochs=10)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

The Trainer also implements many commonly used procedures:

  1. Epoch and batch iteration.
  2. Calling optimizer.step(), loss.backward(), and optimizer.zero_grad().
  3. Calling model.eval() and disabling gradients during validation (see the sketch after this list).
  4. Model checkpoint saving and loading.
  5. TensorBoard logging.
  6. Multi-GPU support.
  7. TPU support.
  8. 16-bit (mixed) precision training.
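
As referenced in item 3, here is a minimal sketch of how validation plugs in. The LitAutoEncoderVal name and the 55k/5k MNIST split are illustrative choices, not from the original post:

from torch.utils.data import random_split

class LitAutoEncoderVal(LitAutoEncoder):
    def validation_step(self, batch, batch_idx):
        # the Trainer calls model.eval() and disables grads around this hook
        x, y = batch
        x = x.view(x.size(0), -1)
        x_hat = self.decoder(self.encoder(x))
        self.log("val_loss", nn.functional.mse_loss(x_hat, x))

train_set, val_set = random_split(dataset, [55000, 5000])
trainer = pl.Trainer(limit_train_batches=100, max_epochs=10)
trainer.fit(
    LitAutoEncoderVal(encoder, decoder),
    utils.data.DataLoader(train_set),
    utils.data.DataLoader(val_set),
)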

[Note]: Multi-GPU training may raise errors when run inside Jupyter; in that case, try running the code directly as a python script.

Use the model

After training, you can export the model to ONNX or TorchScript for production, or simply load the trained weights and run predictions.

# load the model from a checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# pull out the trained encoder
encoder = autoencoder.encoder
encoder.eval()

# embed a batch of fake images
fake_image_batch = torch.randn(8, 28 * 28).to(next(encoder.parameters()).device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

Visualize training

If TensorBoard is installed, you can use it to inspect the experiment as it runs:

tensorboard --logdir .

Additional training settings

# train on 4 GPUs
trainer = pl.Trainer(
    devices=4,
    accelerator="gpu",
)

# train models with 1TB+ of parameters using DeepSpeed or FSDP
trainer = pl.Trainer(
    devices=4,
    accelerator="gpu",
    strategy="deepspeed_stage_2",
    precision=16,
)

# 20+ helpful flags for rapid idea iteration
trainer = pl.Trainer(
    max_epochs=10,
    min_epochs=5,
    overfit_batches=1,
)

# access the latest state-of-the-art techniques
trainer = pl.Trainer(callbacks=[StochasticWeightAveraging(...)])
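
For instance, a runnable version of the snippet above using stochastic weight averaging; the swa_lrs value is an arbitrary choice for illustration:

from lightning.pytorch.callbacks import StochasticWeightAveraging

# swa_lrs sets the learning rate used during the SWA phase
trainer = pl.Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])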

Flexible customization

Custom Training Loops

LightningModule exposes more than 20 hooks (HOOKs) that can be overridden to customize the training process, for example:

class LitAutoEncoder(pl.LightningModule):
    def backward(self, loss):
        # override Lightning's default backward pass
        loss.backward()
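
As a further illustrative sketch, two more hooks; on_train_start and on_train_epoch_end are real LightningModule hooks, while the print bodies are placeholders:

class LitAutoEncoder(pl.LightningModule):
    def on_train_start(self):
        # runs once, before the first epoch
        print("training is starting")

    def on_train_epoch_end(self):
        # runs after every training epoch
        print(f"finished epoch {self.current_epoch}")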

Extend Trainer

Engineering behavior such as how model checkpoints are stored can be implemented in pl.Callback objects and plugged into the Trainer.

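The original post showed the callback code only as an image; as a stand-in, here is a minimal hypothetical sketch. The AWSCheckpoints class name and the upload step are assumptions, not part of the Lightning API:

from lightning.pytorch.callbacks import Callback

class AWSCheckpoints(Callback):
    # hypothetical callback: push the best checkpoint somewhere after validation
    def on_validation_end(self, trainer, pl_module):
        best = trainer.checkpoint_callback.best_model_path
        print("would upload:", best)  # replace with your storage client's upload call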

It can then be passed to the Trainer as follows:

trainer = pl.Trainer(callbacks=[AWSCheckpoints()])
