[Transformer] Some personal notes on SETR + ViT

First of all, a disclaimer: this code is based on the work of the following authors; if there is any infringement it will be deleted immediately.

I have only added my own notes on top of the original code and changed the structure:

https://github.com/920232796/SETR-pytorch

https://github.com/lucidrains/vit-pytorch

The essence of the Transformer here is really just cutting the image into patches and then relying on a linear-layer mapping so that the model can find the relationships between them.
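
As a minimal sketch of that idea (illustrative only, not taken from the repositories above; it assumes torch and einops are installed), cutting an image into patches and projecting each flattened patch with a linear layer looks like this:

import torch
from torch import nn
from einops import rearrange

b, c, h, w = 2, 3, 512, 512          # batch, channels, height, width
p = 32                               # patch size
hidden = 1024                        # embedding dimension

img = torch.rand(b, c, h, w)
# cut the image into (h/p)*(w/p) patches and flatten each patch
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
print(patches.shape)                 # torch.Size([2, 256, 3072])

# one linear layer maps every flattened patch to a hidden vector; the
# transformer then models the relationships between these patch vectors
proj = nn.Linear(p * p * c, hidden)
tokens = proj(patches)
print(tokens.shape)                  # torch.Size([2, 256, 1024])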

transformer_seg.py

import logging
import math
import os
import numpy as np 

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from einops import rearrange
from transformer_model import TransModel2d, TransConfig
import math 

class Encoder2D(nn.Module):
    def __init__(self, config: TransConfig, is_segmentation=True):
        super().__init__()
        self.config = config
        self.out_channels = config.out_channels
        self.bert_model = TransModel2d(config)      ## runs the transformer encoder
        sample_rate = config.sample_rate
        sample_v = int(math.pow(2, sample_rate))
        assert config.patch_size[0] * config.patch_size[1] * config.hidden_size % (sample_v**2) == 0, "patch_size[0] * patch_size[1] * hidden_size must be divisible by sample_v**2"
        self.final_dense = nn.Linear(config.hidden_size, config.patch_size[0] * config.patch_size[1] * config.hidden_size // (sample_v**2))
        self.patch_size = config.patch_size
        self.hh = self.patch_size[0] // sample_v
        self.ww = self.patch_size[1] // sample_v

        self.is_segmentation = is_segmentation
    def forward(self, x):
        ## x: (b, c, h, w)
        b, c, h, w = x.shape
        assert self.config.in_channels == c, "in_channels does not match the channels of the input image"
        p1 = self.patch_size[0]
        p2 = self.patch_size[1]

        if h % p1 != 0 or w % p2 != 0:
            raise ValueError("the image size must be divisible by the patch size")
        hh = h // p1        ## number of patches along the height
        ww = w // p2        ## number of patches along the width

        x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p1, p2 = p2)
        
        encode_x = self.bert_model(x)[-1] # take the last encoder layer
        if not self.is_segmentation:
            return encode_x

        x = self.final_dense(encode_x)
        x = rearrange(x, "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", p1 = self.hh, p2 = self.ww, h = hh, w = ww, c = self.config.hidden_size)
        return encode_x, x 


class PreTrainModel(nn.Module):
    def __init__(self, patch_size, 
                        in_channels, 
                        out_class, 
                        hidden_size=1024, 
                        num_hidden_layers=8, 
                        num_attention_heads=16,
                        decode_features=[512, 256, 128, 64]):
        super().__init__()
        config = TransConfig(patch_size=patch_size, 
                            in_channels=in_channels, 
                            out_channels=0, 
                            hidden_size=hidden_size, 
                            num_hidden_layers=num_hidden_layers, 
                            num_attention_heads=num_attention_heads)
        self.encoder_2d = Encoder2D(config, is_segmentation=False)
        self.cls = nn.Linear(hidden_size, out_class)

    def forward(self, x):
        encode_img = self.encoder_2d(x)
        encode_pool = encode_img.mean(dim=1)
        out = self.cls(encode_pool)
        return out 

class Vit(nn.Module):
    def __init__(self, patch_size, 
                        in_channels, 
                        out_class, 
                        hidden_size=1024, 
                        num_hidden_layers=8, 
                        num_attention_heads=16,
                        sample_rate=4,
                        ):
        super().__init__()
        config = TransConfig(patch_size=patch_size, 
                            in_channels=in_channels, 
                            out_channels=0, 
                            sample_rate=sample_rate,
                            hidden_size=hidden_size, 
                            num_hidden_layers=num_hidden_layers, 
                            num_attention_heads=num_attention_heads)
        self.encoder_2d = Encoder2D(config, is_segmentation=False)
        self.cls = nn.Linear(hidden_size, out_class)

    def forward(self, x):
        encode_img = self.encoder_2d(x)
        
        encode_pool = encode_img.mean(dim=1)
        out = self.cls(encode_pool)
        return out 

class Decoder2D(nn.Module):
    def __init__(self, in_channels, out_channels, features=[512, 256, 128, 64]):
        super().__init__()
        self.decoder_1 = nn.Sequential(
                    nn.Conv2d(in_channels, features[0], 3, padding=1),
                    nn.BatchNorm2d(features[0]),
                    nn.ReLU(inplace=True),
                    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
                )
        self.decoder_2 = nn.Sequential(
                    nn.Conv2d(features[0], features[1], 3, padding=1),
                    nn.BatchNorm2d(features[1]),
                    nn.ReLU(inplace=True),
                    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
                )
        self.decoder_3 = nn.Sequential(
            nn.Conv2d(features[1], features[2], 3, padding=1),
            nn.BatchNorm2d(features[2]),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )
        self.decoder_4 = nn.Sequential(
            nn.Conv2d(features[2], features[3], 3, padding=1),
            nn.BatchNorm2d(features[3]),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )

        self.final_out = nn.Conv2d(features[-1], out_channels, 3, padding=1)

    def forward(self, x):
        x = self.decoder_1(x)
        x = self.decoder_2(x)
        x = self.decoder_3(x)
        x = self.decoder_4(x)
        x = self.final_out(x)
        return x

class SETRModel(nn.Module):
    def __init__(self, patch_size=(32, 32), 
                        in_channels=3, 
                        out_channels=1, 
                        hidden_size=1024, 
                        num_hidden_layers=8, 
                        num_attention_heads=16,
                        decode_features=[512, 256, 128, 64],
                        sample_rate=4,):
        super().__init__()
        config = TransConfig(patch_size=patch_size, 
                            in_channels=in_channels, 
                            out_channels=out_channels, 
                            sample_rate=sample_rate,
                            hidden_size=hidden_size, 
                            num_hidden_layers=num_hidden_layers, 
                            num_attention_heads=num_attention_heads)
        self.encoder_2d = Encoder2D(config)
        self.decoder_2d = Decoder2D(in_channels=config.hidden_size, out_channels=config.out_channels, features=decode_features)

    def forward(self, x):
        _, final_x = self.encoder_2d(x)
        x = self.decoder_2d(final_x)
        return x 


if __name__ == "__main__":
    net = SETRModel(patch_size=(32, 32),        ## how many pixels form one patch
                    in_channels=3,              ## input channels
                    out_channels=1,             ## output channels
                    hidden_size=1024,           ## hidden (embedding) dimension
                    sample_rate=5,              ## each patch side is downsampled by 2**sample_rate before decoding
                    num_hidden_layers=1,        ## number of transformer layers
                    num_attention_heads=16,     ## number of attention heads
                    decode_features=[512, 256, 128, 64])    ## channel widths of the convolutional decoder
    t1 = torch.rand(2, 3, 512, 512)
    print("input: " + str(t1.shape))

    print("output: " + str(net(t1).shape))

transformer_model.py

import logging
import math
import os

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from einops import rearrange

def swish(x):
    return x * torch.sigmoid(x)

def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def mish(x):
    return x * torch.tanh(nn.functional.softplus(x))

ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "mish": mish}

class TransConfig(object):
    
    def __init__(
        self,
        patch_size,
        in_channels,
        out_channels,
        sample_rate=4,
        hidden_size=768,
        num_hidden_layers=8,
        num_attention_heads=6,
        intermediate_size=1024,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
    ):  
        self.sample_rate = sample_rate
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

class TransLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super(TransLayerNorm, self).__init__()

        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps
       

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.gamma * x + self.beta
      
class TransEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()
        ## nn.Embedding(number of positions, embedding dimension)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        input_shape = input_ids.size()
    
        seq_length = input_shape[1]
        device = input_ids.device
        
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).expand(input_shape[:2])

        position_embeddings = self.position_embeddings(position_ids)

        embeddings = input_ids + position_embeddings        ## add positional information
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

## Multi-head self-attention
class TransSelfAttention(nn.Module):
    def __init__(self, config: TransConfig):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    ## split into heads and transpose
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        ## final x shape: (batch_size, num_attention_heads, seq_len, head_size)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states
    ):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        
        # no attention mask is applied here

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # weight the values by the attention probabilities
        context_layer = torch.matmul(attention_probs, value_layer)
        # reshape the weighted V back to [batch_size, length, embedding_dimension]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)

        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer


class TransSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

## multi-head attention + residual connection
class TransAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = TransSelfAttention(config)
        self.output = TransSelfOutput(config)

    def forward(
        self,
        hidden_states,
    ):
        self_outputs = self.self(hidden_states)     ## multi-head attention
        attention_output = self.output(self_outputs, hidden_states)     ## residual connection + layer norm
        
        return attention_output


class TransIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act] ## gelu by default

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

class TransOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

## transformer block
class TransLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = TransAttention(config)
        self.intermediate = TransIntermediate(config)
        self.output = TransOutput(config)

    def forward(
        self,
        hidden_states
    ):
        attention_output = self.attention(hidden_states)    ## multi-head attention + residual
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class TransEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer = nn.ModuleList([TransLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        output_all_encoded_layers=True,
    ):
        all_encoder_layers = []
        
        for i, layer_module in enumerate(self.layer):
            layer_output = layer_module(hidden_states)
            hidden_states = layer_output
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
            
        return all_encoder_layers

class InputDense2d(nn.Module):
    def __init__(self, config):
        super(InputDense2d, self).__init__()
        self.dense = nn.Linear(config.patch_size[0] * config.patch_size[1] * config.in_channels, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act]
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)        ## activation function
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

class InputDense3d(nn.Module):
    def __init__(self, config):
        super(InputDense3d, self).__init__()
        self.dense = nn.Linear(config.patch_size[0] * config.patch_size[1] * config.patch_size[2] * config.in_channels, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act]
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

class TransModel2d(nn.Module):

    def __init__(self, config):
        super(TransModel2d, self).__init__()
        self.config = config
        self.dense = InputDense2d(config)
        self.embeddings = TransEmbeddings(config)
        self.encoder = TransEncoder(config)

    def forward(
        self,
        input_ids,          ## input is the patch-split data (b, hh*ww, p1*p2*c); output is (b, hh*ww, hidden_size)
        output_all_encoded_layers=True,
       
    ):
        dense_out = self.dense(input_ids)       ## projection + layer norm
        embedding_output = self.embeddings(dense_out)   ## add positional encoding
        encoder_layers = self.encoder(embedding_output,output_all_encoded_layers=output_all_encoded_layers,) ## transformer encoder
        sequence_output = encoder_layers[-1]

        if not output_all_encoded_layers:
            # if all encoder layers are not needed, keep only the last one
            encoder_layers = encoder_layers[-1]
        return encoder_layers


class TransModel3d(nn.Module):

    def __init__(self, config):
        super(TransModel3d, self).__init__()
        self.config = config
        self.dense = InputDense3d(config)
        self.embeddings = TransEmbeddings(config)
        self.encoder = TransEncoder(config)

    def forward(
        self,
        input_ids,
        output_all_encoded_layers=True,
       
    ):  
        dense_out = self.dense(input_ids)
        embedding_output = self.embeddings(
            input_ids=dense_out
        )
        encoder_layers = self.encoder(
            embedding_output,
            output_all_encoded_layers=output_all_encoded_layers,
        )
        sequence_output = encoder_layers[-1]
        
        if not output_all_encoded_layers:
            # if all encoder layers are not needed, keep only the last one
            encoder_layers = encoder_layers[-1]
        return encoder_layers

Car segmentation example

The dataset can be downloaded here:

Carvana Image Masking Challenge | Kaggle

# data_url : https://www.kaggle.com/c/carvana-image-masking-challenge/data
import torch
import numpy as np
from SETR.transformer_seg import SETRModel
from PIL import Image
import glob
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torchvision.utils import save_image
import os
import xlwt

img_url = sorted(glob.glob("./data/train/*"))
mask_url = sorted(glob.glob("./data/train_masks/*"))
# print(img_url)
train_size = int(len(img_url) * 0.8)
train_img_url = img_url[:train_size]
train_mask_url = mask_url[:train_size]
val_img_url = img_url[train_size:]
val_mask_url = mask_url[train_size:]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device is " + str(device))
epoches = 100
out_channels = 1


def build_model():
    model = SETRModel(patch_size=(16, 16),
                      in_channels=3,
                      out_channels=1,
                      hidden_size=1024,
                      num_hidden_layers=6,
                      num_attention_heads=16,
                      decode_features=[512, 256, 128, 64])
    return model


class CarDataset(Dataset):
    def __init__(self, img_url, mask_url):
        super(CarDataset, self).__init__()
        self.img_url = img_url
        self.mask_url = mask_url

    def __getitem__(self, idx):
        img = Image.open(self.img_url[idx])
        img = img.resize((256, 256))
        img_array = np.array(img, dtype=np.float32) / 255
        mask = Image.open(self.mask_url[idx])
        mask = mask.resize((256, 256))
        mask = np.array(mask, dtype=np.float32)
        img_array = img_array.transpose(2, 0, 1)

        return torch.tensor(img_array.copy()), torch.tensor(mask.copy())

    def __len__(self):
        return len(self.img_url)


def compute_dice(input, target):
    eps = 0.0001
    # input is the output after the sigmoid.
    input = (input > 0.5).float()
    target = (target > 0.5).float()

    # inter = torch.dot(input.view(-1), target.view(-1)) + eps
    inter = torch.sum(target.view(-1) * input.view(-1)) + eps

    # print(self.inter)
    union = torch.sum(input) + torch.sum(target) + eps

    t = (2 * inter.float()) / union.float()
    return t


def predict():
    model = build_model()
    model.load_state_dict(torch.load("./SETR_car.pth", map_location="cpu"))
    print(model)

    import matplotlib.pyplot as plt
    val_dataset = CarDataset(val_img_url, val_mask_url)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    with torch.no_grad():
        for img, mask in val_loader:
            pred = torch.sigmoid(model(img))
            pred = (pred > 0.5).int()
            plt.subplot(1, 3, 1)
            print(img.shape)
            img = img.permute(0, 2, 3, 1)
            plt.imshow(img[0])
            plt.subplot(1, 3, 2)
            plt.imshow(pred[0].squeeze(0), cmap="gray")
            plt.subplot(1, 3, 3)
            plt.imshow(mask[0], cmap="gray")
            plt.show()


def data_write(file_path, epoch, datas1, datas2):  # datas1 and datas2 are lists
    f = xlwt.Workbook(encoding='utf-8')  # create a workbook with utf-8 encoding
    sheet1 = f.add_sheet(u'sheet1', cell_overwrite_ok=True)  # create the sheet
    sheet1.write(0, 0, label='epoch')  # column headers
    sheet1.write(0, 1, label='train_Loss')
    sheet1.write(0, 2, label='Val_Loss')
    # write the data into row j+1, columns 0..2
    for j in range(len(datas1)):
        sheet1.write(j + 1, 0, epoch[j])
        sheet1.write(j + 1, 1, datas1[j])
        sheet1.write(j + 1, 2, datas2[j])

    f.save(file_path)  # save the file


if __name__ == "__main__":

    model = build_model()
    model.to(device)

    train_dataset = CarDataset(train_img_url, train_mask_url)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

    val_dataset = CarDataset(val_img_url, val_mask_url)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    loss_func = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)

    ## load previously saved weights (comment this out for the first run)
    model.load_state_dict(torch.load('./SETR_car.pth'))

    step = 0
    report_loss = 0.0

    train_men = []
    test_men = []
    epo = []
    for epoch in range(epoches):
        print("epoch is " + str(epoch))
        epo.append(epoch)
        train_loss = 0
        test_loss = 0
        print("进行------训练------测试:")
        for img, mask in tqdm(train_loader, total=len(train_loader)):
            optimizer.zero_grad()
            step += 1
            img = img.to(device)
            mask = mask.to(device)

            pred_img = model(img)  ## pred_img: (batch, out_channels, H, W)
            if out_channels == 1:
                pred_img = pred_img.squeeze(1)  # remove the channel dimension

            loss = loss_func(pred_img, mask)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

        ## validation
        model.eval()
        with torch.no_grad():
            print("Running validation:")
            for val_img, val_mask in tqdm(val_loader, total=len(val_loader)):
                val_img = val_img.to(device)
                val_mask = val_mask.to(device)
                pred_img = torch.sigmoid(model(val_img))
                if out_channels == 1:
                    pred_img = pred_img.squeeze(1)
                cur_dice = compute_dice(pred_img, val_mask)
                test_loss += cur_dice.item()

                if (step % 50 == 0):
                    # ground-truth mask, first item of the batch
                    x_ = val_mask[0]
                    # predicted mask, first item of the batch
                    y = pred_img[0]
                    # stack the two images along dim 0, then save them
                    img = torch.stack([x_, y], 0)
                    if not os.path.exists('./outputs'):
                        os.mkdir('outputs')
                    save_image(img.cpu(), f"./outputs/{step}.png")

            torch.save(model.state_dict(), "./SETR_car.pth")
            model.train()

        train_men.append(train_loss / len(train_loader))
        test_men.append(1 - (test_loss / len(val_loader)))
        data_write("Car_loss.xls", epo, train_men, test_men)

        print("train_loss is " + str(train_men[epoch]))
        print("Val_dice is " + str(test_men[epoch]))

Some changes to SETR + ViT

The original SETR implementation is somewhat convoluted and hard to read, although it is written in great detail. ViT, on the other hand, is simple and easy to understand, but it is not well suited for plugging into a segmentation network. So I combined the strengths of both into one.

transform_self

import logging
import math
import os
import numpy as np 

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from einops import rearrange
from vit import ViT

class Decoder2D(nn.Module):
    def __init__(self, in_channels, out_channels, features=[512, 256, 128, 64]):
        super().__init__()
        self.decoder_1 = nn.Sequential(
                    nn.Conv2d(in_channels, features[0], 3, padding=1),
                    nn.BatchNorm2d(features[0]),
                    nn.ReLU(inplace=True),
                    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
                )
        self.decoder_2 = nn.Sequential(
                    nn.Conv2d(features[0], features[1], 3, padding=1),
                    nn.BatchNorm2d(features[1]),
                    nn.ReLU(inplace=True),
                    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
                )
        self.decoder_3 = nn.Sequential(
            nn.Conv2d(features[1], features[2], 3, padding=1),
            nn.BatchNorm2d(features[2]),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )
        self.decoder_4 = nn.Sequential(
            nn.Conv2d(features[2], features[3], 3, padding=1),
            nn.BatchNorm2d(features[3]),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )

        self.final_out = nn.Conv2d(features[-1], out_channels, 3, padding=1)

    def forward(self, x):
        x = self.decoder_1(x)
        x = self.decoder_2(x)
        x = self.decoder_3(x)
        x = self.decoder_4(x)
        x = self.final_out(x)
        return x

class SETRModel(nn.Module):
    def __init__(self, patch_size=(32, 32),
                        image_size=512,
                        in_channels=3, 
                        out_channels=1, 
                        hidden_size=1024, 
                        num_hidden_layers=8, 
                        num_attention_heads=16,
                        decode_features=[512, 256, 128, 64],
                        sample_rate=4,):
        super().__init__()
        # self.encoder_2d = Encoder2D(config)
        self.encoder_2d = ViT(
                            image_size = image_size,
                            in_channels = in_channels,
                            patch_size = patch_size,
                            hidden_size = hidden_size,             # dimension of each token vector
                            num_hidden_layers = num_hidden_layers,              # L in the ViT figure: how many Transformer Encoder blocks are stacked
                            num_attention_heads = num_attention_heads,             # number of attention heads
                            sample_rate = sample_rate
                            )
        self.decoder_2d = Decoder2D(in_channels=hidden_size, out_channels=out_channels, features=decode_features)

    def forward(self, x):
        final_x = self.encoder_2d(x)
        x = self.decoder_2d(final_x)
        return x 


if __name__ == "__main__":
    net = SETRModel(
                    image_size = 512,
                    patch_size=(32, 32),        ## how many pixels form one patch
                    in_channels=3,              ## input channels
                    out_channels=1,             ## output channels
                    hidden_size=1024,           ## hidden (embedding) dimension
                    sample_rate=5,              ## each patch side is downsampled by 2**sample_rate before decoding
                    num_hidden_layers=1,        ## number of transformer layers
                    num_attention_heads=16,     ## number of attention heads
                    decode_features=[512, 256, 128, 64])    ## channel widths of the convolutional decoder
    t1 = torch.rand(2, 3, 512, 512)
    print("input: " + str(t1.shape))
    print("output: " + str(net(t1).shape))

vit.py

import torch
import math
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

def swish(x):
    return x * torch.sigmoid(x)
def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def mish(x):
    return x * torch.tanh(nn.functional.softplus(x))
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "mish": mish}

# helpers
## If t is already a tuple, return it as-is; otherwise duplicate it into (t, t).
## Handles the case where the image size or patch size is given as an int (e.g. 224 -> (224, 224)).
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes
class PreNorm(nn.Module):
    def __init__(self, dim, fn, eps=1e-12):
        super().__init__()
        self.fn = fn    ## fn is either the multi-head attention module or the feed-forward (MLP) module

        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.variance_epsilon = eps

    def forward(self, x, **kwargs):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        x = self.gamma * x + self.beta      ## y = ((x - E[x]) / sqrt(Var[x] + eps)) * gamma + beta
        return self.fn(x, **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 16, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads       ## 1024
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5       ## 1 / sqrt(dim_head)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale        ## q times k transposed, scaled

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            # print(attn(x).shape)
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, hidden_size, num_hidden_layers, num_attention_heads, in_channels = 3,
                 mlp_dim =2048, act = 'gelu',  dim_head = 64, dropout = 0.1, emb_dropout = 0.1, sample_rate = 4):    ## parameters changed from the original ViT implementation
        super().__init__()
        image_height, image_width = pair(image_size)    ## image size
        patch_height, patch_width = pair(patch_size)    ## patch size

        dim = hidden_size
        depth = num_hidden_layers
        heads = num_attention_heads
        channels = in_channels

        sample_v = int(math.pow(2, sample_rate))
        assert patch_height * patch_width * num_hidden_layers % (sample_v ** 2) == 0, "patch_height * patch_width * num_hidden_layers must be divisible by sample_v**2"
        self.hh = patch_size[0] // sample_v
        self.ww = patch_size[1] // sample_v
        self.h = image_height // patch_height
        self.w = image_width // patch_width
        self.hidden_size = hidden_size

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)     ## number of patches
        patch_dim = channels * patch_height * patch_width           ## length of one flattened patch, e.g. 3 * 32 * 32
        assert act in {'gelu', 'relu', 'swish', 'mish'}, 'act must be one of gelu, relu, swish, mish'
        self.transform_act_fn = ACT2FN[act]

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),          # linear projection, e.g. 32*32*3 -> 1024
        )   ## output: (b, num_patches, dim)

        ## The image is split into patches and positional information is added; the original ViT also adds
        ## an extra class token for classification. My understanding is that it aggregates the information
        ## from all the patches to decide what class the image belongs to (it is not used in this segmentation variant).
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))     ## (1, num_patches, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))                       ## (1, 1, dim)
        self.dropout = nn.Dropout(emb_dropout)

        ## dim=1024,depth=6, head=16, dim_head=64, mlp_dim=2048, dropout=0.1
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        ## everything above prepares the data so it can be fed into the transformer

    def forward(self, img):
        x = self.to_patch_embedding(img)    ## flatten the patches and project them to dim dimensions
        x = self.transform_act_fn(x)        ## apply the chosen activation function

        ## from here on this follows the standard ViT layout
        x += self.pos_embedding[:, :]          ## add positional information
        x = self.dropout(x)

        ## the data is now ready to go through the transformer
        x = self.transformer(x)     ## (b, num_patches, dim)

        x = rearrange(x, "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
                      p1=self.hh, p2=self.ww, h=self.h, w=self.w, c=self.hidden_size)

        return x

Summary:

The structures of ViT and SETR differ in many details:

Difference | ViT | SETR
Structure | position encoding -> layer norm -> multi-head attention | layer norm -> position encoding -> multi-head attention
QKV | 1024 becomes 3072 through one linear layer and is then split into three parts, taken as Q, K and V | 1024 becomes three separate 1024-dim tensors through three linear layers, one each for Q, K and V
Position encoding | randomly initialized parameter | embedding lookup table
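
A minimal sketch of the QKV and position-encoding differences from the table above (illustrative shapes only, assuming dim = 1024 and a sequence of 64 patches):

import torch
from torch import nn

dim, seq = 1024, 64
x = torch.rand(2, seq, dim)

# ViT style: one linear layer to 3*dim, then chunk the result into Q, K, V
to_qkv = nn.Linear(dim, dim * 3, bias=False)
q, k, v = to_qkv(x).chunk(3, dim=-1)                  # each is (2, 64, 1024)

# SETR style: three separate linear layers, one each for Q, K, V
to_q, to_k, to_v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)
q2, k2, v2 = to_q(x), to_k(x), to_v(x)                # each is (2, 64, 1024)

# ViT style: position encoding is a randomly initialized parameter added directly
pos_vit = nn.Parameter(torch.randn(1, seq, dim))

# SETR style: position encoding is looked up from an nn.Embedding table
pos_table = nn.Embedding(512, dim)
pos_setr = pos_table(torch.arange(seq)).unsqueeze(0)  # (1, 64, 1024)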

On this sample task the results of ViT and SETR also clearly differ: the difference in the metrics is not large, but the difference in the output images is obvious.

Comparison | ViT | SETR
Accuracy | slightly worse (20 epochs: 0.975) | slightly better (20 epochs: 0.980)
Speed | almost half the time: train + val = 5 min 20 s | train + val = 7 min 55 s
Result images | a bit rough in places, but acceptable | (example images omitted)

Using the MONAI package directly

from monai.networks.nets import ViT

self.vit = ViT(
            in_channels=in_channels,            ## input channels
            img_size=img_size,                  ## image size
            patch_size=self.patch_size,         ## patch size
            hidden_size=hidden_size,            ## hidden layer size
            mlp_dim=mlp_dim,                    ## MLP dimension
            num_layers=self.num_layers,         ## number of transformer layers
            num_heads=num_heads,                ## number of attention heads
            pos_embed=pos_embed,                ## type of positional embedding
            classification=self.classification, ## whether to add a classification head
            dropout_rate=dropout_rate,          
        )
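
For reference, a small self-contained sketch of how MONAI's ViT might be instantiated and called (the values below are illustrative and the exact signature depends on your MONAI version; its forward returns both the token sequence and the per-layer hidden states):

import torch
from monai.networks.nets import ViT

vit = ViT(
    in_channels=1,
    img_size=(96, 96, 96),           # MONAI's ViT works on 3D volumes by default
    patch_size=(16, 16, 16),
    hidden_size=768,
    mlp_dim=3072,
    num_layers=12,
    num_heads=12,
    pos_embed="conv",
    classification=False,
    dropout_rate=0.0,
)

x = torch.rand(1, 1, 96, 96, 96)
tokens, hidden_states = vit(x)       # tokens: (1, 6*6*6, 768) = (1, 216, 768)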

Origin blog.csdn.net/qq_42792802/article/details/127604908