一组向量与另一组向量之间的 attention 计算（PyTorch 实现）

版权声明：本文为博主原创文章，未经博主允许不得转载。https://blog.csdn.net/guotong1988/article/details/86503378
class GetAttentionHiddens(nn.Module):
    """Attend from sequence ``x1`` over sequence ``x2`` and pool ``x3``.

    Attention scores between ``x1`` and ``x2`` come from an ``AttentionScore``
    module (unless precomputed ``scores`` are passed in). Masked positions are
    excluded, the scores are softmax-normalised over the ``x2`` axis, and the
    resulting weights take a weighted average of ``x3`` (``x3`` defaults to
    ``x2``).
    """

    def __init__(self, input_size, attention_hidden_size, similarity_attention=False):
        """
        Args:
            input_size: forwarded to ``AttentionScore`` as its input feature size.
            attention_hidden_size: forwarded to ``AttentionScore`` as its hidden size.
            similarity_attention: forwarded to ``AttentionScore`` as ``similarity_score``.
        """
        super(GetAttentionHiddens, self).__init__()
        self.scoring = AttentionScore(input_size, attention_hidden_size, similarity_score=similarity_attention)

    def forward(self, x1, x2, x2_mask, x3=None, scores=None, return_scores=False, drop_diagonal=False):
        """Compute attention of ``x1`` over ``x2`` and read values from ``x3``.

        The scores are computed from ``x1`` and ``x2``, but the weighted average
        is taken over ``x3``. If ``x3`` is not given, ``x1`` attends on ``x2``.

        Args:
            x1: [batch, len1, x1_input_size]. Only used when ``scores`` is None.
            x2: [batch, len2, x2_input_size].
            x2_mask: [batch, len2]. Nonzero / True marks ``x2`` positions that
                must not be attended to (e.g. padding). Any integer or bool
                dtype is accepted.
            x3: [batch, len2, x3_input_size]. Defaults to ``x2``.
            scores: optional precomputed unmasked scores [batch, len1, len2].
                The tensor is NOT modified.
            return_scores: if True, also return the masked scores.
            drop_diagonal: if True, position i in ``x1`` may not attend to
                position i in ``x2`` (self-attention). Requires len1 == len2.

        Returns:
            matched_seq: [batch, len1, x3_input_size]. A query row whose keys
            are all masked gets an all-zero vector instead of NaN.
            If ``return_scores``, the tuple ``(matched_seq, masked_scores)``
            is returned instead. ``masked_scores`` holds ``-inf`` at masked
            positions.

        Raises:
            ValueError: if ``drop_diagonal`` is set and the scores are not square.
        """
        if x3 is None:
            x3 = x2

        if scores is None:
            scores = self.scoring(x1, x2)

        # Boolean mask of excluded keys for every query row. Use bool because
        # masked_fill deprecates/rejects uint8 masks in newer PyTorch.
        mask = x2_mask.bool().unsqueeze(1).expand_as(scores)
        if drop_diagonal:
            if scores.size(1) != scores.size(2):
                raise ValueError(
                    "drop_diagonal requires square scores (len1 == len2), got %s" % (tuple(scores.size()),)
                )
            eye = torch.eye(scores.size(1), dtype=torch.bool, device=scores.device).unsqueeze(0)
            mask = mask | eye  # broadcasts [1, L, L] over the batch

        # Out-of-place so a caller-supplied `scores` is left untouched and
        # autograd records the op (the old `.data.masked_fill_` did neither).
        scores = scores.masked_fill(mask, float('-inf'))

        # A row whose keys are all masked would softmax over all -inf and yield
        # NaN (in outputs and gradients). Give those rows finite scores for the
        # softmax, then zero their weights.
        fully_masked = mask.all(dim=2, keepdim=True)
        alpha = F.softmax(scores.masked_fill(fully_masked, 0.0), dim=2)
        alpha = alpha.masked_fill(fully_masked, 0.0)

        # Weighted average of x3: [batch, len1, len2] x [batch, len2, d3].
        matched_seq = alpha.bmm(x3)  # [batch, len1, x3_input_size]
        if return_scores:
            return matched_seq, scores
        return matched_seq
class AttentionScore(nn.Module):
    """Unmasked pairwise attention scores between two sequences.

    Computes ``s_ij = ReLU(W x1_i)^T D ReLU(W x2_j)``, where ``W`` is a single
    bias-free linear map shared by both inputs and ``D`` is a diagonal
    weighting stored in ``self.linear_final``:

    * ``similarity_score=True``: ``D`` is the frozen scalar
      ``1 / sqrt(attention_hidden_size)``, i.e. a scaled dot product of the
      ReLU projections.
    * otherwise: ``D`` is a learnable per-hidden-unit vector initialised to ones.
    """

    def __init__(self, input_size, attention_hidden_size, similarity_score=False):
        super(AttentionScore, self).__init__()
        # Shared projection W, applied to both x1 and x2.
        self.linear = nn.Linear(input_size, attention_hidden_size, bias=False)

        if not similarity_score:
            # Learnable diagonal D: one weight per hidden unit.
            self.linear_final = Parameter(torch.ones(1, 1, attention_hidden_size), requires_grad=True)
        else:
            # Frozen scalar scale 1/sqrt(hidden), broadcast over all hidden units.
            sqrt_hidden = attention_hidden_size ** 0.5
            self.linear_final = Parameter(torch.ones(1, 1, 1) / sqrt_hidden, requires_grad=False)

    def _project(self, seq):
        """Apply W to a [batch, len, input_size] tensor and then ReLU."""
        flat = seq.contiguous().view(-1, seq.size(-1))
        projected = self.linear(flat).view(seq.size(0), seq.size(1), -1)
        return F.relu(projected)

    def forward(self, x1, x2):
        """Score every position of ``x1`` against every position of ``x2``.

        Args:
            x1: [batch, len1, input_size]
            x2: [batch, len2, input_size]

        Returns:
            scores: [batch, len1, len2]. Padding is NOT masked here; the
            caller is responsible for that.
        """
        # NOTE(review): `dropout` and `my_dropout_p` are not defined in this
        # snippet; presumably module-level helpers elsewhere in the file. Confirm.
        x1 = dropout(x1, p=my_dropout_p, training=self.training)
        x2 = dropout(x2, p=my_dropout_p, training=self.training)

        query = self._project(x1)
        key = self._project(x2)

        # Scale each hidden unit of the keys by D, then take dot products.
        weighted_key = key * self.linear_final.expand_as(key)
        return torch.bmm(query, weighted_key.transpose(1, 2))  # [batch, len1, len2]

转载自 https://blog.csdn.net/guotong1988/article/details/86503378