Case study: train Q-learning and SARSA agents in a grid world environment
The grid world environment has the following configurations and rules:
1. The grid world is 5-by-5 and bounded by borders. There are four possible actions (North = 1, South = 2, East = 3, West = 4).
2. The agent starts from cell [2,1] (second row, first column).
3. If the agent reaches the terminal state at cell [5,5] (blue), it receives a +10 reward.
4. The environment contains a special jump from cell [2,4] to cell [4,4], which gives a +5 reward.
5. The agent is blocked by obstacles (black cells).
6. All other actions result in a -1 reward.
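The reward rules above can be sketched as a small function. This is only an illustrative sketch of the rules, not the toolbox's internal implementation; the helper name and signature are hypothetical.

```matlab
% Illustrative sketch of the grid world reward rules (hypothetical helper,
% not the toolbox implementation). States are [row, col] cells.
function r = gridReward(state,nextState)
    if isequal(nextState,[5 5])
        r = 10;     % reached the terminal cell (blue)
    elseif isequal(state,[2 4]) && isequal(nextState,[4 4])
        r = 5;      % took the special jump from [2,4] to [4,4]
    else
        r = -1;     % any other action
    end
end
```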
Create a grid world environment
Create a basic grid world environment
env = rlPredefinedEnv("BasicGridWorld");
To specify that the initial state of the agent is always [2,1], create a reset function that returns the state number of the agent's initial state. This function is called at the beginning of each training episode and simulation. States are numbered column-wise starting from position [1,1]: the state number increases as you move down the first column, then down each subsequent column. Therefore, create an anonymous function handle that sets the initial state to 2.
env.ResetFcn = @() 2;
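The column-wise numbering means a cell [row, col] maps to a state number as follows (an illustrative helper, not part of the toolbox API):

```matlab
% Column-wise state numbering for the 5x5 grid: [1,1] -> 1, [2,1] -> 2, ...
toStateNum = @(row,col) (col-1)*5 + row;
toStateNum(2,1)   % the agent's initial cell [2,1] is state 2
toStateNum(5,5)   % the terminal cell [5,5] is state 25
```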
Fix the random generator seed for reproducibility.
rng(0)
Create a Q-learning agent
To create a Q-learning agent, first use the observation and action specifications from the grid world environment to create a Q table. Then set the learning rate of the representation to 1.
qTable = rlTable(getObservationInfo(env),getActionInfo(env));
qRepresentation = rlQValueRepresentation(qTable,getObservationInfo(env),getActionInfo(env));
qRepresentation.Options.LearnRate = 1;
Next, use this table representation to create a Q-learning agent, and configure its epsilon-greedy exploration.
agentOpts = rlQAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = 0.04;
qAgent = rlQAgent(qRepresentation,agentOpts);
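For reference, after each step the Q-learning agent performs the standard tabular update, which bootstraps from the best action in the next state. The snippet below is a minimal sketch of that update with illustrative variable names and an assumed discount factor, not the toolbox's internals.

```matlab
% Standard tabular Q-learning update (illustrative, not the toolbox internals).
Q = zeros(25,4);          % one row per grid state, one column per action
lr = 1;                   % learning rate, as set on the representation above
gamma = 0.99;             % discount factor (assumed value for this sketch)
s = 2; a = 2; r = -1;     % example transition: state 2, action South, reward -1
sNext = 3;                % resulting state
Q(s,a) = Q(s,a) + lr*(r + gamma*max(Q(sNext,:)) - Q(s,a));
```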
Train the Q-learning agent
To train the agent, first specify the training options. For this example, use the following options:
1. Train for at most 200 episodes, with each episode lasting at most 50 time steps.
2. Stop training when the average cumulative reward over 30 consecutive episodes is greater than 11.
trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes = 200;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 11;
trainOpts.ScoreAveragingWindowLength = 30;
Use the train function to train the Q-learning agent. Training may take a few minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.
doTraining = false;
if doTraining
    % Train the agent.
    trainingStats = train(qAgent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load('basicGWQAgent.mat','qAgent')
end
The Episode Manager window opens and displays the training progress.
Validate Q-learning results
To verify the training results, simulate the agent in the training environment.
Before running the simulation, visualize the environment and configure the visualization to keep track of the agent's state.
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
Use the sim function to simulate the agent in the environment.
sim(qAgent,env)
The agent trace shows that the agent successfully found the jump from cell [2,4] to cell [4,4].
Create and train a SARSA agent
To create a SARSA agent, use the same Q table representation and epsilon-greedy configuration as for the Q-learning agent.
agentOpts = rlSARSAAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = 0.04;
sarsaAgent = rlSARSAAgent(qRepresentation,agentOpts);
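The key difference between the two agents is the update rule. SARSA is on-policy: instead of bootstrapping from max(Q(sNext,:)) as Q-learning does, it uses the action that its epsilon-greedy policy actually selects in the next state. A minimal sketch with illustrative variable names and an assumed discount factor:

```matlab
% Tabular SARSA update for comparison (illustrative, not the toolbox internals).
Q = zeros(25,4);                  % one row per grid state, one column per action
lr = 1;                           % learning rate, as set on the representation
gamma = 0.99;                     % discount factor (assumed value for this sketch)
s = 2; a = 2; r = -1;             % example transition
sNext = 3; aNext = 3;             % next state and the action actually taken there
Q(s,a) = Q(s,a) + lr*(r + gamma*Q(sNext,aNext) - Q(s,a));
```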
Use the train function to train the SARSA agent. Training may take a few minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.
doTraining = false;
if doTraining
    % Train the agent.
    trainingStats = train(sarsaAgent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load('basicGWSarsaAgent.mat','sarsaAgent')
end
Validate SARSA training
To verify the training results, simulate the agent in the training environment.
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
Simulate the agent in the environment.
sim(sarsaAgent,env)
The SARSA agent finds the same grid world solution as the Q-learning agent.