PyTorch implementation and usage of linear self-attention

Copyright notice: this is an original article by the author; reproduction without permission is prohibited. https://blog.csdn.net/guotong1988/article/details/86502457
# For summarizing a set of vectors into a single vector
import torch.nn as nn
import torch.nn.functional as F

my_dropout_p = 0.1  # dropout probability; placeholder value, set as needed

class LinearSelfAttn(nn.Module):
    """Self attention over a sequence:
    * alpha_i = softmax(W x_i) for x_i in X.
    """
    def __init__(self, input_size):
        super(LinearSelfAttn, self).__init__()
        self.linear = nn.Linear(input_size, 1)

    def forward(self, x, x_mask):
        """
        x = [batch, len, hdim]
        x_mask = [batch, len]; 1/True marks padding positions to ignore
        """
        x = F.dropout(x, p=my_dropout_p, training=self.training)

        # Score each position with a single linear projection
        x_flat = x.contiguous().view(-1, x.size(-1))
        scores = self.linear(x_flat).view(x.size(0), x.size(1))
        # Set padded positions to -inf so they receive zero weight after softmax
        scores = scores.masked_fill(x_mask.bool(), -float('inf'))
        alpha = F.softmax(scores, dim=1)
        return alpha  # [batch, len]
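
Side note: the flatten-and-reshape around self.linear is just bookkeeping. nn.Linear operates on the last dimension, so the same scores can be computed directly on the 3-D tensor. A minimal check (tensor sizes here are illustrative, not from the original code):

import torch
import torch.nn as nn

lin = nn.Linear(8, 1)
x = torch.randn(2, 5, 8)                                   # [batch, len, hdim]
scores_flat = lin(x.contiguous().view(-1, 8)).view(2, 5)   # flatten -> score -> reshape
scores_3d = lin(x).squeeze(-1)                             # nn.Linear broadcasts over leading dims
print(torch.allclose(scores_flat, scores_3d))              # True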
# bmm: batch matrix multiplication
# unsqueeze: add a singleton dimension
# squeeze: remove a singleton dimension
def weighted_avg(x, weights):
    """ x = [batch, len, d]
        weights = [batch, len]
    """
    # [batch, 1, len] bmm [batch, len, d] -> [batch, 1, d] -> [batch, d]
    return weights.unsqueeze(1).bmm(x).squeeze(1)
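
The bmm trick above is equivalent to broadcasting the weights over the feature dimension and summing over the sequence. A quick sanity check, assuming the weighted_avg definition above (values are illustrative):

import torch

x = torch.randn(2, 5, 8)                            # [batch, len, d]
weights = torch.softmax(torch.randn(2, 5), dim=1)   # [batch, len], rows sum to 1

via_bmm = weights.unsqueeze(1).bmm(x).squeeze(1)    # [batch, d]
via_sum = (weights.unsqueeze(-1) * x).sum(dim=1)    # [batch, d]
print(torch.allclose(via_bmm, via_sum))             # True, up to floating-point error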

Usage:

linear_self_attn = LinearSelfAttn(hidden_dim)  # instantiate once, e.g. in the model's __init__

# [batch, sentence_len, hidden_dim] -> [batch, sentence_len]
sentence_weights = linear_self_attn(sentence_hiddens, sentence_mask)

# [batch, sentence_len, hidden_dim] weighted by [batch, sentence_len] -> [batch, hidden_dim]
sentence_avg_hidden = weighted_avg(sentence_hiddens, sentence_weights)
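
A self-contained sanity check with dummy tensors (sizes and mask below are illustrative, not from the original post) confirms the shapes and that padded positions get zero weight:

import torch

torch.manual_seed(0)
batch, sentence_len, hidden_dim = 2, 4, 6
sentence_hiddens = torch.randn(batch, sentence_len, hidden_dim)
# 1/True marks padding; the last position(s) of each sentence are padding here
sentence_mask = torch.tensor([[0, 0, 0, 1],
                              [0, 0, 1, 1]], dtype=torch.bool)

attn = LinearSelfAttn(hidden_dim)
attn.eval()  # disable dropout for a deterministic check

weights = attn(sentence_hiddens, sentence_mask)  # [2, 4]
avg = weighted_avg(sentence_hiddens, weights)    # [2, 6]
print(weights.shape, avg.shape)
print(weights)  # each row sums to 1; masked positions are exactly 0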
