强化学习之Sarsa

在强化学习中,Sarsa和Q-Learning很类似,本次内容将会基于之前所讲的Q-Learning的内容。

目录

  • 算法简介
  • 更新准则
  • 探险者上天堂实战

算法简介


Sarsa决策部分和Q-Learning一抹一样,都是采用Q表的方式进行决策,所以我们会在Q表中挑选values比较大的动作实施在环境中来换取奖赏。但是Sarsa的更新是不一样的

更新准则


和上次一样用小学生写作业为例子,我们会经历写作业的状态s1,然后再挑选一个带来最大潜在奖励的动作a2,这样我们就到达了继续写作业的状态s2,而在这一步没如果你用的是Q-Learning,你会观察一下在s2上选取哪一个动作会带来最大的奖赏reward来更新,但是在真正要做决定的时候却不一定会选取到那个带来最大reward的动作,Q-Learning这一步只是估计了接下来的value。而Sarsa在s2这一步估计的动作就是他接下来要做的动作。所以Q(s1,a2)现实的计算值我们也会改动,去掉了maxQ,取而代之的是在S2上我们实实在在选取的a2的Q值。最后像Q-Learning一样,求出现实和估计的差距并更新Q表里的Q(s1,a2)

上图就是Sarsa更新的公式。我们可以看到和Q-Learning的不同之处:

  • 他在当前的state中已经想好了state对应的action,而且想好了下一个state_和下一个action_(Q-learning还没有想好下一个action_
  • 更新Q(s,a)的时候基于的是下一个Q(s_,a_)(Q-learning基于的是maxQ(s_)

这种不同之处使得Sarsa相对于Q-learning显得比较的”胆小“。原因在于

  • Q-learning在更新的时候始终都是选择maxQ最大化,因为这个maxQ变得贪婪,不考虑其他非maxQ的结果。我们可以理解成Q-learning是一种贪婪,大胆,勇敢的算法,对于错误,死亡并不在乎。而Sarsa是一种保守的算法,他在乎每一步的决策,对于错误和死亡比较敏感,这可以在可视化部分看出他们的不同。两种算法都有他们的好处,比如在实际中,如果你比较在乎机器的损害,那么用一种保守的算法,在训练中就可以有效地减少损坏的次数。
  • 从另一个角度想,Q-learning更新使用maxQ,而Sarsa却要看a_的值,而a_的值需要看greedy的脸色,如果greedy=1那么a_就是maxQ,与Q—Learning在greedy=1无差别。greedy值越小,Sarsa越不坚决(选择Q表中大的那个),而是会根据np.random.choice随机选择一个方向,同时也正是因Sarsa多了一项探索的概率,所以才是的Sarsa容易偏离终点,从视觉上看Sarsa有时显得很纠结。正因如此,Sarsa其实在某些程度上显得他很勇敢,因为Sarsa比Q-Learning更有探索精神,也正是这份精神使得Sarsa对终点的渴望不那么果决,饥渴成都要看greedy的脸色,更具多面性。

探险者上天堂实战

背景

黄色是天堂(reward=1),黑色是地狱(reward=-1)。我们的目标就是让探险者经过自己的多次入“地狱”,最终学会入“天堂”

主模块

首先我们先import两个模块,maze_env是我们游戏虚拟环境模块,是用python自带的GUI模块tkinter来编写,具体细节不多赘述,完整代码会放在最后。RL_brain这个模块是RL的大脑部分,稍后会提及。

from maze_env import Maze
from RL_brain import SarsaTable

下面就是我们的更新部分代码

def update():
    for episode in range(100):
        # 初始化环境
        observation = env.reset()

        # Sarsa根据state观测选择行为
        action = RL.choose_action(str(observation))

        while True:
            # 刷新环境
            env.render()

            # 在环境中采取行为,获得下一个state_(observation_),reward,和终止信号
            observation_, reward, done = env.step(action)

            # 根据下一个state(observation_)选取下一个action_
            action_ = RL.choose_action(str(observation_))

            #从(s, a, r, s, a)中学习,更新Q_table的参数
            RL.learn(str(observation), action, reward, str(observation_), action_)

            # 将下一个的observation_和action_当成对应下一步的参数
            observation = observation_
            action = action_

            if done:
                break

    # end of game
    print('game over')
    env.destroy()

if __name__ == "__main__":
    #定义环境enc和RL方式
    env = Maze()
    RL = SarsaTable(actions=list(range(env.n_actions)))
    env.after(100, update)
    env.mainloop()

RL_brain模块

我们定义一个父类classRL,然后SarsaTable作为父类的衍生。

import numpy as np
import pandas as pd


class RL:
    #初始化参数
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = actions  # 行为列表
        self.lr = learning_rate #学习率
        self.gamma = reward_decay  #奖励衰减度
        self.epsilon = e_greedy #贪婪度
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64) #初始化q_table

    #选择行为
    def choose_action(self, observation):
        self.check_state_exist(observation) #检验state是否在q_table中出现
        # 贪婪模式
        if np.random.uniform() < self.epsilon:
            state_action = self.q_table.loc[observation, :]
            # 同一个state,可能会有多个相同的Q action value,所以我们乱序一下
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # 非贪婪模式随机选择action
            action = np.random.choice(self.actions)
        return action

    #学习更新参数
    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)#同样先检验一下q_table中是否存在S_
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            #下个状态不是终止
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        else:
            q_target = r
        #更新参数
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)

    #检验state是否存在
    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # 如果不存在就插入一组全0数据,当做state的所有action的初始values
            self.q_table = self.q_table.append(
                pd.Series(
                    [0]*len(self.actions),
                    index=self.q_table.columns,
                    name=state,
                )
            )

然后我们编写SarsaTablelearn也就是更新功能就完成了。

class SarsaTable(RL):

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # q_target基于选好的a_而不是Q(s_)的最大值
        else:
            q_target = r
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # 更新q_table

最后探险者就可以很轻松的上天堂了!

参考:
https://github.com/MorvanZhou

猜你喜欢

转载自blog.csdn.net/cristiano20/article/details/96454568