【深度学习实战(24)】如何实现“断点续训”?

2024-04-25 03:04

本文主要是介绍【深度学习实战(24)】如何实现“断点续训”?,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、什么是断点续训:

中断的地方,继续训练。与加载预训练权重有什么区别呢?区别在于优化器参数和学习率变了。

二、如何实现“断点续训”

我们需要使用checkpoint方法保存,模型权重,优化器权重,训练轮数。
保存模型,优化器权重可以理解,保存训练轮数是为了获得中断时的学习率。
由于在中断的时候,我们保存了中断时的模型权重,优化器权重,训练轮数,所以再次训练,加载这些参数,便可以继续训练。
实现流程:
(1)断点训练开关设置

# -------------------#
#   断点续训
# -------------------#
resume = True
resume_weights = os.path.join(save_dir, name_last_weights)

(2)使用checkpoint方式保模型权重,优化器权重,训练轮数

# -----------------------------------------------#
#   保存最后一轮模型权重,优化器权重,训练轮数
# -----------------------------------------------#
last_ckpt = {'epoch': epoch, 'model': save_state_dict, 'optimizer': optimizer.state_dict(), 'loss': val_loss}
torch.save(last_ckpt, os.path.join(save_dir, name_last_weights))

(3)模型权重,训练轮数加载

Init_Eoch = ...
model = YourModel()
# -------------------#
#   断点续训
# -------------------#
if resume:if args.resume_weights != '':Init_Epoch = torch.load(args.resume_weights, map_location=device)['epoch']model.load_state_dict(torch.load(args.resume_weights, map_location=device)['model'])

(4)优化器权重加载

optimizer = optim.AdamW(model.parameters(), lr=0.0001)
# -------------------#
#   断点续训
# -------------------#
if resume:if args.resume_weights != '':optimizer.load_state_dict(torch.load(args.resume_weights, map_location=device)['optimizer'])

三、完整“断点续训”框架

# -------------------#
#   断点续训
# -------------------#
resume = True
resume_weights = os.path.join(save_dir, name_last_weights)Init_Eoch = ...
model = YourModel()
# -------------------#
#   断点续训
# -------------------#
if resume:if args.resume_weights != '':Init_Epoch = torch.load(args.resume_weights, map_location=device)['epoch']model.load_state_dict(torch.load(args.resume_weights, map_location=device)['model'])optimizer = optim.AdamW(model.parameters(), lr=0.0001)
# -------------------#
#   断点续训
# -------------------#
if resume:if args.resume_weights != '':optimizer.load_state_dict(torch.load(args.resume_weights, map_location=device)['optimizer'])# -----------------------------------------------#
#   保存最后一轮模型权重,优化器权重,训练轮数
# -----------------------------------------------#
last_ckpt = {'epoch': epoch, 'model': save_state_dict, 'optimizer': optimizer.state_dict(), 'loss': val_loss}
torch.save(last_ckpt, os.path.join(save_dir, name_last_weights))

四、实际应用

从第50轮开始训练,训练到第103轮,中断训练。
loss变化:
在这里插入图片描述

检测变化:
在这里插入图片描述

从第104轮继续训练,训练到第162轮,中断训练。
loss变化:
在这里插入图片描述

检测变化:
在这里插入图片描述

从第163轮继续训练,训练到第320轮,中断训练。
loss变化:
在这里插入图片描述

检测变化:
在这里插入图片描述

从第321轮继续训练,训练到第1000轮,中断训练。
loss变化:
在这里插入图片描述

检测变化:
在这里插入图片描述

这篇关于【深度学习实战(24)】如何实现“断点续训”?的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/933548

相关文章

Python实现对阿里云OSS对象存储的操作详解

《Python实现对阿里云OSS对象存储的操作详解》这篇文章主要为大家详细介绍了Python实现对阿里云OSS对象存储的操作相关知识,包括连接,上传,下载,列举等功能,感兴趣的小伙伴可以了解下... 目录一、直接使用代码二、详细使用1. 环境准备2. 初始化配置3. bucket配置创建4. 文件上传到os

深度解析Java DTO(最新推荐)

《深度解析JavaDTO(最新推荐)》DTO(DataTransferObject)是一种用于在不同层(如Controller层、Service层)之间传输数据的对象设计模式,其核心目的是封装数据,... 目录一、什么是DTO?DTO的核心特点:二、为什么需要DTO?(对比Entity)三、实际应用场景解析

关于集合与数组转换实现方法

《关于集合与数组转换实现方法》:本文主要介绍关于集合与数组转换实现方法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、Arrays.asList()1.1、方法作用1.2、内部实现1.3、修改元素的影响1.4、注意事项2、list.toArray()2.1、方

从原理到实战深入理解Java 断言assert

《从原理到实战深入理解Java断言assert》本文深入解析Java断言机制,涵盖语法、工作原理、启用方式及与异常的区别,推荐用于开发阶段的条件检查与状态验证,并强调生产环境应使用参数验证工具类替代... 目录深入理解 Java 断言(assert):从原理到实战引言:为什么需要断言?一、断言基础1.1 语

深度解析Java项目中包和包之间的联系

《深度解析Java项目中包和包之间的联系》文章浏览阅读850次,点赞13次,收藏8次。本文详细介绍了Java分层架构中的几个关键包:DTO、Controller、Service和Mapper。_jav... 目录前言一、各大包1.DTO1.1、DTO的核心用途1.2. DTO与实体类(Entity)的区别1

使用Python实现可恢复式多线程下载器

《使用Python实现可恢复式多线程下载器》在数字时代,大文件下载已成为日常操作,本文将手把手教你用Python打造专业级下载器,实现断点续传,多线程加速,速度限制等功能,感兴趣的小伙伴可以了解下... 目录一、智能续传:从崩溃边缘抢救进度二、多线程加速:榨干网络带宽三、速度控制:做网络的好邻居四、终端交互

java实现docker镜像上传到harbor仓库的方式

《java实现docker镜像上传到harbor仓库的方式》:本文主要介绍java实现docker镜像上传到harbor仓库的方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 前 言2. 编写工具类2.1 引入依赖包2.2 使用当前服务器的docker环境推送镜像2.2

C++20管道运算符的实现示例

《C++20管道运算符的实现示例》本文简要介绍C++20管道运算符的使用与实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录标准库的管道运算符使用自己实现类似的管道运算符我们不打算介绍太多,因为它实际属于c++20最为重要的

Java easyExcel实现导入多sheet的Excel

《JavaeasyExcel实现导入多sheet的Excel》这篇文章主要为大家详细介绍了如何使用JavaeasyExcel实现导入多sheet的Excel,文中的示例代码讲解详细,感兴趣的小伙伴可... 目录1.官网2.Excel样式3.代码1.官网easyExcel官网2.Excel样式3.代码

Java MQTT实战应用

《JavaMQTT实战应用》本文详解MQTT协议,涵盖其发布/订阅机制、低功耗高效特性、三种服务质量等级(QoS0/1/2),以及客户端、代理、主题的核心概念,最后提供Linux部署教程、Sprin... 目录一、MQTT协议二、MQTT优点三、三种服务质量等级四、客户端、代理、主题1. 客户端(Clien