最简单的模型是tf.keras.Sequential
模型,可以理解为神经网络层的简单堆叠,可以完成一些简单的分类任务,但是却无法表示任意模型。通常将使用函数式API
来构建更加复杂的模型。本文在这里只简单总结序列模型的使用。
1. 导入相关依赖
import tensorflow as tf
2. 定义序列模型
model = tf.keras.Sequential()
3. 添加神经网络层
model.add(tf.keras.layers.Dense(units=32, input_shape=(16,), activation="relu"))
model.add(tf.keras.layers.Dense(units=32, input_shape=(16,),
activation=tf.keras.activations.relu))
model.add(tf.keras.layers.Dense(units=32, input_shape=(16,), activation="relu",
kernel_regularizer=tf.keras.regularizers.l2(l=0.01)))
input arrays of shape (, 16)
and output arrays of shape (, 32)
Dense
层,表示全连接层。其中的参数解释如下:
Arguments
:
units
: 正整数,表示输出的维度。
activation
: 激活函数,默认情况下,系统不会应用任何激活函数,即保持原样。
use_bias
: Boolean
, 指定该层是否使用偏置向量。
kernel_initializer
: 权重矩阵的初始化方式指定,默认是"glorot_uniform
" 初始化器。
bias_initializer
: 偏置向量的初始化方式指定,默认是"zeros
" 初始化器。
kernel_regularizer
: 权重矩阵的正则化方案,默认是None
,即没有正则化。
bias_regularizer
: 偏置向量的正则化方案,默认是None
,即没有正则化。
4. 配置该模型的学习流程
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
loss=tf.keras.losses.categorical_crossentropy,
metrics=[tf.keras.metrics.categorical_accuracy])
model.compile(
optimizer='rmsprop',
loss=None,
metrics=None,
loss_weights=None,
)
optimizer
: 优化器的字符串,或者tf.keras.optimizers
优化器实例对象。
loss
: 优化函数的字符串,或者是tf.losses.Loss
实例对象。
metrics
: 评估模型训练和测试的度量,可以是一个字符串,也可以是字符串列表,如:metrics=['accuracy', 'mse']
,同样也可以是一个对应的实例对象。
5. 开始训练
可以输入的数据类型:
- Numpy
- BatchDataset
5.1 Numpy
model.fit(numpy_train_x, numpy_train_y, epochs=10, batch_size=100,
validation_data=(numpy_val_x, numpy_val_y))
numpy_*
也就是numpy.ndarray
格式的数据
5.2 tf.data格式
dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y)) # 特征和标签的配对
dataset = dataset.batch(32) # 划分每32个为一个batch
dataset = dataset.repeat() # 重复数据集,无数次
val_dataset = tf.data.Dataset.from_tensor_slices((val_x, val_y))
val_dataset = val_dataset.batch(32)
val_dataset = val_dataset.repeat()
model.fit(dataset, epochs=10, steps_per_epoch=30,
validation_data=val_dataset, validation_steps=3)
输入特征与对应标签的配对,使用tf.data.Dataset.from_tensor_slices((输入特征, 标签))
,Numpy
和Tensor
格式的都可以使用该语句读入数据。
dataset.repeat()
表示重复数据集,当repeat()
参数为空时,意思是重复无数遍,永远不会有读取不到数据batch
的情况。因为我们通常需要在数据集上跑多轮。
6. 评估
model.evaluate(numpy_test_x, test_y, batch_size=32)
# 或者
model.evaluate(batch_test_data, steps=30) # test_data.batch(32).repeat()
7. 预测
result = model.predict(batch_test_data, batch_size=32)