Reinforcement Learning: A Small Policy Gradient Example

       This section still uses the example from the previous section; see:

Reinforcement Learning: A Small Q-learning Example

       The original code can be found at:

https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/contents/7_Policy_gradient_softmax

This post only adapts that code to the problem at hand.

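Policy gradient algorithms. Some points that distinguish them from Q-learning:

1. The action space can be continuous, i.e. actions can be sampled from a continuous distribution;

2. The idea is inherited directly from gradient-based optimization, which makes it possible to use a neural network for feature extraction;

3. It can be tied directly to classification or regression modeling (in the code below, the loss is a softmax cross-entropy over actions, weighted by the return);

4. Unlike Q-learning, which updates the Q-table at every step, it accumulates samples into a "batch" and then performs a single update.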
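A compact statement of the update behind points 2 and 4, written to match the class below (the notation is ours, not the original post's). Each stored transition gets a discounted return G_t, which is normalized over the batch (mean μ_G, standard deviation σ_G) to give v_t; the loss is the softmax cross-entropy of the taken action weighted by v_t, so minimizing it shifts probability toward actions whose return is above the batch average:

```latex
G_t = r_t + \gamma\, G_{t+1}, \quad G_{T+1} = 0, \qquad
v_t = \frac{G_t - \mu_G}{\sigma_G}

L(\theta) = \frac{1}{N}\sum_{t=1}^{N} \bigl(-\log \pi_\theta(a_t \mid s_t)\bigr)\, v_t,
\qquad
\nabla_\theta L(\theta) = -\frac{1}{N}\sum_{t=1}^{N} v_t\, \nabla_\theta \log \pi_\theta(a_t \mid s_t)
```

Here the term -log π_θ(a_t | s_t) is what `tf.nn.sparse_softmax_cross_entropy_with_logits` computes, and the recursion for G_t is the loop in `_discount_and_norm_rewards`.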


Example:

The treasure-hunting problem from the previous section.

Code:

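The PolicyGradient class. The original class is modified slightly so that calling valid_distribution prints the action distribution at every position, analogous to the Q-table of the previous section, which lets us check how well the algorithm converges.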
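valid_distribution prints one row per position (0 to 4) and one column per action, in the order of ACTIONS in the runner script ('n', 'e', 's', 'w'). To read such a table like the Q-table of the previous section, one can take the most probable action in each row. A small sketch; `greedy_policy` is a hypothetical helper that is not part of the original code, and the table holds rounded values from the epoch 250000 output of the reward_decay=0.9, single_layer=False run shown later in this post:

```python
import numpy as np

ACTIONS = ['n', 'e', 's', 'w']  # same action order as the runner script

def greedy_policy(prob_table, actions=ACTIONS):
    """Return the most probable action for each position (one row per position)."""
    prob_table = np.asarray(prob_table)
    return [actions[i] for i in prob_table.argmax(axis=1)]

# Rounded rows from the "epoch 250000" output of the multi-layer run.
table = [
    [1.84e-02, 9.76e-01, 5.14e-03, 3.74e-05],
    [1.98e-02, 9.74e-01, 5.73e-03, 4.30e-05],
    [1.00e-02, 1.56e-03, 9.72e-01, 1.63e-02],
    [1.13e-03, 8.77e-05, 3.56e-02, 9.63e-01],
    [9.29e-04, 7.05e-05, 2.84e-02, 9.71e-01],
]
print(greedy_policy(table))  # ['e', 'e', 's', 'w', 'w']
```

Ties are broken toward the first action, because `argmax` returns the first maximal index, so an untrained uniform row maps to 'n'.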

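_build_net redefines the network structure, with either single-layer or multi-layer feature extraction. The only input feature here is the position, so the feature is one-dimensional.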
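For reference, here is roughly what the multi-layer branch of _build_net computes, rewritten in plain numpy: three tanh layers of 5 units, a linear output layer with one unit per action, then a softmax. This is a sketch for illustration only; it mimics the initializers (normal with mean 0 and stddev 0.3, bias 0.1) but uses numpy's generator, so the random weights will not match the TensorFlow ones, and `init_dense`/`forward` are hypothetical names:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def init_dense(rng, n_in, n_out, stddev=0.3, bias=0.1):
    # Mimics random_normal_initializer(mean=0, stddev=0.3) and constant_initializer(0.1).
    return rng.normal(0.0, stddev, size=(n_in, n_out)), np.full(n_out, bias)

def forward(obs, params):
    """obs: (batch, n_features). Hidden tanh layers, then a linear layer and a softmax."""
    h = obs
    for W, b in params[:-1]:
        h = np.tanh(h @ W + b)
    W, b = params[-1]
    return softmax(h @ W + b)

rng = np.random.default_rng(1)
sizes = [1, 5, 5, 5, 4]  # n_features=1 -> 5 -> 5 -> 5 -> n_actions=4
params = [init_dense(rng, i, o) for i, o in zip(sizes[:-1], sizes[1:])]

positions = np.arange(5, dtype=np.float64).reshape(5, 1)  # same input as valid_distribution
probs = forward(positions, params)
print(probs.shape)        # (5, 4)
print(probs.sum(axis=1))  # each row sums to 1 (up to rounding)
```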

import tensorflow as tf
import numpy as np

np.random.seed(1)
tf.set_random_seed(1)
EPSILON = 0.9

class PolicyGradient(object):
    def __init__(self, n_actions, n_features,
                 learning_rate = 0.01, reward_decay = 0.95,
                 output_graph = False, single_layer = False):
        self.n_actions = n_actions
        self.n_features = n_features
        self.lr = learning_rate
        self.gamma = reward_decay
        self.single_layer = single_layer

        self.ep_obs, self.ep_as, self.ep_rs = [], [], []
        self._build_net()
        self.sess = tf.Session()
        if output_graph:
            tf.summary.FileWriter("logs/", self.sess.graph)

        self.sess.run(tf.global_variables_initializer())

    def _build_net(self):
        with tf.name_scope("inputs"):
            self.tf_obs = tf.placeholder(tf.float32, [None, self.n_features], name="observations")
            self.tf_acts = tf.placeholder(tf.int32, [None], name="actions_num")
            # Normalized discounted returns are floats; an int32 placeholder would truncate them.
            self.tf_vt = tf.placeholder(tf.float32, [None], name="action_value")

        if self.single_layer:
            all_act = tf.layers.dense(
                inputs=self.tf_obs,
                units=self.n_actions,
                activation=None,
                kernel_initializer=tf.random_uniform_initializer(minval=0.48, maxval=0.52),
                bias_initializer=tf.constant_initializer(0.1),
                name = "fc1"
            )
        else:
            layer1 = tf.layers.dense(
                inputs=self.tf_obs,
                units=5,
                activation=tf.nn.tanh,
                kernel_initializer=tf.random_normal_initializer(mean = 0, stddev=0.3),
                bias_initializer=tf.constant_initializer(0.1),
                name = "layer1"
            )

            layer2 = tf.layers.dense(
                inputs=layer1,
                units=5,
                activation=tf.nn.tanh,
                kernel_initializer=tf.random_normal_initializer(mean = 0, stddev=0.3),
                bias_initializer=tf.constant_initializer(0.1),
                name = "layer2"
            )

            layer3 = tf.layers.dense(
                inputs=layer2,
                units=5,
                activation=tf.nn.tanh,
                kernel_initializer=tf.random_normal_initializer(mean = 0, stddev=0.3),
                bias_initializer=tf.constant_initializer(0.1),
                name = "layer3"
            )

            all_act = tf.layers.dense(
                inputs=layer3,
                units=self.n_actions,
                activation=None,
                kernel_initializer=tf.random_normal_initializer(mean = 0, stddev=0.3),
                bias_initializer=tf.constant_initializer(0.1),
                name = "fc2"
            )


        self.all_act_prob = tf.nn.softmax(all_act, name = "act_prob")

        with tf.name_scope("loss"):
            neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=all_act, labels=self.tf_acts)
            loss = tf.reduce_mean(neg_log_prob * tf.cast(self.tf_vt, tf.float32))

        with tf.name_scope("train"):
            self.train_op = tf.train.AdamOptimizer(self.lr).minimize(loss)

    def valid_distribution(self):
        total_conclusion = self.sess.run(self.all_act_prob, feed_dict={self.tf_obs: np.arange(5).reshape([5, 1])})
        print("total_conclusion :")
        print(total_conclusion)

    def choose_action(self, observation):
        prob_weights = self.sess.run(self.all_act_prob, feed_dict={self.tf_obs: observation[np.newaxis, :]})
        action = np.random.choice(range(prob_weights.shape[1]), p = prob_weights.ravel())
        return action

    def store_transition(self, s, a, r):
        self.ep_obs.append(s)
        self.ep_as.append(a)
        self.ep_rs.append(r)

    def _discount_and_norm_rewards(self):
        discounted_ep_rs = np.zeros_like(self.ep_rs)
        running_add = 0
        for t in reversed(range(0, len(self.ep_rs))):
            running_add = running_add * self.gamma + self.ep_rs[t]
            discounted_ep_rs[t] = running_add

        if np.sum(np.abs(discounted_ep_rs)) == 0:
            return discounted_ep_rs

        discounted_ep_rs -= np.mean(discounted_ep_rs)
        std = np.std(discounted_ep_rs)
        if std > 0:  # guard against division by zero when all returns are equal
            discounted_ep_rs /= std
        return discounted_ep_rs

    def learn(self):
        discounted_ep_rs_norm = self._discount_and_norm_rewards()

        obs, acts, vt = np.vstack(self.ep_obs), np.array(self.ep_as), discounted_ep_rs_norm
        self.sess.run(self.train_op, feed_dict={
            self.tf_obs: obs,
            self.tf_acts: acts,
            self.tf_vt: vt,
        })

        self.ep_obs, self.ep_as, self.ep_rs = [], [], []
        return vt
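In the runner below, learn() is only called once more than 1000 transitions have piled up, so ep_rs usually holds several episodes back to back, and _discount_and_norm_rewards carries running_add across episode boundaries: the reward of one episode leaks into the returns of the episode stored before it. A sketch of one way to restart the return at each boundary, assuming the runner also records a hypothetical per-transition done flag (the class above does not store one); `discounted_returns` and `normalize` are illustrative names:

```python
import numpy as np

def discounted_returns(rewards, dones, gamma):
    """G_t = r_t + gamma * G_{t+1}, restarting at every episode boundary.

    rewards: floats for the whole buffer (several episodes stored back to back)
    dones:   bools, True on the last transition of each episode
    """
    returns = np.zeros(len(rewards), dtype=np.float64)
    running_add = 0.0
    for t in reversed(range(len(rewards))):
        if dones[t]:
            running_add = 0.0  # the next episode's return must not leak into this one
        running_add = running_add * gamma + rewards[t]
        returns[t] = running_add
    return returns

def normalize(returns):
    """Center and scale as _discount_and_norm_rewards does, guarding against std == 0."""
    std = returns.std()
    if std == 0:
        return np.zeros_like(returns)
    return (returns - returns.mean()) / std

# Two episodes stored back to back; each one ends with a reward of 1.
rewards = [0.0, 0.0, 1.0, 0.0, 1.0]
dones = [False, False, True, False, True]
print(discounted_returns(rewards, dones, gamma=0.5).tolist())  # [0.25, 0.5, 1.0, 0.5, 1.0]
```

Without the reset, the last step of the first episode would get 1 + 0.5 * 0.5 = 1.25 instead of 1.0.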

Runner script:

Several functions defined in the previous section are imported here. They are saved as start.py and used to determine the next state.

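import numpy as np
from copy import deepcopy

from RL_brain import PolicyGradient  # the class above, saved as RL_brain.py
from start import *  # get_env_feedback and update_env from the previous section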

ACTIONS = ['n', 'e', 's', 'w']
n_actions = len(ACTIONS)
# set features as state position only.
n_features = 1

RL = PolicyGradient(
    n_actions=n_actions,
    n_features=n_features,
    learning_rate=0.01,
    reward_decay=0.99
)

debug = False
for i_episode in range(int(1e10)):
    reset = True
    observation, step_counter = [None] * 2
    while True:
        if reset:
            observation = np.random.choice([0, 1, 2, 3, 4])
            observation = np.array([observation])
            step_counter = 0
            reset = False
        step_counter += 1

        action_index = RL.choose_action(observation)
        action = ACTIONS[action_index]
        observation = observation.reshape([])
        ori_observation = deepcopy(observation)

        observation_ ,reward = get_env_feedback(observation, action)

        observation = ori_observation
        reward = float(reward)
        RL.store_transition(observation, action_index, reward)

        if isinstance(observation_, str):  # get_env_feedback returns a string when the episode ends
            if debug:
                update_env(observation_, i_episode, step_counter)
            reset = True
            ep_rs_sum = sum(RL.ep_rs)
            if len(RL.ep_rs) > 1000:
                vt = RL.learn()
            break

        observation = observation_
        observation = np.array([observation])

    if i_episode % 10000 == 0:
        print("epoch {} end".format(i_episode))
        RL.valid_distribution()

if __name__ == "__main__":
    pass

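The results show two key points about building the network:

1. Since the network learns a distribution over movement directions at each position, it should be initialized so that the initial probability distribution is as close as possible to a discrete uniform distribution.

There are generally two ways to achieve this: a single layer whose weights are initialized uniformly over a narrow range, or a multi-layer network.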

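2. With a shallow network, the reward discount factor (reward_decay) should not be too close to 0. Because the feature space available to the whole system is small, a discount factor that is too small makes the model's "short-sightedness" very obvious: in total_conclusion, when one column of some row gets a very large probability, the same column in the other rows also becomes large, which gives poor results. See the output below (reward_decay=0.5, single_layer=True):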
epoch 0 end
total_conclusion :
[[0.25       0.25       0.25       0.25      ]
 [0.24992794 0.25418103 0.24820147 0.24768955]
 [0.24982986 0.25840503 0.2463902  0.24537486]
 [0.24970558 0.26267165 0.24456646 0.2430563 ]
 [0.24955492 0.2669804  0.24273053 0.24073412]]
epoch 10000 end
total_conclusion :
[[0.21704012 0.49736094 0.16374025 0.12185866]
 [0.18894796 0.25690013 0.364241   0.18991089]
 [0.11720868 0.09455233 0.57734776 0.21089128]
 [0.05784947 0.02768868 0.72812897 0.18633284]
 [0.0255025  0.00724229 0.82020557 0.14704964]]
epoch 20000 end
total_conclusion :
[[1.1902366e-01 6.9792479e-01 1.2305441e-01 5.9997093e-02]
 [8.7425530e-02 2.9234013e-01 5.0801122e-01 1.1222315e-01]
 [2.5749996e-02 4.9102332e-02 8.4097552e-01 8.4172145e-02]
 [5.1554046e-03 5.6061218e-03 9.4632423e-01 4.2914219e-02]
 [9.4831176e-04 5.8806542e-04 9.7836179e-01 2.0101879e-02]]
epoch 30000 end
total_conclusion :
[[6.2642455e-02 8.3603925e-01 7.0625499e-02 3.0692697e-02]
 [5.2187711e-02 3.8024789e-01 4.6989158e-01 9.7672820e-02]
 [1.1900097e-02 4.7335804e-02 8.5569036e-01 8.5073665e-02]
 [1.6536257e-03 3.5910148e-03 9.4959885e-01 4.5156576e-02]
 [2.1310350e-04 2.5264540e-04 9.7730553e-01 2.2228684e-02]]
epoch 40000 end
total_conclusion :
[[3.7692133e-02 8.9608622e-01 4.9539912e-02 1.6681695e-02]
 [3.5235319e-02 4.1345775e-01 4.7005072e-01 8.1256196e-02]
 [6.4846254e-03 3.7557058e-02 8.7803781e-01 7.7920519e-02]
 [6.9405942e-04 1.9840726e-03 9.5386559e-01 4.3456275e-02]
 [7.0038055e-05 9.8820849e-05 9.7698158e-01 2.2849590e-02]]
epoch 50000 end
total_conclusion :
[[2.4752773e-02 9.2907399e-01 3.6719639e-02 9.4535714e-03]
 [2.6470292e-02 4.4526717e-01 4.5918724e-01 6.9075264e-02]
 [4.3625305e-03 3.2887880e-02 8.8496488e-01 7.7784650e-02]
 [4.0026163e-04 1.3523112e-03 9.4948435e-01 4.8763059e-02]
 [3.4996196e-05 5.2989431e-05 9.7078079e-01 2.9131269e-02]]
epoch 60000 end
total_conclusion :
[[1.7649898e-02 9.4712859e-01 2.9592250e-02 5.6293383e-03]
 [2.1068770e-02 4.6587506e-01 4.5520204e-01 5.7854064e-02]
 [3.2033925e-03 2.9187966e-02 8.9187586e-01 7.5732864e-02]
 [2.6343120e-04 9.8906469e-04 9.4512820e-01 5.3619299e-02]
 [2.0838543e-05 3.2239463e-05 9.6342945e-01 3.6517467e-02]]
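Point 2 can be made concrete with the discount weights (a generic illustration, not taken from the original runs). A reward k steps ahead is weighted by gamma ** k in the return G_t, so with reward_decay=0.5 a reward 10 steps away keeps less than 0.001 of its value, while with 0.9 it keeps about 0.35. The common rule of thumb 1 / (1 - gamma) gives effective horizons of roughly 2 and 10 steps. `reward_weight` and `effective_horizon` are hypothetical helper names:

```python
def reward_weight(gamma, k):
    """Weight that a reward k steps in the future receives in the discounted return G_t."""
    return gamma ** k

def effective_horizon(gamma):
    """Rule of thumb: rewards much further than about 1 / (1 - gamma) steps barely count."""
    return 1.0 / (1.0 - gamma)

for gamma in (0.5, 0.9, 0.99):
    weights = [round(reward_weight(gamma, k), 4) for k in (1, 5, 10)]
    print(f"gamma={gamma}: weights at 1/5/10 steps = {weights}, "
          f"horizon ~ {effective_horizon(gamma):.0f} steps")
```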

reward_decay=0.9, single_layer=True

epoch 0 end
total_conclusion :
[[0.25       0.25       0.25       0.25      ]
 [0.24992794 0.25418103 0.24820147 0.24768955]
 [0.24982986 0.25840503 0.2463902  0.24537486]
 [0.24970558 0.26267165 0.24456646 0.2430563 ]
 [0.24955492 0.2669804  0.24273053 0.24073412]]
epoch 10000 end
total_conclusion :
[[0.2541749  0.4339638  0.18223304 0.12962824]
 [0.31923154 0.29700777 0.1915905  0.19217019]
 [0.36765605 0.18639956 0.1847071  0.2612373 ]
 [0.3943957  0.10896233 0.1658623  0.3307797 ]
 [0.40119484 0.06040053 0.14123574 0.39716887]]
epoch 20000 end
total_conclusion :
[[0.17127652 0.5577266  0.19322462 0.07777215]
 [0.21130787 0.3176208  0.3229146  0.14815669]
 [0.20633315 0.14316341 0.42711854 0.22338496]
 [0.17253096 0.05525859 0.48378655 0.28842393]
 [0.13284592 0.01964042 0.5045944  0.34291923]]
epoch 30000 end
total_conclusion :
[[0.13867429 0.63206965 0.17291538 0.05634059]
 [0.18240736 0.3259061  0.35483548 0.13685109]
 [0.16338208 0.11442888 0.49583372 0.22635537]
 [0.11672037 0.03204491 0.5526184  0.29861635]
 [0.07565267 0.00814175 0.5587916  0.35741407]]
epoch 40000 end
total_conclusion :
[[0.10081001 0.71501565 0.14359678 0.04057756]
 [0.13471448 0.39188603 0.34479922 0.12860028]
 [0.11042301 0.13174638 0.50783485 0.24999571]
 [0.06612735 0.03235891 0.5464557  0.35505807]
 [0.03474245 0.00697279 0.51587576 0.44240898]]
epoch 50000 end
total_conclusion :
[[0.07768991 0.7635396  0.12839992 0.03037059]
 [0.10534117 0.41834876 0.355451   0.12085913]
 [0.07775372 0.12477711 0.53565377 0.26181543]
 [0.03906842 0.02533454 0.5495034  0.38609374]
 [0.01695424 0.00444262 0.4868603  0.4917428 ]]
epoch 60000 end
total_conclusion :
[[0.05408907 0.8204231  0.10363895 0.02184892]
 [0.07231493 0.50978357 0.31268135 0.10522011]
 [0.05188116 0.16997969 0.50622576 0.27191344]
 [0.02303072 0.03506908 0.5071108  0.43478936]
 [0.00837533 0.00592719 0.41615787 0.5695396 ]]
epoch 70000 end
total_conclusion :
[[0.05169893 0.82091075 0.10882547 0.0185649 ]
 [0.07353021 0.44676957 0.38063854 0.09906169]
 [0.04737122 0.11013763 0.60305846 0.23943269]
 [0.01917202 0.0170566  0.6002202  0.36355114]
 [0.00669014 0.00227753 0.51508164 0.47595078]]
epoch 80000 end
total_conclusion :
[[0.04197544 0.84683263 0.09752993 0.01366197]
 [0.06214712 0.49029282 0.3653582  0.08220191]
 [0.04109263 0.12677413 0.6112474  0.22088583]
 [0.01621071 0.01955695 0.6101133  0.35411903]
 [0.00539158 0.00254359 0.5134281  0.47863674]]
epoch 90000 end
total_conclusion :
[[0.0396813  0.864847   0.08391145 0.01156029]
 [0.06822652 0.5031161  0.34562966 0.08302777]
 [0.04827514 0.12044813 0.58587325 0.24540342]
 [0.01917449 0.01618683 0.55747616 0.40716252]
 [0.00626419 0.00178922 0.43630424 0.55564237]]
epoch 100000 end
total_conclusion :
[[0.03582421 0.8740298  0.08067635 0.00946964]
 [0.06337273 0.50902355 0.35384497 0.07375869]
 [0.04422295 0.11694155 0.6122082  0.22662738]
 [0.01701888 0.01481624 0.58414865 0.3840162 ]
 [0.00538392 0.00154309 0.45817512 0.5348978 ]]
epoch 110000 end
total_conclusion :
[[3.7808131e-02 8.7277049e-01 8.0167063e-02 9.2543894e-03]
 [7.3226146e-02 4.8032302e-01 3.6108947e-01 8.5361369e-02]
 [5.0292756e-02 9.3740024e-02 5.7675588e-01 2.7921137e-01]
 [1.8301729e-02 9.6931336e-03 4.8810893e-01 4.8389620e-01]
 [5.2883681e-03 7.9587736e-04 3.2800779e-01 6.6590798e-01]]
epoch 120000 end
total_conclusion :
[[3.45975310e-02 8.83603871e-01 7.44067430e-02 7.39190215e-03]
 [7.09570944e-02 5.02655447e-01 3.52849692e-01 7.35377371e-02]
 [5.13085201e-02 1.00815214e-01 5.89943051e-01 2.57933199e-01]
 [1.90419760e-02 1.03779277e-02 5.06243348e-01 4.64336753e-01]
 [5.52772358e-03 8.35616316e-04 3.39797616e-01 6.53839052e-01]]
epoch 130000 end
total_conclusion :
[[3.4428395e-02 8.8025826e-01 7.8599989e-02 6.7133084e-03]
 [7.1423702e-02 4.6414855e-01 3.9321566e-01 7.1212135e-02]
 [4.7560442e-02 7.8556448e-02 6.3141793e-01 2.4246517e-01]
 [1.6806141e-02 7.0554558e-03 5.3804874e-01 4.3808973e-01]
 [4.7259689e-03 5.0427718e-04 3.6486074e-01 6.2990904e-01]]
epoch 140000 end
total_conclusion :
[[2.9190067e-02 8.9623755e-01 6.9009550e-02 5.5628545e-03]
 [6.3767701e-02 5.0694317e-01 3.6289334e-01 6.6395715e-02]
 [4.4551507e-02 9.1704644e-02 6.1030209e-01 2.5344178e-01]
 [1.5246476e-02 8.1258537e-03 5.0275469e-01 4.7387296e-01]
 [3.9947815e-03 5.5126823e-04 3.1709099e-01 6.7836297e-01]]
epoch 150000 end
total_conclusion :
[[2.61038337e-02 9.02455091e-01 6.66292384e-02 4.81185177e-03]
 [5.75109012e-02 5.16855896e-01 3.64215791e-01 6.14175014e-02]
 [3.96257900e-02 9.25753191e-02 6.22636378e-01 2.45162502e-01]
 [1.30827725e-02 7.94538576e-03 5.10039926e-01 4.68931943e-01]
 [3.27287335e-03 5.16704924e-04 3.16578746e-01 6.79631650e-01]]
epoch 160000 end
total_conclusion :
[[2.3320269e-02 9.0283412e-01 6.9880642e-02 3.9650649e-03]
 [4.9500983e-02 4.9232301e-01 4.0799561e-01 5.0180390e-02]
 [3.0989032e-02 7.9178333e-02 7.0253527e-01 1.8729736e-01]
 [9.9952416e-03 6.5607498e-03 6.2326348e-01 3.6018053e-01]
 [2.5804520e-03 4.3512895e-04 4.4258085e-01 5.5440360e-01]]
epoch 170000 end
total_conclusion :
[[2.0959368e-02 9.1501921e-01 6.0480446e-02 3.5409157e-03]
 [4.7873363e-02 5.4098064e-01 3.5951683e-01 5.1629137e-02]
 [3.2945268e-02 9.6364401e-02 6.4388281e-01 2.2680752e-01]
 [1.0355502e-02 7.8402609e-03 5.2671248e-01 4.5509183e-01]
 [2.4148514e-03 4.7324455e-04 3.1965497e-01 6.7745697e-01]]
epoch 180000 end
total_conclusion :
[[2.0037860e-02 9.1648364e-01 6.0523044e-02 2.9554151e-03]
 [4.7534682e-02 5.0857818e-01 4.0178594e-01 4.2101104e-02]
 [3.0792855e-02 7.7067420e-02 7.2836429e-01 1.6377541e-01]
 [1.0028366e-02 5.8711735e-03 6.6380942e-01 3.2029104e-01]
 [2.6443410e-03 3.6214772e-04 4.8983011e-01 5.0716347e-01]]
epoch 190000 end
total_conclusion :
[[2.0515777e-02 9.1274059e-01 6.3958667e-02 2.7848831e-03]
 [4.9688257e-02 4.6594495e-01 4.4379443e-01 4.0572379e-02]
 [2.9871482e-02 5.9041731e-02 7.6436615e-01 1.4672065e-01]
 [9.5903184e-03 3.9953678e-03 7.0306307e-01 2.8335130e-01]
 [2.5717402e-03 2.2582512e-04 5.4013824e-01 4.5706421e-01]]
epoch 200000 end
total_conclusion :
[[1.9618945e-02 9.1322052e-01 6.4400353e-02 2.7601693e-03]
 [4.8081201e-02 4.6147650e-01 4.4375980e-01 4.6682552e-02]
 [2.8066944e-02 5.5544838e-02 7.2832942e-01 1.8805878e-01]
 [8.2912296e-03 3.3833098e-03 6.0493928e-01 3.8338622e-01]
 [1.9035551e-03 1.6016315e-04 3.9049777e-01 6.0743850e-01]]
epoch 210000 end
total_conclusion :
[[1.6461434e-02 9.1983438e-01 6.1292339e-02 2.4118708e-03]
 [3.8809031e-02 4.7240821e-01 4.4620728e-01 4.2575449e-02]
 [2.1110723e-02 5.5979695e-02 7.4950135e-01 1.7340830e-01]
 [5.7899337e-03 3.3445982e-03 6.3475835e-01 3.5610709e-01]
 [1.2497237e-03 1.5726314e-04 4.2307195e-01 5.7552105e-01]]
epoch 220000 end
total_conclusion :
[[1.5052675e-02 9.2754936e-01 5.5169675e-02 2.2282486e-03]
 [3.8371310e-02 4.9669358e-01 4.1841790e-01 4.6517245e-02]
 [2.1696571e-02 5.8997225e-02 7.0390111e-01 2.1540506e-01]
 [5.5740857e-03 3.1839970e-03 5.3803563e-01 4.5320624e-01]
 [1.0480476e-03 1.2575880e-04 3.0097839e-01 6.9784772e-01]]
epoch 230000 end
total_conclusion :
[[1.2719852e-02 9.3808472e-01 4.7526378e-02 1.6690409e-03]
 [3.3230912e-02 5.5534708e-01 3.7758598e-01 3.3835981e-02]
 [2.1167737e-02 8.0160193e-02 7.3142356e-01 1.6724853e-01]
 [5.9436020e-03 5.1003071e-03 6.2454700e-01 3.6440909e-01]
 [1.2554835e-03 2.4412917e-04 4.0118751e-01 5.9731287e-01]]
epoch 240000 end
total_conclusion :
[[1.19138835e-02 9.40711200e-01 4.58427444e-02 1.53206289e-03]
 [3.24871019e-02 5.54068565e-01 3.79663557e-01 3.37807797e-02]
 [2.05819905e-02 7.58209676e-02 7.30543077e-01 1.73053950e-01]
 [5.63108036e-03 4.48067114e-03 6.07044756e-01 3.82843435e-01]
 [1.13851484e-03 1.95676825e-04 3.72767657e-01 6.25898123e-01]]
epoch 250000 end
total_conclusion :
[[1.0488704e-02 9.4734508e-01 4.0677190e-02 1.4889601e-03]
 [2.9026536e-02 5.9247094e-01 3.3981869e-01 3.8683772e-02]
 [1.8703880e-02 8.6275846e-02 6.6100842e-01 2.3401189e-01]
 [4.4211950e-03 4.6087466e-03 4.7166994e-01 5.1930004e-01]
 [7.0127757e-04 1.6520331e-04 2.2584547e-01 7.7328807e-01]]
epoch 260000 end
total_conclusion :
[[1.1348815e-02 9.4073093e-01 4.6427861e-02 1.4923185e-03]
 [3.2144837e-02 5.1513839e-01 4.1301119e-01 3.9705586e-02]
 [1.7839961e-02 5.5271905e-02 7.1989143e-01 2.0699683e-01]
 [4.2135976e-03 2.5238392e-03 5.3400928e-01 4.5925325e-01]
 [7.0275116e-04 8.1378137e-05 2.7971739e-01 7.1949846e-01]]
epoch 270000 end
total_conclusion :
[[1.0846548e-02 9.4296771e-01 4.4784609e-02 1.4010831e-03]
 [3.1814110e-02 5.2445072e-01 4.0349430e-01 4.0240850e-02]
 [1.8027861e-02 5.6351926e-02 7.0233160e-01 2.2328855e-01]
 [4.1229795e-03 2.4437420e-03 4.9338910e-01 5.0004423e-01]
 [6.4254826e-04 7.2215320e-05 2.3619159e-01 7.6309371e-01]]
epoch 280000 end
total_conclusion :
[[9.1504268e-03 9.4728011e-01 4.2469598e-02 1.0999093e-03]
 [2.6146499e-02 5.3540409e-01 4.0784016e-01 3.0609239e-02]
 [1.4519225e-02 5.8808822e-02 7.6113093e-01 1.6554105e-01]
 [3.4599402e-03 2.7720346e-03 6.0957092e-01 3.8419718e-01]
 [5.9711642e-04 9.4628042e-05 3.5355282e-01 6.4575541e-01]]
epoch 290000 end
total_conclusion :
[[7.7553848e-03 9.5317161e-01 3.8087185e-02 9.8578667e-04]
 [2.2035997e-02 5.7590652e-01 3.7154385e-01 3.0513545e-02]
 [1.2574041e-02 6.9878809e-02 7.2786993e-01 1.8967718e-01]
 [2.7378439e-03 3.2354214e-03 5.4411316e-01 4.4991362e-01]
 [4.0424295e-04 1.0158180e-04 2.7581939e-01 7.2367477e-01]]
epoch 300000 end
total_conclusion :
[[7.4513480e-03 9.5213538e-01 3.9498676e-02 9.1447186e-04]
 [2.0736352e-02 5.5535841e-01 3.9457828e-01 2.9326942e-02]
 [1.0962928e-02 6.1538193e-02 7.4882549e-01 1.7867348e-01]
 [2.2978734e-03 2.7034695e-03 5.6342137e-01 4.3157732e-01]
 [3.2832366e-04 8.0960788e-05 2.8897664e-01 7.1061409e-01]]

reward_decay=0.9, single_layer=False

epoch 0 end
total_conclusion :
[[0.24272624 0.2664516  0.2431751  0.24764712]
 [0.25305024 0.25117907 0.23963916 0.25613153]
 [0.25990137 0.24152242 0.23768328 0.26089293]
 [0.26320875 0.2369366  0.23683953 0.26301512]
 [0.2643011  0.23534094 0.23657104 0.263787  ]]
epoch 10000 end
total_conclusion :
[[0.01996518 0.96438617 0.00703121 0.00861744]
 [0.06497823 0.8767139  0.02951219 0.02879566]
 [0.20337433 0.01869064 0.53549784 0.24243721]
 [0.17275733 0.00595712 0.58486766 0.23641789]
 [0.17022285 0.00502265 0.5905856  0.23416895]]
epoch 20000 end
total_conclusion :
[[0.00977914 0.9852465  0.00305074 0.00192364]
 [0.04920592 0.91269904 0.02295803 0.01513704]
 [0.1082299  0.00256753 0.5223677  0.36683488]
 [0.10114782 0.00176557 0.51610726 0.38097933]
 [0.10053636 0.00169165 0.51489586 0.38287607]]
epoch 30000 end
total_conclusion :
[[0.01595607 0.9805597  0.00194518 0.00153896]
 [0.02344857 0.97012335 0.00357197 0.00285605]
 [0.05884041 0.00577983 0.67154795 0.26383168]
 [0.0345589  0.00144537 0.6769783  0.28701743]
 [0.03300219 0.00128465 0.67455024 0.291163  ]]
epoch 40000 end
total_conclusion :
[[1.4459175e-02 9.8325908e-01 1.3620944e-03 9.1964524e-04]
 [2.0464014e-02 9.7535014e-01 2.4464938e-03 1.7393864e-03]
 [3.4550514e-02 2.8977739e-03 7.1116060e-01 2.5139111e-01]
 [2.1084290e-02 9.5899979e-04 6.8814623e-01 2.8981048e-01]
 [2.0119997e-02 8.7089848e-04 6.8306744e-01 2.9594159e-01]]
epoch 50000 end
total_conclusion :
[[9.5187183e-03 9.8891360e-01 1.0428770e-03 5.2477716e-04]
 [1.3581087e-02 9.8361140e-01 1.8001809e-03 1.0072116e-03]
 [3.2373980e-02 1.6586123e-03 7.0518827e-01 2.6077914e-01]
 [2.0910097e-02 6.2694255e-04 6.5031850e-01 3.2814443e-01]
 [2.0023042e-02 5.7456945e-04 6.4101899e-01 3.3838332e-01]]
epoch 60000 end
total_conclusion :
[[9.6859140e-03 9.8901761e-01 9.3515235e-04 3.6124868e-04]
 [1.2463641e-02 9.8555148e-01 1.3848786e-03 5.9998740e-04]
 [2.5720207e-02 1.6818406e-03 7.3695904e-01 2.3563893e-01]
 [1.4572582e-02 5.1689544e-04 6.3634038e-01 3.4857017e-01]
 [1.3729123e-02 4.6342265e-04 6.1854339e-01 3.6726403e-01]]
epoch 70000 end
total_conclusion :
[[8.9608859e-03 9.8972487e-01 1.0660614e-03 2.4815302e-04]
 [2.4359480e-02 9.6856683e-01 5.6313323e-03 1.4423265e-03]
 [1.7355645e-02 7.6086522e-04 7.0904922e-01 2.7283433e-01]
 [1.3102032e-02 4.5217722e-04 5.7744300e-01 4.0900281e-01]
 [1.2477216e-02 4.1652124e-04 5.5157876e-01 4.3552756e-01]]
epoch 80000 end
total_conclusion :
[[1.21220425e-02 9.86533999e-01 1.09698437e-03 2.46952492e-04]
 [2.35213954e-02 9.72076714e-01 3.56799131e-03 8.33907863e-04]
 [1.36275385e-02 9.08883812e-04 8.48975301e-01 1.36488259e-01]
 [9.87306144e-03 4.85680881e-04 6.95991516e-01 2.93649793e-01]
 [9.20754019e-03 4.30718996e-04 6.44228756e-01 3.46133053e-01]]
epoch 90000 end
total_conclusion :
[[1.1520125e-02 9.8696262e-01 1.3372508e-03 1.7998939e-04]
 [1.3917091e-02 9.8403889e-01 1.7830743e-03 2.6083150e-04]
 [2.3827225e-02 2.6155615e-03 8.6892623e-01 1.0463096e-01]
 [1.1177201e-02 5.5414089e-04 5.4745471e-01 4.4081402e-01]
 [9.1576474e-03 4.1186428e-04 4.4888943e-01 5.4154110e-01]]
epoch 100000 end
total_conclusion :
[[1.1328434e-02 9.8669565e-01 1.8585753e-03 1.1740912e-04]
 [1.3247624e-02 9.8427898e-01 2.3172286e-03 1.5621875e-04]
 [2.6637968e-02 3.3733707e-03 8.9636964e-01 7.3619053e-02]
 [9.7785229e-03 4.8096036e-04 3.5935339e-01 6.3038719e-01]
 [7.1577146e-03 3.1912065e-04 2.5883538e-01 7.3368782e-01]]
epoch 110000 end
total_conclusion :
[[1.0035030e-02 9.8829287e-01 1.5927270e-03 7.9357873e-05]
 [1.1977794e-02 9.8583788e-01 2.0754482e-03 1.0889151e-04]
 [1.4448395e-02 1.5947536e-03 9.2113924e-01 6.2817626e-02]
 [6.0290424e-03 2.9274501e-04 3.2462633e-01 6.6905183e-01]
 [4.6168631e-03 2.0721863e-04 2.3333853e-01 7.6183736e-01]]
epoch 120000 end
total_conclusion :
[[9.5429290e-03 9.8864740e-01 1.7501512e-03 5.9503229e-05]
 [1.0475075e-02 9.8746568e-01 1.9891693e-03 7.0009039e-05]
 [2.1429989e-02 4.0749395e-03 9.4946623e-01 2.5028812e-02]
 [6.0653742e-03 2.9634649e-04 2.7282944e-01 7.2080880e-01]
 [3.8274133e-03 1.6645205e-04 1.6236641e-01 8.3363974e-01]]
epoch 130000 end
total_conclusion :
[[7.3853144e-03 9.9109393e-01 1.4747812e-03 4.6016314e-05]
 [7.8640431e-03 9.9048471e-01 1.6001627e-03 5.1146781e-05]
 [3.3291042e-02 1.1841053e-02 9.3784636e-01 1.7021498e-02]
 [6.3327486e-03 2.6507085e-04 2.6024795e-01 7.3315424e-01]
 [3.6769619e-03 1.3145835e-04 1.4505264e-01 8.5113889e-01]]
epoch 140000 end
total_conclusion :
[[6.7819259e-03 9.9175429e-01 1.4199471e-03 4.3870627e-05]
 [7.4995244e-03 9.9082559e-01 1.6232599e-03 5.1661998e-05]
 [1.9932389e-02 2.3596454e-03 9.4132286e-01 3.6385126e-02]
 [4.6057180e-03 1.5399103e-04 1.8117251e-01 8.1406778e-01]
 [3.6281743e-03 1.1328981e-04 1.4112492e-01 8.5513365e-01]]
epoch 150000 end
total_conclusion :
[[7.8055235e-03 9.9048704e-01 1.6636901e-03 4.3849177e-05]
 [8.3282758e-03 9.8981315e-01 1.8098198e-03 4.8767695e-05]
 [2.6316063e-02 6.7005209e-03 9.5382625e-01 1.3157116e-02]
 [6.0292105e-03 2.4533601e-04 2.4495459e-01 7.4877083e-01]
 [3.3822318e-03 1.2004423e-04 1.2837136e-01 8.6812639e-01]]
epoch 160000 end
total_conclusion :
[[8.25789664e-03 9.89703715e-01 2.00118846e-03 3.71836068e-05]
 [9.23451316e-03 9.88391459e-01 2.32919818e-03 4.47779967e-05]
 [1.51501503e-02 1.16333005e-03 8.83565664e-01 1.00120865e-01]
 [2.56293686e-03 9.47155131e-05 8.89582932e-02 9.08384085e-01]
 [2.28234520e-03 8.22104776e-05 7.81738386e-02 9.19461608e-01]]
epoch 170000 end
total_conclusion :
[[8.3111525e-03 9.8970544e-01 1.9454649e-03 3.8032471e-05]
 [8.7420335e-03 9.8913711e-01 2.0793625e-03 4.1358129e-05]
 [2.0256471e-02 4.4284128e-03 9.6530449e-01 1.0010734e-02]
 [4.4577671e-03 1.7727882e-04 2.0603830e-01 7.8932661e-01]
 [2.4488093e-03 8.8382440e-05 9.6857719e-02 9.0060514e-01]]
epoch 180000 end
total_conclusion :
[[6.74042059e-03 9.91177499e-01 2.05093599e-03 3.11263211e-05]
 [7.52371922e-03 9.90062356e-01 2.37671356e-03 3.72426912e-05]
 [1.34547865e-02 8.92202544e-04 8.75965297e-01 1.09687738e-01]
 [2.35858560e-03 7.37307128e-05 7.19583929e-02 9.25609291e-01]
 [2.14538001e-03 6.55836702e-05 6.43573254e-02 9.33431685e-01]]
epoch 190000 end
total_conclusion :
[[4.7103683e-03 9.9364752e-01 1.6120900e-03 2.9957560e-05]
 [4.9255760e-03 9.9333936e-01 1.7029172e-03 3.2125514e-05]
 [2.5764411e-02 9.7385934e-03 9.5587242e-01 8.6245174e-03]
 [7.0717232e-03 1.9899267e-04 2.5494295e-01 7.3778635e-01]
 [3.3879858e-03 7.7527067e-05 9.4206490e-02 9.0232801e-01]]
epoch 200000 end
total_conclusion :
[[5.5062859e-03 9.9220121e-01 2.2599392e-03 3.2476153e-05]
 [6.4469520e-03 9.9072838e-01 2.7833492e-03 4.1395240e-05]
 [1.2959373e-02 8.8687212e-04 9.4337398e-01 4.2779762e-02]
 [3.4274671e-03 8.8625340e-05 8.0753848e-02 9.1573012e-01]
 [2.9254081e-03 7.2692987e-05 6.6195659e-02 9.3080622e-01]]
epoch 210000 end
total_conclusion :
[[6.7050313e-03 9.9026489e-01 2.9965141e-03 3.3478875e-05]
 [7.4811736e-03 9.8902625e-01 3.4527977e-03 3.9731662e-05]
 [1.1403289e-02 1.2215463e-03 9.7320312e-01 1.4172052e-02]
 [3.3910186e-03 1.0862275e-04 7.8053258e-02 9.1844708e-01]
 [2.4969408e-03 7.4855518e-05 5.2599501e-02 9.4482875e-01]]
epoch 220000 end
total_conclusion :
[[8.1366943e-03 9.8807412e-01 3.7588137e-03 3.0369394e-05]
 [8.8895541e-03 9.8685741e-01 4.2179059e-03 3.5078654e-05]
 [1.0125207e-02 1.4564092e-03 9.7747517e-01 1.0943188e-02]
 [2.3958066e-03 9.7383920e-05 5.6047663e-02 9.4145912e-01]
 [1.6963906e-03 6.4517873e-05 3.5769559e-02 9.6246952e-01]]
epoch 230000 end
total_conclusion :
[[1.0783126e-02 9.8507023e-01 4.1133207e-03 3.3298995e-05]
 [1.1605123e-02 9.8381996e-01 4.5371768e-03 3.7690890e-05]
 [1.0146837e-02 1.9443299e-03 9.7945261e-01 8.4562581e-03]
 [2.1244360e-03 1.1519293e-04 6.1244793e-02 9.3651557e-01]
 [1.3985953e-03 7.0800306e-05 3.6053002e-02 9.6247756e-01]]
epoch 240000 end
total_conclusion :
[[1.3546581e-02 9.8178297e-01 4.6384856e-03 3.1966913e-05]
 [1.4745870e-02 9.8001206e-01 5.2051023e-03 3.6996076e-05]
 [1.1386400e-02 1.4909450e-03 9.6796215e-01 1.9160511e-02]
 [1.2758711e-03 7.5146745e-05 3.2815725e-02 9.6583331e-01]
 [1.0705406e-03 6.1382270e-05 2.6746141e-02 9.7212189e-01]]
epoch 250000 end
total_conclusion :
[[1.8375706e-02 9.7645175e-01 5.1351385e-03 3.7367889e-05]
 [1.9848045e-02 9.7437811e-01 5.7308432e-03 4.2954627e-05]
 [1.0006342e-02 1.5601777e-03 9.7213757e-01 1.6295882e-02]
 [1.1278926e-03 8.7725217e-05 3.5601538e-02 9.6318287e-01]
 [9.2888821e-04 7.0503593e-05 2.8413255e-02 9.7058737e-01]]

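The last configuration converges to the result we want. (Even though the only feature here is the position, a shallow network is already not good enough.)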
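One way to see why the shallow network struggles (our own explanation, not from the original post): with single_layer=True and a 1-D input x, the logits are x @ W + b, so each action's parameters are shared by every position. A gradient step that makes action a more likely at one position x0 >= 0 therefore also makes it more likely at every other position x >= 0, which matches the "column" effect described in point 2. The numpy sketch below checks this with one plain gradient-descent step (the original code uses Adam). It also shows that the narrow uniform initialization gives a nearly uniform initial distribution, as in point 1. The chosen x0, action index and learning rate are arbitrary illustration values:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(1)
# Single-layer policy (single_layer=True): logits = x @ W + b, x is the 1-D position.
W = rng.uniform(0.48, 0.52, size=(1, 4))  # mimics random_uniform_initializer(0.48, 0.52)
b = np.full(4, 0.1)                        # mimics constant_initializer(0.1)
xs = np.arange(5, dtype=np.float64).reshape(5, 1)  # positions 0..4, as in valid_distribution

before = softmax(xs @ W + b)
print(np.round(before, 3))  # every entry stays close to 0.25

# One plain gradient-descent step on -log pi(a | x0) at a single position.
x0, a, lr = np.array([[2.0]]), 2, 0.1
p0 = softmax(x0 @ W + b)            # shape (1, 4)
grad_z = p0 - np.eye(4)[[a]]        # d(-log pi)/d(logits) = p - onehot(a)
W = W - lr * x0.T @ grad_z          # d(logits)/dW = x0
b = b - lr * grad_z.ravel()         # d(logits)/db = 1

after = softmax(xs @ W + b)
print(after[:, a] > before[:, a])   # column a rises at every position, not only at x0
```

The same step in a multi-layer network also moves other positions, but the tanh hidden layers let different positions use different features, so the coupling is no longer forced to go in one direction for all of them.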

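This optimization method is also a very simple and fairly crude one. Comparing the number of iterations shows that we need about 200,000 episodes before getting a reasonably satisfactory result, and each call to learn requires more than 1000 accumulated observations. In both efficiency and final results, this is not necessarily better than Q-learning, which updates at every step; for a small, finite, countable action space, Q-learning is the better choice.
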
Reposted from blog.csdn.net/sinat_30665603/article/details/80549870