Implementing a Transformer encoder in PyTorch

import torch
from torch import nn
import torch.nn.functional as F
import math

class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, query, key, value, mask=None):
        # Project the inputs into this head's subspace, then apply scaled dot-product attention
        query, key, value = self.q(query), self.k(key), self.v(value)
        scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(query.size(-1))
        if mask is not None:
            if mask.dim() == 2:
                # padding mask [batch_size, seq_len] -> attention mask [batch_size, seq_len, seq_len]
                mask = mask.unsqueeze(dim=1).repeat(1, mask.size(1), 1)
            assert scores.size() == mask.size()
            scores = scores.masked_fill(mask == 0, -float("inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.bmm(weights, value)

class SequentialMultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        head_dim = hidden_size // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(hidden_size, head_dim) for _ in range(num_heads)]
        )
        self.output_linear = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, query, key, value, attn_mask=None, query_mask=None, key_mask=None):
        if query_mask is not None and key_mask is not None:
            # Outer product of the two padding masks gives a [batch_size, len_q, len_k] attention mask;
            # cast to float so bmm also works with integer (0/1) masks
            attn_mask = torch.bmm(query_mask.unsqueeze(-1).float(), key_mask.unsqueeze(1).float())
        # Run the heads one by one and concatenate their outputs along the feature dimension
        x = torch.cat([h(query, key, value, attn_mask) for h in self.heads], dim=-1)
        x = self.dropout(self.output_linear(x))
        return x
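
A minimal shape check for the sequential version (the sizes here are made up and only illustrate the expected dimensions):

# Toy example: batch of 2 sequences of length 5, hidden size 64, 4 heads
seq_mha = SequentialMultiHeadAttention(hidden_size=64, num_heads=4)
x = torch.randn(2, 5, 64)
out = seq_mha(x, x, x)  # a [batch_size, seq_len] 0/1 padding mask could be passed via query_mask/key_mask
print(out.shape)  # torch.Size([2, 5, 64])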



class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k, dropout=.1):
        super(ScaledDotProductAttention, self).__init__()
        self.scale_factor = math.sqrt(d_k)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, attn_mask=None):
        # q: [b_size x n_heads x len_q x d_k]
        # k: [b_size x n_heads x len_k x d_k]
        # v: [b_size x n_heads x len_v x d_v] note: (len_k == len_v)

        # attn: [b_size x n_heads x len_q x len_k]
        scores = torch.matmul(q, k.transpose(-1, -2)) / self.scale_factor
        if attn_mask is not None:
            assert attn_mask.size() == scores.size()
            scores.masked_fill_(attn_mask == 0, -1e9)
        attn = self.dropout(F.softmax(scores, dim=-1))

        # outputs: [b_size x n_heads x len_q x d_v]
        context = torch.matmul(attn, v)
        return context, attn


class _MultiHeadAttention(nn.Module):
    def __init__(self, d_k, d_v, d_model, n_heads, dropout):
        super(_MultiHeadAttention, self).__init__()

        self.d_k = d_k
        self.d_v = d_v
        self.d_model = d_model
        self.n_heads = n_heads

        self.w_q = nn.Linear(d_model, d_k * n_heads)
        self.w_k = nn.Linear(d_model, d_k * n_heads)
        self.w_v = nn.Linear(d_model, d_v * n_heads)

        self.attention = ScaledDotProductAttention(d_k, dropout)

    def forward(self, q, k, v, attn_mask):
        # q: [b_size x len_q x d_model]
        # k: [b_size x len_k x d_model]
        # v: [b_size x len_k x d_model]
        b_size = q.size(0)

        # q_s: [b_size x n_heads x len_q x d_k]
        # k_s: [b_size x n_heads x len_k x d_k]
        # v_s: [b_size x n_heads x len_k x d_v]
        q_s = self.w_q(q).view(b_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k_s = self.w_k(k).view(b_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v_s = self.w_v(v).view(b_size, -1, self.n_heads, self.d_v).transpose(1, 2)

        if attn_mask is not None:  # attn_mask: [b_size x len_k] padding mask (1 = keep, 0 = pad)
            # expand to [b_size x n_heads x len_q x len_k]
            attn_mask = attn_mask[:, None, None, :].repeat(1, self.n_heads, attn_mask.size(1), 1)
        # context: [b_size x n_heads x len_q x d_v],
        # attn: [b_size x n_heads x len_q x len_k]
        context, attn = self.attention(q_s, k_s, v_s, attn_mask=attn_mask)
        # context: [b_size x len_q x n_heads * d_v]
        context = context.transpose(1, 2).contiguous().view(b_size, -1, self.n_heads * self.d_v)

        # return the context and attention weights
        return context, attn



class MultiHeadAttention(nn.Module):
    def __init__(self, d_k, d_v, d_model, n_heads, dropout):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.multihead_attn = _MultiHeadAttention(d_k, d_v, d_model, n_heads, dropout)
        self.proj = nn.Linear(n_heads * d_v, d_model)
        self.dropout = nn.Dropout(dropout)
        # self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, q, k, v, attn_mask):
        # q: [b_size x len_q x d_model]
        # k: [b_size x len_k x d_model]
        # v: [b_size x len_v x d_model] note (len_k == len_v)
        # context: a tensor of shape [b_size x len_q x n_heads * d_v]
        context, attn = self.multihead_attn(q, k, v, attn_mask=attn_mask)

        # project back to the residual size, outputs: [b_size x len_q x d_model]
        output = self.dropout(self.proj(context))
        return output
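
The batched version can be sanity-checked the same way; here attn_mask is the usual [batch_size, seq_len] tensor of 0/1 values (again, all sizes are arbitrary):

# Toy example for the parallel version: d_model=64 split into 4 heads of size 16
mha = MultiHeadAttention(d_k=16, d_v=16, d_model=64, n_heads=4, dropout=0.1)
x = torch.randn(2, 5, 64)
pad_mask = torch.ones(2, 5, dtype=torch.long)
pad_mask[1, 3:] = 0  # pretend the last two positions of the second sequence are padding
out = mha(x, x, x, attn_mask=pad_mask)
print(out.shape)  # torch.Size([2, 5, 64])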


class FeedForward(nn.Module):
    def __init__(self, hidden_size, intermediate_size, dropout):
        super().__init__()
        self.linear_1 = nn.Linear(hidden_size, intermediate_size)
        self.linear_2 = nn.Linear(intermediate_size, hidden_size)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.gelu(x)
        x = self.linear_2(x)
        x = self.dropout(x)
        return x

class TransformerEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)

        # Slower: run the attention heads one after another
        # self.attention = SequentialMultiHeadAttention(
        #     hidden_size=config.hidden_size,
        #     num_heads=config.num_heads
        # )

        # Faster: compute all heads in parallel with batched matrix multiplications
        d_k = config.hidden_size // config.num_heads
        self.attention = MultiHeadAttention(
            d_k=d_k,
            d_v=d_k,
            d_model=config.hidden_size,
            n_heads=config.num_heads,
            dropout=config.dropout
        )
        self.feed_forward = FeedForward(
            config.hidden_size,
            config.intermediate_size,
            config.dropout
        )

    def forward(self, x, mask=None):
        # Apply attention with a skip connection
        x = x + self.attention(x, x, x, attn_mask=mask)
        x = self.layer_norm_1(x)
        # Apply feed-forward layer with a skip connection
        x = x + self.feed_forward(x)
        x = self.layer_norm_2(x)
        return x
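
An encoder layer only reads hidden_size, num_heads, intermediate_size and dropout from the config, so it can be tried out with a minimal stand-in config (SimpleNamespace here is just a hypothetical substitute for the real config object):

from types import SimpleNamespace

toy_config = SimpleNamespace(hidden_size=64, num_heads=4, intermediate_size=256, dropout=0.1)
layer = TransformerEncoderLayer(toy_config)
x = torch.randn(2, 5, 64)
print(layer(x).shape)  # torch.Size([2, 5, 64]); the residual connections keep the shape unchanged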

class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(0.2)

    def forward(self, input_ids):
        # Create position IDs for input sequence
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        # Create token and position embeddings
        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        # Combine token and position embeddings
        embeddings = token_embeddings + position_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class TransformerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return x


class TransModel(nn.Module):
    def __init__(self, config):
        super(TransModel, self).__init__()

        self.encoder = TransformerEncoder(config)
        self.embedding = Embeddings(config)

        self.tanh1 = nn.Tanh()
        self.w1 = nn.Parameter(torch.randn(config.hidden_size))

        self.tanh2 = nn.Tanh()
        self.w2 = nn.Parameter(torch.randn(config.hidden_size))

        if config.use_know:
            self.classifier = nn.Sequential(
                nn.Linear(config.hidden_size * 2, config.hidden_size2),
                nn.ReLU(),
                nn.Linear(config.hidden_size2, config.num_classes)
            )
        else:
            self.classifier = nn.Sequential(
                nn.Linear(config.hidden_size , config.hidden_size2),
                nn.ReLU(),
                nn.Linear(config.hidden_size2, config.num_classes)
            )

        self.loss_fn = nn.CrossEntropyLoss()

        self.config = config


    def get_features(self, input_ids, masks):
        # Flatten an optional extra dimension (e.g. several knowledge sentences per example)
        if input_ids.dim() == 3:
            input_ids = input_ids.view(-1, input_ids.size(-1))
            if masks is not None and masks.dim() == 3:
                masks = masks.view(-1, masks.size(-1))
        emb = self.embedding(input_ids)        # [batch_size, seq_len, hidden_size]
        hidden_vec = self.encoder(emb, masks)  # [batch_size, seq_len, hidden_size]
        hidden_vec = self.tanh1(hidden_vec)    # [batch_size, seq_len, hidden_size]
        # Attention pooling over the sequence dimension
        alpha = F.softmax(torch.matmul(hidden_vec, self.w1), dim=-1).unsqueeze(-1)  # [batch_size, seq_len, 1]
        out = hidden_vec * alpha               # [batch_size, seq_len, hidden_size]
        out = torch.sum(out, dim=-2)           # [batch_size, hidden_size]
        return out


    def forward(self, input_ids, masks, know_input_ids, know_masks, labels=None):
        sent_feature = self.get_features(input_ids, masks)  # [batch, hidden_size]

        if self.config.use_know:
            know_feature = self.get_features(know_input_ids, know_masks)  # [batch * num_know, hidden_size]
            know_feature = know_feature.view(input_ids.size(0), -1, know_feature.size(-1))  # [batch, num_know, hidden_size]

            # Attention pooling over the knowledge sentences
            alpha = F.softmax(self.tanh2(torch.matmul(know_feature, self.w2)), dim=-1).unsqueeze(dim=-1)  # [batch, num_know, 1]
            know_feature = (know_feature * alpha).sum(dim=-2)

            # know_feature = know_feature.sum(dim=-2)

            out = self.classifier(torch.cat([sent_feature, know_feature], dim=-1))
        else:
            out = self.classifier(sent_feature)


        if labels is not None:
            loss = self.loss_fn(out, labels)
            return_tuple = (loss, out)
        else:
            return_tuple = (out, )
        return return_tuple
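
A rough usage sketch for TransModel (the config fields mirror the attributes this file reads; the batch size, sequence length and number of knowledge sentences below are invented for illustration):

from types import SimpleNamespace

toy_config = SimpleNamespace(
    vocab_size=100, max_position_embeddings=64, hidden_size=64, num_heads=4,
    intermediate_size=256, num_hidden_layers=2, dropout=0.1,
    hidden_size2=32, num_classes=3, use_know=True,
)
model = TransModel(toy_config)
input_ids = torch.randint(1, 100, (2, 10))          # [batch, seq_len]
masks = torch.ones(2, 10, dtype=torch.long)
know_input_ids = torch.randint(1, 100, (2, 3, 10))  # [batch, num_know, seq_len]
know_masks = torch.ones(2, 3, 10, dtype=torch.long)
labels = torch.tensor([0, 2])
loss, logits = model(input_ids, masks, know_input_ids, know_masks, labels)
print(loss.item(), logits.shape)  # logits: torch.Size([2, 3])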



if __name__ == '__main__':
    from transformers import AutoConfig
    from transformers import AutoTokenizer

    model_ckpt = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    config = AutoConfig.from_pretrained(model_ckpt)
    # BertConfig uses different attribute names, so map them onto the names used in this file
    config.num_heads = config.num_attention_heads
    config.dropout = config.hidden_dropout_prob

    text = "time flies like an arrow"
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)

    embeddings = Embeddings(config)
    encoder = TransformerEncoder(config)
    # The encoder expects embedded inputs, not raw token IDs
    print(encoder(embeddings(inputs.input_ids)).size())

Reposted from blog.csdn.net/mch2869253130/article/details/128837464