6-7 PyTorch: building the CIFAR-10 training script (part 2)

Insert picture description here
Compared with the previous version, this one uses TensorBoard to visualize the training process. After training, the event logs are written to the log directory.
Run the following command in a terminal:

tensorboard --logdir ./

Then open the address it prints in a browser to view the dashboard:
Insert picture description here
Insert picture description here

# -*- encoding: utf-8 -*-
"""
@File    : train.py
@Time    : 2021-03-07 16:24
@Author  : XD
@Email   : [email protected]
@Software: PyCharm
"""
import os

import torch
import torch.nn as nn
import torchvision
import tensorboardX

from vggnet import VGGNet

from load_cifar10 import train_loader
from load_cifar10 import test_loader


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of passes over the training set.
epoch_num = 2

# Initial learning rate handed to the optimizer.
lr = 0.01

# Mini-batch size assumed by this script.
# NOTE(review): the data loaders are built in load_cifar10; confirm they use
# the same value.
batch_size = 128

net = VGGNet().to(device)

# Multi-class classification, so the loss is cross entropy.
loss_func = nn.CrossEntropyLoss()

# Optimizer.
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
# SGD alternative:
# optimizer = torch.optim.SGD(net.parameters(), lr=lr,
#                             momentum=0.9, weight_decay=5e-4)

# Exponentially decaying learning rate: with step_size=1 and gamma=0.9 the
# learning rate is multiplied by 0.9 on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

# exist_ok=True replaces the check-then-create pair, which could fail if the
# directory appeared between the check and the mkdir.
os.makedirs("log", exist_ok=True)
writer = tensorboardX.SummaryWriter("log")

# Running count of training mini-batches; used as the TensorBoard x-axis.
step_n = 0

# Train for `epoch_num` epochs. After each epoch: save a checkpoint, decay the
# learning rate, and evaluate on the test set. Per-batch loss, accuracy (in
# percent) and an image grid are written to TensorBoard for both phases.

# Running count of evaluated test mini-batches across all epochs. The test
# curves get their own x-axis so that every test batch lands on a distinct
# step instead of all of them being written at the same `step_n`.
test_step_n = 0

try:
    for epoch in range(epoch_num):
        print(" epoch is: ", epoch)
        # Training mode (affects layers such as BatchNorm and Dropout).
        net.train()

        for i, data in enumerate(train_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = net(inputs)
            loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print("step:", i, "loss is:", loss.item())

            # Predicted class is the index of the largest output per sample.
            _, pred = torch.max(outputs.detach(), dim=1)
            correct = pred.eq(labels).sum().item()
            # Use the real number of samples in this batch: the last batch of
            # an epoch may be smaller than the nominal batch size.
            num_samples = labels.size(0)

            writer.add_scalar("train loss:", loss.item(), global_step=step_n)
            # Accuracy of this mini-batch in percent (the raw correct count
            # used to be logged here, scaled by 100).
            writer.add_scalar("train correct",
                              100.0 * correct / num_samples,
                              global_step=step_n)

            im = torchvision.utils.make_grid(inputs)
            writer.add_image("train im", im, global_step=step_n)

            step_n += 1

        # Save one checkpoint per epoch. os.path.join picks the right
        # separator on every platform; the former literal "models\{}.pth"
        # only resolved into the models directory on Windows.
        os.makedirs("models", exist_ok=True)
        torch.save(net.state_dict(),
                   os.path.join("models", "{}.pth".format(epoch + 1)))

        scheduler.step()

        # ---- Evaluation on the test set ----
        sum_loss = 0.0
        sum_correct = 0
        sum_samples = 0
        num_batches = 0

        # Inference mode for BatchNorm / Dropout; set once, not per batch.
        net.eval()
        # No gradients are needed for evaluation; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad():
            for i, data in enumerate(test_loader):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = net(inputs)
                loss = loss_func(outputs, labels)
                _, pred = torch.max(outputs, dim=1)
                correct = pred.eq(labels).sum().item()
                num_samples = labels.size(0)

                sum_loss += loss.item()
                sum_correct += correct
                sum_samples += num_samples
                num_batches += 1

                writer.add_scalar("test loss:", loss.item(),
                                  global_step=test_step_n)
                writer.add_scalar("test correct:",
                                  100.0 * correct / num_samples,
                                  global_step=test_step_n)
                # Log the test images themselves, not the grid left over from
                # the last training batch.
                test_im = torchvision.utils.make_grid(inputs)
                writer.add_image("test im", test_im, global_step=test_step_n)

                test_step_n += 1

        # Mean per-batch loss and overall accuracy in percent. max(..., 1)
        # keeps an empty test loader from raising ZeroDivisionError.
        test_loss = sum_loss / max(num_batches, 1)
        test_correct = 100.0 * sum_correct / max(sum_samples, 1)

        print("epoch is:", epoch + 1, "loss is:", test_loss,
              "test correct is:", test_correct)
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

Guess you like

Origin blog.csdn.net/weixin_46815330/article/details/114682612