Machine Learning: The SARSA Algorithm for Grid World Explained

Copyright notice: This is the blogger's original article and may not be reproduced without permission. https://blog.csdn.net/tomatomas/article/details/77278184

SARSA

SARSA (State-Action-Reward-State-Action) is an algorithm for learning a Markov decision process policy, commonly used in the reinforcement learning area of machine learning. It was introduced in a technical note, where the alternative name SARSA was only mentioned in a footnote.
The name State-Action-Reward-State-Action reflects exactly the five values that its update function depends on: the current state S1, the action A1 chosen in that state, the reward R obtained, the state S2 reached after executing A1 in S1, and the action A2 that will be executed in S2. Stringing together the first letters of these five values gives the word SARSA.

Here is the original Wikipedia passage; if my rendering above isn't great, please be gentle, and yes, you over there, put the brick down:

State-Action-Reward-State-Action (SARSA) is an algorithm for learning a Markov decision process policy, used in the reinforcement learning area of machine learning. It was introduced in a technical note[1] where the alternative name SARSA was only mentioned as a footnote.

This name simply reflects the fact that the main function for updating the Q-value depends on the current state of the agent “S1”, the action the agent chooses “A1”, the reward “R” the agent gets for choosing this action, the state “S2” that the agent will now be in after taking that action, and finally the next action “A2” the agent will choose in its new state. Taking every letter in the quintuple (st, at, rt, st+1, at+1) yields the word SARSA.[2]
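Spelled out, the standard SARSA update that ties these five values together (textbook form, added here for reference; \alpha is the learning rate and \gamma the discount factor, corresponding to learning_rate and discount_factor in the code below) is:

Q(S_1, A_1) \leftarrow Q(S_1, A_1) + \alpha \left[ R + \gamma \, Q(S_2, A_2) - Q(S_1, A_1) \right]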

Code Implementation

Still confused after all of that? Easy: talk is cheap, just show me the code!
As in the previous few articles, the code comes from the same open-source project on GitHub; let's walk through its SARSA implementation together. (Env and SARSAgent are defined elsewhere in that project; only the training loop and the key agent methods are excerpted here.)

if __name__ == "__main__":
    env = Env()
    agent = SARSAgent(actions=list(range(env.n_actions)))

    for episode in range(1000):
        # reset environment and initialize state
        state = env.reset()
        # get action of state from agent
        action = agent.get_action(str(state))

        while True:
            env.render()

            # take action and proceed one step in the environment
            next_state, reward, done = env.step(action)
            next_action = agent.get_action(str(next_state))

            # with sample <s,a,r,s',a'>, agent learns new q function
            agent.learn(str(state), action, reward, str(next_state), next_action)

            state = next_state
            action = next_action

            # print the q function of all states on screen
            env.print_value_all(agent.q_table)

            # if episode ends, then break
            if done:
                break

The entry code is almost the same as the entry code of the Monte Carlo algorithm, so I won't go over it again. Let's look directly at its get_action function:

    def get_action(self, state):
        if np.random.rand() < self.epsilon:
            # take random action
            action = np.random.choice(self.actions)
        else:
            # take action according to the q function table
            state_action = self.q_table[state]
            action = self.arg_max(state_action)
        return action

As before, with a certain probability it returns a random action; in the other case, however, it works a bit differently from the Monte Carlo algorithm. SARSA uses a q_table to look up the Q-values of the current state's actions, and then uses arg_max to randomly pick one of the actions expected to yield the maximum return. The arg_max function is implemented the same way as in the Monte Carlo algorithm.
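The post does not show arg_max itself, so here is a minimal sketch of what such a helper typically looks like, assuming it receives the list of Q-values for one state and breaks ties randomly; the project's actual implementation may differ in detail:

import random

def arg_max(state_action):
    # state_action: list of Q-values for one state, indexed by action
    # gather every action index that attains the maximum Q-value
    max_value = max(state_action)
    best_actions = [i for i, q in enumerate(state_action) if q == max_value]
    # break ties randomly so the agent does not always favor the first action
    return random.choice(best_actions)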
Next, the agent interacts with the environment once and gets three values back: next_state, reward, and done. It then calls get_action again on next_state to obtain the next action. At this point, have you noticed that we now hold all five values in SARSA's name? What follows is the learning step of the algorithm, which the code performs by calling the learn function:

    # with sample <s, a, r, s', a'>, learns new q function
    def learn(self, state, action, reward, next_state, next_action):
        current_q = self.q_table[state][action]
        next_state_q = self.q_table[next_state][next_action]
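        # TD update: Q(s,a) <- Q(s,a) + learning_rate * (reward + discount_factor * Q(s',a') - Q(s,a))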
        new_q = (current_q + self.learning_rate *
                (reward + self.discount_factor * next_state_q - current_q))
        self.q_table[state][action] = new_q

The computation in the learning step is fairly simple and clear, so I won't break it down further. What you can see is that the single most central piece is q_table: the values stored in q_table directly determine the agent's actions.
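The agent's __init__ is not shown in the post. As a rough sketch (the hyperparameter values below are placeholders of my own, not necessarily the project's settings), q_table is typically a defaultdict, so that self.q_table[state][action] works even for states the agent has never visited:

from collections import defaultdict

class SARSAgent:
    def __init__(self, actions):
        self.actions = actions
        # illustrative values only; the real project may use different settings
        self.learning_rate = 0.01
        self.discount_factor = 0.9
        self.epsilon = 0.1
        # unseen states map to a list of zeros, one Q-value per action,
        # so q_table[state][action] never raises a KeyError
        self.q_table = defaultdict(lambda: [0.0] * len(actions))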

Similarities and Differences with the Monte Carlo Algorithm

Now let's compare SARSA with the Monte Carlo algorithm; can you feel how they differ? Their biggest difference is that Monte Carlo updates its value_table only after finishing an entire episode, i.e. after reaching the blue circle, whereas SARSA learns and updates q_table after every single step. The advantages of doing it this way are:

1. It can update at every step, which clearly amounts to online learning, so it learns faster.
2. It can handle tasks that never reach a terminal result, so it applies to a wider range of problems.

But this approach also has a weakness: because the TD target is an estimate, and estimates carry error, the updated values end up biased.
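To make that concrete, here are the two update targets side by side (standard textbook forms, added for clarity):

Monte Carlo target: G_t = r_{t+1} + \gamma r_{t+2} + \gamma^2 r_{t+3} + \dots, the actual discounted return, known only after the episode ends but an unbiased sample.
SARSA (TD) target: r_{t+1} + \gamma \, Q(s_{t+1}, a_{t+1}), available after a single step, but it leans on the current estimate Q(s_{t+1}, a_{t+1}), so any error in that estimate biases the update.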

Reference: http://blog.csdn.net/songrotek/article/details/51382759
