Flappy Bird DQN PyTorch Blog - Code Interpretation

Introduction

In this blog, we introduce how to train an agent to play the Flappy Bird game with the DQN (Deep Q-Network) algorithm on the PyTorch platform. DQN is a reinforcement learning algorithm that is well suited to environments with uncertainty, such as games.

Environment configuration

Before you begin, make sure you have configured the following environment:

(rl) PS C:\Users\dd> conda list
# packages in environment at D:\Software\Miniconda3\envs\rl:
#
# Name                    Version                   Build  Channel
numpy                     1.22.3           py38h7a0a035_0    defaults
numpy-base                1.22.3           py38hca35cd5_0    defaults
opencv-python             4.6.0.66                 pypi_0    pypi
pillow                    6.2.1                    pypi_0    pypi
pygame                    2.1.2                    pypi_0    pypi
pygments                  2.11.2             pyhd3eb1b0_0    defaults
python                    3.8.13               h6244533_0    defaults
python-dateutil           2.8.2              pyhd3eb1b0_0    defaults
python_abi                3.8                      2_cp38    conda-forge
pytorch                   1.8.2           py3.8_cuda11.1_cudnn8_0    pytorch-lts

Please make sure your environment includes the dependencies listed above, specifically PyTorch version 1.8.2.

Project directory structure

Below is a brief overview of the project's directory structure, so you can see how the project is organized and where each file lives.

Project root
|-- qdn_train.py          # DQN training script
|-- flappy_bird.py        # Flappy Bird game implementation
|-- model.py              # DQN model definition
|-- replay_buffer.py      # Experience replay buffer implementation
|-- utils.py              # Utility functions
|-- ...

DQN algorithm

The DQN (Deep Q-Network) algorithm is a reinforcement learning algorithm used here to train an agent to make decisions in the Flappy Bird game. Its key components are:

  1. Replay Memory: At each time step, the agent interacts with the environment and stores the resulting experience in replay memory. Each experience consists of the current state, the selected action, the reward obtained, the next state, and a flag indicating whether the game terminated (a minimal buffer sketch follows this list).

  2. Neural network architecture: A neural network implemented in PyTorch, consisting of convolutional layers followed by fully connected layers. Its output is a Q-value for each possible action.

  3. Training process: At each time step, the agent selects an action based on the current state; interacting with the environment then yields the next state, a reward, and a termination signal. This information is used to update the network weights so as to maximize the expected cumulative reward.

  4. Epsilon-greedy exploration: Early in training, the agent relies heavily on exploration, selecting random actions to discover possible strategies. As training proceeds, the exploration rate is gradually annealed.

  5. Target network: To stabilize training, a separate target network periodically copies the parameters of the main network. This reduces the variance of the learning targets (the update rule is sketched below).
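
To make item 1 concrete, here is a minimal replay memory sketch; the capacity, batch size, and tuple layout are assumptions for illustration, not the project's exact values:

import random
from collections import deque

class ReplayBuffer:
    # Stores (state, action, reward, next_state, terminal) transitions;
    # the deque drops the oldest entries once capacity is reached.
    def __init__(self, capacity=50000):   # capacity is an assumed value
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, terminal):
        self.memory.append((state, action, reward, next_state, terminal))

    def sample(self, batch_size=32):      # batch size is an assumed value
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)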
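
Items 3 and 5 together amount to the standard DQN update. The sketch below assumes the minibatch has already been collated into tensors (actions as integer indices, terminals as 0/1 floats) and uses an assumed discount factor:

import torch
import torch.nn.functional as F

GAMMA = 0.99  # assumed discount factor

def train_step(policy_net, target_net, optimizer,
               states, actions, rewards, next_states, terminals):
    # Q(s, a) predicted by the online network for the actions actually taken
    q_pred = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    # Bootstrapped target computed with the frozen target network (item 5)
    with torch.no_grad():
        next_q = target_net(next_states).max(dim=1).values
        q_target = rewards + GAMMA * next_q * (1.0 - terminals)
    loss = F.mse_loss(q_pred, q_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Every few thousand steps, copy the online weights into the target network:
# target_net.load_state_dict(policy_net.state_dict())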

Interpretation of important functions

preprocess(observation)

Converts one color game frame into a black-and-white binary image: OpenCV resizes the frame to 80x80, converts it to grayscale, and thresholds it into a binary image.
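
A minimal sketch of what such a preprocessing function might look like; the threshold value of 1 is an assumption:

import cv2

def preprocess(observation):
    # Resize the raw color frame to 80x80 and convert it to grayscale
    gray = cv2.cvtColor(cv2.resize(observation, (80, 80)), cv2.COLOR_BGR2GRAY)
    # Binarize: pixels above the threshold become 255 (white), the rest 0 (black)
    _, binary = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    return binary  # a single 80x80 black-and-white frame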

DeepNetWork(nn.Module)

Defines the structure of the neural network, consisting of convolutional layers and fully connected layers, which is used to approximate the Q-value function.
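
A sketch of a typical architecture for this task; the exact layer sizes are assumptions. The input is a stack of four preprocessed 80x80 frames, and the output is one Q-value per action:

import torch
import torch.nn as nn

class DeepNetWork(nn.Module):
    def __init__(self, num_actions=2):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4, padding=2),   # -> 32 x 20 x 20
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),                            # -> 32 x 10 x 10
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # -> 64 x 5 x 5
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),  # -> 64 x 5 x 5
            nn.ReLU(),
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 5 * 5, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions),  # one Q-value per action
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 1600)
        return self.fc(x)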

BirdDQN class

The main reinforcement learning agent class includes the following methods (several are sketched after this list):

  • save(): Save the trained model parameters.
  • load(): Load previously saved model parameters.
  • train(): Train the neural network on small minibatches sampled from replay memory.
  • setPerception(): Update the replay memory, decide whether to run a training step, and print the current status information.
  • getAction(): Select an action for the current state via the epsilon-greedy strategy.
  • setInitState(): Initialize the state by copying one frame of the image four times to form the initial input.
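
A hedged sketch of a few of these methods; the signatures, the epsilon handling, and the checkpoint filename params.pth are assumptions for illustration:

import random
import numpy as np
import torch

class BirdDQNSketch:
    def __init__(self, net, num_actions=2, epsilon=1.0):
        self.net = net
        self.num_actions = num_actions
        self.epsilon = epsilon  # annealed toward a small value during training
        self.current_state = None

    def setInitState(self, frame):
        # Copy one preprocessed 80x80 frame four times -> a (4, 80, 80) state
        self.current_state = np.stack([frame] * 4, axis=0)

    def getAction(self):
        # Epsilon-greedy: explore with probability epsilon, otherwise exploit
        if random.random() < self.epsilon:
            return random.randrange(self.num_actions)
        state = torch.from_numpy(self.current_state).float().unsqueeze(0)
        with torch.no_grad():
            return int(self.net(state).argmax(dim=1).item())

    def save(self, path="params.pth"):
        torch.save(self.net.state_dict(), path)

    def load(self, path="params.pth"):
        self.net.load_state_dict(torch.load(path))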

Main program part

Creates a BirdDQN agent instance, lets it interact with the Flappy Bird game environment, and in a loop performs actions, observes state transitions, and updates the network parameters.
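
A minimal sketch of such a main loop. The names GameState and frame_step are assumptions modeled on common Flappy Bird ports, and frame_step is assumed to take a one-hot action vector and return a frame, a reward, and a terminal flag:

import numpy as np

NUM_ACTIONS = 2
brain = BirdDQN(NUM_ACTIONS)   # the agent class described above (constructor assumed)
flappy = GameState()           # the game environment from flappy_bird.py (name assumed)

# Bootstrap with the "do nothing" action to obtain the first frame
observation, reward, terminal = flappy.frame_step(np.array([1, 0]))
brain.setInitState(preprocess(observation))

while True:
    action = brain.getAction()           # epsilon-greedy action index (assumed)
    one_hot = np.zeros(NUM_ACTIONS)
    one_hot[action] = 1
    observation, reward, terminal = flappy.frame_step(one_hot)
    # Store the transition, train once enough samples exist, log progress
    brain.setPerception(preprocess(observation), action, reward, terminal)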

The above is an interpretation of the main algorithms and functions in the code. This project combines deep learning and reinforcement learning to train an agent to play Flappy Bird, demonstrating the implementation process on the PyTorch platform. If you have any questions or need further explanation, please ask in the comments. I wish you success in your practice!
