分布式强化学习(Distributed RL)入门

参考视频:周博磊强化学习纲要

阅读本文需要强化学习基础,可以阅读我以前的文章:强化学习纲要(周博磊课程)强化学习实践教学

分布式系统

一般情况下我们做的论文课题都是小规模的,使用的都是一个相对较小的数据库,因此使用单机系统基本可以完成任务。但现实生活中的数据往往是巨量的,我们需要一个完整的分布式系统来处理这种大规模的数据。算法和结果只是冰山一角,只有拥有一个好的系统和框架作为支撑,才能得到好的算法和实验结果。

分布式系统需要满足:

  • 一致性:确保多节点的协调运作并且结果和单机运行的结果一致。
  • 容错性:当分布式环境工作时,其中某一个节点出现错误(机器宕机等),任务能够分配到其他机器,工作也能正常。
  • 交流:分布式系统需要I\O和分布式文件系统的知识。

在这里插入图片描述

分布式系统中存在分布式学习的模型和数据,因此有两种并行。一种是不同的机器做一个网络不同部分的计算。第二种是一套机器拥有一个单独的模型拷贝,但是分配到的数据是不同的,计算结果汇总。

在这里插入图片描述

一台电脑可以拥有多块显卡,因此每一块显卡都可以负责模型的一部分,每台电脑都可以放置一个模型,然后把数据分配到每台电脑的每一块显卡上,这就是上面两种并行方法的综合运用。


使用算法和模型的时候需要在机器之间传输信息,怎么实施信息之间的交互是需要解决的问题。一种更新参数的方法是使用Parameter Server。我们可以用另外一台机器接收各个机器传回来的模型参数,然后给这些参数取平均得到更新值,然后把更新后的参数返回给各个机器,使得每一台机器都可以保持相同的参数进行分布式计算。这里我们也可以在各个机器中计算梯度后只把梯度传递回主机,主机中乘上一个学习率即可。

在这里插入图片描述


对于模型的更新有两种常见的方法:同步更新和异步更新。

在这里插入图片描述


上面的方法需要一个主机Parameter Sever,如果主机出现错误,整个训练就会失败。因此我们可以不用主机,这也叫做分散异步随机梯度下降,机器中间点对点传输梯度来更新参数。

在这里插入图片描述


分布式优化:asych SGD,不加lock,这样会导致一个进程读取的参数被另一个进程抢先更新的情况。但Hogwild给出证明两种算法在一定情况下结果趋紧一致。asych SGD是并行系统中广泛应用的设计。

在这里插入图片描述

import torch.multiprocessing as mp
from model import MyModel

def train(model):
    # Construct data_loader,optimizer,etc.
    for data,labels in data_loader:
        optimizer.zero_frad()
        loss_fn(model(data),labels).backward()
        opeimizer.step()      # update the shared parameters
        
 if __name__ == "__main__":
    num_processes = 4
    model = MyModel()
    # NOTE: this is required for the "fork" method to work
    model.share_memory() 
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target-train,args=(model,))
        p.start()
    for p in processes:
        p.join()

MapReduce:分布式学习的开山鼻祖算法

在这里插入图片描述


DisBelief

在这里插入图片描述


AlexNet

在这里插入图片描述


分布式强化学习系统

现在我们考虑强化学习系统的哪些部分可以实行分布式设计。

在这里插入图片描述

我们可以从以上三点出发:考虑增加更多的环境,让更多的智能体同时工作,并且让多个episode同时执行。

在分布式强化学习之中,我们需要多个环境,多个智能体进行交互。因此在不同的机器中都可以有一个环境和一个智能体,得到多个轨迹,然后传回给learner,learner对参数进行更新,然后把更新后的参数传回给不同机器中的智能体。

经典算法

分布式强化学习系统的进展:

在这里插入图片描述


GORILA系统是在DQN的基础上进行分布式加速:

在这里插入图片描述


A3C

A3C相对于GORILA取缔了reply memory,每个线程都保留了自己的actor,因此每个线程的轨迹都是具有多样化的,可以直接采样进行学习。

在这里插入图片描述

部分代码:

processes = []
for rank in range(args.processes):
    p = mp.Process(target=train, args=(shared_model, shared_optimizer, rank, args, info))
    p.start()
    processes.append(p)
for p in processes:
    p.join()

A2C:基于A3C的改进

在这里插入图片描述

在这里插入图片描述


Apex-X

在这里插入图片描述


IMPALA

在这里插入图片描述


RLlib

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述


Evolution Strategies

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/tianjuewudi/article/details/120697733