Latent Constraints: Conditional Generation from Unconditional Generative Models
Jesse Engel, Matthew Hoffman, Adam Roberts arXiv link
Abstract:
深度生成神经网络在复杂数据分布的条件和无条件建模方面都是有效的。
条件生成实现了交互式控制,但添加新的控制条件通常需要代价高昂的重新训练。
在本文中,我们开发了一种条件生成方法,无需重新训练模型。
通过事后学习**潜在约束**(latent constraints)——即用于识别潜在空间中能产生具有所需属性输出之区域的价值函数——我们可以使用基于梯度的优化或摊销的 actor 函数,从这些区域中进行条件采样。
将属性约束与通用的“真实性”约束(强制生成结果与数据分布相似)相结合,我们可以从无条件的变分自编码器(VAE)生成逼真的条件图像。
此外,借助基于梯度的优化,我们演示了保持身份的变换:在潜在空间中进行最小的调整,以修改图像的属性。
最后,在离散的音符序列上,我们展示了零样本(zero-shot)条件生成:在没有标注数据或可微奖励函数的情况下学习潜在约束。
此笔记本包含运行论文相关实验的代码。首先,我们加载预训练的检查点:
- 在 CelebA 上训练的 VAE 模型,使用逐像素高斯似然,观测噪声分别为 $\sigma_x = 1.0$ 和 $\sigma_x = 0.1$。
- 我们还提供了VAE模型的训练和评估集的嵌入。
- 条件 GAN 的生成器(`G`)和判别器(`D`):训练后可将潜在空间中的样本移动到同时满足真实性约束和属性约束的新位置(`D` 以属性标签为条件,同时承担这两种约束)。
- 我们提供两个训练版本:一个不带距离惩罚,另一个的距离惩罚权重为 1e-1。
- 分别在 z 空间(`DAttr`)和像素空间(`Classifier`)中单独训练的属性分类器。
然后我们继续:
* 证明随着 $\sigma_x$ 降低,VAE 的重建质量提高,但以样本质量为代价,而这可以通过潜在约束来弥补。
* 使用 CGAN(`G`, `D`)进行条件生成,分别在有和没有距离惩罚的情况下。
* 使用 `D` 和 `DAttr` 在 z 空间中通过基于梯度的优化执行保持身份的变换。
*评估使用条件属性生成图像的准确性。
笔记本末尾还提供了训练循环,仅供演示。
此 Colab 笔记本是自包含的,可直接在 Google Colab 上运行。代码和检查点也可以单独下载并在本地运行;如果您想训练自己的模型,建议采用这种方式。预训练模型检查点可在 download.magenta.tensorflow.org/models/latent_constraints/latent_constraints.tar 获取。
# This notebook requires DeepMind's sonnet library, which itself
# requires the nightly build of TensorFlow. The command below
# installs both (plus the matching TensorFlow Probability nightly,
# imported below as `tfp`).
# NOTE(review): the code below uses TF1 graph-mode APIs (tf.placeholder,
# tf.Session, tf.train.Saver) and Sonnet 1's snt.AbstractModule; current
# nightlies are TF2-based, so pinning TF 1.x / Sonnet 1.x may be
# required -- TODO confirm.
!pip install -q -U dm-sonnet tf-nightly tfp-nightly
import os
import PIL
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import sklearn.metrics
import sonnet as snt
import tensorflow as tf
import tensorflow_probability as tfp
ds = tfp.distributions
%matplotlib inline
# Copy checkpoints from google cloud
# Recursively copies the latent_constraints directory (checkpoints and the
# .npy embedding/attribute files loaded below) into /content/, which is
# where `basepath` points.
# Copying 3GB, takes a minute
!gsutil -q -m cp -R gs://download.magenta.tensorflow.org/models/latent_constraints /content/
Load the Data
basepath = '/content/latent_constraints/'
# Load CelebA embeddings: per-example mean and std of the latent code,
# later combined as mu + sigma * noise to sample z.
# VAE with x_sigma = 0.1
train_mu, train_sigma, eval_mu, eval_sigma = (
    np.load(basepath + fname)
    for fname in ('train_mu.npy', 'train_sigma.npy',
                  'eval_mu.npy', 'eval_sigma.npy'))
# VAE with x_sigma = 1.0 (eval split only)
eval_mu_xsigma1, eval_sigma_xsigma1 = (
    np.load(basepath + fname)
    for fname in ('eval_mu_xsigma1.npy', 'eval_sigma_xsigma1.npy'))
# Fix NumPy's global RNG so the notebook's random draws are reproducible.
np.random.seed(10003)
# Number of examples in each split.
n_train, n_eval = train_mu.shape[0], eval_mu.shape[0]
# Load Attributes
# Only use 10 salient attributes: `attr_mask` holds their column indices in
# the saved attribute arrays, and `attribute_names` lists them in the same
# order.
attr_mask = [4, 8, 9, 11, 15, 20, 24, 31, 35, 39]
attribute_names = [
    'Bald',
    'Black_Hair',
    'Blond_Hair',
    'Brown_Hair',
    'Eyeglasses',
    'Male',
    'No_Beard',
    'Smiling',
    'Wearing_Hat',
    'Young',
]
# Per-split label matrices restricted to the 10 selected columns.
attr_train, attr_eval, attr_test = (
    np.load(basepath + fname)[:, attr_mask]
    for fname in ('attr_train.npy', 'attr_eval.npy', 'attr_test.npy'))
Define the Graph
所有带变量的函数都封装在 Sonnet 模块中,将它们组合在一起的 `Model` 也是如此。
张量(端点)可作为 `Model` 实例的属性访问。
class Encoder(snt.AbstractModule):
    '''VAE convolutional encoder: image batch -> (mu, sigma) of q(z|x).

    Args:
      n_latent: dimensionality of the latent code.
      layers: sequence of (output_channels, kernel_size, stride) tuples, one
        per ReLU convolution.
      name: Sonnet module name.
    '''

    def __init__(self,
                 n_latent,
                 layers=((256, 5, 2),
                         (512, 5, 2),
                         (1024, 3, 2),
                         (2048, 3, 2)),
                 name='encoder'):
        super(Encoder, self).__init__(name=name)
        self.n_latent = n_latent
        self.layers = layers

    def _build(self, x):
        net = x
        for spec in self.layers:
            channels, kernel, stride = spec[0], spec[1], spec[2]
            net = tf.nn.relu(snt.Conv2D(channels, kernel, stride)(net))
        # Flatten all non-batch dimensions before the dense projection.
        dims = net.get_shape().as_list()
        net = tf.reshape(net, [-1, dims[1] * dims[2] * dims[3]])
        params = snt.Linear(2 * self.n_latent)(net)
        # First half of the projection is the mean; the second half passes
        # through softplus so the scale is positive.
        mu = params[:, :self.n_latent]
        sigma = tf.nn.softplus(params[:, self.n_latent:])
        return mu, sigma
class Decoder(snt.AbstractModule):
    '''VAE convolutional decoder: latent code z -> per-pixel logits.

    Args:
      layers: sequence of (channels, a, b) tuples. The first entry sizes a
        dense projection that is reshaped to an (a, b, channels) feature map;
        every later entry is a transposed convolution with `channels` outputs,
        kernel size `a` and stride `b`, followed by a ReLU except for the
        last one.
      name: Sonnet module name.
    '''

    def __init__(self,
                 layers=((2048, 4, 4),
                         (1024, 3, 2),
                         (512, 3, 2),
                         (256, 5, 2),
                         (3, 5, 2)),
                 name='decoder'):
        super(Decoder, self).__init__(name=name)
        self.layers = layers

    def _build(self, x):
        final_idx = len(self.layers) - 1
        for idx, spec in enumerate(self.layers):
            channels, a, b = spec[0], spec[1], spec[2]
            if idx == 0:
                # Dense projection reshaped to the initial feature map.
                net = snt.Linear(a * b * channels)(x)
                net = tf.reshape(net, [-1, a, b, channels])
                continue
            net = snt.Conv2DTranspose(channels, None, a, b)(net)
            if idx != final_idx:
                net = tf.nn.relu(net)
        # Unsquashed outputs (logits).
        return net
class G(snt.AbstractModule):
    '''CGAN Generator. Maps (z, labels) from z-space back to z-space.

    The output is a gated update z' = (1 - gate) * z + gate * dz, where both
    `gate` (after a sigmoid) and `dz` are predicted from z and the labels.

    Args:
      n_latent: dimensionality of z.
      layers: widths of the hidden ReLU layers; the first width is also used
        for the label embedding.
      name: Sonnet module name.
    '''

    def __init__(self,
                 n_latent,
                 layers=(2048,)*4,
                 name='generator'):
        super(G, self).__init__(name=name)
        self.layers = layers
        self.n_latent = n_latent

    def _build(self, z_and_labels):
        z, labels = z_and_labels
        # Embed the integer labels and concatenate them with z.
        float_labels = tf.cast(labels, tf.float32)
        label_embedding = snt.Linear(self.layers[0])(float_labels)
        net = tf.concat([z, label_embedding], axis=-1)
        for width in self.layers:
            net = tf.nn.relu(snt.Linear(width)(net))
        out = snt.Linear(2 * self.n_latent)(net)
        dz = out[:, :self.n_latent]
        gates = tf.nn.sigmoid(out[:, self.n_latent:])
        return (1-gates) * z + gates * dz
class D(snt.AbstractModule):
    '''CGAN Discriminator on (z, labels), returning unnormalized logits.

    Args:
      output_size: number of output logits.
      layers: widths of the hidden ReLU layers; the first width is also used
        for the label embedding.
      name: Sonnet module name.
    '''

    def __init__(self,
                 output_size=1,
                 layers=(2048,)*4,
                 name='D'):
        super(D, self).__init__(name=name)
        self.layers = layers
        self.output_size = output_size

    def _build(self, z_and_labels):
        z, labels = z_and_labels
        float_labels = tf.cast(labels, tf.float32)
        embedded = snt.Linear(self.layers[0])(float_labels)
        net = tf.concat([z, embedded], axis=-1)
        for width in self.layers:
            net = tf.nn.relu(snt.Linear(width)(net))
        return snt.Linear(self.output_size)(net)
class DAttr(snt.AbstractModule):
    '''Attribute classifier on z-space codes, returning per-attribute logits.

    Args:
      output_size: number of attribute logits.
      layers: widths of the hidden ReLU layers.
      name: Sonnet module name (defaults to 'D'; the notebook's Model passes
        an explicit name).
    '''

    def __init__(self,
                 output_size=1,
                 layers=(2048,)*4,
                 name='D'):
        super(DAttr, self).__init__(name=name)
        self.layers = layers
        self.output_size = output_size

    def _build(self, x):
        net = x
        for width in self.layers:
            net = tf.nn.relu(snt.Linear(width)(net))
        return snt.Linear(self.output_size)(net)
class Classifier(snt.AbstractModule):
    '''Convolutional attribute classifier operating on pixels.

    Args:
      output_size: number of attribute logits.
      layers: sequence of (output_channels, kernel_size, stride) tuples, one
        per ReLU convolution.
      name: Sonnet module name (defaults to 'encoder'; the notebook's Model
        passes an explicit name).
    '''

    def __init__(self,
                 output_size,
                 layers=((256, 5, 2),
                         (256, 3, 1),
                         (512, 5, 2),
                         (512, 3, 1),
                         (1024, 3, 2),
                         (2048, 3, 2)),
                 name='encoder'):
        super(Classifier, self).__init__(name=name)
        self.output_size = output_size
        self.layers = layers

    def _build(self, x):
        net = x
        for spec in self.layers:
            channels, kernel, stride = spec[0], spec[1], spec[2]
            net = tf.nn.relu(snt.Conv2D(channels, kernel, stride)(net))
        # Flatten all non-batch dimensions before the output layer.
        dims = net.get_shape().as_list()
        flat = tf.reshape(net, [-1, dims[1] * dims[2] * dims[3]])
        return snt.Linear(self.output_size)(flat)
class Model(snt.AbstractModule):
    '''All the components glued together.

    Calling the module (``m()``) builds the whole TF1 graph. It contains:
      * the VAE (Encoder / Decoder);
      * the conditional GAN acting in z-space (G, D);
      * the z-space attribute classifier (DAttr);
      * the pixel-space Classifier;
      * the gradient-based optimization of the ``z_prime`` variable.

    Every local name created in ``_build`` (placeholders, losses, train ops,
    savers, ...) is then copied onto the instance, so notebook cells can use
    e.g. ``m.x_mean`` or ``m.train_transform``.

    Args:
      config: dict read for 'batch_size', 'n_latent', 'img_width', 'x_sigma',
        'beta' and 'lambda_weight'.
      name: Sonnet module name.
    '''

    def __init__(self, config, name=''):
        super(Model, self).__init__(name=name)
        self.config = config

    def _build(self, unused_input=None):
        config = self.config
        # Constants
        batch_size = config['batch_size']
        n_latent = config['n_latent']
        img_width = config['img_width']
        half_batch = batch_size // 2
        n_labels = 10

        #-----------------------------------------------------------------------
        ### Placeholders
        #-----------------------------------------------------------------------
        x = tf.placeholder(tf.float32,
                           shape=(None, img_width, img_width, 3), name='x')
        # Attributes
        labels = tf.placeholder(tf.int32, shape=(None, n_labels), name='labels')
        # Real / fake label reward
        r = tf.placeholder(tf.float32, shape=(None, 1), name='D_label')
        # Transform through optimization: z0 anchors the distance penalty and
        # z_prime is the variable optimized by `train_transform`.
        z0 = tf.placeholder(tf.float32, shape=(None, n_latent), name='z0')
        z_prime = tf.get_variable('z_prime',
                                  shape=(half_batch, n_latent), dtype=tf.float32)

        #-----------------------------------------------------------------------
        ### Modules with parameters
        #-----------------------------------------------------------------------
        encoder = Encoder(n_latent=n_latent, name='encoder')
        decoder = Decoder(name='decoder')
        g = G(n_latent=n_latent, name='generator')
        d = D(output_size=1, name='d_z')
        d_attr = DAttr(output_size=n_labels, name='d_attr')
        classifier = Classifier(output_size=n_labels, name='classifier')

        #-----------------------------------------------------------------------
        ### VAE
        #-----------------------------------------------------------------------
        # Encode
        mu, sigma = encoder(x)
        q_z = ds.Normal(loc=mu, scale=sigma)
        # Optimize / Amortize or feedthrough. `transform` and `amortize` are
        # constants that callers override through feed_dict to select z_prime
        # and/or pass z through the generator G.
        q_z_sample = q_z.sample()
        transform = tf.constant(False)
        z = tf.cond(transform, lambda: z_prime, lambda: q_z_sample)
        amortize = tf.constant(False)
        z = tf.cond(amortize, lambda: g((z, labels)), lambda: z)
        # Decode
        logits = decoder(z)
        x_sigma = tf.constant(config['x_sigma'])
        p_x = ds.Normal(loc=tf.nn.sigmoid(logits), scale=x_sigma)
        x_mean = p_x.mean()
        # Reconstruction Loss
        recons = tf.reduce_sum(p_x.log_prob(x), axis=[1, 2, 3])
        mean_recons = tf.reduce_mean(recons)
        # Prior
        p_z = ds.Normal(loc=0., scale=1.)
        prior_sample = p_z.sample(sample_shape=[batch_size, n_latent])
        # KL Loss
        KL_qp = ds.kl_divergence(q_z, p_z)
        KL = tf.reduce_sum(KL_qp, axis=-1)
        mean_KL = tf.reduce_mean(KL)
        beta = tf.constant(config['beta'])
        # VAE Loss
        vae_loss = -mean_recons + mean_KL * beta

        #-----------------------------------------------------------------------
        ### Realism discriminator D(z, labels) in z-space
        #-----------------------------------------------------------------------
        d_logits = d([z, labels])
        r_pred = tf.nn.sigmoid(d_logits)  # r = [0 prior, 1 data]
        d_loss = tf.losses.sigmoid_cross_entropy(r, d_logits)
        # Mean over examples
        d_loss = tf.reduce_mean(d_loss)
        # Gradient Penalty on random interpolates between the first half of the
        # batch (real_data) and the second half (fake_data).
        real_data = z[:half_batch]
        fake_data = z[half_batch:batch_size]
        alpha = tf.random_uniform(shape=[half_batch, n_latent], minval=0., maxval=1.)
        differences = fake_data - real_data
        interpolates = real_data + (alpha * differences)
        interp_pred = d([interpolates, labels[:half_batch]])
        gradients = tf.gradients(interp_pred, [interpolates])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1]) + 1e-10)
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
        # Add penalty
        lambda_weight = tf.constant(config['lambda_weight'])
        d_loss_training = d_loss + lambda_weight * gradient_penalty

        #-----------------------------------------------------------------------
        ### Discriminator Attribute classification (implicit constraint)
        #-----------------------------------------------------------------------
        # Z-Space. `attr_weights` is a constant that callers may override via
        # feed_dict to weight individual attributes.
        attr_weights = tf.constant(np.ones([1, n_labels]).astype(np.float32))
        logits_attr = d_attr(z)
        pred_attr = tf.nn.sigmoid(logits_attr)
        d_loss_attr = tf.losses.sigmoid_cross_entropy(labels,
                                                      logits=logits_attr,
                                                      weights=attr_weights)

        #-----------------------------------------------------------------------
        ### OPTIMIZATION TRANSFORMATION (SGD on z_prime)
        #-----------------------------------------------------------------------
        # Realism Constraint
        transform_r_weight = tf.constant(1.0)
        loss_transform = transform_r_weight * tf.reduce_mean(d_loss)
        # Attribute Constraint
        transform_attr_weight = tf.constant(0.0)
        loss_transform += transform_attr_weight * d_loss_attr
        # Distance Penalty: log(1 + (z' - z0)^2), scaled per dimension.
        transform_penalty_weight = tf.constant(0.0)
        z_sigma_mean = tf.constant(np.ones([1, n_latent]).astype(np.float32))
        transform_penalty = tf.log(1 + (z_prime - z0)**2)
        transform_penalty = transform_penalty * z_sigma_mean**-2
        loss_transform += tf.reduce_mean(transform_penalty_weight * transform_penalty)

        #-----------------------------------------------------------------------
        ### AMORTIZED TRANSFORMATION (Generator)
        #-----------------------------------------------------------------------
        # Realism and Attribute Constraint (D is conditioned on the labels).
        g_loss = -tf.log(tf.clip_by_value(r_pred, 1e-15, 1 - 1e-15))
        g_loss = tf.reduce_mean(g_loss)
        # Distance Penalty between G's output and its input sample.
        g_penalty_weight = tf.constant(0.0)
        g_penalty = tf.log(1 + (z - q_z_sample)**2)
        g_penalty = g_penalty * z_sigma_mean**-2
        g_penalty = tf.reduce_mean(g_penalty)
        g_loss += g_penalty_weight * g_penalty

        #-----------------------------------------------------------------------
        ### Classify Attributes from pixels
        #-----------------------------------------------------------------------
        logits_classifier = classifier(x)
        pred_classifier = tf.nn.sigmoid(logits_classifier)
        classifier_loss = tf.losses.sigmoid_cross_entropy(labels,
                                                          logits=logits_classifier)

        #-----------------------------------------------------------------------
        ### Training
        #-----------------------------------------------------------------------
        # Learning rates (constants, overridable via feed_dict)
        d_lr = tf.constant(3e-4)
        d_attr_lr = tf.constant(3e-4)
        vae_lr = tf.constant(3e-4)
        g_lr = tf.constant(3e-4)
        classifier_lr = tf.constant(3e-4)
        transform_lr = tf.constant(3e-4)
        # Training Ops
        vae_vars = list(encoder.get_variables())
        vae_vars.extend(decoder.get_variables())
        train_vae = tf.train.AdamOptimizer(vae_lr).minimize(vae_loss, var_list=vae_vars)
        d_vars = d.get_variables()
        train_d = tf.train.AdamOptimizer(d_lr, beta1=0, beta2=0.9).minimize(
            d_loss_training, var_list=d_vars)
        classifier_vars = classifier.get_variables()
        train_classifier = tf.train.AdamOptimizer(classifier_lr).minimize(
            classifier_loss, var_list=classifier_vars)
        g_vars = g.get_variables()
        train_g = tf.train.AdamOptimizer(g_lr, beta1=0, beta2=0.9).minimize(
            g_loss, var_list=g_vars)
        d_attr_vars = d_attr.get_variables()
        train_d_attr = tf.train.AdamOptimizer(d_attr_lr).minimize(
            d_loss_attr, var_list=d_attr_vars)
        train_transform = tf.train.AdamOptimizer(transform_lr).minimize(
            loss_transform, var_list=[z_prime])
        # Savers
        vae_saver = tf.train.Saver(vae_vars, max_to_keep=100)
        g_saver = tf.train.Saver(g_vars, max_to_keep=1000)
        d_saver = tf.train.Saver(d_vars, max_to_keep=1000)
        d_attr_saver = tf.train.Saver(d_attr_vars, max_to_keep=1000)
        classifier_saver = tf.train.Saver(classifier_vars, max_to_keep=1000)

        # Add all endpoints as object attributes. dict.update(locals()) works
        # on both Python 2 and 3 (dict.iteritems does not exist on Python 3).
        self.__dict__.update(locals())
Load all models
# Hyperparameters used to build the graph; they must match the pretrained
# checkpoints restored below.
# NOTE(review): 'crop_width' and 'penalty_weight' are not read by Model._build.
config = {
    'n_latent': 1024,
    'img_width': 64,
    'crop_width': 64,
    # Optimization parameters
    'batch_size': 128,
    'beta': 1.0,
    'x_sigma': 0.1,
    'lambda_weight': 10.0,
    'penalty_weight': 0.0,
}
tf.reset_default_graph()
sess = tf.Session()
# Declare and build the model, then initialize every variable so that the
# ones not covered by a saver (e.g. z_prime, optimizer slots) are defined.
m = Model(config)
_ = m()
sess.run(tf.global_variables_initializer())
# Restore each sub-network (VAE, D, G, D_attr, Classifier) in turn.
for saver, ckpt_name in (
        (m.vae_saver, 'vae_best_celeba_0_crop128_beta1.ckpt'),
        (m.d_saver, 'D_d_2_conditional_penalty1e-1_399000.ckpt'),
        (m.g_saver, 'G_d_2_conditional_penalty1e-1_399000.ckpt'),
        (m.d_attr_saver, 'D_attr_best_d_attr_0.ckpt'),
        (m.classifier_saver, 'classifier_best_classifier_0.ckpt'),
):
    ckpt = os.path.join(basepath, ckpt_name)
    saver.restore(sess, ckpt)
GENERATE PLOTS
def im(x):
    """Show image `x` with values clipped to [0, 1] and no axis ticks."""
    clipped = np.maximum(0, np.minimum(1, x))
    plt.imshow(clipped, interpolation='none')
    for hide_ticks in (plt.xticks, plt.yticks):
        hide_ticks([])
def batch_image(b, max_images=64, rows=None, cols=None):
    """Turn a batch of images into a single image mosaic.

    The first ``min(len(b), max_images)`` images are zero-padded to fill a
    ``rows * cols`` grid. Each run of ``cols`` consecutive images is stacked
    vertically into a strip, and the ``rows`` strips are placed side by side.
    The result therefore has shape ``(cols * H, rows * W, C)``.

    Args:
      b: array of shape (N, H, W, C).
      max_images: maximum number of images taken from ``b``.
      rows: number of strips. If None, it defaults to
        ceil(sqrt(n_images)), and ``cols`` is then set to the same value
        (any ``cols`` passed is ignored).
      cols: images per strip. If only ``rows`` is given, it defaults to
        ceil(n_images / rows).

    Returns:
      The mosaic as a single array.

    Raises:
      ValueError: if no images are selected or the grid is too small.
    """
    mb = min(b.shape[0], max_images)
    if mb <= 0:
        raise ValueError('batch_image needs at least one image')
    if rows is None:
        rows = int(np.ceil(np.sqrt(mb)))
        cols = rows
    elif cols is None:
        cols = int(np.ceil(mb / float(rows)))
    diff = rows * cols - mb
    if diff < 0:
        raise ValueError('A %d x %d grid cannot hold %d images' % (rows, cols, mb))
    padding = np.zeros([diff, b.shape[1], b.shape[2], b.shape[3]])
    b = np.vstack([b[:mb], padding])
    # The reshape merges each run of `cols` consecutive images along the
    # height axis, giving `rows` tall strips.
    strips = b.reshape(-1, cols * b.shape[1], b.shape[2], b.shape[3])
    # np.hstack needs a real sequence: passing a generator is deprecated
    # since NumPy 1.16 and rejected by newer releases.
    return np.hstack(list(strips))
# A list of attribute settings on which to condition generation.
# Each list element is one fully-specified condition: a list of
# (value, attribute_name) pairs covering all 10 attributes, in the same
# order as `attribute_names`, with value 1 to request the attribute and 0
# to request its absence.
cond_attr_list = [
[
(0, 'Bald'),
(0, 'Black_Hair'),
(1, 'Blond_Hair'),
(0, 'Brown_Hair'),
(0, 'Eyeglasses'),
(0, 'Male'),
(1, 'No_Beard'),
(1, 'Smiling'),
(0, 'Wearing_Hat'),
(1, 'Young'),
],
[
(0, 'Bald'),
(0, 'Black_Hair'),
(0, 'Blond_Hair'),
(1, 'Brown_Hair'),
(0, 'Eyeglasses'),
(0, 'Male'),
(1, 'No_Beard'),
(1, 'Smiling'),
(0, 'Wearing_Hat'),
(1, 'Young'),
],
[
(0, 'Bald'),
(1, 'Black_Hair'),
(0, 'Blond_Hair'),
(0, 'Brown_Hair'),
(0, 'Eyeglasses'),
(0, 'Male'),
(1, 'No_Beard'),
(1, 'Smiling'),
(0, 'Wearing_Hat'),
(1, 'Young'),
],
[
(0, 'Bald'),
(1, 'Black_Hair'),
(0, 'Blond_Hair'),
(0, 'Brown_Hair'),
(0, 'Eyeglasses'),
(1, 'Male'),
(1, 'No_Beard'),
(1, 'Smiling'),
(0, 'Wearing_Hat'),
(1, 'Young'),
],
[
(0, 'Bald'),
(1, 'Black_Hair'),
(0, 'Blond_Hair'),
(0, 'Brown_Hair'),
(0, 'Eyeglasses'),
(1, 'Male'),
(0, 'No_Beard'),
(1, 'Smiling'),
(0, 'Wearing_Hat'),
(1, 'Young'),
],
[
(0, 'Bald'),
(1, 'Black_Hair'),
(0, 'Blond_Hair'),
(0, 'Brown_Hair'),
(1, 'Eyeglasses'),
(1, 'Male'),
(0, 'No_Beard'),
(1, 'Smiling'),
(0, 'Wearing_Hat'),
(1, 'Young'),
],
[
(1, 'Bald'),
(0, 'Black_Hair'),
(0, 'Blond_Hair'),
(0, 'Brown_Hair'),
(0, 'Eyeglasses'),
(1, 'Male'),
(0, 'No_Beard'),
(1, 'Smiling'),
(0, 'Wearing_Hat'),
(1, 'Young'),
],
[
(1, 'Bald'),
(0, 'Black_Hair'),
(0, 'Blond_Hair'),
(0, 'Brown_Hair'),
(0, 'Eyeglasses'),
(1, 'Male'),
(0, 'No_Beard'),
(0, 'Smiling'),
(0, 'Wearing_Hat'),
(0, 'Young'),
],
]
# Expand every condition into a (batch_size, n_labels) int32 label matrix,
# with the same row repeated for every example in the batch. Values are
# looked up by attribute *name* rather than tuple position, so a reordered,
# missing, duplicated or misspelled entry raises instead of silently
# assigning a value to the wrong label column.
cond_attrs = []
for attrs in cond_attr_list:
    value_by_name = {name: value for value, name in attrs}
    unknown = sorted(set(value_by_name) - set(attribute_names))
    missing = [name for name in attribute_names if name not in value_by_name]
    if unknown or missing or len(value_by_name) != len(attrs):
        raise ValueError('Bad condition %r: unknown=%r, missing=%r'
                         % (attrs, unknown, missing))
    row = np.array([[value_by_name[name] for name in attribute_names]],
                   dtype=np.int32)
    cond_attrs.append(np.repeat(row, m.batch_size, axis=0))
VAE Reconstructions
# Load the D / G pair trained without a distance penalty (used for the
# "Refinement" row).
ckpt = os.path.join(basepath, 'D_d_2_conditional_399000.ckpt')
m.d_saver.restore(sess, ckpt)
ckpt = os.path.join(basepath, 'G_d_2_conditional_399000.ckpt')
m.g_saver.restore(sess, ckpt)

# Nine prior draws shared by both "Sample" rows and the refinement row.
z_prior = np.random.randn(9, m.n_latent)
labels = attr_eval[:9]

# VAE with x_sigma = 1.0: reconstructions of eval images and prior samples.
ckpt = os.path.join(basepath, 'vae_best_celeba_0_crop128_beta1_xsigma1.ckpt')
m.vae_saver.restore(sess, ckpt)
z_eval_blurry = (eval_mu_xsigma1[:9] +
                 eval_sigma_xsigma1[:9] * np.random.randn(9, m.n_latent))
b_eval_recon_blurry = sess.run(
    m.x_mean,
    {m.z: z_eval_blurry, m.labels: np.zeros([m.batch_size, m.n_labels])})
b_samples_blurry = sess.run(m.x_mean, {m.z: z_prior})

# VAE with x_sigma = 0.1: the same, plus samples refined by G.
ckpt = os.path.join(basepath, 'vae_best_celeba_0_crop128_beta1.ckpt')
m.vae_saver.restore(sess, ckpt)
z_eval_sharp = (eval_mu[:9] +
                eval_sigma[:9] * np.random.randn(9, m.n_latent))
b_eval_recon_sharp = sess.run(
    m.x_mean,
    {m.z: z_eval_sharp, m.labels: np.zeros([m.batch_size, m.n_labels])})
b_samples_sharp = sess.run(m.x_mean, {m.z: z_prior})
b_samples_refined = sess.run(m.x_mean, {m.q_z_sample: z_prior,
                                        m.amortize: True,
                                        m.labels: labels})

# Visualize Reconstructions: one subplot row per (images, title) panel,
# starting from the second grid row.
panels = [(b_eval_recon_blurry, 'Recon'),
          (b_samples_blurry, 'Sample'),
          (b_eval_recon_sharp, 'Recon'),
          (b_samples_sharp, 'Sample'),
          (b_samples_refined, 'Refinement')]
tot = 6
row = 5
plt.figure(figsize=[10, 10])
for i in range(row):
    for panel_row, (images, title) in enumerate(panels, start=1):
        plt.subplot(tot, row, 1 + i + row * panel_row)
        im(images[i])
        plt.title(title)
Conditional Generation
# Load the D / G pair trained without a distance penalty.
ckpt = os.path.join(basepath, 'D_d_2_conditional_399000.ckpt')
m.d_saver.restore(sess, ckpt)
ckpt = os.path.join(basepath, 'G_d_2_conditional_399000.ckpt')
m.g_saver.restore(sess, ckpt)

# Compute the Conditional Samples: decode one batch of prior draws as-is
# (first row), then push the same draws through G once per condition.
z_original = sess.run(m.prior_sample)
z_new = [z_original]
b_new = [sess.run(m.x_mean, {m.z: z_original})]
for cond_attr in cond_attrs:
    z_new.append(sess.run(m.z, {m.q_z_sample: z_original,
                                m.amortize: True,
                                m.labels: cond_attr}))
    b_new.append(sess.run(m.x_mean, {m.z: z_new[-1]}))

# Plot them: one mosaic row per condition showing the first `tot` samples.
# (A single figure; the unused first figure and `barr` array were dead code.)
plt.figure(figsize=(18, 12))
n_b = len(b_new)
tot = 16
for i, b in enumerate(b_new):
    plt.subplot(n_b, 1, i + 1)
    im(batch_image(b, max_images=tot, rows=tot, cols=1))
Z-Penalty = 0.1
# Load the D / G pair trained with a 1e-1 distance penalty.
ckpt = os.path.join(basepath, 'D_d_2_conditional_penalty1e-1_399000.ckpt')
m.d_saver.restore(sess, ckpt)
ckpt = os.path.join(basepath, 'G_d_2_conditional_penalty1e-1_399000.ckpt')
m.g_saver.restore(sess, ckpt)

# Compute the Conditional Samples: decode one batch of prior draws as-is
# (first row), then push the same draws through G once per condition.
z_original = sess.run(m.prior_sample)
z_new = [z_original]
b_new = [sess.run(m.x_mean, {m.z: z_original})]
for cond_attr in cond_attrs:
    z_new.append(sess.run(m.z, {m.q_z_sample: z_original,
                                m.amortize: True,
                                m.labels: cond_attr}))
    b_new.append(sess.run(m.x_mean, {m.z: z_new[-1]}))

# Plot them: one mosaic row per condition showing the first `tot` samples.
# (A single figure; the unused first figure and `barr` array were dead code.)
plt.figure(figsize=(18, 12))
n_b = len(b_new)
tot = 16
for i, b in enumerate(b_new):
    plt.subplot(n_b, 1, i + 1)
    im(batch_image(b, max_images=tot, rows=tot, cols=1))
Identity Preserving Transformations
# Restore the D and G checkpoints trained with distance penalty 1e-1.
for saver, ckpt_name in (
        (m.d_saver, 'D_d_2_conditional_penalty1e-1_399000.ckpt'),
        (m.g_saver, 'G_d_2_conditional_penalty1e-1_399000.ckpt')):
    ckpt = os.path.join(basepath, ckpt_name)
    saver.restore(sess, ckpt)
def transform(z_original,
              labels,
              z0=None,
              lr=1e-1,
              n_opt=100,
              penalty_weight=0.0,
              r_weight=1.0,
              attr_weight=1.0,
              attr_weights=None,
              r_threshold=0.9,
              attr_threshold=0.9,
              adaptive=False,
              ):
    """Optimize latent codes, starting at `z_original`, to satisfy constraints.

    Loads `z_original` into the variable `m.z_prime`. It then runs up to
    `n_opt` steps of `m.train_transform`, minimizing the weighted sum of:
      * the realism loss;
      * the z-space attribute loss;
      * the distance penalty to `z0`.

    An example counts as converged the first time its realism prediction
    exceeds `r_threshold` and its weighted attribute accuracy exceeds
    `attr_threshold`. Its code from that step is kept.

    NOTE(review): the Adam slot variables of `m.train_transform` are not
    reset here, so optimizer state carries over between calls.

    Args:
      z_original: (m.half_batch, m.n_latent) starting codes.
      labels: (m.half_batch, m.n_labels) target attribute labels.
      z0: anchor of the distance penalty; defaults to `z_original`.
      lr: learning rate fed to the transform optimizer.
      n_opt: maximum number of optimization steps.
      penalty_weight, r_weight, attr_weight: loss-term weights.
      attr_weights: (1, m.n_labels) per-attribute weights; defaults to ones.
      r_threshold, attr_threshold: convergence thresholds.
      adaptive: if True, after each step set r_weight / attr_weight to
        1 - mean realism prediction / 1 - mean attribute accuracy.

    Returns:
      (z_new, i_threshold, z_trace). `z_new` holds the selected codes;
      examples that never converged get the final iterate. `i_threshold`
      holds the step at which each example converged (0 if it never did).
      `z_trace` is the list of per-step `z_prime` values.
    """
    if z0 is None:
        z0 = z_original
    if attr_weights is None:
        attr_weights = np.ones([1, m.n_labels])
    # Variable.load writes the value without adding a new assign op to the
    # graph (tf.assign here would grow the graph on every call).
    m.z_prime.load(z_original, sess)
    z_new = np.zeros([m.half_batch, m.n_latent])
    i_threshold = np.zeros([m.half_batch])
    # Explicit mask: `i_threshold == 0` cannot tell "converged at step 0"
    # from "not converged yet".
    converged = np.zeros([m.half_batch], dtype=bool)
    z_trace = []
    # Fallback so non-converged examples are well defined even if n_opt == 0.
    z_prime = z_original
    feed = {m.z0: z0,
            m.r: np.ones([m.half_batch, 1]),
            m.labels: labels,
            m.transform: True,
            m.transform_lr: lr,
            m.transform_penalty_weight: penalty_weight,
            m.attr_weights: attr_weights,
            # Dummy images: the graph still needs a value for `x`.
            m.x: np.zeros([m.half_batch, m.img_width, m.img_width, 3])}
    for i in range(n_opt):
        # The two constraint weights may change each step when `adaptive`.
        feed[m.transform_r_weight] = r_weight
        feed[m.transform_attr_weight] = attr_weight
        res = sess.run([m.train_transform,
                        m.loss_transform,
                        m.transform_penalty,
                        m.z_prime,
                        m.r_pred,
                        m.pred_attr],
                       feed)
        z_prime = res[3]
        z_trace.append(z_prime)
        pred_r = res[-2]
        pred_attr = res[-1]
        # Weighted agreement with the targets, averaged over all n_labels.
        attr_acc = 1.0 - np.mean(attr_weights * np.abs(labels - pred_attr), axis=1)
        transform_penalty = np.mean(res[2])
        check_idx = np.where(~converged)[0]
        if len(check_idx) == 0:
            break
        for idx in check_idx:
            if pred_r[idx] > r_threshold and attr_acc[idx] > attr_threshold:
                z_new[idx] = z_prime[idx]
                i_threshold[idx] = i
                converged[idx] = True
        if adaptive:
            r_weight = 1 - np.mean(pred_r)
            attr_weight = 1 - np.mean(attr_acc)
        if i % 100 == 1:
            print('Step %d, NotConverged: %d, Loss: %0.3e, Penalty: %0.3f, '
                  'r: %0.3f, r_min: %0.3f, '
                  'attr: %0.3f, attr_min:%0.3f, ' % (
                      i,
                      len(check_idx),
                      res[1],
                      transform_penalty,
                      np.mean(pred_r),
                      np.min(pred_r),
                      np.mean(attr_acc),
                      np.min(attr_acc),
                  ))
    check_idx = np.where(~converged)[0]
    print('%d did not converge' % len(check_idx))
    for idx in check_idx:
        z_new[idx] = z_prime[idx]
    return z_new, i_threshold, z_trace
# BETA1 TRANSFORMATION
# Start from one sample of q(z|x) per image in the first eval half-batch.
z_original = (eval_mu[:m.half_batch] +
              eval_sigma[:m.half_batch] * np.random.randn(m.half_batch, m.n_latent))
z = z_original
# Each entry is (attribute index, target value, learning rate, r_thresh,
# r_weight). The attribute-accuracy threshold passed to `transform` is
# 1 - r_thresh.
label_list = (
    (1, 1, 1e-2, 1e-3, 0.05),
    (2, 1, 1e-2, 1e-3, 0.05),
    (3, 1, 1e-2, 1e-3, 0.05),
    (4, 1, 1e-2, 1e-5, 0.05),
    (5, 0, 1e-2, 1e-3, 0.05),
    (5, 1, 1e-2, 1e-4, 0.05),
    (6, 0, 5e-3, 1e-4, 0.005),
    (7, 0, 1e-2, 1e-3, 0.05),
    (7, 1, 1e-2, 1e-3, 0.05),
    (9, 0, 3e-2, 1e-4, 0.01),
    (9, 1, 3e-2, 1e-4, 0.01)
)
z_list = [z_original]
b_list = [sess.run(m.x_mean, {m.z: z_original})]
for attr, value, lr, r_thresh, r_weight in label_list:
    print('Label: %d, Value: %d' % (attr, value))
    # Only the attribute being changed contributes to the attribute loss.
    attr_weights = np.zeros([1, m.n_labels])
    attr_weights[:, attr] = 1
    labels = attr_eval[:m.half_batch].copy()
    labels[:, attr] = value
    z, i_threshold, z_trace = transform(
        z_original,
        labels,
        lr=lr,
        n_opt=300,
        r_weight=r_weight,
        attr_weight=1.0,
        attr_weights=attr_weights,
        r_threshold=0.3,
        attr_threshold=(1 - r_thresh),
    )
    z_list.append(z)
    b_list.append(sess.run(m.x_mean, {m.z: z}))

# Plot selected examples: one row per example; the columns are the original
# followed by one transformation per entry of `label_list`.
idxs = (0, 2, 5, 10, 7, 13, 3,)
labels = attr_eval[:m.half_batch].copy()
labels = labels[idxs, :]
n_b = len(b_list)
barr = np.swapaxes(np.array(b_list), 0, 1)
barr = barr[idxs, :, :, :, :]
# Black out transformations whose target value the example already has.
for i in range(len(barr)):
    for j, (k, v, _, _, _) in enumerate(label_list):
        if labels[i, k] == v:
            barr[i, j + 1] = 0.
# Keep the original 12-row layout, but grow the grid if more than 12
# examples are selected (a fixed 12 makes plt.subplot raise).
n_rows = max(12, len(barr))
plt.figure(figsize=(24, 18))
for i, b in enumerate(barr):
    plt.subplot(n_rows, 1, i + 1)
    im(batch_image(b, max_images=n_b, rows=n_b, cols=1))
Attribute Classification Accuracy
# Restore the D and G checkpoints trained with distance penalty 1e-1.
for saver, ckpt_name in (
        (m.d_saver, 'D_d_2_conditional_penalty1e-1_399000.ckpt'),
        (m.g_saver, 'G_d_2_conditional_penalty1e-1_399000.ckpt')):
    ckpt = os.path.join(basepath, ckpt_name)
    saver.restore(sess, ckpt)
Original Images
# Attribute accuracy of the pixel-space classifier on real images.
# NOTE(review): `train_data` / `eval_data` (image arrays) are only created by
# the data-preparation cell in the TRAINING MODELS section; run it (or load
# the arrays it saves) first -- TODO confirm they are not loaded elsewhere.
batch_size = 256
# Prediction Accuracy
train_pred = []
eval_pred = []
# Integer division keeps `range` valid on Python 3. Only a tenth of the
# training set is scored, to keep this cell fast.
for i in range(n_train // 10 // batch_size):
    start = i * batch_size
    end = start + batch_size
    batch_images = train_data[start:end]
    res = sess.run([m.pred_classifier], {m.x: batch_images})
    train_pred.append(res[0])
train_pred = np.vstack(train_pred)
for i in range(n_eval // batch_size):
    start = i * batch_size
    end = start + batch_size
    batch_images = eval_data[start:end]
    res = sess.run([m.pred_classifier], {m.x: batch_images})
    eval_pred.append(res[0])
eval_pred = np.vstack(eval_pred)
train_acc = (train_pred > 0.5) == attr_train[:train_pred.shape[0]]
eval_acc = (eval_pred > 0.5) == attr_eval[:eval_pred.shape[0]]
print("Train Accuracy: %.4f" % (np.mean(train_acc) * 100))
print("Eval Accuracy: %.4f" % (np.mean(eval_acc) * 100))
y_true = attr_eval[:eval_pred.shape[0]]
y_pred = eval_pred >= 0.5
# The private module sklearn.metrics.classification was removed in
# scikit-learn 0.24; use the public sklearn.metrics function.
prec, recall, f1, support = sklearn.metrics.precision_recall_fscore_support(y_true, y_pred)
report = sklearn.metrics.classification_report(y_true, y_pred, digits=3, target_names=attribute_names)
print(report)
Conditional Generation
# Attribute accuracy of conditionally generated images. For each batch of
# eval labels: sample z from N(0, I), transform it with G conditioned on
# those labels, decode, and score the decoded images with the pixel-space
# classifier against the same labels.
# Prediction Accuracy
eval_pred = []
for i in range(n_eval // batch_size):
    start = i * batch_size
    end = start + batch_size
    labels = attr_eval[start:end]
    batch_z = np.random.randn(labels.shape[0], config["n_latent"])
    xsamp, z_prime = sess.run([m.x_mean, m.z],
                              {m.q_z_sample: batch_z,
                               m.amortize: True,
                               m.labels: labels})
    res = sess.run([m.pred_classifier], {m.x: xsamp})
    eval_pred.append(res[0])
eval_pred = np.vstack(eval_pred)
eval_acc = (eval_pred > 0.5) == attr_eval[:eval_pred.shape[0]]
print("Eval Accuracy: %.4f" % (np.mean(eval_acc) * 100))
y_true = attr_eval[:eval_pred.shape[0]]
y_pred = eval_pred >= 0.5
# The private module sklearn.metrics.classification was removed in
# scikit-learn 0.24; use the public sklearn.metrics function.
prec, recall, f1, support = sklearn.metrics.precision_recall_fscore_support(y_true, y_pred)
report = sklearn.metrics.classification_report(y_true, y_pred, digits=3, target_names=attribute_names)
print(report)
TRAINING MODELS
此代码仅用于演示。
在 Titan X GPU 上训练 VAE、D 和 G 可能需要大约一天。
要从头开始训练模型,您需要将完整的 CelebA 数据集(img_align_celeba/、list_attr_celeba.txt、list_eval_partition.txt)下载到基础目录(此处假设为 ~/Desktop/CelebA/)。
下面提供了预处理数据的步骤。
# Running Average: window sizes for smoothing the training / eval traces.
running_N = 100
running_N_eval = 10


def rmean(data):
    """Mean of the last `running_N` entries of `data`."""
    return np.mean(data[-running_N:])


def rmeane(data):
    """Mean of the last `running_N_eval` entries of `data`."""
    return np.mean(data[-running_N_eval:])


batch_size = 256
Prepare the Data (Crop and Pad)
basepath = os.path.expanduser('~/Desktop/CelebA/')
save_path = basepath
# Split ids from list_eval_partition.txt: 0 = train, 1 = validation, 2 = test.
partition = np.loadtxt(basepath + 'list_eval_partition.txt', usecols=(1,))
train_mask = (partition == 0)
eval_mask = (partition == 1)
test_mask = (partition == 2)
print("Train: %d, Validation: %d, Test: %d, Total: %d" % (
    train_mask.sum(), eval_mask.sum(), test_mask.sum(), partition.shape[0]))
# sep=r'\s+' is the documented equivalent of the deprecated
# delim_whitespace=True.
attributes = pd.read_table(basepath + 'list_attr_celeba.txt', skiprows=1,
                           sep=r'\s+', usecols=range(1, 41))
attribute_names = attributes.columns.values
attribute_values = attributes.values
attr_train = attribute_values[train_mask]
attr_eval = attribute_values[eval_mask]
attr_test = attribute_values[test_mask]
# Map -1 entries to 0 so the labels can serve as 0/1 targets.
attr_train[attr_train == -1] = 0
attr_eval[attr_eval == -1] = 0
attr_test[attr_test == -1] = 0
np.save(basepath + 'attr_train.npy', attr_train)
np.save(basepath + 'attr_eval.npy', attr_eval)
np.save(basepath + 'attr_test.npy', attr_test)
def pil_crop_downsample(x, width, out_width):
    """Center-crop PIL image `x` to a `width` square, then resize it.

    Args:
      x: PIL image.
      width: side length of the centered square crop, in pixels.
      out_width: side length of the resized output, in pixels.

    Returns:
      A new PIL image of size (out_width, out_width).
    """
    # `import PIL` alone does not guarantee the PIL.Image submodule is loaded.
    from PIL import Image

    # Integer offsets (floor division) keep the crop box integral on Python 3.
    left, top = ((side - width) // 2 for side in x.size)
    x = x.crop((left, top, left + width, top + width))
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    return x.resize((out_width, out_width), resample=Image.LANCZOS)
def load_and_adjust_file(filename, width, outwidth):
    """Load an image file, center-crop/resize it, and return a float32 array.

    Args:
      filename: path of the image file to open.
      width: side length of the centered square crop, in pixels.
      outwidth: side length of the resized output, in pixels.

    Returns:
      np.float32 array of the resized image with values divided by 255
      (i.e. in [0, 1] for 8-bit images).
    """
    # This is called once per image for the whole dataset, so close each file
    # promptly instead of leaking handles until garbage collection.
    with PIL.Image.open(filename) as raw:
        small = pil_crop_downsample(raw, width, outwidth)
    return np.array(small, np.float32) / 255.
# CELEBA images are (218 x 178) originally; each is center-cropped to
# crop_width x crop_width and downsampled to img_width x img_width.
from glob import glob

filenames = np.sort(glob(basepath + 'img_align_celeba/*.jpg'))
crop_width = 128
img_width = 64
postfix = '_crop_%d_res_%d.npy' % (crop_width, img_width)
n_files = len(filenames)
# The split masks built from list_eval_partition.txt are applied row-wise to
# all_data below, so fail fast (before the slow decode loop) if the image
# directory does not line up with the partition file.
if n_files != partition.shape[0]:
    raise ValueError('Found %d images in %simg_align_celeba/ but %d entries in '
                     'list_eval_partition.txt' % (n_files, basepath, partition.shape[0]))
all_data = np.zeros([n_files, img_width, img_width, 3], np.float32)
for i, fname in enumerate(filenames):
    all_data[i, :, :] = load_and_adjust_file(fname, crop_width, img_width)
    if i % 10000 == 0:
        print('%.2f percent done' % (float(i)/n_files * 100.0))
train_data = all_data[train_mask]
eval_data = all_data[eval_mask]
test_data = all_data[test_mask]
np.save(basepath + 'train' + postfix, train_data)
np.save(basepath + 'eval' + postfix, eval_data)
np.save(basepath + 'test' + postfix, test_data)
Train the VAE
sess.run(tf.variables_initializer(var_list=m.vae_vars))
# Train the VAE
# Batches are taken sequentially and wrap around modulo the number of images,
# so size the epoch from the arrays actually being sliced rather than relying
# on n_train / n_eval having been set by some other cell.
n_train = train_data.shape[0]
n_eval = eval_data.shape[0]
traces = {'i': [],
          'i_eval': [],
          'loss': [],
          'loss_eval': [],
          'recons': [],
          'recons_eval': [],
          'kl': [],
          'kl_eval': []}
n_iters = 200000
# Learning rate decays log-linearly from 3e-4 to 1e-6 over the run.
vae_lr_ = np.logspace(np.log10(3e-4), np.log10(1e-6), n_iters)
for i in range(n_iters):
    start = (i * batch_size) % n_train
    end = start + batch_size
    batch = train_data[start:end]
    res = sess.run([m.train_vae,
                    m.vae_loss,
                    m.mean_recons,
                    m.mean_KL],
                   {m.x: batch,
                    m.vae_lr: vae_lr_[i],
                    m.amortize: False,
                    m.labels: attr_train[start:end]})
    traces['loss'].append(res[1])
    traces['recons'].append(res[2])
    traces['kl'].append(res[3])
    traces['i'].append(i)
    # Every 10 steps: evaluate one eval batch and print running means.
    if i % 10 == 0:
        start = (i * batch_size) % n_eval
        end = start + batch_size
        batch = eval_data[start:end]
        res_eval = sess.run([m.vae_loss, m.mean_recons, m.mean_KL],
                            {m.x: batch, m.labels: attr_eval[start:end]})
        traces['loss_eval'].append(res_eval[0])
        traces['recons_eval'].append(res_eval[1])
        traces['kl_eval'].append(res_eval[2])
        traces['i_eval'].append(i)
        print('Step %5d \t TRAIN \t Loss: %0.3f, Recon: %0.3f, KL: %0.3f '
              '\t EVAL \t Loss: %0.3f, Recon: %0.3f, KL: %0.3f' % (i,
                                                                   rmean(traces['loss']),
                                                                   rmean(traces['recons']),
                                                                   rmean(traces['kl']),
                                                                   rmeane(traces['loss_eval']),
                                                                   rmeane(traces['recons_eval']),
                                                                   rmeane(traces['kl_eval'])))
# Three side-by-side panels: training curve (every step) overlaid with the
# eval curve (every 10th step) for total loss, reconstruction term and KL.
plt.figure(figsize=(18,6))
for panel, (train_key, eval_key, title) in enumerate(
        [('loss', 'loss_eval', 'Loss'),
         ('recons', 'recons_eval', 'Recons'),
         ('kl', 'kl_eval', 'KL')]):
    plt.subplot(131 + panel)
    plt.plot(traces['i'], traces[train_key])
    plt.plot(traces['i_eval'], traces[eval_key])
    plt.title(title)
Train D and G jointly
def _encode_in_batches(data):
    """Run the VAE encoder over every row of `data`, batch by batch.

    Returns the vertically stacked outputs of m.mu and m.sigma. The batch
    count is rounded up, so a final, smaller batch is encoded as well.
    """
    n_rows = data.shape[0]
    mus, sigmas = [], []
    for b in range(int(np.ceil(float(n_rows) / batch_size))):
        if b % 1000 == 0:
            print('%.1f Done' % (float(b) / n_rows * batch_size * 100))
        lo = b * batch_size
        mu_b, sigma_b = sess.run([m.mu, m.sigma], {m.x: data[lo:lo + batch_size]})
        mus.append(mu_b)
        sigmas.append(sigma_b)
    return np.vstack(mus), np.vstack(sigmas)


# Precompute means and vars for the training images.
train_mu, train_sigma = _encode_in_batches(train_data)
sigma_mean = train_sigma.mean(0, keepdims=True)
print(train_mu.shape, train_sigma.shape, train_data.shape)
# Precompute means and vars for the eval images.
eval_mu, eval_sigma = _encode_in_batches(eval_data)
sigma_mean_eval = eval_sigma.mean(0, keepdims=True)
print(eval_mu.shape, eval_sigma.shape, eval_data.shape)
# Sorted per-dimension mean sigma, train vs. eval.
plt.plot(np.sort(sigma_mean.flatten()))
plt.plot(np.sort(sigma_mean_eval.flatten()))
# With eval loss, looking for overfitting
# Fresh weights for the discriminator (first) and the generator (second).
for var_group_ in (m.d_vars, m.g_vars):
    sess.run(tf.variables_initializer(var_list=var_group_))

# Declare hyperparameters
results_D, results_G = [], []
# One empty list per logged quantity, in a fixed key order.
traces = {name: [] for name in ('i', 'i_pred', 'D', 'G', 'G_real', 'g_penalty',
                                'pred_train', 'pred_eval', 'pred_prior',
                                'pred_gen', 'z_dist_eval', 'attr_loss',
                                'attr_acc')}
n_iters = 200000
# Per-step learning-rate arrays; logspace(-4, -4, n) is a flat 1e-4 schedule.
d_lr_ = np.logspace(-4, -4, n_iters)
g_lr_ = np.logspace(-4, -4, n_iters)
g_penalty_weight_ = 0.1        # fed to m.g_penalty_weight
lambda_weight_ = 10            # fed to m.lambda_weight when training D
z_sigma_mean_ = sigma_mean     # fed to m.z_sigma_mean; also scales the eval z distance
percentage_prior_fake = 0.1    # chance that a step's fakes are prior draws instead of G outputs
N_between_update_G = 10        # D steps per G step
N_between_eval = 100           # steps between evaluation printouts
n_train, n_eval = train_mu.shape[0], eval_mu.shape[0]
# Training Loop
# Each step uses half a batch of real encodings plus matching fake samples.
# Floor division keeps every size and index an int: under Python 3, `/` gives
# floats, which numpy rejects as slice bounds and array shapes.
half_batch = batch_size // 2
for i in range(n_iters):
    start = (i * half_batch) % n_train
    end = start + half_batch
    # An independent random slice whose attributes serve as mismatched labels.
    fake_start = np.random.choice(np.arange(n_train - half_batch))
    fake_end = fake_start + half_batch
    real_img = train_data[start:end]
    n_batch = real_img.shape[0]
    # Skip the short batch at the end of an epoch, and the case where the
    # mismatched-label slice coincides with the real one.
    if n_batch == half_batch and start != fake_start:
        # Compare real vs. fake
        fake_z_prior = np.random.randn(half_batch, n_latent)
        real_attr = attr_train[start:end].astype(np.int32)
        fake_attr = attr_train[fake_start:fake_end].astype(np.int32)
        # Gaussian sample around the precomputed encodings: mu + sigma * eps.
        real_z = train_mu[start:end] + train_sigma[start:end] * np.random.randn(half_batch, n_latent)
        if np.random.rand(1) < percentage_prior_fake:
            # Use Prior for fake_samples
            all_z = np.vstack([real_z, fake_z_prior, real_z])
        else:
            # Use Generator to make fake_samples
            fake_z_gen = sess.run(m.z, {m.q_z_sample: fake_z_prior,
                                        m.amortize: True,
                                        m.labels: real_attr})
            all_z = np.vstack([real_z, fake_z_gen, real_z])
        # Third chunk pairs real z with another slice's attributes and is
        # labelled fake, presumably so D also scores attribute agreement.
        all_attr = np.vstack([real_attr, real_attr, fake_attr])
        # Train Discriminator: target 1 for the first chunk, 0 for the others.
        real_r = np.ones([half_batch, 1])
        fake_r = np.zeros([half_batch, 1])
        all_r = np.concatenate([real_r, fake_r, fake_r])
        res_d = sess.run([m.train_d, m.d_loss], {m.z: all_z,
                                                 m.r: all_r,
                                                 m.d_lr: d_lr_[i],
                                                 m.lambda_weight: lambda_weight_,
                                                 m.labels: all_attr})
        # Train Generator once every N_between_update_G steps.
        if i % N_between_update_G == 0:
            if g_penalty_weight_ > 0:
                # Train on real data
                res_g_real = sess.run([m.train_g, m.g_loss, m.g_penalty],
                                      {m.q_z_sample: real_z,
                                       m.amortize: True,
                                       m.g_penalty_weight: g_penalty_weight_,
                                       m.z_sigma_mean: z_sigma_mean_,
                                       m.g_lr: g_lr_[i],
                                       m.labels: real_attr})
                traces['G_real'].append(res_g_real[1])
            # Train on generated data
            res_g = sess.run([m.train_g, m.g_loss, m.g_penalty],
                             {m.q_z_sample: fake_z_prior,
                              m.amortize: True,
                              m.g_penalty_weight: g_penalty_weight_,
                              m.z_sigma_mean: z_sigma_mean_,
                              m.g_lr: g_lr_[i],
                              m.labels: real_attr})
            traces['i'].append(i)
            traces['D'].append(res_d[1])
            traces['G'].append(res_g[1])
            traces['g_penalty'].append(res_g[2])
        if i % N_between_eval == 0:
            eval_start = np.random.choice(np.arange(n_eval - half_batch))
            eval_end = eval_start + half_batch
            real_attr_eval = attr_eval[eval_start:eval_end].astype(np.int32)
            real_z_eval = eval_mu[eval_start:eval_end] + eval_sigma[eval_start:eval_end] * np.random.randn(half_batch, n_latent)
            # Condition G on the eval images' own attributes, matching
            # real_z_eval, so z_dist_eval measures how far G moves them.
            z_eval_gen = sess.run(m.z, {m.q_z_sample: real_z_eval,
                                        m.amortize: True,
                                        m.labels: real_attr_eval})
            # Regenerate fakes: the prior branch above may have skipped G.
            fake_z_gen = sess.run(m.z, {m.q_z_sample: fake_z_prior,
                                        m.amortize: True,
                                        m.labels: real_attr})
            pred_train_ = np.mean(sess.run([m.r_pred], {m.z: real_z, m.labels: real_attr}))
            pred_eval_ = np.mean(sess.run([m.r_pred], {m.z: real_z_eval, m.labels: real_attr_eval}))
            pred_prior_ = np.mean(sess.run([m.r_pred], {m.z: fake_z_prior, m.labels: real_attr}))
            pred_gen_ = np.mean(sess.run([m.r_pred], {m.z: fake_z_gen, m.labels: real_attr}))
            traces['i_pred'].append(i)
            traces['pred_train'].append(pred_train_)
            traces['pred_eval'].append(pred_eval_)
            traces['pred_prior'].append(pred_prior_)
            traces['pred_gen'].append(pred_gen_)
            traces['z_dist_eval'].append(np.mean(((z_eval_gen - real_z_eval)/z_sigma_mean_)**2))
            # rmeane averages the most recent eval-frequency entries
            # (the original called an undefined `rmeanp`).
            print('PRED Step %d, \t TRAIN: %.2e \t EVAL: %.2e \t PRIOR: %.2e \t GEN: %.2e ' % (i,
                  rmeane(traces['pred_train']),
                  rmeane(traces['pred_eval']),
                  rmeane(traces['pred_prior']),
                  rmeane(traces['pred_gen'])))
plt.figure(figsize=(18,12))
# Panel 1: mean discriminator output for each family of latent samples.
plt.subplot(4, 1, 1)
for pred_key, curve_label in (('pred_train', 'train'),
                              ('pred_eval', 'eval'),
                              ('pred_prior', 'prior'),
                              ('pred_gen', 'gen')):
    plt.plot(traces['i_pred'], traces[pred_key], label=curve_label)
plt.ylabel('Predictions')
plt.legend(loc='upper right')
# Panels 2 and 3: generator and discriminator losses on a linear scale.
for row, (loss_key, axis_label) in enumerate((('G', 'G Loss'), ('D', 'D Loss')), start=2):
    plt.subplot(4, 1, row)
    plt.plot(traces['i'], traces[loss_key])
    plt.ylabel(axis_label)
# Panel 4: sigma-normalized squared distance G moves eval encodings (log scale).
plt.subplot(4, 1, 4)
plt.semilogy(traces['i_pred'], traces['z_dist_eval'])
plt.ylabel('Weighted Z Distance Eval')
Train Classifier
sess.run(tf.variables_initializer(var_list=m.classifier_vars))
# Train the Classifier
results = []
results_eval = []
running_N = 100
# Exponential running averages of train / eval loss for the printout.
running_loss = 1
running_loss_eval = 1
classifier_lr_ = 3e-4
# Train
for i in range(40000):
    start = (i * batch_size) % n_train
    end = start + batch_size
    batch_images = train_data[start:end]
    batch_labels = attr_train[start:end]
    res = sess.run([m.train_classifier,
                    m.classifier_loss],
                   {m.x: batch_images,
                    m.labels: batch_labels.astype(np.int32),
                    m.classifier_lr: classifier_lr_})
    running_loss += (res[1] - running_loss) / running_N
    results.append([i] + res[1:])
    if i % 10 == 1:
        start = (i * batch_size) % n_eval
        end = start + batch_size
        eval_images = eval_data[start:end]
        eval_labels = attr_eval[start:end]
        res_eval = sess.run([m.classifier_loss],
                            {m.x: eval_images,
                             m.labels: eval_labels.astype(np.int32)})
        running_loss_eval += (res_eval[0] - running_loss_eval) / (running_N / 10)
        # Record eval loss only on steps where it was actually computed; the
        # original appended every step, reusing stale (or, at i=0, undefined
        # or foreign) res_eval values.
        results_eval.append([i] + res_eval)
        print('Step %d, \t TRAIN \t Loss: %0.3f \t EVAL \t Loss: %0.3f' % (i, running_loss, running_loss_eval))
# Each results row is [step, loss]; transposing gives (steps, losses).
plot_train = np.array(results).T
plot_eval = np.array(results_eval).T
plt.figure(figsize=(18,6))
plt.plot(*plot_train[:2])
plt.plot(*plot_eval[:2])
plt.ylim(1e-1, 1)
plt.title('Loss')
# Prediction Accuracy
# Score the pixel-space classifier on the first tenth of the training set and
# on the eval set, in whole batches only; label arrays are truncated to the
# number of rows actually scored.
train_pred = []
eval_pred = []
test_pred = []
# `//` keeps the batch counts integral (`/` gives floats under Python 3).
for i in range(n_train // 10 // batch_size):
    start = i * batch_size
    end = start + batch_size
    batch_images = train_data[start:end]
    res = sess.run([m.pred_classifier], {m.x: batch_images})
    train_pred.append(res[0])
train_pred = np.vstack(train_pred)
for i in range(n_eval // batch_size):
    start = i * batch_size
    end = start + batch_size
    batch_images = eval_data[start:end]
    res = sess.run([m.pred_classifier], {m.x: batch_images})
    eval_pred.append(res[0])
eval_pred = np.vstack(eval_pred)
# Binarize once with one threshold so the accuracy and the report agree.
train_hat = train_pred >= 0.5
eval_hat = eval_pred >= 0.5
train_acc = train_hat == attr_train[:train_pred.shape[0]]
eval_acc = eval_hat == attr_eval[:eval_pred.shape[0]]
print("Train Accuracy: %.4f" % (np.mean(train_acc) * 100))
print("Eval Accuracy: %.4f" % (np.mean(eval_acc) * 100))
y_true = attr_eval[:eval_pred.shape[0]]
y_pred = eval_hat
# Public API path; `sklearn.metrics.classification` was removed in 0.24.
prec, recall, f1, support = sklearn.metrics.precision_recall_fscore_support(y_true, y_pred)
report = sklearn.metrics.classification_report(y_true, y_pred, digits=2, target_names=attribute_names)
print(report)
Train Attribute Classifier (D_attr) in z space
# Train the Discriminator
sess.run(tf.variables_initializer(var_list=m.d_attr_vars))
traces = {"i": [],
          "i_eval": [],
          "D_loss": [],
          "D_loss_eval": [],
          "accuracy": [],
          }
# Running-mean windows for the printout. Train loss is logged every step and
# eval every 10 steps, so 800 train entries span the same steps as 80 eval ones.
running_N = 800
running_N_eval = 80
n_iters = 10000
d_attr_lr_ = np.logspace(-4, -4, n_iters)
# `//` keeps sizes and slice bounds integral (`/` gives floats under Python 3).
half_batch = batch_size // 2
for i in range(n_iters):
    start = (i * half_batch) % n_train
    end = min(start + half_batch, n_train)
    # Gaussian sample around the precomputed encodings: mu + sigma * eps.
    batch = train_mu[start:end] + train_sigma[start:end] * np.random.randn(end - start, n_latent)
    if batch.shape[0] == half_batch:
        labels = attr_train[start:end]
        # Train
        res = sess.run([m.train_d_attr, m.d_loss_attr],
                       {m.z: batch,
                        m.labels: labels,
                        m.d_attr_lr: d_attr_lr_[i]})
        traces['i'].append(i)
        traces['D_loss'].append(res[1])
    if i % 10 == 1:
        start = (i * half_batch) % n_eval
        end = min(start + half_batch, n_eval)
        batch = eval_mu[start:end] + eval_sigma[start:end] * np.random.randn(end - start, n_latent)
        if batch.shape[0] == half_batch:
            labels = attr_eval[start:end]
            res_eval = sess.run([m.d_loss_attr, m.pred_attr],
                                {m.z: batch,
                                 m.labels: labels})
            y_true = labels
            y_pred = (res_eval[1] > 0.5)
            accuracy = np.mean(y_true == y_pred)
            traces['i_eval'].append(i)
            traces['D_loss_eval'].append(res_eval[0])
            traces['accuracy'].append(accuracy)
    if i % 100 == 0:
        # Train loss uses its own (per-step) window, running_N.
        print('Step %d, \t TRAIN \t Loss: %0.3f \t EVAL \t Loss: %0.3f \t Accuracy: %0.3f' % (
            i,
            np.mean(traces['D_loss'][-running_N:]),
            np.mean(traces['D_loss_eval'][-running_N_eval:]),
            np.mean(traces['accuracy'][-running_N_eval:]),
        ))
plt.figure(figsize=(18,18))
# Row 1: train/eval loss on a log axis. Row 2: eval accuracy on a linear axis.
panel_specs_ = [
    (plt.semilogy, 'Loss', [('i', 'D_loss', 'train'), ('i_eval', 'D_loss_eval', 'eval')]),
    (plt.plot, 'Prediction Accuracy', [('i_eval', 'accuracy', "eval")]),
]
for row_, (draw_, axis_label_, curves_) in enumerate(panel_specs_, start=1):
    plt.subplot(3, 1, row_)
    for x_key_, y_key_, curve_label_ in curves_:
        draw_(traces[x_key_], traces[y_key_], label=curve_label_)
    plt.ylabel(axis_label_)
    plt.legend(loc='upper right')
# Prediction Accuracy
# Score the z-space attribute classifier (m.pred_attr) on the first tenth of
# the training set and on the eval set, in whole batches only; label arrays
# are truncated to the number of rows actually scored.
train_pred = []
eval_pred = []
test_pred = []
# `//` keeps the batch counts integral (`/` gives floats under Python 3).
for i in range(n_train // 10 // batch_size):
    start = i * batch_size
    end = start + batch_size
    batch_images = train_data[start:end]
    res = sess.run([m.pred_attr], {m.x: batch_images,
                                   m.labels: attr_train[start:end]})
    train_pred.append(res[0])
train_pred = np.vstack(train_pred)
for i in range(n_eval // batch_size):
    start = i * batch_size
    end = start + batch_size
    batch_images = eval_data[start:end]
    res = sess.run([m.pred_attr], {m.x: batch_images,
                                   m.labels: attr_eval[start:end]})
    eval_pred.append(res[0])
eval_pred = np.vstack(eval_pred)
# Binarize once with one threshold so the accuracy and the report agree.
train_hat = train_pred >= 0.5
eval_hat = eval_pred >= 0.5
train_acc = train_hat == attr_train[:train_pred.shape[0]]
eval_acc = eval_hat == attr_eval[:eval_pred.shape[0]]
print("Train Accuracy: %.4f" % (np.mean(train_acc) * 100))
print("Eval Accuracy: %.4f" % (np.mean(eval_acc) * 100))
y_true = attr_eval[:eval_pred.shape[0]]
y_pred = eval_hat
# Public API path; `sklearn.metrics.classification` was removed in 0.24.
prec, recall, f1, support = sklearn.metrics.precision_recall_fscore_support(y_true, y_pred)
report = sklearn.metrics.classification_report(y_true, y_pred, digits=2, target_names=attribute_names)
print(report)