Medical Transformer: Gated Axial-Attention for Medical Image Segmentation
Paper interpretation:
https://zhuanlan.zhihu.com/p/408662947
Experimental results:
0 Preface
0.1 The original attention mechanism
0.2 Axial attention mechanism + relative position encoding
0.3 Axial attention mechanism + gating unit
Gated axial attention introduces four gates that control how much information the relative positional encodings contribute to the query, key, and value terms. In this way, the influence of the relative positional encodings on the non-local context encoding is regulated.
Depending on whether the information carried by a relative positional encoding is useful, its gate parameter converges either to 0 or to some higher value. A relative positional encoding that is learned accurately therefore receives a higher weight from the gating mechanism than one that is not.
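To make the role of the four gates concrete, here is a minimal pure-Python sketch of gated attention along a single axis. It follows the gated formulation in the MedT paper, where two gates scale the positional terms inside the softmax and two gates weight the value and its positional term. The function name, the list-based tensors, and the scalar gates `g_q`, `g_k`, `g_v1`, `g_v2` are illustrative assumptions, not the repository's implementation.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def gated_axial_attention_1d(q, k, v, r_q, r_k, r_v, g_q, g_k, g_v1, g_v2):
    """Attend along one axis of length n.

    q, k, v       : lists of n feature vectors.
    r_q, r_k, r_v : n x n tables of relative positional vectors, indexed [i][w].
    g_q, g_k      : gates on the positional terms inside the softmax.
    g_v1, g_v2    : gates on the value and on its positional term.
    """
    n = len(q)
    dim = len(v[0])
    out = []
    for i in range(n):
        logits = [
            dot(q[i], k[w])
            + g_q * dot(q[i], r_q[i][w])
            + g_k * dot(k[w], r_k[i][w])
            for w in range(n)
        ]
        weights = softmax(logits)
        out.append([
            sum(weights[w] * (g_v1 * v[w][d] + g_v2 * r_v[i][w][d]) for w in range(n))
            for d in range(dim)
        ])
    return out


# Toy example: 2 positions, 1-dimensional features (all values are made up).
q = [[0.0], [0.0]]
k = [[1.0], [2.0]]
v = [[1.0], [3.0]]
r = [[[5.0], [-5.0]], [[-5.0], [5.0]]]  # arbitrary positional terms

# Positional gates closed: the positional encodings are ignored.
closed = gated_axial_attention_1d(q, k, v, r, r, r, 0.0, 0.0, 1.0, 0.0)
# Positional gates open: the positional encodings shape the output.
opened = gated_axial_attention_1d(q, k, v, r, r, r, 1.0, 1.0, 1.0, 1.0)
```

With the positional gates at 0 the layer reduces to content-only attention, which matches the behaviour described above for encodings that are not useful.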
1. axialAttentionUNet
1.1 The original axialAttentionUNet
model = ResAxialAttentionUNet(AxialBlock, [1, 2, 4, 1], s= 0.125, **kwargs)
- A U-Net built from residual blocks that use the original axial attention.
ResAxialAttentionUNet(
(conv1): Conv2d(3, 8, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(conv2): Conv2d(8, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv3): Conv2d(128, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(layer1): Sequential(
(0): AxialBlock(
(conv_down): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(layer2): Sequential(
(0): AxialBlock(
(conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): AxialBlock(
(conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer3): Sequential(
(0): AxialBlock(
(conv_down): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): AxialBlock(
(conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): AxialBlock(
(conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): AxialBlock(
(conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer4): Sequential(
(0): AxialBlock(
(conv_down): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv_up): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(decoder1): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(decoder2): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder3): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(adjust): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))
(soft): Softmax(dim=1)
)
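The channel counts in the printout above can be cross-checked against the constructor arguments `[1, 2, 4, 1]` and `s=0.125`. The sketch below assumes a stem width of 64, per-stage base widths of 128/256/512/1024, and a block expansion factor of 2. These defaults are inferred from the printed shapes, not quoted from the source code, and `channel_plan` is a hypothetical helper.

```python
def channel_plan(s, layers, base_widths=(128, 256, 512, 1024), expansion=2, stem=64):
    """Return (stem_channels, [(num_blocks, mid_channels, out_channels), ...]).

    All default widths are assumptions inferred from the printed model.
    """
    stem_channels = int(stem * s)
    plan = []
    for num_blocks, width in zip(layers, base_widths):
        mid = int(width * s)  # channels seen by the axial attention layers
        plan.append((num_blocks, mid, mid * expansion))
    return stem_channels, plan


stem_channels, plan = channel_plan(0.125, [1, 2, 4, 1])
for stage, (num_blocks, mid, out) in enumerate(plan, start=1):
    print(f"layer{stage}: {num_blocks} block(s), attention on {mid} channels, output {out}")
```

Under these assumptions the result agrees with the dump: `conv1` outputs 8 channels, and `layer1` to `layer4` contain 1, 2, 4, and 1 blocks with 32, 64, 128, and 256 output channels.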
1.2 Axial attention network with gating units added
model = ResAxialAttentionUNet(AxialBlock_dynamic, [1, 2, 4, 1], s= 0.125, **kwargs)
In the gated axial attention network, every axial attention layer is replaced with a gated axial attention layer.
ResAxialAttentionUNet(
(conv1): Conv2d(3, 8, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(conv2): Conv2d(8, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv3): Conv2d(128, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(layer1): Sequential(
(0): AxialBlock_dynamic(
(conv_down): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(layer2): Sequential(
(0): AxialBlock_dynamic(
(conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): AxialBlock_dynamic(
(conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer3): Sequential(
(0): AxialBlock_dynamic(
(conv_down): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): AxialBlock_dynamic(
(conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): AxialBlock_dynamic(
(conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): AxialBlock_dynamic(
(conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer4): Sequential(
(0): AxialBlock_dynamic(
(conv_down): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv_up): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(decoder1): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(decoder2): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder3): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(adjust): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))
(soft): Softmax(dim=1)
)
2. Medical Transformer
Note that during training the gating units are not activated for the first 10 epochs; they are switched on only after those 10 epochs.
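The warm-up rule can be sketched as a small schedule. This is only an illustration: it interprets "not activated" as "kept frozen at the initial value", counts epochs from 0, and uses a hypothetical `Gate` stand-in instead of real framework parameters.

```python
GATE_WARMUP_EPOCHS = 10  # taken from the note above


class Gate:
    """Stand-in for a learnable gate; `trainable` mimics a requires-grad flag."""

    def __init__(self, value):
        self.value = value
        self.trainable = False  # assumption: gates start frozen


def update_gates(gates, epoch, warmup=GATE_WARMUP_EPOCHS):
    """Unfreeze every gate once `epoch` (0-based, an assumption) reaches `warmup`."""
    enabled = epoch >= warmup
    for gate in gates:
        gate.trainable = enabled
    return enabled


gates = [Gate(0.1) for _ in range(4)]  # four gates; 0.1 is an arbitrary initial value
schedule = [update_gates(gates, epoch) for epoch in range(12)]
print(schedule)
```

In a real training loop, the same check would run at the start of each epoch, before the optimizer step.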
2.1 Local + global
model = medt_net(AxialBlock,AxialBlock, [1, 2, 4, 1], s= 0.125, **kwargs)
LoGo network:
In the local + global (LoGo) network, the setup is:
- The U-Net is built with the original axial attention; it does not use the gated axial attention unit proposed in the paper.
- The local + global training strategy is used.
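The local + global idea in the list above can be sketched in plain Python: a global branch sees the whole image, a local branch sees each patch separately, and the two outputs are fused. The 4 x 4 patch grid and the element-wise sum follow the description in the MedT paper and are assumptions here. `global_branch` and `local_branch` are hypothetical callables that preserve spatial size.

```python
def split_into_patches(image, grid=4):
    """Split an H x W image (list of rows) into grid x grid equal patches."""
    height, width = len(image), len(image[0])
    if height % grid or width % grid:
        raise ValueError("image size must be divisible by the grid size")
    ph, pw = height // grid, width // grid
    return {
        (i, j): [row[j * pw:(j + 1) * pw] for row in image[i * ph:(i + 1) * ph]]
        for i in range(grid)
        for j in range(grid)
    }


def stitch_patches(patches, grid=4):
    """Reassemble patches produced by split_into_patches into one image."""
    rows = []
    for i in range(grid):
        for r in range(len(patches[(i, 0)])):
            row = []
            for j in range(grid):
                row.extend(patches[(i, j)][r])
            rows.append(row)
    return rows


def logo_forward(image, global_branch, local_branch, grid=4):
    """Run both branches and add their outputs element-wise."""
    global_out = global_branch(image)
    local_out = stitch_patches(
        {pos: local_branch(patch) for pos, patch in split_into_patches(image, grid).items()},
        grid,
    )
    return [
        [g + l for g, l in zip(g_row, l_row)]
        for g_row, l_row in zip(global_out, local_out)
    ]


# Toy 8 x 8 "image" and identity branches, just to exercise the plumbing.
image = [[float(8 * y + x) for x in range(8)] for y in range(8)]
identity = lambda feature_map: feature_map
fused = logo_forward(image, identity, identity)
```

In the printed model that follows, the modules with the `_p` suffix appear to belong to the patch-wise branch and the unsuffixed ones to the whole-image branch.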
medt_net(
(conv1): Conv2d(3, 8, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(conv2): Conv2d(8, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv3): Conv2d(128, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(layer1): Sequential(
(0): AxialBlock(
(conv_down): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(layer2): Sequential(
(0): AxialBlock(
(conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): AxialBlock(
(conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(decoder4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(adjust): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))
(soft): Softmax(dim=1)
(conv1_p): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(conv2_p): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv3_p): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1_p): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn2_p): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn3_p): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu_p): ReLU(inplace=True)
(layer1_p): Sequential(
(0): AxialBlock(
(conv_down): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(layer2_p): Sequential(
(0): AxialBlock(
(conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): AxialBlock(
(conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer3_p): Sequential(
(0): AxialBlock(
(conv_down): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): AxialBlock(
(conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): AxialBlock(
(conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): AxialBlock(
(conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer4_p): Sequential(
(0): AxialBlock(
(conv_down): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention(
(qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention(
(qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv_up): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(decoder1_p): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(decoder2_p): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder3_p): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder4_p): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder5_p): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoderf): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(adjust_p): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))
(soft_p): Softmax(dim=1)
)
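In the printout above, the first block of each downsampling stage pairs a pooled main path (`(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)`) with a shortcut `downsample` convolution that uses `stride=(2, 2)`. The residual addition only works if both paths end at the same spatial size. The sketch below checks this with the standard output-size formula for convolution and pooling layers (dilation 1, floor rounding). The helper name `out_size` and the example input sizes (128 and 32) are assumptions for illustration; they are not taken from the printout.

```python
def out_size(n, kernel_size, stride=1, padding=0):
    """Spatial output size of a conv/pool layer (dilation 1, floor mode)."""
    return (n + 2 * padding - kernel_size) // stride + 1


# Stem convolution: kernel 7, stride 2, padding 3 halves the input size.
stem = out_size(128, kernel_size=7, stride=2, padding=3)

# 3x3 convolution with stride 1 and padding 1 keeps the size unchanged.
same = out_size(stem, kernel_size=3, stride=1, padding=1)

# Main path of a downsampling block: AvgPool2d(kernel_size=2, stride=2).
pooled = out_size(32, kernel_size=2, stride=2, padding=0)

# Shortcut path of the same block: 1x1 convolution with stride 2.
shortcut = out_size(32, kernel_size=1, stride=2, padding=0)

print(stem, same, pooled, shortcut)
```

Both paths of the downsampling block halve the feature map, so their outputs can be added element-wise.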
2.2 With transformer
model = medt_net(AxialBlock_dynamic, AxialBlock_wopos, [1, 2, 4, 1], s=0.125, **kwargs)
How it is used:
- The global branch uses the proposed gated axial attention unit, while the local branch uses the original axial attention without positional encoding.
- The local + global training strategy is used.
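The local + global idea can be sketched in plain Python: the global branch sees the whole image, the local branch sees one patch at a time, and the two outputs are added. This is only an illustrative sketch. The 4 x 4 patch grid, the function names `patch_slices` and `local_global_forward`, and the use of nested lists in place of tensors are all assumptions made here, and `global_fn` / `local_fn` are placeholders for the two branches.

```python
def patch_slices(image_size=128, grid=4):
    """Return (row_slice, col_slice) pairs that tile a square image."""
    if image_size % grid != 0:
        raise ValueError("image_size must be divisible by grid")
    step = image_size // grid
    return [
        (slice(i * step, (i + 1) * step), slice(j * step, (j + 1) * step))
        for i in range(grid)
        for j in range(grid)
    ]


def local_global_forward(image, global_fn, local_fn, grid=4):
    """Combine a whole-image branch with a patch-wise branch by addition.

    image is a square 2D list; both branch functions must return a
    2D list with the same shape as their input.
    """
    size = len(image)
    global_out = global_fn(image)

    # Run the local branch on each patch and paste the result back in place.
    local_out = [[0.0] * size for _ in range(size)]
    for rows, cols in patch_slices(size, grid):
        patch = [row[cols] for row in image[rows]]
        patch_out = local_fn(patch)
        for r, out_row in zip(range(rows.start, rows.stop), patch_out):
            local_out[r][cols] = list(out_row)

    # Element-wise sum of the two branches.
    return [
        [g + l for g, l in zip(g_row, l_row)]
        for g_row, l_row in zip(global_out, local_out)
    ]


# Usage with identity functions standing in for the two branches.
identity = lambda x: [row[:] for row in x]
image = [[r * 8 + c for c in range(8)] for r in range(8)]
combined = local_global_forward(image, identity, identity, grid=4)
```

With identity branches, every output value is simply twice the input value, which makes it easy to confirm that each patch is written back to its original position.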
medt_net(
(conv1): Conv2d(3, 8, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(conv2): Conv2d(8, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv3): Conv2d(128, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(layer1): Sequential(
(0): AxialBlock_dynamic(
(conv_down): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(layer2): Sequential(
(0): AxialBlock_dynamic(
(conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): AxialBlock_dynamic(
(conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_dynamic(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(decoder4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(adjust): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))
(soft): Softmax(dim=1)
(conv1_p): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(conv2_p): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv3_p): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1_p): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn2_p): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn3_p): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu_p): ReLU(inplace=True)
(layer1_p): Sequential(
(0): AxialBlock_wopos(
(conv_down): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
(bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(layer2_p): Sequential(
(0): AxialBlock_wopos(
(conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): AxialBlock_wopos(
(conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer3_p): Sequential(
(0): AxialBlock_wopos(
(conv_down): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): AxialBlock_wopos(
(conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): AxialBlock_wopos(
(conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): AxialBlock_wopos(
(conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer4_p): Sequential(
(0): AxialBlock_wopos(
(conv_down): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(hight_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(width_block): AxialAttention_wopos(
(qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
(bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(conv_up): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(decoder1_p): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(decoder2_p): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder3_p): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder4_p): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoder5_p): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(decoderf): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(adjust_p): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))
(soft_p): Softmax(dim=1)
)
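The channel counts in the printout follow a regular pattern that can be reproduced from the width multiplier `s=0.125`. The sketch below is a reading of the printed shapes, not the model code itself: the base widths (128, 256, 512, 1024), the expansion factor of 2, and the group count of 8 are assumptions inferred from the numbers above (for example, `bn_similarity` has 24 = 3 x 8 channels in the position-aware blocks and 8 in the `_wopos` blocks). The function name `block_channels` is hypothetical.

```python
def block_channels(base_planes, s=0.125, groups=8, expansion=2, with_position=True):
    """Channel counts of one axial block, as they appear in the printout."""
    planes = int(base_planes * s)
    return {
        "planes": planes,            # conv_down output / attention input
        "qkv_out": planes * 2,       # qkv_transform and bn_qkv channels
        # With positional terms the similarity stacks three maps per group;
        # without them there is a single map per group.
        "bn_similarity": groups * 3 if with_position else groups,
        # With positional terms the output is doubled before normalisation.
        "bn_output": planes * 2 if with_position else planes,
        "conv_up_out": planes * expansion,
    }


# Assumed base widths of the four stages.
for base in (128, 256, 512, 1024):
    print(base, block_channels(base), block_channels(base, with_position=False))
```

For the first stage this gives 16 planes, 32 qkv channels and a 32-channel `conv_up` output, matching `layer1` and `layer1_p` in the printout.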
3. References
3.1 Criss-Cross Attention
https://github.com/yearing1017/CCNet_PyTorch/tree/master/CCNet
https://github.com/speedinghzl/CCNet
3.2 Axial Attention Mechanism
https://github.com/lucidrains/axial-attention
Axial Attention in Multidimensional Transformers
3.3 Applications of the Axial Attention Mechanism
MetNet: A Neural Weather Model for Precipitation Forecasting
Medical Transformer: Gated Axial-Attention for Medical Image Segmentation
Axial attention network:
https://blog.csdn.net/hxxjxw/article/details/121445561
https://blog.csdn.net/weixin_43718675/article/details/106760382
https://zhuanlan.zhihu.com/p/408662947
Recommended reading:
https://blog.csdn.net/weixin_43718675/article/details/106760382#t4