Explanation of PyTorch model saving/reloading modules and visualization modules (with source code)

When training a model, some of the many trained models will perform better than the others. We want to store the parameter values corresponding to these models, so that good results do not have to be retrained from scratch later, and so that the models can easily be reproduced in later research. PyTorch provides model saving and reloading modules, namely torch.save() and torch.load(), along with EarlyStopping from pytorchtools; together these solve the model saving and reloading problems described above.

1. Saving and reloading modules

If you want to save/load only the parameters of the model, without saving/loading its structure, you can use the following code.

Here, state_dict is a dictionary object in torch that maps each layer to that layer's parameter tensors.
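
A minimal sketch of this, using a small nn.Linear as a stand-in for your own model and an arbitrary file name model_params.pt:

```python
import torch
import torch.nn as nn

# A stand-in model; substitute your own network here
model = nn.Linear(4, 2)

# Save only the parameters (the state_dict), not the model structure
torch.save(model.state_dict(), "model_params.pt")

# To reload, the architecture must be re-created first,
# then filled in with the saved parameters
model2 = nn.Linear(4, 2)
model2.load_state_dict(torch.load("model_params.pt"))
```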

If you want to save/load both the parameters and the structure of the model at the same time, you can use the following code.
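
A minimal sketch (the file name model_full.pt is arbitrary, and a small nn.Linear again stands in for your own model; note that recent PyTorch versions require weights_only=False to unpickle a full module):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Save the whole model object: structure + parameters (pickled)
torch.save(model, "model_full.pt")

# Reload without re-creating the architecture by hand
# (weights_only=False is needed in recent PyTorch to load a full module)
model2 = torch.load("model_full.pt", weights_only=False)
```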

To obtain a neural network with good performance, the various parts of the model must be tuned during training, that is, the hyperparameters must be adjusted. One of these hyperparameters is the number of training epochs. If it is too small, the model may underfit; if it is too large, the model may overfit. To avoid a badly chosen number of epochs hurting the model, EarlyStopping came into being. EarlyStopping removes the need to set the epoch count manually, and can also be regarded as a regularization method that prevents the network from overfitting.

The principle of EarlyStopping can be roughly divided into three parts:

Divide the original data into a training set and a validation set;

Train only on the training set, and compute the model's error on the validation set every epoch (or every few epochs). If the validation error increases as training continues, stop training;

Use the weights from the stopping point as the final parameters of the network.

Initialize the early_stopping object:

The initialization of the EarlyStopping object includes three parameters, whose meanings are as follows:

patience(int): how many epochs to wait after the last improvement in validation loss; default: 7.

verbose(bool): if True, print a message each time the validation loss improves; if False, print nothing. Default: False.

delta(float): the minimum change in the loss value that counts as an improvement. The model is saved only when the loss improves by more than this value. Default: 0, i.e. the model is saved whenever the loss improves at all.
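
A simplified, self-contained sketch of such an EarlyStopping class; it mirrors the pytorchtools interface (including an assumed default checkpoint path checkpoint.pt) but is not the original implementation:

```python
import torch

class EarlyStopping:
    """Simplified sketch of pytorchtools' EarlyStopping."""

    def __init__(self, patience=7, verbose=False, delta=0, path="checkpoint.pt"):
        self.patience = patience   # epochs to wait after the last improvement
        self.verbose = verbose     # print a message on every improvement
        self.delta = delta         # minimum change that counts as improvement
        self.path = path           # where the best model's state_dict is saved
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            # Loss improved: checkpoint the model and reset the wait counter
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.path)
            if self.verbose:
                print(f"Validation loss decreased to {val_loss:.6f}; model saved.")
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

# Initialize the early_stopping object
early_stopping = EarlyStopping(patience=7, verbose=True)
```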

Define the training function. We want EarlyStopping to print a message and save the parameters whenever the loss on the test set decreases. First create the variables that will be used and initialize the early_stopping object.

Then train the model, record the loss values, and compute and store the mean loss of each epoch on the training set and the test set.

 

Call EarlyStopping's __call__() method to judge whether the loss value has decreased; if it has, save the model and print the information.

Finally, call torch.load() to load the last saved checkpoint, which is the optimal model, and return the model together with the mean loss of each epoch on the training set and the test set.
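
Putting these steps together, a hedged sketch of such a training function (the names train_model and n_epochs are illustrative, the loss and optimizer are arbitrary choices, and the import falls back to a minimal stand-in when pytorchtools is not installed):

```python
import torch
import torch.nn as nn

try:
    from pytorchtools import EarlyStopping
except ImportError:
    class EarlyStopping:  # minimal stand-in so the sketch runs without pytorchtools
        def __init__(self, patience=7, verbose=False, delta=0, path="checkpoint.pt"):
            self.patience, self.delta, self.path = patience, delta, path
            self.best, self.counter, self.early_stop = None, 0, False
        def __call__(self, val_loss, model):
            if self.best is None or val_loss < self.best - self.delta:
                self.best, self.counter = val_loss, 0
                torch.save(model.state_dict(), self.path)
            else:
                self.counter += 1
                self.early_stop = self.counter >= self.patience

def train_model(model, train_loader, test_loader, patience=7, n_epochs=100):
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    avg_train_losses, avg_test_losses = [], []
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    for epoch in range(n_epochs):
        model.train()
        train_losses = []
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        model.eval()
        test_losses = []
        with torch.no_grad():
            for x, y in test_loader:
                test_losses.append(criterion(model(x), y).item())

        # Mean loss of this epoch on the training and test sets
        avg_train_losses.append(sum(train_losses) / len(train_losses))
        avg_test_losses.append(sum(test_losses) / len(test_losses))

        # __call__ checkpoints the model whenever the test loss decreases
        early_stopping(avg_test_losses[-1], model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Reload the last checkpoint, i.e. the best model seen so far
    model.load_state_dict(torch.load("checkpoint.pt"))
    return model, avg_train_losses, avg_test_losses
```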

 

2. Visualization module

During model training, besides saving and loading the trained model, it is often necessary to keep the training-set loss, the validation-set loss, the model's computational graph (i.e. the model framework diagram, or model data-flow diagram), and so on, for subsequent analysis and plotting.

For example, the change in the loss function shows whether the model is converging, and the computational graph shows how data flows through the model.

TensorBoard can visualize data, model computational graphs, etc.; it automatically fetches the latest data, stores it in a log file, and keeps the log file updated with the latest state of the data or model. Commonly used TensorBoard methods include the following seven:

add_graph(): Add a network structure graph to visualize the calculation graph.

add_image()/add_images(): Add single image data/add image data in batches.

add_figure(): Add matplotlib pictures.

add_scalar()/add_scalars(): Add a scalar/batch add scalars, which can be used to draw loss functions in machine learning.

add_histogram(): Add a statistical distribution histogram.

add_pr_curve(): Add a PR (precision-recall) curve.

add_text(): Add text.

The overall usage of Tensorboard, see the figure below 
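
The overall workflow can be sketched as follows (the log directory runs/demo and the tag demo/loss are arbitrary choices):

```python
from torch.utils.tensorboard import SummaryWriter

# 1. Create a writer; events go to the given directory (default: ./runs/...)
writer = SummaryWriter("runs/demo")

# 2. During training, log data with the add_*() methods
for step in range(10):
    writer.add_scalar("demo/loss", 1.0 / (step + 1), global_step=step)

# 3. Close the writer so all events are flushed to disk
writer.close()

# 4. In a terminal: tensorboard --logdir runs
#    then open the printed URL in a browser
```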

 

In TensorBoard, the add_graph() function saves the model's computational graph; it creates the Graphs entry that stores the network structure in TensorBoard. Its signature is add_graph(model, input_to_model=None, verbose=False), with the following parameters:

model (torch.nn.Module) represents the network model that needs to be visualized;

input_to_model (torch.Tensor or list of torch.Tensor) is the input to the model; if the model takes multiple inputs, pass them in order as a list or tuple;

verbose (bool) is a switch controlling whether the graph structure of the network is also printed to the console.

For example, if there is a variable model of type torch.nn.Module whose inputs are the tensors input1 and input2, and you want its computational graph, you can enter the following code to save the data-flow graph into SummaryWriter's log folder.
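
A sketch under these assumptions (TwoInputNet is a made-up two-input model, and the log directory runs/graph_demo is arbitrary):

```python
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

class TwoInputNet(nn.Module):
    """Hypothetical model taking two input tensors."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 3)
        self.fc2 = nn.Linear(4, 3)

    def forward(self, input1, input2):
        return self.fc1(input1) + self.fc2(input2)

model = TwoInputNet()
input1 = torch.randn(1, 4)
input2 = torch.randn(1, 4)

writer = SummaryWriter("runs/graph_demo")
# Multiple inputs are passed in order as a tuple/list
writer.add_graph(model, (input1, input2))
writer.close()
```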

In PyTorch, SummaryWriter's output folder is usually named runs. The saved log file cannot be opened by double-clicking it. Instead, open a cmd window, change to the parent directory of the runs folder, and run tensorboard --logdir runs; then copy the printed link into a browser to view the saved computational graph, data variables, and so on.

TensorBoard's add_scalar()/add_scalars() functions save one scalar or several scalars in a single graph, for example the training loss, the test loss, or both of them plotted together.

The add_scalar() function's signature is add_scalar(tag, scalar_value, global_step=None, walltime=None); its main parameters are as follows:


tag (string) is the data identifier;

scalar_value (float or string/blobname) is the scalar value, i.e. the value you want to save;

global_step (int) is the global step value, which can be understood as x-axis coordinates 
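
For example, a sketch that logs a synthetic training loss (the tag train/loss and the log directory runs/scalar_demo are arbitrary):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/scalar_demo")
for step in range(100):
    loss = 1.0 / (step + 1)  # stand-in for a real training loss
    writer.add_scalar("train/loss", loss, global_step=step)
writer.close()
```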

The add_scalars() function's signature is add_scalars(main_tag, tag_scalar_dict, global_step=None, walltime=None); its main parameters are as follows:

main_tag (string) is the main identifier, which is the parent name of the tag;

tag_scalar_dict (dict) is a dictionary type data that saves the tag and the value corresponding to the tag;

global_step (int) is the global step value, which can be understood as the x-axis coordinate. 

add_scalars() can add scalars in batches. For example, to draw the curves y = x*sin(x), y = x*cos(x), and y = tan(x), you can enter the following code; the saved log file is opened in the same way as described above.
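
A sketch of this (the tag names and the log directory runs/scalars_demo are arbitrary; x advances in steps of 0.1):

```python
import math
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/scalars_demo")
for step in range(100):
    x = step * 0.1
    # One parent tag, three child curves drawn in the same figure
    writer.add_scalars("functions", {
        "xsinx": x * math.sin(x),
        "xcosx": x * math.cos(x),
        "tanx": math.tan(x),
    }, global_step=step)
writer.close()
```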




Origin blog.csdn.net/jiebaoshayebuhui/article/details/130441654