Training large models with PyTorch when GPU memory is not enough (using torch.utils.checkpoint)

With only a single GPU, it is usually hard to train a large model with a large batch size in the conventional way. Gradient checkpointing, provided by torch.utils.checkpoint, addresses this by trading extra computation for lower memory use; see the paper:

Chen T, Xu B, Zhang C, et al. Training Deep Nets with Sublinear Memory Cost. arXiv preprint arXiv:1604.06174, 2016.

Basic principle:
Efficient memory management for network training: using torch.utils.checkpoint (风筝大晒, CSDN blog)

Usage example:
Training larger-than-memory PyTorch models using gradient checkpointing
Chinese translation: Training larger-than-memory PyTorch models using gradient checkpointing – POLARAI.CN
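
Before the full example, here is a minimal sketch (my addition, not taken from the paper or the posts above) of the two entry points in torch.utils.checkpoint: checkpoint wraps a single callable, and checkpoint_sequential splits an nn.Sequential into segments and checkpoints each one. The layer sizes are arbitrary; on recent PyTorch versions a warning may ask you to pass use_reentrant explicitly.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

# a toy stack of blocks; the sizes are arbitrary
blocks = nn.Sequential(*[nn.Sequential(nn.Linear(128, 128), nn.ReLU()) for _ in range(8)])
x = torch.randn(4, 128, requires_grad=True)

# wrap a single callable: its intermediate activations are not stored,
# and the callable is re-run during the backward pass
out = checkpoint(blocks[0], x)

# or split the whole Sequential into 2 segments and checkpoint each segment
out = checkpoint_sequential(blocks, 2, x)
out.sum().backward()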

Below is a simple example I wrote. Note: it is best not to apply checkpoint to the first layer, and dropout and batch-normalization layers should not be placed inside a checkpointed segment (they conflict with the recomputation). The short demo that follows shows why, and the full training example comes after it.
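
To see why such layers are problematic (this demo is my addition), recall that a checkpointed segment is executed twice per training step: once in the forward pass and once more when it is recomputed during backward. Any layer with side effects inside the segment therefore misbehaves, e.g. BatchNorm running statistics get updated twice, and dropout would redraw its mask if checkpoint did not preserve the RNG state by default (which it does, at some bookkeeping cost). Also, with the default re-entrant implementation the input of a checkpointed segment should require grad, which is one more reason to leave the first layer outside the checkpoint.

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

layer = nn.Linear(4, 4)
calls = 0

def counted_forward(t):
    # count how often the checkpointed segment is actually executed
    global calls
    calls += 1
    return layer(t)

x = torch.randn(2, 4, requires_grad=True)
out = cp.checkpoint(counted_forward, x)  # first execution (forward pass)
out.sum().backward()                     # the segment is re-executed here
print(calls)                             # prints 2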

# --------------
# A simple example of using torch.utils.checkpoint to save GPU memory
# Author: qingwen guo
# Time: 2022-03-10 15:07
# --------------
import torch
import torch.nn as nn
import numpy as np
import torch.utils.checkpoint as cp

x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
x = torch.Tensor(x).float()
y = np.array([1, 0, 0, 1])
y = torch.Tensor(y).long()


class MyNet(nn.Module):
    def __init__(self, save_memory=False):
        super(MyNet, self).__init__()

        self.linear1 = nn.Linear(2, 50)
        self.linear2 = nn.Linear(50, 30)
        self.linear3 = nn.Linear(30, 2)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(p=0.5)

        self.save_memory = save_memory

    def forward(self, x):
        if self.save_memory:
            x = self.linear1(x)
            x = self.relu(x)
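            # the two checkpointed calls below do not store their intermediate
            # activations; they are recomputed during loss.backward(), trading
            # extra forward compute for a lower peak memory footprint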
            x = cp.checkpoint(self.linear2, x)
            x = self.dropout(x)
            x = cp.checkpoint(self.linear3, x)
        else:
            x = self.linear1(x)
            x = self.relu(x)
            x = self.linear2(x)
            x = self.dropout(x)
            x = self.linear3(x)

        return x


net = MyNet(save_memory=True)
# train() enables some modules like dropout, and eval() does the opposite
net.train()

# set the optimizer where lr is the learning-rate
optimizer = torch.optim.SGD(net.parameters(), lr=0.05)
loss_func = nn.CrossEntropyLoss()

for epoch in range(50000):
    if epoch % 5000 == 0:
        # call eval() and evaluate the model on the validation set
        # when computing the loss or evaluating the model on the validation set,
        # it's suggested to wrap the code in "with torch.no_grad()" to save memory; it is not used here
        net.eval()
        out = net(x)
        loss = loss_func(out, y)
        print(loss.detach().numpy())
        # call train() and train the model on the training set
        net.train()

    out = net(x)
    loss = loss_func(out, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 5000 == 0:
        net.eval()
        out = net(x)
        loss = loss_func(out, y)
        print(loss.detach().numpy())
        print('----')
        net.train()

    if epoch % 1000 == 0:
        # adjust the learning rate:
        # decay it by a factor of 0.9 every 1000 epochs
        lr = optimizer.param_groups[0]['lr']
        lr *= 0.9
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
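        # (the same schedule could be written with
        # torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.9)
        # and one scheduler.step() per epoch; kept manual here as in the original)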

net.eval()
print(net(x).data)
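
As an optional follow-up (my addition, not part of the original script), one way to check the effect on a GPU is to compare peak memory with and without checkpointing. For this toy network the difference is negligible; the technique pays off for deep networks with large activations. The sketch below assumes a CUDA device is available.

# rough peak-memory comparison (sketch; requires a CUDA device)
def peak_memory_mb(save_memory):
    torch.cuda.reset_peak_memory_stats()
    model = MyNet(save_memory=save_memory).cuda()
    model.train()
    data = torch.randn(4096, 2, device='cuda')
    target = torch.randint(0, 2, (4096,), device='cuda')
    loss = nn.CrossEntropyLoss()(model(data), target)
    loss.backward()
    return torch.cuda.max_memory_allocated() / 1024 ** 2

if torch.cuda.is_available():
    print('with checkpoint   :', peak_memory_mb(True), 'MB')
    print('without checkpoint:', peak_memory_mb(False), 'MB')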


Reposted from blog.csdn.net/u014134327/article/details/123419367