Pytorch-Lighting使用教程(MNIST为例)

2024-06-02 07:12

本文主要是介绍Pytorch-Lighting使用教程(MNIST为例),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、pytorch-lighting简介

1.1 pytorch-lighting是什么

pytorch-lighting(简称pl),基于 PyTorch 的框架。它的核心思想是,将学术代码模型定义、前向 / 反向、优化器、验证等)与工程代码for-loop,保存、tensorboard 日志、训练策略等)解耦开来,使得代码更为简洁清晰。

工程代码经常会出现在深度学习代码中,PyTorch Lightning 对这部分逻辑进行了封装,只需要在 Trainer 类中简单设置即可调用,无需重复造轮子。

1.2 pytorch-lighting优势

  • 通过抽象出样板工程代码,可以更容易地识别和理解ML代码;
  • Lightning的统一结构使得在现有项目的基础上进行构建和理解变得非常容易;
  • Lightning 自动化的代码是用经过全面测试、定期维护并遵循ML最佳实践的高质量代码构建的;

pytorch-lighting最大的好处:

(1)是摆脱了硬件依赖,不需要在程序中显式设置.cuda() 等,PyTorch Lightning 会自动将模型、张量的设备放置在合适的设备;移除.train() 等代码,这也会自动切换

(2)支持分布式训练,自动分配资源,能够很好的进行大规模的DL训练

(3)代码量较少,只需要关心关键的逻辑代码,而框架性的东西,pytorch-lighting已经帮你解决(如自动训练,自动debug)


二、基于Pytorch-Lighting框架训练MNIST模型

1、仅仅训练

下载的所有的数据集都用于训练(没有评估和测试过程,不清楚模型的好与坏)。

# 1. 导入所需的模块
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning.pytorch as pl# 2. 定义编码器和解码器
# 2.1 定义基础编码器Encoder
class Encoder(nn.Module):def __init__(self):super().__init__()self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))def forward(self, x):return self.l1(x)# 2.2 定义基础解码器Decoder
class Decoder(nn.Module):def __init__(self):super().__init__()self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))def forward(self, x):return self.l1(x)# 3. 定义LightningModule
class LitAutoEncoder(pl.LightningModule):# 3.1 加载基础模型def __init__(self, encoder, decoder):super().__init__()self.encoder = encoderself.decoder = decoder# 3.2 训练过程设置def training_step(self, batch, batch_idx):  # 每一个batch数据运算计算loss# training_step defines the train loop.x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)loss = F.mse_loss(x_hat, x)return loss# 3.3 优化器设置def configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)return optimizer# 4. 定义训练数据
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)# 5. 实例化模型
autoencoder = LitAutoEncoder(Encoder(), Decoder())# 6. 开始训练
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

class LitAutoEncoder(pl.LightningModule):

  • 将模型定义代码写在__init__
  • 定义前向传播逻辑
  • 将优化器代码写在 configure_optimizers 钩子中
  • 训练代码写在 training_step 钩子中,可使用 self.log 随时记录变量的值,会保存在 tensorboard 中
  • 验证代码写在 validation_step 钩子中
  • 移除硬件调用.cuda() 等,PyTorch Lightning 会自动将模型、张量的设备放置在合适的设备;移除.train() 等代码,这也会自动切换
  • 根据需要,重写其他钩子函数,例如 validation_epoch_end,对 validation_step 的结果进行汇总;train_dataloader,定义训练数据的加载逻辑
  • 实例化 Lightning Module 和 Trainer 对象,传入数据集
  • 定义训练参数和回调函数,例如训练设备、数量、保存策略,Early Stop、半精度等

运行结果:

2、添加验证和测试模块

在训练之后,加入了测试和评估功能,能更好的指导模型的性能。

# 1. 导入所需的模块
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning.pytorch as plimport torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transformsfrom torch.utils.data import DataLoader# 2. 定义编码器和解码器
# 2.1 定义基础编码器Encoder
class Encoder(nn.Module):def __init__(self):super().__init__()self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))def forward(self, x):return self.l1(x)# 2.2 定义基础解码器Decoder
class Decoder(nn.Module):def __init__(self):super().__init__()self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))def forward(self, x):return self.l1(x)# 3. 定义LightningModule
class LitAutoEncoder(pl.LightningModule):# 3.1 加载基础模型def __init__(self, encoder, decoder):super().__init__()self.encoder = encoderself.decoder = decoder# 3.2 训练过程设置def training_step(self, batch, batch_idx):  # 每一个batch数据运算计算loss# training_step defines the train loop.x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)loss = F.mse_loss(x_hat, x)return loss# 3.3 测试过程设置def test_step(self, batch, batch_idx):# this is the test loopx, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)test_loss = F.mse_loss(x_hat, x)self.log("test_loss", test_loss)# 3.4 验证过程设置def validation_step(self, batch, batch_idx):# this is the validation loopx, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)val_loss = F.mse_loss(x_hat, x)self.log("val_loss", val_loss)# 3.5 优化器设置def configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)return optimizer# 4. 定义训练数据
'''
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)
'''# 4.1 分别下载并加载训练集和测试集
transform = transforms.ToTensor()
train_set = datasets.MNIST(os.getcwd(), download=False, train=True, transform=transform)
test_set = datasets.MNIST(os.getcwd(), download=False, train=False, transform=transform)# 4.2 将训练集中的20%用于验证集
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size# 4.3 设置种子
seed = torch.Generator().manual_seed(42)# 4.4 从训练集中随机拿到80%的测试集和20%的验证集
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)# 4.5 分别加载训练集和测试集
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)# 5. 实例化模型
autoencoder = LitAutoEncoder(Encoder(), Decoder())# 6. 实例化Trainer
trainer = pl.Trainer(max_epochs=10)# 7. 开始训练和评估
trainer.fit(autoencoder, train_loader, valid_loader)# 8.开始测试
trainer.test(model=autoencoder, dataloaders=DataLoader(test_set))

3、权重 & 超参的保存和加载

当模型正在训练时,性能会随着它继续看到更多数据而发生变化。

1)训练完成后,使用在训练过程中发现的最佳性能相对应的权重;

2)权重可以让训练在训练过程中断的情况下从原来的位置恢复。

保存权重:Lightning 会自动为你在当前工作目录下保存一个权重,其中包含上一次训练的状态。这能确保在训练中断的情况下恢复训练。

3.1 自动在当前目录下保存checkpoint

# simply by using the Trainer you get automatic checkpointing
trainer = Trainer()

3.2 指定checkpoint保存的目录

# saves checkpoints to 'some/path/' at every epoch end
trainer = Trainer(default_root_dir="some/path/")

3.3 加载checkpoint

# trainer.fit(autoencoder, train_loader, valid_loader, ckpt_path="/home/gvlib_ljh/class/Lightning_mnist/lightning_logs/version_25/checkpoints/epoch=9-step=160000.ckpt")

4、可视化

在模型开发中,我们跟踪感兴趣的值,例如validation_loss,以可视化模型的学习过程。模型开发就像驾驶一辆没有窗户的汽车,图表和日志提供了了解汽车行驶方向的窗口。借助 Lightning,可以可视化任何您能想到的东西:数字、文本、图像、音频。

要跟踪指标,只需使用 LightningModule 内可用的 self.log 方法。

class LitModel(pl.LightningModule):def training_step(self, batch, batch_idx):value = ...self.log("some_value", value)

同时记录多个指标:

values = {"loss": loss, "acc": acc, "metric_n": metric_n}  # add more items if needed
self.log_dict(values)

4.1 命令行查看

要在命令行进度栏中查看指标,请将 prog_bar 参数设置为 True。

self.log(..., prog_bar=True)

4.2 浏览器查看

默认情况下,Lightning 使用 Tensorboard(如果可用)和一个简单的 CSV 记录器

在命令行中输入(注意:一定是lightning_logs所在的目录):

tensorboard --logdir=lightning_logs/

Tensorboard界面:

Tensorboard输出分析:

完整的代码:

# 1. 导入所需的模块
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning.pytorch as plimport torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transformsfrom torch.utils.data import DataLoaderfrom pytorch_lightning.loggers import TensorBoardLogger# 设置浮点矩阵乘法精度为 'medium'
torch.set_float32_matmul_precision('medium')# 2. 定义编码器和解码器
# 2.1 定义基础编码器Encoder
class Encoder(nn.Module):def __init__(self):super().__init__()self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))def forward(self, x):return self.l1(x)# 2.2 定义基础解码器Decoder
class Decoder(nn.Module):def __init__(self):super().__init__()self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))def forward(self, x):return self.l1(x)# 3. 定义LightningModule
class LitAutoEncoder(pl.LightningModule):# 3.1 加载基础模型def __init__(self, encoder, decoder):super().__init__()self.encoder = encoderself.decoder = decoder# 3.2 训练过程设置def training_step(self, batch, batch_idx):  # 每一个batch数据运算计算loss# training_step defines the train loop.x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)loss = F.mse_loss(x_hat, x)batch_idx_value = batch_idx + 1print(" ")values = {"loss": loss, "batch_idx_value": batch_idx_value}  # add more items if neededself.log_dict(values)# 在命令行界面显示log'''sync_dist=True:分布式计算,数据同步标志prog_bar=True:在控制台上显示'''self.log("train_loss", loss, sync_dist=True, prog_bar=True)return loss# 3.3 测试过程设置def test_step(self, batch, batch_idx):x, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)test_loss = F.mse_loss(x_hat, x)self.log("test_loss", test_loss, sync_dist=True, prog_bar=True)# 3.4 验证过程设置def validation_step(self, batch, batch_idx):# this is the validation loopx, y = batchx = x.view(x.size(0), -1)z = self.encoder(x)x_hat = self.decoder(z)val_loss = F.mse_loss(x_hat, x)self.log("val_loss", val_loss, sync_dist=True, prog_bar=True)# 3.5 优化器设置def configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)return optimizer# 4. 定义训练数据
'''
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)
'''# 4.1 分别下载并加载训练集和测试集
transform = transforms.ToTensor()
train_set = datasets.MNIST(os.getcwd(), download=False, train=True, transform=transform)
test_set = datasets.MNIST(os.getcwd(), download=False, train=False, transform=transform)# 4.2 将训练集中的20%用于验证集
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size# 4.3 设置种子
seed = torch.Generator().manual_seed(42)# 4.4 从训练集中随机拿到80%的测试集和20%的验证集
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)# 4.5 分别加载训练集和测试集
train_loader = DataLoader(train_set, batch_size=256, num_workers=5)
valid_loader = DataLoader(valid_set, batch_size=128, num_workers=5)# 5. 实例化模型
autoencoder = LitAutoEncoder(Encoder(), Decoder())# 6. 实例化Trainer
trainer = pl.Trainer(max_epochs=1000)# 7. 开始训练和评估
trainer.fit(autoencoder, train_loader, valid_loader)
# 7. 从checkpoint恢复状态
# trainer.fit(autoencoder, train_loader, valid_loader, ckpt_path="/home/gvlib_ljh/class/Lightning_mnist/lightning_logs/version_25/checkpoints/epoch=9-step=160000.ckpt")# 8.开始测试
trainer.test(model=autoencoder, dataloaders=DataLoader(test_set))

参考:

https://zhuanlan.zhihu.com/p/659631467

这篇关于Pytorch-Lighting使用教程(MNIST为例)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1023290

相关文章

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

中文分词jieba库的使用与实景应用(一)

知识星球:https://articles.zsxq.com/id_fxvgc803qmr2.html 目录 一.定义: 精确模式(默认模式): 全模式: 搜索引擎模式: paddle 模式(基于深度学习的分词模式): 二 自定义词典 三.文本解析   调整词出现的频率 四. 关键词提取 A. 基于TF-IDF算法的关键词提取 B. 基于TextRank算法的关键词提取

使用SecondaryNameNode恢复NameNode的数据

1)需求: NameNode进程挂了并且存储的数据也丢失了,如何恢复NameNode 此种方式恢复的数据可能存在小部分数据的丢失。 2)故障模拟 (1)kill -9 NameNode进程 [lytfly@hadoop102 current]$ kill -9 19886 (2)删除NameNode存储的数据(/opt/module/hadoop-3.1.4/data/tmp/dfs/na

Hadoop数据压缩使用介绍

一、压缩原则 (1)运算密集型的Job,少用压缩 (2)IO密集型的Job,多用压缩 二、压缩算法比较 三、压缩位置选择 四、压缩参数配置 1)为了支持多种压缩/解压缩算法,Hadoop引入了编码/解码器 2)要在Hadoop中启用压缩,可以配置如下参数

Makefile简明使用教程

文章目录 规则makefile文件的基本语法:加在命令前的特殊符号:.PHONY伪目标: Makefilev1 直观写法v2 加上中间过程v3 伪目标v4 变量 make 选项-f-n-C Make 是一种流行的构建工具,常用于将源代码转换成可执行文件或者其他形式的输出文件(如库文件、文档等)。Make 可以自动化地执行编译、链接等一系列操作。 规则 makefile文件

使用opencv优化图片(画面变清晰)

文章目录 需求影响照片清晰度的因素 实现降噪测试代码 锐化空间锐化Unsharp Masking频率域锐化对比测试 对比度增强常用算法对比测试 需求 对图像进行优化,使其看起来更清晰,同时保持尺寸不变,通常涉及到图像处理技术如锐化、降噪、对比度增强等 影响照片清晰度的因素 影响照片清晰度的因素有很多,主要可以从以下几个方面来分析 1. 拍摄设备 相机传感器:相机传

pdfmake生成pdf的使用

实际项目中有时会有根据填写的表单数据或者其他格式的数据,将数据自动填充到pdf文件中根据固定模板生成pdf文件的需求 文章目录 利用pdfmake生成pdf文件1.下载安装pdfmake第三方包2.封装生成pdf文件的共用配置3.生成pdf文件的文件模板内容4.调用方法生成pdf 利用pdfmake生成pdf文件 1.下载安装pdfmake第三方包 npm i pdfma

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

git使用的说明总结

Git使用说明 下载安装(下载地址) macOS: Git - Downloading macOS Windows: Git - Downloading Windows Linux/Unix: Git (git-scm.com) 创建新仓库 本地创建新仓库:创建新文件夹,进入文件夹目录,执行指令 git init ,用以创建新的git 克隆仓库 执行指令用以创建一个本地仓库的

【北交大信息所AI-Max2】使用方法

BJTU信息所集群AI_MAX2使用方法 使用的前提是预约到相应的算力卡,拥有登录权限的账号密码,一般为导师组共用一个。 有浏览器、ssh工具就可以。 1.新建集群Terminal 浏览器登陆10.126.62.75 (如果是1集群把75改成66) 交互式开发 执行器选Terminal 密码随便设一个(需记住) 工作空间:私有数据、全部文件 加速器选GeForce_RTX_2080_Ti