PyTorch Lightning emerged in response to various problems encountered while using plain PyTorch.
A major feature of PyTorch Lightning is that it treats the model and the system separately. The system defines how a set of models interact with each other, for example a GAN (generator and discriminator networks), Seq2Seq (encoder and decoder networks), or BERT. Sometimes the problem involves only one model, such as UNet or ResNet; in that case the system can be a generic one that describes how the model is used, and it can be reused across many other projects.
Under the PyTorch Lightning framework, each network is defined together with how to train it, how to test it, its optimizer, and so on.
Face Mask Detector in the Lightning community: https://towardsdatascience.com/how-i-built-a-face-mask-detector-for-covid-19-using-pytorch-lightning-67eb3752fd61
PyTorch Lightning tutorial using handwritten digit recognition as the example: https://colab.research.google.com/drive/1Mowb4NzWlRCxzAFjOIJqUmmk_wAT-XP3
1. Install
pip install pytorch-lightning
2. Network design
import torch
from torch import nn
import pytorch_lightning as pl
class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self):
        super(LightningMNISTClassifier, self).__init__()
        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        # layer 1
        x = self.layer_1(x)
        x = torch.relu(x)
        # layer 2
        x = self.layer_2(x)
        x = torch.relu(x)
        # layer 3
        x = self.layer_3(x)
        # log-probability distribution over labels
        x = torch.log_softmax(x, dim=1)
        return x
The plain PyTorch and PyTorch Lightning (PL) model definitions are almost identical. Restoring a trained model also looks similar in both:
# restore with PyTorch
pytorch_model = MNISTClassifier()
pytorch_model.load_state_dict(torch.load(PATH))
pytorch_model.eval()
# restore with Lightning
lightning_model = LightningMNISTClassifier.load_from_checkpoint(PATH)
lightning_model.eval()
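As a hedged illustration of the plain PyTorch restore path, the sketch below saves a state_dict to a temporary file and loads it into a fresh instance. `TinyClassifier` and the temporary file are stand-ins invented for this example (the article's MNISTClassifier and PATH are not shown in full), and the Lightning call load_from_checkpoint is not exercised here.

```python
import os
import tempfile

import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for the article's MNISTClassifier."""

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)            # (b, 1, 28, 28) -> (b, 784)
        x = torch.relu(self.layer_1(x))
        return torch.log_softmax(self.layer_2(x), dim=1)


trained = TinyClassifier()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pt")    # plays the role of PATH
    torch.save(trained.state_dict(), path)      # save the weights only

    restored = TinyClassifier()                 # fresh, randomly initialised
    restored.load_state_dict(torch.load(path))  # overwrite with saved weights

trained.eval()
restored.eval()

x = torch.randn(4, 1, 28, 28)
with torch.no_grad():
    same_output = torch.equal(trained(x), restored(x))
print(same_output)
```

Calling eval() after loading matters for layers such as dropout or batch norm; this tiny model has neither, so it is included only to mirror the restore snippet.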
3. Data
Let's create three splits of the MNIST dataset: train, validation, and test.
In PyTorch, datasets are wrapped in a DataLoader, which handles loading, shuffling, and batching. The steps are:
- Define the image transforms.
- Generate the training, validation, and test datasets.
- Load each dataset into a DataLoader.
The code is as follows:
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torchvision import datasets, transforms
# ----------------
# TRANSFORMS
# ----------------
# prepare transforms standard to MNIST
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
# ----------------
# TRAINING, VAL DATA
# ----------------
# pass the transform so the DataLoader receives tensors rather than PIL images
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
# train (55,000 images), val split (5,000 images)
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
# ----------------
# TEST DATA
# ----------------
mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
# ----------------
# DATALOADERS
# ----------------
# The dataloaders handle shuffling, batching, etc...
mnist_train = DataLoader(mnist_train, batch_size=64)
mnist_val = DataLoader(mnist_val, batch_size=64)
mnist_test = DataLoader(mnist_test, batch_size=64)
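To see the splitting and batching behaviour without downloading MNIST, here is a small sketch on synthetic tensors. The dataset size (100), the split sizes (80/20), and the batch size (32) are arbitrary assumptions chosen to keep the example fast; only `random_split` and `DataLoader` correspond to the code above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic stand-in for MNIST: 100 fake 1x28x28 images with labels 0..9.
images = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(images, labels)

# Split sizes must sum to len(dataset), just like [55000, 5000] for MNIST.
generator = torch.Generator().manual_seed(0)
train_set, val_set = random_split(dataset, [80, 20], generator=generator)

# shuffle=True reorders the training samples on every epoch.
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)

# 80 samples in batches of 32 give two full batches and one partial batch.
batch_shapes = [tuple(x.shape) for x, _ in train_loader]
print(batch_shapes)
```

The last batch is smaller because 80 is not a multiple of 32; passing `drop_last=True` to the DataLoader would discard it instead.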
In plain PyTorch, this data loading code can live anywhere in the training program. In PyTorch Lightning, the DataLoaders can be used directly, or they can be organized under a LightningDataModule through three methods:
train_dataloader()
val_dataloader()
test_dataloader()
A fourth method handles data preparation and downloading:
prepare_data()
Lightning takes this approach so that every model implemented with Lightning follows the same structure, which makes the code readable and organized.
That is, when you encounter a project that uses Lightning, it is clear from the code where the data processing and downloading take place.
The Lightning version is as follows:
import os

import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    def setup(self, stage=None):
        # transforms standard to MNIST
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])
        mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
        # keep the test set on self so that test_dataloader() can reach it
        self.mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
        # train (55,000 images), val split (5,000 images)
        self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)
Optimizer
We choose the Adam optimizer. The PyTorch code is as follows:
pytorch_model = MNISTClassifier()
optimizer = torch.optim.Adam(pytorch_model.parameters(), lr=1e-3)
In Lightning, pass in self.parameters() instead of the parameters of a separate model object, because the LightningModule is itself the model.
class LightningMNISTClassifier(pl.LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
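The idea that the module hands its own parameters to the optimizer can be checked in plain PyTorch, since a LightningModule is also an nn.Module. The sketch below uses a hypothetical `TinyModule` (not part of the article) and inspects the optimizer's parameter groups; it does not involve the Lightning Trainer.

```python
import torch
from torch import nn


class TinyModule(nn.Module):
    """Hypothetical module with a Lightning-style configure_optimizers method."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def configure_optimizers(self):
        # The module owns its parameters, so it passes self.parameters().
        return torch.optim.Adam(self.parameters(), lr=1e-3)


module = TinyModule()
optimizer = module.configure_optimizers()

# Collect the tensors the optimizer will update and compare them, by identity,
# with the tensors registered on the module.
optimized = [p for group in optimizer.param_groups for p in group["params"]]
module_params = list(module.parameters())
all_registered = len(optimized) == len(module_params) and all(
    a is b for a, b in zip(optimized, module_params)
)
print(all_registered)
```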
Loss
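For n-way classification we compute the cross-entropy loss. Cross-entropy is equivalent to the negative log-likelihood loss applied to log_softmax outputs, NegativeLogLikelihood(log_softmax), which is the form used here because the model already returns log_softmax values.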
from torch.nn import functional as F

def cross_entropy_loss(logits, labels):
    return F.nll_loss(logits, labels)
In PyTorch Lightning, we use the exact same code to compute the loss. But we can put it anywhere in the file.
from torch.nn import functional as F

class LightningMNISTClassifier(pl.LightningModule):
    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)
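The equivalence between cross-entropy and negative log-likelihood over log_softmax can be verified numerically. This is a small check on random scores; the batch size of 8 and the 10 classes are arbitrary assumptions for illustration.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)             # raw scores: 8 samples, 10 classes
labels = torch.randint(0, 10, (8,))

# Route 1: log_softmax followed by negative log-likelihood (as in the article).
nll = F.nll_loss(F.log_softmax(logits, dim=1), labels)

# Route 2: cross_entropy applied directly to the raw scores.
ce = F.cross_entropy(logits, labels)

print(torch.allclose(nll, ce))
```

Which route to use is mostly a question of where the log_softmax lives: inside the model's forward (route 1) or inside the loss (route 2).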
Training loop
We have now defined all the key components of a neural network:
- A model (3-layer network)
- A dataset (MNIST)
- An optimizer (Adam)
- A loss function (cross-entropy loss)
What remains is a complete training procedure, which does the following:
- Iterate for many epochs, each a full pass over the dataset $D = \{(x_1, y_1), \ldots, (x_n, y_n)\}$.
- In each epoch, iterate over the dataset in batches $b \in D$.
- Make a forward pass: $\hat{y} = f(x)$.
- Compute the loss: $L = -\sum_{i=1}^{C} y_i \log(\hat{y}_i)$.
- Perform a backward pass, computing gradients for all weights: $\nabla\omega_i = \frac{\partial L}{\partial \omega_i} \quad \forall \omega_i$.
- Apply these gradients to each weight: $\omega_i = \omega_i - \alpha \nabla\omega_i$.
num_epochs = 100
for epoch in range(num_epochs):  # (1)
    for batch in dataloader:  # (2)
        x, y = batch
        logits = model(x)  # (3)
        loss = cross_entropy_loss(logits, y)  # (4)
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        loss.backward()  # (5)
        optimizer.step()  # (6)
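To connect the steps above to concrete numbers, the following sketch performs a single update by hand on a tiny linear classifier, without an optimizer object. The data is random, and the sizes (16 samples, 5 features, C = 3 classes) and the learning rate alpha = 0.1 are assumptions made for the example. The update subtracts the gradient, which is plain gradient descent rather than Adam.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
x = torch.randn(16, 5)                      # one batch b of 16 samples
y = torch.randint(0, 3, (16,))              # labels for C = 3 classes
w = torch.zeros(5, 3, requires_grad=True)   # weights of a linear model f
alpha = 0.1                                 # learning rate


def loss_fn(weights):
    log_probs = torch.log_softmax(x @ weights, dim=1)  # forward pass
    return F.nll_loss(log_probs, y)                    # cross-entropy loss L


loss_before = loss_fn(w)
loss_before.backward()                      # fills w.grad with dL/dw

with torch.no_grad():
    w -= alpha * w.grad                     # w <- w - alpha * gradient
    w.grad.zero_()                          # reset before the next step

loss_after = loss_fn(w)
print(loss_before.item(), loss_after.item())
```

With all-zero weights every class gets probability 1/3, so the initial loss equals log(3); the single step then lowers it.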
PyTorch Training loop
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
import os
# -----------------
# MODEL
# -----------------
class MNISTClassifier(nn.Module):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        # layer 1
        x = self.layer_1(x)
        x = torch.relu(x)
        # layer 2
        x = self.layer_2(x)
        x = torch.relu(x)
        # layer 3
        x = self.layer_3(x)
        # log-probability distribution over labels
        x = torch.log_softmax(x, dim=1)
        return x
# ----------------
# DATA
# ----------------
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
# train (55,000 images), val split (5,000 images)
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
# The dataloaders handle shuffling, batching, etc...
mnist_train = DataLoader(mnist_train, batch_size=64)
mnist_val = DataLoader(mnist_val, batch_size=64)
mnist_test = DataLoader(mnist_test, batch_size=64)
# ----------------
# OPTIMIZER
# ----------------
pytorch_model = MNISTClassifier()
optimizer = torch.optim.Adam(pytorch_model.parameters(), lr=1e-3)
# ----------------
# LOSS
# ----------------
def cross_entropy_loss(logits, labels):
    return F.nll_loss(logits, labels)

# ----------------
# TRAINING LOOP
# ----------------
num_epochs = 1
for epoch in range(num_epochs):
    # TRAINING LOOP
    for train_batch in mnist_train:
        x, y = train_batch
        logits = pytorch_model(x)
        loss = cross_entropy_loss(logits, y)
        print('train loss: ', loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    # VALIDATION LOOP
    with torch.no_grad():
        val_loss = []
        for val_batch in mnist_val:
            x, y = val_batch
            logits = pytorch_model(x)
            val_loss.append(cross_entropy_loss(logits, y).item())
        val_loss = torch.mean(torch.tensor(val_loss))
        print('val_loss: ', val_loss.item())
PyTorch Lightning Training Loop
To do this in Lightning, we extract the main parts of the training and validation loops into three functions:
- training_step
- validation_step
- validation_epoch_end (called validation_end in early Lightning releases)
for epoch in range(num_epochs):
    # TRAINING LOOP
    for train_batch in mnist_train:
        x, y = train_batch  # training_step
        logits = pytorch_model(x)  # training_step
        loss = cross_entropy_loss(logits, y)  # training_step
        print('train loss: ', loss.item())  # training_step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    # VALIDATION LOOP
    with torch.no_grad():
        val_loss = []
        for val_batch in mnist_val:
            x, y = val_batch  # validation_step
            logits = pytorch_model(x)  # validation_step
            val_loss.append(cross_entropy_loss(logits, y).item())  # validation_step
        val_loss = torch.mean(torch.tensor(val_loss))  # validation_epoch_end
        print('val_loss: ', val_loss.item())
training_step
training_step contains the body of the training loop, that is, what happens for a single batch.
class LightningMNISTClassifier(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        # forward and cross_entropy_loss are already defined in the lightning module; the full code follows later
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('train_loss', loss)
        return loss
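Note that training_step only returns the loss; the backward pass and the optimizer step are left to Lightning. Options such as 16-bit precision are requested from the Trainer rather than written into the loop: Trainer(precision=16)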
validation_step
validation_step contains the body of the validation loop. In addition, we want the average loss over all batches of the validation loop. For this we use validation_epoch_end, which receives a list (outputs) containing what validation_step returned for each batch.
outputs = []
for batch in validation_dataloader:
    loss = some_loss(batch)  # validation_step
    outputs.append(loss)  # validation_step
outputs = torch.stack(outputs).mean()  # validation_epoch_end
class LightningMNISTClassifier(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        # logged per batch; Lightning reports the mean over the validation epoch
        self.log('val_loss', loss)
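The split between per-batch work and the epoch-end reduction can be made concrete in plain PyTorch. The sketch below is an assumption-laden stand-in: the data is synthetic (50 samples, 10 features, 3 classes), the model is a single linear layer, and `validation_step` here is a free function written for this example, not the Lightning hook itself.

```python
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Synthetic validation data: 50 samples, 10 features, 3 classes.
dataset = TensorDataset(torch.randn(50, 10), torch.randint(0, 3, (50,)))
val_loader = DataLoader(dataset, batch_size=20)   # batches of 20, 20, 10
model = torch.nn.Linear(10, 3)


def validation_step(batch):
    """Per-batch work: forward pass and loss."""
    x, y = batch
    return F.cross_entropy(model(x), y)


outputs = []
with torch.no_grad():                 # no gradients are needed for validation
    for batch in val_loader:
        outputs.append(validation_step(batch))

# Epoch-end work: reduce the list of per-batch losses to a single number.
val_loss = torch.stack(outputs).mean()
print(len(outputs), val_loss.item())
```

Averaging per-batch means weights the smaller final batch the same as the full ones, which is usually an acceptable approximation for monitoring.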
And finally the complete LightningModule.
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
import os
class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self):
        super(LightningMNISTClassifier, self).__init__()
        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        # layer 1 (b, 1*28*28) -> (b, 128)
        x = self.layer_1(x)
        x = torch.relu(x)
        # layer 2 (b, 128) -> (b, 256)
        x = self.layer_2(x)
        x = torch.relu(x)
        # layer 3 (b, 256) -> (b, 10)
        x = self.layer_3(x)
        # log-probability distribution over labels
        x = torch.log_softmax(x, dim=1)
        return x

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
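# train
model = LightningMNISTClassifier()
# the module defines no dataloaders itself, so pass the MNISTDataModule from the Data section
data_module = MNISTDataModule()
trainer = pl.Trainer()
trainer.fit(model, datamodule=data_module)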
A few observations:
- This structure is highly standardized.
- It is the same PyTorch code, only organized.
- The inner loop of the PyTorch training code becomes training_step. We don't need to do anything about gradients, because Lightning handles them automatically.
for train_batch in mnist_train:
    x, y = train_batch  # training_step
    logits = pytorch_model(x)  # training_step
    loss = cross_entropy_loss(logits, y)  # training_step
    print('train loss: ', loss.item())  # training_step
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
- The inner loop of the PyTorch validation code becomes validation_step.
outputs = []
for batch in validation_dataloader:
    loss = some_loss(batch)  # validation_step
    outputs.append(loss)  # validation_step
- The validation_epoch_end step lets us compute metrics over the entire validation set. Again, we don't need to toggle gradients, put the model in evaluation mode, or write any loop structure. Lightning does this automatically for us.
validation_loss = torch.stack(outputs).mean()  # validation_epoch_end
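To make visible what is being automated, here is a deliberately simplified loop that drives Lightning-style hooks on a plain nn.Module. `ToyModule` and `toy_fit` are hypothetical names written for this sketch; this is not Lightning's actual Trainer code, which does far more (logging, checkpointing, device placement). The data is synthetic with labels that a linear model can learn.

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset


class ToyModule(nn.Module):
    """Hypothetical module exposing Lightning-style hooks (plain PyTorch)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(10, 3)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.net(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-2)


def toy_fit(module, train_loader, val_loader, max_epochs=1):
    """Simplified stand-in for a trainer; not Lightning's implementation."""
    optimizer = module.configure_optimizers()
    val_losses = []
    for _ in range(max_epochs):
        module.train()
        for batch_idx, batch in enumerate(train_loader):
            loss = module.training_step(batch, batch_idx)
            optimizer.zero_grad()      # the gradient handling lives here,
            loss.backward()            # outside the user's training_step
            optimizer.step()
        module.eval()
        with torch.no_grad():          # validation runs without gradients
            outputs = [module.validation_step(b, i) for i, b in enumerate(val_loader)]
        val_losses.append(torch.stack(outputs).mean().item())
    return val_losses


torch.manual_seed(0)
x = torch.randn(200, 10)
y = x[:, :3].argmax(dim=1)             # learnable synthetic labels
train_loader = DataLoader(TensorDataset(x[:160], y[:160]), batch_size=32)
val_loader = DataLoader(TensorDataset(x[160:], y[160:]), batch_size=32)

val_losses = toy_fit(ToyModule(), train_loader, val_loader, max_epochs=5)
print(val_losses)
```

The module contains only the task-specific pieces, while everything generic sits in `toy_fit`; that separation is the point the observations above make about Lightning.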
For greater portability and standardization, we can pull the model definition out into a separate backbone, which lets us pass in arbitrary classifiers.
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
import os
class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        # layer 1 (b, 1*28*28) -> (b, 128)
        x = self.layer_1(x)
        x = torch.relu(x)
        # layer 2 (b, 128) -> (b, 256)
        x = self.layer_2(x)
        x = torch.relu(x)
        # layer 3 (b, 256) -> (b, 10)
        x = self.layer_3(x)
        return x
class LightningClassifier(pl.LightningModule):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def cross_entropy_loss(self, log_probs, labels):
        return F.nll_loss(log_probs, labels)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.backbone(x)
        # log-probability distribution over labels (applied to the logits, not the input)
        log_probs = torch.log_softmax(logits, dim=1)
        loss = self.cross_entropy_loss(log_probs, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.backbone(x)
        # log-probability distribution over labels
        log_probs = torch.log_softmax(logits, dim=1)
        loss = self.cross_entropy_loss(log_probs, y)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
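# train
model = Backbone()
classifier = LightningClassifier(model)
# the classifier defines no dataloaders itself, so pass the MNISTDataModule from the Data section
data_module = MNISTDataModule()
trainer = pl.Trainer()
trainer.fit(classifier, datamodule=data_module)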
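The benefit of separating the backbone can be sketched in plain PyTorch: the same wrapper computes the loss for any network that maps (b, 1, 28, 28) inputs to (b, 10) scores. The `Classifier` wrapper and the two small backbones below are hypothetical examples, not the article's classes, and the input batch is random.

```python
import torch
from torch import nn
from torch.nn import functional as F


class Classifier(nn.Module):
    """Hypothetical wrapper: owns the loss logic, delegates scoring to a backbone."""

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def loss(self, x, y):
        logits = self.backbone(x)                     # backbone returns raw scores
        log_probs = torch.log_softmax(logits, dim=1)  # wrapper normalises them
        return F.nll_loss(log_probs, y)


# Two interchangeable backbones mapping (b, 1, 28, 28) -> (b, 10).
mlp = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
cnn = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 10),
)

torch.manual_seed(0)
x = torch.randn(8, 1, 28, 28)
y = torch.randint(0, 10, (8,))

# The wrapper code is identical for both; only the backbone argument changes.
losses = {name: Classifier(b).loss(x, y).item() for name, b in [("mlp", mlp), ("cnn", cnn)]}
print(losses)
```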