语义相似度匹配(二)—— ESIM模型

语义相似度匹配(一)—— DSSM模型

一、概述

论文来源:TACL 2017

论文链接:Enhanced LSTM for Natural Language Inference

一种专为自然语言推断而生的加强版 LSTM

优点:

精细的设计序列式的推断结构。
考虑局部推断和全局推断。

句子间的注意力机制(intra-sentence attention),来实现局部的推断,进一步实现全局的推断

 

二、 模型原理

作者提到,可以采用句法的LSTM树来处理,也可以用BiLSTM处理, 这里我只介绍BiLSMT的方法,LSTM树的内容有兴趣可以自己阅读论文。

如上图,模型主要分三部分:Input Encoding, Local Inference Modeling, Inference Composition

扫描二维码关注公众号,回复: 12364393 查看本文章

2.1  Input Encoding

首先, 输入是直接采用两个query的embedding,接BiLSTM得到。

附上代码

def forward(self, *input):
   # batch_size * seq_len
    sent1, sent2 = input[0], input[1]
    mask1, mask2 = sent1.eq(0), sent2.eq(0)

   # embeds: batch_size * seq_len => batch_size * seq_len * embeds_dim
    x1 = self.bn_embeds(self.embeds(sent1).transpose(1, 2).contiguous()).transpose(1, 2)
    x2 = self.bn_embeds(self.embeds(sent2).transpose(1, 2).contiguous()).transpose(1, 2)

   # batch_size * seq_len * embeds_dim => batch_size * seq_len * hidden_size
    o1, _ = self.lstm1(x1)
    o2, _ = self.lstm1(x2)    

这块比较简单,对应代码也比较容易,就不做过多解释了

2.2. Local Inference Modeling

首先,计算两个句子 word 之间的相似度,得到相似度矩阵 

对齐

拼接对齐后端信息,这里拼接了对齐前后的向量,对齐前后向量的差和点积,获取差异性。

原理说完了,感觉有一点绕,还是上代码吧

def soft_align_attention(self, x1, x2, mask1, mask2):
    '''
     x1: batch_size * seq_len * hidden_size
     x2: batch_size * seq_len * hidden_size
    '''
    # attention: batch_size * seq_len * seq_len
     attention = torch.matmul(x1, x2.transpose(1, 2))
     mask1 = mask1.float().masked_fill_(mask1, float('-inf'))
     mask2 = mask2.float().masked_fill_(mask2, float('-inf'))

    # weight: batch_size * seq_len * seq_len
     weight1 = F.softmax(attention + mask2.unsqueeze(1), dim=-1)
     x1_align = torch.matmul(weight1, x2)
     weight2 = F.softmax(attention.transpose(1, 2) + mask1.unsqueeze(1), dim=-1)
     x2_align = torch.matmul(weight2, x1)
   
    # x_align: batch_size * seq_len * hidden_size
     return x1_align, x2_align    

def submul(self, x1, x2):
    mul = x1 * x2
    sub = x1 - x2
    return torch.cat([sub, mul], -1)    

def forward(self, *input):
    ···
    
    # Attention
    # output: batch_size * seq_len * hidden_size
    q1_align, q2_align = self.soft_align_attention(o1, o2, mask1, mask2)

    # Enhancement of local inference information
    # batch_size * seq_len * (8 * hidden_size)
    q1_combined = torch.cat([o1, q1_align, self.submul(o1, q1_align)], -1)
    q2_combined = torch.cat([o2, q2_align, self.submul(o2, q2_align)], -1)

其他的逻辑还是好理解的,这里重点说一下soft_align_attention 这块,我翻译为注意力软对齐,不知道对不对。

通过代码解读,attention 像是前面说的相乘得到的相似矩阵,然后 weight1,是通过attention拼接mask2得到,跟x2相乘得到的结果即为x2中跟x1相关联的部分;x2 的align同理。

具体x1 x2 mask1 mask2, o1, o2,可以参见前面第一部分的代码。

2.3 Inference Composition

这一步直接过双向LSTM

并通过average pooling和max pooling操作

再讲这些结果拼接

最后加一个全连接层,全连接层激活函数采用的tanh,最后通过softmax归一化得到最终结果。

def apply_multiple(self, x):
    # input: batch_size * seq_len * (2 * hidden_size)
    p1 = F.avg_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
    p2 = F.max_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
    # output: batch_size * (4 * hidden_size)
    return torch.cat([p1, p2], 1)

def forward(self, *input):
    ...
    
    # inference composition
    # batch_size * seq_len * (2 * hidden_size)
    q1_compose, _ = self.lstm2(q1_combined)
    q2_compose, _ = self.lstm2(q2_combined)

    # Aggregate
    # input: batch_size * seq_len * (2 * hidden_size)
    # output: batch_size * (4 * hidden_size)
    q1_rep = self.apply_multiple(q1_compose)
    q2_rep = self.apply_multiple(q2_compose)

    # Classifier
    x = torch.cat([q1_rep, q2_rep], -1)
    sim = self.fc(x)
    return sim

代码地址:https://github.com/pengshuang/Text-Similarity  (参考其它大佬的,这里包含 ESIM、SiaGRU、ABCNN、BiMPM四种文本相似模型)

三、思考

作者采用了注意力软对齐,能通过DL模型学到两个query之间的相似度影响因子,从而取得更好的效果。

这里提一句,之前见到的一篇文章,品牌词和属性词关联分类,跟这个方法异曲同工,也是通过各种相互之间注意力,从而达到较好的效果。

 类似这种算法 通过中间的交互逻辑,提高了模型对于两者关系的学习能力。


四、 参考文献:

短文本匹配的利器-ESIM https://zhuanlan.zhihu.com/p/47580077 

https://blog.csdn.net/qq_36733823/article/details/101907000  (高校计算机大赛 附代码)

https://blog.csdn.net/pengmingpengming/article/details/88534968 基于深度学习的语义匹配若干模型DSSM,ESIM, BIMPM, ABCNN

个人理解,较为浅薄,如有问题,请指出。

整理过程中参考了一些相关文献和论文,如有侵权,非我本意,请联系我进行修改或注明出处,谢谢!

猜你喜欢

转载自blog.csdn.net/katrina1rani/article/details/110135791