Q-Learning demo

Q-Learning

A hands-on walkthrough in plain Python

Pseudocode

1. Initialize Q(s, a) arbitrarily (e.g., as an all-zero table)
2. repeat (for each episode):
3.     Initialize s (randomly)
4.     repeat (for each step of the episode):
5.         Choose an action a (following some policy, e.g., epsilon-greedy)
6.         Take a in state s and observe the next state s'
7.         Q(s, a) = (1 - alpha) * Q(s, a) + alpha * (R(s, a) + gamma * max[Q(s', a')])   (see the worked example after this list)
8.         s = s'
9.     until s reaches the terminal (goal) state
10. Pick a stopping condition: end training when Q has converged or after a fixed number of episodes
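
To make the update in step 7 concrete, here is one update with made-up numbers (alpha = 0.1, gamma = 0.9; the values are chosen only for illustration). Suppose Q(s, a) = 0.2, the reward R(s, a) = 0, and the best next-state value max[Q(s', a')] = 0.5. Then

    Q(s, a) <- (1 - 0.1) * 0.2 + 0.1 * (0 + 0.9 * 0.5) = 0.18 + 0.045 = 0.225

This is algebraically the same as the incremental form used in the code further down, Q(s, a) <- Q(s, a) + alpha * (target - Q(s, a)), because (1 - alpha) * Q + alpha * target = Q + alpha * (target - Q).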

When the agent is in a given state, it has to choose an action a in order to reach the next state. How should it choose?
If every choice is made purely from experience (always greedy), the agent may get stuck in a local optimum; if every choice is random, convergence is far too slow.
The usual compromise is the epsilon-greedy strategy:
with probability 1 - epsilon the action is chosen at random (exploration), and with probability epsilon the action with the highest Q-value so far is chosen (exploitation). In the code below, EPSILON = 0.9 is the probability of acting greedily, so roughly 10% of the steps explore.
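
As a quick illustration of the selection rule (not part of the original program; the names q_values and greedy_prob are only for this sketch), a minimal NumPy version might look like this:

import numpy as np

def epsilon_greedy(q_values, greedy_prob=0.9):
    # With probability 1 - greedy_prob, or when nothing has been learned yet,
    # pick a random action index; otherwise pick the currently best action.
    if np.random.uniform() > greedy_prob or np.all(q_values == 0):
        return np.random.randint(len(q_values))
    return int(np.argmax(q_values))

# Example: action 1 is usually chosen, but action 0 is still explored occasionally.
print(epsilon_greedy(np.array([0.1, 0.5])))

The full implementation below does the same thing on a row of a pandas DataFrame.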

Code (the environment is a one-dimensional corridor rendered as '-----T': the agent 'o' starts at the left end, and reaching the treasure 'T' at the right end yields a reward of 1):

# coding:utf-8
from __future__ import print_function  # must come at the very top; enables Python 3 style print under Python 2
import numpy as np
import pandas as pd
import time

# Preset values
N_STATES = 6  # number of states
ACTIONS = ['left', 'right']  # the two available actions
EPSILON = 0.9  # epsilon-greedy: probability of choosing the greedy action
ALPHA = 0.1  # learning rate
GAMMA = 0.9  # discount factor: how strongly future rewards are discounted
MAX_EPISODES = 20  # maximum number of episodes
FRESH_TIME = 0.1  # how long each rendered frame stays on screen


# Build and initialize the q_table (all zeros, one row per state, one column per action)
def build_q_table():
    q_table = pd.DataFrame(np.zeros((N_STATES, len(ACTIONS))), columns=ACTIONS)
    return q_table


# Choose an action with the epsilon-greedy strategy:
# act randomly with probability 1 - EPSILON, or when this state's row is still all zeros.
def choose_action(state, q_table):
    state_actions = q_table.iloc[state, :]
    if np.random.uniform() > EPSILON or is_all_zero(state_actions):
        action_name = np.random.choice(ACTIONS)
    else:
        action_name = state_actions.idxmax()
    return action_name


# Check whether a row (pandas Series) is all zeros, i.e., the state has not been explored yet
def is_all_zero(series):
    return (series == 0).all()


# Environment feedback:
# take one step from (S, A), arrive at the next state S_,
# and receive the reward R.
def get_env_feedback(state, action):
    if action == 'right':
        if state == N_STATES - 2:
            next_state = 'terminal'
            R = 1
        else:
            next_state = state + 1
            R = 0
    else:
        R = 0
        if state == 0:
            # Note: a plain `if state:` is False when state == 0, so the else branch ran
            # and next_state became -1; comparing against 0 keeps the agent at the left wall.
            next_state = state
        else:
            next_state = state - 1
    # print('\nS: {}-{}-S\': {}'.format(state, action, next_state))
    return next_state, R


# Environment rendering:
# redraw one frame of the one-dimensional world after each step.
def update_env(state, episode, step_counter):
    env_list = ['-'] * (N_STATES - 1) + ['T']
    if state == 'terminal':
        # episode is 0-based in the training loop, so +1 gives a human-readable count
        interaction = 'episode: %s; total_steps = %s' % (episode + 1, step_counter)
        print('\r{}'.format(interaction), end='')
        time.sleep(2)
        print('\r                           ', end='')  # clear the line
    else:
        env_list[state] = 'o'
        interaction = ''.join(env_list)
        # end='' is Python 3 print syntax; under Python 2 it only works because of the
        # `from __future__ import print_function` line at the top of the file.
        # '\r' is a carriage return: it moves the cursor back to the start of the line.
        print('\r{}'.format(interaction), end='')
        time.sleep(FRESH_TIME)


# Q-learning main loop
def q_learning():
    q_table = build_q_table()
    for episode in range(MAX_EPISODES):  # TODO: if we stopped on Q-table convergence instead, how would convergence be checked?
        step_counter = 0  # number of steps taken in this episode
        state = 0
        is_terminated = False  # flag marking the end of an episode
        update_env(state, episode, step_counter)
        while not is_terminated:
            action = choose_action(state, q_table)
            # print("\nS value: {}\n".format(state))
            # print(action)
            next_state, R = get_env_feedback(state, action)
            q_predict = q_table.loc[state, action]
            if next_state != 'terminal':
                q_target = R + GAMMA * q_table.iloc[next_state, :].max()
            else:
                q_target = R  # the terminal state has no future value; just take the reward
                is_terminated = True
            q_table.loc[state, action] = q_predict + ALPHA * (q_target - q_predict)
            # Could the two cases be merged into the single commented-out line below?
            # No: 'terminal' is not a row label of q_table, so q_table.loc[next_state, :]
            # would raise a KeyError at the end of each episode.
            # q_table.loc[state, action] = (1 - ALPHA) * q_table.loc[state, action] + ALPHA * (
            #         R + GAMMA * q_table.loc[next_state, :].max())
            # print("\nS' value: {}\n".format(state))
            state = next_state
            step_counter += 1
            update_env(state, episode, step_counter)
    return q_table


if __name__ == '__main__':
    q_table = q_learning()
    print('\nQ_Table:\n{}'.format(q_table))
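
After training, the rows for the states the agent actually visits should lean toward 'right', since the only reward sits at the right end of the corridor and GAMMA propagates it backwards. As a small follow-up sketch (not part of the original script), the greedy policy can be read off the learned table with pandas' idxmax, e.g., appended at the end of the __main__ block:

    # For each state, take the action (column label) with the highest learned Q-value.
    greedy_policy = q_table.idxmax(axis=1)
    print('\nGreedy policy:\n{}'.format(greedy_policy))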
