Semantic similarity matching (2) - ESIM model

Previous post: Semantic similarity matching (1) - DSSM model

 

1. Overview

Paper source: ACL 2017

Paper link: Enhanced LSTM for Natural Language Inference

An enhanced version of LSTM designed for natural language inference

Advantages:

A carefully designed sequential inference structure that considers both local inference and global inference.

An inter-sentence attention mechanism (soft alignment) to realize local inference, which is then composed into global inference.

 

2. Model principle

The authors mention that the model can be built either with a tree-LSTM over syntactic parse trees or with a BiLSTM. Here I only introduce the BiLSTM variant; if you are interested in the tree-LSTM version, you can read the paper yourself.

As shown in the figure in the paper, the model is mainly divided into three parts: Input Encoding, Local Inference Modeling, and Inference Composition.

 

2.1  Input Encoding

First, the embeddings of the two queries are fed in directly, followed by a BiLSTM.

The code:

def forward(self, *input):
    # batch_size * seq_len
    sent1, sent2 = input[0], input[1]
    # padding masks: True where the token id is 0 (the padding index)
    mask1, mask2 = sent1.eq(0), sent2.eq(0)

    # embeds: batch_size * seq_len => batch_size * seq_len * embeds_dim
    x1 = self.bn_embeds(self.embeds(sent1).transpose(1, 2).contiguous()).transpose(1, 2)
    x2 = self.bn_embeds(self.embeds(sent2).transpose(1, 2).contiguous()).transpose(1, 2)

    # batch_size * seq_len * embeds_dim => batch_size * seq_len * (2 * hidden_size)
    o1, _ = self.lstm1(x1)
    o2, _ = self.lstm1(x2)

This part is relatively simple and the corresponding code is straightforward, so I won't explain it at length.
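For completeness, here is a minimal sketch of the layer definitions that the forward snippet above assumes. It is not taken from the original post or the linked repository; the names vocab_size and num_classes and the exact classifier head are illustrative, chosen only to be consistent with the shape comments in the snippets of this article.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ESIM(nn.Module):
    def __init__(self, vocab_size, embeds_dim=300, hidden_size=300, num_classes=2):
        super(ESIM, self).__init__()
        # word embeddings; index 0 is the padding index, which is why
        # the masks above are built with sent.eq(0)
        self.embeds = nn.Embedding(vocab_size, embeds_dim, padding_idx=0)
        # batch norm over the embedding channel, hence the transpose(1, 2) calls
        self.bn_embeds = nn.BatchNorm1d(embeds_dim)
        # input-encoding BiLSTM: each direction has hidden_size units,
        # so its output per time step is 2 * hidden_size
        self.lstm1 = nn.LSTM(embeds_dim, hidden_size, batch_first=True, bidirectional=True)
        # composition BiLSTM (section 2.3): the enhanced input is 8 * hidden_size wide
        self.lstm2 = nn.LSTM(8 * hidden_size, hidden_size, batch_first=True, bidirectional=True)
        # classifier head (section 2.3): the concatenated pooled features are 8 * hidden_size
        self.fc = nn.Sequential(
            nn.Linear(8 * hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, num_classes),
        )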

 

2.2 Local Inference Modeling

First, compute the word-level similarity between every pair of words in the two sentences: the BiLSTM outputs are multiplied to obtain a similarity (attention) matrix, e_ij = o1_i · o2_j.

Alignment: each word is then soft-aligned to a weighted sum of the other sentence's representations, using the softmax of its row (or column) of the similarity matrix as weights. To enhance the local inference information, the original vector and its aligned vector are concatenated, together with their element-wise difference and element-wise product.

That is the principle; it can feel a bit convoluted, so let's look at the code.

def soft_align_attention(self, x1, x2, mask1, mask2):
    '''
    x1: batch_size * seq_len * hidden_size
    x2: batch_size * seq_len * hidden_size
    mask1, mask2: batch_size * seq_len, True at padding positions
    '''
    # attention: batch_size * seq_len * seq_len
    attention = torch.matmul(x1, x2.transpose(1, 2))
    # additive masks: 0 for real tokens, -inf for padding, so softmax ignores padding
    mask1 = mask1.float().masked_fill_(mask1, float('-inf'))
    mask2 = mask2.float().masked_fill_(mask2, float('-inf'))

    # weight: batch_size * seq_len * seq_len
    weight1 = F.softmax(attention + mask2.unsqueeze(1), dim=-1)
    x1_align = torch.matmul(weight1, x2)
    weight2 = F.softmax(attention.transpose(1, 2) + mask1.unsqueeze(1), dim=-1)
    x2_align = torch.matmul(weight2, x1)

    # x_align: batch_size * seq_len * hidden_size
    return x1_align, x2_align

def submul(self, x1, x2):
    mul = x1 * x2
    sub = x1 - x2
    return torch.cat([sub, mul], -1)    

def forward(self, *input):
    ...

    # Attention
    # output: batch_size * seq_len * (2 * hidden_size)
    q1_align, q2_align = self.soft_align_attention(o1, o2, mask1, mask2)

    # Enhancement of local inference information: [o; align; o - align; o * align]
    # batch_size * seq_len * (8 * hidden_size)
    q1_combined = torch.cat([o1, q1_align, self.submul(o1, q1_align)], -1)
    q2_combined = torch.cat([o2, q2_align, self.submul(o2, q2_align)], -1)

The rest of the logic is easy to follow, so here I will focus on soft_align_attention, which I translate as "soft attention alignment" (I'm not sure whether that rendering is standard).

Reading the code: attention is the similarity matrix obtained by the matrix multiplication mentioned above. weight1 is obtained by adding mask2 to attention (so padded positions receive zero weight) and applying softmax; multiplying weight1 by x2 gives, for each word of x1, the part of x2 relevant to it. The alignment of x2 is computed symmetrically.

For the definitions of x1, x2, mask1, mask2, o1, and o2, refer to the code in the first part above.
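To make the shapes and the masking concrete, here is a small self-contained toy example (all sizes and values are made up) that reproduces the weight1 / x1_align computation from soft_align_attention:

import torch
import torch.nn.functional as F

# toy shapes: batch of 1, two 4-token sentences, hidden size 6
o1 = torch.randn(1, 4, 6)
o2 = torch.randn(1, 4, 6)

# pretend the last token of sentence 2 is padding (token id 0 => mask True)
mask2 = torch.tensor([[False, False, False, True]])

# similarity matrix: batch_size * seq_len * seq_len
attention = torch.matmul(o1, o2.transpose(1, 2))

# additive mask: 0 for real tokens, -inf for padding
add_mask2 = mask2.float().masked_fill(mask2, float('-inf'))

# each row of weight1 is a distribution over the non-padded tokens of sentence 2
weight1 = F.softmax(attention + add_mask2.unsqueeze(1), dim=-1)
x1_align = torch.matmul(weight1, o2)

print(weight1[0, 0])          # last entry is 0: the padded token gets no weight
print(weight1.sum(dim=-1))    # all ones
print(x1_align.shape)         # torch.Size([1, 4, 6])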

 

2.3 Inference Composition

This step feeds the enhanced local inference vectors into another bidirectional LSTM, then applies both average pooling and max pooling over the sequence and concatenates the pooled results of the two sentences. Finally, a fully connected layer with a tanh activation produces the output, which is normalized by softmax.

def apply_multiple(self, x):
    # input: batch_size * seq_len * (2 * hidden_size)
    p1 = F.avg_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
    p2 = F.max_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
    # output: batch_size * (4 * hidden_size)
    return torch.cat([p1, p2], 1)

def forward(self, *input):
    ...
    
    # inference composition
    # batch_size * seq_len * (2 * hidden_size)
    q1_compose, _ = self.lstm2(q1_combined)
    q2_compose, _ = self.lstm2(q2_combined)

    # Aggregate
    # input: batch_size * seq_len * (2 * hidden_size)
    # output: batch_size * (4 * hidden_size)
    q1_rep = self.apply_multiple(q1_compose)
    q2_rep = self.apply_multiple(q2_compose)

    # Classifier
    x = torch.cat([q1_rep, q2_rep], -1)
    sim = self.fc(x)
    return sim
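Putting it together, a hypothetical usage might look like the following. It assumes the snippets above are assembled into the single ESIM class sketched in section 2.1; the vocabulary size, token ids, and sentence lengths are made up.

import torch
import torch.nn.functional as F

model = ESIM(vocab_size=10000, embeds_dim=300, hidden_size=300, num_classes=2)
model.eval()

# two batches of token ids, 0 = padding
sent1 = torch.tensor([[5, 12, 48, 3, 0, 0]])   # batch_size = 1, seq_len = 6
sent2 = torch.tensor([[7, 12, 9, 0, 0, 0]])

with torch.no_grad():
    logits = model(sent1, sent2)          # batch_size * num_classes
    probs = F.softmax(logits, dim=-1)     # normalized similarity scores

print(probs)

Here the softmax is applied outside the model (during training it would normally sit inside the cross-entropy loss), which matches the description above that the final result is normalized by softmax.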

 

Code address: https://github.com/pengshuang/Text-Similarity (a reference implementation by another author; the repo includes four text-similarity models: ESIM, SiaGRU, ABCNN, and BiMPM)

3. Thoughts

By using soft attention alignment, the model can learn the factors that influence the similarity between the two queries directly from the data, which leads to better results.

As a side note, an article I read earlier on classifying the association between brand words and attribute words took a similar approach, also using various forms of mutual attention to achieve better results.

Algorithms of this kind improve the model's ability to learn the relationship between the two inputs through this intermediate interaction logic.

 


4. References

Short text matching tool: ESIM. https://zhuanlan.zhihu.com/p/47580077

College computer competition write-up with code: https://blog.csdn.net/qq_36733823/article/details/101907000

Semantic matching models based on deep learning (DSSM, ESIM, BiMPM, ABCNN): https://blog.csdn.net/pengmingpengming/article/details/88534968

 

My personal understanding is relatively shallow; if you find any problems, please point them out.

Some related articles and papers were referenced while putting this together. Any infringement is unintentional; please contact me to modify the content or indicate the source. Thank you!

 


Original post: blog.csdn.net/katrina1rani/article/details/110135791