Pytorch Ignite 使用方法

2024-01-16 15:48
文章标签 使用 方法 pytorch ignite

本文主要是介绍Pytorch Ignite 使用方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Pytorch Ignite 使用方法

下载 pip install ignite
官方网址:https://pytorch.org/ignite/concepts.html

Engine

该框架的本质是class Engine,它是一种抽象形式,它在提供的数据上循环给定的次数,执行处理函数并返回结果:

  while epoch < max_epochs:# run an epoch on datadata_iter = iter(data)while True:try:batch = next(data_iter)output = process_function(batch)iter_counter += 1except StopIteration:data_iter = iter(data)if iter_counter == epoch_length:break

因此,模型训练器只是一个引擎,它在训练数据集上循环多次并更新模型参数。同样,可以使用在验证数据集上运行一次并计算指标的引擎来完成模型评估。

例如,用于监督任务的模型训练器:

def train_step(trainer, batch):model.train()optimizer.zero_grad()x, y = prepare_batch(batch)y_pred = model(x)loss = loss_fn(y_pred, y)loss.backward()optimizer.step()return loss.item()trainer = Engine(train_step)
trainer.run(data, max_epochs=100)

训练步骤的输出类型(即在上面的示例中loss.item())不受限制。训练步骤功能可以返回用户想要的一切。输出设置为,trainer.state.output并且可以进一步用于任何类型的处理。

默认情况下,epoch_length 长度由len(data)定义。但是,用户也可以手动将epoch_length 长度定义为要循环的多次迭代。这样,输入数据可以是迭代器。

trainer.run(data, max_epochs=100, epoch_length=200)

如果data是长度未知的有限数据迭代器(对于用户),epoch_length则可以省略参数,并且在耗尽数据迭代器时将自动确定参数。

任何复杂度的训练逻辑都可以使用train_step方法进行编码,并且可以使用此方法来构造训练器。

函数batch中的train_step参数是用户定义的,可以包含单个迭代所需的任何数据。

# 定义模型参数
model_1 = ...
model_2 = ...
# 定义优化器
optimizer_1 = ...
optimizer_2 = ...
# 
criterion_1 = ...
criterion_2 = ...
# ...def train_step(trainer, batch):data_1 = batch["data_1"]data_2 = batch["data_2"]# ...model_1.train()optimizer_1.zero_grad()loss_1 = forward_pass(data_1, model_1, criterion_1)loss_1.backward()optimizer_1.step()# ...model_2.train()optimizer_2.zero_grad()loss_2 = forward_pass(data_2, model_2, criterion_2)loss_2.backward()optimizer_2.step()# ...# User can return any type of structure.# 用户可以返回任何类型的数据结构return {"loss_1": loss_1,"loss_2": loss_2,# ...}trainer = Engine(train_step)
trainer.run(data, max_epochs=100)

事件和处理程序

为了提高Engine灵活性,引入了事件系统,该系统促进了运行的每个步骤之间的交互:

  • Engine启动/完成
  • epoch开始/完成
  • 批处理迭代已开始/已完成

有关事件的完整列表,请参见Events

用户可以执行自定义代码作为事件处理程序。处理程序可以是任何函数:例如lambda,简单函数,类方法等。第一个参数可以选择是engine,但不是必须的。

让我们更详细地考虑当run()被调用时发生的情况:

fire_event(Events.STARTED)
while epoch < max_epochs:fire_event(Events.EPOCH_STARTED)# run once on datafor batch in data:fire_event(Events.ITERATION_STARTED)output = process_function(batch)fire_event(Events.ITERATION_COMPLETED)fire_event(Events.EPOCH_COMPLETED)
fire_event(Events.COMPLETED)

首先,将触发*“Engine已启动”事件并执行其所有事件处理程序(我们将在下一段中看到如何添加事件处理程序)。接下来,在启动循环并发生“epoch 开始”*事件时,等等。每次触发事件时,都会执行附加的处理程序。

使用方法add_event_handler()on()装饰器附加事件处理程序很简单:

trainer = Engine(update_model)trainer.add_event_handler(Events.STARTED, lambda _: print("Start training"))
# or
@trainer.on(Events.STARTED)
def on_training_started(engine):print("Another message of start training")
# or even simpler, use only what you need !
@trainer.on(Events.STARTED)
def on_training_started():print("Another message of start training")# attach handler with args, kwargs
mydata = [1, 2, 3, 4]def on_training_ended(data):print(f"Training is ended. mydata={data}")trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)

可以通过remove_event_handler()或通过RemovableEventHandle 返回的引用来分离事件处理程序add_event_handler()。这可用于将已配置的引擎重用于多个循环:

model = ...
train_loader, validation_loader, test_loader = ...trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model, metrics={"acc": Accuracy()})def log_metrics(engine, title):print(f"Epoch: {trainer.state.epoch} - {title} accuracy: {engine.state.metrics['acc']:.2f}")@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "train"):evaluator.run(train_loader)with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "validation"):evaluator.run(validation_loader)with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "test"):evaluator.run(test_loader)trainer.run(train_loader, max_epochs=100)

还可以将事件处理程序配置为以用户模式调用:每第n个事件一次,或使用自定义事件过滤功能:

model = ...
train_loader, validation_loader, test_loader = ...trainer = create_supervised_trainer(model, optimizer, loss)@trainer.on(Events.ITERATION_COMPLETED(every=50))
def log_training_loss_every_50_iterations():print(f"{trainer.state.epoch} / {trainer.state.max_epochs} : {trainer.state.iteration} - loss: {trainer.state.output:.2f}")@trainer.on(Events.EPOCH_STARTED(once=25))
def do_something_once_on_25_epoch():# do somethingdef custom_event_filter(engine, event):if event in [1, 2, 5, 10, 50, 100]:return Truereturn False@engine.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
def call_on_special_event(engine):# do something on 1, 2, 5, 10, 50, 100 iterationstrainer.run(train_loader, max_epochs=100)

自定义事件

用户还可以定义自定义事件。用户定义的事件应继承于引擎EventEnumregister_events()在引擎中注册。

from ignite.engine import EventEnumclass CustomEvents(EventEnum):"""Custom events defined by user"""CUSTOM_STARTED = 'custom_started'CUSTOM_COMPLETED = 'custom_completed'engine.register_events(*CustomEvents)

这些事件可用于附加任何处理程序,并使用触发fire_event()

@engine.on(CustomEvents.CUSTOM_STARTED)
def call_on_custom_event(engine):# do something@engine.on(Events.STARTED)
def fire_custom_events(engine):engine.fire_event(CustomEvents.CUSTOM_STARTED)

时间线和事件

在事件下方,一些典型的处理程序显示在时间轴上,以进行训练循环,并在每个时期后进行评估:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-U3DJJvxX-1618109731932)(https://pytorch.org/ignite/_images/timeline_and_events.png)]

状态

Engine引入了一个状态来存储process_function,当前时期,迭代和其他有用信息的输出。每个都Engine包含一个State,其中包括以下内容:

  • engine.state.seed:要在每个数据“ epoch”处设置的种子。
  • engine.state.epoch:引擎已完成的纪元数。初始化为0,第一个时期为1。
  • engine.state.iteration:引擎已完成的迭代次数。初始化为0,第一次迭代为1。
  • engine.state.max_epochs:要运行的时期数。初始化为1。
  • engine.state.output:为定义的process_function的输出Engine。见下文。
  • 等等

其他属性可以在的文档中找到State

在下面的代码中,engine.state.output将存储批次损失。此输出用于打印每次迭代的损耗。

def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return loss.item()def on_iteration_completed(engine):iteration = engine.state.iterationepoch = engine.state.epochloss = engine.state.outputprint(f"Epoch: {epoch}, Iteration: {iteration}, Loss: {loss}")trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)

由于对process_function的输出没有限制,因此Ignite为其和提供了output_transform参数 。参数output_transform是用于将engine.state.output转换为预期用途的函数。在下面,我们将看到不同类型的engine.state.output以及如何对其进行转换。metrics``handlers

在下面的代码中,engine.state.output将已处理批次的loss,y_pred,y。如果要附加Accuracy到引擎,则需要output_transform来从engine.state.output获取y_pred和y 。让我们看看如何做到这一点:

def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return loss.item(), y_pred, ytrainer = Engine(update)@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):epoch = engine.state.epochloss = engine.state.output[0]print (f'Epoch {epoch}: train_loss = {loss}')accuracy = Accuracy(output_transform=lambda x: [x[1], x[2]])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)

与上面类似,但是这次process_function的输出是处理后的批次的字典 loss,y_pred,y,这是用户可以使用output_transform从engine.state.output获取y_pred和y的方式。见下文:

def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return {'loss': loss.item(),'y_pred': y_pred,'y': y}trainer = Engine(update)@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):epoch = engine.state.epochloss = engine.state.output['loss']print (f'Epoch {epoch}: train_loss = {loss}')accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)

指标

库提供了各种机器学习任务的现成指标列表。支持两种计算指标的方式:1)在线 和 2)存储整个输出历史记录。

指标可以附加到Engine

from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.attach(evaluator, "accuracy")state = evaluator.run(validation_data)print("Result:", state.metrics)
# > {"accuracy": 0.12345}

或可用作独立对象:

from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.reset()for y_pred, y in get_prediction_target():accuracy.update((y_pred, y))print("Result:", accuracy.compute())

或可用作独立对象:

from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.reset()for y_pred, y in get_prediction_target():accuracy.update((y_pred, y))print("Result:", accuracy.compute())

完整的指标列表和API可以在ignite.metrics模块中找到。

这篇关于Pytorch Ignite 使用方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/613089

相关文章

使用Python合并 Excel单元格指定行列或单元格范围

《使用Python合并Excel单元格指定行列或单元格范围》合并Excel单元格是Excel数据处理和表格设计中的一项常用操作,本文将介绍如何通过Python合并Excel中的指定行列或单... 目录python Excel库安装Python合并Excel 中的指定行Python合并Excel 中的指定列P

浅析Rust多线程中如何安全的使用变量

《浅析Rust多线程中如何安全的使用变量》这篇文章主要为大家详细介绍了Rust如何在线程的闭包中安全的使用变量,包括共享变量和修改变量,文中的示例代码讲解详细,有需要的小伙伴可以参考下... 目录1. 向线程传递变量2. 多线程共享变量引用3. 多线程中修改变量4. 总结在Rust语言中,一个既引人入胜又可

四种Flutter子页面向父组件传递数据的方法介绍

《四种Flutter子页面向父组件传递数据的方法介绍》在Flutter中,如果父组件需要调用子组件的方法,可以通过常用的四种方式实现,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录方法 1:使用 GlobalKey 和 State 调用子组件方法方法 2:通过回调函数(Callb

一文详解Python中数据清洗与处理的常用方法

《一文详解Python中数据清洗与处理的常用方法》在数据处理与分析过程中,缺失值、重复值、异常值等问题是常见的挑战,本文总结了多种数据清洗与处理方法,文中的示例代码简洁易懂,有需要的小伙伴可以参考下... 目录缺失值处理重复值处理异常值处理数据类型转换文本清洗数据分组统计数据分箱数据标准化在数据处理与分析过

Java中Object类的常用方法小结

《Java中Object类的常用方法小结》JavaObject类是所有类的父类,位于java.lang包中,本文为大家整理了一些Object类的常用方法,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. public boolean equals(Object obj)2. public int ha

golang1.23版本之前 Timer Reset方法无法正确使用

《golang1.23版本之前TimerReset方法无法正确使用》在Go1.23之前,使用`time.Reset`函数时需要先调用`Stop`并明确从timer的channel中抽取出东西,以避... 目录golang1.23 之前 Reset ​到底有什么问题golang1.23 之前到底应该如何正确的

Vue项目中Element UI组件未注册的问题原因及解决方法

《Vue项目中ElementUI组件未注册的问题原因及解决方法》在Vue项目中使用ElementUI组件库时,开发者可能会遇到一些常见问题,例如组件未正确注册导致的警告或错误,本文将详细探讨这些问题... 目录引言一、问题背景1.1 错误信息分析1.2 问题原因二、解决方法2.1 全局引入 Element

Python调用另一个py文件并传递参数常见的方法及其应用场景

《Python调用另一个py文件并传递参数常见的方法及其应用场景》:本文主要介绍在Python中调用另一个py文件并传递参数的几种常见方法,包括使用import语句、exec函数、subproce... 目录前言1. 使用import语句1.1 基本用法1.2 导入特定函数1.3 处理文件路径2. 使用ex

详解Vue如何使用xlsx库导出Excel文件

《详解Vue如何使用xlsx库导出Excel文件》第三方库xlsx提供了强大的功能来处理Excel文件,它可以简化导出Excel文件这个过程,本文将为大家详细介绍一下它的具体使用,需要的小伙伴可以了解... 目录1. 安装依赖2. 创建vue组件3. 解释代码在Vue.js项目中导出Excel文件,使用第三

Linux alias的三种使用场景方式

《Linuxalias的三种使用场景方式》文章介绍了Linux中`alias`命令的三种使用场景:临时别名、用户级别别名和系统级别别名,临时别名仅在当前终端有效,用户级别别名在当前用户下所有终端有效... 目录linux alias三种使用场景一次性适用于当前用户全局生效,所有用户都可调用删除总结Linux