Código Dirección
Primero defina un entrenador.
trainer = Pix2PixTrainer(opt)
Dentro de Pix2PixTrainer, primero defina el modelo Pix2PixModel.
self.pix2pix_model = Pix2PixModel(opt)
Defina el generador y el discriminador dentro de Pix2PixModel.
self.netG, self.netD, self.netE = self.initialize_networks(opt)
Defina funciones dentro de initialize_networks.
netG = networks.define_G(opt)
netD = networks.define_D(opt) if opt.isTrain else None
netE = networks.define_E(opt) if opt.use_vae else None
Primero mira el generador:
def define_G(opt):
netG_cls = find_network_using_name(opt.netG, 'generator')#netG=spade
return create_network(netG_cls, opt)
El parámetro de entrada es opt.netG, que corresponde a la opción spade in. En find_network_using_name:
def find_network_using_name(target_network_name, filename):#spade,generator
target_class_name = target_network_name + filename#spadegenerator
module_name = 'models.networks.' + filename#models.networks.generator
network = util.find_class_in_module(target_class_name, module_name)#<class 'models.networks.generator.SPADEGenerator'>
assert issubclass(network, BaseNetwork), \
"Class %s should be a subclass of BaseNetwork" % network
return network
Entrada a find_class_in_module según target_network_name y el nombre de archivo correspondiente:
def find_class_in_module(target_cls_name, module):
target_cls_name = target_cls_name.replace('_', '').lower()#spadegenerator
clslib = importlib.import_module(module)#import_module()返回指定的包或模块
cls = None
for name, clsobj in clslib.__dict__.items():
if name.lower() == target_cls_name:
cls = clsobj
if cls is None:
print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
exit(0)
return cls
Cargamos el módulo a través de la función import_module, y el módulo corresponde a models.networks.generator. Es decir, clslib es la clase en el archivo generador. Repetimos el diccionario clslib, y si nombre es igual a spadegenerator, sea cls = clsobj.
Es decir, la red es igual a cls.
network = util.find_class_in_module(target_class_name, module_name)
Aquí hay dos problemas gramaticales:
①: importar importlib, llamar al método import_module (), puede obtener el módulo clslib de acuerdo con la cadena de entrada, y clslib puede llamar a todos los atributos y métodos en el archivo models.networks.generator.
Dentro del generador está:
puede crear una instancia de SPADEGenerator a través de clslib.SPADEGenerator y luego llamar al método dentro de SPADEGenerator.
Por ejemplo: cree tres archivos nuevos.
tren:
la prueba no se utiliza y las clases del tren se importan al archivo tt.
Debido a que es el directorio del mismo nivel, puede importar directamente el tren de cadenas. Si no está en el directorio del mismo nivel, debe importar el directorio anterior.
Entonces a se convertirá en un módulo, es decir, en tren. Luego cree una instancia de las clases en la carpeta del tren. Finalmente, se llaman los métodos kill y qqq de la clase s.
Salida:
②: dict , este atributo puede ser llamado por el nombre de clase o el objeto de instancia de la clase, si llama a dict directamente por **nombre de clase , generará el diccionario compuesto por todos los atributos de clase en la clase; ** y usar el objeto de instancia de la clase. Llamar a dict generará un diccionario que consta de todos los atributos de instancia de la clase.
Referencia
aquí SPADEGenerator hereda BaseNetwork. Para la clase principal y la subclase con relación de herencia, la clase principal tiene su propio dict y la subclase también tiene su propio dict , que no contendrá la clase principal.dict .
Ejemplo: según el ejemplo anterior, a es un módulo, verifique el __dict__ de a:
salida:
volvamos al código: la red que generamos es la clase <clase 'models.networks.generator.SPADEGenerator'>.
A continuación creamos la red:
cls corresponde a la red SPADEGenerator.
En ESPADA:
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
from models.networks.architecture import ResnetBlock as ResnetBlock
from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
from models.networks.architecture import Zencoder
class SPADEGenerator(BaseNetwork):
@staticmethod
def modify_commandline_options(parser, is_train):
parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
parser.add_argument('--num_upsampling_layers',
choices=('normal', 'more', 'most'), default='normal',
help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator")
return parser
def __init__(self, opt):
super().__init__()
self.opt = opt
nf = opt.ngf
self.sw, self.sh = self.compute_latent_vector_size(opt)
self.Zencoder = Zencoder(3, 512)
self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)
self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='head_0')
self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_0')
self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_1')
self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt, Block_Name='up_0')
self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt, Block_Name='up_1')
self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt, Block_Name='up_2')
self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt, Block_Name='up_3', use_rgb=False)
final_nc = nf
if opt.num_upsampling_layers == 'most':
self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt, Block_Name='up_4')
final_nc = nf // 2
self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
self.up = nn.Upsample(scale_factor=2)
#self.up = nn.Upsample(scale_factor=2, mode='bilinear')
def compute_latent_vector_size(self, opt):
if opt.num_upsampling_layers == 'normal':#默认
num_up_layers = 5
elif opt.num_upsampling_layers == 'more':
num_up_layers = 6
elif opt.num_upsampling_layers == 'most':
num_up_layers = 7
else:
raise ValueError('opt.num_upsampling_layers [%s] not recognized' %
opt.num_upsampling_layers)
sw = opt.crop_size // (2**num_up_layers)#256//32=16
sh = round(sw / opt.aspect_ratio)#8
return sw, sh
def forward(self, input, rgb_img, obj_dic=None):
seg = input
x = F.interpolate(seg, size=(self.sh, self.sw))#(16,16)
x = self.fc(x)#(b,1024,16,16)
style_codes = self.Zencoder(input=rgb_img, segmap=seg)
x = self.head_0(x, seg, style_codes, obj_dic=obj_dic)
x = self.up(x)
x = self.G_middle_0(x, seg, style_codes, obj_dic=obj_dic)
if self.opt.num_upsampling_layers == 'more' or \
self.opt.num_upsampling_layers == 'most':
x = self.up(x)
x = self.G_middle_1(x, seg, style_codes, obj_dic=obj_dic)
x = self.up(x)
x = self.up_0(x, seg, style_codes, obj_dic=obj_dic)
x = self.up(x)
x = self.up_1(x, seg, style_codes, obj_dic=obj_dic)
x = self.up(x)
x = self.up_2(x, seg, style_codes, obj_dic=obj_dic)
x = self.up(x)
x = self.up_3(x, seg, style_codes, obj_dic=obj_dic)
# if self.opt.num_upsampling_layers == 'most':
# x = self.up(x)
# x= self.up_4(x, seg, style_codes, obj_dic=obj_dic)
x = self.conv_img(F.leaky_relu(x, 2e-1))
x = F.tanh(x)
return x
Primero calcule el tamaño del vector espacial latente:
luego calcule la matriz de estiloST. Correspondiente al artículo:
En el código: por convolución, reducción de resolución, reducción de resolución, aumento de resolución, convolución. Genera un vector con 512 canales.
A esto le sigue una secuencia de cuatro bloques de muestreo ascendente:
Correspondiente a:
Dentro del bloque SPADEResnet: El bloque SEAN se define utilizando la clase ACE.
Los parámetros normalizados y el ruido, etc. se definen dentro de ACE.
Diseñemos una expresión regular de Python. Si no la has aprendido, sigue adelante y haz las paces. Sólo puedes obtener el resultado con la depuración primero.
Aquí se utiliza SynchronizedBatchNorm2d para la normalización:
γ y β se obtienen mediante convolución:
después de realizar los cuatro bloques SEAN muestreados, se realiza una convolución final para generar la imagen compuesta. Este es el proceso de toda la red.
Parámetros de impresión del generador:
seguido del discriminador: target_class_name = multiscalediscriminator , module_name = models.networks.discriminator, siguiendo
la lógica del generador . Luego importamos el módulo discriminador. Dentro del discriminador de múltiples escalas: cree dos discriminadores únicos. Defina parámetros dentro de un único discriminador: defina la entrada del discriminador: entrada después de empalmar el canal de etiqueta y la imagen RGB.
Luego pase por una convolución 4x4 con un tamaño de paso de 2, luego pase por dos convoluciones con un tamaño de paso de 2 y finalmente pase por una convolución con un canal de salida de 1 y un tamaño de paso de 1. Registre cada convolución con el modelo.
Es decir, el discriminador consta de cinco convoluciones.
Registre un único discriminador en el discriminador. Regístrese dos veces, de modo que el disco esté compuesto por 10 convoluciones y todas tengan los nombres correspondientes.
MultiscaleDiscriminator(
(discriminator_0): NLayerDiscriminator(
(model0): Sequential(
(0): Conv2d(16, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
(1): LeakyReLU(negative_slope=0.2)
)
(model1): Sequential(
(0): Sequential(
(0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
(1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
)
(1): LeakyReLU(negative_slope=0.2)
)
(model2): Sequential(
(0): Sequential(
(0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
(1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
)
(1): LeakyReLU(negative_slope=0.2)
)
(model3): Sequential(
(0): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
)
)
(discriminator_1): NLayerDiscriminator(
(model0): Sequential(
(0): Conv2d(16, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
(1): LeakyReLU(negative_slope=0.2)
)
(model1): Sequential(
(0): Sequential(
(0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
(1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
)
(1): LeakyReLU(negative_slope=0.2)
)
(model2): Sequential(
(0): Sequential(
(0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
(1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
)
(1): LeakyReLU(negative_slope=0.2)
)
(model3): Sequential(
(0): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
)
)
)
De esta manera, se construyen el generador y el discriminador y netE está vacío.