pytorch に基づいてモデルの枝刈りを実装する

1.剪定の分類

いわゆるモデル プルーニングは、実際には、ニューラル ネットワークから「不要な」重みやバイアス (重み/バイアス) を削除するモデル圧縮テクノロジーです。どのパラメータが「不要」とみなされるかについては、まだ研究中の領域です。

1.1、非構造化枝刈り

非構造化プルーニングとは、完全接続層の単一の重み、畳み込み層の単一の畳み込みカーネル パラメーター要素、カスタム レイヤーのスケーリング フロートなど、パラメーターの単一要素をプルーニングすることを指します。重要な点は、枝刈りの重みオブジェクトはランダムであり、特定の構造を持たないため、 非構造枝刈り と呼ばれます

1.2、構造化された枝刈り

非構造化枝刈りとは対照的に、構造化枝刈りはパラメータ構造全体を枝刈りします。たとえば、重みの行全体または列全体を破棄したり、畳み込み層のフィルター全体を破棄したりします ( Filter)。

1.3、ローカルおよびグローバルのプルーニング

プルーニングは、各レイヤー (ローカル) または複数/すべてのレイヤー (グローバル) で実行できます。

2. PyTorch のプルーニング

現在、PyTorch フレームワークでサポートされている重みプルーニング メソッドは次のとおりです。

  • Random : ランダムなパラメータをトリミングするだけです。
  • 大きさ: 最小の重みを持つパラメータ (L2 ノルムなど) をプルーニングします。

上記の 2 つの方法は実装も計算も簡単で、データがなくても適用できます。

2.1、pytorch 枝刈りの動作原理

プルーニング関数はtorch.nn.utils.pruneクラスに実装されており、コードはファイル torch/nn/utils/prune.py にあります。主なプルーニング クラスを次の図に示します。

pytorch_pruning_api_file.png

枝刈りの原理は、テンソルのマスク実装に基づいています。マスクはテンソルと同じ形状のブール テンソルです。マスクの値は True で、対応する位置の重みを保持する必要があることを示します。マスクの値は False で、対応する位置の重みを保持する必要があることを示します。位置を削除することができます。

Pytorch は<param>、元のパラメータを<param>_originalという名前のパラメータにコピーし、プルーニング マスクを格納するバッファ<param>_maskを作成します。同時に、元の重みに枝刈りマスクを適用するために、モジュール レベルの forward_pre_hook コールバック関数 (モデルが順伝播される前に呼び出されるコールバック関数) も作成します。

Pytorch のプルーニングapiとチュートリアルは非常にわかりにくいので、API とプルーニングの方法と分類をまとめたいと思い、個人的に次の表を作成しました。

pytorch_pruning_api

pytorch でのモデルの枝刈りのワークフローは次のとおりです。

  1. プルーニング メソッド (または独自のプルーニング メソッドを実装するには BasePruningMethod のサブクラス) を選択します。
  2. プルーニング モジュールとパラメーター名を指定します。
  3. 枝刈り率など枝刈り方法のパラメータを設定します。

2.2、ローカルプルーニング

Pytorch フレームワークには、非構造化プルーニングと構造化プルーニングの 2 つのタイプのローカル プルーニングがありますが、構造化プルーニングはローカルのみをサポートし、グローバルはサポートしないことに注意してください。

2.2.1、ローカルの非構造化プルーニング

1. Local Unstructed Pruning の対応する関数プロトタイプは次のとおりです

def random_unstructured(module, name, amount)

1.機能機能:

重みパラメータ テンソルの非構造化枝刈りに使用されます。このメソッドは、枝刈りのためにテンソル内のいくつかの重みまたは接続をランダムに選択し、枝刈り率はユーザーによって指定されます。

2. 関数パラメータの定義:

  • module(nn.Module): nn.Conv2d() や nn.Linear() など、プルーニングする必要があるネットワーク層/モジュール。
  • name(str): 「weight」や「bias」など、プルーニングされるパラメータの名前。
  • amount(int または float): 枝刈りする数量を指定します。0 ~ 1 の 10 進数の場合は枝刈り率を示し、証明書の場合はパラメータの絶対量を直接切り捨てます。たとえばamount=0.2、要素の 20% が枝刈りのためにランダムに選択されることを意味します。

3. 以下はrandom_unstructured関数の使用例です。

import torch
import torch.nn.utils.prune as prune
conv = torch.nn.Conv2d(1, 1, 4)
prune.random_unstructured(conv, name="weight", amount=0.5)
conv.weight
"""
tensor([[[[-0.1703,  0.0000, -0.0000,  0.0690],
          [ 0.1411,  0.0000, -0.0000, -0.1031],
          [-0.0527,  0.0000,  0.0640,  0.1666],
          [ 0.0000, -0.0000, -0.0000,  0.2281]]]], grad_fn=<MulBackward0>)
"""

conv 層出力の重み値の半分が であることがわかります0

2.2.2、ローカル構造化プルーニング

Local Structured Pruning には 2 つの関数があり、対応する関数プロトタイプは次のとおりです。

def random_structured(module, name, amount, dim)
def ln_structured(module, name, amount, n, dim, importance_scores=None)

1. 機能機能

接続の重みを削除する非構造化プルーニングとは異なり、構造化プルーニングではチャネルの重み全体が削除されます。

2. パラメータの定義

ローカルの非構造化関数とよく似ていますが、唯一の違いは、dim パラメーターを定義する必要があることです (ln_structed 関数にはより多くのnパラメーターがあります)。

nは剪定の基準を表し、dim剪定の次元を表します。

torch.nn.Linear の場合:

  • dim = 0: ニューロンを削除します。
  • dim = 1: 入力へのすべての接続を削除します。

torch.nn.Conv2d の場合:

  • dim = 0(チャンネル): チャンネルのプルーニング/フィルター フィルターのプルーニング
  • dim = 1(ニューロン): 2 次元コンボリューション カーネル プルーニング、つまり入力チャネルに接続されたカーネル

2.2.3、ローカル構造化プルーニングのサンプルコード

サンプル コードを記述する前に、まずConv2d関数パラメーター、コンボリューション カーネル形状、軸、テンソルの間の関係を理解する必要があります。

まず、Conv2d 関数のプロトタイプは次のとおりです。

class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

pytorch では、従来の畳み込みの畳み込みカーネルの重みはshape( C_out, C_in, kernel_height, kernel_width) であるため、コードでは畳み込み層の重みはshapeであり[3, 2, 3, 3]、 dim = 0 は形状 [3, 2, 3, 3] に対応します3ここでは、どの軸を設定するかをディムします。そうすると当然、ウェイト テンソルに対応する軸は枝刈り後に変更されます。

薄暗い

これまでの重要な概念を理解した後、実際に使用することができます (dim=0例を以下に示します)。

conv = torch.nn.Conv2d(2, 3, 3)
norm1 = torch.norm(conv.weight, p=1, dim=[1,2,3])
print(norm1)
"""
tensor([1.9384, 2.3780, 1.8638], grad_fn=<NormBackward1>)
"""
prune.ln_structured(conv, name="weight", amount=1, n=2, dim=0)
print(conv.weight)
"""
tensor([[[[-0.0005,  0.1039,  0.0306],
          [ 0.1233,  0.1517,  0.0628],
          [ 0.1075, -0.0606,  0.1140]],

         [[ 0.2263, -0.0199,  0.1275],
          [-0.0455, -0.0639, -0.2153],
          [ 0.1587, -0.1928,  0.1338]]],


        [[[-0.2023,  0.0012,  0.1617],
          [-0.1089,  0.2102, -0.2222],
          [ 0.0645, -0.2333, -0.1211]],

         [[ 0.2138, -0.0325,  0.0246],
          [-0.0507,  0.1812, -0.2268],
          [-0.1902,  0.0798,  0.0531]]],


        [[[ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000]]]], grad_fn=<MulBackward0>)
"""

0実行結果から、畳み込み層パラメーターの最後のチャネル パラメーター テンソルが (テンソルとして) 削除されていることが明確にわかります。その説明については、下の図を参照してください。

薄暗い理解

dim = 1場合:

conv = torch.nn.Conv2d(2, 3, 3)
norm1 = torch.norm(conv.weight, p=1, dim=[0, 2,3])
print(norm1)
"""
tensor([3.1487, 3.9088], grad_fn=<NormBackward1>)
"""
prune.ln_structured(conv, name="weight", amount=1, n=2, dim=1)
print(conv.weight)
"""
tensor([[[[ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000]],

         [[-0.2140,  0.1038,  0.1660],
          [ 0.1265, -0.1650, -0.2183],
          [-0.0680,  0.2280,  0.2128]]],


        [[[-0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000]],

         [[-0.2087,  0.1275,  0.0228],
          [-0.1888, -0.1345,  0.1826],
          [-0.2312, -0.1456, -0.1085]]],


        [[[-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]],

         [[-0.0891,  0.0946, -0.1724],
          [-0.2068,  0.0823,  0.0272],
          [-0.2256, -0.1260, -0.0323]]]], grad_fn=<MulBackward0>)
"""

明らかに、dim=1の次元では、最初のテンソルの L2 ノルムが小さいため、形状 [2, 3, 3] のテンソルでは、最初の [3, 3] テンソル パラメーターが削除されます (つまり、テンソルは 0 行列です)。 。

2.3. グローバルな非構造化プルーニング

前回のローカル プルーニングの対象は特定のネットワーク層でしたが、グローバル プルーニングではモデル全体をみなして指定された割合 (数値) のパラメータを削除し、グローバル プルーニングの結果により各層のスパース率が生じます。モデルは異なりますが、同じです。

グローバル非構造化プルーニング関数のプロトタイプは次のとおりです。

# v1.4.0 版本
def global_unstructured(parameters, pruning_method, **kwargs)
# v2.0.0-rc2版本
def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):

1.機能機能:

どのレイヤーに属しているかに関係なく、すべてのグローバル パラメーター (重みとバイアスを含む) の一部を枝刈りのためにランダムに選択します。

2.パラメータの定義:

  • parameters(((モジュール, 名前) タプルの反復可能)): モデルのパラメーター リストをトリミングします。リスト内の要素は (モジュール, 名前) です。
  • pruning_method(関数): 現在、公式は pruning_method=prune.L1Unstuctured のみをサポートしているようですが、これに加えて、独自に実装した非構造化枝刈りメソッド関数を使用することもできます。
  • importance_scores: 各パラメータの重要度スコアを示します。なしの場合はデフォルトのスコアが使用されます。
  • **kwargs: 特定のプルーニング メソッドに渡される追加パラメーターを表します。たとえば、amount枝刈りする枝刈りの数を指定します。

3.global_unstructured関数のサンプルコードは以下のとおりです。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
# 计算卷积层和整个模型的稀疏度
# 其实调用的是 Tensor.numel 内内函数,返回输入张量中元素的总数
print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
# 程序运行结果
"""
Sparsity in conv1.weight: 3.70%
Global sparsity: 20.00%
"""

実行結果は、モデル全体の (グローバル) スパース性が20%20% であるにもかかわらず、各ネットワーク層のスパース性が必ずしも 20% であるわけではないことを示しています。

3、まとめ

さらに、pytorch フレームワークはいくつかのヘルパー関数も提供します。

  1. torch.nn.utils.prune.is_pruned(module): モジュールがプルーニングされているかどうかを判断します。
  2. torch.nn.utils.prune.remove(module, name):指定された module 内の指定されたパラメータに対するプルーニング操作を削除し、それによってパラメータの元の形状と値を復元するために使用されます。

PyTorch は組み込みのプルーニングを提供しておりAPI、いくつかの非構造化および構造化プルーニング メソッドもサポートしていますが、APIわかりにくく、対応するドキュメントの説明が明確ではないため、nni後ほど Microsoft のオープンソース ツールを組み合わせてモデル プルーニング機能を実現します。

プルーニング方法の詳細については、githubリポジトリModel-Compressionを参照してください。

参考文献

  1. PyTorch を使用してニューラル ネットワークをプルーニングする方法
  2. 剪定チュートリアル
  3. PyTorch プルーニング

記事を読んで何かを得た場合は、まず「いいね!」をしてから保存してください。だって、誰かにバラを贈ると、手にも香りが残るんです
最後に、さらに多くのインタビューや役立つ記事をご覧になるには、私の公式アカウントである Embedded Visionをフォローしてください。

おすすめ

転載: blog.csdn.net/lovely_yoshino/article/details/132405013