XGB-3: 模型IO

2024-02-05 01:20
文章标签 模型 io xgb

本文主要是介绍XGB-3: 模型IO,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在XGBoost 1.0.0中,引入了对使用JSON保存/加载XGBoost模型和相关超参数的支持,旨在用一个可以轻松重用的开放格式取代旧的二进制内部格式。后来在XGBoost 1.6.0中,还添加了对通用二进制JSON的额外支持,作为更高效的模型IO的优化。它们具有相同的文档结构,但具有不同的表示形式,但都统称为JSON格式。本教程旨在分享一些关于XGBoost中使用的JSON序列化方法的基本见解。除非明确说明,以下各节假定正在使用2个输出格式之一,可以通过在保存/加载模型时提供带有.json(或二进制JSON的.ubj)文件扩展名的文件名来启用这两种格式:booster.save_model('model.json')

在开始之前,需要说明的是,XGBoost是一个以树模型为重点的梯度提升库,这意味着在XGBoost内部有两个明显的部分:

  • 由树组成的模型
  • 用于构建模型的超参数和配置

如果是专注于深度学习领域,那么应该清楚由固定张量操作的权重组成的神经网络结构与用于训练它们的优化器(例如RMSprop)之间存在差异。

因此,当调用 booster.save_model(在R中是 xgb.save)时,XGBoost会保存树、一些模型参数(例如在训练树中的输入列数)以及目标函数,这些组合在一起代表了XGBoost中的“模型”概念。至于为什么将目标函数保存为模型的一部分,原因是目标函数控制全局偏差的转换(在XGBoost中称为base_score)。用户可以与他人共享此模型,用于预测、评估或使用不同的超参数集继续训练等

有些情况下,需要保存的不仅仅是模型本身。例如,在分布式训练中,XGBoost执行检查点操作。或者由于某些原因,分布式计算框架决定将模型从一个工作节点复制到另一个工作节点,并在那里继续训练。在这种情况下,序列化输出需要包含足够的信息,以便在不需要用户再次提供任何参数的情况下继续以前的训练。将这种情景视为内存快照( memory snapshot或基于内存的序列化方法),并将其与普通的模型IO操作区分开来。目前,内存快照用于以下情况:

  • Python:使用内置的pickle模块对Booster对象进行pickle
  • R:使用内置函数saveRDS或save对xgb.Booster对象进行持久化
  • JVM:使用内置函数saveModelBooster对象进行序列化

注意:

旧的二进制格式不能区分模型和原始内存序列化格式的差异,它是一切的混合体。JVM包有其自己的基于内存的序列化方法。

为了启用模型 IO 的 JSON 格式支持(仅保存树和目标),请在文件名中使用 .json.ubj 作为文件扩展名,后者是通用二进制 JSON 的扩展名。

  • Python
bst.save_model('model_file_name.json')
  • R
xgb.save(bst, 'model_file_name.json')
  • Scala
val format = "json"  // or val format = "ubj"
model.write.option("format", format).save("model_directory_path")

注意:

仅从由 XGBoost 生成的 JSON 文件加载模型。尝试加载由外部来源生成的 JSON 文件可能导致未定义的行为和崩溃。

关于模型和内存快照的向后兼容性说明

保证模型的向后兼容性,但不保证内存快照的向后兼容性。

模型(树和目标)使用稳定的表示,因此在较早版本的 XGBoost 中生成的模型可以在较新版本的 XGBoost 中访问。如果希望将模型存储或存档以供长期存储,请使用 save_model(Python)和 xgb.save(R)。

另一方面,内存快照(序列化)捕获了 XGBoost 内部的许多内容,其格式不稳定且可能经常更改。因此,内存快照仅适用于检查点,可以持久保存训练配置的完整快照,以便可以从可能的故障中强大地恢复并恢复训练过程。加载由较早版本的 XGBoost 生成的内存快照可能会导致错误或未定义的行为。如果使用 pickle.dump(Python)或 saveRDS(R)持久保存模型,则该模型可能无法在较新版本的 XGBoost 中访问。

自定义目标和度量标准

XGBoost支持用户提供的自定义目标和度量标准函数作为扩展。这些函数不会保存在模型文件中,因为它们是与语言相关的特性。在Python中,用户可以使用pickle将这些函数包含在保存的二进制文件中。其中一个缺点是,pickle输出不是稳定的序列化格式,在不同的Python版本和XGBoost版本上都无法使用,更不用说在不同的语言环境中了。解决此限制的另一种方法是在加载模型后再次提供这些函数。如果定制的函数很有用,请考虑创建一个PR(Pull Request)在XGBoost内部实现它,这样就可以在不同的语言绑定中使用定制的函数。

加载来自不同版本XGBoost的pickled文件

如前所述,pickle模型既不具备可移植性,也不稳定,但在某些情况下,pickled模型是有价值的。将其在将来恢复的一种方法是使用特定版本的Python和XGBoost将其加载回来,然后通过调用save_model导出模型。

可以使用类似的过程来恢复保存在旧RDS文件中的模型。在R中,可以使用remotes包安装旧版本的XGBoost:

library(remotes)
remotes::install_version("xgboost", "0.90.0.1")   # 安装版本0.90.0.1

安装所需的版本后,可以使用readRDS加载RDS文件并恢复xgb.Booster对象。然后,调用xgb.save以使用稳定表示导出模型,就能够在最新版本的XGBoost中使用该模型。

  • Python
import xgboost as xgbbst = xgb.Booster({'nthread': 4}) 
bst.load_model('model_file_name.json')  # load xgb modelpreds = bst.predict(xgb.DMatrix(X_test))  # predict if x_test is not DMatrix format
print(preds)

保存和加载内部参数配置

XGBoost的C APIPython APIR API支持直接将内部配置保存和加载为JSON字符串。在Python包中:

bst = xgboost.train(...)
config = bst.save_config()
print(config)

或在R中:

config <- xgb.config(bst)
print(config)

将打印出类似以下的内容(由于太长,以下内容不是实际输出,仅用于演示):

{"Learner": {"generic_parameter": {"device": "cuda:0","gpu_page_size": "0","n_jobs": "0","random_state": "0","seed": "0","seed_per_iteration": "0"},"gradient_booster": {"gbtree_train_param": {"num_parallel_tree": "1","process_type": "default","tree_method": "hist","updater": "grow_gpu_hist","updater_seq": "grow_gpu_hist"},"name": "gbtree","updater": {"grow_gpu_hist": {"gpu_hist_train_param": {"debug_synchronize": "0",},"train_param": {"alpha": "0","cache_opt": "1","colsample_bylevel": "1","colsample_bynode": "1","colsample_bytree": "1","default_direction": "learn",..."subsample": "1"}}}},"learner_train_param": {"booster": "gbtree","disable_default_eval_metric": "0","objective": "reg:squarederror"},"metrics": [],"objective": {"name": "reg:squarederror","reg_loss_param": {"scale_pos_weight": "1"}}},"version": [1, 0, 0]
}

可以将其加载回由相同版本的XGBoost生成的模型,方法是:

bst.load_config(config)

保存模型和转储模型之间的区别

XGBoost在Booster对象中有一个名为dump_model的函数,它以可读的格式(如txtjsondot(graphviz))导出模型。它的主要用途是进行模型解释或可视化,不应该加载回XGBoost。JSON版本具有模式Schema 。

保存模型(Save Model): 通过save_model函数,XGBoost将整个模型以二进制格式保存到文件中。这包括模型的树结构、超参数和目标函数等。保存的模型文件可以用于在不同的XGBoost版本之间共享、加载和继续训练。

  • Python
booster.save_model('model.bin')
  • R
xgb.save(booster, 'model.bin')

转储模型(Dump Model): 通过dump_model函数,XGBoost将模型导出为可读的文本、JSON或Graphviz DOT格式,以便进行模型解释、可视化或分析。这是为了方便用户查看模型的结构和特性,而不是用于加载回XGBoost进行进一步的训练或预测。

  • Python
booster.dump_model('model.txt')
  • R
xgb.dump(booster, 'model.txt')

Json Schema

JSON格式的另一个重要特点是有一个详细记录的模式(schema),基于这个模式,用户可以轻松地重用XGBoost输出的模型。以下是输出模型的JSON模式(不是序列化,如上所述将不是稳定的)。有关解析XGBoost树模型的示例,请参见/demo/json-model。请注意“dart” booster 中使用的“weight_drop”字段。XGBoost不直接对树叶进行缩放,而是将权重保存为一个单独的数组

{"$schema": "http://json-schema.org/draft-07/schema#","definitions": {"gbtree": {"type": "object","properties": {"name": {"const": "gbtree"},"model": {"type": "object","properties": {"gbtree_model_param": {"$ref": "#/definitions/gbtree_model_param"},"trees": {"type": "array","items": {"type": "object","properties": {"tree_param": {"$ref": "#/definitions/tree_param"},"id": {"type": "integer"},"loss_changes": {"type": "array","items": {"type": "number"}},"sum_hessian": {"type": "array","items": {"type": "number"}},"base_weights": {"type": "array","items": {"type": "number"}},"left_children": {"type": "array","items": {"type": "integer"}},"right_children": {"type": "array","items": {"type": "integer"}},"parents": {"type": "array","items": {"type": "integer"}},"split_indices": {"type": "array","items": {"type": "integer"}},"split_conditions": {"type": "array","items": {"type": "number"}},"split_type": {"type": "array","items": {"type": "integer"}},"default_left": {"type": "array","items": {"type": "integer"}},"categories": {"type": "array","items": {"type": "integer"}},"categories_nodes": {"type": "array","items": {"type": "integer"}},"categories_segments": {"type": "array","items": {"type": "integer"}},"categories_sizes": {"type": "array","items": {"type": "integer"}}},"required": ["tree_param","loss_changes","sum_hessian","base_weights","left_children","right_children","parents","split_indices","split_conditions","default_left","categories","categories_nodes","categories_segments","categories_sizes"]}},"tree_info": {"type": "array","items": {"type": "integer"}}},"required": ["gbtree_model_param","trees","tree_info"]}},"required": ["name","model"]},"gbtree_model_param": {"type": "object","properties": {"num_trees": {"type": "string"},"num_parallel_tree": {"type": "string"}},"required": ["num_trees","num_parallel_tree"]},"tree_param": {"type": "object","properties": {"num_nodes": {"type": "string"},"size_leaf_vector": {"type": "string"},"num_feature": {"type": "string"}},"required": ["num_nodes","num_feature","size_leaf_vector"]},"reg_loss_param": {"type": "object","properties": {"scale_pos_weight": {"type": "string"}}},"pseudo_huber_param": {"type": "object","properties": {"huber_slope": {"type": "string"}}},"aft_loss_param": {"type": "object","properties": {"aft_loss_distribution": {"type": "string"},"aft_loss_distribution_scale": {"type": "string"}}},"softmax_multiclass_param": {"type": "object","properties": {"num_class": { "type": "string" }}},"lambda_rank_param": {"type": "object","properties": {"num_pairsample": { "type": "string" },"fix_list_weight": { "type": "string" }}},"lambdarank_param": {"type": "object","properties": {"lambdarank_num_pair_per_sample": { "type": "string" },"lambdarank_pair_method": { "type": "string" },"lambdarank_unbiased": {"type": "string" },"lambdarank_bias_norm": {"type": "string" },"ndcg_exp_gain": {"type": "string"}}}},"type": "object","properties": {"version": {"type": "array","items": [{"type": "number","minimum": 1},{"type": "number","minimum": 0},{"type": "number","minimum": 0}],"minItems": 3,"maxItems": 3},"learner": {"type": "object","properties": {"feature_names": {"type": "array","items": {"type": "string"}},"feature_types": {"type": "array","items": {"type": "string"}},"gradient_booster": {"oneOf": [{"$ref": "#/definitions/gbtree"},{"type": "object","properties": {"name": { "const": "gblinear" },"model": {"type": "object","properties": {"weights": {"type": "array","items": {"type": "number"}}}}}},{"type": "object","properties": {"name": { "const": "dart" },"gbtree": {"$ref": "#/definitions/gbtree"},"weight_drop": {"type": "array","items": {"type": "number"}}},"required": ["name","gbtree","weight_drop"]}]},"objective": {"oneOf": [{"type": "object","properties": {"name": { "const": "reg:squarederror" },"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}},"required": ["name","reg_loss_param"]},{"type": "object","properties": {"name": { "const": "reg:pseudohubererror" },"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}},"required": ["name","reg_loss_param"]},{"type": "object","properties": {"name": { "const": "reg:squaredlogerror" },"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}},"required": ["name","reg_loss_param"]},{"type": "object","properties": {"name": { "const": "reg:linear" },"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}},"required": ["name","reg_loss_param"]},{"type": "object","properties": {"name": { "const": "reg:logistic" },"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}},"required": ["name","reg_loss_param"]},{"type": "object","properties": {"name": { "const": "binary:logistic" },"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}},"required": ["name","reg_loss_param"]},{"type": "object","properties": {"name": { "const": "binary:logitraw" },"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}},"required": ["name","reg_loss_param"]},{"type": "object","properties": {"name": { "const": "count:poisson" },"poisson_regression_param": {"type": "object","properties": {"max_delta_step": { "type": "string" }}}},"required": ["name","poisson_regression_param"]},{"type": "object","properties": {"name": { "const": "reg:tweedie" },"tweedie_regression_param": {"type": "object","properties": {"tweedie_variance_power": { "type": "string" }}}},"required": ["name","tweedie_regression_param"]},{"properties": {"name": {"const": "reg:absoluteerror"}},"type": "object"},{"properties": {"name": {"const": "reg:quantileerror"},"quantile_loss_param": {"type": "object","properties": {"quantle_alpha": {"type": "array"}}}},"type": "object"},{"type": "object","properties": {"name": { "const": "survival:cox" }},"required": [ "name" ]},{"type": "object","properties": {"name": { "const": "reg:gamma" }},"required": [ "name" ]},{"type": "object","properties": {"name": { "const": "multi:softprob" },"softmax_multiclass_param": { "$ref": "#/definitions/softmax_multiclass_param"}},"required": ["name","softmax_multiclass_param"]},{"type": "object","properties": {"name": { "const": "multi:softmax" },"softmax_multiclass_param": { "$ref": "#/definitions/softmax_multiclass_param"}},"required": ["name","softmax_multiclass_param"]},{"type": "object","properties": {"name": { "const": "rank:pairwise" },"lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}},"required": ["name","lambdarank_param"]},{"type": "object","properties": {"name": { "const": "rank:ndcg" },"lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}},"required": ["name","lambdarank_param"]},{"type": "object","properties": {"name": { "const": "rank:map" },"lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}},"required": ["name","lambda_rank_param"]},{"type": "object","properties": {"name": {"const": "survival:aft"},"aft_loss_param": { "$ref": "#/definitions/aft_loss_param"}}},{"type": "object","properties": {"name": {"const": "binary:hinge"}}}]},"learner_model_param": {"type": "object","properties": {"base_score": { "type": "string" },"num_class": { "type": "string" },"num_feature": { "type": "string" },"num_target": { "type": "string" }}}},"required": ["gradient_booster","objective"]}},"required": ["version","learner"]
}

这篇关于XGB-3: 模型IO的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/679372

相关文章

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

透彻!驯服大型语言模型(LLMs)的五种方法,及具体方法选择思路

引言 随着时间的发展,大型语言模型不再停留在演示阶段而是逐步面向生产系统的应用,随着人们期望的不断增加,目标也发生了巨大的变化。在短短的几个月的时间里,人们对大模型的认识已经从对其zero-shot能力感到惊讶,转变为考虑改进模型质量、提高模型可用性。 「大语言模型(LLMs)其实就是利用高容量的模型架构(例如Transformer)对海量的、多种多样的数据分布进行建模得到,它包含了大量的先验

图神经网络模型介绍(1)

我们将图神经网络分为基于谱域的模型和基于空域的模型,并按照发展顺序详解每个类别中的重要模型。 1.1基于谱域的图神经网络         谱域上的图卷积在图学习迈向深度学习的发展历程中起到了关键的作用。本节主要介绍三个具有代表性的谱域图神经网络:谱图卷积网络、切比雪夫网络和图卷积网络。 (1)谱图卷积网络 卷积定理:函数卷积的傅里叶变换是函数傅里叶变换的乘积,即F{f*g}

秋招最新大模型算法面试,熬夜都要肝完它

💥大家在面试大模型LLM这个板块的时候,不知道面试完会不会复盘、总结,做笔记的习惯,这份大模型算法岗面试八股笔记也帮助不少人拿到过offer ✨对于面试大模型算法工程师会有一定的帮助,都附有完整答案,熬夜也要看完,祝大家一臂之力 这份《大模型算法工程师面试题》已经上传CSDN,还有完整版的大模型 AI 学习资料,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

AI Toolkit + H100 GPU,一小时内微调最新热门文生图模型 FLUX

上个月,FLUX 席卷了互联网,这并非没有原因。他们声称优于 DALLE 3、Ideogram 和 Stable Diffusion 3 等模型,而这一点已被证明是有依据的。随着越来越多的流行图像生成工具(如 Stable Diffusion Web UI Forge 和 ComyUI)开始支持这些模型,FLUX 在 Stable Diffusion 领域的扩展将会持续下去。 自 FLU

SWAP作物生长模型安装教程、数据制备、敏感性分析、气候变化影响、R模型敏感性分析与贝叶斯优化、Fortran源代码分析、气候数据降尺度与变化影响分析

查看原文>>>全流程SWAP农业模型数据制备、敏感性分析及气候变化影响实践技术应用 SWAP模型是由荷兰瓦赫宁根大学开发的先进农作物模型,它综合考虑了土壤-水分-大气以及植被间的相互作用;是一种描述作物生长过程的一种机理性作物生长模型。它不但运用Richard方程,使其能够精确的模拟土壤中水分的运动,而且耦合了WOFOST作物模型使作物的生长描述更为科学。 本文让更多的科研人员和农业工作者

线性因子模型 - 独立分量分析(ICA)篇

序言 线性因子模型是数据分析与机器学习中的一类重要模型,它们通过引入潜变量( latent variables \text{latent variables} latent variables)来更好地表征数据。其中,独立分量分析( ICA \text{ICA} ICA)作为线性因子模型的一种,以其独特的视角和广泛的应用领域而备受关注。 ICA \text{ICA} ICA旨在将观察到的复杂信号