ICLR2023 | PromptPG: Cuando el aprendizaje por refuerzo se encuentra con modelos lingüísticos a gran escala

Fuente | El corazón de la máquina

Ingrese al grupo NLP —> únase al grupo de intercambio NLP (comentario nips/emnlp/nlpcc ingresa al grupo de contribución correspondiente)

El método PromptPG supera a la mejor línea de base (Few-shot CoT GPT-3) en un 5,31 % en la precisión de las respuestas a las preguntas.

El razonamiento matemático es una habilidad central de la inteligencia humana, pero para las máquinas, el pensamiento abstracto y el razonamiento lógico siguen siendo un gran desafío. Los modelos de lenguaje preentrenado a gran escala, como GPT-3 y GPT-4, han logrado un progreso notable en el razonamiento matemático basado en texto (como los problemas matemáticos verbales). Sin embargo, actualmente no está claro si estos modelos pueden manejar problemas más complejos que involucran información heterogénea como datos tabulares.

Para llenar este vacío, los investigadores de UCLA y el Instituto Allen de Inteligencia Artificial (AI2) lanzaron Tabular Math Word Problems (TabMWP), un conjunto de datos de 38,431 problemas de dominio abierto que requieren texto y razonamiento matemático en datos tabulares para obtener la respuesta correcta. respuesta. Cada pregunta en TabMWP está asociada con un contexto, que puede contener imágenes, texto o tablas en formato estructurado.

Los investigadores evaluaron diferentes modelos preentrenados, incluido el GPT-3 de pocos disparos en TabMWP. Como ha descubierto la investigación existente, GPT-3 de pocos disparos depende en gran medida de la selección de ejemplos en contexto, lo que conduce a su rendimiento bastante inestable en el caso de ejemplos seleccionados al azar. Esta inestabilidad es aún más grave cuando se trata de problemas de inferencia complejos como TabMWP.

Para resolver este problema, el autor propone el método PromptPG, que convierte la selección de ejemplos en un problema de bandido contextual en el aprendizaje por refuerzo, y utiliza Policy Gradient para entrenar una red de políticas para aprender a elegir el óptimo en una pequeña cantidad de datos de entrenamiento.-ejemplo de contexto.

Los resultados experimentales muestran que su método PromptPG propuesto supera a la mejor línea de base (Few-shot CoT GPT-3) en un 5,31 % en la precisión de las preguntas y respuestas, y su método reduce significativamente las predicciones en relación con los ejemplos en contexto seleccionados al azar. La varianza de , que mejora la estabilidad de este método.

bba93adda18d082e6d58065cbe2b4287.png

Enlace de papel:

https://arxiv.org/abs/2209.14610

Enlace de código:

https://github.com/lupantech/PromptPG

Página de inicio del proyecto:

https://promptpg.github.io/

Visualización de datos:

https://promptpg.github.io/explorar

5e55911699719e9203fc35243d89b980.png

Conjunto de datos TabMWP

A continuación se muestran dos ejemplos del conjunto de datos TabMWP. Una de ellas es una pregunta de texto libre (free-text) con respuesta numérica, y la otra es una pregunta de opción múltiple (multi-choice) con respuesta de texto. Como puede ver, cada problema proporciona una solución que implica un razonamiento paso a paso. Para resolver los problemas en TabMWP, el sistema debe tener la capacidad de búsqueda de tablas y razonamiento matemático de varios pasos.

Tome el ejemplo de la figura a continuación como ejemplo, para responder "cuánto gastará (si Tracy compra tres tipos de pan)", necesitamos averiguar los precios correspondientes de los tres tipos de pan en la tabla, y luego calcule el precio de compra de cada tipo de pan y súmelos para obtener el costo final.

c39bdbbc61bbc2bb76eeb2999e154c9d.png

Como muestran las estadísticas de la siguiente tabla, el conjunto de datos TabMWP contiene 38 431 problemas matemáticos tabulares. El 74,7% de las preguntas eran preguntas de texto libre y el 25,3% eran preguntas de opción múltiple. TabMWP tiene un total de 28 876 preguntas distintas, 6153 respuestas distintas y 35 442 soluciones distintas, lo que indica su rica diversidad en la distribución de preguntas. La longitud media de las preguntas fue de 22,1 palabras y la longitud media de las respuestas fue de 49,5 palabras, lo que demuestra la riqueza léxica de TabMWP.

Una característica distintiva de TabMWP es que cada pregunta viene con un contexto de tabla, sin el cual no se puede resolver la pregunta. TabMWP tiene un total de 37.644 tablas diferentes con un promedio de 5,9 filas y 2,2 columnas, 12,9 celdas y un máximo de 54 celdas. Estas estadísticas muestran que las tablas en TabMWP también son ricas en diversidad. 

61bbcb078487c9b8972a228f8c15e428.png

El conjunto de datos TabMWP tiene dos tipos de preguntas diferentes y cinco tipos de respuestas diferentes: 

9a184074cd3d5e40f822bf9caf364edb.png

Cada pregunta en TabMWP tiene un contexto tabular, que se representa en tres formatos: imagen, texto semiestructurado y estructurado. Esto abre la posibilidad de desarrollar diferentes tipos de modelos de inferencia. 

c66fe3ed2eb46c0ab2a0c0b22b5fda34.png

En comparación con los conjuntos de datos existentes, TabMWP requiere comprensión tabular y razonamiento matemático para responder preguntas. Además, cada pregunta de TabMWP tiene un proceso de razonamiento detallado de varios pasos, que tiene ventajas obvias en el tamaño del conjunto de datos, el tipo de formulario, el tipo de pregunta y el tipo de respuesta. Hasta donde sabemos, TabMWP es el primer conjunto de datos para el razonamiento matemático en escenarios tabulares de dominio abierto. 

c545a9fdc7d9ccdf836be33887f44691.png

f76c710c8238dabc12e7c1f53f3da068.png

Método PromptPG

Teniendo en cuenta el éxito de los modelos preentrenados a gran escala, como GPT-3, en la resolución de problemas matemáticos verbales, los autores primero establecieron un punto de referencia en TabMWP utilizando Few-shot GPT-3. Seleccionan aleatoriamente algunos ejemplos contextuales del conjunto de entrenamiento, así como también muestras de prueba para formar indicaciones, lo que hace que GPT-3 prediga la respuesta.

Sin embargo, estudios recientes han demostrado que este tipo de aprendizaje basado en una selección aleatoria puede ser muy inestable con diferentes opciones de ejemplos contextuales. La selección aleatoria puede ser menos efectiva cuando se trata de problemas de razonamiento complejos como TabMWP, ya que los problemas involucran tablas de diferentes tipos y formatos.

Para resolver este problema, el autor propone un método mejorado: el aprendizaje rápido a través de Policy Gradient, aprendiendo a seleccionar ejemplos contextuales a partir de una pequeña cantidad de datos de entrenamiento, llamado PromptPG .

Como se muestra en la Figura 2, la red de políticas aprende a encontrar el mejor ejemplo en contexto del grupo de candidatos (ejemplos de candidatos) y su objetivo de optimización es maximizar la predicción de un ejemplo de entrenamiento dado (ejemplo de entrenamiento) cuando interactúa con el GPT -3 ambiente premio. La red de políticas para el ejemplo elegido es un modelo de lenguaje BERT basado en parámetros fijos y una red neuronal de una sola capa con parámetros que se pueden aprender. Después de completar el aprendizaje de optimización, PromptPG puede seleccionar dinámicamente diferentes ejemplos óptimos de ejemplos candidatos para diferentes preguntas de prueba, a fin de maximizar el rendimiento de razonamiento de GPT-3. 

7e42343ad005fda2c5cec3ed6a156005.png

El siguiente es el algoritmo de aprendizaje de PromptPG. 

e830fccee09cfb4720942726a94f34a9.png

5c38649107d3e579aba78fde998dc93e.png


Experimento y Análisis

f44b5ac94af9b95114aef01c1e2058d0.png

Pre-entrenamiento y puesta a punto

La Tabla 3 compara los resultados de PromptPG y diferentes puntos de referencia en el conjunto de datos de TabMWP. Se puede ver que TAPEX funciona mejor que UnifiedQA bajo la premisa de cantidades de parámetros similares debido al entrenamiento previo en datos tabulares. Tanto para TAPEX como para UnifiedQA, aumentar la cantidad de parámetros en el modelo puede mejorar la precisión de la predicción. Además, ajustar el modelo en TabMWP también puede mejorar en gran medida la precisión de la predicción.

modelo de lenguaje a gran escala

GPT-3 puede lograr una precisión similar a la de los modelos UnifiedQA y TAPEX ajustados sin ningún ajuste fino (Zero-shot GPT-3). Si el modelo GPT-3 de pocos disparos selecciona aleatoriamente dos ejemplos en contexto como sugerencias de GPT-3, puede mejorar aún más en un 0,17 % en comparación con el GPT-3 de disparo cero. Al permitir que Few-shot GPT-3 generara un paso intermedio de varios pasos (Few-shot-CoT GPT-3) antes de generar la respuesta final, los investigadores pudieron obtener el mejor modelo de referencia, que logró una tasa de precisión del 62,92 %. .

PromptPG

A diferencia de la selección aleatoria de ejemplos en contexto, el PromptPG propuesto en este documento entrena una red de políticas a través de Policy Gradient para seleccionar ejemplos en contexto más apropiados Ha logrado el resultado de predicción más alto (68.23%) en TabMWP, y su predicción promedio la precisión supera al mejor modelo de referencia (Few-shot-CoT GPT-3) 5,31 %. Vale la pena señalar que PromptPG muestra su superioridad en la precisión de la predicción para casi todos los tipos de preguntas, tipos de respuestas y dificultad de las preguntas. Sin embargo, PromptPG está lejos del rendimiento humano del 90,22 %, y todavía hay mucho margen de mejora.

experimento de ablación

23c785b3cf86443b2348c40a34821d7f.png

La Tabla 4 muestra que todos los elementos de entrada de TabMWP (texto de la pregunta, información del formulario, información de la opción) son fundamentales para responder la pregunta correctamente. Zero-shot GPT-3 logra su precisión de predicción promedio relativamente más alta (59,50 %) solo con todos los elementos de la pregunta como entrada.

Diferentes opciones de muestra

7f8ac9139e92701d58830f5883fd00f5.png

Como experimento comparativo, los investigadores también compararon otros métodos con diferentes selecciones de ejemplos. Como se muestra en la Tabla 5, elegir el mismo tipo de pregunta o tipo de respuesta que la pregunta de la prueba puede ayudar al modelo a encontrar ejemplos más relevantes y mejorar la precisión de la respuesta. Elegir los ejemplos más complejos no mejora constantemente la precisión de las respuestas. La selección fija de los dos mejores ejemplos entre los ejemplos candidatos mejora ligeramente la precisión y reduce la varianza. Seleccionar los ejemplos que son semánticamente más cercanos a la pregunta de la prueba logra la precisión más cercana al método PromptPG. En general, PromptPG demuestra plenamente sus ventajas para mejorar la precisión del pronóstico y reducir la variación del pronóstico.

La siguiente figura muestra ejemplos de selecciones de PromptPG y las predicciones finales. Se puede ver que el método PromptPG puede seleccionar ejemplos con una capacidad matemática similar a los elementos de prueba, mejorando así el rendimiento de inferencia de GPT-3 de pocos disparos.

90a2f3d1f37d27468712ca5822630ddc.png

Ejemplos de éxito predictivo

A continuación se muestra la respuesta correcta de PromptPG a una pregunta de texto libre. El problema pide sumar y dividir ocho números en una tabla para encontrar el promedio. 

aa582a625c94ad3ad8bcbf3c2dc84347.png

En el siguiente ejemplo, se le pide al modelo que comprenda un informe de impuestos y calcule el salario después de las deducciones de impuestos.

89deca04ccc20a3dbf2869fc38cdad25.png

A continuación se muestran las predicciones correctas de PromptPG para preguntas de opción múltiple. La tabla dada tiene un total de 9 filas y 6 columnas. El modelo ubica con éxito la celda objetivo en la tabla y realiza un razonamiento de varios pasos para predecir la respuesta correcta. 

2e51cf99aad274814782ec67ab0e3164.png

En el siguiente ejemplo, el modelo necesita comparar el presupuesto y los costos totales para verificar que Ariana tenga suficiente dinero. 

ec5c7d16d10e4ef3ae4cc455c4143687.png

Ejemplos de errores de predicción

Lo siguiente demuestra las predicciones incorrectas de PromptPG para preguntas de texto libre. El modelo recupera el precio incorrecto del cuarzo rosa, calculando mal el costo total de los tres artículos.

89873d0a4a6d0d29ae543c41809f590a.png

En el siguiente ejemplo, la pregunta proporciona una tabla abstracta de tallo y hojas. El modelo no pudo comprender esta tabla específica del dominio y carecía de un razonamiento lógico avanzado para obtener respuestas incorrectas.

914a7109c05e7f5d6c4c9f8a75311952.png

El siguiente ejemplo muestra que los modelos existentes no parecen tener la capacidad de ordenar números.

1eeb8596c9143cb0b6bf187b666e20a6.png

En el siguiente ejemplo, la hora que es exactamente la misma que la hora actual mencionada en la pregunta no aparece en la tabla, por lo que el modelo no puede ubicar con precisión la hora de salida de la siguiente estación. 

59dd1d0ef3c20dda5117901cb17adc77.png

En el siguiente ejemplo, el modelo tiene dificultades para realizar con precisión operaciones aritméticas en una larga lista de números. 

7f7613fbee5221e4d9843021252827c5.png

cb1d64e0d9dc147d74a16be6c6ba00b9.png


Conclusión y perspectiva

Los autores presentan TabMWP, el primer conjunto de datos a gran escala para la resolución de problemas matemáticos en contextos tabulares. TabMWP contiene 38 431 preguntas de dominio abierto, incluidos dos tipos de preguntas y cinco tipos de respuestas, y cada pregunta está marcada con un proceso de solución de varios pasos.

Utilizando métodos de QA y TableQA de última generación, los autores llevan a cabo experimentos completos en TabMWP en configuraciones de preentrenamiento y ajuste fino, así como la evaluación con el modelo de lenguaje preentrenado a gran escala GPT-3. El autor propone además un nuevo método de aprendizaje por refuerzo PromptPG, que utiliza el aprendizaje de gradiente de políticas para seleccionar la instancia óptima de los datos de entrenamiento para impulsar el modelo GPT-3. Los resultados experimentales muestran que PromptPG supera significativamente las líneas de base existentes y reduce la inestabilidad del rendimiento en las predicciones en comparación con la selección aleatoria.

exterior_predeterminado.png

referencias

exterior_predeterminado.png

[1] Pan Lu, Liang Qiu, Wenhao Yu, Sean Welleck y Kai-Wei Chang. Una encuesta sobre el aprendizaje profundo para el razonamiento matemático. arXiv preprint arXiv:2212.10535, 2022b.

[2] Gabriel Barth-Maron, Matthew W Hoffman, David Budden, Will Dabney, Dan Horgan, Dhruva Tb, Alistair Muldal, Nicolas Heess y Timothy Lillicrap. Gradientes de política deterministas distribucionales distribuidos. preimpresión de arXiv arXiv:1804.08617, 2018.

[2] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Los modelos de lenguaje son aprendices de pocas oportunidades. Avances en los sistemas de procesamiento de información neuronal (NeurIPS), 33:1877–1901, 2020

[3] Jacob Devlin, Ming-Wei Chang, Kenton Lee y Kristina Toutanova. Bert: Pre-entrenamiento de transformadores bidireccionales profundos para la comprensión del lenguaje. preimpresión de arXiv arXiv:1810.04805, 2018.

[4] Daniel Khashabi, Sewon Min, Tushar Khot, Ashish Sabharwal, Oyvind Tafjord, Peter Clark y Hannaneh Hajishirzi. Unifiedqa: cruzando los límites del formato con un solo sistema qa. En Hallazgos de la Asociación de Lingüística Computacional (EMNLP), págs. 1896–1907, 2020.

[5] Takeshi Kojima, Shixiang Shane Gu, Machel Reid, Yutaka Matsuo y Yusuke Iwasawa. Los grandes modelos lingüísticos son razonadores de tiro cero. arXiv preprint arXiv:2205.11916, 2022.

[6] Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen y Jian-Guang Lou. Tapex: Table pre-training via learning a neural sql executor. En Conferencia Internacional sobre Representaciones de Aprendizaje (ICLR), 2022b.

[7] Jason Wei, Xuezhi Wang, Dale Schuurmans, Maarten Bosma, Ed Chi, Quoc Le y Denny Zhou. La cadena de impulsos de pensamiento provoca el razonamiento en grandes modelos de lenguaje. preimpresión de arXiv arXiv:2201.11903, 2022.


Ingrese al grupo NLP —> únase al grupo de intercambio NLP (comentario nips/emnlp/nlpcc ingresa al grupo de contribución correspondiente)

Continúe publicando la información más reciente, como la interpretación del procesamiento del lenguaje natural NLP, documentos diarios de alta calidad, información relevante de primera mano, posiciones de algoritmos de IA, etc.

Únete al planeta, obtendrás:

1.  Actualice de 3 a 5 lecturas de velocidad en papel más recientes y de la más alta calidad todos los días . En unos segundos , puede comprender el contenido general del artículo, incluido un resumen de una oración del artículo, el contenido general, la dirección de la investigación y la descarga del pdf.

2.  Los últimos materiales de aprendizaje introductorio y avanzado . Incluyendo aprendizaje automático, aprendizaje profundo, PNL y otros campos.

3.  La subdivisión específica de las direcciones de NLP incluye, entre otros : análisis de sentimientos, extracción de relaciones, gráfico de conocimiento, análisis de sintaxis, análisis semántico, traducción automática, diálogo humano-computadora, generación de texto, reconocimiento de entidad nombrada, resolución de referencia, modelo de lenguaje grande , aprendizaje de muestra cero, aprendizaje de muestra pequeña, generación de código, multimodalidad, destilación de conocimiento, compresión de modelos, AIGC, PyTorch, TensorFlow, etc.

4.  Información de contratación diaria de 1 a 3 para puestos de AI como PNL, búsqueda, promoción y promoción, y CV . Se pueden organizar entrevistas simuladas.

85b98052efe32f18c5d4cbe6f331bddf.png

Supongo que te gusta

Origin blog.csdn.net/qq_27590277/article/details/130097131
Recomendado
Clasificación