Knock Out TensorFlow 2.0 in 30 Days - Day 25: Three Ways to Train a Model

Three Ways to Train a Model

There are three main ways to train a model: the built-in fit method, the built-in train_on_batch method, and a custom training loop.

Note: the fit_generator method is not recommended in tf.keras; its functionality is already covered by fit.
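
For example, fit can consume a Python generator directly. A minimal sketch (the data_generator below is a hypothetical illustration, not part of the original post):

import numpy as np
import tensorflow as tf

# Hypothetical generator that yields (features, labels) batches indefinitely
def data_generator(batch_size=32):
    while True:
        x = np.random.random((batch_size, 10))
        y = np.random.randint(0, 2, size=(batch_size, 1))
        yield x, y

gen_model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="sigmoid")])
gen_model.compile(optimizer="adam", loss="binary_crossentropy")

# steps_per_epoch is required because a generator has no length
gen_model.fit(data_generator(), steps_per_epoch=100, epochs=2)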

import numpy as np 
import pandas as pd 
import tensorflow as tf
from tensorflow.keras import * 

# Print a timestamped separator line
@tf.function
def printbar():
    ts = tf.timestamp()
    today_ts = ts%(24*60*60)

    # +8 shifts UTC to Beijing time (UTC+8)
    hour = tf.cast(today_ts//3600+8,tf.int32)%tf.constant(24)
    minute = tf.cast((today_ts%3600)//60,tf.int32)
    second = tf.cast(tf.floor(today_ts%60),tf.int32)
    
    def timeformat(m):
        if tf.strings.length(tf.strings.format("{}",m))==1:
            return(tf.strings.format("0{}",m))
        else:
            return(tf.strings.format("{}",m))
    
    timestring = tf.strings.join([timeformat(hour),timeformat(minute),
                timeformat(second)],separator = ":")
    tf.print("=========="*8,end = "")
    tf.print(timestring)

MAX_LEN = 300
BATCH_SIZE = 32
(x_train,y_train),(x_test,y_test) = datasets.reuters.load_data()
# Pad (or truncate) every newswire to a fixed length of MAX_LEN tokens
x_train = preprocessing.sequence.pad_sequences(x_train,maxlen=MAX_LEN)
x_test = preprocessing.sequence.pad_sequences(x_test,maxlen=MAX_LEN)

MAX_WORDS = x_train.max()+1   # vocabulary size
CAT_NUM = y_train.max()+1     # number of topic categories (46 for Reuters)

ds_train = tf.data.Dataset.from_tensor_slices((x_train,y_train)) \
          .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
          .prefetch(tf.data.experimental.AUTOTUNE).cache()
   
ds_test = tf.data.Dataset.from_tensor_slices((x_test,y_test)) \
          .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
          .prefetch(tf.data.experimental.AUTOTUNE).cache()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/reuters.npz
2113536/2110848 [==============================] - 4s 2us/step
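
A side note on the pipeline above: prefetch() is typically the last transformation, and cache() is usually applied early so that shuffling is re-randomized every epoch while the raw data is read only once. A sketch of the conventional ordering (same transformations, reordered):

ds_train = tf.data.Dataset.from_tensor_slices((x_train,y_train)) \
          .cache().shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
          .prefetch(tf.data.experimental.AUTOTUNE)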

1. The built-in fit method

This method is very powerful. It supports training directly on numpy arrays, tf.data.Dataset objects, and Python generators.

It also allows complex control logic over the training process via callbacks.
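
For instance, the standard callbacks in tf.keras.callbacks can stop training early and keep the best weights. A minimal sketch (assuming the compiled model and the ds_train/ds_test datasets built below):

callbacks_list = [
    # Stop if validation loss has not improved for 3 consecutive epochs
    tf.keras.callbacks.EarlyStopping(monitor = "val_loss", patience = 3,
                                     restore_best_weights = True),
    # Keep a copy of the best model seen so far
    tf.keras.callbacks.ModelCheckpoint("best_model.h5", monitor = "val_loss",
                                       save_best_only = True),
]
history = model.fit(ds_train, validation_data = ds_test,
                    epochs = 10, callbacks = callbacks_list)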

tf.keras.backend.clear_session()
def create_model():
    # Embedding -> two Conv1D/MaxPool1D blocks -> softmax over CAT_NUM classes
    model = models.Sequential()
    model.add(layers.Embedding(MAX_WORDS,7,input_length=MAX_LEN))
    model.add(layers.Conv1D(filters = 64,kernel_size = 5,activation = "relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Conv1D(filters = 32,kernel_size = 3,activation = "relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(CAT_NUM,activation = "softmax"))
    return(model)

def compile_model(model):
    model.compile(optimizer=optimizers.Nadam(),
                loss=losses.SparseCategoricalCrossentropy(),
                metrics=[metrics.SparseCategoricalAccuracy(),metrics.SparseTopKCategoricalAccuracy(5)]) 
    return(model)
 
model = create_model()
model.summary()
model = compile_model(model)
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, 300, 7)            216874    
_________________________________________________________________
conv1d (Conv1D)              (None, 296, 64)           2304      
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 148, 64)           0         
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 146, 32)           6176      
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 73, 32)            0         
_________________________________________________________________
flatten (Flatten)            (None, 2336)              0         
_________________________________________________________________
dense (Dense)                (None, 46)                107502    
=================================================================
Total params: 332,856
Trainable params: 332,856
Non-trainable params: 0
_________________________________________________________________
history = model.fit(ds_train,validation_data = ds_test,epochs = 10)
Train for 281 steps, validate for 71 steps
Epoch 1/10
281/281 [==============================] - 13s 46ms/step - loss: 2.0064 - sparse_categorical_accuracy: 0.4593 - sparse_top_k_categorical_accuracy: 0.7438 - val_loss: 1.6618 - val_sparse_categorical_accuracy: 0.5654 - val_sparse_top_k_categorical_accuracy: 0.7627
Epoch 2/10
281/281 [==============================] - 13s 46ms/step - loss: 1.4743 - sparse_categorical_accuracy: 0.6166 - sparse_top_k_categorical_accuracy: 0.7966 - val_loss: 1.5357 - val_sparse_categorical_accuracy: 0.5975 - val_sparse_top_k_categorical_accuracy: 0.7934
Epoch 3/10
281/281 [==============================] - 12s 42ms/step - loss: 1.2023 - sparse_categorical_accuracy: 0.6879 - sparse_top_k_categorical_accuracy: 0.8485 - val_loss: 1.5544 - val_sparse_categorical_accuracy: 0.6233 - val_sparse_top_k_categorical_accuracy: 0.8037
Epoch 4/10
281/281 [==============================] - 12s 41ms/step - loss: 0.9329 - sparse_categorical_accuracy: 0.7573 - sparse_top_k_categorical_accuracy: 0.9061 - val_loss: 1.7027 - val_sparse_categorical_accuracy: 0.6149 - val_sparse_top_k_categorical_accuracy: 0.8041
Epoch 5/10
281/281 [==============================] - 12s 43ms/step - loss: 0.6946 - sparse_categorical_accuracy: 0.8221 - sparse_top_k_categorical_accuracy: 0.9442 - val_loss: 1.9096 - val_sparse_categorical_accuracy: 0.6064 - val_sparse_top_k_categorical_accuracy: 0.8019
Epoch 6/10
281/281 [==============================] - 12s 42ms/step - loss: 0.5219 - sparse_categorical_accuracy: 0.8690 - sparse_top_k_categorical_accuracy: 0.9686 - val_loss: 2.1816 - val_sparse_categorical_accuracy: 0.6006 - val_sparse_top_k_categorical_accuracy: 0.7956
Epoch 7/10
281/281 [==============================] - 12s 42ms/step - loss: 0.4114 - sparse_categorical_accuracy: 0.8999 - sparse_top_k_categorical_accuracy: 0.9810 - val_loss: 2.4422 - val_sparse_categorical_accuracy: 0.5988 - val_sparse_top_k_categorical_accuracy: 0.7956
Epoch 8/10
281/281 [==============================] - 11s 39ms/step - loss: 0.3419 - sparse_categorical_accuracy: 0.9197 - sparse_top_k_categorical_accuracy: 0.9863 - val_loss: 2.6622 - val_sparse_categorical_accuracy: 0.6037 - val_sparse_top_k_categorical_accuracy: 0.7970
Epoch 9/10
281/281 [==============================] - 11s 39ms/step - loss: 0.2969 - sparse_categorical_accuracy: 0.9293 - sparse_top_k_categorical_accuracy: 0.9900 - val_loss: 2.8685 - val_sparse_categorical_accuracy: 0.6051 - val_sparse_top_k_categorical_accuracy: 0.8014
Epoch 10/10
281/281 [==============================] - 12s 41ms/step - loss: 0.2654 - sparse_categorical_accuracy: 0.9354 - sparse_top_k_categorical_accuracy: 0.9919 - val_loss: 3.0531 - val_sparse_categorical_accuracy: 0.6109 - val_sparse_top_k_categorical_accuracy: 0.8023
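
The returned history object stores the per-epoch losses and metrics in the history.history dict, which is convenient to inspect as a DataFrame (a usage sketch, not in the original post):

dfhistory = pd.DataFrame(history.history)
print(dfhistory.tail())   # per-epoch training and validation curves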

2. The built-in train_on_batch method

Compared with fit, this built-in method is more flexible: it lets you control the training process at the batch level directly, without going through callbacks.
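
Each call to train_on_batch runs a single gradient update on one batch and returns the loss and metric values for that batch, in the order given by model.metrics_names. A minimal sketch (assuming the compiled model built below):

for x, y in ds_train.take(1):
    result = model.train_on_batch(x, y)
    print(dict(zip(model.metrics_names, result)))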

tf.keras.backend.clear_session()

def create_model():
    model = models.Sequential()

    model.add(layers.Embedding(MAX_WORDS,7,input_length=MAX_LEN))
    model.add(layers.Conv1D(filters = 64,kernel_size = 5,activation = "relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Conv1D(filters = 32,kernel_size = 3,activation = "relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(CAT_NUM,activation = "softmax"))
    return(model)

def compile_model(model):
    model.compile(optimizer=optimizers.Nadam(),
                loss=losses.SparseCategoricalCrossentropy(),
                metrics=[metrics.SparseCategoricalAccuracy(),metrics.SparseTopKCategoricalAccuracy(5)]) 
    return(model)
 
model = create_model()
model.summary()
model = compile_model(model)
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, 300, 7)            216874    
_________________________________________________________________
conv1d (Conv1D)              (None, 296, 64)           2304      
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 148, 64)           0         
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 146, 32)           6176      
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 73, 32)            0         
_________________________________________________________________
flatten (Flatten)            (None, 2336)              0         
_________________________________________________________________
dense (Dense)                (None, 46)                107502    
=================================================================
Total params: 332,856
Trainable params: 332,856
Non-trainable params: 0
_________________________________________________________________
def train_model(model,ds_train,ds_valid,epochs):

    for epoch in tf.range(1,epochs+1):
        model.reset_metrics()
        
        # Lower the learning rate in the later stage of training
        if epoch == 5:
            model.optimizer.lr.assign(model.optimizer.lr/2.0)
            tf.print("Lowering optimizer Learning Rate...\n\n")
        
        for x, y in ds_train:
            train_result = model.train_on_batch(x, y)

        # reset_metrics=False accumulates metrics over the whole validation set
        for x, y in ds_valid:
            valid_result = model.test_on_batch(x, y,reset_metrics=False)
            
        if epoch%1 ==0:
            printbar()
            tf.print("epoch = ",epoch)
            print("train:",dict(zip(model.metrics_names,train_result)))
            print("valid:",dict(zip(model.metrics_names,valid_result)))
            print("")
train_model(model,ds_train,ds_test,10)
================================================================================09:18:28
epoch =  1
train: {'loss': 1.8016982, 'sparse_categorical_accuracy': 0.54545456, 'sparse_top_k_categorical_accuracy': 0.6363636}
valid: {'loss': 1.6246095, 'sparse_categorical_accuracy': 0.5765806, 'sparse_top_k_categorical_accuracy': 0.76268923}

================================================================================09:18:38
epoch =  2
train: {'loss': 1.7089436, 'sparse_categorical_accuracy': 0.59090906, 'sparse_top_k_categorical_accuracy': 0.6363636}
valid: {'loss': 1.610321, 'sparse_categorical_accuracy': 0.6046305, 'sparse_top_k_categorical_accuracy': 0.7934105}

================================================================================09:18:49
epoch =  3
train: {'loss': 1.5006503, 'sparse_categorical_accuracy': 0.54545456, 'sparse_top_k_categorical_accuracy': 0.8181818}
valid: {'loss': 1.9378943, 'sparse_categorical_accuracy': 0.6251113, 'sparse_top_k_categorical_accuracy': 0.80142474}

================================================================================09:18:59
epoch =  4
train: {'loss': 1.2124836, 'sparse_categorical_accuracy': 0.6363636, 'sparse_top_k_categorical_accuracy': 0.8181818}
valid: {'loss': 2.4395564, 'sparse_categorical_accuracy': 0.6246661, 'sparse_top_k_categorical_accuracy': 0.8009795}

Lowering optimizer Learning Rate...


================================================================================09:19:08
epoch =  5
train: {'loss': 0.8375286, 'sparse_categorical_accuracy': 0.6818182, 'sparse_top_k_categorical_accuracy': 0.95454544}
valid: {'loss': 2.9165814, 'sparse_categorical_accuracy': 0.63178986, 'sparse_top_k_categorical_accuracy': 0.7987533}

================================================================================09:19:17
epoch =  6
train: {'loss': 0.66361845, 'sparse_categorical_accuracy': 0.6818182, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 3.168786, 'sparse_categorical_accuracy': 0.6246661, 'sparse_top_k_categorical_accuracy': 0.7956367}

================================================================================09:19:26
epoch =  7
train: {'loss': 0.50838065, 'sparse_categorical_accuracy': 0.77272725, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 3.3748772, 'sparse_categorical_accuracy': 0.626447, 'sparse_top_k_categorical_accuracy': 0.7987533}

================================================================================09:19:35
epoch =  8
train: {'loss': 0.40919036, 'sparse_categorical_accuracy': 0.77272725, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 3.5492792, 'sparse_categorical_accuracy': 0.62422085, 'sparse_top_k_categorical_accuracy': 0.8009795}

================================================================================09:19:44
epoch =  9
train: {'loss': 0.35043362, 'sparse_categorical_accuracy': 0.8636364, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 3.69528, 'sparse_categorical_accuracy': 0.6202137, 'sparse_top_k_categorical_accuracy': 0.8032057}

================================================================================09:19:54
epoch =  10
train: {'loss': 0.30596277, 'sparse_categorical_accuracy': 0.8636364, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 3.799032, 'sparse_categorical_accuracy': 0.6121995, 'sparse_top_k_categorical_accuracy': 0.8036509}

3. Custom training loop

A custom training loop offers the greatest flexibility: there is no need to compile the model at all; you use the optimizer to backpropagate the loss and update the parameters directly.
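
The core pattern, which the full example below implements, is to record the forward pass on a tf.GradientTape, differentiate the loss with respect to the trainable variables, and apply the gradients with the optimizer:

with tf.GradientTape() as tape:
    predictions = model(features, training = True)
    loss = loss_func(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))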

tf.keras.backend.clear_session()

def create_model():
    
    model = models.Sequential()

    model.add(layers.Embedding(MAX_WORDS,7,input_length=MAX_LEN))
    model.add(layers.Conv1D(filters = 64,kernel_size = 5,activation = "relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Conv1D(filters = 32,kernel_size = 3,activation = "relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(CAT_NUM,activation = "softmax"))
    return(model)

model = create_model()
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, 300, 7)            216874    
_________________________________________________________________
conv1d (Conv1D)              (None, 296, 64)           2304      
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 148, 64)           0         
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 146, 32)           6176      
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 73, 32)            0         
_________________________________________________________________
flatten (Flatten)            (None, 2336)              0         
_________________________________________________________________
dense (Dense)                (None, 46)                107502    
=================================================================
Total params: 332,856
Trainable params: 332,856
Non-trainable params: 0
_________________________________________________________________
optimizer = optimizers.Nadam()
loss_func = losses.SparseCategoricalCrossentropy()

# Mean accumulates an epoch-average of the per-batch losses
train_loss = metrics.Mean(name='train_loss')
train_metric = metrics.SparseCategoricalAccuracy(name='train_accuracy')

valid_loss = metrics.Mean(name='valid_loss')
valid_metric = metrics.SparseCategoricalAccuracy(name='valid_accuracy')

@tf.function
def train_step(model, features, labels):
    # Forward pass recorded on the tape, then backprop and apply gradients
    with tf.GradientTape() as tape:
        predictions = model(features,training = True)
        loss = loss_func(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_metric.update_state(labels, predictions)
    

@tf.function
def valid_step(model, features, labels):
    # Forward pass only; accumulate validation loss and accuracy
    predictions = model(features)
    batch_loss = loss_func(labels, predictions)
    valid_loss.update_state(batch_loss)
    valid_metric.update_state(labels, predictions)
    

def train_model(model,ds_train,ds_valid,epochs):
    for epoch in tf.range(1,epochs+1):
        
        for features, labels in ds_train:
            train_step(model,features,labels)

        for features, labels in ds_valid:
            valid_step(model,features,labels)

        logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}'
        
        if epoch%1 ==0:
            printbar()
            tf.print(tf.strings.format(logs,
            (epoch,train_loss.result(),train_metric.result(),valid_loss.result(),valid_metric.result())))
            tf.print("")
            
        train_loss.reset_states()
        valid_loss.reset_states()
        train_metric.reset_states()
        valid_metric.reset_states()

train_model(model,ds_train,ds_test,10)
================================================================================09:20:04
Epoch=1,Loss:2.03290606,Accuracy:0.454353154,Valid Loss:1.69233954,Valid Accuracy:0.558325887

================================================================================09:20:14
Epoch=2,Loss:1.49616826,Accuracy:0.613226473,Valid Loss:1.52313662,Valid Accuracy:0.604185224

================================================================================09:20:22
Epoch=3,Loss:1.22066784,Accuracy:0.680806041,Valid Loss:1.51799047,Valid Accuracy:0.627337515

================================================================================09:20:33
Epoch=4,Loss:0.945678711,Accuracy:0.749944329,Valid Loss:1.65234017,Valid Accuracy:0.627337515

================================================================================09:20:44
Epoch=5,Loss:0.678333282,Accuracy:0.822533965,Valid Loss:1.8622793,Valid Accuracy:0.621549428

================================================================================09:20:55
Epoch=6,Loss:0.483631164,Accuracy:0.882208884,Valid Loss:2.06073833,Valid Accuracy:0.623775601

================================================================================09:21:04
Epoch=7,Loss:0.371374488,Accuracy:0.912714303,Valid Loss:2.21256471,Valid Accuracy:0.629118443

================================================================================09:21:13
Epoch=8,Loss:0.305030555,Accuracy:0.927410364,Valid Loss:2.36870408,Valid Accuracy:0.63223511

================================================================================09:21:22
Epoch=9,Loss:0.262721539,Accuracy:0.936317086,Valid Loss:2.50547385,Valid Accuracy:0.630454123

================================================================================09:21:32
Epoch=10,Loss:0.234934,Accuracy:0.941104412,Valid Loss:2.62294984,Valid Accuracy:0.626892269

Reposted from blog.csdn.net/Elenstone/article/details/105805104