使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练

本文主要是介绍使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

WebsocketServerWorker端代码:start_worker.py

import argparseimport torch as th
from syft.workers.websocket_server import WebsocketServerWorkerimport syft as sy# Arguments
parser = argparse.ArgumentParser(description="Run websocket server worker.")
parser.add_argument("--port", "-p", type=int, help="port number of the websocket server worker, e.g. --port 8777"
)
parser.add_argument("--host", type=str, default="localhost", help="host for the connection")
parser.add_argument("--id", type=str, help="name (id) of the websocket server worker, e.g. --id alice"
)
parser.add_argument("--verbose","-v",action="store_true",help="if set, websocket server worker will be started in verbose mode",
)def main(**kwargs):  # pragma: no cover"""Helper function for spinning up a websocket participant."""# Create websocket workerworker = WebsocketServerWorker(**kwargs)# Setup toy data (xor example)data = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)target = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)# Create a dataset using the toy datadataset = sy.BaseDataset(data, target)# Tell the worker about the datasetworker.add_dataset(dataset, key="xor")# Start workerworker.start()return workerif __name__ == "__main__":hook = sy.TorchHook(th)args = parser.parse_args()kwargs = {"id": args.id,"host": args.host,"port": args.port,"hook": hook,"verbose": args.verbose,}main(**kwargs)

启动worker

  python start_worker.py --host 172.16.5.45 --port 8777 --id alice

客户端代码:

import inspect
import start_workerprint(inspect.getsource(start_worker.main))# Dependencies
import torch as th
import torch.nn.functional as F
from torch import nnuse_cuda = th.cuda.is_available()
th.manual_seed(1)
device = th.device("cuda" if use_cuda else "cpu")import syft as sy
from syft import workershook = sy.TorchHook(th)  # hook torch as always :)class Net(th.nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(2, 20)self.fc2 = nn.Linear(20, 10)self.fc3 = nn.Linear(10, 1)def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# Instantiate the model
model = Net()# The data itself doesn't matter as long as the shape is right
mock_data = th.zeros(1, 2)# Create a jit version of the model
traced_model = th.jit.trace(model, mock_data)type(traced_model)# Loss function
@th.jit.script
def loss_fn(target, pred):return ((target.view(pred.shape).float() - pred.float()) ** 2).mean()type(loss_fn)optimizer = "SGD"batch_size = 4
optimizer_args = {"lr" : 0.1, "weight_decay" : 0.01}
epochs = 1
max_nr_batches = -1  # not used in this example
shuffle = Truetrain_config = sy.TrainConfig(model=traced_model,loss_fn=loss_fn,optimizer=optimizer,batch_size=batch_size,optimizer_args=optimizer_args,epochs=epochs,shuffle=shuffle)kwargs_websocket = {"host": "172.16.5.45", "hook": hook, "verbose": False}
alice = workers.websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)# Send train config
train_config.send(alice)# Setup toy data (xor example)
data = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
target = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)print("\nEvaluation before training")
pred = model(data)
loss = loss_fn(target=target, pred=pred)
print("Loss: {}".format(loss))
print("Target: {}".format(target))
print("Pred: {}".format(pred))for epoch in range(10):loss = alice.fit(dataset_key="xor")  # ask alice to train using "xor" datasetprint("-" * 50)print("Iteration %s: alice's loss: %s" % (epoch, loss))new_model = train_config.model_ptr.get()print("\nEvaluation after training:")
pred = new_model(data)
loss = loss_fn(target=target, pred=pred)
print("Loss: {}".format(loss))
print("Target: {}".format(target))
print("Pred: {}".format(pred))

运行:

python worker-client.py 

输出结果:

Evaluation before training
Loss: 0.4933376908302307
Target: tensor([[1.],[1.],[0.],[0.]])
Pred: tensor([[ 0.1258],[-0.0994],[ 0.0033],[ 0.0210]], grad_fn=<AddmmBackward>)
--------------------------------------------------
Iteration 0: alice's loss: tensor(0.4933, requires_grad=True)
--------------------------------------------------
Iteration 1: alice's loss: tensor(0.3484, requires_grad=True)
--------------------------------------------------
Iteration 2: alice's loss: tensor(0.2858, requires_grad=True)
--------------------------------------------------
Iteration 3: alice's loss: tensor(0.2626, requires_grad=True)
--------------------------------------------------
Iteration 4: alice's loss: tensor(0.2529, requires_grad=True)
--------------------------------------------------
Iteration 5: alice's loss: tensor(0.2474, requires_grad=True)
--------------------------------------------------
Iteration 6: alice's loss: tensor(0.2441, requires_grad=True)
--------------------------------------------------
Iteration 7: alice's loss: tensor(0.2412, requires_grad=True)
--------------------------------------------------
Iteration 8: alice's loss: tensor(0.2388, requires_grad=True)
--------------------------------------------------
Iteration 9: alice's loss: tensor(0.2368, requires_grad=True)Evaluation after training:
Loss: 0.23491761088371277
Target: tensor([[1.],[1.],[0.],[0.]])
Pred: tensor([[0.6553],[0.3781],[0.4834],[0.4477]], grad_fn=<DifferentiableGraphBackward>)

这篇关于使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/240449

相关文章

Java使用ANTLR4对Lua脚本语法校验详解

《Java使用ANTLR4对Lua脚本语法校验详解》ANTLR是一个强大的解析器生成器,用于读取、处理、执行或翻译结构化文本或二进制文件,下面就跟随小编一起看看Java如何使用ANTLR4对Lua脚本... 目录什么是ANTLR?第一个例子ANTLR4 的工作流程Lua脚本语法校验准备一个Lua Gramm

Java Optional的使用技巧与最佳实践

《JavaOptional的使用技巧与最佳实践》在Java中,Optional是用于优雅处理null的容器类,其核心目标是显式提醒开发者处理空值场景,避免NullPointerExce... 目录一、Optional 的核心用途二、使用技巧与最佳实践三、常见误区与反模式四、替代方案与扩展五、总结在 Java

使用Java将DOCX文档解析为Markdown文档的代码实现

《使用Java将DOCX文档解析为Markdown文档的代码实现》在现代文档处理中,Markdown(MD)因其简洁的语法和良好的可读性,逐渐成为开发者、技术写作者和内容创作者的首选格式,然而,许多文... 目录引言1. 工具和库介绍2. 安装依赖库3. 使用Apache POI解析DOCX文档4. 将解析

Qt中QUndoView控件的具体使用

《Qt中QUndoView控件的具体使用》QUndoView是Qt框架中用于可视化显示QUndoStack内容的控件,本文主要介绍了Qt中QUndoView控件的具体使用,具有一定的参考价值,感兴趣的... 目录引言一、QUndoView 的用途二、工作原理三、 如何与 QUnDOStack 配合使用四、自

C++使用printf语句实现进制转换的示例代码

《C++使用printf语句实现进制转换的示例代码》在C语言中,printf函数可以直接实现部分进制转换功能,通过格式说明符(formatspecifier)快速输出不同进制的数值,下面给大家分享C+... 目录一、printf 原生支持的进制转换1. 十进制、八进制、十六进制转换2. 显示进制前缀3. 指

使用Python构建一个Hexo博客发布工具

《使用Python构建一个Hexo博客发布工具》虽然Hexo的命令行工具非常强大,但对于日常的博客撰写和发布过程,我总觉得缺少一个直观的图形界面来简化操作,下面我们就来看看如何使用Python构建一个... 目录引言Hexo博客系统简介设计需求技术选择代码实现主框架界面设计核心功能实现1. 发布文章2. 加

SpringBoot集成Milvus实现数据增删改查功能

《SpringBoot集成Milvus实现数据增删改查功能》milvus支持的语言比较多,支持python,Java,Go,node等开发语言,本文主要介绍如何使用Java语言,采用springboo... 目录1、Milvus基本概念2、添加maven依赖3、配置yml文件4、创建MilvusClient

shell编程之函数与数组的使用详解

《shell编程之函数与数组的使用详解》:本文主要介绍shell编程之函数与数组的使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录shell函数函数的用法俩个数求和系统资源监控并报警函数函数变量的作用范围函数的参数递归函数shell数组获取数组的长度读取某下的

使用Python开发一个带EPUB转换功能的Markdown编辑器

《使用Python开发一个带EPUB转换功能的Markdown编辑器》Markdown因其简单易用和强大的格式支持,成为了写作者、开发者及内容创作者的首选格式,本文将通过Python开发一个Markd... 目录应用概览代码结构与核心组件1. 初始化与布局 (__init__)2. 工具栏 (setup_t

SpringValidation数据校验之约束注解与分组校验方式

《SpringValidation数据校验之约束注解与分组校验方式》本文将深入探讨SpringValidation的核心功能,帮助开发者掌握约束注解的使用技巧和分组校验的高级应用,从而构建更加健壮和可... 目录引言一、Spring Validation基础架构1.1 jsR-380标准与Spring整合1