3. Stopping Training Automatically

Sometimes, once the model's loss has reached the value we expect, training can be stopped early: this saves time and also helps prevent overfitting.
Here, we stop training as soon as the loss falls below 0.4.

import tensorflow as tf


# Custom callback: check the loss at the end of each epoch and stop
# training once it drops below 0.4.
class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        if logs.get('loss') is not None and logs['loss'] < 0.4:
            print("\nLoss is low so cancelling training!")
            self.model.stop_training = True


callbacks = myCallback()

# Load Fashion-MNIST and scale pixel values to [0, 1]
mnist = tf.keras.datasets.fashion_mnist
(training_images, training_labels), (test_images, test_labels) = mnist.load_data()
training_images_y = training_images / 255.0
test_images_y = test_images / 255.0

# Simple fully connected classifier
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Pass the callback to fit(); training ends early once the condition is met
model.fit(training_images_y, training_labels, epochs=5, callbacks=[callbacks])
"""
Colocations handled automatically by placer.
Epoch 1/5
60000/60000 [==============================] - 12s 194us/sample - loss: 0.4729 - acc: 0.8303
Epoch 2/5
59712/60000 [============================>.] - ETA: 0s - loss: 0.3570 - acc: 0.8698
Loss is low so cancelling training!
60000/60000 [==============================] - 11s 190us/sample - loss: 0.3570 - acc: 0.8697
"""

Reposted from blog.csdn.net/qq_41264055/article/details/125445199