TensorFlow 的tf.one_hot函数用法
发布于2018-08-20 09:42:03
tf.one_hot 函数
-
one_hot(
-
indices,
-
depth,
-
on_value=
None,
-
off_value=
None,
-
axis=
None,
-
dtype=
None,
-
name=
None
-
)
indices: 代表了on_value所在的索引,其他位置值为off_value。类型为tensor,其尺寸与depth共同决定输出tensor的尺寸。
depth:编码深度。
on_value & off_value为编码开闭值,缺省分别为1和0,indices指定的索引处为on_value值;
axis:编码的轴,分情况可取-1、0或-1、0、1,默认为-1
dtype:默认为 on_value 或 off_value的类型,若未提供on_value或off_value,则默认为tf.float32类型。
返回一个 one-hot 张量。
索引中由索引表示的位置取值 on_value,而所有其他位置都取值 off_value。
on_value 和 off_value必须具有匹配的数据类型。如果还提供了 dtype,则它们必须与 dtype 指定的数据类型相同。
如果未提供 on_value,则默认值将为 1,其类型为 dtype。
如果未提供 off_value,则默认值为 0,其类型为 dtype。
假设如下:
-
indices = [
0,
2, -
1,
1]
-
depth =
3
-
on_value =
5.0
-
off_value =
0.0
-
axis = -
1
那么输出为 [4 x 3]:
-
output =
-
[
5.0 0.0 0.0]
// one_hot(0)
-
[
0.0 0.0 5.0]
// one_hot(2)
-
[
0.0 0.0 0.0]
// one_hot(-1)
-
[
0.0 5.0 0.0]
// one_hot(1)
tf.one_hot 函数