Q-learning之一维世界的简单寻宝

Q-learning的算法:

(1)先初始化一个Q table,Q table的行数是state的个数,列数是action的个数。

(2)先随机选择一个作为初始状态S1,根据一些策略选择此状态下的动作,比如贪心策略,假设选择的动作为A1。

(3)判断由A1动作之后的状态S2是不是终止状态,如果是终止状态,返回的reward,相当于找到了宝藏,游戏结束,如果不是最终状态,在S2状态时选择此时使Q值最大的action作为下一步的动作。可以得到一个实际的Q值。Q(S1,A1)=R+λ*maxQ(S2)。更新Q table中的Q(S1,A1)。Q(S1,A1)=Q(S1,A1)+α*[R+λ*maxQ(S2)-Q(S1,A1)], []里面是实际的Q值减去估计的Q值。

简单的代码如下:

 1 #coding=utf-8
 2 import numpy as np
 3 import pandas as pd
 4 import time
 5 #计算机产生一段伪随机数,每次运行的时候产生的随机数都是一样的
 6 np.random.seed(2)
 7 #创建几个全局变量
 8 N_STATES=6#状态的个数,一共有六个状态0-5状态
 9 ACTIONS=["left","right"]#action只有两个左和右
10 EPSILON=0.9#贪心策略
11 ALPHA=0.1#学习率
12 LAMBDA=0.9#discount factor
13 MAX_EPISODEs=10#一共训练10次
14 FRESH_TIME=0.1
15 #初始化一个Q-table,我觉得Q-table里面的值初始化成什么样子应该不影响最终的结果
16 def build_q_table(n_states,actions):
17     table=pd.DataFrame(
18         np.zeros((n_states,len(actions))),
19         columns=actions,
20     )
21     # print(table)
22     return(table)
23 # build_q_table(N_STATES,ACTIONS)
24 def choose_action(state,q_table):
25     state_action=q_table.iloc[state,:]
26     if (np.random.uniform()>EPSILON) or (state_action.all()==0):
27         action_name=np.random.choice(ACTIONS)
28     else:
29         action_name=state_action.idxmax()
30     return action_name
31 def get_env_feedback(s,A):
32     if A=="right":
33         if s==N_STATES-2:
34             s_="terminal"
35             R=1
36         else:
37             s_=s+1
38             R=0
39     else:
40         R=0
41         if s==0:
42             s_=s
43         else:
44             s_=s-1
45     return s_,R
46 def update_env(S,episode,step_couter):
47     env_list=["-"]*(N_STATES-1)+["T"]
48     if S=="terminal":
49         interaction="Episode %s:total_steps=%s"%(episode+1,step_couter)
50         print("\r{}".format(interaction),end='')
51         time.sleep(2)
52         print('\r                        ',end='')
53     else:
54         env_list[S]='0'
55         interaction=''.join(env_list)
56         print("\r{}".format(interaction),end='')
57         time.sleep(FRESH_TIME)
58 def rl():
59     #先初始化一个Q table
60     q_table=build_q_table(N_STATES,ACTIONS)
61     for episode in range(MAX_EPISODEs):
62         step_counter=0
63         #选择一个初始的S
64         S=0
65         is_terminal=False
66         update_env(S,episode,step_counter)
67         #如果S不是终止状态的话,选择动作,得到环境给出的一个反馈S_(新的状态)和R(奖励)
68         while not is_terminal:
69             A=choose_action(S,q_table)
70             S_,R=get_env_feedback(S,A)
71             q_predict=q_table.ix[S,A]
72             if S_!="terminal":
73                 #算出来实际的Q值
74                 q_target=R+LAMBDA*q_table.iloc[S_,:].max()
75             else:
76                 q_target=R
77                 is_terminal=True
78             q_table.ix[S,A]+=ALPHA*(q_target-q_predict)
79             S=S_
80             update_env(
81                 S,episode,step_counter+1
82             )
83             step_counter=step_counter+1
84     return q_table
85 
86 if __name__=="__main__":
87     q_table=rl()
88     print("\r\nQ-table:\n")
89     print(q_table)

猜你喜欢

转载自www.cnblogs.com/hellojiaojiao/p/11352796.html
今日推荐