import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
# from torch.autograd import Variable
# Select GPU when available; all tensors/modules must be moved here explicitly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters.
batchsz = 50
input_size = 784       # 28 * 28 flattened MNIST image
# BUG FIX: MNIST has exactly 10 digit classes, not 50. The original value of 50
# created 40 dead output logits, inflating the cross-entropy normalizer and
# slowing convergence (the run log below reaches only ~80% accuracy).
num_classes = 10
learning_rate = 1e-3
num_epochs = 10

# MNIST datasets: training split is downloaded on first run; ToTensor() scales
# pixel values to [0, 1] float tensors of shape (1, 28, 28).
train_dataset = datasets.MNIST(root="./data", train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transforms.ToTensor())

# Shuffle only the training data; evaluation order is irrelevant but kept stable.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batchsz, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batchsz, shuffle=False)
class LogisticRegression(nn.Module):
    """Multinomial logistic regression: one affine layer mapping a flattened
    input vector to unnormalized class scores (logits)."""

    def __init__(self, input_size, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        # Raw logits are returned on purpose: nn.CrossEntropyLoss applies
        # log-softmax internally, so no activation belongs here.
        return self.linear(x)
# BUG FIX: the model must be moved to `device`. The original left the weights
# on the CPU while the training loop moves each batch to `device`, which
# crashes on the first forward pass whenever CUDA is available.
model = LogisticRegression(input_size, num_classes).to(device)

# CrossEntropyLoss = log-softmax + negative log-likelihood in one call,
# so the model emits raw logits.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
# Standard SGD training cycle: forward, loss, backward, parameter update.
for epoch in range(num_epochs):
    for step, (images, labels) in enumerate(train_loader, start=1):
        # Flatten each (1, 28, 28) image into a 784-vector for the linear layer.
        images = images.view(-1, 28 * 28).to(device)
        labels = labels.to(device)

        # Forward pass and loss.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Zero stale gradients, backpropagate, and take one SGD step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Progress report every 100 mini-batches.
        if step % 100 == 0:
            print("Epoch: [%d/%d],step: [%d/%d], loss: %.4f"
                  % (epoch + 1, num_epochs, step,
                     len(train_dataset) // batchsz, loss.item()))
# Evaluate accuracy over the full test set; no gradients are needed here.
model.eval()
with torch.no_grad():
    # BUG FIX: the accumulators must be initialized OUTSIDE the batch loop.
    # The original reset `correct` and `total` on every iteration, so the
    # printed "accuracy" covered only the final 50-image batch.
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28 * 28).to(device)
        labels = labels.to(device)
        # `model` already lives on `device`, so the output does too; the
        # original's extra .to(device) on the output was a no-op.
        outputs = model(images)
        # Predicted class = index of the max logit along the class dimension.
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print("Accuracy of the model on the 10000 test images: %d %% " % (100 * correct / total))

# Persist only the learned parameters (state dict), not the whole module.
torch.save(model.state_dict(), 'model.pkl')
Run output (accuracy is 80%, which is a bit low):
Epoch: [1/10],step: [100/1200], loss: 3.5799
Epoch: [1/10],step: [200/1200], loss: 3.2070
Epoch: [1/10],step: [300/1200], loss: 2.8602
Epoch: [1/10],step: [400/1200], loss: 2.7669
Epoch: [1/10],step: [500/1200], loss: 2.5550
Epoch: [1/10],step: [600/1200], loss: 2.3238
Epoch: [1/10],step: [700/1200], loss: 2.2219
Epoch: [1/10],step: [800/1200], loss: 2.1080
Epoch: [1/10],step: [900/1200], loss: 1.9478
Epoch: [1/10],step: [1000/1200], loss: 2.0522
Epoch: [1/10],step: [1100/1200], loss: 1.6376
Epoch: [1/10],step: [1200/1200], loss: 1.7999
Epoch: [2/10],step: [100/1200], loss: 1.6434
Epoch: [2/10],step: [200/1200], loss: 1.5681
Epoch: [2/10],step: [300/1200], loss: 1.6594
Epoch: [2/10],step: [400/1200], loss: 1.3429
Epoch: [2/10],step: [500/1200], loss: 1.3682
Epoch: [2/10],step: [600/1200], loss: 1.4606
Epoch: [2/10],step: [700/1200], loss: 1.4246
Epoch: [2/10],step: [800/1200], loss: 1.2950
Epoch: [2/10],step: [900/1200], loss: 1.4897
Epoch: [2/10],step: [1000/1200], loss: 1.2748
Epoch: [2/10],step: [1100/1200], loss: 1.2010
Epoch: [2/10],step: [1200/1200], loss: 1.1625
Epoch: [3/10],step: [100/1200], loss: 1.4096
Epoch: [3/10],step: [200/1200], loss: 1.1130
Epoch: [3/10],step: [300/1200], loss: 1.2882
Epoch: [3/10],step: [400/1200], loss: 1.0916
Epoch: [3/10],step: [500/1200], loss: 1.3179
Epoch: [3/10],step: [600/1200], loss: 1.1008
Epoch: [3/10],step: [700/1200], loss: 1.0496
Epoch: [3/10],step: [800/1200], loss: 0.9719
Epoch: [3/10],step: [900/1200], loss: 1.0363
Epoch: [3/10],step: [1000/1200], loss: 0.9731
Epoch: [3/10],step: [1100/1200], loss: 0.8112
Epoch: [3/10],step: [1200/1200], loss: 1.0275
Epoch: [4/10],step: [100/1200], loss: 0.9957
Epoch: [4/10],step: [200/1200], loss: 1.0084
Epoch: [4/10],step: [300/1200], loss: 0.9670
Epoch: [4/10],step: [400/1200], loss: 0.9439
Epoch: [4/10],step: [500/1200], loss: 0.8658
Epoch: [4/10],step: [600/1200], loss: 0.8131
Epoch: [4/10],step: [700/1200], loss: 1.0619
Epoch: [4/10],step: [800/1200], loss: 0.9512
Epoch: [4/10],step: [900/1200], loss: 1.0149
Epoch: [4/10],step: [1000/1200], loss: 1.0696
Epoch: [4/10],step: [1100/1200], loss: 0.8596
Epoch: [4/10],step: [1200/1200], loss: 0.8374
Epoch: [5/10],step: [100/1200], loss: 0.7280
Epoch: [5/10],step: [200/1200], loss: 0.8602
Epoch: [5/10],step: [300/1200], loss: 0.7759
Epoch: [5/10],step: [400/1200], loss: 0.7611
Epoch: [5/10],step: [500/1200], loss: 0.8259
Epoch: [5/10],step: [600/1200], loss: 0.8372
Epoch: [5/10],step: [700/1200], loss: 0.6777
Epoch: [5/10],step: [800/1200], loss: 0.8727
Epoch: [5/10],step: [900/1200], loss: 0.7577
Epoch: [5/10],step: [1000/1200], loss: 0.8768
Epoch: [5/10],step: [1100/1200], loss: 0.8232
Epoch: [5/10],step: [1200/1200], loss: 0.7646
Epoch: [6/10],step: [100/1200], loss: 0.7075
Epoch: [6/10],step: [200/1200], loss: 0.6172
Epoch: [6/10],step: [300/1200], loss: 0.7680
Epoch: [6/10],step: [400/1200], loss: 0.7379
Epoch: [6/10],step: [500/1200], loss: 0.6340
Epoch: [6/10],step: [600/1200], loss: 0.8410
Epoch: [6/10],step: [700/1200], loss: 0.5893
Epoch: [6/10],step: [800/1200], loss: 0.6192
Epoch: [6/10],step: [900/1200], loss: 0.7843
Epoch: [6/10],step: [1000/1200], loss: 0.7142
Epoch: [6/10],step: [1100/1200], loss: 0.7099
Epoch: [6/10],step: [1200/1200], loss: 0.8738
Epoch: [7/10],step: [100/1200], loss: 0.6682
Epoch: [7/10],step: [200/1200], loss: 0.7589
Epoch: [7/10],step: [300/1200], loss: 0.7899
Epoch: [7/10],step: [400/1200], loss: 0.7006
Epoch: [7/10],step: [500/1200], loss: 0.7020
Epoch: [7/10],step: [600/1200], loss: 0.7768
Epoch: [7/10],step: [700/1200], loss: 0.6600
Epoch: [7/10],step: [800/1200], loss: 0.4579
Epoch: [7/10],step: [900/1200], loss: 0.6737
Epoch: [7/10],step: [1000/1200], loss: 0.5730
Epoch: [7/10],step: [1100/1200], loss: 0.6239
Epoch: [7/10],step: [1200/1200], loss: 0.6538
Epoch: [8/10],step: [100/1200], loss: 0.7026
Epoch: [8/10],step: [200/1200], loss: 0.5370
Epoch: [8/10],step: [300/1200], loss: 0.6488
Epoch: [8/10],step: [400/1200], loss: 0.5992
Epoch: [8/10],step: [500/1200], loss: 0.5131
Epoch: [8/10],step: [600/1200], loss: 0.5598
Epoch: [8/10],step: [700/1200], loss: 0.5424
Epoch: [8/10],step: [800/1200], loss: 0.6852
Epoch: [8/10],step: [900/1200], loss: 0.7855
Epoch: [8/10],step: [1000/1200], loss: 0.8309
Epoch: [8/10],step: [1100/1200], loss: 0.6523
Epoch: [8/10],step: [1200/1200], loss: 0.4958
Epoch: [9/10],step: [100/1200], loss: 0.6622
Epoch: [9/10],step: [200/1200], loss: 0.5233
Epoch: [9/10],step: [300/1200], loss: 0.6910
Epoch: [9/10],step: [400/1200], loss: 0.6326
Epoch: [9/10],step: [500/1200], loss: 0.4408
Epoch: [9/10],step: [600/1200], loss: 0.5026
Epoch: [9/10],step: [800/1200], loss: 0.3960
Epoch: [9/10],step: [900/1200], loss: 0.6953
Epoch: [9/10],step: [1000/1200], loss: 0.4921
Epoch: [9/10],step: [1100/1200], loss: 0.7392
Epoch: [9/10],step: [1200/1200], loss: 0.5771
Epoch: [10/10],step: [100/1200], loss: 0.8520
Epoch: [10/10],step: [200/1200], loss: 0.7730
Epoch: [10/10],step: [300/1200], loss: 0.6758
Epoch: [10/10],step: [400/1200], loss: 0.4674
Epoch: [10/10],step: [500/1200], loss: 0.4745
Epoch: [10/10],step: [600/1200], loss: 0.7266
Epoch: [10/10],step: [700/1200], loss: 0.6446
Epoch: [10/10],step: [800/1200], loss: 0.4252
Epoch: [10/10],step: [900/1200], loss: 0.5491
Epoch: [10/10],step: [1000/1200], loss: 0.7494
Epoch: [10/10],step: [1100/1200], loss: 0.5062
Epoch: [10/10],step: [1200/1200], loss: 0.5986
Accuracy of the model on the 10000 test images: 80 %