Create a custom MATLAB environment from a template
You can define a customized reinforcement learning environment by creating and modifying a template environment class. You can use a custom template environment to:
- Implement more complex environment dynamics.
- Add custom visualizations to your environment.
- Create an interface to third-party libraries defined in languages such as C++, Java, or Python.
Create template class
To define your custom environment, first create a template class file and specify the name of the class. For this example, name the class MyEnvironment.
rlCreateEnvTemplate("MyEnvironment")
The software creates and opens the template class file. The template class is a subclass of the rl.env.MATLABEnvironment abstract class, as shown in the class definition at the start of the template file. This abstract class is the same one used by the other MATLAB reinforcement learning environment objects.
classdef MyEnvironment < rl.env.MATLABEnvironment
By default, the template class implements a simple cart-pole balancing model, similar to the cart-pole predefined environment described in Load predefined control system environment.
To define your environment dynamics, modify the template class, specifying the following:
- Environment properties
- Required environment functions
- Optional environment functions
Environment properties
In the properties section of the template, specify any parameters necessary for creating and simulating the environment. These parameters can include:
- Physical constants. The sample environment defines the acceleration due to gravity (Gravity).
- Environment geometry. The sample environment defines the cart and pole masses (CartMass and PoleMass) and the half-length of the pole (HalfPoleLength).
- Environment constraints. The sample environment defines the pole angle and cart distance thresholds (AngleThreshold and DisplacementThreshold). The environment uses these values to detect when a training episode is finished.
- Variables required for evaluating the environment. The sample environment defines the state vector (State) and a flag for indicating when an episode is finished (IsDone).
- Constants for defining the action or observation spaces. The sample environment defines the maximum force in the action space (MaxForce).
- Constants for calculating the reward signal. The sample environment defines the constants RewardForNotFalling and PenaltyForFalling.
properties
% Specify and initialize the necessary properties of the environment
% Acceleration due to gravity in m/s^2
Gravity = 9.8
% Mass of the cart
CartMass = 1.0
% Mass of the pole
PoleMass = 0.1
% Half the length of the pole
HalfPoleLength = 0.5
% Max force the input can apply
MaxForce = 10
% Sample time
Ts = 0.02
% Angle at which to fail the episode (radians)
AngleThreshold = 12 * pi/180
% Distance at which to fail the episode
DisplacementThreshold = 2.4
% Reward each time step the cart-pole is balanced
RewardForNotFalling = 1
% Penalty when the cart-pole fails to balance
PenaltyForFalling = -10
end
properties
% Initialize system state [x,dx,theta,dtheta]'
State = zeros(4,1)
end
properties(Access = protected)
% Initialize internal flag to indicate episode termination
IsDone = false
end
Required functions
A reinforcement learning environment requires the following functions to be defined. The getObservationInfo, getActionInfo, sim, and validateEnvironment functions are already defined in the base abstract class. To create your environment, you must define the constructor, reset, and step functions.
Sample constructor function
The sample cart-pole constructor creates the environment by:
- Defining the action and observation specifications. For more information about creating these specifications, see rlNumericSpec and rlFiniteSetSpec.
- Calling the constructor of the base abstract class.
function this = MyEnvironment()
% Initialize observation settings
ObservationInfo = rlNumericSpec([4 1]);
ObservationInfo.Name = 'CartPole States';
ObservationInfo.Description = 'x, dx, theta, dtheta';
% Initialize action settings
ActionInfo = rlFiniteSetSpec([-1 1]);
ActionInfo.Name = 'CartPole Action';
% The following line implements built-in functions of the RL environment
this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);
% Initialize property values and precompute necessary values
end
This sample constructor does not include any input arguments. However, you can add input arguments for your custom constructor.
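To make the role of the two specification objects concrete, the following is a minimal, language-neutral Python sketch. The NumericSpec and FiniteSetSpec classes here are illustrative stand-ins for rlNumericSpec and rlFiniteSetSpec, not real APIs; they only show what information each specification carries.

```python
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class NumericSpec:
    """Stand-in for rlNumericSpec: a continuous space of a given shape."""
    shape: Tuple[int, int]
    name: str = ""
    description: str = ""

@dataclass
class FiniteSetSpec:
    """Stand-in for rlFiniteSetSpec: a discrete set of allowed values."""
    elements: List[float]
    name: str = ""

# Mirror the specifications created in the MyEnvironment constructor
obs_info = NumericSpec(shape=(4, 1), name="CartPole States",
                       description="x, dx, theta, dtheta")
act_info = FiniteSetSpec(elements=[-1, 1], name="CartPole Action")
```

An agent created for this environment reads these specifications to size its inputs (a 4-by-1 observation) and enumerate its outputs (the two discrete actions).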
Sample reset function
The sample reset function sets the initial condition of the model and returns the initial values of the observations. It also generates a notification that the environment has been updated by calling the envUpdatedCallback function, which is useful for updating the environment visualization.
% Reset environment to initial state and return initial observation
function InitialObservation = reset(this)
% Theta (+- .05 rad)
T0 = 2 * 0.05 * rand - 0.05;
% Thetadot
Td0 = 0;
% X
X0 = 0;
% Xdot
Xd0 = 0;
InitialObservation = [X0;Xd0;T0;Td0];
this.State = InitialObservation;
% (Optional) Use notifyEnvUpdated to signal that the
% environment is updated (for example, to update the visualization)
notifyEnvUpdated(this);
end
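The expression 2 * 0.05 * rand - 0.05 draws the initial pole angle uniformly from the interval ±0.05 rad. The following is a small Python sketch of the same scaling, where random.random() plays the role of MATLAB's rand; the function name initial_state is illustrative, not part of the template.

```python
import random

def initial_state():
    """Mirror the sample reset: random pole angle, everything else zero."""
    theta0 = 2 * 0.05 * random.random() - 0.05  # uniform in [-0.05, 0.05) rad
    return [0.0, 0.0, theta0, 0.0]              # [x, dx, theta, dtheta]
```

Starting each episode from a slightly different angle prevents the agent from memorizing a single trajectory.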
Sample step function
The sample cart-pole step function:
- Processes the input action.
- Evaluates the environment dynamic equations for one time step.
- Computes and returns the updated observations.
- Computes and returns the reward signal.
- Checks if the episode is complete and returns the IsDone signal as appropriate.
- Generates a notification that the environment has been updated.
function [Observation,Reward,IsDone,LoggedSignals] = step(this,Action)
LoggedSignals = [];
% Get action
Force = getForce(this,Action);
% Unpack state vector
XDot = this.State(2);
Theta = this.State(3);
ThetaDot = this.State(4);
% Cache to avoid recomputation
CosTheta = cos(Theta);
SinTheta = sin(Theta);
SystemMass = this.CartMass + this.PoleMass;
temp = (Force + this.PoleMass*this.HalfPoleLength*ThetaDot^2*SinTheta)...
/SystemMass;
% Apply motion equations
ThetaDotDot = (this.Gravity*SinTheta - CosTheta*temp)...
/ (this.HalfPoleLength*(4.0/3.0 - this.PoleMass*CosTheta*CosTheta/SystemMass));
XDotDot = temp - this.PoleMass*this.HalfPoleLength*ThetaDotDot*CosTheta/SystemMass;
% Euler integration
Observation = this.State + this.Ts.*[XDot;XDotDot;ThetaDot;ThetaDotDot];
% Update system states
this.State = Observation;
% Check terminal condition
X = Observation(1);
Theta = Observation(3);
IsDone = abs(X) > this.DisplacementThreshold || abs(Theta) > this.AngleThreshold;
this.IsDone = IsDone;
% Get reward
Reward = getReward(this);
% (Optional) Use notifyEnvUpdated to signal that the
% environment has been updated (for example, to update the visualization)
notifyEnvUpdated(this);
end
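If you want to check the dynamics outside MATLAB, the following is a self-contained Python sketch of the same Euler update. The constants mirror the template's property values; this is an illustration of the equations of motion in the step function above, not part of the template itself.

```python
import math

# Constants mirroring the template properties
GRAVITY, CART_MASS, POLE_MASS = 9.8, 1.0, 0.1
HALF_POLE_LENGTH, TS = 0.5, 0.02
ANGLE_THRESHOLD = 12 * math.pi / 180
DISPLACEMENT_THRESHOLD = 2.4

def step(state, force):
    """One Euler step of the cart-pole dynamics; returns (state, is_done)."""
    x, x_dot, theta, theta_dot = state
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    system_mass = CART_MASS + POLE_MASS
    temp = (force + POLE_MASS * HALF_POLE_LENGTH
            * theta_dot**2 * sin_t) / system_mass
    theta_acc = (GRAVITY * sin_t - cos_t * temp) / (
        HALF_POLE_LENGTH * (4.0 / 3.0 - POLE_MASS * cos_t**2 / system_mass))
    x_acc = temp - POLE_MASS * HALF_POLE_LENGTH * theta_acc * cos_t / system_mass
    # Euler integration, as in the MATLAB step function
    state = [x + TS * x_dot, x_dot + TS * x_acc,
             theta + TS * theta_dot, theta_dot + TS * theta_acc]
    is_done = (abs(state[0]) > DISPLACEMENT_THRESHOLD
               or abs(state[2]) > ANGLE_THRESHOLD)
    return state, is_done
```

As a quick sanity check, with zero force a small initial pole angle grows until the angle threshold is crossed after a few dozen steps, confirming that the sketch captures the instability of the inverted pendulum.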
Optional functions
You can define any other functions in your template class as needed. For example, you can create helper functions that are called by either step or reset. The cart-pole template model implements a getReward function for computing the reward at each time step.
function Reward = getReward(this)
if ~this.IsDone
Reward = this.RewardForNotFalling;
else
Reward = this.PenaltyForFalling;
end
end
Environment visualization
You can add a visualization to your custom environment by implementing the plot function. In the plot function:
- Create a figure or an instance of a visualizer class of your own implementation. For this example, you create a figure and store its handle in the environment object.
- Call the envUpdatedCallback function.
function plot(this)
% Initiate the visualization
this.Figure = figure('Visible','on','HandleVisibility','off');
ha = gca(this.Figure);
ha.XLimMode = 'manual';
ha.YLimMode = 'manual';
ha.XLim = [-3 3];
ha.YLim = [-1 2];
hold(ha,'on');
% Update the visualization
envUpdatedCallback(this)
end
For this example, store the figure handle as a protected property of the environment object. The envUpdatedCallback function below references this.Figure, so the property must be declared in the class.
properties(Access = protected)
% Handle to the visualization figure
Figure
end
Implement the environment update in the envUpdatedCallback function.
function envUpdatedCallback(this)
if ~isempty(this.Figure) && isvalid(this.Figure)
% Set visualization figure as the current figure
ha = gca(this.Figure);
% Extract the cart position and pole angle
x = this.State(1);
theta = this.State(3);
cartplot = findobj(ha,'Tag','cartplot');
poleplot = findobj(ha,'Tag','poleplot');
if isempty(cartplot) || ~isvalid(cartplot) ...
|| isempty(poleplot) || ~isvalid(poleplot)
% Initialize the cart plot
cartpoly = polyshape([-0.25 -0.25 0.25 0.25],[-0.125 0.125 0.125 -0.125]);
cartpoly = translate(cartpoly,[x 0]);
cartplot = plot(ha,cartpoly,'FaceColor',[0.8500 0.3250 0.0980]);
cartplot.Tag = 'cartplot';
% Initialize the pole plot
L = this.HalfPoleLength*2;
polepoly = polyshape([-0.1 -0.1 0.1 0.1],[0 L L 0]);
polepoly = translate(polepoly,[x,0]);
polepoly = rotate(polepoly,rad2deg(theta),[x,0]);
poleplot = plot(ha,polepoly,'FaceColor',[0 0.4470 0.7410]);
poleplot.Tag = 'poleplot';
else
cartpoly = cartplot.Shape;
polepoly = poleplot.Shape;
end
% Compute the new cart and pole position
[cartposx,~] = centroid(cartpoly);
[poleposx,poleposy] = centroid(polepoly);
dx = x - cartposx;
dtheta = theta - atan2(cartposx-poleposx,poleposy-0.25/2);
cartpoly = translate(cartpoly,[dx,0]);
polepoly = translate(polepoly,[dx,0]);
polepoly = rotate(polepoly,rad2deg(dtheta),[x,0.25/2]);
% Update the cart and pole positions on the plot
cartplot.Shape = cartpoly;
poleplot.Shape = polepoly;
% Refresh rendering in the figure window
drawnow();
end
end
The environment calls the envUpdatedCallback function, and therefore updates the visualization, whenever the environment is updated.
Create a custom environment
After you define your custom environment class, create an instance of it in the MATLAB workspace. At the command line, type the following.
env = MyEnvironment;
If your constructor has input arguments, specify them after the class name. For example, MyEnvironment(arg1,arg2).
After you create your environment, it is best practice to validate the environment dynamics. To do so, use the validateEnvironment function, which displays an error in the command window if there are any issues with your environment implementation.
validateEnvironment(env)
After you validate the environment object, you can use it to train a reinforcement learning agent.