FATE —— 二.2.6 Homo-NN使用FATE接口Trainer

2023-12-19 11:10

本文主要是介绍FATE —— 二.2.6 Homo-NN使用FATE接口Trainer,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言

在本教程中,我们将演示如何使用培训师用户界面返回格式化的预测结果,评估模型的性能,保存模型,并在仪表板上显示损失曲线和性能分数。这些接口允许您的培训师与FATE框架集成,使其更易于使用。

由于官方网站的示例代码有一定的错误,所以在此进行声明,改正后的_proximal_term如下所示:

def _proximal_term(self, model_a, model_b):diff_ = 0for p1, p2 in zip(model_a.parameters(), model_b.parameters()):diff_ += ((p1-p2.detach())**2).sum()return diff_

在本教程中,我们将继续开发我们的玩具FedProx训练器。

FedProx的玩具实现

在上一个教程中,我们通过演示FedProx算法的玩具实现提供了一个具体的示例。在FedProx中,训练过程与标准FedAVG算法略有不同,因为在计算损失时,需要从当前模型和全局模型计算近端项。代码在这里:

from pipeline.component.nn import save_to_fate
%%save_to_fate trainer fedprox.py
import copy
import torch as t
from federatedml.nn.homo.trainer.trainer_base import TrainerBase
from torch.utils.data import DataLoader
# We need to use aggregator client&server class for federation
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
# We use LOGGER to output logs
from federatedml.util import LOGGERclass ToyFedProxTrainer(TrainerBase):def __init__(self, epochs, batch_size, u):super(ToyFedProxTrainer, self).__init__()# trainer parametersself.epochs = epochsself.batch_size = batch_sizeself.u = u# Given two model, we compute the proximal termdef _proximal_term(self, model_a, model_b):diff_ = 0for p1, p2 in zip(model_a.parameters(), model_b.parameters()):diff_ += ((p1-p2.detach())**2).sum()return diff_# implement the train function, this function will be called by client side# contains the local training process and the federation partdef train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):sample_num = len(train_set)aggregator = Noneif self.fed_mode:aggregator = SecureAggregatorClient(True, aggregate_weight=sample_num, communicate_match_suffix='fedprox')  # initialize aggregator# set dataloaderdl = DataLoader(train_set, batch_size=self.batch_size, num_workers=4)for epoch in range(self.epochs):# the local training processLOGGER.debug('running epoch {}'.format(epoch))global_model = copy.deepcopy(self.model)loss_sum = 0# batch training processfor batch_data, label in dl:optimizer.zero_grad()pred = self.model(batch_data)loss_term_a = loss(pred, label)loss_term_b = self._proximal_term(self.model, global_model)loss_ = loss_term_a + (self.u/2) * loss_term_bloss_.backward()loss_sum += float(loss_.detach().numpy())optimizer.step()# pring lossLOGGER.debug('epoch loss is {}'.format(loss_sum))# the aggregation processif aggregator is not None:self.model = aggregator.model_aggregation(self.model)converge_status = aggregator.loss_aggregation(loss_sum)# implement the aggregation function, this function will be called by the sever sidedef server_aggregate_procedure(self, extra_data={}):# initialize aggregatorif self.fed_mode:aggregator = SecureAggregatorServer(communicate_match_suffix='fedprox')# the aggregation process is simple: every epoch the server aggregate model and loss oncefor i in range(self.epochs):aggregator.model_aggregation()merge_loss, _ = aggregator.loss_aggregation()
用户界面

现在我们向您介绍TrainerBase类提供的用户界面,我们将使用这些功能来改进我们的培训师。

格式预测结果

此函数将组织预测结果并返回StdReturnFormat对象,该对象将包装结果。您可以在预测函数的末尾使用此函数返回FATE框架可以解析并显示在命运板上的标准格式。这种标准化格式还允许下游组件(例如评估组件)使用预测结果。

此函数接受四个参数:

  • sample_ids:示例ID列表

  • predict_result:预测得分的张量

  • true_label:真标签张量

  • task_type:正在执行的任务类型。默认值为“auto”,它将自动推断任务类型。其他选项包括“二进制”、“多”和“回归”。目前,FATE仪表板仅支持显示二进制/多分类和回归任务。如果选择“自动”,则将自动推断任务类型。

稍后我们将在FedProx培训器中实现预测。
import torch as t 
from typing import Listdef format_predict_result(self, sample_ids: List, predict_result: t.Tensor,true_label: t.Tensor, task_type: str = None):...
callback_metric和callback_loss

顾名思义,这两个功能使您能够保存数据点,并在命盘上显示自定义评估指标和损失曲线。

使用回调度量函数时,需要提供度量名称、浮点值,并指定度量类型('train'或'validate')和历元索引。使用回调损失函数时,需要提供浮点损失值和历元索引。您的数据将显示在命盘上。

def callback_metric(self, metric_name: str, value: float, metric_type='train', epoch_idx=0):...def callback_loss(self, loss: float, epoch_idx: int):...
总结

此函数允许您在字典中保存训练过程的摘要,例如丢失历史和最佳时期。任务完成后,您可以从管道中检索此摘要。

def summary(self, summary_dict: dict):...
保存和检查点

您可以使用“save”保存模型,并使用“checkpoint”功能设置模型检查点。需要注意的是:

  • “save”仅将模型存储在内存中,因此您保存的模型将是上次使用“save”功能保存的模型。

  • “checkpoint”直接将模型保存到磁盘。

  • “save”只能在客户端(在“train”函数中)调用,而“checkpoint”应该在客户端和服务器端(在“rain”和“server_aggregate_proccure”函数中调用)调用,以确保检查点机制正常工作。

函数中的“extra_data”参数允许您在字典中保存其他数据。这在热启动模型时非常有用,因为您可以使用“train”和“server_aggregate_proccure”函数中的“extra_data”参数检索保存的数据。

def save(self,model=None,epoch_idx=-1,optimizer=None,converge_status=False,loss_history=None,best_epoch=-1,extra_data={}): ...def checkpoint(self,epoch_idx,model=None,optimizer=None,converge_status=False,loss_history=None,best_epoch=-1,extra_data={}): ...
评价

此界面允许您通过自动计算各种性能指标来评估模型。计算的指标取决于数据集和任务的类型

  • 二进制分类:“AUC”和“ks”

  • 多级分类:“准确度”、“精度”和“召回”

  • 回归:“rmse”和“mae”

您可以在参数中指定数据集的类型(“训练”或“验证”)和任务类型(“二元”、“多元”或“回归”)。如果未指定任务类型,将自动从您的分数和标签中推断出任务类型。

def evaluation(self, sample_ids: list, pred_scores: t.Tensor, label: t.Tensor, dataset_type='train',epoch_idx=0, task_type=None):...

改进的FedProx训练器

在本节中,我们将使用前面介绍的接口来改进我们的FedProx Trainer,使其成为一个更全面的培训工具。我们:

  • 我们实现了predict函数,并返回格式化的结果

  • 添加求值函数

  • 培训结束时保存模型

  • 回调损失以保存损失曲线

  • 我们计算准确度分数,然后使用回调度量显示。

from pipeline.component.nn import save_to_fate
%%save_to_fate trainer fedprox_v2.py
import copy
import torch as t
from federatedml.nn.homo.trainer.trainer_base import TrainerBase
from federatedml.nn.dataset.base import Dataset
from torch.utils.data import DataLoader
# We need to use aggregator client&server class for federation
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
# We use LOGGER to output logs
from federatedml.util import LOGGERclass ToyFedProxTrainer(TrainerBase):def __init__(self, epochs, batch_size, u):super(ToyFedProxTrainer, self).__init__()# trainer parametersself.epochs = epochsself.batch_size = batch_sizeself.u = u# Given two model, we compute the proximal termdef _proximal_term(self, model_a, model_b):diff_ = 0for p1, p2 in zip(model_a.parameters(), model_b.parameters()):diff_ += ((p1-p2.detach())**2).sum()return diff_# implement the train function, this function will be called by client side# contains the local training process and the federation partdef train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):sample_num = len(train_set)aggregator = Noneif self.fed_mode:aggregator = SecureAggregatorClient(True, aggregate_weight=sample_num, communicate_match_suffix='fedprox')  # initialize aggregator# set dataloaderdl = DataLoader(train_set, batch_size=self.batch_size, num_workers=4)loss_history = []for epoch in range(self.epochs):# the local training processLOGGER.debug('running epoch {}'.format(epoch))global_model = copy.deepcopy(self.model)loss_sum = 0# batch training processfor batch_data, label in dl:optimizer.zero_grad()pred = self.model(batch_data)loss_term_a = loss(pred, label)loss_term_b = self._proximal_term(self.model, global_model)loss_ = loss_term_a + (self.u/2) * loss_term_bLOGGER.debug('loss is {} loss a is {} loss b is {}'.format(loss_, loss_term_a, loss_term_b))loss_.backward()loss_sum += float(loss_.detach().numpy())optimizer.step()# print lossLOGGER.debug('epoch loss is {}'.format(loss_sum))loss_history.append(loss_sum)# we callback loss hereself.callback_loss(loss_sum, epoch)# we evaluate out model heresample_ids, preds, labels = self._predict(train_set)self.evaluation(sample_ids, preds, labels, 'train', task_type='binary', epoch_idx=epoch)# we manually compute accuracy:acc = ((preds > 0.5 + 0) == labels).sum() / len(labels)acc = float(acc.detach().numpy())self.callback_metric('my_accuracy', acc, epoch_idx=epoch)# the aggregation processif aggregator is not None:self.model = aggregator.model_aggregation(self.model)converge_status = aggregator.loss_aggregation(loss_sum)# We will save model at the end of the trainingself.save(self.model, epoch, optimizer)# We will save model summaryself.summary({'loss_history': loss_history})# implement the aggregation function, this function will be called by the sever sidedef server_aggregate_procedure(self, extra_data={}):# initialize aggregatorif self.fed_mode:aggregator = SecureAggregatorServer(communicate_match_suffix='fedprox')# the aggregation process is simple: every epoch the server aggregate model and loss oncefor i in range(self.epochs):aggregator.model_aggregation()merge_loss, _ = aggregator.loss_aggregation()def _predict(self, dataset: Dataset):len_ = len(dataset)dl = DataLoader(dataset, batch_size=len_)preds, labels = None, Nonefor data, l in dl:preds = self.model(data)labels = lsample_ids = dataset.get_sample_ids()return sample_ids, preds, labels# We implement the predict function heredef predict(self, dataset):sample_ids, preds, labels = self._predict(dataset)return self.format_predict_result(sample_ids, preds, labels, 'binary')

提交Pipeline

在这里,我们提交了一个新的Pipeline来测试我们的新trainer

# torch
import torch as t
from torch import nn
from pipeline import fate_torch_hook
fate_torch_hook(t)
# pipeline
from pipeline.component.homo_nn import HomoNN, TrainerParam  # HomoNN Component, TrainerParam for setting trainer parameter
from pipeline.backend.pipeline import PipeLine  # pipeline class
from pipeline.component import Reader, DataTransform, Evaluation # Data I/O and Evaluation
from pipeline.interface import Data  # Data Interaces for defining data flow# create a pipeline to submitting the job
guest = 9999
host = 10000
arbiter = 10000
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)# read uploaded dataset
train_data_0 = {"name": "breast_homo_guest", "namespace": "experiment"}
train_data_1 = {"name": "breast_homo_host", "namespace": "experiment"}
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=train_data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=train_data_1)# The transform component converts the uploaded data to the DATE standard format
data_transform_0 = DataTransform(name='data_transform_0')
data_transform_0.get_party_instance(role='guest', party_id=guest).component_param(with_label=True, output_format="dense")
data_transform_0.get_party_instance(role='host', party_id=host).component_param(with_label=True, output_format="dense")"""
Define Pytorch model/ optimizer and loss
"""
model = nn.Sequential(nn.Linear(30, 1),nn.Sigmoid()
)
loss = nn.BCELoss()
optimizer = t.optim.Adam(model.parameters(), lr=0.01)"""
Create Homo-NN Component
"""
nn_component = HomoNN(name='nn_0',model=model, # set modelloss=loss, # set lossoptimizer=optimizer, # set optimizer# Here we use fedavg trainer# TrainerParam passes parameters to fedavg_trainer, see below for details about Trainertrainer=TrainerParam(trainer_name='fedprox_v2', epochs=3, batch_size=128, u=0.5),torch_seed=100 # random seed)# define work flow
pipeline.add_component(reader_0)
pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))
pipeline.add_component(nn_component, data=Data(train_data=data_transform_0.output.data))
pipeline.compile()
pipeline.fit()

这篇关于FATE —— 二.2.6 Homo-NN使用FATE接口Trainer的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/511949

相关文章

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

Pandas使用AdaBoost进行分类的实现

《Pandas使用AdaBoost进行分类的实现》Pandas和AdaBoost分类算法,可以高效地进行数据预处理和分类任务,本文主要介绍了Pandas使用AdaBoost进行分类的实现,具有一定的参... 目录什么是 AdaBoost?使用 AdaBoost 的步骤安装必要的库步骤一:数据准备步骤二:模型

使用Pandas进行均值填充的实现

《使用Pandas进行均值填充的实现》缺失数据(NaN值)是一个常见的问题,我们可以通过多种方法来处理缺失数据,其中一种常用的方法是均值填充,本文主要介绍了使用Pandas进行均值填充的实现,感兴趣的... 目录什么是均值填充?为什么选择均值填充?均值填充的步骤实际代码示例总结在数据分析和处理过程中,缺失数

如何使用 Python 读取 Excel 数据

《如何使用Python读取Excel数据》:本文主要介绍使用Python读取Excel数据的详细教程,通过pandas和openpyxl,你可以轻松读取Excel文件,并进行各种数据处理操... 目录使用 python 读取 Excel 数据的详细教程1. 安装必要的依赖2. 读取 Excel 文件3. 读

解决Maven项目idea找不到本地仓库jar包问题以及使用mvn install:install-file

《解决Maven项目idea找不到本地仓库jar包问题以及使用mvninstall:install-file》:本文主要介绍解决Maven项目idea找不到本地仓库jar包问题以及使用mvnin... 目录Maven项目idea找不到本地仓库jar包以及使用mvn install:install-file基

Python使用getopt处理命令行参数示例解析(最佳实践)

《Python使用getopt处理命令行参数示例解析(最佳实践)》getopt模块是Python标准库中一个简单但强大的命令行参数处理工具,它特别适合那些需要快速实现基本命令行参数解析的场景,或者需要... 目录为什么需要处理命令行参数?getopt模块基础实际应用示例与其他参数处理方式的比较常见问http

C 语言中enum枚举的定义和使用小结

《C语言中enum枚举的定义和使用小结》在C语言里,enum(枚举)是一种用户自定义的数据类型,它能够让你创建一组具名的整数常量,下面我会从定义、使用、特性等方面详细介绍enum,感兴趣的朋友一起看... 目录1、引言2、基本定义3、定义枚举变量4、自定义枚举常量的值5、枚举与switch语句结合使用6、枚

使用Python从PPT文档中提取图片和图片信息(如坐标、宽度和高度等)

《使用Python从PPT文档中提取图片和图片信息(如坐标、宽度和高度等)》PPT是一种高效的信息展示工具,广泛应用于教育、商务和设计等多个领域,PPT文档中常常包含丰富的图片内容,这些图片不仅提升了... 目录一、引言二、环境与工具三、python 提取PPT背景图片3.1 提取幻灯片背景图片3.2 提取

usb接口驱动异常问题常用解决方案

《usb接口驱动异常问题常用解决方案》当遇到USB接口驱动异常时,可以通过多种方法来解决,其中主要就包括重装USB控制器、禁用USB选择性暂停设置、更新或安装新的主板驱动等... usb接口驱动异常怎么办,USB接口驱动异常是常见问题,通常由驱动损坏、系统更新冲突、硬件故障或电源管理设置导致。以下是常用解决

使用Python实现图像LBP特征提取的操作方法

《使用Python实现图像LBP特征提取的操作方法》LBP特征叫做局部二值模式,常用于纹理特征提取,并在纹理分类中具有较强的区分能力,本文给大家介绍了如何使用Python实现图像LBP特征提取的操作方... 目录一、LBP特征介绍二、LBP特征描述三、一些改进版本的LBP1.圆形LBP算子2.旋转不变的LB