【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法

本文主要是介绍【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【ShuQiHere】

在深度学习的世界中,训练一个模型不仅需要时间,还需要大量的计算资源。比如,你已经花了几天时间训练一个模型,但突然间,电脑崩溃了,你的所有进度都丢失了。这种情况就像是在一场马拉松比赛的最后一公里摔倒,让人沮丧至极。那么,有没有什么方法可以避免这种悲剧呢?今天,我们就来聊聊如何通过保存和加载模型的权重来应对这些挑战,确保你在深度学习的旅程中不会白费功夫。

模型保存和加载的背景

训练一个深度学习模型就像建造一座摩天大楼。你需要从基础开始,一层层地搭建,最终完成一个复杂的系统。然而,建造过程中难免会遇到意外,比如断电、系统崩溃,甚至是代码错误。这些意外可能让你的努力前功尽弃。如果你不想每次意外发生后都从头开始,那么模型保存和加载就显得尤为重要。

在 TensorFlow 中,我们有两种主要的保存和加载方法:保存整个模型保存模型的权重。理解它们的区别和用法,就像学会了在建造摩天大楼时如何保存施工进度,确保即使遭遇突发事件,你的建筑工程也能顺利继续。

微调模型:从预训练到自定义任务

我们都知道,训练一个深度学习模型需要大量的数据和计算资源。幸运的是,深度学习社区里有很多预训练的模型,这些模型已经在大规模数据集上进行了训练。通过微调(Fine-Tuning),你可以利用这些预训练模型,在它们的基础上进行训练,快速适应新的任务。

场景:训练猫狗识别模型

比如,你想训练一个模型来区分猫和狗。如果从零开始训练,不仅费时费力,而且可能效果不佳。但如果你有一个在 ImageNet 上预训练的模型,就可以大大减少训练时间。你只需加载预训练的权重,并在猫狗数据集上微调模型,这样不仅节省了时间,还能获得更好的效果。

微调代码示例
import tensorflow as tf
from tensorflow.keras import layers, models# 创建基础模型结构
base_model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(64, 64, 3)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(128, (3, 3), activation='relu'),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10, activation='softmax')  # 假设有10个类别
])# 加载预训练权重
base_model.load_weights('pretrained_weights.h5')# 冻结部分层
for layer in base_model.layers[:-1]:layer.trainable = False# 编译并微调模型
base_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])# 微调训练
base_model.fit(new_data, new_labels, epochs=10)

解释

  • 加载预训练权重:通过 load_weights,你可以将已有的知识应用到新的任务中,避免从零开始训练模型。
  • 冻结层:冻结部分层的目的是保留预训练模型中已经学到的通用特征,仅微调特定的几层以适应新任务。这就像是在一个已经建好的摩天大楼里装修几层,改造成你需要的样子,而不是重新建造整栋大楼。

通过这种方式,你可以在保持预训练模型中有用特征的同时,快速适应新的任务或数据集。对于刚入门的小白来说,这是一种高效且实用的策略。

应对训练中断:如何保存和恢复模型

在训练模型的过程中,意外总是难以避免。系统崩溃、断电、内存不足等问题可能随时出现,这就像是在建造摩天大楼时,突然遇到大风暴,导致施工中断。那么,如何在中断后快速恢复呢?定期保存模型的权重是一个明智的选择。

保存权重

通过 TensorFlow 的 ModelCheckpoint 回调函数,你可以定期保存模型的权重,确保即使训练中断,也能从上次保存的进度继续。

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint# 设置检查点保存路径
checkpoint_path = "training_checkpoints/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)# 创建一个 ModelCheckpoint 回调函数
cp_callback = ModelCheckpoint(filepath=checkpoint_path,save_weights_only=True,verbose=1,save_freq=600)  # 每处理 600 个输入样本保存一次# 训练模型,同时保存检查点
model.fit(train_data, train_labels, epochs=10, callbacks=[cp_callback])

解释

  • 定期保存ModelCheckpoint 可以帮助你定期保存模型的权重,就像是在建造摩天大楼时,每隔一段时间都拍张照片保存进度。这样,即使突然下雨或者停电,你也可以在天气恢复后继续施工。
恢复训练

当训练中断时,你可以从最新的检查点恢复训练,而不必从头开始。这不仅节省了时间,也让你不会因为突发事件而感到沮丧。

# 查找最新的检查点
latest = tf.train.latest_checkpoint(checkpoint_dir)# 如果有检查点,加载权重
if latest:model.load_weights(latest)print(f"Loaded weights from checkpoint: {latest}")
else:print("No checkpoint found. Starting from scratch.")# 继续训练
model.fit(train_data, train_labels, epochs=10, initial_epoch=int(latest.split('-')[-1].split('.')[0]) if latest else 0,callbacks=[cp_callback])

解释

  • 从检查点恢复:通过 tf.train.latest_checkpoint 函数,你可以找到最近的检查点,并通过 load_weights 恢复模型的状态,从中断的地方继续训练。这就像是你在大风暴后回到工地,拿出之前保存的施工进度照片,继续建造摩天大楼。
保存与加载完整模型:从开发到生产的无缝衔接

在训练完成后,我们不仅需要保存权重,还可能需要保存整个模型。这样做的目的是为了方便模型的部署和迁移。通过保存整个模型,你可以在不同的环境中无缝地加载和使用它。

保存完整模型
# 保存整个模型
model.save('my_full_model.h5')
加载完整模型
from tensorflow.keras.models import load_model# 加载完整模型
model = load_model('my_full_model.h5')# 继续训练或进行推理
predictions = model.predict(test_data)

解释

  • 保存整个模型:使用 save_model 可以将模型的结构、权重以及优化器状态一并保存,确保模型在不同环境中的一致性。
  • 加载整个模型load_model 允许你在任何支持 TensorFlow/Keras 的环境中重新加载并使用这个模型,无论是继续训练还是部署到生产环境。
超大型语言模型的微调

当我们面对像 LLaMA3 这样超大型的语言模型时,微调过程就更加复杂。由于这些模型的参数量巨大,通常我们不会直接微调所有参数,而是使用**参数高效微调(PEFT)**技术,如 LoRA 或 Adapter。这些技术允许我们通过调整少量参数来适应新的任务,既降低了计算资源的需求,又保证了模型的性能。

LoRA 微调示例
from peft import LoraConfig, get_peft_model# 配置LoRA
lora_config = LoraConfig(r=4,  # 低秩矩阵的秩lora_alpha=16,  # LoRA的缩放因子target_modules=["q_proj", "v_proj"],  # 在注意力层中应用LoRAlora_dropout=0.1,  # LoRA的dropout率
)# 将LoRA应用到模型
model = get_peft_model(model, lora_config)# 开始训练
trainer.train()

解释

  • LoRA 技术:LoRA 通过在模型的特定矩阵上应用低秩分解,实现参数高效微调。这就像是在摩天大楼的某些关键部位加固,从而确保建筑在更高负

载下依然稳固。

最佳实践与总结
  1. 定期保存:在训练过程中,使用 ModelCheckpoint 定期保存权重,防止因中断而丢失进度。
  2. 选择合适的保存方法:在开发过程中,可以使用 save_weights 进行频繁保存;在部署前,使用 save_model 保存整个模型。
  3. 恢复训练:通过 load_weights,你可以轻松恢复训练进度,并在中断后继续模型的优化。
  4. 参数高效微调:对于超大型模型,使用 LoRA 或 Adapter 等技术进行参数高效微调,可以大幅降低资源需求。

通过掌握这些技术,你不仅可以确保模型训练的稳健性,还能有效应对实际开发中的各种挑战。无论是微调预训练模型,还是处理不可预知的中断,load_weightssave_model 都将成为你深度学习开发中的利器。最佳实践与总结


这篇关于【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1121003

相关文章

使用SecondaryNameNode恢复NameNode的数据

1)需求: NameNode进程挂了并且存储的数据也丢失了,如何恢复NameNode 此种方式恢复的数据可能存在小部分数据的丢失。 2)故障模拟 (1)kill -9 NameNode进程 [lytfly@hadoop102 current]$ kill -9 19886 (2)删除NameNode存储的数据(/opt/module/hadoop-3.1.4/data/tmp/dfs/na

认识、理解、分类——acm之搜索

普通搜索方法有两种:1、广度优先搜索;2、深度优先搜索; 更多搜索方法: 3、双向广度优先搜索; 4、启发式搜索(包括A*算法等); 搜索通常会用到的知识点:状态压缩(位压缩,利用hash思想压缩)。

电脑桌面文件删除了怎么找回来?别急,快速恢复攻略在此

在日常使用电脑的过程中,我们经常会遇到这样的情况:一不小心,桌面上的某个重要文件被删除了。这时,大多数人可能会感到惊慌失措,不知所措。 其实,不必过于担心,因为有很多方法可以帮助我们找回被删除的桌面文件。下面,就让我们一起来了解一下这些恢复桌面文件的方法吧。 一、使用撤销操作 如果我们刚刚删除了桌面上的文件,并且还没有进行其他操作,那么可以尝试使用撤销操作来恢复文件。在键盘上同时按下“C

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

浅谈主机加固,六种有效的主机加固方法

在数字化时代,数据的价值不言而喻,但随之而来的安全威胁也日益严峻。从勒索病毒到内部泄露,企业的数据安全面临着前所未有的挑战。为了应对这些挑战,一种全新的主机加固解决方案应运而生。 MCK主机加固解决方案,采用先进的安全容器中间件技术,构建起一套内核级的纵深立体防护体系。这一体系突破了传统安全防护的局限,即使在管理员权限被恶意利用的情况下,也能确保服务器的安全稳定运行。 普适主机加固措施:

webm怎么转换成mp4?这几种方法超多人在用!

webm怎么转换成mp4?WebM作为一种新兴的视频编码格式,近年来逐渐进入大众视野,其背后承载着诸多优势,但同时也伴随着不容忽视的局限性,首要挑战在于其兼容性边界,尽管WebM已广泛适应于众多网站与软件平台,但在特定应用环境或老旧设备上,其兼容难题依旧凸显,为用户体验带来不便,再者,WebM格式的非普适性也体现在编辑流程上,由于它并非行业内的通用标准,编辑过程中可能会遭遇格式不兼容的障碍,导致操

透彻!驯服大型语言模型(LLMs)的五种方法,及具体方法选择思路

引言 随着时间的发展,大型语言模型不再停留在演示阶段而是逐步面向生产系统的应用,随着人们期望的不断增加,目标也发生了巨大的变化。在短短的几个月的时间里,人们对大模型的认识已经从对其zero-shot能力感到惊讶,转变为考虑改进模型质量、提高模型可用性。 「大语言模型(LLMs)其实就是利用高容量的模型架构(例如Transformer)对海量的、多种多样的数据分布进行建模得到,它包含了大量的先验

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

【北交大信息所AI-Max2】使用方法

BJTU信息所集群AI_MAX2使用方法 使用的前提是预约到相应的算力卡,拥有登录权限的账号密码,一般为导师组共用一个。 有浏览器、ssh工具就可以。 1.新建集群Terminal 浏览器登陆10.126.62.75 (如果是1集群把75改成66) 交互式开发 执行器选Terminal 密码随便设一个(需记住) 工作空间:私有数据、全部文件 加速器选GeForce_RTX_2080_Ti

AI Toolkit + H100 GPU,一小时内微调最新热门文生图模型 FLUX

上个月,FLUX 席卷了互联网,这并非没有原因。他们声称优于 DALLE 3、Ideogram 和 Stable Diffusion 3 等模型,而这一点已被证明是有依据的。随着越来越多的流行图像生成工具(如 Stable Diffusion Web UI Forge 和 ComyUI)开始支持这些模型,FLUX 在 Stable Diffusion 领域的扩展将会持续下去。 自 FLU