Primeros pasos con PyTorch: aprendiendo redes neuronales convolucionales (CNN)

Primeros pasos con PyTorch: aprendiendo redes neuronales convolucionales (CNN)

[Imagen no disponible en esta copia]

Resultado de la ejecución:

 [1,  300] loss: 1.008
 [1,  600] loss: 0.289
 [1,  900] loss: 0.203
Accuracy on test set: 95 % 
 [2,  300] loss: 0.163
 [2,  600] loss: 0.135
 [2,  900] loss: 0.123
Accuracy on test set: 96 % 
 [3,  300] loss: 0.109
 [3,  600] loss: 0.107
 [3,  900] loss: 0.099
Accuracy on test set: 97 % 
 [4,  300] loss: 0.092
 [4,  600] loss: 0.085
 [4,  900] loss: 0.084
Accuracy on test set: 97 % 
 [5,  300] loss: 0.078
 [5,  600] loss: 0.077
 [5,  900] loss: 0.074
Accuracy on test set: 98 % 
 [6,  300] loss: 0.069
 [6,  600] loss: 0.070
 [6,  900] loss: 0.068
Accuracy on test set: 98 % 
 [7,  300] loss: 0.065
 [7,  600] loss: 0.061
 [7,  900] loss: 0.061
Accuracy on test set: 98 % 
 [8,  300] loss: 0.060
 [8,  600] loss: 0.056
 [8,  900] loss: 0.059
Accuracy on test set: 98 % 
 [9,  300] loss: 0.051
 [9,  600] loss: 0.052
 [9,  900] loss: 0.057
Accuracy on test set: 98 % 
 [10,  300] loss: 0.052
 [10,  600] loss: 0.049
 [10,  900] loss: 0.052
Accuracy on test set: 98 % 

Process finished with exit code 0

import  torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim


# Step 1: prepare the MNIST dataset.

batch_size = 64
# ToTensor() converts each image to a float tensor in [0, 1]; Normalize() then
# standardizes it with the standard MNIST mean (0.1307) and std (0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST(root='../dataset/mnist',
                               train=True,
                               download=True,
                               transform=transform)

# Training batches are reshuffled every epoch.
# NOTE: the "loder" spelling is kept because train()/test() refer to these names.
train_loder = DataLoader(train_dataset,
                         shuffle=True,
                         batch_size=batch_size)

test_dataset = datasets.MNIST(root='../dataset/mnist',
                              train=False,
                              download=True,
                              transform=transform)

# Evaluation order does not matter, so the test set is not shuffled.
test_loder = DataLoader(test_dataset,
                        shuffle=False,
                        batch_size=batch_size)

# Step 2: build the network.
class Net(torch.nn.Module):
    """Two-stage convolutional classifier that outputs 10 logits per image.

    Each stage is conv -> 2x2 max-pool -> ReLU. The final linear layer takes
    320 = 20 * 4 * 4 features, which is what the stack produces for a
    1x28x28 input (spatial size 28 -> 24 -> 12 -> 8 -> 4).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)    # 1 -> 10 channels
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)   # 10 -> 20 channels
        self.pooling = torch.nn.MaxPool2d(2)                  # shared 2x2 max-pool
        self.fc = torch.nn.Linear(320, 10)                    # fully connected head

    def _stage(self, conv, feats):
        """Apply one conv -> max-pool -> ReLU stage."""
        return F.relu(self.pooling(conv(feats)))

    def forward(self, x):
        n = x.size(0)
        for conv in (self.conv1, self.conv2):
            x = self._stage(conv, x)
        # Flatten everything but the batch dimension before the linear layer.
        return self.fc(x.view(n, -1))

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Module.to() moves the parameters in place and returns the same module.
model = Net().to(device)

# Cross-entropy on raw logits, optimized with plain SGD (lr = 0.01).
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)


#step3 训练
def train(epoch):
    running_loss = 0.0
    for batch_idx,data in enumerate(train_loder,0):
        inputs,target = data
        inputs,target = inputs.to(device),target.to(device)
        optimizer.zero_grad()    #梯度清零

        #forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs,target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print(' [%d,%5d] loss: %.3f' % (epoch + 1,batch_idx + 1, running_loss / 300))
            running_loss = 0.0

def test():
    correct = 0
    total = 0
    with torch.no_grad():      #不计算梯度
        for data in test_loder:
            inputs,target = data
            inputs, target = inputs.to(device), target.to(device)
            outputs = model(inputs)
            _,predicted = torch.max(outputs.data,dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print('Accuracy on test set: %d %% '%(100 * correct / total))

if __name__ == '__main__':
    # Ten epochs; evaluate on the held-out test set after each one.
    for ep in range(10):
        train(ep)
        test()

Quizás también te interese

Fuente original: blog.csdn.net/weixin_41281151/article/details/108556059
Recomendado
Clasificación