0.说明
本系列笔记用于记录NeuralNLP-NeuralClassifier源码精读,此篇笔记是对get_data_loader()
函数的精读记录
附上train()
函数
1.def train(conf):
2. logger = util.Logger(conf)
3. if not os.path.exists(conf.checkpoint_dir): # 用来保存模型
4. os.makedirs(conf.checkpoint_dir)
5. model_name = conf.model_name # FastText
6. dataset_name = "ClassificationDataset"
7. collate_name = "FastTextCollator" if model_name == "FastText" \
else "ClassificationCollator"
8. train_data_loader, validate_data_loader, test_data_loader = \
get_data_loader(dataset_name, collate_name, conf) # 数据预处理,获取DataLoader类对象
# 是一个ClassificationDataset对象,只执行了__init__函数,加载了{key: index}和{index: key}
# 有__getitem__函数,可以用[]调用
# {key: index}和{index: key}两种字典不为空,调用__getitem__函数时返回空
9. empty_dataset = globals()[dataset_name](conf, [])
10. model = get_classification_model(model_name, empty_dataset, conf) # 设置模型
11. loss_fn = globals()["ClassificationLoss"](
label_size=len(empty_dataset.label_map), loss_type=conf.train.loss_type) # 设置损失函数 BCEWITHLOGITS
12. optimizer = get_optimizer(conf, model) # 设置优化器ADAM
13. evaluator = cEvaluator(conf.eval.dir) # 设置计算准确率的各项指标
14. trainer = globals()["ClassificationTrainer"](
empty_dataset.label_map, logger, evaluator, conf, loss_fn) # 有准确率和损失函数
15. best_epoch = -1
16. best_performance = 0
17. model_file_prefix = conf.checkpoint_dir + "/" + model_name
18. for epoch in range(conf.train.start_epoch,
conf.train.start_epoch + conf.train.num_epochs): # 迭代训练
19. start_time = time.time()
20. trainer.train(train_data_loader, model, optimizer, "Train", epoch)
21. trainer.eval(train_data_loader, model, optimizer, "Train", epoch)
22. performance = trainer.eval(
validate_data_loader, model, optimizer, "Validate", epoch) # 计算准确率的各项指标,返回fscore_list
23. trainer.eval(test_data_loader, model, optimizer, "test", epoch)
24. if performance > best_performance: # record the best model
25. best_epoch = epoch
26. best_performance = performance
27. save_checkpoint({
'epoch': epoch,
'model_name': model_name,
'state_dict': model.state_dict(),
'best_performance': best_performance,
'optimizer': optimizer.state_dict(),
}, model_file_prefix)
28. time_used = time.time() - start_time
29. logger.info("Epoch %d cost time: %d second" % (epoch, time_used))
# best model on validateion set
30. best_epoch_file_name = model_file_prefix + "_" + str(best_epoch)
31. best_file_name = model_file_prefix + "_best"
32. shutil.copyfile(best_epoch_file_name, best_file_name)
33. load_checkpoint(model_file_prefix + "_" + str(best_epoch), conf, model, optimizer)
34. trainer.eval(test_data_loader, model, optimizer, "Best test", best_epoch)
3.1 get_data_loader()
函数
train()
函数第8行,使用get_data_loader()
函数加载训练集、验证集和测试集的数据,用来PyTorch的数据读取,PyTorch的数据读取主要包含Dataset、DataLoader、DataLoaderIter三个类,这三者大致是一个依次封装的关系: DataLoaderIter封装DataLoader,DataLoader又封装Dataset。本项目自定义了Dataset类。
进函数细看
def get_data_loader(dataset_name, collate_name, conf):
"""Get data loader: Train, Validate, Test
:param dataset_name: "ClassificationDataset"
:param collate_name: "FastTextCollator"
"""
train_dataset = globals()[dataset_name](
conf, conf.data.train_json_files, generate_dict=True)
collate_fn = globals()[collate_name](conf, len(train_dataset.label_map))
train_data_loader = DataLoader(
train_dataset, batch_size=conf.train.batch_size, shuffle=True,
num_workers=conf.data.num_worker, collate_fn=collate_fn,
pin_memory=True)
... # 对验证集、测试集的处理同训练集,只是传入DataLoader的参数不同,collate_fn用的同一个
return train_data_loader, validate_data_loader, test_data_loader
概括来讲,函数主要有三个部分:train_dataset
返回一个batch的所有数据,collate_fn
将所有数据封装到一起,后面就可以用data_loader
得到每一个batch的数据
globals()[dataset_name]()
这是一种动态调用类的方式,方括号里传入字符串形式的类名,后面圆括号中是该类构造函数的参数,globals()
函数会以字典的形式返回该类的全部全局变量,后面就可以用a.b
的形式返回实例a的全局变量b,代码执行的时候会调用该类的构造函数,完成对象的初始化。
3.11 ClassificationDataset
类
这里,碰壁的地方,没见过这个用法,代入传入的参数
train_dataset = globals()["ClassificationDataset"](
conf, conf.data.train_json_files, generate_dict=True)
generate_dict=True
表明需要重新建立token等特征与下标index的对应关系,项目中将特征与index的对应关系存储在.dict文件中,相同数据在执行完第一遍后,可将该参数设置为generate_dict=False
,这样就不会再重新处理一遍token等特征,会节省一些运行时间。
train_dataset
中的每条数据最终会被处理成以下格式:数字就是token对应的下标
{
'doc_label': [23, 4],
'doc_token': [14878, 2193, ...],
'doc_char': [3488, 254, 41, ...], # 不限制 token 长度
'doc_char_in_token': [[2415, 3488], [254, 41], ...], # 限制token的最长长度(4)
'doc_token_ngram': [0],
'doc_keyword': [0],
'doc_topic': [0]
}
__init__()
通过上面的动态类,调用自定义的ClassificationDataset
类的构造函数,该构造函数只是继承其父类DatasetBase
类的构造函数,父类继承自抽象类torch.utils.data.Dataset
,最终得到token、char等特征与index的对应关系,在该构造函数中
-
self._init_dict()
函数初始化统计数据,在父类DatasetBase
中,该函数只有1行代码raise NotImplementedError
,只抛出了一个未实现的异常,这表明该函数是预留函数,必须在子函数中实现,否则将会报错,所以需要去子类ClassificationDataset
中找该函数的具体实现:
函数中定义了很多列表、字典等,从第一行就可以看出,模型使用的信息包括文档标签label,分词token,字符char,token_ngram,关键词keyword,主题词topic六类数据,后面保存了一些配置文件中的配置信息,接着定义统计数据所使用的字典。self.label_map
等统计所有{特征:出现次数}
,self.label_count_list
等保存最终需要的的特征,self.id_to_label_map
保存{index: label}
的对应关系。 -
for i, json_file in enumerate(json_files)
:用这个for循环读取每个文件的每一条数据,并统计数据集大小
tell()
函数返回文件的当前位置 -
generate_dict=True
:如果前面动态调用类里传入的generate_dict
参数为True,则开始下面的步骤:
(1)统计每个特征的出现次数,根据data.generate_dict_using_json_files
等配置信息决定参与训练的数据只有训练集还是也包括验证集和测试集。通过内置函数_insert_vocab
逐条处理所有数据特征,该函数的具体处理函数self._insert_vocab()
的实现也在其子类ClassificationDataset
中,self._insert_vocab()
又调用了父类的一些函数来进行最终的计数操作。在对token进行处理的同时,一方面会同时统计char,另一方面若配置文件中feature.token_ngram>1
则会同时统计token_ngram:假设feature.token_ngram=3
,token为['感谢', '司机', '大哥', '行程', '安排']
,则最终得到token_ngram_map = ['感谢司机', '司机大哥', '大哥行程', '行程安排', '感谢司机大哥', '司机大哥行程', '大哥行程安排']
(2)若有预训练好的embedding文件,则加载到相应的字典里
(3)self._shrink_dict()
:处理self.label_map
中的数据,将其中数据根据特征出现次数逆序,只留下数量大于设置的最小值feature.min_token_count
且出现次数在topfeature.max_token_dict_size
的特征,保存在self.label_count_list
中
(4)self._save_dict()
:将self.label_count_list
中的数据写入到.dict文件中,同时构建self.id_to_label_map[index] = label
字典 -
self._load_dict()
:加载.dict文件中的数据。最终得到两组对应字典:self.label_map = {key: index}
和self.id_to_label_map = {index: key}
,这里的key指的就是前面六类数据。
这里用到了self.VOCAB_UNKNOWN=1
,self.VOCAB_PADDING=0
,self.VOCAB_PADDING_LEARNABLE=2
三种标记词,self.VOCAB_UNKNOWN=1
用于标记在训练数据中未出现过的未登陆词;self.VOCAB_PADDING=0
用于对长度不足的数据进行填充,因为我们在训练时要保证每一条数据的维度是相同的,填充为0则表明该位是不可学习的,因为在反向传播时求得的导数始终为0;self.VOCAB_PADDING_LEARNABLE=2
也是用来填充,不同点是该位是可学习的。
动态调用的ClassificationDataset
类的构造函数执行完毕
__getitem__()
所有自定义的Dataset都需要继承并且实现抽象类torch.utils.data.Dataset
的__getitem__()
方法,来定义如何取数据,实现__getitem__()
方法意味着,它的实例对象(假设为P)可以以P[index]的形式取值。__getitem__()
方法的具体实现如下:
def __getitem__(self, idx):
if idx >= self.sample_size: # self.sample_size = 373883
raise IndexError
index = self.sample_index[idx]
with open(self.files[index[0]]) as fin:
fin.seek(index[1])
json_str = fin.readline()
return self._get_vocab_id_list(json.loads(json_str))
其中,self.sample_index = [[], []]
是一个存放数据的列表,该列表中的元素又是一个个列表,元素列表是二维的,第一个元素记录当前数据属于第几个文件(配置文件中data.train_json_files
配置项是列表类型的,意味着可以使用多个文件作为训练集),第二个元素记录当前数据在文件中的起始位置。
当调用该函数时,会打开指定的文件,定位到所需数据的起始位置,再返回self._get_vocab_id_list()
函数处理过的数据格式。
父类DatasetBase
类中未对数据做任何处理,直接原样返回。
子类ClassificationDataset
类使用__init__()
函数中加载的self.label_map = {key: index}
对应关系字典将原始数据中的token等特征转换为其对应的下标,具体实现如下:
def _get_vocab_id_list(self, json_obj):
"""Use dict to convert all vocabs to ids
"""
doc_labels = json_obj[self.DOC_LABEL]
doc_tokens = \
json_obj[self.DOC_TOKEN][0:self.config.feature.max_token_len]
doc_keywords = json_obj[self.DOC_KEYWORD]
doc_topics = json_obj[self.DOC_TOPIC]
token_ids, char_ids, char_in_token_ids, token_ngram_ids = \
self._token_to_id(doc_tokens, self.token_map, self.char_map,
self.config.feature.token_ngram,
self.token_ngram_map,
self.config.feature.max_char_len,
self.config.feature.max_char_len_per_token) # 将具体的token用其对应的index代替
return {self.DOC_LABEL: self._label_to_id(doc_labels, self.label_map) if self.model_mode != ModeType.PREDICT else [0],
self.DOC_TOKEN: token_ids,
self.DOC_CHAR: char_ids,
self.DOC_CHAR_IN_TOKEN: char_in_token_ids,
self.DOC_TOKEN_NGRAM: token_ngram_ids,
self.DOC_KEYWORD: self._vocab_to_id(doc_keywords, self.keyword_map),
self.DOC_TOPIC: self._vocab_to_id(doc_topics, self.topic_map)}
从token中可以提取出char,也可以拼接出token_ngram,对于未出现过的char、token会用未登陆词标志self.VOCAB_UNKNOWN
代替,对于空字符串或ngram<=1
的情况,则会用不可学习填充标志self.VOCAB_PADDING
代替,最终返回对应的id列表
(这里有行代码没看懂,char_id_list.extend(char_id[0:max_char_sequence_length]),max_char_sequence_length对应配置文件中的feature.max_char_len,char_len的长度难道不是1?如果你知道,欢迎评论区交流讨论,谢谢)
假设P为ClassificationDataset
类的实例,当你调用P[index]时得到的数据格式为
{
'doc_label': [23, 4],
'doc_token': [14878, 2193, ...],
'doc_char': [3488, 254, 41, ...], # 不限制 token 长度
'doc_char_in_token': [[2415, 3488], [254, 41], ...], # 限制token的最长长度(4)
'doc_token_ngram': [0],
'doc_keyword': [0],
'doc_topic': [0]
}
至此,ClassificationDataset
类完成了它的使命。
3.12 FastTextCollator
类
该类的作用是打包batch,将每一个batch的所有数据整合到一个大的列表中,再将所有列表转换成torch.tensor
类型的数据,构建成字典返回。
代入传入的参数得
collate_fn = globals()["FastTextCollator"](conf, len(train_dataset.label_map))
__init__()
通过上面的动态类,调用FastTextCollator
类的构造函数,该类只重写了__call__()
函数,构造函数需要继续调用其父类ClassificationCollator
类的,
ClassificationCollator
类的self._init_()
函数
加载一些配置文件中的配置项__call__()
函数:重写__call__()
函数意味着之后可以将其实例当成函数来调用。参数为batch
,传进来的是批量的数据,每一条数据的格式与上面调用P[index]时得到的数据格式相同,整个batch的数据都会加入到同一个大的列表中。
首先,对batch中的每一条数据依次执行:
(1)添加label:如果是单标签训练,则所有数据的标签构成一个列表;如果是多标签训练,则每条数据的标签构成一个小列表,这些小列表再构成一个大的列表;
(2)添加token:_append_vocab()
函数是定义在__call__()
函数内部的嵌套函数,该函数只考虑数据中的所有非未登陆词,记录进相应的列表中;
(3)同理添加token_ngram、keywords、topics。
接着,将label的格式转换成torch.tensor
,对多标签训练数据,需要通过函数self._get_multi_hot_label()
将每条数据的label转换成one-hot
的形式,函数不是很好理解,可以通过注释中的例子理解:
def _get_multi_hot_label(self, doc_labels):
"""For multi-label classification 转成one-hot形式
Generate multi-hot for input labels
e.g. input: [[0,1], [2]]
output: [[1,1,0], [0,0,1]]
"""
batch_size = len(doc_labels) # 2
max_label_num = max([len(x) for x in doc_labels]) # 2
doc_labels_extend = \
[[doc_labels[i][0] for x in range(max_label_num)] for i in range(batch_size)] # [[0, 0], [2, 2]],将所有子列表转换成同一纬度
for i in range(0, batch_size):
doc_labels_extend[i][0: len(doc_labels[i])] = doc_labels[i] # [[0, 1], [2, 2]],保证每条数据中前面的数据与输入相同
y = torch.Tensor(doc_labels_extend).long() # [[0, 1], [2, 2]],转换成torch
# scatter_函数说明:
# 第一个1代表沿着y轴按行看(如果是0则沿着x轴按列看
# y是每一行待填充数据的位置
# 最后一个1代表填充的数据是1,也可以写成torch.Tensor的形式填充不同的值
y_onehot = torch.zeros(batch_size, self.label_size).scatter_(1, y, 1) # [[1., 1., 0.], [0., 0., 1.]]
return y_onehot
最后,构建成字典返回,字典中的每个value值都转换成torch.tensor
类型。
至此,FastTextCollator
类完成了它的使命。
3.13 实例化DataLoader
对象
DataLoader
是PyTorch自带的类,传入相应参数即可实例化一个DataLoader
对象。
参数data.num_worker>1
即为使用多线程读取数据。
在该类的__init__()
函数中定义了一堆成员变量,通过__iter__()
函数传进_MultiProcessingDataLoaderIter()
类中,然后就可以通过_MultiProcessingDataLoaderIter()
类中的__next__()
函数读取一个个batch了
同样的函数处理训练集、验证集、测试集,最终返回train_data_loader, validate_data_loader, test_data_loader