联邦学习FedAvg-基于去中心化数据的深度网络高效通信学习

        随着计算机算力的提升,机器学习作为海量数据的分析处理技术,已经广泛服务于人类社会。 然而,机器学习技术的发展过程中面临两大挑战:一是数据安全难以得到保障,隐私泄露问题亟待解决;二是网络安全隔离和行业隐私,不同行业部门之间存在数据壁垒,导致数据形成“孤岛”无法安全共享,而仅凭各部门独立数据训练的机器学习模型性能无法达到全局最优化。为解决上述问题,谷歌提出了联邦学习(FL,federated learning)技术。

        本文主要对联邦学习的开山之作《Communication-Efficient Learning of Deep Networks from Decentralized Data》 进行重点内容的解读与整理总结。

论文链接:Communication-Efficient Learning of Deep Networks from Decentralized Data

源码实现:https://gitcode.net/mirrors/WHDY/fedavg?utm_source=csdn_github_accelerator 

目录

摘要

1. 介绍

1.1 问题来源

1.2 本文贡献

1.3 联邦学习特性

1.4 联邦优化

1.5 相关工作

1.6 联邦学习框架图

2. 算法介绍

2.1 联邦随机梯度下降(FedSGD)

2.2 联邦平均算法(FedAvg)

3. 实验设计与实现

3.1 模型初始化

3.2 数据集的设置

3.2.1 MNIST数据集

3.2.2 莎士比亚作品集

3.3 实验优化

3.3.1 增加并行性

3.3.2 增加客户端计算量

 3.4 探究客户端数据集的过度优化

3.5 CIFAR实验

3.6 大规模LSTM实验

4. 总结展望

 摘要

现代移动设备拥有大量的适合模型学习的数据,基于这些数据训练得到的模型可以极大地提升用户体验。例如,语言模型能提升语音设别的准确率和文本输入的效率,图像模型能自动筛选好的照片。然而,移动设备拥有的丰富的数据经常具有关于用户的敏感的隐私信息且多个移动设备所存储的数据总量很大,这样一来,不适合将各个移动设备的数据上传到数据中心,然后使用传统的方法进行模型训练。作者提出了一个替代方法,这种方法可以基于分布在各个设备上的数据(无需上传到数据中心),然后通过局部计算的更新值进行聚合来学习到一个共享模型。作者定义这种非中心化方法为“联邦学习”。作者针对深度网络的联邦学习任务提出了一种实用方法,这种方法在学习过程中多次对模型进行平均。同时,作者使用了五种不同的模型和四个数据集对这种方法进行了实验验证。实验结果表明,这种方法面对不平衡以及非独立同分布的数据,具有较好的鲁棒性。在这种方法中,通信所产生的资源开销是主要的瓶颈,实验结果表明,与同步随机梯度下降相比,该方法的通信轮次减少了10-100倍。

1 介绍

1.1 问题来源

        移动设备中有大量数据适合机器学习任务,利用这些数据反过来可以改善用户体验。例如图像识别模型可以帮助用户挑选好的照片。但是这些数据具有高度私密性,并且数据量大,所以我们不可能把这些数据拿到云端服务器进行集中训练。论文提出了一种分布式机器学习方法称为联邦学习(Federal Learning),在该框架中,服务器将全局模型下发给客户,客户端利用本地数据集进行训练,并将训练后的权重上传到服务器,从而实现全局模型的更新。

1.2 本文贡献

  • 提出了从分散的存储于各个移动设备的数据中训练模型是一个重要的研究方向
  • 提出了一个简单实用的算法来解决这种在非中心化设置下的学习问题
  • 做了大量实验来评估所提算法

        具体来说,本文介绍了“联邦平均”算法,这种算法融合了客户端上的局部随机梯度下降计算与服务器上的模型平均。作者使用该算法进行了大量实验,结果表明了这种算法对于不平衡且非独立同分布的数据具有很好的鲁棒性,并且使得在非中心存储的数据上进行深度网络训练所需的通信轮次减少了几个数量级。

1.3 联邦学习特性

  • 从多个移动设备中存储的真实数据中进行模型训练比从存储在数据中心的数据中进行模型训练更具优势
  • 由于数据具有隐私,且多个移动设备所存储的数据总量很大,因此不适合将其上传至数据中心再进行模型训练
  • 对于监督学习任务,数据中的标签信息可以从用户与应用程序的交互中推断出来

1.4 联邦优化

        传统分布式学习关注点在于如何将一个大型神经网络训练分布式进行,数据仍然可能是在几个大的训练中心存储。而联邦学习更关注数据本身,利用联邦学习保证了数据不出本地,并根据数据的特点,对学习模型进行改进。相比于典型的分布式优化问题,联邦优化具有几个关键特性:

  • Non-IID:数据的特征和分布在不同参与方间存在差异
  • Unbalanced:一些用户会更多地使用服务或应用程序,导致本地训练数据量存在差
  • Massively distributed:参与优化的用户数>>平均每个用户的数据量
  • Limited communication:无法保证客户端和服务器端的高效通信

 本文重点关注优化任务中非独立同分布和不平衡问题,以及通信受限的临界属性。

注:独立同分布假设(IID)

        非凸神经网络的目标函数:

对于一个机器学习的问题来说,有,即用模型参数w预测实例的损失。

        设有K个client,第k个client的数据点为P_{k},对应的数据集数量为n_{k}=\left | P_{k} \right |上式可写为:

P_{k}上的数据集是随机均匀采样的,称IID设置,此时有:

不成立则称Non-IID。 

1.5 相关工作

        相关工作中,2010年通过迭代平均本地训练的模型来对感知机进行分布式训练,2015年研究了语音识别深度神经网络的分布式训练,在2015论文里研究了使用“软”平均的异步训练方法。这些工作都考虑的是数据中心化背景下的分布式训练,没有考虑具有数据不平衡且非独立同分布特点的联邦学习任务。但是它们提供了一种思路,即通过迭代平均本地训练模型的算法来解决联邦学习的问题。与本文的研究动机相似在这篇论文中讨论了保护设备中的用户数据的隐私的优点。而在这篇论文中,作者关注于训练深度网络,强调隐私的重要性以及通过在每一轮通信中仅共享一部分参数,进而降低通信开销;但是,他们也没有考虑数据的不平衡以及非独立同分布性,并且他们的研究工作缺乏实验评估。

1.6 联邦学习框架图

2 算法介绍

2.1 联邦随机梯度下降(FedSGD)

设置固定的学习率η,对K个客户端的数据计算其损失梯度:

中心服务器聚合每个客户端计算的梯度,以此来更新模型参数:

其中,

2.2 联邦平均算法(FedAvg)

在客户端进行局部模型的更新:

中心服务器对每个客户端更新后的参数进行加权平均:

每个客户端可以独立地更新模型参数多次,然后再将更新好的参数发送给中心服务器进行加权平均:

FedAvg的计算量与三个参数有关:

  • C:每轮训练选择客户端的比例
  • E:每个客户端更新参数的循环次数所设计的一个因子
  • B:客户端更新参数时,每次梯度下降所使用的数据量

对于一个拥有n_{k}个数据样本的客户端,每轮本地参数更新的次数为:

注:FedSGD只是FedAvg的一个特例,即当参数E=1,B=∞时,FedAvg等价于FedSGD。
 
FedSGD和FedAvg的关系示意图:

3 实验设计与实现

3.1 模型初始化

实验设置
  • 数据集:MNIST中600个无重复的独立同分布样本
  • E=20; C=1; B=50; 中心服务器聚合一次
  • 不同模型使用不同/相同的初始化模型,并通过θ对两模型参数进行加权求和
       

研究模型平均对模型效果的影响:

        这里有两种情况,一种是不同模型使用不同的初始化模型;一种是不同模型使用相同的初始化模型。并且可以通过参数控制权重比进行模型的加权求和。

        可看到,采用不同的初始化参数进行模型平均后,平均模型的效果变差,模型性能比两个父模型都差;采用相同的初始化参数进行模型平均后,对模型的平均可以显著的减少整个训练集的损失,模型性能优于两个父模型。

        该结论是用于实现联邦学习的重要支撑,在每一轮训练时,server发布全局模型,使各个client采用相同的参数模型进行训练,可以有效的减少训练集的损失。

3.2 数据集的设置

        初步研究包括两个数据集三个模型族,前两个模型用于识别MNIST数据集,后一个用于实现莎士比亚作品集单词预测。

3.2.1 MNIST数据集

2NN:拥有两个隐藏层,每层200个神经元的多层感知机模型,ReLu激活;

CNN:两个卷积核大小为5X5的卷积层(分别是32通道和64通道,每层后都有一个2X2的最大池化层);

IDD:数据随机打乱分给100个客户端,每个客户端600个样例;

Non-IDD:按数字标签将数据集划分为200个大小为300的碎片,每个客户端两个碎片;

  • 3.2.2 莎士比亚作品集

LSTM:将输入字符嵌入到一个已学习的8维空间中,然后通过两个LSTM层处理嵌入的字符,每层256个节点,最后,第二个LSTM层的输出被发送到每一个字符有一个节点的softmax输出层,使用unroll的80个字符长度进行训练;

Unbalanced-Non-IID:每个角色形成一个客户端,共1146个客户端;

Balanced-IID:直接将数据集划分给1146个客户端;

3.3 实验优化

        在数据中心存储的优化中,通信开销相对较小,计算开销占主导地位。而在联邦优化中,任何一个单一设备所具有的数据量较少,且现代移动设备有相对快的处理器所以这里更关注通信开销因此,我们想要使用额外的计算来减少训练模型所需通信的轮次主要有两个方法,分别是提高并行度以及增加每个客户端的计算量。

3.3.1 增加并行性

固定参数E,对C和B进行讨论。

  •  当B=∞时,增加客户端比例,效果提升的优势较小;
  • 当B=10时,有显著改善,特别是在Non-IID情况下;
  • 在B=10,当C≥0.1时,收敛速度有明显改进,当用户达到一定数量时,收敛增加的速度不再明显。

3.3.2 增加客户端计算量

对于增加每个客户端的计算量,可以通过减小B或者增加E来实现。

  • 每轮增加更多的本地SGD更新可以显著降低通信成本;
  • 对于Unbalanced-Non-IDD的莎士比亚数据减少通信轮数倍数更多,推测可能某些客户端有相对较大的本地数据集,使得增加本地训练更有价值;

 将上述实验结果用折线图的形式展示,这里蓝色线表示的是联邦随机梯度下降的结果:

  • FedAvg相比FedSGD不仅降低通信轮数,还具有更高的测试精度。推测是平均模型产生了类似Dropout的正则化效益; 

 3.4 探究客户端数据集的过度优化

        在E=5以及E=25的设置下,对于大的本地更新次数而言,联邦平均的训练损失会停滞或发散;因此在实际应用时,对于一些模型,在训练后期减少本地训练周期将有助于收敛。 

3.5 CIFAR实验

在CTFAR数据集上进行实验,模型是TensorFlow教程中的模型包括两个卷积层,两个全连接层和一个线性传输层,大约10^6个参数。下表给出了baselineSGD、FedSGD和FedAvg达到三种不同精度目标的通信轮数。

不同学习率下FedSGD和FedAvg的曲线:

3.6 大规模LSTM实验

 为了证明我们的方法对于解决实际问题的有效性,我们进行了一项大规模单词预测任务。

训练集包含来自大型社交网络的100万个公共帖子。我们根据作者对帖子进行分组,总共有超过50个客户端。我们将每个客户的数据集限制为最多5000个单词。模型是一个256节点的LSTM,其词汇量为10000个单词。每个单词的输入和输出嵌入为192维,并与模型共同训练;总共有4950544个参数,使用10个字符的unroll。

对于联邦平均和联邦随机梯度下降的最佳学习率曲线:

  • 相同准确率的情况下,FedAvg的通信轮数更少;测试精度方差更小;
  • E=1比E=5的表现效果更好; 

4 总结展望

         我们的实验表明,联邦学习可以在实践中实现,因为它可以使用相对较少的几轮通信来训练高质量的模型,这一点在各种模型体系结构上得到了证明:一个多层感知器、两个不同的卷积NNs、一个两层LSTM和一个大规模LSTM。虽然联邦学习提供了许多实用的隐私保护,但是通过差分隐私、安全多方计算提供了可以提供更有力的保障,或者他们的组合是未来工作的一个有趣方向。请注意,这两类技术最自然地应用于像FedAvg这样的同步算法。

参考文章:

https://blog.csdn.net/qq_41605740/article/details/124584939?spm=1001.2014.3001.5506

https://blog.csdn.net/weixin_45662974/article/details/119464191?spm=1001.2014.3001.5506 

https://zhuanlan.zhihu.com/p/515756280 

猜你喜欢

转载自blog.csdn.net/SmartLab307/article/details/132583309