ChatGLM的Trainer模块解析

2024-02-21 04:36
文章标签 模块 解析 chatglm trainer

本文主要是介绍ChatGLM的Trainer模块解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Trainer类

为什么这样比较好

Trainer是一个简单但功能完备的PyTorch训练和评估循环,针对🤗 Transformers进行了优化。

参数:

  • model: [PreTrainedModel]或torch.nn.Module类型,可选。用于训练、评估或预测的模型。如果未提供,则必须传入model_init
  • args: [TrainingArguments]类型,可选。用于调整训练的参数。如果未提供,则默认为一个基本的[TrainingArguments]实例,其中output_dir设置为当前目录下的名为tmp_trainer的目录。
  • data_collator: DataCollator类型,可选。用于从train_dataseteval_dataset的元素列表中形成批次的函数。如果未提供tokenizer,则默认为[default_data_collator];否则,默认为[DataCollatorWithPadding]的实例。
  • train_dataset: torch.utils.data.Datasettorch.utils.data.IterableDataset类型,可选。用于训练的数据集。如果是[~datasets.Dataset]类型,则会自动删除模型的forward()方法不接受的列。
  • eval_dataset: [torch.utils.data.Dataset]类型或字典类型,可选。用于评估的数据集。如果是[~datasets.Dataset]类型,则会自动删除模型的forward()方法不接受的列。如果是字典类型,则会对每个数据集进行评估,并在度量名称前添加字典键。
  • tokenizer: [PreTrainedTokenizerBase]类型,可选。用于预处理数据的分词器。如果提供,将在批处理输入时自动对输入进行最大长度填充,并将其与模型一起保存,以便在重新运行中断的训练或重用微调模型时更容易。
  • model_init: Callable[[], PreTrainedModel]类型,可选。用于实例化要使用的模型的函数。如果提供,每次调用[~Trainer.train]都会从此函数给出的模型的新实例开始。

重要属性:

  • model: 始终指向核心模型。如果使用transformers模型,它将是[PreTrainedModel]的子类。
  • model_wrapped: 始终指向最外层的模型,如果有一个或多个其他模块包装了原始模型。这是应该用于前向传递的模型。例如,在DeepSpeed下,内部模型被包装在DeepSpeed中,然后再次包装在torch.nn.DistributedDataParallel中。如果内部模型没有被包装,则self.model_wrappedself.model相同。
  • is_model_parallel: 模型是否已切换到模型并行模式(与数据并行不同,这意味着一些模型层在不同的GPU上分割)。
  • place_model_on_device: 是否自动将模型放置在设备上。如果使用模型并行或deepspeed,或者如果默认的TrainingArguments.place_model_on_device被覆盖为返回False,则将其设置为False
  • is_in_train: 模型当前是否正在运行train(例如,在train期间调用evaluate时)。

初始化部分

def _init_ 初始化函数

data_collator 这个属性的作用

329行 设置了随机化的种子,这个种子在哪里发挥作用?种子在采样器中发挥作用

参数里有一个full_determinism(完全决定)什么是完全决定?

deepspeed是什么?Deepspeed是一个用于深度学习训练的优化工具和库。它旨在提高训练速度和效率,并减少对昂贵硬件资源的需求。Deepspeed通过使用模型并行化、梯度累积、动态精度缩放等技术,可以在大规模模型和分布式训练中提供显著的性能改进。args.deepspeed是一个变量,它可能是用来指定是否启用Deepspeed的选项。通过检查args.deepspeed的值,你可以确定是否使用Deepspeed来优化你的深度学习训练过程。

hp_name是干什么的

有一个memory_tracker 记忆跟踪?

什么是当前的log_level 级别?

MODEL_MAPPING_NAMES这个里面存储着什么

sharded_ddp是共享GPU设置,不能和deepspeed和fsdp一起用

backward_prefetch是干什么的?

forword_prefetch是干什么的?

limit_all_gathers是干什么的?

回调函数

def add_callback() 增加回调函数,不是很理解这里的回调函数

这段代码是一个类方法add_callback的实现,它用于向当前的[~transformer.TrainerCallback]列表中添加一个回调函数。

回调函数是一段可以在训练过程的特定点执行的代码,用于执行如保存模型检查点、调整学习率或简单地打印出进度信息等任务。代码524行 在这行特定的代码中,添加的回调函数是PrinterCallbackDEFAULT_PROGRESS_CALLBACK,具体取决于self.args.disable_tqdm的值。

def pop_callback(self, callback): 弹出一个回调函数并返回

def remove_callback(self, callback): 移除一个回调函数

模型移动到设备

def _move_model_to_device(self, model, device):

这是一个名为_move_model_to_device的方法的实现。这个方法的目的是将模型移动到指定的设备上。它接受两个参数:modeldevice

在代码中,model = model.to(device)这一行将模型移动到指定的设备上。to()是一个PyTorch中的方法,用于将对象移动到指定的设备上。在这种情况下,model对象被移动到device设备上。

接下来的条件语句检查self.args.parallel_mode是否等于ParallelMode.TPU,并且model对象是否具有名为tie_weights的属性。如果条件成立,那么model.tie_weights()将被调用。这个方法可能是在特定的情况下用于绑定模型权重的。

总的来说,这个方法的作用是将给定的模型移动到指定的设备上,并在特定情况下重新绑定模型的权重。

签名列

设置签名列,移除签名列?不太懂

def _set_signature_columns_if_needed(self):

inspect.signature(self.model.forward)这行代码使用Python的内置inspect模块来获取模型的forward方法的签名。函数的签名是对函数参数的描述。总的来说,这个方法将_signature_columns设置为模型的forward方法的参数名称的列表,再加上"label"、"label_ids"和self.label_names的元素,同时移除任何重复的元素。

def _remove_unused_columns 移除未使用的列

def _get_collator_with_removed_columns()

data_collator是一个在PyTorch的数据加载过程中使用的概念,它是一个可调用的对象,用于将一批数据样本组合成一个批次的数据。

在PyTorch中,当你创建一个DataLoader对象时,你可以传递一个data_collator作为参数。DataLoader会使用这个data_collator来将从数据集中抽取的样本组合成一个批次的数据。

data_collator通常需要处理两个主要任务:

  1. 将数据样本的形状对齐,例如,通过填充短的序列使得所有的序列长度相同。
  2. 将数据样本堆叠(stack)成一个批次的数据。

例如,在处理自然语言处理(NLP)任务时,data_collator可能需要将不同长度的句子填充到相同的长度,然后将它们堆叠成一个批次的数据。总的来说,data_collator是一个在PyTorch的数据加载过程中使用的工具,用于将一批数据样本组合成一个批次的数据。

训练数据采样 和 迭代器 评估迭代器 和采样器

def _get_train_sampler 不太理解采样

def get_train_dataloader 训练迭代器

def _get_eval_sampler

def get_eval_dataloader

get_test_dataloader

在PyTorch中,训练采样器(Sampler)和训练迭代器(Iterator)是两个不同的概念,它们在数据加载和处理过程中起着不同的作用。

  1. 训练采样器(Sampler):Sampler负责定义如何从数据集中抽取样本。例如,RandomSampler会在每个epoch开始时随机打乱数据集的顺序,SequentialSampler则会按照数据集的原始顺序抽取样本。Sampler的主要作用是定义数据抽取的顺序,这对于训练模型来说非常重要,因为不同的抽样策略可能会导致模型的训练效果有所不同。
  2. 训练迭代器(Iterator):Iterator负责从数据集中实际抽取样本,并将它们组织成一个个的batch。在PyTorch中,这通常是通过DataLoader类来实现的。DataLoader接受一个数据集和一个Sampler作为输入,然后生成一个迭代器,这个迭代器可以在训练循环中使用,用于按照Sampler定义的顺序抽取数据并组织成batch。

创建优化器和策略

def create_optimizer_and_scheduler

def create_optimizer

def get_optimizer_cls_and_kwargs

def create_scheduler

hp search 超参数搜索

def num_examples

def _hp_search_setup

def _report_to_hp_search

hp_search可能是指超参数搜索(Hyperparameter Search),这是一种在机器学习中常用的技术,用于自动寻找模型的最优超参数。

超参数是在开始学习过程之前设置的参数,不能由训练过程学习。例如,学习率、批次大小、训练轮数、隐藏层的数量等都是超参数。超参数的选择可以极大地影响模型的性能。

超参数搜索的目标是找到一组超参数,使得在验证集上的性能最优。常见的超参数搜索方法包括网格搜索(Grid Search)、随机搜索(Random Search)和贝叶斯优化(Bayesian Optimization)等。

在某些机器学习库中,可能会提供hp_search或类似的工具来帮助进行超参数搜索。

保存模型

def _tune_save_checkpoint

def call_model_init

不知道在干什么

def torch_jit_model_eval

“JIT” 是 “Just-In-Time” 编译的缩写,这是一种在运行时动态将代码编译为机器代码的技术,以提高代码的执行效率。这种技术在许多编程语言和系统中都有应用,包括 Java、.NET 和 Python 的某些实现。

在 PyTorch 中,JIT 通过 TorchScript 提供。TorchScript 是 PyTorch 的一个子集,它可以将 PyTorch 程序转换为一个中间表示(Intermediate Representation,IR),然后对这个 IR 进行优化,并最终将其编译为机器代码。这使得 PyTorch 模型可以在没有 Python 解释器的环境中运行,例如在非 Python 环境或在移动设备上。

torch.jit 是 PyTorch 的一个模块,提供了将 PyTorch 模型转换为 TorchScript 的工具。例如,torch.jit.trace 函数可以通过运行模型一次并记录其操作来生成一个 TorchScript 模型,torch.jit.script 函数则可以将模型直接转换为 TorchScript,这对于包含控制流(如 if 和 for)的模型来说非常有用。

def ipex_optimize_model

PEX是为了在Intel硬件(如Intel Xeon CPU)上优化PyTorch性能而开发的。它可以提供自动混合精度训练(Automatic Mixed Precision,AMP)、通道最后的内存格式(Channel Last memory format)、JIT编译等功能,以提高PyTorch在Intel硬件上的运行效率。

def _wrap_model 这里是包装模型

这部分代码是一个Python类中的一个方法,名为_wrap_model。它的作用是对给定的模型进行包装和配置,以便在训练或推理过程中使用。

具体来说,这个方法执行以下操作:

  1. 检查是否需要使用Torch的编译功能,如果需要,则使用指定的后端和模式对模型进行编译。
  2. 检查是否需要使用Intel Extension for PyTorch (IPEX),如果需要,则对模型进行优化,选择数据类型为torch.bfloat16或torch.float32。
  3. 检查是否启用了SageMaker的多进程训练模式,如果是,则根据指定的参数创建一个分布式模型。
  4. 检查是否启用了DeepSpeed训练模式,如果是,则返回已经初始化的DeepSpeed模型。
  5. 检查模型是否已经被包装过,如果是,则直接返回模型。
  6. 检查是否需要使用Apex库进行混合精度训练(仅适用于torch < 1.6版本),如果是,则使用Apex库对模型和优化器进行初始化。
  7. 检查是否需要进行多GPU训练,如果是,则使用nn.DataParallel对模型进行包装。
  8. 检查是否需要在评估模式下使用Torch的即时编译模式,如果是,则对模型进行即时编译。
  9. 检查是否需要进行分布式训练,如果是,则根据指定的参数使用ShardedDDP或FullyShardedDDP对模型进行包装。
  10. 检查是否需要使用PyTorch FSDP(Fully Sharded Data Parallel)进行分布式训练,如果是,则使用FSDP对模型进行包装。
  11. 检查是否启用了SageMaker的分布式训练模式,如果是,则使用nn.parallel.DistributedDataParallel对模型进行包装。
  12. 检查是否在本地使用多GPU训练,如果是,则使用nn.parallel.DistributedDataParallel对模型进行包装。
  13. 返回经过包装和配置后的模型。

这个方法的目的是根据给定的训练模式和配置对模型进行适当的包装和优化,以便在训练或推理过程中获得更好的性能和效果。

def train

初始化训练循环

def _inner_training_loop

def _get_output_dir()

def _load_from_checkpoint

def _load_best_model

def _issue_warnings_after_load

def _maybe_log_save_evaluate

def _load_rng_state

def _save_checkpoint

def _load_optimizer_and_scheduler

def hyperparameter_search

def log

def _prepare_input

def _prepare_inputs

def compute_loss_context_manager

def autocast_smart_context_manager

def training_step

def compute_loss

def is_local_process_zero

def is_world_process_zero(self)

def save_model

def _save_tpu

def _save

def store_flos

def _sorted_checkpoints

def _rotate_checkpoints

def evaluate

def predict

def evaluation_loop

def _nested_gather

def _pad_across_processes

def prediction_step

def floating_point_ops

def init_git_repo

def create_model_card

def _push_from_checkpoint

def _push_from_checkpoint

def GLUJ

这篇关于ChatGLM的Trainer模块解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/730583

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

python: 多模块(.py)中全局变量的导入

文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块(文件)中使用 global 声明变量时,这个变量只在该模块的全局命名空

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过滤推荐功能2.基于用户的协同过滤推荐功能 前言     在信息过载的时代,推荐系统成为连接用户与内容的桥梁。本文聚焦于

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

OWASP十大安全漏洞解析

OWASP(开放式Web应用程序安全项目)发布的“十大安全漏洞”列表是Web应用程序安全领域的权威指南,它总结了Web应用程序中最常见、最危险的安全隐患。以下是对OWASP十大安全漏洞的详细解析: 1. 注入漏洞(Injection) 描述:攻击者通过在应用程序的输入数据中插入恶意代码,从而控制应用程序的行为。常见的注入类型包括SQL注入、OS命令注入、LDAP注入等。 影响:可能导致数据泄

从状态管理到性能优化:全面解析 Android Compose

文章目录 引言一、Android Compose基本概念1.1 什么是Android Compose?1.2 Compose的优势1.3 如何在项目中使用Compose 二、Compose中的状态管理2.1 状态管理的重要性2.2 Compose中的状态和数据流2.3 使用State和MutableState处理状态2.4 通过ViewModel进行状态管理 三、Compose中的列表和滚动

Spring 源码解读:自定义实现Bean定义的注册与解析

引言 在Spring框架中,Bean的注册与解析是整个依赖注入流程的核心步骤。通过Bean定义,Spring容器知道如何创建、配置和管理每个Bean实例。本篇文章将通过实现一个简化版的Bean定义注册与解析机制,帮助你理解Spring框架背后的设计逻辑。我们还将对比Spring中的BeanDefinition和BeanDefinitionRegistry,以全面掌握Bean注册和解析的核心原理。

CSP 2023 提高级第一轮 CSP-S 2023初试题 完善程序第二题解析 未完

一、题目阅读 (最大值之和)给定整数序列 a0,⋯,an−1,求该序列所有非空连续子序列的最大值之和。上述参数满足 1≤n≤105 和 1≤ai≤108。 一个序列的非空连续子序列可以用两个下标 ll 和 rr(其中0≤l≤r<n0≤l≤r<n)表示,对应的序列为 al,al+1,⋯,ar​。两个非空连续子序列不同,当且仅当下标不同。 例如,当原序列为 [1,2,1,2] 时,要计算子序列 [

多线程解析报表

假如有这样一个需求,当我们需要解析一个Excel里多个sheet的数据时,可以考虑使用多线程,每个线程解析一个sheet里的数据,等到所有的sheet都解析完之后,程序需要提示解析完成。 Way1 join import java.time.LocalTime;public class Main {public static void main(String[] args) thro

ZooKeeper 中的 Curator 框架解析

Apache ZooKeeper 是一个为分布式应用提供一致性服务的软件。它提供了诸如配置管理、分布式同步、组服务等功能。在使用 ZooKeeper 时,Curator 是一个非常流行的客户端库,它简化了 ZooKeeper 的使用,提供了高级的抽象和丰富的工具。本文将详细介绍 Curator 框架,包括它的设计哲学、核心组件以及如何使用 Curator 来简化 ZooKeeper 的操作。 1