easy-Fpn Source Code Walkthrough (2): train

Complete code of train.py

import argparse
import os
import time
import uuid
from collections import deque
from typing import Optional

from tensorboardX import SummaryWriter
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader

from backbone.base import Base as BackboneBase
from config.train_config import TrainConfig as Config
from dataset.base import Base as DatasetBase
from logger import Logger as Log
from model import Model
from roi.wrapper import Wrapper as ROIWrapper


def _train(dataset_name: str, backbone_name: str, path_to_data_dir: str, path_to_checkpoints_dir: str, path_to_resuming_checkpoint: Optional[str]):
    dataset = DatasetBase.from_name(dataset_name)(path_to_data_dir, DatasetBase.Mode.TRAIN, Config.IMAGE_MIN_SIDE, Config.IMAGE_MAX_SIDE)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True)

    Log.i('Found {:d} samples'.format(len(dataset)))

    backbone = BackboneBase.from_name(backbone_name)(pretrained=True)
    model = Model(backbone, dataset.num_classes(), pooling_mode=Config.POOLING_MODE,
                  anchor_ratios=Config.ANCHOR_RATIOS, anchor_scales=Config.ANCHOR_SCALES,
                  rpn_pre_nms_top_n=Config.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=Config.RPN_POST_NMS_TOP_N).cuda()
    optimizer = optim.SGD(model.parameters(), lr=Config.LEARNING_RATE,
                          momentum=Config.MOMENTUM, weight_decay=Config.WEIGHT_DECAY)
    scheduler = MultiStepLR(optimizer, milestones=Config.STEP_LR_SIZES, gamma=Config.STEP_LR_GAMMA)

    step = 0
    time_checkpoint = time.time()
    losses = deque(maxlen=100)
    summary_writer = SummaryWriter(os.path.join(path_to_checkpoints_dir, 'summaries'))
    should_stop = False

    num_steps_to_display = Config.NUM_STEPS_TO_DISPLAY
    num_steps_to_snapshot = Config.NUM_STEPS_TO_SNAPSHOT
    num_steps_to_finish = Config.NUM_STEPS_TO_FINISH

    if path_to_resuming_checkpoint is not None:
        step = model.load(path_to_resuming_checkpoint, optimizer, scheduler)
        Log.i(f'Model has been restored from file: {path_to_resuming_checkpoint}')

    Log.i('Start training')
    
    while not should_stop:
        for batch_index, (_, image_batch, _, bboxes_batch, labels_batch) in enumerate(dataloader):
            assert image_batch.shape[0] == 1, 'only batch size of 1 is supported'

            image = image_batch[0].cuda()
            bboxes = bboxes_batch[0].cuda()
            labels = labels_batch[0].cuda()

            forward_input = Model.ForwardInput.Train(image, gt_classes=labels, gt_bboxes=bboxes)
            forward_output: Model.ForwardOutput.Train = model.train().forward(forward_input)

            anchor_objectness_loss, anchor_transformer_loss, proposal_class_loss, proposal_transformer_loss = forward_output
            loss = anchor_objectness_loss + anchor_transformer_loss + proposal_class_loss + proposal_transformer_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            losses.append(loss.item())
            summary_writer.add_scalar('train/anchor_objectness_loss', anchor_objectness_loss.item(), step)
            summary_writer.add_scalar('train/anchor_transformer_loss', anchor_transformer_loss.item(), step)
            summary_writer.add_scalar('train/proposal_class_loss', proposal_class_loss.item(), step)
            summary_writer.add_scalar('train/proposal_transformer_loss', proposal_transformer_loss.item(), step)
            summary_writer.add_scalar('train/loss', loss.item(), step)
            step += 1
            
            
            if step == num_steps_to_finish:
                should_stop = True

            if step % num_steps_to_display == 0:
                elapsed_time = time.time() - time_checkpoint
                time_checkpoint = time.time()
                steps_per_sec = num_steps_to_display / elapsed_time
                samples_per_sec = dataloader.batch_size * steps_per_sec
                eta = (num_steps_to_finish - step) / steps_per_sec / 3600
                avg_loss = sum(losses) / len(losses)
                lr = scheduler.get_lr()[0]
                Log.i(f'[Step {step}] Avg. Loss = {avg_loss:.6f}, Learning Rate = {lr:.6f} ({samples_per_sec:.2f} samples/sec; ETA {eta:.1f} hrs)')
                
            if step % num_steps_to_snapshot == 0 or should_stop:
                path_to_checkpoint = model.save(path_to_checkpoints_dir, step, optimizer, scheduler)
                Log.i(f'Model has been saved to {path_to_checkpoint}')

            if should_stop:
                break

    Log.i('Done')


if __name__ == '__main__':
    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument('-s', '--dataset', type=str, choices=DatasetBase.OPTIONS, required=True, help='name of dataset')
        parser.add_argument('-b', '--backbone', type=str, choices=BackboneBase.OPTIONS, required=True, help='name of backbone model')
        parser.add_argument('-d', '--data_dir', type=str, default='./data', help='path to data directory')
        parser.add_argument('-o', '--outputs_dir', type=str, default='./outputs', help='path to outputs directory')
        parser.add_argument('-r', '--resume_checkpoint', type=str, help='path to resuming checkpoint')
        parser.add_argument('--image_min_side', type=float, help='default: {:g}'.format(Config.IMAGE_MIN_SIDE))
        parser.add_argument('--image_max_side', type=float, help='default: {:g}'.format(Config.IMAGE_MAX_SIDE))
        parser.add_argument('--anchor_ratios', type=str, help='default: "{!s}"'.format(Config.ANCHOR_RATIOS))
        parser.add_argument('--anchor_scales', type=str, help='default: "{!s}"'.format(Config.ANCHOR_SCALES))
        parser.add_argument('--pooling_mode', type=str, choices=ROIWrapper.OPTIONS, help='default: {.value:s}'.format(Config.POOLING_MODE))
        parser.add_argument('--rpn_pre_nms_top_n', type=int, help='default: {:d}'.format(Config.RPN_PRE_NMS_TOP_N))
        parser.add_argument('--rpn_post_nms_top_n', type=int, help='default: {:d}'.format(Config.RPN_POST_NMS_TOP_N))
        parser.add_argument('--learning_rate', type=float, help='default: {:g}'.format(Config.LEARNING_RATE))
        parser.add_argument('--momentum', type=float, help='default: {:g}'.format(Config.MOMENTUM))
        parser.add_argument('--weight_decay', type=float, help='default: {:g}'.format(Config.WEIGHT_DECAY))
        parser.add_argument('--step_lr_sizes', type=str, help='default: {!s}'.format(Config.STEP_LR_SIZES))
        parser.add_argument('--step_lr_gamma', type=float, help='default: {:g}'.format(Config.STEP_LR_GAMMA))
        parser.add_argument('--num_steps_to_display', type=int, help='default: {:d}'.format(Config.NUM_STEPS_TO_DISPLAY))
        parser.add_argument('--num_steps_to_snapshot', type=int, help='default: {:d}'.format(Config.NUM_STEPS_TO_SNAPSHOT))
        parser.add_argument('--num_steps_to_finish', type=int, help='default: {:d}'.format(Config.NUM_STEPS_TO_FINISH))
        args = parser.parse_args()

        dataset_name = args.dataset
        backbone_name = args.backbone
        path_to_data_dir = args.data_dir
        path_to_outputs_dir = args.outputs_dir
        path_to_resuming_checkpoint = args.resume_checkpoint

        path_to_checkpoints_dir = os.path.join(path_to_outputs_dir, 'checkpoints-{:s}-{:s}-{:s}-{:s}'.format(
            time.strftime('%Y%m%d%H%M%S'), dataset_name, backbone_name, str(uuid.uuid4()).split('-')[0]))
        os.makedirs(path_to_checkpoints_dir)

        Config.setup(image_min_side=args.image_min_side, image_max_side=args.image_max_side,
                     anchor_ratios=args.anchor_ratios, anchor_scales=args.anchor_scales, pooling_mode=args.pooling_mode,
                     rpn_pre_nms_top_n=args.rpn_pre_nms_top_n, rpn_post_nms_top_n=args.rpn_post_nms_top_n,
                     learning_rate=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay,
                     step_lr_sizes=args.step_lr_sizes, step_lr_gamma=args.step_lr_gamma,
                     num_steps_to_display=args.num_steps_to_display, num_steps_to_snapshot=args.num_steps_to_snapshot, num_steps_to_finish=args.num_steps_to_finish)

        Log.initialize(os.path.join(path_to_checkpoints_dir, 'train.log'))
        Log.i('Arguments:')
        for k, v in vars(args).items():
            Log.i(f'\t{k} = {v}')
        Log.i(Config.describe())

        _train(dataset_name, backbone_name, path_to_data_dir, path_to_checkpoints_dir, path_to_resuming_checkpoint)

    main()
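
A typical invocation might look like this (the dataset and backbone names below are assumptions for illustration; the valid choices are whatever DatasetBase.OPTIONS and BackboneBase.OPTIONS define):

python train.py -s voc2007 -b resnet101 -d ./data -o ./outputs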

Code Walkthrough

Start reading from if __name__ == '__main__':
First, an argparse parser receives the input arguments:

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--dataset', type=str, choices=DatasetBase.OPTIONS, required=True, help='name of dataset')
...
args = parser.parse_args()

Then the checkpoints directory is built:

path_to_checkpoints_dir = os.path.join(path_to_outputs_dir, 'checkpoints-{:s}-{:s}-{:s}-{:s}'.format(
    time.strftime('%Y%m%d%H%M%S'), dataset_name, backbone_name, str(uuid.uuid4()).split('-')[0]))
os.makedirs(path_to_checkpoints_dir)

Next, Config.setup() is called to register the received arguments as hyperparameters. Config's setup() is a classmethod: it can be called without instantiating the class, and it can change the original values of class attributes.

from config.train_config import TrainConfig as Config
Config.setup(...)

An example of a classmethod setting a class attribute:

class A(object):
    num = 3

    @classmethod
    def set_num(cls, n):
        cls.num = n  # rebinds the class attribute, affecting all readers of A.num

    @staticmethod
    def get_num():
        return A.num

print(A.get_num())
A.set_num(6)
print(A.get_num())

>> 3
>> 6

Then the _train(...) function is called to start training:

def _train(dataset_name: str, backbone_name: str, path_to_data_dir: str, path_to_checkpoints_dir: str, path_to_resuming_checkpoint: Optional[str]):
    '''
    dataset_name: name of the dataset
    backbone_name: backbone network used for feature extraction
    path_to_data_dir: where the dataset is stored
    path_to_checkpoints_dir: where the trained checkpoints are saved
    path_to_resuming_checkpoint: path of the checkpoint to resume training from
                                 (Optional is explained below the code)
    '''
    dataset = DatasetBase.from_name(dataset_name)(path_to_data_dir, DatasetBase.Mode.TRAIN, Config.IMAGE_MIN_SIDE, Config.IMAGE_MAX_SIDE)
    # DatasetBase is the base class wrapping datasets. from_name is a static method, so it can
    # be called without an instance; it returns the subclass that wraps the requested dataset,
    # which is then instantiated immediately.
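    # A minimal sketch of what such a from_name factory might look like (the class and
    # module names below are assumptions for illustration, not the actual easy-FPN code):
    #
    # class Base:
    #     OPTIONS = ['voc2007', 'coco2017']
    #
    #     @staticmethod
    #     def from_name(name):
    #         if name == 'voc2007':
    #             from dataset.voc2007 import VOC2007
    #             return VOC2007  # returns the class itself, not an instance
    #         elif name == 'coco2017':
    #             from dataset.coco2017 import COCO2017
    #             return COCO2017
    #         else:
    #             raise ValueError('unknown dataset name: {:s}'.format(name))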
    
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True)
    # DataLoader is the data-loading class; dataset is the wrapper subclass instantiated above

    # If a class defines __len__(), calling len() on an instance of it returns whatever
    # element count that method reports
    Log.i('Found {:d} samples'.format(len(dataset)))
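    # For example (a toy class, not part of easy-FPN):
    # class Toy:
    #     def __init__(self, items):
    #         self._items = items
    #     def __len__(self):
    #         return len(self._items)
    # len(Toy([1, 2, 3]))
    # >> 3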

    backbone = BackboneBase.from_name(backbone_name)(pretrained=True)
    # Same usage pattern as DatasetBase above
    
    model = Model(backbone, dataset.num_classes(), pooling_mode=Config.POOLING_MODE,
                  anchor_ratios=Config.ANCHOR_RATIOS, anchor_scales=Config.ANCHOR_SCALES,
                  rpn_pre_nms_top_n=Config.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=Config.RPN_POST_NMS_TOP_N).cuda()
    # backbone is the feature-extraction backbone network
    # dataset.num_classes() is the number of classes in the dataset
    # pooling_mode is the RoI pooling mode
    # anchor_ratios is the list of anchor aspect ratios
    # anchor_scales is the list of base anchor sizes
    # rpn_pre_nms_top_n: before NMS, keep the top rpn_pre_nms_top_n anchors ranked by score
    # rpn_post_nms_top_n: after NMS, keep the top rpn_post_nms_top_n proposals ranked by score
                      
    optimizer = optim.SGD(model.parameters(), lr=Config.LEARNING_RATE,
                          momentum=Config.MOMENTUM, weight_decay=Config.WEIGHT_DECAY)
    # The optimizer is SGD with momentum; weight_decay is the weight-decay coefficient,
    # i.e. the L2-regularization coefficient
    
    scheduler = MultiStepLR(optimizer, milestones=Config.STEP_LR_SIZES, gamma=Config.STEP_LR_GAMMA)
    # MultiStepLR decays the learning rate on a schedule: each time the step count reaches a
    # milestone, the current learning rate is multiplied by gamma. Note that because
    # scheduler.step() is called once per training step below, the milestones here are
    # measured in steps, not epochs.
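    # Illustrative schedule (these milestone/gamma values are made up, not the config defaults):
    # with LEARNING_RATE=0.001, STEP_LR_SIZES=[60000, 80000], STEP_LR_GAMMA=0.1:
    #   steps [0, 60000)     -> lr = 0.001
    #   steps [60000, 80000) -> lr = 0.0001
    #   steps >= 80000       -> lr = 0.00001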
    
    step = 0
    time_checkpoint = time.time()
    # Record the current time as a reference point
    
    losses = deque(maxlen=100)
    # A deque that keeps at most the 100 most recent loss values
    
    summary_writer = SummaryWriter(os.path.join(path_to_checkpoints_dir, 'summaries'))
    # summary_writer records scalars for training visualization; think of it as a file-like
    # handle whose argument is the directory where the event files will be written
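    # The scalars logged below can then be inspected with the TensorBoard CLI, e.g.:
    #   tensorboard --logdir <path_to_checkpoints_dir>/summaries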
    
    should_stop = False
    # Flag that ends training; with the default hyperparameters it becomes True at 80000 steps

    num_steps_to_display = Config.NUM_STEPS_TO_DISPLAY
    num_steps_to_snapshot = Config.NUM_STEPS_TO_SNAPSHOT
    num_steps_to_finish = Config.NUM_STEPS_TO_FINISH

    # If a checkpoint from a previous run is passed in, training resumes from where it stopped
    if path_to_resuming_checkpoint is not None:
        step = model.load(path_to_resuming_checkpoint, optimizer, scheduler)
        Log.i(f'Model has been restored from file: {path_to_resuming_checkpoint}')

    Log.i('Start training')

    while not should_stop:
        for batch_index, (_, image_batch, _, bboxes_batch, labels_batch) in enumerate(dataloader):
            assert image_batch.shape[0] == 1, 'only batch size of 1 is supported'
            # image_batch.shape = [1, 3, H, W], i.e. BCHW: batch size, channels, height, width
            
            # Since each batch holds exactly one sample, index 0 is used throughout
            image = image_batch[0].cuda()
            bboxes = bboxes_batch[0].cuda()
            labels = labels_batch[0].cuda()
            
            forward_input = Model.ForwardInput.Train(image, gt_classes=labels, gt_bboxes=bboxes)
            # The training-input wrapper class defined in Model
            
            forward_output: Model.ForwardOutput.Train = model.train().forward(forward_input)
            # model.train() switches the model to training mode; calling forward then returns
            # an instance of the training-output wrapper class
            # (model.eval() switches the model to evaluation mode)

            anchor_objectness_loss, anchor_transformer_loss, proposal_class_loss, proposal_transformer_loss = forward_output
            loss = anchor_objectness_loss + anchor_transformer_loss + proposal_class_loss + proposal_transformer_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            losses.append(loss.item())
            # from collections import deque
            # l = deque(maxlen=10)
            # for i in range(15):
            #     l.append(i)
            # print(l)
            # >> deque([5, 6, 7, 8, 9, 10, 11, 12, 13, 14], maxlen=10)
            
            summary_writer.add_scalar('train/anchor_objectness_loss', anchor_objectness_loss.item(), step)
            summary_writer.add_scalar('train/anchor_transformer_loss', anchor_transformer_loss.item(), step)
            summary_writer.add_scalar('train/proposal_class_loss', proposal_class_loss.item(), step)
            summary_writer.add_scalar('train/proposal_transformer_loss', proposal_transformer_loss.item(), step)
            summary_writer.add_scalar('train/loss', loss.item(), step)
            # Typically we log the loss at each training step, the accuracy per epoch, and
            # sometimes the corresponding learning rate.
            # Logging scalar values is cheap, so log anything you consider important.
            # Use writer.add_scalar('myscalar', value, iteration) to log a scalar.
            # Note that passing a PyTorch tensor directly will not work: if x is a scalar
            # tensor, remember to extract its value with x.item().
            step += 1

            if step == num_steps_to_finish:
                should_stop = True

            if step % num_steps_to_display == 0:
                elapsed_time = time.time() - time_checkpoint 
                # Time elapsed since the last display; time.time() returns the number of
                # seconds since the Unix epoch (1970-01-01)
                time_checkpoint = time.time()
                steps_per_sec = num_steps_to_display / elapsed_time
                # Training steps completed per second

                samples_per_sec = dataloader.batch_size * steps_per_sec
                # Samples processed per second
                
                eta = (num_steps_to_finish - step) / steps_per_sec / 3600
                # Estimated hours remaining until training finishes
                avg_loss = sum(losses) / len(losses)
                # Average loss over the last (up to) 100 steps
                lr = scheduler.get_lr()[0]
                # Current learning rate
                Log.i(f'[Step {step}] Avg. Loss = {avg_loss:.6f}, Learning Rate = {lr:.6f} ({samples_per_sec:.2f} samples/sec; ETA {eta:.1f} hrs)')
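                # Worked example with made-up numbers: if num_steps_to_display=20 and the
                # last 20 steps took 10 s, then steps_per_sec = 20 / 10 = 2.0,
                # samples_per_sec = 1 * 2.0 = 2.0, and at step=1000 with
                # num_steps_to_finish=80000, eta = 79000 / 2.0 / 3600 ≈ 11.0 hrs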
            
            # Save a checkpoint whenever the step count reaches a snapshot point (or training ends)
            if step % num_steps_to_snapshot == 0 or should_stop:
                path_to_checkpoint = model.save(path_to_checkpoints_dir, step, optimizer, scheduler)
                Log.i(f'Model has been saved to {path_to_checkpoint}')

            if should_stop:
                break

    Log.i('Done')

type hints
    Type hints annotate the data types of a function's inputs and outputs; the types live in the typing package, and the basic ones include str, list, dict, and so on.
Usage example:

def hello(name: str) -> None:
    print('hello {}'.format(name))

    Type hints include many other kinds of annotations; here we focus on Union and Optional, since they come up often in Python.
    Union is used when several data types are possible: for example, if a function may return either a str or a list depending on the situation, the annotation can be written as Union[list, str]. Optional is a shorthand for a Union that includes None: if a value may be a str or None, write Optional[str], which is equivalent to Union[str, None]. Note that this is NOT the same thing as giving the parameter a default value of None; the default still has to be written explicitly, as the examples below show:

Original: def func(args=None):
Wrong:    def func(args: Optional[str]) -> None:
Right:    def func(args: Optional[str] = None) -> None:
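
A small, self-contained Union sketch (the wrap function and its name are made up for illustration, not part of easy-FPN):

from typing import Union

def wrap(value: Union[str, list]) -> list:
    # a lone string is wrapped in a list; a list passes through unchanged
    if isinstance(value, str):
        return [value]
    return value

print(wrap('a'))
print(wrap(['a', 'b']))

>> ['a']
>> ['a', 'b']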

Type hints can also define custom types, among other things.
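
For instance, a simple type alias (the Box alias and areas function here are illustrative only):

from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (left, top, right, bottom)

def areas(boxes: List[Box]) -> List[float]:
    # compute width * height for each box
    return [(r - l) * (b - t) for l, t, r, b in boxes]

print(areas([(0.0, 0.0, 2.0, 3.0)]))

>> [6.0]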


Reposted from blog.csdn.net/ThunderF/article/details/104719858