import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
# Mini-batch size shared by the training and test loaders.
batch_size = 64

# Directory where torchvision stores (and, with download=True, fetches) MNIST.
_MNIST_ROOT = './dataset/mnist/'

# Chain the listed operations into one preprocessing pipeline:
# ToTensor turns each image into a float tensor, then Normalize applies
# (x - mean) / std with mean 0.1307 and std 0.3081 (presumably the MNIST
# training-set statistics -- confirm if the dataset changes).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST(root=_MNIST_ROOT,
                               train=True,
                               download=True,
                               transform=transform)
# Shuffle the training data every epoch.
train_loader = DataLoader(train_dataset,
                          shuffle=True,
                          batch_size=batch_size)

test_dataset = datasets.MNIST(root=_MNIST_ROOT,
                              train=False,
                              download=True,
                              transform=transform)
# Keep a fixed order for evaluation so results are reproducible.
test_loader = DataLoader(test_dataset,
                         shuffle=False,
                         batch_size=batch_size)


class Net(torch.nn.Module):
    """Fully connected classifier: 784 -> 512 -> 256 -> 128 -> 64 -> 10.

    ReLU follows each of the four hidden layers. The last layer returns raw
    logits with no softmax; ``torch.nn.CrossEntropyLoss`` expects exactly
    that form.
    """

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Flatten *x* into rows of 784 features and return 10 logits per row."""
        # reshape (unlike view) also accepts non-contiguous tensors; -1 lets
        # the batch dimension be inferred, exactly as view(-1, 784) did.
        x = x.reshape(-1, 784)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        return self.l5(x)
# Loss on the raw logits returned by Net.forward (no softmax is applied there).
criterion = torch.nn.CrossEntropyLoss()
# The network being trained; its parameters are handed to the SGD optimizer.
model = Net()
# SGD with momentum over every parameter of the model.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


def train(epoch, log_interval=300):
    """Run one pass over ``train_loader``, updating ``model`` in place.

    Args:
        epoch: zero-based epoch index; printed 1-based in the log line.
        log_interval: number of batches between loss reports. Each report
            prints the mean loss over the preceding ``log_interval`` batches.
            Defaults to 300, the previously hard-coded value.

    Raises:
        ValueError: if ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError('log_interval must be >= 1, got %r' % (log_interval,))
    # Put layers such as dropout/batch-norm (none today) into training mode;
    # test() switches the model to eval mode.
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, target) in enumerate(train_loader):
        optimizer.zero_grad()
        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % log_interval == log_interval - 1:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / log_interval))
            running_loss = 0.0


def test():
    """Evaluate ``model`` on ``test_loader`` and print its top-1 accuracy.

    Returns:
        The accuracy as a percentage (float), or None when the loader yields
        no samples. The printed value is truncated to an int by ``%d``; the
        returned value is exact.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for images, labels in test_loader:
            outputs = model(images)
            # Index of the largest logit in each row is the predicted label.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Avoid ZeroDivisionError when the loader is empty.
        print('Accuracy on test set: n/a (no samples)')
        return None
    accuracy = 100 * correct / total
    print('Accuracy on test set: %d %%' % accuracy)
    return accuracy


if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()
D:\ANACONDA\envs\pytorch_gpu\python.exe E:/Python面试准备/python基础/练习/xd.py
[1,   300] loss: 2.173
[1,   600] loss: 0.805
[1,   900] loss: 0.419
Accuracy on test set: 88 %
[2,   300] loss: 0.325
[2,   600] loss: 0.270
[2,   900] loss: 0.221
Accuracy on test set: 93 %
[3,   300] loss: 0.188
[3,   600] loss: 0.173
[3,   900] loss: 0.155
Accuracy on test set: 95 %
[4,   300] loss: 0.135
[4,   600] loss: 0.127
[4,   900] loss: 0.115
Accuracy on test set: 96 %
[5,   300] loss: 0.101
[5,   600] loss: 0.100
[5,   900] loss: 0.091
Accuracy on test set: 96 %
[6,   300] loss: 0.080
[6,   600] loss: 0.077
[6,   900] loss: 0.077
Accuracy on test set: 96 %
[7,   300] loss: 0.060
[7,   600] loss: 0.066
[7,   900] loss: 0.066
Accuracy on test set: 97 %
[8,   300] loss: 0.047
[8,   600] loss: 0.055
[8,   900] loss: 0.053
Accuracy on test set: 97 %
[9,   300] loss: 0.040
[9,   600] loss: 0.042
[9,   900] loss: 0.043
Accuracy on test set: 97 %
[10,   300] loss: 0.032
[10,   600] loss: 0.033
[10,   900] loss: 0.036
Accuracy on test set: 97 %
Process finished with exit code 0