超参数自动优化方法PBT(Population Based Training)

我们知道,机器学习模型的效果好坏很大程度上取决于超参的选取。人肉调参需要依赖经验与直觉,且花费大量精力。PBT(Population based training)是DeepMind在论文《Population Based Training of Neural Networks》中提出的一种异步的自动超参数调节优化方法。以往的自动调节超参方法可分为两类:parallel search和sequential optimization。前者并行执行很多不同超参的优化任务,优点是可以并行利用计算资源更快找到最优解;后者需要利用之前的信息来进行下一步的超参优化,因此只能串行执行,但一般能得到更好的解。PBT完美地结合两种方法,兼具两者优点。它被应用于一些领域取得了不错的效果。如DeepMind的论文《Human-level performance in first-person multiplayer games with population-based deep reinforcement learning》将之用于第一人称多人游戏使AI达到人类水平。还有今年UC Berkeley的论文《Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules》中用PBT来自动学习data augmentation策略,在几个benchmark上达到了不错的精度。另外,最近自动驾驶公司Waymo也称将PBT应用于识别任务,与手工调参相比可以提高精度和加快训练速度。

PBT开局与parallel search类似,会并行训练一批随机初始化的模型。过程中它会周期性地将表现好的模型替换表现不好的模型(exploitation),同时再加上随机扰动(主要是为了exploration)。PBT与其它方法的一个重要不同是它在训练的过程中对超参进行调节,因此可以更快地发现超参和优异的schedule。论文《Population Based Training of Neural Networks》中的示意图非常清楚地示意了整个过程,及与其它方法的区别:
在这里插入图片描述

PBT是一种很通用的方法,可以用于很多场景,其一般套路如下:

  1. Step:对模型训练一步。至于一步是一次iteration还是一个epoch还是其它可以根据需要指定。
  2. Eval:在验证集上做评估。
  3. Ready: 选取群体中的一个模型来进行下面的exploit和explore操作(即perturbation)。这个模型一般是上次做过该操作后经过指定的时间(或迭代次数等)。
  4. Exploit: 将那些经过评估比较烂的模型用那些比较牛叉的模型替代。
  5. Explore: 对上一步产生的复制体模型加随机扰动,如加上随机值或重采样。

Ray中实现了PBT算法。Ray中关于PBT有三个example:一个是learning rate搜索pbt_example.py,另一个是强化学习算法PPO的超参数搜索pbt_ppo_example.py。还有一个是pbt_tune_cifar10_with_keras.py。我们来看下最简单的pbt_example.py。其中的PBTBenchmarkExample类继承自Trainable类,它是一个toy的模拟环境,假设在模型训练过程中最优的learning rate是变化的,是accuracy的函数。目标是找到learning rate的schedule。它的核心函数是_train(),这里会模拟最优的learning rate。

然后看主函数,首先通过ray.init()初始化ray,然后创建PopulationBasedTraining对象,接着通过run()函数开始超参搜索过程。

    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=20,
        hyperparam_mutations={
            # distribution for resampling
            "lr": lambda: random.uniform(0.0001, 0.02),
            # allow perturbations within this set of categorical values
            "some_other_factor": [1, 2],
        })
        
    run(
        PBTBenchmarkExample,
        name="pbt_test",
        scheduler=pbt,
        reuse_actors=True,
        verbose=False,
        **{
            "stop": {
                "training_iteration": 2000,
            },
            "num_samples": 4,
            "config": {
                "lr": 0.0001,
                # note: this parameter is perturbed but has no effect on
                # the model training in this example
                "some_other_factor": 1,
            },
        })    

先看第一步,PopulationBasedTraining的实现在python/ray/tune/schedulers/pbt.py中。它继承自FIFOScheduler类。构造函数中几个主要参数:

  • time_attr: 用于定义训练时长的测度,要求单调递增,比如training_iteration
  • metric: 训练结果衡量目标。
  • mode: 上面metric属性是越高越好,还是越低越好。
  • perturbation_interval: 模型会以time_attr为间隔来进行perturbation。
  • hyperparam_mutations: 需要变异的超参。它是一个dict,对于每个key对应list或者function。如果没设这个,就需要在custom_explore_fn中指定。
  • quantile_fraction: 决定按多大比例将表现好的头部模型克隆到尾部模型。
  • resample_probability: 当对超参进行exploration时从原分布中重新采样的概率,否则会根据现有的值调整。
  • custom_explore_fn: 自定义的exploration函数。

第二步中run()函数实现在ray/python/ray/tune/tune.py中:

def run(run_or_experiment, name=None, ...):
    trial_executor = traial_executor or RayTrialExecutor(...)
    experiment = run_or_experiment
    if not isinstance(run_or_experiment, Experiment):
    	if not isinstance(run_or_experiment, Experiment):
    	experiment = Experiment(...)
    ...
    runner = TrialRunner(
        search_alg=search_alg or BasicVariantGenerator(),
        scheduler=scheduler or FIFOScheduler(),
        local_checkpoint_dir=experiment.checkpoint_dir,
        remote_checkpoint_dir=experiment.remote_checkpoint_dir,
        sync_to_cloud=sync_to_cloud,
        checkpoint_period=global_checkpoint_period,
        resume=resume,
        launch_web_server=with_server,
        server_port=server_port,
        verbose=bool(verbose > 1),
        trial_executor=trial_executor)
        
    runner.add_experiment(experiment)
    ...
    while not runner.is_finished():
       runner.step()
       ...
       
	wait_for_sync()
	...
	return ExperimentAnalysis(runner.checkpoint_file, trials=trials)

第一个参数run_or_experiment是要训练的目标任务,参数scheduler就是上面创建的PopulationBasedTraining,负责超参搜索时的调度。

其中几个关键类关系如下图:
在这里插入图片描述
SearchAlgorithm的实现类BasicVariantGenerator会根据给定的Experiment产生参数变体。每个待训练的参数变体会创建相应的Trial对象。Trial有PENDING, RUNNING, PAUSED, TERMINATED, ERROR几种状态。它会开始于PENDING状态,开始训练后转为RUNNING状态,出错了就到ERROR状态,成功的话就是TERMINATED状态。训练中还可能被TrialScheduler暂停(转入PAUSED状态)并释放资源。

TrialRunner是最核心的数据结构,它管理一系列的Trial对象,并且执行一个事件循环,将这些任务通过TrialExecutor的实现类RayTrialExecutor提交到Ray cluster运行。RayTrialExecutor会负责资源的管理。这里通过Ray分布执行的主要是Trainable的实现类(上例中就是PBTBenchmarkExample)中的_train()函数。RayTrialExecutor对象中的_running维护了正在运行的Trial。在循环中,TrialRunner会通过TrialScheduler的实现类PopulationBasedTraining来进行调度。它的choose_trial_to_run()函数从trial_runner的queue中拿出状态为PENDING或者PAUSED的trial,并且选取离上次做perturbation最久的一个保证尽可能公平。

run函数主要做以下几步:

扫描二维码关注公众号,回复: 8899694 查看本文章
  1. 创建RayTrailExecutor对象(如果没有传入trial_executor的话)。
  2. 如果目标任务不是以Experiment对象形式给出,会按照给定的其它参数构建Experiment对象。
  3. 创建TrialRunner对象,它基于Ray来调度事件循环。
    1. 创建搜索算法对象(如果没给),默认为BasicVariantGenerator(实现在basic_variant.py)。它主要用于产生新的参数变体。
    2. 创建执行实验的调度器(如果没给),默认为FIFOScheduler。上例中给定了PopulationBasedTraining,所以这里就不需要创建了。
    3. 创建TrialRunner对象(实现在trial_runner.py)。并上面创建的Experiment对象通过add_experiment()函数加到TrialRunner对象中。
  4. 进入主循环,通过TrialRunneris_finished()函数判断是否结束。如果没有,就调用TrialRunnerstep()函数执行一步。step()函数的主要工作下面再细说。
  5. 收尾工作。如通过wait_for_sync()函数同步远端目标,记录没有正常结束的trial,返回分析信息。

其中比较关键的是step()函数,其主要流程如下:
在这里插入图片描述

当一个Trial训练结束返回结果时,TrialRunner会调用PopulationBasedTrainingon_trial_result()函数。这里就是PBT的精华了。结合文章开关的PBT一般套路,主要步骤如下:

  1. 如果离上次pertubation的时间还没到指定间隔,则返回让该Trial继续训练。
  2. 调用_quantiles()函数按设定的比例__quantile_fraction得到所有Trial中表现好的头部和表现不好的尾部。
  3. 如果当前trial是比较牛的那一批,那赶紧存成checkpoint,等着被其它trial克隆学习。
  4. 如果很不幸地,当前trial属于比较差的那一批,那就从牛的那批中随机挑一个(为trial_to_clone),然后调用_exploit()函数。该函数会调用explore()函数对trial_to_clone进行扰动,然后将它的参数设置和checkpoint设置到当前trial。这样,当前trial就“洗心革面”,重新出发了。
  5. 如果TrialRunner中有PENDING和PAUSED状态的trial,则请求暂停当前trial,让出资源。否则的话就继续训练着。

最后,总结下主要模块间的大体流程:
在这里插入图片描述

发布了211 篇原创文章 · 获赞 438 · 访问量 148万+

猜你喜欢

转载自blog.csdn.net/ariesjzj/article/details/100047416
今日推荐