PyTorchLightning集成SwanLab进行训练监控与可视化

2024-05-24 16:52

本文主要是介绍PyTorchLightning集成SwanLab进行训练监控与可视化,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文档:https://docs.swanlab.cn/zh/guide_cloud/integration/integration-pytorch-lightning.html

PyTorch Lightning是一个开源的机器学习库,它建立在 PyTorch 之上,旨在帮助研究人员和开发者更加方便地进行深度学习模型的研发。Lightning 的设计理念是将模型训练中的繁琐代码(如设备管理、分布式训练等)与研究代码(模型架构、数据处理等)分离,从而使研究人员可以专注于研究本身,而不是底层的工程细节。

你可以使用PyTorch Lightning快速进行模型训练,同时使用SwanLab进行实验跟踪与可视化。

1. 引入SwanLabLogger

from swanlab.integration.pytorch_lightning import SwanLabLogger

SwanLabLogger是适配于PyTorch Lightning的日志记录类。

SwanLabLogger可以定义的参数有:

  • project、experiment_name、description等与swanlab.init效果一致的参数

2. 传入Trainer

import pytorch_lightning as pl...# 实例化SwanLabLogger
swanlab_logger = SwanLabLogger(project="lightning-visualization")trainer = pl.Trainer(...# 传入callbacks参数logger=swanlab_logger,
)trainer.fit(...)

3. 完整案例代码

from swanlab.integration.pytorch_lightning import SwanLabLoggerimport importlib.util
import osimport pytorch_lightning as pl
from torch import nn, optim, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))# define the LightningModule
class LitAutoEncoder(pl.LightningModule):def __init__(self, encoder, decoder):super().__init__()self.encoder = encoderself.decoder = decoderdef training_step(self, batch, batch_idx):# training_step defines the train loop.# it is independent of forwardx, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)loss = nn.functional.mse_loss(x_hat, x)# Logging to TensorBoard (if installed) by defaultself.log("train_loss", loss)return lossdef test_step(self, batch, batch_idx):# test_step defines the test loop.# it is independent of forwardx, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)loss = nn.functional.mse_loss(x_hat, x)# Logging to TensorBoard (if installed) by defaultself.log("test_loss", loss)return lossdef configure_optimizers(self):optimizer = optim.Adam(self.parameters(), lr=1e-3)return optimizer# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)# setup data
dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
train_dataset, val_dataset = utils.data.random_split(dataset, [55000, 5000])
test_dataset = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())train_loader = utils.data.DataLoader(train_dataset)
val_loader = utils.data.DataLoader(val_dataset)
test_loader = utils.data.DataLoader(test_dataset)swanlab_logger = SwanLabLogger(project="swanlab_example",experiment_name="example_experiment",
)trainer = pl.Trainer(limit_train_batches=100, max_epochs=5, logger=swanlab_logger)trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(dataloaders=test_loader)

这篇关于PyTorchLightning集成SwanLab进行训练监控与可视化的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/998980

相关文章

使用Python实现矢量路径的压缩、解压与可视化

《使用Python实现矢量路径的压缩、解压与可视化》在图形设计和Web开发中,矢量路径数据的高效存储与传输至关重要,本文将通过一个Python示例,展示如何将复杂的矢量路径命令序列压缩为JSON格式,... 目录引言核心功能概述1. 路径命令解析2. 路径数据压缩3. 路径数据解压4. 可视化代码实现详解1

Python 交互式可视化的利器Bokeh的使用

《Python交互式可视化的利器Bokeh的使用》Bokeh是一个专注于Web端交互式数据可视化的Python库,本文主要介绍了Python交互式可视化的利器Bokeh的使用,具有一定的参考价值,感... 目录1. Bokeh 简介1.1 为什么选择 Bokeh1.2 安装与环境配置2. Bokeh 基础2

Pandas使用AdaBoost进行分类的实现

《Pandas使用AdaBoost进行分类的实现》Pandas和AdaBoost分类算法,可以高效地进行数据预处理和分类任务,本文主要介绍了Pandas使用AdaBoost进行分类的实现,具有一定的参... 目录什么是 AdaBoost?使用 AdaBoost 的步骤安装必要的库步骤一:数据准备步骤二:模型

使用Pandas进行均值填充的实现

《使用Pandas进行均值填充的实现》缺失数据(NaN值)是一个常见的问题,我们可以通过多种方法来处理缺失数据,其中一种常用的方法是均值填充,本文主要介绍了使用Pandas进行均值填充的实现,感兴趣的... 目录什么是均值填充?为什么选择均值填充?均值填充的步骤实际代码示例总结在数据分析和处理过程中,缺失数

Spring Boot 集成 Quartz并使用Cron 表达式实现定时任务

《SpringBoot集成Quartz并使用Cron表达式实现定时任务》本篇文章介绍了如何在SpringBoot中集成Quartz进行定时任务调度,并通过Cron表达式控制任务... 目录前言1. 添加 Quartz 依赖2. 创建 Quartz 任务3. 配置 Quartz 任务调度4. 启动 Sprin

QT进行CSV文件初始化与读写操作

《QT进行CSV文件初始化与读写操作》这篇文章主要为大家详细介绍了在QT环境中如何进行CSV文件的初始化、写入和读取操作,本文为大家整理了相关的操作的多种方法,希望对大家有所帮助... 目录前言一、CSV文件初始化二、CSV写入三、CSV读取四、QT 逐行读取csv文件五、Qt如何将数据保存成CSV文件前言

SpringBoot集成Milvus实现数据增删改查功能

《SpringBoot集成Milvus实现数据增删改查功能》milvus支持的语言比较多,支持python,Java,Go,node等开发语言,本文主要介绍如何使用Java语言,采用springboo... 目录1、Milvus基本概念2、添加maven依赖3、配置yml文件4、创建MilvusClient

通过Spring层面进行事务回滚的实现

《通过Spring层面进行事务回滚的实现》本文主要介绍了通过Spring层面进行事务回滚的实现,包括声明式事务和编程式事务,具有一定的参考价值,感兴趣的可以了解一下... 目录声明式事务回滚:1. 基础注解配置2. 指定回滚异常类型3. ​不回滚特殊场景编程式事务回滚:1. ​使用 TransactionT

Java中使用Hutool进行AES加密解密的方法举例

《Java中使用Hutool进行AES加密解密的方法举例》AES是一种对称加密,所谓对称加密就是加密与解密使用的秘钥是一个,下面:本文主要介绍Java中使用Hutool进行AES加密解密的相关资料... 目录前言一、Hutool简介与引入1.1 Hutool简介1.2 引入Hutool二、AES加密解密基础

SpringSecurity6.0 如何通过JWTtoken进行认证授权

《SpringSecurity6.0如何通过JWTtoken进行认证授权》:本文主要介绍SpringSecurity6.0通过JWTtoken进行认证授权的过程,本文给大家介绍的非常详细,感兴趣... 目录项目依赖认证UserDetailService生成JWT token权限控制小结之前写过一个文章,从S