Channel Exchange(通道交换)

今天在阅读《Changer: Feature Interaction is What You Need for Change Detection》这篇论文中,其中的通道交换问题不是很清楚,和大家一起学习一下。在这里插入图片描述
这是网络整体和Channel_exchange所处的位置。具体的channel_exchange细节如下图
在这里插入图片描述
它的公式如下:它的涵义是如果exchange_mask是1则交换x0与x1
在这里插入图片描述
然后是伪代码展示如下
在这里插入图片描述

import torch
#  根据通道c和间距p定义exchange_mask
#  exchange_mask:tensor([ True, False,  True])
exchange_mask = torch.arange(3)%2==0
#  根据批量个数扩张exchange_mask 
#  exchange_mask
#  tensor([[ True, False,  True],
#          [ True, False,  True]])
exchange_mask = exchange_mask.unsqueeze(0).expand((2,-1))
# x1:    tensor([[[[ 0.1008, -0.1781],
#           [-0.3207, -0.3508]],

#           [[-1.0290, -1.2271],
#            [ 0.6968, -1.4028]],

#           [[ 0.6179, -0.3015],
#            [-0.8388, -0.0772]]],


#          [[[ 0.7332,  1.2673],
          [-1.2648,  1.0845]],

#           [[-0.8240, -0.8279],
#            [-1.2621,  0.2231]],

#           [[ 0.5337,  1.0726],
#            [ 1.1852,  0.3237]]]])
# x1:    tensor([[[[ 1.4240, -1.3492],
#               [ 0.7807, -2.0613]],

#              [[-0.3764,  0.2079],
#               [-1.0346, -1.2292]],

#              [[ 2.0369,  0.8063],
#              [-0.1176, -0.8294]]],


#             [[[-0.4517, -0.3017],
#               [-1.2654, -0.4562]],

#              [[ 0.6034, -0.4299],
#               [ 0.0746, -0.3248]],

#              [[-0.8387, -0.4018],
#               [-1.1967,  1.1520]]]])
x1 = torch.randn(2,3,2,2)
x2 = torch.randn(2,3,2,2)
out_x1, out_x2 = torch.zeros_like(x1),torch.zeros_like(x2)
#   拷贝不交换的通道
out_x1[~exchange_mask,] = x1[~exchange_mask,]
#  >>> out_x1
#  tensor([[[[ 0.0000,  0.0000],
#            [ 0.0000,  0.0000]],

#           [[-1.0290, -1.2271],
#            [ 0.6968, -1.4028]],

#           [[ 0.0000,  0.0000],
#            [ 0.0000,  0.0000]]],


#          [[[ 0.0000,  0.0000],
#            [ 0.0000,  0.0000]],

#           [[-0.8240, -0.8279],
#            [-1.2621,  0.2231]],

#           [[ 0.0000,  0.0000],
#            [ 0.0000,  0.0000]]]])
out_x2[~exchange_mask,] = x2[~exchange_mask,]
#   >>> out_x2
#   tensor([[[[ 0.0000,  0.0000],
#             [ 0.0000,  0.0000]],

#            [[-0.3764,  0.2079],
#             [-1.0346, -1.2292]],

#            [[ 0.0000,  0.0000],
#             [ 0.0000,  0.0000]]],


#           [[[ 0.0000,  0.0000],
#             [ 0.0000,  0.0000]],

#            [[ 0.6034, -0.4299],
#             [ 0.0746, -0.3248]],

#            [[ 0.0000,  0.0000],
#             [ 0.0000,  0.0000]]]])
#   交换通道
out_x1[exchange_mask,] = x2[exchange_mask,]
out_x2[exchange_mask,] = x1[exchange_mask,]
#   >>> out_x1
#   tensor([[[[ 1.4240, -1.3492],
#             [ 0.7807, -2.0613]],

#            [[-1.0290, -1.2271],
#             [ 0.6968, -1.4028]],

#            [[ 2.0369,  0.8063],
#             [-0.1176, -0.8294]]],


#           [[[-0.4517, -0.3017],
#             [-1.2654, -0.4562]],

#            [[-0.8240, -0.8279],
#             [-1.2621,  0.2231]],

#            [[-0.8387, -0.4018],
#             [-1.1967,  1.1520]]]])
#   >>> out_x2
#   tensor([[[[ 0.1008, -0.1781],
#             [-0.3207, -0.3508]],

#            [[-0.3764,  0.2079],
#             [-1.0346, -1.2292]],

#            [[ 0.6179, -0.3015],
#             [-0.8388, -0.0772]]],


#           [[[ 0.7332,  1.2673],
#             [-1.2648,  1.0845]],

#            [[ 0.6034, -0.4299],
#             [ 0.0746, -0.3248]],

#            [[ 0.5337,  1.0726],
#             [ 1.1852,  0.3237]]]])

此次学习后,知道了亲自动手的好处。

猜你喜欢

转载自blog.csdn.net/Aaaha_jasper/article/details/127412847