TensorFlow 2.1: saving and loading models for checkpoint-resume training

1. Introduction

        Model training often takes a long time, and we usually want to keep the best parameters found so far. Saving the best model parameters lets us continue training later, makes testing convenient, and makes it easier to reach an optimal solution.

        Checkpoint-resume training (breakpoint resume training) means that training can start from a previously saved model instead of from scratch.

        You can also load a trained model directly and use it for prediction.

2. How to save and load the model

        Here is the code first:

import tensorflow as tf
import os

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()
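# Loading the model
# Define the path: checkpoint_save_path = "*********.ckpt"; files with this .ckpt prefix hold the saved model
# A saved model always comes with an .index file, so the following code checks whether a saved model already exists
# if os.path.exists(checkpoint_save_path + '.index'):
#     print('-------------load the model-----------------')
# If a saved model exists, load it directly with .load_weights()
#     model.load_weights(checkpoint_save_path)

# Saving the model
# callbacks.ModelCheckpoint is one of the Keras callbacks; detailed callback usage will be covered in another blog post
# callback = tf.keras.callbacks.ModelCheckpoint(
#   filepath=path to the checkpoint file,
#   save_weights_only=whether to save only the weights,
#   save_best_only=whether to keep only the best model)
# The callback must also be passed to fit, whose return value is assigned to history (this callback is the same one defined above), i.e.: history = model.fit(..., callbacks=[callback])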
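
As a rough sketch of the load-then-predict workflow, the snippet below trains briefly, saves the weights, rebuilds the same architecture, loads the weights, and predicts. The random stand-in data, the `build_model` helper, the temporary directory, and the `mnist.weights.h5` file name are assumptions for illustration and are not part of the original code. The `.weights.h5` suffix is used here because, as far as I know, newer Keras releases expect it for weight files, while the `.ckpt` path in this post targets TensorFlow 2.1.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical helper: same layers as the article's model, plus an
    # explicit input so the weights exist before any training happens.
    return tf.keras.models.Sequential([
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])


# Random stand-in data with MNIST's shape (not the real dataset).
rng = np.random.RandomState(0)
x = rng.rand(64, 28, 28).astype('float32')
y = rng.randint(0, 10, size=64)

# Assumed location and file name for the saved weights.
weights_path = os.path.join(tempfile.mkdtemp(), 'mnist.weights.h5')

model = build_model()
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])
model.fit(x, y, batch_size=32, epochs=1, verbose=0)
model.save_weights(weights_path)

# Later, for example in a separate script: rebuild the architecture and load.
restored = build_model()
if os.path.exists(weights_path):
    restored.load_weights(weights_path)

# Each row of probs holds the softmax probabilities for the 10 classes.
probs = restored.predict(x[:5], verbose=0)
predicted_digits = np.argmax(probs, axis=1)
print(predicted_digits)
```

The restored model does not need to be compiled just to call `predict`; compiling is only needed again if training continues.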

Keras callbacks come in handy when we want to run certain tasks at the end of training, of each epoch, or of each batch, such as saving the model, writing logs, or computing the current accuracy.
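
To make the idea concrete, here is a minimal sketch of a custom callback that runs at the end of every epoch. The `EpochLogger` class, the tiny model, and the random data are hypothetical and only serve to show the `on_epoch_end` hook of `tf.keras.callbacks.Callback`.

```python
import numpy as np
import tensorflow as tf


class EpochLogger(tf.keras.callbacks.Callback):
    """Hypothetical callback that records the loss after every epoch."""

    def __init__(self):
        super().__init__()
        self.losses = []

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes the metrics of the finished epoch in `logs`.
        logs = logs or {}
        self.losses.append(logs.get('loss'))
        print('epoch', epoch + 1, 'finished, loss =', logs.get('loss'))


# Random stand-in data: 64 samples, 8 features, 2 classes.
rng = np.random.RandomState(0)
x = rng.rand(64, 8).astype('float32')
y = rng.randint(0, 2, size=64)

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(4, activation='relu'),
    tf.keras.layers.Dense(2, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

logger = EpochLogger()
model.fit(x, y, batch_size=32, epochs=3, verbose=0, callbacks=[logger])
```

The same class can also define `on_train_end` or `on_train_batch_end` to hook into the other two moments mentioned above.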

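# Callbacks can do the following:
# ModelCheckpoint (checkpoint-resume training): saves the weights of the current model
# EarlyStopping: stops training once the monitored loss stops decreasing; the best weights can be restored by setting restore_best_weights=True
# LearningRateScheduler: dynamically adjusts training parameters during training, such as the optimizer's learning rate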
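
The sketch below shows one possible way to combine `EarlyStopping` and `LearningRateScheduler` in a single `fit` call. The small model, the random data, the `schedule` function, and the chosen `patience` and decay factor are assumptions for illustration, not values from the original post.

```python
import numpy as np
import tensorflow as tf

# Random stand-in data: 128 samples, 8 features, 2 classes.
rng = np.random.RandomState(0)
x = rng.rand(128, 8).astype('float32')
y = rng.randint(0, 2, size=128)

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(2, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')


def schedule(epoch, lr):
    # Hypothetical schedule: keep the learning rate for the first 2 epochs,
    # then multiply it by 0.9 at the start of every later epoch.
    return float(lr) if epoch < 2 else float(lr) * 0.9


# Stop when val_loss has not improved for 2 epochs and roll back to the
# weights of the best epoch seen so far.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                              patience=2,
                                              restore_best_weights=True)
lr_schedule = tf.keras.callbacks.LearningRateScheduler(schedule)

history = model.fit(x, y, validation_split=0.25, batch_size=32, epochs=10,
                    verbose=0, callbacks=[early_stop, lr_schedule])

# Fewer than 10 entries means EarlyStopping ended training early.
epochs_run = len(history.history['loss'])
print('epochs actually run:', epochs_run)
```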

 Code execution result:

The model parameters have been saved here:

We can easily read parameters from the model using

3. View the saved parameters

        Sometimes you may want to inspect the trainable parameters directly. The following shows how to view the parameters and save them to a text file:

print(model.trainable_variables) 

This statement prints the parameter values, but there is a problem:

Most of the printed values are replaced by ellipses. What if we want to see all of them?

Just raise the print threshold by adding the following code at the beginning:

import numpy as np
np.set_printoptions(threshold=np.inf)

 

Now all the parameters are printed in full.
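
The effect of the print threshold can be checked without a model. In this small sketch, the `weights` array is only a stand-in for a large weight tensor, and `np.printoptions` with `sys.maxsize` is used as a temporary alternative to the global `np.set_printoptions(threshold=np.inf)` setting, so the change applies only inside the `with` block.

```python
import sys

import numpy as np

# Stand-in for a large weight array (2000 values, above the default threshold).
weights = np.arange(2000, dtype=np.float32)

# With default settings, NumPy summarizes large arrays with an ellipsis.
summarized = str(weights)

# Inside this block the threshold is raised, so every value is printed.
with np.printoptions(threshold=sys.maxsize):
    full = str(weights)

print('ellipsis in default output:', '...' in summarized)
print('ellipsis in full output:', '...' in full)
```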

Then save the parameters to a txt file:

# Write each variable's name, shape, and values to a text file
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')

The txt file now appears in the current directory:

The file was saved successfully.

Here is the complete code:
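
A text dump is easy to read but awkward to load back into arrays. If the values need to be reloaded later, one option is NumPy's `.npz` format. The sketch below uses random stand-in arrays, a temporary directory, and made-up key names (`dense_kernel`, `dense_bias`); none of these come from the original code.

```python
import os
import tempfile

import numpy as np

# Stand-in arrays shaped like the first Dense layer of the model above
# (random values, not real trained weights).
rng = np.random.RandomState(0)
weights = {
    'dense_kernel': rng.rand(784, 128).astype('float32'),
    'dense_bias': np.zeros(128, dtype='float32'),
}

# Save all arrays into a single .npz archive, one entry per key.
npz_path = os.path.join(tempfile.mkdtemp(), 'weights.npz')
np.savez(npz_path, **weights)

# Load the archive back into a plain dict of arrays.
with np.load(npz_path) as loaded:
    restored = {name: loaded[name] for name in loaded.files}

for name, value in restored.items():
    print(name, value.shape)
```

With a real model, the arrays could come from `v.numpy()` for each `v` in `model.trainable_variables`, using names of your own choosing as keys.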

import tensorflow as tf
import os
import numpy as np
np.set_printoptions(threshold=np.inf)
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()
print(model.trainable_variables)
# Write each variable's name, shape, and values to a text file
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')

Origin blog.csdn.net/qq_46006468/article/details/119645336