Implementing K-Fold Cross-Validation on the MNIST Dataset

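Basic idea
K-fold cross-validation splits the dataset into k mutually exclusive subsets (usually of equal size). Each subset serves as the validation set exactly once, while the remaining k-1 subsets together form the training set. The model trained in each round is evaluated on that round's validation set, and these results are used to estimate the model's performance.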
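
The partitioning described above can be sketched in plain Python, independent of any framework. The helper name `k_fold_splits` and its `seed` parameter are illustrative assumptions, not part of the original post; the last fold absorbs any remainder when the sample count is not divisible by k.

```python
import random


def k_fold_splits(n_samples, k, seed=0):
    """Yield (train_indices, val_indices) once for each of the k folds."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # shuffle once, before splitting
    fold_size = n_samples // k
    for i in range(k):
        start = i * fold_size
        # The last fold takes any leftover samples.
        end = n_samples if i == k - 1 else start + fold_size
        val = indices[start:end]                  # this fold validates
        train = indices[:start] + indices[end:]   # the other k-1 folds train
        yield train, val


for fold, (train, val) in enumerate(k_fold_splits(10, 5)):
    print(fold, 'train size:', len(train), 'val size:', len(val))
```

Every index lands in exactly one validation fold, and the training and validation indices of a given fold never overlap.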

Functions used

  1. tf.range()

tf.range(start, limit, delta=1, dtype=None, name='range')

Creates a sequence of numbers that starts at start and extends in increments of delta up to, but not including, limit.
For example:

a = tf.range(0, 10)
tf.Tensor([0 1 2 3 4 5 6 7 8 9], shape=(10,), dtype=int32)
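
As a small supplementary sketch of the delta and limit behavior described above (the variable names are illustrative only):

```python
import tensorflow as tf

# A delta of 2 steps through the even numbers; the limit 10 itself is excluded.
evens = tf.range(0, 10, delta=2)
print(evens)

# With a single argument, that argument is the limit and start defaults to 0.
first_five = tf.range(5)
print(first_five)
```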

  2. tf.random.shuffle()
tf.random.shuffle(
    value,
    seed=None,
    name=None
)

Randomly shuffles the tensor value along its first dimension.
For example:

a = tf.random.shuffle(a)
tf.Tensor([4 3 7 5 9 8 6 1 0 2], shape=(10,), dtype=int32)
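
To illustrate what "along the first dimension" means for a tensor with more than one axis, here is a minimal sketch on a small 4x3 matrix (the names `m` and `shuffled` are illustrative). The rows change order, but the contents of each row stay intact.

```python
import tensorflow as tf

# Rows are [0 1 2], [3 4 5], [6 7 8], [9 10 11].
m = tf.reshape(tf.range(0, 12), [4, 3])

# Only the order of the rows is permuted; each row keeps its three values.
shuffled = tf.random.shuffle(m)
print(shuffled)
```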

  3. tf.gather()
    Extracts the slices of a tensor at the positions given by an index array.
    For example:
index = tf.range(0, 2)  # [0, 1]
x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(tf.gather(x, index))

Output:
[[1 2 3]
 [4 5 6]]

With these three functions, the elements of a dataset can be randomly shuffled and then split.
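
The following sketch combines the three functions on toy data to show why one shuffled index is applied to both features and labels: gathering with the same index keeps each sample paired with its label. The toy tensors and the 8/2 split are assumptions made for illustration.

```python
import tensorflow as tf

# Toy data standing in for images and labels: row i is [2i, 2i + 1], label is i.
x = tf.reshape(tf.range(0, 20), [10, 2])
y = tf.range(0, 10)

# One shuffled index, reused for x and y so that pairs stay aligned.
index = tf.random.shuffle(tf.range(0, 10))
x_train, y_train = tf.gather(x, index[:8]), tf.gather(y, index[:8])
x_val, y_val = tf.gather(x, index[8:]), tf.gather(y, index[8:])

print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
# Check alignment: the first feature of sample i is always 2 * i.
print(bool(tf.reduce_all(x_train[:, 0] == 2 * y_train)))
```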

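Final code
The following uses 10-fold cross-validation as an example. The code performs a single split of the 60,000 training images (54,000 for training, 6,000 for validation), which corresponds to one of the 10 folds; full cross-validation repeats training and validation once per fold.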

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
from tensorflow.keras import datasets, layers, Sequential, optimizers

# load data
(x, y), (x_test, y_test) = datasets.mnist.load_data()

# build datasets
def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [-1, 28 * 28])
    y = tf.cast(y, dtype=tf.int64)
    y = tf.one_hot(y, depth=10)
    return x, y


print('datasets:', x.shape, y.shape, x_test.shape, y_test.shape)

index = tf.range(0, 60000)
index = tf.random.shuffle(index)
x_train, y_train = tf.gather(x, index[:54000]), tf.gather(y, index[:54000])  # 60000 * 9/10
x_val, y_val = tf.gather(x, index[-6000:]), tf.gather(y, index[-6000:])

# print the shapes of training dataset and validation dataset
print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)

batchsz = 128
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# shuffle before batching so that individual samples, not whole batches, are shuffled
db_train = db_train.shuffle(54000).batch(batchsz).map(preprocess)

db_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
# validation data does not need shuffling
db_val = db_val.batch(batchsz).map(preprocess)

db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsz).map(preprocess)

# print a sample in training dataset
sample = next(iter(db_train))
print('sample shape:', sample[0].shape, sample[1].shape)

# build network
network = Sequential([
    layers.Dense(256, activation='relu'),  # [b, 784] => [b, 256]
    layers.Dense(128, activation='relu'),  # [b, 256] => [b, 128]
    layers.Dense(64, activation='relu'),  # [b, 128] => [b, 64]
    layers.Dense(32, activation='relu'),  # [b, 64] => [b, 32]
    layers.Dense(10, )  # [b, 32] => [b, 10]
])

network.build(input_shape=[None, 28 * 28])
network.summary()

network.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
network.fit(db_train, epochs=10, validation_data=db_val, validation_freq=1)

# print test accuracy
print('test accuracy:')
network.evaluate(db_test)
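
The code above trains on one split only. A minimal sketch of the full loop, assuming the same network layout and preprocessing, might look like the following. The helper names `build_network`, `make_dataset`, and `k_fold_cross_validation` are hypothetical and not from the original post; a fresh model is built for every fold so that no fold benefits from weights trained on its own validation data.

```python
import tensorflow as tf
from tensorflow.keras import layers, Sequential, optimizers


def build_network():
    # Same layer sizes as the network above; called once per fold.
    return Sequential([
        layers.Dense(256, activation='relu'),
        layers.Dense(128, activation='relu'),
        layers.Dense(64, activation='relu'),
        layers.Dense(32, activation='relu'),
        layers.Dense(10),
    ])


def make_dataset(x, y, batch_size, training):
    # Scale pixels to [0, 1], flatten images, one-hot encode labels.
    x = tf.reshape(tf.cast(x, tf.float32) / 255., [-1, 28 * 28])
    y = tf.one_hot(tf.cast(y, tf.int32), depth=10)
    db = tf.data.Dataset.from_tensor_slices((x, y))
    if training:
        db = db.shuffle(buffer_size=x.shape[0])
    return db.batch(batch_size)


def k_fold_cross_validation(x, y, k=10, epochs=10, batch_size=128):
    """Return the validation accuracy of each of the k folds."""
    n = x.shape[0]
    index = tf.random.shuffle(tf.range(n))  # shuffle once, then slice folds
    fold_size = n // k
    accuracies = []
    for i in range(k):
        start = i * fold_size
        end = n if i == k - 1 else start + fold_size  # last fold takes the rest
        val_idx = index[start:end]
        train_idx = tf.concat([index[:start], index[end:]], axis=0)
        db_train = make_dataset(tf.gather(x, train_idx),
                                tf.gather(y, train_idx), batch_size, True)
        db_val = make_dataset(tf.gather(x, val_idx),
                              tf.gather(y, val_idx), batch_size, False)
        network = build_network()
        network.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                        loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                        metrics=['accuracy'])
        network.fit(db_train, epochs=epochs, verbose=0)
        _, acc = network.evaluate(db_val, verbose=0)
        accuracies.append(float(acc))
    return accuracies
```

Calling k_fold_cross_validation(x, y) with the arrays returned by datasets.mnist.load_data() yields ten validation accuracies; their mean is the cross-validated estimate of model performance. The test set is left untouched until a final evaluation.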





Reposted from blog.csdn.net/coolyuan/article/details/104276183