# Use TensorBoard to visualize the training process. After training, the
# logs are written to the "log" directory. To view them, run in a terminal:
#     tensorboard --logdir ./
# and open the printed URL in a browser.
# -*- encoding: utf-8 -*-
"""
@File : train.py
@Time : 2021-03-07 16:24
@Author : XD
@Email : [email protected]
@Software: PyCharm
"""
import os
import torch
import torch.nn as nn
import torchvision
import tensorboardX
from vggnet import VGGNet
from load_cifar10 import train_loader
from load_cifar10 import test_loader
# Select the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of training epochs (an earlier note mentioned 200; set to 2 here).
epoch_num = 2
# Learning rate.
lr = 0.01
# Mini-batch size — assumed to match the DataLoader configuration in
# load_cifar10; TODO confirm.
batch_size = 128
net = VGGNet().to(device)
# Multi-class classification, so use cross-entropy loss.
loss_func = nn.CrossEntropyLoss()
# Define the optimizer.
optimizer = torch.optim.Adam(net.parameters(),lr = lr)
#optimizer = torch.optim.SGD(net.parameters(),lr = lr,
#                            momentum = 0.9, weight_decay = 5e-4)
# Step decay of the learning rate: multiply it by 0.9 after every epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size = 1,gamma = 0.9)
# TensorBoard writer; logs go to the "log" directory.
if not os.path.exists("log"):
    os.mkdir("log")
writer = tensorboardX.SummaryWriter("log")
# Global step counter shared by all TensorBoard scalars/images.
step_n = 0
# Main loop: one training pass over train_loader per epoch, followed by a
# full evaluation pass over test_loader, with metrics logged to TensorBoard.
for epoch in range(epoch_num):
    print(" epoch is: ", epoch)
    net.train()  # training-mode behavior: BatchNorm updates stats, Dropout active

    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = net(inputs)
        loss = loss_func(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print("step:", i, "loss is:", loss.item())

        # Per-batch accuracy: argmax prediction vs. ground-truth labels.
        _, pred = torch.max(outputs.data, dim=1)
        correct = pred.eq(labels.data).cpu().sum()

        writer.add_scalar("train loss:", loss.item(), global_step=step_n)
        # Normalize by the actual batch size (the last batch may be smaller
        # than batch_size) so this scalar is a true percentage; the original
        # logged the raw count scaled by 100.
        writer.add_scalar("train correct",
                          100.0 * correct.item() / labels.size(0),
                          global_step=step_n)

        im = torchvision.utils.make_grid(inputs)
        writer.add_image("train im", im, global_step=step_n)

        step_n += 1

    # Save a checkpoint once per epoch. os.path.join keeps the path portable;
    # the original "models\{}.pth" only produced the intended path on Windows.
    if not os.path.exists("models"):
        os.mkdir("models")
    torch.save(net.state_dict(), os.path.join("models", "{}.pth".format(epoch + 1)))

    scheduler.step()  # decay the learning rate once per epoch

    # Evaluation pass over the test set.
    sum_loss = 0
    sum_correct = 0
    net.eval()  # inference-mode behavior for BatchNorm / Dropout
    with torch.no_grad():  # no gradients needed during evaluation
        for i, data in enumerate(test_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = net(inputs)
            loss = loss_func(outputs, labels)

            _, pred = torch.max(outputs.data, dim=1)
            correct = pred.eq(labels.data).cpu().sum()

            sum_loss += loss.item()
            sum_correct += correct.item()

            writer.add_scalar("test loss:", loss.item(), global_step=step_n)
            writer.add_scalar("test correct:",
                              100.0 * correct.item() / labels.size(0),
                              global_step=step_n)
            # Log the current test batch; the original logged the stale
            # training-batch grid `im` here.
            test_im = torchvision.utils.make_grid(inputs)
            writer.add_image("test im", test_im, global_step=step_n)

    # Epoch-level averages over all test batches.
    test_loss = sum_loss * 1.0 / len(test_loader)
    test_correct = sum_correct * 100.0 / len(test_loader) / batch_size
    print("epoch is:", epoch + 1, "loss is:", test_loss,
          "test correct is:", test_correct)

writer.close()