Pytorch Ignite 使用方法

2024-01-16 15:48
文章标签 使用 方法 pytorch ignite

本文主要是介绍Pytorch Ignite 使用方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Pytorch Ignite 使用方法

下载 pip install ignite
官方网址:https://pytorch.org/ignite/concepts.html

Engine

该框架的本质是class Engine,它是一种抽象形式,它在提供的数据上循环给定的次数,执行处理函数并返回结果:

  while epoch < max_epochs:# run an epoch on datadata_iter = iter(data)while True:try:batch = next(data_iter)output = process_function(batch)iter_counter += 1except StopIteration:data_iter = iter(data)if iter_counter == epoch_length:break

因此,模型训练器只是一个引擎,它在训练数据集上循环多次并更新模型参数。同样,可以使用在验证数据集上运行一次并计算指标的引擎来完成模型评估。

例如,用于监督任务的模型训练器:

def train_step(trainer, batch):model.train()optimizer.zero_grad()x, y = prepare_batch(batch)y_pred = model(x)loss = loss_fn(y_pred, y)loss.backward()optimizer.step()return loss.item()trainer = Engine(train_step)
trainer.run(data, max_epochs=100)

训练步骤的输出类型(即在上面的示例中loss.item())不受限制。训练步骤功能可以返回用户想要的一切。输出设置为,trainer.state.output并且可以进一步用于任何类型的处理。

默认情况下,epoch_length 长度由len(data)定义。但是,用户也可以手动将epoch_length 长度定义为要循环的多次迭代。这样,输入数据可以是迭代器。

trainer.run(data, max_epochs=100, epoch_length=200)

如果data是长度未知的有限数据迭代器(对于用户),epoch_length则可以省略参数,并且在耗尽数据迭代器时将自动确定参数。

任何复杂度的训练逻辑都可以使用train_step方法进行编码,并且可以使用此方法来构造训练器。

函数batch中的train_step参数是用户定义的,可以包含单个迭代所需的任何数据。

# 定义模型参数
model_1 = ...
model_2 = ...
# 定义优化器
optimizer_1 = ...
optimizer_2 = ...
# 
criterion_1 = ...
criterion_2 = ...
# ...def train_step(trainer, batch):data_1 = batch["data_1"]data_2 = batch["data_2"]# ...model_1.train()optimizer_1.zero_grad()loss_1 = forward_pass(data_1, model_1, criterion_1)loss_1.backward()optimizer_1.step()# ...model_2.train()optimizer_2.zero_grad()loss_2 = forward_pass(data_2, model_2, criterion_2)loss_2.backward()optimizer_2.step()# ...# User can return any type of structure.# 用户可以返回任何类型的数据结构return {"loss_1": loss_1,"loss_2": loss_2,# ...}trainer = Engine(train_step)
trainer.run(data, max_epochs=100)

事件和处理程序

为了提高Engine灵活性,引入了事件系统,该系统促进了运行的每个步骤之间的交互:

  • Engine启动/完成
  • epoch开始/完成
  • 批处理迭代已开始/已完成

有关事件的完整列表,请参见Events

用户可以执行自定义代码作为事件处理程序。处理程序可以是任何函数:例如lambda,简单函数,类方法等。第一个参数可以选择是engine,但不是必须的。

让我们更详细地考虑当run()被调用时发生的情况:

fire_event(Events.STARTED)
while epoch < max_epochs:fire_event(Events.EPOCH_STARTED)# run once on datafor batch in data:fire_event(Events.ITERATION_STARTED)output = process_function(batch)fire_event(Events.ITERATION_COMPLETED)fire_event(Events.EPOCH_COMPLETED)
fire_event(Events.COMPLETED)

首先,将触发*“Engine已启动”事件并执行其所有事件处理程序(我们将在下一段中看到如何添加事件处理程序)。接下来,在启动循环并发生“epoch 开始”*事件时,等等。每次触发事件时,都会执行附加的处理程序。

使用方法add_event_handler()on()装饰器附加事件处理程序很简单:

trainer = Engine(update_model)trainer.add_event_handler(Events.STARTED, lambda _: print("Start training"))
# or
@trainer.on(Events.STARTED)
def on_training_started(engine):print("Another message of start training")
# or even simpler, use only what you need !
@trainer.on(Events.STARTED)
def on_training_started():print("Another message of start training")# attach handler with args, kwargs
mydata = [1, 2, 3, 4]def on_training_ended(data):print(f"Training is ended. mydata={data}")trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)

可以通过remove_event_handler()或通过RemovableEventHandle 返回的引用来分离事件处理程序add_event_handler()。这可用于将已配置的引擎重用于多个循环:

model = ...
train_loader, validation_loader, test_loader = ...trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model, metrics={"acc": Accuracy()})def log_metrics(engine, title):print(f"Epoch: {trainer.state.epoch} - {title} accuracy: {engine.state.metrics['acc']:.2f}")@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "train"):evaluator.run(train_loader)with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "validation"):evaluator.run(validation_loader)with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "test"):evaluator.run(test_loader)trainer.run(train_loader, max_epochs=100)

还可以将事件处理程序配置为以用户模式调用:每第n个事件一次,或使用自定义事件过滤功能:

model = ...
train_loader, validation_loader, test_loader = ...trainer = create_supervised_trainer(model, optimizer, loss)@trainer.on(Events.ITERATION_COMPLETED(every=50))
def log_training_loss_every_50_iterations():print(f"{trainer.state.epoch} / {trainer.state.max_epochs} : {trainer.state.iteration} - loss: {trainer.state.output:.2f}")@trainer.on(Events.EPOCH_STARTED(once=25))
def do_something_once_on_25_epoch():# do somethingdef custom_event_filter(engine, event):if event in [1, 2, 5, 10, 50, 100]:return Truereturn False@engine.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
def call_on_special_event(engine):# do something on 1, 2, 5, 10, 50, 100 iterationstrainer.run(train_loader, max_epochs=100)

自定义事件

用户还可以定义自定义事件。用户定义的事件应继承于引擎EventEnumregister_events()在引擎中注册。

from ignite.engine import EventEnumclass CustomEvents(EventEnum):"""Custom events defined by user"""CUSTOM_STARTED = 'custom_started'CUSTOM_COMPLETED = 'custom_completed'engine.register_events(*CustomEvents)

这些事件可用于附加任何处理程序,并使用触发fire_event()

@engine.on(CustomEvents.CUSTOM_STARTED)
def call_on_custom_event(engine):# do something@engine.on(Events.STARTED)
def fire_custom_events(engine):engine.fire_event(CustomEvents.CUSTOM_STARTED)

时间线和事件

在事件下方,一些典型的处理程序显示在时间轴上,以进行训练循环,并在每个时期后进行评估:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-U3DJJvxX-1618109731932)(https://pytorch.org/ignite/_images/timeline_and_events.png)]

状态

Engine引入了一个状态来存储process_function,当前时期,迭代和其他有用信息的输出。每个都Engine包含一个State,其中包括以下内容:

  • engine.state.seed:要在每个数据“ epoch”处设置的种子。
  • engine.state.epoch:引擎已完成的纪元数。初始化为0,第一个时期为1。
  • engine.state.iteration:引擎已完成的迭代次数。初始化为0,第一次迭代为1。
  • engine.state.max_epochs:要运行的时期数。初始化为1。
  • engine.state.output:为定义的process_function的输出Engine。见下文。
  • 等等

其他属性可以在的文档中找到State

在下面的代码中,engine.state.output将存储批次损失。此输出用于打印每次迭代的损耗。

def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return loss.item()def on_iteration_completed(engine):iteration = engine.state.iterationepoch = engine.state.epochloss = engine.state.outputprint(f"Epoch: {epoch}, Iteration: {iteration}, Loss: {loss}")trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)

由于对process_function的输出没有限制,因此Ignite为其和提供了output_transform参数 。参数output_transform是用于将engine.state.output转换为预期用途的函数。在下面,我们将看到不同类型的engine.state.output以及如何对其进行转换。metrics``handlers

在下面的代码中,engine.state.output将已处理批次的loss,y_pred,y。如果要附加Accuracy到引擎,则需要output_transform来从engine.state.output获取y_pred和y 。让我们看看如何做到这一点:

def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return loss.item(), y_pred, ytrainer = Engine(update)@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):epoch = engine.state.epochloss = engine.state.output[0]print (f'Epoch {epoch}: train_loss = {loss}')accuracy = Accuracy(output_transform=lambda x: [x[1], x[2]])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)

与上面类似,但是这次process_function的输出是处理后的批次的字典 loss,y_pred,y,这是用户可以使用output_transform从engine.state.output获取y_pred和y的方式。见下文:

def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return {'loss': loss.item(),'y_pred': y_pred,'y': y}trainer = Engine(update)@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):epoch = engine.state.epochloss = engine.state.output['loss']print (f'Epoch {epoch}: train_loss = {loss}')accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)

指标

库提供了各种机器学习任务的现成指标列表。支持两种计算指标的方式:1)在线 和 2)存储整个输出历史记录。

指标可以附加到Engine

from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.attach(evaluator, "accuracy")state = evaluator.run(validation_data)print("Result:", state.metrics)
# > {"accuracy": 0.12345}

或可用作独立对象:

from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.reset()for y_pred, y in get_prediction_target():accuracy.update((y_pred, y))print("Result:", accuracy.compute())

或可用作独立对象:

from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.reset()for y_pred, y in get_prediction_target():accuracy.update((y_pred, y))print("Result:", accuracy.compute())

完整的指标列表和API可以在ignite.metrics模块中找到。

这篇关于Pytorch Ignite 使用方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/613089

相关文章

中文分词jieba库的使用与实景应用(一)

知识星球:https://articles.zsxq.com/id_fxvgc803qmr2.html 目录 一.定义: 精确模式(默认模式): 全模式: 搜索引擎模式: paddle 模式(基于深度学习的分词模式): 二 自定义词典 三.文本解析   调整词出现的频率 四. 关键词提取 A. 基于TF-IDF算法的关键词提取 B. 基于TextRank算法的关键词提取

使用SecondaryNameNode恢复NameNode的数据

1)需求: NameNode进程挂了并且存储的数据也丢失了,如何恢复NameNode 此种方式恢复的数据可能存在小部分数据的丢失。 2)故障模拟 (1)kill -9 NameNode进程 [lytfly@hadoop102 current]$ kill -9 19886 (2)删除NameNode存储的数据(/opt/module/hadoop-3.1.4/data/tmp/dfs/na

Hadoop数据压缩使用介绍

一、压缩原则 (1)运算密集型的Job,少用压缩 (2)IO密集型的Job,多用压缩 二、压缩算法比较 三、压缩位置选择 四、压缩参数配置 1)为了支持多种压缩/解压缩算法,Hadoop引入了编码/解码器 2)要在Hadoop中启用压缩,可以配置如下参数

Makefile简明使用教程

文章目录 规则makefile文件的基本语法:加在命令前的特殊符号:.PHONY伪目标: Makefilev1 直观写法v2 加上中间过程v3 伪目标v4 变量 make 选项-f-n-C Make 是一种流行的构建工具,常用于将源代码转换成可执行文件或者其他形式的输出文件(如库文件、文档等)。Make 可以自动化地执行编译、链接等一系列操作。 规则 makefile文件

使用opencv优化图片(画面变清晰)

文章目录 需求影响照片清晰度的因素 实现降噪测试代码 锐化空间锐化Unsharp Masking频率域锐化对比测试 对比度增强常用算法对比测试 需求 对图像进行优化,使其看起来更清晰,同时保持尺寸不变,通常涉及到图像处理技术如锐化、降噪、对比度增强等 影响照片清晰度的因素 影响照片清晰度的因素有很多,主要可以从以下几个方面来分析 1. 拍摄设备 相机传感器:相机传

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

pdfmake生成pdf的使用

实际项目中有时会有根据填写的表单数据或者其他格式的数据,将数据自动填充到pdf文件中根据固定模板生成pdf文件的需求 文章目录 利用pdfmake生成pdf文件1.下载安装pdfmake第三方包2.封装生成pdf文件的共用配置3.生成pdf文件的文件模板内容4.调用方法生成pdf 利用pdfmake生成pdf文件 1.下载安装pdfmake第三方包 npm i pdfma

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

浅谈主机加固,六种有效的主机加固方法

在数字化时代,数据的价值不言而喻,但随之而来的安全威胁也日益严峻。从勒索病毒到内部泄露,企业的数据安全面临着前所未有的挑战。为了应对这些挑战,一种全新的主机加固解决方案应运而生。 MCK主机加固解决方案,采用先进的安全容器中间件技术,构建起一套内核级的纵深立体防护体系。这一体系突破了传统安全防护的局限,即使在管理员权限被恶意利用的情况下,也能确保服务器的安全稳定运行。 普适主机加固措施:

webm怎么转换成mp4?这几种方法超多人在用!

webm怎么转换成mp4?WebM作为一种新兴的视频编码格式,近年来逐渐进入大众视野,其背后承载着诸多优势,但同时也伴随着不容忽视的局限性,首要挑战在于其兼容性边界,尽管WebM已广泛适应于众多网站与软件平台,但在特定应用环境或老旧设备上,其兼容难题依旧凸显,为用户体验带来不便,再者,WebM格式的非普适性也体现在编辑流程上,由于它并非行业内的通用标准,编辑过程中可能会遭遇格式不兼容的障碍,导致操