使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练

本文主要是介绍使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

WebsocketServerWorker端代码:start_worker.py

import argparseimport torch as th
from syft.workers.websocket_server import WebsocketServerWorkerimport syft as sy# Arguments
parser = argparse.ArgumentParser(description="Run websocket server worker.")
parser.add_argument("--port", "-p", type=int, help="port number of the websocket server worker, e.g. --port 8777"
)
parser.add_argument("--host", type=str, default="localhost", help="host for the connection")
parser.add_argument("--id", type=str, help="name (id) of the websocket server worker, e.g. --id alice"
)
parser.add_argument("--verbose","-v",action="store_true",help="if set, websocket server worker will be started in verbose mode",
)def main(**kwargs):  # pragma: no cover"""Helper function for spinning up a websocket participant."""# Create websocket workerworker = WebsocketServerWorker(**kwargs)# Setup toy data (xor example)data = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)target = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)# Create a dataset using the toy datadataset = sy.BaseDataset(data, target)# Tell the worker about the datasetworker.add_dataset(dataset, key="xor")# Start workerworker.start()return workerif __name__ == "__main__":hook = sy.TorchHook(th)args = parser.parse_args()kwargs = {"id": args.id,"host": args.host,"port": args.port,"hook": hook,"verbose": args.verbose,}main(**kwargs)

启动worker

  python start_worker.py --host 172.16.5.45 --port 8777 --id alice

客户端代码:

import inspect
import start_workerprint(inspect.getsource(start_worker.main))# Dependencies
import torch as th
import torch.nn.functional as F
from torch import nnuse_cuda = th.cuda.is_available()
th.manual_seed(1)
device = th.device("cuda" if use_cuda else "cpu")import syft as sy
from syft import workershook = sy.TorchHook(th)  # hook torch as always :)class Net(th.nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(2, 20)self.fc2 = nn.Linear(20, 10)self.fc3 = nn.Linear(10, 1)def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# Instantiate the model
model = Net()# The data itself doesn't matter as long as the shape is right
mock_data = th.zeros(1, 2)# Create a jit version of the model
traced_model = th.jit.trace(model, mock_data)type(traced_model)# Loss function
@th.jit.script
def loss_fn(target, pred):return ((target.view(pred.shape).float() - pred.float()) ** 2).mean()type(loss_fn)optimizer = "SGD"batch_size = 4
optimizer_args = {"lr" : 0.1, "weight_decay" : 0.01}
epochs = 1
max_nr_batches = -1  # not used in this example
shuffle = Truetrain_config = sy.TrainConfig(model=traced_model,loss_fn=loss_fn,optimizer=optimizer,batch_size=batch_size,optimizer_args=optimizer_args,epochs=epochs,shuffle=shuffle)kwargs_websocket = {"host": "172.16.5.45", "hook": hook, "verbose": False}
alice = workers.websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)# Send train config
train_config.send(alice)# Setup toy data (xor example)
data = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
target = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)print("\nEvaluation before training")
pred = model(data)
loss = loss_fn(target=target, pred=pred)
print("Loss: {}".format(loss))
print("Target: {}".format(target))
print("Pred: {}".format(pred))for epoch in range(10):loss = alice.fit(dataset_key="xor")  # ask alice to train using "xor" datasetprint("-" * 50)print("Iteration %s: alice's loss: %s" % (epoch, loss))new_model = train_config.model_ptr.get()print("\nEvaluation after training:")
pred = new_model(data)
loss = loss_fn(target=target, pred=pred)
print("Loss: {}".format(loss))
print("Target: {}".format(target))
print("Pred: {}".format(pred))

运行:

python worker-client.py 

输出结果:

Evaluation before training
Loss: 0.4933376908302307
Target: tensor([[1.],[1.],[0.],[0.]])
Pred: tensor([[ 0.1258],[-0.0994],[ 0.0033],[ 0.0210]], grad_fn=<AddmmBackward>)
--------------------------------------------------
Iteration 0: alice's loss: tensor(0.4933, requires_grad=True)
--------------------------------------------------
Iteration 1: alice's loss: tensor(0.3484, requires_grad=True)
--------------------------------------------------
Iteration 2: alice's loss: tensor(0.2858, requires_grad=True)
--------------------------------------------------
Iteration 3: alice's loss: tensor(0.2626, requires_grad=True)
--------------------------------------------------
Iteration 4: alice's loss: tensor(0.2529, requires_grad=True)
--------------------------------------------------
Iteration 5: alice's loss: tensor(0.2474, requires_grad=True)
--------------------------------------------------
Iteration 6: alice's loss: tensor(0.2441, requires_grad=True)
--------------------------------------------------
Iteration 7: alice's loss: tensor(0.2412, requires_grad=True)
--------------------------------------------------
Iteration 8: alice's loss: tensor(0.2388, requires_grad=True)
--------------------------------------------------
Iteration 9: alice's loss: tensor(0.2368, requires_grad=True)Evaluation after training:
Loss: 0.23491761088371277
Target: tensor([[1.],[1.],[0.],[0.]])
Pred: tensor([[0.6553],[0.3781],[0.4834],[0.4477]], grad_fn=<DifferentiableGraphBackward>)

这篇关于使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/240449

相关文章

从零教你安装pytorch并在pycharm中使用

《从零教你安装pytorch并在pycharm中使用》本文详细介绍了如何使用Anaconda包管理工具创建虚拟环境,并安装CUDA加速平台和PyTorch库,同时在PyCharm中配置和使用PyTor... 目录背景介绍安装Anaconda安装CUDA安装pytorch报错解决——fbgemm.dll连接p

SpringBoot快速接入OpenAI大模型的方法(JDK8)

《SpringBoot快速接入OpenAI大模型的方法(JDK8)》本文介绍了如何使用AI4J快速接入OpenAI大模型,并展示了如何实现流式与非流式的输出,以及对函数调用的使用,AI4J支持JDK8... 目录使用AI4J快速接入OpenAI大模型介绍AI4J-github快速使用创建SpringBoot

Vue项目的甘特图组件之dhtmlx-gantt使用教程和实现效果展示(推荐)

《Vue项目的甘特图组件之dhtmlx-gantt使用教程和实现效果展示(推荐)》文章介绍了如何使用dhtmlx-gantt组件来实现公司的甘特图需求,并提供了一个简单的Vue组件示例,文章还分享了一... 目录一、首先 npm 安装插件二、创建一个vue组件三、业务页面内 引用自定义组件:四、dhtmlx

使用Python创建一个能够筛选文件的PDF合并工具

《使用Python创建一个能够筛选文件的PDF合并工具》这篇文章主要为大家详细介绍了如何使用Python创建一个能够筛选文件的PDF合并工具,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下... 目录背景主要功能全部代码代码解析1. 初始化 wx.Frame 窗口2. 创建工具栏3. 创建布局和界面控件4

一文详解如何在Python中使用Requests库

《一文详解如何在Python中使用Requests库》:本文主要介绍如何在Python中使用Requests库的相关资料,Requests库是Python中常用的第三方库,用于简化HTTP请求的发... 目录前言1. 安装Requests库2. 发起GET请求3. 发送带有查询参数的GET请求4. 发起PO

Java中的Cursor使用详解

《Java中的Cursor使用详解》本文介绍了Java中的Cursor接口及其在大数据集处理中的优势,包括逐行读取、分页处理、流控制、动态改变查询、并发控制和减少网络流量等,感兴趣的朋友一起看看吧... 最近看代码,有一段代码涉及到Cursor,感觉写法挺有意思的。注意是Cursor,而不是Consumer

javaScript在表单提交时获取表单数据的示例代码

《javaScript在表单提交时获取表单数据的示例代码》本文介绍了五种在JavaScript中获取表单数据的方法:使用FormData对象、手动提取表单数据、使用querySelector获取单个字... 方法 1:使用 FormData 对象FormData 是一个方便的内置对象,用于获取表单中的键值

Node.js net模块的使用示例

《Node.jsnet模块的使用示例》本文主要介绍了Node.jsnet模块的使用示例,net模块支持TCP通信,处理TCP连接和数据传输,具有一定的参考价值,感兴趣的可以了解一下... 目录简介引入 net 模块核心概念TCP (传输控制协议)Socket服务器TCP 服务器创建基本服务器服务器配置选项服

如何使用CSS3实现波浪式图片墙

《如何使用CSS3实现波浪式图片墙》:本文主要介绍了如何使用CSS3的transform属性和动画技巧实现波浪式图片墙,通过设置图片的垂直偏移量,并使用动画使其周期性地改变位置,可以创建出动态且具有波浪效果的图片墙,同时,还强调了响应式设计的重要性,以确保图片墙在不同设备上都能良好显示,详细内容请阅读本文,希望能对你有所帮助...

Rust中的注释使用解读

《Rust中的注释使用解读》本文介绍了Rust中的行注释、块注释和文档注释的使用方法,通过示例展示了如何在实际代码中应用这些注释,以提高代码的可读性和可维护性... 目录Rust 中的注释使用指南1. 行注释示例:行注释2. 块注释示例:块注释3. 文档注释示例:文档注释4. 综合示例总结Rust 中的注释