PyTorch代码学习-ImageNET训练

PyTorch代码学习-ImageNET训练

文章说明:本人学习pytorch/examples/ImageNET/main()理解(待续)

# -*- coding: utf-8 -*-
import argparse  # 命令行解释器相关程序,命令行解释器
import os        # 操作系统文件相关
import shutil    # 文件高级操作
import time      # 调用时间模块

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn        # gpu 使用
import torch.distributed as dist            # 分布式(pytorch 0.2)
import torch.optim                          # 优化器
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

# name中若为小写且不以‘——’开头,则对其进行升序排列
model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))                
    # callable功能为判断返回对象是否可调用(即某种功能)。

# 创建argparse.ArgumentParser对象
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# 添加命令行元素
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=1, type=int,
                    help='number of distributed processes')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='gloo', type=str,
                    help='distributed backend')

# 定义参数
best_prec1 = 0

# 定义主函数main()
def main():
    global args, best_prec1
    # 使用函数parse_args()进行参数解析,输入默认是sys.argv[1:],
    # 返回值是一个包含命令参数的Namespace,所有参数以属性的形式存在,比如args.myoption。
    args = parser.parse_args()

########## 使用多播地址进行初始化
    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

##### step1: create model and set GPU 
    # 导入pretrained model 或者创建model
    if args.pretrained:
        # format 格式化表达字符串,上述默认arch为resnet18
        print("=> using pre-trained model '{}'".format(args.arch))      
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()
    # 分布式运行,可实现在多块GPU上运行
    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            # 批处理,多GPU默认用dataparallel使用在多块gpu上
            model.features = torch.nn.DataParallel(model.features)           
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        # Wrap model in DistributedDataParallel (CUDA only for the moment)
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)


##### step2: define loss function (criterion) and optimizer
    # 使用交叉熵损失函数
    criterion = nn.CrossEntropyLoss().cuda()                            
    # optimizer 使用 SGD + momentum
    # 动量,默认设置为0.9
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                # 权值衰减,默认为1e-4                 
                                weight_decay=args.weight_decay)         


   # 恢复模型(详见模型存取与恢复)
####step3:optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):                                 # 判断返回的是不是文件
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)                        # load 一个save的对象
            args.start_epoch = checkpoint['epoch']                      # default = 90
            best_prec1 = checkpoint['best_prec1']                       # best_prec1 = 0
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])          # load_state_dict:恢复模型
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

##### step4: Data loading code base of dataset(have downloaded) and normalize
    # 从 train、val文件中导入数据
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # 数据预处理:normalize: - mean / std
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],       
                                     std=[0.229, 0.224, 0.225])

    # ImageFolder 一个通用的数据加载器
    train_dataset = datasets.ImageFolder(
        traindir,
        # 对数据进行预处理
        transforms.Compose([                      # 将几个transforms 组合在一起
            transforms.RandomSizedCrop(224),      # 随机切再resize成给定的size大小
            transforms.RandomHorizontalFlip(),    # 概率为0.5,随机水平翻转。
            transforms.ToTensor(),                # 把一个取值范围是[0,255]或者shape为(H,W,C)的numpy.ndarray,
                                                  # 转换成形状为[C,H,W],取值范围是[0,1.0]的torch.FloadTensor
            normalize,
        ]))

#######
    if args.distributed:
        # Use a DistributedSampler to restrict each process to a distinct subset of the dataset.
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
######

    # train 数据下载及预处理
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([ 
            # 重新改变大小为`size`,若:height>width`,则:(size*height/width, size)
            transforms.Scale(256),
            # 将给定的数据进行中心切割,得到给定的size。
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)         # default workers = 4

##### step5: 验证函数
    if args.evaluate:
        validate(val_loader, model, criterion)             # 自定义的validate函数,见下
        return

##### step6:开始训练模型
    for epoch in range(args.start_epoch, args.epochs):
        # Use .set_epoch() method to reshuffle the dataset partition at every iteration
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)      # adjust_learning_rate 自定义的函数,见下

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best)


# 定义相关函数
# def train 函数
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda(async=True)
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        # criterion 为定义过的损失函数
        loss = criterion(output, target_var)        

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # 每十步输出一次
        if i % args.print_freq == 0:     # default=10
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))


def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()


    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        target = target.cuda(async=True)
        # 这是一种用来包裹张量并记录应用的操作
        """
        Attributes:
        data: 任意类型的封装好的张量。
        grad: 保存与data类型和位置相匹配的梯度,此属性难以分配并且不能重新分配。
        requires_grad: 标记变量是否已经由一个需要调用到此变量的子图创建的bool值。只能在叶子变量上进行修改。
        volatile: 标记变量是否能在推理模式下应用(如不保存历史记录)的bool值。只能在叶变量上更改。
        is_leaf: 标记变量是否是图叶子(如由用户创建的变量)的bool值.
        grad_fn: Gradient function graph trace.

        Parameters:
        data (any tensor class): 要包装的张量.
        requires_grad (bool): bool型的标记值. **Keyword only.**
        volatile (bool): bool型的标记值. **Keyword only.**
        """
        input_var = torch.autograd.Variable(input, volatile=True)
        target_var = torch.autograd.Variable(target, volatile=True)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   i, len(val_loader), batch_time=batch_time, loss=losses,
                   top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg

# 保存当前节点
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

# 计算并存储参数当前值或平均值
class AverageMeter(object):
    # Computes and stores the average and current value
    """
       batch_time = AverageMeter()
       即 self = batch_time
       则 batch_time 具有__init__,reset,update三个属性,
       直接使用batch_time.update()调用
       功能为:batch_time.update(time.time() - end)
               仅一个参数,则直接保存参数值
        对应定义:def update(self, val, n=1)
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))
        这些有两个参数则求参数val的均值,保存在avg中##不确定##

    """
    def __init__(self):
        self.reset()       # __init__():reset parameters

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# 更新 learning_rate :每30步,学习率降至前的10分之1
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))            # args.lr = 0.1 , 即每30步,lr = lr /10
    for param_group in optimizer.param_groups:       # 将更新的lr 送入优化器 optimizer 中,进行下一次优化
        param_group['lr'] = lr

# 计算准确度
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k
    prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
    """
    maxk = max(topk)
    # size函数:总元素的个数
    batch_size = target.size(0)

    # topk函数选取output前k大个数
    _, pred = output.topk(maxk, 1, True, True)
    ##########不了解t()
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == '__main__':
    main()

   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345
  • 346
  • 347
  • 348
  • 349
  • 350
  • 351
  • 352
  • 353
  • 354
  • 355
  • 356
  • 357
  • 358
  • 359
  • 360
  • 361
  • 362
  • 363
  • 364
  • 365
  • 366
  • 367
  • 368
  • 369
  • 370
  • 371
  • 372
  • 373
  • 374
  • 375
  • 376
  • 377
  • 378
  • 379
  • 380
  • 381
  • 382
  • 383
  • 384
  • 385
  • 386
  • 387
  • 388
  • 389

在模型完成训练后,我们需要将训练好的模型保存为一个文件供测试使用,或者因为一些原因我们需要继续之前的状态训练之前保存的模型,那么如何在PyTorch中保存和恢复模型呢?

参考PyTorch官方的这份repo,我们知道有两种方法可以实现我们想要的效果。

方法一(推荐):

第一种方法也是官方推荐的方法,只保存和恢复模型中的参数。

保存

恢复

使用这种方法,我们需要自己导入模型的结构信息。

方法二:

使用这种方法,将会保存模型的参数和结构信息。

保存

恢复

一个相对完整的例子

saving

loading

获取模型中某些层的参数

对于恢复的模型,如果我们想查看某些层的参数,可以:

# 打印网络的结构 print(model)
1
2
3
4
5
6
7
8
9
10
# 定义一个网络
from collections import OrderedDict
model = nn . Sequential ( OrderedDict ( [
                   ( ‘conv1’ , nn . Conv2d ( 1 , 20 , 5 ) ) ,
                   ( ‘relu1’ , nn . ReLU ( ) ) ,
                   ( ‘conv2’ , nn . Conv2d ( 20 , 64 , 5 ) ) ,
                   ( ‘relu2’ , nn . ReLU ( ) )
                 ] ) )
# 打印网络的结构
print ( model )

Out:

如果我们想获取conv1的weight和bias:

文章来源:http://www.aiboy.pub/2017/06/05/How_To_Save_And_Restore_Model/

本站QQ群(242251466)和微信讨论群,欢迎加入:

#group { display: block; margin: 0 auto; }



(adsbygoogle = window.adsbygoogle || []).push({});




猜你喜欢

转载自blog.csdn.net/Jason_mmt/article/details/82687613
今日推荐