LSTM 预测各国人均GDP（PyTorch）

以下是训练过程的代码：

from pandas_datareader.pandas_datareader import wb
import torch.nn
import torch
import torch.optim
import csv
from IPython.display import display
import pandas as pd 
import numpy
import matplotlib.pyplot as plt

class Net(torch.nn.Module):
    """Single-layer LSTM followed by a linear head emitting one value per step.

    ``forward`` always appends a trailing feature axis of size 1 to its input
    before feeding the LSTM, so as written the model is only shape-consistent
    when ``input_size == 1``.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Keep these attribute names: they determine the state_dict keys
        # ("rnn.*", "fc.*") used when saving/loading weights.
        self.rnn = torch.nn.LSTM(input_size, hidden_size)
        self.fc = torch.nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Run the LSTM over ``x`` and project every step to a scalar.

        The LSTM uses PyTorch's default ``batch_first=False``, so a 2-D ``x``
        is presumably laid out as ``(seq_len, batch)`` — TODO confirm with the
        caller. For a 2-D input the result has the same shape as ``x``.
        """
        features = x.unsqueeze(2)               # (d0, d1) -> (d0, d1, 1)
        hidden_seq, _state = self.rnn(features)  # -> (d0, d1, hidden_size)
        projected = self.fc(hidden_seq)         # -> (d0, d1, 1)
        return projected[..., 0]                # drop the size-1 output axis
 
# Two-letter country codes (they look like ISO 3166-1 alpha-2 codes, e.g.
# 'CN', 'US') for the economies whose per-capita GDP the article models.
countries=['BR', 'CA', 'CN', 'FR', 'DE', 'IN', 'IL','JP', 'SA', 'GB', 'US']
# NOTE(review): this statement is truncated in the scraped source. `w` is not
# defined anywhere in the visible code, so this line raises NameError unless
# `w` is defined elsewhere. It presumably continued into a call on the `wb`
# module imported at the top of the file. The remainder must be recovered
# from the original article; do not guess it.
dat = w

猜你喜欢

转载自：https://blog.csdn.net/tony2278/article/details/105248260
今日推荐