tf.nn.embedding_lookup
tf.nn.embedding_lookup(
    params,
    ids,
    partition_strategy='mod',
    name=None,
    validate_indices=True,
    max_norm=None
)
Args:
- params: A single tensor representing the complete embedding table, or a list of p tensors that all have the same shape except for the first dimension, representing a sharded embedding table.
- ids: A Tensor of type int32 or int64 containing the ids to be looked up in params.
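- partition_strategy: A string specifying how ids are assigned to partitions; relevant only when params is a list of more than one tensor. Currently "div" and "mod" are supported. Default is "mod".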
- name: A name for the operation (optional).
- validate_indices: DEPRECATED.
- max_norm: If provided, each looked-up embedding is clipped so that its L2 norm does not exceed max_norm; embeddings whose norm is already smaller are left unchanged.
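To make the difference between the two partition strategies concrete, here is a minimal pure-Python sketch of how ids are assigned to partitions under "mod" and "div". The helper name partition_ids is hypothetical and not part of TensorFlow; it only mirrors the assignment rule as described in the TensorFlow documentation, not the library's actual implementation.

```python
def partition_ids(num_ids, num_partitions, strategy="mod"):
    """Return, for each partition, the list of ids assigned to it.

    Hypothetical helper for illustration only; not a TensorFlow function.
    """
    if strategy == "mod":
        # id i goes to partition i % num_partitions.
        partitions = [[] for _ in range(num_partitions)]
        for i in range(num_ids):
            partitions[i % num_partitions].append(i)
        return partitions
    if strategy == "div":
        # ids are assigned in contiguous blocks; when the ids do not divide
        # evenly, the first `extras` partitions receive one more id each.
        base, extras = divmod(num_ids, num_partitions)
        partitions, start = [], 0
        for p in range(num_partitions):
            size = base + (1 if p < extras else 0)
            partitions.append(list(range(start, start + size)))
            start += size
        return partitions
    raise ValueError("strategy must be 'mod' or 'div'")


mod_split = partition_ids(13, 5, "mod")
div_split = partition_ids(13, 5, "div")
print(mod_split)
print(div_split)
```

With 13 ids and 5 partitions, "mod" interleaves the ids across partitions, while "div" keeps consecutive ids together.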
Code:
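import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.random_uniform)
import numpy as np

# A 50x50 embedding table filled with random values in [0, 10).
embedding = tf.random_uniform(minval=0, maxval=10, shape=(50, 50))
# A 100x10 table whose row i holds the values 10*i .. 10*i + 9.
embed = np.arange(1000).reshape(100, 10)
ids = [2, 1, 0]

# Gather rows 2, 1 and 0 (in that order) from each table.
some_embedding = tf.nn.embedding_lookup(embedding, ids=ids)
some_embed = tf.nn.embedding_lookup(embed, ids=ids)

with tf.Session() as sess:
    print(sess.run(some_embedding))
    print(sess.run(some_embed))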
Result (the first array is random, so its values differ between runs):
[[1.9174099 1.4985561 8.763132 5.1513696 3.0472064 1.3083363
8.763195 9.585258 6.5263486 4.5616875 1.7619014 3.474176
6.629846 7.9499936 3.3529973 3.8903248 3.5784101 2.235669
5.382924 3.5264575 6.6558695 5.08219 8.104395 4.7079873
6.942602 8.672055 9.369327 2.9697776 2.1853507 7.4739685
7.0480814 6.7367516 2.4502242 3.1823123 1.7625916 9.54441
5.154147 0.1355648 7.0668087 0.90946436 1.8672597 6.3007355
0.8664024 3.5557115 2.82979 1.6570354 5.1281223 4.6919146
2.8726697 4.4800186 ]
[7.004551 0.11087418 4.191972 1.8102491 5.845166 5.9069967
5.808556 7.354327 5.4576874 5.7963667 6.093507 9.707213
4.706266 0.89920044 7.6548386 4.213927 3.8301349 9.722362
3.239541 2.0103395 6.5536427 0.5413759 6.9744706 4.4435463
7.05449 8.396385 9.050558 8.112217 9.726019 3.3593094
1.8417215 5.391085 0.22029877 3.4840071 8.741225 1.8167889
2.7843773 7.633685 5.02874 4.2803707 8.34441 2.6564145
8.721051 6.1455345 0.05278707 5.921961 0.7731664 4.822198
0.674752 7.203877 ]
[4.561104 4.972656 3.8639069 8.641497 5.19176 6.2906036
8.288484 1.6396677 4.090029 0.7623017 8.548621 8.682202
6.331664 5.1811004 4.39334 9.6758795 8.779367 6.086396
1.9097066 3.6502922 4.1653624 9.572655 1.2448967 2.2894263
1.0316873 9.54537 4.5649695 8.553878 7.0591154 6.1101866
3.672725 3.2488763 8.497752 9.083318 5.6434383 0.73266864
5.2025137 6.840452 7.8819666 7.0142117 5.364611 7.310052
9.392939 7.888708 5.781497 1.2782669 5.573305 1.9477713
2.99865 2.6828694 ]]
[[20 21 22 23 24 25 26 27 28 29]
[10 11 12 13 14 15 16 17 18 19]
[ 0 1 2 3 4 5 6 7 8 9]]
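
The second array shows that, for a single unpartitioned table, the lookup simply gathers the rows named by ids, in the order given. A rough NumPy sketch of that behaviour, including the max_norm clipping, might look like the following. The function lookup is a hypothetical stand-in written for illustration; it is not how TensorFlow implements the op, and the small epsilon guarding against division by zero is an assumption of this sketch.

```python
import numpy as np


def lookup(table, ids, max_norm=None):
    """Gather rows of `table` by `ids`; optionally clip each row's L2 norm.

    Hypothetical NumPy stand-in for illustration, not the TensorFlow op.
    """
    rows = np.asarray(table, dtype=np.float64)[np.asarray(ids)]
    if max_norm is not None:
        norms = np.linalg.norm(rows, axis=-1, keepdims=True)
        # Scale down only rows whose norm exceeds max_norm; others keep scale 1.
        scale = np.minimum(1.0, max_norm / np.maximum(norms, 1e-12))
        rows = rows * scale
    return rows


embed = np.arange(1000).reshape(100, 10)
picked = lookup(embed, [2, 1, 0])                  # rows 2, 1, 0 unchanged
clipped = lookup(embed, [2, 1, 0], max_norm=1.0)   # every row rescaled to norm 1
untouched = lookup(embed, [0], max_norm=100.0)     # norm below 100, left as is
print(picked)
print(np.linalg.norm(clipped, axis=1))
```

Here all three selected rows have an L2 norm well above 1.0, so each is scaled down to norm 1.0, while row 0 stays unchanged when max_norm is 100.0 because its norm is already smaller.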