PyTorch シーケンスからシーケンスへ（Seq2Seq）の例

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Build a sequence-to-sequence (encoder-decoder) model from two GRUs.
class Seq2Seq(nn.Module):
    """GRU encoder-decoder that emits ``output_size`` logits per decoding step.

    Both GRUs use PyTorch's default sequence-first layout, so tensors are
    ``(seq_len, batch, features)``.

    Args:
        input_size: Feature size of each encoder input step.
        hidden_size: Hidden size shared by the encoder and decoder GRUs.
        output_size: Number of output classes (logits per step). Predicted
            class indices are fed back into the decoder through an embedding.
        max_length: Number of decoding steps when no ``target_seq`` is given.
    """

    def __init__(self, input_size, hidden_size, output_size, max_length=10):
        super(Seq2Seq, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.max_length = max_length
        self.encoder = nn.GRU(input_size, hidden_size)
        self.decoder = nn.GRU(hidden_size, hidden_size)
        # Maps a class index to a hidden_size vector so the decoder (whose
        # input size is hidden_size) can consume its own predictions or
        # teacher-forced targets.
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, input_seq, target_seq=None, teacher_forcing_ratio=0.5):
        """Encode ``input_seq`` and decode a sequence of output logits.

        Args:
            input_seq: Float tensor of shape ``(src_len, batch, input_size)``.
            target_seq: Optional tensor of class indices, assumed to have
                shape ``(tgt_len, batch)``. When given, exactly ``tgt_len``
                steps are decoded, and at each step the ground-truth index is
                fed to the next step with probability ``teacher_forcing_ratio``
                (otherwise the model's own argmax prediction is fed back).
            teacher_forcing_ratio: Probability in ``[0, 1]`` of feeding the
                target instead of the prediction. Ignored without a target.

        Returns:
            Tensor of shape ``(steps, batch, output_size)`` holding
            unnormalized logits, where ``steps`` is ``tgt_len`` when
            ``target_seq`` is given and ``self.max_length`` otherwise.
        """
        batch_size = input_seq.size(1)
        # Decoding as many steps as the target keeps outputs aligned for a
        # loss and keeps target_seq[t] in range for teacher forcing.
        steps = self.max_length if target_seq is None else target_seq.size(0)
        if steps == 0:
            return input_seq.new_zeros(0, batch_size, self.output_size)

        # Only the final encoder hidden state is used to seed the decoder.
        _, decoder_hidden = self.encoder(input_seq)

        # Start-of-sequence input: an all-zero vector on the input's device.
        decoder_input = input_seq.new_zeros(1, batch_size, self.hidden_size)

        outputs = []
        for t in range(steps):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            logits = self.output_layer(decoder_output)  # (1, batch, output_size)
            outputs.append(logits)

            use_teacher = (
                target_seq is not None
                and torch.rand(()).item() < teacher_forcing_ratio
            )
            if use_teacher:
                next_token = target_seq[t].reshape(1, batch_size).long()
            else:
                next_token = logits.argmax(2)  # (1, batch) class indices
            decoder_input = self.embedding(next_token)  # (1, batch, hidden_size)

        return torch.cat(outputs, dim=0)

おすすめ

転載: blog.csdn.net/jacke121/article/details/135053396