话不多说,附上
一位GitHub群众eriklindernoren就发布了17种GAN的Keras实现,得到Keras亲爸爸François Chollet在Twitter上的热情推荐。
这真的超级棒,对于我这种辣鸡,看一篇外文要好几天还看不到的人来说,有了这些代码,瞬间明白了文章的意思和做法了!
现在将我看的代码总结下关键地方。
1、acgan
核心思想在于加入了条件,也就是标签,我的理解是这在一定程度是规范了潜在空间的编码结构。
noise = Input(shape=(self.latent_dim,))
label = Input(shape=(1,))
img = self.generator([noise, label])
同时在discrimination里,不仅仅要识别出图片的真假,同时也要判别图片的标签。
valid, target_label = self.discriminator(img)
# The combined model (stacked generator and discriminator)
# Trains the generator to fool the discriminator
self.combined = Model([noise, label], [valid, target_label])
而在生成器中,标签的使用是先用embedding进行词嵌入,在将编码后的标签与随机生成的潜在输入相乘(相当于对原本输入的随机噪声进行了一定的映射),最后送到generation里面生成图片:
noise = Input(shape=(self.latent_dim,))
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))####
model_input = multiply([noise, label_embedding])
img = model(model_input)
return Model([noise, label], img)
2、bigan
bigan的关键之处在于增加了encode器,思想应该:是既然训练出的generator能够能够让Z能够很好的生成img,那么我这个img应该也能很好的反过来被编码成Z。
# Build the generator
self.generator = self.build_generator()
# Build the encoder
self.encoder = self.build_encoder()
# The part of the bigan that trains the discriminator and encoder
self.discriminator.trainable = False
# Generate image from sampled noise
z = Input(shape=(self.latent_dim, ))
img_ = self.generator(z)
# Encode image
img = Input(shape=self.img_shape)
z_ = self.encoder(img)
# Latent -> img is fake, and img -> latent is valid
fake = self.discriminator([z, img_])
valid = self.discriminator([z_, img])
# Set up and compile the combined model
# Trains generator to fool the discriminator
self.bigan_generator = Model([z, img], [fake, valid])
3.WGAN
扫描二维码关注公众号,回复:
11346269 查看本文章
核心思想就是loss使用了(EM)堆土距离:
def wasserstein_loss(self, y_true, y_pred):
return K.mean(y_true * y_pred)
然后标签是-1与1 不是0 和1
valid = -np.ones((batch_size, 1))
fake = np.ones((batch_size, 1))
相对于传统的GAN,WGAN只做了以下三点简单的改动
D最后一层去掉sigmoid
G和D的loss不取log(sigmoid_cross_entropy_with_logits)
每次更新D的参数之后,将其绝对值截断到不超过一个固定常数c(-0.01,0.01),即gradient clipping(前作);或使用梯度惩罚,即gradient penalty(后作)
# Clip critic weights
for l in self.critic.layers:
weights = l.get_weights()
weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
l.set_weights(weights)