Notes on a pluggable TTA trick from Kaggle

Link: https://www.kaggle.com/leighplt/pytorch-tta-flip-left-right
I've seen plenty of TTA (test-time augmentation) implementations; today I found a neat Python coding trick in this kernel that is worth writing down.

import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data
import torchvision
from torchvision import models

import cv2
from pathlib import Path
import glob

# define the device up front so the snippets below run end to end
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#============= tta ===================
# The TTA here is pluggable: it works for both training and prediction, and the usage
# further down is self-explanatory. @staticmethod makes a method that takes no implicit
# self argument, so it can be called without instantiating the class, although calling it
# on an instance also works; very flexible. With an instance: c = TTAFunction() and then
# c.hflip(x); without one, simply TTAFunction.hflip(x). Note that tta() itself does take
# self: it is meant to be mixed into a model class (see the small demo after the class).
class TTAFunction:
    """
    Simple TTA function
    """
    @staticmethod
    def hflip(x):
        return x.flip(3)
    
    @staticmethod
    def vflip(x):
        return x.flip(2)
    
    def tta(self, x):
        # self is expected to be an nn.Module that mixes this class in,
        # so self.eval() and self.forward() resolve on the model itself
        self.eval()
        with torch.no_grad():
            result = self.forward(x)
            x = self.hflip(x)
            # predict on the flipped input, then flip the prediction back
            # so it aligns with the original orientation
            result += self.hflip(self.forward(x))
        return 0.5 * result  # average the two predictions
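
A quick sanity check of the static flip helpers (a toy tensor of my own, not from the original kernel) shows why no instance is needed and why flipping the prediction back lines everything up:

x = torch.arange(4.).reshape(1, 1, 2, 2)  # NCHW batch with one 2x2 image
assert torch.equal(TTAFunction.hflip(TTAFunction.hflip(x)), x)  # flipping twice restores the input
assert torch.equal(TTAFunction.hflip(x)[..., 0], x[..., 1])     # columns are mirrored along dim 3
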
#============= model ===================
def conv3x3(in_, out):
    return nn.Conv2d(in_, out, 3, padding=1)

class ConvRelu(nn.Module):
    def __init__(self, in_, out):
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.activation(x)
        return x

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()

        self.block = nn.Sequential(
            ConvRelu(in_channels, middle_channels),
            nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)
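
With kernel_size=3, stride=2, padding=1, output_padding=1, the transposed convolution exactly doubles the spatial size: out = (in - 1)*2 - 2*1 + 3 + 1 = 2*in. A quick check with made-up channel counts:

blk = DecoderBlock(16, 32, 8)               # hypothetical channel sizes for illustration
print(blk(torch.randn(1, 16, 8, 8)).shape)  # torch.Size([1, 8, 16, 16]) -- 8x8 doubled to 16x16
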

class UNet11(TTAFunction, nn.Module): # mix in TTAFunction so the model gains .tta()
    def __init__(self, num_filters=32):
        """
        :param num_classes:
        :param num_filters:
        """
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)

        # Convolutions are from VGG11
        self.encoder = models.vgg11().features
        
        # "relu" layer is taken from VGG probably for generality, but it's not clear 
        self.relu = self.encoder[1]
        
        self.conv1 = self.encoder[0]
        self.conv2 = self.encoder[3]
        self.conv3s = self.encoder[6]
        self.conv3 = self.encoder[8]
        self.conv4s = self.encoder[11]
        self.conv4 = self.encoder[13]
        self.conv5s = self.encoder[16]
        self.conv5 = self.encoder[18]

        self.center = DecoderBlock(num_filters * 8 * 2, num_filters * 8 * 2, num_filters * 8)
        self.dec5 = DecoderBlock(num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 8)
        self.dec4 = DecoderBlock(num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 4)
        self.dec3 = DecoderBlock(num_filters * (8 + 4), num_filters * 4 * 2, num_filters * 2)
        self.dec2 = DecoderBlock(num_filters * (4 + 2), num_filters * 2 * 2, num_filters)
        self.dec1 = ConvRelu(num_filters * (2 + 1), num_filters)
        
        self.final = nn.Conv2d(num_filters, 1, kernel_size=1)

    def forward(self, x):
        conv1 = self.relu(self.conv1(x))
        conv2 = self.relu(self.conv2(self.pool(conv1)))
        conv3s = self.relu(self.conv3s(self.pool(conv2)))
        conv3 = self.relu(self.conv3(conv3s))
        conv4s = self.relu(self.conv4s(self.pool(conv3)))
        conv4 = self.relu(self.conv4(conv4s))
        conv5s = self.relu(self.conv5s(self.pool(conv4)))
        conv5 = self.relu(self.conv5(conv5s))

        center = self.center(self.pool(conv5))

        # decoder: concatenate each upsampled feature map with the encoder output of matching size (skip connections)
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(torch.cat([dec2, conv1], 1))
        return torch.sigmoid(self.final(dec1))

def unet11(**kwargs):
    model = UNet11(**kwargs)
    return model
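
A hypothetical smoke test on a dummy batch (TGS Salt images are 101x101 and are typically padded to 128x128 before being fed to the network); input and output sizes match because every pooling step in the encoder is undone by a decoder block:

net = unet11()
with torch.no_grad():
    print(net(torch.randn(1, 3, 128, 128)).shape)      # torch.Size([1, 1, 128, 128])
    print(net.tta(torch.randn(1, 3, 128, 128)).shape)  # the mixed-in tta() works the same way
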

def get_model():
    np.random.seed(717)
    torch.cuda.manual_seed(717)
    torch.manual_seed(717)
    model = unet11()
    model.train()
    return model.to(device)
#============= use tta for prediction ===================
model = get_model()
model.load_state_dict(torch.load(model_pth)['state_dict'])  # model_pth: path to a trained checkpoint

test_dataset = TGSSaltDataset(test_path, test_file_list, is_test=True)  # this dataset class is defined in the original kernel

all_predictions = []
for image in data.DataLoader(test_dataset, batch_size = 30):
    image = image[0].type(torch.FloatTensor).to(device)
    y_pred = model.tta(image).cpu().data.numpy() # use tta_flip
    all_predictions.append(y_pred)
all_predictions_stacked = np.vstack(all_predictions)[:, 0, :, :]
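
The final layer applies a sigmoid, so all_predictions_stacked holds probabilities in [0, 1]; a common follow-up (my addition, not part of the original kernel) is thresholding at 0.5 to get binary salt masks:

binary_masks = (all_predictions_stacked > 0.5).astype(np.uint8)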

Reposted from blog.csdn.net/qq_20373723/article/details/110500670