MATLAB Reinforcement Learning Toolbox (6) Create a custom MATLAB environment from a template


You can define a custom reinforcement learning environment by creating and modifying a template environment class. You can use a custom template environment to:

  1. Implement more complex environment dynamics.
  2. Add custom visualizations to the environment.
  3. Create interfaces to third-party libraries defined in languages such as C++, Java, or Python.

Create template class

To define your custom environment, first create a template class file and specify the name of the class. For this example, name the class MyEnvironment.

rlCreateEnvTemplate("MyEnvironment")

The software creates and opens the template class file. The template class is a subclass of the rl.env.MATLABEnvironment abstract class, as shown in the class definition at the start of the template file. This abstract class is the same one used by the other MATLAB reinforcement learning environment objects.

classdef MyEnvironment < rl.env.MATLABEnvironment

By default, the template class implements a simple cart-pole balancing model, similar to the cart-pole predefined environments described in Load Predefined Control System Environments.

To define your environment dynamics, modify the template class by specifying the following:

  1. Environment properties
  2. Required environment functions
  3. Optional environment functions

Environment properties

In the properties section of the template, specify any parameters required to create and simulate the environment. These parameters can include:

  1. Physical constants. The example environment defines the acceleration due to gravity (Gravity).
  2. Environment geometry. The example environment defines the cart and pole masses (CartMass and PoleMass) and the half-length of the pole (HalfPoleLength).
  3. Environment constraints. The example environment defines the pole angle and cart displacement thresholds (AngleThreshold and DisplacementThreshold). The environment uses these values to detect when a training episode ends.
  4. Variables required for evaluating the environment. The example environment defines the state vector (State) and a flag that indicates when an episode is over (IsDone).
  5. Constants for defining the action or observation spaces. The example environment defines the maximum force in the action space (MaxForce).
  6. Constants for calculating the reward signal. The example environment defines the constants RewardForNotFalling and PenaltyForFalling.

properties
    % Specify and initialize the necessary properties of the environment  
    % Acceleration due to gravity in m/s^2
    Gravity = 9.8
    
    % Mass of the cart
    CartMass = 1.0
    
    % Mass of the pole
    PoleMass = 0.1
    
    % Half the length of the pole
    HalfPoleLength = 0.5
    
    % Max force the input can apply
    MaxForce = 10
           
    % Sample time
    Ts = 0.02
    
    % Angle at which to fail the episode (radians)
    AngleThreshold = 12 * pi/180
        
    % Distance at which to fail the episode
    DisplacementThreshold = 2.4
        
    % Reward each time step the cart-pole is balanced
    RewardForNotFalling = 1
    
    % Penalty when the cart-pole fails to balance
    PenaltyForFalling = -10 
end
    
properties
    % Initialize system state [x,dx,theta,dtheta]'
    State = zeros(4,1)
end

properties(Access = protected)
    % Initialize internal flag to indicate episode termination
    IsDone = false        
end

Required functions

A reinforcement learning environment requires the following functions to be defined. The getObservationInfo, getActionInfo, sim, and validateEnvironment functions are already defined in the base abstract class. To create your environment, you must define the constructor, reset, and step functions.

Sample constructor

The sample cart-pole constructor creates the environment by:

  1. Defining the action and observation specifications. For more information about creating these specifications, see rlNumericSpec and rlFiniteSetSpec.

  2. Calling the constructor of the base abstract class.

function this = MyEnvironment()
    % Initialize observation settings
    ObservationInfo = rlNumericSpec([4 1]);
    ObservationInfo.Name = 'CartPole States';
    ObservationInfo.Description = 'x, dx, theta, dtheta';

    % Initialize action settings   
    ActionInfo = rlFiniteSetSpec([-1 1]);
    ActionInfo.Name = 'CartPole Action';

    % The following line implements built-in functions of the RL environment
    this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);

    % Initialize property values and precompute necessary values
    updateActionInfo(this);
end

This example constructor does not contain any input parameters. However, you can add input parameters to the custom constructor.
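
The constructor above also calls an updateActionInfo helper that is not shown in this excerpt. A minimal sketch, assuming its job is to rescale the two discrete actions to -MaxForce and +MaxForce:

function updateActionInfo(this)
    % Rescale the discrete action set so the two actions apply
    % -MaxForce or +MaxForce to the cart
    this.ActionInfo.Elements = this.MaxForce*[-1 1];
end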

Sample reset function

The sample reset function sets the initial condition of the model and returns the initial values of the observations. It also generates a notification that the environment has been updated by calling the notifyEnvUpdated function, which is useful for updating any environment visualization.

% Reset environment to initial state and return initial observation
function InitialObservation = reset(this)
    % Theta (+- .05 rad)
    T0 = 2 * 0.05 * rand - 0.05;  
    % Thetadot
    Td0 = 0;
    % X 
    X0 = 0;
    % Xdot
    Xd0 = 0;

    InitialObservation = [X0;Xd0;T0;Td0];
    this.State = InitialObservation;

    % (Optional) Use notifyEnvUpdated to signal that the 
    % environment is updated (for example, to update the visualization)
    notifyEnvUpdated(this);
end

Sample step function

The sample cart-pole step function:

  1. Processes the input action.
  2. Evaluates the environment dynamic equations for one time step.
  3. Computes and returns the updated observation.
  4. Computes and returns the reward signal.
  5. Checks if the episode is complete and returns the IsDone signal as appropriate.
  6. Generates a notification that the environment has been updated.

function [Observation,Reward,IsDone,LoggedSignals] = step(this,Action)
    LoggedSignals = [];

    % Get action
    Force = getForce(this,Action);            

    % Unpack state vector
    XDot = this.State(2);
    Theta = this.State(3);
    ThetaDot = this.State(4);

    % Cache to avoid recomputation
    CosTheta = cos(Theta);
    SinTheta = sin(Theta);            
    SystemMass = this.CartMass + this.PoleMass;
    temp = (Force + this.PoleMass*this.HalfPoleLength*ThetaDot^2*SinTheta)...
        /SystemMass;

    % Apply motion equations            
    ThetaDotDot = (this.Gravity*SinTheta - CosTheta*temp)...
        / (this.HalfPoleLength*(4.0/3.0 - this.PoleMass*CosTheta*CosTheta/SystemMass));
    XDotDot  = temp - this.PoleMass*this.HalfPoleLength*ThetaDotDot*CosTheta/SystemMass;

    % Euler integration
    Observation = this.State + this.Ts.*[XDot;XDotDot;ThetaDot;ThetaDotDot];

    % Update system states
    this.State = Observation;

    % Check terminal condition
    X = Observation(1);
    Theta = Observation(3);
    IsDone = abs(X) > this.DisplacementThreshold || abs(Theta) > this.AngleThreshold;
    this.IsDone = IsDone;

    % Get reward
    Reward = getReward(this);

    % (Optional) Use notifyEnvUpdated to signal that the 
    % environment has been updated (for example, to update the visualization)
    notifyEnvUpdated(this);
end

Optional functions

You can define any other functions in the template class as needed. For example, you can create helper functions that are called by either step or reset. The cart-pole template class implements a getReward function that calculates the reward at each time step.

function Reward = getReward(this)
    if ~this.IsDone
        Reward = this.RewardForNotFalling;
    else
        Reward = this.PenaltyForFalling;
    end          
end
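
The step function above also calls a getForce helper that is not shown in this excerpt. A minimal sketch, assuming the action elements have already been rescaled to -MaxForce and +MaxForce (for example by the updateActionInfo helper sketched earlier):

function Force = getForce(this,Action)
    % Validate the action and map it to the force applied to the cart
    if ~ismember(Action,this.ActionInfo.Elements)
        error('Action must be %g for pushing left or %g for pushing right.',...
            -this.MaxForce,this.MaxForce);
    end
    Force = Action;
end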

Environment visualization

You can add a visualization to your custom environment by implementing the plot function. In the plot function:

  1. Create a figure or an instance of your own visualizer class. For this example, create a figure and store its handle in the environment object.
  2. Call the envUpdatedCallback function.

function plot(this)
    % Initiate the visualization
    this.Figure = figure('Visible','on','HandleVisibility','off');
    ha = gca(this.Figure);
    ha.XLimMode = 'manual';
    ha.YLimMode = 'manual';
    ha.XLim = [-3 3];
    ha.YLim = [-1 2];
    hold(ha,'on');
    % Update the visualization
    envUpdatedCallback(this)
end

For this example, store the figure handle as a protected property of the environment object.
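
A minimal sketch of that property block (the name Figure is assumed here because the plot and envUpdatedCallback functions reference this.Figure):

properties(Access = protected)
    % Handle to the visualization figure, empty until plot is called
    Figure
end

In the envUpdatedCallback function, update the visualization based on the current state of the environment.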

function envUpdatedCallback(this)
    if ~isempty(this.Figure) && isvalid(this.Figure)
        % Set visualization figure as the current figure
        ha = gca(this.Figure);

        % Extract the cart position and pole angle
        x = this.State(1);
        theta = this.State(3);

        cartplot = findobj(ha,'Tag','cartplot');
        poleplot = findobj(ha,'Tag','poleplot');
        if isempty(cartplot) || ~isvalid(cartplot) ...
                || isempty(poleplot) || ~isvalid(poleplot)
            % Initialize the cart plot
            cartpoly = polyshape([-0.25 -0.25 0.25 0.25],[-0.125 0.125 0.125 -0.125]);
            cartpoly = translate(cartpoly,[x 0]);
            cartplot = plot(ha,cartpoly,'FaceColor',[0.8500 0.3250 0.0980]);
            cartplot.Tag = 'cartplot';

            % Initialize the pole plot
            L = this.HalfPoleLength*2;
            polepoly = polyshape([-0.1 -0.1 0.1 0.1],[0 L L 0]);
            polepoly = translate(polepoly,[x,0]);
            polepoly = rotate(polepoly,rad2deg(theta),[x,0]);
            poleplot = plot(ha,polepoly,'FaceColor',[0 0.4470 0.7410]);
            poleplot.Tag = 'poleplot';
        else
            cartpoly = cartplot.Shape;
            polepoly = poleplot.Shape;
        end

        % Compute the new cart and pole position
        [cartposx,~] = centroid(cartpoly);
        [poleposx,poleposy] = centroid(polepoly);
        dx = x - cartposx;
        dtheta = theta - atan2(cartposx-poleposx,poleposy-0.25/2);
        cartpoly = translate(cartpoly,[dx,0]);
        polepoly = translate(polepoly,[dx,0]);
        polepoly = rotate(polepoly,rad2deg(dtheta),[x,0.25/2]);

        % Update the cart and pole positions on the plot
        cartplot.Shape = cartpoly;
        poleplot.Shape = polepoly;

        % Refresh rendering in the figure window
        drawnow();
    end
end     

The environment calls the envUpdatedCallback function, and therefore updates the visualization when the environment is updated.

Create a custom environment

After defining your custom environment class, create an instance of it in the MATLAB workspace. At the command line, type the following.

env = MyEnvironment;

If your constructor has input parameters, specify them after the class name. For example, MyEnvironment(arg1,arg2).
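
For instance, a minimal sketch of a constructor that accepts a sample time (a hypothetical input argument) could look like this; the specification setup mirrors the sample constructor above.

function this = MyEnvironment(sampleTime)
    % Set up specifications as in the sample constructor
    ObservationInfo = rlNumericSpec([4 1]);
    ObservationInfo.Name = 'CartPole States';
    ActionInfo = rlFiniteSetSpec([-1 1]);
    ActionInfo.Name = 'CartPole Action';
    this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);

    % Store the hypothetical input argument in an environment property
    this.Ts = sampleTime;
    updateActionInfo(this);
end

You would then create the environment with env = MyEnvironment(0.01).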

After creating the environment, it is best practice to validate the environment dynamics. To do so, use the validateEnvironment function, which displays an error in the command window if your environment implementation has any problems.

validateEnvironment(env)

After validating the environment object, you can use it to train reinforcement learning agents.
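
For example, a minimal sketch of training a default DQN agent on this environment follows; the agent type and training options are illustrative choices, not part of the template.

% Create a default DQN agent from the environment specifications
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
agent = rlDQNAgent(obsInfo,actInfo);

% Train the agent with modest, hypothetical settings
trainOpts = rlTrainingOptions('MaxEpisodes',200,'MaxStepsPerEpisode',500);
trainingStats = train(agent,env,trainOpts);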
