举例如下:
#!/usr/bin/env python # -*- coding: utf-8 -*- import numpy as np import tensorflow as tf input_data=np.array([[1,0],[0,1]]) embedding=np.array([[1,2,3],[4,5,6]]) result = tf.nn.embedding_lookup(embedding, input_data) with tf.Session() as sess: init = tf.initialize_all_variables() sess.run(init) print(sess.run(result))
输出
[[[4 5 6] [1 2 3]] [[1 2 3] [4 5 6]]]
可以看出:将输入数据input_data在embedding矩阵中查找,得出新矩阵inputs,查找规则为
首先看矩阵
input_data是多少行的, 若是m=2行矩阵,查找结果矩阵inputs就是m=2个(维度的)
再看input_data中每一行,比如第一行为[1,0],意味着查找嵌入矩阵embedding中的第一行和第0行元素,存入结果矩阵results的第一个矩阵中