nnUNet 更改学习率和衰减优化器的方法

2023-11-06 11:36

本文主要是介绍nnUNet 更改学习率和衰减优化器的方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

此为记录贴,逻辑混乱 仅供参考:
勿喷
nnUNet默认的学习率衰减方法为线性衰减,优化器为SGD,在.\nnUNet\nnunetv2\training\nnUNetTrainer\nnUNetTrainer.py文件中nnUNetTrainer基类中定义 如下:

    def configure_optimizers(self):optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,momentum=0.99, nesterov=True)lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)return optimizer, lr_scheduler

为了改变优化器和学习率衰减方法:
我们可以继承nnUNetTrainer类重写一个 nnUNetTrainerCosAnneal类,当然nnUnet已经贴心的为我们写好了 在.\nnUNet\nnunetv2\training\nnUNetTrainer\variants\optimizer\nnUNetTrainerAdam
原始代码如下:

import torch
from torch.optim import Adam, AdamWfrom nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainerclass nnUNetTrainerAdam(nnUNetTrainer):def configure_optimizers(self):optimizer = AdamW(self.network.parameters(),lr=self.initial_lr,weight_decay=self.weight_decay,amsgrad=True)# optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,#                             momentum=0.99, nesterov=True)lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)return optimizer, lr_scheduler

如果按照上一篇博客的方法直接更改训练方法为nnUNetTrainerAdam的话,会弹出如下警告:

 UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1
.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first 
value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-ratewarnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`.

警告已经说的很明白了,就不翻译了,为了避免不能在训练的时候调整学习率,我们需要去改变lr_scheduler.step()optimizer.step() 调用顺序。就要在重写on_train_epoch_starttrain_step函数
下列文件可以作为参考:
要修改优化器也可以直接在
optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, momentum=0.99, nesterov=True)
更改即可

from torch.optim.lr_scheduler import CosineAnnealingLR
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import *class nnUNetTrainerCosAnneal(nnUNetTrainer):def configure_optimizers(self):optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,momentum=0.99, nesterov=True)lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs,eta_min=1e-4)return optimizer, lr_schedulerdef on_train_epoch_start(self):self.network.train()# self.lr_scheduler.step() #don't need call lr_scheduler.step() in this functionself.print_to_log_file('')self.print_to_log_file(f'Epoch {self.current_epoch}')self.print_to_log_file(f"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}")# lrs are the same for all workers so we don't need to gather them in case of DDP trainingself.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch)def train_step(self, batch: dict) -> dict:data = batch['data']target = batch['target']data = data.to(self.device, non_blocking=True)if isinstance(target, list):target = [i.to(self.device, non_blocking=True) for i in target]else:target = target.to(self.device, non_blocking=True)self.optimizer.zero_grad(set_to_none=True)# Autocast is a little bitch.# If the device_type is 'cpu' then it's slow as heck and needs to be disabled.# If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)# So autocast will only be active if we have a cuda device.with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():output = self.network(data)# del datal = self.loss(output, target)if self.grad_scaler is not None:self.grad_scaler.scale(l).backward()self.grad_scaler.unscale_(self.optimizer)torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)self.grad_scaler.step(self.optimizer)self.grad_scaler.update()else:l.backward()torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)self.optimizer.step()self.lr_scheduler.step()## add lr_scheduler.step() after optimizer.step()return {'loss': l.detach().cpu().numpy()}

要使用这个类进行训练,运行以下命令即可:

nnUNetV2_train 002 2d 0 -tr nnUNetTrainerCosAnneal

记录完毕,继续炼丹

这篇关于nnUNet 更改学习率和衰减优化器的方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/356375

相关文章

C++变换迭代器使用方法小结

《C++变换迭代器使用方法小结》本文主要介绍了C++变换迭代器使用方法小结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录1、源码2、代码解析代码解析:transform_iterator1. transform_iterat

C++中std::distance使用方法示例

《C++中std::distance使用方法示例》std::distance是C++标准库中的一个函数,用于计算两个迭代器之间的距离,本文主要介绍了C++中std::distance使用方法示例,具... 目录语法使用方式解释示例输出:其他说明:总结std::distance&n编程bsp;是 C++ 标准

SpringBoot3实现Gzip压缩优化的技术指南

《SpringBoot3实现Gzip压缩优化的技术指南》随着Web应用的用户量和数据量增加,网络带宽和页面加载速度逐渐成为瓶颈,为了减少数据传输量,提高用户体验,我们可以使用Gzip压缩HTTP响应,... 目录1、简述2、配置2.1 添加依赖2.2 配置 Gzip 压缩3、服务端应用4、前端应用4.1 N

Linux换行符的使用方法详解

《Linux换行符的使用方法详解》本文介绍了Linux中常用的换行符LF及其在文件中的表示,展示了如何使用sed命令替换换行符,并列举了与换行符处理相关的Linux命令,通过代码讲解的非常详细,需要的... 目录简介检测文件中的换行符使用 cat -A 查看换行符使用 od -c 检查字符换行符格式转换将

SpringBoot实现数据库读写分离的3种方法小结

《SpringBoot实现数据库读写分离的3种方法小结》为了提高系统的读写性能和可用性,读写分离是一种经典的数据库架构模式,在SpringBoot应用中,有多种方式可以实现数据库读写分离,本文将介绍三... 目录一、数据库读写分离概述二、方案一:基于AbstractRoutingDataSource实现动态

Java中的String.valueOf()和toString()方法区别小结

《Java中的String.valueOf()和toString()方法区别小结》字符串操作是开发者日常编程任务中不可或缺的一部分,转换为字符串是一种常见需求,其中最常见的就是String.value... 目录String.valueOf()方法方法定义方法实现使用示例使用场景toString()方法方法

Java中List的contains()方法的使用小结

《Java中List的contains()方法的使用小结》List的contains()方法用于检查列表中是否包含指定的元素,借助equals()方法进行判断,下面就来介绍Java中List的c... 目录详细展开1. 方法签名2. 工作原理3. 使用示例4. 注意事项总结结论:List 的 contain

Spring Boot + MyBatis Plus 高效开发实战从入门到进阶优化(推荐)

《SpringBoot+MyBatisPlus高效开发实战从入门到进阶优化(推荐)》本文将详细介绍SpringBoot+MyBatisPlus的完整开发流程,并深入剖析分页查询、批量操作、动... 目录Spring Boot + MyBATis Plus 高效开发实战:从入门到进阶优化1. MyBatis

MyBatis 动态 SQL 优化之标签的实战与技巧(常见用法)

《MyBatis动态SQL优化之标签的实战与技巧(常见用法)》本文通过详细的示例和实际应用场景,介绍了如何有效利用这些标签来优化MyBatis配置,提升开发效率,确保SQL的高效执行和安全性,感... 目录动态SQL详解一、动态SQL的核心概念1.1 什么是动态SQL?1.2 动态SQL的优点1.3 动态S

macOS无效Launchpad图标轻松删除的4 种实用方法

《macOS无效Launchpad图标轻松删除的4种实用方法》mac中不在appstore上下载的应用经常在删除后它的图标还残留在launchpad中,并且长按图标也不会出现删除符号,下面解决这个问... 在 MACOS 上,Launchpad(也就是「启动台」)是一个便捷的 App 启动工具。但有时候,应