Medical Image Segmentation 2: TransUnet - Transformers Make Strong Encoders for Medical Image Segmentation

In this article you will find the following:
- How did attention become popular in CNNs? (Non-Local)
- What does the Transformer architecture bring? (Multi-Head Self-Attention)
- Why is the Transformer so popular in CV? (Vision Transformer and SETR)
- How does TransUnet modify Unet and the Transformer? (ResNet50 + ViT as the backbone/encoder)
- A PyTorch implementation of TransUnet
- The author's complaints and traces of laziness

In the field of medical image segmentation, U-shaped networks, especially Unet, have achieved excellent results. However, the CNN structure is poor at establishing long-range dependencies; in other words, a CNN has a limited receptive field. The receptive field can be enlarged by stacking more convolutional layers or by using dilated (atrous) convolutions, but this introduces problems of its own (including, but not limited to, kernel degradation and the gridding artifacts caused by dilated convolution), so the final gains are limited.

The Transformer architecture, based on the self-attention mechanism, has achieved important results in NLP tasks. Vision Transformer (ViT) then introduced the Transformer into the CV field and achieved excellent results that same year, and the Transformer has been popular in CV ever since.

Having said that, why can the Transformer structure achieve good results in the CV field?

Attention is all you need?

Before introducing the Transformer, let's take a look at what is interesting on the CNN side.
First, review the Non-Local structure:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Starting from Non-Local, the attention mechanism swept the major conferences in 2017 and 2018, appearing in Non-Local Net, DANet, PSANet, ISANet, CCNet and other networks. There is only one core idea: the attention mechanism can establish long-range connections at unlimited distance, breaking through the limited receptive field of CNN models. Of course, one drawback of computing attention this way is that it is computationally expensive. Networks such as CCNet and ISANet therefore optimize exactly this defect, and a number of top-conference papers have come out of that direction.
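
As a rough sketch (my own illustration of the generic scaled dot-product form above, not the exact Non-Local module), attention over a flattened feature map can be written as follows; the N x N score matrix is exactly where the heavy computation comes from:

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    # q, k, v: (batch, N, d), where N = H * W for a flattened feature map
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5   # (batch, N, N) -- quadratic in N
    attn = F.softmax(scores, dim=-1)
    return attn @ v                                 # (batch, N, d)

# toy example: a 64x64 feature map with 256 channels, flattened to N = 4096 positions
x = torch.randn(1, 64 * 64, 256)
out = scaled_dot_product_attention(x, x, x)
print(out.shape)  # torch.Size([1, 4096, 256])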

So why was Non-Local proposed to compute attention in the first place? Because the authors of Non-Local drew inspiration from the Transformer. So let's go back to the classic paper that proposed the Transformer, "Attention is all you need".

This paper makes two main contributions: one is the Transformer itself, and the other is Multi-Head Attention, which replaces single-head attention with a multi-head attention mechanism.

The structure of the Transformer is very simple, consisting mainly of Multi-Head Attention, FFN, and Norm modules. The part to pay attention to is Multi-Head Attention.

Multi-Head Attention is not hard to understand; it is just one variant of the attention mechanism. As the name suggests, there are multiple heads, and each head computes its own attention: the Scaled Dot-Product Attention step is performed h times and the outputs are then concatenated. In this way, each position gets h representations, so the output is richer than with a single Scaled Dot-Product Attention.

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, ..., \mathrm{head}_h)W^O$$
$$\mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
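
A shape-level sketch of the head split and concatenation described by the formulas above (my own minimal illustration, assuming d_model is divisible by h; the W tensors stand in for the learned projection matrices):

import torch

def multi_head_attention(q, k, v, w_q, w_k, w_v, w_o, h):
    # q, k, v: (batch, N, d_model); w_q, w_k, w_v, w_o: (d_model, d_model)
    b, n, d_model = q.shape
    d_head = d_model // h
    def split(x, w):
        # project, then split the channel dimension into h heads
        return (x @ w).view(b, n, h, d_head).transpose(1, 2)     # (b, h, N, d_head)
    qh, kh, vh = split(q, w_q), split(k, w_k), split(v, w_v)
    scores = qh @ kh.transpose(-2, -1) / d_head ** 0.5           # (b, h, N, N)
    out = scores.softmax(dim=-1) @ vh                            # (b, h, N, d_head)
    out = out.transpose(1, 2).reshape(b, n, d_model)             # Concat(head_1, ..., head_h)
    return out @ w_o                                             # final projection W^O

b, n, d_model, h = 1, 10, 512, 8
x = torch.randn(b, n, d_model)
w_q, w_k, w_v, w_o = (torch.randn(d_model, d_model) for _ in range(4))
print(multi_head_attention(x, x, x, w_q, w_k, w_v, w_o, h).shape)  # torch.Size([1, 10, 512])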

Vision Transformer - the pioneer from CNN to Transformer

Vision Transformer (ViT) can be said to be the pioneer of Transformers in CV and also the savior of CV researchers. Before ViT, nobody knew how much longer CV would keep struggling with Non-Local variants. (Of course, the Transformer itself is close to struggling now.)
The ViT paper is "AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE"; Google's choice of title is quite interesting.

The principle is also very simple. The Transformer processes sequence data, and image data cannot be fed into it directly. ViT therefore splits the image into 9 pieces, i.e. 9 patches (of course it could be 16, 25, etc., depending on the patch size you choose). The patches are then flattened and stitched together in order to form a sequence. After adding a positional encoding, this sequence can be fed into the Transformer. The positional encoding lets the model know the order of the image patches, which helps learning.
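
A minimal sketch of this patch embedding idea (my own illustration using einops, with an assumed 224x224 RGB input and 16x16 patches; the cls token used for classification is omitted here):

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

patch, dim = 16, 768
to_tokens = nn.Sequential(
    # (b, c, H, W) -> (b, num_patches, patch*patch*c): cut into patches and flatten each one
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch, p2=patch),
    nn.Linear(3 * patch * patch, dim),   # linear projection of the flattened patches
)
img = torch.randn(1, 3, 224, 224)
tokens = to_tokens(img)                                  # (1, 196, 768)
pos = nn.Parameter(torch.randn(1, tokens.size(1), dim))  # learnable positional encoding
x = tokens + pos                                         # sequence ready for the Transformer
print(x.shape)  # torch.Size([1, 196, 768])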

ViT's success on ImageNet gave the CV community hope. Segmentation is a major CV task, and since ViT can classify, it can also act as a backbone for segmentation tasks, just like ResNet.

SETR: ViT can also be used for semantic segmentation!

Then, in another CVPR paper, "Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers", SETR was the first to use ViT as the backbone for a semantic segmentation task.

The SETR model is also very simple: a classic encoder-decoder network with ViT as the backbone, plus three different decoder designs, and semantic segmentation experiments showing that ViT is feasible for segmentation. It's a very simple idea, and whoever implements it first eats the meat first (thanks to ViT for a free top-conference paper).

Main text

That was a lot of preamble about CNNs, attention, Non-Local, and the Transformer; let's get back to the TransUnet model. A large share of CV papers are patchwork jobs (and TransUnet looks like a patchwork job too), but there is an art to patchwork tailoring. The TransUnet structure is shown in the figure below.

It is still a classic Unet-shaped network, but unlike a pure CNN-based Unet, the first three encoder stages here are CNN-based while the last stage is Transformer-based. That is, the last stage of Unet's encoder is replaced with a Transformer.

Why only one Transformer layer?

TransUnet has its own reasons for replacing only part of the encoder with a Transformer. Although the Transformer obtains a global receptive field, it has weaknesses when handling fine details.
The paper "Segmenter: Transformer for Semantic Segmentation" discusses the influence of patch size on the model's predictions: a large patch size computes faster but segments edges noticeably worse, while a small patch size gives more precise edges.

Plenty of evidence shows that the Transformer is weak at segmenting local details; a CNN, by contrast, benefits from its local receptive field and recovers fine details more accurately. TransUnet therefore replaces only the last stage, which focuses on global information, exactly what the Transformer is good at, while the shallow, detail-oriented work is left to the CNN.

TransUnet Specific Details

  • The decoder is very simple: the typical combination of skip connections and upsampling.
  • For the encoder part:
    • The authors chose the first three stages of ResNet50 as the CNN part, which is easy to understand: ResNet works well.
    • The last stage is the ViT structure, i.e. 12 Transformer layers.
    • The authors call this encoder R50-ViT.

For an introduction to ViT, you can read another article of mine (ViT + SETR); this article lazily omits it.

However, note that if the input to ViT has size (b, c, H, W), then with patch size P the ViT output is (b, c, H/P, W/P); this (H/P, W/P) feature map then needs to be upsampled back toward the (H, W) scale.
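
A quick shape check matching the numbers used in the code below (a 512x512 input, ResNet features at 1/8 resolution, patch size 16; the ViT output here is just a random tensor standing in for the real one):

import torch
import torch.nn.functional as F

H = W = 512
feat = torch.randn(1, 256, H // 8, W // 8)    # ResNet50 stage-3 output: (1, 256, 64, 64)
P = 16                                        # ViT patch size
grid_h, grid_w = feat.shape[2] // P, feat.shape[3] // P   # 4 x 4 patch grid
vit_out = torch.randn(1, 512, grid_h, grid_w)             # ViT tokens reshaped into a map
# upsample the coarse ViT map so it lines up with the decoder's expected 1/16 resolution
up = F.interpolate(vit_out, scale_factor=P // 2, mode='bilinear', align_corners=True)
print(up.shape)  # torch.Size([1, 512, 32, 32])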

TransUnet model implementation

Encoder part

The encoder consists mainly of ResNet50 and ViT. In the ResNet50 part, the 4x downsampling of the stem_block is removed and the first three stages are kept, each downsampling by 2x; the output of the last retained stage is used as the ViT input. This keeps the feature sizes and channel counts consistent with the original image.

import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
    # standard BasicBlock: no channel expansion
    expansion: int = 1
    def __init__(self, inplanes, planes, stride = 1, downsample = None, groups = 1,
        base_width = 64, dilation = 1, norm_layer = None):

        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=dilation, groups=groups, bias=False, dilation=dilation)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        # the second conv keeps the spatial size (stride 1)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=dilation, groups=groups, bias=False, dilation=dilation)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample= None,
        groups = 1, base_width = 64, dilation = 1, norm_layer = None,):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, bias=False, padding=dilation, dilation=dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(
        self,block, layers,num_classes = 1000, zero_init_residual = False, groups = 1,
        width_per_group = 64, replace_stride_with_dilation = None, norm_layer = None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 2
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
            
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        # 3x3 stride-1 stem: the original 4x downsampling is removed here
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # planes are divided by 4 so that, with Bottleneck expansion=4, the stage outputs are 64/128/256/512 channels
        self.layer1 = self._make_layer(block, 64//4, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128//4, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256//4, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512//4, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block,
        planes,
        blocks,
        stride = 1,
        dilate = False,
    ):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # replace the stride with dilation so the spatial size is preserved
            self.dilation *= stride
            stride = 1
            
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,  planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                norm_layer(planes * block.expansion))

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        out = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        out.append(x)
        x = self.layer2(x)
        out.append(x)
        x = self.layer3(x)
        out.append(x)
        # the last stage (layer4) is not used as an output
        # x = self.layer4(x)
        # out.append(x)
        return out

    def forward(self, x) :
        return self._forward_impl(x)

    @staticmethod
    def _resnet(block, layers, pretrained_path = None, **kwargs):
        model = ResNet(block, layers, **kwargs)
        if pretrained_path is not None:
            model.load_state_dict(torch.load(pretrained_path), strict=False)
        return model

    @staticmethod
    def resnet50(pretrained_path=None, **kwargs):
        return ResNet._resnet(Bottleneck, [3, 4, 6, 3], pretrained_path, **kwargs)

    @staticmethod
    def resnet101(pretrained_path=None, **kwargs):
        return ResNet._resnet(Bottleneck, [3, 4, 23, 3], pretrained_path, **kwargs)

if __name__ == "__main__":
    v = ResNet.resnet50().cuda()
    img = torch.randn(1, 3, 512, 512).cuda()
    preds = v(img)
    # torch.Size([1, 64, 256, 256])
    print(preds[0].shape)
    # torch.Size([1, 128, 128, 128])
    print(preds[1].shape)
    # torch.Size([1, 256, 64, 64])
    print(preds[2].shape)

Next comes the ViT part; ViT takes the third output of ResNet50 as its input.

import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


def pair(t):
    return t if isinstance(t, tuple) else (t, t)

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim, channels = 512, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.out = Rearrange("b (h w) c->b c h w", h=image_height//patch_height, w=image_width//patch_width)
        
        # upsample by 8x here to keep the feature size consistent with the figure
        self.upsample = nn.UpsamplingBilinear2d(scale_factor = patch_size//2)
        self.conv = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU())

    def forward(self, img):
        # this corresponds to the Linear Projection in the figure: split the image into patches and embed them as a sequence
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        # prepend the cls token (the "index") to the patch sequence
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        # feed the sequence into the Transformer
        x = self.transformer(x)

        # delete cls_tokens: the index token must be removed before producing the output feature map
        output = x[:,1:,:]
        output = self.out(output)

        # after the Transformer, upsample back up and refine with a convolution
        output = self.upsample(output)
        output = self.conv(output)

        return output


if __name__ == "__main__":
    v = ViT(image_size = (64, 64), patch_size = 16, channels = 256, dim = 512, depth = 12, heads = 16, mlp_dim = 1024, dropout = 0.1, emb_dropout = 0.1).cpu()
    # assume the ResNet50 stage-3 output is (1, 256, 64, 64), i.e. (b, c, H/8, W/8)
    img = torch.randn(1, 256, 64, 64).cpu()
    preds = v(img)
    # the output is (b, c, H/16, W/16)
    # preds:  torch.Size([1, 512, 32, 32])
    print("preds: ", preds.size())

Now merge the two parts and package them into the TransUnetEncoder class.

class TransUnetEncoder(nn.Module):
    def __init__(self, **kwargs):
        super(TransUnetEncoder, self).__init__()
        self.R50 = ResNet.resnet50()
        self.Vit = ViT(image_size = (64, 64), patch_size = 16, channels = 256, dim = 512, depth = 12, heads = 16, mlp_dim = 1024, dropout = 0.1, emb_dropout = 0.1)

    def forward(self, x):
        x1, x2, x3 = self.R50(x)
        x4 = self.Vit(x3)
        return [x1, x2, x3, x4]

if __name__ == "__main__":
    x = torch.randn(1, 3, 512, 512).cuda()
    net = TransUnetEncoder().cuda()
    out = net(x)
    # torch.Size([1, 64, 256, 256])
    print(out[0].shape)
    # torch.Size([1, 128, 128, 128])
    print(out[1].shape)
    # torch.Size([1, 256, 64, 64])
    print(out[2].shape)
    # torch.Size([1, 512, 32, 32])
    print(out[3].shape)

Decoder part

The decoder is the classic Unet decoder module: at each stage it upsamples, concatenates the skip connection, and convolves. It is packaged into the TransUnetDecoder class.

class TransUnetDecoder(nn.Module):
    def __init__(self, out_channels=64, **kwargs):
        super(TransUnetDecoder, self).__init__()
        self.decoder1 = nn.Sequential(
            nn.Conv2d(out_channels//4, out_channels//4, 3, padding=1), 
            nn.BatchNorm2d(out_channels//4),
            nn.ReLU()            
        )
        self.upsample1 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels, out_channels//4, 3, padding=1),
            nn.BatchNorm2d(out_channels//4),
            nn.ReLU()     
        )

        self.decoder2 = nn.Sequential(
            nn.Conv2d(out_channels*2, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()            
        )
        self.upsample2 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels*2, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()     
        )

        self.decoder3 = nn.Sequential(
            nn.Conv2d(out_channels*4, out_channels*2, 3, padding=1),
            nn.BatchNorm2d(out_channels*2),
            nn.ReLU()            
        )        
        self.upsample3 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels*4, out_channels*2, 3, padding=1),
            nn.BatchNorm2d(out_channels*2),
            nn.ReLU()     
        )

        self.decoder4 = nn.Sequential(
            nn.Conv2d(out_channels*8, out_channels*4, 3, padding=1),
            nn.BatchNorm2d(out_channels*4),
            nn.ReLU()                           
        )
        self.upsample4 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels*8, out_channels*4, 3, padding=1),
            nn.BatchNorm2d(out_channels*4),
            nn.ReLU()     
        )

    def forward(self, inputs):
        x1, x2, x3, x4 = inputs
        # x4: (b, 512, H/16, W/16) from the ViT branch
        
        x4 = self.upsample4(x4)
        x = self.decoder4(torch.cat([x4, x3], dim=1))        
        
        x = self.upsample3(x)
        x = self.decoder3(torch.cat([x, x2], dim=1))

        x = self.upsample2(x)
        x = self.decoder2(torch.cat([x, x1], dim=1))

        x = self.upsample1(x)
        x = self.decoder1(x)

        return x

if __name__ == "__main__":
    x1 = torch.randn([1, 64, 256, 256]).cuda()
    x2 = torch.randn([1, 128, 128, 128]).cuda()
    x3 = torch.randn([1, 256, 64, 64]).cuda()
    x4 = torch.randn([1, 512, 32, 32]).cuda()
    net = TransUnetDecoder().cuda()
    out = net([x1,x2,x3,x4])
    # out: torch.Size([1, 16, 512, 512])
    print(out.shape)

TransUnet class

Finally, the Encoder and Decoder are packaged into TransUnet.

class TransUnet(nn.Module):
    # mainly change num_classes to match your dataset
    def __init__(self, num_classes=4, **kwargs):
        super(TransUnet, self).__init__()
        self.TransUnetEncoder = TransUnetEncoder()
        self.TransUnetDecoder = TransUnetDecoder()
        self.cls_head = nn.Conv2d(16, num_classes, 1)
    def forward(self, x):
        x = self.TransUnetEncoder(x)
        x = self.TransUnetDecoder(x)
        x = self.cls_head(x)
        return x
    
if __name__ == "__main__":
    # input image size [1, 3, 512, 512]
    x1 = torch.randn([1, 3, 512, 512]).cuda()
    net = TransUnet().cuda()
    out = net(x1)
    # output shape [batch, num_classes, 512, 512]
    print(out.shape)

Test it on the CamVid dataset

Since there are no suitable medical images at hand, I just grabbed a dataset to test the segmentation effect.
CamVid is a segmentation dataset for autonomous driving; it has only eight or nine hundred images, so it runs quickly on my machine.
Some parameters are set as follows

# imports
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
 
torch.manual_seed(17)
# custom dataset: CamVidDataset
class CamVidDataset(torch.utils.data.Dataset):
    """CamVid Dataset. Reads images, applies augmentation and preprocessing transformations.

    Args:
        images_dir (str): path to the images folder
        masks_dir (str): path to the segmentation masks folder
    """
    
    def __init__(self, images_dir, masks_dir):
        self.transform = A.Compose([
            A.Resize(512, 512),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ]) 
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
 
    
    def __getitem__(self, i):
        # read data
        image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
        mask = np.array( Image.open(self.masks_fps[i]).convert('RGB'))
        image = self.transform(image=image,mask=mask)
        
        return image['image'], image['mask'][:,:,0]
        
    def __len__(self):
        return len(self.ids)
    
    
# dataset paths
DATA_DIR = r'../blork_file/dataset/camvid/' # set this according to your own path
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')
    
train_dataset = CamVidDataset(
    x_train_dir, 
    y_train_dir, 
)
val_dataset = CamVidDataset(
    x_valid_dir, 
    y_valid_dir, 
)
 
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=True, drop_last=True)

Some model and training process settings

from d2l import torch as d2l
from tqdm import tqdm
import pandas as pd
import monai
# model
model = TransUnet(num_classes=33).cuda()
# train for 100 epochs
epochs_num = 100
# use the SGD optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,80], gamma=0.5)

# multi-class cross-entropy loss
lossf = nn.CrossEntropyLoss(ignore_index=255)

def evaluate_accuracy_gpu(net, data_iter, device=None):
    if isinstance(net, nn.Module):
        net.eval()  # Set the model to evaluation mode
        if not device:
            device = next(iter(net.parameters())).device
    # No. of correct predictions, no. of predictions
    metric = d2l.Accumulator(2)

    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # handle the case where the batch is a list of tensors
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            output = net(X)
            metric.add(d2l.accuracy(output, y), d2l.size(y))
    return metric[0] / metric[1]


# training function
def train_ch13(net, train_iter, test_iter, loss, optimizer, num_epochs, schedule, devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1], legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    # lists used to log training statistics

    loss_list = []
    train_acc_list = []
    test_acc_list = []
    epochs_list = []
    time_list = []
    lr_list = []

    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = d2l.Accumulator(4)
        for i, (X, labels) in enumerate(train_iter):
            timer.start()

            if isinstance(X, list):
                X = [x.to(devices[0]) for x in X]
            else:
                X = X.to(devices[0])
            gt = labels.long().to(devices[0])

            net.train()
            optimizer.zero_grad()
            result = net(X)
            loss_sum = loss(result, gt)
            loss_sum.sum().backward()
            optimizer.step()

            acc = d2l.accuracy(result, gt)
            metric.add(loss_sum, acc, labels.shape[0], labels.numel())

            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[3], None))
                
        schedule.step()

        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
        print(f"epoch {
      
      epoch+1}/{
      
      epochs_num} --- loss {
      
      metric[0] / metric[2]:.3f} --- train acc {
      
      metric[1] / metric[3]:.3f} --- test acc {
      
      test_acc:.3f} --- lr {
      
      optimizer.state_dict()['param_groups'][0]['lr']} --- cost time {
      
      timer.sum()}")
        
        #--------- save the training logs ---------------
        df = pd.DataFrame()
        loss_list.append(metric[0] / metric[2])
        train_acc_list.append(metric[1] / metric[3])
        test_acc_list.append(test_acc)
        epochs_list.append(epoch+1)
        time_list.append(timer.sum())
        lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
        
        df['epoch'] = epochs_list
        df['loss'] = loss_list
        df['train_acc'] = train_acc_list
        df['test_acc'] = test_acc_list
        df["lr"] = lr_list
        df['time'] = time_list
        
        df.to_excel("../blork_file/savefile/TransUnet_camvid.xlsx")
        #---------------- save the model -------------------
        if np.mod(epoch+1, 5) == 0:
            torch.save(net.state_dict(), f'../blork_file/checkpoints/TransUnet_{epoch+1}.pth')

    # save the final model
    torch.save(net.state_dict(), f'../blork_file/checkpoints/TransUnet_last.pth')
    
# start training
train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num, schedule)

Training result: (training curve figure not reproduced here)

Closing remarks

Although the code in this article is relatively rough, it roughly corresponds to the original TransUnet figure. If you want models of different scales, all you need to change is the number of channels in each stage, but you have to modify and verify this in ResNet50, ViT, and the decoder. If you want to use TransUnet on a different dataset, you only need to change the value of num_classes when creating the model.
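
For example (a minimal usage sketch; the class count below is just a hypothetical binary-segmentation setting, i.e. background plus one foreground class):

# e.g. a binary medical segmentation task: background + 1 foreground class
model = TransUnet(num_classes=2).cuda()
x = torch.randn(2, 3, 512, 512).cuda()   # a batch of two 512x512 RGB images
logits = model(x)                        # shape: (2, 2, 512, 512)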

Author's notes:

  • num_classes is composed of: background + category 1 + category 2 + ... + category n.
  • I am being lazy here, and I am also criticizing myself for it: if I were not lazy, I would tie the channel counts together so that the model scale could be changed in a single place, instead of the current situation where several places must be edited and then verified.
  • That said, the verification process is itself a learning process, so beginners will benefit a lot from reading and tweaking the code.
  • So I have found a good excuse for my laziness here.
  • This article wraps up TransUnet; at a reader's request, the next article will cover SwinUnet.
  • Personally, I think the Transformer's results are not always great. At least on my own cell dataset, the results of Swin Transformer were not as good as those of traditional CNN models. The Transformer's drawbacks are obvious, and it consumes a lot of GPU resources. Its segmentation of large objects is very good, which is the strength of the attention mechanism, but its handling of small objects and boundaries is clearly weaker. In such cases, the multi-scale Deformable Attention from Deformable-DETR may work well, since it can attend more to local information. Moreover, the 2022 conferences have already started transforming the Transformer by integrating CNNs into it, aiming for both local and global strengths, e.g. MixFormer, MaxViT, and so on.
  • In short, I personally think CV is approaching a bottleneck, and I look forward to the birth of the next dark horse that will overturn both the Transformer and the CNN.

Origin blog.csdn.net/yumaomi/article/details/127217206