【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法

本文主要是介绍【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【ShuQiHere】

在深度学习的世界中,训练一个模型不仅需要时间,还需要大量的计算资源。比如,你已经花了几天时间训练一个模型,但突然间,电脑崩溃了,你的所有进度都丢失了。这种情况就像是在一场马拉松比赛的最后一公里摔倒,让人沮丧至极。那么,有没有什么方法可以避免这种悲剧呢?今天,我们就来聊聊如何通过保存和加载模型的权重来应对这些挑战,确保你在深度学习的旅程中不会白费功夫。

模型保存和加载的背景

训练一个深度学习模型就像建造一座摩天大楼。你需要从基础开始,一层层地搭建,最终完成一个复杂的系统。然而,建造过程中难免会遇到意外,比如断电、系统崩溃,甚至是代码错误。这些意外可能让你的努力前功尽弃。如果你不想每次意外发生后都从头开始,那么模型保存和加载就显得尤为重要。

在 TensorFlow 中,我们有两种主要的保存和加载方法:保存整个模型保存模型的权重。理解它们的区别和用法,就像学会了在建造摩天大楼时如何保存施工进度,确保即使遭遇突发事件,你的建筑工程也能顺利继续。

微调模型:从预训练到自定义任务

我们都知道,训练一个深度学习模型需要大量的数据和计算资源。幸运的是,深度学习社区里有很多预训练的模型,这些模型已经在大规模数据集上进行了训练。通过微调(Fine-Tuning),你可以利用这些预训练模型,在它们的基础上进行训练,快速适应新的任务。

场景:训练猫狗识别模型

比如,你想训练一个模型来区分猫和狗。如果从零开始训练,不仅费时费力,而且可能效果不佳。但如果你有一个在 ImageNet 上预训练的模型,就可以大大减少训练时间。你只需加载预训练的权重,并在猫狗数据集上微调模型,这样不仅节省了时间,还能获得更好的效果。

微调代码示例
import tensorflow as tf
from tensorflow.keras import layers, models# 创建基础模型结构
base_model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(64, 64, 3)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(128, (3, 3), activation='relu'),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10, activation='softmax')  # 假设有10个类别
])# 加载预训练权重
base_model.load_weights('pretrained_weights.h5')# 冻结部分层
for layer in base_model.layers[:-1]:layer.trainable = False# 编译并微调模型
base_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])# 微调训练
base_model.fit(new_data, new_labels, epochs=10)

解释

  • 加载预训练权重:通过 load_weights,你可以将已有的知识应用到新的任务中,避免从零开始训练模型。
  • 冻结层:冻结部分层的目的是保留预训练模型中已经学到的通用特征,仅微调特定的几层以适应新任务。这就像是在一个已经建好的摩天大楼里装修几层,改造成你需要的样子,而不是重新建造整栋大楼。

通过这种方式,你可以在保持预训练模型中有用特征的同时,快速适应新的任务或数据集。对于刚入门的小白来说,这是一种高效且实用的策略。

应对训练中断:如何保存和恢复模型

在训练模型的过程中,意外总是难以避免。系统崩溃、断电、内存不足等问题可能随时出现,这就像是在建造摩天大楼时,突然遇到大风暴,导致施工中断。那么,如何在中断后快速恢复呢?定期保存模型的权重是一个明智的选择。

保存权重

通过 TensorFlow 的 ModelCheckpoint 回调函数,你可以定期保存模型的权重,确保即使训练中断,也能从上次保存的进度继续。

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint# 设置检查点保存路径
checkpoint_path = "training_checkpoints/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)# 创建一个 ModelCheckpoint 回调函数
cp_callback = ModelCheckpoint(filepath=checkpoint_path,save_weights_only=True,verbose=1,save_freq=600)  # 每处理 600 个输入样本保存一次# 训练模型,同时保存检查点
model.fit(train_data, train_labels, epochs=10, callbacks=[cp_callback])

解释

  • 定期保存ModelCheckpoint 可以帮助你定期保存模型的权重,就像是在建造摩天大楼时,每隔一段时间都拍张照片保存进度。这样,即使突然下雨或者停电,你也可以在天气恢复后继续施工。
恢复训练

当训练中断时,你可以从最新的检查点恢复训练,而不必从头开始。这不仅节省了时间,也让你不会因为突发事件而感到沮丧。

# 查找最新的检查点
latest = tf.train.latest_checkpoint(checkpoint_dir)# 如果有检查点,加载权重
if latest:model.load_weights(latest)print(f"Loaded weights from checkpoint: {latest}")
else:print("No checkpoint found. Starting from scratch.")# 继续训练
model.fit(train_data, train_labels, epochs=10, initial_epoch=int(latest.split('-')[-1].split('.')[0]) if latest else 0,callbacks=[cp_callback])

解释

  • 从检查点恢复:通过 tf.train.latest_checkpoint 函数,你可以找到最近的检查点,并通过 load_weights 恢复模型的状态,从中断的地方继续训练。这就像是你在大风暴后回到工地,拿出之前保存的施工进度照片,继续建造摩天大楼。
保存与加载完整模型:从开发到生产的无缝衔接

在训练完成后,我们不仅需要保存权重,还可能需要保存整个模型。这样做的目的是为了方便模型的部署和迁移。通过保存整个模型,你可以在不同的环境中无缝地加载和使用它。

保存完整模型
# 保存整个模型
model.save('my_full_model.h5')
加载完整模型
from tensorflow.keras.models import load_model# 加载完整模型
model = load_model('my_full_model.h5')# 继续训练或进行推理
predictions = model.predict(test_data)

解释

  • 保存整个模型:使用 save_model 可以将模型的结构、权重以及优化器状态一并保存,确保模型在不同环境中的一致性。
  • 加载整个模型load_model 允许你在任何支持 TensorFlow/Keras 的环境中重新加载并使用这个模型,无论是继续训练还是部署到生产环境。
超大型语言模型的微调

当我们面对像 LLaMA3 这样超大型的语言模型时,微调过程就更加复杂。由于这些模型的参数量巨大,通常我们不会直接微调所有参数,而是使用**参数高效微调(PEFT)**技术,如 LoRA 或 Adapter。这些技术允许我们通过调整少量参数来适应新的任务,既降低了计算资源的需求,又保证了模型的性能。

LoRA 微调示例
from peft import LoraConfig, get_peft_model# 配置LoRA
lora_config = LoraConfig(r=4,  # 低秩矩阵的秩lora_alpha=16,  # LoRA的缩放因子target_modules=["q_proj", "v_proj"],  # 在注意力层中应用LoRAlora_dropout=0.1,  # LoRA的dropout率
)# 将LoRA应用到模型
model = get_peft_model(model, lora_config)# 开始训练
trainer.train()

解释

  • LoRA 技术:LoRA 通过在模型的特定矩阵上应用低秩分解,实现参数高效微调。这就像是在摩天大楼的某些关键部位加固,从而确保建筑在更高负

载下依然稳固。

最佳实践与总结
  1. 定期保存:在训练过程中,使用 ModelCheckpoint 定期保存权重,防止因中断而丢失进度。
  2. 选择合适的保存方法:在开发过程中,可以使用 save_weights 进行频繁保存;在部署前,使用 save_model 保存整个模型。
  3. 恢复训练:通过 load_weights,你可以轻松恢复训练进度,并在中断后继续模型的优化。
  4. 参数高效微调:对于超大型模型,使用 LoRA 或 Adapter 等技术进行参数高效微调,可以大幅降低资源需求。

通过掌握这些技术,你不仅可以确保模型训练的稳健性,还能有效应对实际开发中的各种挑战。无论是微调预训练模型,还是处理不可预知的中断,load_weightssave_model 都将成为你深度学习开发中的利器。最佳实践与总结


这篇关于【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1121003

相关文章

Python实现图片分割的多种方法总结

《Python实现图片分割的多种方法总结》图片分割是图像处理中的一个重要任务,它的目标是将图像划分为多个区域或者对象,本文为大家整理了一些常用的分割方法,大家可以根据需求自行选择... 目录1. 基于传统图像处理的分割方法(1) 使用固定阈值分割图片(2) 自适应阈值分割(3) 使用图像边缘检测分割(4)

Java中Switch Case多个条件处理方法举例

《Java中SwitchCase多个条件处理方法举例》Java中switch语句用于根据变量值执行不同代码块,适用于多个条件的处理,:本文主要介绍Java中SwitchCase多个条件处理的相... 目录前言基本语法处理多个条件示例1:合并相同代码的多个case示例2:通过字符串合并多个case进阶用法使用

Python中__init__方法使用的深度解析

《Python中__init__方法使用的深度解析》在Python的面向对象编程(OOP)体系中,__init__方法如同建造房屋时的奠基仪式——它定义了对象诞生时的初始状态,下面我们就来深入了解下_... 目录一、__init__的基因图谱二、初始化过程的魔法时刻继承链中的初始化顺序self参数的奥秘默认

html5的响应式布局的方法示例详解

《html5的响应式布局的方法示例详解》:本文主要介绍了HTML5中使用媒体查询和Flexbox进行响应式布局的方法,简要介绍了CSSGrid布局的基础知识和如何实现自动换行的网格布局,详细内容请阅读本文,希望能对你有所帮助... 一 使用媒体查询响应式布局        使用的参数@media这是常用的

Spring 基于XML配置 bean管理 Bean-IOC的方法

《Spring基于XML配置bean管理Bean-IOC的方法》:本文主要介绍Spring基于XML配置bean管理Bean-IOC的方法,本文给大家介绍的非常详细,对大家的学习或工作具有一... 目录一. spring学习的核心内容二. 基于 XML 配置 bean1. 通过类型来获取 bean2. 通过

基于Python实现读取嵌套压缩包下文件的方法

《基于Python实现读取嵌套压缩包下文件的方法》工作中遇到的问题,需要用Python实现嵌套压缩包下文件读取,本文给大家介绍了详细的解决方法,并有相关的代码示例供大家参考,需要的朋友可以参考下... 目录思路完整代码代码优化思路打开外层zip压缩包并遍历文件:使用with zipfile.ZipFil

Python处理函数调用超时的四种方法

《Python处理函数调用超时的四种方法》在实际开发过程中,我们可能会遇到一些场景,需要对函数的执行时间进行限制,例如,当一个函数执行时间过长时,可能会导致程序卡顿、资源占用过高,因此,在某些情况下,... 目录前言func-timeout1. 安装 func-timeout2. 基本用法自定义进程subp

Python列表去重的4种核心方法与实战指南详解

《Python列表去重的4种核心方法与实战指南详解》在Python开发中,处理列表数据时经常需要去除重复元素,本文将详细介绍4种最实用的列表去重方法,有需要的小伙伴可以根据自己的需要进行选择... 目录方法1:集合(set)去重法(最快速)方法2:顺序遍历法(保持顺序)方法3:副本删除法(原地修改)方法4:

Python中判断对象是否为空的方法

《Python中判断对象是否为空的方法》在Python开发中,判断对象是否为“空”是高频操作,但看似简单的需求却暗藏玄机,从None到空容器,从零值到自定义对象的“假值”状态,不同场景下的“空”需要精... 目录一、python中的“空”值体系二、精准判定方法对比三、常见误区解析四、进阶处理技巧五、性能优化

C++中初始化二维数组的几种常见方法

《C++中初始化二维数组的几种常见方法》本文详细介绍了在C++中初始化二维数组的不同方式,包括静态初始化、循环、全部为零、部分初始化、std::array和std::vector,以及std::vec... 目录1. 静态初始化2. 使用循环初始化3. 全部初始化为零4. 部分初始化5. 使用 std::a