一、环境构建
搭建一个简单的迷宫环境,红色位置出发,黑色位置代表失败,黄色位置代表成功,让红色块慢慢通过不断探索学习的方式走到黄色的位置
#初始化迷宫
def _build_maze(self):
h = self.MAZE_H*self.UNIT
w = self.MAZE_W*self.UNIT
#初始化画布
self.canvas = tk.Canvas(self, bg='white', height=h, width=w)
#画线
for c in range(0, w, self.UNIT):
self.canvas.create_line(c, 0, c, h)
for r in range(0, h, self.UNIT):
self.canvas.create_line(0, r, w, r)
# 陷阱
self.hells = [self._draw_rect(3, 2, 'black'),
self._draw_rect(3, 3, 'black'),
self._draw_rect(3, 4, 'black'),
self._draw_rect(3, 5, 'black'),
self._draw_rect(4, 1, 'black'),
self._draw_rect(4, 5, 'black'),
self._draw_rect(1, 0, 'black'),
self._draw_rect(1, 1, 'black'),
self._draw_rect(1, 2, 'black'),
self._draw_rect(1, 3, 'black'),
self._draw_rect(1, 4, 'black')]
self.hell_coords = []
for hell in self.hells:
self.hell_coords.append(self.canvas.coords(hell))
# 奖励
self.oval = self._draw_rect(4, 5, 'yellow')
# 玩家对象
self.rect = self._draw_rect(0, 0, 'red')
self.canvas.pack()
然后就是实现走迷宫的动作了,“上下左右”走到对应的位置得到不同的结果,如果走到了黑块就得到-1的惩罚并结束回合,走到黄块得到1的奖励并结束回合,当然需要返回当前的行走策略得到的奖励(或惩罚)
def step(self, action):
s = self.canvas.coords(self.rect)
base_action = np.array([0, 0])
if action == 0: # up
if s[1] > self.UNIT:
base_action[1] -= self.UNIT
elif action == 1: # down
if s[1] < (self.MAZE_H - 1) * self.UNIT:
base_action[1] += self.UNIT
elif action == 2: # right
if s[0] < (self.MAZE_W - 1) * self.UNIT:
base_action[0] += self.UNIT
elif action == 3: # left
if s[0] > self.UNIT:
base_action[0] -= self.UNIT
#根据策略移动红块
self.canvas.move(self.rect, base_action[0], base_action[1])
s_ = self.canvas.coords(self.rect)
#判断是否得到奖励或惩罚
done = False
if s_ == self.canvas.coords(self.oval):
reward = 1
done = True
elif s_ in self.hell_coords:
reward = -1
done = True
#elif base_action.sum() == 0:
# reward = -1
else:
reward = 0
self.old_s = s
return s_, reward, done
二、实现Q Learning
Q-Learning的原理很简单,就是用一张Q表来记录每个状态下取不同的策略(action)的权值,而权值是根据历史经验(得到的奖励、惩罚)来不断更新得到的
这是根据Q表来得到价值最高的步骤,当然为了有探索性所以给了一定权重进行完全随机
#选择动作
def choose_action(self, s):
self.check_state_exist(s)
if np.random.uniform() < self.e_greedy:
state_action = self.q_table.ix[s, :]
state_action = state_action.reindex(
np.random.permutation(state_action.index)) #防止相同列值时取第一个列,所以打乱列的顺序
action = state_action.argmax()
else:
action = np.random.choice(self.actions)
return action
另外就是记录当前的状态,下一步的动作,这个动作得到的奖励或惩罚根据这个核心算法更新到Q表中
#更新q表
def rl(self, s, a, r, s_):
self.check_state_exist(s_)
q_predict = self.q_table.ix[s, a] #q估计
if s_ != 'terminal':
q_target = r + self.reward_decay * self.q_table.ix[s_, :].max() #q现实
else:
q_target = r
self.q_table.ix[s, a] += self.learning_rate * (q_target - q_predict)
三、训练实验
训练的步骤是
1、根据当前的状态得到下一个步骤
2、执行这个步骤,得到执行后的状态
3、记录算法计算出的权值
def update():
for episode in range(100):
s = env.reset()
while True:
env.render()
#选择一个动作
action = RL.choose_action(str(s))
#执行这个动作得到反馈(下一个状态s 奖励r 是否结束done)
s_, r, done = env.step(action)
#更新状态表
RL.rl(str(s), action, r, str(s_))
s = s_
if done:
print(episode)
break
我在第25轮的时候得到第一次奖励,等到了第50轮基本就是走最短路径了
四、Q表的解读
这里我们的Q表的数据结构是以动作为列,每一行是对应不同的状态。初始化的时候是这样的:
Empty DataFrame
Columns: [0, 1, 2, 3]
Index: []
0 1 2 3
[5.0, 5.0, 35.0, 35.0] 0.0 0.0 0.0 0.0
这里的[5.0, 5.0, 35.0, 35.0]代表当前我们的在迷宫中的状态(位置)。
接下来我们随机走一格,根据走一格后的结果(奖励)进行更新到当前这个状态中,由于我们的宝箱很远,而陷阱很多,所以如果走到陷阱的时候会得到一个负奖励,如:
0 1 2 3
[5.0, 5.0, 35.0, 35.0] 0.0 0.0 -0.01 0.0
[5.0, 45.0, 35.0, 75.0] 0.0 0.0 -0.01 0.0
[45.0, 5.0, 75.0, 35.0] 0.0 0.0 0.00 0.0
[5.0, 85.0, 35.0, 115.0] 0.0 0.0 -0.01 0.0
[45.0, 85.0, 75.0, 115.0] 0.0 0.0 0.00 0.0
[45.0, 45.0, 75.0, 75.0] 0.0 0.0 0.00 0.0
[5.0, 125.0, 35.0, 155.0] 0.0 0.0 -0.01 0.0
[45.0, 125.0, 75.0, 155.0] 0.0 0.0 0.00 0.0
[5.0, 165.0, 35.0, 195.0] 0.0 0.0 -0.01 0.0
这些为负数的动作代表了这个状态下这么走会遇到陷阱,随着不断的试错,这张表不断的完善,终于拿到第一个奖励(估计在第30轮)
('----------------', '[205.0, 205.0, 235.0, 235.0]', '------------------', 3, '---------------', 1, '----------------', '[165.0, 205.0, 195.0, 235.0]')
0 1 2 3
[5.0, 5.0, 35.0, 35.0] 0.00 0.00 -0.039404 0.000000
[5.0, 45.0, 35.0, 75.0] 0.00 0.00 -0.019900 0.000000
[45.0, 45.0, 75.0, 75.0] 0.00 0.00 0.000000 0.000000
[5.0, 85.0, 35.0, 115.0] 0.00 0.00 -0.019900 0.000000
[45.0, 85.0, 75.0, 115.0] 0.00 0.00 0.000000 0.000000
[45.0, 5.0, 75.0, 35.0] 0.00 0.00 0.000000 0.000000
[5.0, 125.0, 35.0, 155.0] 0.00 0.00 -0.010000 0.000000
[45.0, 125.0, 75.0, 155.0] 0.00 0.00 0.000000 0.000000
[5.0, 165.0, 35.0, 195.0] 0.00 0.00 -0.010000 0.000000
[5.0, 205.0, 35.0, 235.0] 0.00 0.00 0.000000 0.000000
[45.0, 205.0, 75.0, 235.0] -0.01 0.00 0.000000 0.000000
[45.0, 165.0, 75.0, 195.0] 0.00 0.00 0.000000 0.000000
[85.0, 205.0, 115.0, 235.0] 0.00 0.00 -0.010000 0.000000
[125.0, 205.0, 155.0, 235.0] 0.00 0.00 0.000000 0.000000
[85.0, 165.0, 115.0, 195.0] 0.00 0.00 -0.010000 -0.010000
[85.0, 125.0, 115.0, 155.0] 0.00 0.00 -0.010000 -0.010000
[125.0, 125.0, 155.0, 155.0] 0.00 0.00 0.000000 0.000000
[85.0, 85.0, 115.0, 115.0] 0.00 0.00 -0.010000 -0.010000
[125.0, 165.0, 155.0, 195.0] 0.00 0.00 0.000000 0.000000
[85.0, 45.0, 115.0, 75.0] 0.00 0.00 0.000000 -0.029701
[85.0, 5.0, 115.0, 35.0] 0.00 0.00 0.000000 -0.010000
[125.0, 85.0, 155.0, 115.0] 0.00 0.00 0.000000 0.000000
[125.0, 45.0, 155.0, 75.0] 0.00 -0.01 -0.010000 0.000000
[125.0, 5.0, 155.0, 35.0] 0.00 0.00 0.000000 0.000000
[165.0, 5.0, 195.0, 35.0] 0.00 -0.01 0.000000 0.000000
[205.0, 5.0, 235.0, 35.0] 0.00 0.00 0.000000 0.000000
[205.0, 45.0, 235.0, 75.0] 0.00 0.00 0.000000 -0.010000
[165.0, 45.0, 195.0, 75.0] 0.00 0.00 0.000000 0.000000
[205.0, 85.0, 235.0, 115.0] 0.00 0.00 0.000000 0.000000
[205.0, 125.0, 235.0, 155.0] 0.00 0.00 0.000000 0.000000
[165.0, 85.0, 195.0, 115.0] -0.01 0.00 0.000000 -0.010000
[165.0, 125.0, 195.0, 155.0] 0.00 0.00 0.000000 -0.010000
[205.0, 165.0, 235.0, 195.0] 0.00 0.00 0.000000 0.000000
[165.0, 165.0, 195.0, 195.0] 0.00 0.00 0.000000 -0.010000
[205.0, 205.0, 235.0, 235.0] 0.00 0.00 0.000000 0.010000
[165.0, 205.0, 195.0, 235.0] 0.00 0.00 0.000000 0.000000
以这次为分水岭,再继续走会不断根据得到的那个奖励逐步把可以得到奖励的值更新到最优路径上,后期的表会是这样,基本可以一次就走到迷宫上去了,这是我迭代99步的Q表:
0 1 2 \
[5.0, 5.0, 35.0, 35.0] 0.000000e+00 7.215152e-30 -5.851985e-02
[45.0, 5.0, 75.0, 35.0] 0.000000e+00 0.000000e+00 0.000000e+00
[5.0, 45.0, 35.0, 75.0] 0.000000e+00 4.987220e-28 -3.940399e-02
[45.0, 45.0, 75.0, 75.0] 0.000000e+00 0.000000e+00 0.000000e+00
[5.0, 85.0, 35.0, 115.0] 0.000000e+00 3.331252e-26 -1.000000e-02
[45.0, 85.0, 75.0, 115.0] 0.000000e+00 0.000000e+00 0.000000e+00
[5.0, 125.0, 35.0, 155.0] 0.000000e+00 2.139487e-24 -4.900995e-02
[5.0, 165.0, 35.0, 195.0] 0.000000e+00 1.288811e-22 -4.900995e-02
[45.0, 125.0, 75.0, 155.0] 0.000000e+00 0.000000e+00 0.000000e+00
[45.0, 165.0, 75.0, 195.0] 0.000000e+00 0.000000e+00 0.000000e+00
[5.0, 205.0, 35.0, 235.0] 0.000000e+00 0.000000e+00 7.144125e-21
[45.0, 205.0, 75.0, 235.0] -2.970100e-02 0.000000e+00 3.617735e-19
[85.0, 205.0, 115.0, 235.0] 1.668205e-17 0.000000e+00 -3.940399e-02
[85.0, 165.0, 115.0, 195.0] 6.976738e-16 9.453308e-30 -1.990000e-02
[125.0, 205.0, 155.0, 235.0] 0.000000e+00 0.000000e+00 0.000000e+00
[85.0, 125.0, 115.0, 155.0] 2.633329e-14 0.000000e+00 -2.970100e-02
[85.0, 85.0, 115.0, 115.0] 8.916751e-13 0.000000e+00 -1.000000e-02
[85.0, 45.0, 115.0, 75.0] 0.000000e+00 0.000000e+00 2.689670e-11
[125.0, 165.0, 155.0, 195.0] 0.000000e+00 0.000000e+00 0.000000e+00
[125.0, 125.0, 155.0, 155.0] 0.000000e+00 0.000000e+00 0.000000e+00
[85.0, 5.0, 115.0, 35.0] 0.000000e+00 0.000000e+00 0.000000e+00
[125.0, 85.0, 155.0, 115.0] 0.000000e+00 0.000000e+00 0.000000e+00
[125.0, 45.0, 155.0, 75.0] 7.171044e-10 -1.000000e-02 -1.000000e-02
[125.0, 5.0, 155.0, 35.0] 0.000000e+00 0.000000e+00 1.676346e-08
[165.0, 5.0, 195.0, 35.0] 1.396498e-13 -1.000000e-02 3.409809e-07
[165.0, 45.0, 195.0, 75.0] 0.000000e+00 0.000000e+00 0.000000e+00
[205.0, 5.0, 235.0, 35.0] 0.000000e+00 5.989715e-06 3.319001e-13
[205.0, 45.0, 235.0, 75.0] 0.000000e+00 8.988374e-05 2.131374e-08
[205.0, 85.0, 235.0, 115.0] 0.000000e+00 1.126386e-03 0.000000e+00
[165.0, 85.0, 195.0, 115.0] -1.000000e-02 2.585824e-04 0.000000e+00
[205.0, 125.0, 235.0, 155.0] 0.000000e+00 0.000000e+00 0.000000e+00
[205.0, 165.0, 235.0, 195.0] 1.354722e-07 0.000000e+00 0.000000e+00
[205.0, 205.0, 235.0, 235.0] 0.000000e+00 0.000000e+00 0.000000e+00
[165.0, 165.0, 195.0, 195.0] 4.077995e-04 3.949939e-01 0.000000e+00
[165.0, 125.0, 195.0, 155.0] 0.000000e+00 8.241680e-02 1.012101e-04
[165.0, 205.0, 195.0, 235.0] 0.000000e+00 0.000000e+00 0.000000e+00
3
[5.0, 5.0, 35.0, 35.0] 0.000000e+00
[45.0, 5.0, 75.0, 35.0] 0.000000e+00
[5.0, 45.0, 35.0, 75.0] 0.000000e+00
[45.0, 45.0, 75.0, 75.0] 0.000000e+00
[5.0, 85.0, 35.0, 115.0] 0.000000e+00
[45.0, 85.0, 75.0, 115.0] 0.000000e+00
[5.0, 125.0, 35.0, 155.0] 0.000000e+00
[5.0, 165.0, 35.0, 195.0] 0.000000e+00
[45.0, 125.0, 75.0, 155.0] 0.000000e+00
[45.0, 165.0, 75.0, 195.0] 0.000000e+00
[5.0, 205.0, 35.0, 235.0] 0.000000e+00
[45.0, 205.0, 75.0, 235.0] 0.000000e+00
[85.0, 205.0, 115.0, 235.0] 0.000000e+00
[85.0, 165.0, 115.0, 195.0] -1.990000e-02
[125.0, 205.0, 155.0, 235.0] 0.000000e+00
[85.0, 125.0, 115.0, 155.0] -1.000000e-02
[85.0, 85.0, 115.0, 115.0] -1.000000e-02
[85.0, 45.0, 115.0, 75.0] -1.990000e-02
[125.0, 165.0, 155.0, 195.0] 0.000000e+00
[125.0, 125.0, 155.0, 155.0] 0.000000e+00
[85.0, 5.0, 115.0, 35.0] -1.990000e-02
[125.0, 85.0, 155.0, 115.0] 0.000000e+00
[125.0, 45.0, 155.0, 75.0] 0.000000e+00
[125.0, 5.0, 155.0, 35.0] 0.000000e+00
[165.0, 5.0, 195.0, 35.0] 0.000000e+00
[165.0, 45.0, 195.0, 75.0] 0.000000e+00
[205.0, 5.0, 235.0, 35.0] 0.000000e+00
[205.0, 45.0, 235.0, 75.0] -1.000000e-02
[205.0, 85.0, 235.0, 115.0] 7.290000e-09
[165.0, 85.0, 195.0, 115.0] -1.000000e-02
[205.0, 125.0, 235.0, 155.0] 1.185054e-02
[205.0, 165.0, 235.0, 195.0] 0.000000e+00
[205.0, 205.0, 235.0, 235.0] 0.000000e+00
[165.0, 165.0, 195.0, 195.0] -1.000000e-02
[165.0, 125.0, 195.0, 155.0] -1.000000e-02
[165.0, 205.0, 195.0, 235.0] 0.000000e+00
五、完整的代码
1、环境
# coding: utf-8
import sys
import time
import numpy as np
if sys.version_info.major == 2:
import Tkinter as tk
else:
import tkinter as tk
class Maze(tk.Tk, object):
UNIT = 40 # pixels
MAZE_H = 6 # grid height
MAZE_W = 6 # grid width
def __init__(self):
super(Maze, self).__init__()
self.action_space = ['U', 'D', 'L', 'R']
self.n_actions = len(self.action_space)
self.title('迷宫')
self.geometry('{0}x{1}'.format(self.MAZE_H * self.UNIT, self.MAZE_W * self.UNIT)) #窗口大小
self._build_maze()
#画矩形
#x y 格坐标
#color 颜色
def _draw_rect(self, x, y, color):
center = self.UNIT / 2
w = center - 5
x_ = self.UNIT * x + center
y_ = self.UNIT * y + center
return self.canvas.create_rectangle(x_-w, y_-w, x_+w, y_+w, fill = color)
#初始化迷宫
def _build_maze(self):
h = self.MAZE_H*self.UNIT
w = self.MAZE_W*self.UNIT
#初始化画布
self.canvas = tk.Canvas(self, bg='white', height=h, width=w)
#画线
for c in range(0, w, self.UNIT):
self.canvas.create_line(c, 0, c, h)
for r in range(0, h, self.UNIT):
self.canvas.create_line(0, r, w, r)
# 陷阱
self.hells = [self._draw_rect(3, 2, 'black'),
self._draw_rect(3, 3, 'black'),
self._draw_rect(3, 4, 'black'),
self._draw_rect(3, 5, 'black'),
self._draw_rect(4, 1, 'black'),
self._draw_rect(4, 5, 'black'),
self._draw_rect(1, 0, 'black'),
self._draw_rect(1, 1, 'black'),
self._draw_rect(1, 2, 'black'),
self._draw_rect(1, 3, 'black'),
self._draw_rect(1, 4, 'black')]
self.hell_coords = []
for hell in self.hells:
self.hell_coords.append(self.canvas.coords(hell))
# 奖励
self.oval = self._draw_rect(4, 5, 'yellow')
# 玩家对象
self.rect = self._draw_rect(0, 0, 'red')
self.canvas.pack() #执行画
#重新初始化
def reset(self):
self.update()
time.sleep(0.5)
self.canvas.delete(self.rect)
self.rect = self._draw_rect(0, 0, 'red')
self.old_s = None
#返回 玩家矩形的坐标 [5.0, 5.0, 35.0, 35.0]
return self.canvas.coords(self.rect)
#走下一步
def step(self, action):
s = self.canvas.coords(self.rect)
base_action = np.array([0, 0])
if action == 0: # up
if s[1] > self.UNIT:
base_action[1] -= self.UNIT
elif action == 1: # down
if s[1] < (self.MAZE_H - 1) * self.UNIT:
base_action[1] += self.UNIT
elif action == 2: # right
if s[0] < (self.MAZE_W - 1) * self.UNIT:
base_action[0] += self.UNIT
elif action == 3: # left
if s[0] > self.UNIT:
base_action[0] -= self.UNIT
#根据策略移动红块
self.canvas.move(self.rect, base_action[0], base_action[1])
s_ = self.canvas.coords(self.rect)
#判断是否得到奖励或惩罚
done = False
if s_ == self.canvas.coords(self.oval):
reward = 1
done = True
elif s_ in self.hell_coords:
reward = -1
done = True
#elif base_action.sum() == 0:
# reward = -1
else:
reward = 0
self.old_s = s
return s_, reward, done
def render(self):
time.sleep(0.01)
self.update()
2、Q-Learning
# coding: utf-8
import pandas as pd
import numpy as np
class q_learning_model_maze:
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.99):
self.actions = actions
self.learning_rate = learning_rate
self.reward_decay = reward_decay
self.e_greedy = e_greedy
self.q_table = pd.DataFrame(columns=actions,dtype=np.float32)
#检查状态是否存在
def check_state_exist(self, state):
if state not in self.q_table.index:
self.q_table = self.q_table.append(
pd.Series(
[0] * len(self.actions),
index=self.q_table.columns,
name=state,
)
)
#选择动作
def choose_action(self, s):
self.check_state_exist(s)
if np.random.uniform() < self.e_greedy:
state_action = self.q_table.ix[s, :]
state_action = state_action.reindex(
np.random.permutation(state_action.index)) #防止相同列值时取第一个列,所以打乱列的顺序
action = state_action.argmax()
else:
action = np.random.choice(self.actions)
return action
#更新q表
def rl(self, s, a, r, s_):
self.check_state_exist(s_)
q_predict = self.q_table.ix[s, a] #q估计
if s_ != 'terminal':
q_target = r + self.reward_decay * self.q_table.ix[s_, :].max() #q现实
else:
q_target = r
self.q_table.ix[s, a] += self.learning_rate * (q_target - q_predict)
3、训练实验
# coding: utf-8
from maze_env_1 import Maze
from q_learning_model_maze import q_learning_model_maze
def update():
for episode in range(100):
s = env.reset()
while True:
env.render()
#选择一个动作
action = RL.choose_action(str(s))
#执行这个动作得到反馈(下一个状态s 奖励r 是否结束done)
s_, r, done = env.step(action)
#更新状态表
RL.rl(str(s), action, r, str(s_))
s = s_
if done:
print(episode)
break
if __name__ == "__main__":
env = Maze()
RL = q_learning_model_maze(actions=list(range(env.n_actions)))
env.after(10, update) #延迟10毫秒执行update
env.mainloop()