强化学习——蛇棋游戏策略迭代实现
1"表格式"Agent
在之前的文章的基础之上,本文对搭建的蛇棋游戏采用策略迭代的方法实现。策略迭代时,环境的状态转移概率需要对Agent公开,这样Agent就能利用这些信息做出更好的决策。对于蛇棋来说,如果知道骰子的每一面朝上的概率是均匀的,以及棋盘上的每一个梯子都是可见的,就可以计算出状态转移概率。下面一段简单的代码,根据环境的信息生成问题中所有实体“表格式”数据结构,Agent代码的基本结构如下所示:
class TableAgent(object):
def __init__(self, env):
self.s_len = env.observation_space.n # |S|
self.a_len = env.action_space.n # |A|
self.r = [env.reward(s) for s in range(0, self.s_len)] # R
self.pi = np.array([0 for s in range(0, self.s_len)]) # π
self.p = env.p # P
self.value_pi = np.zeros((self.s_len)) # V
self.value_q = np.zeros((self.s_len, self.a_len)) # Q
self.gamma = 0.8 # γ
def play(self, state):
return self.pi[state]
2、对游戏的评估
代码中实现了三种策略:最优策略、全部执行第一个行动的策略和全部执行第二个行动的策略。为了简化代码,直接用一个数组表示Agent要执行的行动。完成了策略的构建,就可以评估策略了,这其实就是让Agent在真实的环境上进行交互,得到回报总和。代码如下:
def eval_game(env, policy):
state = env.reset()
return_val = 0
# 有两种play的方法,一种是用我们定义的智能体去玩,另一种是直接指定每个s的a。
while True:
if isinstance(policy, TableAgent) or isinstance(policy, ModelFreeAgent):
act = policy.play(state)
elif isinstance(policy, list):
act = policy[state]
else:
raise Error('Illegal policy')
state, reward, terminate, _ = env.step(act) # 不断游戏直至结束
return_val += reward
if terminate:
break
return return_val
def test_easy():
policy_opt = [1] * 97 + [0] * 3 # 最优策略
policy_0 = [0] * 100 # 全部都投掷第一个骰子(1~3)
policy_1 = [1] * 100 # 全部都投掷第二个骰子(1~6)
np.random.seed(0)
sum_opt = 0
sum_0 = 0
sum_1 = 0
env = SnakeEnv(0, [3, 6])
for i in range(10000):
sum_opt += eval_game(env, policy_opt)
sum_0 += eval_game(env, policy_0)
sum_1 += eval_game(env, policy_1)
print('opt avg={}'.format(sum_opt / 10000.0))
print('0 avg={}'.format(sum_0 / 10000.0))
print('1 avg={}'.format(sum_1 / 10000.0))
在程序中,我们使用每一种策略进行一万局游戏,并显示每一种策略的平均得分,由于棋盘上没有梯子,所以棋局的环境不用发生变化。游戏最终的平均得分如下:
3、策略迭代
3.1、策略评估
策略评估公式:
程序如下:
# 迭代计算V直至收敛
def policy_evaluation(self, agent, max_iter=-1):
iteration = 0
while True:
iteration += 1
new_value_pi = agent.value_pi.copy()
for i in range(1, agent.s_len):
value_sas = []
ac = agent.pi[i]
transition = agent.p[ac, i, :]
value_sa = np.dot(transition, agent.r + agent.gamma * agent.value_pi)
new_value_pi[i] = value_sa
diff = np.sqrt(np.sum(np.power(agent.value_pi - new_value_pi, 2)))
if diff < 1e-6:
break
else:
agent.value_pi = new_value_pi
if iteration == max_iter:
break
3.2、策略改善
完成上面的部分之后,根据之前的状态值函数计算状态行为值函数:
完成计算之后,根据同意状态下的行为值函数更新策略:
代码如下:
# 根据V更新π
def policy_improvement(self, agent):
new_policy = np.zeros_like(agent.pi)
for i in range(1, agent.s_len):
for j in range(0, agent.a_len):
agent.value_q[i, j] = np.dot(agent.p[j, i, :], agent.r + agent.gamma * agent.value_pi)
max_act = np.argmax(agent.value_q[i, :])
new_policy[i] = max_act
if np.all(np.equal(new_policy, agent.pi)):
return False
else:
agent.pi = new_policy
return True
# 大框架:进行一定次数的迭代,每次先策略评估,再策略改善
再将上面的程序联合起来,怎个算法的执行如下所示:
def policy_iteration(self, agent):
iteration = 0
while True:
iteration += 1
self.policy_evaluation(agent)
ret = self.policy_improvement(agent)
if not ret:
break
print('Iter {} rounds converge'.format(iteration))
运行程序:
env = SnakeEnv(0, [3,6])
agent = TableAgent(env)
pi_algo = PolicyIteration()
pi_algo.policy_iteration(agent)
print('return_pi={}'.format(eval_game(env, agent)))
print(agent.pi)
运行结果如下:
可以看出,每一轮迭代结束,策略都进行了一次更新,当策略没有更新时,迭代结束。