使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练

本文主要是介绍使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

WebsocketServerWorker端代码:start_worker.py

import argparseimport torch as th
from syft.workers.websocket_server import WebsocketServerWorkerimport syft as sy# Arguments
parser = argparse.ArgumentParser(description="Run websocket server worker.")
parser.add_argument("--port", "-p", type=int, help="port number of the websocket server worker, e.g. --port 8777"
)
parser.add_argument("--host", type=str, default="localhost", help="host for the connection")
parser.add_argument("--id", type=str, help="name (id) of the websocket server worker, e.g. --id alice"
)
parser.add_argument("--verbose","-v",action="store_true",help="if set, websocket server worker will be started in verbose mode",
)def main(**kwargs):  # pragma: no cover"""Helper function for spinning up a websocket participant."""# Create websocket workerworker = WebsocketServerWorker(**kwargs)# Setup toy data (xor example)data = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)target = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)# Create a dataset using the toy datadataset = sy.BaseDataset(data, target)# Tell the worker about the datasetworker.add_dataset(dataset, key="xor")# Start workerworker.start()return workerif __name__ == "__main__":hook = sy.TorchHook(th)args = parser.parse_args()kwargs = {"id": args.id,"host": args.host,"port": args.port,"hook": hook,"verbose": args.verbose,}main(**kwargs)

启动worker

  python start_worker.py --host 172.16.5.45 --port 8777 --id alice

客户端代码:

import inspect
import start_workerprint(inspect.getsource(start_worker.main))# Dependencies
import torch as th
import torch.nn.functional as F
from torch import nnuse_cuda = th.cuda.is_available()
th.manual_seed(1)
device = th.device("cuda" if use_cuda else "cpu")import syft as sy
from syft import workershook = sy.TorchHook(th)  # hook torch as always :)class Net(th.nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(2, 20)self.fc2 = nn.Linear(20, 10)self.fc3 = nn.Linear(10, 1)def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# Instantiate the model
model = Net()# The data itself doesn't matter as long as the shape is right
mock_data = th.zeros(1, 2)# Create a jit version of the model
traced_model = th.jit.trace(model, mock_data)type(traced_model)# Loss function
@th.jit.script
def loss_fn(target, pred):return ((target.view(pred.shape).float() - pred.float()) ** 2).mean()type(loss_fn)optimizer = "SGD"batch_size = 4
optimizer_args = {"lr" : 0.1, "weight_decay" : 0.01}
epochs = 1
max_nr_batches = -1  # not used in this example
shuffle = Truetrain_config = sy.TrainConfig(model=traced_model,loss_fn=loss_fn,optimizer=optimizer,batch_size=batch_size,optimizer_args=optimizer_args,epochs=epochs,shuffle=shuffle)kwargs_websocket = {"host": "172.16.5.45", "hook": hook, "verbose": False}
alice = workers.websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)# Send train config
train_config.send(alice)# Setup toy data (xor example)
data = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
target = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)print("\nEvaluation before training")
pred = model(data)
loss = loss_fn(target=target, pred=pred)
print("Loss: {}".format(loss))
print("Target: {}".format(target))
print("Pred: {}".format(pred))for epoch in range(10):loss = alice.fit(dataset_key="xor")  # ask alice to train using "xor" datasetprint("-" * 50)print("Iteration %s: alice's loss: %s" % (epoch, loss))new_model = train_config.model_ptr.get()print("\nEvaluation after training:")
pred = new_model(data)
loss = loss_fn(target=target, pred=pred)
print("Loss: {}".format(loss))
print("Target: {}".format(target))
print("Pred: {}".format(pred))

运行:

python worker-client.py 

输出结果:

Evaluation before training
Loss: 0.4933376908302307
Target: tensor([[1.],[1.],[0.],[0.]])
Pred: tensor([[ 0.1258],[-0.0994],[ 0.0033],[ 0.0210]], grad_fn=<AddmmBackward>)
--------------------------------------------------
Iteration 0: alice's loss: tensor(0.4933, requires_grad=True)
--------------------------------------------------
Iteration 1: alice's loss: tensor(0.3484, requires_grad=True)
--------------------------------------------------
Iteration 2: alice's loss: tensor(0.2858, requires_grad=True)
--------------------------------------------------
Iteration 3: alice's loss: tensor(0.2626, requires_grad=True)
--------------------------------------------------
Iteration 4: alice's loss: tensor(0.2529, requires_grad=True)
--------------------------------------------------
Iteration 5: alice's loss: tensor(0.2474, requires_grad=True)
--------------------------------------------------
Iteration 6: alice's loss: tensor(0.2441, requires_grad=True)
--------------------------------------------------
Iteration 7: alice's loss: tensor(0.2412, requires_grad=True)
--------------------------------------------------
Iteration 8: alice's loss: tensor(0.2388, requires_grad=True)
--------------------------------------------------
Iteration 9: alice's loss: tensor(0.2368, requires_grad=True)Evaluation after training:
Loss: 0.23491761088371277
Target: tensor([[1.],[1.],[0.],[0.]])
Pred: tensor([[0.6553],[0.3781],[0.4834],[0.4477]], grad_fn=<DifferentiableGraphBackward>)

这篇关于使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/240449

相关文章

一文详解如何使用Java获取PDF页面信息

《一文详解如何使用Java获取PDF页面信息》了解PDF页面属性是我们在处理文档、内容提取、打印设置或页面重组等任务时不可或缺的一环,下面我们就来看看如何使用Java语言获取这些信息吧... 目录引言一、安装和引入PDF处理库引入依赖二、获取 PDF 页数三、获取页面尺寸(宽高)四、获取页面旋转角度五、判断

MyBatis-Plus通用中等、大量数据分批查询和处理方法

《MyBatis-Plus通用中等、大量数据分批查询和处理方法》文章介绍MyBatis-Plus分页查询处理,通过函数式接口与Lambda表达式实现通用逻辑,方法抽象但功能强大,建议扩展分批处理及流式... 目录函数式接口获取分页数据接口数据处理接口通用逻辑工具类使用方法简单查询自定义查询方法总结函数式接口

C++中assign函数的使用

《C++中assign函数的使用》在C++标准模板库中,std::list等容器都提供了assign成员函数,它比操作符更灵活,支持多种初始化方式,下面就来介绍一下assign的用法,具有一定的参考价... 目录​1.assign的基本功能​​语法​2. 具体用法示例​​​(1) 填充n个相同值​​(2)

Spring StateMachine实现状态机使用示例详解

《SpringStateMachine实现状态机使用示例详解》本文介绍SpringStateMachine实现状态机的步骤,包括依赖导入、枚举定义、状态转移规则配置、上下文管理及服务调用示例,重点解... 目录什么是状态机使用示例什么是状态机状态机是计算机科学中的​​核心建模工具​​,用于描述对象在其生命

使用Python删除Excel中的行列和单元格示例详解

《使用Python删除Excel中的行列和单元格示例详解》在处理Excel数据时,删除不需要的行、列或单元格是一项常见且必要的操作,本文将使用Python脚本实现对Excel表格的高效自动化处理,感兴... 目录开发环境准备使用 python 删除 Excphpel 表格中的行删除特定行删除空白行删除含指定

深入理解Go语言中二维切片的使用

《深入理解Go语言中二维切片的使用》本文深入讲解了Go语言中二维切片的概念与应用,用于表示矩阵、表格等二维数据结构,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习吧... 目录引言二维切片的基本概念定义创建二维切片二维切片的操作访问元素修改元素遍历二维切片二维切片的动态调整追加行动态

prometheus如何使用pushgateway监控网路丢包

《prometheus如何使用pushgateway监控网路丢包》:本文主要介绍prometheus如何使用pushgateway监控网路丢包问题,具有很好的参考价值,希望对大家有所帮助,如有错误... 目录监控网路丢包脚本数据图表总结监控网路丢包脚本[root@gtcq-gt-monitor-prome

Python通用唯一标识符模块uuid使用案例详解

《Python通用唯一标识符模块uuid使用案例详解》Pythonuuid模块用于生成128位全局唯一标识符,支持UUID1-5版本,适用于分布式系统、数据库主键等场景,需注意隐私、碰撞概率及存储优... 目录简介核心功能1. UUID版本2. UUID属性3. 命名空间使用场景1. 生成唯一标识符2. 数

SpringBoot中如何使用Assert进行断言校验

《SpringBoot中如何使用Assert进行断言校验》Java提供了内置的assert机制,而Spring框架也提供了更强大的Assert工具类来帮助开发者进行参数校验和状态检查,下... 目录前言一、Java 原生assert简介1.1 使用方式1.2 示例代码1.3 优缺点分析二、Spring Fr

Python办公自动化实战之打造智能邮件发送工具

《Python办公自动化实战之打造智能邮件发送工具》在数字化办公场景中,邮件自动化是提升工作效率的关键技能,本文将演示如何使用Python的smtplib和email库构建一个支持图文混排,多附件,多... 目录前言一、基础配置:搭建邮件发送框架1.1 邮箱服务准备1.2 核心库导入1.3 基础发送函数二、