本文主要是介绍Pytorch Ignite 使用方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
Pytorch Ignite 使用方法
下载 pip install ignite
官方网址:https://pytorch.org/ignite/concepts.html
Engine
该框架的本质是class Engine
,它是一种抽象形式,它在提供的数据上循环给定的次数,执行处理函数并返回结果:
while epoch < max_epochs:# run an epoch on datadata_iter = iter(data)while True:try:batch = next(data_iter)output = process_function(batch)iter_counter += 1except StopIteration:data_iter = iter(data)if iter_counter == epoch_length:break
因此,模型训练器只是一个引擎,它在训练数据集上循环多次并更新模型参数。同样,可以使用在验证数据集上运行一次并计算指标的引擎来完成模型评估。
例如,用于监督任务的模型训练器:
def train_step(trainer, batch):model.train()optimizer.zero_grad()x, y = prepare_batch(batch)y_pred = model(x)loss = loss_fn(y_pred, y)loss.backward()optimizer.step()return loss.item()trainer = Engine(train_step)
trainer.run(data, max_epochs=100)
训练步骤的输出类型(即在上面的示例中loss.item()
)不受限制。训练步骤功能可以返回用户想要的一切。输出设置为,trainer.state.output
并且可以进一步用于任何类型的处理。
默认情况下,epoch_length 长度由
len(data)
定义。但是,用户也可以手动将epoch_length 长度定义为要循环的多次迭代。这样,输入数据可以是迭代器。trainer.run(data, max_epochs=100, epoch_length=200)
如果data是长度未知的有限数据迭代器(对于用户),
epoch_length
则可以省略参数,并且在耗尽数据迭代器时将自动确定参数。
任何复杂度的训练逻辑都可以使用train_step
方法进行编码,并且可以使用此方法来构造训练器。
函数batch
中的train_step
参数是用户定义的,可以包含单个迭代所需的任何数据。
# 定义模型参数
model_1 = ...
model_2 = ...
# 定义优化器
optimizer_1 = ...
optimizer_2 = ...
#
criterion_1 = ...
criterion_2 = ...
# ...def train_step(trainer, batch):data_1 = batch["data_1"]data_2 = batch["data_2"]# ...model_1.train()optimizer_1.zero_grad()loss_1 = forward_pass(data_1, model_1, criterion_1)loss_1.backward()optimizer_1.step()# ...model_2.train()optimizer_2.zero_grad()loss_2 = forward_pass(data_2, model_2, criterion_2)loss_2.backward()optimizer_2.step()# ...# User can return any type of structure.# 用户可以返回任何类型的数据结构return {"loss_1": loss_1,"loss_2": loss_2,# ...}trainer = Engine(train_step)
trainer.run(data, max_epochs=100)
事件和处理程序
为了提高Engine
灵活性,引入了事件系统,该系统促进了运行的每个步骤之间的交互:
- Engine启动/完成
- epoch开始/完成
- 批处理迭代已开始/已完成
有关事件的完整列表,请参见Events
。
用户可以执行自定义代码作为事件处理程序。处理程序可以是任何函数:例如lambda,简单函数,类方法等。第一个参数可以选择是engine,但不是必须的。
让我们更详细地考虑当run()
被调用时发生的情况:
fire_event(Events.STARTED)
while epoch < max_epochs:fire_event(Events.EPOCH_STARTED)# run once on datafor batch in data:fire_event(Events.ITERATION_STARTED)output = process_function(batch)fire_event(Events.ITERATION_COMPLETED)fire_event(Events.EPOCH_COMPLETED)
fire_event(Events.COMPLETED)
首先,将触发*“Engine已启动”事件并执行其所有事件处理程序(我们将在下一段中看到如何添加事件处理程序)。接下来,在启动循环并发生“epoch 开始”*事件时,等等。每次触发事件时,都会执行附加的处理程序。
使用方法add_event_handler()
或 on()
装饰器附加事件处理程序很简单:
trainer = Engine(update_model)trainer.add_event_handler(Events.STARTED, lambda _: print("Start training"))
# or
@trainer.on(Events.STARTED)
def on_training_started(engine):print("Another message of start training")
# or even simpler, use only what you need !
@trainer.on(Events.STARTED)
def on_training_started():print("Another message of start training")# attach handler with args, kwargs
mydata = [1, 2, 3, 4]def on_training_ended(data):print(f"Training is ended. mydata={data}")trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)
可以通过remove_event_handler()
或通过RemovableEventHandle
返回的引用来分离事件处理程序add_event_handler()
。这可用于将已配置的引擎重用于多个循环:
model = ...
train_loader, validation_loader, test_loader = ...trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model, metrics={"acc": Accuracy()})def log_metrics(engine, title):print(f"Epoch: {trainer.state.epoch} - {title} accuracy: {engine.state.metrics['acc']:.2f}")@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "train"):evaluator.run(train_loader)with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "validation"):evaluator.run(validation_loader)with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "test"):evaluator.run(test_loader)trainer.run(train_loader, max_epochs=100)
还可以将事件处理程序配置为以用户模式调用:每第n个事件一次,或使用自定义事件过滤功能:
model = ...
train_loader, validation_loader, test_loader = ...trainer = create_supervised_trainer(model, optimizer, loss)@trainer.on(Events.ITERATION_COMPLETED(every=50))
def log_training_loss_every_50_iterations():print(f"{trainer.state.epoch} / {trainer.state.max_epochs} : {trainer.state.iteration} - loss: {trainer.state.output:.2f}")@trainer.on(Events.EPOCH_STARTED(once=25))
def do_something_once_on_25_epoch():# do somethingdef custom_event_filter(engine, event):if event in [1, 2, 5, 10, 50, 100]:return Truereturn False@engine.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
def call_on_special_event(engine):# do something on 1, 2, 5, 10, 50, 100 iterationstrainer.run(train_loader, max_epochs=100)
自定义事件
用户还可以定义自定义事件。用户定义的事件应继承于引擎EventEnum
并register_events()
在引擎中注册。
from ignite.engine import EventEnumclass CustomEvents(EventEnum):"""Custom events defined by user"""CUSTOM_STARTED = 'custom_started'CUSTOM_COMPLETED = 'custom_completed'engine.register_events(*CustomEvents)
这些事件可用于附加任何处理程序,并使用触发fire_event()
。
@engine.on(CustomEvents.CUSTOM_STARTED)
def call_on_custom_event(engine):# do something@engine.on(Events.STARTED)
def fire_custom_events(engine):engine.fire_event(CustomEvents.CUSTOM_STARTED)
时间线和事件
在事件下方,一些典型的处理程序显示在时间轴上,以进行训练循环,并在每个时期后进行评估:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-U3DJJvxX-1618109731932)(https://pytorch.org/ignite/_images/timeline_and_events.png)]
状态
Engine
引入了一个状态来存储process_function,当前时期,迭代和其他有用信息的输出。每个都Engine
包含一个State
,其中包括以下内容:
- engine.state.seed:要在每个数据“ epoch”处设置的种子。
- engine.state.epoch:引擎已完成的纪元数。初始化为0,第一个时期为1。
- engine.state.iteration:引擎已完成的迭代次数。初始化为0,第一次迭代为1。
- engine.state.max_epochs:要运行的时期数。初始化为1。
- engine.state.output:为定义的process_function的输出
Engine
。见下文。 - 等等
其他属性可以在的文档中找到State
。
在下面的代码中,engine.state.output将存储批次损失。此输出用于打印每次迭代的损耗。
def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return loss.item()def on_iteration_completed(engine):iteration = engine.state.iterationepoch = engine.state.epochloss = engine.state.outputprint(f"Epoch: {epoch}, Iteration: {iteration}, Loss: {loss}")trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)
由于对process_function的输出没有限制,因此Ignite为其和提供了output_transform参数 。参数output_transform是用于将engine.state.output转换为预期用途的函数。在下面,我们将看到不同类型的engine.state.output以及如何对其进行转换。metrics``handlers
在下面的代码中,engine.state.output将已处理批次的loss,y_pred,y。如果要附加Accuracy
到引擎,则需要output_transform来从engine.state.output获取y_pred和y 。让我们看看如何做到这一点:
def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return loss.item(), y_pred, ytrainer = Engine(update)@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):epoch = engine.state.epochloss = engine.state.output[0]print (f'Epoch {epoch}: train_loss = {loss}')accuracy = Accuracy(output_transform=lambda x: [x[1], x[2]])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)
与上面类似,但是这次process_function的输出是处理后的批次的字典 loss,y_pred,y,这是用户可以使用output_transform从engine.state.output获取y_pred和y的方式。见下文:
def update(engine, batch):x, y = batchy_pred = model(inputs)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()return {'loss': loss.item(),'y_pred': y_pred,'y': y}trainer = Engine(update)@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):epoch = engine.state.epochloss = engine.state.output['loss']print (f'Epoch {epoch}: train_loss = {loss}')accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)
指标
库提供了各种机器学习任务的现成指标列表。支持两种计算指标的方式:1)在线 和 2)存储整个输出历史记录。
指标可以附加到Engine
:
from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.attach(evaluator, "accuracy")state = evaluator.run(validation_data)print("Result:", state.metrics)
# > {"accuracy": 0.12345}
或可用作独立对象:
from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.reset()for y_pred, y in get_prediction_target():accuracy.update((y_pred, y))print("Result:", accuracy.compute())
或可用作独立对象:
from ignite.metrics import Accuracyaccuracy = Accuracy()accuracy.reset()for y_pred, y in get_prediction_target():accuracy.update((y_pred, y))print("Result:", accuracy.compute())
完整的指标列表和API可以在ignite.metrics模块中找到。
这篇关于Pytorch Ignite 使用方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!