Tensorflow-tf.sparse_to_dense()

对tf.sparse_to_dense()的参数concated的理解:

batch_size = tf.size(labels)

labels = tf.expand_dims(labels, 1)

indices = tf.expand_dims(tf.range(0, batch_size, 1), 1)

concated = tf.concat([indices, labels], 1)

onehot_labels = tf.sparse_to_dense(concated, tf.stack([batch_size, NUM_CLASSES]), 1.0, 0.0)

========================================

实际上在tf.sparse_to_dense()的原函数定义中,传递的这个参数是indices,即索引值。tf.stack()位置处对应的参数是输出的tensor的shape,第三第四个参数为指定的索引的值和非指定的索引的值。

所以对于one_hot编码的输出:

[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]

 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]

 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]

 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]

 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]

 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]

 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]

 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

[0. 1. 0. 1. 0. 1. 0. 0. 0. 0.]

concated中的元素,就是上面这个array中非0元素的对应的索引,可以说是坐标

比如[0,1] 、[1,3]、[2,5]…等等

所以索引的第一个元素相当于是第几行,第二个元素为第几列。因为是one_hot编码,所以根据这个编码的特点,每一行中必须有且仅有一个元素非零,而对于列来说,未必每一列都有元素非零,也未必每一列都不止一个元素为非零。所以对于非零元素在整个array中的索引,行坐标是递增且每一行都照顾得到的。所以concated = tf.concat(tf.expand_dims(tf.range(0, batch_size, 1), 1), tf.expand_dims(labels, 1))

来得到所有非零元素在整个输出的array中的索引(坐标)

注意tf.expand_dims()函数,用于增加维度,只有增加维度后才能进行合适的tensor的tf.concat

猜你喜欢

转载自blog.csdn.net/weixin_39721347/article/details/86172611