Pytorch Ignite 使用方法

2024-01-16 15:48
文章标签 使用 方法 pytorch ignite

本文主要是介绍Pytorch Ignite 使用方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Pytorch Ignite 使用方法

下载 pip install ignite
官方网址:https://pytorch.org/ignite/concepts.html

Engine

该框架的本质是class Engine,它是一种抽象形式,它在提供的数据上循环给定的次数,执行处理函数并返回结果:

  while epoch < max_epochs:# run an epoch on datadata_iter = iter(data)while True:try:batch = next(data_iter)output = process_function(batch)iter_counter += 1except StopIteration:data_iter = iter(data)if iter_counter == epoch_length:break

因此,模型训练器只是一个引擎,它在训练数据集上循环多次并更新模型参数。同样,可以使用在验证数据集上运行一次并计算指标的引擎来完成模型评估。

例如,用于监督任务的模型训练器:

def train_step(trainer, batch):model.train()optimizer.zero_grad()x, y = prepare_batch(batch)y_pred = model(x)loss = loss_fn(y_pred, y)loss.backward()optimizer.step()return loss.item()trainer = Engine(train_step)
trainer.run(data, max_epochs=100)

训练步骤的输出类型(即在上面的示例中loss.item())不受限制。训练步骤功能可以返回用户想要的一切。输出设置为,trainer.state.output并且可以进一步用于任何类型的处理。

默认情况下,epoch_length 长度由len(data)定义。但是,用户也可以手动将epoch_length 长度定义为要循环的多次迭代。这样,输入数据可以是迭代器。

trainer.run(data, max_epochs=100, epoch_length=200)

如果data是长度未知的有限数据迭代器(对于用户),epoch_length则可以省略参数,并且在耗尽数据迭代器时将自动确定参数。

任何复杂度的训练逻辑都可以使用train_step方法进行编码,并且可以使用此方法来构造训练器。

函数batch中的train_step参数是用户定义的,可以包含单个迭代所需的任何数据。

# 定义模型参数
model_1 = ...
model_2 = ...
# 定义优化器
optimizer_1 = ...
optimizer_2 = ...
# 
criterion_1 = ...
criterion_2 = ...
# ...def train_step(trainer, batch):data_1 = batch["data_1"]data_2 = batch["data_2"]# ...model_1.train()optimizer_1.zero_grad()loss_1 = forward_pass(data_1, model_1, criterion_1)loss_1.backward()optimizer_1.step()# ...model_2.train()optimizer_2.zero_grad()loss_2 = forward_pass(data_2, model_2, criterion_2)loss_2.backward()optimizer_2.step()# ...# User can return any type of structure.# 用户可以返回任何类型的数据结构return {"loss_1": loss_1,"loss_2": loss_2,# ...}trainer = Engine(train_step)
trainer.run(data, max_epochs=100)

事件和处理程序

为了提高Engine灵活性,引入了事件系统,该系统促进了运行的每个步骤之间的交互:

  • Engine启动/完成
  • epoch开始/完成
  • 批处理迭代已开始/已完成

有关事件的完整列表,请参见Events

用户可以执行自定义代码作为事件处理程序。处理程序可以是任何函数:例如lambda,简单函数,类方法等。第一个参数可以选择是engine,但不是必须的。

让我们更详细地考虑当run()被调用时发生的情况:

fire_event(Events.STARTED)
while epoch < max_epochs:fire_event(Events.EPOCH_STARTED)# run once on datafor batch in data:fire_event(Events.ITERATION_STARTED)output = process_function(batch)fire_event(Events.ITERATION_COMPLETED)fire_event(Events.EPOCH_COMPLETED)
fire_event(Events.COMPLETED)

首先,将触发*“Engine已启动”事件并执行其所有事件处理程序(我们将在下一段中看到如何添加事件处理程序)。接下来,在启动循环并发生“epoch 开始”*事件时,等等。每次触发事件时,都会执行附加的处理程序。

使用方法add_event_handler()on()装饰器附加事件处理程序很简单:

trainer = Engine(update_model)trainer.add_event_handler(Events.STARTED, lambda _: print("Start training"))
# or
@trainer.on(Events.STARTED)
def on_training_started(engine):print("Another message of start training")
# or even simpler, use only what you need !
@trainer.on(Events.STARTED)
def on_training_started():print("Another message of start training")# attach handler with args, kwargs
mydata = [1, 2, 3, 4]def on_training_ended(data):print(f"Training is ended. mydata={data}")trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)

可以通过remove_event_handler()或通过RemovableEventHandle 返回的引用来分离事件处理程序add_event_handler()。这可用于将已配置的引擎重用于多个循环:

model = ...
train_loader, validation_loader, test_loader = ...trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model, metrics={"acc": Accuracy()})def log_metrics(engine, title):print(f"Epoch: {trainer.state.epoch} - {title} accuracy: {engine.state.metrics['acc']:.2f}")@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "train"):evaluator.run(train_loader)with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "validation"):evaluator.run(validation_loader)with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "test"):evaluator.run(test_loader)trainer.run(train_loader, max_epochs=100)

还可以将事件处理程序配置为以用户模式调用:每第n个事件一次,或使用自定义事件过滤功能:

model = ...
train_loader, validation_loader, test_loader = ...trainer = create_supervised_trainer(model, optimizer, loss)@trainer.on(Events.ITERATION_COMPLETED(every=50))
def log_training_loss_every_50_iterations():print(f"{trainer.state.epoch} / {trainer.state.max_epochs} : {trainer.state.iteration} - loss: {trainer.state.output:.2f}")@trainer.on(Events.EPOCH_STARTED(once=25))
def do_something_once_on_25_epoch():# do somethingdef custom_event_filter(engine, event):if event in [1, 2, 5, 10, 50, 100]:return Truereturn False@engine.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
def call_on_special_event(engine):# do something on 1, 2, 5, 10, 50, 100 iterationstrainer.run(train_loader, max_epochs=100)

自定义事件

用户还可以定义自定义事件。用户定义的事件应继承于引擎EventEnumregister_events()在引擎中注册。

from ignite.engine import EventEnumclass CustomEvents(EventEnum):"""Custom events defined by user"""CUSTOM_STARTED = 'custom_started'CUSTOM_COMPLETED = 'custom_completed'engine.register_events(*CustomEvents)

这些事件可用于附加任何处理程序,并使用触发fire_event()

@engine.on(CustomEvents.CUSTOM_STARTED)
def call_on_custom_event(engine):# do something@engine.on(Events.STARTED)
def fire_custom_events(engine):engine.fire_event(CustomEvents.CUSTOM_STARTED)

时间线和事件

在事件下方,一些典型的处理程序显示在时间轴上,以进行训练循环,并在每个时期后进行评估:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-U3DJJvxX-1618109731932)(https://pytorch.org/ignite/_images/timeline_and_events.png)]

状态

Engine引入了一个状态来存储process_function,当前时期,迭代和其他有用信息的输出。每个都Engine包含一个State,其中包括以下内容:

  • engine.state.seed:要在每个数据“ epoch”处设置的种子。
  • engine.state.epoch:引擎已完成的纪元数。初始化为0,第一个时期为1。
  • engine.state.iteration:引擎已完成的迭代次数。初始化为0,第一次迭代为1。
  • engine.state.max_epochs:要运行的时期数。初始化为1。
  • engine.state.output:为定义的process_function的输出Engine。见下文。
  • 等等

其他属性可以在的文档中找到State

在下面的代码中,engine.state.output将存储批次损失。此输出用于打印每次迭代的损耗。

def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return loss.item()def on_iteration_completed(engine):iteration = engine.state.iterationepoch = engine.state.epochloss = engine.state.outputprint(f"Epoch: {epoch}, Iteration: {iteration}, Loss: {loss}")trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)

由于对process_function的输出没有限制,因此Ignite为其和提供了output_transform参数 。参数output_transform是用于将engine.state.output转换为预期用途的函数。在下面,我们将看到不同类型的engine.state.output以及如何对其进行转换。metrics``handlers

在下面的代码中,engine.state.output将已处理批次的loss,y_pred,y。如果要附加Accuracy到引擎,则需要output_transform来从engine.state.output获取y_pred和y 。让我们看看如何做到这一点:

def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return loss.item(), y_pred, ytrainer = Engine(update)@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):epoch = engine.state.epochloss = engine.state.output[0]print (f'Epoch {epoch}: train_loss = {loss}')accuracy = Accuracy(output_transform=lambda x: [x[1], x[2]])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)

与上面类似,但是这次process_function的输出是处理后的批次的字典 loss,y_pred,y,这是用户可以使用output_transform从engine.state.output获取y_pred和y的方式。见下文:

def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return {'loss': loss.item(),'y_pred': y_pred,'y': y}trainer = Engine(update)@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):epoch = engine.state.epochloss = engine.state.output['loss']print (f'Epoch {epoch}: train_loss = {loss}')accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)

指标

库提供了各种机器学习任务的现成指标列表。支持两种计算指标的方式:1)在线 和 2)存储整个输出历史记录。

指标可以附加到Engine

from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.attach(evaluator, "accuracy")state = evaluator.run(validation_data)print("Result:", state.metrics)
# > {"accuracy": 0.12345}

或可用作独立对象:

from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.reset()for y_pred, y in get_prediction_target():accuracy.update((y_pred, y))print("Result:", accuracy.compute())

或可用作独立对象:

from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.reset()for y_pred, y in get_prediction_target():accuracy.update((y_pred, y))print("Result:", accuracy.compute())

完整的指标列表和API可以在ignite.metrics模块中找到。

这篇关于Pytorch Ignite 使用方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/613089

相关文章

Nginx安全防护的多种方法

《Nginx安全防护的多种方法》在生产环境中,需要隐藏Nginx的版本号,以避免泄漏Nginx的版本,使攻击者不能针对特定版本进行攻击,下面就来介绍一下Nginx安全防护的方法,感兴趣的可以了解一下... 目录核心安全配置1.编译安装 Nginx2.隐藏版本号3.限制危险请求方法4.请求限制(CC攻击防御)

python生成随机唯一id的几种实现方法

《python生成随机唯一id的几种实现方法》在Python中生成随机唯一ID有多种方法,根据不同的需求场景可以选择最适合的方案,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习... 目录方法 1:使用 UUID 模块(推荐)方法 2:使用 Secrets 模块(安全敏感场景)方法

一文详解如何使用Java获取PDF页面信息

《一文详解如何使用Java获取PDF页面信息》了解PDF页面属性是我们在处理文档、内容提取、打印设置或页面重组等任务时不可或缺的一环,下面我们就来看看如何使用Java语言获取这些信息吧... 目录引言一、安装和引入PDF处理库引入依赖二、获取 PDF 页数三、获取页面尺寸(宽高)四、获取页面旋转角度五、判断

MyBatis-Plus通用中等、大量数据分批查询和处理方法

《MyBatis-Plus通用中等、大量数据分批查询和处理方法》文章介绍MyBatis-Plus分页查询处理,通过函数式接口与Lambda表达式实现通用逻辑,方法抽象但功能强大,建议扩展分批处理及流式... 目录函数式接口获取分页数据接口数据处理接口通用逻辑工具类使用方法简单查询自定义查询方法总结函数式接口

C++中assign函数的使用

《C++中assign函数的使用》在C++标准模板库中,std::list等容器都提供了assign成员函数,它比操作符更灵活,支持多种初始化方式,下面就来介绍一下assign的用法,具有一定的参考价... 目录​1.assign的基本功能​​语法​2. 具体用法示例​​​(1) 填充n个相同值​​(2)

MySQL深分页进行性能优化的常见方法

《MySQL深分页进行性能优化的常见方法》在Web应用中,分页查询是数据库操作中的常见需求,然而,在面对大型数据集时,深分页(deeppagination)却成为了性能优化的一个挑战,在本文中,我们将... 目录引言:深分页,真的只是“翻页慢”那么简单吗?一、背景介绍二、深分页的性能问题三、业务场景分析四、

JAVA中安装多个JDK的方法

《JAVA中安装多个JDK的方法》文章介绍了在Windows系统上安装多个JDK版本的方法,包括下载、安装路径修改、环境变量配置(JAVA_HOME和Path),并说明如何通过调整JAVA_HOME在... 首先去oracle官网下载好两个版本不同的jdk(需要登录Oracle账号,没有可以免费注册)下载完

Spring StateMachine实现状态机使用示例详解

《SpringStateMachine实现状态机使用示例详解》本文介绍SpringStateMachine实现状态机的步骤,包括依赖导入、枚举定义、状态转移规则配置、上下文管理及服务调用示例,重点解... 目录什么是状态机使用示例什么是状态机状态机是计算机科学中的​​核心建模工具​​,用于描述对象在其生命

使用Python删除Excel中的行列和单元格示例详解

《使用Python删除Excel中的行列和单元格示例详解》在处理Excel数据时,删除不需要的行、列或单元格是一项常见且必要的操作,本文将使用Python脚本实现对Excel表格的高效自动化处理,感兴... 目录开发环境准备使用 python 删除 Excphpel 表格中的行删除特定行删除空白行删除含指定

深入理解Go语言中二维切片的使用

《深入理解Go语言中二维切片的使用》本文深入讲解了Go语言中二维切片的概念与应用,用于表示矩阵、表格等二维数据结构,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习吧... 目录引言二维切片的基本概念定义创建二维切片二维切片的操作访问元素修改元素遍历二维切片二维切片的动态调整追加行动态