Práctica de entrenamiento de ajuste del modelo ASR chino/inglés de NeMo

1.Instalar nemo

instalación de pip -U nemo_toolkit[todas] métricas ASR

2. Descargue el modelo ASR previamente entrenado localmente (se recomienda usar huggleface, que es mucho más rápido que el sitio web oficial de nvidia)

3. Cree el modelo ASR localmente

asr_model = nemo_asr.models.EncDecCTCModel.restore_from("stt_zh_quartznet15x5.nemo")

3. Defina train_mainfest, un archivo json que contiene la ruta del archivo de voz, la duración y el texto de voz.

{"audio_filepath": "test.wav", "duration": 8.69, "text": "Oye, anteayer, me dijiste cuál era la tasa de interés para el período 12 y el número de trabajo era 908262, cero Si paga entre 80.000 y 10.000 yuanes, el interés será de 80.000 yuanes en doce cuotas"}

4. Lea la configuración yaml del modelo.

# Utilice YAML para leer el archivo de configuración del modelo quartznet.
Intente:
    desde ruamel.yaml importe YAML
excepto ModuleNotFoundError:
    desde ruamel_yaml importe YAML
config_path ="/NeMo/examples/asr/conf/quartznet/quartznet_15x5_zh.yaml"

yaml = YAML(typ='safe')
con open(config_path) como f:
    params = yaml.load(f)
print(params['model']['train_ds']['manifest_filepath'])
print(params[' modelo']['validation_ds']['manifest_filepath'])

5. Configurar manifiesto de capacitación y verificación.

train_manifest = "train_manifest.json"
val_manifest = "train_manifest.json"

params['model']['train_ds']['manifest_filepath']=train_manifest
params['model']['validation_ds']['manifest_filepath']=val_manifest
print(params['model']['train_ds'][ 'manifest_filepath'])
print(params['model']['validation_ds']['manifest_filepath'])

asr_model.setup_training_data(train_data_config=params['model']['train_ds'])
asr_model.setup_validation_data(val_data_config=params['model']['validation_ds'])

6. Utilice el entrenamiento pytorch_lightning
import pytorch_lightning as pl 
trainer = pl.Trainer(accelerator='gpu', devices=1,max_epochs=10)
trainer.fit(asr_model)#Llame al método 'fit' para comenzar a entrenar 

7. Guarde el modelo entrenado.

asr_model.save_to('my_stt_zh_quartznet15x5.nemo')

8. Ver los resultados después del entrenamiento.

my_asr_model = nemo_asr.models.EncDecCTCModel.restore_from("my_stt_zh_quartznet15x5.nemo")
queries=my_asr_model.transcribe(['test1.wav'])
print(consultas)

#['Oye, anteayer me dijiste cuál es la tasa de interés para 12 cuotas. Si tu número de trabajo es 9082602, si es 0.810.000, el interés será de 80 en 12 cuotas.']

9. Calcular la tasa de error de palabras

de ASR_metrics import utils as metrics
s1 = "Oye, anteayer, ayer me dijiste cuál es la tasa de interés para el período de 12 períodos. Si el número de empleado es 908262, si es 0.810,000, el interés será 80 en 12 periodos. "#Especifica la respuesta correcta
s2 = " ".join(queries)#Resultados de reconocimiento
print("Tasa de error de palabras: {}".format(metrics.calculate_cer(s1,s2)))#Calcular la palabra tasa de error cer print
("Tasa de precisión:{}".format(1-metrics.calculate_cer(s1,s2)))#Calcular precisión exactitud

#Tasa de error de palabras: 0,041666666666666664

#Tasa de precisión: 0,95833333333333334

10. Agregar salida de signo de puntuación

desde zhpr.predict importe DocumentDataset, merge_stride, decode_pred
desde transformadores importe AutoModelForTokenClassification, AutoTokenizer
desde torch.utils.data importe DataLoader

def predict_step(lote,modelo,tokenizador):
        lote_out = []
        lote_input_ids = lote

        codificaciones = {'input_ids': lote_input_ids}
        salida = modelo (** codificaciones)

        predicted_token_class_id_batch = salida['logits'].argmax(-1)
        para predicted_token_class_ids, input_ids en zip(predicted_token_class_id_batch, batch_input_ids):
            out=[]
            tokens = tokenizer.convert_ids_to_tokens(input_ids)
            
            # calcula el inicio del pad en input_ids
            # y también trunca el predecir
            # print(tokenizer.decode(batch_input_ids))
            input_ids = input_ids.tolist()
            intente:
                input_id_pad_start = input_ids.index(tokenizer.pad_token_id)
            excepto:
                input_id_pad_start = len(input_ids)
            input_ids = input_ids[:input_id_pad_start]
            tokens = tokens[:input_id_pad_start]
    
            # predicted_token_class_ids
            predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids]
            predicted_tokens_classes = predicted_tokens_classes[:input_id_pad_start]

            para token,ner en zip(tokens,predicted_tokens_classes):
                out.append((token,ner))
            lote_out.append(out)
        devolver lote_out

if __name__ == "__main__":
    tamaño_ventana = 256
    paso = 200
    texto = consultas[0]
    conjunto de datos = DocumentDataset(texto,tamaño_ventana=tamaño_ventana,paso=paso)
    cargador de datos = Cargador de datos(conjunto de datos=conjunto de datos,shuffle=False,tamaño_batch=5)

    model_name = 'zh-wiki-punctuation-restore'
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_pred_out = []
    para lote en el cargador de datos:
        lote_out = predict_step(batch,model,tokenizer)
        para fuera en lote_out:
            model_pred_out.append(out)
        
    merge_pred_result = merge_stride(model_pred_out,step)
    merge_pred_result_deocde = decode_pred(merge_pred_result)
    merge_pred_result_deocde = ''.join (merge_pred_result_deocde)
    print(merge_pred_result_deocde)
#Oye, me lo dijiste anteayer. Ayer me dijeron cuál era el tipo de interés a doce plazos. Si el número de trabajo es 19082602, si es 0,810 el interés será de 80 en 12 cuotas.

Supongo que te gusta

Origin blog.csdn.net/wxl781227/article/details/132254944
Recomendado
Clasificación