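This section reuses the example from the previous section; see:
Reinforcement Learning: A Small Q-learning Example
The code prototype can be found at:
This post only adapts that code to the problem at hand.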
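Policy gradient algorithm (the main points that distinguish it from Q-Learning include:
1. The action space can be continuous, i.e. actions can be sampled from a continuous distribution;
2. The idea is inherited directly from gradient-based optimization, which makes it possible to use a neural network for feature extraction;
3. It can be tied directly to classification or regression modeling;
4. Unlike Q-learning, which updates the Q-table at every step, it accumulates a "batch" of samples and then performs one combined update.)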
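Points 2 and 4 can be tried out without TensorFlow. The sketch below is a minimal NumPy illustration, not the code used later in this post: it assumes a state-independent softmax policy over four actions, and the logits, learning rate and the batch of (action, return) pairs are made-up values. It accumulates the gradient of the log-probability over a batch and applies a single update.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


def grad_log_prob(logits, action):
    """Gradient of log pi(action) with respect to the logits: onehot(action) - pi."""
    pi = softmax(logits)
    onehot = np.zeros_like(pi)
    onehot[action] = 1.0
    return onehot - pi


rng = np.random.default_rng(0)
logits = np.zeros(4)   # hypothetical logits for the actions n, e, s, w
lr = 0.1               # hypothetical learning rate

# Hypothetical batch of (action index, return) pairs collected before updating.
batch = [(1, 1.0), (3, -1.0), (1, 0.5)]

# Accumulate the gradient over the whole batch, then update once (gradient ascent).
grad = np.zeros_like(logits)
for action, ret in batch:
    grad += ret * grad_log_prob(logits, action)
logits += lr * grad / len(batch)

probs = softmax(logits)
sampled_action = rng.choice(len(probs), p=probs)  # actions are sampled, not argmax-ed
print(probs, sampled_action)
```

After the update, the action with positive returns becomes more likely and the one with a negative return less likely, which is the whole mechanism the class below builds on.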
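Example:
The treasure-hunt problem from the previous section.
Code:
The PolicyGradient class. The original class is modified slightly so that calling the valid_distribution function prints a distribution similar to the Q-table of the previous section, which lets us judge how well the algorithm converges.
_build_net redefines the network structure, with either single-layer or multi-layer feature extraction. The only input feature here is the position, i.e. a one-dimensional feature.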
```python
import tensorflow as tf
import numpy as np

np.random.seed(1)
tf.set_random_seed(1)

EPSILON = 0.9


class PolicyGradient(object):
    def __init__(self, n_actions, n_features, learning_rate=0.01, reward_decay=0.95,
                 output_graph=False, single_layer=False):
        self.n_actions = n_actions
        self.n_features = n_features
        self.lr = learning_rate
        self.gamma = reward_decay
        self.single_layer = single_layer
        # Buffers for observations, actions and rewards collected since the last update.
        self.ep_obs, self.ep_as, self.ep_rs = [], [], []
        self._build_net()
        self.sess = tf.Session()
        if output_graph:
            tf.summary.FileWriter("logs/", self.sess.graph)
        self.sess.run(tf.global_variables_initializer())

    def _build_net(self):
        with tf.name_scope("inputs"):
            self.tf_obs = tf.placeholder(tf.float32, [None, self.n_features], name="observations")
            self.tf_acts = tf.placeholder(tf.int32, [None], name="actions_num")
            # NOTE: declared int32 as in the original post, although learn() feeds
            # normalized float returns here; tf.float32 is the usual choice.
            self.tf_vt = tf.placeholder(tf.int32, [None], name="action_value")

        if self.single_layer:
            # Single layer: small-range uniform weights keep the initial policy near uniform.
            all_act = tf.layers.dense(
                inputs=self.tf_obs,
                units=self.n_actions,
                activation=None,
                kernel_initializer=tf.random_uniform_initializer(minval=0.48, maxval=0.52),
                bias_initializer=tf.constant_initializer(0.1),
                name="fc1"
            )
        else:
            # Multi-layer: three tanh hidden layers of 5 units, then a linear output layer.
            layer1 = tf.layers.dense(
                inputs=self.tf_obs,
                units=5,
                activation=tf.nn.tanh,
                kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.3),
                bias_initializer=tf.constant_initializer(0.1),
                name="layer1"
            )
            layer2 = tf.layers.dense(
                inputs=layer1,
                units=5,
                activation=tf.nn.tanh,
                kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.3),
                bias_initializer=tf.constant_initializer(0.1),
                name="layer2"
            )
            layer3 = tf.layers.dense(
                inputs=layer2,
                units=5,
                activation=tf.nn.tanh,
                kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.3),
                bias_initializer=tf.constant_initializer(0.1),
                name="layer3"
            )
            all_act = tf.layers.dense(
                inputs=layer3,
                units=self.n_actions,
                activation=None,
                kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.3),
                bias_initializer=tf.constant_initializer(0.1),
                name="fc2"
            )

        self.all_act_prob = tf.nn.softmax(all_act, name="act_prob")

        with tf.name_scope("loss"):
            # Cross entropy with the taken action equals -log pi(a|s).
            neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=all_act, labels=self.tf_acts)
            loss = tf.reduce_mean(neg_log_prob * tf.cast(self.tf_vt, tf.float32))

        with tf.name_scope("train"):
            self.train_op = tf.train.AdamOptimizer(self.lr).minimize(loss)

    def valid_distribution(self):
        # Print the action distribution for each of the five positions 0..4.
        total_conclusion = self.sess.run(
            self.all_act_prob, feed_dict={self.tf_obs: np.arange(5).reshape([5, 1])})
        print("total_conclusion :")
        print(total_conclusion)

    def choose_action(self, observation):
        # Sample an action from the current policy distribution.
        prob_weights = self.sess.run(
            self.all_act_prob, feed_dict={self.tf_obs: observation[np.newaxis, :]})
        action = np.random.choice(range(prob_weights.shape[1]), p=prob_weights.ravel())
        return action

    def store_transition(self, s, a, r):
        self.ep_obs.append(s)
        self.ep_as.append(a)
        self.ep_rs.append(r)

    def _discount_and_norm_rewards(self):
        # Discounted return for every stored step, computed backwards.
        discounted_ep_rs = np.zeros_like(self.ep_rs)
        running_add = 0
        for t in reversed(range(0, len(self.ep_rs))):
            running_add = running_add * self.gamma + self.ep_rs[t]
            discounted_ep_rs[t] = running_add
        if np.sum(np.abs(discounted_ep_rs)) == 0:
            return discounted_ep_rs
        # Normalize to zero mean and unit variance.
        discounted_ep_rs -= np.mean(discounted_ep_rs)
        discounted_ep_rs /= np.std(discounted_ep_rs)
        return discounted_ep_rs

    def learn(self):
        discounted_ep_rs_norm = self._discount_and_norm_rewards()
        obs, acts, vt = np.vstack(self.ep_obs), np.array(self.ep_as), discounted_ep_rs_norm
        self.sess.run(self.train_op, feed_dict={
            self.tf_obs: obs,
            self.tf_acts: acts,
            self.tf_vt: vt,
        })
        # Clear the buffers after each update.
        self.ep_obs, self.ep_as, self.ep_rs = [], [], []
        return vt
```
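The return computation done in _discount_and_norm_rewards can also be tried on its own. The following is a NumPy-only sketch under the assumption of one short episode whose only reward arrives at the last step; the function names discounted_returns and normalize are illustrative and are not part of the class above.

```python
import numpy as np


def discounted_returns(rewards, gamma):
    """Discounted return-to-go for every step, accumulated from the end backwards."""
    returns = np.zeros(len(rewards), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = running * gamma + rewards[t]
        returns[t] = running
    return returns


def normalize(returns):
    """Shift to zero mean and scale to unit variance; all-zero input is returned as is."""
    if np.sum(np.abs(returns)) == 0:
        return returns
    return (returns - returns.mean()) / returns.std()


# Hypothetical episode: three steps without reward, then the treasure.
rewards = [0.0, 0.0, 0.0, 1.0]
raw = discounted_returns(rewards, gamma=0.5)
norm = normalize(raw)
print(raw)   # steps closer to the reward get a larger return
print(norm)  # early steps end up below zero, late steps above zero
```

After normalization, roughly half of the steps get a negative weight, so the update pushes probability away from early moves and towards the moves made close to the reward.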
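Script that runs the training and prints the results:
Several functions defined in the previous section are imported here. They are saved in start.py and are used to determine the next state.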
```python
import numpy as np
from RL_brain import PolicyGradient
from start import *
from copy import deepcopy

ACTIONS = ['n', 'e', 's', 'w']
n_actions = len(ACTIONS)
# set features as state position only.
n_features = 1

RL = PolicyGradient(
    n_actions=n_actions,
    n_features=n_features,
    learning_rate=0.01,
    reward_decay=0.99
)

debug = False

for i_episode in range(int(1e10)):
    reset = True
    observation, step_counter = [None] * 2
    while True:
        if reset:
            # Start a new game from a random position.
            observation = np.random.choice([0, 1, 2, 3, 4])
            observation = np.array([observation])
            step_counter = 0
            reset = False
        step_counter += 1
        action_index = RL.choose_action(observation)
        action = ACTIONS[action_index]
        observation = observation.reshape([])
        # Keep a copy in case get_env_feedback modifies its argument.
        ori_observation = deepcopy(observation)
        observation_, reward = get_env_feedback(observation, action)
        observation = ori_observation
        reward = float(reward)
        RL.store_transition(observation, action_index, reward)
        # A string as next state marks the end of a game.
        if type(observation_) == type(""):
            if debug:
                update_env(observation_, i_episode, step_counter)
            reset = True
            ep_rs_sum = sum(RL.ep_rs)
            # Update only after more than 1000 transitions have been collected.
            if len(RL.ep_rs) > 1000:
                vt = RL.learn()
                break
        observation = observation_
        observation = np.array([observation])
    if i_episode % 10000 == 0:
        print("epoch {} end".format(i_episode))
        RL.valid_distribution()

if __name__ == "__main__":
    pass
```
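The results show two key points to keep in mind when building the network:
1. Since we are learning the distribution over movement directions at each position, the network should be initialized so that the initial probability distribution is as close as possible to a discrete uniform distribution.
There are generally two ways to obtain such a distribution: a single layer whose weights are initialized uniformly in a small range, or a multi-layer initialization.
2. With a shallow network, the reward discount factor should not be too close to 0. Because the feature space of the whole system is small, a discount factor that is too small makes the model's "short-sightedness" very visible. For example, in total_conclusion, when the probability in one column of some row becomes very large, the other rows of that column also become large, which gives a poor result. See the output below (reward_decay=0.5, single_layer=True):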
```
epoch 0 end
total_conclusion :
[[0.25 0.25 0.25 0.25 ]
 [0.24992794 0.25418103 0.24820147 0.24768955]
 [0.24982986 0.25840503 0.2463902 0.24537486]
 [0.24970558 0.26267165 0.24456646 0.2430563 ]
 [0.24955492 0.2669804 0.24273053 0.24073412]]
epoch 10000 end
total_conclusion :
[[0.21704012 0.49736094 0.16374025 0.12185866]
 [0.18894796 0.25690013 0.364241 0.18991089]
 [0.11720868 0.09455233 0.57734776 0.21089128]
 [0.05784947 0.02768868 0.72812897 0.18633284]
 [0.0255025 0.00724229 0.82020557 0.14704964]]
epoch 20000 end
total_conclusion :
[[1.1902366e-01 6.9792479e-01 1.2305441e-01 5.9997093e-02]
 [8.7425530e-02 2.9234013e-01 5.0801122e-01 1.1222315e-01]
 [2.5749996e-02 4.9102332e-02 8.4097552e-01 8.4172145e-02]
 [5.1554046e-03 5.6061218e-03 9.4632423e-01 4.2914219e-02]
 [9.4831176e-04 5.8806542e-04 9.7836179e-01 2.0101879e-02]]
epoch 30000 end
total_conclusion :
[[6.2642455e-02 8.3603925e-01 7.0625499e-02 3.0692697e-02]
 [5.2187711e-02 3.8024789e-01 4.6989158e-01 9.7672820e-02]
 [1.1900097e-02 4.7335804e-02 8.5569036e-01 8.5073665e-02]
 [1.6536257e-03 3.5910148e-03 9.4959885e-01 4.5156576e-02]
 [2.1310350e-04 2.5264540e-04 9.7730553e-01 2.2228684e-02]]
epoch 40000 end
total_conclusion :
[[3.7692133e-02 8.9608622e-01 4.9539912e-02 1.6681695e-02]
 [3.5235319e-02 4.1345775e-01 4.7005072e-01 8.1256196e-02]
 [6.4846254e-03 3.7557058e-02 8.7803781e-01 7.7920519e-02]
 [6.9405942e-04 1.9840726e-03 9.5386559e-01 4.3456275e-02]
 [7.0038055e-05 9.8820849e-05 9.7698158e-01 2.2849590e-02]]
epoch 50000 end
total_conclusion :
[[2.4752773e-02 9.2907399e-01 3.6719639e-02 9.4535714e-03]
 [2.6470292e-02 4.4526717e-01 4.5918724e-01 6.9075264e-02]
 [4.3625305e-03 3.2887880e-02 8.8496488e-01 7.7784650e-02]
 [4.0026163e-04 1.3523112e-03 9.4948435e-01 4.8763059e-02]
 [3.4996196e-05 5.2989431e-05 9.7078079e-01 2.9131269e-02]]
epoch 60000 end
total_conclusion :
[[1.7649898e-02 9.4712859e-01 2.9592250e-02 5.6293383e-03]
 [2.1068770e-02 4.6587506e-01 4.5520204e-01 5.7854064e-02]
 [3.2033925e-03 2.9187966e-02 8.9187586e-01 7.5732864e-02]
 [2.6343120e-04 9.8906469e-04 9.4512820e-01 5.3619299e-02]
 [2.0838543e-05 3.2239463e-05 9.6342945e-01 3.6517467e-02]]
```
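One way to see where the short-sightedness comes from: a reward received k steps later reaches the current step scaled by gamma^k. The sketch below only evaluates this factor for a few discount values; the helper name reward_weight and the choice of k are illustrative assumptions, and 1 / (1 - gamma) is used as a common rule-of-thumb horizon rather than anything computed by the training code.

```python
def reward_weight(gamma, steps_before_reward):
    """Factor by which a reward is scaled when seen from `steps_before_reward` steps earlier."""
    return gamma ** steps_before_reward


for gamma in (0.5, 0.9, 0.99):
    weights = [reward_weight(gamma, k) for k in range(10)]
    horizon = 1.0 / (1.0 - gamma)  # rule-of-thumb effective horizon
    print("gamma={:.2f}  weight 9 steps before reward={:.4f}  horizon~{:.0f}".format(
        gamma, weights[-1], horizon))
```

With gamma=0.5 a move made nine steps before the treasure sees almost none of the reward, while with gamma=0.9 or 0.99 it still receives a sizeable share.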
With reward_decay=0.9, single_layer=True:
epoch 0 end total_conclusion : [[0.25 0.25 0.25 0.25 ] [0.24992794 0.25418103 0.24820147 0.24768955] [0.24982986 0.25840503 0.2463902 0.24537486] [0.24970558 0.26267165 0.24456646 0.2430563 ] [0.24955492 0.2669804 0.24273053 0.24073412]] epoch 10000 end total_conclusion : [[0.2541749 0.4339638 0.18223304 0.12962824] [0.31923154 0.29700777 0.1915905 0.19217019] [0.36765605 0.18639956 0.1847071 0.2612373 ] [0.3943957 0.10896233 0.1658623 0.3307797 ] [0.40119484 0.06040053 0.14123574 0.39716887]] epoch 20000 end total_conclusion : [[0.17127652 0.5577266 0.19322462 0.07777215] [0.21130787 0.3176208 0.3229146 0.14815669] [0.20633315 0.14316341 0.42711854 0.22338496] [0.17253096 0.05525859 0.48378655 0.28842393] [0.13284592 0.01964042 0.5045944 0.34291923]] epoch 30000 end total_conclusion : [[0.13867429 0.63206965 0.17291538 0.05634059] [0.18240736 0.3259061 0.35483548 0.13685109] [0.16338208 0.11442888 0.49583372 0.22635537] [0.11672037 0.03204491 0.5526184 0.29861635] [0.07565267 0.00814175 0.5587916 0.35741407]] epoch 40000 end total_conclusion : [[0.10081001 0.71501565 0.14359678 0.04057756] [0.13471448 0.39188603 0.34479922 0.12860028] [0.11042301 0.13174638 0.50783485 0.24999571] [0.06612735 0.03235891 0.5464557 0.35505807] [0.03474245 0.00697279 0.51587576 0.44240898]] epoch 50000 end total_conclusion : [[0.07768991 0.7635396 0.12839992 0.03037059] [0.10534117 0.41834876 0.355451 0.12085913] [0.07775372 0.12477711 0.53565377 0.26181543] [0.03906842 0.02533454 0.5495034 0.38609374] [0.01695424 0.00444262 0.4868603 0.4917428 ]] epoch 60000 end total_conclusion : [[0.05408907 0.8204231 0.10363895 0.02184892] [0.07231493 0.50978357 0.31268135 0.10522011] [0.05188116 0.16997969 0.50622576 0.27191344] [0.02303072 0.03506908 0.5071108 0.43478936] [0.00837533 0.00592719 0.41615787 0.5695396 ]] epoch 70000 end total_conclusion : [[0.05169893 0.82091075 0.10882547 0.0185649 ] [0.07353021 0.44676957 0.38063854 0.09906169] [0.04737122 0.11013763 0.60305846 0.23943269] 
[0.01917202 0.0170566 0.6002202 0.36355114] [0.00669014 0.00227753 0.51508164 0.47595078]] epoch 80000 end total_conclusion : [[0.04197544 0.84683263 0.09752993 0.01366197] [0.06214712 0.49029282 0.3653582 0.08220191] [0.04109263 0.12677413 0.6112474 0.22088583] [0.01621071 0.01955695 0.6101133 0.35411903] [0.00539158 0.00254359 0.5134281 0.47863674]] epoch 90000 end total_conclusion : [[0.0396813 0.864847 0.08391145 0.01156029] [0.06822652 0.5031161 0.34562966 0.08302777] [0.04827514 0.12044813 0.58587325 0.24540342] [0.01917449 0.01618683 0.55747616 0.40716252] [0.00626419 0.00178922 0.43630424 0.55564237]] epoch 100000 end total_conclusion : [[0.03582421 0.8740298 0.08067635 0.00946964] [0.06337273 0.50902355 0.35384497 0.07375869] [0.04422295 0.11694155 0.6122082 0.22662738] [0.01701888 0.01481624 0.58414865 0.3840162 ] [0.00538392 0.00154309 0.45817512 0.5348978 ]] epoch 110000 end total_conclusion : [[3.7808131e-02 8.7277049e-01 8.0167063e-02 9.2543894e-03] [7.3226146e-02 4.8032302e-01 3.6108947e-01 8.5361369e-02] [5.0292756e-02 9.3740024e-02 5.7675588e-01 2.7921137e-01] [1.8301729e-02 9.6931336e-03 4.8810893e-01 4.8389620e-01] [5.2883681e-03 7.9587736e-04 3.2800779e-01 6.6590798e-01]] epoch 120000 end total_conclusion : [[3.45975310e-02 8.83603871e-01 7.44067430e-02 7.39190215e-03] [7.09570944e-02 5.02655447e-01 3.52849692e-01 7.35377371e-02] [5.13085201e-02 1.00815214e-01 5.89943051e-01 2.57933199e-01] [1.90419760e-02 1.03779277e-02 5.06243348e-01 4.64336753e-01] [5.52772358e-03 8.35616316e-04 3.39797616e-01 6.53839052e-01]] epoch 130000 end total_conclusion : [[3.4428395e-02 8.8025826e-01 7.8599989e-02 6.7133084e-03] [7.1423702e-02 4.6414855e-01 3.9321566e-01 7.1212135e-02] [4.7560442e-02 7.8556448e-02 6.3141793e-01 2.4246517e-01] [1.6806141e-02 7.0554558e-03 5.3804874e-01 4.3808973e-01] [4.7259689e-03 5.0427718e-04 3.6486074e-01 6.2990904e-01]] epoch 140000 end total_conclusion : [[2.9190067e-02 8.9623755e-01 6.9009550e-02 5.5628545e-03] [6.3767701e-02 
5.0694317e-01 3.6289334e-01 6.6395715e-02] [4.4551507e-02 9.1704644e-02 6.1030209e-01 2.5344178e-01] [1.5246476e-02 8.1258537e-03 5.0275469e-01 4.7387296e-01] [3.9947815e-03 5.5126823e-04 3.1709099e-01 6.7836297e-01]] epoch 150000 end total_conclusion : [[2.61038337e-02 9.02455091e-01 6.66292384e-02 4.81185177e-03] [5.75109012e-02 5.16855896e-01 3.64215791e-01 6.14175014e-02] [3.96257900e-02 9.25753191e-02 6.22636378e-01 2.45162502e-01] [1.30827725e-02 7.94538576e-03 5.10039926e-01 4.68931943e-01] [3.27287335e-03 5.16704924e-04 3.16578746e-01 6.79631650e-01]] epoch 160000 end total_conclusion : [[2.3320269e-02 9.0283412e-01 6.9880642e-02 3.9650649e-03] [4.9500983e-02 4.9232301e-01 4.0799561e-01 5.0180390e-02] [3.0989032e-02 7.9178333e-02 7.0253527e-01 1.8729736e-01] [9.9952416e-03 6.5607498e-03 6.2326348e-01 3.6018053e-01] [2.5804520e-03 4.3512895e-04 4.4258085e-01 5.5440360e-01]] epoch 170000 end total_conclusion : [[2.0959368e-02 9.1501921e-01 6.0480446e-02 3.5409157e-03] [4.7873363e-02 5.4098064e-01 3.5951683e-01 5.1629137e-02] [3.2945268e-02 9.6364401e-02 6.4388281e-01 2.2680752e-01] [1.0355502e-02 7.8402609e-03 5.2671248e-01 4.5509183e-01] [2.4148514e-03 4.7324455e-04 3.1965497e-01 6.7745697e-01]] epoch 180000 end total_conclusion : [[2.0037860e-02 9.1648364e-01 6.0523044e-02 2.9554151e-03] [4.7534682e-02 5.0857818e-01 4.0178594e-01 4.2101104e-02] [3.0792855e-02 7.7067420e-02 7.2836429e-01 1.6377541e-01] [1.0028366e-02 5.8711735e-03 6.6380942e-01 3.2029104e-01] [2.6443410e-03 3.6214772e-04 4.8983011e-01 5.0716347e-01]] epoch 190000 end total_conclusion : [[2.0515777e-02 9.1274059e-01 6.3958667e-02 2.7848831e-03] [4.9688257e-02 4.6594495e-01 4.4379443e-01 4.0572379e-02] [2.9871482e-02 5.9041731e-02 7.6436615e-01 1.4672065e-01] [9.5903184e-03 3.9953678e-03 7.0306307e-01 2.8335130e-01] [2.5717402e-03 2.2582512e-04 5.4013824e-01 4.5706421e-01]] epoch 200000 end total_conclusion : [[1.9618945e-02 9.1322052e-01 6.4400353e-02 2.7601693e-03] [4.8081201e-02 
4.6147650e-01 4.4375980e-01 4.6682552e-02] [2.8066944e-02 5.5544838e-02 7.2832942e-01 1.8805878e-01] [8.2912296e-03 3.3833098e-03 6.0493928e-01 3.8338622e-01] [1.9035551e-03 1.6016315e-04 3.9049777e-01 6.0743850e-01]] epoch 210000 end total_conclusion : [[1.6461434e-02 9.1983438e-01 6.1292339e-02 2.4118708e-03] [3.8809031e-02 4.7240821e-01 4.4620728e-01 4.2575449e-02] [2.1110723e-02 5.5979695e-02 7.4950135e-01 1.7340830e-01] [5.7899337e-03 3.3445982e-03 6.3475835e-01 3.5610709e-01] [1.2497237e-03 1.5726314e-04 4.2307195e-01 5.7552105e-01]] epoch 220000 end total_conclusion : [[1.5052675e-02 9.2754936e-01 5.5169675e-02 2.2282486e-03] [3.8371310e-02 4.9669358e-01 4.1841790e-01 4.6517245e-02] [2.1696571e-02 5.8997225e-02 7.0390111e-01 2.1540506e-01] [5.5740857e-03 3.1839970e-03 5.3803563e-01 4.5320624e-01] [1.0480476e-03 1.2575880e-04 3.0097839e-01 6.9784772e-01]] epoch 230000 end total_conclusion : [[1.2719852e-02 9.3808472e-01 4.7526378e-02 1.6690409e-03] [3.3230912e-02 5.5534708e-01 3.7758598e-01 3.3835981e-02] [2.1167737e-02 8.0160193e-02 7.3142356e-01 1.6724853e-01] [5.9436020e-03 5.1003071e-03 6.2454700e-01 3.6440909e-01] [1.2554835e-03 2.4412917e-04 4.0118751e-01 5.9731287e-01]] epoch 240000 end total_conclusion : [[1.19138835e-02 9.40711200e-01 4.58427444e-02 1.53206289e-03] [3.24871019e-02 5.54068565e-01 3.79663557e-01 3.37807797e-02] [2.05819905e-02 7.58209676e-02 7.30543077e-01 1.73053950e-01] [5.63108036e-03 4.48067114e-03 6.07044756e-01 3.82843435e-01] [1.13851484e-03 1.95676825e-04 3.72767657e-01 6.25898123e-01]] epoch 250000 end total_conclusion : [[1.0488704e-02 9.4734508e-01 4.0677190e-02 1.4889601e-03] [2.9026536e-02 5.9247094e-01 3.3981869e-01 3.8683772e-02] [1.8703880e-02 8.6275846e-02 6.6100842e-01 2.3401189e-01] [4.4211950e-03 4.6087466e-03 4.7166994e-01 5.1930004e-01] [7.0127757e-04 1.6520331e-04 2.2584547e-01 7.7328807e-01]] epoch 260000 end total_conclusion : [[1.1348815e-02 9.4073093e-01 4.6427861e-02 1.4923185e-03] [3.2144837e-02 
5.1513839e-01 4.1301119e-01 3.9705586e-02] [1.7839961e-02 5.5271905e-02 7.1989143e-01 2.0699683e-01] [4.2135976e-03 2.5238392e-03 5.3400928e-01 4.5925325e-01] [7.0275116e-04 8.1378137e-05 2.7971739e-01 7.1949846e-01]] epoch 270000 end total_conclusion : [[1.0846548e-02 9.4296771e-01 4.4784609e-02 1.4010831e-03] [3.1814110e-02 5.2445072e-01 4.0349430e-01 4.0240850e-02] [1.8027861e-02 5.6351926e-02 7.0233160e-01 2.2328855e-01] [4.1229795e-03 2.4437420e-03 4.9338910e-01 5.0004423e-01] [6.4254826e-04 7.2215320e-05 2.3619159e-01 7.6309371e-01]] epoch 280000 end total_conclusion : [[9.1504268e-03 9.4728011e-01 4.2469598e-02 1.0999093e-03] [2.6146499e-02 5.3540409e-01 4.0784016e-01 3.0609239e-02] [1.4519225e-02 5.8808822e-02 7.6113093e-01 1.6554105e-01] [3.4599402e-03 2.7720346e-03 6.0957092e-01 3.8419718e-01] [5.9711642e-04 9.4628042e-05 3.5355282e-01 6.4575541e-01]] epoch 290000 end total_conclusion : [[7.7553848e-03 9.5317161e-01 3.8087185e-02 9.8578667e-04] [2.2035997e-02 5.7590652e-01 3.7154385e-01 3.0513545e-02] [1.2574041e-02 6.9878809e-02 7.2786993e-01 1.8967718e-01] [2.7378439e-03 3.2354214e-03 5.4411316e-01 4.4991362e-01] [4.0424295e-04 1.0158180e-04 2.7581939e-01 7.2367477e-01]] epoch 300000 end total_conclusion : [[7.4513480e-03 9.5213538e-01 3.9498676e-02 9.1447186e-04] [2.0736352e-02 5.5535841e-01 3.9457828e-01 2.9326942e-02] [1.0962928e-02 6.1538193e-02 7.4882549e-01 1.7867348e-01] [2.2978734e-03 2.7034695e-03 5.6342137e-01 4.3157732e-01] [3.2832366e-04 8.0960788e-05 2.8897664e-01 7.1061409e-01]]
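The epoch 0 rows of the two single-layer runs are close to 0.25 everywhere, which is what the small-range uniform initialization is meant to achieve. The sketch below reproduces the idea in NumPy under the assumption that the single layer computes logits = x * W + b with W drawn from U(0.48, 0.52) and b = 0.1, as in the single_layer branch of _build_net; the random seed and therefore the exact numbers differ from the TensorFlow run.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


rng = np.random.default_rng(1)
n_actions = 4
positions = np.arange(5, dtype=np.float64).reshape(5, 1)  # the five positions 0..4

# Single dense layer with small-range uniform weights and a constant bias.
W = rng.uniform(0.48, 0.52, size=(1, n_actions))
b = np.full(n_actions, 0.1)

init_probs = softmax(positions @ W + b)
print(init_probs)
print("largest deviation from uniform:", np.abs(init_probs - 0.25).max())
```

Position 0 multiplies every weight by zero, so its row is exactly uniform; for the other positions the logits differ by at most 0.04 times the position, which keeps every probability within a few hundredths of 0.25.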
With reward_decay=0.9, single_layer=False:
epoch 0 end total_conclusion : [[0.24272624 0.2664516 0.2431751 0.24764712] [0.25305024 0.25117907 0.23963916 0.25613153] [0.25990137 0.24152242 0.23768328 0.26089293] [0.26320875 0.2369366 0.23683953 0.26301512] [0.2643011 0.23534094 0.23657104 0.263787 ]] epoch 10000 end total_conclusion : [[0.01996518 0.96438617 0.00703121 0.00861744] [0.06497823 0.8767139 0.02951219 0.02879566] [0.20337433 0.01869064 0.53549784 0.24243721] [0.17275733 0.00595712 0.58486766 0.23641789] [0.17022285 0.00502265 0.5905856 0.23416895]] epoch 20000 end total_conclusion : [[0.00977914 0.9852465 0.00305074 0.00192364] [0.04920592 0.91269904 0.02295803 0.01513704] [0.1082299 0.00256753 0.5223677 0.36683488] [0.10114782 0.00176557 0.51610726 0.38097933] [0.10053636 0.00169165 0.51489586 0.38287607]] epoch 30000 end total_conclusion : [[0.01595607 0.9805597 0.00194518 0.00153896] [0.02344857 0.97012335 0.00357197 0.00285605] [0.05884041 0.00577983 0.67154795 0.26383168] [0.0345589 0.00144537 0.6769783 0.28701743] [0.03300219 0.00128465 0.67455024 0.291163 ]] epoch 40000 end total_conclusion : [[1.4459175e-02 9.8325908e-01 1.3620944e-03 9.1964524e-04] [2.0464014e-02 9.7535014e-01 2.4464938e-03 1.7393864e-03] [3.4550514e-02 2.8977739e-03 7.1116060e-01 2.5139111e-01] [2.1084290e-02 9.5899979e-04 6.8814623e-01 2.8981048e-01] [2.0119997e-02 8.7089848e-04 6.8306744e-01 2.9594159e-01]] epoch 50000 end total_conclusion : [[9.5187183e-03 9.8891360e-01 1.0428770e-03 5.2477716e-04] [1.3581087e-02 9.8361140e-01 1.8001809e-03 1.0072116e-03] [3.2373980e-02 1.6586123e-03 7.0518827e-01 2.6077914e-01] [2.0910097e-02 6.2694255e-04 6.5031850e-01 3.2814443e-01] [2.0023042e-02 5.7456945e-04 6.4101899e-01 3.3838332e-01]] epoch 60000 end total_conclusion : [[9.6859140e-03 9.8901761e-01 9.3515235e-04 3.6124868e-04] [1.2463641e-02 9.8555148e-01 1.3848786e-03 5.9998740e-04] [2.5720207e-02 1.6818406e-03 7.3695904e-01 2.3563893e-01] [1.4572582e-02 5.1689544e-04 6.3634038e-01 3.4857017e-01] [1.3729123e-02 
4.6342265e-04 6.1854339e-01 3.6726403e-01]] epoch 70000 end total_conclusion : [[8.9608859e-03 9.8972487e-01 1.0660614e-03 2.4815302e-04] [2.4359480e-02 9.6856683e-01 5.6313323e-03 1.4423265e-03] [1.7355645e-02 7.6086522e-04 7.0904922e-01 2.7283433e-01] [1.3102032e-02 4.5217722e-04 5.7744300e-01 4.0900281e-01] [1.2477216e-02 4.1652124e-04 5.5157876e-01 4.3552756e-01]] epoch 80000 end total_conclusion : [[1.21220425e-02 9.86533999e-01 1.09698437e-03 2.46952492e-04] [2.35213954e-02 9.72076714e-01 3.56799131e-03 8.33907863e-04] [1.36275385e-02 9.08883812e-04 8.48975301e-01 1.36488259e-01] [9.87306144e-03 4.85680881e-04 6.95991516e-01 2.93649793e-01] [9.20754019e-03 4.30718996e-04 6.44228756e-01 3.46133053e-01]] epoch 90000 end total_conclusion : [[1.1520125e-02 9.8696262e-01 1.3372508e-03 1.7998939e-04] [1.3917091e-02 9.8403889e-01 1.7830743e-03 2.6083150e-04] [2.3827225e-02 2.6155615e-03 8.6892623e-01 1.0463096e-01] [1.1177201e-02 5.5414089e-04 5.4745471e-01 4.4081402e-01] [9.1576474e-03 4.1186428e-04 4.4888943e-01 5.4154110e-01]] epoch 100000 end total_conclusion : [[1.1328434e-02 9.8669565e-01 1.8585753e-03 1.1740912e-04] [1.3247624e-02 9.8427898e-01 2.3172286e-03 1.5621875e-04] [2.6637968e-02 3.3733707e-03 8.9636964e-01 7.3619053e-02] [9.7785229e-03 4.8096036e-04 3.5935339e-01 6.3038719e-01] [7.1577146e-03 3.1912065e-04 2.5883538e-01 7.3368782e-01]] epoch 110000 end total_conclusion : [[1.0035030e-02 9.8829287e-01 1.5927270e-03 7.9357873e-05] [1.1977794e-02 9.8583788e-01 2.0754482e-03 1.0889151e-04] [1.4448395e-02 1.5947536e-03 9.2113924e-01 6.2817626e-02] [6.0290424e-03 2.9274501e-04 3.2462633e-01 6.6905183e-01] [4.6168631e-03 2.0721863e-04 2.3333853e-01 7.6183736e-01]] epoch 120000 end total_conclusion : [[9.5429290e-03 9.8864740e-01 1.7501512e-03 5.9503229e-05] [1.0475075e-02 9.8746568e-01 1.9891693e-03 7.0009039e-05] [2.1429989e-02 4.0749395e-03 9.4946623e-01 2.5028812e-02] [6.0653742e-03 2.9634649e-04 2.7282944e-01 7.2080880e-01] [3.8274133e-03 1.6645205e-04 
1.6236641e-01 8.3363974e-01]] epoch 130000 end total_conclusion : [[7.3853144e-03 9.9109393e-01 1.4747812e-03 4.6016314e-05] [7.8640431e-03 9.9048471e-01 1.6001627e-03 5.1146781e-05] [3.3291042e-02 1.1841053e-02 9.3784636e-01 1.7021498e-02] [6.3327486e-03 2.6507085e-04 2.6024795e-01 7.3315424e-01] [3.6769619e-03 1.3145835e-04 1.4505264e-01 8.5113889e-01]] epoch 140000 end total_conclusion : [[6.7819259e-03 9.9175429e-01 1.4199471e-03 4.3870627e-05] [7.4995244e-03 9.9082559e-01 1.6232599e-03 5.1661998e-05] [1.9932389e-02 2.3596454e-03 9.4132286e-01 3.6385126e-02] [4.6057180e-03 1.5399103e-04 1.8117251e-01 8.1406778e-01] [3.6281743e-03 1.1328981e-04 1.4112492e-01 8.5513365e-01]] epoch 150000 end total_conclusion : [[7.8055235e-03 9.9048704e-01 1.6636901e-03 4.3849177e-05] [8.3282758e-03 9.8981315e-01 1.8098198e-03 4.8767695e-05] [2.6316063e-02 6.7005209e-03 9.5382625e-01 1.3157116e-02] [6.0292105e-03 2.4533601e-04 2.4495459e-01 7.4877083e-01] [3.3822318e-03 1.2004423e-04 1.2837136e-01 8.6812639e-01]] epoch 160000 end total_conclusion : [[8.25789664e-03 9.89703715e-01 2.00118846e-03 3.71836068e-05] [9.23451316e-03 9.88391459e-01 2.32919818e-03 4.47779967e-05] [1.51501503e-02 1.16333005e-03 8.83565664e-01 1.00120865e-01] [2.56293686e-03 9.47155131e-05 8.89582932e-02 9.08384085e-01] [2.28234520e-03 8.22104776e-05 7.81738386e-02 9.19461608e-01]] epoch 170000 end total_conclusion : [[8.3111525e-03 9.8970544e-01 1.9454649e-03 3.8032471e-05] [8.7420335e-03 9.8913711e-01 2.0793625e-03 4.1358129e-05] [2.0256471e-02 4.4284128e-03 9.6530449e-01 1.0010734e-02] [4.4577671e-03 1.7727882e-04 2.0603830e-01 7.8932661e-01] [2.4488093e-03 8.8382440e-05 9.6857719e-02 9.0060514e-01]] epoch 180000 end total_conclusion : [[6.74042059e-03 9.91177499e-01 2.05093599e-03 3.11263211e-05] [7.52371922e-03 9.90062356e-01 2.37671356e-03 3.72426912e-05] [1.34547865e-02 8.92202544e-04 8.75965297e-01 1.09687738e-01] [2.35858560e-03 7.37307128e-05 7.19583929e-02 9.25609291e-01] [2.14538001e-03 
6.55836702e-05 6.43573254e-02 9.33431685e-01]] epoch 190000 end total_conclusion : [[4.7103683e-03 9.9364752e-01 1.6120900e-03 2.9957560e-05] [4.9255760e-03 9.9333936e-01 1.7029172e-03 3.2125514e-05] [2.5764411e-02 9.7385934e-03 9.5587242e-01 8.6245174e-03] [7.0717232e-03 1.9899267e-04 2.5494295e-01 7.3778635e-01] [3.3879858e-03 7.7527067e-05 9.4206490e-02 9.0232801e-01]] epoch 200000 end total_conclusion : [[5.5062859e-03 9.9220121e-01 2.2599392e-03 3.2476153e-05] [6.4469520e-03 9.9072838e-01 2.7833492e-03 4.1395240e-05] [1.2959373e-02 8.8687212e-04 9.4337398e-01 4.2779762e-02] [3.4274671e-03 8.8625340e-05 8.0753848e-02 9.1573012e-01] [2.9254081e-03 7.2692987e-05 6.6195659e-02 9.3080622e-01]] epoch 210000 end total_conclusion : [[6.7050313e-03 9.9026489e-01 2.9965141e-03 3.3478875e-05] [7.4811736e-03 9.8902625e-01 3.4527977e-03 3.9731662e-05] [1.1403289e-02 1.2215463e-03 9.7320312e-01 1.4172052e-02] [3.3910186e-03 1.0862275e-04 7.8053258e-02 9.1844708e-01] [2.4969408e-03 7.4855518e-05 5.2599501e-02 9.4482875e-01]] epoch 220000 end total_conclusion : [[8.1366943e-03 9.8807412e-01 3.7588137e-03 3.0369394e-05] [8.8895541e-03 9.8685741e-01 4.2179059e-03 3.5078654e-05] [1.0125207e-02 1.4564092e-03 9.7747517e-01 1.0943188e-02] [2.3958066e-03 9.7383920e-05 5.6047663e-02 9.4145912e-01] [1.6963906e-03 6.4517873e-05 3.5769559e-02 9.6246952e-01]] epoch 230000 end total_conclusion : [[1.0783126e-02 9.8507023e-01 4.1133207e-03 3.3298995e-05] [1.1605123e-02 9.8381996e-01 4.5371768e-03 3.7690890e-05] [1.0146837e-02 1.9443299e-03 9.7945261e-01 8.4562581e-03] [2.1244360e-03 1.1519293e-04 6.1244793e-02 9.3651557e-01] [1.3985953e-03 7.0800306e-05 3.6053002e-02 9.6247756e-01]] epoch 240000 end total_conclusion : [[1.3546581e-02 9.8178297e-01 4.6384856e-03 3.1966913e-05] [1.4745870e-02 9.8001206e-01 5.2051023e-03 3.6996076e-05] [1.1386400e-02 1.4909450e-03 9.6796215e-01 1.9160511e-02] [1.2758711e-03 7.5146745e-05 3.2815725e-02 9.6583331e-01] [1.0705406e-03 6.1382270e-05 2.6746141e-02 
9.7212189e-01]] epoch 250000 end total_conclusion : [[1.8375706e-02 9.7645175e-01 5.1351385e-03 3.7367889e-05] [1.9848045e-02 9.7437811e-01 5.7308432e-03 4.2954627e-05] [1.0006342e-02 1.5601777e-03 9.7213757e-01 1.6295882e-02] [1.1278926e-03 8.7725217e-05 3.5601538e-02 9.6318287e-01] [9.2888821e-04 7.0503593e-05 2.8413255e-02 9.7058737e-01]]
The last one is the convergence result we want. (The feature here is still only the position, so a shallow network cannot be used.)
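A possible reading of why the single layer struggles with a position-only feature: with logits that are linear in the position, the log-odds between any two actions are also linear in the position, so the rows of total_conclusion cannot move independently of each other. The sketch below checks this numerically for randomly chosen weights; the weights are arbitrary and are not the trained ones.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


rng = np.random.default_rng(0)
positions = np.arange(5, dtype=np.float64).reshape(5, 1)

# Arbitrary single-layer parameters (one input feature, four actions).
W = rng.normal(size=(1, 4))
b = rng.normal(size=4)
layer_probs = softmax(positions @ W + b)

# Log-odds between action 'e' (index 1) and action 'w' (index 3) at each position.
log_odds = np.log(layer_probs[:, 1]) - np.log(layer_probs[:, 3])

# For a linear function of the position, all second differences vanish.
second_diff = np.diff(log_odds, n=2)
print(log_odds)
print(second_diff)
```

Because the log-odds change by the same amount from one position to the next, making one action very likely at one position drags its probability at the neighbouring positions along, which is consistent with the coupling between rows seen in the single-layer outputs. The tanh hidden layers remove this restriction.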
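This optimization method is also a very simple and rather crude one. Comparing the number of iterations shows that
here we need about 200000 epochs before the result becomes reasonably satisfactory, and every call to learn requires at least 1000 accumulated observations. In terms of performance and quality this is not necessarily better than Q-Learning, which updates at every step. For a small, bounded, countable action space, Q-learning is the better choice.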
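For comparison, a tabular Q-learning update touches the table after every single transition. The sketch below is a generic illustration with a 5 x 4 table matching the five positions and four actions used here; the learning rate, discount factor and the two transitions are made-up values and are not taken from the previous section's code.

```python
import numpy as np

n_states, n_actions = 5, 4
q_table = np.zeros((n_states, n_actions))
alpha, gamma = 0.1, 0.9  # hypothetical learning rate and discount factor


def q_update(q, state, action, reward, next_state, terminal):
    """One Q-learning step: move Q(s, a) towards reward + gamma * max_a' Q(s', a')."""
    target = reward if terminal else reward + gamma * np.max(q[next_state])
    q[state, action] += alpha * (target - q[state, action])


# Hypothetical transition that ends the game with reward 1.
q_update(q_table, state=3, action=1, reward=1.0, next_state=None, terminal=True)
# Hypothetical earlier transition leading into state 3, with no immediate reward.
q_update(q_table, state=2, action=1, reward=0.0, next_state=3, terminal=False)
print(q_table)
```

Each transition is used immediately, so information about the reward starts to propagate backwards after a handful of steps instead of waiting for more than 1000 stored samples.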