Ray Tune相关API介绍

1. 注册可训练的函数或类

ray.tune.register_trainable(name, trainable)

参数:

  • name (str) - 注册的方法或函数名。

  • trainable (obj) - 函数或tune.Trainable类。函数必须采用(config, status_reporter)作为参数,并且在注册的过程中自动转换为类。

 

2. 构造experiment对象

ray.tune.Experiment(name, run, stop=None, config=None, trial_resources=None, repeat=1, local_dir=None, upload_dir='', checkpoint_freq=0, max_failures=3)

参数:

  • name (str) – 名字。

  • run (str) – 要训练的算法或模型。 这可以指内置算法的名称(例如RLLib的DQN或PPO),或者在tune注册表中注册的用户定义的可训练函数或类。

  • stop (dict) - 停止标准。 值可以是TrainingResult中的任何字段,以先到达者为准。 默认为空字典。

  • config (dict) – 特定于算法的配置(例如env,hyperparams)。 默认为空字典。

  • trial_resources (dict) – 每次试验分配的机器资源,例如: {"cpu":64,"gpu":8}。 请注意,除非您在此处指定GPU,否则不会分配GPU。 默认为1个CPU和0个GPU。

  • repeat (int) – 重复每次试验的次数。 默认为1。

  • local_dir (str) – 将训练结果保存到的本地目录。 默认为〜/ ray_results。

  • upload_dir (str) – 同步训练结果的可选URI地址(例如s3:// bucket)。

  • checkpoint_freq (int) – 设置检查点间的训练迭代次数。 值0(默认值)禁用设置检查点。

  • max_failures (int) – 设置尝试从最后一个检查点恢复试验的最多次数。 仅在启用了检查点时适用。 默认为3。

 

3. 运行实验程序

ray.tune.run_experiments(experiments, scheduler=None, with_server=False, server_port=4321, verbose=True, queue_trials=False)

参数:

  • experiments (Experiment | list | dict) - 要运行的实验。

  • scheduler (TrialScheduler) - 用于执行实验的调度程序。在FIFO(默认),MedianStopping,AsyncHyperBand,HyperBand或HyperOpt中进行选择。

  • with_server (bool) - 启动后台Tune服务器。 使用客户端API需要。

  • server_port (int) - 启动TuneServer的端口号。

  • verbose (bool) - 每次试验应打印多少输出。

  • queue_trials (bool) - 当群集当前没有足够的资源来启动试验时,是否对试验进行排队。 在自动扩展群集上运行时,应将其设置为True以启用自动向上扩展。

返回值:

  • trial对象列表,保存每个已执行试验的数据。

4. 调度程序HyperOptScheduler

ray.tune.hpo_scheduler.HyperOptScheduler(max_concurrent=None, reward_attr='episode_reward_mean')

参数:

  • max_concurrent (int | None)– 最大并发试验次数。如果为None,则仅在资源可用时才会对试验进行排队。
  • reward_attr (str) – TrainingResult目标值属性。 这是指一个递增的值,在与HyperOpt交互时在内部被否定,以便HyperOpt可以“最大化”该值。

 

5. 调度程序AsyncHyperBandScheduler

ray.tune.async_hyperband.AsyncHyperBandScheduler(time_attr='training_iteration', reward_attr='episode_reward_mean', max_t=100, grace_period=10, reduction_factor=3, brackets=3)

参数:

  • time_attr (str) – 用于比较时间的TrainingResult 。请注意,你可以传递非时间性的东西,例如training_iteration作为进度的度量,唯一的要求是属性应该单调增加。

  • reward_attr (str) – TrainingResult目标值属性。 与time_attr一样,这可以指任何客观值。 终止进程时将使用此属性。

  • max_t (float) – 每次试验的最大时间单位。 经过max_t时间单位(由time_attr确定)后,试验将停止。

  • grace_period (float) – 终止试验的条件。单位与time_attr命名的属性相同。

  • reduction_factor (float) – 用于设置减半的速率和数量。

  • brackets (int) – bracket的数量。 每个bracket具有不同的减半率,由减少系数指定。

 

6. 调度程序HyperBandScheduler

ray.tune.hyperband.HyperBandScheduler(time_attr='training_iteration', reward_attr='episode_reward_mean', max_t=81)

参数:

  • time_attr (str) – 用于比较时间的TrainingResult 。请注意,你可以传递非时间性的东西,例如training_iteration作为进度的度量,唯一的要求是属性应该单调增加。
  • reward_attr (str) – TrainingResult目标值属性。 与time_attr一样,这可以指任何客观值。 终止进程时将使用此属性。

  • max_t (int) – 每次试验的最大时间单位。 经过max_t时间单位(由time_attr确定)后,试验将停止。该调度器将在时间结束后终止试验。

 

7. 调度程序MedianStoppingRule

ray.tune.median_stopping_rule.MedianStoppingRule(time_attr='time_total_s', reward_attr='episode_reward_mean', grace_period=60.0, min_samples_required=3, hard_stop=True, verbose=True)

参数:

  • time_attr (str) – 用于比较时间的TrainingResult 。请注意,你可以传递非时间性的东西,例如training_iteration作为进度的度量,唯一的要求是属性应该单调增加。
  • reward_attr (str) – TrainingResult目标值属性。 与time_attr一样,这可以指任何客观值。

  • grace_period (float) – 终止试验的条件。单位与time_attr命名的属性相同。

  • min_samples_required (int) – 最小样本计算中位数。

  • hard_stop (bool) – 如果为False,则暂停试验而不是停止试验。 当所有其他试验完成后,暂停试验将恢复并允许以FIFO运行。

  • verbose (bool) – 如果为True,则每次试验报告时都会输出中位数和最佳 结果。 默认为True。

 

8. 可训练模型,函数等的抽象类

ray.tune.trainable.Trainable(config=None, logger_creator=None)

参数:

  • config(obj) - 试验的超参数配置。
  • logdir(str) - 放置训练输出的目录。

重载方法:

  • _train() - 重载该方法实现train():在trainable上调用train()将执行一次训练的逻辑迭代。
  • _save(checkpoint_dir) - 重载该方法实现save():调用save()将可训练的训练状态保存到磁盘,并且restore(path)应该将训练状态恢复到给定状态。

  • _restore(checkpoint_path) - 重载该方法实现restore()。

  • _setup() - 重载该方法实现自定义初始化。

  • _stop() - 重载该方法实现清理和关闭程序。

注:

1)通常,在继承Trainable时,你只需要在这里实现_train,_save和_restore。

2)如果你不需要checkpoint/restore,那么你也可以通过提供my_train(config,reporter)函数并调用以下内容来实现,而不是实现此类。

      register_trainable(“my_func”,train)

     注册它以便与Tune一起使用。该功能将自动转换为该接口(无检查点功能)。

 

9. 实现客户端与正在进行的Tune实验进行交互,需要服务器已开始运行

ray.tune.web_server.TuneClient(tune_address)

方法:

  • get_all_trials() - 返回所有试验的列表(trial_id,config,status)。
  • get_trial(trial_id) - 返回查询试验的最后结果。

  • add_trial(name, trial_spec) - 给相应名字的试验添加配置。

  • stop_trial(trial_id) - 关闭相应id的试验。

 

10. PopulationBasedTraining(PBT)算法

ray.tune.pbt.PopulationBasedTraining(time_attr='time_total_s', reward_attr='episode_reward_mean', perturbation_interval=60.0, hyperparam_mutations={}, resample_probability=0.25, custom_explore_fn=None)

参数:

  • time_attr (str) – 用于比较时间的TrainingResult 。请注意,你可以传递非时间性的东西,例如training_iteration作为进度的度量,唯一的要求是属性应该单调增加。
  • reward_attr (str) – TrainingResult目标值属性。 与time_attr一样,这可以指任何客观值。 终止进程时将使用此属性。

  • perturbation_interval (float) – 模型将在time_attr的这个时间间隔内考虑扰动。 请注意,扰动会导致检查点开销,因此你不应将此设置为过于频繁。

  • hyperparam_mutations (dict) – Hyperparams变异。 格式如下:对于每个键,可以提供列表或函数。 列表指定一组允许的分类值。 函数指定连续参数的分布。 你必须至少指定hyperparam_mutations或custom_explore_fn中的一个。

  • resample_probability (float) – 应用hyperparam_mutations时从原始分布重新采样的概率。 如果不重新采样,如果连续,则值将被因子1.2或0.8扰动,或者如果离散则将值更改为相邻值。

  • custom_explore_fn (func) – 你还可以指定自定义探索功能。 在应用来自hyperparam_mutations的内置扰动后,此函数被调用为f(config),并应根据需要返回更新的配置。 你必须至少指定hyperparam_mutations或custom_explore_fn中的一个。

 

 

 

猜你喜欢

转载自blog.csdn.net/u011254180/article/details/81175151