话不多说,由于直接下载glue data需要科学上网,尤其是MRPC的数据需要下载dev_ids文件,但是scripts里面的url是有问题的,笔者也是用了各种方法之后才把数据下载并用hub上的代码process好:
glue data百度云(包含原始data以及process后的train/dev/test.tsv文件):
提取码:nfim
但是————我特喵的这tsv文件格式明显有问题啊!
按道理,他的tsv文件读出来每一行应该都有固定个数的cell,分别有id,seq1,seq2,label等信息,这样很简单的一行行读取处理数据,但是有几个数据的tsv文件格式并不是这样规整的。
就像上面这种,一个cell里面藏了好几条数据,这是因为每一行的分割‘\n’被当作str对象,而没有被当作文件流读取的分割对象,所以好几条id,seq都挤到一行的某个seq字段去了
不过这种情况毕竟是少数,本身数据量打,有几千几万条,也不差这么几十条训练验证数据吧,所以一开始我写代码偷了个懒:
error_cnt=0
with open("./train.tsv","r",encoding="utf-8") as train:
t=csv.reader(train,delimiter='\t')
for i,line in enumerate(t):
print(list(line))
try:
assert len(line) == 4
except:
print("error line,skip")
error_cnt += 1
lb=line[1]
if lb not in labels:
labels.append(lb)
label = labels.index(lb)
instance = dict()
instance["index"],instance['seq1'],instance['seq2'],instance["label"]=i,line[3],None,label
train_file.append(instance)
眼不见为净…但是后来突然发现,这拟码是glue board啊,train/dev少了不要紧,test data少了咋submit到leader board啊…
所以还是写了一下处理的代码:
def break_up(error_str:str)->list:
ans = []
lines=error_str.strip().split("\n")
for line in lines:
res=line.split("\t")
ans += res
return ans
def fix_glue_error(data:list,stand:int):
ans = []
for element in data:
res=break_up(element)
ans += res
assert len(ans) % stand == 0
new_data = []
for i in range(len(ans)//stand):
new_element=ans[i*stand:(i+1)*stand]
new_data.append(new_element)
return new_data
其实仔细一分析,处理这种格式问题很简单,不就是"\n","\t"没被reader认出来吗,那就手动split就行了。上面的代码就是把data
这个list传入(也就是原先read出来的每一行line
),然后给他全部split(’\n’)+split("\t"),也就是全部打散。然后依据实现已知的stand
(每一行理应有的字段数),给他重组成新的line
with open("./train.tsv","r",encoding="utf-8") as train:
t=csv.reader(train,delimiter='\t')
for i,line in enumerate(t):
if i==0:
print("skip first line")
continue
# print(list(line))
# try:
# assert len(line) == 4
# except:
# # print("error_line:",line)
new_lines=fix_glue_error(line,4)
for i,new_line in enumerate(new_lines):
lb=new_line[3]
if lb not in labels:
labels.append(lb)
label = labels.index(lb)
instance = dict()
instance["index"],instance['seq1'],instance['seq2'],instance["label"]=int(new_line[0]),new_line[1],new_line[2],label
train_file.append(instance)
print(len(train_file))
然后可以校对一下自己处理的glue data的instance数目是否正确(尤其是test):
GLUE基准数据集介绍
这位老兄应该也是自己proces的,因为我看他有几个数据的indtance数目是不对的,仅作参考吧,纠结的可以看原论文校对,或者提交leader board(数量不对它会有提示)