Masking GAN

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u011961856/article/details/79057469

github代码:https://github.com/tgeorgy/mgan

文章的创新点:

1.生成网络输入x,输出包括分割模板mask,和中间图像y,根据mask将输入x与中间图像y结合,得到生成图像.这样得到的生成图像背景与输入x相同,前景为生成部分.

2.采用端到端训练,在cyclegan损失函数的基础上,添加了对输出生成图像进行约束.

模型结构如下,

diagram

生成网络首先输出为分割模板mask,以及中间图像y,将中间图像y和mask混合,得到的输出作为最后的生成生成图像.生成网络代码如下,

class Generator(nn.Module):
    def __init__(self, input_nc=3, output_nc=4, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6):
        assert(n_blocks >= 0)
        super(Generator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, use_dropout=use_dropout)]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ReflectionPad2d(1),
                      nn.Conv2d(ngf * mult, int(ngf * mult / 2),
                                kernel_size=3, stride=1),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True),
                      nn.Conv2d(int(ngf * mult / 2), int(ngf * mult / 2)*4,
                                kernel_size=1, stride=1),
                      nn.PixelShuffle(2),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True),
                     ]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]

        self.model = nn.Sequential(*model)

代码中,生成网络输入通道为3,输出通道为4,第一个通道为mask,其他三个通道为中间生成图像.

def forward(self, input):
    output = self.model(input)
    mask = F.sigmoid(output[:, :1])
    oimg = output[:, 1:]
    mask = mask.repeat(1, 3, 1, 1)
    oimg = oimg*mask + input*(1-mask)

    return oimg, mask

采用cyclegan结构,也就是,包含两个生成网络,两个判别网络.

对于每个生成网络,损失函数包括三个部分,第一个为loss_P2N_cyc ,与cyclegan loss相同,即输入到生成网络g1的输出,在输入生成网络g2,得到输出与输入尽量相同.第二个loss_P2N_gan为gan损失函数,也就是判别网络判断label为真.第三个为loss_N2P_idnt,也就是生成网路g1的输出与label尽量相似,也就是文章是end to end(输入-label对应)训练,由于cyclegan不是end to end,所以没有这个损失函数,

criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()
criterion_gan = nn.MSELoss()
# Train P2N Generator
real_pos_v = Variable(real_pos)
fake_neg, mask_neg = netP2N(real_pos_v)
rec_pos, _ = netN2P(fake_neg)
fake_neg_lbl = netDN(fake_neg)

loss_P2N_cyc = criterion_cycle(rec_pos, real_pos_v)
loss_P2N_gan = criterion_gan(fake_neg_lbl, Variable(real_lbl))
loss_N2P_idnt = criterion_identity(fake_neg, real_pos_v)
# Train N2P Generator
real_neg_v = Variable(real_neg)
fake_pos, mask_pos = netN2P(real_neg_v)
rec_neg, _ = netP2N(fake_pos)
fake_pos_lbl = netDP(fake_pos)

loss_N2P_cyc = criterion_cycle(rec_neg, real_neg_v)
loss_N2P_gan = criterion_gan(fake_pos_lbl, Variable(real_lbl))
loss_P2N_idnt = criterion_identity(fake_pos, real_neg_v)

loss_G = ((loss_P2N_gan + loss_N2P_gan)*0.5 +
          (loss_P2N_cyc + loss_N2P_cyc)*lambda_cycle +
          (loss_P2N_idnt + loss_N2P_idnt)*lambda_identity)

判别网络用于判别输入的真假,

# Train Discriminators
netDN.zero_grad()
netDP.zero_grad()
fake_neg_score = netDN(fake_neg.detach())
loss_D = criterion_gan(fake_neg_score, Variable(fake_lbl))
fake_pos_score = netDP(fake_pos.detach())
loss_D += criterion_gan(fake_pos_score, Variable(fake_lbl))

real_neg_score = netDN.forward(real_neg_v)
loss_D += criterion_gan(real_neg_score, Variable(real_lbl))
real_pos_score = netDP.forward(real_pos_v)
loss_D += criterion_gan(real_pos_score, Variable(real_lbl))

猜你喜欢

转载自blog.csdn.net/u011961856/article/details/79057469
GAN