Tensorflow2——模型保存与加载以及训练数据保存和断点续训

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

你能收获什么?

        通过阅读这篇博客,你可以了解如何在Tensorflow训练过程中保存准确率和loss,以及如何在tensorflow中保存与加载模型,如何再重新接着上一轮的训练过程继续训练。

最近在神经网络训练的过程中,需要保存训练过程中的数据,并且再下次训练的时候能够接着上次训练的结果进行断点续训。所以通过Tensorflow2官网查询到对于model.fit相关的回调函数的编写方法。下面总结了一下,在Tensorflow中,对于参数保存以及断点续训的内容,在最后会给出一个示例代码,供大家参考。

1. 保存训练数据

如何保存训练过程的数据,包括训练轮数(Epoch),训练集acc,训练集loss,验证集acc,验证集loss。通过Tensorflow官网可以找到一个保存训练数据的回调函数,即tf.keras.callbacks.CSVLogger。使用方法很简单,只需要指定保存路径,然后将方法加入到model.fit的callbacks参数列表中。示例代码如下:

#参数说明
#append: 是否在指定文件基础上追加内容,
csv_logger = CSVLogger('training.log',append=False)
model.fit(X_train, Y_train, callbacks=[csv_logger])
复制代码

2. 保存与加载模型

保存模型可以借助tf.keras.callbacks.ModelCheckpoint()进行保存。接口说明如下:我们在使用过程中,大致只需要使用filepath,save_best_only,save_weight_only三个参数。其中filepath指明保存文件的路径,save_best_only说明是否只保存为佳模型,save_weight_only说明是否只保存模型权重。

tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch',
    options=None, initial_value_threshold=None, **kwargs
)
复制代码

我们保存了模型,就需要加载模型,加载模型就通过model.load_weights(filepath)进行模型的加载操作。

这里给出一个示例代码,供大家参考,具体需求可以根据代码修改。


model = TestModel()
model.compile(....)

# 读取保存的模型权重
checkpoint_save_path = './checkpoint/Baseline.ckpt'
checkpoint_save_best_path = './checkpoint_best/Baseline.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('------load the model------')
    model.load_weights(checkpoint_save_path)
    
# 这里只保存权重(这样在下一次训练开始的时候,就能够接着上一次的训练继续训练)
cp_callback_save = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True)
# 这里只保存最优模型(这样在训练结束后,能够保存训练过程中的最优模型)
cp_callback_save_best = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_best_path,save_weights_only=True,save_best_only=True)

model.fit(.....,callbacks=[cp_callback_save,cp_callback_save_best])

复制代码

3. 示例代码

前面介绍了参数和模型的保存,回到刚刚的问题,我们需要在上一次训练后继续之前的训练过程,并且保存数据参数。首先,如果要接着上一轮继续训练,那么就需要知道上一轮训练了多少轮,我们可以通过我们的参数数据文件,很容易得出我们训练了多少论,接着我们可以借助model.fit的initial_epoch指定起始轮数,这样就可以使得训练接着上一轮继续训练。参考代码如下:

# 返回训练轮数,filename:训练数据保存文件路径
def get_init_epoch(filename):
    with open(filename) as f:
        f_csv = csv.DictReader(f)
        count = 0
        for row in f_csv:
            count = count+1
        return count

init_epoch = 0 # 起始轮次
if os.path.exists(filename):
    init_epoch = get_init_epoch(filename)
model = Test()
model.compile(...)


checkpoint_save_path = './checkpoint/Baseline.ckpt'
checkpoint_save_best_path = './checkpoint_best/Baseline.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('------load the model------')
    model.load_weights(checkpoint_save_path)

cp_callback_save = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True)
cp_callback_save_best = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_best_path,save_weights_only=True,save_best_only=True)

# 通过设置append可以选择是否在原文件上添加还是重新创建增加内容
csv_logger = CSVLogger('training_log',append=True) 

model.fit(....,init_epoch=init_epoch,callbacks=[csv_logger,cp_callback_save_best,cp_callback_save])
复制代码

如果你需要自定义回调函数,可以参考下列资料

所有回调函数都将 keras.callbacks.Callback 类作为子类,并重写在训练、测试和预测的各个阶段调用的一组方法。回调函数对于在训练期间了解模型的内部状态和统计信息十分有用。

回调函数方法概述 全局方法

on_(train|test|predict)_begin(self, logs=None)
在 fit/evaluate/predict 开始时调用。

on_(train|test|predict)_end(self, logs=None)
在 fit/evaluate/predict 结束时调用。

Batch-level methods for training/testing/predicting
on_(train|test|predict)_batch_begin(self, batch, logs=None)
正好在训练/测试/预测期间处理批次之前调用。

on_(train|test|predict)_batch_end(self, batch, logs=None)
在训练/测试/预测批次结束时调用。在此方法中,logs 是包含指标结果的字典。

周期级方法(仅训练)
on_epoch_begin(self, epoch, logs=None)
在训练期间周期开始时调用。

on_epoch_end(self, epoch, logs=None)
在训练期间周期开始时调用。
复制代码

基本示例 让我们来看一个具体的例子。

class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys)) 

复制代码

猜你喜欢

转载自juejin.im/post/7080449901938606088