Use PyTorch para implementar o modelo Skip-gram em Word2Vec

Word2VecDatasetPrimeiro, um conjunto de dados customizado usando a classe é criado para gerar dados de treinamento. Em seguida, o modelo Skip-gram é definido e treinado usando a função de perda de entropia cruzada e o otimizador Adam.

Em cada época de treinamento, o carregador de dados é percorrido, propagação direta, perda de computação, propagação reversa e atualizações de peso para cada lote. Finalmente, o vetor de palavras treinado é obtido e pode ser usado word_vectorpara obter a representação do vetor de palavras de uma palavra específica.

Certifique-se de instalar o PyTorch antes de executar, você pode pip install torchinstalá-lo usando Observe que o código será executado na GPU, se disponível. Se você não possui uma GPU, .to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))remova a seção e execute-a na CPU.

A seguir está um exemplo de código para implementar o modelo Skip-gram usando PyTorch:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Hyperparameters
embedding_dim = 100
window_size = 2
learning_rate = 0.001
epochs = 100
batch_size = 32

# Example corpus
corpus = [['I', 'enjoy', 'playing', 'football', 'with', 'my', 'friends'],
          ['We', 'like', 'to', 'play', 'tennis', 'on', 'weekends'],
          ['She', 'is', 'a', 'good', 'dancer']]

# Create vocabulary
vocab = list(set([word for sentence in corpus for word in sentence]))
vocab_size = len(vocab)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for idx, word in enumerate(vocab)}

# Generate training data
class Word2VecDataset(Dataset):
    def __init__(self, corpus, word2idx):
        self.data = []
        for sentence in corpus:
            word_indices = [word2idx[word] for word in sentence]
            for center_word_idx, center_word in enumerate(word_indices):
                for context_word_idx in range(max(0, center_word_idx - window_size), min(center_word_idx + window_size + 1, len(word_indices))):
                    if context_word_idx != center_word_idx:
                        context_word = word_indices[context_word_idx]
                        self.data.append((center_word, context_word))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

dataset = Word2VecDataset(corpus, word2idx)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define Skip-gram model
class SkipGramModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(SkipGramModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)
        
    def forward(self, center_word):
        embedded = self.embedding(center_word)
        output = self.linear(embedded)
        return output

model = SkipGramModel(vocab_size, embedding_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training
for epoch in range(epochs):
    running_loss = 0.0
    for i, (center_word, context_word) in enumerate(dataloader):
        optimizer.zero_grad()
        
        center_word = center_word.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        context_word = context_word.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        
        output = model(center_word)
        loss = criterion(output, context_word)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    average_loss = running_loss / len(dataloader)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {average_loss:.4f}')

# Get trained word embeddings
trained_embeddings = model.embedding.weight.data.numpy()

# Example usage - Getting word vector for a word
word = 'football'
word_vector = trained_embeddings[word2idx[word]]
print(f"Word vector for '{word}': {word_vector}")

Os resultados da execução são os seguintes:

Época 1/100, Perda: 3,1324
Época 2/100, Perda: 3,0791
Época 3/100, Perda: 2,9902
Época 4/100, Perda: 2,9392
Época 5/100, Perda: 2,8870
Época 6/100, Perda: 2,8166
Época 7 /100, Perda: 2.7615
Epoch 8/100, Perda: 2.7017
Epoch 9/100, Perda: 2.6500
Epoch 10/100, Perda: 2.5993
Epoch 11/100, Perda: 2.5496
Epoch 12/100, Perda: 2.5013
Epoch 13/100 , Perda: 2,4621
Época 14/100, Perda: 2,4079
Época 15/100, Perda: 2,3660
Época 16/100, Perda: 2,3229 Época 17/100, Perda
: 2,2795
Época 18/100, Perda: 2,2398
Época 19/100, Perda : 2,1998
Época 20/100, Perda: 2,1582
Época 21/100, Perda: 2,1278
Época 22/100, Perda: 2,1023
Época 23/100, Perda: 2,0569
Época 24/100, Perda: 2,0245
Época 25/100, Perda: 1,9936 Época 26/100
, Perda: 1,9639 Época
27/100, Perda: 1,9344
Época 28 /100, Perda: 1,9137
Época 29/100, Perda: 1,8888
Época 30/100, Perda: 1,8586 Época 31/100, Perda: 1,8352
Época 32/100
, Perda: 1,8200
Época 33/100, Perda: 1,7815
Época 34/1 00 , Perda: 1,7685
Época 35/100, Perda: 1,7531
Época 36/100, Perda: 1,7209 Época
37/100, Perda: 1,7049 Época
38/100, Perda: 1,6881
Época 39/100, Perda: 1,6775
Época 40/100, Perda : 1,6517
Época 41/100, Perda: 1,6390
Época 42/100, Perda: 1,6238
Época 43/100, Perda: 1,6077
Época 44/100, Perda: 1,5939
Época 45/100, Perda: 1,5745
Época 46/100, Perda: 1,5703 Época 47/100
, Perda: 1,5574
Época 48/100, Perda: 1,5458
Época 49 /100, Perda: 1,5308
Época 50/100, Perda: 1,5215
Época 51/100, Perda: 1,5122 Época
52/100, Perda: 1,4988 Época
53/100, Perda: 1,4958
Época 54/100, Perda: 1,4773
Época 55/1 00 , Perda: 1,4746
Época 56/100, Perda: 1,4618
Época 57/100, Perda: 1,4560
Época 58/100, Perda: 1,4506
Época 59/100, Perda: 1,4380
Época 60/100, Perda: 1,4266
Época 61/100, Perda : 1,4257
Época 62/100, Perda: 1,4148
Época 63/100, Perda: 1,4090
Época 64/100, Perda: 1,4070
Época 65/100, Perda: 1,3940
Época 66/100, Perda: 1,3890 Época 67/100
, Perda: 1,3846 Época 68/100
, Perda: 1,3813
Época 69/100, Perda: 1,3738
Época 70 /100, Perda: 1,3717
Época 71/100, Perda: 1,3681
Época 72/100, Perda: 1,3594 Época
73/100, Perda: 1,3593 Época 74/100, Perda
: 1,3504
Época 75/100, Perda: 1,3447
Época 76/1 00 , Perda: 1,3439
Época 77/100, Perda: 1,3397
Época 78/100, Perda: 1,3315
Época 79/100, Perda: 1,3260
Época 80/100, Perda: 1,3253
Época 81/100, Perda: 1,3229
Época 82/100, Perda : 1,3215
Época 83/100, Perda: 1,3148
Época 84/100, Perda: 1,3160
Época 85/100, Perda: 1,3072
Época 86/100, Perda: 1,3105
Época 87/100, Perda: 1,3104
Época 88/100, Perda: 1,3018
Época 89/100, Perda: 1,2912 Época
90/100, Perda: 1,2950
Época 91 /100, Perda: 1,2938
Época 92/100, Perda: 1,2951
Época 93/100, Perda: 1,2859
Época 94/100, Perda: 1,2902 Época 95/100
, Perda: 1,2840
Época 96/100, Perda: 1,2748
Época 97/1 00 , Perda: 1,2840
Época 98/100, Perda: 1,2763
Época 99/100, Perda: 1,2772
Época 100/100, Perda: 1,2746


Vetor de palavras para 'futebol':

[-1.2727762   0.8401019  -0.5115612   2.0667355   1.1854529  -0.7444803
 -1.9658612  -1.0488677   0.98938674 -1.1675086   1.582392    1.7414839
 -0.4892138  -1.2149098   0.15343344 -1.8318586   0.41794038  0.25481498
  0.6008032  -0.23904797  0.80143225 -1.0495795  -1.0174142  -0.01827855
  2.7477944  -0.9574399   1.025569    2.4843202  -0.2796719  -0.4390253
 -1.4423424  -1.8073392   0.1897556   0.90259725  2.7565296  -0.28331178
 -1.8443514   0.77545553 -1.0289538   0.71483964  1.1801128  -0.22635305
  0.5960759   0.6690206  -1.9100318   1.2388043  -0.68522704  0.92120373
  1.0252377  -1.4376261  -0.6595934   0.31699112  0.6751458   0.99656415
  0.40565705 -1.0904227  -0.3513346  -0.66078615  1.1834346  -1.0899751
 -1.4925232 -0.30818892 1.4249563 0.06006899 -3.2386255 0.96192694
 -1.1045157 0.5540482 -1.5388466 -0.8721646 1.1221852 1.6488599
  0.44869688 1.1519432 -1.4588032 -0.04230021 -0.33113605 1.1316347
 -0.7425484 -0.11400439 0.37237874 -0.34573358 0.4140474 -0.04413145
  0.6157635 -1.0094129 -1.2208599 -0.7154122 0.9412035 0.9452426
 -0.0973389 -0.23566085 0,34300375 -0,95858365 0,8764276
 -0,5669889 -1,933235 0,22371146 1,6641699 1,3258857]

Guess you like

Origin blog.csdn.net/Metal1/article/details/132886936