Word2VecDataset
Primeiro, um conjunto de dados customizado usando a classe é criado para gerar dados de treinamento. Em seguida, o modelo Skip-gram é definido e treinado usando a função de perda de entropia cruzada e o otimizador Adam.
Em cada época de treinamento, o carregador de dados é percorrido, propagação direta, perda de computação, propagação reversa e atualizações de peso para cada lote. Finalmente, o vetor de palavras treinado é obtido e pode ser usado word_vector
para obter a representação do vetor de palavras de uma palavra específica.
Certifique-se de instalar o PyTorch antes de executar, você pode pip install torch
instalá-lo usando Observe que o código será executado na GPU, se disponível. Se você não possui uma GPU, .to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
remova a seção e execute-a na CPU.
A seguir está um exemplo de código para implementar o modelo Skip-gram usando PyTorch:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# Hyperparameters
embedding_dim = 100
window_size = 2
learning_rate = 0.001
epochs = 100
batch_size = 32
# Example corpus
corpus = [['I', 'enjoy', 'playing', 'football', 'with', 'my', 'friends'],
['We', 'like', 'to', 'play', 'tennis', 'on', 'weekends'],
['She', 'is', 'a', 'good', 'dancer']]
# Create vocabulary
vocab = list(set([word for sentence in corpus for word in sentence]))
vocab_size = len(vocab)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for idx, word in enumerate(vocab)}
# Generate training data
class Word2VecDataset(Dataset):
def __init__(self, corpus, word2idx):
self.data = []
for sentence in corpus:
word_indices = [word2idx[word] for word in sentence]
for center_word_idx, center_word in enumerate(word_indices):
for context_word_idx in range(max(0, center_word_idx - window_size), min(center_word_idx + window_size + 1, len(word_indices))):
if context_word_idx != center_word_idx:
context_word = word_indices[context_word_idx]
self.data.append((center_word, context_word))
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
dataset = Word2VecDataset(corpus, word2idx)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Define Skip-gram model
class SkipGramModel(nn.Module):
def __init__(self, vocab_size, embedding_dim):
super(SkipGramModel, self).__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.linear = nn.Linear(embedding_dim, vocab_size)
def forward(self, center_word):
embedded = self.embedding(center_word)
output = self.linear(embedded)
return output
model = SkipGramModel(vocab_size, embedding_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Training
for epoch in range(epochs):
running_loss = 0.0
for i, (center_word, context_word) in enumerate(dataloader):
optimizer.zero_grad()
center_word = center_word.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
context_word = context_word.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
output = model(center_word)
loss = criterion(output, context_word)
loss.backward()
optimizer.step()
running_loss += loss.item()
average_loss = running_loss / len(dataloader)
print(f'Epoch {epoch+1}/{epochs}, Loss: {average_loss:.4f}')
# Get trained word embeddings
trained_embeddings = model.embedding.weight.data.numpy()
# Example usage - Getting word vector for a word
word = 'football'
word_vector = trained_embeddings[word2idx[word]]
print(f"Word vector for '{word}': {word_vector}")
Os resultados da execução são os seguintes:
Época 1/100, Perda: 3,1324
Época 2/100, Perda: 3,0791
Época 3/100, Perda: 2,9902
Época 4/100, Perda: 2,9392
Época 5/100, Perda: 2,8870
Época 6/100, Perda: 2,8166
Época 7 /100, Perda: 2.7615
Epoch 8/100, Perda: 2.7017
Epoch 9/100, Perda: 2.6500
Epoch 10/100, Perda: 2.5993
Epoch 11/100, Perda: 2.5496
Epoch 12/100, Perda: 2.5013
Epoch 13/100 , Perda: 2,4621
Época 14/100, Perda: 2,4079
Época 15/100, Perda: 2,3660
Época 16/100, Perda: 2,3229 Época 17/100, Perda
: 2,2795
Época 18/100, Perda: 2,2398
Época 19/100, Perda : 2,1998
Época 20/100, Perda: 2,1582
Época 21/100, Perda: 2,1278
Época 22/100, Perda: 2,1023
Época 23/100, Perda: 2,0569
Época 24/100, Perda: 2,0245
Época 25/100, Perda: 1,9936 Época 26/100
, Perda: 1,9639 Época
27/100, Perda: 1,9344
Época 28 /100, Perda: 1,9137
Época 29/100, Perda: 1,8888
Época 30/100, Perda: 1,8586 Época 31/100, Perda: 1,8352
Época 32/100
, Perda: 1,8200
Época 33/100, Perda: 1,7815
Época 34/1 00 , Perda: 1,7685
Época 35/100, Perda: 1,7531
Época 36/100, Perda: 1,7209 Época
37/100, Perda: 1,7049 Época
38/100, Perda: 1,6881
Época 39/100, Perda: 1,6775
Época 40/100, Perda : 1,6517
Época 41/100, Perda: 1,6390
Época 42/100, Perda: 1,6238
Época 43/100, Perda: 1,6077
Época 44/100, Perda: 1,5939
Época 45/100, Perda: 1,5745
Época 46/100, Perda: 1,5703 Época 47/100
, Perda: 1,5574
Época 48/100, Perda: 1,5458
Época 49 /100, Perda: 1,5308
Época 50/100, Perda: 1,5215
Época 51/100, Perda: 1,5122 Época
52/100, Perda: 1,4988 Época
53/100, Perda: 1,4958
Época 54/100, Perda: 1,4773
Época 55/1 00 , Perda: 1,4746
Época 56/100, Perda: 1,4618
Época 57/100, Perda: 1,4560
Época 58/100, Perda: 1,4506
Época 59/100, Perda: 1,4380
Época 60/100, Perda: 1,4266
Época 61/100, Perda : 1,4257
Época 62/100, Perda: 1,4148
Época 63/100, Perda: 1,4090
Época 64/100, Perda: 1,4070
Época 65/100, Perda: 1,3940
Época 66/100, Perda: 1,3890 Época 67/100
, Perda: 1,3846 Época 68/100
, Perda: 1,3813
Época 69/100, Perda: 1,3738
Época 70 /100, Perda: 1,3717
Época 71/100, Perda: 1,3681
Época 72/100, Perda: 1,3594 Época
73/100, Perda: 1,3593 Época 74/100, Perda
: 1,3504
Época 75/100, Perda: 1,3447
Época 76/1 00 , Perda: 1,3439
Época 77/100, Perda: 1,3397
Época 78/100, Perda: 1,3315
Época 79/100, Perda: 1,3260
Época 80/100, Perda: 1,3253
Época 81/100, Perda: 1,3229
Época 82/100, Perda : 1,3215
Época 83/100, Perda: 1,3148
Época 84/100, Perda: 1,3160
Época 85/100, Perda: 1,3072
Época 86/100, Perda: 1,3105
Época 87/100, Perda: 1,3104
Época 88/100, Perda: 1,3018
Época 89/100, Perda: 1,2912 Época
90/100, Perda: 1,2950
Época 91 /100, Perda: 1,2938
Época 92/100, Perda: 1,2951
Época 93/100, Perda: 1,2859
Época 94/100, Perda: 1,2902 Época 95/100
, Perda: 1,2840
Época 96/100, Perda: 1,2748
Época 97/1 00 , Perda: 1,2840
Época 98/100, Perda: 1,2763
Época 99/100, Perda: 1,2772
Época 100/100, Perda: 1,2746
Vetor de palavras para 'futebol':
[-1.2727762 0.8401019 -0.5115612 2.0667355 1.1854529 -0.7444803
-1.9658612 -1.0488677 0.98938674 -1.1675086 1.582392 1.7414839
-0.4892138 -1.2149098 0.15343344 -1.8318586 0.41794038 0.25481498
0.6008032 -0.23904797 0.80143225 -1.0495795 -1.0174142 -0.01827855
2.7477944 -0.9574399 1.025569 2.4843202 -0.2796719 -0.4390253
-1.4423424 -1.8073392 0.1897556 0.90259725 2.7565296 -0.28331178
-1.8443514 0.77545553 -1.0289538 0.71483964 1.1801128 -0.22635305
0.5960759 0.6690206 -1.9100318 1.2388043 -0.68522704 0.92120373
1.0252377 -1.4376261 -0.6595934 0.31699112 0.6751458 0.99656415
0.40565705 -1.0904227 -0.3513346 -0.66078615 1.1834346 -1.0899751
-1.4925232 -0.30818892 1.4249563 0.06006899 -3.2386255 0.96192694
-1.1045157 0.5540482 -1.5388466 -0.8721646 1.1221852 1.6488599
0.44869688 1.1519432 -1.4588032 -0.04230021 -0.33113605 1.1316347
-0.7425484 -0.11400439 0.37237874 -0.34573358 0.4140474 -0.04413145
0.6157635 -1.0094129 -1.2208599 -0.7154122 0.9412035 0.9452426
-0.0973389 -0.23566085 0,34300375 -0,95858365 0,8764276
-0,5669889 -1,933235 0,22371146 1,6641699 1,3258857]