论文解读:BIT | Remote Sensing Image Change Detection with Transformers

论文解读:BIT | Remote Sensing Image Change Detection with Transformers
论文地址:https://arxiv.org/pdf/2103.00208.pdf
项目地址:https://github.com/justchenhao/BIT_CD
在这里插入图片描述

现代变化检测(CD)凭借其强大的深度卷积识别能力取得了显著的成功。然而,由于场景中物体的复杂性,高分辨率遥感CD仍然具有挑战性。在这里,我们提出了一个bitemporal image transformer(BIT)来有效地建模时空域内的上下文。.我们的直觉是,兴趣变化的高级概念可以用一些视觉单词来表示,即语义token。为了实现这一点,我们将双时态图像表示为几个token,并使用transformer编码器在紧凑的基于token的时空中建模上下文。然后将学习到的上下文丰富的tokens反馈到像素空间,通过transformer解码器细化原始特征。我们将BIT合并到一个深度的基于特征差异的CD框架中。在三个CD数据集上进行的大量实验证明了该方法的有效性和有效性。值得注意的是,我们的基于bit的模型仅使用了3倍的计算成本和模型参数,其性能显著优于纯卷积基线。基于一个没有复杂结构(如FPN,UNet)的朴素主干(ResNet18),我们的模型超过了几种最先进的CD方法,包括在效率和准确性方面优于最近的四种基于关注的方法.

基本总结

  • 1、传统方法的感受野有限,在大面积物体的变化检测中容易存在漏检(可见章节3.2中的实验效果
  • 2、BIT与同时期的方法相比,属于flop低、参数少的轻量化模型;且基于transformer后,模型的泛化能力更为强劲(可见章节3.3中的两个图表
  • 3、论文本质就是原始transformer在语义分割中的应用,写作较为丰富和扎实,并未针对变化检测的数据特性提出网络结构或module的设计(可见章节4中的消融实验)。文章在于写作方面建议阅读原文,可以学习各种对比实验的设置、网络layer的公式化等等。
  • 4、作者通过系列消融实验证明了原始transformer中Token(特征投影到图像语义空间)、position embedding(transformer编码器和解码器中对x的累加值)是必须的,并提供了相应的可视化图表。然而在segformer和changeformer中并没有position embedding和复杂的Token编码,仅使用conv实现OverlapPatchEmbed即达到了较好的性能。

1、论文总览

引入了双时间图像变换器(BIT)来有效地对双时间图像中的随机上下文进行建模。作者认为,感兴趣变化区域的高级概念可以用一些视觉单词来表示,即语义tokens。BIT没有在像素空间中建模像素之间的密集关系,而是将输入的图像表示为一些高级语义tokens,并在一个紧凑的基于tokens的时空中对上下文进行建模。此外,我们通过利用每个像素和语义tokens之间的关系来增强原始像素空间的特征表示。图1给出了一个例子来显示我们的比特对图像特征的影响。考虑到与建筑概念相关的原始图像特征(见图1 (b)),BIT学会了通过考虑时空上的全局上下文来进一步一致地突出建筑区域(见图1 ©)。请注意,我们展示了增强特征和原始特征之间的差异图像,以更好地展示所提议的BIT的作用。

我们将BIT合并到一个基于深度特征差异的CD框架中。我们的基于bit的模型的总体过程如图2所示。利用CNN主干网(ResNet)从输入的图像对中提取高级语义特征。我们利用空间注意将每个时间特征映射转换为一组紧凑的语义tokens。然后,我们使用一个transformer[15]编码器来建模在两个令牌集中的上下文。由此产生的上下文丰富的tokens被Siamese-transformer译码器重新投影到像素空间,以增强原始的像素级特征。最后,我们从两个细化的特征图中计算特征差异图像(FDI),然后将它们输入一个浅层CNN,以产生像素级的变化预测。

总体贡献为:

  • 提出了一种有效的基于transformer的遥感图像变化检测方法。我们在CD任务中引入了transformer,以更好地在双时间图像中建模上下文,这有助于识别感兴趣的变化并排除不相关的变化。
  • 我们的BIT没有建模像素空间中任何一对元素对之间的密集关系,而是将输入的图像表示为几个视觉单词,即token,并在基于紧凑tokens的时空中对上下文建模。
  • 在三个CD数据集上的大量实验验证了该方法的有效性和有效性。我们用BIT替换了ResNet18的最后一个卷积阶段,所得到的基于BIT的模型仅以降低3倍的计算成本和模型参数的性能优于纯卷积模型。在一个没有复杂结构(如FPN,UNEN)的朴素CNN主干,我们的方法在效率和准确性方面显示出比最近几种基于注意力的CD方法的更好的性能。

2、 方法

基于bit的模型的总体过程如图2所示,将BIT合并到一个正常的变化检测pipeline(最后一个特征图输出)中,因为我们想要利用卷积和transformer的优势。我们的模型从几个卷积块开始,获得每个输入图像的特征图,然后将它们输入到BIT,生成增强的双时态特征。最后,将得到的特征图输入到一个预测头,以产生像素级的预测。我们的关键见解是,BIT学习并关联高级语义概念的全局上下文,并进行反馈,以有利于原始的双时间特征。

BIT有三个主要组件:
1)Siamese Semantic tokenizer,将像素到概念生成一个紧凑的语义tokens为每个时间输入,
2)transformer encoder,建模语义上下文概念的时空,
3)Siamese transformer,投影相应的语义tokens回到像素空间获得每个时间的精炼特征映射。
算法1显示了基于变化检测的BIT模型的推理细节。

BIT对应的实现代码如下,可以看到所对x_feature进行提取的token用于transformer encoder编码,在transformer decoder中才将x_feature与token合并。最后是简单的分类操作

class BASE_Transformer(ResNet):
    """
    Resnet of 8 downsampling + BIT + bitemporal feature Differencing + a small CNN
    """
    def __init__(self, input_nc, output_nc, with_pos, resnet_stages_num=5,
                 token_len=4, token_trans=True,
                 enc_depth=1, dec_depth=1,
                 dim_head=64, decoder_dim_head=64,
                 tokenizer=True, if_upsample_2x=True,
                 pool_mode='max', pool_size=2,
                 backbone='resnet18',
                 decoder_softmax=True, with_decoder_pos=None,
                 with_decoder=True):
        super(BASE_Transformer, self).__init__(input_nc, output_nc,backbone=backbone,
                                             resnet_stages_num=resnet_stages_num,
                                               if_upsample_2x=if_upsample_2x,
                                               )
        self.token_len = token_len
        self.conv_a = nn.Conv2d(32, self.token_len, kernel_size=1,
                                padding=0, bias=False)
        self.tokenizer = tokenizer
        if not self.tokenizer:
            #  if not use tokenzier,then downsample the feature map into a certain size
            self.pooling_size = pool_size
            self.pool_mode = pool_mode
            self.token_len = self.pooling_size * self.pooling_size

        self.token_trans = token_trans
        self.with_decoder = with_decoder
        dim = 32
        mlp_dim = 2*dim

        self.with_pos = with_pos
        if with_pos is 'learned':
            self.pos_embedding = nn.Parameter(torch.randn(1, self.token_len*2, 32))
        decoder_pos_size = 256//4
        self.with_decoder_pos = with_decoder_pos
        if self.with_decoder_pos == 'learned':
            self.pos_embedding_decoder =nn.Parameter(torch.randn(1, 32,
                                                                 decoder_pos_size,
                                                                 decoder_pos_size))
        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.dim_head = dim_head
        self.decoder_dim_head = decoder_dim_head
        self.transformer = Transformer(dim=dim, depth=self.enc_depth, heads=8,
                                       dim_head=self.dim_head,
                                       mlp_dim=mlp_dim, dropout=0)
        self.transformer_decoder = TransformerDecoder(dim=dim, depth=self.dec_depth,
                            heads=8, dim_head=self.decoder_dim_head, mlp_dim=mlp_dim, dropout=0,
                                                      softmax=decoder_softmax)

    def _forward_semantic_tokens(self, x):
        b, c, h, w = x.shape
        spatial_attention = self.conv_a(x)
        spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous()
        spatial_attention = torch.softmax(spatial_attention, dim=-1)
        x = x.view([b, c, -1]).contiguous()
        tokens = torch.einsum('bln,bcn->blc', spatial_attention, x)

        return tokens

    def _forward_reshape_tokens(self, x):
        # b,c,h,w = x.shape
        if self.pool_mode is 'max':
            x = F.adaptive_max_pool2d(x, [self.pooling_size, self.pooling_size])
        elif self.pool_mode is 'ave':
            x = F.adaptive_avg_pool2d(x, [self.pooling_size, self.pooling_size])
        else:
            x = x
        tokens = rearrange(x, 'b c h w -> b (h w) c')
        return tokens

    def _forward_transformer(self, x):
        if self.with_pos:
            x += self.pos_embedding
        x = self.transformer(x)
        return x

    def _forward_transformer_decoder(self, x, m):
        b, c, h, w = x.shape
        if self.with_decoder_pos == 'fix':
            x = x + self.pos_embedding_decoder
        elif self.with_decoder_pos == 'learned':
            x = x + self.pos_embedding_decoder
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.transformer_decoder(x, m)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h)
        return x

    def _forward_simple_decoder(self, x, m):
        b, c, h, w = x.shape
        b, l, c = m.shape
        m = m.expand([h,w,b,l,c])
        m = rearrange(m, 'h w b l c -> l b c h w')
        m = m.sum(0)
        x = x + m
        return x

    def forward(self, x1, x2):
        # forward backbone resnet
        x1 = self.forward_single(x1)
        x2 = self.forward_single(x2)

        #  forward tokenzier
        if self.tokenizer:
            token1 = self._forward_semantic_tokens(x1)
            token2 = self._forward_semantic_tokens(x2)
        else:
            token1 = self._forward_reshape_tokens(x1)
            token2 = self._forward_reshape_tokens(x2)
        # forward transformer encoder
        if self.token_trans:
            self.tokens_ = torch.cat([token1, token2], dim=1)
            self.tokens = self._forward_transformer(self.tokens_)
            token1, token2 = self.tokens.chunk(2, dim=1)
        # forward transformer decoder
        if self.with_decoder:
            x1 = self._forward_transformer_decoder(x1, token1)
            x2 = self._forward_transformer_decoder(x2, token2)
        else:
            x1 = self._forward_simple_decoder(x1, token1)
            x2 = self._forward_simple_decoder(x2, token2)
        # feature differencing
        x = torch.abs(x1 - x2)
        if not self.if_upsample_2x:
            x = self.upsamplex2(x)
        x = self.upsamplex4(x)
        # forward small cnn
        x = self.classifier(x)
        if self.output_sigmoid:
            x = self.sigmoid(x)
        return x

2.1 Semantic tokenizer

对输入图像的兴趣的变化可以用一些高级概念来描述,即语义tokens,并且语义概念可以被双时像数据共享。为此,我们使用Siamese tokenizer从每个时态的特征图中提取紧凑的语义tokens。类似于NLP中的tokens器,它将输入句子分割成几个元素(即单词或短语),并用一个tokens向量表示每个元素,我们的语义tokens器将整个图像分割成几个视觉单词,每个单词对应一个tokens向量。如图3所示,为了获得紧凑tokens,我们的tokens器学习一组空间注意映射,将特征映射空间集中到一组特征,即tokens集。

X 1 、 X 2 ∈ R H W × C X^1、X^2∈R^{HW×C} X1X2RHW×C为输入的双时特征图,其中H、W、C为特征图的高度、宽度和信道尺寸。设 T 1 , T 2 ∈ R L × C T^1,T^2∈R^{L×C} T1T2RL×C为两组tokens,其中L为tokens的词汇表集的大小。
对于特征映射 X i ( i = 1 , 2 ) X^i(i = 1,2) Xii=1,2上的每个像素 X p i X^i_p Xpi,我们使用点向卷积得到L个语义组,每一组表示一个语义概念。然后,我们利用一个对每个语义群的HW维数进行操作的softmax函数来计算空间注意映射。最后,我们利用注意力映射来计算像素的加权平均和。

具体实现代码如下

    #self.conv_a = nn.Conv2d(32, self.token_len, kernel_size=1,padding=0, bias=False)
    def _forward_semantic_tokens(self, x):
        b, c, h, w = x.shape
        spatial_attention = self.conv_a(x)
        spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous()
        spatial_attention = torch.softmax(spatial_attention, dim=-1)
        x = x.view([b, c, -1]).contiguous()
        tokens = torch.einsum('bln,bcn->blc', spatial_attention, x)

        return tokens

2.2 Transformer Encoder

在获得输入双时图像的两个语义tokens集 T 1 , T 2 T^1,T^2 T1,T2后,我们用Transformer Encoder[15]对这些tokens之间的上下文进行建模。我们的动机是,Transformer可以充分利用基于tokens的时空中的全局语义关系,从而为每个时态产生上下文化的tokens表示。如图4 (a)所示,我们首先将这两组令牌连接到一个令牌集 T ∈ R 2 L × C T∈R^{2L×C} TR2L×C中,并将其输入Transformer Encoder,获得一个新的令牌集 T n e w T^{new} Tnew。最后,我们将tokens分成两组 T n e w i ( i = 1 , 2 ) T^i_{new}(i = 1,2) Tnewii=1,2.

    def _forward_transformer(self, x):
        if self.with_pos:
            x += self.pos_embedding
        x = self.transformer(x)
        return x
    def forward(self, x1, x2):
        # forward backbone resnet
        # forward tokenzier
        # forward transformer encoder
        if self.token_trans:
            self.tokens_ = torch.cat([token1, token2], dim=1)
            self.tokens = self._forward_transformer(self.tokens_)
            token1, token2 = self.tokens.chunk(2, dim=1)#token_new

Transformer Encoder由多头自注意(MSA)和多层感知器(MLP)块的NE层组成(图4 (a))。与使用post-norm残差单元的原始Transformer不同,我们遵循ViT [33]采用了pre-norm残差单元(在attention操作前先进行layernorm),即在MSA/MLP之前进行层归一化。PreNorm已经被证明比post-norm[52]更稳定和更优胜。

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            ]))
    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            x = attn(x, mask = mask)
            x = ff(x)
        return x

在每一层l中,self-attention的输入是一个三元组(query Q,Key K,Value V),从输入 T l − 1 ∈ R 2 L × C T^{l−1}∈R^{2L×C} Tl1R2L×C计算出来,结果为:

其中的W为具体layer中的可学习参数,一般表现为linear或conv。具体attention的计算方式为

transformer的核心思想是mutil-head self-attention。MSA并行执行多个独立的注意头,多头输出将被连接起来,然后投影得到最终的值。MSA的优点是它可以共同关注来自不同位置的不同表示子空间的信息。其公式如下所示,r为attention头的数量。

具体实现代码如下:

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads #设置在输出上的多头
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)# 一次性将QKV计算出来
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )#对多头注意力进行集成得到一个头的输出

    def forward(self, x, mask = None):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)#计算出q, k, v
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)#基于lambda表达式对q, k, v进行reshape操作

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale#实现Q*K/^d得到atten-map
        mask_value = -torch.finfo(dots.dtype).max

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, mask_value)#在atten-map上进行mask操作
            del mask

        attn = dots.softmax(dim=-1)


        out = torch.einsum('bhij,bhjd->bhid', attn, v)#得到atten-value
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

MLP块由两个线性transformation层组成,中间的激活函数为GELU。输入和输出的维数为C,层内的维数为2C。公式为:

其对应实现代码为:

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

注意,我们将可学习的位置嵌入(PE)添加到token序列T中,然后再将该序列输入transformer层。我们的实验表明:有必要在token上补充PE(Position embedding)。PE对有关元素在基于tokens的时空中的相对位置或绝对位置的信息进行编码。这样的位置信息可能有利于上下文建模。例如,时间位置信息可以指导transformer利用与时间相关的上下文。
这里所谓的pos_embedding,仅是一个可训练的位置参数,pos_embedding_decoder也是如此,并未进行特殊的初始化

        self.with_pos = with_pos
        if with_pos is 'learned':
            self.pos_embedding = nn.Parameter(torch.randn(1, self.token_len*2, 32))
        decoder_pos_size = 256//4
        self.with_decoder_pos = with_decoder_pos
        if self.with_decoder_pos == 'learned':
            self.pos_embedding_decoder =nn.Parameter(torch.randn(1, 32,
                                                                 decoder_pos_size,
                                                                 decoder_pos_size))

2.3 Transformer Decoder

到目前为止,我们已经为每个时间图像获得了两组上下文丰富的token T n e w i ( i = 1 , 2 ) T^i_{new}(i = 1,2) Tnewii=1,2。这些上下文token包含紧凑的高级语义信息,很好地揭示了兴趣的变化。现在,我们需要将基于概念的表示投影回像素空间,以获得像素级的特征。为了实现这一点,我们使用一个改进的Siamese Transformer Decoder[15]来细化每个时间的图像特征。如图4 (b)所示,给定一系列特征Xi,transformer解码器利用每个像素与令牌集Ti新之间的关系,获得细化的Xi新特征。我们将Xi中的像素作为查询,而将tokens作为键。我们的直觉是,每个像素都可以用紧凑的语义tokens的组合来表示。

具体实现代码如下,可以看到Decoder与encoder使用的是不同的pos_embedding

    def __init__():
        self.transformer_decoder = TransformerDecoder(dim=dim, depth=self.dec_depth,
                            heads=8, dim_head=self.decoder_dim_head, mlp_dim=mlp_dim, dropout=0,
                                                      softmax=decoder_softmax)
    def _forward_transformer_decoder(self, x, m):
        b, c, h, w = x.shape
        if self.with_decoder_pos == 'fix':
            x = x + self.pos_embedding_decoder
        elif self.with_decoder_pos == 'learned':
            x = x + self.pos_embedding_decoder
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.transformer_decoder(x, m)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h)
        return x
    def forward(self):
        # forward transformer decoder
        if self.with_decoder:
            x1 = self._forward_transformer_decoder(x1, token1)
            x2 = self._forward_transformer_decoder(x2, token2)

transformer decoder由多头部交叉注意(MA)和MLP块组成。与[15]中的原始实现不同,我们删除了MSA块,以避免 X i X^i Xi中像素之间密集关系的大量计算。我们采用PerNorm和相同的配置的MLP作为Transformer Encoder。在MSA中,Q、K、V来自相同的输入序列;而在MA中,Q来自图像特征 X i X^i Xi,K和V来自token T n e w i T^i_{new} Tnewi。在公式上,在每一层l,MA被定义为

代码具体实现如下,Cross_Attention拆分开了qkv的计算方式,因为q只针对x, kv要针对token,所有代码上使用PreNorm2对两个数据进行layernorm

class Cross_Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., softmax=True):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim ** -0.5

        self.softmax = softmax
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, m, mask = None):

        b, n, _, h = *x.shape, self.heads
        q = self.to_q(x)
        k = self.to_k(m)
        v = self.to_v(m)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), [q,k,v])

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = -torch.finfo(dots.dtype).max

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, mask_value)
            del mask

        if self.softmax:
            attn = dots.softmax(dim=-1)
        else:
            attn = dots
        # attn = dots
        # vis_tmp(dots)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        # vis_tmp2(out)

        return out

class PreNorm2(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, x2, **kwargs):
        return self.fn(self.norm(x), self.norm(x2), **kwargs)
class TransformerDecoder(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, softmax=True):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual2(PreNorm2(dim, Cross_Attention(dim, heads = heads,
                                                        dim_head = dim_head, dropout = dropout,
                                                        softmax=softmax))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            ]))
    def forward(self, x, m, mask = None):
        """target(query), memory"""
        for attn, ff in self.layers:
            x = attn(x, m, mask = mask)
            x = ff(x)
        return x

2.4 模型结构

CNN backbone

我们使用改进的ResNet18 [32]来提取双时态图像特征图。原来的ResNet18有5个stages,每个stage都有2个降采样。我们替换最后两个stage的下采样为1并添加一个点的卷积(输出通道C=32)避免ResNet减少特征维度,其次是一个双线性插值层,从而获得输出特性映射与降采样系数4来减少空间细节的损失。我们将这个主干命名为ResNet18_S5。为了验证所提出的方法的有效性,我们还使用了两个较轻的主干,即ResNet18_S4/ResNet18_S3,它只使用了ResNet18的前四/三个stage。

Bitemporal image transformer

我们设置token长度L = 4。我们将transformer编码器的层数设置为1,将transformer解码器的层数设置为8。MSA和MA中的头数h设置为8,每个头的信道尺寸d设置为8

Prediction head

得益于CNN主干和BIT提取的高级语义特征,我们采用了一个非常浅的FCN来进行变化识别。给定两个上采样特征图 X 1 ∗ , X 2 ∗ ∈ R H × W × C X^{1∗},X^{2∗}∈R^{H×W×C} X1X2RH×W×C(H,W分别为原始图像的高度、宽度),具体计算过程如下,其中g表示变化分类器(具体实现为输出通道为2conv3x3), σ表示为softmax函数。

cls head的实现代码极为简单,具体如下

    def __init__():
        self.classifier = TwoLayerConv2d(in_channels=32, out_channels=output_nc)
    def forward(self, x1, x2):
        x = torch.abs(x1 - x2)
        if not self.if_upsample_2x:
            x = self.upsamplex2(x)
        x = self.upsamplex4(x)
        # forward small cnn
        x = self.classifier(x)
        if self.output_sigmoid:
            x = self.sigmoid(x)

Loss function

在训练阶段,我们最小化交叉熵损失来优化网络参数。在形式上,损失函数的定义为:

3、实验

3.1 实验环境

数据集

LEVIR-CD 是一个公共的大型构建CD数据集。它包含637对高分辨率(0.5m)RS图像,大小为1024×1024。我们遵循它的默认数据集划分(训练/验证/测试)。由于GPU内存容量的限制,我们将图像切割成大小为256×256的小块,没有重叠。因此,我们分别获得了7120/1024/2048对patch用于训练/验证/测试;

WHU-CD 是一个公共建筑的CD数据集。它包含一对大小为32507×15354的高分辨率(分辨率0.075m)航拍图像。由于[54]中没有提供数据划分方案,我们将图像裁剪成大小为256×256的小块,并将其随机分成三个部分:分别为6096/762/762进行训练/验证/测试。

DSIFN-CD 是一个公共的二值CD数据集。它分别包括来自中国六个主要城市的6对大型高分辨率(2米)卫星图像。该数据集包含了多种土地覆盖物体的变化,如道路、建筑物、农田和水体。我们遵循作者提供的大小为512×512的默认裁剪样本。我们分别有3600/340/48个样本用于训练/验证/测试。

模型

  • base:我们的基线模型,由CNN主干(ResNet18 S5)和预测头组成。
  • BIT:BIT-base的light backbone模型(ResNet18 S4)。
    为了进一步评估该方法的效率,我们另外设置了以下模型:
  • BaseS4:一个轻的CNN主干(ResNet18 S4)+的预测头。
  • Base S3:一个很轻的CNN主干(ResNet18 S3)+的预测头。
  • BIT S3:基于BIT的模型,具有非常轻的主干(ResNet18 S3)。

实验参数

优化器:SGD ,momentum 为 0.99,weight decay为0.0005
学习率:最初被设置为0.01,然后线性衰减为0
epoch:200 (每个epoch都进行验证,最佳模型用于测试)
评价指标:F1、precision、recall、IoU、平均精度(OA)

3.2 Comparison to state-of-the-art

对比模型

  • FC-EF [22]:图像级融合方法,其中双时间图像作为一个单一的输入连接到一个完全卷积的网络。
  • FC-Siam-Di [22]:特征级融合方法,采用Siamese-FCN提取多层特征,并利用特征差异融合双时间信息。
  • FC-Siam-Conc [22]:特征级融合方法,它采用Siamese-FCN来提取multi-level特征,并使用特征连接来融合双时态信息。
  • DTCDSCN [9]:多尺度特征连接方法,为深度Siamese-FCN增加通道注意和空间注意,从而获得更多的鉴别特征。请注意,他们还在单独时像标签下训练了两个额外的语义分割解码器。为了进行公平的比较,我们省略了语义分割解码器。
  • STANet [2]:基于度量的Siamese-FCN的方法,该方法集成了时空注意机制,获得更多的区分特征。
  • IFNet [10]:多尺度特征连接方法,将信道注意和空间注意应用于解码器每层上的连接双时间特征。使用深度监督(即,在解码器的每个级别上计算监督损失),以更好地训练中间层。
  • SNUNet [14]:多尺度特征连接方法,它结合了孪生网络和NestedUNet[55],提取高分辨率的高级特征。通道注意应用于每一个级别的特征提取。并采用深度监督来提高中间特征的识别能力。

表1报告了模型LEVIR-CD、WHU-CD和DSIFN-CD测试集的总体比较结果。结果显示,基于bit的模型在这些数据集上始终显著地优于其他方法。例如,在三个数据集上,BIT的f1分数分别超过了最近的STANet2/1.6/4.7分。请注意,我们的CNN主干只是纯ResNet,我们不应用复杂的结构,如FPN或UNET,通过融合高空间精度和高级语义特征的像素级预测任务。我们可以得出结论,即使使用一个简单的主干,基于bit的模型也可以获得优越的性能。这可能归因于我们的BIT能够在全局高度抽象的空间范围内建模上下文,并利用上下文来增强像素空间中的特征表示。

这些方法在三个数据集上的可视化比较如图5所示。为了更好地查看视图,使用不同的颜色来表示TP(白色)、TN(黑色)、FP(红色)、FN(绿色)。我们可以观察到,基于比特的模型比其他模型取得了更好的结果。首先,我们的基于bit的模型可以更好地避免假阳性(例如,图5 (a),(e),(g),(i)),因为物体的外观与兴趣变化的外观相似。例如,如图5 (a)所示,大多数比较方法都错误地将游泳池区域划分为建筑变化(视图为红色),而基于全局上下文建模增强的判别特征,STANet和我们的BIT可以减少这种错误检测。在图5 ©中,传统方法误认为道路是建筑变化,因为道路具有与建筑具有相似的颜色行为,这些方法由于接收场地有限,不能排除这些伪变化。

其次,我们的BIT也可以很好地处理由土地覆盖元素的季节差异或外观变化引起的不相关的变化(例如,图5 (b)、(f)和(l))。图5 (f)中构建的非语义变化的一个例子说明了我们的BIT的有效性,它学习时空域内的有效上下文,以更好地表达真实的语义变化,排除不相关的变化。最后,我们的BIT可以生成相对完整的预测结果(例如,图5 ©,(h)和(j))。例如,在图5 (j)中,由于图像2中的大建筑面积的接收场有限,一些比较方法无法完全检测到(视为绿色),而我们基于bit的模型得到了更完整的结果。

3.3 Model efficiency and effectiveness

为了公平地比较模型的效率,我们在配备了英特尔Xeon Silver 4214 CPU和NVIDIA特斯拉V100 GPU的计算服务器上测试了所有的方法。表2报告参数的数量,每秒浮点操作(FLOPs),以及不同方法在LEVIR-CD、WHU-CD和DSIFN-CD测试集上的F1/IoU分数。

首先,我们通过比较卷积对应项来验证我们提出的BIT的效率。上图表明,在Base S3/Base S4上添加的BIT(BIT S3/BIT S4)比卷积层更大的(Base S4/Base S5)更高效。例如,BIT S4在三个测试集上比f1分数的基础S5多1.7/2.4/10.8分,而模型参数少3倍,计算成本低3倍。此外,我们可以观察到,与Base S4相比,添加更多的卷积层只会引入微小的改进(即三个测试集上的f1分数的0.16/0.75/0.18点),而BIT的改进比CNN多得多(即4∼60倍)。
其次,我们与四种基于注意力的方法(DTCDSCN、STANet、IFNet和SNUNet)进行了比较。如上图所示,我们的BIT S4在F1/IoU分数上优于四位同行,其计算复杂度和模型参数都要小得多。有趣的是,即使有更轻的主干(大约小10倍),我们的基于bit的模型(BIT S3)在大多数数据集上仍然优于四种比较方法。比较结果进一步证明了我们的基于bit的模型的有效性和有效性。

Training visualization 图6显示了每个训练时期的训练/验证集上的平均f1分数。我们可以观察到,虽然Base模型和BIT模型在训练精度方面有相似的性能,但BIT模型在稳定性和有效性方面的验证准确性方面优于Base模型。说明比特的训练更加稳定、高效,基于比特的模型具有更强的泛化能力。这可能是因为它能够学习紧凑的上下文丰富的概念,这有效地代表了兴趣的变化。

4、模型分析

4.1 Ablation studies

Context modeling. 在Transformer Encoder(TE)上进行消融实验,以验证其在上下文建模中的有效性,其中多头自注意是TE中建模上下文的核心组件。在表3中,当从BIT中删除TE时,我们可以观察到LEVIR-CD、WHUCD和DSIFN-CD数据集上的f1-分数出现了一致且显著的下降。这说明了基于tokens的时空内关系的self-attention在TE模型中的重要性。此外,我们用一个非局部的[56]self-attention代替了我们的BIT,它能够在基于像素的时空中建模关系。比较结果显示我们的BIT在三个测试集上优于Non-local测试,并具有显著的差距。这可能是因为我们的BIT在基于token的空间中学习上下文(token更紧凑,具有更高的信息密度比Non-local的关系更好,从而促进了关系的有效提取)

Ablation on tokenizer. 通过将token从BIT中移除来对其进行消融。所得到的模型可以考虑使用密集token,这是由CNN主干提取的特征序列。表3中,基于bit的模型(T x)的f1分数显著下降。这表明,tokens发生器模块在我们的基于transformer的框架中是至关重要的。我们可以看到这个模型(T x,最后一行)只比S4稍微好一点。这可能是因为密集的特征包含过多的冗余信息,这使得训练基于transformer的模型成为一项艰巨的任务。相反,我们提出的tokens器在空间上汇集了密集的特征来聚合语义信息,从而获得概念的紧凑tokens。
Ablation on transformer decoder. 为了验证我们的transformer解码器(TD)的有效性,我们将其替换为一个简单的模块来融合来自TE的新tokens T i T_i Ti和来自CNN主干的原始特征 X i X^i Xi。在简单模块中,我们将 T n e w i T^i_{new} Tnewi(包含Ltokens)中每个tokens的空间维数扩展为HxW的形状。L扩展tokens被加到 X i X^i Xi中,产生更新的特征,然后输入预测头。表3表示没有TD的BIT模型的性能持续下降。这可能是因为交叉注意(TD的核心部分)提供了一种优雅的方法,通过建模它们的关系来使用上下文丰富的tokens来增强原始特征。此外,BIT(没有TE和TD)都远远低于正常的BIT模型。
Effect of position embedding transformer的体系结构是排列不变的,而CD任务同时需要空间和时间上的位置信息。为此,我们将学习到的位置嵌入(PE)添加到输入到transformer的特征序列中。我们在TE和TD中对PE进行消融。我们设置了不包含PE的BIT模型作为基线。如表4所示,当我们的BIT模型在输入TE的tokens中时,BIT模型在三个测试集的f1分数上取得了一致的改进。这表明双时态tokens集中的位置信息对于TE中的上下文建模至关重要。与基线相比,当向输入TD的查询中添加PE时,对BIT模型的f1分数没有显著改善。位置信息对于进入TD的查询可能是不必要的,因为进入TD的键(即tokens)是高度抽象的,并且不包含空间结构。因此,我们只在TE中的TE中添加PE,而没有在BIT模型中添加TD。

4.2 Parameter analysis

Token length. 我们的tokens器在空间上将图像的密集特征集中到一个紧凑的tokens集中。我们的直觉是,双时间图像中兴趣的变化可以用一些视觉概念来描述,即语义tokens。tokens集L的长度是一个重要的超参数。我们分别对不同的L∈{2、4、8、16、32}进行了测试,以分析其对我们的模型在LEVIR-CD、WHUCD和DSIFN-CD数据集上的性能的影响。表5表示,当将tokens长度从32减少到4时,模型的f1分数有显著改善。这表明,一个紧凑的tokens集足以表示感兴趣的变化的语义概念,而冗余的tokens可能会阻碍模型的性能。我们还可以观察到,当L从4进一步下降到2时,f1-分数略有下降。这是因为当L太短时,模型可能会丢失一些与改变概念相关的有用信息。因此,我们把L设为4。

Depth of transformer transformer层数是一个重要的超参数。我们测试了在TE和TD中包含不同数量的transformer层的BIT模型的不同配置。表6显示,当增加transformer编码器的深度时,在三个数据集上的BIT的F1/IoU得分没有显著的改善。这表明,通过单层TE可以很好地学习双时tokens之间的关系。表6还表明,模型的性能与解码器的深度大致呈正相关。这可能是因为图像特征在transformer解码器的每一层之后通过考虑上下文令牌而得到细化。当解码器深度为8时,得到了最佳的结果。虽然通过进一步增加解码器深度可能会提高性能,但对于效率和精度之间的权衡,我们将编码器深度设置为1,解码器深度设置为8。

Ablation on tokenizer. 通过将tokens器从比特中移除来对其进行消融。所得到的模型可以考虑使用密集tokens,这是由CNN主干提取的特征序列。如表3所示,基于bit的模型(w.o.。tokens化器)的f1分数显著下降。这表明,tokens发生器模块在我们的基于transformer的框架中是至关重要的。我们可以看到这个模型(without tokenizer)只比S4稍微好一点。这可能是因为密集的特征包含过多的冗余信息,这使得训练基于transformer的模型成为一项艰巨的任务。相反,我们提出的tokens器在空间上汇集了密集的特征来聚合语义信息,从而获得概念的紧凑tokens。

4.3 Token visualization

我们假设我们的tokens器可以提取揭示兴趣变化的高级语义概念。为了更好地理解语义tokens,我们将tokens化器从双时间特征映射中提取的注意映射 A i l ∈ R H W A_i^l∈R^{HW} AilRHW可视化。图7显示了来自LEVIR-CD、WHU-CD和DSIFN-CD数据集的一些双时间图像的tokens的可视化结果。我们为每个输入图像显示从Ti中选择的两个tokens的注意力图。红色表示注意值较高,蓝色表示注意值较低。

从图7中我们可以看出,所提取的令牌可以注意到属于感兴趣变化的语义概念的区域。不同的tokens可能与具有不同语义意义的对象相关联。例如,由于LEVIR-CD和WHU-CD数据集只描述了建筑的变化,因此这些数据集中所学习的令牌主要关注于属于建筑的像素。而因为DSIFN-CD数据集包含各种变化,这些令牌可以突出不同的语义区域,如建筑、农田和水体。有趣的是,如图7 ©和(f)所示,我们的tokens器也可以突出显示建筑周围的像素(例如,阴影),即使在训练我们的模型时没有提供对这些区域的明确监督。这并不奇怪,因为建筑周围的环境是物体识别的关键线索。这表明,我们的模型可以隐式地学习一些额外的概念,以促进对变化的识别。

4.4 Network visualization

为了更好地理解我们的模型,我们提供了一个例子来可视化BIT模型的不同阶段的激活图。给定双时间图像(图8 (a)),一个暹罗的FCN生成高级特征图Xi(图8 (b))。然后,tokens化器使用学习到的注意力映射Ai将特征映射空间地池化到几个tokens向量中(图8 ©)。然后,由转换器编码器生成的上下文丰富的tokens通过变压器解码器投影回像素空间,从而得到改进的特征映射Xinew(图8 (d))。
我们从原始特征Xi和改进的特征Xi新中展示了四个相应的代表性特征图。从图8 (b)和图(d)中,我们可以观察到,我们的模型可以提取出与每个时间图像的兴趣变化相关的高级特征,如建筑及其边缘的概念。为了更好地说明BIT模块的效果,改进后的特征与原始特征之间的差异图像如图8 (e)所示。这表明我们的BIT可以进一步突出与变化类别相关的语义概念区域。最后,预测头计算Xi new和Xi之间的特征差分图(图8 (f)),生成变化概率图P(图8 (g))。

猜你喜欢

转载自blog.csdn.net/a486259/article/details/129730631