Teacher Gavin's Transformer Live Lessons - Detailed CRF Modeling in NLP Information Extraction (2)

I. Overview

This article describes how to implement a linear-chain CRF in PyTorch. In general, CRFs can be defined over any undirected graph structure, such as sequences, trees, or arbitrary graphs. Here we use the sequential structure, in which the model makes its decision based on the transition from the previous state; such a model is called a linear-chain CRF.

II. Implementing a linear-chain CRF in PyTorch

  1. Theory of CRFs

Since the publication of the first CRF paper, CRFs have been widely used across machine learning, for example in bioinformatics, computer vision, and natural language processing. Combining a CRF with an LSTM gives good results. In the DIET architecture diagram mentioned earlier, a CRF is used to extract information from the output produced by the Transformer. When a CRF layer is built on top of a BiLSTM model, POS tagging of a sequence becomes more accurate.


For CRFs, the goal in a sequence classification problem is to find the probability of a label sequence y given an input sequence of vectors X; this conditional probability is written P(y | X).

The following notation can be defined:

Training set: pairs of input and target sequences {(X^i, y^i)}

The i-th input sequence of vectors: X^i = [x_1, x_2, ..., x_l]

The i-th target sequence of labels: y^i = [y_1, y_2, ..., y_l]

Ignoring transitions between labels for the moment, the conditional probability can be expressed as:

P(y | X) = exp( Σ_{k=1}^{l} U(x_k, y_k) ) / Z(X)

Modeling each per-step term P(y_k | x_k) with an exponential followed by normalization is similar to the softmax transformation widely used in neural networks. The exp function is used for the following reasons:

- Avoid underflow: multiplying very small numbers together yields even smaller numbers and can cause numerical underflow

- Non-negative outputs: all values are greater than 0

- Monotonically increasing: larger scores map to larger values

U(x_k, y_k) is called the unary score: the score for predicting label y_k from the vector x_k at the k-th time step. x_k can be any representation; in practice it is often a concatenation of surrounding elements, such as word embeddings from a sliding window. Each unary score is produced by learnable weights in the model, which is easiest to understand if you think of them as LSTM outputs.
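As a rough sketch (the layer names and sizes below are illustrative assumptions, not taken from the article), the unary scores can be obtained by projecting BiLSTM states onto the label space:

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions, not from the article)
vocab_size, emb_dim, hidden_dim, nb_labels = 1000, 50, 64, 5

embedding = nn.Embedding(vocab_size, emb_dim)
bilstm = nn.LSTM(emb_dim, hidden_dim // 2, batch_first=True, bidirectional=True)
hidden2label = nn.Linear(hidden_dim, nb_labels)

tokens = torch.randint(0, vocab_size, (1, 7))   # (batch=1, seq_len=7)
states, _ = bilstm(embedding(tokens))           # (1, 7, hidden_dim)
emissions = hidden2label(states)                # (1, 7, nb_labels): U(x_k, y_k) for every label
```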

Z(X) is usually called the partition function. We can think of it as a normalization factor, since at the end we need a probability distribution; it plays the same role as the denominator of the softmax function.
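As a quick sanity check (a toy example, not from the article): without transition scores, this exp-and-normalize construction reduces to an independent softmax at each time step:

```python
import itertools
import torch

torch.manual_seed(0)
seq_len, nb_labels = 3, 2
U = torch.randn(seq_len, nb_labels)              # toy unary scores
y = torch.tensor([0, 1, 0])                      # one candidate label sequence

# exp of the summed unary scores for y, normalized over all label sequences
numerator = torch.exp(U[torch.arange(seq_len), y].sum())
Z = sum(
    torch.exp(U[torch.arange(seq_len), torch.tensor(c)].sum())
    for c in itertools.product(range(nb_labels), repeat=seq_len)
)
p = numerator / Z

# Without transitions this equals a product of per-step softmax probabilities
p_check = torch.softmax(U, dim=1)[torch.arange(seq_len), y].prod()
assert torch.allclose(p, p_check)
```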

Having described only a plain classification model so far, we can now add learnable weights that model the transition from label y_k to label y_{k+1}. This creates dependencies between consecutive labels, which is what makes the model a linear-chain CRF. To achieve this, the previous probability is multiplied by P(y_{k+1} | y_k); thanks to the exponential, this multiplication becomes a sum of the unary scores U(x_k, y_k) and the learnable transition scores T(y_k, y_{k+1}).

The entire conditional probability is then expressed as:

P(y | X) = exp( Σ_{k=1}^{l} U(x_k, y_k) + Σ_{k=1}^{l-1} T(y_k, y_{k+1}) ) / Z(X)

where the partition function Z(X) sums exp( Σ_k U(x_k, y'_k) + Σ_k T(y'_k, y'_{k+1}) ) over every possible label sequence y'.

In the code implementation, T(y, y) can be regarded as a matrix of shape (nb_labels, nb_labels), where each entry is a learnable parameter representing the transition from the i-th label to the j-th label.
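A minimal sketch of such a matrix as a learnable PyTorch parameter (the variable name and initialization are assumptions for illustration):

```python
import torch
import torch.nn as nn

nb_labels = 5  # assumed size of the label set

# transitions[i, j] is the learnable score for moving from label i to label j
transitions = nn.Parameter(torch.empty(nb_labels, nb_labels))
nn.init.uniform_(transitions, -0.1, 0.1)
```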

Below is a description of these variables:

- emissions or unary scores (U): scores for predicting label y_k given the input x_k

- transition scores (T): scores for label y_k being followed by label y_{k+1}

- partition function (Z): normalization factor used to obtain a proper probability distribution over label sequences
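Putting these quantities together, the numerator of P(y | X), i.e. the score of one particular label sequence, can be sketched as follows (the function and argument names are assumptions, continuing the shapes used above):

```python
import torch

def score_sequence(emissions, tags, transitions):
    """Score of one label sequence: unary scores along the path plus transition scores.

    emissions:   (seq_len, nb_labels) unary scores
    tags:        (seq_len,) long tensor with the chosen label at each step
    transitions: (nb_labels, nb_labels) learnable transition scores
    """
    seq_len = emissions.size(0)
    unary = emissions[torch.arange(seq_len), tags].sum()    # sum of U(x_k, y_k)
    trans = transitions[tags[:-1], tags[1:]].sum()          # sum of T(y_k, y_{k+1})
    return unary + trans
```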

The next step is to define the function Z. Since it is a normalization over whole sequences, every possible combination of labels at every time step has to be considered. For a sequence of length l, enumerating these combinations costs O(|y|^l), where |y| is the number of labels.
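For intuition only, here is a brute-force version of Z (a toy check under the same assumed names as above; it enumerates every label sequence and is therefore exponential in the sequence length):

```python
import itertools
import torch

def partition_brute_force(emissions, transitions):
    """Naive Z(X): sum exp(score) over every possible label sequence -- O(|y|^l)."""
    seq_len, nb_labels = emissions.shape
    total = torch.tensor(0.0)
    for combo in itertools.product(range(nb_labels), repeat=seq_len):
        tags = torch.tensor(combo)
        total = total + torch.exp(score_sequence(emissions, tags, transitions))
    return total
```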

We can compute Z efficiently with dynamic programming by exploiting the recurrent structure of the chain. Such algorithms are called the forward algorithm or the backward algorithm, depending on the direction in which you traverse the sequence; they are unrelated to the forward and backward propagation used to train neural networks.
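A sketch of the forward algorithm in log-space (again using the assumed shapes from the earlier snippets); it computes log Z(X) in O(l · |y|²) instead of O(|y|^l):

```python
import torch

def log_partition_forward(emissions, transitions):
    """log Z(X) via the forward algorithm (dynamic programming in log-space)."""
    seq_len, nb_labels = emissions.shape
    alphas = emissions[0]                                    # (nb_labels,) scores after the first step
    for k in range(1, seq_len):
        # broadcast: alphas[i] + transitions[i, j] + emissions[k, j],
        # then log-sum-exp over the previous label i
        scores = alphas.unsqueeze(1) + transitions + emissions[k].unsqueeze(0)
        alphas = torch.logsumexp(scores, dim=0)              # (nb_labels,)
    return torch.logsumexp(alphas, dim=0)                    # scalar: log Z(X)
```

On small examples, exponentiating this result should match the brute-force Z from the previous sketch.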

Origin: blog.csdn.net/m0_49380401/article/details/123605916