Training a logistic regression model with PyTorch

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batchsz = 50
input_size = 784       # each 28x28 MNIST image is flattened into a 784-dim vector
num_classes = 50       # note: MNIST has only 10 classes, so 10 is the correct value here (see the accuracy note below)
learning_rate = 1e-3
num_epochs = 10

train_dataset = datasets.MNIST(root="./data", train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batchsz, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batchsz, shuffle=False)
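
Each batch from these loaders is an image tensor of shape [batchsz, 1, 28, 28] plus a label tensor of shape [batchsz]; the minimal sanity-check sketch below (an illustrative addition, assuming the default loaders above) confirms the shapes that get flattened later.

# Optional sanity check: peek at one training batch
images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([50, 1, 28, 28])
print(labels.shape)   # torch.Size([50])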

class LogisticRegression(nn.Module):
    def __init__(self, input_size, num_classes):
        super(LogisticRegression, self).__init__()
        # Multiclass logistic regression is just a single linear layer;
        # nn.CrossEntropyLoss applies log-softmax to its output internally.
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        out = self.linear(x)   # raw logits, no explicit softmax needed
        return out

model = LogisticRegression(input_size, num_classes).to(device)  # move parameters onto the selected device

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    for i, (images,labels) in enumerate(train_loader):
        images = images.view(-1, 28*28).to(device)  # flatten each [1, 28, 28] image into a 784-dim vector
        labels = labels.to(device)

        # forward pass, compute the loss, then one optimizer step
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print("Epoch: [%d/%d],step: [%d/%d], loss: %.4f" 
                % (epoch+1, num_epochs, i+1, len(train_dataset) // batchsz, loss.item()))

# Evaluate on the test set; no gradients are needed for inference.
model.eval()   # not strictly necessary for a single linear layer, but good practice
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)                        # accumulate counts over all test batches
        correct += (predicted == labels).sum().item()

    print("Accuracy of the model on the 10000 test images: %d %% " % (100*correct/total))

torch.save(model.state_dict(), 'model.pkl')
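
To use the saved weights later, a minimal reload sketch would look like this (assuming the LogisticRegression class defined above is available in the loading script):

# Recreate the architecture, then load the saved weights
model = LogisticRegression(input_size, num_classes).to(device)
model.load_state_dict(torch.load('model.pkl', map_location=device))
model.eval()   # switch to inference mode before making predictions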

Run output (accuracy is 80%, which is a bit low):

Epoch: [1/10],step: [100/1200], loss: 3.5799
Epoch: [1/10],step: [200/1200], loss: 3.2070
Epoch: [1/10],step: [300/1200], loss: 2.8602
Epoch: [1/10],step: [400/1200], loss: 2.7669
Epoch: [1/10],step: [500/1200], loss: 2.5550
Epoch: [1/10],step: [600/1200], loss: 2.3238
Epoch: [1/10],step: [700/1200], loss: 2.2219
Epoch: [1/10],step: [800/1200], loss: 2.1080
Epoch: [1/10],step: [900/1200], loss: 1.9478
Epoch: [1/10],step: [1000/1200], loss: 2.0522
Epoch: [1/10],step: [1100/1200], loss: 1.6376
Epoch: [1/10],step: [1200/1200], loss: 1.7999
Epoch: [2/10],step: [100/1200], loss: 1.6434
Epoch: [2/10],step: [200/1200], loss: 1.5681
Epoch: [2/10],step: [300/1200], loss: 1.6594
Epoch: [2/10],step: [400/1200], loss: 1.3429
Epoch: [2/10],step: [500/1200], loss: 1.3682
Epoch: [2/10],step: [600/1200], loss: 1.4606
Epoch: [2/10],step: [700/1200], loss: 1.4246
Epoch: [2/10],step: [800/1200], loss: 1.2950
Epoch: [2/10],step: [900/1200], loss: 1.4897
Epoch: [2/10],step: [1000/1200], loss: 1.2748
Epoch: [2/10],step: [1100/1200], loss: 1.2010
Epoch: [2/10],step: [1200/1200], loss: 1.1625
Epoch: [3/10],step: [100/1200], loss: 1.4096
Epoch: [3/10],step: [200/1200], loss: 1.1130
Epoch: [3/10],step: [300/1200], loss: 1.2882
Epoch: [3/10],step: [400/1200], loss: 1.0916
Epoch: [3/10],step: [500/1200], loss: 1.3179
Epoch: [3/10],step: [600/1200], loss: 1.1008
Epoch: [3/10],step: [700/1200], loss: 1.0496
Epoch: [3/10],step: [800/1200], loss: 0.9719
Epoch: [3/10],step: [900/1200], loss: 1.0363
Epoch: [3/10],step: [1000/1200], loss: 0.9731
Epoch: [3/10],step: [1100/1200], loss: 0.8112
Epoch: [3/10],step: [1200/1200], loss: 1.0275
Epoch: [4/10],step: [100/1200], loss: 0.9957
Epoch: [4/10],step: [200/1200], loss: 1.0084
Epoch: [4/10],step: [300/1200], loss: 0.9670
Epoch: [4/10],step: [400/1200], loss: 0.9439
Epoch: [4/10],step: [500/1200], loss: 0.8658
Epoch: [4/10],step: [600/1200], loss: 0.8131
Epoch: [4/10],step: [700/1200], loss: 1.0619
Epoch: [4/10],step: [800/1200], loss: 0.9512
Epoch: [4/10],step: [900/1200], loss: 1.0149
Epoch: [4/10],step: [1000/1200], loss: 1.0696
Epoch: [4/10],step: [1100/1200], loss: 0.8596
Epoch: [4/10],step: [1200/1200], loss: 0.8374
Epoch: [5/10],step: [100/1200], loss: 0.7280
Epoch: [5/10],step: [200/1200], loss: 0.8602
Epoch: [5/10],step: [300/1200], loss: 0.7759
Epoch: [5/10],step: [400/1200], loss: 0.7611
Epoch: [5/10],step: [500/1200], loss: 0.8259
Epoch: [5/10],step: [600/1200], loss: 0.8372
Epoch: [5/10],step: [700/1200], loss: 0.6777
Epoch: [5/10],step: [800/1200], loss: 0.8727
Epoch: [5/10],step: [900/1200], loss: 0.7577
Epoch: [5/10],step: [1000/1200], loss: 0.8768
Epoch: [5/10],step: [1100/1200], loss: 0.8232
Epoch: [5/10],step: [1200/1200], loss: 0.7646
Epoch: [6/10],step: [100/1200], loss: 0.7075
Epoch: [6/10],step: [200/1200], loss: 0.6172
Epoch: [6/10],step: [300/1200], loss: 0.7680
Epoch: [6/10],step: [400/1200], loss: 0.7379
Epoch: [6/10],step: [500/1200], loss: 0.6340
Epoch: [6/10],step: [600/1200], loss: 0.8410
Epoch: [6/10],step: [700/1200], loss: 0.5893
Epoch: [6/10],step: [800/1200], loss: 0.6192
Epoch: [6/10],step: [900/1200], loss: 0.7843
Epoch: [6/10],step: [1000/1200], loss: 0.7142
Epoch: [6/10],step: [1100/1200], loss: 0.7099
Epoch: [6/10],step: [1200/1200], loss: 0.8738
Epoch: [7/10],step: [100/1200], loss: 0.6682
Epoch: [7/10],step: [200/1200], loss: 0.7589
Epoch: [7/10],step: [300/1200], loss: 0.7899
Epoch: [7/10],step: [400/1200], loss: 0.7006
Epoch: [7/10],step: [500/1200], loss: 0.7020
Epoch: [7/10],step: [600/1200], loss: 0.7768
Epoch: [7/10],step: [700/1200], loss: 0.6600
Epoch: [7/10],step: [800/1200], loss: 0.4579
Epoch: [7/10],step: [900/1200], loss: 0.6737
Epoch: [7/10],step: [1000/1200], loss: 0.5730
Epoch: [7/10],step: [1100/1200], loss: 0.6239
Epoch: [7/10],step: [1200/1200], loss: 0.6538
Epoch: [8/10],step: [100/1200], loss: 0.7026
Epoch: [8/10],step: [200/1200], loss: 0.5370
Epoch: [8/10],step: [300/1200], loss: 0.6488
Epoch: [8/10],step: [400/1200], loss: 0.5992
Epoch: [8/10],step: [500/1200], loss: 0.5131
Epoch: [8/10],step: [600/1200], loss: 0.5598
Epoch: [8/10],step: [700/1200], loss: 0.5424
Epoch: [8/10],step: [800/1200], loss: 0.6852
Epoch: [8/10],step: [900/1200], loss: 0.7855
Epoch: [8/10],step: [1000/1200], loss: 0.8309
Epoch: [8/10],step: [1100/1200], loss: 0.6523
Epoch: [8/10],step: [1200/1200], loss: 0.4958
Epoch: [9/10],step: [100/1200], loss: 0.6622
Epoch: [9/10],step: [200/1200], loss: 0.5233
Epoch: [9/10],step: [300/1200], loss: 0.6910
Epoch: [9/10],step: [400/1200], loss: 0.6326
Epoch: [9/10],step: [500/1200], loss: 0.4408
Epoch: [9/10],step: [600/1200], loss: 0.5026
Epoch: [9/10],step: [800/1200], loss: 0.3960
Epoch: [9/10],step: [900/1200], loss: 0.6953
Epoch: [9/10],step: [1000/1200], loss: 0.4921
Epoch: [9/10],step: [1100/1200], loss: 0.7392
Epoch: [9/10],step: [1200/1200], loss: 0.5771
Epoch: [10/10],step: [100/1200], loss: 0.8520
Epoch: [10/10],step: [200/1200], loss: 0.7730
Epoch: [10/10],step: [300/1200], loss: 0.6758
Epoch: [10/10],step: [400/1200], loss: 0.4674
Epoch: [10/10],step: [500/1200], loss: 0.4745
Epoch: [10/10],step: [600/1200], loss: 0.7266
Epoch: [10/10],step: [700/1200], loss: 0.6446
Epoch: [10/10],step: [800/1200], loss: 0.4252
Epoch: [10/10],step: [900/1200], loss: 0.5491
Epoch: [10/10],step: [1000/1200], loss: 0.7494
Epoch: [10/10],step: [1100/1200], loss: 0.5062
Epoch: [10/10],step: [1200/1200], loss: 0.5986
Accuracy of the model on the 10000 test images: 80 %
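
The low accuracy is likely related to num_classes being set to 50 even though MNIST has only 10 digit classes, combined with plain SGD at a small learning rate for just 10 epochs. A minimal sketch of the usual adjustments is below (switching to Adam is an assumption here, not part of the original run); with these, a plain logistic regression on MNIST typically reaches roughly 90% or a bit above.

# Sketch of common adjustments (assumptions, not the configuration that produced the log above)
num_classes = 10                                      # MNIST digits 0-9
model = LogisticRegression(input_size, num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)   # Adam usually converges faster than plain SGD here
# then rerun the same training and evaluation loops as above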

Reposted from blog.csdn.net/qq_37369201/article/details/109472106