[Slowly updated series] CREStereo source code reading 2: reading specific modules

Background

The previous post recorded CREStereo's overall framework. Now let's look at the detailed logic of a few of its more prominent modules.

References

Paper Reading
Source Code
Paper
Previous Blog

Specific modules

After a general pass through the framework, these are the modules that need in-depth analysis:
LocalFeatureTransformer, AGCL, BasicUpdateBlock and convex_upsample. The attention module should be read together with AGCL.

LocalFeatureTransformer

Additional references

LoFTR reading notes
LoFTR source code (for reference)

Details

Let's look at the code first, and turn to the paper when something is unclear. Start directly from forward. This layer is simply an arbitrary sequence of self-attention and cross-attention steps; the exact combination depends on how self.layer_names is defined.

for layer, name in zip(self.layers, self.layer_names):
    if name == "self":
        feat0 = layer(feat0, feat0, mask0, mask0)
        feat1 = layer(feat1, feat1, mask1, mask1)
    elif name == "cross":
        feat0 = layer(feat0, feat1, mask0, mask1)
        feat1 = layer(feat1, feat0, mask1, mask0)
    else:
        raise KeyError
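A detail of this loop that is easy to miss: in a "cross" step, feat1 is computed from the feat0 that was just updated, not from the value before the step. The sketch below reproduces the same dispatch with the masks dropped and a hypothetical scalar `toy_layer` standing in for LoFTREncoderLayer, so the update order becomes visible.

```python
def run_layers(layers, layer_names, feat0, feat1):
    """Same dispatch as LocalFeatureTransformer.forward, masks dropped."""
    for layer, name in zip(layers, layer_names):
        if name == "self":
            feat0 = layer(feat0, feat0)
            feat1 = layer(feat1, feat1)
        elif name == "cross":
            feat0 = layer(feat0, feat1)
            feat1 = layer(feat1, feat0)  # feat0 is already the updated value here
        else:
            raise KeyError(name)
    return feat0, feat1


def toy_layer(x, source):
    # hypothetical stand-in for LoFTREncoderLayer: a plain "x + source"
    return x + source


print(run_layers([toy_layer], ["cross"], 1, 10))               # (11, 21)
print(run_layers([toy_layer] * 2, ["self", "cross"], 1, 10))   # (22, 42)
```

I can't tell from the code alone whether this asymmetry matters in practice. A symmetric variant would compute both outputs from the old values, e.g. `feat0, feat1 = layer(feat0, feat1), layer(feat1, feat0)`, since Python evaluates the whole right-hand side first.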

So look at the initialization. The number of stacked submodules is determined by layer_names, and each submodule is a deep copy of LoFTREncoderLayer(d_model, nhead, attention). Scroll up in the code to see what this module contains.

def __init__(self, d_model, nhead, layer_names, attention):
    super(LocalFeatureTransformer, self).__init__()

    self.d_model = d_model
    self.nhead = nhead
    self.layer_names = layer_names
    encoder_layer = LoFTREncoderLayer(d_model, nhead, attention)
    self.layers = [
        copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))
    ]
    self._reset_parameters()

Again, start from forward. Mapping this code onto the paper is a bit difficult, so read it together with the paper. I meant to explain it step by step, but after reading it I found that it matches the paper's description almost exactly, so just read the paper's description here.

def forward(self, x, source, x_mask=None, source_mask=None):
    bs = x.shape[0]
    query, key, value = x, source, source

    # multi-head attention
    query = F.reshape(
        self.q_proj(query), (bs, -1, self.nhead, self.dim)
    )  # [N, L, (H, D)] (H=8, D=256//8)
    key = F.reshape(
        self.k_proj(key), (bs, -1, self.nhead, self.dim)
    )  # [N, S, (H, D)]
    value = F.reshape(self.v_proj(value), (bs, -1, self.nhead, self.dim))
    message = self.attention(
        query, key, value, q_mask=x_mask, kv_mask=source_mask
    )  # [N, L, (H, D)]
    message = self.merge(
        F.reshape(message, (bs, -1, self.nhead * self.dim))
    )  # [N, L, C]
    message = self.norm1(message)

    # feed-forward network
    message = self.mlp(F.concat([x, message], axis=2))
    message = self.norm2(message)

    return x + message
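To make the shape flow of this forward concrete, here is a numpy sketch of one encoder layer. It is a stand-in under several assumptions, not CREStereo's code: weights are random, biases and LayerNorm's learnable scale/shift are dropped, masks are omitted, and the attention is a LoFTR-style linear attention with an elu(x) + 1 feature map, written from my understanding of it rather than copied from the source. The MLP shape (2C -> 2C -> C with ReLU) is my recollection of LoFTR. `EncoderLayerSketch`, `linear_attention` and `layer_norm` are hypothetical names.

```python
import numpy as np


def elu_feature_map(x):
    # elu(x) + 1, written as relu(x) + exp(min(x, 0)); always positive
    return np.maximum(x, 0.0) + np.exp(np.minimum(x, 0.0))


def linear_attention(q, k, v, eps=1e-6):
    """q: [N, L, H, D]; k, v: [N, S, H, D] -> [N, L, H, D] (masks omitted)."""
    Q, K = elu_feature_map(q), elu_feature_map(k)
    kv = np.einsum("nshd,nshv->nhdv", K, v)                         # sum over S once
    z = 1.0 / (np.einsum("nlhd,nhd->nlh", Q, K.sum(axis=1)) + eps)  # per-query normalizer
    return np.einsum("nlhd,nhdv,nlh->nlhv", Q, kv, z)


def layer_norm(x, eps=1e-5):
    # LayerNorm over the channel axis, without learnable scale/shift
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


class EncoderLayerSketch:
    """Hypothetical numpy stand-in for LoFTREncoderLayer (random weights, no biases)."""

    def __init__(self, d_model=32, nhead=4, seed=0):
        rng = np.random.default_rng(seed)

        def init(fan_in, fan_out):
            return rng.standard_normal((fan_in, fan_out)) / np.sqrt(fan_in)

        self.nhead, self.dim = nhead, d_model // nhead
        self.q_proj, self.k_proj, self.v_proj = (init(d_model, d_model) for _ in range(3))
        self.merge = init(d_model, d_model)
        # MLP applied to concat([x, message]): 2C -> 2C -> C
        self.mlp1, self.mlp2 = init(2 * d_model, 2 * d_model), init(2 * d_model, d_model)

    def __call__(self, x, source):
        bs = x.shape[0]
        query = (x @ self.q_proj).reshape(bs, -1, self.nhead, self.dim)      # [N, L, H, D]
        key = (source @ self.k_proj).reshape(bs, -1, self.nhead, self.dim)   # [N, S, H, D]
        value = (source @ self.v_proj).reshape(bs, -1, self.nhead, self.dim)
        message = linear_attention(query, key, value)                        # [N, L, H, D]
        message = message.reshape(bs, -1, self.nhead * self.dim) @ self.merge  # [N, L, C]
        message = layer_norm(message)
        hidden = np.maximum(np.concatenate([x, message], axis=2) @ self.mlp1, 0.0)
        message = layer_norm(hidden @ self.mlp2)
        return x + message                                                   # residual


layer = EncoderLayerSketch(d_model=32, nhead=4)
feats = np.random.default_rng(1)
print(layer(feats.standard_normal((2, 10, 32)), feats.standard_normal((2, 7, 32))).shape)  # (2, 10, 32)
```

The point of linear attention is the order of the einsums: the product of K and V is summed over the S source positions once per head (cost O(S·D²)), so the L×S attention matrix is never formed, while the result equals normalized kernel attention with weights phi(q)·phi(k).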

PS: There is a small spot in LinearAttention that could be optimized, though it barely helps speed, haha. MegEngine does not seem to have a native elu operator, as the original author also points out, so elu is written out by hand. After looking it over, I felt that the +1 and -1 here can cancel out; the commented-out line in elu_feature_map is my version.

PPS: Looking at it this way, the transformer architecture is certainly distinctive, but that is all it is. I feel that the many redundant operations such as reshape slow down the computation; couldn't this be trimmed from some angle? By the way, the code uses a Linear Transformer. I don't know the various variants very well; after reading an article about them, I wonder whether the Performer it mentions would be worth trying here.

def elu(x, alpha=1.0):
    return F.maximum(0, x) + F.minimum(0, alpha * (F.exp(x) - 1))

def elu_feature_map(x):
    return elu(x) + 1
    # return F.relu(x) + F.minimum(1, F.exp(x))
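A quick numpy check of the PS above: elu(x) + 1 and the commented-out relu(x) + min(1, exp(x)) agree everywhere, because for x > 0 both give x + 1 and for x ≤ 0 both give exp(x). The third form, relu(x) + exp(min(x, 0)), is my own addition and not from the source; it is the same function but never evaluates exp on large positive inputs, so its intermediate values cannot overflow to inf.

```python
import numpy as np


def elu(x, alpha=1.0):
    # same formula as the MegEngine snippet above, in numpy
    with np.errstate(over="ignore"):  # exp(x) may hit inf for large x; min() discards it
        return np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))


def elu_feature_map(x):
    return elu(x) + 1


def elu_feature_map_no_offset(x):
    # the commented-out variant: relu(x) + min(1, exp(x))
    with np.errstate(over="ignore"):
        return np.maximum(0, x) + np.minimum(1, np.exp(x))


def elu_feature_map_safe(x):
    # hypothetical overflow-free variant: relu(x) + exp(min(x, 0))
    return np.maximum(0, x) + np.exp(np.minimum(x, 0))


x = np.concatenate([np.linspace(-20, 20, 2001), [1000.0]])
print(np.allclose(elu_feature_map(x), elu_feature_map_no_offset(x)))  # True
print(np.allclose(elu_feature_map(x), elu_feature_map_safe(x)))       # True
```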

AGCL

Additional references
Details

Wow, this thing is long; I really don't feel like reading it.
Start with __call__. Simply put, it dispatches to one of two methods, corr_iter or corr_att_offset, depending on iter_mode. Then we go through each of them in turn.

def __call__(self, flow, extra_offset, small_patch=False, iter_mode=False):
    if iter_mode:
        corr = self.corr_iter(self.fmap1, self.fmap2, flow, small_patch)
    else:
        corr = self.corr_att_offset(
            self.fmap1, self.fmap2, flow, extra_offset, small_patch
        )
    return corr

First, corr_iter. self.coords is a grid of pixel coordinates generated when the module is initialized; adding flow gives the new sampling coordinates, and the right feature map is then sampled at those coordinates. The small_patch flag shows that the search can be either 2D (a 3x3 window) or 1D (a 1x9 window), but it always covers 9 points, which keeps the output size consistent. The left and right features are then split into 4 groups along the channel dimension. For each group, self.get_correlation computes 9 correlations within the window psize_list[i] and stacks them into one correlation result, and the four group results are concatenated into the final output.
Why split the features into 4 parts first and then merge them again? I don't know whether this is for computational efficiency or related to the input content; it will need a more detailed analysis later.
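The first half of corr_iter (coordinate grid plus flow, then sampling the right feature at the shifted positions) can be sketched in numpy as follows. Assumptions: `make_coords` is a hypothetical stand-in for self.coords with channel 0 as x and channel 1 as y, coordinates are in pixels, and neighbors that fall outside the image contribute zero. The real bilinear_sampler may normalize coordinates or handle borders differently.

```python
import numpy as np


def make_coords(n, h, w):
    """Hypothetical stand-in for self.coords: [N, 2, H, W], channel 0 = x, channel 1 = y."""
    ys, xs = np.meshgrid(np.arange(h, dtype=np.float64),
                         np.arange(w, dtype=np.float64), indexing="ij")
    return np.broadcast_to(np.stack([xs, ys]), (n, 2, h, w)).copy()


def bilinear_sample(img, coords):
    """img: [N, C, H, W]; coords: [N, Ho, Wo, 2] pixel (x, y) -> [N, C, Ho, Wo].
    Out-of-image neighbors contribute zero (an assumption of this sketch)."""
    n, c, h, w = img.shape
    img_t = img.transpose(0, 2, 3, 1)                 # [N, H, W, C] for gathering
    x, y = coords[..., 0], coords[..., 1]
    x0, y0 = np.floor(x).astype(int), np.floor(y).astype(int)
    fx, fy = x - x0, y - y0                           # fractional parts
    batch = np.arange(n)[:, None, None]
    out = np.zeros(x.shape + (c,))
    corners = ((0, 0, (1 - fy) * (1 - fx)), (0, 1, (1 - fy) * fx),
               (1, 0, fy * (1 - fx)), (1, 1, fy * fx))
    for dy, dx, weight in corners:
        xi, yi = x0 + dx, y0 + dy
        valid = (xi >= 0) & (xi < w) & (yi >= 0) & (yi < h)
        vals = img_t[batch, np.clip(yi, 0, h - 1), np.clip(xi, 0, w - 1)]  # [N, Ho, Wo, C]
        out += (weight * valid)[..., None] * vals
    return out.transpose(0, 3, 1, 2)


def warp_right(right_feature, flow):
    """coords = self.coords + flow, then sample the right feature there."""
    n, _, h, w = flow.shape
    coords = make_coords(n, h, w) + flow
    return bilinear_sample(right_feature, coords.transpose(0, 2, 3, 1))
```

With a purely horizontal integer flow this reduces to shifting the right feature along x, as in the stereo setting; a fractional flow of 0.5 averages two horizontally neighboring pixels.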

def corr_iter(self, left_feature, right_feature, flow, small_patch):

    coords = self.coords + flow
    coords = F.transpose(coords, (0, 2, 3, 1))
    right_feature = bilinear_sampler(right_feature, coords)

    if small_patch:
        psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]
        dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
    else:
        psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]
        dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]

    N, C, H, W = left_feature.shape
    lefts = F.split(left_feature, 4, axis=1)
    rights = F.split(right_feature, 4, axis=1)

    corrs = []
    for i in range(len(psize_list)):
        corr = self.get_correlation(
            lefts[i], rights[i], psize_list[i], dilate_list[i]
        )
        corrs.append(corr)

    final_corr = F.concat(corrs, axis=1)

    return final_corr
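The second half, the grouped window correlation, might look like this in numpy. It is a sketch of the idea rather than CREStereo's get_correlation: `window_correlation` and `grouped_correlation` are hypothetical names, the right feature is edge-padded (an assumption), and the window is traversed row-major, so the zero-offset entry sits at index 4 for both the 3x3 and the 1x9 window. With 4 groups the output has 4 x 9 = 36 channels.

```python
import numpy as np


def window_correlation(left, right, psize=(3, 3), dilate=(1, 1)):
    """left, right: [N, C, H, W] -> [N, ph * pw, H, W]. For each offset in the
    window, average left * shifted(right) over channels (edge padding assumed)."""
    n, c, h, w = left.shape
    ph, pw = psize
    dy, dx = dilate
    pady, padx = ph // 2 * dy, pw // 2 * dx
    right_pad = np.pad(right, ((0, 0), (0, 0), (pady, pady), (padx, padx)), mode="edge")
    corrs = []
    for i in range(ph):              # row-major over the window
        for j in range(pw):
            shifted = right_pad[:, :, i * dy:i * dy + h, j * dx:j * dx + w]
            corrs.append((left * shifted).mean(axis=1))
    return np.stack(corrs, axis=1)


def grouped_correlation(left, right_warped, small_patch, groups=4):
    """Mirrors corr_iter after warping: split channels into groups, correlate
    each group in a 9-point window, concatenate along channels."""
    psize = (3, 3) if small_patch else (1, 9)
    lefts = np.split(left, groups, axis=1)
    rights = np.split(right_warped, groups, axis=1)
    return np.concatenate(
        [window_correlation(l, r, psize) for l, r in zip(lefts, rights)], axis=1)


feat = np.random.default_rng(0).standard_normal((1, 8, 5, 6))
print(grouped_correlation(feat, feat, small_patch=True).shape)   # (1, 36, 5, 6)
```

As for the question of why the channels are split into 4 groups: one plausible reading (my guess, not confirmed by the source) is group-wise correlation in the style of GwcNet, where each channel group yields its own correlation map, so the update block receives 36 channels of matching cost instead of the 9 a single full-channel dot product would give.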

To be continued…


Original post: blog.csdn.net/weixin_42492254/article/details/125081621