联邦学习【01】杨强第三章横向联邦学习复现

2024-06-02 19:20

本文主要是介绍联邦学习【01】杨强第三章横向联邦学习复现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

环境:有无gpu均可
anaconda环境:conda install pytorch1.13.1 torchvision0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
具体安装命令结合自己的gpu的cuda版本
在这里插入图片描述
项目代码放在
https://gitee.com/ihan1001gitee
https://github.com/ihan1001github
项目架构
在这里插入图片描述

conf.json文件,存放训练参数

{"model_name" : "resnet18","no_models" : 10,"type" : "cifar","global_epochs" :20,"local_epochs" : 3,"k" : 10,"batch_size" : 32,"lr" : 0.001,"momentum" : 0.0001,"lambda" : 0.1
}

在这里插入图片描述

model_name:模型名称
no_models:客户端总数量
type:数据集信息
global_epochs:全局迭代次数,即服务端与客户端的通信迭代次数
local_epochs:本地模型训练迭代次数
k:每一轮迭代时,服务端会从所有客户端中挑选k个客户端参与训练。
batch_size:本地训练每一轮的样本数
lr,momentum,lambda:本地训练的超参数设置

lr(learning rate)表示模型在每一次迭代时更新参数的步长大小。通常需要根据数据集和模型的具体情况来选择一个合适的学习率,如果学习率过高,则容易导致参数更新过快而无法收敛;如果学习率过低,则会使得模型收敛速度变慢甚至无法收敛。例如,可以使用optim.SGD类中的lr参数来设置学习率。momentum是一种加速梯度下降的方法,它可以在更新参数的过程中积累之前的梯度信息,并将其作为当前梯度的一部分。这样可以使得梯度下降更加平滑,避免震荡或者陷入局部极小值。通常可以将momentum设置为0.9左右,具体取值也需要根据数据集和模型来进行调整。例如,可以使用optim.SGD类中的momentum参数来设置动量。lambda是一种正则化方法,用于避免模型过拟合。正则化可以通过对模型的损失函数添加一个惩罚项来实现。通常有两种常见的正则化方法:L1正则化和L2正则化。L1正则化会使得模型的权重向量变得更加稀疏,可以通过torch.nn.L1Loss或者optim.L1Loss来实现;L2正则化则会使得模型的权重向量变得更加平滑,可以通过torch.nn.MSELoss或者optim.MSELoss来实现。参数lambda用于控制正则化项的大小,通常需要根据数据集和模型的具体情况来进行调整。

dataset.py

'''
-*- coding: utf-8 -*-
@Time    : 2023/11/23 17:35
@Author  : ihan
@File    : FederateAI
@Software: PyCharm
'''import torch
from torchvision import datasets, transforms#获取数据集
#调用语句 train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["type"])
def get_dataset(dir, name):if name == 'mnist':train_dataset = datasets.MNIST(dir, train=True, download=True, transform=transforms.ToTensor())print(train_dataset)print("--------------------")'''dir 数据路径  train=True 代表训练数据集   Flase代表测试数据集download  是否从互联网下载数据集并存放在dir中transform=transforms.ToTensor() 表示将图像转换为张量形式'''eval_dataset = datasets.MNIST(dir, train=False, transform=transforms.ToTensor())print(eval_dataset)print("--------------------")elif name == 'cifar':'''transforms.Compose 是将多个transform组合起来使用(由transform构成的列表)'''transform_train = transforms.Compose([#transforms.RandomCrop(32, padding=4),#transforms.RandomCrop: 切割中心点的位置随机选取#随机裁剪原始图像,裁剪的大小为32x32像素,padding参数表示在图像的四周填充4个像素,以避免裁剪后图像边缘信息丢失。transforms.RandomHorizontalFlip(),#以0.5的概率对图像进行随机水平翻转,增加数据集的多样性。transforms.ToTensor(),#将图像转换为张量形式,便于后续处理。#transforms.Normalize: 给定均值:(R,G,B) 方差:(R,G,B),将会把Tensor正则化transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),#对图像进行标准化操作,减去均值并除以标准差。这里使用的均值和标准差是CIFAR-10数据集的全局均值和标准差,可以使模型更容易收敛。])print(transform_train)print("--------------------")transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])print(transform_test)print("--------------------")train_dataset = datasets.CIFAR10(dir, train=True, download=True,transform=transform_train)eval_dataset = datasets.CIFAR10(dir, train=False, transform=transform_test)return train_dataset, eval_dataset
# if __name__ == '__main__':
#     train_datasets, eval_datasets = get_dataset("./data/", "cifar")
#     # print(train_datasets)
#     # print(eval_datasets)

在这里插入图片描述
server.py

'''
-*- coding: utf-8 -*-
@Time    : 2023/11/23 17:37
@Author  : ihan
@File    : FederateAI
@Software: PyCharm
'''import models, torchclass Server(object):'''定义构造函数。在构造函数中,服务端的工作包括:第一,将配置信息拷贝到服务端中;第二,按照配置中的模型信息获取模型,这里我们使用torchvision的models模块内置的ResNet - 18模型。'''def __init__(self, conf, eval_dataset):# 导入配置文件self.conf = conf# 根据配置获取模型文件self.global_model = models.get_model(self.conf["model_name"])# 生成一个测试集合加载器                                                           # 设置单个批次大小     # 打乱数据集self.eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=self.conf["batch_size"], shuffle=True)'''定义模型聚合函数。前面我们提到服务端的主要功能是进行模型的聚合,因此定义构造函数后,我们需要在类中定义模型聚合函数,通过接收客户端上传的模型,使用聚合函数更新全局模型。聚合方案有很多种,本节我们采用经典的FedAvg算法。'''# 全局聚合模型# weight_accumulator 存储了每一个客户端的上传参数变化值/差值def model_aggregate(self, weight_accumulator):# 遍历服务器的全局模型for name, data in self.global_model.state_dict().items():# 更新每一层乘上学习率update_per_layer = weight_accumulator[name] * self.conf["lambda"]# 因为update_per_layer的type是floatTensor,所以将起转换为模型的LongTensor(有一定的精度损失)if data.type() != update_per_layer.type():data.add_(update_per_layer.to(torch.int64))else:data.add_(update_per_layer)'''定义模型评估函数。对当前的全局模型,利用评估数据评估当前的全局模型性能。通常情况下,服务端的评估函数主要对当前聚合后的全局模型进行分析,用于判断当前的模型训练是需要进行下一轮迭代、还是提前终止,或者模型是否出现发散退化的现象。根据不同的结果,服务端可以采取不同的措施策略。'''def model_eval(self):self.global_model.eval()total_loss = 0.0correct = 0dataset_size = 0for batch_id, batch in enumerate(self.eval_loader):data, target = batchdataset_size += data.size()[0]if torch.cuda.is_available():data = data.cuda()target = target.cuda()output = self.global_model(data)total_loss += torch.nn.functional.cross_entropy(output, target,reduction='sum').item()  # sum up batch losspred = output.data.max(1)[1]  # get the index of the max log-probabilitycorrect += pred.eq(target.data.view_as(pred)).cpu().sum().item()acc = 100.0 * (float(correct) / float(dataset_size))total_l = total_loss / dataset_sizereturn acc, total_l

在这里插入图片描述client.py

'''
-*- coding: utf-8 -*-
@Time    : 2023/11/23 17:33
@Author  : ihan
@File    : FederateAI
@Software: PyCharm
''''''
横向联邦学习的客户端主要功能是接收服务端的下发指令和全局模型,利用本地数据进行局部模型训练。'''import models, torch, copyclass Client(object):'''定义构造函数。在客户端构造函数中,客户端的主要工作包括:首先,将配置信息拷贝到客户端中;然后,按照配置中的模型信息获取模型,通常由服务端将模型参数传递给客户端,客户端将该全局模型覆盖掉本地模型;最后,配置本地训练数据,在本案例中,我们通过torchvision 的datasets 模块获取cifar10数据集后按客户端ID切分,不同的客户端拥有不同的子数据集,相互之间没有交集。'''def __init__(self, conf, model, train_dataset, id=-1):self.conf = confself.local_model = models.get_model(self.conf["model_name"])self.client_id = idself.train_dataset = train_datasetall_range = list(range(len(self.train_dataset)))data_len = int(len(self.train_dataset) / self.conf['no_models'])train_indices = all_range[id * data_len: (id + 1) * data_len]self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=conf["batch_size"],sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices))'''定义模型本地训练函数。本例是一个图像分类的例子,因此,我们使用交叉熵作为本地模型的损失函数,利用梯度下降来求解并更新参数值'''def local_train(self, model):for name, param in model.state_dict().items():self.local_model.state_dict()[name].copy_(param.clone())# print(id(model))optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'],momentum=self.conf['momentum'])# print(id(self.local_model))self.local_model.train()for e in range(self.conf["local_epochs"]):for batch_id, batch in enumerate(self.train_loader):data, target = batchif torch.cuda.is_available():data = data.cuda()target = target.cuda()optimizer.zero_grad()output = self.local_model(data)loss = torch.nn.functional.cross_entropy(output, target)loss.backward()optimizer.step()print("Epoch %d done." % e)diff = dict()for name, data in self.local_model.state_dict().items():diff[name] = (data - model.state_dict()[name])# print(diff[name])return diff

main.py

'''
-*- coding: utf-8 -*-
@Time    : 2023/11/23 17:33
@Author  : ihan
@File    : FederateAI
@Software: PyCharm
'''
import argparse, json
import datetime
import os
import logging
import torch, random
import ihanfrom server import *
from client import *
import models, datasetsif __name__ == '__main__':parser = argparse.ArgumentParser(description='Federated Learning')parser.add_argument('-c', '--conf', dest='conf')args = parser.parse_args()#读取配置文件信息。with open(args.conf, 'r') as f:conf = json.load(f)ihan.saveIn(conf["model_name"],conf["no_models"],conf["type"],conf["global_epochs"],conf["local_epochs"],conf["k"])#获取数据集train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["type"])#分别定义一个服务端对象和多个客户端对象,用来模拟横向联邦训练场景。server = Server(conf, eval_datasets)clients = []for c in range(conf["no_models"]):clients.append(Client(conf, server.global_model, train_datasets, c))print("\n\n")#每一轮的迭代,服务端会从当前的客户端集合中随机挑选一部分参与本轮迭代训练,#被选中的客户端调用本地训练接口local_train进行本地训练,#最后服务端调用模型聚合函数model_aggregate来更新全局模型for e in range(conf["global_epochs"]):candidates = random.sample(clients, conf["k"])weight_accumulator = {}for name, params in server.global_model.state_dict().items():weight_accumulator[name] = torch.zeros_like(params)for c in candidates:diff = c.local_train(server.global_model)for name, params in server.global_model.state_dict().items():weight_accumulator[name].add_(diff[name])server.model_aggregate(weight_accumulator)acc, loss = server.model_eval()print("Epoch %d, acc: %f, loss: %f\n" % (e, acc, loss))ihan.savePara(e,acc,loss)

ihan.py 用于输出实验结果,写入txt

'''
-*- coding: utf-8 -*-
@Time    : 2023/12/25 22:53
@Author  : ihan
@File    : FederateAI
@Software: PyCharm
'''from datetime import datetime
import time		# python里的日期模块def savePara(e,acc,loss):date = time.strftime('%Y-%m-%d %H:%M:%S').split()		# 按空格分开,这里宫格是我自己添加的print(date)     # ['2021-07-22', '16:11:00']hour_minute = date[1].split(':')print(hour_minute)      # 时,分,秒filepath = "/home/"+ date[0]+'.txt'output = "%s:Epoch %d  training accuracy :  %g, train Loss : %f " % (datetime.now(), e, acc, loss)with open(filepath,"a+") as f:f.write(output+'\n')f.close
def saveIn(model_name,no_models,date_type,global_epochs,local_epochs,current_clients):date = time.strftime('%Y-%m-%d %H:%M:%S').split()		# 按空格分开,这里宫格是我自己添加的print(date)     # ['2021-07-22', '16:11:00']hour_minute = date[1].split(':')print(hour_minute)      # 时,分,秒filepath = "/home/"+ date[0]+'.txt'output="%s:model_name %s,no_models %d,date_type %s,global_epochs %d,local_epochs %d,current_clients %d start"\% (datetime.now(), model_name,no_models,date_type,global_epochs,local_epochs,current_clients)with open(filepath, "a+") as f:f.write(output + '\n')f.close

在这里插入图片描述
关键代码解释
更新每一层乘上lambda=1/k

update_per_layer = weight_accumulator[name] * self.conf["lambda"]
data.add_(update_per_layer.to(torch.int64))

在这里插入图片描述weight_accumulator[name]代表
在这里插入图片描述

计算weight_accumulator[name]

for name, params in server.global_model.state_dict().items():weight_accumulator[name].add_(diff[name])   

需要先计算diff[name]

for name, data in self.local_model.state_dict().items():diff[name] = (data - model.state_dict()[name])

这段代码是在遍历self.local_model的状态字典,并计算本地模型参数与全局模型参数之间的差异,将差异存储在diff字典中。

具体来说,对于本地模型中的每个参数,通过访问全局模型的状态字典model.state_dict()来获取对应参数的数值,然后计算本地模型参数与全局模型参数之间的差异,并将这个差异存储在diff字典中,以参数的名称作为键。

这个操作的目的可能是为了计算本地模型相对于全局模型的参数更新量。在一些分布式或异步更新的训练算法中,本地模型通常是在单个机器上训练得到的,而全局模型则是由多个机器上的模型进行聚合得到的。为了将本地模型的参数更新同步到全局模型,需要计算本地模型参数与全局模型参数之间的差异,并将这个差异用于更新全局模型的参数。

完成

server.model_aggregate(weight_accumulator)

在这里插入图片描述

这篇关于联邦学习【01】杨强第三章横向联邦学习复现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1024862

相关文章

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

hdu 2602 and poj 3624(01背包)

01背包的模板题。 hdu2602代码: #include<stdio.h>#include<string.h>const int MaxN = 1001;int max(int a, int b){return a > b ? a : b;}int w[MaxN];int v[MaxN];int dp[MaxN];int main(){int T;int N, V;s

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识