强化学习经典算法笔记——策略迭代算法
上一篇讲了价值迭代算法,这一篇介绍另一个动态规划算法——策略迭代算法(Policy Iteration)。
简单介绍
Value Iteration的思路是:先迭代找出一个最优的Value Function,然后再根据Value Function迭代出一个最优策略。
Policy Iteration的思路是反着的,首先给定一个初始化的策略函数,一般是随机策略。给予这个策略,可以得到每个状态下采取的动作,进而得到reward和下一状态,利用更新法则,更新value function。
然后,根据更新后的value function更新之前的随机策略。如此,就完成了value-policy的交替更新,直至收敛。
编程实现
还是以FrozenLake游戏为例,实现Policy Iteration算法。
import gym
import numpy as np
env = gym.make('FrozenLake-v0')
def compute_value_function(policy, gamma=1.0, threshold = 1e-20):
'''
计算value function
'''
# len=16
value_table = np.zeros(env.nS)
while True:
updated_value_table = np.copy(value_table)
for state in range(env.nS):
# 根据策略函数,给定一个状态,输出该状态下的应该采取的动作
action = policy[state]
# 采取动作后,遍历所有可能的转移状态,计算状态价值
# 和value iteration的区别是:
# value iteration的策略是从value function中根据greedy原则选出来的
# policy iteration的策略是事先给定的,value是根据policy得出的,这个action的价值代表了当前状态的价值?
value_table[state] = sum([ trans_prob * (reward_prob + gamma * updated_value_table[next_state])
for trans_prob, next_state, reward_prob,_ in env.P[state][action] ])
if (np.sum((np.fabs(updated_value_table - value_table))) <= threshold):
break
return value_table
def extract_policy(value_table, gamma = 1.0):
'''
'''
policy = np.zeros(env.observation_space.n)
for state in range(env.observation_space.n):
Q_table = np.zeros(env.action_space.n)
for action in range(env.action_space.n):
for next_sr in env.P[state][action]:
trans_prob, next_state, reward_prob,_= next_sr
Q_table[action] += (trans_prob * (reward_prob + gamma * value_table[next_state]))
policy[state] = np.argmax(Q_table)
return policy
def policy_iteration(env,gamma = 1.0, no_of_iterations = 200000):
'''
状态值估计和策略函数的优化是交替进行的,从随机策略出发,估计状态价值
再从收敛的状态值函数出发,优化之前的随机策略。由此往复,直至收敛
'''
gamma = 1.0
# 随机策略
random_policy = np.zeros(env.observation_space.n)
for i in range(no_of_iterations):
new_value_function = compute_value_function(random_policy, gamma)
new_policy = extract_policy(new_value_function, gamma)
if (np.all(random_policy == new_policy)):
print('Policy-Iteration converged at step %d.'%(i+1))
break
random_policy = new_policy
return new_policy
print (policy_iteration(env))
最后结果
Policy-Iteration converged at step 7.
[0. 3. 3. 3. 0. 0. 0. 0. 3. 1. 0. 0. 0. 2. 1. 0.]