【分系列发送 缓慢更新】CREStereo源码阅读2——特定模块阅读

背景

CREStereo先前记录了大概框架,现在看看有些存在感很强的模块具体的逻辑。

参考

论文阅读
源码
论文
上一篇博客

特定模块

经过框架大概阅读,列举了这些需要深入分析的模块:
LocalFeatureTransformer, AGCL, BasicUpdateBlock, convex_upsample。注意力应该会和AGCL在一块儿看。

LocalFeatureTransformer

新增参考

LoFTR阅读
LoFTR参考源码

具体内容

还是先从代码角度看,不是很理解的看别人的论文阅读吧。直接从forward开始。可以看到这个层单纯的是自注意力和交叉注意力的任意组合,具体组合方式看self.layer是怎么定义的。

for layer, name in zip(self.layers, self.layer_names):
    if name == "self":
        feat0 = layer(feat0, feat0, mask0, mask0)
        feat1 = layer(feat1, feat1, mask1, mask1)
    elif name == "cross":
        feat0 = layer(feat0, feat1, mask0, mask1)
        feat1 = layer(feat1, feat0, mask1, mask0)
    else:
        raise KeyError

于是看向初始化部分。可以看到套多少个子模块连接是由layer_names决定。这个子模块是由 LoFTREncoderLayer(d_model, nhead, attention)定义,于是接着把代码往上翻,看这个模块具体内容。

def __init__(self, d_model, nhead, layer_names, attention):
    super(LocalFeatureTransformer, self).__init__()

    self.d_model = d_model
    self.nhead = nhead
    self.layer_names = layer_names
    encoder_layer = LoFTREncoderLayer(d_model, nhead, attention)
    self.layers = [
        copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))
    ]
    self._reset_parameters()

还是从forward开始看,这边就有点难对应了,结合论文来看。本来想顺着说一遍,读着读着发现基本一致,就看这边论文的描述吧。

def forward(self, x, source, x_mask=None, source_mask=None):
    bs = x.shape[0]
    query, key, value = x, source, source

    # multi-head attention
    query = F.reshape(
        self.q_proj(query), (bs, -1, self.nhead, self.dim)
    )  # [N, L, (H, D)] (H=8, D=256//8)
    key = F.reshape(
        self.k_proj(key), (bs, -1, self.nhead, self.dim)
    )  # [N, S, (H, D)]
    value = F.reshape(self.v_proj(value), (bs, -1, self.nhead, self.dim))
    message = self.attention(
        query, key, value, q_mask=x_mask, kv_mask=source_mask
    )  # [N, L, (H, D)]
    message = self.merge(
        F.reshape(message, (bs, -1, self.nhead * self.dim))
    )  # [N, L, C]
    message = self.norm1(message)

    # feed-forward network
    message = self.mlp(F.concat([x, message], axis=2))
    message = self.norm2(message)

    return x + message

PS:在LinearAttention里面有个小小的可以优化的小地方,基本对速度没啥帮助哈哈。就是megengine好像没有原生的elu算子,他是这么表示的。我左瞅右瞅,感觉这边加减1可以消掉,elu_feature_map注释掉的部分是我的写法。

PPS:其实这么看下来,感觉transformer架构很特别,但也只是很特别,感觉reshape之类的冗余操作太多有点影响计算速度。是不是可以从一些角度对此进行一些修改?顺带代码中用的是Linear Transformer,我不是很了解一些变体,看到这么一篇文章,是不是可以试试这里面提到的Performer?

def elu(x, alpha=1.0):
    return F.maximum(0, x) + F.minimum(0, alpha * (F.exp(x) - 1))

def elu_feature_map(x):
    return elu(x) + 1
    # return F.relu(x) + F.minimum(1, F.exp(x))

AGCL

新增参考
具体内容

好家伙,这玩意儿好长,不想看。
先从__call__开始阅读。简单来看就是对应上这两种模块。然后对应以下具体内容。

def __call__(self, flow, extra_offset, small_patch=False, iter_mode=False):
    if iter_mode:
        corr = self.corr_iter(self.fmap1, self.fmap2, flow, small_patch)
    else:
        corr = self.corr_att_offset(
            self.fmap1, self.fmap2, flow, extra_offset, small_patch
        )
    return corr

首先是corr_iter。self.coords是模块初始化时生成的相当于是点坐标集,加上flow意思就是新的坐标集,然后从右图按坐标集采点。从这边的small_patch就可以看出他的搜索方式是可选择的1D或2D搜索,但是保证是9个点,从而保证计算的一致性。然后将左右特征按照特征通道数4分割,self.get_correlation就是在psize_list[i]范围内,执行9次相关性计算然后拼接出一个相关性结果,最后四个部分拼接成最后输出的结果。
这边为什么要先分出4部分再融合呢?不知道是处于精简计算的关系还是和输入内容有关,之后要更细致的分析

def corr_iter(self, left_feature, right_feature, flow, small_patch):

    coords = self.coords + flow
    coords = F.transpose(coords, (0, 2, 3, 1))
    right_feature = bilinear_sampler(right_feature, coords)

    if small_patch:
        psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]
        dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
    else:
        psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]
        dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]

    N, C, H, W = left_feature.shape
    lefts = F.split(left_feature, 4, axis=1)
    rights = F.split(right_feature, 4, axis=1)

    corrs = []
    for i in range(len(psize_list)):
        corr = self.get_correlation(
            lefts[i], rights[i], psize_list[i], dilate_list[i]
        )
        corrs.append(corr)

    final_corr = F.concat(corrs, axis=1)

    return final_corr

To be continued…

猜你喜欢

转载自blog.csdn.net/weixin_42492254/article/details/125081621