Latent Constraints: Conditional Generation from Unconditional Generative Models

Latent Constraints: Conditional Generation from Unconditional Generative Models

Jesse Engel, Matthew Hoffman, Adam Roberts arXiv link


Abstract:
深度生成神经网络在复杂数据分布的条件和无条件建模方面都是有效的。
条件生成实现了交互式控制,但创建新控件通常需要昂贵的再训练。
在本文中,我们开发了一种条件生成方法,无需重新训练模型。
通过事后学习潜在约束(latent constraints),即能够识别潜在空间中哪些区域会产生具有所需属性的输出的值函数,我们可以使用基于梯度的优化或摊销的演员(actor)函数,从这些区域中有条件地进行采样。
将属性约束与通用的“真实性”(realism)约束相结合(后者强制生成结果与数据分布相似),我们可以从无条件的变分自编码器(VAE)生成逼真的条件图像。
此外,使用基于梯度的优化,我们演示了保持身份的变换:
在潜在空间中进行最小的调整,以修改图像的属性。
最后,在离散的音符序列上,我们展示了零样本(zero-shot)条件生成,即在没有标注数据或可微分奖励函数的情况下学习潜在约束。


在这里插入图片描述


此笔记本包含用于运行论文相关实验的代码。 首先,我们加载预先训练的检查点:

  • 在CelebA上训练的VAE模型,采用逐像素的高斯数据似然 $\mathcal{N}(\mu(z), \sigma_x=0.1)$ 和 $\mathcal{N}(\mu(z), \sigma_x=1)$。
  • 我们还提供了VAE模型的训练和评估集的嵌入。
  • 来自条件GAN的生成器($G$)和鉴别器($D$),经过训练后可以把潜在空间中的样本移动到满足真实性约束($r$)和属性约束($r_{attr}$)的新位置。
  • 我们训练了两个版本:一个没有距离惩罚,另一个的距离惩罚权重为1e-1。
  • 分别在z空间($D_{attr}$)和像素空间($Classifier$)中单独训练的属性分类器。
    然后我们继续:
    *证明当 $\sigma_x$ 降低时VAE的重建会变得更清晰,但代价是样本质量下降,而这可以通过潜在约束来弥补。
    *使用CGAN($D$ 和 $G$)进行条件生成,分别在有和没有距离惩罚的情况下。
    *使用 $D_{attr}$ 在z空间中执行保持身份的变换。
    *评估按条件属性生成的图像的属性准确性。

在笔记本(notebook)的末尾还提供了训练循环,仅用于演示目的。


这个colab笔记本是自包含的,应该可以直接在谷歌云上运行。 代码和检查点也可以单独下载并在本地运行;如果您想训练自己的模型,建议采用这种方式。 可以在 download.magenta.tensorflow.org/models/latent_constraints/latent_constraints.tar 上找到预训练模型检查点。

# This notebook requires DeepMind's sonnet library, which itself
# requires the nightly build of TensorFlow. The command below 
# installs both.
!pip install -q -U dm-sonnet tf-nightly tfp-nightly

import os
import PIL

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import sklearn.metrics
import sonnet as snt
import tensorflow as tf
import tensorflow_probability as tfp

ds = tfp.distributions

%matplotlib inline
# Copy checkpoints from google cloud
# Copying 3GB, takes a minute
!gsutil -q -m cp -R gs://download.magenta.tensorflow.org/models/latent_constraints /content/

Load the Data

basepath = '/content/latent_constraints/'

# CelebA embeddings produced by the VAE trained with x_sigma = 0.1.
train_mu, train_sigma, eval_mu, eval_sigma = (
    np.load(basepath + fname)
    for fname in ('train_mu.npy',
                  'train_sigma.npy',
                  'eval_mu.npy',
                  'eval_sigma.npy'))

# Evaluation embeddings produced by the VAE trained with x_sigma = 1.0.
eval_mu_xsigma1, eval_sigma_xsigma1 = (
    np.load(basepath + fname)
    for fname in ('eval_mu_xsigma1.npy', 'eval_sigma_xsigma1.npy'))

np.random.seed(10003)
n_train = train_mu.shape[0]
n_eval = eval_mu.shape[0]

# Only 10 salient attributes are used; `attr_mask` holds their column
# indices in the stored attribute arrays, in the same order as
# `attribute_names`.
attr_mask = [4, 8, 9, 11, 15, 20, 24, 31, 35, 39]
attribute_names = [
    'Bald',
    'Black_Hair',
    'Blond_Hair',
    'Brown_Hair',
    'Eyeglasses',
    'Male',
    'No_Beard',
    'Smiling',
    'Wearing_Hat',
    'Young',
]

# Load each split and keep only the selected attribute columns.
attr_train, attr_eval, attr_test = (
    np.load(basepath + fname)[:, attr_mask]
    for fname in ('attr_train.npy', 'attr_eval.npy', 'attr_test.npy'))

Define the Graph

所有带变量的函数都包含在sonnet模块中。

把它们组合在一起的 `Model` 也是一个sonnet模块。

张量(端点)可以作为 `Model` 的属性访问。

class Encoder(snt.AbstractModule):
  """Convolutional encoder of the VAE.

  Returns the two halves of a final linear layer as `(mu, sigma)`; `sigma`
  is passed through a softplus so that it is positive.
  """

  def __init__(self,
               n_latent,
               layers=((256, 5, 2),
                       (512, 5, 2),
                       (1024, 3, 2),
                       (2048, 3, 2)),
               name='encoder'):
    """Store the configuration.

    Args:
      n_latent: Number of latent dimensions.
      layers: One 3-tuple per convolution, passed positionally to
        `snt.Conv2D` (presumably output channels, kernel size, stride).
      name: Sonnet module name.
    """
    super(Encoder, self).__init__(name=name)
    self.layers = layers
    self.n_latent = n_latent

  def _build(self, x):
    """Encode the rank-4 tensor `x` and return `(mu, sigma)`."""
    features = x
    for spec in self.layers:
      conv = snt.Conv2D(spec[0], spec[1], spec[2])
      features = tf.nn.relu(conv(features))

    # Flatten everything except the leading (batch) dimension.
    static_shape = features.get_shape().as_list()
    flat_size = static_shape[1] * static_shape[2] * static_shape[3]
    flat = tf.reshape(features, [-1, flat_size])

    # One linear layer emits both statistics; split it in the middle.
    stats = snt.Linear(2 * self.n_latent)(flat)
    mu = stats[:, :self.n_latent]
    sigma = tf.nn.softplus(stats[:, self.n_latent:])
    return mu, sigma


class Decoder(snt.AbstractModule):
  """Convolutional decoder of the VAE; returns un-squashed logits."""

  def __init__(self,
               layers=((2048, 4, 4),
                       (1024, 3, 2),
                       (512, 3, 2),
                       (256, 5, 2),
                       (3, 5, 2)),
               name='decoder'):
    """Store the configuration.

    Args:
      layers: Sequence of 3-tuples. The first is used as
        (depth, height, width) of the initial feature map; the remaining
        ones are passed to `snt.Conv2DTranspose` as
        (output_channels, kernel_shape, stride) with `output_shape=None`.
      name: Sonnet module name.
    """
    super(Decoder, self).__init__(name=name)
    self.layers = layers

  def _build(self, x):
    """Decode the latent batch `x` into image logits."""
    final_index = len(self.layers) - 1
    for index, spec in enumerate(self.layers):
      if index == 0:
        # Linear projection reshaped into the initial feature map.
        depth, height, width = spec[0], spec[1], spec[2]
        projected = snt.Linear(height * width * depth)(x)
        h = tf.reshape(projected, [-1, height, width, depth])
        continue
      h = snt.Conv2DTranspose(spec[0], None, spec[1], spec[2])(h)
      # Every transposed convolution except the last is followed by a ReLU.
      if index != final_index:
        h = tf.nn.relu(h)
    return h


class G(snt.AbstractModule):
  """CGAN generator mapping a latent vector to a new latent vector.

  The output is a per-dimension gated mix of the input `z` and a proposed
  replacement `dz`.
  """

  def __init__(self,
               n_latent,
               layers=(2048,)*4,
               name='generator'):
    """Store the configuration.

    Args:
      n_latent: Number of latent dimensions.
      layers: Widths of the hidden fully connected layers; the first width
        is also used for the label embedding.
      name: Sonnet module name.
    """
    super(G, self).__init__(name=name)
    self.n_latent = n_latent
    self.layers = layers

  def _build(self, z_and_labels):
    """Return `z_prime` for the pair `(z, labels)`."""
    z, labels = z_and_labels
    label_embedding = snt.Linear(self.layers[0])(tf.cast(labels, tf.float32))
    hidden = tf.concat([z, label_embedding], axis=-1)
    for width in self.layers:
      hidden = tf.nn.relu(snt.Linear(width)(hidden))

    # First half of the output proposes new values, second half gates them.
    out = snt.Linear(2 * self.n_latent)(hidden)
    dz = out[:, :self.n_latent]
    gates = tf.nn.sigmoid(out[:, self.n_latent:])
    return (1-gates) * z + gates * dz


class D(snt.AbstractModule):
  """CGAN discriminator scoring `(z, labels)` pairs."""

  def __init__(self,
               output_size=1,
               layers=(2048,)*4,
               name='D'):
    """Store the configuration.

    Args:
      output_size: Number of output logits.
      layers: Widths of the hidden fully connected layers; the first width
        is also used for the label embedding.
      name: Sonnet module name.
    """
    super(D, self).__init__(name=name)
    self.output_size = output_size
    self.layers = layers

  def _build(self, z_and_labels):
    """Return logits for the pair `(z, labels)`."""
    z, labels = z_and_labels
    label_embedding = snt.Linear(self.layers[0])(tf.cast(labels, tf.float32))
    hidden = tf.concat([z, label_embedding], axis=-1)
    for width in self.layers:
      hidden = tf.nn.relu(snt.Linear(width)(hidden))
    return snt.Linear(self.output_size)(hidden)


class DAttr(snt.AbstractModule):
  """Attribute classifier operating on latent vectors.

  NOTE(review): the default `name` is 'D', the same default as the `D`
  module; callers in this file always pass an explicit name.
  """

  def __init__(self,
               output_size=1,
               layers=(2048,)*4,
               name='D'):
    """Store the configuration.

    Args:
      output_size: Number of output logits.
      layers: Widths of the hidden fully connected layers.
      name: Sonnet module name.
    """
    super(DAttr, self).__init__(name=name)
    self.output_size = output_size
    self.layers = layers

  def _build(self, x):
    """Return attribute logits for the latent batch `x`."""
    hidden = x
    for width in self.layers:
      hidden = tf.nn.relu(snt.Linear(width)(hidden))
    return snt.Linear(self.output_size)(hidden)
  
  
class Classifier(snt.AbstractModule):
  """Convolutional attribute classifier operating on pixels.

  NOTE(review): the default `name` is 'encoder', the same default as the
  `Encoder` module; callers in this file always pass an explicit name.
  """

  def __init__(self,
               output_size,
               layers=((256, 5, 2),
                       (256, 3, 1),
                       (512, 5, 2),
                       (512, 3, 1),
                       (1024, 3, 2),
                       (2048, 3, 2)),
               name='encoder'):
    """Store the configuration.

    Args:
      output_size: Number of output logits.
      layers: One 3-tuple per convolution, passed positionally to
        `snt.Conv2D` (presumably output channels, kernel size, stride).
      name: Sonnet module name.
    """
    super(Classifier, self).__init__(name=name)
    self.layers = layers
    self.output_size = output_size

  def _build(self, x):
    """Return attribute logits for the rank-4 tensor `x`."""
    features = x
    for spec in self.layers:
      conv = snt.Conv2D(spec[0], spec[1], spec[2])
      features = tf.nn.relu(conv(features))

    # Flatten everything except the leading (batch) dimension.
    static_shape = features.get_shape().as_list()
    flat_size = static_shape[1] * static_shape[2] * static_shape[3]
    flat = tf.reshape(features, [-1, flat_size])
    return snt.Linear(self.output_size)(flat)
class Model(snt.AbstractModule):
  """All the components glued together.

  Building the module (calling the instance once) creates the placeholders,
  sub-modules, losses, training ops and savers, and then exposes every local
  variable of `_build` as an attribute of the instance (for example
  `m.x_mean`, `m.train_vae`, `m.vae_saver`).
  """

  def __init__(self, config, name=''):
    """Store the configuration.

    Args:
      config: Dict read in `_build`; it must provide 'batch_size',
        'n_latent', 'img_width', 'x_sigma', 'beta' and 'lambda_weight'.
      name: Sonnet module name.
    """
    super(Model, self).__init__(name=name)
    self.config = config

  def _build(self, unused_input=None):
    """Build the whole graph and publish its endpoints as attributes.

    Args:
      unused_input: Ignored; present so the module can be called without
        arguments.
    """
    config = self.config

    # Constants
    batch_size = config['batch_size']
    n_latent = config['n_latent']
    img_width = config['img_width']
    half_batch = int(batch_size / 2)
    n_labels = 10

    #---------------------------------------------------------------------------
    ### Placeholders
    #---------------------------------------------------------------------------
    x = tf.placeholder(tf.float32, 
                       shape=(None, img_width, img_width, 3), name='x')
    # Attributes
    labels = tf.placeholder(tf.int32, shape=(None, n_labels), name='labels')
    # Real / fake label reward
    r = tf.placeholder(tf.float32, shape=(None, 1), name='D_label')
    # Transform through optimization: z0 is the starting point used by the
    # distance penalty, z_prime is the variable that gets optimized.
    z0 = tf.placeholder(tf.float32, shape=(None, n_latent), name='z0')
    z_prime = tf.get_variable('z_prime', 
                              shape=(half_batch, n_latent), dtype=tf.float32)

    #---------------------------------------------------------------------------
    ### Modules with parameters
    #---------------------------------------------------------------------------
    encoder = Encoder(n_latent=n_latent, name='encoder')
    decoder = Decoder(name='decoder')
    g = G(n_latent=n_latent, name='generator')
    d = D(output_size=1, name='d_z')
    d_attr = DAttr(output_size=n_labels, name='d_attr')
    classifier = Classifier(output_size=n_labels, name='classifier')


    #---------------------------------------------------------------------------
    ### VAE
    #---------------------------------------------------------------------------
    # Encode
    mu, sigma = encoder(x)
    q_z = ds.Normal(loc=mu, scale=sigma)

    # Optimize / Amortize or feedthrough
    q_z_sample = q_z.sample()

    # `transform` and `amortize` default to False. Other code in this file
    # overrides them through feed_dict to route z through the optimized
    # variable or through the generator.
    transform = tf.constant(False)
    z = tf.cond(transform, lambda: z_prime, lambda: q_z_sample)

    amortize = tf.constant(False)
    z = tf.cond(amortize, lambda: g((z, labels)), lambda: z)

    # Decode
    logits = decoder(z)
    x_sigma = tf.constant(config['x_sigma'])
    p_x = ds.Normal(loc=tf.nn.sigmoid(logits), scale=x_sigma)
    x_mean = p_x.mean()

    # Reconstruction term: per-example log-likelihood summed over all pixels
    # and channels (higher is better, hence the minus sign in the loss).
    recons = tf.reduce_sum(p_x.log_prob(x), axis=[1, 2, 3])

    mean_recons = tf.reduce_mean(recons)

    # Prior
    p_z = ds.Normal(loc=0., scale=1.)
    prior_sample = p_z.sample(sample_shape=[batch_size, n_latent])

    # KL Loss
    KL_qp = ds.kl_divergence(q_z, p_z)
    KL = tf.reduce_sum(KL_qp, axis=-1)
    mean_KL = tf.reduce_mean(KL)

    beta = tf.constant(config['beta'])

    # VAE Loss
    vae_loss = -mean_recons + mean_KL * beta

    #---------------------------------------------------------------------------
    ### Discriminator Constraint in Img and Z space and Digit space (implicit)
    #---------------------------------------------------------------------------
    d_logits = d([z, labels])

    r_pred = tf.nn.sigmoid(d_logits)  # r = [0 prior, 1 data]
    d_loss = tf.losses.sigmoid_cross_entropy(r, d_logits)

    # Mean over examples
    d_loss = tf.reduce_mean(d_loss)

    # Gradient Penalty: evaluate D on random interpolations between the first
    # and second half of the batch and push the gradient norm toward 1.
    real_data = z[:half_batch]
    fake_data = z[half_batch:batch_size]
    alpha = tf.random_uniform(shape=[half_batch, n_latent], minval=0., maxval=1.)
    differences = fake_data - real_data
    interpolates = real_data + (alpha * differences)
    interp_pred = d([interpolates, labels[:half_batch]])
    gradients = tf.gradients(interp_pred, [interpolates])[0]
    # The small epsilon keeps the square root differentiable at zero.
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]) + 1e-10)
    gradient_penalty = tf.reduce_mean((slopes - 1.)**2)

    # Add penalty
    lambda_weight = tf.constant(config['lambda_weight'])
    d_loss_training = d_loss + lambda_weight * gradient_penalty

    
    #---------------------------------------------------------------------------
    ### Discriminator Attribute classification (implicit constraint)
    #---------------------------------------------------------------------------
    # Z-Space
    attr_weights = tf.constant(np.ones([1, n_labels]).astype(np.float32))
    logits_attr = d_attr(z)
    pred_attr = tf.nn.sigmoid(logits_attr)
    d_loss_attr =  tf.losses.sigmoid_cross_entropy(labels, 
                                                   logits=logits_attr, 
                                                   weights=attr_weights)

    
    #---------------------------------------------------------------------------
    ### OPTIMIZATION TRANSFORMATION (SGD)
    #---------------------------------------------------------------------------
    # The three weights below are constants so that they can be overridden
    # through feed_dict.
    # Realism Constraint
    transform_r_weight = tf.constant(1.0)
    loss_transform = transform_r_weight * tf.reduce_mean(d_loss)

    # Attribute Constraint
    transform_attr_weight = tf.constant(0.0)
    loss_transform += transform_attr_weight * d_loss_attr
    
    # Distance Penalty
    transform_penalty_weight = tf.constant(0.0)
    z_sigma_mean = tf.constant(np.ones([1, n_latent]).astype(np.float32))
    transform_penalty = tf.log(1 + (z_prime - z0)**2)
    transform_penalty = transform_penalty * z_sigma_mean**-2
    loss_transform += tf.reduce_mean(transform_penalty_weight * transform_penalty)


    #---------------------------------------------------------------------------
    ### AMORTIZED TRANSFORMATION (Generator)
    #---------------------------------------------------------------------------
    # Realism and Attribute Constraint: the generator is rewarded when D
    # predicts "data" (r_pred close to 1); clipping avoids log(0).
    g_loss = -tf.log(tf.clip_by_value(r_pred, 1e-15, 1 - 1e-15))
    g_loss = tf.reduce_mean(g_loss)

    # Distance Penalty
    g_penalty_weight = tf.constant(0.0)
    g_penalty = tf.log(1 + (z - q_z_sample)**2)
    g_penalty = g_penalty * z_sigma_mean**-2
    g_penalty = tf.reduce_mean(g_penalty) 
    g_loss += g_penalty_weight * g_penalty

    #---------------------------------------------------------------------------
    ### Classify Attributes from pixels
    #---------------------------------------------------------------------------
    logits_classifier = classifier(x)
    pred_classifier = tf.nn.sigmoid(logits_classifier)
    classifier_loss =  tf.losses.sigmoid_cross_entropy(labels, 
                                                       logits=logits_classifier)

    
    #---------------------------------------------------------------------------
    ### Training
    #---------------------------------------------------------------------------
    # Learning rates (constants, so they can be overridden through feed_dict)
    d_lr = tf.constant(3e-4)
    d_attr_lr = tf.constant(3e-4)
    vae_lr = tf.constant(3e-4)
    g_lr = tf.constant(3e-4)
    classifier_lr = tf.constant(3e-4)
    transform_lr = tf.constant(3e-4)

    # Training Ops: each optimizer only updates its own module's variables.
    vae_vars = list(encoder.get_variables())
    vae_vars.extend(decoder.get_variables())
    train_vae = tf.train.AdamOptimizer(vae_lr).minimize(vae_loss, var_list=vae_vars)

    d_vars = d.get_variables()
    train_d = tf.train.AdamOptimizer(d_lr, beta1=0, beta2=0.9).minimize(
        d_loss_training, var_list=d_vars)

    classifier_vars = classifier.get_variables()
    train_classifier = tf.train.AdamOptimizer(classifier_lr).minimize(
        classifier_loss, var_list=classifier_vars)

    g_vars = g.get_variables()
    train_g = tf.train.AdamOptimizer(g_lr, beta1=0, beta2=0.9).minimize(
        g_loss, var_list=g_vars)

    d_attr_vars = d_attr.get_variables()
    train_d_attr = tf.train.AdamOptimizer(d_attr_lr).minimize(
        d_loss_attr, var_list=d_attr_vars)

    # The transformation optimizer only moves z_prime.
    train_transform = tf.train.AdamOptimizer(transform_lr).minimize(
            loss_transform, var_list=[z_prime])
    
    # Savers
    vae_saver = tf.train.Saver(vae_vars, max_to_keep=100)
    g_saver = tf.train.Saver(g_vars, max_to_keep=1000)
    d_saver = tf.train.Saver(d_vars, max_to_keep=1000)
    d_attr_saver = tf.train.Saver(d_attr_vars, max_to_keep=1000)
    classifier_saver = tf.train.Saver(classifier_vars, max_to_keep=1000)

    # Add all endpoints as object attributes. `items()` works on both
    # Python 2 and 3 (`iteritems()` is Python 2 only). The list() snapshot
    # keeps the iteration safe, because the loop variables are themselves
    # locals of this function.
    for k, v in list(locals().items()):
      self.__dict__[k] = v

Load all models

config = {
    'n_latent': 1024,
    'img_width': 64,
    'crop_width': 64,
    # Optimization parameters
    'batch_size': 128,
    'beta': 1.0,
    'x_sigma': 0.1,
    'lambda_weight': 10.0,
    'penalty_weight': 0.0,
}
tf.reset_default_graph()
sess = tf.Session()

# Declare the model, build its graph, then initialize every variable.
m = Model(config)
_ = m()
sess.run(tf.global_variables_initializer())

# Restore each component from its checkpoint, in this order:
# VAE, D, G, D_attr, Classifier.
for saver, filename in (
    (m.vae_saver, 'vae_best_celeba_0_crop128_beta1.ckpt'),
    (m.d_saver, 'D_d_2_conditional_penalty1e-1_399000.ckpt'),
    (m.g_saver, 'G_d_2_conditional_penalty1e-1_399000.ckpt'),
    (m.d_attr_saver, 'D_attr_best_d_attr_0.ckpt'),
    (m.classifier_saver, 'classifier_best_classifier_0.ckpt'),
):
  ckpt = os.path.join(basepath, filename)
  saver.restore(sess, ckpt)

GENERATE PLOTS

def im(x):
  """Show `x` as an image with values limited to [0, 1] and no axis ticks."""
  capped = np.minimum(1, x)
  limited = np.maximum(0, capped)
  plt.imshow(limited, interpolation='none')
  for set_ticks in (plt.xticks, plt.yticks):
    set_ticks([])
  
def batch_image(b, max_images=64, rows=None, cols=None):
  """Turn a batch of images into a single image mosaic.

  Args:
    b: Array indexed as (image, height, width, channel).
    max_images: Maximum number of images taken from the front of `b`.
    rows: Number of tiles placed side by side (along the width axis). When
      None, a square grid large enough for the images is used and `cols`
      is ignored.
    cols: Number of tiles stacked along the height axis inside each strip.
      When None while `rows` is given, the smallest value that fits all
      images is used.

  Returns:
    Array of shape (cols * height, rows * width, channels). Unused grid
    slots are filled with zeros.
  """
  mb = min(b.shape[0], max_images)
  if rows is None:
    rows = int(np.ceil(np.sqrt(mb)))
    cols = rows
  elif cols is None:
    # Previously this case crashed on `rows * cols` with cols=None.
    cols = int(np.ceil(float(mb) / rows))
  diff = rows * cols - mb
  # Pad with blank images so that the grid is completely filled.
  b = np.vstack([b[:mb], np.zeros([diff, b.shape[1], b.shape[2], b.shape[3]])])
  # Each entry of `tmp` is one strip of `cols` images stacked vertically.
  tmp = b.reshape(-1, cols * b.shape[1], b.shape[2], b.shape[3])
  # np.hstack needs a real sequence; generators are deprecated since
  # NumPy 1.16 and rejected by newer releases.
  img = np.hstack([tmp[i] for i in range(rows)])
  return img
# A list of attributes from which to condition generation.
# Each list element corresponds to a different fully-specified condition:
# one (value, attribute name) pair for each of the 10 attributes.
_COND_ATTR_NAMES = (
    'Bald',
    'Black_Hair',
    'Blond_Hair',
    'Brown_Hair',
    'Eyeglasses',
    'Male',
    'No_Beard',
    'Smiling',
    'Wearing_Hat',
    'Young',
)

# One row per condition; columns follow the order of _COND_ATTR_NAMES.
_COND_ATTR_VALUES = (
    (0, 0, 1, 0, 0, 0, 1, 1, 0, 1),
    (0, 0, 0, 1, 0, 0, 1, 1, 0, 1),
    (0, 1, 0, 0, 0, 0, 1, 1, 0, 1),
    (0, 1, 0, 0, 0, 1, 1, 1, 0, 1),
    (0, 1, 0, 0, 0, 1, 0, 1, 0, 1),
    (0, 1, 0, 0, 1, 1, 0, 1, 0, 1),
    (1, 0, 0, 0, 0, 1, 0, 1, 0, 1),
    (1, 0, 0, 0, 0, 1, 0, 0, 0, 0),
)

cond_attr_list = [
    list(zip(values, _COND_ATTR_NAMES)) for values in _COND_ATTR_VALUES
]
    
  
# For every condition, build an int32 array with one identical row of
# attribute values per batch element.
cond_attrs = [
    np.repeat(
        np.array([[pair[0] for pair in attrs]]).astype(np.int32),
        m.batch_size,
        axis=0)
    for attrs in cond_attr_list
]

VAE Reconstructions

# Load D (checkpoint trained without a distance penalty)
ckpt = os.path.join(basepath, 'D_d_2_conditional_399000.ckpt')
m.d_saver.restore(sess, ckpt)
# Load G (checkpoint trained without a distance penalty)
ckpt = os.path.join(basepath, 'G_d_2_conditional_399000.ckpt')
m.g_saver.restore(sess, ckpt)
# Nine latent vectors drawn from a standard normal; the same draws are
# decoded by both VAE checkpoints below so the samples are comparable.
z_prior = np.random.randn(9, m.n_latent)
labels = attr_eval[:9]

# --- VAE trained with x_sigma = 1 (the "blurry" results) ---
ckpt = os.path.join(basepath, 'vae_best_celeba_0_crop128_beta1_xsigma1.ckpt')
m.vae_saver.restore(sess, ckpt)

# Sample z from the stored posterior statistics of the first nine
# evaluation examples.
z_eval_blurry = (eval_mu_xsigma1[:9] + 
                 eval_sigma_xsigma1[:9] * np.random.randn(9, m.n_latent))
# Feeding m.z directly bypasses the encoder; only the decoder is evaluated.
b_eval_recon_blurry = sess.run(m.x_mean, {m.z:z_eval_blurry, m.labels:np.zeros([m.batch_size, m.n_labels])})
b_samples_blurry = sess.run(m.x_mean, {m.z:z_prior})


# --- VAE trained with x_sigma = 0.1 (the "sharp" results) ---
ckpt = os.path.join(basepath, 'vae_best_celeba_0_crop128_beta1.ckpt')
m.vae_saver.restore(sess, ckpt)

z_eval_sharp = (eval_mu[:9] + 
                eval_sigma[:9] * np.random.randn(9, m.n_latent))
b_eval_recon_sharp = sess.run(m.x_mean, {m.z:z_eval_sharp, m.labels:np.zeros([m.batch_size, m.n_labels])})
b_samples_sharp = sess.run(m.x_mean, {m.z:z_prior})
# Refinement: feed the prior draws in place of the posterior sample and
# switch on `amortize`, so they pass through the generator G (conditioned
# on `labels`) before being decoded.
b_samples_refined = sess.run(m.x_mean, {m.q_z_sample:z_prior, 
                                        m.amortize:True, 
                                        m.labels:labels})
# Visualize Reconstructions
# Visualize Reconstructions: a grid of `tot` rows by `row` columns.
# The first grid row is left empty; each later row shows one result type.
tot = 6
row = 5
plt.figure(figsize=[10, 10])
panels = (
    (b_eval_recon_blurry, 'Recon'),
    (b_samples_blurry, 'Sample'),
    (b_eval_recon_sharp, 'Recon'),
    (b_samples_sharp, 'Sample'),
    (b_samples_refined, 'Refinement'),
)
for i in range(row):
  for band, (images, title) in enumerate(panels, 1):
    plt.subplot(tot, row, 1 + i + 5 * band)
    im(images[i])
    plt.title(title)

在这里插入图片描述

Conditional Generation

# Load D and G trained without a distance penalty.
ckpt = os.path.join(basepath, 'D_d_2_conditional_399000.ckpt')
m.d_saver.restore(sess, ckpt)
ckpt = os.path.join(basepath, 'G_d_2_conditional_399000.ckpt')
m.g_saver.restore(sess, ckpt)

# Compute the Conditional Samples.
# The first entry is the unconditioned baseline: prior samples decoded as-is.
z_original = sess.run(m.prior_sample)
z_new = [z_original]
b_new = [sess.run(m.x_mean, {m.z:z_original})]

# For each condition, push the same prior samples through G (amortize=True)
# and decode the transformed latents.
for cond_attr in cond_attrs:
  z_new.append(sess.run(m.z, {m.q_z_sample: z_original, m.amortize:True, m.labels:cond_attr}))
  b_new.append(sess.run(m.x_mean, {m.z:z_new[-1]}))

# Plot them: one horizontal strip of `tot` images per entry of b_new.
# (The earlier version also opened an extra, empty figure and built an
# unused rearranged copy of b_new; both were removed.)
plt.figure(figsize=(18, 12))
n_b = len(b_new)
tot = 16
for i, b in enumerate(b_new):
  plt.subplot(n_b, 1, i + 1)
  im(batch_image(b, max_images=tot, rows=tot, cols=1))

在这里插入图片描述

Z-Penalty = 0.1

# Load D and G trained with a distance penalty of 1e-1.
ckpt = os.path.join(basepath, 'D_d_2_conditional_penalty1e-1_399000.ckpt')
m.d_saver.restore(sess, ckpt)
ckpt = os.path.join(basepath, 'G_d_2_conditional_penalty1e-1_399000.ckpt')
m.g_saver.restore(sess, ckpt)

# Compute the Conditional Samples.
# The first entry is the unconditioned baseline: prior samples decoded as-is.
z_original = sess.run(m.prior_sample)
z_new = [z_original]
b_new = [sess.run(m.x_mean, {m.z:z_original})]

# For each condition, push the same prior samples through G (amortize=True)
# and decode the transformed latents.
for cond_attr in cond_attrs:
  z_new.append(sess.run(m.z, {m.q_z_sample: z_original, m.amortize:True, m.labels:cond_attr}))
  b_new.append(sess.run(m.x_mean, {m.z:z_new[-1]}))

# Plot them: one horizontal strip of `tot` images per entry of b_new.
# (The earlier version also opened an extra, empty figure and built an
# unused rearranged copy of b_new; both were removed.)
plt.figure(figsize=(18, 12))
n_b = len(b_new)
tot = 16
for i, b in enumerate(b_new):
  plt.subplot(n_b, 1, i + 1)
  im(batch_image(b, max_images=tot, rows=tot, cols=1))

在这里插入图片描述

Identity Preserving Transformations

# Restore D, then G, from the checkpoints trained with a 1e-1 penalty.
for saver, filename in (
    (m.d_saver, 'D_d_2_conditional_penalty1e-1_399000.ckpt'),
    (m.g_saver, 'G_d_2_conditional_penalty1e-1_399000.ckpt'),
):
  ckpt = os.path.join(basepath, filename)
  saver.restore(sess, ckpt)
def transform(z_original, 
              labels,
              z0=None,
              lr=1e-1, 
              n_opt=100, 
              penalty_weight=0.0,
              r_weight=1.0, 
              attr_weight=1.0, 
              attr_weights=np.ones([1, m.n_labels]),
              r_threshold = 0.9,
              attr_threshold = 0.9,
              adaptive=False,
             ):
  """Optimize latent vectors by gradient steps on `m.z_prime`.

  Copies `z_original` into the `m.z_prime` variable and runs
  `m.train_transform` for up to `n_opt` steps. A batch element counts as
  converged once its realism prediction exceeds `r_threshold` and its
  weighted attribute accuracy exceeds `attr_threshold`.

  Relies on the module-level `sess` and `m`. The default for `attr_weights`
  is evaluated when the function is defined, so `m` must already be built.

  Args:
    z_original: Initial latent vectors assigned to `m.z_prime`.
    labels: Target attribute labels fed to `m.labels`.
    z0: Reference point for the distance penalty; defaults to `z_original`.
    lr: Value fed to `m.transform_lr`.
    n_opt: Maximum number of optimization steps.
    penalty_weight: Value fed to `m.transform_penalty_weight`.
    r_weight: Value fed to `m.transform_r_weight`.
    attr_weight: Value fed to `m.transform_attr_weight`.
    attr_weights: Per-attribute weights, fed to `m.attr_weights` and also
      used for the accuracy computed here.
    r_threshold: Realism prediction required for convergence.
    attr_threshold: Attribute accuracy required for convergence.
    adaptive: If True, `r_weight` and `attr_weight` are recomputed after
      every step as one minus the mean realism / attribute accuracy.

  Returns:
    A tuple `(z_new, i_threshold, z_trace)`: the selected latent vectors,
    the step index recorded for each element (0 where none was recorded),
    and the list of `m.z_prime` values fetched at every step.
  """

  if z0 is None:
    z0 = z_original
  # NOTE(review): tf.assign builds a new op on every call of this function,
  # so the graph grows with repeated use.
  _ = sess.run(tf.assign(m.z_prime, z_original))
  z_new = np.zeros([m.half_batch, m.n_latent])
  # Step at which each element met both thresholds; 0 means "not yet".
  i_threshold = np.zeros([m.half_batch])
  z_trace = []
  for i in range(n_opt):
    res = sess.run([m.train_transform, 
                    m.loss_transform, 
                    m.transform_penalty,
                    m.z_prime,
                    m.r_pred,
                    m.pred_attr], 
                   {m.z0: z0,
                    m.r: np.ones([m.half_batch, 1]),
                    m.labels: labels,
                    m.transform: True, 
                    m.transform_lr: lr,
                    m.transform_penalty_weight: penalty_weight,
                    m.transform_r_weight: r_weight,
                    m.transform_attr_weight: attr_weight,
                    m.attr_weights: attr_weights,
                    # x is a dummy input: with transform=True, z comes from
                    # z_prime instead of the encoder output.
                    m.x: np.zeros([m.half_batch, m.img_width, m.img_width, 3])
                   })

    z_prime = res[3]
    z_trace.append(z_prime)
    pred_r = res[-2]
    pred_attr = res[-1]
    # Weighted mean agreement between the target labels and the predictions.
    attr_acc = 1.0 - np.mean(attr_weights * np.abs(labels - pred_attr), axis=1)
    transform_penalty = np.mean(res[2])
    # Elements that have not been marked as converged yet.
    check_idx = np.where(i_threshold == 0)[0]
    if len(check_idx) == 0:
      break
    for idx in check_idx:
      if pred_r[idx] > r_threshold and attr_acc[idx] > attr_threshold:
        z_new[idx] = z_prime[idx]
        # NOTE(review): at i == 0 this stores 0, which is the same value
        # used for "not converged"; such an element keeps being re-checked
        # and may be overwritten at a later step.
        i_threshold[idx] = i
    if adaptive:
      r_weight = 1 - np.mean(pred_r)
      attr_weight = 1 - np.mean(attr_acc)

    if i % 100 == 1:
      print( 'Step %d, NotConverged: %d, Loss: %0.3e, Penalty: %0.3f, '
            'r: %0.3f, r_min: %0.3f, '
            'attr: %0.3f, attr_min:%0.3f, ' % (
                i,
                len(check_idx), 
                res[1], 
                transform_penalty, 
                np.mean(pred_r), 
                np.min(pred_r),
                np.mean(attr_acc), 
                np.min(attr_acc),
           ))
      

  # Elements that never met the thresholds get the last fetched z_prime.
  # NOTE(review): z_prime is unbound here if n_opt == 0.
  check_idx = np.where(i_threshold == 0)[0]
  print('%d did not converge' % len(check_idx))
  for idx in check_idx:
    z_new[idx] = z_prime[idx]
  return z_new, i_threshold, z_trace
  
# BETA1 TRANSFORMATION
# Sample starting latents from the stored posterior statistics of the first
# half_batch evaluation examples.
z_original = (eval_mu[:m.half_batch] + 
              eval_sigma[:m.half_batch] * np.random.randn(m.half_batch, m.n_latent))
z = z_original

# Each entry is (attribute index, target value, learning rate,
# r_thresh, realism weight). Note that r_thresh is used below as
# attr_threshold = 1 - r_thresh, not as the realism threshold.
label_list = (
    (1, 1, 1e-2, 1e-3, 0.05), 
    (2, 1, 1e-2, 1e-3, 0.05), 
    (3, 1, 1e-2, 1e-3, 0.05), 
    (4, 1, 1e-2, 1e-5, 0.05), 
    (5, 0, 1e-2, 1e-3, 0.05), 
    (5, 1, 1e-2, 1e-4, 0.05), 
    (6, 0, 5e-3, 1e-4, 0.005), 
    (7, 0, 1e-2, 1e-3, 0.05), 
    (7, 1, 1e-2, 1e-3, 0.05), 
    (9, 0, 3e-2, 1e-4, 0.01), 
    (9, 1, 3e-2, 1e-4, 0.01)
)
# Index 0 of both lists holds the untransformed latents / decodings.
z_list = [z_original]
b_list = [sess.run(m.x_mean, {m.z: z_original})]


for attr, value, lr, r_thresh, r_weight in label_list:
  print('Label: %d, Value: %d' % (attr, value))
  # Only the attribute being changed contributes to the attribute loss.
  attr_weights=np.zeros([1, m.n_labels])
  # Start from the true labels and overwrite the single target attribute.
  labels = attr_eval[:m.half_batch].copy()
  labels[:, attr] = value
  attr_weights[:, attr] = 1

  # Every transformation restarts from z_original, not the previous result.
  z, i_threshold, z_trace = transform(
      z_original, 
      labels, 
      lr=lr,
      n_opt=300,
      r_weight=r_weight,
      attr_weight=1.0,
      attr_weights=attr_weights,
      r_threshold=0.3,
      attr_threshold=(1 - r_thresh),
  )
  z_list.append(z)
  b_list.append(sess.run(m.x_mean, {m.z: z}))
# Hand-picked examples to display, with their original attribute labels.
idxs = (0, 2, 5, 10, 7, 13,  3,)
labels = attr_eval[:m.half_batch].copy()[idxs, :]

plt.figure(figsize=(24, 18))
n_b = len(b_list)
tot = 6
# Put the example axis first so each row of `barr` holds one example across
# all transformations, then keep only the selected examples.
barr = np.swapaxes(np.array(b_list), 0, 1)[idxs, :, :, :, :]

# Black out a transformation wherever the example already had the target
# value for that attribute. Slot 0 is the original image, hence j + 1.
for j, spec in enumerate(label_list):
  k, v = spec[0], spec[1]
  for i in range(len(barr)):
    if labels[i, k] == v:
      barr[i, j+1] = 0.

for i, b in enumerate(barr):
  plt.subplot(12, 1, i + 1)
  im(batch_image(b, max_images=n_b, rows=n_b, cols=1))

在这里插入图片描述

Attribute Classification Accuracy

# Restore D, then G, from the checkpoints trained with a 1e-1 penalty.
for saver, filename in (
    (m.d_saver, 'D_d_2_conditional_penalty1e-1_399000.ckpt'),
    (m.g_saver, 'G_d_2_conditional_penalty1e-1_399000.ckpt'),
):
  ckpt = os.path.join(basepath, filename)
  saver.restore(sess, ckpt)

Original Images

# NOTE: `train_data` and `eval_data` (the image arrays) are only created in
# the data-preparation cells further down this notebook; they must exist
# before this cell is run.
batch_size = 256
# Prediction Accuracy
train_pred = []
eval_pred = []
test_pred = []

# Classify the first tenth of the training set, in whole batches.
# Floor division keeps the loop bound an integer on Python 3 as well.
for i in range(n_train // 10 // batch_size):
  start = i * batch_size
  end = start + batch_size
  batch_images = train_data[start:end]

  res = sess.run([m.pred_classifier], {m.x: batch_images})
  train_pred.append(res[0])
train_pred = np.vstack(train_pred)

# Classify the evaluation set, in whole batches.
for i in range(n_eval // batch_size):
  start = i * batch_size
  end = start + batch_size
  batch_images = eval_data[start:end]

  res = sess.run([m.pred_classifier], {m.x: batch_images})
  eval_pred.append(res[0])
eval_pred = np.vstack(eval_pred)

# Element-wise agreement between thresholded predictions and labels; the
# label arrays are truncated to the number of examples actually classified.
train_acc = (train_pred > 0.5) == attr_train[:train_pred.shape[0]]
eval_acc = (eval_pred > 0.5) == attr_eval[:eval_pred.shape[0]]

print("Train Accuracy: %.4f" % (np.mean(train_acc) * 100))
print("Eval Accuracy: %.4f" % (np.mean(eval_acc) * 100))
y_true = attr_eval[:eval_pred.shape[0]]
y_pred = eval_pred >= 0.5

# Use the public sklearn.metrics entry point; the private
# `sklearn.metrics.classification` module no longer exists in current
# scikit-learn releases.
prec, recall, f1, support = sklearn.metrics.precision_recall_fscore_support(y_true, y_pred)
report = sklearn.metrics.classification_report(y_true, y_pred, digits=3, target_names=attribute_names)
print(report)

在这里插入图片描述

Conditional Generation

# Prediction Accuracy
train_pred = []
eval_pred = []
test_pred = []

# For every batch of evaluation labels: draw latents from a standard normal,
# let G move them under those labels (amortize=True), decode, and classify
# the resulting images. Floor division keeps the loop bound an integer on
# Python 3 as well.
for i in range(n_eval // batch_size):
  start = (i * batch_size)
  end = start + batch_size
  labels = attr_eval[start:end]
  batch_z = np.random.randn(labels.shape[0], config["n_latent"])
  # (An unused scratch array was built here from `img_width`, a name that
  # is only defined further down the notebook; it has been removed.)
  xsamp, z_prime = sess.run([m.x_mean, m.z],
                  {m.q_z_sample: batch_z, 
                   m.amortize:True, 
                   m.labels:labels})
  res = sess.run([m.pred_classifier], {m.x: xsamp})
  eval_pred.append(res[0])
eval_pred = np.vstack(eval_pred)

# Element-wise agreement between thresholded predictions and the labels the
# samples were conditioned on.
eval_acc = (eval_pred > 0.5) == attr_eval[:eval_pred.shape[0]]

# Agreement of the highest-scoring attribute index (not printed below).
eval_softmax = eval_pred.argmax(axis=-1) == attr_eval[:eval_pred.shape[0]].argmax(axis=-1)

print("Eval Accuracy: %.4f" % (np.mean(eval_acc) * 100))
y_true = attr_eval[:eval_pred.shape[0]]
y_pred = eval_pred >= 0.5

# Use the public sklearn.metrics entry point; the private
# `sklearn.metrics.classification` module no longer exists in current
# scikit-learn releases.
prec, recall, f1, support = sklearn.metrics.precision_recall_fscore_support(y_true, y_pred)
report = sklearn.metrics.classification_report(y_true, y_pred, digits=3, target_names=attribute_names)
print(report)

在这里插入图片描述

TRAINING MODELS

此代码用于演示目的。

在TitanX GPU上训练VAE和D和G可能需要一天左右的时间。

要从头开始训练模型,您需要将整个CelebA数据集下载到基本目录(假设为 ~/Desktop/CelebA/)。 下面提供了预处理数据的步骤。

所需文件:img_align_celeba/、list_attr_celeba.txt、list_eval_partition.txt

# Running Average
running_N = 100
running_N_eval = 10
batch_size = 256


def rmean(data):
  """Return the mean of the last `running_N` entries of `data`."""
  return np.mean(data[-running_N:])


def rmeane(data):
  """Return the mean of the last `running_N_eval` entries of `data`."""
  return np.mean(data[-running_N_eval:])

Prepare the Data (Crop and Pad)

basepath = os.path.expanduser('~/Desktop/CelebA/')
save_path = basepath
# The second column of the partition file assigns each image to a split:
# 0 = train, 1 = validation, 2 = test.
partition = np.loadtxt(basepath + 'list_eval_partition.txt', usecols=(1,))
train_mask = (partition == 0)
eval_mask = (partition == 1)
test_mask = (partition == 2)

# print() with a single argument behaves the same on Python 2 and 3.
print("Train: %d, Validation: %d, Test: %d, Total: %d" % (train_mask.sum(), eval_mask.sum(), test_mask.sum(), partition.shape[0]))
attributes = pd.read_table(basepath + 'list_attr_celeba.txt', skiprows=1, delim_whitespace=True, usecols=range(1, 41))
attribute_names = attributes.columns.values
attribute_values = attributes.values
attr_train = attribute_values[train_mask]
attr_eval = attribute_values[eval_mask]
attr_test = attribute_values[test_mask]

# Recode "attribute absent" from -1 to 0 so the labels are 0/1.
attr_train[attr_train == -1] = 0
attr_eval[attr_eval == -1] = 0
attr_test[attr_test == -1] = 0

np.save(basepath + 'attr_train.npy', attr_train)
np.save(basepath + 'attr_eval.npy', attr_eval)
np.save(basepath + 'attr_test.npy', attr_test)
def pil_crop_downsample(x, width, out_width):
  """Center-crop a PIL image to a square and resize it.

  Args:
    x: PIL image.
    width: Side length, in pixels, of the centered square crop.
    out_width: Side length, in pixels, of the returned image.

  Returns:
    A new `out_width` x `out_width` PIL image.
  """
  # Load the submodule explicitly; a bare `import PIL` does not guarantee
  # that `PIL.Image` is available.
  import PIL.Image
  # Floor division keeps the crop box integral on Python 3 as well.
  left, upper = ((side - width) // 2 for side in x.size)
  x = x.crop([left, upper, left + width, upper + width])
  # LANCZOS is the same filter as the former ANTIALIAS alias, which was
  # removed in Pillow 10.
  return x.resize([out_width, out_width], resample=PIL.Image.LANCZOS)

def load_and_adjust_file(filename, width, outwidth):
  """Load an image file, center-crop and downsample it, and return a float array.

  Args:
    filename: Path of the image file to open.
    width: Side length of the centered square crop, passed to
      `pil_crop_downsample`.
    outwidth: Side length of the resized output, passed to
      `pil_crop_downsample`.

  Returns:
    A float32 numpy array holding the processed pixels divided by 255.
  """
  # `import PIL` on its own does not guarantee the Image submodule is loaded.
  import PIL.Image

  # The context manager closes the underlying file as soon as the pixels have
  # been copied into the array, instead of leaving it to garbage collection
  # while many files are processed in a loop.
  with PIL.Image.open(filename) as img:
    img = pil_crop_downsample(img, width, outwidth)
    return np.array(img, np.float32) / 255.
# CELEBA images are (218 x 178) originally
# `glob` is not among the notebook's top-level imports, so bring it into scope
# here; importing the function (not the module) keeps any other `glob(...)`
# calls working.
from glob import glob

# Sorted so that image i lines up with row i of the partition masks.
# NOTE(review): assumes list_eval_partition.txt is ordered by file name - confirm.
filenames = np.sort(glob(basepath + 'img_align_celeba/*.jpg'))

crop_width = 128
img_width = 64
# Suffix recording the preprocessing settings in the saved file names.
postfix = '_crop_%d_res_%d.npy' % (crop_width, img_width)

n_files = len(filenames)
# Preallocate one float32 RGB slot per image.
all_data = np.zeros([n_files, img_width, img_width, 3], np.float32)
for i, fname in enumerate(filenames):
  all_data[i, :, :] = load_and_adjust_file(fname, crop_width, img_width)
  if i % 10000 == 0:
    print('%.2f percent done' % (float(i)/n_files * 100.0))

# Split by the partition masks and write one array per split.
train_data = all_data[train_mask]
eval_data = all_data[eval_mask]
test_data = all_data[test_mask]
np.save(basepath + 'train' + postfix, train_data)
np.save(basepath + 'eval' + postfix, eval_data)
np.save(basepath + 'test' + postfix, test_data)

Train the VAE

# (Re-)initialize only the variables listed in m.vae_vars.
sess.run(tf.variables_initializer(var_list=m.vae_vars))

# Train the VAE
results = []
results_eval = []

# One history list per logged quantity; the '*_eval' lists only grow on the
# steps where the eval batch is scored.
traces = {name: [] for name in ('i', 'i_eval',
                                'loss', 'loss_eval',
                                'recons', 'recons_eval',
                                'kl', 'kl_eval')}

n_iters = 200000
# Learning rate per step, log-spaced from 3e-4 down to 1e-6.
vae_lr_ = np.logspace(np.log10(3e-4), np.log10(1e-6), n_iters)

for i in range(n_iters):
  # Walk through the training set, wrapping around at n_train.
  lo = (i * batch_size) % n_train
  hi = lo + batch_size
  train_feed = {
      m.x: train_data[lo:hi],
      m.vae_lr: vae_lr_[i],
      m.amortize: False,
      m.labels: attr_train[lo:hi],
  }
  _, loss_, recons_, kl_ = sess.run(
      [m.train_vae, m.vae_loss, m.mean_recons, m.mean_KL], train_feed)

  traces['loss'].append(loss_)
  traces['recons'].append(recons_)
  traces['kl'].append(kl_)
  traces['i'].append(i)

  # Everything below runs on every 10th step only.
  if i % 10 != 0:
    continue

  lo = (i * batch_size) % n_eval
  hi = lo + batch_size
  eval_feed = {m.x: eval_data[lo:hi], m.labels: attr_eval[lo:hi]}
  loss_eval_, recons_eval_, kl_eval_ = sess.run(
      [m.vae_loss, m.mean_recons, m.mean_KL], eval_feed)

  traces['loss_eval'].append(loss_eval_)
  traces['recons_eval'].append(recons_eval_)
  traces['kl_eval'].append(kl_eval_)
  traces['i_eval'].append(i)

  # Report running means of the train and eval histories.
  smoothed = (i,
              rmean(traces['loss']),
              rmean(traces['recons']),
              rmean(traces['kl']),
              rmeane(traces['loss_eval']),
              rmeane(traces['recons_eval']),
              rmeane(traces['kl_eval']))
  print('Step %5d \t TRAIN \t Loss: %0.3f, Recon: %0.3f, KL: %0.3f '
        '\t EVAL \t  Loss: %0.3f, Recon: %0.3f, KL: %0.3f' % smoothed)
# Three side-by-side panels (loss, reconstruction term, KL term); each shows
# the per-step training curve followed by the sparser eval curve.
plt.figure(figsize=(18, 6))

vae_panels = (('loss', 'Loss'), ('recons', 'Recons'), ('kl', 'KL'))
for column, (trace_key, panel_title) in enumerate(vae_panels, 1):
  plt.subplot(1, 3, column)
  plt.plot(traces['i'], traces[trace_key])
  plt.plot(traces['i_eval'], traces[trace_key + '_eval'])
  plt.title(panel_title)
# Optional y-limits for zooming in: Loss (30, 100), Recons (-100, -30), KL (10, 100).

Train D and G jointly

# Precompute means and vars
# Run the training images through the model in batches and collect m.mu and
# m.sigma for every example, so later cells can sample z without the images.
train_mu = []
train_sigma = []
# Ceiling division so a final partial batch is included.
n_batches = int(np.ceil(float(n_train) / batch_size))
for i in range(n_batches):
  if i % 1000 == 0:
    # Percentage of training examples processed so far.
    # print() with a single pre-formatted string works on Python 2 and 3.
    print('%.1f Done' % (float(i) / n_train * batch_size * 100))
  start = i * batch_size
  end = start + batch_size
  res = sess.run([m.mu, m.sigma], {m.x: train_data[start:end]})
  train_mu.append(res[0])
  train_sigma.append(res[1])
train_mu = np.vstack(train_mu)
train_sigma = np.vstack(train_sigma)
# Average sigma over examples, keeping a leading axis of length 1.
sigma_mean = train_sigma.mean(0, keepdims=True)
# Formatting into one string keeps the output identical on Python 2 and 3.
print('%s %s %s' % (train_mu.shape, train_sigma.shape, train_data.shape))
# Precompute means and vars
# Same pass as for the training set, over the eval images: collect m.mu and
# m.sigma for every example.
eval_mu = []
eval_sigma = []
# Ceiling division so a final partial batch is included.
n_batches = int(np.ceil(float(n_eval) / batch_size))
for i in range(n_batches):
  if i % 1000 == 0:
    # Percentage of eval examples processed so far.
    # print() with a single pre-formatted string works on Python 2 and 3.
    print('%.1f Done' % (float(i) / n_eval * batch_size * 100))
  start = i * batch_size
  end = start + batch_size
  res = sess.run([m.mu, m.sigma], {m.x: eval_data[start:end]})
  eval_mu.append(res[0])
  eval_sigma.append(res[1])
eval_mu = np.vstack(eval_mu)
eval_sigma = np.vstack(eval_sigma)
# Average sigma over examples, keeping a leading axis of length 1.
sigma_mean_eval = eval_sigma.mean(0, keepdims=True)
# Formatting into one string keeps the output identical on Python 2 and 3.
print('%s %s %s' % (eval_mu.shape, eval_sigma.shape, eval_data.shape))

# Compare the sorted per-dimension mean sigma of the train and eval sets.
plt.plot(np.sort(sigma_mean.flatten()))
plt.plot(np.sort(sigma_mean_eval.flatten()))

在这里插入图片描述

# With eval loss, looking for overfitting
# (Re-)initialize only the variables listed in m.d_vars and m.g_vars.
sess.run(tf.variables_initializer(var_list=m.d_vars))
sess.run(tf.variables_initializer(var_list=m.g_vars))

# Declare hyperparameters
results_D = []
results_G = []
# Histories: 'i', 'D', 'G', 'g_penalty' grow on every executed step, the
# 'pred_*' / 'z_dist_eval' lists only on evaluation steps ('i_pred').
traces = {'i': [],
          'i_pred': [],
          'D': [],
          'G': [],
          'G_real': [],
          'g_penalty': [],
          'pred_train': [],
          'pred_eval': [],
          'pred_prior': [],
          'pred_gen': [],
          'z_dist_eval': [],
          'attr_loss': [],
          'attr_acc': [],
         }


n_iters = 200000
# logspace(-4, -4, n) is a constant 1e-4 schedule.
d_lr_ = np.logspace(-4, -4, n_iters)
g_lr_ = np.logspace(-4, -4, n_iters)

g_penalty_weight_ = 0.1
lambda_weight_ = 10
z_sigma_mean_ = sigma_mean

# Probability of drawing the "fake" codes straight from the prior instead of
# from the generator.
percentage_prior_fake = 0.1
N_between_update_G = 10
N_between_eval = 100

n_train = train_mu.shape[0]
n_eval = eval_mu.shape[0]

# Each step uses half a batch of real codes. Floor division keeps every size
# and slice bound an int on Python 3 as well; plain `/` yields a float there,
# which breaks slicing and np.random.randn.
half_batch = batch_size // 2

# Training Loop
for i in range(n_iters):
  start = (i * half_batch) % n_train
  end = start + half_batch
  # Random offset of a second batch whose attributes act as mismatched labels.
  fake_start = np.random.choice(np.arange(n_train - half_batch))
  fake_end = fake_start + half_batch

  real_img = train_data[start:end]
  n_batch = real_img.shape[0]
  # Skip incomplete batches at the end of the data, and the case where the
  # "mismatched" batch would coincide with the real one.
  if n_batch == half_batch and start != fake_start:
    # Compare real vs. fake
    fake_z_prior = np.random.randn(half_batch, n_latent)
    real_attr = attr_train[start:end].astype(np.int32)
    fake_attr = attr_train[fake_start:fake_end].astype(np.int32)
    # Sample z from the precomputed per-example mean and sigma.
    real_z = train_mu[start:end] + train_sigma[start:end] * np.random.randn(half_batch, n_latent)

    if np.random.rand(1) < percentage_prior_fake:
      # Use Prior for fake_samples
      all_z = np.vstack([real_z, fake_z_prior, real_z])
    else:
      # Use Generator to make fake_samples
      fake_z_gen = sess.run(m.z, {m.q_z_sample: fake_z_prior,
                                  m.amortize: True,
                                  m.labels: real_attr,})
      all_z = np.vstack([real_z, fake_z_gen, real_z])
    # Three groups, in order: real z with its own attributes, fake z with the
    # real attributes, real z with attributes taken from the other batch.
    all_attr = np.vstack([real_attr, real_attr, fake_attr])

    # Train Discriminator
    # Targets: 1 for the first group, 0 for the other two.
    real_r = np.ones([half_batch, 1])
    fake_r = np.zeros([half_batch, 1])
    all_r = np.concatenate([real_r, fake_r, fake_r])
    res_d = sess.run([m.train_d, m.d_loss], {m.z: all_z,
                                             m.r: all_r,
                                             m.d_lr: d_lr_[i],
                                             m.lambda_weight: lambda_weight_,
                                             m.labels: all_attr,})

    # Train Generator
    if i % N_between_update_G == 0:
      if g_penalty_weight_ > 0:
        # Train on real data
        res_g_real = sess.run([m.train_g, m.g_loss, m.g_penalty],
                              {m.q_z_sample: real_z,
                               m.amortize: True,
                               m.g_penalty_weight: g_penalty_weight_,
                               m.z_sigma_mean: z_sigma_mean_,
                               m.g_lr: g_lr_[i],
                               m.labels: real_attr,})
        traces['G_real'].append(res_g_real[1])

      # Train on generated data
      res_g = sess.run([m.train_g, m.g_loss, m.g_penalty],
                       {m.q_z_sample: fake_z_prior,
                        m.amortize: True,
                        m.g_penalty_weight: g_penalty_weight_,
                        m.z_sigma_mean: z_sigma_mean_,
                        m.g_lr: g_lr_[i],
                        m.labels: real_attr,})

    # On steps without a generator update the most recent `res_g` is logged
    # again, which keeps 'G' / 'g_penalty' the same length as 'i'.
    # NOTE(review): if step 0 is skipped by the guard above, `res_g` is not yet
    # defined when first read here - confirm whether that needs handling.
    traces['i'].append(i)
    traces['D'].append(res_d[1])
    traces['G'].append(res_g[1])
    traces['g_penalty'].append(res_g[2])

    if i % N_between_eval == 0:
      eval_start = np.random.choice(np.arange(n_eval - half_batch))
      eval_end = eval_start + half_batch
      real_attr_eval = attr_eval[eval_start:eval_end].astype(np.int32)
      real_z_eval = eval_mu[eval_start:eval_end] + eval_sigma[eval_start:eval_end] * np.random.randn(half_batch, n_latent)
      # NOTE(review): the eval codes are fed with the *training* batch's
      # attributes (real_attr), not real_attr_eval - confirm this is intended.
      z_eval_gen = sess.run(m.z, {m.q_z_sample: real_z_eval,
                                  m.amortize: True,
                                  m.labels: real_attr,})
      fake_z_gen = sess.run(m.z, {m.q_z_sample: fake_z_prior,
                                  m.amortize: True,
                                  m.labels: real_attr,})

      # Mean of m.r_pred on: train codes, eval codes, prior samples, and
      # generator outputs.
      pred_train_ = np.mean(sess.run([m.r_pred], {m.z: real_z, m.labels: real_attr}))
      pred_eval_ = np.mean(sess.run([m.r_pred], {m.z: real_z_eval, m.labels: real_attr_eval}))
      pred_prior_ = np.mean(sess.run([m.r_pred], {m.z: fake_z_prior, m.labels: real_attr}))
      pred_gen_ = np.mean(sess.run([m.r_pred], {m.z: fake_z_gen, m.labels: real_attr}))

      traces['i_pred'].append(i)
      traces['pred_train'].append(pred_train_)
      traces['pred_eval'].append(pred_eval_)
      traces['pred_prior'].append(pred_prior_)
      traces['pred_gen'].append(pred_gen_)
      # Mean squared shift of the eval codes, measured in units of the mean sigma.
      traces['z_dist_eval'].append(np.mean(((z_eval_gen - real_z_eval)/z_sigma_mean_)**2))
      # NOTE(review): `rmeanp` is not defined in this part of the file; it must
      # be provided by an earlier cell - confirm.
      # print() with a single pre-formatted string works on Python 2 and 3.
      print('PRED Step %d, \t TRAIN: %.2e \t EVAL: %.2e \t PRIOR: %.2e \t GEN: %.2e ' % (i,
                                                                                         rmeanp(traces['pred_train']),
                                                                                         rmeanp(traces['pred_eval']),
                                                                                         rmeanp(traces['pred_prior']),
                                                                                         rmeanp(traces['pred_gen'])))
    
# Four stacked panels summarizing the D/G training run.
plt.figure(figsize=(18, 12))

# Panel 1: the four logged prediction series, sharing the evaluation steps.
plt.subplot(4, 1, 1)
for series in ('train', 'eval', 'prior', 'gen'):
  plt.plot(traces['i_pred'], traces['pred_' + series], label=series)
plt.ylabel('Predictions')
plt.legend(loc='upper right')

# Panels 2 and 3: generator and discriminator loss on a linear axis
# (plt.semilogy is an alternative for the D loss).
for row, (trace_key, axis_label) in enumerate((('G', 'G Loss'), ('D', 'D Loss')), 2):
  plt.subplot(4, 1, row)
  plt.plot(traces['i'], traces[trace_key])
  plt.ylabel(axis_label)

# Panel 4: sigma-weighted squared shift of the eval codes, log-scaled y axis.
plt.subplot(4, 1, 4)
plt.semilogy(traces['i_pred'], traces['z_dist_eval'])
plt.ylabel('Weighted Z Distance Eval')

Train Classifier

# (Re-)initialize only the variables listed in m.classifier_vars.
sess.run(tf.variables_initializer(var_list=m.classifier_vars))
# Train the Classifier
results = []
results_eval = []

# Exponentially smoothed losses, both started at 1; running_N sets the
# smoothing length for the train loss.
running_N = 100
running_loss = 1
running_loss_eval = 1

classifier_lr_ = 3e-4

# Train
for i in range(40000):
  # Walk through the training set, wrapping around at n_train.
  lo = (i * batch_size) % n_train
  hi = lo + batch_size
  train_feed = {m.x: train_data[lo:hi],
                m.labels: attr_train[lo:hi].astype(np.int32),
                m.classifier_lr: classifier_lr_}
  _, train_loss = sess.run([m.train_classifier, m.classifier_loss], train_feed)
  running_loss += (train_loss - running_loss) / running_N

  # Evaluation, bookkeeping and logging happen on steps 1, 11, 21, ...
  if i % 10 != 1:
    continue

  lo = (i * batch_size) % n_eval
  hi = lo + batch_size
  eval_feed = {m.x: eval_data[lo:hi],
               m.labels: attr_eval[lo:hi].astype(np.int32)}
  eval_loss, = sess.run([m.classifier_loss], eval_feed)
  # The eval loss is updated 10x less often, so its smoothing length is 10x shorter.
  running_loss_eval += (eval_loss - running_loss_eval) / (running_N / 10)

  # Keep the raw (unsmoothed) losses for plotting.
  results.append([i, train_loss])
  results_eval.append([i, eval_loss])

  print('Step %d, \t TRAIN \t Loss: %0.3f \t EVAL \t Loss: %0.3f' % (i, running_loss, running_loss_eval))
# Transpose the [step, loss] records so that row 0 holds steps and row 1 losses.
plot_train = np.transpose(np.array(results))
plot_eval = np.transpose(np.array(results_eval))

# Train curve first, eval curve second, on a shared axis.
plt.figure(figsize=(18, 6))
for curve in (plot_train, plot_eval):
  plt.plot(curve[0], curve[1])
plt.ylim(1e-1, 1)
plt.title('Loss')
# Prediction Accuracy
train_pred = []
eval_pred = []
test_pred = []

# Score only the first tenth of the training set, in whole batches. Floor
# division keeps the range() argument an int on Python 3.
for i in range(n_train // 10 // batch_size):
  start = i * batch_size
  end = start + batch_size
  batch_images = train_data[start:end]

  res = sess.run([m.pred_classifier], {m.x: batch_images})
  train_pred.append(res[0])
train_pred = np.vstack(train_pred)

# Score the eval set in whole batches; a trailing partial batch is dropped.
for i in range(n_eval // batch_size):
  start = i * batch_size
  end = start + batch_size
  batch_images = eval_data[start:end]

  res = sess.run([m.pred_classifier], {m.x: batch_images})
  eval_pred.append(res[0])
eval_pred = np.vstack(eval_pred)

# Element-wise agreement between thresholded predictions and the labels,
# truncated to the number of predictions actually computed.
train_acc = (train_pred > 0.5) == attr_train[:train_pred.shape[0]]
eval_acc = (eval_pred > 0.5) == attr_eval[:eval_pred.shape[0]]

# print() with a single pre-formatted string works on Python 2 and 3.
print("Train Accuracy: %.4f" % (np.mean(train_acc) * 100))
print("Eval Accuracy: %.4f" % (np.mean(eval_acc) * 100))
y_true = attr_eval[:eval_pred.shape[0]]
y_pred = eval_pred >= 0.5

# Use the public sklearn.metrics entry point: the `sklearn.metrics.classification`
# submodule was a private implementation path and is no longer importable.
prec, recall, f1, support = sklearn.metrics.precision_recall_fscore_support(y_true, y_pred)
report = sklearn.metrics.classification_report(y_true, y_pred, digits=2, target_names=attribute_names)
print(report)

Train Attribute Classifier (D_attr) in z space

# Train the Discriminator
# (Re-)initialize only the variables listed in m.d_attr_vars.
sess.run(tf.variables_initializer(var_list=m.d_attr_vars))
# Histories: 'i' / 'D_loss' grow on training steps, the rest on eval steps.
traces = {"i": [],
          "i_eval": [],
          "D_loss": [],
          "D_loss_eval": [],
          "accuracy": [],
         }
# NOTE(review): running_N is set here but the log line below averages every
# series over running_N_eval entries - confirm which window the train loss
# was meant to use.
running_N = 800
running_N_eval = 80

n_iters = 10000
# logspace(-4, -4, n) is a constant 1e-4 schedule.
d_attr_lr_ = np.logspace(-4, -4, n_iters)

# Half a batch of latent codes per step. Floor division keeps every slice
# bound and size an int on Python 3 as well; plain `/` yields a float there.
half_batch = batch_size // 2

for i in range(n_iters):
  start = (i * half_batch) % n_train
  end = min(start + half_batch, n_train)
  # Sample z from the precomputed per-example mean and sigma.
  batch = train_mu[start:end] + train_sigma[start:end] * np.random.randn(end-start, n_latent)
  # Incomplete batches at the end of the data are skipped.
  if batch.shape[0] == half_batch:
    labels = attr_train[start:end]

    # Train
    res = sess.run([m.train_d_attr, m.d_loss_attr],
                   {m.z: batch,
                    m.labels: labels,
                    m.d_attr_lr: d_attr_lr_[i],})

    traces['i'].append(i)
    traces['D_loss'].append(res[1])

  if i % 10 == 1:
    start = (i * half_batch) % n_eval
    end = min(start + half_batch, n_eval)
    batch = eval_mu[start:end] + eval_sigma[start:end] * np.random.randn(end-start, n_latent)
    if batch.shape[0] == half_batch:
      labels = attr_eval[start:end]

      res_eval = sess.run([m.d_loss_attr, m.pred_attr],
                          {m.z: batch,
                           m.labels: labels,})

      # Element-wise accuracy of the predictions thresholded at 0.5.
      y_true = labels
      y_pred = (res_eval[1] > 0.5)
      accuracy = np.mean(y_true == y_pred)

      traces['i_eval'].append(i)
      traces['D_loss_eval'].append(res_eval[0])
      traces['accuracy'].append(accuracy)


  if i % 100 == 0:
    # print() with a single pre-formatted string works on Python 2 and 3.
    print('Step %d, \t TRAIN \t Loss: %0.3f \t EVAL \t Loss: %0.3f \t Accuracy: %0.3f' % (i,
                                                                       np.mean(traces['D_loss'][-running_N_eval:]),
                                                                       np.mean(traces['D_loss_eval'][-running_N_eval:]),
                                                                       np.mean(traces['accuracy'][-running_N_eval:]),
                                                                       ))
# The figure is laid out as a 3-row grid; only the first two rows are drawn.
plt.figure(figsize=(18, 18))

# Row 1: train and eval loss on a log-scaled y axis.
plt.subplot(3, 1, 1)
loss_curves = (('i', 'D_loss', 'train'), ('i_eval', 'D_loss_eval', 'eval'))
for steps_key, loss_key, legend_name in loss_curves:
  plt.semilogy(traces[steps_key], traces[loss_key], label=legend_name)
plt.ylabel('Loss')
plt.legend(loc='upper right')

# Row 2: accuracy on the eval batches.
plt.subplot(3, 1, 2)
plt.plot(traces['i_eval'], traces['accuracy'], label="eval")
plt.ylabel('Prediction Accuracy')
plt.legend(loc='upper right')
# Prediction Accuracy
train_pred = []
eval_pred = []
test_pred = []

# Score only the first tenth of the training set, in whole batches. Floor
# division keeps the range() argument an int on Python 3.
for i in range(n_train // 10 // batch_size):
  start = i * batch_size
  end = start + batch_size
  batch_images = train_data[start:end]

  # m.pred_attr is evaluated from images here (m.x), with the labels fed too.
  res = sess.run([m.pred_attr], {m.x: batch_images,
                                 m.labels: attr_train[start:end]})
  train_pred.append(res[0])
train_pred = np.vstack(train_pred)

# Score the eval set in whole batches; a trailing partial batch is dropped.
for i in range(n_eval // batch_size):
  start = i * batch_size
  end = start + batch_size
  batch_images = eval_data[start:end]

  res = sess.run([m.pred_attr], {m.x: batch_images,
                                 m.labels: attr_eval[start:end]})
  eval_pred.append(res[0])
eval_pred = np.vstack(eval_pred)

# Element-wise agreement between thresholded predictions and the labels,
# truncated to the number of predictions actually computed.
train_acc = (train_pred > 0.5) == attr_train[:train_pred.shape[0]]
eval_acc = (eval_pred > 0.5) == attr_eval[:eval_pred.shape[0]]

# print() with a single pre-formatted string works on Python 2 and 3.
print("Train Accuracy: %.4f" % (np.mean(train_acc) * 100))
print("Eval Accuracy: %.4f" % (np.mean(eval_acc) * 100))

y_true = attr_eval[:eval_pred.shape[0]]
y_pred = eval_pred >= 0.5

# Use the public sklearn.metrics entry point: the `sklearn.metrics.classification`
# submodule was a private implementation path and is no longer importable.
prec, recall, f1, support = sklearn.metrics.precision_recall_fscore_support(y_true, y_pred)
report = sklearn.metrics.classification_report(y_true, y_pred, digits=2, target_names=attribute_names)
print(report)

猜你喜欢

转载自blog.csdn.net/weixin_41697507/article/details/89300900