【画像分類】 【ディープラーニング】 【軽量ネットワーク】 【Pytorch版】 ShuffleNet_V2モデルアルゴリズムの詳細説明

【画像分類】 【ディープラーニング】 【軽量ネットワーク】 【Pytorch版】 ShuffleNet_V2モデルアルゴリズムの詳細説明


序文

ShuffleNet_V2 は、記事「ShuffleNet V2: Efficient CNN Architecture Design [ECCV-2018]」[論文アドレス] で、Megvii Technology の Ma 氏、Ningning 氏らが提案した改良モデルです。この論文では、効率的なネットワーク アーキテクチャ設計を提案しています。原則: 第一に、間接的な指標 (FLOP など) ではなく直接的な指標 (速度など) を使用すること、第二に、4 つのクロスプラットフォーム設計ガイドラインを提案し、これらのガイドラインの指導の下で ShuffleNet_V2 を設計することです。


ShuffleNet_V2の説明

MobileNet_v1、v2、ShuffleNet_v1、Xception などのこれまでの一部のネットワーク モデルでは、浮動小数点演算 (FLOP) の量をある程度削減するために、グループ化畳み込みまたは深さ分離可能な畳み込みを使用していましたが、FLOP は完全なものではありません。モデルの速度を直接測定する指標であり、理論上の計算量を通じて間接的にのみモデルの速度を測定します。
しかし、実際のデバイスでは、さまざまな最適化計算が行われるため、メモリ アクセス コスト (MAC) やプラットフォームの特性によっても制限されるため、計算量によってモデルの速度を正確に測定することはできません。同じ FLOP に対して異なる推論速度が表示されます。次の図は、元の論文のさまざまな状況における推論速度の具体的な状況を示しています。
速度は FLOP によって完全に決定されるわけではありません。上の図に示すように、赤いボックスはさまざまなデバイスを表し、左側は結果です。右側がGPU、ARMの結果です。以下の 2 つの図からわかるように、異なるモデルの MFLOP が同じである場合、速度は異なります。

バッチ/秒: データ ウェアハウス書き込み操作モジュールが 1 秒あたりに受信したバッチの数。

では、デバイスの実行速度に影響を与える他の要因は何でしょうか? この論文では、直接指標と間接指標の間の矛盾には 2 つの理由が考えられると説明されています。

  • まず、FLOP では速度に影響を与えるいくつかの重要な要素が考慮されていません。たとえば、メモリ アクセス コストは、
    グループ畳み込みで多くの計算時間を消費し、GPU 操作の潜在的なパフォーマンスのボトルネックにもなります。また、並列処理もあります。同じ FLOP を使用すると、並列性の高いネットワークの方が高速に実行されます。
  • 第 2 に、FLOP の同じ操作が異なるプラットフォームでは異なる方法で実行されます。たとえば、初期の研究では、行列の乗算を高速化するためにテンソル分解が広く使用されていました。これにより FLOP は 75% 削減できますが、GPU での動作は遅くなります。これは、CuDnn が 3×3 畳み込み用に特別に最適化されているためです。積は 1×1 畳み込みの理論上の時間の 9 倍ではなくなり、この分解には高速化する明確な意味がありません。

したがって、この文書では、効率的なネットワーク アーキテクチャ設計のための 2 つの主要な原則を提案します。第一に、間接的な指標 (FLOP など) ではなく直接的な指標 (速度など) を使用すること、第二に、指標はターゲット プラットフォーム上で検証される必要があること、同時に 4 つのクロスプラットフォーム設計ガイドラインが提案されており、新しいネットワーク アーキテクチャ ShuffleNet V2 は、これらのガイドラインに基づいて設計されています。

4 つの実践的な指針となるアイデア

次の図は、ShuffleNet_V1 および MobileNet_V2 ネットワーク コンポーネントの実行速度に関する元の論文の統計です:

図から、GPU/ARM 上のモデルのさまざまな操作に費やされた時間がわかります。FLOP は畳み込み部分の計算量のみを表しますこれが実行時間の大半を占めますが、実際にはデータの入出力、データのスクランブル、要素レベルの処理関連演算(テンソル加算、活性化関数の処理など)など、ElemwiseやData関連の部分もかなりの時間を消費します。 、など)。
この論文では、特定のプラットフォームでの ShuffleNetv1 と MobileNetv2 の実行時間を調査し、理論と実験を組み合わせることで、4 つの実践的な指針を提案しています。

G1: チャネル幅が等しいため、ストレージ アクセス コストが削減されます。

チャネル幅が等しいため、メモリ アクセス コスト (MAC) が最小限に抑えられます。
Xception [参考資料]、MobileNet_V1 [参考資料]、MobileNet_V2 [参考資料]、ShuffleNet_V1 [参考資料] などの現在のネットワークはすべて深さ分離可能な畳み込みを使用し、1×1 ポイントの畳み込みが大部分を占めます。計算量: 入力特徴マップがh × w × c 1 {\rm{h}} \times {\rm{w}} \times { { \rm{c}}_1}であると仮定します。h×w×c1, 1×1点の畳み込みはc 1 × c 2 × 1 × 1 { {\rm{c}}_1} \times { {\rm{c}}_2} \times 1 \times 1c1×c2×1×1 の場合、出力特徴マップのサイズは変更されないため、1×1 ポイント畳み込みの FLOP は次のようになります。B = h × w × c 1 × c 2 {\rm{B = h}} \times {\rm{w} } \times { {\rm{c}}_1} \times { {\rm{c}}_2}B=h×w×c1×c2

FLOPs 計算は乗算と加算を浮動小数点演算として扱います

コンピューティング デバイスのバッファーが特徴マップ全体とすべてのパラメーターを保存するのに十分な大きさであると仮定すると、1×1 ポイント畳み込みのメモリ アクセス コスト (メモリ アクセス数) は、MAC = hwc 1 + hwc 2 + c 1 c 2となります。 = hw ( c 1 + c 2 ) + c 1 c 2 {\rm{MAC = hw}}{ { \rm{c}}_1}{\rm{ + hw}}{ {\rm{c}}_2 } + { {\ rm{c}}_1}{ {\rm{c}}_2} = {\rm{hw(}}{ { \rm{c}}_1}{\rm{ + }}{ { \rm{c} }_2}) + { {\rm{c}}_1}{ {\rm{c}}_2}マック=c_ _1+ハードウェア2+c1c2=うわー( c1+ c2+c1c2、入力特徴マップ、出力特徴マップ、重みパラメーターのコストを表します。BB
固定時B時、c 2 = B hwc 1 { {\rm{c}}_2} = \frac{B}{ {hw{ {\rm{c}}_1}}}c2=ああああ1B,均等
メモリMAC = hw ( c 1 + c 2 ) + c 1 c 2 = ( hw ) 2 ( c 1 + c 2 ) 2 + B hw ≥ ( hw ) 2 ( 4 c 1 c 2 ) + B hw に基づく≥ 2 hw B + B hw {\rm{MAC = hw(}}{ {\rm{c}}_1}{\rm{ + }}{ {\rm{c}}_2}) + { {\rm {c}}_1}{ {\rm{c}}_2} = \sqrt { { { {\rm{(hw}})}^2}{ { { \rm{(}}{ {\rm{c }}_1}{\rm{ + }}{ {\rm{c}}_2})}^2}} + \frac{B}{ {hw}} \ge \sqrt { { { {\rm{( hw}})}^2}{\rm{(4}}{ {\rm{c}}_1}{ {\rm{c}}_2})} + \frac{B}{ {hw}} \ ge 2\sqrt { {\rm{hwB}}} + \frac{B}{ {hw}}マック=うわー( c1+ c2+c1c2=( hw )2 ( c1+ c22 +はぁ、wB( hw )2 (4c)1c2 +はぁ、wB2hwB +はぁ、wB
平均不等式から、 c 1 = c 2 { {\rm{c}}_1} = { {\rm{c}}_2}の場合がわかります。 c1=c2時間取( c 1 + c 2 ) 2 { {\rm{(}}{ {\rm{c}}_1}{\rm{ + }}{ {\rm{c}}_2})^2}( c1+ c22の下限、つまり( c 1 + c 2 ) 2 = 4 c 1 c 2 { {\rm{(}}{ {\rm{c}}_1}{\rm{ + }}{ { \rm{c }}_2})^2} = {\rm{4}}{ {\rm{c}}_1}{ {\rm{c}}_2}( c1+ c22=4c _1c2マックマックMAC は最小値を取得します
指定された計算制限の下では、 MAC MACM A Cには下限があります。
この結論を検証するために、論文は実験解析を実施しました。下の表のテスト ネットワークは 10 個の繰り返しブロックでスタックされ、各ブロックには 2 つの畳み込み層が含まれ、入力チャネルはc 1 { {\rm{c}}_1 }c1、出力チャネルはc 2 { {\rm{c}}_2}ですc2

テーブル内のデータから、次のことが得られます。 when c 1 : c 2 { {\rm{c}}_1}{\rm{:}}{ { \rm{c}}_2}c1: c2値が 1:1 に近づくと、MAC 値はますます小さくなり、ネットワーク動作の評価速度はますます速くなります。

G2: 多数のグループ化された畳み込みによりストレージ アクセスが増加します

過剰なグループ コンボリューションにより MAC が増加する
グループ コンボリューションは、現在のネットワーク構造設計の中核であり、チャネル間の疎な接続、つまり同じグループ内の機能のみに接続することにより、計算の複雑さの FLOP を削減します。一方で、より多くのチャネルを使用できるようになり、ネットワーク容量が増加し、精度が向上しますが、他方では、チャネル数が増加するにつれて、より多くの MAC がもたらされます。
1×1 グループ化畳み込みの場合、グループ化畳み込み FLOP の計算式は次のとおりです。
B = h × w × 1 × 1 × c 1 g × c 2 g × g = hwc 1 c 2 g {\rm{B = h} } \times {\rm{w}} \times 1 \times 1 \times \frac{ { { { \rm{c}}_1}}}{g} \times \frac{ { { { \ rm{c} } _2}}}{g} \times g = \frac{ { {\rm{hw}}{ {\rm{c}}_1}{ {\rm{c}}_2}}}{g}B=h×w×1×1×gc1×gc2×g=gc_ _1c2
グループ化コンボリューション MAC の計算式:
MAC = hw ( c 1 + c 2 ) + c 1 c 2 g = hwc 1 + B gc 1 + B hw MAC = hw({c_1} + {c_2}) + \frac{ { {c_1}{c_2}}}{g} = hw{c_1} + \frac{ {Bg}}{ { {c_1}}} + \frac{B}{ {hw}}マック_ _=うわ( c _1+c2+gc1c2=ああああ1+c1Bg _+はぁ、wB
入力特徴マップh × w × c 1 {\rm{h}} \times {\rm{w}} \times { {\rm{c}}_1}h×w×c1BBが固定されると固定されますBの場合、 c 2 g \frac{ { {c_2}}}{g}を修正する必要があります gc2の比率なので、MAC MACM・A・Cgggは関係に正比例します。
この論文では、10 個のグループ化ポイント畳み込み層を重ね合わせて実験を設計し、異なる数のグループ化グループを使用して、同じ計算コスト FLOP を確保しながらモデルの実行時間をテストしました。結果は次の表に示されています。

総計算量を固定してグループ数を変えると、使用するグループが増えるほど実際の実行速度が遅くなることがわかります。したがって、グループ化された畳み込みのグループ数は、ハードウェア プラットフォームと対象タスクに基づいて慎重に選択することを推奨しており、メモリ アクセスの増加を無視して、精度が向上するからといって単純にグループ数を多く選択することはできません。コストMAC。

G3: ネットワークの断片化により並列処理が低下する

ネットワークの断片化により並列度が低下します。GoogLeNet
シリーズ: Inception V1 [参考]、V2 [参考]、V3 [参考] V4 [参考] などでは、ネットワーク内の各単位ブロックがマルチブランチ構造 (マルチブランチ構造) を使用しています。この構造では、大きなオペレータの代わりに小さなオペレータ (フラグメント オペレータ/フラグメント オペレータ) が主に使用され、ネットワーク構造ブロック内の各畳み込み演算またはプーリング演算はフラグメント オペレータと呼ばれます。過去の論文では、断片化された構造によりモデルの精度が向上する可能性があることが示されていますが、この GPU の構造は並列性の高いデバイスには適していないため、効率が低下します
ネットワークの断片化、またはネットワークの分岐が効率に与える影響を定量化するために、この論文ではさまざまな程度の断片化を伴う一連のネットワーク構成要素を評価しています。

具体的には、比較実験の各構成ブロックは、順次または並列構造の 1 ~ 4 つの 1x1 畳み込み層で構成されます。
ネットワーク ブランチがパフォーマンスに与える影響を検証するために、この論文では、さまざまなブランチ レベルのネットワークで比較実験を実施しました。各ブロックの積み重ねを 10 回繰り返します。以下の表の結果は、断片化により GPU の速度が低下することを示しています。

過剰なネットワーク ブランチは、GPU デバイスの実行速度を大幅に低下させますが、ARM プラットフォームでは速度の低下は比較的穏やかです。

G4: 要素レベルの操作は無視できません

要素ごとの操作は無視できません
一部の要素ごとの操作 (要素ごとの演算子) も、特に GPU デバイスで時間のかなりの部分を占めます。FLOP は比較的小さいですが、MAC 値は大きくなります。特に論文では、深い畳み込みも要素レベルの演算であり、通常はより高い MAC/FLOP 値を持つと推測しています。

この論文の要素ごとの演算子 (要素レベルの演算) には、ReLU、AddTensor、AddBias などが含まれます。

この論文では、ResNet の「ボトルネック」ユニットを使用して実験を行っています。実験では ReLU とショートカット操作が削除されています。GPU および ARM デバイスでは、実行速度が約 20% 向上しました。結果は次の表に示されています。

ShuffleNet_V2のモデル構造

ShuffleNet_V1 は、ポイントごとのグループ畳み込みとボトルネック状の構造という 2 つのテクノロジーを使用します。この記事の前章の導入から、点ごとのグループ畳み込みとボトルネック構造の両方が MAC (G2 原則と G1 原則) を増加させることがわかります。これは、特に軽量ネットワークでは無視できません。さらに、残りの接続 (G3 原則と G4 原則) であまりにも多くのグループ化や要素ごとの加算を使用することはお勧めできません。
したがって、効率的なネットワーク モデルを構築するには、密な畳み込みやグループ化をあまり行わずに、多数の同じ幅のチャネルを維持することが重要です。したがって、ShuffleNet_V1 基本ユニットに基づいて、ShuffleNet_V2 基本ユニットはチャネル分割を導入します。
次の図は、元の論文の ShuffleNet_V1 と ShuffleNet_V2 の比較図です。 stride=1 の場合、ShuffleNet_V2 の基本ユニットはチャネル セグメンテーションを通じてcc

を入力します。cチャネルの特徴マップは 2 つの分岐部分に分割されます。1 つの部分はc, { {\rm{c}}^,}cチャネルのショートカット分岐、他の部分はc − c 、{\rm{c - }}{ {\rm{c}}^,}c c各チャネルのトランク ブランチ、簡単にするために、 setc , = c / 2 { {\rm{c}}^,} = c/2c=c /2 (G3 原則を満たす)参考文献と同様に、次の基本ユニットに直接入りますバックボーン ブランチには、同じ数のチャネルを持つ 3 つの畳み込み(G1 原則を満たす)、元の通常の畳み込みになります(G2 原則を満たす)幹ブランチのチャネル シャッフルはスプライシング後に移動されます。最後に、2 つのブランチの出力が ShuffleNet_V1 で追加されるのではなく結合され、基本ユニットの入力チャネルと出力チャネルの数が一貫した状態に保たれます(G1 原則を満たす)
stride=2 では、チャネル セグメンテーションが削除され、ShuffleNet_V1 ショートカット ブランチの 3x3 平均プーリングが 3x3 深さ畳み込み + 1x1 通常畳み込みの組み合わせに置き換えられます。
要素ごとの演算演算子 ReLU は右分岐にのみ存在し、3 つの連続する要素ごとの演算演算子、スプライシング、チャネル シャッフル、およびチャネル分割を 1 つの要素ごとの演算子(G4 原則を満たす)

次の図は、元の論文で示されている ShuffleNet_V2 モデル構造の詳細な概略図です。

ShuffleNet_V2 は、画像分類において 2 つの部分に分かれています:バックボーン部分:主に ShuffleNet_V2 基本ユニット、畳み込み層とプーリング層 (集約層)、分類器で構成されます。部分: グローバル プーリング層と完全接続層で構成されます。

ShuffleNet_V2 の基本ユニット チャネルの数は、0.5 倍の比率に従ってスケーリングされ、さまざまな複雑さの ShuffleNet_V2 ネットワークが生成されます。


ShuffleNet_V2 Pytorch コード

チャンネルシャッフル:機能のインタラクティブ性と表現力が向上します。

def channel_shuffle(x, groups):
    # 获得特征图的所以维度的数据
    batch_size, num_channels, height, width = x.shape
    # 对特征通道进行分组
    channels_per_group = num_channels // groups
    # reshape新增特征图的维度
    x = x.view(batch_size, groups, channels_per_group, height, width)
    # 通道混洗(将输入张量的指定维度进行交换)
    x = torch.transpose(x, 1, 2).contiguous()
    # reshape降低特征图的维度
    x = x.view(batch_size, -1, height, width)
    return x

チャネル シャッフルのコード図を以下に示します。

ShuffleNet Uint 基本ユニット): 1×1 畳み込みおよび 3×3 深さ畳み込み + BN 層 + 活性化関数

class ShuffleUnit(nn.Module):
    def __init__(self, input_c: int, output_c: int, stride: int):
        super(ShuffleUnit, self).__init__()
        # 步长必须在1和2之间
        if stride not in [1, 2]:
            raise ValueError("illegal stride value.")
        self.stride = stride

        # 输出通道必须能二被等分
        assert output_c % 2 == 0
        branch_features = output_c // 2

        # 当stride为1时,input_channel是branch_features的两倍
        # '<<' 是位运算,可理解为计算×2的快速方法
        assert (self.stride != 1) or (input_c == branch_features << 1)

        # 捷径分支
        if self.stride == 2:
            # 进行下采样:3×3深度卷积+1×1卷积
            self.branch1 = nn.Sequential(
                self.depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(input_c),
                nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True)
            )
        else:
            # 不进行下采样:保持原状
            self.branch1 = nn.Sequential()

        # 主干分支
        self.branch2 = nn.Sequential(
            # 1×1卷积+3×3深度卷积+1×1卷积
            nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True)
        )

    # 深度卷积
    @staticmethod
    def depthwise_conv(input_c, output_c, kernel_s, stride, padding, bias= False):
        return nn.Conv2d(in_channels=input_c, out_channels=output_c, kernel_size=kernel_s,
                         stride=stride, padding=padding, bias=bias, groups=input_c)

    def forward(self, x):
        if self.stride == 1:
            # 通道切分
            x1, x2 = x.chunk(2, dim=1)
            # 主干分支和捷径分支拼接
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            # 通道切分被移除
            # 主干分支和捷径分支拼接
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        # 通道混洗
        out = channel_shuffle(out, 2)
        return out

完全なコード

from typing import List, Callable

import torch
from torch import Tensor
import torch.nn as nn
from torchsummary import summary

def channel_shuffle(x, groups):
    # 获得特征图的所以维度的数据
    batch_size, num_channels, height, width = x.shape
    # 对特征通道进行分组
    channels_per_group = num_channels // groups
    # reshape新增特征图的维度
    x = x.view(batch_size, groups, channels_per_group, height, width)
    # 通道混洗(将输入张量的指定维度进行交换)
    x = torch.transpose(x, 1, 2).contiguous()
    # reshape降低特征图的维度
    x = x.view(batch_size, -1, height, width)
    return x

class ShuffleUnit(nn.Module):
    def __init__(self, input_c: int, output_c: int, stride: int):
        super(ShuffleUnit, self).__init__()
        # 步长必须在1和2之间
        if stride not in [1, 2]:
            raise ValueError("illegal stride value.")
        self.stride = stride

        # 输出通道必须能二被等分
        assert output_c % 2 == 0
        branch_features = output_c // 2

        # 当stride为1时,input_channel是branch_features的两倍
        # '<<' 是位运算,可理解为计算×2的快速方法
        assert (self.stride != 1) or (input_c == branch_features << 1)

        # 捷径分支
        if self.stride == 2:
            # 进行下采样:3×3深度卷积+1×1卷积
            self.branch1 = nn.Sequential(
                self.depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(input_c),
                nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True)
            )
        else:
            # 不进行下采样:保持原状
            self.branch1 = nn.Sequential()

        # 主干分支
        self.branch2 = nn.Sequential(
            # 1×1卷积+3×3深度卷积+1×1卷积
            nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True)
        )

    # 深度卷积
    @staticmethod
    def depthwise_conv(input_c, output_c, kernel_s, stride, padding, bias= False):
        return nn.Conv2d(in_channels=input_c, out_channels=output_c, kernel_size=kernel_s,
                         stride=stride, padding=padding, bias=bias, groups=input_c)

    def forward(self, x):
        if self.stride == 1:
            # 通道切分
            x1, x2 = x.chunk(2, dim=1)
            # 主干分支和捷径分支拼接
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            # 通道切分被移除
            # 主干分支和捷径分支拼接
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        # 通道混洗
        out = channel_shuffle(out, 2)
        return out


class ShuffleNetV2(nn.Module):
    def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, ShuffleUnit=ShuffleUnit):
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels

        # 输入通道
        input_channels = 3
        output_channels = self._stage_out_channels[0]

        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )
        input_channels = output_channels
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # 三个基本单元组层
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential

        stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(stage_names, stages_repeats,
                                                  self._stage_out_channels[1:]):
            # 每个Stage的首个基础单元都需要进行下采样,其他单元不需要
            seq = [ShuffleUnit(input_channels, output_channels, 2)]
            for i in range(repeats - 1):
                seq.append(ShuffleUnit(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels
        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )
        # 全局平局池化
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        # 全连接层
        self.fc = nn.Linear(output_channels, num_classes)
        # 权重初始化
        self.init_params()
    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

def shufflenet_v2_x0_5(num_classes=1000):
    """
    weight: https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 48, 96, 192, 1024],
                         num_classes=num_classes)
    return model
def shufflenet_v2_x1_0(num_classes=1000):
    """
    weight: https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 116, 232, 464, 1024],
                         num_classes=num_classes)
    return model

def shufflenet_v2_x1_5(num_classes=1000):
    """
    weight: https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 176, 352, 704, 1024],
                         num_classes=num_classes)
    return model

def shufflenet_v2_x2_0(num_classes=1000):
    """
    weight: https://download.pytorch.org/models/shufflenetv2_x2_0-8be3c8ee.pth
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 244, 488, 976, 2048],
                         num_classes=num_classes)
    return model

if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = shufflenet_v2_x2_0().to(device)
    summary(model, input_size=(3, 224, 224))

summary ではネットワーク構造とパラメータを出力できるため、構築されたネットワーク構造を簡単に確認できます。


要約する

4 つの実践的な指針の原則ができるだけ簡単かつ詳細に紹介され、ShuffleNet_V2 モデルの構造と pytorch コードが説明されます。

おすすめ

転載: blog.csdn.net/yangyu0515/article/details/135168267
おすすめ