pytorch和torchtext的版本要对应起来:
https://blog.csdn.net/qq_42363032/article/details/126929995?spm=1001.2014.3001.5501
glove预训练的emb向量下载:
https://blog.csdn.net/weixin_44912159/article/details/105538891
先下载 GloVe 词向量,可以选择的维度有 50、100、200、300(其中 50/100/200 维对应 name='6B',name='840B' 只提供 300 维)。我这里选的 300:
# cache_dir is the cache path where the GloVe vectors are stored
cache_dir = './data/'
# dim is the embedding dimension
glove = vocab.GloVe(name='840B', dim=300, cache=cache_dir)
# Print the size of the vector table to check the dimensions
print(glove.vectors.size())
# Output recorded in the original note: torch.Size([2196017, 300])
创建 nn.Embedding 时预加载即可:
# from_pretrained expects the 2-D weight tensor, not the GloVe wrapper object,
# so pass glove.vectors; freeze=False keeps the embedding weights trainable.
self.glove_emb = nn.Embedding.from_pretrained(glove.vectors, freeze=False)
但是实际使用的时候会和torchtext一起,更方便,类似于:
# Text field: spaCy tokenizer (en_core_web_sm model), lower-cased, with
# fix_length taken from tcfig.pad_size and include_lengths enabled.
word_field = Field(
    tokenize='spacy',
    tokenizer_language='en_core_web_sm',
    lower=True,
    include_lengths=True,
    fix_length=tcfig.pad_size,
)
# Label field whose dtype is torch.long.
label_field = LabelField(dtype=torch.long)
def sst_word_char(path, word_field, label_field, batch_size, device, word_emb_file, cache_dir):
    """Build the train/dev iterators from JSON-lines files under ``path``.

    Args:
        path: Directory containing ``train_demo.jsonl`` and ``val_demo.jsonl``.
        word_field: Field applied to the ``text`` key; exposed on each
            example as ``text_word``.
        label_field: Field applied to the ``label`` key; exposed as ``label``.
        batch_size: Batch size of the training iterator.
        device: Device handed to the iterators.
        word_emb_file: Pretrained vector file name given to ``vocab.Vectors``.
        cache_dir: Cache directory given to ``vocab.Vectors``.

    Returns:
        ``(train_iter, dev_iter)``. The dev iterator uses a single batch
        holding all ``len(dev)`` examples.
    """
    # Map each JSON key to (attribute name on the example, Field).
    fields = {
        'text': ('text_word', word_field),
        'label': ('label', label_field),
    }

    # splits() creates one Dataset per file, here from JSON-lines input.
    # NOTE(review): skip_header=True skips the first line of each file; for
    # JSON-lines that looks like it drops the first record — confirm the
    # files really start with a header line.
    train, dev = TabularDataset.splits(
        path=path,
        train='train_demo.jsonl',
        validation='val_demo.jsonl',
        format='json',
        skip_header=True,
        fields=fields,
    )

    # Build the vocabularies.
    # unk_init is set on the Vectors object itself: build_vocab only forwards
    # its own unk_init when `vectors` is a string alias, so with a ready-made
    # Vectors instance the initializer for out-of-vocabulary tokens must be
    # supplied here. The build_vocab argument is kept for compatibility.
    word_vectors = vocab.Vectors(
        word_emb_file, cache=cache_dir, unk_init=torch.Tensor.normal_
    )
    word_field.build_vocab(
        train,
        dev,
        max_size=25000,
        vectors=word_vectors,
        unk_init=torch.Tensor.normal_,
    )
    label_field.build_vocab(train, dev)

    train_iter, dev_iter = BucketIterator.splits(
        (train, dev),
        batch_sizes=(batch_size, len(dev)),
        sort_key=lambda x: len(x.text_word),
        sort_within_batch=True,
        repeat=False,
        shuffle=True,
        device=device,
    )
    return train_iter, dev_iter
# Build the train/validation iterators from the configured paths.
train_loader, val_loader = sst_word_char(
    path=tcfig.data_dir,
    word_field=word_field,
    label_field=label_field,
    batch_size=tcfig.batch_size,
    device=device,
    word_emb_file=tcfig.glove_word_file,
    cache_dir=tcfig.cache_path,
)
# Vector table attached to the vocabulary of word_field.
word_embeddings = word_field.vocab.vectors
print('---', word_embeddings.shape)
# ----------
for batch, batch_data in enumerate(train_loader):
    text_word = batch_data.text_word
    y = batch_data.label
    # NOTE(review): word_field is created with include_lengths=True, so
    # text_word is presumably a (tokens, lengths) pair — confirm net accepts it.
    pred_probs = net(text_word)
https://blog.csdn.net/cskywit/article/details/93407830
https://blog.csdn.net/sinat_26917383/article/details/83029140