TensorFlow 2.0 Keras 回调函数

回调函数

1、TensorBoard：查看训练过程数据

tf.keras.callbacks.TensorBoard(日志保存路径)

# Create the output directory for the callbacks.
# os.makedirs(..., exist_ok=True) replaces the exists()/mkdir() pair: it cannot
# race with another process creating the same directory, and it also creates
# missing parent directories if log_dir is later changed to a nested path.
log_dir = 'callbacks'
os.makedirs(log_dir, exist_ok=True)

# TensorBoard callback: writes event logs under log_dir for viewing in TensorBoard.
callbacks = [tf.keras.callbacks.TensorBoard(log_dir)]

TensorBoard 的详细使用过程见原文链接。

2、EarlyStopping：提前停止训练

tf.keras.callbacks.EarlyStopping(停止参数)

callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',  # quantity to watch: validation loss
        min_delta=1e-3,      # a decrease smaller than this does not count as an improvement
        patience=5,          # stop after this many consecutive epochs without improvement
    ),
]

3、ModelCheckpoint：保存模型检查点

tf.keras.callbacks.ModelCheckpoint(输出文件名, 保存参数)

# Checkpoint file path, placed inside log_dir.
output_model_file = os.path.join(log_dir, 'fashion_mnist_model.h5')
callbacks = [
    # val_loss is ModelCheckpoint's default monitor (spelled out here). With
    # save_best_only=True the file is rewritten only when val_loss improves,
    # so it always holds the best model seen so far.
    tf.keras.callbacks.ModelCheckpoint(
        output_model_file,
        monitor='val_loss',
        save_best_only=True,
    ),
]

4、在 model.fit 中设置回调函数

model.fit(..., callbacks=回调函数列表)

# The three callbacks together, handed to model.fit in one list.
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir),
    tf.keras.callbacks.ModelCheckpoint(output_model_file, save_best_only=True),
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',  # watch validation loss
        min_delta=1e-3,      # minimum decrease that counts as an improvement
        patience=5,          # epochs without improvement tolerated before stopping
    ),
]

# Train for at most 20 epochs; EarlyStopping may end training sooner.
history = model.fit(
    x_train_scaled,
    y_train,
    epochs=20,
    validation_data=(x_valid_scaled, y_valid),
    callbacks=callbacks,
)

完整代码

import os
import sys

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
import tensorflow as tf
from sklearn.preprocessing import StandardScaler
from tensorflow import keras

# Report the TensorFlow and Python versions, then each library's name and version.
print(tf.__version__)
print(sys.version_info)

for lib in (mpl, np, pd, sklearn, keras, tf):
    print(lib.__name__, lib.__version__)

# Load Fashion-MNIST and hold out the first 5000 training images as a validation set.
(x_train_all, y_train_all), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]

print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

print(np.max(x_train), np.min(x_train))
scaler = StandardScaler()

# Pixels are flattened to a single column, so one mean/std is learned over all
# training pixels, and the result is reshaped back to (N, 28, 28).
# The scaler is fitted on the training split ONLY, and the same statistics are
# then applied to the validation and test splits with transform(). Calling
# fit_transform() on every split (as before) leaks validation/test statistics
# into preprocessing and gives each split a different normalisation.
x_train_scaled = scaler.fit_transform(
    x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_valid_scaled = scaler.transform(
    x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
x_test_scaled = scaler.transform(
    x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)

print(np.max(x_train_scaled), np.min(x_train_scaled))

# tf.keras.Sequential built from a layer list: flatten each 28x28 image, two
# ReLU hidden layers (300 and 100 units), and a 10-unit softmax output.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=[28, 28]),
    tf.keras.layers.Dense(300, activation="relu"),
    tf.keras.layers.Dense(100, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])

# relu: y = max(0, x)
# softmax: for x = [x1, x2, x3],
#          y = [e^x1/sum, e^x2/sum, e^x3/sum], where sum = e^x1 + e^x2 + e^x3

# sparse_categorical_crossentropy takes integer class ids as labels;
# the 'adam' string selects Adam with its default settings.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer='adam',
    metrics=['accuracy'],
)

print(model.layers)
model.summary()

# Callbacks: TensorBoard, EarlyStopping, ModelCheckpoint. All output goes under log_dir.
# os.makedirs(..., exist_ok=True) replaces the exists()/mkdir() pair, which could
# race and would fail if log_dir were changed to a nested path.
log_dir = 'callbacks'
os.makedirs(log_dir, exist_ok=True)

output_model_file = os.path.join(log_dir, 'fashion_mnist_model.h5')
callbacks = [
    # Event logs for TensorBoard under log_dir.
    tf.keras.callbacks.TensorBoard(log_dir),
    # Rewrite the .h5 file only when val_loss (ModelCheckpoint's default monitor) improves.
    tf.keras.callbacks.ModelCheckpoint(output_model_file,
                                       save_best_only=True),
    tf.keras.callbacks.EarlyStopping(monitor='val_loss',  # metric to watch
                                     min_delta=1e-3,  # smaller decreases are not improvements
                                     patience=5),  # epochs without improvement before stopping
]

# Train for at most 20 epochs; EarlyStopping may stop sooner.
history = model.fit(x_train_scaled, y_train, epochs=20,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks=callbacks)

# Per-epoch metrics recorded by fit (loss/accuracy plus their val_ counterparts).
print(history.history)


def plot_learning_curves(history, ylim=(0, 1)):
    """Plot every per-epoch metric in *history* on one set of axes and show it.

    Args:
        history: mapping of metric name -> list of per-epoch values, such as
            the ``.history`` attribute of the object returned by ``model.fit``.
        ylim: ``(bottom, top)`` limits for the y axis. The default ``(0, 1)``
            keeps the original behaviour and suits accuracy curves, but it
            clips any loss values above 1. Pass ``None`` to let matplotlib
            autoscale the y axis instead.
    """
    pd.DataFrame(history).plot(figsize=(8, 5))
    plt.grid(True)
    if ylim is not None:
        plt.gca().set_ylim(*ylim)
    plt.show()


# Visualise the recorded training curves, then score the model on the test split.
plot_learning_curves(history=history.history)

model.evaluate(x=x_test_scaled, y=y_test)

猜你喜欢

转载自blog.csdn.net/weixin_45875105/article/details/113609847