使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练

本文主要是介绍使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

WebsocketServerWorker端代码:start_worker.py

import argparseimport torch as th
from syft.workers.websocket_server import WebsocketServerWorkerimport syft as sy# Arguments
parser = argparse.ArgumentParser(description="Run websocket server worker.")
parser.add_argument("--port", "-p", type=int, help="port number of the websocket server worker, e.g. --port 8777"
)
parser.add_argument("--host", type=str, default="localhost", help="host for the connection")
parser.add_argument("--id", type=str, help="name (id) of the websocket server worker, e.g. --id alice"
)
parser.add_argument("--verbose","-v",action="store_true",help="if set, websocket server worker will be started in verbose mode",
)def main(**kwargs):  # pragma: no cover"""Helper function for spinning up a websocket participant."""# Create websocket workerworker = WebsocketServerWorker(**kwargs)# Setup toy data (xor example)data = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)target = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)# Create a dataset using the toy datadataset = sy.BaseDataset(data, target)# Tell the worker about the datasetworker.add_dataset(dataset, key="xor")# Start workerworker.start()return workerif __name__ == "__main__":hook = sy.TorchHook(th)args = parser.parse_args()kwargs = {"id": args.id,"host": args.host,"port": args.port,"hook": hook,"verbose": args.verbose,}main(**kwargs)

启动worker

  python start_worker.py --host 172.16.5.45 --port 8777 --id alice

客户端代码:

import inspect
import start_workerprint(inspect.getsource(start_worker.main))# Dependencies
import torch as th
import torch.nn.functional as F
from torch import nnuse_cuda = th.cuda.is_available()
th.manual_seed(1)
device = th.device("cuda" if use_cuda else "cpu")import syft as sy
from syft import workershook = sy.TorchHook(th)  # hook torch as always :)class Net(th.nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(2, 20)self.fc2 = nn.Linear(20, 10)self.fc3 = nn.Linear(10, 1)def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# Instantiate the model
model = Net()# The data itself doesn't matter as long as the shape is right
mock_data = th.zeros(1, 2)# Create a jit version of the model
traced_model = th.jit.trace(model, mock_data)type(traced_model)# Loss function
@th.jit.script
def loss_fn(target, pred):return ((target.view(pred.shape).float() - pred.float()) ** 2).mean()type(loss_fn)optimizer = "SGD"batch_size = 4
optimizer_args = {"lr" : 0.1, "weight_decay" : 0.01}
epochs = 1
max_nr_batches = -1  # not used in this example
shuffle = Truetrain_config = sy.TrainConfig(model=traced_model,loss_fn=loss_fn,optimizer=optimizer,batch_size=batch_size,optimizer_args=optimizer_args,epochs=epochs,shuffle=shuffle)kwargs_websocket = {"host": "172.16.5.45", "hook": hook, "verbose": False}
alice = workers.websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)# Send train config
train_config.send(alice)# Setup toy data (xor example)
data = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
target = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)print("\nEvaluation before training")
pred = model(data)
loss = loss_fn(target=target, pred=pred)
print("Loss: {}".format(loss))
print("Target: {}".format(target))
print("Pred: {}".format(pred))for epoch in range(10):loss = alice.fit(dataset_key="xor")  # ask alice to train using "xor" datasetprint("-" * 50)print("Iteration %s: alice's loss: %s" % (epoch, loss))new_model = train_config.model_ptr.get()print("\nEvaluation after training:")
pred = new_model(data)
loss = loss_fn(target=target, pred=pred)
print("Loss: {}".format(loss))
print("Target: {}".format(target))
print("Pred: {}".format(pred))

运行:

python worker-client.py 

输出结果:

Evaluation before training
Loss: 0.4933376908302307
Target: tensor([[1.],[1.],[0.],[0.]])
Pred: tensor([[ 0.1258],[-0.0994],[ 0.0033],[ 0.0210]], grad_fn=<AddmmBackward>)
--------------------------------------------------
Iteration 0: alice's loss: tensor(0.4933, requires_grad=True)
--------------------------------------------------
Iteration 1: alice's loss: tensor(0.3484, requires_grad=True)
--------------------------------------------------
Iteration 2: alice's loss: tensor(0.2858, requires_grad=True)
--------------------------------------------------
Iteration 3: alice's loss: tensor(0.2626, requires_grad=True)
--------------------------------------------------
Iteration 4: alice's loss: tensor(0.2529, requires_grad=True)
--------------------------------------------------
Iteration 5: alice's loss: tensor(0.2474, requires_grad=True)
--------------------------------------------------
Iteration 6: alice's loss: tensor(0.2441, requires_grad=True)
--------------------------------------------------
Iteration 7: alice's loss: tensor(0.2412, requires_grad=True)
--------------------------------------------------
Iteration 8: alice's loss: tensor(0.2388, requires_grad=True)
--------------------------------------------------
Iteration 9: alice's loss: tensor(0.2368, requires_grad=True)Evaluation after training:
Loss: 0.23491761088371277
Target: tensor([[1.],[1.],[0.],[0.]])
Pred: tensor([[0.6553],[0.3781],[0.4834],[0.4477]], grad_fn=<DifferentiableGraphBackward>)

这篇关于使用pysyft发送模型给带数据集的远端WebsocketServerWorker作联合训练的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/240449

相关文章

Python数据验证神器Pydantic库的使用和实践中的避坑指南

《Python数据验证神器Pydantic库的使用和实践中的避坑指南》Pydantic是一个用于数据验证和设置的库,可以显著简化API接口开发,文章通过一个实际案例,展示了Pydantic如何在生产环... 目录1️⃣ 崩溃时刻:当你的API接口又双叒崩了!2️⃣ 神兵天降:3行代码解决验证难题3️⃣ 深度

Linux内核定时器使用及说明

《Linux内核定时器使用及说明》文章详细介绍了Linux内核定时器的特性、核心数据结构、时间相关转换函数以及操作API,通过示例展示了如何编写和使用定时器,包括按键消抖的应用... 目录1.linux内核定时器特征2.Linux内核定时器核心数据结构3.Linux内核时间相关转换函数4.Linux内核定时

python中的flask_sqlalchemy的使用及示例详解

《python中的flask_sqlalchemy的使用及示例详解》文章主要介绍了在使用SQLAlchemy创建模型实例时,通过元类动态创建实例的方式,并说明了如何在实例化时执行__init__方法,... 目录@orm.reconstructorSQLAlchemy的回滚关联其他模型数据库基本操作将数据添

Spring配置扩展之JavaConfig的使用小结

《Spring配置扩展之JavaConfig的使用小结》JavaConfig是Spring框架中基于纯Java代码的配置方式,用于替代传统的XML配置,通过注解(如@Bean)定义Spring容器的组... 目录JavaConfig 的概念什么是JavaConfig?为什么使用 JavaConfig?Jav

MySQL快速复制一张表的四种核心方法(包括表结构和数据)

《MySQL快速复制一张表的四种核心方法(包括表结构和数据)》本文详细介绍了四种复制MySQL表(结构+数据)的方法,并对每种方法进行了对比分析,适用于不同场景和数据量的复制需求,特别是针对超大表(1... 目录一、mysql 复制表(结构+数据)的 4 种核心方法(面试结构化回答)方法 1:CREATE

详解C++ 存储二进制数据容器的几种方法

《详解C++存储二进制数据容器的几种方法》本文主要介绍了详解C++存储二进制数据容器,包括std::vector、std::array、std::string、std::bitset和std::ve... 目录1.std::vector<uint8_t>(最常用)特点:适用场景:示例:2.std::arra

Java使用Spire.Doc for Java实现Word自动化插入图片

《Java使用Spire.DocforJava实现Word自动化插入图片》在日常工作中,Word文档是不可或缺的工具,而图片作为信息传达的重要载体,其在文档中的插入与布局显得尤为关键,下面我们就来... 目录1. Spire.Doc for Java库介绍与安装2. 使用特定的环绕方式插入图片3. 在指定位

Springboot3 ResponseEntity 完全使用案例

《Springboot3ResponseEntity完全使用案例》ResponseEntity是SpringBoot中控制HTTP响应的核心工具——它能让你精准定义响应状态码、响应头、响应体,相比... 目录Spring Boot 3 ResponseEntity 完全使用教程前置准备1. 项目基础依赖(M

Java使用Spire.Barcode for Java实现条形码生成与识别

《Java使用Spire.BarcodeforJava实现条形码生成与识别》在现代商业和技术领域,条形码无处不在,本教程将引导您深入了解如何在您的Java项目中利用Spire.Barcodefor... 目录1. Spire.Barcode for Java 简介与环境配置2. 使用 Spire.Barco

Android使用java实现网络连通性检查详解

《Android使用java实现网络连通性检查详解》这篇文章主要为大家详细介绍了Android使用java实现网络连通性检查的相关知识,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录NetCheck.Java(可直接拷贝)使用示例(Activity/Fragment 内)权限要求