PyTorch Lightning

Introduction

At present, many AI training projects seem to be built on PyTorch Lightning, so let's learn about it today, with the goal of using it proficiently later. The official description is: build and train PyTorch models, and use Lightning Apps templates to connect them to the ML lifecycle, without having to deal with DIY infrastructure, cost management, scaling, and other headaches.

How to Use

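  1. Install

    pip install pytorch-lightning

  2. Add the imports

    import os
    import torch
    from torch import nn
    import torch.nn.functional as F
    from torchvision.datasets import MNIST
    from torch.utils.data import DataLoader, random_split
    from torchvision import transforms
    import pytorch_lightning as pl
    
  3. Define a LightningModule (a LightningModule is a subclass of nn.Module)

    class LitAutoEncoder(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
            self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

        def forward(self, x):
            # forward defines inference: flattened image -> 3-dim embedding
            embedding = self.encoder(x)
            return embedding

        def training_step(self, batch, batch_idx):
            # training_step defines one step of the training loop, independent of forward
            x, y = batch
            x = x.view(x.size(0), -1)  # flatten 28x28 images into 784-dim vectors
            z = self.encoder(x)
            x_hat = self.decoder(z)
            loss = F.mse_loss(x_hat, x)
            self.log('train_loss', loss)
            return loss

        def validation_step(self, batch, batch_idx):
            # logs 'val_loss', the metric that early stopping / checkpointing can monitor
            x, y = batch
            x = x.view(x.size(0), -1)
            x_hat = self.decoder(self.encoder(x))
            self.log('val_loss', F.mse_loss(x_hat, x))

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer
    
  4. Train

    # the MNIST training split has 60,000 images: 55,000 for training, 5,000 for validation
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train, val = random_split(dataset, [55000, 5000])
    
    autoencoder = LitAutoEncoder()
    trainer = pl.Trainer()
    trainer.fit(autoencoder, DataLoader(train), DataLoader(val))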
    
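To see what `trainer.fit` takes off our hands, here is a heavily simplified, hypothetical sketch of the kind of loop it drives, written in plain Python with no PyTorch. `ToyModule`, `ToyTrainer`, `ToySGD` and `Loss` are made-up names for illustration; the real Trainer also handles devices, validation, logging, checkpointing and callbacks, and its exact call order may differ. What carries over is the division of labour: the module supplies `training_step` and `configure_optimizers`, while the trainer owns the epoch/batch loops and the `zero_grad`/`backward`/`step` calls.

```python
class Loss:
    """Toy stand-in for a scalar loss tensor: holds a value and a backward() hook."""

    def __init__(self, value, backward_fn):
        self.value = value
        self._backward_fn = backward_fn

    def backward(self):
        self._backward_fn()


class ToySGD:
    """Toy optimizer over a dict of scalar parameters and their gradients."""

    def __init__(self, params, grads, lr):
        self.params, self.grads, self.lr = params, grads, lr

    def step(self):
        for name in self.params:
            self.params[name] -= self.lr * self.grads[name]

    def zero_grad(self):
        for name in self.grads:
            self.grads[name] = 0.0


class ToyModule:
    """Plays the LightningModule role: fits y = w * x with a squared-error loss."""

    def __init__(self):
        self.params = {"w": 0.0}
        self.grads = {"w": 0.0}

    def training_step(self, batch, batch_idx):
        x, y = batch
        err = self.params["w"] * x - y

        def backward_fn():
            # d/dw (w*x - y)**2 = 2 * (w*x - y) * x, added to the stored gradient
            self.grads["w"] += 2 * err * x

        return Loss(err ** 2, backward_fn)

    def configure_optimizers(self):
        return ToySGD(self.params, self.grads, lr=0.01)


class ToyTrainer:
    """Plays the Trainer role: owns the loops and the zero_grad/backward/step calls."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs

    def fit(self, module, train_batches):
        optimizer = module.configure_optimizers()
        for _ in range(self.max_epochs):
            for batch_idx, batch in enumerate(train_batches):
                loss = module.training_step(batch, batch_idx)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        return module


data = [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0)]  # points on y = 3x
model = ToyTrainer(max_epochs=200).fit(ToyModule(), data)
print(round(model.params["w"], 3))  # prints 3.0
```

Because the module never mentions loops or optimizer calls, the same `ToyModule` could be handed to a different trainer (for example one that spreads batches over several workers) without changes, which is the same idea behind Lightning's "hardware independent" models.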

Advanced features

  • Multi-GPU

    trainer = pl.Trainer(max_epochs=1, accelerator='gpu', devices=8)
    
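  • TPU: pass accelerator='tpu' and devices=8 (the number of TPU cores) to the Trainer

  • 16-bit precision: pass precision=16 to the Trainer to train with 16-bit mixed precision

  • experiment logging: pass a logger to the Trainer, e.g. logger=TensorBoardLogger('logs/'), with TensorBoardLogger imported from pytorch_lightning.loggers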

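  • early stopping

    from pytorch_lightning.callbacks import EarlyStopping

    # the module must log 'val_loss' (e.g. in validation_step) for this to work
    es = EarlyStopping(monitor='val_loss')
    trainer = pl.Trainer(callbacks=[es])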
    
  • model checkpoint

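    from pytorch_lightning.callbacks import ModelCheckpoint

    # keeps the checkpoint with the lowest 'val_loss' (defaults: mode='min', save_top_k=1)
    checkpointing = ModelCheckpoint(monitor='val_loss')
    trainer = pl.Trainer(callbacks=[checkpointing])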
    
  • torchscript

    # torchscript
    autoencoder = LitAutoEncoder()
    torch.jit.save(autoencoder.to_torchscript(), "model.pt")
    
  • ONNX

    # onnx
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile:
        autoencoder = LitAutoEncoder()
        # the sample must match forward's input: a batch of flattened 28x28 images
        input_sample = torch.randn((1, 28 * 28))
        autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)
        assert os.path.isfile(tmpfile.name)
    
  • training tricks

    40+ training tricks are available to choose from

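Both EarlyStopping and ModelCheckpoint above watch the logged `val_loss` after each validation run. The sketch below is a simplified, hypothetical re-implementation of that bookkeeping for a "lower is better" metric in plain Python; the class name `MetricWatcher` and the loss values are made up. `patience` and `min_delta` mirror EarlyStopping's parameters of the same names (patience counts validation checks without improvement, default 3); the real callbacks also support `mode='max'`, `save_top_k`, actually writing checkpoint files, and more.

```python
class MetricWatcher:
    """Hypothetical sketch: tracks a lower-is-better metric the way
    EarlyStopping (patience) and ModelCheckpoint (best score) do."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = float("inf")
        self.best_epoch = None  # the epoch a "best" checkpoint would come from
        self.wait_count = 0

    def update(self, epoch, val_loss):
        """Record one validation result; return True if training should stop."""
        if val_loss < self.best_score - self.min_delta:
            # improvement: remember it (checkpoint) and reset the patience counter
            self.best_score = val_loss
            self.best_epoch = epoch
            self.wait_count = 0
        else:
            self.wait_count += 1
        return self.wait_count >= self.patience


val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.64, 0.68, 0.69, 0.70, 0.60]
watcher = MetricWatcher(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if watcher.update(epoch, loss):
        stopped_at = epoch
        break

print(stopped_at, watcher.best_epoch, watcher.best_score)  # prints: 8 5 0.64
```

Note that training stops at epoch 8 and never sees the 0.60 at epoch 9: a larger `patience` trades extra compute for a better chance of catching late improvements.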
Advantages

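  • Models become hardware-agnostic
  • Code is simpler and clearer to read, because engineering code is abstracted away
  • Code is refactored into a standard structure that separates the research from the engineering
  • Fewer mistakes, because Lightning handles the tricky engineering
  • Keeps all the flexibility (a LightningModule is still a PyTorch module) while removing a lot of boilerplate
  • Integrates with popular machine learning tools
  • Tested across Python and PyTorch versions, operating systems, and multi-GPU/TPU setups
  • Minimal running-speed overhead compared with plain PyTorch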

Manual control of the training process

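Setting self.automatic_optimization = False tells Lightning to skip its automatic backward, optimizer step and zero_grad, so training_step has to call them itself:

    class LitAutoEncoder(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # switch to manual optimization
            self.automatic_optimization = False

        def training_step(self, batch, batch_idx):
            # self.optimizers() returns the optimizers from configure_optimizers;
            # by default (use_pl_optimizer=True) they are wrapped as LightningOptimizer,
            # pass use_pl_optimizer=False to get the raw torch.optim objects
            opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

            loss_a = ...
            self.manual_backward(loss_a)
            opt_a.step()
            opt_a.zero_grad()

            loss_b = ...
            # retain_graph=True keeps the graph alive so loss_b can be backpropagated again;
            # the two backward passes accumulate into the same gradients
            self.manual_backward(loss_b, retain_graph=True)
            self.manual_backward(loss_b)
            opt_b.step()
            opt_b.zero_grad()

        # configure_optimizers (not shown) must return two optimizers, e.g. return opt_a, opt_b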

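A common reason to switch to manual optimization is to control exactly when `step()` and `zero_grad()` run, for example to accumulate gradients over several batches before one update. The plain-Python sketch below (no PyTorch; `ToyParam`, `train_with_accumulation` and `accumulate_every` are hypothetical names) mimics PyTorch's contract for a single scalar parameter: backward adds into `.grad`, `step()` applies it, and `zero_grad()` clears it. It also shows why the two `manual_backward(loss_b, ...)` calls above double loss_b's gradient before `opt_b.step()`.

```python
class ToyParam:
    """Hypothetical scalar parameter whose .grad accumulates like a PyTorch tensor's."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


def backward(param, x, y, scale=1.0):
    """Add scale * d/dw (w*x - y)**2 into param.grad (adds, never overwrites)."""
    param.grad += scale * 2 * (param.value * x - y) * x


def step(param, lr):
    """Apply the accumulated gradient (plain SGD)."""
    param.value -= lr * param.grad


def zero_grad(param):
    param.grad = 0.0


def train_with_accumulation(batches, accumulate_every, lr=0.01, epochs=100):
    """Manual-optimization loop: one optimizer step per `accumulate_every` batches.
    In this sketch, leftover batches at the end of an epoch roll into the next window."""
    w = ToyParam(0.0)
    num_steps = 0
    for _ in range(epochs):
        for batch_idx, (x, y) in enumerate(batches):
            # scale so that each update uses the mean gradient over the window
            backward(w, x, y, scale=1.0 / accumulate_every)
            if (batch_idx + 1) % accumulate_every == 0:
                step(w, lr)
                zero_grad(w)
                num_steps += 1
    return w.value, num_steps


# Two backward calls on the same loss add up, like calling manual_backward(loss_b) twice
p = ToyParam(1.0)
backward(p, 2.0, 0.0)  # grad = 2 * (1*2 - 0) * 2 = 8.0
backward(p, 2.0, 0.0)  # grad accumulates to 16.0
print(p.grad)          # prints 16.0

data = [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0), (4.0, 12.0)]  # points on y = 3x
w, steps = train_with_accumulation(data, accumulate_every=2)
print(round(w, 3), steps)  # prints 3.0 200
```

With four batches per epoch and `accumulate_every=2`, only two optimizer steps happen per epoch (200 over 100 epochs), yet the parameter still converges, because each step uses the averaged gradient of its window.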
Examples

Hello world
  • MNIST
Contrastive Learning
  • BYOL
  • CPC v2
  • Moco v2
  • SIMCLR
NLP
  • GPT-2
  • BERT
Reinforcement Learning
  • DQN
  • Dueling-DQN
  • Reinforce
Vision
  • GAN
Classic ML
  • Logistic Regression
  • Linear Regression

Official API Tutorial

Welcome to ⚡ PyTorch Lightning — PyTorch Lightning 1.8.0dev documentation (pytorch-lightning.readthedocs.io)

Summary

With about 20,000 stars on GitHub, PyTorch Lightning is clearly a popular and useful project. So far I have only tried some of the examples. Once I fully grasp its simple syntax, it can really help us reduce repetitive AI code.

Origin blog.csdn.net/be_humble/article/details/126638270