import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# Build the sequence-to-sequence (encoder-decoder) model.
class Seq2Seq(nn.Module):
    """GRU encoder-decoder that decodes a sequence of per-step class logits.

    The encoder consumes ``input_seq``; its final hidden state seeds the
    decoder. At each decoding step the decoder is fed a one-hot vector of
    size ``output_size``. The first step gets an all-zeros vector, which acts
    as the start symbol. Later steps get either the previous prediction
    (greedy argmax) or, with probability ``teacher_forcing_ratio``, the
    ground-truth token from ``target_seq``.

    Args:
        input_size: feature size of each encoder input step.
        hidden_size: hidden-state size shared by encoder and decoder.
        output_size: number of output classes (logits produced per step).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Kept so forward() can size the start symbol and the one-hot inputs.
        self.output_size = output_size
        self.encoder = nn.GRU(input_size, hidden_size)
        # The decoder is fed one-hot vectors over the output classes, so its
        # input size must be output_size (hidden_size only worked by accident
        # when the two sizes were equal).
        self.decoder = nn.GRU(output_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, input_seq, target_seq=None, teacher_forcing_ratio=0.5,
                max_length=None):
        """Encode ``input_seq`` and decode ``max_length`` steps of logits.

        Args:
            input_seq: tensor of shape (seq_len, batch, input_size). The
                encoder GRU uses the default ``batch_first=False`` layout.
            target_seq: optional ground-truth tokens used for teacher forcing.
                Assumed to be integer class indices of shape (tgt_len, batch)
                -- TODO confirm against callers.
            teacher_forcing_ratio: probability, per step, of feeding the
                ground-truth token instead of the model's own prediction.
                It only has an effect when ``target_seq`` is given.
            max_length: number of decoding steps. Defaults to the length of
                ``target_seq`` if given, otherwise to the length of
                ``input_seq``.

        Returns:
            Logits of shape (max_length, batch, output_size), on the same
            device and with the same dtype as ``input_seq``.
        """
        batch_size = input_seq.size(1)
        if max_length is None:
            max_length = (target_seq.size(0) if target_seq is not None
                          else input_seq.size(0))
        if max_length <= 0:
            return input_seq.new_zeros(0, batch_size, self.output_size)

        _, decoder_hidden = self.encoder(input_seq)
        # All-zeros vector acts as the start-of-sequence symbol. Allocate it
        # from input_seq so it follows the model's device and dtype.
        decoder_input = input_seq.new_zeros(1, batch_size, self.output_size)
        step_logits = []
        for t in range(max_length):
            decoder_output, decoder_hidden = self.decoder(
                decoder_input, decoder_hidden)
            output = self.output_layer(decoder_output)  # (1, batch, output_size)
            step_logits.append(output)

            use_teacher = (
                target_seq is not None
                and t < target_seq.size(0)
                and teacher_forcing_ratio > 0
                and torch.rand(()).item() < teacher_forcing_ratio
            )
            if use_teacher:
                next_token = target_seq[t].reshape(1, batch_size).long()
            else:
                next_token = output.argmax(2)  # greedy prediction, (1, batch)
            # Tokens are integer indices; the GRU needs float feature vectors.
            decoder_input = nn.functional.one_hot(
                next_token, self.output_size).to(input_seq.dtype)
        return torch.cat(step_logits, dim=0)
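# Source article: "pytorch シーケンスからシーケンスへの例" (PyTorch sequence-to-sequence example)
# 転載 (reprinted from): blog.csdn.net/jacke121/article/details/135053396
# おすすめ / ランキング ("Recommended" / "Ranking"): web-page navigation labels captured with the snippet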