Fine-tuning a NeMo Chinese/English ASR model: a hands-on walkthrough

1. Install NeMo

pip install -U "nemo_toolkit[all]" ASR-metrics
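
A quick way to confirm the install worked is to import the toolkit and print its version (a minimal sanity check; the version string will vary):

import nemo
import nemo.collections.asr as nemo_asr

print(nemo.__version__)  # confirms the toolkit imports cleanly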

2. Download the pre-trained ASR model locally (Hugging Face is recommended; it is much faster than the NVIDIA official site)
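
If you prefer to skip the manual download, NeMo can also fetch a checkpoint by name via from_pretrained, which downloads from NGC and caches it locally (a sketch assuming the checkpoint is published under this name and that you have network access):

import nemo.collections.asr as nemo_asr

# downloads and caches the checkpoint on first use
asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name="stt_zh_quartznet15x5")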

3. Load the pre-trained model from the local checkpoint

import nemo.collections.asr as nemo_asr

asr_model = nemo_asr.models.EncDecCTCModel.restore_from("stt_zh_quartznet15x5.nemo")

4. Define train_manifest, a JSON-lines file in which each line holds an audio file path, its duration, and the transcript

{"audio_filepath": "test.wav", "duration": 8.69, "text": "Hey, the day before yesterday, you told me what the interest rate for the 12th period was, and the job number was 908262, zero If you pay 80,000 to 10,000 yuan, the interest will be 80,000 yuan in twelve installments"}

5. Read the model's YAML configuration

# Use YAML to read the quartznet model configuration file
try:
    from ruamel.yaml import YAML
except ModuleNotFoundError:
    from ruamel_yaml import YAML
config_path = "/NeMo/examples/asr/conf/quartznet/quartznet_15x5_zh.yaml"

yaml = YAML(typ='safe')
with open(config_path) as f:
    params = yaml.load(f)
print(params['model']['train_ds']['manifest_filepath'])
print(params['model']['validation_ds']['manifest_filepath'])

6. Set the training and validation manifests

train_manifest = "train_manifest.json"
# Note: this demo reuses the training manifest for validation;
# for real fine-tuning, point val_manifest at a held-out set.
val_manifest = "train_manifest.json"

params['model']['train_ds']['manifest_filepath'] = train_manifest
params['model']['validation_ds']['manifest_filepath'] = val_manifest
print(params['model']['train_ds']['manifest_filepath'])
print(params['model']['validation_ds']['manifest_filepath'])

asr_model.setup_training_data(train_data_config=params['model']['train_ds'])
asr_model.setup_validation_data(val_data_config=params['model']['validation_ds'])
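
Before training, it is also common to lower the learning rate for fine-tuning rather than keeping the from-scratch value. One way is to reuse the optim section of the same YAML config (a sketch; the 0.001 value is an assumption to be tuned, not taken from the original post):

from omegaconf import DictConfig

params['model']['optim']['lr'] = 0.001  # assumed fine-tuning LR; tune as needed
asr_model.setup_optimization(optim_config=DictConfig(params['model']['optim']))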

7. Train with pytorch_lightning

import pytorch_lightning as pl

trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=10)
trainer.fit(asr_model)  # call 'fit' to start training
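
On larger datasets, a variant with mixed precision and periodic checkpointing may be worth trying (a sketch; monitoring 'val_loss' assumes NeMo logs that key, and 'checkpoints' is a placeholder directory):

import pytorch_lightning as pl

checkpoint_cb = pl.callbacks.ModelCheckpoint(dirpath='checkpoints', save_top_k=1, monitor='val_loss')
trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=10,
                     precision=16, callbacks=[checkpoint_cb])
trainer.fit(asr_model)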

8. Save the fine-tuned model

asr_model.save_to('my_stt_zh_quartznet15x5.nemo')

9. Check the results after training

my_asr_model = nemo_asr.models.EncDecCTCModel.restore_from("my_stt_zh_quartznet15x5.nemo")
queries = my_asr_model.transcribe(['test1.wav'])
print(queries)

#['Hey, you told me the day before yesterday what the interest rate is for the twelve installments. If the employee number is 9082602, if it is 0.810,000, the interest will be divided into 12 installments and the interest will be 80.']

10. Calculate the character error rate (CER)

from ASR_metrics import utils as metrics

s1 = "Hey, the day before yesterday you told me what the interest rate for the twelve installments was. The employee number is 908262. If it is 0.81 million, the interest over the twelve installments comes to 80."  # the reference transcript
s2 = " ".join(queries)  # the recognition result
print("Character error rate: {}".format(metrics.calculate_cer(s1, s2)))  # CER
print("Accuracy: {}".format(1 - metrics.calculate_cer(s1, s2)))          # 1 - CER

#Character error rate: 0.041666666666666664

#Accuracy: 0.9583333333333334
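
As a cross-check, CER is simply the character-level Levenshtein distance divided by the reference length. Here is a self-contained sketch that recomputes it without ASR_metrics (stripping spaces first is an assumption about how the package normalizes text):

def char_error_rate(ref, hyp):
    # character-level Levenshtein distance via dynamic programming
    ref, hyp = ref.replace(" ", ""), hyp.replace(" ", "")
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1,              # deletion
                           cur[j - 1] + 1,           # insertion
                           prev[j - 1] + (r != h)))  # substitution
        prev = cur
    return prev[-1] / len(ref)

print(char_error_rate(s1, s2))  # should be close to the ASR_metrics value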

11. Add punctuation to the output

from zhpr.predict import DocumentDataset, merge_stride, decode_pred
from transformers import AutoModelForTokenClassification, AutoTokenizer
from torch.utils.data import DataLoader

def predict_step(batch, model, tokenizer):
    batch_out = []
    batch_input_ids = batch

    encodings = {'input_ids': batch_input_ids}
    output = model(**encodings)

    predicted_token_class_id_batch = output['logits'].argmax(-1)
    for predicted_token_class_ids, input_ids in zip(predicted_token_class_id_batch, batch_input_ids):
        out = []
        tokens = tokenizer.convert_ids_to_tokens(input_ids)

        # find where padding starts in input_ids, then truncate
        # both the tokens and the predictions to that length
        input_ids = input_ids.tolist()
        try:
            input_id_pad_start = input_ids.index(tokenizer.pad_token_id)
        except ValueError:  # no padding in this sequence
            input_id_pad_start = len(input_ids)
        input_ids = input_ids[:input_id_pad_start]
        tokens = tokens[:input_id_pad_start]

        # map predicted class ids to punctuation labels
        predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids]
        predicted_tokens_classes = predicted_tokens_classes[:input_id_pad_start]

        for token, label in zip(tokens, predicted_tokens_classes):
            out.append((token, label))
        batch_out.append(out)
    return batch_out

if __name__ == "__main__":
    window_size = 256
    step = 200
    text = queries[0]
    dataset = DocumentDataset(text, window_size=window_size, step=step)
    dataloader = DataLoader(dataset=dataset, shuffle=False, batch_size=5)

    model_name = 'zh-wiki-punctuation-restore'
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_pred_out = []
    for batch in dataloader:
        batch_out = predict_step(batch, model, tokenizer)
        for out in batch_out:
            model_pred_out.append(out)

    merge_pred_result = merge_stride(model_pred_out, step)
    merge_pred_decoded = decode_pred(merge_pred_result)
    merge_pred_decoded = ''.join(merge_pred_decoded)
    print(merge_pred_decoded)
#Hey, you told me the day before yesterday. Yesterday I was told what the twelve-period interest rate was. If the job number is 19082602, if it is 0.810, the interest will be 80 in 12 installments.


Original post: blog.csdn.net/wxl781227/article/details/132254944