While reading the paper "Changer: Feature Interaction is What You Need for Change Detection", I found its channel-exchange operation unclear, so let's work through it together.
The first figure shows the overall network and where Channel_exchange sits; the details of channel_exchange are shown in the next figure.
Its formula is as follows: for each channel c, if exchange_mask[c] is 1 (True), the corresponding channels of x0 and x1 are swapped; otherwise each stream keeps its own channel. That is, out0_c = x1_c and out1_c = x0_c when exchange_mask[c] is True, else out0_c = x0_c and out1_c = x1_c.
The pseudocode walkthrough is as follows:
import torch
# Build exchange_mask from the channel count c and spacing p (here c=3, p=2)
# exchange_mask: tensor([ True, False,  True])
exchange_mask = torch.arange(3) % 2 == 0
# Expand the mask along the batch dimension (batch size = 2)
# exchange_mask:
# tensor([[ True, False,  True],
#         [ True, False,  True]])
exchange_mask = exchange_mask.unsqueeze(0).expand((2, -1))
x1 = torch.randn(2, 3, 2, 2)
x2 = torch.randn(2, 3, 2, 2)
# Example values from one run:
# x1: tensor([[[[ 0.1008, -0.1781],
#               [-0.3207, -0.3508]],
#              [[-1.0290, -1.2271],
#               [ 0.6968, -1.4028]],
#              [[ 0.6179, -0.3015],
#               [-0.8388, -0.0772]]],
#             [[[ 0.7332,  1.2673],
#               [-1.2648,  1.0845]],
#              [[-0.8240, -0.8279],
#               [-1.2621,  0.2231]],
#              [[ 0.5337,  1.0726],
#               [ 1.1852,  0.3237]]]])
# x2: tensor([[[[ 1.4240, -1.3492],
#               [ 0.7807, -2.0613]],
#              [[-0.3764,  0.2079],
#               [-1.0346, -1.2292]],
#              [[ 2.0369,  0.8063],
#               [-0.1176, -0.8294]]],
#             [[[-0.4517, -0.3017],
#               [-1.2654, -0.4562]],
#              [[ 0.6034, -0.4299],
#               [ 0.0746, -0.3248]],
#              [[-0.8387, -0.4018],
#               [-1.1967,  1.1520]]]])
out_x1, out_x2 = torch.zeros_like(x1),torch.zeros_like(x2)
# Copy the channels that are not exchanged
out_x1[~exchange_mask] = x1[~exchange_mask]
# >>> out_x1
# tensor([[[[ 0.0000, 0.0000],
# [ 0.0000, 0.0000]],
# [[-1.0290, -1.2271],
# [ 0.6968, -1.4028]],
# [[ 0.0000, 0.0000],
# [ 0.0000, 0.0000]]],
# [[[ 0.0000, 0.0000],
# [ 0.0000, 0.0000]],
# [[-0.8240, -0.8279],
# [-1.2621, 0.2231]],
# [[ 0.0000, 0.0000],
# [ 0.0000, 0.0000]]]])
out_x2[~exchange_mask] = x2[~exchange_mask]
# >>> out_x2
# tensor([[[[ 0.0000, 0.0000],
# [ 0.0000, 0.0000]],
# [[-0.3764, 0.2079],
# [-1.0346, -1.2292]],
# [[ 0.0000, 0.0000],
# [ 0.0000, 0.0000]]],
# [[[ 0.0000, 0.0000],
# [ 0.0000, 0.0000]],
# [[ 0.6034, -0.4299],
# [ 0.0746, -0.3248]],
# [[ 0.0000, 0.0000],
# [ 0.0000, 0.0000]]]])
# Exchange the masked channels between the two streams
out_x1[exchange_mask] = x2[exchange_mask]
out_x2[exchange_mask] = x1[exchange_mask]
# >>> out_x1
# tensor([[[[ 1.4240, -1.3492],
# [ 0.7807, -2.0613]],
# [[-1.0290, -1.2271],
# [ 0.6968, -1.4028]],
# [[ 2.0369, 0.8063],
# [-0.1176, -0.8294]]],
# [[[-0.4517, -0.3017],
# [-1.2654, -0.4562]],
# [[-0.8240, -0.8279],
# [-1.2621, 0.2231]],
# [[-0.8387, -0.4018],
# [-1.1967, 1.1520]]]])
# >>> out_x2
# tensor([[[[ 0.1008, -0.1781],
# [-0.3207, -0.3508]],
# [[-0.3764, 0.2079],
# [-1.0346, -1.2292]],
# [[ 0.6179, -0.3015],
# [-0.8388, -0.0772]]],
# [[[ 0.7332, 1.2673],
# [-1.2648, 1.0845]],
# [[ 0.6034, -0.4299],
# [ 0.0746, -0.3248]],
# [[ 0.5337, 1.0726],
# [ 1.1852, 0.3237]]]])
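The masked copies above can also be written without intermediate zero tensors, for example with torch.where. This is only a sketch of an equivalent vectorized exchange; the function name channel_exchange and the spacing parameter p are my own, not taken from the paper's released code:

```python
import torch

def channel_exchange(x1, x2, p=2):
    # Swap every p-th channel between two feature maps of shape (B, C, H, W).
    c = x1.shape[1]
    # Mask of shape (1, C, 1, 1) broadcasts over batch and spatial dimensions.
    mask = (torch.arange(c) % p == 0).view(1, c, 1, 1)
    out1 = torch.where(mask, x2, x1)  # masked channels come from the other stream
    out2 = torch.where(mask, x1, x2)
    return out1, out2

x1 = torch.randn(2, 3, 2, 2)
x2 = torch.randn(2, 3, 2, 2)
out1, out2 = channel_exchange(x1, x2)
# With c=3 and p=2, channels 0 and 2 are swapped and channel 1 is kept.
```

This avoids allocating zeros_like buffers and the two-step copy, while producing the same result as the masked-assignment version above.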
Working through this by hand showed me how much clearer an operation becomes once you actually run it yourself.