FATE —— 二.2.5 Homo-NN定制Trainer以控制训练过程

2023-12-19 11:10

本文主要是介绍FATE —— 二.2.5 Homo-NN定制Trainer以控制训练过程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言

在本教程中,您将学习如何创建和定制您自己的Trainer,以控制培训过程、进行预测并汇总结果以满足您的特定需求。我们将首先向您介绍需要实现的TrainerBase类的接口。然后,我们将提供FedProx算法的工具示例(请注意,这只是一个工具示例,不应在生产中使用),以帮助您更好地理解教练定制的概念。

Trainer基本类

基础

TrainerBase Class是FATE中所有Homo NN培训师的基地。要创建自定义训练器,您需要将位于federatedml.hhomo.trainer_base中的TrainerBase类进行子类化。您必须实现两个必需的函数:

  • “train()”函数:该函数接受五个参数:训练数据集实例(必须是数据集的子类)、验证数据集实例、带有初始化训练参数的优化器实例、损失函数和可能包含预热启动任务的附加数据的额外数据字典。在此函数中,您可以定义Homo NN任务的客户端训练和联合过程。

  • “server_aggregate_procedere()”函数:此函数接受一个参数,一个额外的数据字典,可能包含热启动任务的其他数据。它由服务器调用,您可以在其中定义聚合过程。

还有一个可选的“predict()”函数:它接受一个参数,一个数据集,并允许您定义培训师如何进行预测。如果您想使用FATE框架,您需要确保返回数据的格式正确,以便FATE能够正确显示(我们将在后面的教程中介绍)。"

在Homo NN客户端组件中,“set_model()”函数用于将初始化的模型设置为训练器。开发培训师时,可以使用“set_model()”设置模型,然后在培训师中使用“self.model”访问模型。

这里显示这些接口的源代码:

class TrainerBase(object):def __init__(self, **kwargs):...self._model = None......@propertydef model(self):if not hasattr(self, '_model'):raise AttributeError('model variable is not initialized, remember to call'' super(your_class, self).__init__()')if self._model is None:raise AttributeError('model is not set, use set_model() function to set training model')return self._model@model.setterdef model(self, val):self._model = val@abc.abstractmethoddef train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):"""train_set:数据集实例,必须是数据集子类(federatedml.nn.Dataset.base)的实例,例如,TableData()(来自federatedml.nn.dataset.table)validate_set:数据集实例,但可选的必须是数据集子类的实例(federatedml.nn.dataset.base),例如TableData()(来自federateddl.nn.datadataset.table)优化器:pytorch优化器类实例,例如,t.optim.Adam()、t.optim.SGD()loss:pytorch loss类,例如,nn.BECLoss(),nn.CrossEntropyLoss()"""pass@abc.abstractmethoddef predict(self, dataset):pass@abc.abstractmethoddef server_aggregate_procedure(self, extra_data={}):pass
Fed模式/本地模式

培训师有一个属性“self.fed_mode”,在运行联合任务时设置为True。您可以使用此变量来确定培训师是在联合模式下运行还是在本地调试模式下运行。如果要在本地测试培训器,可以使用“local_mode()”函数将“self.fed_mode”设置为False。

示例:开发工具FedProx

为了帮助您理解如何实现这些函数,我们将通过演示FedProx算法的玩具实现来提供一个具体的示例如该网址。在FedProx中,训练过程与标准FedAVG算法略有不同,因为在计算损失时,需要从当前模型和全局模型计算近端项。我们将带您一步一步地阅读带有注释的代码。

工具FedProx

这是训练器的代码,保存在federatedml.nn.homo.trainer模块中。此培训器实现两个功能:train和server_aggregate_proccure。这些功能可以完成简单的培训任务。该代码包含注释以提供更多详细信息。

from pipeline.component.nn import save_to_fate
%%save_to_fate trainer fedprox.py
import copy
from federatedml.nn.homo.trainer.trainer_base import TrainerBase
from torch.utils.data import DataLoader
# We need to use aggregator client&server class for federation
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
# We use LOGGER to output logs
from federatedml.util import LOGGERclass ToyFedProxTrainer(TrainerBase):def __init__(self, epochs, batch_size, u):super(ToyFedProxTrainer, self).__init__()# trainer parametersself.epochs = epochsself.batch_size = batch_sizeself.u = u# Given two model, we compute the proximal termdef _proximal_term(self, model_a, model_b):diff_ = 0for p1, p2 in zip(model_a.parameters(), model_b.parameters()):diff_ += t.sqrt((p1-p2.detach())**2).sum()return diff_# implement the train function, this function will be called by client side# contains the local training process and the federation partdef train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):sample_num = len(train_set)aggregator = Noneif self.fed_mode:aggregator = SecureAggregatorClient(True, aggregate_weight=sample_num, communicate_match_suffix='fedprox')  # initialize aggregator# set dataloaderdl = DataLoader(train_set, batch_size=self.batch_size, num_workers=4)for epoch in range(self.epochs):# the local training processLOGGER.debug('running epoch {}'.format(epoch))global_model = copy.deepcopy(self.model)loss_sum = 0# batch training processfor batch_data, label in dl:optimizer.zero_grad()pred = self.model(batch_data)loss_term_a = loss(pred, label)loss_term_b = self._proximal_term(self.model, global_model)loss_ = loss_term_a + (self.u/2) * loss_term_bloss_.backward()loss_sum += float(loss_.detach().numpy())optimizer.step()# print lossLOGGER.debug('epoch loss is {}'.format(loss_sum))# the aggregation processif aggregator is not None:self.model = aggregator.model_aggregation(self.model)converge_status = aggregator.loss_aggregation(loss_sum)# implement the aggregation function, this function will be called by the sever sidedef server_aggregate_procedure(self, extra_data={}):# initialize aggregatorif self.fed_mode:aggregator = SecureAggregatorServer(communicate_match_suffix='fedprox')# the aggregation process is simple: every epoch the server aggregate model and loss oncefor i in range(self.epochs):aggregator.model_aggregation()merge_loss, _ = aggregator.loss_aggregation()
本地测试

我们可以使用local_mode()在本地测试新的FedProx训练器。

import torch as t
from federatedml.nn.dataset.table import TableDatasetmodel = t.nn.Sequential(t.nn.Linear(30, 1),t.nn.Sigmoid()
)ds = TableDataset()
ds.load('../examples/data/breast_homo_guest.csv')  # 根据自己得文件地址进行调整trainer = ToyFedProxTrainer(10, 128, u=0.1)
trainer.set_model(model)
opt = t.optim.Adam(model.parameters(), lr=0.01)
loss = t.nn.BCELoss()  
# 由于这里要求输入值(不是分类)的范围要在(0,1)之间,否则会报错。但是模型中的Sigmoid函数已经对其进行了处理。所以,笔者在这里并没有看清楚其损失函数得出错位置,于是将其BCELoss损失函数替换为了MSELosstrainer.local_mode()
trainer.train(ds, None, opt, loss)
这里在进行训练时,产生了报错。经过笔者debug后发现在经过现行层Linear(30, 1)后,输出为nan。如果有小伙伴知道如何解决,望告知。

笔者在官网提出该问题后,官方团队给出答复:

def _proximal_term(self, model_a, model_b):diff_ = 0for p1, p2 in zip(model_a.parameters(), model_b.parameters()):diff_ += ((p1-p2.detach())**2).sum()return diff_

可以工作!然后,我们将提交一个联合任务,看看我们的培训师是否工作正常。

提交新任务以测试ToyFedProx
# torch
import torch as t
from torch import nn
from pipeline import fate_torch_hook
fate_torch_hook(t)
# pipeline
from pipeline.component.homo_nn import HomoNN, TrainerParam  # HomoNN Component, TrainerParam for setting trainer parameter
from pipeline.backend.pipeline import PipeLine  # pipeline class
from pipeline.component import Reader, DataTransform, Evaluation # Data I/O and Evaluation
from pipeline.interface import Data  # Data Interaces for defining data flow# create a pipeline to submitting the job
guest = 9999
host = 10000
arbiter = 10000
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)# read uploaded dataset
train_data_0 = {"name": "breast_homo_guest", "namespace": "experiment"}
train_data_1 = {"name": "breast_homo_host", "namespace": "experiment"}
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=train_data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=train_data_1)# The transform component converts the uploaded data to the DATE standard format
data_transform_0 = DataTransform(name='data_transform_0')
data_transform_0.get_party_instance(role='guest', party_id=guest).component_param(with_label=True, output_format="dense")
data_transform_0.get_party_instance(role='host', party_id=host).component_param(with_label=True, output_format="dense")"""
Define Pytorch model/ optimizer and loss
"""
model = nn.Sequential(nn.Linear(30, 1),nn.Sigmoid()
)
loss = nn.BCELoss()
optimizer = t.optim.Adam(model.parameters(), lr=0.01)"""
Create Homo-NN Component
"""
nn_component = HomoNN(name='nn_0',model=model, # set modelloss=loss, # set lossoptimizer=optimizer, # set optimizer# Here we use fedavg trainer# TrainerParam passes parameters to fedavg_trainer, see below for details about Trainertrainer=TrainerParam(trainer_name='fedprox', epochs=3, batch_size=128, u=0.5),torch_seed=100 # random seed)# define work flow
pipeline.add_component(reader_0)
pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))
pipeline.add_component(nn_component, data=Data(train_data=data_transform_0.output.data))
pipeline.compile()
pipeline.fit()

这篇关于FATE —— 二.2.5 Homo-NN定制Trainer以控制训练过程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/511950

相关文章

Spring Security 基于表达式的权限控制

前言 spring security 3.0已经可以使用spring el表达式来控制授权,允许在表达式中使用复杂的布尔逻辑来控制访问的权限。 常见的表达式 Spring Security可用表达式对象的基类是SecurityExpressionRoot。 表达式描述hasRole([role])用户拥有制定的角色时返回true (Spring security默认会带有ROLE_前缀),去

浅析Spring Security认证过程

类图 为了方便理解Spring Security认证流程,特意画了如下的类图,包含相关的核心认证类 概述 核心验证器 AuthenticationManager 该对象提供了认证方法的入口,接收一个Authentiaton对象作为参数; public interface AuthenticationManager {Authentication authenticate(Authenti

作业提交过程之HDFSMapReduce

作业提交全过程详解 (1)作业提交 第1步:Client调用job.waitForCompletion方法,向整个集群提交MapReduce作业。 第2步:Client向RM申请一个作业id。 第3步:RM给Client返回该job资源的提交路径和作业id。 第4步:Client提交jar包、切片信息和配置文件到指定的资源提交路径。 第5步:Client提交完资源后,向RM申请运行MrAp

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss

MOLE 2.5 分析分子通道和孔隙

软件介绍 生物大分子通道和孔隙在生物学中发挥着重要作用,例如在分子识别和酶底物特异性方面。 我们介绍了一种名为 MOLE 2.5 的高级软件工具,该工具旨在分析分子通道和孔隙。 与其他可用软件工具的基准测试表明,MOLE 2.5 相比更快、更强大、功能更丰富。作为一项新功能,MOLE 2.5 可以估算已识别通道的物理化学性质。 软件下载 https://pan.quark.cn/s/57

MiniGPT-3D, 首个高效的3D点云大语言模型,仅需一张RTX3090显卡,训练一天时间,已开源

项目主页:https://tangyuan96.github.io/minigpt_3d_project_page/ 代码:https://github.com/TangYuan96/MiniGPT-3D 论文:https://arxiv.org/pdf/2405.01413 MiniGPT-3D在多个任务上取得了SoTA,被ACM MM2024接收,只拥有47.8M的可训练参数,在一张RTX

Solr 使用Facet分组过程中与分词的矛盾解决办法

对于一般查询而言  ,  分词和存储都是必要的  .  比如  CPU  类型  ”Intel  酷睿  2  双核  P7570”,  拆分成  ”Intel”,”  酷睿  ”,”P7570”  这样一些关键字并分别索引  ,  可能提供更好的搜索体验  .  但是如果将  CPU  作为 Facet  字段  ,  最好不进行分词  .  这样就造成了矛盾  ,  解决方法

Spark MLlib模型训练—聚类算法 PIC(Power Iteration Clustering)

Spark MLlib模型训练—聚类算法 PIC(Power Iteration Clustering) Power Iteration Clustering (PIC) 是一种基于图的聚类算法,用于在大规模数据集上进行高效的社区检测。PIC 算法的核心思想是通过迭代图的幂运算来发现数据中的潜在簇。该算法适用于处理大规模图数据,特别是在社交网络分析、推荐系统和生物信息学等领域具有广泛应用。Spa

控制反转 的种类

之前对控制反转的定义和解释都不是很清晰。最近翻书发现在《Pro Spring 5》(免费电子版在文章最后)有一段非常不错的解释。记录一下,有道翻译贴出来方便查看。如有请直接跳过中文,看后面的原文。 控制反转的类型 控制反转的类型您可能想知道为什么有两种类型的IoC,以及为什么这些类型被进一步划分为不同的实现。这个问题似乎没有明确的答案;当然,不同的类型提供了一定程度的灵活性,但

Python:豆瓣电影商业数据分析-爬取全数据【附带爬虫豆瓣,数据处理过程,数据分析,可视化,以及完整PPT报告】

**爬取豆瓣电影信息,分析近年电影行业的发展情况** 本文是完整的数据分析展现,代码有完整版,包含豆瓣电影爬取的具体方式【附带爬虫豆瓣,数据处理过程,数据分析,可视化,以及完整PPT报告】   最近MBA在学习《商业数据分析》,大实训作业给了数据要进行数据分析,所以先拿豆瓣电影练练手,网络上爬取豆瓣电影TOP250较多,但对于豆瓣电影全数据的爬取教程很少,所以我自己做一版。 目