数据集对象的建立
tf.data.Dataset
由一系列的可迭代访问的元素(element)组成,每个元素包含一个或多个张量。比如说,对于一个由图像组成的数据集,每个元素可以是一个形状为 长×宽×通道数
的图片张量,也可以是由图片张量和图片标签张量组成的元组(Tuple)
最基础的建立 tf.data.Dataset
的方法是使用 tf.data.Dataset.from_tensor_slices()
,适用于数据量较小(能够整个装进内存)的情况
import tensorflow as tf
import numpy as np

# Build a tf.data.Dataset from in-memory tensors with from_tensor_slices().
# Each (x, y) pair becomes one element of the dataset; this approach suits
# datasets small enough to fit entirely in memory.
X = tf.constant([2013, 2014, 2015, 2016, 2017])
Y = tf.constant([12000, 14000, 15000, 16500, 17500])
# NumPy arrays work identically:
# X = np.array([2013, 2014, 2015, 2016, 2017])
# Y = np.array([12000, 14000, 15000, 16500, 17500])
dataset = tf.data.Dataset.from_tensor_slices((X, Y))
for x, y in dataset:
    print(x.numpy(), y.numpy())
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Load MNIST, scale pixels to [0, 1] and add a trailing channel axis so the
# training images become a [60000, 28, 28, 1] float32 array.
(train_data, train_label), (_, _) = tf.keras.datasets.mnist.load_data()
train_data = np.expand_dims(train_data.astype(np.float32) / 255.0, axis=-1)
mnist_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))
# NOTE: this displays every training image one window at a time (60000 of them).
for image, label in mnist_dataset:
    plt.title(label.numpy())
    plt.imshow(image.numpy()[:, :, 0])
    plt.show()
数据集对象的预处理
tf.data.Dataset
类为我们提供了多种数据集预处理方法。最常用的如:
-
Dataset.map(f)
:对数据集中的每个元素应用函数f
,得到一个新的数据集(这部分往往结合tf.io
进行读写和解码文件,tf.image
进行图像处理); -
Dataset.shuffle(buffer_size)
:将数据集打乱(设定一个固定大小的缓冲区(Buffer),取出前buffer_size
个元素放入,并从缓冲区中随机采样,采样后的数据用后续数据替换); -
Dataset.batch(batch_size)
:将数据集分成批次,即对每batch_size
个元素,使用tf.stack()
在第 0 维合并,成为一个元素;
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Load MNIST and normalize to [60000, 28, 28, 1] float32 in [0, 1].
(train_data, train_label), (_, _) = tf.keras.datasets.mnist.load_data()
train_data = np.expand_dims(train_data.astype(np.float32) / 255.0, axis=-1)  # [60000, 28, 28, 1]
mnist_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))


def rot90(image, label):
    """Rotate the image 90 degrees counter-clockwise; the label is unchanged."""
    image = tf.image.rot90(image)
    return image, label


# Dataset.map applies rot90 to every (image, label) element.
mnist_dataset = mnist_dataset.map(rot90)
# Shuffle with a 10000-element buffer, then group elements into batches of 4
# (batch stacks batch_size elements along a new leading axis).
mnist_dataset = mnist_dataset.shuffle(buffer_size=10000).batch(4)
for images, labels in mnist_dataset:
    fig, axs = plt.subplots(1, 4)
    for i in range(4):
        axs[i].set_title(labels.numpy()[i])
        axs[i].imshow(images.numpy()[i, :, :, 0])
    plt.show()
使用 tf.data
的并行化策略提高训练流程效率
tf.data
的数据集对象为我们提供了 Dataset.prefetch()
方法,使得我们可以让数据集对象 Dataset
在训练时预取出若干个元素,使得在 GPU 训练的同时 CPU 可以准备数据,从而提升训练流程的效率
# Prefetch lets the CPU prepare upcoming elements while the accelerator trains;
# AUTOTUNE picks the buffer size automatically.
mnist_dataset = mnist_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
与此类似, Dataset.map()
也可以利用多核 CPU 资源,通过 num_parallel_calls 参数并行化地对数据项进行变换,从而提高效率
# Run the rot90 transformation with 2 parallel calls; tf.data.experimental.AUTOTUNE
# could be passed instead to choose the parallelism automatically.
mnist_dataset = mnist_dataset.map(map_func=rot90, num_parallel_calls=2)
数据集元素的获取与使用
# Illustrative pseudo-code (not runnable): iterating a dataset built from a
# tuple of tensors yields one tensor per component on each step.
dataset = tf.data.Dataset.from_tensor_slices((A, B, C, ...))
for a, b, c, ... in dataset:
# operate on tensors a, b, c, ... — e.g. feed them to a model for training
import tensorflow as tf
import os
import matplotlib.pyplot as plt
# Training hyper-parameters.
num_epochs = 1
batch_size = 4
learning_rate = 0.001
# Expected directory layout: ./catsdogs/{train,test}/{Cat,Dog}/ with image files.
data_dir = './catsdogs'
train_cats_dir = data_dir + '/train/Cat/'
train_dogs_dir = data_dir + '/train/Dog/'
test_cats_dir = data_dir + '/test/Cat/'
test_dogs_dir = data_dir + '/test/Dog/'
def _decode_and_resize(filename, label):
    """Read the image at `filename`, decode, resize to 256x256 and scale to [0, 1].

    Returns (image, label) where image is a float tensor of shape [256, 256, C].
    """
    image_string = tf.io.read_file(filename)  # raw file bytes
    # decode_image handles JPEG/PNG/GIF/BMP; expand_animations=False guarantees
    # a 3-D tensor even for animated GIFs.
    image_decoded = tf.image.decode_image(image_string, expand_animations=False)
    image_resized = tf.image.resize(image_decoded, [256, 256]) / 255.0
    return image_resized, label
if __name__ == '__main__':
    # --- Build the training dataset --------------------------------------
    train_cat_filenames = tf.constant([train_cats_dir + filename for filename in os.listdir(train_cats_dir)])
    train_dog_filenames = tf.constant([train_dogs_dir + filename for filename in os.listdir(train_dogs_dir)])
    train_filenames = tf.concat([train_cat_filenames, train_dog_filenames], axis=-1)
    # Labels: cat -> 0, dog -> 1, aligned with the concatenated filename list.
    train_labels = tf.concat([
        tf.zeros(train_cat_filenames.shape, dtype=tf.int32),
        tf.ones(train_dog_filenames.shape, dtype=tf.int32)],
        axis=-1)
    train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))
    # Decode/resize in parallel; AUTOTUNE picks the degree of parallelism.
    train_dataset = train_dataset.map(
        map_func=_decode_and_resize,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Shuffle: keep a 1000-element buffer, sample from it at random, refill
    # with subsequent elements.
    train_dataset = train_dataset.shuffle(buffer_size=1000)
    train_dataset = train_dataset.batch(batch_size)
    # Let the CPU prepare upcoming batches while the accelerator trains.
    train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # Simple CNN classifier: two conv/pool stages followed by two dense
    # layers with a softmax over the 2 classes (cat/dog).
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(256, 256, 3)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(32, 5, activation='relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(2, activation='softmax')
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=[tf.keras.metrics.sparse_categorical_accuracy]
    )
    model.fit(train_dataset, epochs=num_epochs)

    # --- Build the test dataset and evaluate ------------------------------
    test_cat_filenames = tf.constant([test_cats_dir + filename for filename in os.listdir(test_cats_dir)])
    test_dog_filenames = tf.constant([test_dogs_dir + filename for filename in os.listdir(test_dogs_dir)])
    test_filenames = tf.concat([test_cat_filenames, test_dog_filenames], axis=-1)
    test_labels = tf.concat([
        tf.zeros(test_cat_filenames.shape, dtype=tf.int32),
        tf.ones(test_dog_filenames.shape, dtype=tf.int32)],
        axis=-1)
    test_dataset = tf.data.Dataset.from_tensor_slices((test_filenames, test_labels))
    test_dataset = test_dataset.map(_decode_and_resize)
    test_dataset = test_dataset.batch(batch_size)
    print(model.metrics_names)
    print(model.evaluate(test_dataset))
tensorflow读取图片
import tensorflow as tf
import os
import matplotlib.pyplot as plt


def decode_and_resize(filename):
    """Read a JPEG at `filename`, resize to 256x256 and scale pixels to [0, 1]."""
    image_string = tf.io.read_file(filename)            # raw file bytes
    image_decoded = tf.image.decode_jpeg(image_string)  # decode the JPEG
    image_resized = tf.image.resize(image_decoded, [256, 256]) / 255.0
    return image_resized


image = decode_and_resize('./catsdogs/test/Cat/85.jpg')
plt.imshow(image.numpy())
plt.show()
print(image)
参考文献: