Loading pretrained GloVe embeddings in PyTorch

The PyTorch and torchtext versions must match each other:

https://blog.csdn.net/qq_42363032/article/details/126929995?spm=1001.2014.3001.5501

Downloading the pretrained GloVe vectors:

https://blog.csdn.net/weixin_44912159/article/details/105538891

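First download the GloVe vectors. The available dimensions are 50, 100, 200 and 300 (these are the options of the 6B release; the 840B release used below only comes in 300 dimensions). I chose 300: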


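from torchtext import vocab

# cache_dir is where the downloaded GloVe vectors are cached
cache_dir = './data/'

# name selects the GloVe release ('840B' = 840B-token Common Crawl); dim is the embedding dimension
glove = vocab.GloVe(name='840B', dim=300, cache=cache_dir)

# Print the shape to check it: (number of words, embedding dimension)
print(glove.vectors.size())
# torch.Size([2196017, 300])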
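vocab.GloVe downloads the zipped text file and parses it for you. The format is simple: one token per line, followed by its dim floating-point components separated by spaces. The sketch below is a simplified stand-in for that parsing step. It is not torchtext's implementation, and load_glove_text is a hypothetical helper. It reads a tiny made-up 3-dimensional sample instead of the real multi-GB glove.840B.300d.txt:

```python
import io
import torch

def load_glove_text(lines, dim):
    """Hypothetical helper: parse GloVe-style lines ("token v1 ... v_dim") into (stoi, vectors).

    Simplified sketch only; torchtext's vocab.Vectors also handles download, caching and unk_init.
    """
    itos, rows = [], []
    for line in lines:
        line = line.rstrip()
        if not line:
            continue
        # The last `dim` fields are the vector; everything before them is the token,
        # so a token that itself contains spaces stays intact.
        parts = line.rsplit(" ", dim)
        if len(parts) != dim + 1:
            continue  # wrong number of fields (e.g. a header line): skip it
        try:
            row = [float(v) for v in parts[1:]]
        except ValueError:
            continue  # non-numeric field: skip the line
        itos.append(parts[0])
        rows.append(row)
    stoi = {tok: i for i, tok in enumerate(itos)}
    vectors = torch.tensor(rows, dtype=torch.float32).reshape(len(rows), dim)
    return stoi, vectors

# Made-up sample standing in for the real file (dim=3 here, not 300).
sample = io.StringIO("the 0.1 0.2 0.3\ncat 0.4 0.5 0.6\n. . . 0.7 0.8 0.9\n")
stoi, vectors = load_glove_text(sample, dim=3)
print(vectors.size())  # torch.Size([3, 3])
print(vectors[stoi["cat"]])
```

Splitting from the right with maxsplit=dim is a defensive choice: it keeps working even if a token contains spaces.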

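When creating the nn.Embedding layer, simply initialize it from the pretrained vectors:

# pass the tensor glove.vectors (not the GloVe object itself); freeze=False lets the vectors be fine-tuned
self.glove_emb = nn.Embedding.from_pretrained(glove.vectors, freeze=False)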
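from_pretrained expects a 2-D float tensor of shape [num_words, dim] (for GloVe, the glove.vectors tensor). Below is a minimal self-contained sketch. A small random tensor stands in for glove.vectors so nothing has to be downloaded. padding_idx=1 is an illustrative assumption that matches the default position of <pad> in a torchtext Field vocabulary:

```python
import torch
import torch.nn as nn

# Stand-in for glove.vectors (the real one is [2196017, 300]): 10 "words", dim 300.
pretrained = torch.randn(10, 300)

# freeze=False: the weights receive gradients and are fine-tuned during training.
# padding_idx=1 (assumption): the <pad> row gets no gradient updates.
emb = nn.Embedding.from_pretrained(pretrained, freeze=False, padding_idx=1)

token_ids = torch.tensor([[2, 5, 1]])  # one sequence of 3 ids, the last one is padding
out = emb(token_ids)
print(out.shape)                 # torch.Size([1, 3, 300])
print(emb.weight.requires_grad)  # True
```

With the default freeze=True the embedding would stay fixed at the GloVe values, which saves memory and time when fine-tuning is not needed.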

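In practice, though, it is more convenient to combine this with torchtext's data pipeline (Field, TabularDataset, BucketIterator), along these lines: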

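import torch
from torchtext import vocab
# Field, LabelField, TabularDataset and BucketIterator belong to the legacy torchtext API:
# torchtext.data in <= 0.8, torchtext.legacy.data in 0.9-0.11 (removed in 0.12)
try:
    from torchtext.legacy.data import Field, LabelField, TabularDataset, BucketIterator
except ImportError:
    from torchtext.data import Field, LabelField, TabularDataset, BucketIterator

# tcfig is the author's config object (pad_size, data_dir, batch_size, glove_word_file, cache_path)
word_field = Field(tokenize='spacy', lower=True, include_lengths=True, fix_length=tcfig.pad_size, tokenizer_language='en_core_web_sm')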
label_field = LabelField(dtype=torch.long)

def sst_word_char(path, word_field, label_field, batch_size, device, word_emb_file, cache_dir):

    # JSON key -> (attribute name on each Example, Field that processes it)
    fields = {
        'text': ('text_word', word_field),
        'label': ('label', label_field)
    }

    # splits() creates the Dataset for several splits (train/validation) in one call; format='json' reads one JSON object per line
    # THUCNews
    train, dev = TabularDataset.splits(path=path, train='train_demo.jsonl', validation='val_demo.jsonl', format='json', skip_header=True, fields=fields)

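    # Build the vocabulary from the pretrained vectors.
    # unk_init is also passed here: when build_vocab receives a Vectors object rather than a string
    # alias, OOV rows appear to be filled by the Vectors object's own unk_init (zeros by default).
    word_vectors = vocab.Vectors(word_emb_file, cache=cache_dir, unk_init=torch.Tensor.normal_)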

    word_field.build_vocab(train, dev, max_size=25000, vectors=word_vectors, unk_init=torch.Tensor.normal_)

    label_field.build_vocab(train, dev)

    train_iter, dev_iter = BucketIterator.splits(
        (train, dev), batch_sizes=(batch_size, len(dev)), sort_key=lambda x: len(x.text_word),
        sort_within_batch=True, repeat=False, shuffle=True, device=device
    )

    return train_iter, dev_iter
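With format='json', TabularDataset reads the file as JSON lines: each line is one JSON object whose keys match the keys of the fields dict ('text' and 'label' here). The post does not show its data files. The sketch below therefore writes a made-up two-example train_demo.jsonl in that shape, then reads it back. The temporary directory stands in for tcfig.data_dir, and the texts and labels are invented:

```python
import json
import os
import tempfile

# Made-up examples; the keys must match the keys of the `fields` dict ('text', 'label').
rows = [
    {"text": "A gripping, well-acted film.", "label": "pos"},
    {"text": "Two hours I will never get back.", "label": "neg"},
]

data_dir = tempfile.mkdtemp()  # stand-in for tcfig.data_dir
path = os.path.join(data_dir, "train_demo.jsonl")
with open(path, "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row, ensure_ascii=False) + "\n")  # one JSON object per line

# Read it back line by line, one json.loads per line, as a JSON-lines reader does.
with open(path, encoding="utf-8") as f:
    loaded = [json.loads(line) for line in f]
print(loaded[0]["text"])  # A gripping, well-acted film.
```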

train_loader, val_loader = sst_word_char(
    tcfig.data_dir, word_field, label_field, tcfig.batch_size, device,
    tcfig.glove_word_file, tcfig.cache_path)

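# [vocab size, 300] tensor whose rows follow word_field.vocab.itos; pass it to nn.Embedding.from_pretrained
word_embeddings = word_field.vocab.vectors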
print('---', word_embeddings.shape)
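Roughly, build_vocab(..., vectors=word_vectors) gives every vocabulary index one row. That row is the pretrained vector if the token appears in the GloVe file, and a row filled by an unk_init function otherwise. With max_size=25000 and the default <unk> and <pad> specials, word_embeddings therefore has at most 25,002 rows instead of the 2,196,017 of the full 840B file. The sketch below imitates that alignment step on a toy vocabulary and a toy pretrained table. It is not torchtext's code, and align_vectors is a hypothetical helper:

```python
import torch

def align_vectors(itos, pretrained_stoi, pretrained_vectors, unk_init=torch.Tensor.normal_):
    """Hypothetical helper: row i is the pretrained vector of itos[i], or unk_init(...) if missing."""
    dim = pretrained_vectors.size(1)
    matrix = torch.empty(len(itos), dim)
    hits = 0
    for i, token in enumerate(itos):
        j = pretrained_stoi.get(token)
        if j is not None:
            matrix[i] = pretrained_vectors[j]  # copy the pretrained row
            hits += 1
        else:
            matrix[i] = unk_init(torch.empty(dim))  # OOV (and specials): random N(0, 1) row
    return matrix, hits

# Toy vocabulary ordered like torchtext's (<unk>, <pad> first) and a toy 4-d "GloVe" table.
itos = ["<unk>", "<pad>", "movie", "great", "zzqx"]
pretrained_stoi = {"movie": 0, "great": 1, "the": 2}
pretrained_vectors = torch.arange(12, dtype=torch.float32).reshape(3, 4)

torch.manual_seed(0)
word_embeddings, hits = align_vectors(itos, pretrained_stoi, pretrained_vectors)
print(word_embeddings.shape)                                    # torch.Size([5, 4])
print(f"{hits}/{len(itos)} tokens found in the pretrained table")  # 2/5 tokens found in the pretrained table
```

The resulting matrix plays the role of word_field.vocab.vectors and can go straight into nn.Embedding.from_pretrained(word_embeddings, freeze=False).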

# ----------

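for batch, batch_data in enumerate(train_loader):
    # include_lengths=True makes text_word a (padded token ids, lengths) tuple
    (text_word, text_lengths), y = batch_data.text_word, batch_data.label
    # text_lengths is available if net packs the padded sequences
    pred_probs = net(text_word)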
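Because word_field uses include_lengths=True, batch_data.text_word is a (padded_ids, lengths) tuple. Because Field's batch_first defaults to False, padded_ids has shape [seq_len, batch]. The post does not show net. GloveLSTMClassifier below is a hypothetical LSTM classifier showing one way to use those lengths with pack_padded_sequence. It is fed random ids in place of a real batch:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class GloveLSTMClassifier(nn.Module):
    """Hypothetical model; `embeddings` would be word_field.vocab.vectors."""

    def __init__(self, embeddings, hidden_size, num_classes, pad_idx=1):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(embeddings, freeze=False, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embeddings.size(1), hidden_size)  # batch_first=False, like the Field
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, text, lengths):
        # text: [seq_len, batch] token ids; lengths: [batch] true lengths before padding
        embedded = self.embedding(text)  # [seq_len, batch, emb_dim]
        # Packing makes the LSTM skip the padded positions.
        packed = pack_padded_sequence(embedded, lengths.cpu(), enforce_sorted=False)
        _, (hidden, _) = self.lstm(packed)  # hidden: [1, batch, hidden_size]
        return self.fc(hidden[-1])          # [batch, num_classes]

# Stand-in batch: 2 sequences padded to length 4 with pad index 1.
torch.manual_seed(0)
net = GloveLSTMClassifier(torch.randn(50, 300), hidden_size=64, num_classes=2)
text_word = torch.tensor([[5, 7], [8, 9], [3, 1], [4, 1]])  # [seq_len=4, batch=2]
text_lengths = torch.tensor([4, 2])
logits = net(text_word, text_lengths)
print(logits.shape)  # torch.Size([2, 2])
```

Since the iterator already uses sort_within_batch=True, batches arrive sorted by decreasing length. enforce_sorted=False is kept anyway so the model also accepts unsorted input.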



Reposted from blog.csdn.net/qq_42363032/article/details/126954614