tensorflow2中的tf.data.Dataset.from_tensor_slices()

tf.data.Dataset.from_tensor_slices()函数的参数是tensor。
该函数的作用是接收tensor,对tensor的第一维度进行切分,并返回一个表示该tensor的切片数据集
以minist训练集为例:
x的shape为(60000,28,28),将x作为参数传递给tf.data.Dataset.from_tensor_slices(),
将返回一个含有60000个切片的数据集,每个切片为 28*28 的图像(但数据集不知道里面有多少个切片)。
代码如下:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets

(x, y),  _ = datasets.mnist.load_data()
x = tf.convert_to_tensor(x, dtype=tf.float32)/255.
print(x.shape)

train_db = tf.data.Dataset.from_tensor_slices(x)
print(train_db)

输出的结果为:
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/coolyuan/article/details/104203167