LSTM with multi-head attention (MultiHeadAttention)

12/10: Placing the MultiHeadAttention after the LSTM gives better results.

12/11: Adding one MultiHeadAttention at both the head and the tail (before and after the LSTM) improves the results further; see the sketch after the MultiHeadAttention class below.

For the classification head, the relu6 activation works better than gelu.

import os
import time
from typing import Dict, Iterable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear

class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # self-attention (xa is None) or the first cross-attention call:
            # project keys/values from the input
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # cross-attention with a populated cache: reuse the cached keys/values
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        # split into heads: (batch, ctx, state) -> (batch, head, ctx, state // head)
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k  # (batch, head, ctx, ctx) attention scores
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]

        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
        # merge heads back to (batch, ctx, state)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
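To illustrate the "head and tail" arrangement described above, here is a minimal sketch of an LSTM classifier that applies one MultiHeadAttention (the class defined above) before the LSTM and another after it, with relu6 in the classification head. The hyperparameters (n_state=128, n_head=4, num_classes=10) and the mean-pooling over the time dimension are illustrative assumptions, not values from the original post.

class LSTMAttnClassifier(nn.Module):
    # Sketch: MultiHeadAttention -> LSTM -> MultiHeadAttention -> relu6 classifier.
    # Hyperparameters below are illustrative assumptions, not from the original post.
    def __init__(self, n_state: int = 128, n_head: int = 4, num_classes: int = 10):
        super().__init__()
        self.attn_in = MultiHeadAttention(n_state, n_head)   # attention at the "head"
        self.lstm = nn.LSTM(n_state, n_state, batch_first=True)
        self.attn_out = MultiHeadAttention(n_state, n_head)  # attention at the "tail"
        self.classifier = nn.Sequential(
            nn.Linear(n_state, n_state),
            nn.ReLU6(),                  # relu6 reported to work better than gelu here
            nn.Linear(n_state, num_classes),
        )

    def forward(self, x: Tensor) -> Tensor:
        # x: (batch, seq_len, n_state)
        x, _ = self.attn_in(x)           # the class above returns (output, qk); keep the output
        x, _ = self.lstm(x)
        x, _ = self.attn_out(x)
        x = x.mean(dim=1)                # mean-pool over time (an assumption; last step also works)
        return self.classifier(x)


# quick shape check
model = LSTMAttnClassifier()
logits = model(torch.randn(2, 50, 128))  # -> (2, 10)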


Reposted from blog.csdn.net/jacke121/article/details/134658358