This article shows how to integrate SwanLab with PyTorch Lightning for training monitoring and visualization. We hope it serves as a useful reference for developers facing this task.
Docs: https://docs.swanlab.cn/zh/guide_cloud/integration/integration-pytorch-lightning.html
PyTorch Lightning is an open-source machine learning library built on top of PyTorch, designed to help researchers and developers build deep learning models more conveniently. Lightning's design philosophy is to separate the boilerplate of model training (device management, distributed training, and so on) from the research code (model architecture, data processing, and so on), so that researchers can focus on the research itself rather than low-level engineering details.
You can use PyTorch Lightning to train models quickly while using SwanLab for experiment tracking and visualization.
1. Import SwanLabLogger

```python
from swanlab.integration.pytorch_lightning import SwanLabLogger
```
SwanLabLogger is a logging class adapted for PyTorch Lightning.

Parameters you can pass to SwanLabLogger include:

- project, experiment_name, description, and other parameters that behave the same as their counterparts in swanlab.init
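As a sketch of how these parameters flow through, SwanLabLogger's keyword arguments are forwarded to swanlab.init. The function below is a hypothetical illustration of that forwarding, not SwanLab's actual code (the name make_init_kwargs and its logic are assumptions for demonstration):

```python
# Hypothetical sketch: collect keyword arguments as SwanLabLogger would
# forward them to swanlab.init. Not the real SwanLab implementation.
def make_init_kwargs(project, experiment_name=None, description=None, **extra):
    """Gather the arguments that would be passed on to swanlab.init."""
    kwargs = {"project": project}
    if experiment_name is not None:
        kwargs["experiment_name"] = experiment_name
    if description is not None:
        kwargs["description"] = description
    kwargs.update(extra)  # any other swanlab.init parameter is passed through
    return kwargs

print(make_init_kwargs("lightning-visualization", experiment_name="run-1"))
# → {'project': 'lightning-visualization', 'experiment_name': 'run-1'}
```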
2. Pass it to the Trainer

```python
import pytorch_lightning as pl

# Instantiate SwanLabLogger
swanlab_logger = SwanLabLogger(project="lightning-visualization")

trainer = pl.Trainer(
    ...,
    # pass it via the logger parameter
    logger=swanlab_logger,
)
trainer.fit(...)
```
3. Complete example

```python
import os

import pytorch_lightning as pl
from torch import nn, optim, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from swanlab.integration.pytorch_lightning import SwanLabLogger

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))


# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop; it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # logged metrics are sent to SwanLab via the logger passed to Trainer
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # validation_step is needed so the val_dataloaders passed to fit()
        # are actually used for a validation loop
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        # test_step defines the test loop; it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

# setup data
dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
train_dataset, val_dataset = utils.data.random_split(dataset, [55000, 5000])
test_dataset = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())

train_loader = utils.data.DataLoader(train_dataset)
val_loader = utils.data.DataLoader(val_dataset)
test_loader = utils.data.DataLoader(test_dataset)

swanlab_logger = SwanLabLogger(
    project="swanlab_example",
    experiment_name="example_experiment",
)

trainer = pl.Trainer(limit_train_batches=100, max_epochs=5, logger=swanlab_logger)
trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(dataloaders=test_loader)
```
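Two quick sanity checks on the numbers used in the example above, in plain Python: the 55000/5000 random_split covers MNIST's 60,000 training images exactly, and flattening a 28x28 image yields the 784 input features expected by the encoder's first nn.Linear layer:

```python
# MNIST's training set has 60,000 images; the example splits it
# into 55,000 training and 5,000 validation samples.
train_size, val_size = 55_000, 5_000
assert train_size + val_size == 60_000

# Each 28x28 image is flattened by x.view(x.size(0), -1) into a vector
# whose length must match the encoder's first layer, nn.Linear(28 * 28, 128).
height, width = 28, 28
flat_features = height * width
print(flat_features)  # → 784
```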
That concludes this article on integrating SwanLab with PyTorch Lightning for training monitoring and visualization. We hope it proves helpful to developers!