Torchtext读取JSON数据

在文本预处理一节,介绍了如何利用torchtext读取tsv格式的文本数据。对于分类问题,这是足够的。但是在处理如NER和机器翻译等问题时,我们构造的输入通常就不是(类别,序列)这样的结构了,而是(序列,序列)。另一方面,在搭建混合网络时,有时我们希望能够给模型多个输入(例如cnn-bilstm-crf中,既需要字符又需要单词输入),这超过了tsv所能。因此要另辟蹊径。

尽管Torchtext封装了一个SequenceTaggingDataset类用于构造NER数据,但是在实际使用中发现,十分不方便生成batch。

json格式是采取字典的方式存储数据,这带来了很大的灵活性。但是关于其完整使用的相关文章非常有限,官方文档也未给出详细案例(不得不吐槽一下torchtext的说明文档)。

使用

下面以NER任务为例, 进行json使用说明,以抛砖引玉。事实上,这种方法可以无缝地拓展到其他的任务上去,如文本分类,机器翻译等。

对于NER任务,标签(target)和输入(source)都是同样长度的序列。例如:

source: 
人 民 网 1 月 1 日 讯 据 《 纽 约 时 报 》 报 道 , 美 国 华 尔 街 股 市 在 2 0 1 3 年 的 最 后 一 天 继 续 上 涨 , 和 全 球 股 市 一 样 , 都 以 最 高 纪 录 或 接 近 最 高 纪 录 结 束 本 年 的 交 易 。
target:
O O O B_T I_T I_T I_T O O O B_LOC I_LOC O O O O O O B_LOC I_LOC I_LOC I_LOC I_LOC O O O B_T I_T I_T I_T I_T O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O

文本预处理

一般来说,为了节省内存,会先将其做个字符-ID映射(这里只是举例说明,和上面不是同一段文本):

source:
5627 5580 5550 5636 4509 5192 5466 5463 5624 5520 4871 5637 5607 5411 5313 5251 5528 5628 5580 5612 5292 5636 5626 5637 4810
target:
9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 1

构造Json格式的文件

因为我们最终是通过torchtext来生成batch,所以首先要把数据存储为torchtext能够读取的json格式。值得注意的是,torchtext能够读取的json文件和我们一般意义上的json文件格式是不同的(这也是比较坑的地方),我们需要把上面的数据处理成如下格式:

{"source": "10 111 2 3", "target": "1 1 2 2"}
{"source": "10 111 2 3", "target": "1 1 2 2"}
{"source": "10 111 2 3", "target": "1 1 2 2"}
{"source": "10 111 2 3", "target": "1 1 2 2"}
{"source": "10 111 2 3", "target": "1 1 2 2"}

可以看到,里面的内容和通常的Json并无区别,每个字段采用字典的格式存储。不同的是,多个json序列中间是以换行符隔开的,而且最外面没有列表。

那么怎么构造这样的数据呢?

我们知道python中的json模块是用来处理json文件的,但是对所有序列进行json.dump()后的结果并非我们想要的(序列之间不是以换行符隔开的),几经尝试,找到如下的处理方式,仅做参考,可能还有更好的处理方法:

with open(config.TRAIN_FILE, 'w') as fw:
    for sent, label in train:
        sent = ' '.join([str(w) for w in sent])
        label = ' '.join([str(l) for l in label])
        df = {
    
    "source": sent, "target": label}
        encode_json = json.dumps(df)
        # 一行一行写入,并且采用print到文件的方式
        print(encode_json, file=fw)

这里采用的是一行一行进行dumps, 然后print到文件,就能得到我们想要的格式了。

接下来就是使用torchtext读取了,这个和之前处理tsv文件并无太大差异。

Torchtext读取

这里和之前处理tsv类似,不多赘述,只是将几个不同的点提出来说一下。

(1)Field的定义,这里source和target都是序列,因此两个字段的定义方式基本相同

(2)传入TabularDataset的fields和tsv的定义有所不同,这里定义成字典-元组格式

(3) TabularDataset的format要指定成json格式

(4)pad_token根据需要,一般来说source使用0作为padding, target使用-1进行padding

def create_dataset(self):
    SOURCE = Field(sequential=True, tokenize=x_tokenize,
                 use_vocab=False, batch_first=True,
                 fix_length=self.fix_length,   #  如需静态padding,则设置fix_length, 但要注意要大于文本最大长度
                 eos_token=None, init_token=None,
                 include_lengths=True, pad_token=0)

    TARGET = Field(sequential=True, tokenize=x_tokenize,
                 use_vocab=False, batch_first=True,
                 fix_length=self.fix_length,   #  如需静态padding,则设置fix_length, 但要注意要大于文本最大长度
                 eos_token=None, init_token=None,
                 include_lengths=False, pad_token=-1)

    fields = {
    
    'source': ('source', SOURCE), 'target': ('target', TARGET)}

    train, valid = TabularDataset.splits(
        path=config.ROOT_DIR,
        train=self.train_path, validation=self.valid_path,
        format="json",
        skip_header=False,
        fields=fields)
    return train, valid


def get_iterator(self, train, valid):
    train_iter = BucketIterator(train,
                                batch_size=self.batch_size,
                                device = torch.device("cpu"),  # cpu by -1, gpu by 0
                                sort_key=lambda x: len(x.source), # field sorted by len
                                sort_within_batch=True,
                                repeat=False)
    val_iter = BucketIterator(valid,
                                batch_size=self.batch_size,
                                device=torch.device("cpu"),  # cpu by -1, gpu by 0
                                sort_key=lambda x: len(x.source),  # field sorted by len
                                sort_within_batch=True,
                                repeat=False)
    return train_iter, val_iter

猜你喜欢

转载自blog.csdn.net/weixin_43896398/article/details/85559172
今日推荐