基于PyTorch的「Keras」:除了核心逻辑通通都封装

2024-01-26 09:20

本文主要是介绍基于PyTorch的「Keras」:除了核心逻辑通通都封装,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

(给Python开发者加星标,提升Python技能

转自:机器之心

Keras 和 PyTorch 都是对初学者最友好的深度学习框架,它们用起来就像描述架构的简单语言一样,告诉框架哪一层该用什么就行。很多研究者和开发者都在考虑到底哪一个框架更好,但目前两个框架都非常流行,它们都各有优势。最近,Facebook 研究员 William Falcon 为 PyTorch 披上了一件 Keras 的外衣,他表明用这样的框架做研究简直不要太爽。


PyTorch Lightning 地址:https://github.com/williamFalcon/pytorch-lightning


640?wx_fmt=png


看起来像 Keras 的 PyTorch


Keras 本身的目的就是对深度学习框架(TensorFlow、Theano)进行了进一步的 API 封装。作为 TensorFlow 的高度封装,Keras 的抽象层次非常高,很多 API 细节都隐藏了起来。虽然 PyTorch 同样使用动态计算图,也方便快捷,但总体上 Keras 隐藏的细节更多一些。


反观 PyTorch,它提供一个相对较低级别的实验环境,使用户可以更加自由地编写自定义层、查看数值优化任务等等。例如在 PyTorch 1.0 中,编译工具 torch.jit 就包含一种名为 Torch Script 的语言,它是 Python 的子语言,开发者使用它能进一步对模型进行优化。


用 PyTorch 写模型,除了数据加载和模型定义部分外,整个训练和验证的逻辑、配置都需要我们手动完成,这些步骤都较为繁琐。甚至可以说,研究者需要耗费相当多的精力处理这一部分的代码,还要祈祷不出 Bug。但是对于大多数研究实验来说,训练和验证的循环体差不多都是一样的,实现的功能也相当一致,所以为什么不将这些通用的东西都打包在一起,这样训练不就简单了么?


William Falcon 正是这样想的,他将 PyTorch 开发中的各种通用配置全都包装起来,我们只需要写核心逻辑就行。通过 PyTorch Lightning,PyTorch 就类似于 Keras,它能以更高级的形式快速搭建模型。


项目作者是谁


要完成这样的工作,工作量肯定是非常大的,因为从超参搜索、模型 Debug、分布式训练、训练和验证的循环逻辑到模型日志的打印,都需要写一套通用的方案,确保各种任务都能用得上。所以 Facebook 的这位小哥哥 William Falcon 还是很厉害的。


640?wx_fmt=png


640?wx_fmt=png


他是一位 NYU 和 Facebook 的开发者。目前在 NYU 攻读 PhD。从 GitHub 的活动来看,小哥是一位比较活跃的开发者。


640?wx_fmt=png


这是一位披着 Keras 外衣的 PyTorch


Lightning 是 PyTorch 非常轻量级的包装,研究者只需要写最核心的训练和验证逻辑,其它过程都会自动完成。因此这就有点类似 Keras 那种高级包装,它隐藏了绝大多数细节,只保留了最通俗易懂的接口。Lightning 能确保自动完成部分的正确性,对于核心训练逻辑的提炼非常有优势。


那么我们为什么要用 Lightning?


当我们开始构建新项目,最后你希望做的可能就是记录训练循环、多集群训练、float16 精度、提前终止、模型加载/保存等等。这一系列过程可能需要花很多精力来解决各式各样、千奇百怪的 Bug,因此很难把精力都放在研究的核心逻辑上。


通过使用 Lightning,这些部分都能保证是 Work 的,因此能抽出精力关注我们要研究的东西:数据、训练、验证逻辑。此外,我们完全不需要担心使用多 GPU 加速会很难,因为 Lightning 会把这些东西都做好。


所以 Lightning 都能帮我们干什么?


下图展示了构建一个机器学习模型都会经历哪些过程,很多时候最困难的还不是写模型,是各种配置与预处理过程。如下蓝色的部分需要用 LightningModule 定义,而灰色部分 Lightning 可以自动完成。我们需要做的,差不多也就加载数据、定义模型、确定训练和验证过程。


640?wx_fmt=png


下面的伪代码展示了大致需要定义的几大模块,它们再加上模型架构定义就能成为完整的模型。


 
# what to do in the training loop	
def training_step(self, data_batch, batch_nb):	
# what to do in the validation loop	
def validation_step(self, data_batch, batch_nb):	# how to aggregate validation_step outputs	
def validation_end(self, outputs):	# and your dataloaders	
def tng_dataloader():	
def val_dataloader():	
def test_dataloader():


除了需要定义的模块外,以下步骤均可通过 Lightning 自动完成。当然,每个模块可以单独进行配置。


640?wx_fmt=png


640?wx_fmt=png


Lightning 怎么用


Lightning 的使用也非常简单,只需要两步就能完成:定义 LightningModel;拟合训练器。


以经典的 MNIST 图像识别为例,如下展示了 LightningModel 的示例。我们可以照常导入 PyTorch 模块,但这次不是继承 nn.Module,而是继承 LightningModel。然后我们只需要照常写 PyTorch 就行了,该调用函数还是继续调用。这里看上去似乎没什么不同,但注意方法名都是确定的,这样才能利用 Lightning 的后续过程。


 
import os	
import torch	
from torch.nn import functional as F	
from torch.utils.data import DataLoader	
from torchvision.datasets import MNIST	
import torchvision.transforms as transforms	import pytorch_lightning as ptl	class CoolModel(ptl.LightningModule):	def __init__(self):	super(CoolModel, self).__init__()	# not the best model...	self.l1 = torch.nn.Linear(28 * 28, 10)	def forward(self, x):	return torch.relu(self.l1(x.view(x.size(0), -1)))	def my_loss(self, y_hat, y):	return F.cross_entropy(y_hat, y)	def training_step(self, batch, batch_nb):	x, y = batch	y_hat = self.forward(x)	return {'loss': self.my_loss(y_hat, y)}	def validation_step(self, batch, batch_nb):	x, y = batch	y_hat = self.forward(x)	return {'val_loss': self.my_loss(y_hat, y)}	def validation_end(self, outputs):	avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()	return {'avg_val_loss': avg_loss}	def configure_optimizers(self):	return [torch.optim.Adam(self.parameters(), lr=0.02)]	@ptl.data_loader	def tng_dataloader(self):	return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)	@ptl.data_loader	def val_dataloader(self):	return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)	@ptl.data_loader	def test_dataloader(self):	return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)


随后,第二步即拟合训练器。这就比较类似 Keras 这类高级包装了,它将训练配置细节、循环体、以及日志输出等更加具体的信息全都隐藏了,一个 fit() 方法就能自动搞定一切。


这相比以前写 PyTorch 更加便捷精炼一些,而且分布式训练也非常容易,只要给出设备 id 就行了。


 
from pytorch_lightning import Trainer	
from test_tube import Experiment	model = CoolModel()	
exp = Experiment(save_dir=os.getcwd())	# train on cpu using only 10% of the data (for demo purposes)	
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1)	# train on 4 gpus	
# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 2, 3])	
# train on 32 gpus across 4 nodes (make sure to submit appropriate SLURM job)	
# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 2, 3, 4, 5, 6, 7], nb_gpu_nodes=4)	
# train (1 epoch only here for demo)	
trainer.fit(model)	
# view tensorflow logs print(f'View tensorboard logs by running\ntensorboard --logdir {os.getcwd()}')	
print('and going to http://localhost:6006 on your browser')


其他特性


Pytorch-Lightning 还可以和 TensorBoard 无缝对接。


640?wx_fmt=png


只需要定义运行的路径:



 
from test_tube import Experiment	
from pytorch-lightning import Trainer	exp = Experiment(save_dir = 『/some/path』)	
trainer = Trainer(experiment = exp)


将 TensorBoard 连接到路径上即可:


 
tensorboard -logdir /some/path


8月13日晚,腾讯将在澳门 IJCAI 2019 大会期间举办腾讯学术工业交流会(TAIC),诚邀AI从业者前来参加,共同探讨 AI 的应用与未来发展。点击「阅读原文」了解详情并参与报名。



推荐阅读

(点击标题可跳转阅读)

PyTorch 实战:计算 Wasserstein 距离

PyTorch 1.0 正式版发布了!

GitHub 热门项目:PyTorch 资源大全


觉得本文对你有帮助?请分享给更多人

关注「Python开发者」加星标,提升Python技能

640?wx_fmt=png

好文章,我在看❤️

这篇关于基于PyTorch的「Keras」:除了核心逻辑通通都封装的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/646329

相关文章

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

JavaSE——封装、继承和多态

1. 封装 1.1 概念      面向对象程序三大特性:封装、继承、多态 。而类和对象阶段,主要研究的就是封装特性。何为封装呢?简单来说就是套壳屏蔽细节 。     比如:对于电脑这样一个复杂的设备,提供给用户的就只是:开关机、通过键盘输入,显示器, USB 插孔等,让用户来和计算机进行交互,完成日常事务。但实际上:电脑真正工作的却是CPU 、显卡、内存等一些硬件元件。

PostgreSQL核心功能特性与使用领域及场景分析

PostgreSQL有什么优点? 开源和免费 PostgreSQL是一个开源的数据库管理系统,可以免费使用和修改。这降低了企业的成本,并为开发者提供了一个活跃的社区和丰富的资源。 高度兼容 PostgreSQL支持多种操作系统(如Linux、Windows、macOS等)和编程语言(如C、C++、Java、Python、Ruby等),并提供了多种接口(如JDBC、ODBC、ADO.NET等

哈希表的封装和位图

文章目录 2 封装2.1 基础框架2.2 迭代器(1)2.3 迭代器(2) 3. 位图3.1 问题引入3.2 左移和右移?3.3 位图的实现3.4 位图的题目3.5 位图的应用 2 封装 2.1 基础框架 文章 有了前面map和set封装的经验,容易写出下面的代码 // UnorderedSet.h#pragma once#include "HashTable.h"

深入解析秒杀业务中的核心问题 —— 从并发控制到事务管理

深入解析秒杀业务中的核心问题 —— 从并发控制到事务管理 秒杀系统是应对高并发、高压力下的典型业务场景,涉及到并发控制、库存管理、事务管理等多个关键技术点。本文将深入剖析秒杀商品业务中常见的几个核心问题,包括 AOP 事务管理、同步锁机制、乐观锁、CAS 操作,以及用户限购策略。通过这些技术的结合,确保秒杀系统在高并发场景下的稳定性和一致性。 1. AOP 代理对象与事务管理 在秒杀商品

封装MySQL操作时Where条件语句的组织

在对数据库进行封装的过程中,条件语句应该是相对难以处理的,毕竟条件语句太过于多样性。 条件语句大致分为以下几种: 1、单一条件,比如:where id = 1; 2、多个条件,相互间关系统一。比如:where id > 10 and age > 20 and score < 60; 3、多个条件,相互间关系不统一。比如:where (id > 10 OR age > 20) AND sco

逻辑表达式,最小项

目录 得到此图的逻辑电路 1.画出它的真值表 2.根据真值表写出逻辑式 3.画逻辑图 逻辑函数的表示 逻辑表达式 最小项 定义 基本性质 最小项编号 最小项表达式   得到此图的逻辑电路 1.画出它的真值表 这是同或的逻辑式。 2.根据真值表写出逻辑式   3.画逻辑图   有两种画法,1是根据运算优先级非>与>或得到,第二种是采

UMI复现代码运行逻辑全流程(一)——eval_real.py(尚在更新)

一、文件夹功能解析 全文件夹如下 其中,核心文件作用为: diffusion_policy:扩散策略核心文件夹,包含了众多模型及基础库 example:标定及配置文件 scripts/scripts_real:测试脚本文件,区别在于前者倾向于单体运行,后者为整体运行 scripts_slam_pipeline:orb_slam3运行全部文件 umi:核心交互文件夹,作用在于构建真

文章解读与仿真程序复现思路——电力自动化设备EI\CSCD\北大核心《考虑燃料电池和电解槽虚拟惯量支撑的电力系统优化调度方法》

本专栏栏目提供文章与程序复现思路,具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源程序擅长文章解读,论文与完整源程序,等方面的知识,电网论文源程序关注python