如何优雅地使用大型词嵌入?


下载地址

在开始之前,请下载文档。
百度网盘链接(337.97MB):https://pan.baidu.com/s/1WZEGJeHBzmqs_tVFV-zBBA 提取码:8dqm


特定词语解释

低内存机器: 这里指的是内存小于32GB的计算机。

大型词嵌入: 这里指的是词嵌入文件大于15G的词嵌入文件。

腾讯词嵌入: 这里指的是腾讯发布的 AILab ChineseEmbedding,其下载地址为(解压后:15.5GB):https://ai.tencent.com/ailab/nlp/data/Tencent_AILab_ChineseEmbedding.tar.gz,其下载页面为:https://ai.tencent.com/ailab/nlp/embedding.html


解决了什么问题?

  本文解决了大型词嵌入在低性能、低内存机器上的资源耗尽的问题(ResourceExhaustedError)。

  以腾讯词嵌入为例,腾讯中文的词向量映射集在解压后有15.5G,共计有8,824,330条字词短语,内存较小的计算机显然不能直接加载,故为满足小内存、低性能的计算机的需要,特建立对词嵌入的映射关系文件,映射后只有313MB,满足了此类计算机的需求。


原理是什么?

  这里是典型的以时间换空间的方式解决在使用腾讯词嵌入的时候内存资源耗尽的问题。在词嵌入与程序之间建立一个中间的映射文件,程序通过映射文件读取词嵌入的内容,映射文件格式如下:

  程序通过词汇可以访问到对应的文件指针的起始位置以及读取长度,然后程序就可以直接访问磁盘中的对应的数据了。


怎么用?

1. 安装linecache

$ pip install linecache

2. 创建文件夹

  创建名字叫做“utils”的文件夹,里面放入“ShowProcess .py”,这个可以在文章末尾复制代码,可以直接从文中的百度网盘链接中下载。

  创建名字叫做“embeddings”的文件夹,里面放入解压好了的“Tencent_AILab_ChineseEmbedding.txt”文件,以及“ReadEmbeddings .py”,映射文件“embeddings_map_index.txt”也将在这里生成。

  目录结构如图:

2. 加载模块

from embeddings.ReadEmbeddings import ReadEmbeddings
EMB = ReadEmbeddings()

3. (可选)关键参数设置

   指定词嵌入文件位置(默认使用腾讯词嵌入)。

EMB.emb_file="要指定的词嵌入文件的位置,默认设置:embeddings/Tencent_AILab_ChineseEmbedding.txt"

   指定词嵌入词条数量(默认使用腾讯词嵌入的词条数量)。

EMB.max_count = 8824330

  如果有需要指定生成的映射文件的位置,可以在这里指定。

EMB.map_file="要指定的生成映射文件的位置,默认设置:embeddings/embeddings_map_index.txt"

4. (首次使用)创建映射文件

  这个过程需要一个小时左右,可以选择自己生成,也可以选择博主生成好了的文件。

EMB.creat_map_file()

5. 加载词嵌入映射文件

  生成完之后就可将映射文件加载进内存了,你可自行查看映射列表的内容。

map_list = []
map_list = EMB.load_map_in_memery()

6. 单个查询

  这里有提供单个词组查询的功能:

word = '你好'
value = []
value = EMB.find_by_word(map_list, word)

7. 批量查询

  这里有提供批量查询的功能:

query_list = ['这个','世界','需要','更多的','英雄']
return_dict = {}
return_dict = EMB.find_by_list(map_list, query_list)

8. 释放内存

  当不再需要映射文件时,立即释放内存。

EMB.clear_cache()

高级用法

指定编码

  在初始化的时候就可指定全局编码,在读取词嵌入以及创建映射文件的时候可以使用统一的编码。

EMB.encoding = '指定文件编码,默认设置:utf8'

启用日志

  单独查询与批量查询都具备写入日志的功能(仅记录查询失败日志)。

#--------------------------------方法一--------------------------------

log_obj = open("log_file.log", "a+",encoding=EMB.encoding) #或者自行指定encoding
# 单独查询
word = "你好"
value = EMB.find_by_word(map_list, word, f_log_obj=log_obj)

# 批量查询
query_list = ['这个','世界','需要','更多的','英雄']
return_dict = {}
return_dict = EMB.find_by_list(map_list, query_list, f_log_obj=log_obj)

# 当不再需要记录查询日志时
log_obj.close()

#--------------------------------方法二--------------------------------
with open("log_file.log", "a+", encoding=EMB.encoding) as log_obj:
    ...

启用元素删除功能

  元素删除功能在每个词汇仅查询一次的条件下才能启用,此功能在需要查询的词汇量特别大的时候会显著提升查询效率,有效减少查询时间,举个栗子:


# 比如在Keras中,你已经到了使用Tokenizer的这一步:
...

tokenizer = Tokenizer()
tokenizer.fit_on_texts(dataset_all.tolist())
vocab = tokenizer.word_index
# 比如: vocab["你好"] == 10
# 那么“你好”的排名就是10
# 比如你发现len(vocab) == 1000000
# 你就需要创建一个(1,000,001, 200)的矩阵,200这个参数是依据你词嵌入的维度来决定的。
# 你可以翻转key与word构成vocab_resverse,然后使用下面的语句

matrix[10] = EMB.find_by_word(map_list, vocab_resverse[10], is_del_element=True)

...

  由于Python的特性,此功能会影响到外部变量,受影响的外部变量:

map_list = EMB.load_map_in_memery()
A = map_list
B = A

# 这里A、B、map_list均会受到元素删除功能的影响
word = "你好"
value = EMB.find_by_word(map_list, word, is_del_element=True)
# 这里“你好”一词已经从“map_list”中移除,无法再被查询到。

#虽然已经明确说了每个词汇仅查询一次的条件,但是还是给出解决办法:
#--------------------------------方法一--------------------------------
import copy
map_list = EMB.load_map_in_memery()
A = copy.deepcopy(map_list)

# 这里仅A会受到元素删除功能的影响
word = "你好"
value = EMB.find_by_word(A, word, is_del_element=True)
# 这里“你好”一词已经从“A”中移除,无法再被A查询到,但是map_list不受影响。

#--------------------------------方法二--------------------------------
# 重新加载词嵌入映射文件
map_list = EMB.load_map_in_memery()

代码部分

ShowProcess . py

# -*- coding: UTF-8 -*-

class ShowProcess():
    """
    显示处理进度的类
    调用该类相关函数即可实现处理进度的显示
    """
    i = 0 # 当前的处理进度
    max_steps = 100 # 总共需要处理的次数
    max_arrow = 100 #进度条的长度
    infoDone = 'done'

    # 初始化函数,需要知道总共的处理次数
    def __init__(self, max_steps=max_steps, infoDone = 'Done'):
        self.max_steps = max_steps
        self.i = 0
        self.infoDone = infoDone

    # 显示函数,根据当前的处理进度i显示进度
    # 效果为[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>]100.00% 999|999
    def show_process(self, i=None):
        if i is not None:
            self.i = i
        else:
            self.i += 1
        num_arrow = int(self.i * self.max_arrow / self.max_steps) #计算显示多少个'>'
        num_line = self.max_arrow - num_arrow #计算显示多少个'-'
        percent = self.i * 100.0 / self.max_steps #计算完成进度,格式为xx.xx%
        process_bar = '[{curr_process}{incomplete}] {curr_percent:.2f}%\t{i}|{max_steps}\r'.format(
                    curr_process = '>' * num_arrow, 
                    incomplete = '-' * num_line,
                    curr_percent = percent,
                    i = i,
                    max_steps = self.max_steps)
        sys.stdout.write(process_bar) #这两句打印字符到终端
        sys.stdout.flush()
        if self.i >= self.max_steps:
            self.close()

    def close(self):
        print('')
        print(self.infoDone)
        self.i = 0

ReadEmbeddings . py

# -*- coding: UTF-8 -*-

import linecache
import copy

from utils.ShowProcess import ShowProcess

class ReadEmbeddings():
    
    def __init__(self, emb_file='embeddings/Tencent_AILab_ChineseEmbedding.txt',map_file='embeddings/embeddings_map_index.txt', \
               encoding='utf8'):
        self.map_file = map_file
        self.emb_file = emb_file
        self.encoding = encoding
        self.counter = 0
        self.max_count = 8824330
        self.process=ShowProcess(max_steps=self.max_count)
        
    def creat_map_file(self):
        '''
        当映射表损坏或丢失时,可用此函数来创建映射表,映射表为313MB(320,873KB),大约需要1h+,
        【警告】:本函数的所有代码在修改前请三思。
        '''
        emb_file_name = self.emb_file
        map_file_name = self.map_file
        start_seek = 12
        length = 0
        self.process.max_steps=self.max_count
        i = 1
        with open(emb_file_name, 'rb') as emb_file:
            next(emb_file) # 跳过第一行的无用数据。
            with open(map_file_name, 'a', encoding='utf8') as map_file:
                for e in emb_file:
                    length = len(e)
                    curr_e = str(e, encoding='utf8').split()[0]
                    curr_m = '{word} {index} {start_seek} {length}\n'.format(word=curr_e, index=i, start_seek=start_seek, length=length)
                    map_file.write(curr_m)
                    start_seek = start_seek + length
                    i += 1
                    self.process.show_process(i)   
    
    def load_map_in_memery(self):
        '''
        因为词嵌入文件过大,所以我们只需要加载它的映射就可以了
        
        返回:
            map_list -- 映射列表
        '''
        # 将映射表加载进内存
        tmp_map_list = linecache.getlines(self.map_file)
        map_list = []
        for m in tmp_map_list:
            curr_m = m.split()
            curr_list_word = [curr_m[0]]
            curr_list_value = list(map(int, curr_m[1:]))
            map_list.append(curr_list_word + curr_list_value)
        return map_list
    
    def clear_cache(self):
        '''
        当不再需要映射列表的时候,调用此函数可以立即清除缓存,释放空间。
        '''
        linecache.clearcache()
        
    def find_by_list(self, map_list, query_list ,f_log_obj=None):
        '''
        批量查询它的词嵌入权值
        参数:
            map_list -- 映射列表,可通过load_map_in_memery()函数得到。
            query_list -- 要查询的词汇的列表
            f_log_obj -- 使用open语句打开的文件对象
        返回:
            return_dict -- 字典类型,键为词汇,值为对应的词嵌入权值。
                example:{'我们':[0.238955, -0.192848, ... , 0.137744],'...':[...],...}
            
        '''
        is_log = False
        if f_log_obj:
            is_log = True
        
        query_list2 = copy.deepcopy(query_list)
        if len(query_list2) == 0:
            if is_log:
                is_log = False
                f_log_obj.write("query_list is empty!\n")
            return -1
        
        return_dict = {}
        with open(self.emb_file, 'rb') as emb:
            for m in map_list:
                for q in query_list2:
                    if q in m:
                        emb.seek(m[2])
                        value = list(map(float, emb.read(m[3]).split()[1:]))
                        return_dict[str(q)] = value
                        query_list2.remove(q)
        if len(query_list2) >= 1:
            for q in query_list2:
                #print("Waring: " + q + " not in the embeddings.")
                if is_log:
                    f_log_obj.write("未找到:{word} \n".format(word=str(q)))
                return_dict[str(q)] = [0.0]*200
                query_list2.remove(q)
                
        return return_dict
    
    def find_by_word(self, map_list, word, is_del_element=False, f_log_obj=None):
        '''
        查询它的词嵌入权值
        参数:
            map_list -- 映射列表,可通过load_map_in_memery()函数得到。
            word -- 要查询的词汇
            f_log_obj -- 使用open语句打开的文件对象
        返回:
            value -- 200维词汇向量
            
        '''
        is_log = False
        if f_log_obj:
            is_log = True
        value = []
        with open(self.emb_file, 'rb') as emb:
            for m in map_list:
                if word in m:
                    emb.seek(m[2])
                    value = list(map(float, emb.read(m[3]).split()[1:]))
                    if is_del_element:
                        map_list.remove(m)
                        f_log_obj.write("删除:{m} \n".format(m=str(m)))
                    return value
        if is_log:
            f_log_obj.write("未找到:{word} \n".format(word=word))
            return [0.0]*200
发布了57 篇原创文章 · 获赞 1690 · 访问量 76万+

猜你喜欢

转载自blog.csdn.net/u013733326/article/details/94050224