MATLAB Reinforcement Learning in Practice (12): Create an Agent for a Custom Reinforcement Learning Algorithm


This example shows how to create a custom agent for your own custom reinforcement learning algorithm. Doing so allows you to take advantage of the following built-in features of the Reinforcement Learning Toolbox™ software.

  1. Access all agent functions, including train and sim

  2. Use Episode Manager to visualize training progress

  3. Train the agent in the Simulink® environment

In this example, you convert a custom REINFORCE training loop into a custom agent class. For more information about the REINFORCE custom training loop, see Train Reinforcement Learning Policy Using Custom Training Loop. For more information on writing custom agent classes, see Custom Agents.

Fix the random number generator seed for reproducibility.

rng(0)

Create the environment

Create the same training environment used in the Train Reinforcement Learning Policy Using Custom Training Loop example. This environment is a cart-pole balancing environment with a discrete action space. Use the rlPredefinedEnv function to create the environment.

env = rlPredefinedEnv('CartPole-Discrete');

Extract the observation and action specifications from the environment.

obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

Get the number of observations (numObs) and the number of actions (numAct).

numObs = obsInfo.Dimension(1);
numAct = numel(actInfo.Elements);

Define policy

In this example, the reinforcement learning policy is a discrete-action stochastic policy. It is represented by a deep neural network that contains fullyConnectedLayer, reluLayer, and softmaxLayer layers. Given the current observation, the network outputs the probability of each discrete action. The softmaxLayer ensures that each output probability lies in the range [0, 1] and that all probabilities sum to 1.
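As a toolbox-independent illustration of what the final softmaxLayer does, the following sketch converts two hypothetical raw outputs (logits) of the 'output' layer into action probabilities. The logit values are made up; subtracting the maximum before exponentiating is a standard numerical-stability trick and does not change the result.

```matlab
% Hypothetical logits from the 'output' fully connected layer (2 discrete actions).
z = [2.0; -1.0];

% Softmax: shift by the max for numerical stability, exponentiate, normalize.
e = exp(z - max(z));
p = e / sum(e);   % each entry lies in [0, 1] and the entries sum to 1
```

Here the larger logit gets the larger probability, and the two probabilities sum to 1, which is what lets the actor treat the network output as a discrete distribution.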

Create the deep neural network for the actor.

actorNetwork = [featureInputLayer(numObs,'Normalization','none','Name','state')
                fullyConnectedLayer(24,'Name','fc1')
                reluLayer('Name','relu1')
                fullyConnectedLayer(24,'Name','fc2')
                reluLayer('Name','relu2')
                fullyConnectedLayer(2,'Name','output')
                softmaxLayer('Name','actionProb')];

Use the rlStochasticActorRepresentation object to create an actor representation.

actorOpts = rlRepresentationOptions('LearnRate',1e-3,'GradientThreshold',1);
actor = rlStochasticActorRepresentation(actorNetwork,...
    obsInfo,actInfo,'Observation','state',actorOpts);

Custom Agent Class

To define your custom agent, first create a class that is a subclass of the rl.agent.CustomAgent class. The custom agent class for this example is defined in CustomReinforceAgent.m.

The CustomReinforceAgent class has the following class definition, which specifies the agent class name and its abstract base class.

classdef CustomReinforceAgent < rl.agent.CustomAgent

To define your agent, you must specify the following:

  1. Agent properties

  2. Constructor

  3. Critic representation that estimates the discounted long-term reward (if required for learning)

  4. Actor representation that selects an action based on the current observation (if required for learning)

  5. Required agent methods

  6. Optional agent methods

Agent properties

In the properties section of the class file, specify any parameters required to create and train the agent.

The rl.agent.CustomAgent class already contains properties for the agent sample time (SampleTime) and the action and observation specifications (ActionInfo and ObservationInfo, respectively).

The custom REINFORCE agent defines the following additional agent properties.

properties
    % Actor representation
    Actor
    
    % Agent options
    Options
    
    % Experience buffer
    ObservationBuffer
    ActionBuffer
    RewardBuffer
end

properties (Access = private)
    % Training utilities
    Counter
    NumObservation
    NumAction
end

Constructor

To create a custom agent, you must define a constructor. The constructor performs the following operations.

  1. Define the action and observation specifications. For more information on creating these specifications, see rlNumericSpec and rlFiniteSetSpec.

  2. Set the agent properties.

  3. Call the constructor of the abstract base class.

  4. Define the sample time (required for training in Simulink environments).

For example, the CustomReinforceAgent constructor defines the action and observation spaces based on the input actor representation.

function obj = CustomReinforceAgent(Actor,Options)
    %CUSTOMREINFORCEAGENT Construct custom agent
    %   AGENT = CUSTOMREINFORCEAGENT(ACTOR,OPTIONS) creates custom
    %   REINFORCE AGENT from rlStochasticActorRepresentation ACTOR
    %   and structure OPTIONS. OPTIONS has fields:
    %       - DiscountFactor
    %       - MaxStepsPerEpisode
    
    % (required) Call the abstract class constructor.
    obj = obj@rl.agent.CustomAgent();
    obj.ObservationInfo = Actor.ObservationInfo;
    obj.ActionInfo = Actor.ActionInfo;
    
    % (required for Simulink environment) Register sample time. 
    % For MATLAB environment, use -1.
    obj.SampleTime = -1;
    
    % (optional) Register actor and agent options.
    Actor = setLoss(Actor,@lossFunction);
    obj.Actor = Actor;
    obj.Options = Options;
    
    % (optional) Cache the number of observations and actions.
    obj.NumObservation = prod(obj.ObservationInfo.Dimension);
    obj.NumAction = prod(obj.ActionInfo.Dimension);
    
    % (optional) Initialize buffer and counter.
    reset(obj);
end

The constructor uses a function handle to set the loss function of the actor representation to lossFunction, which is implemented as a local function in CustomReinforceAgent.m.

function loss = lossFunction(policy,lossData)

    % Create the action indication matrix.
    batchSize = lossData.batchSize;
    Z = repmat(lossData.actInfo.Elements',1,batchSize);
    actionIndicationMatrix = lossData.actionBatch(:,:) == Z;
    
    % Resize the discounted return to the size of policy.
    G = actionIndicationMatrix .* lossData.discountedReturn;
    G = reshape(G,size(policy));
    
    % Round any policy values less than eps to eps.
    policy(policy < eps) = eps;
    
    % Compute the loss.
    loss = -sum(G .* log(policy),'all');
    
end
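To see what lossFunction computes, here is a toolbox-independent sketch on a hypothetical batch of three steps. The indicator matrix keeps only the log-probability of the action actually taken at each step, weighted by that step's discounted return, so the loss is -sum over t of G_t * log(pi(a_t | s_t)). The action set [-10 10] is an assumption about CartPole-Discrete, and the actions, returns, and policy values are made up; actionBatch is already a 1-by-B row here, which is what actionBatch(:,:) produces from the 1-by-1-by-B buffer slice in the real code.

```matlab
% Hypothetical batch: 2 discrete actions, 3 time steps.
elements = [-10; 10];            % assumed action set, as a column
actionBatch = [10 -10 10];       % action taken at each step (1-by-B)
discountedReturn = [3 2 1];      % discounted return G_t at each step (1-by-B)
policy = [0.4 0.7 0.2;           % pi(a = -10 | s_t) for t = 1..3
          0.6 0.3 0.8];          % pi(a = +10 | s_t) for t = 1..3

batchSize = numel(actionBatch);
Z = repmat(elements, 1, batchSize);              % 2-by-B candidate actions
actionIndicationMatrix = (actionBatch == Z);     % true where the action was taken
G = actionIndicationMatrix .* discountedReturn;  % G_t on the taken action, 0 elsewhere
policy(policy < eps) = eps;                      % avoid log(0)
loss = -sum(sum(G .* log(policy)));              % REINFORCE loss, about 2.469
```

Minimizing this loss increases the log-probability of actions that were followed by large returns, which is the REINFORCE policy-gradient update.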

Required functions

To create a custom reinforcement learning agent, you must define the following implementation functions.

  1. getActionImpl: Evaluate the agent policy and select an action during simulation.

  2. getActionWithExplorationImpl: Evaluate the policy and select an action with exploration during training.

  3. learnImpl: Define how the agent learns from the current experience.

To call these functions in your own code, use the wrapper methods of the abstract base class. For example, to call getActionImpl, use getAction. The wrapper methods have the same input and output arguments as the implementation methods.

getActionImpl Function

The getActionImpl function evaluates the agent policy and selects an action when you simulate the agent using the sim function. This function must have the following signature, where obj is the agent object, Observation is the current observation, and Action is the selected action.

 function Action = getActionImpl(obj,Observation)

For a custom REINFORCE agent, you can select an action by calling the getAction function of the actor representation. The discrete rlStochasticActorRepresentation generates a discrete distribution from the observation and samples an action from that distribution.

function Action = getActionImpl(obj,Observation)
    % Compute an action using the policy given the current 
    % observation.
    
    Action = getAction(obj.Actor,Observation);
end
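The sampling step that the stochastic actor performs can be sketched without the toolbox as inverse-CDF sampling from a categorical distribution. This is an illustrative assumption about the technique, not the toolbox's internal code; the action set and probabilities are hypothetical.

```matlab
elements = [-10; 10];   % assumed discrete action set
p = [0.3; 0.7];         % hypothetical softmax output for the current observation

c = cumsum(p);          % cumulative distribution over actions
c(end) = 1;             % guard against round-off in the last cumulative value
idx = find(rand <= c, 1);   % first action whose cumulative probability covers the draw
action = elements(idx);
```

Because rand is uniform on (0, 1), action 10 is returned with probability 0.7 and action -10 with probability 0.3.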

getActionWithExplorationImpl Function

When you train the agent using the train function, the getActionWithExplorationImpl function selects actions using the agent's exploration model. In this function, you can implement exploration techniques such as epsilon-greedy exploration or additive Gaussian noise. This function must have the following signature, where obj is the agent object, Observation is the current observation, and Action is the selected action.

function Action = getActionWithExplorationImpl(obj,Observation)

For a custom REINFORCE agent, the getActionWithExplorationImpl function is the same as getActionImpl. Stochastic actors always explore by default; that is, they always sample an action from a probability distribution.

function Action = getActionWithExplorationImpl(obj,Observation)
    % Compute an action using the exploration policy given the  
    % current observation.
    
    % REINFORCE: Stochastic actors always explore by default
    % (sample from a probability distribution)
    Action = getAction(obj.Actor,Observation);
end
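REINFORCE needs no extra exploration, but an agent with a deterministic (greedy) policy could use epsilon-greedy exploration, one of the techniques mentioned above, inside getActionWithExplorationImpl. The following toolbox-independent sketch shows only the selection rule; the Q-values (for example, from a critic), the action set, and epsilon are hypothetical.

```matlab
qValues  = [0.2; 0.9];   % hypothetical value estimates for each discrete action
elements = [-10; 10];    % assumed discrete action set
epsilon  = 0.1;          % hypothetical exploration probability

if rand < epsilon
    idx = randi(numel(elements));   % explore: pick a uniformly random action
else
    [~, idx] = max(qValues);        % exploit: pick the greedy action
end
action = elements(idx);
```

In a real agent, epsilon is usually decayed over training so that the agent explores early and exploits later.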

learnImpl Function

The learnImpl function defines how the agent learns from the current experience. This function implements the custom learning algorithm of the agent by updating the policy parameters and selecting an action with exploration for the next state. The function must have the following signature, where obj is the agent object, Experience is the current agent experience, and Action is the selected action.

function Action = learnImpl(obj,Experience)

The agent experience is the cell array Experience = {state,action,reward,nextState,isDone}. Here:

  1. state is the current observation.

  2. action is the current action. This is different from the output argument Action, which is the action selected for the next state.

  3. reward is the current reward.

  4. nextState is the next observation.

  5. isDone is a logical flag indicating that the training episode is complete.
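To make the nesting concrete, the following sketch builds one hypothetical experience and unpacks it the same way learnImpl does. It assumes one observation channel (a 4-by-1 vector, as in the cart-pole environment) and one action channel, each wrapped in its own cell array; that wrapping is why learnImpl writes Obs{1} and Action{1}. All numeric values are made up.

```matlab
% Hypothetical single-step experience.
observation     = {[0.01; 0.00; 0.02; 0.00]};   % cell: one 4-by-1 observation channel
action          = {10};                         % cell: one scalar action channel
reward          = 1;
nextObservation = {[0.01; 0.19; 0.02; -0.28]};
isDone          = false;
Experience = {observation, action, reward, nextObservation, isDone};

% Unpack as in learnImpl: Experience{i} is the element, {1} selects the channel.
Obs    = Experience{1};
Action = Experience{2};
x = Obs{1};      % 4-by-1 numeric observation
a = Action{1};   % scalar action
```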

For the custom REINFORCE agent, replicate steps 2 through 7 of the custom training loop in Train Reinforcement Learning Policy Using Custom Training Loop. Omit steps 1, 8, and 9, because you use the built-in train function to train the agent.

function Action = learnImpl(obj,Experience)
    % Define how the agent learns from an Experience, which is a
    % cell array with the following format.
    %   Experience = {observation,action,reward,nextObservation,isDone}
    
    % Reset buffer at the beginning of the episode.
    if obj.Counter < 2
        resetBuffer(obj);
    end
    
    % Extract data from experience.
    Obs = Experience{1};
    Action = Experience{2};
    Reward = Experience{3};
    NextObs = Experience{4};
    IsDone = Experience{5};
    
    % Save data to buffer.
    obj.ObservationBuffer(:,:,obj.Counter) = Obs{1};
    obj.ActionBuffer(:,:,obj.Counter) = Action{1};
    obj.RewardBuffer(:,obj.Counter) = Reward;
    
    if ~IsDone
        % Choose an action for the next state.
        Action = getActionWithExplorationImpl(obj, NextObs);
        obj.Counter = obj.Counter + 1;
    else
        % Learn from episodic data.
        
        % Collect data from the buffer.
        BatchSize = min(obj.Counter,obj.Options.MaxStepsPerEpisode);
        ObservationBatch = obj.ObservationBuffer(:,:,1:BatchSize);
        ActionBatch = obj.ActionBuffer(:,:,1:BatchSize);
        RewardBatch = obj.RewardBuffer(:,1:BatchSize);
        
        % Compute the discounted future reward.
        DiscountedReturn = zeros(1,BatchSize);
        for t = 1:BatchSize
            G = 0;
            for k = t:BatchSize
                G = G + obj.Options.DiscountFactor ^ (k-t) * RewardBatch(k);
            end
            DiscountedReturn(t) = G;
        end
        
        % Organize data to pass to the loss function.
        LossData.batchSize = BatchSize;
        LossData.actInfo = obj.ActionInfo;
        LossData.actionBatch = ActionBatch;
        LossData.discountedReturn = DiscountedReturn;
        
        % Compute the gradient of the loss with respect to the
        % actor parameters.
        ActorGradient = gradient(obj.Actor,'loss-parameters',...
            {ObservationBatch},LossData);
        
        % Update the actor parameters using the computed gradients.
        obj.Actor = optimize(obj.Actor,ActorGradient);
        
        % Reset the counter.
        obj.Counter = 1;
    end
end
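The nested loop in learnImpl computes each discounted return G_t = r_t + gamma*r_(t+1) + gamma^2*r_(t+2) + ... directly, which costs O(B^2) operations for an episode of B steps. The following toolbox-independent sketch checks it against the equivalent backward recursion G_t = r_t + gamma*G_(t+1), which needs only O(B) operations. The reward values are hypothetical; the discount factor matches this example's options.

```matlab
discountFactor = 0.995;        % same value as options.DiscountFactor in this example
rewardBatch = [1 1 1 1 1];     % hypothetical rewards for a 5-step episode
B = numel(rewardBatch);

% Nested loop, as written in learnImpl: O(B^2) operations.
G1 = zeros(1,B);
for t = 1:B
    g = 0;
    for k = t:B
        g = g + discountFactor^(k-t) * rewardBatch(k);
    end
    G1(t) = g;
end

% Backward recursion G_t = r_t + discountFactor*G_(t+1): O(B) operations.
G2 = zeros(1,B);
running = 0;
for t = B:-1:1
    running = rewardBatch(t) + discountFactor * running;
    G2(t) = running;
end
```

Both versions give the same returns up to floating-point round-off; the recursive form matters only for long episodes, since MaxStepsPerEpisode is 250 here.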

Optional functions

(Optional) You can define how to reset the agent at the beginning of training by specifying the resetImpl function with the following function signature, where obj is the agent object.

function resetImpl(obj)

Using this function, you can put the agent into a known or random state before training.

function resetImpl(obj)
    % (Optional) Define how the agent is reset before training.
    
    resetBuffer(obj);
    obj.Counter = 1;
end

In addition, you can define any other helper functions in the custom agent class as needed. For example, the custom REINFORCE agent defines the resetBuffer function to reinitialize the experience buffer at the beginning of each training episode.

function resetBuffer(obj)
    % Reinitialize all experience buffers.
    
    obj.ObservationBuffer = zeros(obj.NumObservation,1,obj.Options.MaxStepsPerEpisode);
    obj.ActionBuffer = zeros(obj.NumAction,1,obj.Options.MaxStepsPerEpisode);
    obj.RewardBuffer = zeros(1,obj.Options.MaxStepsPerEpisode);
end
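The buffers store one time step per page of a 3-D array, and learnImpl later slices off the filled pages. The following toolbox-independent sketch mimics that layout; it assumes 4 observations and a scalar action (so NumAction is 1), as in the cart-pole environment, and stores three hypothetical steps.

```matlab
% Assumed sizes for the cart-pole environment.
numObservation = 4;
numAction = 1;
maxSteps = 250;

% Preallocate the buffers the same way resetBuffer does.
observationBuffer = zeros(numObservation, 1, maxSteps);
actionBuffer = zeros(numAction, 1, maxSteps);
rewardBuffer = zeros(1, maxSteps);

% Store three hypothetical steps, one page per time step.
for counter = 1:3
    observationBuffer(:,:,counter) = counter * ones(numObservation, 1);
    actionBuffer(:,:,counter) = 10;
    rewardBuffer(:,counter) = 1;
end

% Slice the filled part of the episode, as learnImpl does.
batchSize = 3;
observationBatch = observationBuffer(:,:,1:batchSize);   % 4-by-1-by-3
actionBatch = actionBuffer(:,:,1:batchSize);             % 1-by-1-by-3
rewardBatch = rewardBuffer(:,1:batchSize);               % 1-by-3
```

Preallocating to MaxStepsPerEpisode avoids growing arrays inside the training loop; only the first BatchSize pages are passed to the loss.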

Create a custom agent

After defining the custom agent class, create an instance of it in the MATLAB workspace. To create a custom REINFORCE agent, first specify agent options.

options.MaxStepsPerEpisode = 250;
options.DiscountFactor = 0.995;

Then, using the options and the previously defined actor representation, call the custom agent constructor.

agent = CustomReinforceAgent(actor,options);

Train a custom agent

Configure training to use the following options.

  1. Run training for at most 5000 episodes, with each episode lasting at most 250 steps.

  2. Terminate training when the maximum number of episodes is reached or when the average reward over the last 100 episodes reaches 240.

numEpisodes = 5000;
aveWindowSize = 100;
trainingTerminationValue = 240;
trainOpts = rlTrainingOptions(...
    'MaxEpisodes',numEpisodes,...
    'MaxStepsPerEpisode',options.MaxStepsPerEpisode,...
    'ScoreAveragingWindowLength',aveWindowSize,...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',trainingTerminationValue);
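The average-reward stopping rule described in the list above can be sketched without the toolbox on a synthetic reward curve. This assumes the average is a plain mean over the most recent 100 episodes (or over all episodes so far when fewer than 100 have run); the episode rewards are made up and rise linearly until they hit the 250-step cap.

```matlab
% Hypothetical episode rewards: improving linearly, capped at 250.
episodeRewards = min(250, 20 + 2*(1:300));
aveWindowSize = 100;
trainingTerminationValue = 240;

stopEpisode = NaN;
for ep = 1:numel(episodeRewards)
    first = max(1, ep - aveWindowSize + 1);          % start of the averaging window
    avgReward = mean(episodeRewards(first:ep));      % moving average reward
    if avgReward >= trainingTerminationValue
        stopEpisode = ep;                            % 183 for this synthetic curve
        break
    end
end
```

Even though individual episodes already reach 250 at episode 115, training continues until enough high-reward episodes fill the window, which makes the stopping rule robust to a single lucky episode.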

Train the agent using the train function. Training this agent is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;
if doTraining
    % Train the agent.
    trainStats = train(agent,env,trainOpts);
else
    % Load pretrained agent for the example.
    load('CustomReinforce.mat','agent');
end


Custom agent simulation

Enable environment visualization, which is updated every time the environment step function is called.

plot(env)

To verify the performance of the trained agent, simulate it in the cart-pole environment. For more information about agent simulation, see rlSimulationOptions and sim.

simOpts = rlSimulationOptions('MaxSteps',options.MaxStepsPerEpisode);
experience = sim(env,agent,simOpts);



Origin blog.csdn.net/wangyifan123456zz/article/details/109651707