DLA ニューラル ネットワークの極端なトレーニング方法: 勾配チェックポインティング

勾配チェックポイント設定

        一般に、トレーニング プロセスでは (GPU か CPU かに関係なく) 中間結果を保存する必要があります。順方向伝播は入力 (bottom_data) に基づいて出力 (top_data) を計算し、逆方向伝播は (変数がトレーニング用の勾配を開く場合) top_diff からbottom_diff を計算します。top和bottom是包含数据和梯度的两个结构体,整个网络的每层top和bottom在训练的过程中都会保存,这消耗了大量的内存。

        これらの変数が保存されていない場合、再配布と計算によりメモリ使用量が大幅に削減されますが、ネットワークのトレーニング時間が無限に長くなります。これら 2 つの矛盾のバランスを取るために、論文「サブリニア メモリ コストによるディープ ネットのトレーニング」では、サブリニア メモリ コストを使用してディープ ネットワークをトレーニングしています。ディープ ニューラル ネットワークのトレーニングのメモリ消費を削減する体系的なアプローチを提案しています。具体的には、n 層ネットワークをトレーニングするために O(sqrt(n)) メモリを消費し、ミニバッチあたりの追加のフォワード パスの計算コストのみがかかるアルゴリズムを設計します。チェックポイントが設定された特徴マップは、sqrt(n) ごとに保持されます。

コード

  • https://pytorch.org/docs/stable/checkpoint.html
// https://discuss.pytorch.org/t/trying-to-understand-torch-utils-checkpoint/95224
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm.notebook import tqdm

from torch import optim
import torchvision.models as models
from torch import nn

CHECKPOINT = True
BATCH_SIZE = 32
dev = "cuda:0"

class ImageDataset(Dataset):
    def __init__(self,length = 100000,size = 244):
        self.length = length
        self.size = 244
    def __len__(self):
        return self.length
    def __getitem__(self,idx,display = False):
        return torch.from_numpy(np.random.randn(2,3,self.size,self.size))
train = ImageDataset()
trainloader = DataLoader(
    train,
    batch_size = BATCH_SIZE,
    num_workers = 24,
    pin_memory = True
)

resnet = models.resnet50(pretrained = False)

class MODEL(nn.Module):
    def __init__(self,model):
        super(MODEL,self).__init__()
        self.model = model
        self.LR = nn.Linear(1000,1000)
    def forward(self,x):
        if CHECKPOINT == False:
            o1 = self.model(x[:,0])
            o2 = self.model(x[:,1])
        else:
            o1 = torch.utils.checkpoint.checkpoint(self.model,x[:,0])
            o2 = torch.utils.checkpoint.checkpoint(self.model,x[:,1])
        
        return torch.mean((self.LR(o1)-o2)**2)
    
resnet = MODEL(resnet).to(dev)

optimizer = optim.SGD(resnet.parameters(),lr = .001)

for T in tqdm(trainloader):
    out = torch.mean(resnet(T.float().to(dev)))
    optimizer.zero_grad()
    out.backward()
    optimizer.step()

CG

ここに画像の説明を挿入

  • https://github.com/merrymercy/dtr-prototype

ゼロオフロード

  • https://arxiv.org/pdf/2101.06840.pdf 大規模なモデルのトレーニングは、複雑なモデルの再構築と高価な GPU クラスターへのアクセスを必要とする少数派の分野でした。ZeRO-Offload は、大規模なモデルのトレーニングをほぼすべての人が利用できるようにすることで機能します。単一の GPU で 13 億を超えるパラメーターを使用してモデルをトレーニングできます。これは、GPU と比較してサイズが 10 倍増加します。PyTorch などの一般的なフレームワークは、モデルを必要とせずにこれを実行します。データサイエンティストの計算効率を変更または犠牲にします。ZeRO-Offload は、データと計算をオフロードすることにより、大規模モデルの CPU トレーニングを可能にします。計算効率を維持するために、GPU 内外のデータ移動を最小限に抑え、CPU の計算時間を短縮すると同時に、GPU コストのメモリを最大限に節約するように設計されています。したがって、ZeRO-Offload はシングルで 40 TFlops/GPU を達成できます。10B パラメータ モデルの NVIDIA V100 GPU と比較して、1.4B パラメータ モデルの PyTorch のみの 30TF は、使い果たすことなくトレーニングできるパラメータ モデルの最大メモリです。ZeRO-Offload は、利用可能な場合には複数の GPU にまたがって拡張できるように設計されており、最大 128 GPU でほぼ線形の高速化を実現します。さらに、モデル並列処理と連携して、単一の DGX-2 ボックス上で 70 億を超えるパラメーターを含むモデルをトレーニングできます。これは、モデル並列処理のみを使用した場合と比較して、モデル サイズが 4.5 倍増加します。ZeRO-Offload は、計算効率とメモリ効率と使いやすさを組み合わせることで、大規模モデルのトレーニングを民主化し、データ サイエンティストでも GPU にアクセスするだけでトレーニングにアクセスできるようにします。

勾配累積

        一般に、トレーニング中のバッチが大きいほど、トレーニング結果がより安定します。勾配累積トレーニング法は、ディープ ニューラル ネットワークのトレーニングに使用される手法であり、ビデオ メモリ要件を削減し、トレーニング結果を向上させることを目的としています。従来のトレーニング方法では、モデルのパラメーターは、単一のデータ バッチから計算された平均勾配によって更新されます。ただし、勾配累積トレーニングでは、モデルのパラメーターの更新は、複数のバッチの勾配累積を通じて取得されます。

以下は、勾配累積トレーニングの基本的な手順です。

  1. 累積するグラジエントのバッチ数を決定するグラジエント累積ステップ数 (累積ステップ) を設定します。

  2. モデルのパラメータを初期化します。

  3. 各トレーニング バッチについて:

    • データの現在のバッチを使用して順伝播を実行し、損失を計算します。
    • 損失を逆伝播することで勾配が計算されます。
    • 現在のバッチの勾配を前の勾配値まで累積します。
  4. 累積が設定されたステップ数に達すると、累積された勾配がモデル パラメーターの更新に適用されます。

    • パラメータの更新された値は、累積された勾配を平均することによって取得されます。
    • 更新された値を使用してモデルのパラメーターを更新します。
  5. すべてのトレーニング バッチが完了するまで、手順 3 と 4 を繰り返します。

勾配累積トレーニングの主な利点は、バッチごとに必要なビデオ メモリの量を削減できることであり、ビデオ メモリが限られているハードウェア上で大規模なモデルをトレーニングできるようになります。さらに、勾配の累積によりモデルの収束性が向上し、モデルのパフォーマンスと汎化能力が向上します。

おすすめ

転載: blog.csdn.net/ResumeProject/article/details/132123257