Notes on Key Points of PyTorch Training

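When we train a neural network with PyTorch, a fairly large dataset or a very deep network can make training slow. In that case we need to save intermediate results, so that next time we can restore the parameters from them and continue training.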

Here are the main parts of the code:

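import torch

# Import the torch module

import torchvision
import torchvision.transforms as transforms

# Import torchvision and its transforms, which the data-loading code uses

import argparse

# A handy Python library for parsing command-line arguments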

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')

# Create the parser; the description is shown in the help message

parser.add_argument('--lr', default=0.1, type=float, help='learning rate')

# Define the learning rate

parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')

# Define whether to resume the model from a checkpoint

args = parser.parse_args()

# Parse the command-line arguments
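As a quick check of how these two options behave, the sketch below parses explicit argument lists instead of the real command line. The variable names `defaults` and `resumed` are illustrative and not part of the original script.

```python
import argparse

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')

# Passing a list to parse_args() bypasses sys.argv, which makes the result easy to inspect.
defaults = parser.parse_args([])
resumed = parser.parse_args(['--lr', '0.01', '-r'])

print(defaults.lr, defaults.resume)  # 0.1 False
print(resumed.lr, resumed.resume)    # 0.01 True
```

Because `--resume` uses `action='store_true'`, it takes no value: the flag is False unless it appears on the command line.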

use_cuda = torch.cuda.is_available()

best_acc = 0  # best test accuracy

start_epoch = 0 # start from epoch 0 or last checkpoint epoch

# Data

print('==>Preparing data...')





transform_train = transforms.Compose([

    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),

])

# Define the data augmentation for the training set


transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Define the preprocessing for the test set (no augmentation, only tensor conversion and normalization)
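To make the Normalize step concrete, here is a minimal sketch of the arithmetic it performs, written with plain tensor operations. The random `image` is a stand-in for the output of ToTensor() and is an assumption for illustration only.

```python
import torch

# Per-channel mean and std, copied from the Normalize calls above.
mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)

# Stand-in for a ToTensor() result: a 3x32x32 float image with values in [0, 1].
image = torch.rand(3, 32, 32)

# Normalize subtracts the channel mean and divides by the channel std.
normalized = (image - mean) / std

# Reversing the operation recovers the original pixel values (up to rounding).
restored = normalized * std + mean
print(torch.allclose(restored, image, atol=1e-5))
```

The same statistics are applied to both the training and the test set, so the network sees inputs on the same scale in both phases.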

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)
# Load the training set with torchvision and wrap it in a DataLoader

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)

# Load the test set with torchvision and wrap it in a DataLoader
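The batching behaviour of DataLoader can be seen without downloading CIFAR-10. The sketch below uses a hypothetical random dataset of 300 images as a stand-in, and sets num_workers=0 so that it runs in a single process.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for CIFAR-10: 300 random 3x32x32 images with labels in 0..9.
images = torch.rand(300, 3, 32, 32)
labels = torch.randint(0, 10, (300,))
dataset = TensorDataset(images, labels)

# Same batch size as the training loader above; shuffling changes order, not batch sizes.
loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=0)

batch_sizes = [inputs.size(0) for inputs, targets in loader]
print(batch_sizes)  # [128, 128, 44]
```

The last batch is smaller because 300 is not a multiple of 128; DataLoader keeps it unless drop_last=True is passed.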

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Define the class names

if args.resume:

    # Load the checkpoint and restore the model, the best accuracy, and the epoch

    checkpoint = torch.load('./checkpoint/ckpt.t7')

    net = checkpoint['net']

    best_acc = checkpoint['acc']

    start_epoch = checkpoint['epoch']
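The resume branch reads the keys 'net', 'acc', and 'epoch', so the saving side has to write a dictionary with the same keys. The following is a minimal sketch of a full save-and-resume round trip under some assumptions: the tiny nn.Linear model, the temporary directory, and the accuracy value are all illustrative, and the sketch stores the model's state_dict under 'net' rather than the whole model object as the code above does, since that is the form the PyTorch documentation recommends.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical tiny model standing in for the real network.
net = nn.Linear(4, 2)
best_acc = 91.5  # illustrative value, not a real result
epoch = 7

# Save: the same keys that the resume branch reads, with a state_dict under 'net'.
ckpt_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(ckpt_dir, 'ckpt.t7')
state = {'net': net.state_dict(), 'acc': best_acc, 'epoch': epoch}
torch.save(state, ckpt_path)

# Resume: rebuild the same architecture first, then load the saved parameters into it.
restored_net = nn.Linear(4, 2)
checkpoint = torch.load(ckpt_path)
restored_net.load_state_dict(checkpoint['net'])
restored_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

print(torch.equal(restored_net.weight, net.weight), restored_acc, start_epoch)
```

With a state_dict, the model class must be constructed before loading, which is why the sketch creates `restored_net` before calling load_state_dict.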

if use_cuda:
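The listing ends at the `if use_cuda:` branch. As a hedged sketch of what such a branch is commonly used for, the example below moves a model and its inputs to the GPU when one is available and stays on the CPU otherwise. The small nn.Linear model and the `device` variable are assumptions for illustration and do not come from the original code.

```python
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Hypothetical small model; the real network is not shown in the source.
net = nn.Linear(8, 10).to(device)

# Inputs must be on the same device as the model parameters.
inputs = torch.rand(5, 8, device=device)
outputs = net(inputs)

print(outputs.shape, outputs.device)
```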







Reposted from blog.csdn.net/ewqapple/article/details/81063963