记录 之 tensorflow函数:tf.data.Dataset.from_tensor_slices

tf.data.Dataset.from_tensor_slices(),是常见的数据处理函数,它的作用是将给定的元组(turple)、列表(list)、张量(tensor)等特征进行特征切片。切片的范围是从最外层维度开始的。
更具体的,假设我们有一组特征集合(features),以及这组数据集合所对应的标签集合(labels),那么我们如何将每个数据与其对应的标签进行组合,构成一个个完整训练数据集合([feature_1, label_1],[feature_2, label_2],........).讲道理,tf.data.Dataset.from_tensor_slices函数就是完成这个需求。

例:

import tensorflow as tf

a = tf.random_uniform((4,3))
b = tf.random_uniform((4,1))
data1 = tf.data.Dataset.from_tensor_slices((a,b))
data2 = tf.data.Dataset.from_tensor_slices(a)
print(data1)
print(data2)

>>> <DatasetV1Adapter shapes: ((3,), (1,)), types: (tf.float32, tf.float32)>

>>> <DatasetV1Adapter shapes: (3,), types: tf.float32>

我们可以看到返回值是一个DatasetV1Adapter,这是一个数据迭代器。

猜你喜欢

转载自blog.csdn.net/qq_41368074/article/details/110006719
今日推荐