tfgan折腾笔记(三):核心函数详述——gan_loss族

gan_loss族的函数有:

1.gan_loss:

函数原型:

def gan_loss(
    # GANModel.
    model,
    # Loss functions.
    generator_loss_fn=tuple_losses.wasserstein_generator_loss,
    discriminator_loss_fn=tuple_losses.wasserstein_discriminator_loss,
    # Auxiliary losses.
    gradient_penalty_weight=None,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    mutual_information_penalty_weight=None,
    aux_cond_generator_weight=None,
    aux_cond_discriminator_weight=None,
    tensor_pool_fn=None,
    # Options.
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=True)

参数:

model:gan_model族函数的返回值

generator_loss_fn:生成器使用的损失函数,可用函数见其他说明。

discriminator_loss_fn:判别器使用的损失函数,可用函数见其他说明。

gradient_penalty_weight:如果不是None,则必须提供一个非负数或Tensor,意义为梯度惩罚的权值。

gradient_penalty_epsilon:如果提供了上一个参数,那么这个参数应该提供一个用于在梯度罚函数中维持数值稳定性的较小的正值。 请注意,某些应用程序需要增加此值以避免NaN。

gradient_penalty_target:如果上上个参数不是None,那么这个参数就指明了梯度规范的目标值。应该是一个数值类型或Tensor。

gradient_penalty_one_sided:(暂不明白什么意思)。

mutual_information_penalty_weight:交叉信息惩罚权值。如果不是None,必须提供一个非负数或Tensor。

aux_cond_generator_weight:如果不是None,则添加生成器分类损失。

aux_cond_discriminator_weight:如果不是None,则添加判别器分类损失。

tensor_pool_fn:tensor pool函数。此函数传入tuple类型:(generated_data, generator_inputs),函数将它们放在内部pool中,并且返回上一个pool中的值。如,可以传入tfgan.features.tensor_pool。

reduction:传入tf.losses.Reduction类的函数。

add_summaries:是否添加总结到Tensorboard日志。

返回值:

 返回“GANLoss 命名元组”。

函数内部实现:

# Create standard losses with optional kwargs, if the loss functions accept
  # them.
  def _optional_kwargs(fn, possible_kwargs):
    """Returns a kwargs dictionary of valid kwargs for a given function."""
    if inspect.getargspec(fn).keywords is not None:
      return possible_kwargs
    actual_args = inspect.getargspec(fn).args
    actual_kwargs = {}
    for k, v in possible_kwargs.items():
      if k in actual_args:
        actual_kwargs[k] = v
    return actual_kwargs
  possible_kwargs = {'reduction': reduction, 'add_summaries': add_summaries}
  gen_loss = generator_loss_fn(
      model, **_optional_kwargs(generator_loss_fn, possible_kwargs))
  dis_loss = discriminator_loss_fn(
      pooled_model, **_optional_kwargs(discriminator_loss_fn, possible_kwargs))

其他说明:

  •  tfgan内置损失函数:
__all__ = [
    'acgan_discriminator_loss',
    'acgan_generator_loss',
    'least_squares_discriminator_loss',
    'least_squares_generator_loss',
    'modified_discriminator_loss',
    'modified_generator_loss',
    'minimax_discriminator_loss',
    'minimax_generator_loss',
    'wasserstein_discriminator_loss',
    'wasserstein_hinge_discriminator_loss',
    'wasserstein_hinge_generator_loss',
    'wasserstein_generator_loss',
    'wasserstein_gradient_penalty',
    'mutual_information_penalty',
    'combine_adversarial_loss',
    'cycle_consistency_loss',
    'stargan_generator_loss_wrapper',
    'stargan_discriminator_loss_wrapper',
    'stargan_gradient_penalty_wrapper'
]

2.cyclegan_loss:

函数原型:

def cyclegan_loss(
    model,
    # Loss functions.
    generator_loss_fn=tuple_losses.least_squares_generator_loss,
    discriminator_loss_fn=tuple_losses.least_squares_discriminator_loss,
    # Auxiliary losses.
    cycle_consistency_loss_fn=tuple_losses.cycle_consistency_loss,
    cycle_consistency_loss_weight=10.0,
    # Options
    **kwargs)

参数:

model:gan_model族函数的返回值

generator_loss_fn:生成器使用的损失函数。

discriminator_loss_fn:判别器使用的损失函数。

cycle_consistency_loss_fn:循环一致性损失函数。

cycle_consistency_loss_weight:循环一致性损失的权值。

**kwargs:这里的参数将直接传递给cyclegan_loss函数内部调用的gan_loss函数。

返回值:

返回“CycleGANLoss 命名元组”。

函数内部实现:

循环一致性损失函数与权值的定义:

  # Defines cycle consistency loss.
  cycle_consistency_loss = cycle_consistency_loss_fn(
      model, add_summaries=kwargs.get('add_summaries', True))
  cycle_consistency_loss_weight = _validate_aux_loss_weight(
      cycle_consistency_loss_weight, 'cycle_consistency_loss_weight')
  aux_loss = cycle_consistency_loss_weight * cycle_consistency_loss

**kwargs的实现:

  # Defines losses for each partial model.
  def _partial_loss(partial_model):
    partial_loss = gan_loss(
        partial_model,
        generator_loss_fn=generator_loss_fn,
        discriminator_loss_fn=discriminator_loss_fn,
        **kwargs)
    return partial_loss._replace(generator_loss=partial_loss.generator_loss +
                                 aux_loss)

  with tf.compat.v1.name_scope('cyclegan_loss_x2y'):
    loss_x2y = _partial_loss(model.model_x2y)
  with tf.compat.v1.name_scope('cyclegan_loss_y2x'):
    loss_y2x = _partial_loss(model.model_y2x)

其他说明:

  • cycle-gan实际上是由两个普通gan组合而成的,其loss是普通gan的loss加上循环一致性损失。
  • 循环一致性损失权值越大,则X->Y->X循环的相似性方面学习的越快。

3.stargan_loss:

函数原型:

def stargan_loss(
    model,
    generator_loss_fn=tuple_losses.stargan_generator_loss_wrapper(
        losses_wargs.wasserstein_generator_loss),
    discriminator_loss_fn=tuple_losses.stargan_discriminator_loss_wrapper(
        losses_wargs.wasserstein_discriminator_loss),
    gradient_penalty_weight=10.0,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    reconstruction_loss_fn=tf.compat.v1.losses.absolute_difference,
    reconstruction_loss_weight=10.0,
    classification_loss_fn=tf.compat.v1.losses.softmax_cross_entropy,
    classification_loss_weight=1.0,
    classification_one_hot=True,
    add_summaries=True)

参数:

model:gan_model族函数的返回值

generator_loss_fn:生成器使用的损失函数。

discriminator_loss_fn:判别器使用的损失函数。

gradient_penalty_weight:如果不是None,则必须提供一个非负数或Tensor,意义为梯度惩罚的权值。

gradient_penalty_epsilon:如果提供了上一个参数,那么这个参数应该提供一个用于在梯度罚函数中维持数值稳定性的较小的正值。 请注意,某些应用程序需要增加此值以避免NaN。

gradient_penalty_target:如果上上个参数不是None,那么这个参数就指明了梯度规范的目标值。应该是一个数值类型或Tensor。

gradient_penalty_one_sided:(暂不明白什么意思)。

reconstruction_loss_fn:重建损失函数。

reconstruction_loss_weight:重建损失的权重。

classification_loss_fn:分类损失函数。

classification_loss_weight:分类损失的权重。

classification_one_hot:分类的one_hot_label。

add_summaries:是否向tensorboard添加总结。

返回值:

 返回“StarGANLoss 命名元组”。

函数内部实现:

 梯度惩罚函数与权值的定义:

  # Gradient Penalty.
  if _use_aux_loss(gradient_penalty_weight):
    gradient_penalty_fn = tuple_losses.stargan_gradient_penalty_wrapper(
        losses_wargs.wasserstein_gradient_penalty)
    discriminator_loss += gradient_penalty_fn(
        model,
        epsilon=gradient_penalty_epsilon,
        target=gradient_penalty_target,
        one_sided=gradient_penalty_one_sided,
        add_summaries=add_summaries) * gradient_penalty_weight

重建损失函数与权值的定义:

  # Reconstruction Loss.
  reconstruction_loss = reconstruction_loss_fn(model.input_data,
                                               model.reconstructed_data)
  generator_loss += reconstruction_loss * reconstruction_loss_weight
  if add_summaries:
    tf.compat.v1.summary.scalar('reconstruction_loss', reconstruction_loss)

分类损失函数与权值定义:

  # Classification Loss.
  generator_loss += _classification_loss_helper(
      true_labels=model.generated_data_domain_target,
      predict_logits=model.discriminator_generated_data_domain_predication,
      scope_name='generator_classification_loss') * classification_loss_weight
  discriminator_loss += _classification_loss_helper(
      true_labels=model.input_data_domain_label,
      predict_logits=model.discriminator_input_data_domain_predication,
      scope_name='discriminator_classification_loss'
  ) * classification_loss_weight

其他说明:

 无

猜你喜欢

转载自www.cnblogs.com/WongWai95/p/TFGAN-ZHE-TENG-BI-JI-3.html