Premiers pas avec PyTorch : apprendre les réseaux de neurones convolutifs (CNN)

Premiers pas avec PyTorch : apprendre les réseaux de neurones convolutifs (CNN)

Insérez la description de l'image ici

Résultat de l'exécution :

 [1,  300] loss: 1.008
 [1,  600] loss: 0.289
 [1,  900] loss: 0.203
Accuracy on test set: 95 % 
 [2,  300] loss: 0.163
 [2,  600] loss: 0.135
 [2,  900] loss: 0.123
Accuracy on test set: 96 % 
 [3,  300] loss: 0.109
 [3,  600] loss: 0.107
 [3,  900] loss: 0.099
Accuracy on test set: 97 % 
 [4,  300] loss: 0.092
 [4,  600] loss: 0.085
 [4,  900] loss: 0.084
Accuracy on test set: 97 % 
 [5,  300] loss: 0.078
 [5,  600] loss: 0.077
 [5,  900] loss: 0.074
Accuracy on test set: 98 % 
 [6,  300] loss: 0.069
 [6,  600] loss: 0.070
 [6,  900] loss: 0.068
Accuracy on test set: 98 % 
 [7,  300] loss: 0.065
 [7,  600] loss: 0.061
 [7,  900] loss: 0.061
Accuracy on test set: 98 % 
 [8,  300] loss: 0.060
 [8,  600] loss: 0.056
 [8,  900] loss: 0.059
Accuracy on test set: 98 % 
 [9,  300] loss: 0.051
 [9,  600] loss: 0.052
 [9,  900] loss: 0.057
Accuracy on test set: 98 % 
 [10,  300] loss: 0.052
 [10,  600] loss: 0.049
 [10,  900] loss: 0.052
Accuracy on test set: 98 % 

Process finished with exit code 0

import  torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim


# Step 1: prepare the MNIST dataset.

batch_size = 64

# Normalize with the commonly used MNIST pixel statistics: mean 0.1307,
# std 0.3081 (ToTensor() has already scaled pixels to [0, 1]).
# The mean was previously mistyped as 0.137.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST(root='../dataset/mnist',
                               train=True,
                               download=True,
                               transform=transform)

# NOTE: the "loder" spelling is kept because train() and test() look these
# names up as module globals.
train_loder = DataLoader(train_dataset,
                         shuffle=True,  # reshuffle training samples every epoch
                         batch_size=batch_size)

test_dataset = datasets.MNIST(root='../dataset/mnist',
                              train=False,
                              download=True,
                              transform=transform)

test_loder = DataLoader(test_dataset,
                        shuffle=False,  # deterministic evaluation order
                        batch_size=batch_size)

# Step 2: build the network.
class Net(torch.nn.Module):
    """Small two-stage convolutional classifier.

    Each stage applies convolution -> 2x2 max pooling -> ReLU. The flattened
    feature map then goes through one fully connected layer that produces
    10 outputs. The Linear layer expects 320 input features (20 channels x
    4 x 4), which assumes 28x28 single-channel images.
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: it determines how parameter
        # initialization draws from the random number generator.
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)   # 1 -> 10 channels
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)  # 10 -> 20 channels
        self.pooling = torch.nn.MaxPool2d(2)                 # 2x2 max pooling
        self.fc = torch.nn.Linear(320, 10)                   # fully connected layer

    def forward(self, x):
        n = x.size(0)
        # Both stages: convolution -> pooling -> ReLU.
        for conv in (self.conv1, self.conv2):
            x = F.relu(self.pooling(conv(x)))
        # Flatten everything except the batch dimension, then classify.
        return self.fc(x.view(n, -1))

model = Net()

# Run on the first CUDA GPU when one is available to speed things up;
# otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# Cross-entropy loss on the raw logits, optimized with plain SGD (lr = 0.01).
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


# Step 3: training.
def train(epoch, log_interval=300):
    """Run one training epoch over ``train_loder`` and print the running loss.

    Args:
        epoch: zero-based epoch index; it is printed as ``epoch + 1``.
        log_interval: number of mini-batches whose mean loss is printed on
            each log line (default 300).

    Raises:
        ValueError: if ``log_interval`` is less than 1.
    """
    if log_interval < 1:
        raise ValueError("log_interval must be >= 1")

    # Make sure layers such as dropout/batch-norm behave in training mode,
    # even if an earlier evaluation pass left the model in eval mode.
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, target) in enumerate(train_loder):
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()  # clear gradients left over from the previous step

        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        # .item() returns a Python float, so no autograd graph is retained.
        running_loss += loss.item()
        if (batch_idx + 1) % log_interval == 0:
            print(' [%d,%5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / log_interval))
            running_loss = 0.0

def test():
    """Evaluate ``model`` on ``test_loder`` and print the top-1 accuracy.

    The model is switched to eval mode for the pass. Its previous
    train/eval mode is restored afterwards, even if an error occurs.

    Returns:
        Accuracy in percent as a float. Returns ``0.0`` when the test set
        is empty instead of dividing by zero.
    """
    was_training = model.training
    model.eval()  # inference behaviour for layers such as dropout/batch-norm
    correct = 0
    total = 0
    try:
        with torch.no_grad():  # gradients are not needed for evaluation
            for inputs, target in test_loder:
                inputs, target = inputs.to(device), target.to(device)
                outputs = model(inputs)
                # Predicted class = index of the largest output per sample.
                predicted = outputs.argmax(dim=1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
    finally:
        # Restore the caller's mode so a following training pass is unaffected.
        model.train(was_training)

    accuracy = 100 * correct / total if total else 0.0
    print('Accuracy on test set: %d %% ' % accuracy)
    return accuracy

if __name__ == '__main__':
    # Ten rounds of one full training epoch followed by one evaluation pass.
    for ep in range(10):
        train(ep)
        test()

Je suppose que tu aimes

Source : blog.csdn.net/weixin_41281151/article/details/108556059
conseillé
Classement