Transformer输入部分结构原理分析

Transformer网络架构

在这里插入图片描述

输入部分

源:源文本嵌入层及其位置编码器 目标:目标文本嵌入层及其位置编码器

文本嵌入层的作用:目的是将文本word2id的数字转变为以向量的方式表示
位置编码器的作用:因为在Transformer的编码器结构中, 并没有针对词汇位置信息的处理,因此需要在Embedding层后加入位置编码器,将词汇位置不同可能会产生不同语义的信息加入到词嵌入张量中, 以弥补位置信息的缺失

文本嵌入层与位置编码器代码实现

import torch
import torch.nn as nn
from torch.autograd import Variable
import math
class embedded(nn.Module):
	def __init__(self,vocab,dim_model):
		super().__init__()
		self.dim_model=dim_model
		self.embedding=nn.Embedding(vocab,dim_model)
	def forward(self,input):
		embedded=self.embedding(input)
		return embedded*math.sqrt(self.dim_model)

class positionalEncoding(nn.Module):
	def __init__(self,d_model,dropout,max_len=5000):
		super().__init__()
		self.dropout=nn.Dropout(p=dropout)
	
		self.pe=torch.zeros(max_len,d_model)

		self.position=torch.arange(0,max_len).unsqueeze(1)
	
		div_term=torch.exp(torch.arange(0,d_model,2)*-(math.log(10000.0)/d_model))

		self.pe[:,0::2]=torch.sin(self.position*div_term)
		self.pe[:,1::2]=torch.cos(self.position*div_term)
	
		self.pe=self.pe.unsqueeze(0)
		self.register_buffer('ppe',self.pe)

	
	def forward(self,x):
		x=x+Variable(self.pe[:,:x.size(1)],requires_grad=False)		
		
	
		return self.dropout(x)


if __name__=='__main__':
	dim_model=512
	vocab=10
	input=torch.tensor([[1,2,3,4],[5,6,7,8]])
	emb=embedded(vocab,dim_model)
	emp=emb(input)
	
	d_model=512
	dropout=0.2
	max_len=50
	x=emp
	pe=positionalEncoding(d_model,dropout,max_len)
	output_pe=pe(x)
	print(output_pe)

运行结果

tensor([[[-16.4910,  -6.9946,   0.0000,  ...,  -1.4522,   5.3102, -49.9791],
         [  2.5286, -21.2322, -31.5919,  ...,  34.9202, -25.4638,  36.9452],
         [-11.0071, -47.8129, -39.6638,  ...,  -3.3422, -14.5521, -39.2135],
         [ -0.0000,   3.5995,   6.9863,  ...,  -3.6490,  38.5564,   2.1162]],

        [[  6.8392, -18.2051,   0.0000,  ...,  -2.7158,  21.9266,  -0.0000],
         [ 36.3135, -25.6620, -16.9607,  ...,  24.3223,  -0.0000, -11.4633],
         [-25.7412,   0.0000, -12.9964,  ...,   1.3707,   7.1693,   2.4895],
         [ -0.0000,  21.8448,  -3.3618,  ...,   2.7337,  15.8136,  -8.6898]]],
       grad_fn=<MulBackward0>)

发布了66 篇原创文章 · 获赞 1 · 访问量 7004

猜你喜欢

转载自blog.csdn.net/qq_41128383/article/details/105673141
今日推荐