TensorFlow 的tf.one_hot函数用法 TensorFlow 的tf.one_hot函数用法

TensorFlow 的tf.one_hot函数用法

tf.one_hot 函数
     
      
      

 
  
  
  1. one_hot(
  2. indices,
  3. depth,
  4. on_value= None,
  5. off_value= None,
  6. axis= None,
  7. dtype= None,
  8. name= None
  9. )

indices: 代表了on_value所在的索引,其他位置值为off_value。类型为tensor,其尺寸与depth共同决定输出tensor的尺寸。

depth:编码深度。

on_value & off_value为编码开闭值,缺省分别为1和0,indices指定的索引处为on_value值;

axis:编码的轴,分情况可取-1、0或-1、0、1,默认为-1

dtype:默认为 on_value 或 off_value的类型,若未提供on_value或off_value,则默认为tf.float32类型。

返回一个 one-hot 张量。 

索引中由索引表示的位置取值 on_value,而所有其他位置都取值 off_value。 

on_value 和 off_value必须具有匹配的数据类型。如果还提供了 dtype,则它们必须与 dtype 指定的数据类型相同。

如果未提供 on_value,则默认值将为 1,其类型为 dtype。 

如果未提供 off_value,则默认值为 0,其类型为 dtype。

 

假设如下:


 
  
  
  1. indices = [ 0, 2, - 1, 1]
  2. depth = 3
  3. on_value = 5.0
  4. off_value = 0.0
  5. axis = - 1

那么输出为 [4 x 3]:


 
  
  
  1. output =
  2. [ 5.0 0.0 0.0] // one_hot(0)
  3. [ 0.0 0.0 5.0] // one_hot(2)
  4. [ 0.0 0.0 0.0] // one_hot(-1)
  5. [ 0.0 5.0 0.0] // one_hot(1)
tf.one_hot 函数
     
  
  

Guess you like

Origin blog.csdn.net/qq_40837542/article/details/103934376