强化学习——蛇棋游戏策略迭代实现

1"表格式"Agent

  在之前的文章的基础之上,本文对搭建的蛇棋游戏采用策略迭代的方法实现。策略迭代时,环境的状态转移概率需要对Agent公开,这样Agent就能利用这些信息做出更好的决策。对于蛇棋来说,如果知道骰子的每一面朝上的概率是均匀的,以及棋盘上的每一个梯子都是可见的,就可以计算出状态转移概率。下面一段简单的代码,根据环境的信息生成问题中所有实体“表格式”数据结构,Agent代码的基本结构如下所示:

class TableAgent(object):
    def __init__(self, env):
        self.s_len = env.observation_space.n # |S|
        self.a_len = env.action_space.n # |A|
        self.r = [env.reward(s) for s in range(0, self.s_len)] # R
        self.pi = np.array([0 for s in range(0, self.s_len)]) # π
        self.p = env.p # P
        self.value_pi = np.zeros((self.s_len)) # V
        self.value_q = np.zeros((self.s_len, self.a_len)) # Q
        self.gamma = 0.8 # γ

    def play(self, state):
        return self.pi[state]

2、对游戏的评估

  代码中实现了三种策略:最优策略、全部执行第一个行动的策略和全部执行第二个行动的策略。为了简化代码,直接用一个数组表示Agent要执行的行动。完成了策略的构建,就可以评估策略了,这其实就是让Agent在真实的环境上进行交互,得到回报总和。代码如下:

def eval_game(env, policy):
    state = env.reset()
    return_val = 0
    # 有两种play的方法,一种是用我们定义的智能体去玩,另一种是直接指定每个s的a。
    while True:
        if isinstance(policy, TableAgent) or isinstance(policy, ModelFreeAgent):
            act = policy.play(state)
        elif isinstance(policy, list):
            act = policy[state]
        else:
            raise Error('Illegal policy')
        state, reward, terminate, _ = env.step(act) # 不断游戏直至结束
        return_val += reward
        if terminate:
          break
    return return_val

def test_easy():
    policy_opt = [1] * 97 + [0] * 3 # 最优策略
    policy_0 = [0] * 100 # 全部都投掷第一个骰子(1~3)
    policy_1 = [1] * 100 # 全部都投掷第二个骰子(1~6)
    np.random.seed(0)
    sum_opt = 0
    sum_0 = 0
    sum_1 = 0
    env = SnakeEnv(0, [3, 6])
    for i in range(10000):
        sum_opt += eval_game(env, policy_opt)
        sum_0 += eval_game(env, policy_0)
        sum_1 += eval_game(env, policy_1)
    print('opt avg={}'.format(sum_opt / 10000.0))
    print('0 avg={}'.format(sum_0 / 10000.0))
    print('1 avg={}'.format(sum_1 / 10000.0))

  在程序中,我们使用每一种策略进行一万局游戏,并显示每一种策略的平均得分,由于棋盘上没有梯子,所以棋局的环境不用发生变化。游戏最终的平均得分如下:
在这里插入图片描述

3、策略迭代

3.1、策略评估

策略评估公式: v T ( s t ) = a π ( a s t ) s t + 1 p ( s t + 1 s t , a ) [ r t + 1 + γ v T 1 ( s t + 1 ) ] {v^T}({s_t}) = \sum\limits_a {\pi (a|{s_t})} \sum\limits_{{s_{t + 1}}} {p({s_{t + 1}}|{s_t},a)[{r_{t + 1}} + \gamma {v^{T - 1}}({s_{t + 1}})]}
程序如下:

 # 迭代计算V直至收敛
    def policy_evaluation(self, agent, max_iter=-1):
        iteration = 0
        while True:
            iteration += 1
            new_value_pi = agent.value_pi.copy()
            for i in range(1, agent.s_len):
                value_sas = []
                ac = agent.pi[i]
                transition = agent.p[ac, i, :]
                value_sa = np.dot(transition, agent.r + agent.gamma * agent.value_pi)
                new_value_pi[i] = value_sa
            diff = np.sqrt(np.sum(np.power(agent.value_pi - new_value_pi, 2)))
            if diff < 1e-6:
                break
            else:
                agent.value_pi = new_value_pi
            if iteration == max_iter:
                break

3.2、策略改善

完成上面的部分之后,根据之前的状态值函数计算状态行为值函数: q ( s t , a t ) = s t + 1 p ( s t + 1 s t , a t ) [ r t + γ v ( s t + 1 ) ] q({s_t},{a_t}) = \sum\limits_{{s_{t + 1}}} {p({s_{t + 1}}|{s_t},{a_t})[{r_t} + \gamma v({s_{t + 1}})]} 完成计算之后,根据同意状态下的行为值函数更新策略: π ( s ) = arg max a q ( s , a ) \pi (s) = \mathop {\arg \max }\limits_a q(s,a)
代码如下:

 # 根据V更新π
    def policy_improvement(self, agent):
        new_policy = np.zeros_like(agent.pi)
        for i in range(1, agent.s_len):
            for j in range(0, agent.a_len):
                agent.value_q[i, j] = np.dot(agent.p[j, i, :], agent.r + agent.gamma * agent.value_pi)
            max_act = np.argmax(agent.value_q[i, :])
            new_policy[i] = max_act
        if np.all(np.equal(new_policy, agent.pi)):
            return False
        else:
            agent.pi = new_policy
            return True

        # 大框架:进行一定次数的迭代,每次先策略评估,再策略改善

再将上面的程序联合起来,怎个算法的执行如下所示:

 def policy_iteration(self, agent):
        iteration = 0
        while True:
            iteration += 1
            self.policy_evaluation(agent)
            ret = self.policy_improvement(agent)
            if not ret:
                break
        print('Iter {} rounds converge'.format(iteration))

运行程序:

env = SnakeEnv(0, [3,6])
agent = TableAgent(env)
pi_algo = PolicyIteration()
pi_algo.policy_iteration(agent)
print('return_pi={}'.format(eval_game(env, agent)))
print(agent.pi)

运行结果如下:
在这里插入图片描述
  可以看出,每一轮迭代结束,策略都进行了一次更新,当策略没有更新时,迭代结束。

发布了19 篇原创文章 · 获赞 25 · 访问量 2415

猜你喜欢

转载自blog.csdn.net/fly975247003/article/details/102153016
今日推荐