【分系列发送】CREStereo源码阅读1——框架阅读

背景

CREStereo试了预训练的模型,效果真的非常好,就是还没办法实时。打算自己参照着找找有没有可以加速的地方

参考

论文阅读
源码
论文

过程

整体框架

打开nets文件夹里面的crestereo.py,先看看大体的inference流程,具体太长了,总的就不贴了,一点一点贴。

  1. 特征提取
    上来先归一到-1~+1就不说了,通过self.fnet把左右图特征提取出来,然后cast改改类型,根据注释应该是原图像的1/4大小
fmap1, fmap2 = self.fnet([image1, image2])
  1. 最初的降采样+卷积+上下文语义提取
    先是将刚刚提取到的特征进行平均池化,大小从1/4变到1/8,然后经过一次卷积之后通过sigmoid函数。这边sigmoid的写法让我有点懵,实际上和我注释掉的写法是等价的,是不是这么写计算的时候会快?得到offset_dw8
# 1/4 -> 1/8
# feature
fmap1_dw8 = F.avg_pool2d(fmap1, 2, stride=2)
fmap2_dw8 = F.avg_pool2d(fmap2, 2, stride=2)

# offset
offset_dw8 = self.conv_offset_8(fmap1_dw8)
offset_dw8 = self.range_8 * (F.sigmoid(offset_dw8) - 0.5) * 2.0
# offset_dw8 = self.range_8 * F.tanh(offset_dw8*0.5)

随后将还是之前的特征图按特征通道进行分割,一半用tanh激活函数,一半用relu,池化后得到net_dw8和inp_dw8

# context
net, inp = F.split(fmap1, [hdim], axis=1)
net = F.tanh(net)
inp = F.relu(inp)
net_dw8 = F.avg_pool2d(net, 2, stride=2)
inp_dw8 = F.avg_pool2d(inp, 2, stride=2)

随后有一段类似的降采样到1/16的,不再叙述。
然后就是位置编码,初始化一个1/16分辨率的位置编码,对左右的dw16特征进行编码整合

# positional encoding and self-attention
pos_encoding_fn_small = PositionEncodingSine(
    d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
)
# 'n c h w -> n (h w) c'
x_tmp = pos_encoding_fn_small(fmap1_dw16)
fmap1_dw16 = F.reshape(
    F.transpose(x_tmp, (0, 2, 3, 1)),
    (x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1]),
)
# 'n c h w -> n (h w) c'
x_tmp = pos_encoding_fn_small(fmap2_dw16)
fmap2_dw16 = F.reshape(
    F.transpose(x_tmp, (0, 2, 3, 1)),
    (x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1]),
)

随后经过自注意力部分后再变回原来的shape

fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
fmap1_dw16, fmap2_dw16 = [
    F.transpose(
        F.reshape(x, (x.shape[0], image1.shape[2] // 16, -1, x.shape[2])),
        (0, 3, 1, 2),
    )
    for x in [fmap1_dw16, fmap2_dw16]
]
  1. flow_init初始化的计算
    先对1/4、1/8、1/16分辨率的特征分别进行AGCL自适应组相关层的初始化,具体内容之后展开
corr_fn = AGCL(fmap1, fmap2)
corr_fn_dw8 = AGCL(fmap1_dw8, fmap2_dw8)
corr_fn_att_dw16 = AGCL(fmap1_dw16, fmap2_dw16, att=self.cross_att_fn)

对于已经有输入flow_init的情况,识别flow_init和自己需要的大小的比例,然后进行插值,注意有个负号。得到一个1/4分辨率的flow初值。

predictions = []
flow = None
flow_up = None
if flow_init is not None:
    scale = fmap1.shape[2] / flow_init.shape[2]
    flow = -scale * F.nn.interpolate(
        flow_init,
        size=(fmap1.shape[2], fmap1.shape[3]),
        mode="bilinear",
        align_corners=True,
    )

对于没有flow_init输入的情况,先初始化一个零值的flow,然后使用刚刚初始化过的AGCL模块,先计算出一个粗略的相关性结果。注意这边的small_patch代表着这个AGCL模块会根据itr的值交替进行1维和2维的搜索。这边1D和2D的交替搜索是这个模型的一个亮点

# zero initialization
flow_dw16 = self.zero_init(fmap1_dw16)

# Recurrent Update Module
# RUM: 1/16
for itr in range(iters // 2):
    if itr % 2 == 0:
        small_patch = False
    else:
        small_patch = True

    flow_dw16 = flow_dw16.detach()
    out_corrs = corr_fn_att_dw16(
        flow_dw16, offset_dw16, small_patch=small_patch
    )

随后通过之前提取的1/16的语义和粗略相关性结果更新新的语义和mask,以及flow的变化,更新到flow上,进行4倍的上采样和4倍的插值,直接拉回原图大小(是不是?我不确定),然后把这个得到的flow_up放到predictions里面,这个predictions除了带有最后结果,应该还有每个小环节得到的flow,用来计算loss用的。

    with amp.autocast(enabled=self.mixed_precision):
        net_dw16, up_mask, delta_flow = self.update_block(
            net_dw16, inp_dw16, out_corrs, flow_dw16
        )

    flow_dw16 = flow_dw16 + delta_flow
    flow = self.convex_upsample(flow_dw16, up_mask, rate=4)
    flow_up = -4 * F.nn.interpolate(
        flow,
        size=(4 * flow.shape[2], 4 * flow.shape[3]),
        mode="bilinear",
        align_corners=True,
    )
    predictions.append(flow_up)

随后将刚刚得到的1/16分辨率上采样的1/4分辨率的flow进行插值(下采样)得到一个对应dw8分辨率的flow初值

scale = fmap1_dw8.shape[2] / flow.shape[2]
flow_dw8 = -scale * F.nn.interpolate(
    flow,
    size=(fmap1_dw8.shape[2], fmap1_dw8.shape[3]),
    mode="bilinear",
    align_corners=True,
)

后面一段和上面1/16的类似,不断更新得到一个认为可靠的flow,并更新到predictions里面,就不贴代码了。最后得到一个1/4分辨率的flow初值。

  1. 最后得到结果
    其实和上面1/16分辨率的没多少区别,只是最后得到的结果不需要再往1/4分辨率上变了。最后4倍上采样到原图大小输出。
 # RUM: 1/4
for itr in range(iters):
    if itr % 2 == 0:
        small_patch = False
    else:
        small_patch = True

    flow = flow.detach()
    out_corrs = corr_fn(flow, None, small_patch=small_patch, iter_mode=True)

    with amp.autocast(enabled=self.mixed_precision):
        net, up_mask, delta_flow = self.update_block(net, inp, out_corrs, flow)

    flow = flow + delta_flow
    flow_up = -self.convex_upsample(flow, up_mask, rate=4)
    predictions.append(flow_up)
部分特定模块展开

经过以上的大概叙述,列举以下一些需要深入分析的模块:
LocalFeatureTransformer, AGCL, BasicUpdateBlock, convex_upsample
接下来的内容就放第二篇来写了。

猜你喜欢

转载自blog.csdn.net/weixin_42492254/article/details/125069478