Model-Based Dynamic Programming: the Policy Iteration Method (maze code implementation)

Straight to the code.
policy_iteration_method.py

import random

class PIM:
    def __init__(self):
        self.pi = dict()  # policy: encoded state -> action
        self.v = dict()   # state-value function: encoded state -> value


    def create(self, mdp):
        for state in mdp.states:
            self.v[self.encode_state(state)] = 0
            # start from an arbitrary (random) policy
            action = random.choice(mdp.state_action(state))
            self.pi[self.encode_state(state)] = action
        
        

    def policy_iteration(self, mdp, iterations):
        for _ in range(iterations):
            self.policy_evaluate(mdp)
            self.policy_improve(mdp)

    # policy evaluation
    def policy_evaluate(self, mdp):
        for _ in range(1000):
            delta = 0.0
            # sweep over all states
            for state in mdp.states:
                # skip terminal states
                if mdp.is_done(state):
                    continue
                # action chosen by the current policy for this state
                action = self.pi[self.encode_state(state)]

                # transition under that action
                _, s, r = mdp.transform(state, action)
                # v = r + gamma * v'
                new_v = r + mdp.gamma * self.v[self.encode_state(s)]
                # accumulate the absolute change |v - new_v|
                delta += abs(self.v[self.encode_state(state)] - new_v)
                # update v
                self.v[self.encode_state(state)] = new_v
            # treat the values as converged once delta < 1e-6
            if delta < 1e-6:
                break

    # policy improvement
    def policy_improve(self, mdp):
        # sweep over all states
        for state in mdp.states:
            # skip terminal states
            if mdp.is_done(state):
                continue
            actions = mdp.state_action(state)
            # start from the first available action
            a1 = actions[0]

            # transition under that action
            t, s, r = mdp.transform(state, a1)
            # v = r + gamma * v'
            v1 = r + mdp.gamma * self.v[self.encode_state(s)]

            # try the remaining actions
            for action in actions:
                # transition under the candidate action
                t, s, r = mdp.transform(state, action)
                # keep the action with the higher backed-up value
                if v1 < r + mdp.gamma * self.v[self.encode_state(s)]:
                    a1 = action
                    v1 = r + mdp.gamma * self.v[self.encode_state(s)]
            # store the greedy action for the current state
            self.pi[self.encode_state(state)] = a1

    def action(self, state):
        return self.pi[self.encode_state(state)]

    # encode a state (x, y) as a dictionary key
    def encode_state(self, state):
        return "%d_%d" % (state[0], state[1])

env_maze.py

import numpy as np

class Maze():
    def __init__(self):
        # maze layout: 0 = free cell, 1 = wall, 2 = exit
        # (the transpose below makes block[x][y] index column x, row y of this layout)
        self.block = np.asarray(
        # 0 1 2 3 4 5 6
        [[1,1,1,1,1,1,1]
        ,[1,0,0,0,0,0,1]
        ,[1,0,1,0,0,0,1]
        ,[1,0,1,0,1,1,1]
        ,[1,0,1,0,1,2,1]
        ,[1,0,1,0,0,0,1]
        ,[1,1,1,1,1,1,1]]).T

        # current position (x, y)
        self.state = np.zeros(2, dtype=int)

        # discount factor gamma
        self.gamma = 0.8

        # state space
        self.states = None
        self.create_states()
    
    # build the state space (all non-wall cells)
    def create_states(self):
        self.states = []
        for i in range(0, 6):
            for j in range(0, 6):
                if self.block[i][j] != 1:
                    self.states.append(np.asarray([i, j]))
        
    # legal actions from a given state (moves into walls are excluded)
    def state_action(self, state):
        actions = []
        if self.block[state[0]-1][state[1]] != 1:
            actions.append("left")
        if self.block[state[0]+1][state[1]] != 1:
            actions.append("right")
        if self.block[state[0]][state[1]-1] != 1:
            actions.append("up")
        if self.block[state[0]][state[1]+1] != 1:
            actions.append("down")
        return actions
    
    # state transition: returns (legal actions, next state, reward)
    def transform(self, state, action):
        new_state = state.copy()
        actions = self.state_action(state)
        if action not in actions:
            # illegal action: stay in place (return the same 3-tuple shape as below)
            print(state)
            print(action)
            print(actions)
            return actions, state, self.r(state)
        if action == "left":
            new_state[0] -= 1
        elif action == "right":
            new_state[0] += 1
        elif action == "up":
            new_state[1] -= 1
        elif action == "down":
            new_state[1] += 1
        return actions, new_state, self.r(new_state)
    
    # reset to a random free cell
    def reset(self):
        # sample x, y in 1..5
        x = np.random.randint(1, 6)
        y = np.random.randint(1, 6)
        if self.block[x][y] == 1:
            # landed on a wall: resample (the recursive result must be returned)
            return self.reset()
        self.state[0] = x
        self.state[1] = y
        return self.state

    def step(self, action):
        _, state, r = self.transform(self.state, action)
        self.state = state
        return self.state, r, self.is_done(self.state)
    
    def is_done(self, state):
        if self.block[state[0]][state[1]] == 2:
            return True
        return False
    
    def r(self, state):
        if self.block[state[0]][state[1]] == 2:
            return 1
        return 0

    def render(self):
        print(self.state)
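
As a quick sanity check of the environment (a hypothetical snippet, not part of the original scripts), you can confirm the size of the state space and the legal actions from a given cell:

from env_maze import Maze

maze = Maze()
print(len(maze.states))            # 18 non-wall cells, matching the dictionaries printed below
print(maze.state_action([1, 1]))   # ['right', 'down'] from the corner cell (1, 1)
print(maze.is_done([5, 4]))        # True: block[5][4] == 2 is the exit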

run_PIM.py

from policy_iteration_method import PIM
from env_maze import Maze

if __name__ == "__main__":

    # set up the environment
    env = Maze()

    state = env.reset()

    worker = PIM()
    worker.create(env)

    worker.policy_iteration(env,100)
    
    print(worker.v)
    print(worker.pi)

Results after 100 iterations:
{'1_1': 0.1677721600000001, '1_2': 0.13421772800000006, '1_3': 0.10737418240000006, '1_4': 0.08589934592000005, '1_5': 0.06871947673600004, '2_1': 0.2097152000000001, '3_1': 0.2621440000000001, '3_2': 0.32768000000000014, '3_3': 0.40960000000000013, '3_4': 0.5120000000000001, '3_5': 0.6400000000000001, '4_1': 0.2097152000000001, '4_2': 0.2621440000000001, '4_5': 0.8, '5_1': 0.1677721600000001, '5_2': 0.2097152000000001, '5_4': 0, '5_5': 1.0}

{'1_1': 'right', '1_2': 'up', '1_3': 'up', '1_4': 'up', '1_5': 'up', '2_1': 'right', '3_1': 'down', '3_2': 'down', '3_3': 'down', '3_4': 'down', '3_5': 'right', '4_1': 'left', '4_2': 'left', '4_5': 'right', '5_1': 'left', '5_2': 'left', '5_4': 'none', '5_5': 'up'}

The policy does find the exit, and it converges much faster than DQN; the drawback is that the method needs the entire state space up front.
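
To watch the learned policy actually walk out of the maze, a short rollout sketch (hypothetical, appended at the end of the __main__ block in run_PIM.py) could look like this:

    # roll out the learned policy from a random start until the exit is reached
    state = env.reset()
    for step in range(50):  # safety cap on the episode length
        action = worker.action(state)
        state, reward, done = env.step(action)
        print(step, action, state, reward)
        if done:
            print("reached the exit in %d steps" % (step + 1))
            break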

Reprinted from blog.csdn.net/qq_27389705/article/details/89048768