前言
正文
1. 词频统计
- 使用的是jieba分词,如果是基于字的词向量直接split()就行
def get_word_freq(file_path):
    """Count token frequencies across the train/val/test datasets.

    Each file is JSON-lines; the 'question' field of every record is
    segmented with jieba and the resulting tokens are tallied.

    Args:
        file_path: directory holding train.json, val.json and test.json.

    Returns:
        token_counter: [Counter] token -> frequency over all three files.
    """
    token_counter = Counter()
    for file_name in ['train.json', 'val.json', 'test.json']:
        full_path = os.path.join(file_path, file_name)
        with codecs.open(full_path, 'r', 'utf-8') as infs:
            for line in infs:
                record = json.loads(line.strip())
                # tally every jieba token of the question text
                token_counter.update(jieba.lcut(record['question']))
    print("*** {} words in total ***".format(len(token_counter)))
    return token_counter
2. 抽取词向量
- 最终返回分词映射到id的字典、词嵌入矩阵
- 未知(unk)和补全(pad)字符的index分别为0和1,词向量用全0表示
def get_embedding(embed_path, token_counter, freq_threshold, embed_dim):
    """Load pretrained vectors for every sufficiently frequent token.

    <unk> and <pad> receive ids 0 and 1 and all-zero vectors; tokens that
    have a pretrained vector are numbered from 2 upward.

    Args:
        embed_path    : path of the pretrained embedding file.
        token_counter : [dict] token -> frequency (from get_word_freq).
        freq_threshold: [int] tokens below this frequency are not extracted.
        embed_dim     : [int] dimensionality of each embedding vector.

    Returns:
        token2id : [dict] token -> integer id.
        embed_mat: [ListOfList] embedding matrix indexed by id.
    """
    embed_dict = {}
    filtered_elements = [tok for tok, cnt in token_counter.items()
                         if cnt >= freq_threshold]
    with codecs.open(embed_path, 'r', 'utf-8') as infs:
        for line in infs:
            fields = line.strip().split()
            # everything before the last embed_dim fields is the token,
            # joined without separators (tokens may have split on spaces)
            token = ''.join(fields[:-embed_dim])
            count = token_counter.get(token)
            if count is not None and count >= freq_threshold:
                embed_dict[token] = [float(v) for v in fields[-embed_dim:]]
    print("{} / {} tokens have corresponding embedding vector".format(
        len(embed_dict), len(filtered_elements)))
    unk = "<unk>"
    pad = "<pad>"
    # real tokens are numbered starting at 2; 0/1 are reserved below
    token2id = {}
    for idx, token in enumerate(embed_dict, 2):
        token2id[token] = idx
    token2id[unk] = 0
    token2id[pad] = 1
    embed_dict[unk] = [0.0] * embed_dim
    embed_dict[pad] = [0.0] * embed_dim
    id2embed = {idx: embed_dict[token] for token, idx in token2id.items()}
    embed_mat = [id2embed[idx] for idx in range(len(id2embed))]
    return token2id, embed_mat
3. 调用
- 如果没有缓存,则调用之前两个函数,根据训练、验证、测试集进行词向量抽取
- 如果有缓存,直接读取缓存
- 缓存读写使用pickle进行序列化和反序列化
def load_word_embedding(embed_path, token2id_cache='data/token2id_cache.pkl', embed_mat_cache='data/embed_mat_cache.pkl'):
    """Return token2id and embed_mat, using pickle caches when available.

    Args:
        embed_path     : path of the pretrained embedding file.
        token2id_cache : pickle cache path for token2id.
        embed_mat_cache: pickle cache path for the embedding matrix.

    Returns:
        token2id : [dict] token -> id.
        embed_mat: [ListOfList] embedding matrix.
    """
    cache_ready = (os.path.exists(token2id_cache)
                   and os.path.exists(embed_mat_cache))
    if not cache_ready:
        print('*** generating token2id and id2embed from datasets ***')
        token_counter = get_word_freq('data/')
        token2id, embed_mat = get_embedding(
            embed_path, token_counter, freq_threshold=2, embed_dim=300)
        # persist both structures so the next run can skip extraction
        with open(token2id_cache, 'wb') as outf:
            pickle.dump(token2id, outf)
        with open(embed_mat_cache, 'wb') as outf:
            pickle.dump(embed_mat, outf)
    else:
        print('*** load token2id and id2embed from cache ***')
        with open(token2id_cache, 'rb') as inf:
            token2id = pickle.load(inf)
        with open(embed_mat_cache, 'rb') as inf:
            embed_mat = pickle.load(inf)
    return token2id, embed_mat