TensorFlow - Using tf.nn.embedding_lookup

  • Prototype: tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None, validate_indices=True, max_norm=None)
  • What tf.nn.embedding_lookup actually does is look up, in the given embedding data, the vector stored at each requested row index.
  • A quick look through code, covering both 1-D and multi-row ids:
    # -*- coding: utf-8 -*-
    import tensorflow as tf
    import numpy as np
    
    # Embedding matrix: 5 rows (ids 0..4), each a 3-dimensional vector.
    a = [[0.1, 0.2, 0.3], [1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3]]
    a = np.asarray(a)
    # dtype must be passed by keyword: positionally, the second argument
    # of tf.Variable is `trainable`, not the dtype.
    idx1 = tf.Variable([0, 2, 3, 1], dtype=tf.int32)                   # 1-D ids
    idx2 = tf.Variable([[0, 2, 3, 1], [4, 0, 2, 2]], dtype=tf.int32)   # 2-D ids
    out1 = tf.nn.embedding_lookup(a, idx1)
    out2 = tf.nn.embedding_lookup(a, idx2)
    init = tf.global_variables_initializer()
    
    with tf.Session() as sess:
        sess.run(init)
        print(sess.run(out1))
        print(out1)
        print('==================')
        print(sess.run(out2))
        print(out2)
  • Output:
    [[ 0.1  0.2  0.3]
     [ 2.1  2.2  2.3]
     [ 3.1  3.2  3.3]
     [ 1.1  1.2  1.3]]
    Tensor("embedding_lookup:0", shape=(4, 3), dtype=float64)
    ==================
    [[[ 0.1  0.2  0.3]
      [ 2.1  2.2  2.3]
      [ 3.1  3.2  3.3]
      [ 1.1  1.2  1.3]]
    
     [[ 4.1  4.2  4.3]
      [ 0.1  0.2  0.3]
      [ 2.1  2.2  2.3]
      [ 2.1  2.2  2.3]]]
    Tensor("embedding_lookup_1:0", shape=(2, 4, 3), dtype=float64)
  • Dimension discussion: the lookup gathers, from the embedding data, the vector at each index in ids and stacks the results. The output shape is always the shape of ids concatenated with the embedding's shape minus its first dimension, i.e. ids.shape + params.shape[1:]. It also follows that every value in ids must be no greater than the size of the embedding's first dimension minus one.
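For a single (unsharded) params tensor, this shape rule can be checked with plain NumPy fancy indexing, which gathers rows the same way the lookup does (a sketch for intuition, not TensorFlow's implementation):

```python
import numpy as np

# params[ids] gathers one row of params per entry of ids, so the
# result shape is ids.shape + params.shape[1:].
params = np.arange(15, dtype=np.float64).reshape(5, 3)  # 5 embeddings of size 3
ids = np.array([[0, 2, 3, 1], [4, 0, 2, 2]])            # shape (2, 4)

out = params[ids]
print(out.shape)  # (2, 4, 3) == ids.shape + params.shape[1:]
```

Every entry of ids must be a valid row index, i.e. strictly less than params.shape[0].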
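The partition_strategy='mod' argument in the prototype only matters when params is passed as a list of shards: id i is then assigned to shard i % num_shards, at offset i // num_shards within that shard. A minimal NumPy sketch of this assignment (mod_lookup and the shard arrays are illustrative names, not TensorFlow API):

```python
import numpy as np

# Illustrative sketch of the 'mod' partition strategy: id i lives in
# shard i % num_shards, at row i // num_shards of that shard.
def mod_lookup(shards, ids):
    num_shards = len(shards)
    return np.stack([shards[i % num_shards][i // num_shards] for i in ids])

# A 5-row embedding split into 2 shards: shard 0 holds the vectors for
# ids 0, 2, 4 and shard 1 holds the vectors for ids 1, 3.
shard0 = np.array([[0.0], [2.0], [4.0]])  # ids 0, 2, 4
shard1 = np.array([[1.0], [3.0]])         # ids 1, 3

print(mod_lookup([shard0, shard1], [3, 0, 4]))  # vectors for ids 3, 0, 4
```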

Reference: https://www.jianshu.com/p/ad88a0afa98f

Reprinted from www.cnblogs.com/Jesee/p/11445560.html