A detailed look at PyTorch's embedding layer (from principle to practice)

In NLP work the embedding layer comes up all the time, and PyTorch ships with one built in.

What is an embedding layer?

There are better write-ups of this elsewhere; here is the short version.

Put in the plainest possible terms: in NLP, an embedding layer takes a list of tokens such as ['你', '好', '吗'] and encodes it as

 ‘你’ --------------[0.2,0.1]
 ‘好’ --------------[0.3,0.2]
 ‘吗’ --------------[0.6,0.5]

that is, a vector for each token.
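A minimal runnable sketch of that mapping (the three-token vocabulary and the 2-dimensional vector size are made up for illustration, and a freshly created layer returns random values until it is trained):

import torch
import torch.nn as nn

# Hypothetical 3-token vocabulary: each token gets an integer index.
vocab = {'你': 0, '好': 1, '吗': 2}

# An embedding table with 3 rows (one per token), each a 2-dimensional vector.
embedding = nn.Embedding(num_embeddings=3, embedding_dim=2)

# Look up the vectors for the sentence '你好吗'.
indices = torch.LongTensor([vocab[w] for w in ['你', '好', '吗']])
print(embedding(indices))   # shape (3, 2); values are random until trained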

Why use an embedding layer?

Others have explained this at length; I'll summarize it in one sentence:
one-hot encoding is a waste of memory, and we're all kids from poor families.
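To make the memory argument concrete, here is a rough comparison under made-up sizes (a 10,000-word vocabulary and 128-dimensional embeddings): a one-hot vector spends 10,000 numbers on every word, almost all of them zeros, while the embedding spends only 128.

import torch
import torch.nn.functional as F

vocab_size, embed_dim = 10_000, 128     # made-up sizes for illustration

# One-hot: each word is a 10,000-long vector that is almost entirely zeros.
one_hot = F.one_hot(torch.tensor([3]), num_classes=vocab_size).float()
print(one_hot.shape)    # torch.Size([1, 10000])

# Embedding: the same word becomes a dense 128-dimensional vector.
dense = torch.nn.Embedding(vocab_size, embed_dim)(torch.tensor([3]))
print(dense.shape)      # torch.Size([1, 128])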

How to use it in PyTorch

Official documentation (English)

Class definition

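The picture here showed the class definition from the official docs; for reference, the constructor signature looks roughly like this (details may vary slightly between PyTorch versions):

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None,
                   max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
                   sparse=False, _weight=None)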

Parameters

The important parameters here are:

  1. num_embeddings: the size of the embedding layer's dictionary (the number of words in the vocabulary)
  2. embedding_dim: the size of each embedding vector (see the sketch after this list)
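A quick way to see what these two numbers control (the sizes below are arbitrary): the layer's weight is just a lookup table with num_embeddings rows and embedding_dim columns.

import torch.nn as nn

# A dictionary of 10 words, each mapped to a 3-dimensional vector.
embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)

# The underlying lookup table: one row per word in the dictionary.
print(embedding.weight.shape)   # torch.Size([10, 3])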

Explanation

The following points should be noted:

  1. By default the output keeps the same leading dimensions as the input; embedding_dim is appended as the last dimension, so an input of shape (*) produces an output of shape (*, embedding_dim). See the sketch below.

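For instance, with the same arbitrary sizes as the docs example that follows, a batch of indices of shape (2, 4) comes out with shape (2, 4, 3):

import torch
import torch.nn as nn

embedding = nn.Embedding(10, 3)
batch = torch.LongTensor([[1, 2, 4, 5],
                          [4, 3, 2, 9]])   # input shape: (2, 4)

print(embedding(batch).shape)              # torch.Size([2, 4, 3])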

Examples

>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])


>>> # example with padding_idx
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0,2,0,5]])
>>> embedding(input)
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.1535, -2.0309,  0.9315],
         [ 0.0000,  0.0000,  0.0000],
         [-0.1655,  0.9897,  0.0635]]])
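A side note on padding_idx: the row at that index defaults to all zeros and, per the documentation, is not updated during training, which is why the padded positions above stay zero.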

And a small example of my own that may be easier to understand:

import torch

a = torch.LongTensor([0])                # a single token index
embedding = torch.nn.Embedding(2, 5)     # a dictionary of 2 entries, each a 5-dim vector
b = embedding(a)                         # look up the vector for index 0
b
Out[29]: tensor([[1.7931, 0.5004, 0.3444, 0.7140, 0.3001]], grad_fn=<EmbeddingBackward>)

Origin blog.csdn.net/weixin_43914889/article/details/104699657