Resumen de introducción a Unet

Unet ya es un modelo de segmentación muy antiguo, es el modelo propuesto en "U-Net: Convolutional Networks for Biomedical Image Segmentation" en 2015.

Enlace en papel: https://arxiv.org/abs/1505.04597

Antes de Unet, era la red FCN más antigua. FCN es la abreviatura de Fully Convolutional Netowkrs, que establece el marco básico para dividir la red. Sin embargo, la precisión de la red FCN es baja y no es tan fácil de usar como Unet. .

La red Unet es muy simple, la primera mitad es extracción de características y la segunda mitad es sobremuestreo. En algunas literaturas, esta estructura se denomina estructura de codificador-decodificador Dado que la estructura general de la red es una letra U inglesa más grande, se denomina U-net.

La estructura de la red es la siguiente:

imagen-20220324143317132

  • Codificador: la mitad izquierda consta de dos capas convolucionales de 3x3 (RELU) más una capa de agrupación máxima de 2x2 para formar un módulo de reducción de resolución (como se puede ver en el código más adelante);
  • Decodificador: Hay medias partes, que se componen de una capa convolucional de muestreo ascendente (capa deconvolucional) + concat de costura de características + dos capas convolucionales de 3x3 (ReLU) repetidamente (como se puede ver en el código); este tipo de paso a través del Número de canales Empalme, puede obtener más funciones, pero también consume más memoria.

La estructura de UNet está diseñada para que pueda combinar la información de características de bajo nivel y características de alto nivel.

Funciones de bajo nivel (profundas) : información de baja resolución después de múltiples submuestreos. Puede proporcionar información semántica contextual del objetivo de segmentación en la imagen completa, que puede entenderse como una característica que refleja la relación entre el objetivo y su entorno. Esta característica es útil para el juicio de categorías de objetos (por lo que los problemas de clasificación generalmente solo requieren información profunda/de baja resolución y no involucran la fusión de múltiples escalas)

Funciones de alto nivel (superficial) : información de alta resolución que pasa directamente del codificador al decodificador de la misma altura a través de la operación de concatenación. Puede proporcionar características más refinadas para la segmentación, como gradientes, etc.

Sobre las razones de la incompatibilidad del tamaño:

​ Se puede ver en la imagen que las dimensiones a la izquierda y a la derecha no son correctas, por lo que si desea hacer coincidir, debe realizar el recorte. La explicación de todas las flechas grises es copiar y recortar, pero ninguna de las reproducidas. modelos están disponibles De esta manera, el tamaño de los lados izquierdo y derecho se establece para que sea el mismo, y se agrega relleno a cada convolución, de modo que el tamaño no cambie después de la convolución.

Desventajas:

  1. La red opera muy lentamente. La red se ejecuta una vez para cada vecindario y repite la operación para los vecindarios superpuestos .

  2. La red necesita hacer un compromiso entre la localización precisa y la obtención de información contextual . Los parches más grandes requieren más capas de agrupación máxima , lo que reduce la precisión de la localización, mientras que los vecindarios pequeños permiten que la red adquiera menos información contextual.

Código UNet (pytorch)

import torch.nn as nn
import torch
from torch import autograd
 
#把常用的2个卷积操作简单封装下
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch), #添加了BN层
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
 
    def forward(self, input):
        return self.conv(input)
 
class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # 逆卷积,也可以使用上采样(保证k=stride,stride即上采样倍数)
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)
 
    def forward(self, x):
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        out = nn.Sigmoid()(c10)
        return out

Código Unet (Keras)

def unet(pretrained_weights=None, input_size=(256, 256, 3)):
    inputs = Input(input_size)

    conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
    conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
    conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
    drop5 = Dropout(0.5)(conv5)

    up6 = Conv2D(512, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
        UpSampling2D(size=(2, 2))(drop5))
    merge6 = concatenate([drop4, up6], axis=3)
    conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
    conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)

    up7 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
        UpSampling2D(size=(2, 2))(conv6))
    merge7 = concatenate([conv3, up7], axis=3)
    conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
    conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)

    up8 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
        UpSampling2D(size=(2, 2))(conv7))
    merge8 = concatenate([conv2, up8], axis=3)
    conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
    conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)

    up9 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
        UpSampling2D(size=(2, 2))(conv8))
    merge9 = concatenate([conv1, up9], axis=3)
    conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
    conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    conv9 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    conv10 = Conv2D(1, 1, activation='sigmoid')(conv9)

    model = Model(inputs=inputs, outputs=conv10)
    model.summary()

    if (pretrained_weights):
        model.load_weights(pretrained_weights)

    return model

Supongo que te gusta

Origin blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/123714994
Recomendado
Clasificación