tensorflow学习笔记--embedding_lookup()用法

 

个人分类: tensorflowpython

所属专栏: Tensorflow修炼手册

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u013041398/article/details/60955847

embedding_lookup( )的用法 
关于tensorflow中embedding_lookup( )的用法,在Udacity的word2vec会涉及到,本文将通俗的进行解释。

首先看一段网上的简单代码:

#!/usr/bin/env/python
# coding=utf-8
import tensorflow as tf
import numpy as np

input_ids = tf.placeholder(dtype=tf.int32, shape=[None])

embedding = tf.Variable(np.identity(5, dtype=np.int32))
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print(embedding.eval())
print(sess.run(input_embedding, feed_dict={input_ids:[1, 2, 3, 0, 3, 2, 1]}))

代码中先使用palceholder定义了一个未知变量input_ids用于存储索引,和一个已知变量embedding,是一个5*5的对角矩阵。 
运行结果为:

embedding = [[1 0 0 0 0]
             [0 1 0 0 0]
             [0 0 1 0 0]
             [0 0 0 1 0]
             [0 0 0 0 1]]
input_embedding = [[0 1 0 0 0]
                   [0 0 1 0 0]
                   [0 0 0 1 0]
                   [1 0 0 0 0]
                   [0 0 0 1 0]
                   [0 0 1 0 0]
                   [0 1 0 0 0]]

简单的讲就是根据input_ids中的id,寻找embedding中的对应元素。比如,input_ids=[1,3,5],则找出embedding中下标为1,3,5的向量组成一个矩阵返回。

如果将input_ids改写成下面的格式:

input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
print(sess.run(input_embedding, feed_dict={input_ids:[[1, 2], [2, 1], [3, 3]]}))

输出结果就会变成如下的格式:

[[[0 1 0 0 0]
  [0 0 1 0 0]]
 [[0 0 1 0 0]
  [0 1 0 0 0]]
 [[0 0 0 1 0]
  [0 0 0 1 0]]]

对比上下两个结果不难发现,相当于在np.array中直接采用下标数组获取数据。需要注意的细节是返回的tensor的dtype和传入的被查询的tensor的dtype保持一致;和ids的dtype无关。

猜你喜欢

转载自blog.csdn.net/qq_41853536/article/details/82689068
今日推荐