使用模型平均下的深度网络的联邦学习

本文出自Google的Federated Learning of Deep Networks using Model Averaging,主要介绍使用模型平均方法的联邦式学习。

引言

丰富的数据通常是隐私敏感性的,数量大或者两者都有,这将会妨碍使用传统方法登录到数据中心并在此进行训练。我们提倡将训练数据分布在移动设备上,并通过聚合本地计算的更新来学习一个共享模型。这种方法被称为联邦学习。我们提出了一种实用的深度网络联邦学习方法,证明该方法对自然生成的不平衡和无IID(独立且恒等分布)数据分布具有鲁棒性。这种方法运行在相对较少的通信环境(联邦学习的主要约束)中训练高质量的模型。关键点是:尽管我们优化了非凸性损失函数,对来自多个客户机的更新进行参数平均会产生较好的结果。

一、简介

  1. 存在许多用于分布式优化的算法,但这些算法通常具有通信需求且只被一个数据中心网络结构所满足。这些算法的理论合理性和实际性能在很大程度上取决于假定数据在计算节点上是独立且恒等分布的。综上所述,这些需求相当于假设完整的训练集由建模器控制并存储在一个集中的位置。
  2. 我们研究了一种学习技术,它允许用户在不需要集中存储数据的情况下,从这些丰富的数据中去获得共享模型的好处。这种方法还允许我们利用网络边缘可用的廉价计算来扩展学习任务。我们将这种方法称为联邦学习,因为学习任务是由一个中央服务器协调的松散的参与设备联合(客户机)来解决的。每个客户端都有一个本地的训练数据集且从未上传到服务器。相反,每个客户端计算对由中央服务器维护的当前全局模型的更新,并且只有这个更新可以被交流。由于这些更新是特定于改进当前的模型,所以一旦它们被应用了就没有理由去存储它们。
  3. 我们引入了联邦平均算法,它将每个客户机上本地SGD训练与中央服务器执行模型平均的通信循环结合起来。实验证明:它对于不平衡和无IID的分布式数据具有鲁棒性,并且可以将训练深度网络所需要的通信周期减少一到两个数量级。

二、联邦式学习

  1. 联邦学习适合任务的特点:(1)对来自移动设备的真实数据的训练比对数据中心通常可用的代理数据的训练具有明显优势; (2)这种数据是隐私敏感性的或者规模较大的,因此它不适合将其记录到数据中心来进行模型训练;(3)对于监督型任务,数据集上的标签可以从用户与他们设备的交互过程中自然推理出来。
  2. 应用实例:(1)图像分类,例如预测哪些照片最有可能在未来被查看或分享多次;(2)语言模型,可以通过提高解码、下个单词预测甚至预测整个回复来提高触摸屏键盘上的语音识别和文本输入。
  3. 这些训练样本的分布与容易获取的代理数据集有很大的不同:聊天和短信中使用的语言与标准语料库有着很大的不同,人们在手机上拍摄的照片很可能与网络上的普通照片有所不同。另外,这些问题的标签集都是直接可用的:输入的文本是学习语言模型的自标签,照片标签是可以通过用户与照片应用程序的交互过程来定义(哪些照片被删除、分享或查看)。
  4. 联邦学习的隐私性:(1)我们必须考虑到一个攻击者可能通过检查模型参数学习到什么内容,这些参数被优化过程中参与到的客户机所共享。对于真正隐私敏感性的任务,差异化隐私技术可以提供严格的最坏情况下的隐私保证,即使对方有任意的边缘信息。(2)下一个问题是对方可以通过访问到单个客户的更新信息来学习到什么内容。如果一方信任中心服务器,那么加密和其他标准化安全协议对于这种类型的攻击是一个基础的防护。一个强有力的保证可以通过强制执行本地差异化隐私来实现,在这种情况下,我们不是向最终的模型添加噪音干扰,而是干扰每个更新,从而阻止中央服务器对客户机进行任何确定的推断。也可以使用安全多方计算来对多个客户机更新执行聚合,允许使用更少的随机噪声来实现本地的差异化隐私。
  5. 大型数据的优势:在数据中心训练每个客户机所需的网络流量只是每个客户机本地数据集的大小,必须要被传输一次。对于联邦学习,每个客户机的流量是每轮的通信量乘以更新的规模。如果更新规模相对于所需的训练数据数量要小,则后一种数量要小得多。

三、联邦式优化

  1. 联邦优化与典型的分布式优化问题有几个关键的区别:(1)Non-IID:给定客户端的训练数据集通常为基于特定用户对移动设备的使用情况,因此,任何特定用户的本地数据集都不能代表总体分布。(2)不平衡性:一些用户会频繁地使用产生训练集的服务或应用程序,导致一些客户端有着大量的本地训练集,而另外一些用户则只有很少或没有数据。(3)大规模分布式:在实际场景中,我们期望参与优化的客户机数量要远大于每个客户机的平均样例数量。
  2. 我们假设同步更新方案在几轮通信过程中进行,这里有一组固定的K个客户端,每个客户端都有一个固定的本地数据集。在每轮的开始,客户端的一个随机分数C被选择好,然后服务器将当前全局算法状态(当前的模型参数)发送到每个客户端。每个客户端执行基于全局状态和本地数据集的本地计算,然后发送更新到服务器。服务器应用这些更新到它的全局状态,并重复此过程。
  3. 对于一个机器学习问题,我们通常令: f i ( w ) = ϱ ( x i , y i ; w ) f_i(w)=\varrho(x_i,y_i;w) ,这是模型参数w对示例 ( x i , y i ) (x_i,y_i) 的预测损失。我们假定这里有K个客户端用于分享数据,用 P k P_k 表示客户端k上数据点的索引值,令 n k = P k n_k=|P_k| 。于是我们可以使:
               f ( w ) = k = 1 K n k n F k ( w ) w h e r e F k ( w ) = 1 n k i ϵ P k f i ( w ) . f(w)=\sum_{k=1}^{K}\frac{n_k}{n}F_k(w) \quad where \quad F_k(w)=\frac{1}{n_k}\sum_{i\epsilon P_k}f_i(w).
  4. 在联邦优化中,通信代价占据主导地位。此外,我们希望每个客户端每天只参与一个小数量的更新循环。另一方面,因为任何单个设备上的数据集都小于总数据集的大小,且现代智能手机有着相对较快的处理器(包括GPU),与许多型号的通信成本相比,计算基本是免费的。因此,我们的目标是使用额外的计算,目的是减少训练一个模型所需要的通信循环次数。这里有两种我们添加计算的方式:(1)提高并行性:使用更多的客户端在每个通信循环中独立工作;(2)增加每个客户端的计算量。
  5. 我们考虑的(参数化)算法集的一个端点是简单的一次平均,每个客户端都为模型求解来使本地数据损失函数最小化(可能为正则化),然后这些模型被平均得到最终的全局模型。这种方法在独立且恒等分布形式的数据凸情况下进行了广泛研究,在最坏的情况下,生成的全局模型并不比在单个客户端上训练的模型要好。

四、联邦平均算法

  1. SGD可以很自然地应用于联邦优化问题,每轮通信只执行一个小批量梯度计算(比如在随机选择的客户端上)。这种方法计算效率高,但是需要大量的训练才能产生好的模型。这种方式的计算量由三个关键参数控制:C(每轮执行计算的客户端设备的比例);E(每个客户机在每轮上通过其本地数据集执行的训练次数);B(用于客户机更新的批处理大小),其中将 B设为无穷代表将整个本地数据集看作单个批处理量。

  2. 我们令B=无穷,E=1来生成一种可变小批量规模的SGD形式,这个算法每轮选择C的客户端比例,并计算这些客户端所拥有的所有数据的损失梯度,因此C=1相当于全批次(非随机)的梯度下降。我们将这种算法称为联邦式SGD,而批次选择机制不同于通过均匀地随机选择单个样例来选择批次,其批梯度g仍满足于 E [ g ] = f ( w ) E[g]=\triangledown f(w)

  3. 带有一个固定的学习率 η \eta 的分布式梯度下降的典型实现有着每个客户端k计算 g ( k ) = F k ( w t ) g(k)=\triangledown F_k(w_t) ,即当前模型 w t w_t 的本地数据集上平均梯度,然后中央服务器将这些梯度集合起来并应用于更新中:
                   w t + 1 w t η k = 1 K n k n g k w_{t+1} \leftarrow w_t-\eta \sum_{k=1}K\frac{n_k}{n}g_k
    因为 η k = 1 K n k n g k = f ( w t ) \eta \sum_{k=1}K\frac{n_k}{n}g_k=\triangledown f(w_t) ,所以上述表达式也可以表示为:
               k , w t + 1 k w t η g k a n d w t + 1 k = 1 K n k n w t + 1 k . \forall k,w^k_{t+1}\leftarrow w_t-\eta g_k \quad and \quad w_{t+1} \leftarrow \sum_{k=1}{K}\frac{n_k}{n}w_{t+1}^k.

  4. 每个客户端都使用本地数据在当前模型上进行一步梯度下降,然后服务器对得到的模型进行加权平均。对于一个带有 n k n_k 本地样例的客户端,每轮本地更新的数量由 u k = E n k B u_k=E\frac{n_k}{B} 来给出。

  5. 最近的工作表明在实践中,充分参数化的神经网络的损失曲面表现得很好,特别是不像以前所认为的那么容易出现糟糕的局部极小值。当我们从相同的随机初始化开始两个模型,然后再一次独立训练每一个不同的子数据集,我们发现普通的参数平均工作得很好:两个模型的平均值,在完整MNIST训练集上获得的损失函数值明显要低于单独在两个小数据集上进行训练所获得的最好模型。

  6. 实验算法如下所示:
    算法代码

五、实验结果

  1. 对于每个任务,我们选择一个适度大小的代理数据集,这样我们就可以彻底地研究FedAvg算法的超参数。虽然每次单独训练运行相对要小,但我们为这些实验训练了2000多个单独的模型。
  2. 第一个任务是MNIST数字识别任务,它有两种模型构建方式:(1)一个简单双隐层模型,每层有200个单元使用ReLu激活,我们将其定义为MNIST 2NN;(2)一个CNN有着两个5x5卷积层(第一个有着32个通道,第二个有着64个通道,每层后跟着2x2的最大池化),一个全连接层有着512个单元和ReLu激活,和一个最终的softmax输出层。
  3. 为了研究联邦优化,我们也需要明确数据如何在客户端上分布。我们研究了两种分割MNIST数据的方式:IID(数据被打乱,然后被分到100个客户机,每个客户机接收到600个样本);Non-IID(根据数字标签将数据排序,将其分成大小为300的200个碎片化数据,然后指定100个客户机中每个有2个碎片)。因此,让我们探索一下我们的算法对盖度non-IID数据的破坏程度。
  4. 第二个任务是语言模型,为了研究联邦优化,我们建立了一个数据集,它来自于莎士比亚全集。这个数据集是明显不平衡的,一些角色只有几句台词,而一些则有着大量台词。使用相同的训练/测试分割,我们也可以形成一个平衡的IID版本的数据集。
  5. 在这个数据集上我们训练了一个堆叠的字符级别的LSTM语言模型,在读取一行中的每个字符后,可以预测下一个字符。该模型以一系列字符作为输入,并将每个字符嵌入到一个学习的8维空间中。嵌入的字符通过2LSTM层进行处理,每个LSTM层有着256个节点。最终第二个LSTM层的输出被发送到每个字符有一个节点的softmax输出层。
  6. 提高并行性:我们首先设置C来测试实验效果,它控制了多客户端的并行度。为了计算达到目标测试准确率所需要的通信轮数,我们为每个参数设置的组合构造了一个学习曲线,使曲线可以单调性改进,然后计算曲线达到目标值的通信轮数。基于实验中所展示的结果,在剩下的实验中,我们固定C=0.1,在计算效率和收敛速度中取得了很好的平衡,对于固定B增加C并没有一个很好的效果,而将B=无穷和B=10的轮数相比起来,可以发现有一个明显的加速。下表展示了对于MNIST模型C值变化时的影响,每个表实体给出了2NN达到测试精度97%和CNN达到测试精度99%时所需要的通信轮数。下图则展示了对于MNIST模型测试精度与通信轮数的变化曲线图,我们将C设置为0.1。
    提高并行性
    提高并行性
  7. 增加每个客户机的计算量:我们将C固定为0.1,每轮为每个客户机增加更多的计算,减少B,增加E或两者都增加。每个客户机每轮更新的预期数量由 给出,只要B足够大,就能够充分利用客户端硬件上可用的并行性,降低它的计算时间基本上没有成本,这是第一个调优的参数。当我们将在完全不同的数字对上训练的模型参数平均时,平均提供了任何优势。因此,我们认为这为这个方案的鲁棒性提供了有力证据。下表展示了在MNIST模型上达到目标精度所需要的通信轮数的加速,下图展示了Shakespeare LSTM模型的学习曲线。可以看到每轮添加更多的本地SGD更新然后再进行模型平均可以产生一个很好的加速效果。
    增加计算量
    增加计算量
  8. 我们推测,除了降低通信成本外,模型平均策略还产生了与dropout类型的正则化好处。我们主要关注泛化性能,但是FedAvg也能有效地优化训练损失,甚至超过了测试集精度停滞不前的程度。下图展示了FedAvg对于优化训练损失的变化影响关系。
    优化训练损失
  9. 当前模型参数仅通过初始化影响每个客户端更新中执行的优化。当E趋近于无穷时,至少对于一个凸问题最终的初始条件应当是不相关的,无论初始化如何都会达到全局最小值。而对于一个非凸问题,只要初始化在同一个区域里,算法就会收敛到相同的局部最小值。实验结果表明:对于一些模型特别是在收敛阶段的后期,每轮本地计算量的衰减(减小E或者增大B)可能是有用的,同样衰减学习率也是有用的。对于E值较大的情况,收敛速度下降的幅度并不大。下图上展示了在 Shakespeare LSTM问题上初始化训练中E值的影响,下则展示了MNIST CNN的实验影响。
    训练次数变化

六、结论和未来工作

我们的实验表明,联邦学习具有重要的前景,因为高质量的模型可以使用相对较少的通信轮数进行训练。下一步的一个重要步骤是,在更大的数据集上对所提出的方法进行进一步的经验评估,这些数据真正地捕捉到了现实世界问题的大规模分布式本质。为了保持算法探索范围是可控的,我们限制自己以朴素SGD为基础。也可以研究我们的方法与其他优化算法,如AdaGrad和ADAM,以及模型结构的变化可以帮助优化,如dropout和批量规范化,这些都是未来工作的一个研究方向。

发布了31 篇原创文章 · 获赞 40 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/librahfacebook/article/details/90245669