tensorflow的one_hot()函数

       在机器学习中进行分类任务时,数据的label一般是一维的,无论是几分类,比如2分类,数据的label一般就是0和1的组合;3分类的时候就是0,1,2的组合等。以此类推,若是100分类,则需要100个变量标识各个label,为了简化计算和减少参数,一般的会将一维的label扩展为n维的label,比如2分类的问题用[1,0]表示label为0的类别,[0,1]表示label为1的类别;若是3分类的,则用

[1,0,0]

[0,1,0]

[0,0,1]分别表示label为0,1,2的类别。

而将一维标签转为n维标签,有多种算法可以实现,如果是使用google的深度学习框架tensorflow,则可以直接用函数one_hot()实现。

def one_hot(indices,
            depth,
            on_value=None,
            off_value=None,
            axis=None,
            dtype=None,
            name=None):
  """Returns a one-hot tensor.

在tensorflow中one_hot()定义如上,一共有七个参数,定义分别如下:

Args:
indices: A Tensor of indices.
depth: A scalar defining the depth of the one hot dimension.
on_value: A scalar defining the value to fill in output when indices[j] = i. (default: 1)
off_value: A scalar defining the value to fill in output when indices[j] != i. (default: 0)
axis: The axis to fill (default: -1, a new inner-most axis).
dtype: The data type of the output tensor.
name: A name for the operation (optional).

在七个参数中,indices就是需要进行one_hot编码的tensor,depth就是上面说到的n分类。而on_value和off_value分别为当前位置是某个分类时用哪个数字表示和不是某个分类时用哪个数字表示,默认分别用1和0表示。2分类和3分类的热编码程序如下:

import tensorflow as tf

a = [1,0,0,1,1]
b = tf.one_hot(a,2)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    c = sess.run(b)
    print(c)

output:
[[0. 1.]
 [1. 0.]
 [1. 0.]
 [0. 1.]
 [0. 1.]]
import tensorflow as tf

a = [1,0,2,1,2,0]
b = tf.one_hot(a,3)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    c = sess.run(b)
    print(c)

output:
[[0. 1. 0.]
 [1. 0. 0.]
 [0. 0. 1.]
 [0. 1. 0.]
 [0. 0. 1.]
 [1. 0. 0.]]

下图是将参数on_value和off_value分别赋值为10和5后的运行结果:

import tensorflow as tf

a = [1,0,2,1,2,0]
b = tf.one_hot(a,3,10,5)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    c = sess.run(b)
    print(c)

output:
[[ 5 10  5]
 [10  5  5]
 [ 5  5 10]
 [ 5 10  5]
 [ 5  5 10]
 [10  5  5]]

最后的三个参数中axis是用来定义从哪个维度进行one_hot编码的,默认是-1,dtype则是定义输出的tensor的数据类型,若是on_value、off_value、dtype都为赋值,则默认为dtype = tf.float32;参数name则是tensor的name。

import tensorflow as tf

a = [1,0,2,1,2,0]
b = tf.one_hot(a,3,10,5,0)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    c = sess.run(b)
    print(c)

output:
[[ 5 10  5  5  5 10]
 [10  5  5 10  5  5]
 [ 5  5 10  5 10  5]]

注意:axis的参数已赋值为0 ,即从默认值-1改为0

附one_hot()函数定义地址:https://tensorflow.google.cn/api_docs/python/tf/one_hot

猜你喜欢

转载自blog.csdn.net/Asunqingwen/article/details/82778592