Fine-tuning NeMo Chinese/English ASR Models in Practice

1. Install NeMo

pip install -U nemo_toolkit[all] ASR-metrics

2. Download the pretrained ASR model to a local directory (Hugging Face is recommended; it is much faster than the NVIDIA site)
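
If manual downloading is inconvenient, NeMo can also fetch a published checkpoint by name and cache it locally. A minimal sketch; the model name "stt_zh_quartznet15x5" is assumed here and may differ on NGC/Hugging Face:

import nemo.collections.asr as nemo_asr

# Download and cache the pretrained Mandarin QuartzNet checkpoint by name (assumed name),
# then keep a local .nemo copy so the later steps can restore from disk.
asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name="stt_zh_quartznet15x5")
asr_model.save_to("stt_zh_quartznet15x5.nemo")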

3. Create the ASR model from the local checkpoint

import nemo.collections.asr as nemo_asr

asr_model = nemo_asr.models.EncDecCTCModel.restore_from("stt_zh_quartznet15x5.nemo")

4. Define train_manifest, a JSON file whose entries contain the audio file path, duration, and transcript text

{"audio_filepath": "test.wav", "duration": 8.69, "text": "诶前天跟我说昨天跟我说十二期利率是多少工号幺九零八二六十二期的话零点八一万的话分十二期利息八十嘛"}

5. Read the model's YAML configuration

# Read the quartznet model config file with YAML
try:
    from ruamel.yaml import YAML
except ModuleNotFoundError:
    from ruamel_yaml import YAML
config_path = "/NeMo/examples/asr/conf/quartznet/quartznet_15x5_zh.yaml"

yaml = YAML(typ='safe')
with open(config_path) as f:
    params = yaml.load(f)
print(params['model']['train_ds']['manifest_filepath'])
print(params['model']['validation_ds']['manifest_filepath'])

6. Set the training and validation manifests

train_manifest = "train_manifest.json"
val_manifest = "train_manifest.json"  # for a real run, point this at a separate validation manifest

params['model']['train_ds']['manifest_filepath'] = train_manifest
params['model']['validation_ds']['manifest_filepath'] = val_manifest
print(params['model']['train_ds']['manifest_filepath'])
print(params['model']['validation_ds']['manifest_filepath'])

asr_model.setup_training_data(train_data_config=params['model']['train_ds'])
asr_model.setup_validation_data(val_data_config=params['model']['validation_ds'])
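
For fine-tuning it usually also helps to lower the learning rate before training. A hedged sketch, assuming the YAML has a model.optim section (as the stock quartznet configs do) and that a smaller learning rate such as 0.001 suits your data:

from omegaconf import DictConfig

# Assumption: the quartznet YAML carries a model.optim section; copy it and
# shrink the learning rate for fine-tuning instead of training from scratch.
optim_cfg = dict(params['model']['optim'])
optim_cfg['lr'] = 0.001  # assumed fine-tuning learning rate, tune to taste
asr_model.setup_optimization(optim_config=DictConfig(optim_cfg))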

7. Train with pytorch_lightning

import pytorch_lightning as pl

trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=10)
trainer.fit(asr_model)  # call the fit() method to start training
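
Optionally, a standard pytorch_lightning ModelCheckpoint callback can be added so the best model is kept during training. A sketch, assuming the model logs a 'val_loss' metric during validation:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Assumption: the model logs 'val_loss' during validation; change the monitor key if it does not.
checkpoint_cb = ModelCheckpoint(monitor='val_loss', save_top_k=1, mode='min')
trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=10, callbacks=[checkpoint_cb])
trainer.fit(asr_model)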

8. Save the trained model

asr_model.save_to('my_stt_zh_quartznet15x5.nemo')

9. Check the results after training

my_asr_model = nemo_asr.models.EncDecCTCModel.restore_from("my_stt_zh_quartznet15x5.nemo")
queries = my_asr_model.transcribe(['test1.wav'])
print(queries)

#['诶前天跟我说的昨天跟我说十二期利率是多少工号幺九零八二六零十二期的话零点八一万的话分十二期利息八十嘛']

10. Compute the character error rate (CER)

from ASR_metrics import utils as metrics
s1 = "诶前天跟我说昨天跟我说十二期利率是多少工号幺九零八二六十二期的话零点八一万的话分十二期利息八十嘛"  # reference transcript
s2 = " ".join(queries)  # recognition result
print("CER: {}".format(metrics.calculate_cer(s1, s2)))  # character error rate
print("Accuracy: {}".format(1 - metrics.calculate_cer(s1, s2)))  # accuracy = 1 - CER

#CER: 0.041666666666666664

#Accuracy: 0.9583333333333334
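
calculate_cer here is essentially the character-level Levenshtein distance divided by the reference length. If ASR_metrics is unavailable, a minimal sketch of an equivalent computation (stripping spaces before comparison is an assumption about how you want to score):

def char_error_rate(ref, hyp):
    # Character-level Levenshtein distance divided by the reference length.
    ref = ref.replace(" ", "")
    hyp = hyp.replace(" ", "")
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (r != h)))   # substitution
        prev = cur
    return prev[-1] / max(len(ref), 1)

print(char_error_rate(s1, s2))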

11. Add punctuation to the output

from zhpr.predict import DocumentDataset,merge_stride,decode_pred
from transformers import AutoModelForTokenClassification,AutoTokenizer
from torch.utils.data import DataLoader

def predict_step(batch, model, tokenizer):
    batch_out = []
    batch_input_ids = batch

    encodings = {'input_ids': batch_input_ids}
    output = model(**encodings)

    predicted_token_class_id_batch = output['logits'].argmax(-1)
    for predicted_token_class_ids, input_ids in zip(predicted_token_class_id_batch, batch_input_ids):
        out = []
        tokens = tokenizer.convert_ids_to_tokens(input_ids)

        # Find where padding starts in input_ids and truncate both the tokens
        # and the predictions to the unpadded length.
        input_ids = input_ids.tolist()
        try:
            input_id_pad_start = input_ids.index(tokenizer.pad_token_id)
        except ValueError:
            input_id_pad_start = len(input_ids)
        input_ids = input_ids[:input_id_pad_start]
        tokens = tokens[:input_id_pad_start]

        # Map predicted class ids to punctuation labels.
        predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids]
        predicted_tokens_classes = predicted_tokens_classes[:input_id_pad_start]

        for token, ner in zip(tokens, predicted_tokens_classes):
            out.append((token, ner))
        batch_out.append(out)
    return batch_out

if __name__ == "__main__":
    window_size = 256
    step = 200
    text = queries[0]
    dataset = DocumentDataset(text,window_size=window_size,step=step)
    dataloader = DataLoader(dataset=dataset,shuffle=False,batch_size=5)

    model_name = 'zh-wiki-punctuation-restore'
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_pred_out = []
    for batch in dataloader:
        batch_out = predict_step(batch,model,tokenizer)
        for out in batch_out:
            model_pred_out.append(out)
        
    merge_pred_result = merge_stride(model_pred_out, step)
    merge_pred_result_decoded = decode_pred(merge_pred_result)
    merge_pred_result_decoded = ''.join(merge_pred_result_decoded)
    print(merge_pred_result_decoded)
#诶前天跟我说的。昨天跟我说十二期利率是多少。工号幺九零八二六零十二期的话,零点八一万的话,分十二期利息八十嘛。

Reposted from blog.csdn.net/wxl781227/article/details/132254944