PyTorch Lightning Official Manual
Introduction to Pytorch Lightning
PyTorch Lightning is a deep learning framework for professional AI researchers and machine learning engineers who need maximum flexibility without sacrificing performance at scale. Lightning takes your idea from research paper to production at the same speed.
A LightningModule is a lightweight organizational layer on top of plain PyTorch that keeps maximum flexibility with minimal boilerplate. It acts as a model "recipe" that specifies all the training details.
Write 80% less code. Lightning removes about 80% of the duplicated code (boilerplate) to minimize the surface area for bugs so you can focus on delivering value instead of engineering.
Maintaining maximum flexibility, the complete PyTorch training code can be defined in training_step.
Handle datasets of any size with no special requirements: massive datasets are processed directly with the standard PyTorch DataLoader.
Install Lightning
pip install lightning
or
conda install lightning -c conda-forge
Import related packages after installation
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TestTubeLogger
Define LightningModule
LightningModule organizes your PyTorch code into 6 parts:
Initialization (__init__ and setup()).
Training (training_step())
Validation (validation_step())
Testing (test_step())
Prediction (predict_step())
Optimizer and LR scheduler (configure_optimizers())
When you use Lightning, the code isn't abstracted -- it's just organized. All other code not in LightningModule has been automatically executed for you by Trainer.
# Minimal Lightning workflow: instantiate the LightningModule and hand it to a
# Trainer, which runs the entire training loop for you.
net = MyLightningModuleNet()
trainer = Trainer()
trainer.fit(net)
No .cuda() or .to(device) calls are required. Lightning already does this for you. as follows:
# don't do in Lightning: no manual device placement is needed
x = torch.Tensor(2, 3)
x = x.cuda()
x = x.to(device)

# do this instead
x = x  # leave it alone!

# or to init a new tensor on the same device/dtype as an existing one
new_x = torch.Tensor(2, 3)
new_x = new_x.to(x)  # .to(tensor) matches that tensor's device and dtype
When running under the distributed strategy, Lightning handles distributed samplers for you by default.
# Don't do in Lightning... no manual DistributedSampler is needed
data = MNIST(...)
sampler = DistributedSampler(data)
DataLoader(data, sampler=sampler)

# do this instead -- Lightning injects the distributed sampler automatically
data = MNIST(...)
DataLoader(data)
LightningModule is actually a torch.nn.Module, but with some added features:
# Load weights from a checkpoint, freeze all parameters for inference, and
# call the module like a plain torch.nn.Module.
net = Net.load_from_checkpoint(PATH)
net.freeze()
out = net(x)
Example: Building and Training a Network with Lightning
1. Build the model
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import lightning.pytorch as pl
# NOTE(review): the training snippet below also uses MNIST and transforms,
# which come from torchvision (torchvision.datasets / torchvision.transforms).
class LitModel(pl.LightningModule):
    """Minimal LightningModule: one linear layer over flattened 28x28 images."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 10)

    def forward(self, x):
        # Flatten each image to a 784-dim vector before the linear layer.
        flat = x.view(x.size(0), -1)
        return torch.relu(self.l1(flat))

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self(inputs)
        return F.cross_entropy(logits, targets)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
2 Train the network
# Wrap the dataset in a standard PyTorch DataLoader -- no Lightning-specific
# loader is required.  (MNIST and transforms come from torchvision.)
train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
trainer = pl.Trainer(max_epochs=1)
model = LitModel()
trainer.fit(model, train_dataloaders=train_loader)
3. Other LightningModule hooks:
Name
Description
__init__
and setup()
initialization
forward()
Run data through model only (separate from training_step)
training_step()
Complete training steps
validation_step()
Complete Verification Steps
test_step()
complete test procedure
predict_step()
Complete Prediction Steps
configure_optimizers()
Define optimizer and LR scheduler
3.1 Lightning dataset loading
There are two implementation methods for datasets:
- Directly call third-party public datasets (such as: MNIST and other datasets)
- Custom dataset (inherit torch.utils.data.dataset.Dataset, custom class)
3.1.1 Using public datasets
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
class MyExampleModel(pl.LightningModule):
    """Example of loading a public dataset (MNIST) inside a LightningModule.

    The dataset is downloaded and split once in __init__; the *_dataloader
    hooks then wrap each split in a plain PyTorch DataLoader.
    """

    def __init__(self, args):
        super().__init__()
        # 60k MNIST training images split into 50k train / 5k val / 5k test.
        dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
        train_dataset, val_dataset, test_dataset = random_split(dataset, [50000, 5000, 5000])
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        ...

    def train_dataloader(self):
        # NOTE(review): self.batch_size is never assigned in the visible code;
        # it is presumably set in the elided "..." section -- confirm.
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1, shuffle=True)
3.1.2 Custom dataset
(1) Complete the compilation of dataset by yourself
# -*- coding: utf-8 -*-
'''
@Description: Define the format of data used in the model.
'''
import sys
import pathlib

import torch
from torch.utils.data import Dataset

from utils import sort_batch_by_len, source2ids

# Make the package directory importable when this file is run directly.
abs_path = pathlib.Path(__file__).parent.absolute()
# Fix: the original was sys.path.append(sys.path.append(abs_path)), which
# appends the inner call's return value (None) to sys.path.
sys.path.append(str(abs_path))
class SampleDataset(Dataset):
    """
    The class represents a sample set for training: a torch Dataset over
    (source_text, target_text) pairs for sequence-to-sequence models.
    """

    def __init__(self, data_pairs, vocab):
        # data_pairs: iterable of (source tokens, target tokens) pairs.
        self.src_texts = [data_pair[0] for data_pair in data_pairs]
        self.tgt_texts = [data_pair[1] for data_pair in data_pairs]
        self.vocab = vocab
        self._len = len(data_pairs)  # Keep track of how many data points.

    def __len__(self):
        return self._len

    def __getitem__(self, index):
        # Convert the source text self.src_texts[index] to vocabulary ids;
        # `oovs` collects the out-of-vocabulary tokens found in this sample.
        src_ids, oovs = source2ids(self.src_texts[index], self.vocab)
        item = {
            'x': [self.vocab.SOS] + src_ids + [self.vocab.EOS],
            'y': [self.vocab.SOS] + [self.vocab[i] for i in self.tgt_texts[index]] + [self.vocab.EOS],
            'x_len': len(self.src_texts[index]),
            'y_len': len(self.tgt_texts[index]),
            'oovs': oovs,
            'len_oovs': len(oovs)
        }
        return item
(2) Customize the DataModule class (inherited from LightningDataModule) to call DataLoader
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
class MyDataModule(pl.LightningDataModule):
    """Encapsulates dataset preparation, splitting and DataLoader creation."""

    def __init__(self):
        super().__init__()

    def prepare_data(self):
        # Implement dataset download etc. here.
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed (rank 0 / cuda:0)
        pass

    # Fix: the original snippet contained a stray, syntactically invalid
    # "def forward()" line here; a LightningDataModule has no forward() hook,
    # so the line was removed.

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        # Builds the datasets; every GPU runs this hook, and `stage` marks
        # which phase ('fit', 'test', ...) is being prepared.
        if stage == 'fit' or stage is None:
            self.train_dataset = MyDataset(self.train_file_path, self.train_file_num, transform=None)
            self.val_dataset = MyDataset(self.val_file_path, self.val_file_num, transform=None)
        if stage == 'test' or stage is None:
            self.test_dataset = MyDataset(self.test_file_path, self.test_file_num, transform=None)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1, shuffle=True)
3.2 Training
3.2.1Training Loop:
To activate the training loop, override training_step().
class LitClassifier(pl.LightningModule):
    """Wraps an arbitrary classification model for Lightning training."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        # `batch` is one batch sampled from train_dataloader and `batch_idx`
        # is its index; training_step must return the loss.
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss
3.2.2 Train Epoch-level Metrics:
If you want to compute epoch-level metrics and log them, use log().
def training_step(self, batch, batch_idx):
    """Training step that logs both per-step and per-epoch metrics."""
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss
3.2.3Train Epoch-level Operations
Override the on_train_epoch_end() method if you need to use all outputs from each training_step().
def __init__(self):
    super().__init__()
    # Accumulates per-step predictions so they can be used at epoch end.
    self.training_step_outputs = []

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...
    self.training_step_outputs.append(preds)
    return loss

def on_train_epoch_end(self):
    # Called once per epoch; has access to everything stored in training_step.
    all_preds = torch.stack(self.training_step_outputs)
    # do something with all preds
    ...
    self.training_step_outputs.clear()  # free memory
3.3 Validation
3.3.1 Validation Loop
To activate the validation loop during training, override the validation_step() function.
class LitModel(pl.LightningModule):
    """Demonstrates the validation hook: log a per-batch validation loss."""

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss)
It is also possible to run only the validation loop on the validation data loader by overriding validation_step() and calling validate().
# Run only the validation loop (no training) on the validation dataloader.
model = Model()
trainer = Trainer()
trainer.validate(model)
Validation on a single device is recommended to ensure each sample is evaluated exactly once. This helps ensure that research papers are benchmarked in the correct manner. Otherwise, in a multi-device setup, samples may be duplicated when using DistributedSampler, e.g. strategy="ddp": it replicates some samples across devices to ensure that all devices have the same batch size in case of uneven input.
3.3.2 Validation Epoch-level Metrics
Override the on_validation_epoch_end() function if you need to use all the outputs of each validation_step(). Note that this method is called before on_train_epoch_end().
def __init__(self):
    super().__init__()
    # Accumulates validation predictions across the epoch.
    self.validation_step_outputs = []

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    self.validation_step_outputs.append(pred)
    return pred

def on_validation_epoch_end(self):
    # Runs before on_train_epoch_end(); aggregate everything collected above.
    all_preds = torch.stack(self.validation_step_outputs)
    # do something with all preds
    ...
    self.validation_step_outputs.clear()  # free memory
3.4 Testing
3.4.1 Test Loop
The procedure for enabling a test loop is the same as for enabling a verification loop. See the above section for details. To do this, rewrite the test_step() function.
# Test after fitting; test() automatically loads the best weights for you.
model = Model()
trainer = Trainer()
trainer.fit(model)
trainer.test(model)
There are two ways to call test():
# (1) call after training -- automatically reloads the best weights from the
# previous fit() run
trainer = Trainer()
trainer.fit(model)
trainer.test(dataloaders=test_dataloader)

# (2) or call with a pretrained model restored from a checkpoint
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
trainer.test(model, dataloaders=test_dataloader)
As above, testing on a single device is recommended to ensure each sample is evaluated exactly once. This helps ensure that research papers are benchmarked in the correct manner. Otherwise, in a multi-device setup, samples may be duplicated when using DistributedSampler, e.g. strategy="ddp": it replicates some samples across devices to ensure that all devices have the same batch size in case of uneven input.
3.5 Inference
3.5.1 Prediction Loop
By default, the predict_step() method runs the forward() method. To customize this behavior, just override the predict_step() method. As follows, rewrite predict_step() and try Monte Carlo Dropout:
class LitMCdropoutModel(pl.LightningModule):
    """Monte Carlo Dropout inference: average several stochastic forward passes.

    Args:
        model: the wrapped network.
        mc_iteration: number of dropout-enabled forward passes to average.
    """

    def __init__(self, model, mc_iteration):
        super().__init__()
        self.model = model
        self.dropout = nn.Dropout()
        self.mc_iteration = mc_iteration

    def predict_step(self, batch, batch_idx):
        # enable Monte Carlo Dropout (keep the dropout layer active at inference)
        self.dropout.train()

        # Fix: the original referenced an undefined name `x`; the input to the
        # model is `batch`.
        # take average of `self.mc_iteration` iterations
        pred = torch.vstack([
            self.dropout(self.model(batch)).unsqueeze(0)
            for _ in range(self.mc_iteration)
        ]).mean(dim=0)
        return pred
There are two ways to call predict():
# (1) call after training -- automatically reloads the best weights from the
# previous fit() run
trainer = Trainer()
trainer.fit(model)
predictions = trainer.predict(dataloaders=predict_dataloader)

# (2) or call with a pretrained model restored from a checkpoint
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
predictions = trainer.predict(model, dataloaders=test_dataloader)
NOTE:
In older PyTorch Lightning versions, training_step was followed by the corresponding training_step_end(self, batch_parts) and training_epoch_end(self, training_step_outputs) hooks; validation_step by validation_step_end(self, batch_parts) and validation_epoch_end(self, validation_step_outputs); and test_step by test_step_end(self, batch_parts) and test_epoch_end(self, test_step_outputs). In Lightning 2.x these were replaced by the on_train_epoch_end / on_validation_epoch_end / on_test_epoch_end hooks used in the examples above.
3.6 Save the model with Trainer
Set the default_root_dir parameter in Trainer, and Lightning will automatically save the most recently trained epoch's model to the current working directory (os.getcwd()), or specify the directory when defining the Trainer:
# Checkpoints are written under this directory.
trainer = Trainer(default_root_dir='/your/path/to/save/checkpoints')
Autosave of models can also be turned off:
# Disable automatic checkpointing.  NOTE(review): checkpoint_callback=False is
# the legacy flag; recent pytorch-lightning versions use
# enable_checkpointing=False -- confirm against the installed version.
trainer = Trainer(checkpoint_callback=False)
3.7 Load the pre-trained model, the complete process
def main(hparams):
    """Build the NeRF system with checkpointing and logging, then launch training."""
    system = NeRFSystem(hparams)
    # Save a checkpoint every epoch (save_top_k=-1), ranked by val/psnr.
    checkpoint_callback = \
        ModelCheckpoint(filepath=os.path.join(f'ckpts/{hparams.exp_name}',
                                              '{epoch:d}'),
                        monitor='val/psnr',
                        mode='max',
                        save_top_k=-1)
    logger = TestTubeLogger(save_dir="logs",
                            name=hparams.exp_name,
                            debug=False,
                            create_git_tag=False,
                            log_graph=False)
    # NOTE(review): checkpoint_callback=, resume_from_checkpoint=,
    # weights_summary=, progress_bar_refresh_rate=, gpus= and
    # accelerator='ddp' are legacy (pre-2.0) Trainer arguments -- confirm the
    # installed pytorch-lightning version before reusing this snippet.
    trainer = Trainer(max_epochs=hparams.num_epochs,
                      checkpoint_callback=checkpoint_callback,
                      resume_from_checkpoint=hparams.ckpt_path,
                      logger=logger,
                      weights_summary=None,
                      progress_bar_refresh_rate=hparams.refresh_every,
                      gpus=hparams.num_gpus,
                      accelerator='ddp' if hparams.num_gpus>1 else None,
                      num_sanity_val_steps=1,
                      benchmark=True,
                      profiler="simple" if hparams.num_gpus==1 else None)
    trainer.fit(system)


if __name__ == '__main__':
    hparams = get_opts()
    main(hparams)
4 The complete example is as follows, NeRFW:
import os
from opt import get_opts
import torch
from collections import defaultdict
from torch.utils.data import DataLoader
from datasets import dataset_dict
# models
from models.nerf import *
from models.rendering import *
# optimizer, scheduler, visualization
from utils import *
# losses
from losses import loss_dict
# metrics
from metrics import *
# pytorch-lightning
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TestTubeLogger
class NeRFSystem(LightningModule):
    """NeRF-W training system: models, embeddings, data, losses and loops in one place."""

    def __init__(self, hparams):
        super().__init__()
        # NOTE(review): assigning to self.hparams directly only works on older
        # pytorch-lightning versions; newer ones require
        # self.hparams.update(hparams) (see the commented alternative below).
        self.hparams = hparams
        # self.hparams.update(hparams)
        self.loss = loss_dict['nerfw'](coef=1)
        self.models_to_train = []

        # Positional embeddings for 3D position (xyz) and view direction (dir).
        self.embedding_xyz = PosEmbedding(hparams.N_emb_xyz-1, hparams.N_emb_xyz)
        self.embedding_dir = PosEmbedding(hparams.N_emb_dir-1, hparams.N_emb_dir)
        self.embeddings = {'xyz': self.embedding_xyz,
                           'dir': self.embedding_dir}
        # Optional per-image appearance ('a') and transient ('t') embeddings
        # from NeRF in the Wild; only these (and the dict of NeRF networks)
        # are collected into models_to_train for the optimizer.
        if hparams.encode_a:
            self.embedding_a = torch.nn.Embedding(hparams.N_vocab, hparams.N_a)
            self.embeddings['a'] = self.embedding_a
            self.models_to_train += [self.embedding_a]
        if hparams.encode_t:
            self.embedding_t = torch.nn.Embedding(hparams.N_vocab, hparams.N_tau)
            self.embeddings['t'] = self.embedding_t
            self.models_to_train += [self.embedding_t]

        # The coarse network is always built; the fine network is added only
        # when hierarchical sampling is enabled (N_importance > 0).
        self.nerf_coarse = NeRF('coarse',
                                in_channels_xyz=6*hparams.N_emb_xyz+3,
                                in_channels_dir=6*hparams.N_emb_dir+3)
        self.models = {'coarse': self.nerf_coarse}
        if hparams.N_importance > 0:
            self.nerf_fine = NeRF('fine',
                                  in_channels_xyz=6*hparams.N_emb_xyz+3,
                                  in_channels_dir=6*hparams.N_emb_dir+3,
                                  encode_appearance=hparams.encode_a,
                                  in_channels_a=hparams.N_a,
                                  encode_transient=hparams.encode_t,
                                  in_channels_t=hparams.N_tau,
                                  beta_min=hparams.beta_min)
            self.models['fine'] = self.nerf_fine
        self.models_to_train += [self.models]

    def get_progress_bar_dict(self):
        # Drop the version number ("v_num") from the progress-bar display.
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items

    def forward(self, rays, ts):
        """Do batched inference on rays using chunk."""
        B = rays.shape[0]
        results = defaultdict(list)
        # Render in slices of hparams.chunk rays to bound peak GPU memory.
        for i in range(0, B, self.hparams.chunk):
            rendered_ray_chunks = \
                render_rays(self.models,
                            self.embeddings,
                            rays[i:i+self.hparams.chunk],
                            ts[i:i+self.hparams.chunk],
                            self.hparams.N_samples,
                            self.hparams.use_disp,
                            self.hparams.perturb,
                            self.hparams.noise_std,
                            self.hparams.N_importance,
                            self.hparams.chunk, # chunk size is effective in val mode
                            self.train_dataset.white_back)
            for k, v in rendered_ray_chunks.items():
                results[k] += [v]
        # Concatenate the per-chunk tensors back into full-length results.
        for k, v in results.items():
            results[k] = torch.cat(v, 0)
        return results

    def setup(self, stage):
        # Instantiate train/val datasets; the extra kwargs depend on which
        # dataset was requested.
        dataset = dataset_dict[self.hparams.dataset_name]
        kwargs = {'root_dir': self.hparams.root_dir}
        if self.hparams.dataset_name == 'phototourism':
            kwargs['img_downscale'] = self.hparams.img_downscale
            kwargs['val_num'] = self.hparams.num_gpus  # presumably one val image per GPU -- confirm
            kwargs['use_cache'] = self.hparams.use_cache
        elif self.hparams.dataset_name == 'blender':
            kwargs['img_wh'] = tuple(self.hparams.img_wh)
            kwargs['perturbation'] = self.hparams.data_perturb
        self.train_dataset = dataset(split='train', **kwargs)
        self.val_dataset = dataset(split='val', **kwargs)

    def configure_optimizers(self):
        # The optimizer is kept on self so training_step can log its
        # learning rate.
        self.optimizer = get_optimizer(self.hparams, self.models_to_train)
        scheduler = get_scheduler(self.hparams, self.optimizer)
        return [self.optimizer], [scheduler]

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          shuffle=True,
                          num_workers=4,
                          batch_size=self.hparams.batch_size,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          shuffle=False,
                          num_workers=4,
                          batch_size=1, # validate one image (H*W rays) at a time
                          pin_memory=True)

    def training_step(self, batch, batch_nb):
        rays, rgbs, ts = batch['rays'], batch['rgbs'], batch['ts']
        results = self(rays, ts)
        loss_d = self.loss(results, rgbs)
        loss = sum(l for l in loss_d.values())
        # PSNR is a monitoring metric only -- no gradients needed.
        with torch.no_grad():
            typ = 'fine' if 'rgb_fine' in results else 'coarse'
            psnr_ = psnr(results[f'rgb_{typ}'], rgbs)
        self.log('lr', get_learning_rate(self.optimizer))
        self.log('train/loss', loss)
        for k, v in loss_d.items():
            self.log(f'train/{k}', v, prog_bar=True)
        self.log('train/psnr', psnr_, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_nb):
        rays, rgbs, ts = batch['rays'], batch['rgbs'], batch['ts']
        # Validation batch_size is 1, so squeeze off the batch dimension.
        rays = rays.squeeze() # (H*W, 3)
        rgbs = rgbs.squeeze() # (H*W, 3)
        ts = ts.squeeze() # (H*W)
        results = self(rays, ts)
        loss_d = self.loss(results, rgbs)
        loss = sum(l for l in loss_d.values())
        log = {'val_loss': loss}
        typ = 'fine' if 'rgb_fine' in results else 'coarse'
        # For the first val batch only, log a GT / prediction / depth image strip.
        if batch_nb == 0:
            if self.hparams.dataset_name == 'phototourism':
                WH = batch['img_wh']
                W, H = WH[0, 0].item(), WH[0, 1].item()
            else:
                W, H = self.hparams.img_wh
            img = results[f'rgb_{typ}'].view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W)
            img_gt = rgbs.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W)
            depth = visualize_depth(results[f'depth_{typ}'].view(H, W)) # (3, H, W)
            stack = torch.stack([img_gt, img, depth]) # (3, 3, H, W)
            self.logger.experiment.add_images('val/GT_pred_depth',
                                              stack, self.global_step)
        psnr_ = psnr(results[f'rgb_{typ}'], rgbs)
        log['val_psnr'] = psnr_
        return log

    def validation_epoch_end(self, outputs):
        # Average the per-image metrics collected by validation_step.
        mean_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        mean_psnr = torch.stack([x['val_psnr'] for x in outputs]).mean()
        self.log('val/loss', mean_loss)
        self.log('val/psnr', mean_psnr, prog_bar=True)
def main(hparams):
    """Build the NeRF system with checkpointing and logging, then launch training."""
    system = NeRFSystem(hparams)
    # Save a checkpoint every epoch (save_top_k=-1), ranked by val/psnr.
    checkpoint_callback = \
        ModelCheckpoint(filepath=os.path.join(f'ckpts/{hparams.exp_name}',
                                              '{epoch:d}'),
                        monitor='val/psnr',
                        mode='max',
                        save_top_k=-1)
    logger = TestTubeLogger(save_dir="logs",
                            name=hparams.exp_name,
                            debug=False,
                            create_git_tag=False,
                            log_graph=False)
    # NOTE(review): checkpoint_callback=, resume_from_checkpoint=,
    # weights_summary=, progress_bar_refresh_rate=, gpus= and
    # accelerator='ddp' are legacy (pre-2.0) Trainer arguments -- confirm the
    # installed pytorch-lightning version before reusing this snippet.
    trainer = Trainer(max_epochs=hparams.num_epochs,
                      checkpoint_callback=checkpoint_callback,
                      resume_from_checkpoint=hparams.ckpt_path,
                      logger=logger,
                      weights_summary=None,
                      progress_bar_refresh_rate=hparams.refresh_every,
                      gpus=hparams.num_gpus,
                      accelerator='ddp' if hparams.num_gpus>1 else None,
                      num_sanity_val_steps=1,
                      benchmark=True,
                      profiler="simple" if hparams.num_gpus==1 else None)
    trainer.fit(system)


if __name__ == '__main__':
    hparams = get_opts()
    main(hparams)