详解FedAvg:联邦学习的开山之作

2024-06-08 17:44

本文主要是介绍详解FedAvg:联邦学习的开山之作,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

FedAvg:2017年 开山之作

论文地址:https://proceedings.mlr.press/v54/mcmahan17a/mcmahan17a.pdf
源码地址:https://github.com/shaoxiongji/federated-learning
针对的问题:移动设备中有大量的数据,但显然我们不能收集这些数据到云端以进行集中训练,所以引入了一种分布式的机器学习方法,即联邦学习Federal Learning。在FL中,server将全局模型下放给各client,client利用本地的数据去训练模型,并将训练后的权重上传到server,从而实现全局模型的更新。
论文贡献

  1. 提出了联邦学习这个研究方向,简单来说就是从分散的存储于各设备的数据中训练模型;
  2. 提出了FedAvg算法;
  3. 通过实验验证了FedAvg的可靠性;

总结一下就是,本文提出了FedAvg算法,这种算法融合了client上的局部随机梯度下降和server上的模型平均。作者用该算法做了不少实验,结果表明FedAvg对于unbalanced且non-iid的数据有很好的鲁棒性,并且使得在非数据中心存储的数据上进行深度网络训练所需的通信轮次减少了好几个数量级。
算法介绍

  1. 联邦随机梯度下降算法FedSGD

设定固定的学习率η,对K个client的数据计算损失梯度:
g k = ▽ F k ( w t ) g_k=\bigtriangledown F_k(w_t) gk=Fk(wt)
server将聚合每个服务器计算的梯度,以此来更新模型参数:
w t + 1 ← w t − η ∑ k = 1 K n k n g k = w t − η ▽ f ( w t ) w_{t+1}\leftarrow w_t-\eta\sum\limits_{k=1}^K\frac{n_k}{n}g_k=w_t-\eta\bigtriangledown f(w_t) wt+1wtηk=1Knnkgk=wtηf(wt)

  1. 联邦平均算法FedAvg:

在client进行局部模型的更新:
w t + 1 k ← w t − η g k w_{t+1}^k\leftarrow w_t-\eta g_k wt+1kwtηgk
server对每个client更新后的权重进行加权平均:
w t + 1 ← ∑ k = 1 K n k n w t + 1 k w_{t+1}\leftarrow \sum_{k=1}^K \frac{n_k}{n}w_{t+1}^k wt+1k=1Knnkwt+1k
注意,在这里每个client可以在本地独立地多次更新本地权重,然后将更好的权重参数发给server进行加权平均。这样做的好处是不用每更新一次就去聚合,这大大减少了通信量。
FedAvg的计算量与3个参数有关:

  • C:每轮训练选择client的比例,每一轮通信时只选择C*K个client;(K为client总数)
  • E:每个client更新本地权重时,在本地数据集上训练E轮;
  • B:client更新权重时,每次梯度下降所使用的数据量,即本地数据集的batch size;

对于一个拥有 n k n_k nk个数据样本的client,每轮通信本地参数的更新次数为:
u k = E × n k B u_k=E\times\frac{n_k}{B} uk=E×Bnk
所以我们可知,FedSGD只是FedAvg的一个特例,即当参数 E = 1 , B = ∞ E=1,B=\infty E=1B=时,FedAvg等价于FedSGD。注: B = ∞ B=\infty B=意味着batch size大小就是本地数据集大小。
下面为FedAvg的算法流程图:
FedAvg算法流程图
实验设计与实现
Q1:在训练伊始,需不需要对模型进行统一初始化?
image.png
可见,采用不同的初始化参数进行模型平均,模型性能比两个父模型都差(左图);而统一初始化后,对模型的平均可以显著减少整个训练集的loss,模型性能优于两个父模型(右图)。
该结论是实现FL的重要支持,在每一轮通信时,server有必要发布全局模型,使各client采用相同的参数在本地数据集上进行训练,可以有效减少loss。
Q2:数据集怎么设置?
原文中主要研究了MNIST数据集和一个莎士比亚作品集构建的数据集,但我们在这里主要关注MNIST数据集和Cifar-10数据集,这两个数据集也是以后FL领域工作最常用的。
在模型选择方面,作者选择了多层感知机MLP和卷积神经网络CNN。
在数据集划分方面,作者假设有100个client,对于MNIST数据集,进行了iid和non-iid两种划分:

  • MNIST-iid:数据随机打乱分给100个client,每个client得到600个样例;
  • MNIST-non-iid:按数字label将数据集划分为200个大小为300的碎片,每个client两个碎片,意味着每个client至多只能获得两种label的样例;

对于Cifar-10数据集,做了iid划分。
Q3:实验咋做的?
作者指出,相比于传统模式下训练模型时计算开销为主通信开销较小的情况,在FL中,通信开销才是大头,因此减少通信开销才是我们需要关注的,作者提出可以通过加大计算以减少训练模型所需的通信轮数。作者提出主要有两种方法:提高并行度、增加每个client的计算量
而FedAvg的计算量在前面我们也给出过,再来看一下:
u k = E × n k B u_k=E\times\frac{n_k}{B} uk=E×Bnk
提高并行度:固定参数E,对C和B进行讨论。注:此处C=0时,算法也会选择一个client参与,详见上面的算法流程图。
2NN测试集acc 97%,CNN测试集acc 99%所需的通信轮数

  • B = ∞ B=\infty B=时,增加client的比例C,效果提升的优势较小;
  • B = 10 B=10 B=10时,效果显著改善了,特别是在non-iid情况下;
  • B = 10 , C ≥ 10 B=10,C\geq10 B=10,C10时,收敛速度明显改进,当client到一定数量后,收敛速度增加也不明显了;

增加每个client的计算量:根据公式,可以通过增加E或者减小B实现。
对测试集到达期望acc所需的通信轮数

  • 每个通信轮次内增加更多的本地SGD可以显著降低通信成本;
  • 对于unbalanced-non-iid的莎士比亚数据集减少的通信轮数更多,推测可能某些client有相对较大的本地数据集,这种情况下增加了本地训练的价值;

Q4:FedAvg VS FedSGD?
image.png
蓝色实现即为FedSGD。由图可知,FedAvg相比FedSGD不仅降低通信轮数,还具有更高的测试精度。推测是平均模型产生了类似Dropout的正则化效益。
Q5:加大每个client的计算量会不会导致过拟合?
image.png
加大每个client的计算量(主要体现在加大E),确实可能导致训练损失停滞或发散。所以在实际应用时,在训练后期减少各client的E,或者在loss有震荡的苗头时即刻停止,这样做有助于收敛。
Q6:在Cifar-10数据集上的表现如何?
如下图所示:
image.png
image.png
针对第一张图的一点吐槽,你去拿分布式深度学习去pk单机上的深度学习,去比通信轮数,这不是太不公平了。。。
总结展望
作者证明了FL在实践中是可行的,能够用相对较少的通信轮数训练出高质量的模型。并且提出未来的一个方向就是通过差分隐私、安全多方技术等隐私保护技术去组合FL以提供隐私保护。

这篇关于详解FedAvg:联邦学习的开山之作的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1042867

相关文章

详解C#如何提取PDF文档中的图片

《详解C#如何提取PDF文档中的图片》提取图片可以将这些图像资源进行单独保存,方便后续在不同的项目中使用,下面我们就来看看如何使用C#通过代码从PDF文档中提取图片吧... 当 PDF 文件中包含有价值的图片,如艺术画作、设计素材、报告图表等,提取图片可以将这些图像资源进行单独保存,方便后续在不同的项目中使

Android中Dialog的使用详解

《Android中Dialog的使用详解》Dialog(对话框)是Android中常用的UI组件,用于临时显示重要信息或获取用户输入,本文给大家介绍Android中Dialog的使用,感兴趣的朋友一起... 目录android中Dialog的使用详解1. 基本Dialog类型1.1 AlertDialog(

C#数据结构之字符串(string)详解

《C#数据结构之字符串(string)详解》:本文主要介绍C#数据结构之字符串(string),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录转义字符序列字符串的创建字符串的声明null字符串与空字符串重复单字符字符串的构造字符串的属性和常用方法属性常用方法总结摘

Java中StopWatch的使用示例详解

《Java中StopWatch的使用示例详解》stopWatch是org.springframework.util包下的一个工具类,使用它可直观的输出代码执行耗时,以及执行时间百分比,这篇文章主要介绍... 目录stopWatch 是org.springframework.util 包下的一个工具类,使用它

Java进行文件格式校验的方案详解

《Java进行文件格式校验的方案详解》这篇文章主要为大家详细介绍了Java中进行文件格式校验的相关方案,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、背景异常现象原因排查用户的无心之过二、解决方案Magandroidic Number判断主流检测库对比Tika的使用区分zip

Java实现时间与字符串互相转换详解

《Java实现时间与字符串互相转换详解》这篇文章主要为大家详细介绍了Java中实现时间与字符串互相转换的相关方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、日期格式化为字符串(一)使用预定义格式(二)自定义格式二、字符串解析为日期(一)解析ISO格式字符串(二)解析自定义

springboot security快速使用示例详解

《springbootsecurity快速使用示例详解》:本文主要介绍springbootsecurity快速使用示例,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝... 目录创www.chinasem.cn建spring boot项目生成脚手架配置依赖接口示例代码项目结构启用s

Python中随机休眠技术原理与应用详解

《Python中随机休眠技术原理与应用详解》在编程中,让程序暂停执行特定时间是常见需求,当需要引入不确定性时,随机休眠就成为关键技巧,下面我们就来看看Python中随机休眠技术的具体实现与应用吧... 目录引言一、实现原理与基础方法1.1 核心函数解析1.2 基础实现模板1.3 整数版实现二、典型应用场景2

一文详解SpringBoot响应压缩功能的配置与优化

《一文详解SpringBoot响应压缩功能的配置与优化》SpringBoot的响应压缩功能基于智能协商机制,需同时满足很多条件,本文主要为大家详细介绍了SpringBoot响应压缩功能的配置与优化,需... 目录一、核心工作机制1.1 自动协商触发条件1.2 压缩处理流程二、配置方案详解2.1 基础YAML

Python实现无痛修改第三方库源码的方法详解

《Python实现无痛修改第三方库源码的方法详解》很多时候,我们下载的第三方库是不会有需求不满足的情况,但也有极少的情况,第三方库没有兼顾到需求,本文将介绍几个修改源码的操作,大家可以根据需求进行选择... 目录需求不符合模拟示例 1. 修改源文件2. 继承修改3. 猴子补丁4. 追踪局部变量需求不符合很