Código SEAN(1)

Código Dirección
Primero defina un entrenador.

trainer = Pix2PixTrainer(opt)

Dentro de Pix2PixTrainer, primero defina el modelo Pix2PixModel.

self.pix2pix_model = Pix2PixModel(opt)

Defina el generador y el discriminador dentro de Pix2PixModel.

self.netG, self.netD, self.netE = self.initialize_networks(opt)

Defina funciones dentro de initialize_networks.

netG = networks.define_G(opt)
netD = networks.define_D(opt) if opt.isTrain else None
netE = networks.define_E(opt) if opt.use_vae else None

Primero mira el generador:

def define_G(opt):
    netG_cls = find_network_using_name(opt.netG, 'generator')#netG=spade
    return create_network(netG_cls, opt)

El parámetro de entrada es opt.netG, que corresponde a la opción spade in. En find_network_using_name:

def find_network_using_name(target_network_name, filename):#spade,generator
    target_class_name = target_network_name + filename#spadegenerator
    module_name = 'models.networks.' + filename#models.networks.generator
    network = util.find_class_in_module(target_class_name, module_name)#<class 'models.networks.generator.SPADEGenerator'>
    assert issubclass(network, BaseNetwork), \
        "Class %s should be a subclass of BaseNetwork" % network

    return network

Entrada a find_class_in_module según target_network_name y el nombre de archivo correspondiente:

def find_class_in_module(target_cls_name, module):
    target_cls_name = target_cls_name.replace('_', '').lower()#spadegenerator
    clslib = importlib.import_module(module)#import_module()返回指定的包或模块
    cls = None
    for name, clsobj in clslib.__dict__.items():
        if name.lower() == target_cls_name:
            cls = clsobj

    if cls is None:
        print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
        exit(0)

    return cls

Cargamos el módulo a través de la función import_module, y el módulo corresponde a models.networks.generator. Es decir, clslib es la clase en el archivo generador. Repetimos el diccionario clslib, y si nombre es igual a spadegenerator, sea cls = clsobj.
Es decir, la red es igual a cls.

network = util.find_class_in_module(target_class_name, module_name)

Aquí hay dos problemas gramaticales:
①: importar importlib, llamar al método import_module (), puede obtener el módulo clslib de acuerdo con la cadena de entrada, y clslib puede llamar a todos los atributos y métodos en el archivo models.networks.generator.
inserte la descripción de la imagen aquí
Dentro del generador está:
inserte la descripción de la imagen aquípuede crear una instancia de SPADEGenerator a través de clslib.SPADEGenerator y luego llamar al método dentro de SPADEGenerator.
Por ejemplo: cree tres archivos nuevos.
inserte la descripción de la imagen aquí
tren:
inserte la descripción de la imagen aquí
la prueba no se utiliza y las clases del tren se importan al archivo tt.
inserte la descripción de la imagen aquí
Debido a que es el directorio del mismo nivel, puede importar directamente el tren de cadenas. Si no está en el directorio del mismo nivel, debe importar el directorio anterior.
Entonces a se convertirá en un módulo, es decir, en tren. Luego cree una instancia de las clases en la carpeta del tren. Finalmente, se llaman los métodos kill y qqq de la clase s.
Salida:
inserte la descripción de la imagen aquí
②: dict , este atributo puede ser llamado por el nombre de clase o el objeto de instancia de la clase, si llama a dict directamente por **nombre de clase , generará el diccionario compuesto por todos los atributos de clase en la clase; ** y usar el objeto de instancia de la clase. Llamar a dict generará un diccionario que consta de todos los atributos de instancia de la clase.
Referencia
aquí SPADEGenerator hereda BaseNetwork. Para la clase principal y la subclase con relación de herencia, la clase principal tiene su propio dict y la subclase también tiene su propio dict , que no contendrá la clase principal.dict .
Ejemplo: según el ejemplo anterior, a es un módulo, verifique el __dict__ de a:
inserte la descripción de la imagen aquí
salida:
inserte la descripción de la imagen aquí
volvamos al código: la red que generamos es la clase <clase 'models.networks.generator.SPADEGenerator'>.
A continuación creamos la red:inserte la descripción de la imagen aquí
inserte la descripción de la imagen aquí
cls corresponde a la red SPADEGenerator.
En ESPADA:

"""
Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
from models.networks.architecture import ResnetBlock as ResnetBlock
from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
from models.networks.architecture import Zencoder

class SPADEGenerator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
        parser.add_argument('--num_upsampling_layers',
                            choices=('normal', 'more', 'most'), default='normal',
                            help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator")

        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

        self.Zencoder = Zencoder(3, 512)


        self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='head_0')

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_0')
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_1')

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt, Block_Name='up_0')
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt, Block_Name='up_1')
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt, Block_Name='up_2')
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt, Block_Name='up_3', use_rgb=False)

        final_nc = nf

        if opt.num_upsampling_layers == 'most':
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt, Block_Name='up_4')
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)
        #self.up = nn.Upsample(scale_factor=2, mode='bilinear')
    def compute_latent_vector_size(self, opt):
        if opt.num_upsampling_layers == 'normal':#默认
            num_up_layers = 5
        elif opt.num_upsampling_layers == 'more':
            num_up_layers = 6
        elif opt.num_upsampling_layers == 'most':
            num_up_layers = 7
        else:
            raise ValueError('opt.num_upsampling_layers [%s] not recognized' %
                             opt.num_upsampling_layers)

        sw = opt.crop_size // (2**num_up_layers)#256//32=16
        sh = round(sw / opt.aspect_ratio)#8

        return sw, sh

    def forward(self, input, rgb_img, obj_dic=None):
        seg = input
        x = F.interpolate(seg, size=(self.sh, self.sw))#(16,16)
        x = self.fc(x)#(b,1024,16,16)

        style_codes = self.Zencoder(input=rgb_img, segmap=seg)
        x = self.head_0(x, seg, style_codes, obj_dic=obj_dic)

        x = self.up(x)
        x = self.G_middle_0(x, seg, style_codes, obj_dic=obj_dic)

        if self.opt.num_upsampling_layers == 'more' or \
           self.opt.num_upsampling_layers == 'most':
            x = self.up(x)

        x = self.G_middle_1(x, seg, style_codes,  obj_dic=obj_dic)

        x = self.up(x)
        x = self.up_0(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_1(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_2(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_3(x, seg, style_codes,  obj_dic=obj_dic)

        # if self.opt.num_upsampling_layers == 'most':
        #     x = self.up(x)
        #     x= self.up_4(x, seg, style_codes,  obj_dic=obj_dic)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        x = F.tanh(x)
        return x

Primero calcule el tamaño del vector espacial latente:
inserte la descripción de la imagen aquí
luego calcule la matriz de estiloST. Correspondiente al artículo:
inserte la descripción de la imagen aquí
En el código: por convolución, reducción de resolución, reducción de resolución, aumento de resolución, convolución. Genera un vector con 512 canales.
inserte la descripción de la imagen aquí
A esto le sigue una secuencia de cuatro bloques de muestreo ascendente:
inserte la descripción de la imagen aquí
Correspondiente a:
inserte la descripción de la imagen aquí
Dentro del bloque SPADEResnet: El bloque SEAN se define utilizando la clase ACE.
inserte la descripción de la imagen aquí
Los parámetros normalizados y el ruido, etc. se definen dentro de ACE.
inserte la descripción de la imagen aquí
Diseñemos una expresión regular de Python. Si no la has aprendido, sigue adelante y haz las paces. Sólo puedes obtener el resultado con la depuración primero.
inserte la descripción de la imagen aquí
Aquí se utiliza SynchronizedBatchNorm2d para la normalización:
inserte la descripción de la imagen aquí
γ y β se obtienen mediante convolución:
inserte la descripción de la imagen aquí
después de realizar los cuatro bloques SEAN muestreados, se realiza una convolución final para generar la imagen compuesta. Este es el proceso de toda la red.
Parámetros de impresión del generador:
inserte la descripción de la imagen aquí
seguido del discriminador: target_class_name = multiscalediscriminator , module_name = models.networks.discriminator, siguiendo
la lógica del generador . Luego importamos el módulo discriminador. Dentro del discriminador de múltiples escalas: cree dos discriminadores únicos. Defina parámetros dentro de un único discriminador: defina la entrada del discriminador: entrada después de empalmar el canal de etiqueta y la imagen RGB.

inserte la descripción de la imagen aquí

inserte la descripción de la imagen aquí
inserte la descripción de la imagen aquí
inserte la descripción de la imagen aquí

inserte la descripción de la imagen aquí
Luego pase por una convolución 4x4 con un tamaño de paso de 2, luego pase por dos convoluciones con un tamaño de paso de 2 y finalmente pase por una convolución con un canal de salida de 1 y un tamaño de paso de 1. Registre cada convolución con el modelo.
inserte la descripción de la imagen aquí
Es decir, el discriminador consta de cinco convoluciones.
Registre un único discriminador en el discriminador. Regístrese dos veces, de modo que el disco esté compuesto por 10 convoluciones y todas tengan los nombres correspondientes.
inserte la descripción de la imagen aquí

MultiscaleDiscriminator(
  (discriminator_0): NLayerDiscriminator(
    (model0): Sequential(
      (0): Conv2d(16, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model1): Sequential(
      (0): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
        (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model2): Sequential(
      (0): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model3): Sequential(
      (0): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    )
  )
  (discriminator_1): NLayerDiscriminator(
    (model0): Sequential(
      (0): Conv2d(16, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model1): Sequential(
      (0): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
        (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model2): Sequential(
      (0): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model3): Sequential(
      (0): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    )
  )
)

De esta manera, se construyen el generador y el discriminador y netE está vacío.

Supongo que te gusta

Origin blog.csdn.net/qq_43733107/article/details/132645878
Recomendado
Clasificación