VisionTransformer(一)—— Embedding Patched与Word embedding及其实现

Embedding Patched与Word embedding及其实现


前言

VisionTransformer可以说火到不得不会的程度,而本人之前其实对NLP领域了解不是很多,在学习中,认为在VIT论文里比较值得学习的地方有两点,一个是对图片的预处理成image token的Embedding Patched,另一个则是Transformer模块里的多头注意力模块,这次先讲讲个人Embedding Patched的理解。

零、VIT是什么?

在了解其他东西前,先对VIT做一个,个人简单的理解与概况。 

这里简单的说VIT其实就是作者想对image采取和context一样的处理方法,将image像context一样处理成一个个token,然后送到transform中,然后再接上一个分类头,就得到了一个基于transform的分类器了。

所以说其实要搞懂VIT,我画了一个图,其实就是两部分:

  • 一部分就是如何将image处理成token的样子——Embedding Patch。
  • 另一部分就是transformer,而这里的transformer相较与NLP里的transformer是没有Decoder部分的,所以只有Encoder。而Encoder部分网络,除了多头注意力模块——Multi-Head Attention以外的其他部分实现起来是很简单的,所以另一部分就要研究的就是这个多头注意力机制。 

而这篇文章,主要讲解一下个人对Embedding Patch的理解。 

 一、Word Embedding

想要对Embedding Patch有比较好的了解,个人认为有必要简单介绍一下在NLP领域里的Word Embedding技术,对比学习,会有更深的理解。

1)为什么要有Word Embedding

Word Embedding简单的说,其实就是一种token(词)到向量的映射编码

为什么需要做这件事呢,下面举个例子来说明这件事。

有一个句子:今天天气不错,我要去看电影。假如我们想要机器就认识这个句子,那我们可以采取先对这个句子每个部分进行分词,然后对每个分词出来的单词进行编码,那么下次在遇到相应的单词时,就可以通过查这个编码,去获得这个句子的意思了。

今天天气不错,我要去看电影。可以通过分词划分为,今天/天气/不错/,//要去//电影这8个词,那么我们对这八个词进行one-hot编码,比如今天可以得到编码为[1,0,0,0,0,0,0,0],而则被编码为[0,0,0,0,1,0,0,0].

那么当下次遇到句子:今天去看电影,机器只要先对这个句子进行分词,然后在自己的码表中查找相应的编码,那么机器就可以认识这个句子了。

但我们现在考虑两个问题就是:

  1. 中文的字和词太多了,如果按照这个方法进行one-hot编码,那么这个码表,少说也是一个5000*5000的稀疏矩阵,如果要用这个矩阵去给机器进行学习,是十分耗费内存和浪费时间的。
  2. 使用one-hot编码,词和词之间丢失了关联性。比如对于中文来说,应该属于近义词,英语里catcats更应该是相近的词,但如果使用了one-hot编码,词与词之间的相似性被丢弃了,也不利于机器进行学习。

所以就引入了word embedding的做法。 

2)Word Embedding在做什么

现在,每个token不仅停留在独热编码,而是把每个token的独热编码再映射为N维(embedded_dim)空间上的点。比如今天可以再次后编码为[0.1,0.2,0.3], 则编码为[0.5,0.6,0.6].

所以说整个word embedding(广义)具体在做什么事情,我用两步来概括:

  1. 对context进行分词操作。
  2. 对分好的词进行one-hot编码,根据学习相应的权重对one-hot编码进行N(embedded_dim)维空间的映射.

这里针对第二点,以我们上面的例子来说,今天天气不错,我要去看电影这个句子,通过one-hot编码,整个句子可以表示为一个8X8的矩阵,我们通过学习一个权重矩阵,其大小为8X(embedded_dim),那么与其相乘,就相对于做了一个到(embedded_dim)维的映射,得到8X(embedded_dim)的矩阵。

而如果我们对这N维的空间上的点进行降维,会发现意思相近的词,相互靠近的情况

 例如上图,man和king的点更加相近,cat和cats也更加相近。所以word embedding即解决了one-hot编码稀疏矩阵的问题,又使编码的向量具有的语义信息。

、Embedding Patch

word embedding是针对context进行编码,便于使机器进行学习的方法,而Embedding patch则是针对image进行编码,便于机器学习的方法。而像作者说的,作者的本义其实就是在想,将image当成context一样去处理。

所以Embedding patch也其实在做两步:

  1. 将图片像分词一样划分
  2. 将分好的图片(我们这里称为Patch)进行N(embedded_dim)维空间的映射。

 1)将图片进行划分成Patch

对context进行分词实际上比较简单,比如说英语里的句子,基本就是按照空格进行划分,这个没有什么问题。但图片没有明显的分开处, 不过最直观的想法就是将二维的图片直接拉成一维的向量,如28X28的图片,拉成1X784长度的向量,将784维的向量当成context,然后去做word embedding。但这种方法问题就在与消耗太大,NLP领域处理一个14X14=196长度的句子已经算是比较费时的事情,更何况只是28*28的图片,照现在CV领域处理图片都基本在224X224往上,明显不行。

那么我们换一个思路,我们将图片先分成一个个(PXP)的小块,这么我们设P为7,那么一个28X28的图片,就可以被划分成16个7X7的图片了,我们再拉平的话,对于transformer来说就变的可以处理了。

                              

2) N(embeded_dim)维空间映射

我们现在对图片划分成Patch并且将图片拉平,相对于完成了context中的对句子分词并且one-hot编码的工作,我们以28X28的图片为例,我们现在得到的是一个16X49的矩阵。接下来我们可以像context一样,去构造一个可以学习的49(PXP) X (embedded_dim)的权重矩阵,那么16X49的矩阵与其相乘,得到一个16Xembedded_dim的向量,就相对于对其进行了embedded_dim维的映射了。

                  

官方给的图就是我所说的意思。

3)实现 Embedding Patch

 但真正实现Embedding Patch并不需要所说的这么麻烦,因为对于图片来说,我们可以通过卷积操作就可以直接完成对图片的分块以及embedded_dim维映射的关系。

 我们要做的与这个动图有所区别的地方是,我们每个Patch是不重复的(当然也有人会划分成重复的),即每个stride的长度为P。一个28X28的图片,通过一个7X7Xembedded_dim的kernel,且stride是7,再拉平,就得到了一个embedded_dimX16的向量了,经过转置后为16Xembedded_dim的矩阵,是不是就和我们上面的结果一样了。

那么我们只需要通过一个卷积操作就相对于完成了我们的Embedding Patch。

具体代码如下

class PatchEmbedding(nn.Module):
    def __init__(self,
                 patch_size,
                 in_channels,
                 embedded_dim,
                 dropout=0.):
        super().__init__()
        self.patch_embedded = nn.Conv2d(in_channels=in_channels,
                                        out_channels=embedded_dim,
                                        kernel_size=patch_size,
                                        stride=patch_size,
                                        bias=False)
        # 这里加了dropout的操作
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # X = [batchsize, 1, 28, 28]
        x = self.patch_embedded(x)
        # X = [batchsize, embedded_dim, h, w]
        x = x.flatten(2)
        # X = [batchsize, embedded_dim, h*w]
        x = x.transpose(2, 1)
        # X = [batchisize, h*w, embedded_dim]
        x = self.dropout(x)
        return x

经过 Embedding Patch之后,其实图片就相对于是token了,后面就是transformer的事了。

所以VIT论文的题目也叫做AN IMAGE IS WORTH 16X16 WORDS。

当然其实除了Embedding Patch,其实还有一个Position embedded和cls token的东西,下面简单介绍一下吧。

4)Position embedded

position embedded是考虑到如果就这么简单的划分Patch,没有考虑到图片间的位置关系问题,这个以一个context为例就很好理解了。

句子1:我来到这里坐公交车

句子2:我坐公交车到这里来 

很明显,由于坐公交车在句子中的位置,导致语义完全不同,而图片之间也有类似这样的关系,所以我们要引入图片间位置的关系。而每个分好的Patch如何确定位置了,作者简单的直接在原来Patch的基础上,接上一个与之形状相应的权重矩阵,让网络自己学就好了,得到的效果也不错。

5)Class token(代表分类的标签)

还有一个小细节是 Class token,这个东西原本是Bert里面用于区分句子的情感或者对句子内容进行分类的一个技巧。如果想要了解比较详细,可以去看看Bert的论文。

所以其实我们如果用VIT做分类任务的话,最后送入分类头的内容其实是这个Class token里面的东西,所以说class token是一个与我们分出Patch一样维度的矩阵,它是我们最后送入分类器进行学习的对象,根据反向传播,这个class token会自动强制的去向其他image token学习相应的特征。

所以说最后其实完整的Embedding Patch的图应该是下面的样子

6)完整细节代码

class PatchEmbedding(nn.Module):
    def __init__(self,
                 image_size,
                 patch_size,
                 in_channels,
                 embedded_dim,
                 dropout=0.):
        super().__init__()
        num_patches = (image_size // patch_size) * (image_size // patch_size)
        self.patch_embedding = nn.Conv2d(in_channels=in_channels,
                                        out_channels=embedded_dim,
                                        kernel_size=patch_size,
                                        stride=patch_size,
                                        bias=False)
        self.dropout = nn.Dropout(dropout)
        class_token = torch.zeros(
            size=(1, 1, embedded_dim)
        )
        self.class_token = nn.parameter.Parameter(class_token)

        position_embedding = truncate_normal(size=(1, num_patches+1, embedded_dim))

        self.position_embedding = nn.parameter.Parameter(position_embedding)

    def forward(self, x):
        class_tokens = self.class_token.expand([x.shape[0], -1, -1])
        # X = [batchsize, 1, 28, 28]
        x = self.patch_embedding(x)
        # X = [batchsize, embedded_dim, h, w]
        x = x.flatten(2)
        # X = [batchsize, embedded_dim, h*w]
        x = x.transpose(2, 1)
        # X = [batchisize, h*w, embedded_dim]
        x = torch.concat([class_tokens, x], axis=1)
        # X = [batchisize, h*w+1, embedded_dim]
        x = x + self.position_embedding
        # X = [batchisize, h*w+1, embedded_dim]
        x = self.dropout(x)
        return x

    @staticmethod
    def truncate_normal(size, std=1, mean=0):
        # 返回一个截断正态分布的tensor
        lower, upper = mean - 2 * std, mean + 2 * std  # 截断在[μ-2σ, μ+2σ]
        X = stats.truncnorm((lower - mean) / std, (upper - mean) / std, loc=mean, scale=std)
        size_ = 1
        shape = []
        for s in size:
            size_ *= s
            shape.append(s)

        X = np.array(X.rvs(size_), dtype='float32')
        X = torch.from_numpy(X)
        X = X.reshape(size)
        return X

总结

 整篇写比较口语化,且都是个人的一些学习理解,如果有出错的地方,请在评论区指出,欢迎讨论探讨。

猜你喜欢

转载自blog.csdn.net/lzzzzzzm/article/details/122902777