With only a single GPU, it is usually hard to train a large model with a large batch size, because the memory needed to store intermediate activations grows with both. Gradient checkpointing (torch.utils.checkpoint) trades compute for memory: it stores only a subset of activations and recomputes the rest during the backward pass, reducing activation memory from O(n) to roughly O(sqrt(n)) for an n-layer network at the cost of about one extra forward pass. For details, see the paper:
Chen T, Xu B, Zhang C, et al. Training Deep Nets with Sublinear Memory Cost. arXiv preprint arXiv:1604.06174, 2016.
Basic principle:
Efficient memory management for network training: using torch.utils.checkpoint (风筝大晒's blog, CSDN)
Usage example:
Training larger-than-memory PyTorch models using gradient checkpointing
Chinese translation: Using gradient checkpointing to train PyTorch models larger than memory – POLARAI.CN
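Before the full example below, here is a minimal sketch of the core API. (Assumption: the use_reentrant flag only exists in newer PyTorch releases; on older versions simply drop it.)

# --------------
# Minimal sketch: wrap one block with torch.utils.checkpoint.checkpoint
# --------------
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(100, 100), nn.ReLU(), nn.Linear(100, 100))
# the input to a checkpointed segment should require grad,
# otherwise no gradient flows back through the segment
x = torch.randn(8, 100, requires_grad=True)
# forward: activations inside `block` are not stored;
# backward: `block` is re-run to recompute them (compute traded for memory)
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()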
Below is a simple example I wrote. Two caveats: it is best not to checkpoint the first layer (its input is the raw data, which does not require grad, so gradients may fail to propagate into the checkpointed segment), and dropout and batch-normalization layers should not be placed inside a checkpointed segment, since recomputation conflicts with their randomness and running statistics.
# --------------
# A simple example of using torch.utils.checkpoint to save GPU memory
# Author: qingwen guo
# Time: 2022-03-10 15:07
# --------------
import torch
import torch.nn as nn
import numpy as np
import torch.utils.checkpoint as cp
# a tiny toy dataset: the four rows of the XNOR truth table
x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
x = torch.Tensor(x).float()
y = np.array([1, 0, 0, 1])
y = torch.Tensor(y).long()
class MyNet(nn.Module):
    def __init__(self, save_memory=False):
        super(MyNet, self).__init__()
        self.linear1 = nn.Linear(2, 50)
        self.linear2 = nn.Linear(50, 30)
        self.linear3 = nn.Linear(30, 2)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(p=0.5)
        self.save_memory = save_memory

    def forward(self, x):
        if self.save_memory:
            # keep the first layer out of the checkpoint (see the note above)
            x = self.linear1(x)
            x = self.relu(x)
            # linear2 and linear3 activations are recomputed during backward
            x = cp.checkpoint(self.linear2, x)
            # dropout is applied outside the checkpointed segments
            x = self.dropout(x)
            x = cp.checkpoint(self.linear3, x)
        else:
            x = self.linear1(x)
            x = self.relu(x)
            x = self.linear2(x)
            x = self.dropout(x)
            x = self.linear3(x)
        return x
net = MyNet(save_memory=True)
# train() enables modules like dropout, and eval() does the opposite
net.train()
# set up the optimizer, where lr is the learning rate
optimizer = torch.optim.SGD(net.parameters(), lr=0.05)
loss_func = nn.CrossEntropyLoss()
for epoch in range(50000):
    if epoch % 5000 == 0:
        # evaluate the model before this epoch's update;
        # when computing a monitoring loss or evaluating on a validation set,
        # wrapping the forward pass in "with torch.no_grad():" saves memory.
        # It is omitted here for brevity.
        net.eval()
        out = net(x)
        loss = loss_func(out, y)
        print(loss.detach().numpy())

    # call train() and train the model on the training set
    net.train()
    out = net(x)
    loss = loss_func(out, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 5000 == 0:
        # evaluate again after the update
        net.eval()
        out = net(x)
        loss = loss_func(out, y)
        print(loss.detach().numpy())
        print('----')
        net.train()

    if epoch % 1000 == 0:
        # decay the learning rate by 10% every 1000 epochs
        lr = optimizer.param_groups[0]['lr']
        lr *= 0.9
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
net.eval()
with torch.no_grad():
    print(net(x))
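When the model is a plain stack of layers, torch.utils.checkpoint.checkpoint_sequential is a convenient alternative: it splits an nn.Sequential into segments and stores activations only at segment boundaries, which is essentially the sublinear-memory scheme from the Chen et al. paper. A minimal sketch follows (same caveat about use_reentrant on older PyTorch; batch-normalization layers inside segments have the same problem as above):

# --------------
# Sketch: checkpoint_sequential over a 10-layer stack
# --------------
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(*[nn.Linear(50, 50) for _ in range(10)])
inp = torch.randn(4, 50, requires_grad=True)
# split the 10 layers into 2 segments; only segment-boundary activations
# are kept, everything inside a segment is recomputed during backward
out = checkpoint_sequential(model, 2, inp, use_reentrant=False)
out.sum().backward()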