TensorFlow API: tf.nn.embedding_lookup

tf.nn.embedding_lookup

tf.nn.embedding_lookup(
    params,
    ids,
    partition_strategy='mod',
    name=None,
    validate_indices=True,
    max_norm=None
)

Args:

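- params: A single tensor representing the complete embedding table, or a list of P tensors that all have the same shape except for the first dimension, representing a sharded (partitioned) embedding table.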

- ids: A Tensor of type int32 or int64 containing the ids to look up in params.

- partition_strategy: Determines how ids are assigned to shards when params is a list with more than one tensor. Currently "div" and "mod" are supported. Default is "mod".

- name: A name for the operation (optional).

- validate_indices: DEPRECATED.

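- max_norm: If not None, each looked-up embedding whose L2 norm is larger than this value is clipped (rescaled) so that its L2 norm equals max_norm.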
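To make partition_strategy concrete, the sketch below emulates in NumPy how ids are assigned to P shards. With "mod", id i lives in shard i % P at row i // P. With "div", ids are split into contiguous blocks, and the first n % P shards get one extra id. For 13 ids and 5 shards this gives the layouts printed at the end, which follow the example used in the TensorFlow documentation. The helpers shard, locate and sharded_lookup are hypothetical illustrations, not TensorFlow functions.

```python
import numpy as np

def shard(table, num_shards, strategy):
    """Split a full table into num_shards pieces laid out for 'mod' or 'div' (illustrative)."""
    if strategy == "mod":
        # Shard s holds ids s, s + P, s + 2P, ... (P = num_shards)
        return [table[s::num_shards] for s in range(num_shards)]
    # 'div': contiguous blocks; the first n % P shards get one extra row
    return np.array_split(table, num_shards)

def locate(i, n, num_shards, strategy):
    """Return (shard index, row within that shard) for id i in a vocabulary of n ids."""
    if strategy == "mod":
        return i % num_shards, i // num_shards
    per, extras = divmod(n, num_shards)
    boundary = extras * (per + 1)  # ids below this live in the larger shards
    if i < boundary:
        return i // (per + 1), i % (per + 1)
    return extras + (i - boundary) // per, (i - boundary) % per

def sharded_lookup(shards, ids, n, strategy):
    """Gather the rows for ids from a list of shards."""
    hits = [locate(i, n, len(shards), strategy) for i in ids]
    return np.stack([shards[s][r] for s, r in hits])

table = np.arange(13).reshape(13, 1)  # row i holds the value i
for strategy in ("mod", "div"):
    shards = shard(table, 5, strategy)
    print(strategy, [s[:, 0].tolist() for s in shards])
    # The sharded lookup must agree with a plain lookup in the full table
    assert (sharded_lookup(shards, [12, 0, 7, 9], 13, strategy) == table[[12, 0, 7, 9]]).all()
# mod [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]
# div [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]
```

Whichever strategy is used, the shards must be laid out with the same strategy that is passed to the lookup, otherwise ids map to the wrong rows.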
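The max_norm clipping can be sketched the same way. This NumPy version is an assumption-labeled illustration (clip_rows_by_norm is a hypothetical name): it rescales each embedding vector whose L2 norm exceeds max_norm down to exactly max_norm and leaves shorter vectors unchanged. As far as I know, TensorFlow applies this to the looked-up result, not to the stored table.

```python
import numpy as np

def clip_rows_by_norm(embeddings, max_norm):
    """Hypothetical sketch of max_norm clipping on looked-up embeddings."""
    # L2 norm of each embedding vector (last axis)
    norms = np.linalg.norm(embeddings, axis=-1, keepdims=True)
    # Scale factor is 1 for short rows and max_norm / norm for long ones;
    # the floor on norms avoids dividing by zero for all-zero rows
    scale = np.minimum(1.0, max_norm / np.maximum(norms, 1e-12))
    return embeddings * scale

looked_up = np.array([[3.0, 4.0], [0.3, 0.4], [0.0, 0.0]])  # norms 5, 0.5, 0
print(clip_rows_by_norm(looked_up, 1.0))  # approx [[0.6 0.8] [0.3 0.4] [0. 0.]]
```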

Code:

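import tensorflow as tf  # TensorFlow 1.x (uses tf.Session)
import numpy as np

# A 50 x 50 table of random floats in [0, 10): 50 ids, 50-dimensional embeddings
embedding = tf.random_uniform(shape=(50, 50), minval=0, maxval=10)
ids = [2, 1, 0]
# A 100 x 10 table whose row i is [10*i, ..., 10*i + 9]
embed = tf.constant(np.arange(1000).reshape(100, 10))

# Look up rows 2, 1 and 0 of each table; each result has shape (3, table_width)
some_embedding = tf.nn.embedding_lookup(embedding, ids=ids)
some_embed = tf.nn.embedding_lookup(embed, ids=ids)
with tf.Session() as sess:
    print(sess.run(some_embedding))
    print(sess.run(some_embed))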

Output (the first matrix comes from an unseeded random table, so its values differ between runs):

[[1.9174099  1.4985561  8.763132   5.1513696  3.0472064  1.3083363
  8.763195   9.585258   6.5263486  4.5616875  1.7619014  3.474176
  6.629846   7.9499936  3.3529973  3.8903248  3.5784101  2.235669
  5.382924   3.5264575  6.6558695  5.08219    8.104395   4.7079873
  6.942602   8.672055   9.369327   2.9697776  2.1853507  7.4739685
  7.0480814  6.7367516  2.4502242  3.1823123  1.7625916  9.54441
  5.154147   0.1355648  7.0668087  0.90946436 1.8672597  6.3007355
  0.8664024  3.5557115  2.82979    1.6570354  5.1281223  4.6919146
  2.8726697  4.4800186 ]
 [7.004551   0.11087418 4.191972   1.8102491  5.845166   5.9069967
  5.808556   7.354327   5.4576874  5.7963667  6.093507   9.707213
  4.706266   0.89920044 7.6548386  4.213927   3.8301349  9.722362
  3.239541   2.0103395  6.5536427  0.5413759  6.9744706  4.4435463
  7.05449    8.396385   9.050558   8.112217   9.726019   3.3593094
  1.8417215  5.391085   0.22029877 3.4840071  8.741225   1.8167889
  2.7843773  7.633685   5.02874    4.2803707  8.34441    2.6564145
  8.721051   6.1455345  0.05278707 5.921961   0.7731664  4.822198
  0.674752   7.203877  ]
 [4.561104   4.972656   3.8639069  8.641497   5.19176    6.2906036
  8.288484   1.6396677  4.090029   0.7623017  8.548621   8.682202
  6.331664   5.1811004  4.39334    9.6758795  8.779367   6.086396
  1.9097066  3.6502922  4.1653624  9.572655   1.2448967  2.2894263
  1.0316873  9.54537    4.5649695  8.553878   7.0591154  6.1101866
  3.672725   3.2488763  8.497752   9.083318   5.6434383  0.73266864
  5.2025137  6.840452   7.8819666  7.0142117  5.364611   7.310052
  9.392939   7.888708   5.781497   1.2782669  5.573305   1.9477713
  2.99865    2.6828694 ]]
[[20 21 22 23 24 25 26 27 28 29]
 [10 11 12 13 14 15 16 17 18 19]
 [ 0  1  2  3  4  5  6  7  8  9]]
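The second printed matrix is deterministic, so it can be checked without TensorFlow. For a single params tensor the lookup is plain row indexing, and the NumPy sketch below (an illustration, not the TensorFlow call itself) reproduces rows 2, 1 and 0 of the 100 x 10 table in the requested order.

```python
import numpy as np

embed = np.arange(1000).reshape(100, 10)  # row i is [10*i, ..., 10*i + 9]
ids = [2, 1, 0]
rows = embed[ids]  # same as np.take(embed, ids, axis=0)
print(rows)
# [[20 21 22 23 24 25 26 27 28 29]
#  [10 11 12 13 14 15 16 17 18 19]
#  [ 0  1  2  3  4  5  6  7  8  9]]
```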


Reposted from blog.csdn.net/nockinonheavensdoor/article/details/80224890