Detailed explanation of the ViT (Vision Transformer) source code

1. Project configuration instructions

        Parameter Description:

Dataset:

       --name cifar10-100_500

       --dataset cifar10

Model variant:

       --model_type ViT-B_16

Pretrained weights: 

       --pretrained_dir checkpoint/ViT-B_16.npz

2. Patch embedding and position embedding

        For the image encoding, take ViT-B/16 as an example (batch size 16, 224×224 input images). First, a convolution with a 16×16 kernel and a stride of 16 cuts the image into non-overlapping patches and projects each one to 768 dimensions, giving a tensor of shape [16, 768, 14, 14]. This is flattened and transposed to [16, 196, 768]. Then the class token ("patch 0"), of shape [16, 1, 768], is prepended, giving [16, 197, 768].

        For the position encoding, a learnable parameter of shape [1, 197, 768] is created: one position for each of the 196 patches plus one for the class token.

        Finally, the patch embeddings (including the class token) and the position embeddings are added element-wise, and dropout is applied, to produce the input to the encoder.

The code is shown below:

class Embeddings(nn.Module):
    """Build the encoder input from patch embeddings, a class token and position embeddings.

    Shapes for ViT-B/16 with 224x224 images and batch size 16:
    image [16, 3, 224, 224] -> patch conv [16, 768, 14, 14] -> tokens [16, 196, 768]
    -> with class token prepended [16, 197, 768].

    Args:
        config: model config providing ``patches`` (with either a ``"grid"`` entry for
            hybrid models or a ``"size"`` entry), ``hidden_size`` and
            ``transformer["dropout_rate"]``; hybrid configs also need
            ``resnet.num_layers`` and ``resnet.width_factor``.
        img_size: input image size, an int or an (H, W) pair.
        in_channels: number of image channels; ignored for hybrid models, which take
            the channel count from the ResNet backbone.
    """

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        img_size = _pair(img_size)

        # Patch size and number of patches (n_patches).
        if config.patches.get("grid") is not None:
            # Hybrid model: the patch convolution runs on the ResNet feature map. This
            # code assumes the backbone downsamples the image by 16, so patch_size is
            # measured in feature-map pixels.
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            # Count the tokens the patch convolution actually produces on the
            # (H/16, W/16) feature map. (H/16)*(W/16) is only correct when
            # patch_size == (1, 1); otherwise position_embeddings would not match.
            n_patches = (img_size[0] // 16 // patch_size[0]) * (img_size[1] // 16 // patch_size[1])
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        # Hybrid model: a ResNet backbone produces the feature map that gets patchified.
        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                         width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        # Non-overlapping patch projection (kernel == stride), e.g. output [16, 768, 14, 14].
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        # Learnable position embeddings, one per patch plus one for the class token,
        # e.g. [1, 197, 768].
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches + 1, config.hidden_size))
        # Learnable class token ("patch 0"), [1, 1, hidden_size]. It is prepended as
        # token 0, whose encoder output is later used for classification.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: image batch of shape [B, C, H, W].

        Returns:
            Tensor of shape [B, n_patches + 1, hidden_size]: the class token followed by
            the patch tokens, plus position embeddings, after dropout.
        """
        B = x.shape[0]
        # Broadcast the class token over the batch: [B, 1, hidden].
        cls_tokens = self.cls_token.expand(B, -1, -1)
        if self.hybrid:
            x = self.hybrid_model(x)
        x = self.patch_embeddings(x)       # [B, hidden, H', W'], e.g. [16, 768, 14, 14]
        x = x.flatten(2)                   # [B, hidden, H'*W'], e.g. [16, 768, 196]
        x = x.transpose(-1, -2)            # [B, H'*W', hidden], e.g. [16, 196, 768]
        # Prepend the class token: [B, H'*W' + 1, hidden], e.g. [16, 197, 768].
        x = torch.cat((cls_tokens, x), dim=1)

        # Add the position embeddings (broadcast over the batch).
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings

3. Encoder

Multi-head attention module:

        First, the query, key and value projections q, k and v are computed from the input. Because multi-head attention is used (12 heads), their shape is converted from [16, 197, 768] to [16, 12, 197, 64].
        
        Next, the similarity between q and k is computed as q·kᵀ. Since this relates every token to every other token, its shape is [16, 12, 197, 197]. The scores are divided by √64 (the square root of the per-head dimension) to keep their scale under control. A softmax then turns them into attention weights.
        
        Multiplying the weights by v gives the extracted features with shape [16, 12, 197, 64]. These are reshaped back to [16, 197, 768] and passed through a final linear projection.

class Attention(nn.Module):
    """Multi-head self-attention.

    Shapes in the comments use ViT-B/16: batch 16, 197 tokens, hidden size 768,
    12 heads of size 64.

    Args:
        config: provides ``hidden_size``, ``transformer["num_heads"]`` and
            ``transformer["attention_dropout_rate"]``.
        vis: if True, ``forward`` also returns the attention probabilities
            (taken before attention dropout); otherwise it returns None for them.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by the number of heads.
    """

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        # Number of attention heads.
        self.num_attention_heads = config.transformer["num_heads"]
        # Fail fast: a non-divisible hidden size would silently truncate the head size
        # and only surface later as an opaque shape mismatch in forward().
        if config.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_heads ({self.num_attention_heads})"
            )
        # Per-head dimension, e.g. 768 // 12 = 64.
        self.attention_head_size = config.hidden_size // self.num_attention_heads
        # Total size over all heads (equals hidden_size after the check above).
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Query / key / value projections.
        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        # Output projection.
        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        # NOTE(review): the output-projection dropout also uses attention_dropout_rate
        # (not dropout_rate); kept as-is for compatibility — confirm this is intended.
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: [B, N, all_head] -> [B, heads, N, head_size]."""
        # e.g. [16, 197, 768] -> [16, 197, 12, 64]
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        # [16, 197, 12, 64] -> [16, 12, 197, 64]
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply self-attention to ``hidden_states`` of shape [B, N, hidden_size].

        Returns:
            ``(attention_output, weights)``: the output has shape [B, N, hidden_size];
            ``weights`` are the softmax probabilities [B, heads, N, N] if ``vis`` is
            True, else None.
        """
        # q, k, v projections: [16, 197, 768]
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        # Split into heads: [16, 197, 768] -> [16, 12, 197, 64]
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        # Query-key similarity: [16, 12, 197, 197]
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # Scale by sqrt(head_size) to keep the softmax inputs in a reasonable range.
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)
        # Weighted sum of values: [16, 12, 197, 64]
        context_layer = torch.matmul(attention_probs, value_layer)
        # [16, 12, 197, 64] -> [16, 197, 12, 64]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # Merge heads: [16, 197, 12, 64] -> [16, 197, 768]
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        # Output projection: [16, 197, 768]
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights

Transformer encoder

        Each encoder block works as follows:
        
        1. The input x is layer-normalized and fed to multi-head attention, and the result is added back to x (a residual connection).
        2. The sum is layer-normalized again and passed through the MLP (two fully connected layers).
        3. The MLP output is again added through a residual connection.
        
        Stacking L such blocks produces the final encoder output.

 

class Block(nn.Module):
    """One Transformer encoder layer with pre-LayerNorm residual sub-layers.

    ``forward`` computes::

        x = attn(attention_norm(x)) + x
        x = ffn(ffn_norm(x)) + x

    and returns ``(x, weights)``, where ``weights`` is whatever the attention
    sub-layer returns as its second output (None unless ``vis`` is True).
    The token tensor keeps its shape, e.g. [16, 197, 768] for ViT-B/16.
    """

    def __init__(self, config, vis):
        super(Block, self).__init__()
        width = config.hidden_size
        # Token embedding width, e.g. 768.
        self.hidden_size = width
        # Separate LayerNorms in front of the attention and MLP sub-layers.
        self.attention_norm = LayerNorm(width, eps=1e-6)
        self.ffn_norm = LayerNorm(width, eps=1e-6)
        # Build the MLP before the attention module, as before, so parameter
        # registration order (and any RNG use during construction) is unchanged.
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        """Run one encoder layer over tokens ``x``; return ``(x, attention weights)``."""
        # Attention sub-layer with residual connection.
        attended, weights = self.attn(self.attention_norm(x))
        x = attended + x
        # MLP sub-layer with residual connection.
        x = self.ffn(self.ffn_norm(x)) + x
        return x, weights

Overall structure

        The input x first goes through patch embedding and position embedding, which gives a tensor of shape [16, 197, 768]. This tensor is fed to the encoder. After the L encoder blocks, the output at position 0 is taken out; this is the class token, which represents the classification feature. It is fed to the classification head to obtain the prediction.

class VisionTransformer(nn.Module):
    """ViT classifier: a Transformer encoder followed by a linear head on the class token.

    Args:
        config: model config; ``hidden_size`` and ``classifier`` are read here and the
            whole config is passed to ``Transformer``.
        img_size: input image size, passed to ``Transformer``.
        num_classes: number of output classes (size of the linear head).
        zero_head: stored on the instance; not used inside this class.
        vis: passed to ``Transformer``; presumably controls whether attention weights
            are returned — confirm against ``Transformer``.
    """

    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier

        self.transformer = Transformer(config, img_size, vis)
        self.head = Linear(config.hidden_size, num_classes)

    def forward(self, x, labels=None):
        """Classify a batch of images.

        Args:
            x: input batch passed to the encoder.
            labels: optional class indices; when given, the loss is returned instead
                of the logits.

        Returns:
            The scalar cross-entropy loss if ``labels`` is given, otherwise a
            ``(logits, attn_weights)`` tuple with logits of shape [B, num_classes].
        """
        x, attn_weights = self.transformer(x)
        # x: [B, tokens, hidden], e.g. [16, 197, 768]; token 0 is the class token.
        logits = self.head(x[:, 0])  # e.g. [16, 10]
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            return loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
        return logits, attn_weights

 

 

 

         

 

 

Guess you like

Source: blog.csdn.net/qq_52053775/article/details/126261070#comments_27453747