PyTorch Lightning Usage: LightningModule, LightningDataModule, Trainer, ModelCheckpoint

PyTorch Lightning official manual: Welcome to ⚡ PyTorch Lightning — PyTorch Lightning 2.1.0dev documentation
https://lightning.ai/docs/pytorch/latest/

Introduction to PyTorch Lightning

PyTorch Lightning is a deep learning framework for professional AI researchers and machine learning engineers who need maximum flexibility without sacrificing performance at scale. Lightning takes your idea to paper and to production at the same speed.

A LightningModule is a lightweight organization of plain PyTorch code that keeps maximum flexibility with minimal overhead. It acts as a model "recipe" that specifies all the training details.

Write 80% less code. Lightning removes about 80% of the duplicated code (boilerplate) to minimize the surface area for bugs so you can focus on delivering value instead of engineering.

Flexibility is preserved: the complete PyTorch training logic lives in training_step().

Datasets of any size are handled without special requirements; standard PyTorch DataLoaders can be used directly, even for massive datasets.

 Install Lightning

pip install lightning

 or

conda install lightning -c conda-forge

 After installation, import the related packages (the imports below are the ones used by the later examples):

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TestTubeLogger

 Define LightningModule

LightningModule organizes your PyTorch code into 6 parts:

Initialization (__init__ and setup()).
Training (training_step())
Validation (validation_step())
Testing (test_step())
Prediction (predict_step())
Optimizer and LR scheduler (configure_optimizers())

When you use Lightning, the code isn't abstracted away -- it's just organized. Everything that is not in the LightningModule is executed for you automatically by the Trainer.

net = MyLightningModuleNet()
trainer = Trainer()
trainer.fit(net)

 No .cuda() or .to(device) calls are required; Lightning handles device placement for you:

# don't do in Lightning
x = torch.Tensor(2, 3)
x = x.cuda()
x = x.to(device)

# do this instead
x = x  # leave it alone!

# or to init a new tensor
new_x = torch.Tensor(2, 3)
new_x = new_x.to(x)
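
Inside a LightningModule, new tensors can also be created directly on whatever device the module currently lives on via self.device; the module and shapes below are an illustrative sketch, not part of the original example:

import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class DeviceAwareModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(8, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # self.device follows wherever Lightning moved the module (CPU/GPU/TPU),
        # so no explicit .cuda()/.to(device) call is needed
        noise = torch.randn(x.size(0), 8, device=self.device)
        y_hat = self.l1(x + noise)
        return F.cross_entropy(y_hat, y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)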

When running under a distributed strategy, Lightning handles the DistributedSampler for you by default.

# Don't do in Lightning...
data = MNIST(...)
sampler = DistributedSampler(data)
DataLoader(data, sampler=sampler)

# do this instead
data = MNIST(...)
DataLoader(data)

 LightningModule is actually a torch.nn.Module, but with some added features:

net = Net.load_from_checkpoint(PATH)
net.freeze()
out = net(x)

Example: Building and Training a Network with Lightning

1. Build the model

import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning.pytorch as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

2. Train the network

import os
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
trainer = pl.Trainer(max_epochs=1)
model = LitModel()

trainer.fit(model, train_dataloaders=train_loader)

3. The main LightningModule hooks:

Name                      Description

__init__ and setup()      Initialization
forward()                 Run data through the model only (kept separate from training_step)
training_step()           The complete training step
validation_step()         The complete validation step
test_step()               The complete test step
predict_step()            The complete prediction step
configure_optimizers()    Define optimizers and LR schedulers
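
Putting the table together, a minimal skeleton that implements every hook might look like the sketch below (the linear layer, losses and optimizer are placeholders):

import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class LitSkeleton(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        # inference-only path, kept separate from training_step
        return self.layer(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", F.cross_entropy(self(x), y))

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log("test_loss", F.cross_entropy(self(x), y))

    def predict_step(self, batch, batch_idx):
        x, _ = batch
        return self(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)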

3.1 Lightning dataset loading

There are two ways to provide a dataset:

  • Use a ready-made public dataset (e.g. MNIST)
  • Write a custom dataset (a class that inherits from torch.utils.data.Dataset)

3.1.1 Using public datasets

import os
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import pytorch_lightning as pl

class MyExampleModel(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
        train_dataset, val_dataset, test_dataset = random_split(dataset, [50000, 5000, 5000])

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        ...

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1, shuffle=True)
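
Because the dataloaders are methods of the LightningModule here, nothing extra has to be passed to fit(); a short usage sketch (args and max_epochs are placeholders, and __init__ is assumed to set self.batch_size):

model = MyExampleModel(args)       # args as produced by your own argument parser
trainer = pl.Trainer(max_epochs=5)

trainer.fit(model)    # picks up train_dataloader()/val_dataloader() from the module
trainer.test(model)   # picks up test_dataloader() from the module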

3.1.2 Custom dataset

(1) Write the Dataset class yourself

# -*- coding: utf-8 -*-
'''
@Description: Define the format of data used in the model.
'''

import sys
import pathlib
import torch
from torch.utils.data import Dataset
from utils import sort_batch_by_len, source2ids

abs_path = pathlib.Path(__file__).parent.absolute()
sys.path.append(str(abs_path))


class SampleDataset(Dataset):
    """
    The class represents a sample set for training.
    """

    def __init__(self, data_pairs, vocab):
        self.src_texts = [data_pair[0] for data_pair in data_pairs]
        self.tgt_texts = [data_pair[1] for data_pair in data_pairs]
        self.vocab = vocab
        self._len = len(data_pairs)  # Keep track of how many data points.

    def __len__(self):
        return self._len
        
    def __getitem__(self, index):
        # print("\nself.src_texts[{0}] = {1}".format(index, self.src_texts[index]))
        src_ids, oovs = source2ids(self.src_texts[index], self.vocab)  # convert self.src_texts[index] to ids; oovs collects tokens outside the vocabulary
        item = {
            'x': [self.vocab.SOS] + src_ids + [self.vocab.EOS],
            'y': [self.vocab.SOS] + [self.vocab[i] for i in self.tgt_texts[index]] + [self.vocab.EOS],
            'x_len': len(self.src_texts[index]),
            'y_len': len(self.tgt_texts[index]),
            'oovs': oovs,
            'len_oovs': len(oovs)
        }

        return item
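
Since each item contains variable-length token lists, a collate_fn is normally passed to the DataLoader to pad a batch to a common length (the original project imports sort_batch_by_len for this); the pad_collate below is only a hypothetical sketch of that idea:

import torch
from torch.utils.data import DataLoader


def pad_collate(batch, pad_id=0):
    # hypothetical collate_fn: pad 'x' and 'y' to the longest sequence in the batch
    max_x = max(len(item['x']) for item in batch)
    max_y = max(len(item['y']) for item in batch)
    x = torch.tensor([item['x'] + [pad_id] * (max_x - len(item['x'])) for item in batch])
    y = torch.tensor([item['y'] + [pad_id] * (max_y - len(item['y'])) for item in batch])
    return {'x': x, 'y': y, 'oovs': [item['oovs'] for item in batch]}


# usage (assuming `pairs` and `vocab` already exist):
# loader = DataLoader(SampleDataset(pairs, vocab), batch_size=8, shuffle=True, collate_fn=pad_collate)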

(2) Define a custom DataModule class (inheriting from LightningDataModule) that builds the DataLoaders

from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl


class MyDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

    def prepare_data(self):
        # download, split, etc...
        # only called on one process (e.g. cuda:0 / rank 0) in distributed training,
        # so dataset downloads typically go here; don't assign state in this hook
        pass

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP: each GPU runs this, and `stage`
        # marks which phase ('fit', 'test', ...) the data is being set up for
        if stage == 'fit' or stage is None:
            self.train_dataset = MyDataset(self.train_file_path, self.train_file_num, transform=None)
            self.val_dataset = MyDataset(self.val_file_path, self.val_file_num, transform=None)
        if stage == 'test' or stage is None:
            self.test_dataset = MyDataset(self.test_file_path, self.test_file_num, transform=None)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1, shuffle=True)
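
A DataModule like this is simply handed to the Trainer, which then calls prepare_data(), setup() and the *_dataloader() hooks at the right time; a minimal usage sketch (LitModel stands in for your own LightningModule):

dm = MyDataModule()
model = LitModel()
trainer = pl.Trainer(max_epochs=10)

trainer.fit(model, datamodule=dm)    # uses train_dataloader() / val_dataloader()
trainer.test(model, datamodule=dm)   # uses test_dataloader()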

3.2 Training

3.2.1 Training Loop

To activate the training loop, override training_step().

class LitClassifier(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss  # must return the loss; `batch` is one batch sampled from train_dataloader and `batch_idx` is its index

3.2.2 Train Epoch-level Metrics:

If you want to compute epoch-level metrics and log them, use log().

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)

    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss
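
For metrics other than the loss (e.g. accuracy), a common pattern is to keep a torchmetrics object on the module and log it with on_epoch=True so Lightning aggregates it correctly across steps and devices; a sketch assuming torchmetrics is installed and a 10-class problem:

import torchmetrics


def __init__(self):
    super().__init__()
    self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)


def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)

    self.train_acc(y_hat, y)  # update the running metric state
    # logging the metric object lets Lightning compute and reset it per epoch
    self.log("train_acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
    return loss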

3.2.3 Train Epoch-level Operations

Override the on_train_epoch_end() method if you need to use all outputs from each training_step().

def __init__(self):
    super().__init__()
    self.training_step_outputs = []


def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...
    self.training_step_outputs.append(preds)
    return loss


def on_train_epoch_end(self):
    all_preds = torch.stack(self.training_step_outputs)
    # do something with all preds
    ...
    self.training_step_outputs.clear()  # free memory

3.3 Validation

3.3.1 Validation Loop

To activate the validation loop during training, override the validation_step() function.

class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss)

It is also possible to run only the validation loop on the validation dataloaders by overriding validation_step() and calling trainer.validate().

model = Model()
trainer = Trainer()
trainer.validate(model)

Validation on a single device is recommended to ensure each sample is evaluated exactly once; this helps ensure research results are benchmarked correctly. Otherwise, in a multi-device setup that uses a DistributedSampler (e.g. strategy="ddp"), some samples may be duplicated: the sampler replicates samples across devices so that every device sees the same batch size when the input size is uneven.
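
A common pattern is therefore to train with DDP and run the final evaluation with a second, single-device Trainer; a sketch (the accelerator/devices values and val_dataloader are illustrative):

# train on multiple GPUs
trainer = Trainer(accelerator="gpu", devices=4, strategy="ddp")
trainer.fit(model)

# evaluate on exactly one device so no sample is duplicated
eval_trainer = Trainer(accelerator="gpu", devices=1)
eval_trainer.validate(model, dataloaders=val_dataloader)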

3.3.2 Validation Epoch-level Metrics

Override the on_validation_epoch_end() function if you need to use all the outputs of each validation_step(). Note that this method is called before on_train_epoch_end().

def __init__(self):
    super().__init__()
    self.validation_step_outputs = []


def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    self.validation_step_outputs.append(pred)
    return pred


def on_validation_epoch_end(self):
    all_preds = torch.stack(self.validation_step_outputs)
    # do something with all preds
    ...
    self.validation_step_outputs.clear()  # free memory

3.4 Testing

3.4.1 Test Loop

The procedure for enabling the test loop is the same as for the validation loop (see the section above). To enable it, override the test_step() function.

model = Model()
trainer = Trainer()
trainer.fit(model)

# automatically loads the best weights for you
trainer.test(model)

There are two ways to call test():

# call after training
trainer = Trainer()
trainer.fit(model)

# automatically auto-loads the best weights from the previous run
trainer.test(dataloaders=test_dataloader)

# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
trainer.test(model, dataloaders=test_dataloader)
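
In recent Lightning versions, trainer.test() can also be pointed at a specific checkpoint via ckpt_path; for example, ckpt_path="best" picks the best checkpoint tracked by a ModelCheckpoint callback (a short sketch):

trainer = Trainer()
trainer.fit(model)

# evaluate the best checkpoint tracked by the ModelCheckpoint callback
trainer.test(model, dataloaders=test_dataloader, ckpt_path="best")

# or evaluate an explicit checkpoint file
trainer.test(model, dataloaders=test_dataloader, ckpt_path="path/to/model.ckpt")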

As above, evaluation on a single device is recommended to ensure each sample is tested exactly once; this helps ensure research results are benchmarked correctly. Otherwise, in a multi-device setup that uses a DistributedSampler (e.g. strategy="ddp"), some samples may be duplicated: the sampler replicates samples across devices so that every device sees the same batch size when the input size is uneven.

3.5 Inference

3.5.1 Prediction Loop

By default, predict_step() simply calls forward(). To customize this behavior, override predict_step(); the example below overrides it to run Monte Carlo Dropout:

class LitMCdropoutModel(pl.LightningModule):
    def __init__(self, model, mc_iteration):
        super().__init__()
        self.model = model
        self.dropout = nn.Dropout()
        self.mc_iteration = mc_iteration

    def predict_step(self, batch, batch_idx):
        # enable Monte Carlo Dropout
        self.dropout.train()

        x = batch  # if your predict dataloader yields (inputs, targets), unpack accordingly
        # take the average over `self.mc_iteration` stochastic forward passes
        pred = torch.vstack([self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]).mean(dim=0)
        return pred

There are two ways to call predict():

# call after training
trainer = Trainer()
trainer.fit(model)

# automatically auto-loads the best weights from the previous run
predictions = trainer.predict(dataloaders=predict_dataloader)

# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
predictions = trainer.predict(model, dataloaders=test_dataloader)

NOTE (Lightning < 2.0 API):

training_step has the companion hooks training_step_end(self, batch_parts) and training_epoch_end(self, training_step_outputs);

validation_step has the companion hooks validation_step_end(self, batch_parts) and validation_epoch_end(self, validation_step_outputs);

test_step has the companion hooks test_step_end(self, batch_parts) and test_epoch_end(self, test_step_outputs).

In Lightning 2.x these *_step_end/*_epoch_end hooks were removed; use on_train_epoch_end(), on_validation_epoch_end() and on_test_epoch_end() as shown above instead.

 3.6 Save the model with Trainer

Lightning automatically saves a checkpoint for the most recent training epoch into the current working directory (os.getcwd()). To change the location, set the default_root_dir parameter when creating the Trainer:

trainer = Trainer(default_root_dir='/your/path/to/save/checkpoints')

Automatic checkpointing can also be turned off (this uses the Lightning 1.x argument; the sketch below shows the current API):

trainer = Trainer(checkpoint_callback=False)
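
In Lightning 2.x, checkpointing is controlled with enable_checkpointing and a ModelCheckpoint callback that takes dirpath/filename instead of filepath; a hedged sketch of the newer API (paths and the monitored metric are illustrative):

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(dirpath="ckpts/my_exp",
                                filename="{epoch:d}",
                                monitor="val/psnr",
                                mode="max",
                                save_top_k=-1)   # keep a checkpoint for every epoch

trainer = Trainer(callbacks=[checkpoint_cb])

# or disable automatic checkpointing entirely
trainer = Trainer(enable_checkpointing=False)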

3.7 Loading a pre-trained model: the complete workflow

def main(hparams):
    system = NeRFSystem(hparams)
    checkpoint_callback = \
        ModelCheckpoint(filepath=os.path.join(f'ckpts/{hparams.exp_name}',
                                               '{epoch:d}'),
                        monitor='val/psnr',
                        mode='max',
                        save_top_k=-1)

    logger = TestTubeLogger(save_dir="logs",
                            name=hparams.exp_name,
                            debug=False,
                            create_git_tag=False,
                            log_graph=False)

    trainer = Trainer(max_epochs=hparams.num_epochs,
                      checkpoint_callback=checkpoint_callback,
                      resume_from_checkpoint=hparams.ckpt_path,
                      logger=logger,
                      weights_summary=None,
                      progress_bar_refresh_rate=hparams.refresh_every,
                      gpus=hparams.num_gpus,
                      accelerator='ddp' if hparams.num_gpus>1 else None,
                      num_sanity_val_steps=1,
                      benchmark=True,
                      profiler="simple" if hparams.num_gpus==1 else None)

    trainer.fit(system)


if __name__ == '__main__':
    hparams = get_opts()
    main(hparams)
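
Note that main() above targets an older (roughly 1.x) Lightning API: ModelCheckpoint(filepath=...), checkpoint_callback=, resume_from_checkpoint=, weights_summary=, progress_bar_refresh_rate=, gpus= and TestTubeLogger have since been removed or renamed. On Lightning 2.x the same setup would look roughly like the sketch below; it is an approximation, not the original project's code:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger


def main(hparams):
    system = NeRFSystem(hparams)

    checkpoint_callback = ModelCheckpoint(dirpath=f'ckpts/{hparams.exp_name}',
                                          filename='{epoch:d}',
                                          monitor='val/psnr',
                                          mode='max',
                                          save_top_k=-1)

    logger = TensorBoardLogger(save_dir="logs", name=hparams.exp_name)

    trainer = Trainer(max_epochs=hparams.num_epochs,
                      callbacks=[checkpoint_callback],
                      logger=logger,
                      accelerator='gpu',
                      devices=hparams.num_gpus,
                      strategy='ddp' if hparams.num_gpus > 1 else 'auto',
                      num_sanity_val_steps=1,
                      benchmark=True,
                      profiler="simple" if hparams.num_gpus == 1 else None)

    # resuming from a checkpoint moved from the Trainer constructor to fit()
    trainer.fit(system, ckpt_path=hparams.ckpt_path)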

4. The complete example (NeRF-W) is as follows:

import os
from opt import get_opts
import torch
from collections import defaultdict

from torch.utils.data import DataLoader
from datasets import dataset_dict

# models
from models.nerf import *
from models.rendering import *

# optimizer, scheduler, visualization
from utils import *

# losses
from losses import loss_dict

# metrics
from metrics import *

# pytorch-lightning
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TestTubeLogger


class NeRFSystem(LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        # self.hparams.update(hparams)
        self.loss = loss_dict['nerfw'](coef=1)

        self.models_to_train = []
        self.embedding_xyz = PosEmbedding(hparams.N_emb_xyz-1, hparams.N_emb_xyz)
        self.embedding_dir = PosEmbedding(hparams.N_emb_dir-1, hparams.N_emb_dir)
        self.embeddings = {'xyz': self.embedding_xyz,
                           'dir': self.embedding_dir}

        if hparams.encode_a:
            self.embedding_a = torch.nn.Embedding(hparams.N_vocab, hparams.N_a)
            self.embeddings['a'] = self.embedding_a
            self.models_to_train += [self.embedding_a]
        if hparams.encode_t:
            self.embedding_t = torch.nn.Embedding(hparams.N_vocab, hparams.N_tau)
            self.embeddings['t'] = self.embedding_t
            self.models_to_train += [self.embedding_t]

        self.nerf_coarse = NeRF('coarse',
                                in_channels_xyz=6*hparams.N_emb_xyz+3,
                                in_channels_dir=6*hparams.N_emb_dir+3)
        self.models = {'coarse': self.nerf_coarse}
        if hparams.N_importance > 0:
            self.nerf_fine = NeRF('fine',
                                  in_channels_xyz=6*hparams.N_emb_xyz+3,
                                  in_channels_dir=6*hparams.N_emb_dir+3,
                                  encode_appearance=hparams.encode_a,
                                  in_channels_a=hparams.N_a,
                                  encode_transient=hparams.encode_t,
                                  in_channels_t=hparams.N_tau,
                                  beta_min=hparams.beta_min)
            self.models['fine'] = self.nerf_fine
        self.models_to_train += [self.models]

    def get_progress_bar_dict(self):
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items

    def forward(self, rays, ts):
        """Do batched inference on rays using chunk."""
        B = rays.shape[0]
        results = defaultdict(list)
        for i in range(0, B, self.hparams.chunk):
            rendered_ray_chunks = \
                render_rays(self.models,
                            self.embeddings,
                            rays[i:i+self.hparams.chunk],
                            ts[i:i+self.hparams.chunk],
                            self.hparams.N_samples,
                            self.hparams.use_disp,
                            self.hparams.perturb,
                            self.hparams.noise_std,
                            self.hparams.N_importance,
                            self.hparams.chunk, # chunk size is effective in val mode
                            self.train_dataset.white_back)

            for k, v in rendered_ray_chunks.items():
                results[k] += [v]

        for k, v in results.items():
            results[k] = torch.cat(v, 0)
        return results

    def setup(self, stage):
        dataset = dataset_dict[self.hparams.dataset_name]
        kwargs = {'root_dir': self.hparams.root_dir}
        if self.hparams.dataset_name == 'phototourism':
            kwargs['img_downscale'] = self.hparams.img_downscale
            kwargs['val_num'] = self.hparams.num_gpus
            kwargs['use_cache'] = self.hparams.use_cache
        elif self.hparams.dataset_name == 'blender':
            kwargs['img_wh'] = tuple(self.hparams.img_wh)
            kwargs['perturbation'] = self.hparams.data_perturb
        self.train_dataset = dataset(split='train', **kwargs)
        self.val_dataset = dataset(split='val', **kwargs)

    def configure_optimizers(self):
        self.optimizer = get_optimizer(self.hparams, self.models_to_train)
        scheduler = get_scheduler(self.hparams, self.optimizer)
        return [self.optimizer], [scheduler]

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          shuffle=True,
                          num_workers=4,
                          batch_size=self.hparams.batch_size,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          shuffle=False,
                          num_workers=4,
                          batch_size=1, # validate one image (H*W rays) at a time
                          pin_memory=True)
    
    def training_step(self, batch, batch_nb):
        rays, rgbs, ts = batch['rays'], batch['rgbs'], batch['ts']
        results = self(rays, ts)
        loss_d = self.loss(results, rgbs)
        loss = sum(l for l in loss_d.values())

        with torch.no_grad():
            typ = 'fine' if 'rgb_fine' in results else 'coarse'
            psnr_ = psnr(results[f'rgb_{typ}'], rgbs)

        self.log('lr', get_learning_rate(self.optimizer))
        self.log('train/loss', loss)
        for k, v in loss_d.items():
            self.log(f'train/{k}', v, prog_bar=True)
        self.log('train/psnr', psnr_, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_nb):
        rays, rgbs, ts = batch['rays'], batch['rgbs'], batch['ts']
        rays = rays.squeeze() # (H*W, 3)
        rgbs = rgbs.squeeze() # (H*W, 3)
        ts = ts.squeeze() # (H*W)
        results = self(rays, ts)
        loss_d = self.loss(results, rgbs)
        loss = sum(l for l in loss_d.values())
        log = {'val_loss': loss}
        typ = 'fine' if 'rgb_fine' in results else 'coarse'
    
        if batch_nb == 0:
            if self.hparams.dataset_name == 'phototourism':
                WH = batch['img_wh']
                W, H = WH[0, 0].item(), WH[0, 1].item()
            else:
                W, H = self.hparams.img_wh
            img = results[f'rgb_{typ}'].view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W)
            img_gt = rgbs.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W)
            depth = visualize_depth(results[f'depth_{typ}'].view(H, W)) # (3, H, W)
            stack = torch.stack([img_gt, img, depth]) # (3, 3, H, W)
            self.logger.experiment.add_images('val/GT_pred_depth',
                                               stack, self.global_step)

        psnr_ = psnr(results[f'rgb_{typ}'], rgbs)
        log['val_psnr'] = psnr_

        return log

    def validation_epoch_end(self, outputs):
        mean_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        mean_psnr = torch.stack([x['val_psnr'] for x in outputs]).mean()

        self.log('val/loss', mean_loss)
        self.log('val/psnr', mean_psnr, prog_bar=True)


def main(hparams):
    system = NeRFSystem(hparams)
    checkpoint_callback = \
        ModelCheckpoint(filepath=os.path.join(f'ckpts/{hparams.exp_name}',
                                               '{epoch:d}'),
                        monitor='val/psnr',
                        mode='max',
                        save_top_k=-1)

    logger = TestTubeLogger(save_dir="logs",
                            name=hparams.exp_name,
                            debug=False,
                            create_git_tag=False,
                            log_graph=False)

    trainer = Trainer(max_epochs=hparams.num_epochs,
                      checkpoint_callback=checkpoint_callback,
                      resume_from_checkpoint=hparams.ckpt_path,
                      logger=logger,
                      weights_summary=None,
                      progress_bar_refresh_rate=hparams.refresh_every,
                      gpus=hparams.num_gpus,
                      accelerator='ddp' if hparams.num_gpus>1 else None,
                      num_sanity_val_steps=1,
                      benchmark=True,
                      profiler="simple" if hparams.num_gpus==1 else None)

    trainer.fit(system)


if __name__ == '__main__':
    hparams = get_opts()
    main(hparams)

Source: blog.csdn.net/qq_35831906/article/details/131386514