tf.keras快速入门——序列模型

最简单的模型是tf.keras.Sequential模型,可以理解为神经网络层的简单堆叠,可以完成一些简单的分类任务,但是却无法表示任意模型。通常将使用函数式API来构建更加复杂的模型。本文在这里只简单总结序列模型的使用。

1. 导入相关依赖

import tensorflow as tf

2. 定义序列模型

model = tf.keras.Sequential()

3. 添加神经网络层

model.add(tf.keras.layers.Dense(units=32, input_shape=(16,), activation="relu"))
model.add(tf.keras.layers.Dense(units=32, input_shape=(16,), 
                                activation=tf.keras.activations.relu))
model.add(tf.keras.layers.Dense(units=32, input_shape=(16,), activation="relu", 
                                kernel_regularizer=tf.keras.regularizers.l2(l=0.01)))

input arrays of shape (, 16)
and output arrays of shape (
, 32)

Dense层,表示全连接层。其中的参数解释如下:
Arguments:
units: 正整数,表示输出的维度。
activation: 激活函数,默认情况下,系统不会应用任何激活函数,即保持原样。
use_bias: Boolean, 指定该层是否使用偏置向量。
kernel_initializer: 权重矩阵的初始化方式指定,默认是"glorot_uniform" 初始化器。
bias_initializer: 偏置向量的初始化方式指定,默认是"zeros" 初始化器。
kernel_regularizer: 权重矩阵的正则化方案,默认是None,即没有正则化。
bias_regularizer: 偏置向量的正则化方案,默认是None,即没有正则化。

4. 配置该模型的学习流程

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
             loss=tf.keras.losses.categorical_crossentropy,
             metrics=[tf.keras.metrics.categorical_accuracy])
model.compile(
    optimizer='rmsprop',
    loss=None,
    metrics=None,
    loss_weights=None,
)

optimizer: 优化器的字符串,或者tf.keras.optimizers优化器实例对象。
loss: 优化函数的字符串,或者是tf.losses.Loss 实例对象。
metrics: 评估模型训练和测试的度量,可以是一个字符串,也可以是字符串列表,如:metrics=['accuracy', 'mse'],同样也可以是一个对应的实例对象。

5. 开始训练

可以输入的数据类型:

  • Numpy
  • BatchDataset

5.1 Numpy

model.fit(numpy_train_x, numpy_train_y, epochs=10, batch_size=100,
          validation_data=(numpy_val_x, numpy_val_y))

numpy_*也就是numpy.ndarray格式的数据

5.2 tf.data格式

dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y)) # 特征和标签的配对
dataset = dataset.batch(32) # 划分每32个为一个batch
dataset = dataset.repeat() # 重复数据集,无数次
val_dataset = tf.data.Dataset.from_tensor_slices((val_x, val_y))
val_dataset = val_dataset.batch(32)
val_dataset = val_dataset.repeat()

model.fit(dataset, epochs=10, steps_per_epoch=30,
          validation_data=val_dataset, validation_steps=3)

输入特征与对应标签的配对,使用tf.data.Dataset.from_tensor_slices((输入特征, 标签))NumpyTensor格式的都可以使用该语句读入数据。
dataset.repeat()表示重复数据集,当repeat()参数为空时,意思是重复无数遍,永远不会有读取不到数据batch的情况。因为我们通常需要在数据集上跑多轮。

6. 评估

model.evaluate(numpy_test_x, test_y, batch_size=32)
# 或者
model.evaluate(batch_test_data, steps=30) # test_data.batch(32).repeat()

7. 预测

result = model.predict(batch_test_data, batch_size=32)

案例

鸢尾花分类
手写数字识别

猜你喜欢

转载自blog.csdn.net/qq_26460841/article/details/113543962