6-7 Building a CIFAR-10 training script with PyTorch (Part 1)

[Figure omitted in this copy.]

The code is explained in detail below.

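  1. `import torch.nn.functional as F`: this module provides functional (stateless) counterparts of the layers in `torch.nn`, including a large number of loss and activation functions. (The script below does not actually import or use `F`; it uses the module-style `nn.CrossEntropyLoss` instead.)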
# -*- encoding: utf-8 -*-
"""
@File    : train.py
@Time    : 2021-03-07 16:24
@Author  : XD
@Email   : [email protected]
@Software: PyCharm
"""
import os

import torch
import torch.nn as nn
import torchvision

from vggnet import VGGNet

from load_cifar10 import train_loader
from load_cifar10 import test_loader


# Use the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of passes over the training set. (The original comment said
# "200 epochs", but the value is 1, presumably for a quick trial run.)
epoch_num = 1

# Learning rate.
# NOTE(review): with Adam this value looks high -- the sample log shows the
# loss jumping from ~2.5 to ~15 within the first three steps. Consider a
# smaller value (e.g. 1e-3) -- confirm experimentally.
lr = 0.01

# Nominal batch size, kept for reference only.
# NOTE(review): the real batch size is whatever load_cifar10 configured for
# train_loader; presumably 128 as well -- confirm.
batch_size = 128

net = VGGNet().to(device)

# Multi-class classification, so use cross-entropy loss.
loss_func = nn.CrossEntropyLoss()

# Optimizer.
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
# Alternative optimizer:
# optimizer = torch.optim.SGD(net.parameters(), lr=lr,
#                             momentum=0.9, weight_decay=5e-4)

# Step decay (not exponential decay): the learning rate is multiplied by
# gamma=0.9 every step_size=5 calls to scheduler.step(), which happens once
# per epoch below.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

# Checkpoint directory; create it once, up front, instead of re-checking
# every epoch.
os.makedirs("models", exist_ok=True)

step_n = 0  # global optimisation-step counter across all epochs
for epoch in range(epoch_num):
    print(" epoch is ", epoch)
    # Training mode: affects dropout / batch-norm layers, if VGGNet has any.
    net.train()

    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = net(inputs)
        loss = loss_func(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step_n += 1

        # Mini-batch accuracy in percent. Divide by the actual number of
        # samples in this batch, not the nominal batch_size: the final batch
        # is smaller (50000 % 128 == 80) unless the loader drops it.
        pred = outputs.argmax(dim=1)
        correct = (pred == labels).sum().item()
        accuracy = 100.0 * correct / labels.size(0)

        print("step", i, "loss is:", loss.item(),
              "mini-batch correct is:", accuracy)

    # Save a checkpoint after every epoch as models/<epoch>.pth. os.path.join
    # replaces the Windows-only "models\{}.pth" literal, whose "\{" is also
    # an invalid escape sequence.
    torch.save(net.state_dict(),
               os.path.join("models", "{}.pth".format(epoch + 1)))

    scheduler.step()

    print("lr is", optimizer.param_groups[0]["lr"])
Sample output:

num_of_train: 50000
num_of_test: 10000
 epoch is  0
step 0 loss is: 2.487577199935913
step 0 loss is: 2.487577199935913 mini-batch correct is: tensor(11.7188)
step 1 loss is: 11.613815307617188
step 1 loss is: 11.613815307617188 mini-batch correct is: tensor(7.8125)
step 2 loss is: 15.141844749450684

Guess you like

Source: blog.csdn.net/weixin_46815330/article/details/114500237