Transformer: patch embedding code

A simple version of ViT (without the attention part)

This post mainly records how to handle Patch Embedding and explains the basic framework of ViT. The next post will cover the complete ViT framework.


How do we apply a Transformer to an image? As shown in the figure below, the pipeline is: image → split into patches → learnable mapping → features.
(figure: patch embedding pipeline)

Overall network structure:

(figure: overall ViT network structure)

Practical part:

Patch Embedding converts the original 2-D image into a sequence of 1-D patch embeddings. For a 224×224 input with patch size 7, this yields (224/7)² = 32×32 = 1024 patches, each mapped to a 16-dimensional embedding.
Patch Embedding part of the code:

class PatchEmbedding(nn.Module):
    def __init__(self, image_size, in_channels, patch_size, embed_dim, dropout=0.):
        super(PatchEmbedding, self).__init__()
        # patch_embed is simply a strided convolution: kernel_size == stride == patch_size
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # x: [4, 3, 224, 224]
        x = self.patch_embed(x)
        # x: [n, embed_dim, h', w'] = [4, 16, 32, 32]
        x = x.flatten(2)        # merge h' and w' into one dimension: [n, embed_dim, h'*w'] = [4, 16, 1024]
        x = x.permute(0, 2, 1)  # [n, h'*w', embed_dim] = [4, 1024, 16]
        x = self.drop(x)
        print(x.shape)          # [4, 1024, 16], i.e. [batch_size, num_patch, embed_dim]
        return x
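As a side note, a Conv2d whose kernel_size equals its stride is the same as cutting the image into non-overlapping patches and applying one shared linear projection to each patch. A minimal sketch checking this equivalence (the unfold-based version is my own illustration, not part of the original post):

import torch
import torch.nn as nn

x = torch.randn(4, 3, 224, 224)
conv = nn.Conv2d(3, 16, kernel_size=7, stride=7, bias=False)

# cut out non-overlapping 7x7 patches: [4, 3*7*7, 1024] -> [4, 1024, 147]
patches = nn.functional.unfold(x, kernel_size=7, stride=7).permute(0, 2, 1)

# reuse the conv kernel as a linear weight of shape [147, 16]
linear_out = patches @ conv.weight.reshape(16, -1).t()    # [4, 1024, 16]
conv_out = conv(x).flatten(2).permute(0, 2, 1)            # [4, 1024, 16]

print(torch.allclose(linear_out, conv_out, atol=1e-5))    # True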

The ViT code (the attention part is omitted for now):

class Vit(nn.Module):
    def __init__(self):
        super(Vit, self).__init__()
        self.patch_embed = PatchEmbedding(224, 3, 7, 16)   # image -> patch tokens
        layer_list = [Encoder(16) for i in range(5)]       # assume 5 encoder layers with embed_dim 16
        self.encoders = nn.Sequential(*layer_list)
        self.head = nn.Linear(16, 10)        # encoder output dim is 16; num_classes is 10
        self.avg = nn.AdaptiveAvgPool1d(1)   # average over all patch tokens

    def forward(self, x):
        x = self.patch_embed(x)       # x: [4, 1024, 16]
        for encoder in self.encoders:
            x = encoder(x)
        # [n, h'*w', c]
        x = x.permute((0, 2, 1))      # [4, 16, 1024]
        # [n, c, h'*w']
        x = self.avg(x)               # [n, c, 1] = [4, 16, 1]
        x = x.flatten(1)              # [n, c] = [4, 16]
        x = self.head(x)
        return x
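Note that this simplified ViT classifies by average-pooling all patch tokens, instead of using the extra [CLS] token from the original ViT paper. A quick sanity check of the pooling step (illustrative values):

import torch
import torch.nn as nn

tokens = torch.randn(4, 1024, 16)    # encoder output: [n, num_patch, embed_dim]
pooled = nn.AdaptiveAvgPool1d(1)(tokens.permute(0, 2, 1)).flatten(1)
print(pooled.shape)                                             # torch.Size([4, 16])
print(torch.allclose(pooled, tokens.mean(dim=1), atol=1e-6))    # True: same as the mean over patches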

Complete code:

from PIL import Image
import numpy as np
import torch
import torch.nn as nn

# Identity: a placeholder module that does nothing
class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

# The Mlp is just two fully connected layers, usually placed after the attention layer.
# It first expands the channel dim (16) by a factor of 4 to 64, then shrinks it back,
# so the channel count is unchanged.
class Mlp(nn.Module):
    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.):   # mlp_ratio is the expansion ratio
        super(Mlp, self).__init__()
        self.fc1 = nn.Linear(embed_dim, int(embed_dim * mlp_ratio))   # expand
        self.fc2 = nn.Linear(int(embed_dim * mlp_ratio), embed_dim)   # project back to the original size
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self,x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
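# Quick shape check (illustrative): with embed_dim=16, fc1 expands 16 -> 64 and
# fc2 projects back 64 -> 16, so Mlp(16)(torch.randn(4, 1024, 16)) keeps shape [4, 1024, 16].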

class PatchEmbedding(nn.Module):
    def __init__(self, image_size, in_channels, patch_size, embed_dim, dropout=0.):
        super(PatchEmbedding, self).__init__()
        # patch_embed is simply a strided convolution: kernel_size == stride == patch_size
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # x: [4, 3, 224, 224]
        x = self.patch_embed(x)
        # x: [n, embed_dim, h', w'] = [4, 16, 32, 32]
        x = x.flatten(2)        # merge h' and w' into one dimension: [n, embed_dim, h'*w'] = [4, 16, 1024]
        x = x.permute(0, 2, 1)  # [n, h'*w', embed_dim] = [4, 1024, 16]
        x = self.drop(x)
        print(x.shape)          # [4, 1024, 16], i.e. [batch_size, num_patch, embed_dim]
        return x

class Encoder(nn.Module):
    def __init__(self, embed_dim):
        super(Encoder, self).__init__()
        self.atten = Identity()                    # self-attention is left unimplemented for now
        self.attn_norm = nn.LayerNorm(embed_dim)   # LayerNorm for the attention sublayer
        self.mlp = Mlp(embed_dim)
        self.mlp_norm = nn.LayerNorm(embed_dim)    # LayerNorm for the MLP sublayer

    def forward(self, x):
        # residual connection around the attention sublayer
        h = x
        x = self.atten(x)       # self-attention first
        x = self.attn_norm(x)   # then LayerNorm
        x = h + x

        # residual connection around the MLP sublayer
        h = x
        x = self.mlp(x)         # MLP first
        x = self.mlp_norm(x)    # then LayerNorm
        x = h + x

        return x
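# Note: the Encoder above normalizes *after* each sublayer (post-norm). The
# original ViT paper instead applies LayerNorm *before* each sublayer (pre-norm),
# i.e. x = x + attn(LN(x)). A minimal sketch of that variant for comparison
# (a hypothetical class, not used by Vit below):
class PreNormEncoder(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.atten = Identity()                    # attention placeholder, as above
        self.attn_norm = nn.LayerNorm(embed_dim)
        self.mlp = Mlp(embed_dim)
        self.mlp_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x + self.atten(self.attn_norm(x))   # LN before attention, then residual
        x = x + self.mlp(self.mlp_norm(x))      # LN before MLP, then residual
        return x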



class Vit(nn.Module):
    def __init__(self):
        super(Vit, self).__init__()
        self.patch_embed = PatchEmbedding(224, 3, 7, 16)   # image -> patch tokens
        layer_list = [Encoder(16) for i in range(5)]       # assume 5 encoder layers with embed_dim 16
        self.encoders = nn.Sequential(*layer_list)
        self.head = nn.Linear(16, 10)        # encoder output dim is 16; num_classes is 10
        self.avg = nn.AdaptiveAvgPool1d(1)   # average over all patch tokens

    def forward(self, x):
        x = self.patch_embed(x)       # x: [4, 1024, 16]
        for encoder in self.encoders:
            x = encoder(x)
        # [n, h'*w', c]
        x = x.permute((0, 2, 1))      # [4, 16, 1024]
        # [n, c, h'*w']
        x = self.avg(x)               # [n, c, 1] = [4, 16, 1]
        x = x.flatten(1)              # [n, c] = [4, 16]
        x = self.head(x)
        return x


def test():
    # 1. load an image (assumed to be 224 x 224)
    img = np.array(Image.open('test.jpg'))   # [224, 224, 3]
    t = torch.tensor(img, dtype=torch.float32)
    print(t.shape)                # [224, 224, 3]
    # HWC -> CHW, then add a batch dimension: [224, 224, 3] -> [1, 3, 224, 224]
    sample = t.permute(2, 0, 1).unsqueeze(0)
    print(sample.shape)

    # 2. patch embedding: convert the 2-D image into a sequence of 1-D patch embeddings
    # patch_size is the patch size: the 224 x 224 image becomes a 32 x 32 grid of patches
    # in_channels is 3 for an RGB image
    # embed_dim is the dimension each patch is mapped to

    patch_embedding = PatchEmbedding(image_size=224, patch_size=7, in_channels=3, embed_dim=1)
    # run the forward pass
    out = patch_embedding(sample)
    print(out)
    #print(out.shape)

    mlp = Mlp(embed_dim=1)
    out = mlp(out)
    print(out.shape)

def main():
    t = torch.randn([4,3,224,224])
    model=Vit()
    out=model(t)
    print(out.shape)


if __name__ == "__main__":
    main()

The final output shape is [4, 10]. The next post will present the complete ViT code, including the attention part.
