word-embedding

Code

The script below trains a small CBOW (continuous bag-of-words) word-embedding model in PyTorch on a short two-sentence corpus.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)
CONTEXT_SIZE = 2   # words on each side of the target word
EMBEDDING_DIM = 30
EPOCH = 100
VERBOSE = 5        # print the average loss every VERBOSE epochs

# 0) split the text into tokens
corpus_text = "This tutorial will walk you through the key ideas of deep learning programming using Pytorch." \
              " Many of the concepts (such as the computation graph abstraction and autograd) " \
              "are not unique to Pytorch and are relevant to any deep learning tool kit out there.".split(' ')


# 1) define the model
class CBOW(nn.Module):

    def __init__(self, vocab_size, embedding_size, context_size):
        super(CBOW, self).__init__()

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.context_size = context_size
        self.embeddings = nn.Embedding(self.vocab_size, self.embedding_size)

        self.lin1 = nn.Linear(self.context_size * 2 * self.embedding_size, 512)
        self.lin2 = nn.Linear(512, self.vocab_size)

    def forward(self, inp):
        # (2*CONTEXT_SIZE, embedding_size) -> one (1, 2*CONTEXT_SIZE*embedding_size) row
        out = self.embeddings(inp).view(1, -1)

        out = self.lin1(out)
        out = F.relu(out)

        out = self.lin2(out)
        out = F.log_softmax(out, dim=1)

        return out

    def get_word_vector(self, word_idx):
        # look up a single word's embedding as a (1, embedding_size) tensor
        word = torch.tensor([word_idx], dtype=torch.long)
        return self.embeddings(word).view(1, -1)
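
# --- Illustrative shape check (not in the original post) ---
# One context window holds 2*CONTEXT_SIZE word indices; the forward pass of an
# untrained model should return log-probabilities of shape (1, vocab_size).
_demo = CBOW(len(set(corpus_text)), EMBEDDING_DIM, CONTEXT_SIZE)
_demo_out = _demo(torch.tensor([0, 1, 2, 3], dtype=torch.long))
assert _demo_out.shape == (1, len(set(corpus_text)))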



def train_cbow(data, unique_vocab, word_to_idx):

    cbow = CBOW(len(unique_vocab), EMBEDDING_DIM, CONTEXT_SIZE)

    nll_loss = nn.NLLLoss()
    optimizer = optim.SGD(cbow.parameters(), lr=0.001)

    print(len(data))  # number of (context, target) training pairs

    for epoch in range(EPOCH):

        total_loss = 0

        for context, target in data:

            inp_var = torch.tensor([word_to_idx[word] for word in context], dtype=torch.long)

            target_var = torch.tensor([word_to_idx[target]], dtype=torch.long)

            cbow.zero_grad()
            log_prob = cbow(inp_var)
            loss = nll_loss(log_prob, target_var)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()  # .item() replaces the deprecated .data

        if epoch % VERBOSE == 0:
            loss_avg = total_loss / len(data)
            print("{}/{} loss {:.2f}".format(epoch, EPOCH, loss_avg))
    return cbow


def main():

    # build the (context, target) training pairs: a window of 2*CONTEXT_SIZE
    # words around each position, with the middle word as the target
    data = list()
    for i in range(CONTEXT_SIZE, len(corpus_text) - CONTEXT_SIZE):
        # CONTEXT_SIZE words to the left of the target + CONTEXT_SIZE to the right
        data_context = corpus_text[i - CONTEXT_SIZE:i] + corpus_text[i + 1:i + CONTEXT_SIZE + 1]
        data_target = corpus_text[i]
        data.append((data_context, data_target))

    print("some data: ", data[:3])

    # map each unique word to an integer index
    unique_vocab = list(set(corpus_text))
    word_to_idx = {w:i for i, w in enumerate(unique_vocab)}
    print('word_to_idx: ', word_to_idx)

    # train the model
    cbow = train_cbow(data, unique_vocab, word_to_idx)
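
    # --- Illustrative addition (not in the original post, so its output does
    # not appear in the results below): cosine similarity between two of the
    # learned word vectors. ---
    vec_a = cbow.get_word_vector(word_to_idx['deep'])
    vec_b = cbow.get_word_vector(word_to_idx['learning'])
    print('cos(deep, learning) = {:.4f}'.format(
        F.cosine_similarity(vec_a, vec_b).item()))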




if __name__ == "__main__":

    main()

Results


some data:  [(['This', 'tutorial', 'walk', 'you'], 'will'), (['tutorial', 'will', 'you', 'through'], 'walk'), (['will', 'walk', 'through', 'the'], 'you')]
word_to_idx:  {'using': 0, 'relevant': 1, 'unique': 2, 'tutorial': 3, 'concepts': 4, 'graph': 5, 'Many': 6, 'the': 7, 'are': 8, 'deep': 9, 'key': 10, 'ideas': 11, 'there.': 12, 'out': 13, 'Pytorch': 14, 'learning': 15, 'of': 16, 'not': 17, 'autograd)': 18, 'tool': 19, 'will': 20, '(such': 21, 'you': 22, 'This': 23, 'through': 24, 'Pytorch.': 25, 'as': 26, 'kit': 27, 'any': 28, 'computation': 29, 'and': 30, 'to': 31, 'abstraction': 32, 'programming': 33, 'walk': 34}
39
0/100 loss 3.60
5/100 loss 3.14
10/100 loss 2.73
15/100 loss 2.33
20/100 loss 1.97
25/100 loss 1.65
30/100 loss 1.36
35/100 loss 1.11
40/100 loss 0.91
45/100 loss 0.74
50/100 loss 0.61
55/100 loss 0.51
60/100 loss 0.42
65/100 loss 0.36
70/100 loss 0.31
75/100 loss 0.27
80/100 loss 0.24
85/100 loss 0.21
90/100 loss 0.19
95/100 loss 0.17

Process finished with exit code 0
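
As a follow-up, the weight matrix of the trained nn.Embedding layer can be read out directly and used for nearest-neighbour lookups. The sketch below is a minimal illustration, not part of the original script: it assumes the cbow, unique_vocab, and word_to_idx objects produced by main() above, and the helper name nearest_neighbors is made up for this example.

import torch
import torch.nn.functional as F

def nearest_neighbors(cbow, unique_vocab, word_to_idx, word, k=3):
    # full embedding matrix, shape (vocab_size, EMBEDDING_DIM)
    emb = cbow.embeddings.weight.detach()
    # cosine similarity of the query word against every row of the matrix
    query = emb[word_to_idx[word]].unsqueeze(0)    # (1, dim)
    sims = F.cosine_similarity(query, emb, dim=1)  # (vocab_size,)
    # take k+1 hits so the query word itself (similarity 1.0) can be dropped
    top = torch.topk(sims, k + 1).indices.tolist()
    return [unique_vocab[i] for i in top if unique_vocab[i] != word][:k]

# e.g. nearest_neighbors(cbow, unique_vocab, word_to_idx, 'learning')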


Reposted from blog.csdn.net/sinat_15355869/article/details/88077636