Interpretation of the paper: Contrastive Learning Reduces Hallucination in Conversations


GitHub: https://github.com/sunnweiwei/MixCL

1. Motivation

  • Open-domain dialogue systems depend on large amounts of knowledge (e.g., commonsense and factual knowledge), which is usually injected through a retrieval step. Large language models can now serve as high-quality dialogue agents that generate more informative responses;
  • However, large models often hallucinate: they generate plausible-looking content that is actually irrelevant or factually wrong in context;
  • The authors randomly sampled 200 examples from Wizard-of-Wikipedia, had BART generate a response for each, and asked three experts to annotate the 200 responses for intrinsic and extrinsic hallucinations. More than 50% of the responses contained hallucinations.

image.png
The specific proportions are as follows:
image.png

  • Hallucination has several causes, for example a mismatch between the training and inference objectives: training maximizes likelihood, and at inference time the model generates by following this same likelihood-driven pattern.
  • Previous work on mitigating hallucination usually relies on external knowledge, for example through retrieval or post-editing.

2. Method

Problem definition

Given a question or dialogue context $x$ and the corresponding retrieved knowledge $\mathcal{K}$, the goal is to generate a response $y$ based on both the context and the knowledge.
There are currently two paradigms for knowledge-grounded dialogue, as shown in the figure below:
image.png

  • KB mode: retrieve from the knowledge base using the dialogue context, then generate a reply conditioned on the context and the retrieved documents;
  • LM mode: the language-model paradigm, in which the language model is first pre-trained on the knowledge base and then answers directly.

This paper focuses on the LM mode, which is trained in two stages:
(1) Pre-training: BART is used as the language model and pre-trained on the knowledge base:
image.png
(2) Fine-tuning (SFT): the model is trained autoregressively on the dialogue dataset with the MLE objective:
image.png
However, the MLE loss encourages the model to blindly imitate the training data, which leads to hallucination: the model relies too heavily on previous tokens, so errors easily propagate.

Studies have found that models trained with standard MLE may over-rely on previously predicted tokens, which exacerbates error propagation (Wang and Sennrich 2020). As a result, at inference time, as the generated sequence grows, errors accumulate along it, and the model tends to amplify them and generate hallucinated content.
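To make the train/inference mismatch concrete, here is a minimal Python sketch built on a hypothetical toy bigram model (the vocabulary and probabilities are invented for illustration and are not from the paper). `mle_loss` scores the gold response with teacher forcing, while `greedy_decode` feeds the model's own predictions back in:

```python
import math
from typing import Dict, List

# Hypothetical toy bigram "language model": P(next token | previous token).
# The probabilities are made up purely for illustration.
BIGRAM: Dict[str, Dict[str, float]] = {
    "<s>": {"paris": 0.6, "rome": 0.4},
    "paris": {"is": 1.0},
    "rome": {"is": 1.0},
    "is": {"in": 1.0},
    "in": {"italy": 0.7, "france": 0.3},
    "italy": {"</s>": 1.0},
    "france": {"</s>": 1.0},
}


def mle_loss(gold: List[str]) -> float:
    """Average token-level negative log-likelihood under teacher forcing."""
    loss, prev = 0.0, "<s>"
    for tok in gold:
        loss -= math.log(BIGRAM[prev][tok])
        prev = tok  # teacher forcing: the next step conditions on the *gold* token
    return loss / len(gold)


def greedy_decode(max_len: int = 10) -> List[str]:
    """At inference time each step conditions on the model's *own* previous output."""
    out, prev = [], "<s>"
    for _ in range(max_len):
        dist = BIGRAM[prev]
        tok = max(dist, key=dist.get)
        if tok == "</s>":
            break
        out.append(tok)
        prev = tok  # an early mistake is fed back in and conditions every later step
    return out


gold = ["rome", "is", "in", "italy", "</s>"]
print(mle_loss(gold))
print(greedy_decode())  # ['paris', 'is', 'in', 'italy']
```

The MLE loss only ever conditions on gold prefixes, so training never exposes the model to its own wrong first token. At inference the wrong "paris" is kept, and the rest of the sentence is generated fluently on top of it, yielding a plausible but false statement.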

MixCL

This paper proposes MixCL, a training strategy based on mixed contrastive learning to reduce model hallucinations.
The method is shown in the figure below:
image.png
It consists of two core steps: negative sampling and mixed contrastive learning.

Negative Sampling

$z^+$ denotes a positive sample, i.e., a correct knowledge or text fragment, obtained through a function $Q_{Pos}(x)$. This function takes the input text $x$ and outputs the correct knowledge fragments; it can be implemented with human annotations or heuristic rules.
$z^-$ denotes a negative sample, i.e., a knowledge fragment that is non-factual or irrelevant to the input $x$. The paper designs two methods to obtain $z^-$:
(1) Retrieval-based: a TF-IDF retriever takes the input text $x$ and the knowledge base $\mathcal{K}$ and returns a set of $z^-$. Because TF-IDF matches on lexical overlap, the sampled fragments can easily be confused with the input text, but they are still negatives;
image.png
(2) Model-generated: a bootstrapping strategy takes fragments generated by the model itself as negatives:
image.png
An NLI tool is used to ensure that these model-generated fragments do not contain the correct knowledge.
Combining the two methods gives the final negative sampling function:
image.png
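The two negative sources can be sketched in Python as below. This is a minimal, self-contained illustration, not the repository's code: the TF-IDF retriever is a hand-rolled bag-of-words version, `toy_entails` is a hypothetical word-overlap stand-in for a real NLI model, and the names `q_ret`, `q_gen`, and `q_neg` are illustrative.

```python
import math
from collections import Counter
from typing import Callable, Dict, List


def tokenize(text: str) -> List[str]:
    return text.lower().split()


def build_idf(docs: List[str]) -> Dict[str, float]:
    """Smoothed inverse document frequency computed over the knowledge base."""
    n = len(docs)
    df = Counter(tok for doc in docs for tok in set(tokenize(doc)))
    return {tok: math.log(n / count) + 1.0 for tok, count in df.items()}


def tfidf(text: str, idf: Dict[str, float]) -> Dict[str, float]:
    tf = Counter(tokenize(text))
    return {tok: c * idf.get(tok, 0.0) for tok, c in tf.items()}  # unseen words get weight 0


def cosine(a: Dict[str, float], b: Dict[str, float]) -> float:
    dot = sum(v * b.get(t, 0.0) for t, v in a.items())
    na = math.sqrt(sum(v * v for v in a.values()))
    nb = math.sqrt(sum(v * v for v in b.values()))
    return dot / (na * nb) if na and nb else 0.0


def q_ret(x: str, knowledge: List[str], positive: str, k: int) -> List[str]:
    """Retrieval-based negatives: the top-k TF-IDF matches for x, excluding the positive."""
    idf = build_idf(knowledge)
    query = tfidf(x, idf)
    candidates = [doc for doc in knowledge if doc != positive]
    candidates.sort(key=lambda doc: cosine(query, tfidf(doc, idf)), reverse=True)
    return candidates[:k]


def q_gen(generations: List[str], positive: str,
          entails: Callable[[str, str], bool]) -> List[str]:
    """Model-generated negatives: keep only generations NOT entailed by the positive."""
    return [g for g in generations if not entails(positive, g)]


def q_neg(x: str, knowledge: List[str], positive: str, generations: List[str],
          entails: Callable[[str, str], bool], k: int = 2) -> List[str]:
    """Negative sampling function: union of retrieved and generated negatives."""
    return q_ret(x, knowledge, positive, k) + q_gen(generations, positive, entails)


def toy_entails(premise: str, hypothesis: str) -> bool:
    """Hypothetical NLI stand-in: 'entailed' if every hypothesis word is in the premise."""
    return set(tokenize(hypothesis)) <= set(tokenize(premise))


kb = [
    "Rome is the capital of Italy",
    "Paris is the capital of France",
    "The Colosseum is in Rome",
    "Bananas are yellow",
]
negs = q_neg("what is the capital of Italy", kb, "Rome is the capital of Italy",
             ["Rome is the capital of Italy", "Milan is the capital of Italy"], toy_entails)
print(negs)
# ['Paris is the capital of France', 'The Colosseum is in Rome', 'Milan is the capital of Italy']
```

The top retrieved negative shares most of its words with the query, which is exactly what makes TF-IDF negatives hard to distinguish from the correct knowledge; the generated copy of the positive is dropped by the entailment filter.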

Mixed Contrastive Learning

First, the basic contrastive learning loss is as follows:
image.png
where $l$ denotes the cross-entropy loss and $M$ is the number of negative samples.
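Since the loss above is only available as an image, the sketch below assumes the common InfoNCE-style form, in which each sample is scored by its negative cross-entropy $s = -l$ and the positive must outscore the $M$ negatives: $-\log \frac{e^{-l(x,z^+)}}{e^{-l(x,z^+)} + \sum_{i=1}^{M} e^{-l(x,z_i^-)}}$. Treat this exact form as an assumption.

```python
import math
from typing import Sequence


def contrastive_loss(l_pos: float, l_negs: Sequence[float]) -> float:
    """InfoNCE-style loss with scores s = -l (assumed form, see lead-in).

    Computes -log( exp(-l_pos) / (exp(-l_pos) + sum_i exp(-l_neg_i)) )
    using the log-sum-exp trick for numerical stability.
    """
    scores = [-l_pos] + [-l for l in l_negs]
    m = max(scores)
    log_denominator = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_denominator - scores[0]


# M = 3 negatives with the same loss as the positive: the model cannot tell them apart,
# so the loss equals log(M + 1).
print(contrastive_loss(1.0, [1.0, 1.0, 1.0]))  # log(4) ≈ 1.386
```

Lowering the cross-entropy on $z^+$ relative to every $z_i^-$ is the only way to drive this loss toward 0.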
When training BERT- or GPT-style models, $l$ is usually a token-level or sentence-level loss. However, the hallucinations a model produces typically occur over text spans, so this paper proposes span-based contrastive learning.
(1) Extracting spans
First, spans are extracted from the positive and negative texts respectively.
Since hallucinations can be intrinsic or extrinsic, two span extraction strategies are designed.

  • Intrinsic hallucinations: these usually involve confusion at the entity level, so NER is used to extract entities such as persons and times;
  • Extrinsic hallucinations: irrelevant content appears in the text, so constituency parsing is used to extract sentence constituents such as nouns and particles.

(2) Building mixed examples
Following Mixup and related work, a positive sample and a negative sample are mixed: $\tilde{z} = \mathrm{Mix}(z^+, z^-)$.
The specific operation is as follows:

  • Given a positive sample $z^+$ and a negative sample $z^-$;
  • Randomly pick one of the previously extracted spans from the positive sample;
  • Randomly pick one of the previously extracted spans from the negative sample;
  • Replace the selected span in the negative sample with the selected span from the positive sample to obtain $\tilde{z}$;
  • Define a sequence $\phi$ with the same length as $\tilde{z}$, where each element is 0 or 1: 0 means the token at the corresponding position of $\tilde{z}$ comes from $z^-$, and 1 means it comes from $z^+$.

In other words, 0/1 marks whether each token of the mixed sequence $\tilde{z}$ comes from the negative or the positive sample.
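The mixing steps can be sketched in Python as follows. Spans are given as hypothetical half-open token ranges standing in for NER or constituency-parser output, and the function follows the description above, replacing a span of the negative sample with a span of the positive sample. It is an illustration, not the repository's implementation.

```python
import random
from typing import List, Sequence, Tuple

Span = Tuple[int, int]  # half-open token range [start, end)


def mix(
    pos: Sequence[str],
    neg: Sequence[str],
    pos_spans: Sequence[Span],
    neg_spans: Sequence[Span],
    rng: random.Random,
) -> Tuple[List[str], List[int]]:
    """Replace a random span of `neg` with a random span of `pos`.

    Returns the mixed tokens z~ and phi, where phi[j] = 1 if token j came
    from the positive sample and 0 if it came from the negative sample.
    """
    ps, pe = rng.choice(list(pos_spans))
    ns, ne = rng.choice(list(neg_spans))
    mixed = list(neg[:ns]) + list(pos[ps:pe]) + list(neg[ne:])
    phi = [0] * ns + [1] * (pe - ps) + [0] * (len(neg) - ne)
    return mixed, phi


pos = "Rome is the capital of Italy".split()
neg = "Paris is the capital of France".split()
# Hypothetical entity spans, e.g. as an NER tagger might return them.
mixed, phi = mix(pos, neg, [(5, 6)], [(5, 6)], random.Random(0))
print(mixed)  # ['Paris', 'is', 'the', 'capital', 'of', 'Italy']
print(phi)    # [0, 0, 0, 0, 0, 1]
```

Because the two selected spans may differ in length, $\phi$ is built from the three pieces of $\tilde{z}$ rather than from the original lengths, which keeps it aligned token by token.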

(3) Contrastive loss

For each input $x$ in the dataset, first obtain its positive sample $z^+$, then sample $M$ negative samples $z_i^-$.
The total loss over all inputs $x$ is defined as follows:
image.png

For a pair consisting of a positive sample $z^+$ and a negative sample $z_i^-$, the loss is defined as follows:
image.png
where $\tilde{z}_i = \mathrm{Mix}(z^+, z_i^-)$, $|\tilde{z}_i|$ is the number of tokens in this sequence, and $\phi_{ij}$ indicates whether the $j$-th token of $\tilde{z}_i$ is positive.
The loss is still a token-level causal language modeling objective. The difference is that some tokens come from the positive sample and some from the negative sample, and the negative tokens can be viewed as simulated hallucinations during training.

  • During training, if $\phi_{ij}=1$, the current token is positive, and the probability of predicting it is maximized;
  • If $\phi_{ij}=0$, the current token is negative, and the probability of predicting it is minimized.
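The per-pair loss is also only shown as an image. A form that matches the two rules above is $l(x,\tilde{z}_i,\phi_i) = -\frac{1}{|\tilde{z}_i|}\sum_{j=1}^{|\tilde{z}_i|}\big[\phi_{ij}\log p(\tilde{z}_{ij}\mid \tilde{z}_{i,<j},x) + (1-\phi_{ij})\log\big(1-p(\tilde{z}_{ij}\mid \tilde{z}_{i,<j},x)\big)\big]$. Both the $\log(1-p)$ term for negative tokens (an unlikelihood-style penalty) and the $1/|\tilde{z}_i|$ normalization are assumptions here. A Python sketch that takes the model's per-token probabilities as input:

```python
import math
from typing import Sequence


def mixed_sequence_loss(probs: Sequence[float], phi: Sequence[int],
                        eps: float = 1e-8) -> float:
    """Token-level loss on one mixed sequence z~_i (assumed form, see lead-in).

    probs[j] = p(z~_ij | z~_i,<j, x) under the model; phi[j] in {0, 1}.
    Positive tokens (phi = 1) contribute -log p, which pushes p up;
    negative tokens (phi = 0) contribute -log(1 - p), which pushes p down.
    The sum is normalised by |z~_i|.
    """
    if len(probs) != len(phi):
        raise ValueError("probs and phi must have the same length")
    total = 0.0
    for p, f in zip(probs, phi):
        total -= math.log(max(p, eps)) if f == 1 else math.log(max(1.0 - p, eps))
    return total / len(probs)


print(mixed_sequence_loss([0.5, 0.5], [1, 0]))  # log(2) ≈ 0.693
```

Summing this over the $M$ mixed sequences of each input, and then over all inputs, gives the total mixed contrastive loss described above.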

The final total training loss is:
image.png
At initialization, $\alpha_1=0.4$, $\alpha_2=0.3$, and $\alpha_3=0.3$.
These weights then change linearly until $\alpha_1=0.5$, $\alpha_2=0.5$, and $\alpha_3=0$.
$\alpha_3$ is kept above 0 at the beginning to prevent the model from catastrophic forgetting.
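The linear weight schedule can be sketched as below, assuming the interpolation runs over a fixed number of training steps (`total_steps` is a hypothetical parameter; the post does not say how long the schedule lasts):

```python
from typing import Tuple

Weights = Tuple[float, float, float]


def loss_weights(step: int, total_steps: int,
                 start: Weights = (0.4, 0.3, 0.3),
                 end: Weights = (0.5, 0.5, 0.0)) -> Weights:
    """Linearly interpolate (alpha_1, alpha_2, alpha_3) from `start` to `end`."""
    t = min(max(step / total_steps, 0.0), 1.0)  # training progress, clamped to [0, 1]
    a1, a2, a3 = (s + (e - s) * t for s, e in zip(start, end))
    return a1, a2, a3


for step in (0, 500, 1000):
    print(step, loss_weights(step, total_steps=1000))
```

Since both the start and the end weights sum to 1, every intermediate point of the linear schedule does too.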

3. Experiment

Dataset

Wizard-of-Wikipedia(WoW)

Evaluation metrics

F1, ROUGE-L, BLEU-2/4, MT, Knowledge-F1, Entity-F1, Acc.

F1 (Dinan et al. 2019) calculates the unigram F1 between the generated text and the ground-truth text. For ROUGE (Lin 2004), we use ROUGE-L (RL for short) following previous work. For BLEU (Papineni et al. 2002), we use BLEU-2 and BLEU-4 (B2 and B4 for short) with the implementation in the NLTK toolkit. MT (Meteor) (Denkowski and Lavie 2014) is based on the harmonic mean of unigram precision and recall. Knowledge-F1 (Dinan et al. 2019) (KF1 for short) calculates the F1 between the generated response and the ground-truth knowledge sentence, which indicates the informativeness of a response. Acc measures the knowledge selection accuracy. As we skip the knowledge selection step, we select knowledge by matching the generated response with each knowledge candidate in WoW using the F1 score. Entity-F1 (EF1 for short) identifies entities in text using spaCy, deletes the non-entity words, and calculates the F1 score between the modified generated text and the ground-truth response. EF1 eliminates the impact of stop words and focuses on the accuracy of entities.
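A minimal Python sketch of unigram F1 and of the F1-based knowledge selection used for Acc. The normalization (lowercasing, stripping punctuation and the articles a/an/the) follows the common SQuAD/ParlAI convention and is an assumption; the repository's `utils/evaluation.py` may normalize differently.

```python
import re
import string
from collections import Counter
from typing import List

_PUNCT = set(string.punctuation)


def normalize(text: str) -> List[str]:
    """Lowercase, drop punctuation and English articles, split on whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in _PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return text.split()


def unigram_f1(pred: str, gold: str) -> float:
    p, g = normalize(pred), normalize(gold)
    num_same = sum((Counter(p) & Counter(g)).values())  # clipped unigram overlap
    if num_same == 0:
        return 0.0
    precision, recall = num_same / len(p), num_same / len(g)
    return 2 * precision * recall / (precision + recall)


def select_knowledge(response: str, candidates: List[str]) -> int:
    """Index of the knowledge candidate with the highest F1 against the response."""
    return max(range(len(candidates)), key=lambda i: unigram_f1(response, candidates[i]))


print(unigram_f1("The cat sat.", "a cat sat down"))  # ≈ 0.8
print(select_knowledge("Rome is the capital of Italy",
                       ["Paris is in France", "Italy's capital is Rome"]))  # 1
```

Knowledge-F1 reuses `unigram_f1` with the ground-truth knowledge sentence as `gold`; Acc then checks whether `select_knowledge` picks the ground-truth candidate.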

Reference implementation of these metrics: https://github.com/sunnweiwei/MixCL/blob/main/utils/evaluation.py
In addition, three annotators were invited to score 100 samples from the test set on four aspects:

Informativeness (0, 1, or 2 points), which measures whether the response contains knowledge; Relevancy (0, 1, or 2 points), which measures whether the response's content is relevant to the dialogue; Factuality (0 or 1 point), which measures whether the information in the response is factually correct; and Humanlikeness (0, 1, or 2 points), which measures whether the response is human-like in its fluency and naturalness.

Experiment Details

The backbone is BART-Large (400M parameters), and the knowledge base is Wikipedia.

Experimental results

(1) Automatic evaluation
image.png
MixCL clearly improves on the various automatic metrics.
(2) Manual evaluation
image.png
The proposed MixCL also obtains the highest human evaluation scores, and on some aspects its scores approach those of human responses.
(3) Ablation experiment
image.png
Model training involves three losses and two negative sampling functions. Removing any one of these components lowers performance, although the drops are not large.
(4) Efficiency analysis
image.png
The horizontal axis is the latency of generating a response, and the vertical axis is F1.
MixCL obtains the best F1 with the lowest latency, i.e., the best overall trade-off between response quality and speed.
(5) Case Study
image.png


Origin blog.csdn.net/qq_36426650/article/details/132001440