tf.nn.embedding_lookup()的用法

举例如下:
 
 
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf 
input_data=np.array([[1,0],[0,1]])
embedding=np.array([[1,2,3],[4,5,6]])
result = tf.nn.embedding_lookup(embedding, input_data)
with tf.Session() as sess:
    init = tf.initialize_all_variables()
    sess.run(init)
    print(sess.run(result))

输出
[[[4 5 6]
  [1 2 3]]

 [[1 2 3]
  [4 5 6]]]

可以看出:将输入数据input_data在embedding矩阵中查找,得出新矩阵inputs,查找规则为
首先看矩阵 input_data是多少行的, 若是m=2行矩阵,查找结果矩阵inputs就是m=2个(维度的)
再看input_data中每一行,比如第一行为[1,0],意味着查找嵌入矩阵embedding中的第一行和第0行元素,存入结果矩阵results的第一个矩阵中

猜你喜欢

转载自blog.csdn.net/weixin_38145317/article/details/79410395