# Google Dopamine search framework algorithm (Python) - developed by 立哥 (Li Ge)

# Copyright 2020 Jacky Zong. All rights reserved.
#coding=utf-8

"""Tests for dopamine.agents.rainbow.rainbow_agent.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from dopamine.agents.dqn import dqn_agent
from dopamine.agents.rainbow import rainbow_agent
from dopamine.discrete_domains import atari_lib
from dopamine.utils import test_utils
import numpy as np
import tensorflow as tf


class ProjectDistributionTest(tf.test.TestCase):
  """Argument-validation tests for `rainbow_agent.project_distribution`.

  Every scenario is exercised twice:

  * with constant tensors, where the test expects a `ValueError` to be raised
    by the call to `project_distribution` itself, and
  * with shape-less placeholders and `validate_args=True`, where the test
    expects a `tf.errors.InvalidArgumentError` when the returned op is run.
  """

  # Two rows of five values each; used as `supports` by every test.
  _SUPPORTS = [[0, 2, 4, 6, 8], [3, 4, 5, 6, 7]]
  # Weights with the same 2x5 shape as `_SUPPORTS`.
  _WEIGHTS = [[0.1, 0.2, 0.3, 0.2, 0.2], [0.1, 0.2, 0.3, 0.2, 0.2]]
  # Weights with only four values per row, i.e. one fewer than `_SUPPORTS`.
  _SHORT_WEIGHTS = [[0.1, 0.2, 0.3, 0.2], [0.1, 0.2, 0.3, 0.2]]

  def _assertConstantProjectionFails(self, weights, target_support,
                                     expected_regex):
    """Expects `project_distribution` on constants to raise `ValueError`.

    Args:
      weights: nested list converted to a float32 constant `weights` tensor.
      target_support: value converted to a float32 constant `target_support`
        tensor.
      expected_regex: pattern the `ValueError` message must match.
    """
    supports_t = tf.constant(self._SUPPORTS, dtype=tf.float32)
    weights_t = tf.constant(weights, dtype=tf.float32)
    target_support_t = tf.constant(target_support, dtype=tf.float32)
    # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which was
    # removed from `unittest` in Python 3.12.
    with self.assertRaisesRegex(ValueError, expected_regex):
      rainbow_agent.project_distribution(supports_t, weights_t,
                                         target_support_t)

  def _assertPlaceholderProjectionFails(self, weights, target_support,
                                        expected_regex=None):
    """Expects running the projection on fed values to fail.

    The op is built from placeholders with unspecified shape and
    `validate_args=True`, then run with the given values; the run must raise
    `tf.errors.InvalidArgumentError`.

    Args:
      weights: nested list fed to the `weights` placeholder.
      target_support: value fed to the `target_support` placeholder.
      expected_regex: optional pattern the error message must match. When
        None, only the exception type is checked.
    """
    supports_ph = tf.compat.v1.placeholder(tf.float32, None)
    weights_ph = tf.compat.v1.placeholder(tf.float32, None)
    target_support_ph = tf.compat.v1.placeholder(tf.float32, None)
    projection = rainbow_agent.project_distribution(
        supports_ph, weights_ph, target_support_ph, validate_args=True)
    feed_dict = {
        supports_ph: self._SUPPORTS,
        weights_ph: weights,
        target_support_ph: target_support,
    }
    if expected_regex is None:
      expectation = self.assertRaises(tf.errors.InvalidArgumentError)
    else:
      expectation = self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                           expected_regex)
    # `cached_session` is the supported replacement for the deprecated
    # `test_session`.
    with self.cached_session() as sess:
      tf.compat.v1.global_variables_initializer().run()
      with expectation:
        sess.run(projection, feed_dict=feed_dict)

  def testInconsistentSupportsAndWeightsParameters(self):
    """Constant 2x4 weights against 2x5 supports raise 'are incompatible'."""
    self._assertConstantProjectionFails(
        self._SHORT_WEIGHTS, [4, 5, 6, 7, 8], 'are incompatible')

  def testInconsistentSupportsAndWeightsWithPlaceholders(self):
    """Fed 2x4 weights against 2x5 supports fail with 'assertion failed'."""
    self._assertPlaceholderProjectionFails(
        self._SHORT_WEIGHTS, [4, 5, 6, 7, 8], 'assertion failed')

  def testInconsistentSupportsAndTargetSupportParameters(self):
    """A constant 3-value target support raises 'are incompatible'."""
    self._assertConstantProjectionFails(
        self._WEIGHTS, [4, 5, 6], 'are incompatible')

  def testInconsistentSupportsAndTargetSupportWithPlaceholders(self):
    """A fed 3-value target support fails with 'assertion failed'."""
    self._assertPlaceholderProjectionFails(
        self._WEIGHTS, [4, 5, 6], 'assertion failed')

  def testZeroDimensionalTargetSupport(self):
    """A constant scalar target support raises 'Index out of range'."""
    self._assertConstantProjectionFails(
        self._WEIGHTS, 3, 'Index out of range')

  def testZeroDimensionalTargetSupportWithPlaceholders(self):
    """A fed scalar target support fails when the op is run."""
    self._assertPlaceholderProjectionFails(self._WEIGHTS, 3)

  def testMultiDimensionalTargetSupport(self):
    """A constant rank-2 target support raises 'out of bounds'."""
    self._assertConstantProjectionFails(
        self._WEIGHTS, [[3]], 'out of bounds')

  def testMultiDimensionalTargetSupportWithPlaceholders(self):
    """A fed rank-2 target support fails when the op is run."""
    self._assertPlaceholderProjectionFails(self._WEIGHTS, [[3]])

# You may also like

# Reposted from blog.csdn.net/weixin_45806384/article/details/108694063