Vision Transformer (vit) の原理分析と機能の可視化

目次

ビタミンの紹介

Vitモデル構造図

vit入力処理

画像のタイリング

クラストークンとポジションの追加

特徴抽出

ビットコード


ビタミンの紹介

Vision Transformer (ViT) は、画像認識およびコンピューター ビジョン タスク用のTransformerアーキテクチャに基づくディープ ラーニング モデルです。従来の畳み込みニューラル ネットワーク (CNN) とは異なり、ViT は画像をシリアル化された入力として直接見なし、セルフ アテンション メカニズムを使用して画像内のピクセル関係を処理します。

ViT は、画像を一連のパッチに分割し各パッチを入力シーケンスとしてベクトル表現に変換することによって機能しますこれらのベクトルは、セルフ アテンション メカニズムとフィードフォワード ニューラル ネットワーク層を含む多層の Transformer エンコーダを通じて処理されますこれにより、画像内のさまざまな位置でのコンテキスト依存関係がキャプチャされます最後に、Transformer エンコーダ出力を分類または回帰することで、特定のビジョン タスクを実行できます。

vit コード リファレンス:ニューラル ネットワーク学習の短い記録 67 - ビジョン トランスフォーマー (VIT) モデルの Pytorch バージョンの再発の詳細な説明_vit の再発_Bubbliiing のブログ - CSDN ブログ

トランスフォーマーを画像処理に直接適用できないのはなぜですか? これは、トランスフォーマー自体はシーケンス タスク (NLP など) の処理に使用されますが、画像は 2 次元または 3 次元であり、ピクセル間には特定の構造的関係があるためです。画素と画素 ある程度の相関関係が必要となるため、計算量が非常に多くなります。そこでvitが誕生しました。


Vitモデル構造図

Vit のモデル構造を下図に示します。vit は画像ブロックをトランスフォーマーに適用することです。CNN はスライディング ウィンドウの考えに基づいており、畳み込みカーネルを使用して画像を畳み込み、特徴マップを取得します。画像に NLP の入力シーケンスを模倣させるには、まず画像をパッチに分割し、次にこれらの画像パッチを並べてネットワークに入力し (このようにして画像シーケンスになります)、その後特徴抽出を実行します。トランスフォーマー を介して、最終的に MLP を介してこれらの特徴を分類します [実際、これは、以前の CNN 分類タスクでバックボーンをトランスフォーマーに置き換えることとして理解できます]。

図 1: モデルの概要。 画像を固定サイズのパッチに分割し、それぞれを線形に埋め込み、位置埋め込みを追加して、結果のベクトルのシーケンスを標準の Transformer エンコーダに送ります。 分類を実行するには、追加の学習可能な「分類トークン」をシーケンスに追加するという標準的なアプローチを使用します。 Transformer エンコーダの図は、Vaswani らからインスピレーションを得たものです。 (2017年)。

vit入力処理

画像のタイリング

Image Blocking は上の vit 図のパッチ、position Embedding は位置埋め込み(画像ブロックの位置情報が得られる)です。では、画像をブロックするにはどうすればよいでしょうか? 最も単純なものは、畳み込みによって実現できます。畳み込みカーネル サイズとステップ サイズを設定することで、画像ブロックの解像度とそれを何ブロックに分割するかを制御できます。

コードではどのように実装されているのでしょうか? 以下のコードをご覧ください。

class PatchEmbed(nn.Module):
    def __init__(self, input_shape = [224,224], patch_size = 16, in_channels = 3, num_features = 768, norm_layer = None, flatten = True):
        super().__init__()
        self.num_patch = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)  # 14*14的patch = 196
        self.flatten = flatten

        self.proj = nn.Conv2d(in_channels, num_features, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(num_features) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)  # 先进行卷积 [1,3,224,224] ->[1,768,14,14]
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # x.flatten.transpose(1, 2) shape:[1,768,196]
        x = self.norm(x)
        return x

上記のコードの num_patch は、分割できる画像ブロックの数です。proj は入力画像に対して畳み込みブロックを実行し、特徴マッピングを実行します。入力サイズが 1、3、224、224 であると仮定します。畳み込み演算の後、1、768、14、14 が得られます [768 個の解像度が形成されることを示します)畳み込みによる画像ブロック サイズは 14×14]。

各画像ブロックは 1 回抽出されます。画像ブロックの 1 つを視覚化できます。

入力画像
入力画像
画像ブロックの 1 つ

次に平坦化操作を実行します。これは [1, 768, 196] になり、最終的に Layernorm 層を通過して最終出力を取得します。

入力シーケンスを視覚化する

クラストークンとポジションの追加

上記の操作により、タイル化された特徴シーケンス (形状は 1,768,196)が得られます次に、クラス トークンがシーケンスに追加され、このトークンは特徴抽出のために前の特徴シーケンスとともにネットワークに送信されます。図ではクラス トークンは 0* であるため、長さ 196 の元のシーケンスは長さ 197 のシーケンスになります。

次に、すべての特徴シーケンスに位置情報を追加できるPosition embeddingが追加されます[197,768] 行列を生成し、それを元の特徴シーケンスに追加します。この時点で、ネットワーク入力の前処理パッチ+位置埋め込みが完了します。

# class token的定义
self.cls_token      = nn.Parameter(torch.zeros(1, 1, num_features))

# position embedding定义
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))

コード:

class VisionTransformer(nn.Module):
    def __init__(
            self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
            depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
        ):
        """
        输入大小为224,以16*16的卷积核分块14*14
        :param input_shape: 网络输入大小
        :param patch_size:  分块大小
        :param in_chans:  输入通道
        :param num_classes:  类别数量
        :param num_features: 特征图维度
        :param num_heads:  多头注意力机制head的数量
        :param mlp_ratio: MLP ratio
        :param qkv_bias: qkv的bias
        :param drop_rate: dropout rate
        :param norm_layer: layernorm
        :param act_layer: 激活函数
        """
        super().__init__()
        #-----------------------------------------------#
        #   224, 224, 3 -> 196, 768
        #-----------------------------------------------#
        self.patch_embed    = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_channels=in_chans, num_features=num_features)
        num_patches         = (224 // patch_size) * (224 // patch_size)
        self.num_features   = num_features
        self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
        self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]

        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))  # shape [1,1,768]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))  # shape [1, 197, 768]

    def forward_features(self, x):
        x = self.patch_embed(x)  # 先分块 [1,196, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # [1,1,768]
        x = torch.cat((cls_token, x), dim=1)  # [1,197,768]
        cls_token_pe = self.pos_embed[:, 0:1, :]  # 获取class token pos_embed 【类位置信息】
        img_token_pe = self.pos_embed[:, 1:, :]  # 后196维度是图像特征的位置信息

        img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)  # [1,768,14,14]
        img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
        img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)  # [1,196,768]
        pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)  # [1,197,768]

        x = self.pos_drop(x + pos_embed)

特徴抽出

CNN ネットワークと同様に、特徴抽出にもバックボーンが必要です。vit では、特徴抽出にトランスフォーマー エンコーダーが使用されます。

入力は [197,768] のシーケンスで、197 にはクラス トークン [学習可能]、画像シーケンス、pos_embed [学習可能] が含まれます。このシーケンスは、特徴抽出のためにエンコーダーに入力されます。トランスフォーマーにおける特徴抽出の重要なコンポーネントは、マルチヘッド アテンションです。

上の図では、入力画像が最初に Norm 層を通過し、次にq、k、v の 3 つの部分に分割され、同時にマルチヘッド アテンション メカニズムに入力されることがわかります。時間、それが自己注意のメカニズムです。次に入力を残差側に追加し、Norm と MLP を介して出力します。

q はクエリ シーケンスです。q と k の乗算により、q の各クエリ ベクトルと k の特徴ベクトルの間の相関関係または重要度が求められます。次に、これに元の入力ベクトル v を乗算して、各シーケンスの寄与を取得します (実際には、チャネル アテンション メカニズムに多少似ています)。

多くの自己注意を構築することで特徴を抽出します。CNN と比較すると、トランスフォーマーの基本コンポーネント単位はセルフアテンションであり、CNN の基本コンポーネント単位はコンボリューション カーネルです。

セルフアテンションメカニズムのコード:

コード内の qkv:

# 幾何学的意味: q、k、v は num_heads 個のヘッドに分散され (各ヘッドには qkv があります)、各ヘッドには 197*64 個の特徴シーケンスがあります。

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads  = num_heads
        self.scale      = (dim // num_heads) ** -0.5

        self.qkv        = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop  = nn.Dropout(attn_drop)
        self.proj       = nn.Linear(dim, dim)
        self.proj_drop  = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C     = x.shape
        qkv         = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v     = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

同様に、qkv を使用して特徴を視覚化し、q、k、v に何が含まれているかを確認できます。q、k、v の形状は同じであり、その形状は [batch_size, num_heads, 197, 768//num_heads] です。q の最初のヘッドの入力を視覚化します (この図には 12 個のヘッドがあります)。各頭部で 64 個の特徴が抽出されます)。

最初の頭部の q 特徴ベクトル

次に第一頭のkの特徴を見てみましょう。

最初の頭部の k の固有ベクトル

次に、q 行列と k 行列を乗算することで注意の重みが得られます。 

ps: コード内の q と k の結果が q @ k.transpose(-2,-1) になるのはなぜですか? なぜkTではないのでしょうか?これは、あと 2 つの行列計算を行う場合、最後の 2 次元についてのみ計算すればよく、各ヘッドの q と k の間の内積演算は最後の 2 次元で実行する必要があるためです。したがって、最後の 2 つの次元を交換するだけです。

q と k の行列乗算によって得られる注意特徴マップは次のとおりです。まだ最初の頭部のみを視覚化しています (合計 12 の頭部があり、各頭部の注意特徴マップは異なります)。

次に、sofmax を使用して、すべての頭の注意スコアを計算します。 

最初の頭の注目スコアは次のとおりです。

tensor([[9.9350e-01, 2.5650e-05, 2.6444e-05, ..., 3.7445e-05, 3.3614e-05,
         2.7365e-05],
        [3.7948e-01, 2.3743e-01, 8.7877e-02、...、2.2976e-05、1.2177e-
         04、6.6991e-04]、
        [3.7756e-01、1.2583e-01、1.4249e-01、...、1.0860e-05、 3.4743e-05、1.1384e
         -04]、
        ...、
        [4.1151e-01、3.6945e-05、9.8513e-06、...、1.5886e-01、1.1042e-
         01、4.4855e-02] 、
        [4.0967e-01、1.7754e-04、2.8480e-05、...、1.0884e-01、1.4333e-
         01、1.2111e-01]、
        [4.1888e-01、6.8779e-04、6.7465e -05, ..., 3.5659e-02, 9.4098e-02,
         2.2174e-01]], デバイス='cuda:0')

取得された注意スコアに v を乗じて、各チャネルの寄与度を取得します。 

次に MLP レイヤーを追加すると、最後に Transformer ブロックを取得できます。 

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs = (drop, drop)

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
    def forward(self, x):
        '''
        :param x: 输入序列
        x --> layer_norm --> mulit head attention --> + --> x --> layer_norm --> mlp --> +-->x
        |____________________________________________|     |_____________________________|

        '''
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

ビットコード

GitHub - YINYIPENG-EN/vit_classification_pytorch: vit を使用して画像分類を実装します。GitHub でアカウントを作成して、YINYIPENG-EN/vit_classification_pytorch の開発に貢献してください。https: icon-default.png?t=N7T8//github.com/YINYIPENG-EN/vit_classification_pytorch.git

class VisionTransformer(nn.Module):
    def __init__(
            self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
            depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
        ):
        """
        输入大小为224,以16*16的卷积核分块14*14
        :param input_shape: 网络输入大小
        :param patch_size:  分块大小
        :param in_chans:  输入通道
        :param num_classes:  类别数量
        :param num_features: 特征图维度
        :param num_heads:  多头注意力机制head的数量
        :param mlp_ratio: MLP ratio
        :param qkv_bias: qkv的bias
        :param drop_rate: dropout rate
        :param norm_layer: layernorm
        :param act_layer: 激活函数
        """
        super().__init__()
        #-----------------------------------------------#
        #   224, 224, 3 -> 196, 768
        #-----------------------------------------------#
        self.patch_embed    = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_channels=in_chans, num_features=num_features)
        num_patches         = (224 // patch_size) * (224 // patch_size)
        self.num_features   = num_features
        self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
        self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]

        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))  # shape [1,1,768]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))  # shape [1, 197, 768]

        # -----------------------------------------------#
        #   197, 768 -> 197, 768  12次
        # -----------------------------------------------#
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.Sequential(
            *[
                Block(
                    dim=num_features,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer
                ) for i in range(depth)
            ]
        )
        self.norm = norm_layer(num_features)
        self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()


    def forward_features(self, x):
        x = self.patch_embed(x)  # 先分块 [1,196, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # [1,1,768]
        x = torch.cat((cls_token, x), dim=1)  # [1,197,768]
        cls_token_pe = self.pos_embed[:, 0:1, :]  # 获取class token pos_embed 【类位置信息】
        img_token_pe = self.pos_embed[:, 1:, :]  # 后196维度是图像特征的位置信息

        img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)  # [1,768,14,14]
        img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
        img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)  # [1,196,768]
        pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)  # [1,197,768] 获得最终的位置信息

        x = self.pos_drop(x + pos_embed)  # 将位置信息和图像序列相加

        x = self.blocks(x)  # 特征提取
        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def freeze_backbone(self):
        backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
        for module in backbone:
            try:
                for param in module.parameters():
                    param.requires_grad = False
            except:
                module.requires_grad = False

    def Unfreeze_backbone(self):
        backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
        for module in backbone:
            try:
                for param in module.parameters():
                    param.requires_grad = True
            except:
                module.requires_grad = True

おすすめ

転載: blog.csdn.net/z240626191s/article/details/132504292