transformer结构输出部分

在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F

class generator(nn.Module):
        def __init__(self,d_model,vocab):
                super().__init__()

                self.linear=nn.Linear(d_model,vocab)

        def forward(self,input)
                return F.log_foftmax(self.linear(input),dim=-1) 

gen = Generator(d_model, vocab_size)
gen_result = gen(x)
发布了66 篇原创文章 · 获赞 1 · 访问量 7001

猜你喜欢

转载自blog.csdn.net/qq_41128383/article/details/105731814
今日推荐