1210: Placing the MultiHeadAttention at the end works better.
1211: Adding one MultiHeadAttention at both the head and the tail improves results further (see the sketch after the class below).
For the classification head, relu6 works better than gelu as the activation function.
import os
import time
from typing import Dict, Iterable, Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Linear
class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)  # key projection has no bias
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(self, x: Tensor, xa: Optional[Tensor] = None,
                mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
        q = self.query(x)
        if kv_cache is None or xa is None or self.key not in kv_cache:
            # self-attention (xa is None) or the first cross-attention call:
            # project keys/values from the input (or the encoder output xa)
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # cached cross-attention: reuse the stored key/value projections
            k, v = kv_cache[self.key], kv_cache[self.value]
        # split into heads, attend, then merge the heads back into n_state;
        # mask (if given) must broadcast against (batch, head, q_len, k_len)
        q, k, v = (t.view(*t.shape[:2], self.n_head, -1).transpose(1, 2)
                   for t in (q, k, v))
        wv = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return self.out(wv.transpose(1, 2).flatten(start_dim=2))
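A minimal sketch of the 1211 note: one MultiHeadAttention at the head and one at the tail of the network, with a relu6-activated classification head as in the note above. The backbone layers, hidden size, class count, and mean-pooling here are placeholder assumptions for illustration, not the original model.
class HeadTailAttentionClassifier(nn.Module):
    def __init__(self, n_state: int = 256, n_head: int = 4, n_classes: int = 10):
        super().__init__()
        self.head_attn = MultiHeadAttention(n_state, n_head)  # attention at the head
        self.backbone = nn.Sequential(                        # placeholder backbone
            Linear(n_state, n_state), nn.ReLU6(), Linear(n_state, n_state)
        )
        self.tail_attn = MultiHeadAttention(n_state, n_head)  # attention at the tail
        self.classifier = nn.Sequential(                      # relu6 beat gelu per the note
            Linear(n_state, n_state), nn.ReLU6(), Linear(n_state, n_classes)
        )

    def forward(self, x: Tensor) -> Tensor:  # x: (batch, seq, n_state)
        x = self.head_attn(x)
        x = self.backbone(x)
        x = self.tail_attn(x)
        return self.classifier(x.mean(dim=1))  # mean-pool over seq, then classify

if __name__ == "__main__":
    model = HeadTailAttentionClassifier()
    logits = model(torch.randn(2, 16, 256))
    print(logits.shape)  # torch.Size([2, 10])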