百川2模型解读

2024-05-05 17:44
文章标签 模型 解读 百川

本文主要是介绍百川2模型解读,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

简介

Baichuan 2是多语言大模型,目前开源了70亿和130亿参数规模的模型。在公开基准如MMLU、CMMLU、GSM8K和HumanEval上的评测,Baichuan 2达到或超过了其他同类开源模型,并在医学和法律等垂直领域表现优异。此外,官方还发布所有预训练模型的checkpoints,帮助研究社区更好地理解Baichuan 2的训练过程。总结下Baichuan 2特点:

  • 多语言支持:Baichuan 2专注于训练在多种语言中表现优异的模型,包括不仅限于英文。这使得Baichuan 2在处理各种语言的任务时能够取得显著的性能提升。
  • 广泛的训练数据:Baichuan 2 是从头开始训练的,训练数据约有2.6万亿个token。相对于以往的模型,Baichuan 2 提供了更丰富的数据资源,从而能够更好地支持多语言的开发和应用。
  • 垂直领域优化:Baichuan 2不仅在通用任务上表现出色,还在特定领域(如医学和法律)的任务中展现了卓越的性能。这为特定领域的应用提供了强有力的支持。

GitHub:

https://github.com/baichuan-inc/Baichuan2

技术报告:

https://cdn.baichuan-ai.com/paper/Baichuan2-technical-report.pdf

预训练

Baichuan 2 base模型(即基座模型)和其他模型的对比评测结果如下,可以看出多数评测数据上Baichuan 2遥遥领先!

预训练数据集

在构建数据的时候,本着追求数据的全面性和代表性,从多个来源收集数据,包括一般的互联网网页,书籍,研究论文,代码库等。训练语料库的组成如Figure 1所示:

可以看出,数据类型比较广泛,Top3数据类型是科技、商业和娱乐。

数据处理:主要关注数据的数量和质量。

  • 数量:构建了一个大规模的聚类和去重系统,支持LSH(局部敏感哈希)类和embedding类形式的数据特征。该系统能够在几小时内对万亿级的数据进行聚类和去重,从而保证数据的高效利用。基于聚类技术对文档、段落和句子进行去重和评分。这些分数用于后续预训练步骤的数据抽样。不同数据处理阶段的训练数据规模如Figure 2 所示:

  • 质量:句子级别质量过滤,过滤暴力、色情、种族歧视、仇恨言论等有害内容。

模型架构

在模型架构层面,主要还是基于Transformer,但是做了如下修改:

Tokenizer

分词器Tokenizer需要平衡两个关键因素:高压缩率以实现高效推理,以及适当大小的词汇表以确保每个词嵌入被充分训练。为此,Baichuan 2的词汇表大小从 Baichuan 1的 64,000 扩展到 125,696。

在Tokenizer方面使用来自 SentencePiece 的字节对编码(BPE)。需要补充说明的是,不对输入文本使用任何规范化,也不像 Baichuan 1那样添加虚拟前缀。此外,将数值分割成单个数字以更好地编码数值数据。为了处理包含额外空格的代码数据,在Tokenizer中添加了仅包含空格的token。字符覆盖率设置为0.9999,稀有字符回退到 UTF-8 字节。将token到最大长度设置为32,以兼容较长的中文短语。Baichuan 2 Tokenizer的训练数据来自 Baichuan 2 预训练语料库,为了提高覆盖范围采样更多代码示例和学术论文数据。Table 2展示了Baichuan 2分词器与其他分词器的详细比较。

位置编码。在 Baichuan 1 的基础上,为 Baichuan 2-7B 采用 Rotary Positional Embedding(RoPE),为 Baichuan 2-13B 采用 ALiBi。ALiBi是一种较新的位置编码技术,可以改善外推性能。然而,大多数开源模型使用 RoPE 作为位置embeddings,像 Flash Attention这样的注意力机制优化方法。这是因为Flash Attention是基于乘法的,无需将 attention_mask 传递给注意力操作,所以采用RoPE更合适。尽管如此,从初步实验结果发现,位置嵌入的选择对模型性能影响不大。为了促进关于bias-based 和 multiplication-based注意力机制的进一步研究,在 Baichuan 2-7B 上应用 RoPE,在 Baichuan 2-13B 上应用 ALiBi(与Baichuan 1 保持一致)。

激活函数和归一化

使用 SwiGLU 激活函数,这是一种 switch-activated 的 GLU 变体。然而,SwiGLU有一个“双线性”层,并包含三个参数矩阵,与 原始Transformer的前馈层有两个矩阵不同,因此将隐层尺寸从 4 倍隐层尺寸减少到 8/3 隐层尺寸,并四舍五入为 128 的倍数。

对于 Baichuan 2 的注意力层,采用由 xFormers2 实现的内存高效注意力。通过利用 xFormers 的优化注意力和偏置能力,可以在降低内存开销的同时有效地结合 ALiBi 的基于偏置的位置编码。这为 Baichuan 2 的大规模训练提供了性能和效率优势。

在Transformer Block的输入应用层归一化(Layer Normalization),这对于warm-up更具鲁棒性。此外,使用 RMSNorm(均方根归一化),这种方法只计算输入特征的方差,效率更高。

优化器

选用AdamW优化器,β1 和 β2 分别设置为 0.9 和 0.95。使用 0.1 的权重衰减,并将梯度范数剪切到 0.5。模型先用 2,000 个线性缩放step进行warmed up,达到最大学习率,然后应用余弦衰减直到最小学习率。参数和学习率详情见于Table 3:

混合精度: 模型训练使用 BFloat16 混合精度,在前向和反向计算中使用BFloat16,而在优化器更新中使用Float32。与Float16相比,BFloat16 具有更好的动态范围,使其对训练大型语言模型中的大值更具鲁棒性。然而,BFloat16 的低精度在某些设置中会引发一些问题。例如,在一些 RoPE 和 ALibi 的实现中,当整数超过 256 时,torch.arange 操作会由于碰撞而失败,导致无法区分附近位置。因此,对于一些对于值敏感的操作,如位置嵌入,使用完整精度。

NormHead: 为了稳定训练并提高模型性能,对输出嵌入(也称为“head”)进行归一化。NormHead在实验中有两个优点。

  • 稳定训练。在实验中发现head的范数容易不稳定,训练过程中稀有token嵌入的范数变小,会扰乱训练动态。NormHead 可以显著稳定训练动态。
  • 降低了L2距离在计算logits时的影响。实验中发现语义信息主要通过嵌入的余弦相似性而不是 L2 距离编码。由于当前的线性分类器通过点积计算 logits,它是 L2 距离和余弦相似性的混合。NormHead 减轻了在计算 logits 时 L2 距离的干扰。

Max-z 损失: 在训练过程中,LLM 的 logits 可能变得非常大。虽然 softmax 函数对于绝对 logits 值是不可知的,因为它只依赖于它们的相对值。大的 logits 在推理过程中会带来问题,因为常见的重复惩罚实现(如 Hugging Face 实现3中的 model.generate)直接将标量应用于 logits。以这种方式收缩非常大的 logits 可以显著改变 softmax 之后的概率,使模型对重复惩罚超参数的选择敏感。受到 NormSoftmax 和 PaLM 中的辅助 z-损失的启发,添加一个max-z loss 对logit值进行归一化:

其中z是最大的logit值。这有助于稳定训练并使推理时对超参数更具鲁棒性。

Scaling Laws

随着模型大小、数据集大小和用于训练的计算浮点数的增加,模型的性能会提高。并且为了获得最佳性能,所有三个因素必须同时放大。当不受其他两个因素的制约时,模型性能与每个单独的因素都有幂律关系。当这种幂率关系出现时,可以提前对模型的性能进行预测。基于该定律可以在深度学习和大型语言模型的训练代价变得越来越昂贵的当下确保性能。

具体如何操作呢?在训练数十亿参数的大型语言模型之前,先训练一些小型模型,并为训练更大模型拟合缩放定律。训练了从 10M 到 3B 一系列模型(最终模型的 1/1000 到 1/10),并且每个模型最多训练 1 万亿个token,使用的超参数和数据集与Baichuan 2相同。根据不同模型的最终损失,可以获取从训练 flops 到目标损失的映射。为了拟合模型的缩放定律,采用了 Henighan 等人(2020)给出的公式:

其中是不可约损失,第一项是可约损失,它被表示为一个幂律缩放项。是训练 flops, 是在该 flops 中模型的最终损失。使用 SciPy4 库的 curve_fit函数来拟合参数。最终拟合的缩放曲线以及预测的 70 亿和 130 亿参数模型的最终损失如Figure 4 所示。可以看到,拟合的缩放定律准确地预测了 Baichuan 2 的最终损失。

通过这个实验,研究人员可以确定最终的模型规模,并为训练提供相应的资源配置,以保证训练的高效性和性能表现。

Infrastructure

为了实现GPU资源的高效利用,研究人员为弹性训练框架智能集群调度策略开发了一种协同设计方法。

由于 GPU 在多用户和任务之间共享,每个任务的具体行为不可预测,这通常导致集群中出现空闲的 GPU 节点。由于配置8块A800 GPUs的单个机器足以满足 Baichuan 7B 和 Baichuan 13B模型的内存需求,因此训练框架的设计主要集中在机器级弹性。机器级弹性使其能够根据集群状态动态修改任务资源,从而为智能调度算法奠定基础。

为满足机器级弹性的要求,训练框架集成了张量并行和ZeRO 驱动的数据并行。在每台机器内部设置张量并行,并使用ZeRO共享数据并行,以实现机器之间的弹性缩放。

此外,采用张量分割技术。通过分割某些计算以减少峰值内存占用,如大词汇表的交叉熵计算。这种方法使其能够在不增加额外计算和通信的情况下满足内存需求,使系统更高效。

为了在不影响模型准确性的前提下进一步加速训练,研究人员实现了混合精度训练,在这里使用 BFloat16 执行前向和反向计算,而在优化器更新时使用Float32。此外,为了有效地将训练集群扩展到数千个GPU,整合了以下技术,以避免降低通信效率:

  • 拓扑感知的分布式训练。在大规模集群中,网络连接经常跨越多层交换机。通过策略性地安排分布式训练的排名,以最大程度地减少不同交换机之间的频繁访问,从而减少延迟并提高整体训练效率。
  • ZeRO 的混合和分层分区。通过将参数分区到 GPU,ZeRO3 以增加全收集通信开销为代价,减少内存消耗。当扩展到数千个 GPU 时,这种方法会带来显著的通信瓶颈。为了解决这个问题,研究人员提出了一种混合和分层分区方案。具体来说,首先将优化器状态分区到所有 GPU, 然后自适应地决定哪些层需要激活ZeRO3,以及是否分层分区参数。

通过整合这些策略,该系统能够在 1,024 个 NVIDIA A800 GPU 上高效地训练 Baichuan 2-7B 和 Baichuan 2-13B 模型,实现超过 180 TFLOPS 的计算效率。

对齐

Baichuan 2 还引入了对齐过程,从而产生了两个Chat模型:Baichuan 2-7B-Chat 和 Baichuan 2-13B-Chat。Baichuan 2的对齐过程包括两个部分:有监督微调(SFT)和来自人类反馈的强化学习(RLHF)。

监督微调

在监督微调阶段,标注人员为各种数据源的提示(Prompt)进行注释,每个提示根据与Claude类似的关键原则,被标记为有帮助或无害。使用交叉验证到方式验证数据质量:会让一位权威的标注者检查特定标注工作组标注的批次样本的质量,拒绝任何不符合质量标准的批次数据。最终收集超过 100k 的监督微调样本,并基于这些数据训练基座模型。接下来,通过 RLHF做强化学习以进一步改进结果。整个 RLHF 过程,包括 RM 和 RL 训练,如Figure 5 所示。

Reward Model(RM)

为提示(Prompt)设计一个3层次的分类系统,包括6个主要类别、30个次要类别和超过200个三级类别。从用户的角度来看,希望分类系统能够全面覆盖所有类型的用户需求;从训练奖励模型的角度来看,每个类别中的提示应具有足够的多样性,以确保奖励模型能够很好地泛化。给定一个提示,用不同大小和阶段(SFT,PPO)的 Baichuan 2 模型生成多样化的回应。在训练RM时,只使用由 Baichuan 2 模型族生成的回应。用于训练奖励模型的损失函数与InstructGPT的损失函数一致。训练得到的奖励模型表现与LLaMA 2一致,这表明两个回应之间的分数差距越大,奖励模型的区分准确性越高,

PPO

获得奖励模型之后,使用PPO算法进一步训练语言模型。具体使用了4种模型:actor模型(负责生成回应)、reference模型(用于计算固定参数的KL惩罚)、reward模型(提供整个回应的总体奖励,固定参数)以及 critic模型(用于学习每个token的值)。

在RLHF训练过程中,critic模型在初始训练时先做20个step的warmed up。再通过标准PPO算法更新critic和actor模型。对于所有模型,使用了0.5的梯度裁剪、5e-6的恒定学习率、PPO裁剪阈值ϵ = 0.1。将KL惩罚系数β设为0.2,并随着step的增加逐渐减小到0.005。对于所有Chat模型包括Baichuan 2-7B-Chat和Baichuan 2-13B-Chat进行350次迭代。

安全

百川的研究人员认为模型的安全性改进不仅在于数据清洗或对齐阶段的约束,还在于所有训练阶段中积极获取正面知识并识别负面知识。在整个Baichuan 2训练过程基于这一理念增强了模型的安全性。

预训练阶段

在预训练阶段,主要关注数据的安全性。整个预训练数据集进行严格的数据过滤流程,从而增强安全性。官方制定了一套规则和模型,以去除有害内容,如暴力、色情、种族歧视、仇恨言论等。

此外,策划了一个中英文双语数据集,包括数百家知名网站的数百万网页,这代表了各种正面价值领域,涵盖政策、法律、弱势群体、普遍价值观、传统美德等。同时提高对该数据集的采样概率。

对齐阶段

建立了一个包含6种类型攻击和100多种细粒度安全价值类别的红队程序,由10名具有传统互联网安全经验的专家标注团队初始化安全对齐提示(Prompt)。这些初始化提示是从预训练数据集中检索相关片段,然后创建回应,最终产生了约1,000个初始化的标注提示数据。

  • 专家标注团队通过初始化的对齐模型引导了一个50人的外包标注团队,进行红蓝对抗,生成了20万个攻击提示。
  • 通过使用专门的多值监督采样方法,最大程度地利用攻击数据,以生成不同安全级别的回应。

在RL优化阶段,也将安全性作为首要考虑:

  • 在安全性强化的开始,DPO 方法有效地利用了有限数量的标注数据,以增强对特定脆弱性问题的性能。
  • 通过使用集成有益和无害目标的奖励模型,进行了PPO安全性强化训练。

总结

模型百川2
参数量7b,13b
训练token数2.6万亿
tokenizerBPE
词表大小125696
位置编码7b:RoPE ; 13b:ALiBi (影响不大)
最长上下文4096
激活函数SwiGLU
归一化Layer Normalization + RMSNorm
注意力机制xFormers2
优化器AdamW+NormHead+Max-z损失

参考:

【论文阅读】《Baichuan 2: Open Large-scale Language Models》

这篇关于百川2模型解读的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/962282

相关文章

Golang的CSP模型简介(最新推荐)

《Golang的CSP模型简介(最新推荐)》Golang采用了CSP(CommunicatingSequentialProcesses,通信顺序进程)并发模型,通过goroutine和channe... 目录前言一、介绍1. 什么是 CSP 模型2. Goroutine3. Channel4. Channe

MySQL中时区参数time_zone解读

《MySQL中时区参数time_zone解读》MySQL时区参数time_zone用于控制系统函数和字段的DEFAULTCURRENT_TIMESTAMP属性,修改时区可能会影响timestamp类型... 目录前言1.时区参数影响2.如何设置3.字段类型选择总结前言mysql 时区参数 time_zon

MySQL中的锁和MVCC机制解读

《MySQL中的锁和MVCC机制解读》MySQL事务、锁和MVCC机制是确保数据库操作原子性、一致性和隔离性的关键,事务必须遵循ACID原则,锁的类型包括表级锁、行级锁和意向锁,MVCC通过非锁定读和... 目录mysql的锁和MVCC机制事务的概念与ACID特性锁的类型及其工作机制锁的粒度与性能影响多版本

Redis过期键删除策略解读

《Redis过期键删除策略解读》Redis通过惰性删除策略和定期删除策略来管理过期键,惰性删除策略在键被访问时检查是否过期并删除,节省CPU开销但可能导致过期键滞留,定期删除策略定期扫描并删除过期键,... 目录1.Redis使用两种不同的策略来删除过期键,分别是惰性删除策略和定期删除策略1.1惰性删除策略

Redis与缓存解读

《Redis与缓存解读》文章介绍了Redis作为缓存层的优势和缺点,并分析了六种缓存更新策略,包括超时剔除、先删缓存再更新数据库、旁路缓存、先更新数据库再删缓存、先更新数据库再更新缓存、读写穿透和异步... 目录缓存缓存优缺点缓存更新策略超时剔除先删缓存再更新数据库旁路缓存(先更新数据库,再删缓存)先更新数

Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)

《Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)》:本文主要介绍Python基于火山引擎豆包大模型搭建QQ机器人详细的相关资料,包括开通模型、配置APIKEY鉴权和SD... 目录豆包大模型概述开通模型付费安装 SDK 环境配置 API KEY 鉴权Ark 模型接口Prompt

C#反射编程之GetConstructor()方法解读

《C#反射编程之GetConstructor()方法解读》C#中Type类的GetConstructor()方法用于获取指定类型的构造函数,该方法有多个重载版本,可以根据不同的参数获取不同特性的构造函... 目录C# GetConstructor()方法有4个重载以GetConstructor(Type[]

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了