GAN变体keras程序总结

话不多说,附上

eriklindernoren/Keras-GAN

一位GitHub群众eriklindernoren就发布了17种GAN的Keras实现,得到Keras亲爸爸François Chollet在Twitter上的热情推荐。

这真的超级棒,对于我这种辣鸡,看一篇外文要好几天还看不到的人来说,有了这些代码,瞬间明白了文章的意思和做法了!

现在将我看的代码总结下关键地方。

1、acgan

核心思想在于加入了条件,也就是标签,我的理解是这在一定程度是规范了潜在空间的编码结构。

 noise = Input(shape=(self.latent_dim,))
 label = Input(shape=(1,))
 img = self.generator([noise, label])

同时在discrimination里,不仅仅要识别出图片的真假,同时也要判别图片的标签

 valid, target_label = self.discriminator(img)

 # The combined model  (stacked generator and discriminator)
 # Trains the generator to fool the discriminator
 
self.combined = Model([noise, label], [valid, target_label])

而在生成器中,标签的使用是先用embedding进行词嵌入,在将编码后的标签与随机生成的潜在输入相乘(相当于对原本输入的随机噪声进行了一定的映射),最后送到generation里面生成图片:

noise = Input(shape=(self.latent_dim,))
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))####

model_input = multiply([noise, label_embedding])
img = model(model_input)

return Model([noise, label], img)

2、bigan

bigan的关键之处在于增加了encode器,思想应该:是既然训练出的generator能够能够让Z能够很好的生成img,那么我这个img应该也能很好的反过来被编码成Z。

# Build the generator
self.generator = self.build_generator()

# Build the encoder
self.encoder = self.build_encoder()

# The part of the bigan that trains the discriminator and encoder
self.discriminator.trainable = False

# Generate image from sampled noise
z = Input(shape=(self.latent_dim, ))
img_ = self.generator(z)

# Encode image
img = Input(shape=self.img_shape)
z_ = self.encoder(img)

# Latent -> img is fake, and img -> latent is valid
fake = self.discriminator([z, img_])
valid = self.discriminator([z_, img])

# Set up and compile the combined model
# Trains generator to fool the discriminator
self.bigan_generator = Model([z, img], [fake, valid])

3.WGAN

扫描二维码关注公众号,回复: 11346269 查看本文章

核心思想就是loss使用了(EM)堆土距离:

def wasserstein_loss(self, y_true, y_pred):
    return K.mean(y_true * y_pred)

然后标签是-1与1 不是0 和1 

valid = -np.ones((batch_size, 1))
fake = np.ones((batch_size, 1))

相对于传统的GAN,WGAN只做了以下三点简单的改动

D最后一层去掉sigmoid

G和D的loss不取log(sigmoid_cross_entropy_with_logits)

每次更新D的参数之后,将其绝对值截断到不超过一个固定常数c(-0.01,0.01),即gradient clipping(前作);或使用梯度惩罚,即gradient penalty(后作)

 # Clip critic weights
for l in self.critic.layers:
     weights = l.get_weights()
     weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
     l.set_weights(weights)

猜你喜欢

转载自blog.csdn.net/qq_26593695/article/details/94594656
今日推荐