OpenAI Gym|cart-pole-v1任务的环境源码

 本文代码来源于Gym官方文档

https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.pyhttps://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py

cart-pole-v1任务的实现见pytorch实现CartPole-v1任务的DQN代码_bujbujbiu的博客-CSDN博客

描述

一根杆子由一个非驱动的接头连接到一辆小车上,小车沿着无摩擦的轨道移动。杆子被垂直放置在手推车上,目标是通过在手推车上施加左右方向的力来平衡杆子。

Action Space

动作是shape为(1, )的ndarray数组,可以取值{0,1},表示小车被施加力的方向

 施加力所减少或增加的速度不是固定的,它取决于杆子指向的角度。杆子重心改变了移动下面的手推车所需的能量

Observation Space

状态是shape为(4, )的ndarray数组,包括小车位置,小车速度,杆子角度,杆子角速度

 上述定义的范围只是状态空间中各要素的可能取值,但是不是episode运行允许的范围,终止条件如下:

(1)小车x轴的位置(index 0)可以取值(-4.8,4.8),但是如果小车离开(-2.4,2.4)的范围,episode终止

(2)杆子角度可以在(-0.418, 0.418) radians (or **±24°**)间,但是如果杆子超过(-0.2095, 0.2095) (or **±12°**)范围,episode终止

Rewards

训楼目标是尽可能久的保持杆子不倒,因此每步都能获得+1的奖励,包括终止步,奖励阈值475

初始状态

所有观察值都被赋于(-0.05,0.05)中的一个均匀随机值

Episode终止

有下列情形之一的,episode终止:

(1)杆子角度大于±12°

(2)小车位置大于±2.4(小车中心到达显示屏边缘)

(3)episode长度大于500 (v0为200)

参数

gym.make('CartPole-v1')

完整代码

import math
from typing import Optional, Union

import numpy as np
import pygame
from pygame import gfxdraw

import gym
from gym import spaces, logger
from gym.utils import seeding

class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50}

    def __init__(self):
        # 以下参数用于执行动作函数中,计算施加一定力对小车和杆子的影响
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5  # 杆子长度一半
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # 状态更新时间间隔(秒)
        self.kinematics_integrator = "euler"

        # 杆子角度阈值=12度,小车位置阈值=2.4
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        # 定义Observation Space的四个要素
        high = np.array(
            [
                self.x_threshold * 2, # 小车位置4.8
                np.finfo(np.float32).max, # 返回float32类型数据最大值
                self.theta_threshold_radians * 2, # 杆子角度24度
                np.finfo(np.float32).max,
            ],
            dtype=np.float32,
        )
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)
        # 定义action space,Discrete(2)={0,1}
        self.action_space = spaces.Discrete(2)

        self.screen = None
        self.clock = None
        self.isopen = True
        self.state = None

        self.steps_beyond_done = None

    def step(self, action):
        # assert相当于if else语句,满足前面条件则正常运行,否则报错或中断
        err_msg = f"{action!r} ({type(action)}) invalid"
        assert self.action_space.contains(action), err_msg
        assert self.state is not None, "Call reset before using step method."
        x, x_dot, theta, theta_dot = self.state
        # 力向右为正,像左为负
        force = self.force_mag if action == 1 else -self.force_mag
        costheta = math.cos(theta)
        sintheta = math.sin(theta)

        # 施加力对杆子和小车影响的数学公式https://coneural.org/florian/papers/05_cart_pole.pdf
        temp = (
            force + self.polemass_length * theta_dot ** 2 * sintheta
        ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        # 更新状态值
        if self.kinematics_integrator == "euler":
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
        else:
            x_dot = x_dot + self.tau * xacc
            x = x + self.tau * x_dot
            theta_dot = theta_dot + self.tau * thetaacc
            theta = theta + self.tau * theta_dot

        self.state = (x, x_dot, theta, theta_dot)

        # 判断是否出现终止条件
        done = bool(
            x < -self.x_threshold
            or x > self.x_threshold
            or theta < -self.theta_threshold_radians
            or theta > self.theta_threshold_radians
        )

        # 根据执行动作后的状态计算奖励函数
        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell!
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            if self.steps_beyond_done == 0:
                logger.warn(
                    "You are calling 'step()' even though this "
                    "environment has already returned done = True. You "
                    "should always call 'reset()' once you receive 'done = "
                    "True' -- any further steps are undefined behavior."
                )
            self.steps_beyond_done += 1
            reward = 0.0
        # 返回执行一个动作后的新状态,奖励,是否终止
        return np.array(self.state, dtype=np.float32), reward, done, {}
    # 重置环境
    def reset(
        self,
        *,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        super().reset(seed=seed)
        self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
        self.steps_beyond_done = None
        if not return_info:
            return np.array(self.state, dtype=np.float32)
        else:
            return np.array(self.state, dtype=np.float32), {}
    # 图像引擎,用于展示训练过程中物体的变化(可不要)
    def render(self, mode="human"):
        screen_width = 600
        screen_height = 400

        world_width = self.x_threshold * 2
        scale = screen_width / world_width
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        if self.state is None:
            return None

        x = self.state

        if self.screen is None:
            pygame.init()
            pygame.display.init()
            self.screen = pygame.display.set_mode((screen_width, screen_height))
        if self.clock is None:
            self.clock = pygame.time.Clock()

        self.surf = pygame.Surface((screen_width, screen_height))
        self.surf.fill((255, 255, 255))

        l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
        axleoffset = cartheight / 4.0
        cartx = x[0] * scale + screen_width / 2.0  # MIDDLE OF CART
        carty = 100  # TOP OF CART
        cart_coords = [(l, b), (l, t), (r, t), (r, b)]
        cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
        gfxdraw.aapolygon(self.surf, cart_coords, (0, 0, 0))
        gfxdraw.filled_polygon(self.surf, cart_coords, (0, 0, 0))

        l, r, t, b = (
            -polewidth / 2,
            polewidth / 2,
            polelen - polewidth / 2,
            -polewidth / 2,
        )

        pole_coords = []
        for coord in [(l, b), (l, t), (r, t), (r, b)]:
            coord = pygame.math.Vector2(coord).rotate_rad(-x[2])
            coord = (coord[0] + cartx, coord[1] + carty + axleoffset)
            pole_coords.append(coord)
        gfxdraw.aapolygon(self.surf, pole_coords, (202, 152, 101))
        gfxdraw.filled_polygon(self.surf, pole_coords, (202, 152, 101))

        gfxdraw.aacircle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )
        gfxdraw.filled_circle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )

        gfxdraw.hline(self.surf, 0, screen_width, carty, (0, 0, 0))

        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
        if mode == "human":
            pygame.event.pump()
            self.clock.tick(self.metadata["render_fps"])
            pygame.display.flip()

        if mode == "rgb_array":
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
            )
        else:
            return self.isopen

    def close(self):
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()
            self.isopen = False

猜你喜欢

转载自blog.csdn.net/weixin_45526117/article/details/123775401