Deep Reinforcement Learning: DDPG Code Reading - ddpg.py (1)

Table of contents

1. Write ddpg.py

1.1 Import required packages and other python files

1.2 Define the training function train()

1.2.1 Code Summary

1.2.2 Code decomposition

1.3 Define the test function test()

1.3.1 Code Summary

1.3.2 Code decomposition

1.4 Define the main function

1.4.1 Code Summary

1.4.2 Code decomposition

1.5 Call the training function or test function as needed

1.6 Questions


1. Write ddpg.py

First, write the ddpg.py file, which involves the following steps:

1.1 Import required packages and other python files

import tensorflow as tf
import numpy as np
import gym
from gym import wrappers
import argparse
import pprint as pp
import sys

from replay_buffer import ReplayBuffer
from AandC import *
from TrainOrTest import *

This code imports several commonly used Python modules along with the project's custom modules: tensorflow, numpy, gym, argparse, pprint, sys, and the custom modules replay_buffer, AandC, and TrainOrTest.

The specific explanation is as follows:

  • tensorflow: Used to define and train deep learning models.
  • numpy: Used for numerical calculations and array operations.
  • gym: OpenAI Gym, a Python toolkit for building and training reinforcement learning agents.
  • argparse: Used to parse command line arguments.
  • pprint: Used to pretty print Python objects.
  • sys: Used to access system-related functions, such as command-line parameters.
  • replay_buffer: A custom replay buffer class for storing and sampling experience data (transitions).
  • AandC: Custom Actor and Critic classes, used to build the actor and critic networks in the DDPG algorithm.
  • TrainOrTest: Custom training and testing routines, used to run the training and testing loops of the DDPG algorithm.
  • argparse is a module in the Python standard library for parsing command-line arguments. With argparse you can conveniently specify parameters on the command line, perform type checks on them, restrict their allowed values, generate help messages, and so on. It is a very common tool when writing command-line tools and scripts, helping developers process command-line input and control program behavior flexibly.

  • pprint is a module in the Python standard library for pretty-printing Python objects. It provides the pprint() function to format complex Python objects (such as dictionaries, lists, and nested structures) so they are easier to read. This is very useful when debugging and printing complex data structures, and helps developers understand the structure and content of the data.

  • sys is a module in the Python standard library that provides system-related functionality, such as access to command-line arguments and interaction with the interpreter. The sys.argv variable holds the list of command-line arguments, the sys.exit() function exits a program with an exit code, and sys.stdout and sys.stderr are the standard output and standard error streams. The sys module is very useful for handling command-line arguments, exception handling, and system interaction.
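As a quick illustration (a standalone sketch, not part of ddpg.py), here is how pprint and sys are typically used in scripts like this one:

import pprint as pp
import sys

config = {'env': 'Pendulum-v0', 'lr': {'actor': 0.0001, 'critic': 0.001}, 'seeds': [258, 259]}
pp.pprint(config)        # nested dict printed with readable indentation

print(sys.argv)          # the raw command-line argument list; sys.argv[0] is the script path
if len(sys.argv) > 99:   # contrived guard, only to demonstrate sys.exit()
    sys.exit(1)          # exit immediately with a non-zero status code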

1.2 Define the training function train()

1.2.1 Code Summary

def train(args):
    with tf.Session() as sess:
        env = gym.make(args['env'])  # create the Gym environment
        np.random.seed(int(args['random_seed']))  # set the NumPy random seed
        tf.set_random_seed(int(args['random_seed']))  # set the TensorFlow random seed
        env.seed(int(args['random_seed']))  # set the environment's random seed
        env._max_episode_steps = int(args['max_episode_len'])  # set the maximum steps per episode

        state_dim = env.observation_space.shape[0]   # state dimension
        action_dim = env.action_space.shape[0]  # action dimension
        action_bound = env.action_space.high  # upper bound of the actions

        actor = ActorNetwork(sess, state_dim, action_dim, action_bound,
                             float(args['actor_lr']), float(args['tau']), int(args['minibatch_size']))  # create the Actor network

        critic = CriticNetwork(sess, state_dim, action_dim,
                               float(args['critic_lr']), float(args['tau']), float(args['gamma']),
                               actor.get_num_trainable_vars())  # create the Critic network

        trainDDPG(sess, env, args, actor, critic)  # run DDPG training via the trainDDPG function

        saver = tf.train.Saver()  # create a Saver
        saver.save(sess, "ckpt/model")  # save the trained model under "ckpt" with prefix "model"
        print("saved model")  # print a confirmation message

This train() function trains the Actor and Critic neural networks of the DDPG (Deep Deterministic Policy Gradient) algorithm. It accepts a single argument, args, a dictionary containing the parameters required for training, such as the learning rates, discount factor, target-network update parameter, and buffer size.

Inside the function, a TensorFlow session is first created (tf.Session()), and a Gym environment is created from the passed-in parameters (gym.make(args['env'])). The random seeds are then set to ensure reproducibility, the Actor and Critic networks are initialized, and the trainDDPG() function is called to run DDPG training.

After training completes, tf.train.Saver() is used to create a saver, the trained model is saved under the checkpoint prefix "ckpt/model", and the message "saved model" is printed.

Note: before calling this function, make sure the relevant libraries have been imported and that ActorNetwork, CriticNetwork, and trainDDPG have been defined; their implementations live in the imported modules (AandC and TrainOrTest).

1.2.2 Code decomposition

(1) np.random.seed(int(args['random_seed']))

        env = gym.make(args['env'])
        np.random.seed(int(args['random_seed']))
        tf.set_random_seed(int(args['random_seed']))
        env.seed(int(args['random_seed']))
        env._max_episode_steps = int(args['max_episode_len'])

This code uses the gym, numpy, and tensorflow libraries to create an OpenAI Gym environment and to set the random seeds and the maximum number of steps per episode.

Specific steps are as follows:

① Use gym's make() function to create an OpenAI Gym environment whose name is specified by args['env'].

② Use numpy's random.seed() function to set the NumPy random seed to the integer value of args['random_seed'], to control randomness.

③ Use tensorflow's set_random_seed() function to set TensorFlow's random seed to the integer value of args['random_seed'], to control TensorFlow's randomness.

④ Use the env object's seed() method to set the environment's random seed to the integer value of args['random_seed'], to control the environment's randomness.

⑤ Set env._max_episode_steps to the integer value of args['max_episode_len'], which limits the maximum number of steps in each episode.

Note: before running this code, make sure the args dictionary contains the correct values for parameters such as 'env', 'random_seed', and 'max_episode_len'; otherwise errors may occur.

This code initializes the environment and sets the random seeds and the maximum episode length, which makes the experiment repeatable and controlled.

(2) state_dim = env.observation_space.shape[0]

        state_dim = env.observation_space.shape[0]   # state dimension
        action_dim = env.action_space.shape[0]  # action dimension
        action_bound = env.action_space.high  # upper bound of the actions

The above code obtains the dimensionality of the environment's states and actions, and stores the upper bound of the actions in a variable.

env.observation_space.shape returns the shape of the environment's state space, i.e. the dimension information of the state. env.observation_space.shape[0] takes its first element, the state dimension.

env.action_space.shape returns the shape of the environment's action space, i.e. the dimension information of the action. env.action_space.shape[0] takes its first element, the action dimension.

env.action_space.high returns the upper bound of the action space, i.e. the maximum value of each action component. This value is used to bound the output actions of the Actor network so that generated actions never exceed the limits of the environment's action space, which helps train a reasonable action policy.
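As a quick sanity check (a standalone sketch; the values shown are those of the classic Pendulum-v0 environment, the default in this script):

import gym

env = gym.make('Pendulum-v0')
print(env.observation_space.shape)  # (3,)  -> state_dim = 3
print(env.action_space.shape)       # (1,)  -> action_dim = 1
print(env.action_space.high)        # [2.]  -> action_bound = 2.0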

(3) actor = ActorNetwork(...) and critic = CriticNetwork(...)

        actor = ActorNetwork(sess, state_dim, action_dim, action_bound,
                             float(args['actor_lr']), float(args['tau']), 
                             int(args['minibatch_size']))  

        critic = CriticNetwork(sess, state_dim, action_dim,
                             float(args['critic_lr']), float(args['tau']), 
                             float(args['gamma']),
                             actor.get_num_trainable_vars())

The above code creates the Actor network and the Critic network, and passes in the corresponding parameters for initialization.

The ActorNetwork class implements the Actor network. It is initialized with sess (the TensorFlow session object), state_dim (state dimension), action_dim (action dimension), action_bound (action upper bound), actor_lr (the Actor network's learning rate), tau (the target-network update parameter), and minibatch_size (the mini-batch size).

The CriticNetwork class implements the Critic network. It is initialized with sess (the TensorFlow session object), state_dim (state dimension), action_dim (action dimension), critic_lr (the Critic network's learning rate), tau (the target-network update parameter), gamma (the discount factor), and num_actor_vars (the number of trainable parameters in the Actor network).

Note that actor.get_num_trainable_vars() is called here to obtain the number of trainable parameters in the Actor network, and that count is passed to the Critic. The Critic uses it to identify which trainable variables belong to itself, which ensures that the Critic does not touch the Actor's parameters when updating its own network and target network.
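The following standalone sketch shows why passing this count is useful. It is an assumption based on the common DDPG reference implementation; the actual AandC.py may differ:

import tensorflow as tf

# Actor variables are created first...
tf.Variable(tf.zeros([3, 8]), name='actor_w')
num_actor_vars = len(tf.trainable_variables())  # what get_num_trainable_vars() would return

# ...then the Critic's variables are created.
tf.Variable(tf.zeros([8, 1]), name='critic_w')

# The Critic can recover exactly its own variables by slicing past the Actor's:
critic_vars = tf.trainable_variables()[num_actor_vars:]
print([v.name for v in critic_vars])  # ['critic_w:0']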

(4) saver.save(sess, "ckpt/model")

        trainDDPG(sess, env, args, actor, critic)  # run DDPG training via the trainDDPG function

        saver = tf.train.Saver()  # create a Saver
        saver.save(sess, "ckpt/model")  # save the trained model under "ckpt" with prefix "model"
        print("saved model")

The trainDDPG function trains the DDPG (Deep Deterministic Policy Gradient) algorithm. It takes sess (the TensorFlow session object), env (the OpenAI Gym environment object), args (the command-line parameter dictionary), actor (the Actor network object), and critic (the Critic network object) as input.

Following the DDPG training procedure, trainDDPG trains the Actor and Critic networks in the given environment, updating the network parameters to optimize the policy and the value function. For the exact training loop, refer to the implementation of trainDDPG in TrainOrTest.py.
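Since TrainOrTest.py is not reproduced in this post, here is a heavily simplified sketch of what a trainDDPG-style loop typically does, modeled on the common DDPG reference implementation. The method names (actor.predict, actor.predict_target, critic.train, critic.action_gradients, replay_buffer.sample_batch, etc.) and the exploration_noise() helper are assumptions for illustration, not the actual code:

import numpy as np

def train_ddpg_sketch(sess, env, args, actor, critic, replay_buffer):
    """Simplified DDPG loop; the method names below are assumptions about AandC.py."""
    for episode in range(int(args['max_episodes'])):
        s = env.reset()
        for step in range(int(args['max_episode_len'])):
            # Act with the deterministic policy plus exploration noise
            # (exploration_noise() is a hypothetical helper, e.g. Ornstein-Uhlenbeck).
            a = actor.predict(np.reshape(s, (1, -1))) + exploration_noise()
            s2, r, done, _ = env.step(a[0])
            replay_buffer.add(s, a[0], r, done, s2)

            if replay_buffer.size() > int(args['minibatch_size']):
                s_b, a_b, r_b, d_b, s2_b = replay_buffer.sample_batch(
                    int(args['minibatch_size']))

                # Critic target: y = r + gamma * Q'(s', mu'(s')), zeroed at terminal states.
                q_next = critic.predict_target(s2_b, actor.predict_target(s2_b))
                y = r_b + float(args['gamma']) * np.squeeze(q_next) * (1 - d_b)

                # Update the Critic on the TD target, then the Actor along the
                # sampled policy gradient, then soft-update both target networks.
                critic.train(s_b, a_b, np.reshape(y, (-1, 1)))
                a_grads = critic.action_gradients(s_b, actor.predict(s_b))
                actor.train(s_b, a_grads[0])
                actor.update_target_network()
                critic.update_target_network()

            s = s2
            if done:
                break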

After training completes, the code creates a saver with tf.train.Saver() to save the trained model. Calling saver.save(sess, "ckpt/model") saves the model under the "ckpt" directory with the file-name prefix "model", where sess is the current TensorFlow session object. Finally, print("saved model") reports that saving succeeded.

1.3 Define the test function test()

1.3.1 Code Summary

def test(args):
    with tf.Session() as sess:
        env = gym.make(args['env'])  # create the test environment
        np.random.seed(int(args['random_seed']))  # set the NumPy random seed
        tf.set_random_seed(int(args['random_seed']))  # set the TensorFlow random seed
        env.seed(int(args['random_seed']))  # set the test environment's random seed
        env._max_episode_steps = int(args['max_episode_len'])  # set the maximum steps per episode

        state_dim = env.observation_space.shape[0]  # get the dimensionality of the state space
        action_dim = env.action_space.shape[0]  # get the dimensionality of the action space
        action_bound = env.action_space.high  # get the upper bound of the actions

        actor = ActorNetwork(sess, state_dim, action_dim, action_bound,
                             float(args['actor_lr']), float(args['tau']), int(args['minibatch_size']))  # create the Actor network object

        critic = CriticNetwork(sess, state_dim, action_dim,
                               float(args['critic_lr']), float(args['tau']), float(args['gamma']),
                               actor.get_num_trainable_vars())  # create the Critic network object

        saver = tf.train.Saver()  # create a Saver object for loading the model parameters saved during training
        saver.restore(sess, "ckpt/model")  # load the model parameters saved during training

        testDDPG(sess, env, args, actor, critic)  # run the loaded model in the test environment and report results

This test(args) function loads the previously trained DDPG model and runs it in the test environment.

The specific steps of the function are as follows:

  1. Create a TensorFlow session sess for executing TensorFlow operations.
  2. Use the gym.make() function to create a test environment env, and set the random seeds and the maximum number of episode steps.
  3. Get the dimensionality of the state space, the dimensionality of the action space, and the upper bound of the actions.
  4. Create an Actor network object actor with the same architecture used during training.
  5. Create a Critic network object critic with the same architecture used during training.
  6. Create a Saver object saver for loading the model parameters saved during training.
  7. Use the saver.restore() function to load the model parameters saved during training.
  8. Call the testDDPG() function to run the loaded model in the test environment and report the test results.

1.3.2 Code decomposition

(1) env = gym.make(args['env'])

env = gym.make(args['env'])

gym.make(args['env']) creates a Gym environment object using the Gym library.

Gym is an open-source toolkit for developing and comparing reinforcement learning algorithms. It provides a collection of standardized environments, including various games and control problems. Calling gym.make() with an environment name creates the corresponding Gym environment object for training or testing reinforcement learning algorithms.

In this code, args is a dictionary of parameter values, including the 'env' parameter, which names the Gym environment to create. args['env'] retrieves that name and passes it to gym.make(), and the resulting environment object is assigned to the variable env so it can be used later during training or testing.
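For context, a minimal interaction loop with a Gym environment looks like this (a standalone sketch using the old Gym API that this post targets):

import gym

env = gym.make('Pendulum-v0')
s = env.reset()
for _ in range(5):
    a = env.action_space.sample()      # a random action within the Box bounds
    s, r, done, info = env.step(a)     # old Gym API: 4-tuple return
    if done:
        s = env.reset()
env.close()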

(2) np.random.seed(int(args['random_seed']))

np.random.seed(int(args['random_seed']))

np.random.seed(int(args['random_seed'])) sets the seed of NumPy's random number generator so that random number generation is reproducible. In machine learning and deep learning, the random seed is often fixed to ensure that each run of the code generates the same sequence of random numbers, which makes results easier to compare and reproduce.

In this code, args is a dictionary of parameter values, including the 'random_seed' parameter. int(args['random_seed']) converts the seed from a string to an integer and passes it to np.random.seed(), fixing NumPy's random number generation. This way, the same seed produces the same random sequence on every run, ensuring the test is reproducible.
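A quick standalone check of this reproducibility guarantee:

import numpy as np

np.random.seed(258)
a = np.random.rand(3)
np.random.seed(258)
b = np.random.rand(3)
print(np.array_equal(a, b))  # True: reseeding reproduces the exact same sequence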

(3) tf.set_random_seed(int(args['random_seed']))

tf.set_random_seed(int(args['random_seed'])) 

tf.set_random_seed(int(args['random_seed'])) sets TensorFlow's graph-level random seed.

Training in machine learning often involves randomness, for example when initializing model parameters or sampling data. To make experiments repeatable, the random seed is usually fixed so that each run of the code produces the same random number sequence.

In this code, args is a dictionary of parameter values, including the 'random_seed' parameter. args['random_seed'] retrieves the seed, which is converted to an integer and passed to tf.set_random_seed(). This fixes TensorFlow's randomness so that training or testing runs see the same random sequences, making the experiment repeatable.

(4) env.seed(int(args['random_seed']))

env.seed(int(args['random_seed']))

env.seed(int(args['random_seed'])) sets the random seed of the Gym environment.

In OpenAI Gym, the environment's randomness typically includes the random initial state, random action sampling, and so on. Setting the environment's random seed ensures that the same random environment behavior is obtained across runs.

In this code, args is a dictionary of parameter values, including the 'random_seed' parameter. args['random_seed'] retrieves the seed, which is converted to an integer and passed to env.seed(). This fixes the environment's randomness so that training or testing runs see the same environment behavior, making the experiment repeatable.

(5) saver.restore(sess, "ckpt/model")

saver.restore(sess, "ckpt/model")

saver.restore(sess, "ckpt/model") loads previously saved model parameters through the TensorFlow Saver object.

In this code, a Saver object saver is created with tf.train.Saver(); it is used to save and restore the parameters of a TensorFlow model. sess is the TensorFlow session object used to execute operations in the computation graph. "ckpt/model" is the checkpoint path, where "ckpt" is the directory holding the checkpoint files and "model" is the file-name prefix.

saver.restore(sess, "ckpt/model") loads the saved model parameters from that path and restores their values into the current TensorFlow session, so the loaded parameters can be used in subsequent code for prediction, testing, or other operations. Once loaded, the parameter values can be accessed through the corresponding TensorFlow variables in later runs of the model.
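A minimal save/restore round trip illustrating this mechanism (a standalone sketch assuming TensorFlow 1.x, the API used throughout this post):

import os
import tensorflow as tf

w = tf.Variable(42.0, name='w')
saver = tf.train.Saver()

os.makedirs('ckpt', exist_ok=True)        # Saver does not create the parent directory
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'ckpt/model')        # writes model.index, model.data-..., model.meta

with tf.Session() as sess:
    saver.restore(sess, 'ckpt/model')     # no initializer needed; values are restored
    print(sess.run(w))                    # 42.0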

1.4 Define the main function

1.4.1 Code Summary

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='provide arguments for DDPG agent')

    # agent parameters
    parser.add_argument('--actor-lr', help='actor network learning rate', default=0.0001)
    parser.add_argument('--critic-lr', help='critic network learning rate', default=0.001)
    parser.add_argument('--gamma', help='discount factor for Bellman updates', default=0.99)
    parser.add_argument('--tau', help='target update parameter', default=0.001)
    parser.add_argument('--buffer-size', help='max size of the replay buffer', default=1000000)
    parser.add_argument('--minibatch-size', help='size of minibatch', default=64)

    # run parameters
    parser.add_argument('--env', help='gym env', default='Pendulum-v0')
    parser.add_argument('--random-seed', help='random seed', default=258)
    parser.add_argument('--max-episodes', help='max num of episodes', default=250)
    parser.add_argument('--max-episode-len', help='max length of each episode', default=1000)
    parser.add_argument('--render-env', help='render gym env', action='store_true')
    parser.add_argument('--mode', help='train/test', default='train')

    args = vars(parser.parse_args())

    pp.pprint(args)

This code is the entry point of a Python script that uses the argparse module to parse command-line arguments and pass those arguments to the training or testing process of a DDPG (Deep Deterministic Policy Gradient) agent.

The specific parameters are explained as follows:

  • --actor-lr: Specifies the learning rate of the Actor network, the default value is 0.0001.
  • --critic-lr: Specifies the learning rate of the Critic network, the default value is 0.001.
  • --gamma: Specifies the discount factor in the Bellman updates, the default value is 0.99.
  • --tau: Specifies the target-network update parameter, the default value is 0.001.
  • --buffer-size: Specifies the maximum size of the replay buffer, the default value is 1000000.
  • --minibatch-size: Specifies the size of each mini-batch, the default value is 64.
  • --env: Specifies the name of the Gym environment, the default is 'Pendulum-v0'.
  • --random-seed: Specifies the random seed, the default value is 258.
  • --max-episodes: Specifies the maximum number of episodes for training or testing, the default value is 250.
  • --max-episode-len: Specifies the maximum number of steps per episode, the default value is 1000.
  • --render-env: Whether to render the Gym environment, the default is False.
  • --mode: Specifies the run mode, either 'train' or 'test', the default is 'train'.

After the argparse module parses the command line, these parameters are stored in the dictionary args and printed with pp.pprint(args) so their values can be inspected. Subsequent code reads the corresponding values from the args dictionary and passes them to the DDPG agent's training or testing functions.

1.4.2 Code decomposition

(1) if __name__ == '__main__':

if __name__ == '__main__':

This is a common Python idiom used to determine whether the current module is being run directly as the main program, rather than imported as a module by another program.

In Python, every module has a __name__ attribute representing the module's name. When a module is run directly, its __name__ attribute is set to '__main__'; when the module is imported by another program, __name__ is set to the module's own name.

Therefore, the conditional if __name__ == '__main__': tests whether the current module is running as the main program. If the condition holds, the module is the main program and the corresponding main logic executes; if not, the module has been imported and can be called by other programs. This separates main-program logic from import-time logic, making the code more flexible and reusable.
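A tiny standalone illustration of this idiom (module_demo.py is a hypothetical file name):

# module_demo.py
def main():
    print('running as a script')

if __name__ == '__main__':
    main()  # executes on `python module_demo.py`, but not on `import module_demo`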

(2) parser = argparse.ArgumentParser(description='provide arguments for DDPG agent')

parser = argparse.ArgumentParser(description='provide arguments for DDPG agent')

This line creates an ArgumentParser object for parsing command-line arguments. The description parameter describes the ArgumentParser object and is typically used to explain the purpose of the command-line arguments; here the description is "provide arguments for DDPG agent".

(3) parser.add_argument('--actor-lr', help='actor network learning rate', default=0.0001)

 parser.add_argument('--actor-lr', help='actor network learning rate', default=0.0001)

parser.add_argument() is a method of the ArgumentParser object for adding command-line arguments. Calling it defines a parameter that the command line accepts, along with its properties, such as the parameter name, default value, data type, and help message.

This line defines a command-line parameter named --actor-lr, used to set the learning rate of the Actor network. The help parameter sets its help message to "actor network learning rate", and the default parameter sets its default value to 0.0001. The other parameters are defined in a similar manner.

(4) parser.add_argument('--render-env', help='render gym env', action='store_true')

parser.add_argument('--render-env', help='render gym env', action='store_true')

This line adds an optional command-line argument, --render-env, via the argparse module. The --render-env flag controls whether images of the Gym environment are rendered during training or testing.

The specific explanation is as follows:

  • --render-env: Specifies the name of the parameter, render-env.
  • help='render gym env': The help message for the parameter is 'render gym env'; this text is displayed when the user passes -h or --help.
  • action='store_true': When the user specifies the --render-env flag, its value is set to True. If the user does not pass the flag, its value defaults to False.

This way, when --render-env is passed on the command line, args['render_env'] is set to True (argparse converts dashes in option names to underscores), and the code can use this value to decide whether to render the Gym environment. If --render-env is not passed, args['render_env'] is False.

(5) args = vars(parser.parse_args())

args = vars(parser.parse_args())

This line stores the command-line arguments parsed by the ArgumentParser in the dictionary args.

The vars() function converts the argparse.Namespace object returned by parser.parse_args() into a dictionary, whose keys are the parameter names and whose values are the parameter values passed on the command line.

For example, if --actor-lr 0.0002 is passed on the command line, the args dictionary will contain the key-value pair 'actor_lr': '0.0002'. Note that the value is the string '0.0002', because no type= was given to add_argument(); this is why train() and test() wrap the values in float() and int(). The value can be accessed via args['actor_lr'], and the other command-line parameters are accessed the same way.

parser.parse_args()

This line uses the ArgumentParser's parsing method. It parses the parameters passed on the command line, validating them against the argument specification defined on the ArgumentParser object.

Here, parser is an ArgumentParser object, and calling parse_args() parses the command-line arguments according to the specifications added with add_argument(), storing the parsed values in an argparse.Namespace object.

For example, if --actor-lr 0.0002 is passed on the command line, parse_args() parses it and returns a Namespace object containing that value. The parsed value can be accessed as an attribute of the namespace (e.g. ns.actor_lr) before vars() converts the namespace into a dictionary.
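A standalone demonstration of this parsing flow (parse_args() is given an explicit argument list here to simulate the command line):

import argparse
import pprint as pp

parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--actor-lr', help='actor network learning rate', default=0.0001)
parser.add_argument('--render-env', help='render gym env', action='store_true')

# Simulate `python script.py --actor-lr 0.0002 --render-env`:
ns = parser.parse_args(['--actor-lr', '0.0002', '--render-env'])
args = vars(ns)
pp.pprint(args)  # {'actor_lr': '0.0002', 'render_env': True}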

1.5 Call the training function or test function as needed

if (args['mode'] == 'train'):
    train(args)
elif (args['mode'] == 'test'):
    test(args)

This code decides whether to run training or testing based on the value of the 'mode' field in the args dictionary.

  • If the value of args['mode'] equals 'train', the train function is called with args.
  • If the value of args['mode'] equals 'test', the test function is called with args.

In this way, the training or testing operation is selected dynamically according to the mode value passed on the command line, giving the program more flexible behavior. Note that the train and test functions must be defined and implemented elsewhere in the code; this snippet only dispatches to the appropriate one based on args['mode'].

1.6 Questions

(1) Question: Why does env.observation_space.shape[0] use [0] as the state dimension?

In OpenAI Gym, observation_space is an object describing the environment's state space, usually of type Box or Discrete. For a Box state space, observation_space.shape returns a tuple of shape (n,), where n is the number of state dimensions. Thus observation_space.shape[0] is the size of the first dimension of the state space, i.e. the state dimension.

The number of state dimensions, obtained via env.observation_space.shape[0], is typically used as the number of nodes in the input layer of a neural network model, so the input layer can be sized correctly. For example, if the state space has 3 dimensions, env.observation_space.shape[0] is 3, indicating a state dimension of 3.

(2) Question: Why does env.action_space.shape[0] use [0] as the action dimension?

In OpenAI Gym, action_space is an object describing the environment's action space, usually of type Box, Discrete, or MultiDiscrete. For a Box action space, action_space.shape returns a tuple of shape (n,), where n is the number of action dimensions. Thus action_space.shape[0] is the size of the first dimension of the action space, i.e. the action dimension.

The number of action dimensions, obtained via env.action_space.shape[0], is typically used as the number of nodes in the output layer of a neural network model, so the output layer can be sized correctly. For example, if the action space has 2 dimensions, env.action_space.shape[0] is 2, indicating an action dimension of 2.

(3) Question: Why do the generated model files have different suffixes?

In TensorFlow, model parameters can be saved with different file extensions, such as ".ckpt", ".meta", ".index", ".data", etc. These file extensions are used to store different information.

  • The ".ckpt" file is the checkpoint file of TensorFlow , which is used to store the parameter values ​​of the model. It contains parameter information such as weights and biases of the model, and is saved in binary format. Equivalent to a file named directly after checkpoint.
  • The ".meta" file is TensorFlow's meta-graph file, which is used to store the calculation graph structure of the model. It contains the network structure of the model, operation nodes and other information.
  • The ".index" file is the index file of TensorFlow, which is used to store the index information of each parameter in the ".ckpt" file .
  • The ".data" file is the data file of TensorFlow, which is used to store the actual value of each parameter in the ".ckpt" file .

Together, these suffixed files make up a complete saved model, which is written with the Saver object's save() method and loaded with its restore() method. When saving and loading, you specify the file path and prefix; in the code above, "ckpt/model" is used, where "ckpt" is the directory and "model" is the file-name prefix. The specific suffixes are added and recognized automatically by the Saver object.

(4) Question: What is the relationship between checkpoint, model.data-00000-of-00001, model.index, and model.meta?

In TensorFlow, these files are all parts of a saved checkpoint, which stores the model's parameter information and computation-graph structure.

  • "checkpoint" file : used to save the meta information of the current training checkpoint, including the latest checkpoint file name, a list of all checkpoint files, the latest training steps and other information.

  • "model.data-00000-of-00001" file: used to save the parameter data of the model, including binary data of parameters such as weight and bias.

  • "model.index" file: used to save the index information of the model parameters, including the name, shape, type and other information of the parameters, which is used to quickly find the model parameters.

  • "model.meta" file: used to save the metadata of the TensorFlow calculation graph, including the structure of the graph, the name and type of variables, and other information.

These files together form a complete checkpoint, storing the parameter information and computation-graph structure of a TensorFlow model. When loading model parameters with the Saver object's restore() method, you only need to specify the checkpoint prefix (e.g. "ckpt/model"); the "model.data-00000-of-00001", "model.index", and "model.meta" files generated at save time are identified and loaded automatically.

Note: the exact file names generated depend on the version of TensorFlow used, but this does not affect loading and restoring the model, because TensorFlow automatically identifies and loads the corresponding files based on their suffixes and contents.

(5) Question: What is args?

args is a Python dictionary containing the arguments parsed from the command line and their corresponding values. It is produced after parsing the command-line arguments with the argparse module.

In the code, args is produced from the object returned by parser.parse_args(). parser.parse_args() parses the command-line parameters and stores the results, and once vars() converts the result into a dictionary, the parameter names (e.g. from --actor-lr, --env) become the keys and the corresponding parameter values become the values.

For example, if a command like the following is executed on the command line:

python my_script.py --actor-lr 0.0001 --env 'Pendulum-v0' --render-env

Then the args dictionary might contain the following entries, among the other defined parameters (note that the numeric value remains a string, since no type= was given to add_argument()):

args = {
    'actor_lr': '0.0001',
    'env': 'Pendulum-v0',
    'render_env': True
}

These parameter values can be accessed through the args dictionary and processed as needed in the code. For example, args['actor_lr'] returns '0.0001', args['env'] returns 'Pendulum-v0', and args['render_env'] returns True.
