ChatGLM的Trainer模块解析

2024-02-21 04:36
文章标签 模块 解析 chatglm trainer

本文主要是介绍ChatGLM的Trainer模块解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Trainer类

为什么这样比较好

Trainer是一个简单但功能完备的PyTorch训练和评估循环,针对🤗 Transformers进行了优化。

参数:

  • model: [PreTrainedModel]或torch.nn.Module类型,可选。用于训练、评估或预测的模型。如果未提供,则必须传入model_init
  • args: [TrainingArguments]类型,可选。用于调整训练的参数。如果未提供,则默认为一个基本的[TrainingArguments]实例,其中output_dir设置为当前目录下的名为tmp_trainer的目录。
  • data_collator: DataCollator类型,可选。用于从train_dataseteval_dataset的元素列表中形成批次的函数。如果未提供tokenizer,则默认为[default_data_collator];否则,默认为[DataCollatorWithPadding]的实例。
  • train_dataset: torch.utils.data.Datasettorch.utils.data.IterableDataset类型,可选。用于训练的数据集。如果是[~datasets.Dataset]类型,则会自动删除模型的forward()方法不接受的列。
  • eval_dataset: [torch.utils.data.Dataset]类型或字典类型,可选。用于评估的数据集。如果是[~datasets.Dataset]类型,则会自动删除模型的forward()方法不接受的列。如果是字典类型,则会对每个数据集进行评估,并在度量名称前添加字典键。
  • tokenizer: [PreTrainedTokenizerBase]类型,可选。用于预处理数据的分词器。如果提供,将在批处理输入时自动对输入进行最大长度填充,并将其与模型一起保存,以便在重新运行中断的训练或重用微调模型时更容易。
  • model_init: Callable[[], PreTrainedModel]类型,可选。用于实例化要使用的模型的函数。如果提供,每次调用[~Trainer.train]都会从此函数给出的模型的新实例开始。

重要属性:

  • model: 始终指向核心模型。如果使用transformers模型,它将是[PreTrainedModel]的子类。
  • model_wrapped: 始终指向最外层的模型,如果有一个或多个其他模块包装了原始模型。这是应该用于前向传递的模型。例如,在DeepSpeed下,内部模型被包装在DeepSpeed中,然后再次包装在torch.nn.DistributedDataParallel中。如果内部模型没有被包装,则self.model_wrappedself.model相同。
  • is_model_parallel: 模型是否已切换到模型并行模式(与数据并行不同,这意味着一些模型层在不同的GPU上分割)。
  • place_model_on_device: 是否自动将模型放置在设备上。如果使用模型并行或deepspeed,或者如果默认的TrainingArguments.place_model_on_device被覆盖为返回False,则将其设置为False
  • is_in_train: 模型当前是否正在运行train(例如,在train期间调用evaluate时)。

初始化部分

def _init_ 初始化函数

data_collator 这个属性的作用

329行 设置了随机化的种子,这个种子在哪里发挥作用?种子在采样器中发挥作用

参数里有一个full_determinism(完全决定)什么是完全决定?

deepspeed是什么?Deepspeed是一个用于深度学习训练的优化工具和库。它旨在提高训练速度和效率,并减少对昂贵硬件资源的需求。Deepspeed通过使用模型并行化、梯度累积、动态精度缩放等技术,可以在大规模模型和分布式训练中提供显著的性能改进。args.deepspeed是一个变量,它可能是用来指定是否启用Deepspeed的选项。通过检查args.deepspeed的值,你可以确定是否使用Deepspeed来优化你的深度学习训练过程。

hp_name是干什么的

有一个memory_tracker 记忆跟踪?

什么是当前的log_level 级别?

MODEL_MAPPING_NAMES这个里面存储着什么

sharded_ddp是共享GPU设置,不能和deepspeed和fsdp一起用

backward_prefetch是干什么的?

forword_prefetch是干什么的?

limit_all_gathers是干什么的?

回调函数

def add_callback() 增加回调函数,不是很理解这里的回调函数

这段代码是一个类方法add_callback的实现,它用于向当前的[~transformer.TrainerCallback]列表中添加一个回调函数。

回调函数是一段可以在训练过程的特定点执行的代码,用于执行如保存模型检查点、调整学习率或简单地打印出进度信息等任务。代码524行 在这行特定的代码中,添加的回调函数是PrinterCallbackDEFAULT_PROGRESS_CALLBACK,具体取决于self.args.disable_tqdm的值。

def pop_callback(self, callback): 弹出一个回调函数并返回

def remove_callback(self, callback): 移除一个回调函数

模型移动到设备

def _move_model_to_device(self, model, device):

这是一个名为_move_model_to_device的方法的实现。这个方法的目的是将模型移动到指定的设备上。它接受两个参数:modeldevice

在代码中,model = model.to(device)这一行将模型移动到指定的设备上。to()是一个PyTorch中的方法,用于将对象移动到指定的设备上。在这种情况下,model对象被移动到device设备上。

接下来的条件语句检查self.args.parallel_mode是否等于ParallelMode.TPU,并且model对象是否具有名为tie_weights的属性。如果条件成立,那么model.tie_weights()将被调用。这个方法可能是在特定的情况下用于绑定模型权重的。

总的来说,这个方法的作用是将给定的模型移动到指定的设备上,并在特定情况下重新绑定模型的权重。

签名列

设置签名列,移除签名列?不太懂

def _set_signature_columns_if_needed(self):

inspect.signature(self.model.forward)这行代码使用Python的内置inspect模块来获取模型的forward方法的签名。函数的签名是对函数参数的描述。总的来说,这个方法将_signature_columns设置为模型的forward方法的参数名称的列表,再加上"label"、"label_ids"和self.label_names的元素,同时移除任何重复的元素。

def _remove_unused_columns 移除未使用的列

def _get_collator_with_removed_columns()

data_collator是一个在PyTorch的数据加载过程中使用的概念,它是一个可调用的对象,用于将一批数据样本组合成一个批次的数据。

在PyTorch中,当你创建一个DataLoader对象时,你可以传递一个data_collator作为参数。DataLoader会使用这个data_collator来将从数据集中抽取的样本组合成一个批次的数据。

data_collator通常需要处理两个主要任务:

  1. 将数据样本的形状对齐,例如,通过填充短的序列使得所有的序列长度相同。
  2. 将数据样本堆叠(stack)成一个批次的数据。

例如,在处理自然语言处理(NLP)任务时,data_collator可能需要将不同长度的句子填充到相同的长度,然后将它们堆叠成一个批次的数据。总的来说,data_collator是一个在PyTorch的数据加载过程中使用的工具,用于将一批数据样本组合成一个批次的数据。

训练数据采样 和 迭代器 评估迭代器 和采样器

def _get_train_sampler 不太理解采样

def get_train_dataloader 训练迭代器

def _get_eval_sampler

def get_eval_dataloader

get_test_dataloader

在PyTorch中,训练采样器(Sampler)和训练迭代器(Iterator)是两个不同的概念,它们在数据加载和处理过程中起着不同的作用。

  1. 训练采样器(Sampler):Sampler负责定义如何从数据集中抽取样本。例如,RandomSampler会在每个epoch开始时随机打乱数据集的顺序,SequentialSampler则会按照数据集的原始顺序抽取样本。Sampler的主要作用是定义数据抽取的顺序,这对于训练模型来说非常重要,因为不同的抽样策略可能会导致模型的训练效果有所不同。
  2. 训练迭代器(Iterator):Iterator负责从数据集中实际抽取样本,并将它们组织成一个个的batch。在PyTorch中,这通常是通过DataLoader类来实现的。DataLoader接受一个数据集和一个Sampler作为输入,然后生成一个迭代器,这个迭代器可以在训练循环中使用,用于按照Sampler定义的顺序抽取数据并组织成batch。

创建优化器和策略

def create_optimizer_and_scheduler

def create_optimizer

def get_optimizer_cls_and_kwargs

def create_scheduler

hp search 超参数搜索

def num_examples

def _hp_search_setup

def _report_to_hp_search

hp_search可能是指超参数搜索(Hyperparameter Search),这是一种在机器学习中常用的技术,用于自动寻找模型的最优超参数。

超参数是在开始学习过程之前设置的参数,不能由训练过程学习。例如,学习率、批次大小、训练轮数、隐藏层的数量等都是超参数。超参数的选择可以极大地影响模型的性能。

超参数搜索的目标是找到一组超参数,使得在验证集上的性能最优。常见的超参数搜索方法包括网格搜索(Grid Search)、随机搜索(Random Search)和贝叶斯优化(Bayesian Optimization)等。

在某些机器学习库中,可能会提供hp_search或类似的工具来帮助进行超参数搜索。

保存模型

def _tune_save_checkpoint

def call_model_init

不知道在干什么

def torch_jit_model_eval

“JIT” 是 “Just-In-Time” 编译的缩写,这是一种在运行时动态将代码编译为机器代码的技术,以提高代码的执行效率。这种技术在许多编程语言和系统中都有应用,包括 Java、.NET 和 Python 的某些实现。

在 PyTorch 中,JIT 通过 TorchScript 提供。TorchScript 是 PyTorch 的一个子集,它可以将 PyTorch 程序转换为一个中间表示(Intermediate Representation,IR),然后对这个 IR 进行优化,并最终将其编译为机器代码。这使得 PyTorch 模型可以在没有 Python 解释器的环境中运行,例如在非 Python 环境或在移动设备上。

torch.jit 是 PyTorch 的一个模块,提供了将 PyTorch 模型转换为 TorchScript 的工具。例如,torch.jit.trace 函数可以通过运行模型一次并记录其操作来生成一个 TorchScript 模型,torch.jit.script 函数则可以将模型直接转换为 TorchScript,这对于包含控制流(如 if 和 for)的模型来说非常有用。

def ipex_optimize_model

PEX是为了在Intel硬件(如Intel Xeon CPU)上优化PyTorch性能而开发的。它可以提供自动混合精度训练(Automatic Mixed Precision,AMP)、通道最后的内存格式(Channel Last memory format)、JIT编译等功能,以提高PyTorch在Intel硬件上的运行效率。

def _wrap_model 这里是包装模型

这部分代码是一个Python类中的一个方法,名为_wrap_model。它的作用是对给定的模型进行包装和配置,以便在训练或推理过程中使用。

具体来说,这个方法执行以下操作:

  1. 检查是否需要使用Torch的编译功能,如果需要,则使用指定的后端和模式对模型进行编译。
  2. 检查是否需要使用Intel Extension for PyTorch (IPEX),如果需要,则对模型进行优化,选择数据类型为torch.bfloat16或torch.float32。
  3. 检查是否启用了SageMaker的多进程训练模式,如果是,则根据指定的参数创建一个分布式模型。
  4. 检查是否启用了DeepSpeed训练模式,如果是,则返回已经初始化的DeepSpeed模型。
  5. 检查模型是否已经被包装过,如果是,则直接返回模型。
  6. 检查是否需要使用Apex库进行混合精度训练(仅适用于torch < 1.6版本),如果是,则使用Apex库对模型和优化器进行初始化。
  7. 检查是否需要进行多GPU训练,如果是,则使用nn.DataParallel对模型进行包装。
  8. 检查是否需要在评估模式下使用Torch的即时编译模式,如果是,则对模型进行即时编译。
  9. 检查是否需要进行分布式训练,如果是,则根据指定的参数使用ShardedDDP或FullyShardedDDP对模型进行包装。
  10. 检查是否需要使用PyTorch FSDP(Fully Sharded Data Parallel)进行分布式训练,如果是,则使用FSDP对模型进行包装。
  11. 检查是否启用了SageMaker的分布式训练模式,如果是,则使用nn.parallel.DistributedDataParallel对模型进行包装。
  12. 检查是否在本地使用多GPU训练,如果是,则使用nn.parallel.DistributedDataParallel对模型进行包装。
  13. 返回经过包装和配置后的模型。

这个方法的目的是根据给定的训练模式和配置对模型进行适当的包装和优化,以便在训练或推理过程中获得更好的性能和效果。

def train

初始化训练循环

def _inner_training_loop

def _get_output_dir()

def _load_from_checkpoint

def _load_best_model

def _issue_warnings_after_load

def _maybe_log_save_evaluate

def _load_rng_state

def _save_checkpoint

def _load_optimizer_and_scheduler

def hyperparameter_search

def log

def _prepare_input

def _prepare_inputs

def compute_loss_context_manager

def autocast_smart_context_manager

def training_step

def compute_loss

def is_local_process_zero

def is_world_process_zero(self)

def save_model

def _save_tpu

def _save

def store_flos

def _sorted_checkpoints

def _rotate_checkpoints

def evaluate

def predict

def evaluation_loop

def _nested_gather

def _pad_across_processes

def prediction_step

def floating_point_ops

def init_git_repo

def create_model_card

def _push_from_checkpoint

def _push_from_checkpoint

def GLUJ

这篇关于ChatGLM的Trainer模块解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/730583

相关文章

使用Python实现批量访问URL并解析XML响应功能

《使用Python实现批量访问URL并解析XML响应功能》在现代Web开发和数据抓取中,批量访问URL并解析响应内容是一个常见的需求,本文将详细介绍如何使用Python实现批量访问URL并解析XML响... 目录引言1. 背景与需求2. 工具方法实现2.1 单URL访问与解析代码实现代码说明2.2 示例调用

SSID究竟是什么? WiFi网络名称及工作方式解析

《SSID究竟是什么?WiFi网络名称及工作方式解析》SID可以看作是无线网络的名称,类似于有线网络中的网络名称或者路由器的名称,在无线网络中,设备通过SSID来识别和连接到特定的无线网络... 当提到 Wi-Fi 网络时,就避不开「SSID」这个术语。简单来说,SSID 就是 Wi-Fi 网络的名称。比如

多模块的springboot项目发布指定模块的脚本方式

《多模块的springboot项目发布指定模块的脚本方式》该文章主要介绍了如何在多模块的SpringBoot项目中发布指定模块的脚本,作者原先的脚本会清理并编译所有模块,导致发布时间过长,通过简化脚本... 目录多模块的springboot项目发布指定模块的脚本1、不计成本地全部发布2、指定模块发布总结多模

SpringCloud配置动态更新原理解析

《SpringCloud配置动态更新原理解析》在微服务架构的浩瀚星海中,服务配置的动态更新如同魔法一般,能够让应用在不重启的情况下,实时响应配置的变更,SpringCloud作为微服务架构中的佼佼者,... 目录一、SpringBoot、Cloud配置的读取二、SpringCloud配置动态刷新三、更新@R

使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)

《使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)》在现代软件开发中,处理JSON数据是一项非常常见的任务,无论是从API接口获取数据,还是将数据存储为JSON格式,解析... 目录1. 背景介绍1.1 jsON简介1.2 实际案例2. 准备工作2.1 环境搭建2.1.1 添加

Python中构建终端应用界面利器Blessed模块的使用

《Python中构建终端应用界面利器Blessed模块的使用》Blessed库作为一个轻量级且功能强大的解决方案,开始在开发者中赢得口碑,今天,我们就一起来探索一下它是如何让终端UI开发变得轻松而高... 目录一、安装与配置:简单、快速、无障碍二、基本功能:从彩色文本到动态交互1. 显示基本内容2. 创建链

Node.js 中 http 模块的深度剖析与实战应用小结

《Node.js中http模块的深度剖析与实战应用小结》本文详细介绍了Node.js中的http模块,从创建HTTP服务器、处理请求与响应,到获取请求参数,每个环节都通过代码示例进行解析,旨在帮... 目录Node.js 中 http 模块的深度剖析与实战应用一、引言二、创建 HTTP 服务器:基石搭建(一

在C#中合并和解析相对路径方式

《在C#中合并和解析相对路径方式》Path类提供了几个用于操作文件路径的静态方法,其中包括Combine方法和GetFullPath方法,Combine方法将两个路径合并在一起,但不会解析包含相对元素... 目录C#合并和解析相对路径System.IO.Path类幸运的是总结C#合并和解析相对路径对于 C

Java解析JSON的六种方案

《Java解析JSON的六种方案》这篇文章介绍了6种JSON解析方案,包括Jackson、Gson、FastJSON、JsonPath、、手动解析,分别阐述了它们的功能特点、代码示例、高级功能、优缺点... 目录前言1. 使用 Jackson:业界标配功能特点代码示例高级功能优缺点2. 使用 Gson:轻量

Java如何接收并解析HL7协议数据

《Java如何接收并解析HL7协议数据》文章主要介绍了HL7协议及其在医疗行业中的应用,详细描述了如何配置环境、接收和解析数据,以及与前端进行交互的实现方法,文章还分享了使用7Edit工具进行调试的经... 目录一、前言二、正文1、环境配置2、数据接收:HL7Monitor3、数据解析:HL7Busines