ChatGLM多轮对话微调-多轮对话训练数据的自动生成(标注)

        通常使用大模型进行业务数据微调的时候,需要对历史对话数据进行细粒度的整理,比如:1-3轮对话数据的微调,以便模型能够学会多轮对话。以ChatGLM为例,微调对话任务的时候,微调会导致模型的理解能力别削弱(无法理解相似语义的输入),即当输入数据prompt的分布与训练数据分布不一致时,模型不会按照训练集的response进行输出,而是使用模型原有的能力进行输出,模型输出结果出现不可控的情况。这个时候需要对输入的数据进行数据增强,数据的方法很多,但个人认为对于样本比较少的对话,最有效的方式应该是人工进行标注,即人工写出输入数据prompt的各种可能的语义相似的样本来(根据对数据增强方式的理解,如:释义、采样和加噪),有人说数据增强的方式怎么做也无法与人工标注的效果相比,只适合于写论文,这里不做评价和扩展。仅针对多轮对话进行1-3轮的对话数据自动标注说明。

       假定历史对话的格式为:

#test.txt
坐席:Y0
客户:X0
坐席:Y1
客户:X1
.....
坐席:Yn
客户:Xn

        说明:1轮指的是n=0时,坐席和客户说的话作为输入,3轮指的是n=2时,坐席和客户说的话作为输入。

        最多3轮,个人认为1-3轮的叠加能解决大部分场景的多轮对话的问题。

1.读取历史对话文本test.txt

import pandas as pd
data =[]
file_name = 'test'
with open(f'{file_name}.txt') as f:
    data = f.readlines()
print(data)

2.自动生成1-3轮对话标注

#to ChatGLM格式
lines=[]
prompt=''
for i,row in enumerate(data):
    if i>0 and i%2==0:
        temps = data[i-3:i-1]
        if len(temps) == 2:
            history = [[temps[0],temps[1]]]
        else:
            history = [['','']]
        lines.append({"prompt":prompt.replace('\n',''),"response":row.replace('\n',''),"history":history})
        prompt = row
    else:
        prompt = row
    if i==len(data)-1:
        prompt = ''
prompt=''
for i,row in enumerate(data):
    if i>0 and i%2==0:
        temps = data[i-5:i-1]
        if len(temps) == 4:
            history = [[temps[0],temps[1]],[temps[2],temps[3]]]
        else:
            history = [['','']]
        lines.append({"prompt":prompt.replace('\n',''),"response":row.replace('\n',''),"history":history})
        prompt = row
    else:
        prompt = row
    if i==len(data)-1:
        prompt = ''
prompt=''
for i,row in enumerate(data):
    if i>0 and i%2==0:
        temps = data[i-7:i-1]
        if len(temps) == 6:
            history = [[temps[0],temps[1]],[temps[2],temps[3]],[temps[4],temps[5]]]
        else:
            history = [['','']]
        lines.append({"prompt":prompt.replace('\n',''),"response":row.replace('\n',''),"history":history})
        prompt = row
    else:
        prompt = row
    if i==len(data)-1:
        prompt = ''
print(lines)

3.显示自动标注的结果

df = pd.DataFrame(lines)
df

4.保存生成的标注数据

import json
with open ('train.json','w') as f:
    json.dump(lines,f)

猜你喜欢

转载自blog.csdn.net/wxl781227/article/details/131005577
今日推荐