TensorFlow 2 - saving and loading models, saving training data, and resuming training from checkpoints


What can you get?

By reading this post, you will learn how to save accuracy and loss during TensorFlow training, how to save and load models in TensorFlow, and how to resume training from where a previous run left off.

Recently, while training a neural network, I needed to save the data produced during training and let the next run continue from the results of the previous one. The TensorFlow 2 documentation describes the callbacks for model.fit that make this possible. This post summarizes how to save training data and resume training from a checkpoint in TensorFlow, with a complete example at the end for reference.

1. Save the training data

We want to save the data produced during training: the epoch number, training-set accuracy, training-set loss, validation-set accuracy, and validation-set loss. The TensorFlow documentation provides a callback for exactly this, tf.keras.callbacks.CSVLogger. It is very simple to use: specify the save path, then add the callback to the callbacks list of model.fit. Sample code:

from tensorflow.keras.callbacks import CSVLogger

# append: whether to append to the specified file if it already exists
# (False overwrites the file at the start of training)
csv_logger = CSVLogger('training.log', append=False)
model.fit(X_train, Y_train, callbacks=[csv_logger])
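Assuming accuracy is among the compiled metrics and validation data is passed to fit, the resulting training.log looks roughly like the following (the values are illustrative; the columns follow the model's metric names):

epoch,accuracy,loss,val_accuracy,val_loss
0,0.9134,0.2901,0.9402,0.2143
1,0.9587,0.1402,0.9531,0.1644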

2. Save and load models

Saving the model is done with tf.keras.callbacks.ModelCheckpoint(). The interface is shown below. In practice we only need three parameters: filepath, save_best_only, and save_weights_only. filepath is the path to save to, save_best_only controls whether only the best model is kept, and save_weights_only controls whether only the model weights (rather than the full model) are saved.

tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch',
    options=None, initial_value_threshold=None, **kwargs
)
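One convenient detail: filepath may contain named formatting options, which are filled in with the epoch number and values from logs at save time. A small sketch, assuming validation data is passed to fit (the path is illustrative), that keeps one weight file per epoch tagged with the validation loss:

# {epoch} and {val_loss} are substituted when each checkpoint is written
cp_per_epoch = tf.keras.callbacks.ModelCheckpoint(
    filepath='./checkpoint/weights.{epoch:02d}-{val_loss:.2f}.ckpt',
    save_weights_only=True)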

After saving, the model needs to be loaded back. Since we saved only the weights, they are restored with model.load_weights(filepath).
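If save_weights_only=False is used instead, the callback saves the full model, and model.load_weights no longer applies; the whole model is restored with tf.keras.models.load_model. A minimal sketch, assuming the model was saved to the illustrative path './saved_model/Baseline':

# Restores architecture, weights, and optimizer state in one call
model = tf.keras.models.load_model('./saved_model/Baseline')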

Here is sample code for reference; adapt it to your own needs.


import os
import tensorflow as tf

model = TestModel()
model.compile(....)

# Restore saved model weights if a checkpoint exists
checkpoint_save_path = './checkpoint/Baseline.ckpt'
checkpoint_save_best_path = './checkpoint_best/Baseline.ckpt'
# The TF checkpoint format writes an .index file alongside the data files,
# so its presence indicates a previously saved checkpoint
if os.path.exists(checkpoint_save_path + '.index'):
    print('------load the model------')
    model.load_weights(checkpoint_save_path)

# Save only the weights after every epoch, so the next run can pick up
# where the previous one left off
cp_callback_save = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path, save_weights_only=True)
# Save only the best model, so the best weights seen during training
# are still available after training ends
cp_callback_save_best = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_best_path, save_weights_only=True,
    save_best_only=True)

model.fit(....., callbacks=[cp_callback_save, cp_callback_save_best])
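A note on the design: two ModelCheckpoint callbacks are used so that the latest weights (for resuming) and the best weights (for final use) are kept in separate paths. With save_best_only=True, the callback monitors val_loss by default, as the signature above shows, and only overwrites its file when the monitored value improves.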


3. Sample code

The previous sections covered saving the training data and the model. Back to the original problem: we want to continue the previous training run while still saving data and parameters. First, to continue from the last run we need to know how many epochs have already been trained, which is easy to read from the training-data file. Then we can pass the starting epoch to model.fit through its initial_epoch argument, so that training picks up where the previous run stopped. Reference code:

import csv
import os

# Return the number of epochs already trained.
# filename: path of the saved training-data file
def get_init_epoch(filename):
    with open(filename) as f:
        f_csv = csv.DictReader(f)
        # Each row of the CSV log corresponds to one completed epoch
        return sum(1 for row in f_csv)

filename = 'training.log'  # path of the CSV training log
init_epoch = 0  # starting epoch
if os.path.exists(filename):
    init_epoch = get_init_epoch(filename)
model = Test()
model.compile(...)


checkpoint_save_path = './checkpoint/Baseline.ckpt'
checkpoint_save_best_path = './checkpoint_best/Baseline.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('------load the model------')
    model.load_weights(checkpoint_save_path)

cp_callback_save = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path, save_weights_only=True)
cp_callback_save_best = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_best_path, save_weights_only=True,
    save_best_only=True)

# append=True appends to the existing log instead of recreating it,
# so the log survives across runs
csv_logger = CSVLogger(filename, append=True)

model.fit(...., initial_epoch=init_epoch,
          callbacks=[csv_logger, cp_callback_save_best, cp_callback_save])
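Note that initial_epoch only shifts the epoch counter at which training starts; model.fit still trains until the epochs argument is reached. So epochs should be the total target number of epochs across all runs, not the number of additional epochs for this run.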

If you need to write a custom callback, the following material is a useful reference.

All callbacks subclass the keras.callbacks.Callback class and override a set of methods called at various stages of training, testing, and predicting. Callbacks are useful for getting a view of the model's internal state and statistics during training.

Overview of callback methods

Global methods

on_(train|test|predict)_begin(self, logs=None)
Called at the beginning of fit/evaluate/predict.

on_(train|test|predict)_end(self, logs=None)
Called at the end of fit/evaluate/predict.

Batch-level methods for training/testing/predicting

on_(train|test|predict)_batch_begin(self, batch, logs=None)
Called right before processing a batch during training/testing/predicting.

on_(train|test|predict)_batch_end(self, batch, logs=None)
Called at the end of training/testing/predicting a batch. Within this method, logs is a dict containing the metrics results.

Epoch-level methods (training only)

on_epoch_begin(self, epoch, logs=None)
Called at the beginning of an epoch during training.

on_epoch_end(self, epoch, logs=None)
Called at the end of an epoch during training.

A basic example

Let's look at a concrete example.

from tensorflow import keras

class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys)) 



Origin juejin.im/post/7080449901938606088