本文主要是介绍论文笔记《Robust Federated Learning with Noisy Labels》,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
读论文:Robust Federated Learning with Noisy Labels
- 应用背景(问题与挑战)
- 相关工作
- Federated learning
- Learning on noisy data
- 解决方案的局限性(motivation)
- 方案介绍
- Problem definition and notations
- Local updates
- local clean set
- naive average features
- local class-wise centroids
- 类特征损失函数
- Global-guided pseudo-labeling
- 总损失值
- Global updates
- Weight aggregation
- Global centroid aggregation
- Experiments
- 总结
论文信息:
作者:Seunghan Yang, Hyoungseob Park, Junyoung Byun, Changick Kim
单位:KAIST, South Korea
期刊:arXiv:2012.01700v1 [cs.LG] 3 Dec 2020
引用信息:Yang S, Park H, Byun J, et al. Robust Federated Learning with Noisy Labels[J]. arXiv preprint arXiv:2012.01700, 2020.
论文链接:https://arxiv.org/abs/2012.01700
应用背景(问题与挑战)
为解决传统数据集中式学习容易造成数据隐私泄露的问题,联邦学习技术应运而出。联邦学习(FL)允许每个参与方本地进行模型训练,然后协作构建出一个共同的深度学习模型,而不需要参与方提交训练数据给服务器。通常来说,联邦学习首先在每个交互轮次中,由服务器广播统一的初始模型,然后服务器选择部分client用他们本地数据参与训练,这些被选中的client将训练后的模型提交给服务器,服务器以一定的聚合算法对这些模型进行聚合。然后重复迭代上述过程,直至全局模型收敛。最早提出的FL算法如FedAvg,利用本地数据集数量作为聚合的权值。这之后,不少优秀的研究将FL应用到实际需求中。如处理非独立同分布的工作[文献]、噪声通信聚合研究[文献]、 domain adaptation(猜测是设备选择算法)[文献]、公平资源分配[文献]和 continual learning[文献]。
然而,这些研究通常没有考虑到实际应用中存在噪声标签的问题。(在原文作者看来)出于隐私保护的考量,联邦学习系统中用户数据的标签应该由机器生成标签。而这些由机器进行标注的标签通常是会存在问题的。在集中式机器学习中,已有各种算法提出来解决该问题。如取样可靠数据最小化噪声标签的影响[文献]、更新标签[文献]、根据匹配的原型估计标签[文献]等。
但这些工作在联邦学习环境中通常会造成性能下降的问题。由于噪声数据分布差异,局部模型(local model)会形成不一致的决策边界,他们的模型参数(weight)方向差异大,会造成聚合困难,如图1 所示。
相关工作
Federated learning
联邦学习有如下的限制:
- statistical challenges (non-i.i.d. data);
- lower network bandwidth;
- inconsistent accuracy across devices;
- noisy communication.
为降低这些限制对FL的影响,FedProx,FedMA, research about the convergence of FedAvg专注于non-iid数据中模型的收敛性。DGC,signSGD,STC利用梯度压缩技术专注于减小系统的带宽限制。对于联邦学习的公平性和通信噪音干扰也分别有研究。但这些研究通常假定本地数据是干净的,这在实际情况中并不总是存在,因为本地数据是由客户端收集并进行标注的。
Learning on noisy data
集中式学习有着许多关于噪声标签的研究。如基于噪声清洗的方法,co-teaching,标签校正方法,通过网络表示能力来纠正错误标签。亦或是在自学习中基于比较样本的特征和其他类别特征来确定样本标签,或是在元学习中优化不易过拟合的参数、采用噪声容忍模型等。但这些方法都是基于集中式学习的,在噪声标签联邦学习环境下,不同参与方有不同的噪声分布,这会导致本地模型的决策边界不一致,甚至有可能局部模型的权重严重发散,全局模型无法收敛。
解决方案的局限性(motivation)
1、中心式学习会造成隐私泄露问题
2、由于不同client噪声标签分布不一致,传统FL算法效率降低,甚至无法收敛。如图2A所示
3、集中式的噪声标签鲁棒算法在FL环境中效率降低
本文提出的解决方法:
1、FL架构
2、设置类特征中心辅助判断模型边界,如图2B所示
方案介绍
Problem definition and notations
符号 | 符号说明 | 备注 |
---|---|---|
centroids | (特征)中心? | global centroids and local centroids |
D k = { ( x k i , y k i ) } i = 1 n k {D}_{k}=\left \{ ({{x}_{k}^i},{{y}_{k}^i}) \right\}_{i=1}^{n_k} Dk={(xki,yki)}i=1nk | 第k个客户端的本地训练数据,包括图片和对应的标签 | – |
f G c {f}_{G}^{c} fGc | 服务器下发的第c类别中心 | local class-wise centroids |
f k c {f}_{k}^{c} fkc | 客户端上的第c类别中心(具体如下所示) | – |
local centroids | 本地(局部)中心是在本地数据集上,使用来自 | Local centroids are the average feature vectors from the global average pooling layer in each local dataset |
y k i {y}_{k}^{i} yki | 由softmax层提取到的真实标签对应的one-hot特征向量,差不多就是标签的意思 | one-hot vector of the ground truth label and |
a pseudo-label extracted by the softmax layer | ||
y ^ k i {\hat y}_{k}^{i} y^ki | 由softmax层提取到的,经由服务器下发的分类器对本地数据样本 i i i进行判别得到的伪标签 | one-hot向量介绍 |
F k {F}_{k} Fk | 第 k k k个客户端的特征提取器 | |
C k {C}_{k} Ck | 第 k k k个客户端的分类器 |
Local updates
在每轮本地训练开始前,被选中参加FL的客户端接收服务器发送的全局模型参数和全局类中心,以如下损失函数进行训练:
L c k = m k l c e ( C k ( F k ( x k ) ) , y k ) + ( 1 − m k ) l c e ( C k ( F k ( x k ) ) , y ^ k ) {L}_{c}^k = {m}_{k}{l}_{ce}({C}_{k}({F}_{k}({x}_{k})),{y}_{k})+(1-{m}_{k}){l}_{ce}({C}_{k}({F}_{k}({x}_{k})),{\hat y}_{k}) Lck=mklce(Ck(Fk(xk)),yk)+(1−mk)lce(Ck(Fk(xk)),y^k)
符号 | 符号说明 | 备注 |
---|---|---|
m k {m}_{k} mk | {0,1}掩码,用于控制当前计算损失值的样本是否为可信样本(The complementary use of feature similarity-based and ground truth labels can help to find accurate confident samples.) | m k i = 1 ( y ~ k i = y k i ) m_{k}^{i}=\mathbb{1}\left(\tilde{y}_{k}^{i}=y_{k}^{i}\right) mki=1(y~ki=yki) |
y ~ k i \tilde{y}_{k}^{i} y~ki | the feature similarity-based labels 。对于第 i i i个样本(假设属于y类),计算本地特征提取器提取改样本的特征与y类特征中心(local class-wise centroid,下文会介绍)的相似度。也就是说,这个标签记录的是该样本的特征最接近某一类别 | y ~ k i = argmax y sim ( f k y , F k ( x k i ) ) \tilde{y}_{k}^{i}=\operatorname{argmax}_{y} \operatorname{sim}\left(\mathbf{f}_{k}^{y}, F_{k}\left(x_{k}^{i}\right)\right) y~ki=argmaxysim(fky,Fk(xki)) |
y ^ k \hat{y}_{k} y^k | 伪类标签,一种标签校正方法的结果,是由服务器下发的全局 F G {F}_{G} FG和 C G {C}_{G} CG计算得到的样本标签 | y ^ k = C G ( F G ( x k ) ) \hat{y}_{k}=C_{G}\left(F_{G}\left(\mathbf{x}_{\mathbf{k}}\right)\right) y^k=CG(FG(xk)) |
l c e ( ⋅ ) l_{ce}(\cdot) lce(⋅) | 交叉熵损失函数 | 计算本地参与训练样本的损失值,结果为一批数的平均值或者交叉熵值的向量 |
那么,这个损失函数计算的是样本标签修正过的损失值。当 m k = 1 {m}_{k}=1 mk=1时, L c k {L}_{c}^k Lck基于原本标签计算;当 m k = 0 {m}_{k}=0 mk=0时, L c k {L}_{c}^k Lck基于修正过的标签计算。这样做的优势在于样本标签更接近真实值,得到的损失值也能更有效的对反向传播进行反馈。并且增加的通信开销较模型参数来说是十分小的(如果模型参数不采用梯度压缩方式传播),因为多传输的是类别的特征中心。
local clean set
考虑到如果直接用所有的local samples来计算local centroids,那么存在noisy label的samples会引入负面效果,因此这里采用基于loss的local centroid生成方式(loss小则说明是clean label,依旧是非常常见的假设)。,这里参照博主AgentDS的文章,里面讲的很详细,这里简单复述一下。
co-teaching理论在我看来是源于深度学习中的记忆模式(A Closer Look at Memorization in Deep Networks)。深度学习首先会学习简单模式(simple patterns),然后才对模型的噪声数据进行拟合。应用该理论,co-teaching中在前期的保留更多的训练样本数量,然后随着训练进行,逐渐以减小 R ( t ) R(t) R(t)比例保留损失值更小的这些比例的样本,从理论上达到在网络进行记忆噪声数据之前丢弃这些噪声样本。
本文通过逐渐减小的 R ( t ) R(t) R(t)比例选择损失值较小的样本参与训练( we refine the dataset D k D_k Dk by selecting R ( t ) R(t) R(t) percentage of small-loss instances on each client as follows):
D ^ k = argmin D k ′ : ∣ D k ′ ∣ ≥ R ( t ) ∣ D k ∣ l c e ( D k ′ ) \hat{D}_{k}=\operatorname{argmin}_{D_{k}^{\prime}:\left|D_{k}^{\prime}\right| \geq R(t)\left|D_{k}\right|} l_{c e}\left(D_{k}^{\prime}\right) D^k=argminDk′:∣Dk′∣≥R(t)∣Dk∣lce(Dk′)
符号 | 符号说明 | 备注 |
---|---|---|
D ^ k \hat{D}_{k} D^k | 近似认为由损失值较小的样本集合构成,比较干净的数据集 | |
D k ′ D_{k}^{\prime} Dk′ | 经过选择的数据集 | |
R ( t ) R(t) R(t) | 控制每轮应该选择多少比例的小损失值样本数量, R ( t ) = 1 − min { t T τ , τ } R(t)=1-\min \left\{\frac{t}{T} \tau, \tau\right\} R(t)=1−min{Ttτ,τ} | We set T and 小 T \mathcal{T} T to 10 and ϵ \epsilon ϵ in our experiments。公式来源于co-teaching,这里的含义为随着t的增加,R(t)在减小 |
1.1 \text {1.1} 1.1 | 集合里的样本数量 | 绝对值符号(狗头) |
naive average features
利用筛选过后的干净数据集 D ^ k \hat{D}_{k} D^k,第 k k k个客户端根据小损失样本,计算每个类的原始平均特征 f ^ k c \hat {f}_{k}^c f^kc:
f ^ k c = 1 n ~ k c ∑ x k i ∈ D ^ k F k ( x k i ) 1 ( y k i = c ) \hat{{f}}_{k}^{c}=\frac{1}{\tilde{n}_{k}^{c}} \sum_{x_{k}^{i} \in \hat{D}_{k}} F_{k}\left(x_{k}^{i}\right) {1}\left(y_{k}^{i}=c\right) f^kc=n~kc1xki∈D^k∑Fk(xki)1(yki=c)
符号 | 符号说明 | 备注 |
---|---|---|
n ~ k c \tilde{{n}}_{k}^{c} n~kc | D ^ k \hat{D}_{k} D^k中标签为 c c c的样本数量 | |
1 ( ⋅ ) {1}(\cdot) 1(⋅) | 指示函数,判断里面内容是否相等,相等返回1,不等返回0 | 这里是在筛选所有样本里为 c c c类的样本 |
在筛选样本后,对这些样本用本地特征提取器 F k F_{k} Fk进行特征提取,并平均得到原始平均特征 f ^ k c \hat {f}_{k}^c f^kc。
local class-wise centroids
由于联邦学习中客户端持有non-iid数据,数据分布不一致,每个客户端的类原始特征中心可能会有许多差异,如果直接平均聚合可能会产生不必要的偏差(这里理解为直接平均可能偏离最接近真实类特征向量的特征中心,为了让这些类中心特征的聚合更贴近全局收敛的特征中心将采用加权聚合的方式)。
f k c = ( 1 − s i m ( f G c , f ^ k c ) 2 ) f G c + s i m ( f G c , f ^ k c ) 2 f ^ k c {f}_{k}^{c}=\left(1-{sim}\left({f}_{G}^{c}, \hat{{f}}_{k}^{c}\right)^{2}\right){f}_{G}^{c}+{sim}\left({f}_{G}^{c}, \hat{{f}}_{k}^{c}\right)^{2} \hat{{f}}_{k}^{c} fkc=(1−sim(fGc,f^kc)2)fGc+sim(fGc,f^kc)2f^kc
符号 | 符号说明 | 备注 |
---|---|---|
s i m ( ⋅ , ⋅ ) {sim}(\cdot, \cdot) sim(⋅,⋅) | 相似度函数,文章采用余弦相似度 | |
f G c f_{G}^c fGc | 由服务器下发得到,是聚合后的全局平均特征中心 |
这样,可以看出当 f ^ k c \hat {f}_{k}^c f^kc和 f G c f_{G}^c fGc相似度较高时,余弦值较大, f ^ k c \hat {f}_{k}^c f^kc占比更大,更大程度保留了 f ^ k c \hat {f}_{k}^c f^kc的类特征。这样做的优点在于深度学习首先学习sample pattern,在训练初期学习较多样本的sample pattern,聚合得到的全局特征中心受到较低的噪声样本的干扰;然后用这个全局特征中心来约束局部特征中心,这种基于相似性的更新可以在一定程度上减小噪声数据的干扰,即使在训练后期对样本进行记忆的时候。(和动量更新的思路有点像,都是利用已有知识来指导更新) 考虑到本文non-iid的设置,如果直接使用全局的类中心特征作为本地的类中心特征,会使得类中心特征与本地数据集分布不一致。当两个中心较为相近则赋予本地特征中心更高的权重以贴近本地数据集的分布;当两个中心相似度较小时,则赋予全局中心更高的权重以指导本地类特征中心的更新。
类特征损失函数
L c e n k = ∑ i = 1 n k m k i ∥ F k ( x k i ) − f k y k i ∥ 2 2 L_{c e n}^{k}=\sum_{i=1}^{n_{k}} m_{k}^{i}\left\|F_{k}\left(x_{k}^{i}\right)-{f}_{k}^{y_{k}^{i}}\right\|_{2}^{2} Lcenk=i=1∑nkmki∥∥∥Fk(xki)−fkyki∥∥∥22
符号 | 符号说明 | 备注 |
---|---|---|
m k i m_{k}^{i} mki | 用于判断该样本是否为可信样本,如果是则计算进损失值 | m k i = 1 ( y ~ k i = y k i ) m_{k}^{i}=\mathbb{1}\left(\tilde{y}_{k}^{i}=y_{k}^{i}\right) mki=1(y~ki=yki) |
f k y k i {f}_{k}^{y_{k}^{i}} fkyki | y k i y_{k}^{i} yki为样本 i i i的标签,即为某一类别 c c c。 | – |
这里计算的是所有参与训练的样本的,当当前训练的样本在 m k i m_{k}^{i} mki判断下为可信样本时, m k i = 1 m_{k}^{i}=1 mki=1,才将二范数平方值加入损失值。也就是说,这个损失函数是约束经过特征提取器提取的样本特征与类中心特征的差异。因为这个类中心特征 f k y k i {f}_{k}^{y_{k}^{i}} fkyki是由本地和全局中心特征加权得到的,因此可以将其用来约束损失值,进而使模型学的参数可以更接近于提取到类中心特征,减小模型参数差异。那为什么要减小模型参数差异呢?(自己的理解)由于DNN存在记忆模式,全局首先习得简单模式,使本地模型参数效果更接近全局效果,一定程度上保证了模型参数不会远离收敛点。也就是文章说的 exploit these local centroids to reduce weight diver-gence of clients’ models。
Global-guided pseudo-labeling
用户在接受全局模型后,用全局的特征提取器和全局分类器来预测本地样本的标签值,作为经由全局模型指导产生的伪标签。
y ^ k = C G ( F G ( x k ) ) \hat{{y}}_{k}=C_{G}\left(F_{G}\left({x}_{{k}}\right)\right) y^k=CG(FG(xk))
采用这个伪标签可以配合公式一计算损失值,如果当前样本不为可信样本时, m k = 0 {m}_{k}=0 mk=0,采用全局模型预测的标签计算损失值。
L c k = m k l c e ( C k ( F k ( x k ) ) , y k ) + ( 1 − m k ) l c e ( C k ( F k ( x k ) ) , y ^ k ) {L}_{c}^k = {m}_{k}{l}_{ce}({C}_{k}({F}_{k}({x}_{k})),{y}_{k})+(1-{m}_{k}){l}_{ce}({C}_{k}({F}_{k}({x}_{k})),{\hat y}_{k}) Lck=mklce(Ck(Fk(xk)),yk)+(1−mk)lce(Ck(Fk(xk)),y^k)
总损失值
L total k = L c k + λ cen L cen k + λ e L e k L_{\text {total }}^{k}=L_{c}^{k}+\lambda_{\text {cen }} L_{\text {cen }}^{k}+\lambda_{e} L_{e}^{k} Ltotal k=Lck+λcen Lcen k+λeLek
L e k = − ∑ i p i log p i L_{e}^k=-\sum_{i}{p}^{i} \log {p}^{i} Lek=−∑ipilogpi为预测结果的熵正则化(the entropy regularization of prediction results), p k i p_{k}^i pki是softmax输出结果 C k ( F k ( x k ) ) {C}_{k}({F}_{k}({x}_{k})) Ck(Fk(xk)),也就是最大可能标签的概率值。 L e k L_{e}^k Lek是一个网络前向计算的结果得到的损失值。 L c k {L}_{c}^k Lck是根据可信标签与全局模型预测的伪标签得到的损失值, L cen k L_{\text {cen }}^{k} Lcen k则是关于中心特征的损失值。
那么,local update algorithm如下:
Global updates
结合FedAvg算法来进行聚合,值得注意的是,由于对局部特征中心进行加权求和,可以有效地减小噪声类的干扰(since it performs a class-wise summation of local centroids, it is less affected by different noise ratios in classes)。
Weight aggregation
典型的用数据数量作为加权聚合的权值,这里的 θ \theta θ为模型参数,用于用户预测类的。
θ G = ∑ k ∈ K n k n θ L , k \theta_{G}=\sum_{k \in K} \frac{n_{k}}{n} \theta_{L, k} θG=k∈K∑nnkθL,k
Global centroid aggregation
考虑到不同客户端有不同的噪声分布,如果只是简单的平均可能会使得收敛效果降低,因此使用基于类特征相似度的加权聚合。
f G c = 1 ∑ k ∈ K w k c ∑ k ∈ K w k c f k c {f}_{G}^{c}=\frac{1}{\sum_{k \in K} w_{k}^{c}} \sum_{k \in K} w_{k}^{c} {f}_{k}^{c} fGc=∑k∈Kwkc1k∈K∑wkcfkc 其中 w k c = s i m ( f ^ G c , f k c ) w_{k}^c=sim(\hat f_G^c,f_k^c) wkc=sim(f^Gc,fkc),为存储的全局 c c c类特征中心 f ^ G c \hat f_G^c f^Gc(可能是上一轮的聚合值)与局部 c c c类的特征中心 f k c f_k^c fkc的相似度。并将所有用户的权值 w k c w_k^c wkc累加起来作为分母,保证特征中心的收敛性。如果两个特征中心比较相近,那么 s i m ( f ^ G c , f k c ) sim(\hat f_G^c,f_k^c) sim(f^Gc,fkc)值大,占有的权值大。那为什么采用这样的加权方式呢,为什么相似度高的局部特征中心占比就要大呢?一方面,早期的 f ^ G c \hat f_G^c f^Gc由simple pattern构成,比较能代表真正的中心特征“收敛中心方向”,通过控制每一轮local centroid的加权聚合来提升global centriod的准确性,如果在一开始接近global centriod的local centriod权值更大,激励这个收敛中心,抑制其他中心。另一方面,类似于真相发现理论(truth discovery algorithm),当全局处在高质量IID数据时,产生的梯度在收敛方向是比较一致的。对比过来,如果clean sample较多时,特征中心在收敛方向也是较为一致的,也即特征中心向量在方向是较为一致的。通过扩大他们的权重,可以保证特征中心的收敛性。
Experiments
待续。。。
总结
本文介绍了联邦学习中处理噪声标签的处理方法。构造了一个由样本标签比较+特征中心+前向计算组合而成的损失函数。基于相似度的可信样本判断,设置局部和全局类特征中心用以校正样本标签。待续。
这篇关于论文笔记《Robust Federated Learning with Noisy Labels》的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!