import torch
from torch import nn
import torchvision.datasets as dsets
import torch.utils.data as Data
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# ---- Training hyper-parameters ---------------------------------------------
EPOCH: int = 5          # number of passes over the training data
BATCH_SIZE: int = 64    # samples per mini-batch
TIME_STEP: int = 28     # sequence length fed to the LSTM (each 28x28 image is read row by row)
INPUT_SIZE: int = 28    # features per time step (one image row)
LR: float = 1e-3        # Adam learning rate

# ---- Device selection ------------------------------------------------------
USE_GPU: bool = torch.cuda.is_available()
print('GPU:', USE_GPU)
# Prepare the MNIST datasets. download=False: the files must already exist under
# ../../data_sets/mnist.
train_dataset = dsets.MNIST(root='../../data_sets/mnist',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=False)
test_dataset = dsets.MNIST(root='../../data_sets/mnist',
                           train=False,
                           transform=transforms.ToTensor(),
                           download=False)
train_data_loader = Data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Must wrap test_dataset (not train_dataset) so the reported accuracy is measured on
# held-out data. Evaluation order is irrelevant, so the test set is not shuffled.
test_data_loader = Data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


class RNN(nn.Module):
    """LSTM classifier: reads a sequence and predicts one of ``num_classes`` labels
    from the output of the last time step.

    Args:
        input_size: features per time step (defaults to INPUT_SIZE).
        hidden_size: LSTM hidden/cell state size.
        num_layers: number of stacked LSTM layers.
        num_classes: number of output logits.
    """

    def __init__(self, input_size=INPUT_SIZE, hidden_size=64, num_layers=3, num_classes=10):
        super(RNN, self).__init__()
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,  # inputs/outputs are laid out as (batch, time_step, input_size)
        )
        self.out = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for x of shape
        (batch, time_step, input_size)."""
        # None -> the LSTM starts from zero hidden and cell states.
        r_out, _ = self.rnn(x, None)
        # Classify from the output of the final time step only.
        return self.out(r_out[:, -1, :])


rnn = RNN()
if USE_GPU:
    rnn = rnn.cuda()
print(rnn)
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()

# Train on the TRAINING set. The loss is sampled every 20 steps; collect it in a list
# (O(1) append) and convert once, instead of np.insert, which copies the array each call.
loss_history = []
rnn.train()
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_data_loader):
        # Feed each image as a sequence of TIME_STEP rows with INPUT_SIZE pixels each.
        b_x = b_x.view(-1, TIME_STEP, INPUT_SIZE)
        if USE_GPU:
            b_x = b_x.cuda()
            b_y = b_y.cuda()
        pre = rnn(b_x)
        loss = loss_func(pre, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 20 == 0:
            loss_value = loss.item()
            print("Epoch:", epoch, "Step:", step, "loss", loss_value)
            loss_history.append(loss_value)
loss_numpy = np.array(loss_history)

# Evaluate on the held-out TEST set. eval() switches layers such as dropout/batch-norm to
# inference behaviour; none are present today, but this keeps evaluation correct if added.
rnn.eval()
total = 0
correct = 0
with torch.no_grad():
    for x, y in test_data_loader:
        x = x.view(-1, TIME_STEP, INPUT_SIZE)
        if USE_GPU:
            x = x.cuda()
            y = y.cuda()
        predicted = rnn(x).argmax(dim=1)
        total += y.size(0)
        correct += (predicted == y).sum().item()
# Report the actual number of evaluated samples instead of a hard-coded count.
print('Accuracy of the network on the {} test images: {} %'.format(total, 100 * correct / total))
# Draw the training-loss curve: one point per 20 training steps, as recorded above.
ax = plt.gca()
ax.plot(loss_numpy, 'r-')
ax.set_title('loss function', fontsize='large')
ax.set_xlabel('step (20)')
ax.set_ylabel('loss')
plt.show()
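# 准确率 (accuracy) reported by the original run:
#   "Accuracy of the network on the 10000 test images: 98.67833333333333 %"
# NOTE: despite the message, that run evaluated on the 60,000 training images
# (98.678333...% = 59207/60000, which is not a whole count out of 10000), so this is
# training-set accuracy, not test accuracy.
# loss函数图像(锯齿形): loss-function plot (sawtooth-shaped):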