在文本预处理一节,介绍了如何利用torchtext读取tsv格式的文本数据。对于分类问题,这是足够的。但是在处理如NER和机器翻译等问题时,我们构造的输入通常就不是(类别,序列)这样的结构了,而是(序列,序列)。另一方面,在搭建混合网络时,有时我们希望能够给模型多个输入(例如cnn-bilstm-crf中,既需要字符又需要单词输入),这超过了tsv所能。因此要另辟蹊径。
尽管Torchtext封装了一个SequenceTaggingDataset类用于构造NER数据,但是在实际使用中发现,十分不方便生成batch。
json格式是采取字典的方式存储数据,这带来了很大的灵活性。但是关于其完整使用的相关文章非常有限,官方文档也未给出详细案例(不得不吐槽一下torchtext的说明文档)。
使用
下面以NER任务为例, 进行json使用说明,以抛砖引玉。事实上,这种方法可以无缝地拓展到其他的任务上去,如文本分类,机器翻译等。
对于NER任务,标签(target)和输入(source)都是同样长度的序列。例如:
source:
人 民 网 1 月 1 日 讯 据 《 纽 约 时 报 》 报 道 , 美 国 华 尔 街 股 市 在 2 0 1 3 年 的 最 后 一 天 继 续 上 涨 , 和 全 球 股 市 一 样 , 都 以 最 高 纪 录 或 接 近 最 高 纪 录 结 束 本 年 的 交 易 。
target:
O O O B_T I_T I_T I_T O O O B_LOC I_LOC O O O O O O B_LOC I_LOC I_LOC I_LOC I_LOC O O O B_T I_T I_T I_T I_T O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O
文本预处理
一般来说,为了节省内存,会先将其做个字符-ID映射(这里只是举例说明,和上面不是同一段文本):
source:
5627 5580 5550 5636 4509 5192 5466 5463 5624 5520 4871 5637 5607 5411 5313 5251 5528 5628 5580 5612 5292 5636 5626 5637 4810
target:
9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 1
构造Json格式的文件
因为我们最终是通过torchtext来生成batch,所以首先要把数据存储为torchtext能够读取的json格式。值得注意的是,torchtext能够读取的json文件和我们一般意义上的json文件格式是不同的(这也是比较坑的地方),我们需要把上面的数据处理成如下格式:
{"source": "10 111 2 3", "target": "1 1 2 2"}
{"source": "10 111 2 3", "target": "1 1 2 2"}
{"source": "10 111 2 3", "target": "1 1 2 2"}
{"source": "10 111 2 3", "target": "1 1 2 2"}
{"source": "10 111 2 3", "target": "1 1 2 2"}
可以看到,里面的内容和通常的Json并无区别,每个字段采用字典的格式存储。不同的是,多个json序列中间是以换行符隔开的,而且最外面没有列表。
那么怎么构造这样的数据呢?
我们知道python中的json模块是用来处理json文件的,但是对所有序列进行json.dump()后的结果并非我们想要的(序列之间不是以换行符隔开的),几经尝试,找到如下的处理方式,仅做参考,可能还有更好的处理方法:
with open(config.TRAIN_FILE, 'w') as fw:
for sent, label in train:
sent = ' '.join([str(w) for w in sent])
label = ' '.join([str(l) for l in label])
df = {
"source": sent, "target": label}
encode_json = json.dumps(df)
# 一行一行写入,并且采用print到文件的方式
print(encode_json, file=fw)
这里采用的是一行一行进行dumps, 然后print到文件,就能得到我们想要的格式了。
接下来就是使用torchtext读取了,这个和之前处理tsv文件并无太大差异。
Torchtext读取
这里和之前处理tsv类似,不多赘述,只是将几个不同的点提出来说一下。
(1)Field的定义,这里source和target都是序列,因此两个字段的定义方式基本相同
(2)传入TabularDataset的fields和tsv的定义有所不同,这里定义成字典-元组格式
(3) TabularDataset的format要指定成json格式
(4)pad_token根据需要,一般来说source使用0作为padding, target使用-1进行padding
def create_dataset(self):
SOURCE = Field(sequential=True, tokenize=x_tokenize,
use_vocab=False, batch_first=True,
fix_length=self.fix_length, # 如需静态padding,则设置fix_length, 但要注意要大于文本最大长度
eos_token=None, init_token=None,
include_lengths=True, pad_token=0)
TARGET = Field(sequential=True, tokenize=x_tokenize,
use_vocab=False, batch_first=True,
fix_length=self.fix_length, # 如需静态padding,则设置fix_length, 但要注意要大于文本最大长度
eos_token=None, init_token=None,
include_lengths=False, pad_token=-1)
fields = {
'source': ('source', SOURCE), 'target': ('target', TARGET)}
train, valid = TabularDataset.splits(
path=config.ROOT_DIR,
train=self.train_path, validation=self.valid_path,
format="json",
skip_header=False,
fields=fields)
return train, valid
def get_iterator(self, train, valid):
train_iter = BucketIterator(train,
batch_size=self.batch_size,
device = torch.device("cpu"), # cpu by -1, gpu by 0
sort_key=lambda x: len(x.source), # field sorted by len
sort_within_batch=True,
repeat=False)
val_iter = BucketIterator(valid,
batch_size=self.batch_size,
device=torch.device("cpu"), # cpu by -1, gpu by 0
sort_key=lambda x: len(x.source), # field sorted by len
sort_within_batch=True,
repeat=False)
return train_iter, val_iter