TensorBoard 标量:记录Keras训练中的指标

机器学习中我们需要理解关键指标,例如在训练过程中损失函数是如何变化的。我们可以根据这些指标来判断是否过拟合,或者是否进行了长时间的不必要训练。我们通过不同训练过程中这些指标的对比来不断完善模型。

TensorBoard的Scalar Dashboard,通过简单的API就可以看到这些标量。这个教程提供了非常基础的样例帮助我们学习如何通过TensorBoard的API来提高Keras模型。我们将学习如何用keras TensorBoard的返回和TensorFlow的SummaryAPI来看到这些标量。

样例

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
from packaging import version

import tensorflow as tf
from tensorflow import keras

import numpy as np
import os

print("TensorFlow version: ", tf.__version__)
assert version.parse(tf.__version__).release[0] >= 2, "this notebook requires Tensorflow 2.0 or above"

data_size = 1000
#训练集80%
train_pct = 0.8

train_size = int(data_size * train_pct)

#创建-1到1的随机数据
x = np.linspace(-1,1,data_size)
np.random.shuffle(x)

#生成输出数据y = 0.5x+2+noise
y = 0.5*x +2 +np.random.normal(0,0.05,(data_size,))

#划分成测试集和训练集
x_train, y_train = x[:train_size], y[:train_size]
x_test, y_test = x[train_size:],y[train_size:]

#定义,训练和评估模型
#训练时记录损失函数数据的方法如下:
#1.创建Keras TensorBoard callback  2.给定一个日志路径 3.通过TensorBoard callback来实现keras的Model.fit()
logdir = os.path.join("logs", datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

model = keras.models.Sequential([keras.layers.Dense(16, input_dim=1),keras.layers.Dense(1),
])

model.compile(loss = 'mse',
optimizer = keras.optimizers.SGD(lr=0.2))

print("Training ... With default parameters, this takes less than 10 seconds")


training_history = model.fit(
    x_train, # input
    y_train, # output
    batch_size=train_size,
    verbose=0, # Suppress chatty output; use Tensorboard instead
    epochs=100,
    validation_data=(x_test, y_test),
    callbacks=[tensorboard_callback],
)
print(model.predict([60,25,2]))
print("Average test loss: ", np.average(training_history.history['loss']))
(tensorflow) C:\Users\70908\Anaconda3\envs\tensorflow\Scripts>tensorboard --logdir=F://tensorflow//TESTCode//logs

 注意:logs所在的路径为绝对路径

发布了93 篇原创文章 · 获赞 2 · 访问量 3049

猜你喜欢

转载自blog.csdn.net/qq_40041064/article/details/104728124
今日推荐