Post-Training有多重要?一文带你了解全部细节

2024-09-08 00:52

本文主要是介绍Post-Training有多重要?一文带你了解全部细节,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 简介

  • 随着LLM学界和工业界日新月异的发展,不仅预训练所用的算力和数据正在疯狂内卷,后训练(post-training)的对齐和微调方法也在不断更新。
  • InstructGPT、WebGPT等较早发布的模型使用标准RLHF方法,其中的数据管理风格和规模似乎已经过时。
  • 近来,Meta、谷歌和英伟达等AI巨头纷纷发布开源模型,附带发布详尽的论文或报告,包括Llama 3.1、Nemotron 340B、Gemma 2,以及Apple Intellegence的基础模型报告。

2. post-training方法

2.1 llama3.1

在这里插入图片描述

2.1.1 Post-training Data

  • Human Preference Data
    在这里插入图片描述

    • 每个prompt由两个不同的模型构造两个response,并且每条数据标注有四个等级(特别好,好,稍好,略微好)。注意这里不同的模型可以是不同的混合比例数据,对齐策略和擅长不同能力的模型。
    • 每条reference data有三个不同的等级(edited > chosen > rejected),其中edited是标注者人工修改的。
    • 具体的人类偏好数据见上表,其中,General English包含多种子类,比如说知识问答,精确的指令遵循等。不涉及具体的能力。
    • reward model训练会使用所有迭代轮次的偏好数据,并且仅使用标注为(特别好,好)的数据。
    • 每轮DPO训练仅使用最后一批的偏好数据,且使用标注为(特别好,好)的数据。
  • SFT Data
    在这里插入图片描述

    • sft数据组成:有拒绝回复的人工标注集合,某种能力的合成数据及人工标注数据。
    • RS(Rejection sampling):每个prompt用当前最好的模型得到10-30个response,并用reward模型选择最好的一个回复,最后,在最后一轮迭代中,引入系统提示来引导模型符合语气,风格和格式。
    • 详细的sft数据组成见上表,在每一轮post-training中,都会仔细的调整数据比例,在众多的benchmark得到最佳效果。最终的数据训练epoch会多次在高质量数据源上采样(有点类似于Does your data spark joy? Performance gains from domain upsampling at the end of training这篇工作)
  • Data Processing and Quality Control
    由于大部分的post-training数据都是模型生成的,所以需要细致的清洗和质量控制。

    • 数据清洗:通过规则的方式过滤一些数据,比如说过度使用表情符号,感叹号,或者是 “I’m sorry” or “I apologize”等句式的比例。
    • 数据剪切:
      • Topic classification: 训练分类器(eg.Llama 3 8B)将所有数据分为两类,Coarsely-grained buckets (“mathematical reasoning”) and fine-grained buckets (“geometry and trigonometry”)。
      • Quality scoring: 用RM和llama3对数据进行打分。其中,RM的前1/4分认为是高质量数据;llama3会对一般数据有三个维度打分(accuracy, instruction following, and tone/presentation),代码数据有两个维度打分(bug identification and user intention)。最终发现,RM和llama3打分有很大的不重合部分,并采用RM或llama3认为高质量的数据。
      • Difficulty scoring:通过Instag score[2]和llama score[3]来判断。Instag score主要是通过llama3 70B判断sft prompt的意图,意图越多prompt越难。llama score是通过llama3用三分制衡量prompt的难度。
      • Semantic deduplication: 使用RoBERTa模型对所有数据进行聚类,并对每个类簇下数据使用quality score × difficulty score来进行排序,从高分到低分开始选择,并过滤掉cosine similarity大于某个阈值的数据。

2.1.2 Modeling

  • chat dialog format
    设计了一个新的多消息聊天协议,它使用了各种特殊的头和终止令牌。头token用于指示会话中每个消息的源和目的地。同样,终止token表示何时在人类和AI之间交替发言
  • Reward Modeling
    • loss:在这里去除掉了 m ( r ) m(r) m(r)
      在这里插入图片描述
    • 数据:每条prompt有2或3个response,并将所有response随机打乱后和prompt拼接起来作为一条训练数据。
  • Supervised Finetuning
    • 使用rejection-sampled data and other data sources数据,在每条数据的target token上计算cross entropy loss。
    • LR: 10 e − 5 10e^{-5} 10e5
  • Direct Preference Optimization
    • 超参:learning rate of 10−5 and set the β hyper-parameter to be 0.1
    • Masking out formatting tokens: mask掉一些特殊符号,比如header和termination tokens。We observe that having these tokens contribute to the loss may lead to undesired model behaviors such as tail repetition or abruptly generating termination tokens。
    • Regularization with NLL loss:add an additional negative log-likelihood (NLL) loss term with a scaling
      coefficient of 0.2 on the chosen sequence[4]
  • Model Averaging
    在RM, SFT, or DPO阶段,averger不同版本数据和超参得到的模型。
  • Iterative Rounds
    重复上面的stage 6轮,每轮的reference annotations and SFT data, sampling synthetic data都是从最新的模型得到的。

2.1.3 能力建设

  1. code
  • Expert training:在pretrain过后,在1T token训练数据中continuing pretrain(数据中85%+为代码数据)。用于post-training阶段搜集数据。
    • 用领域数据continuing pretrain能提升该领域内的效果[5]。
    • 在pretrain后期,最后几千个steps,将context length扩充到16K tokens,并用高质量代码数据训练。
    • 使用post-training中用到的对齐方法对齐expert model。
  • Synthetic data generation:code generation的问题有: difficulty in following instructions, code syntax errors, incorrect code generation, and difficulty in fixing bug。通过人工标注数据和生成数据都是可以很好的解决这些问题。模型构造数据成本更低,不受标注人员专业水平限制。总共构造了2.7M数据用在SFT阶段
    • Synthetic data generation: execution feedback:用更好的模型生成的数据微调较差的模型是有效果的,但用llama3 405B模型生成的数据微调自己效果甚微,甚至会有反向的作用。因此采用execution feedback的方式解决这些问题,下面介绍了生成代码数据的一些进程。
      • Problem description generation:生成一系列包含各个主题的代码问题描述(为了保证多样性,会随机选择一些代码段用来帮助模型生成代码问题描述)。
      • Solution generation: 在指定代码语言的前提下,让llama3生成各个代码问题的解答。注意:增加一些通用的代码规则以及在注释中解释其思考过程对于提升代码效果是有帮助的。
      • Correctness analysis: 1)通过parser或linter判断生成代码是否有语法错误,格式问题等。2)让模型生成单元测试并执行。
      • Error feedback and iterative self-correction: 当有任何的代码错误,就让模型根据反馈自我修正。修正prompt包括了初始问题,错误的代码, parser/linter/teste中得到的反馈。只有通过所有检查的数据采用用于SFT,最终有20%的错误数据自我修正。
      • Fine-tuning and iterative improvement: finetune过程分为多个阶段,每个阶段都基于上个阶段。在经过每个阶段优化后,模型在代码生成上都会有提升。
    • Synthetic data generation: programming language translation:为了提升一些不常见的语言的效果(Typescript/PHP),希望可以将一些常见的语言(C++/java)翻译到不常见的语言[6]。
      在这里插入图片描述
    • Synthetic data generation: backtranslation
      • Generate: 让llama3生成代表目标能力的数据(例如,给代码段添加注释或解释某个代码段。)
      • Backtranslate:让模型反向翻译代码,例如从代码段的注释到代码,从某段代码的解释生成该代码。
      • Filter:用初始的code作为reference,让模型判断生成代码的质量(eg,有多少信心生成的代码跟初始差不多)。在SFT阶段,用自我验证分数高的数据进行训练。
  • System prompt steering during rejection sampling:在rejection sampling中使用系统prompt提升生成代码的质量。在这里插入图片描述
  • Filtering training data with execution and model-as-judge signals:用于rejection-sampled中的数据,会有一些语法错误的数据,为了解决这个问题,使用“model-as-judge”方法,在代码准确性和格式两个方法进行0/1评分,最终选择达到2分的数据作为训练样本。当然,这样处理有个问题,会大量过滤一些困难的prompt,所以一些人工修改一些困难样本的response,让其在模型评估下达到比较好的得分。
  1. Multilinguality
  • Expert training:在pretrain后,用90%占比的多语言数据continuing pretrain。然后进行post-training。然后利用这个export model搜集更高质量的非英语数据,一直到pretrain完成。
  • Multilingual data collection:multilingual SFT数据2.4%人工标注,44.2%是其他NLP任务,18.8%是rejection sampled data及34.6% translated reasoning data。
    • Human annotations: 从语言学家和说母语者那搜集高质量的标注数据。
    • Data from other NLP tasks:
    • Rejection sampled data:通过人工标注的prompt生成rejection sampled data。
      • Generation:post-training前几轮中,temperature范围在[0.2, 1],并随机选择作为超参进行生产数据。在最后一轮中,temperature为0.6。
      • Selection:首先判断语言是否一致,然后通过reward model进行筛选数据。
    • Translated data:翻译synthetic quantitative reasoning data 数据到其他语言用于训练,能提升MGSM效果。
  1. Math and Reasoning
  • 难点:
    • Lack of prompts: 随着问题复杂度的提供,用于SFT的prompt数据减少。
    • Lack of ground truth chain of thought: 缺少正确的CoT数据。
    • Incorrect intermediate steps: 用模型生成中间步骤,但是中间步骤可能是错误的。
    • Teaching models to use external tools: 增加模型使用额外工具的能力,比如说代码解释器。
    • Discrepancy between training and inference: 训练和推理时,使用方式不一致。
  • 解决方法:
    • Addressing the lack of prompts: 将预训练数据中数学文本转化为qa形式,并用于模型SFT。另外,构造了数学能力的分类系统,并要求人提供相关的prompt/questions。
    • Augmenting training data with step-wise reasoning traces: 使用llama3生成一系列prompt的解题步骤,针对每个prompt,llama3会生成多个答案,然后用正确答案过滤。同时,让llama3对答案中的每个步骤自我矫正。
    • Filtering incorrect reasoning traces: 训练针对结果和step的reward model,用于过滤生成的答案。针对难度大的prompt,会使用蒙特卡洛方法+reward model去得到合理的推理路径。
    • Interleaving code and text reasoning: 提示llama3通过文本推理和相关python代码的组合来解决推理问题。代码执行结果作为反馈信息用于过滤推理错误的case。
    • Learning from feedback and mistakes: 为了模拟人类反馈,采用 生成错误答案,并prompting llama3模型去进行错误纠正从而得到正确的答案。
  1. Long Context
  • SFT and synthetic data generation:通过合成数据的方式得到长文本数据。例如,使用早期版本的llama3生成多轮qa,长文本摘要,代码推理等。
    • Question answering: 预训练数据会划分为8k的chunk,让早期的llama3模型基于随机选择的chunk生成QA pairs。在训练阶段,整个文档会作为context。
    • Summarization: 针对长文本,会分成多个8k的chunk,然后使用最好的模型对每个chunk进行总结,然后将生成的多个总结进行总结。训练的时候,会输入整个文档,让模型进行总结。同时也会基于总结的摘要生成QA对,让模型基于整个文档对question进行回答。
    • Long context code reasoning: 解析python文件,辨别import语句及依赖。然后,选择至少被其他五个文件依赖的文件,随后,从代码库中删除其中一个文件。让模型判断哪些文件依赖缺失的文件并生成相关的缺失代码。

随后,将这些生成的数据按照sequence length分类(16K,32K,64K和128K)。实验表明,将占比0.1%的合成长文本数据与最初的短文本数据混合会提升short-context和long-context的benchmark。

  • DPO:SFT模型在长文本任务上有很好的效果的话,在DPO阶段,用短文本训练数据并不会有负面影响。主要原因可能是因为DPO训练step要小于SFT。
  1. Tool Use
  • 相关的工具有:
    • Search engine
    • python interpreter
    • Mathematical computational engine
  • 目标:通过训练,模型能够通过使用这些工具解决用户的问题,包括一些多轮的对话。如果用户问题需要多个tool调用,模型可以写一个多步的plan,在sequence中调用工具,并在调用工具后进行推理。同时,也会增强模型的zero-shot能力。例如,输入长下文,没有见过的tool,工具说明和query,让模型能够生成正确的工具调用。
  • Implementation:zero-shot工具主要转化为python函数实现,模型只需要function signature和docstring作为context去生成对应的调用。另外,会将所有的函数定义和调用转化为json格式。
  • Data collection
    • 数据标注针对同一个context,给出更好的回复。如果两条回复都不好,可以编辑其中一条。然后被选择或被编辑的回复就会被加入到context中,然后继续对话。
    • 没有使用rejection samping,因为没有观察到收益。
  • Tool datasets
    • Single-step tool use:在人工构造的prompt基础上,通过few-shot的方式提升模型使用工具的准确率,在生成合适的工具调用后,直接执行并将结果加到对应prompt中。最后,模型基于这些得到最终的结果。
    • Multi-step tool use: 首先,让模型生成至少包含两个工具调用的prompt=。随后,few-shot prompt llama3生成包含交叉工具调用和推理的解题结果[7]。
      在这里插入图片描述
    • File uploads: prompt主要包含能基于所提供的文件,要求总结文件的内容,查找和修复错误,优化一段代码,执行数据分析或可视化。
      在这里插入图片描述
  • Zero-shot tool use data:通过在大量多样性的合成数据<unctions definitions, user query, corresponding
    call>上微调来优化llama3的zero-shot tool use abilities。
    • Single, nested, and parallel function calling:
    • Multi-turn function calling: 通过多个agent完成。
  1. Factuality
  • 在预训练数据中抽取数据片段;
  • 根据这些片段,让llama3模型生成基于事实的问题;
  • llama3生成这些问题的回复;
  • 对这些回复进行打分;
  • 对这些问题的信息度进行打分;
  • 在多轮回复中的信息不一致和不正确的回答进行拒绝。
  1. Steerability:增强通过系统指令的可操控性。例如,保证回复长度,格式,语调及人设的一致。
  • Data collection:让标注人物设计不同的系统指令,然后标注人员参与到与模型的对话中,并评估模型执行指令的效果。
  • Modeling:利用搜集到的数据训练reward model,rejection sampling,SFT和DPO。

2.2 Nemotron340B[8]

2.2.1 简介

  • 英伟达开源了Nemotron-4-340B等一系列模型,包括Nemotron-4-340B-Base,Nemotron-4-340B-Instruct和Nemotron-4-340B-Reward模型。预训练的高质量数据总共有9 trillion tokens,8T用于pretrain,1T用于continued pretrain,具体的细节见[9]。对齐阶段使用的数据98%都是合成的。文章里面详细的解释了一些预训练和对齐细节,有兴趣的同学,可以仔细看看。
  • 模型结构是decode only,细节包括使用RoPE,squared ReLU,没有bias,dropout rate为0,不固定input-output embedding,使用GQA。总共参数有340B,embedding parameters有9.4B,非embedding parameters有331.6B。参数细节见下表:
    在这里插入图片描述* 训练细节
    • 使用768个DGX H100节点,每个节点包含8个H100-80G。
    • 并行策略:8 TP,12 PP 和数据并行
      在这里插入图片描述
      这里面MFU表示gpu的利用率,理论峰值是100%。
  • 将预训练两个阶段,第一阶段是正常的pretrain,第二阶段是continued pretrain。在第二阶段有两个设计:
      1. 数据包含主要的continue pretrain数据和第一阶段中的高质量数据(增加高质量数据的采样比重),另外还增加了少量对齐阶段的QA格式的样本数据,同时在训练中提高模型准确率比较低的领域数据比重。
      1. learning rate schedule:优先更陡的衰减策略,而不是学习率大小。

2.2.2 Alignment

  • Reward Model
    • 数据:搜集了10K人类偏好数据,名字为HelpSteer2[10],搜集方法跟HelpSteer类似[11]。
    • 模型:multi-attribute regression reward models,能够更好的预测细粒度的reward,分辨两个相似回复的细微差别。构建方式是在Nemotron-4-340B-Base模型基础上,将最后一层softmax替换为reward head。reward head是一个linear projection,将最后一层hidden states映射到HelpSteer属性(Helpfulness,Correctness, Coherence, Complexity, Verbosity)的五维向量。在推理阶段,这五种属性值加权和后得到一个最终的reward。
    • 结果:在RewardBench上达到最高准确率。
      在这里插入图片描述
  • Alignment Data
    整个对齐阶段,只用了20K人工标注数据,其中10K用于supervised,finetuning,10K Helpsteer2数据用于reward model训练和preference finetuning,剩下超过98%的数据都是合成的。下面重点介绍下合成数据流程。
    • Prompt preparation
      • 现存prompt数据:LMSYS-Chat-1M prompts
      • 指令多样性:任务多样性(写作,Open QA,closed QA等),话题多样性(日常生活,人文学等),指令多样性(json输出,Yes-or-No回答等)
      • 模型:Mixtral-8x7B-Instruct-v0.1,生成open QA,writing,closed QA和math&coding任务prompt。
      • 方法:针对每个prompt任务,会使用不同的话题或关键词作为种子来生成[12, 13]。同时也会生成不同指令的prompt,比如说“输出格式为json”。另外,为了增加多轮能力,也会生成两轮的prompt。
        • Synthetic single-turn prompts
          在这里插入图片描述
          • topic搜集:首先,让模型生成不同的大话题,然后基于生成的大话题继续生成相关的子话题,最后加上人工搜集的话题,总共有3k个话题。
          • open QA:根据提供的话题,生成相关的prompt,为了让prompt更详细的,会让模型对生成的prompt进行修改。
          • writing:让模型根据不同的话题生成不同的文章,跟open QA一样,会对生成的prompt进行二次修改。
          • closed QA:基于C4数据集,让模型基于任意一个文档生成对应的指令,然后用固定的模版将文档和生成的指令拼接起来。
          • math&coding:在数学和代码中搜集多样的关键词集合,然后针对数学和python代码生成大类和子类,还会判断Wikipedia实体是否跟math&code有关,也会在预训练数据中搜集python或math关键词,随后加上人工搜集的,总共有12K python相关的关键词和17K数学相关的关键词。最后让模型基本搜集到的关键词生成问题。
        • Synthetic instruction-following prompts:随机选择一些合成prompt,对于每个合成prompt,会让模型随机生成一个指令,例如,你的回复必须包含三段。最后,将prompt和指令通过固定的模版合成在一起。
        • Synthetic two-turn prompts:具体形式为 “User: XXX; Assistant: XXX; User: XXX,第一轮prompt来自ShareGPT,然后生成对应的回复,最后让模型生成第二轮的问题。
        • Real-world LMSYS prompts:为了更好的模拟真实世界的请求,使用LMSYS-Chat-1M数据集[14]。然后,会将数据分成两份,一份用于SFT,一份用于preference learning。在SFT中,会过滤掉数据集中的不安全prompt,但在preference learning中保留这些,让模型能区分安全或不安全的回复。
    • Synthetic Dialogue Generation:单轮对话数据,很简单的让模型基于query生成回复。多轮对话数据,让模型不断的扮演user和assistant生成多轮对话。为了让对话多样性更好,会会通过指令的方式,设定user不同的人设信息。
    • Synthetic Preference Data Generation:合成偏好数据prompt来自 synthetic single-turn prompts, instruction-following prompts, two-turn prompts, 以及来自MATH,GSM8K,LMSYS和ShareGPT等真实数据集,然后会让不同的模型生成回复。如果prompt有ground truth的话,可以通过ground truth判断正确的回复,构造<prompt, chosen, rejected>三元组样本。如果没有的话,可以通过LLM和Reward model来判断。使用LLM来判断回复好坏,为了排除位置的影响,会讲候选的response不同的排列方式过LLM,如果两次答案一样,则认为这条样本是好的样本。但是实验结果表明,使用Reward model是更好的方式。
    • Iterative Weak-to-Strong Alignment
      在这里插入图片描述
        1. 将Mixtral-8x7B-Instruct-v0.1作为初始对齐模型,生成合成数据。
        1. 让第一步得到的合成数据,训练Nemotron-4-340B-Interm-1-Base,得到Nemotron-4-340B-Interm-1-Instruct模型,因为Nemotron-4-340B-Interm-1-Base要比Mixtral-8x7B-Instruct-v0.1更好,所以训练得到的Nemotron-4-340B-Interm-1-Instruct模型会比Mixtral-8x7B-Instruct-v0.1更好。
        1. 用Nemotron-4-340B-Interm-1-Instruct模型生成新的合成数据,同时训练Nemotron-4-340B-Interm-1-Base,得到模型Nemotron-4-340B-Interm-2-Chat。
    • Additional Data Sources: CantTalkAboutThis,Open-Platypus,PRM800K,FinQA, Glaive AI等。
  • Alignment Algorithms[15]:
    • Staged Supervised Fine-tuning:实验表明同时训练多种能力,多种能力之间会有冲突的现象,尤其是code能力。为了更好的增强模型的多种能力,这里采用了两阶段训练。
      • Code SFT:
        • 数据:通过self instruction[16]和wizard coder mutations[17]方式,构造大量的合成代码数据。通过这种方式构造了800K的数据。模型基于这些数据训练一个epoch,LR为 3 e − 7 3e{-7} 3e7,batch size为128。
      • General SFT:用多种任务组成的总共200K数据继续训练,为了防止灾难性遗忘,微调数据中有2%包含了代码数据。最后,在batch size为128,LR在 [ 1 e − 7 , 5 e − 7 ] [1e-7,5e-7] [1e7,5e7]之间,总共训练了3个epoch。在两个阶段训练中,只计算response的loss。
    • Preference Fine-tuning:这个阶段包括了多次迭代,每次迭代算法涉及DPO和RPO。
      • Direct Preference Optimization (DPO):在训练足够长时,发现模型会过拟合,在一些任务上指标会下降。所以在chosen response上增加了SFT loss。模型训练参数:learning rate within [3e-8, 3e-7], kl regularization coefficient in the DPO loss within [3e-4, 3e-3], and the weight of the SFT loss within [1e-5, 1e-3]
      • Reward-aware Preference Optimization (RPO):实验中发现,有些rejected response仅比chosen response差一点,有些则差很多,DPO算法无法衡量出这种差距,所有这里提出了RPO算法。
        在这里插入图片描述
        • 其中 r ( x , y c ) , r ( x , y L ) r(x, y_{c}),r(x, y_L) r(x,yc)r(xyL)分别表示reward model在chosen和rejected resposne的reward score。 D [ a ∣ ∣ B ] : = σ ( b ) l o g σ ( b ) σ ( a ) + ( 1 − σ ( b ) ) l o g 1 − σ ( b ) 1 − σ ( a ) D[a||B]:=\sigma(b)log\frac{\sigma(b)}{\sigma(a)}+(1-\sigma(b))log\frac{1-\sigma(b)}{1-\sigma(a)} D[a∣∣B]:=σ(b)logσ(a)σ(b)+(1σ(b))log1σ(a)1σ(b)
        • 在训练中,会使用DPO得到的checkpoint初始化RPO模型,训练数据为300K,额外加上sft loss,训练三个epoch。
        • 除此之外,现在还有不少的类DPO方法,比如说DNO,cDPO,IPO,DistillDPO和BRAINn等算法。

2.3 Gemma2[18]

2.3.1 简介

模型中使用了interleaving local-global attentions[19]以及group-query attention,另外,在训练2B和7B模型时使用了知识蒸馏的方法。

2.3.2 模型

  • 结构:decoder-only transformer architecture,具体参数如下:
    在这里插入图片描述
    • Local Sliding Window[19] and Global Attention:local sliding window窗口大小为4096,global attention span设置为8192tokens。local和global交替使用。
    • Logits soft-capping:将每个attention layer和最有一层的logits保证在[-soft_cap, +soft_cap]之间。在self-attention layer,soft_cap大小设置为50,在final layer,soft_cap大小设置为30。
      在这里插入图片描述
    • Post-norm and pre-norm with RMSNorm:使用RMSNorm对每个transformer sub-layer,attention layer及feedforward layer的input和output进行normalization。
    • Grouped-Query Attention:num_groups=2,这个参数能同时保证提升推理速度和保障下游任务的效果。

2.3.3 Pre-Trainiing

  • Training Data:总共有13 trilliion token,其中27B模型用13T数据训练,9B模型有8T数据训练,2B模型用2T数据训练。训练数据包括网络文档,code及科学文章。数据混合比例参考Gemini1.0[20]。
    • Tokenizer:使用SentencePiece,词表大小为256k。
    • Filtering:参考Gemma1。
  • Knowledge Distillation
    在这里插入图片描述
    其中, P s P_s Ps表示student的概率。

2.3.4 Post-Training

  • 数据:在Gemma1.1数据基础上进行了扩充,使用了LMSYS-chat-1M的prompt,没有用到回复。
  • SFT:在真实和合成的prompt上进行微调,每个prompt的回复都是由teacher模型生成。
  • RLHF:算法跟Gemma1.1相似,但是在reward model上有轻微改动。reward model会更大,在多轮对话能力上也更强。
  • Model merging:在将不同参数的模型进行average[21]
  • Data filtering:在合成数据后,会过滤掉包含个人信息,不安全,有毒的数据,不一致和重复的数据。包含语境归因,规避,拒绝最小化幻觉(?)改进模型的效果(这块没懂)。
  • Formatting:包含一些特定的格式内容
    在这里插入图片描述
    在这里插入图片描述

3. 对比

对比三个模型的不同后处理,可以发现有共同点和不同点:

  • 共同点:
    • 在post-training阶段都使用了合成数据,合成数据的质量可能高于人类数据,特别是对于具有挑战性的任务;
    • RLHF可以比指令微调扩展到更大规模,目前仅gemma用到了RLHF,在llama和Nemotron340B中都使用了类DPO算法;
    • 需要多轮训练和生成才能得到最佳模型,基本都是先sft,后面加上DPO或RLHF;
    • 在post-training中,数据质量和多样性很重要。因此,数据过滤是训练中最重要的部分。
  • 不同点:
    • 只有Gemma2在post-training阶段用到了RLHF,llama3.1,Nemotron340B在模型训练阶段都使用类DPO算法;
    • llama3.1会对不同的能力进行详细的优化,包括code,Multilinguality等;
    • llama3.1和Gemma2都用到model average这个策略;
    • Nemotron340B对DPO算法进行了改进,llama3则使用RS,Gemma2使用RLHF;
    • Nemotron340B在post-training阶段,为了缓解不同能力之间的冲突,将sft阶段分为两个部分,分别是code sft和general sft;
    • gemma2在模型结构有一些改进,可以自己研究下;

4. 总结

本文主要介绍了llama3.1, Nemotron340B,gemma2三个模型的一些post-training细节。涉及数据搜集和处理,模型训练,训练算法细节等方面。当然,由于篇幅很长,肯定会有一些错误的地方,欢迎大家指正。
在这里插入图片描述

5. 参考文献

  • [1] https://www.interconnects.ai/p/frontier-model-post-training
  • [2] Instag: Instruction tagging for analyzing supervised fine-tuning of large language models
  • [3] What makes good data for alignment? a comprehensive study of automatic data selection in instruction tuning,
  • [4] Iterative reasoning preference optimization
  • [5] Don’t stop pretraining: Adapt language models to domains and tasks
  • [6] reaking language barriers in multilingual mathematical reasoning: Insights and observations
  • [7] React: Synergizing reasoning and acting in language models
  • [8] Nemotron-4 340B Technical Report
  • [9] Nemotron-4 15b technical report
  • [10] Helpsteer2: Open-source dataset for training top-performing reward models
  • [11] Helpsteer: Multi-attribute helpfulness dataset for steerlm
  • [12] Enhancing chat language models by scaling high-quality instructional conversations
  • [13] Camel: Communicative agents for ”mind” exploration of large language model society
  • [14] Lmsys-chat-1m: A large-scale real-world llm conversation dataset
  • [15] Training language models to follow instructions with human feedback
  • [16] Self-instruct: Aligning language models with self-generated instructions
  • [17] Wizardcoder: Empowering code large language models with evol-instruct
  • [18] Gemma 2: Improving Open Language Models at a Practical Size
  • [19] Long-former: The long-document transformer.
  • [20] Gemini: A family of highly capable multimodal models
  • [21] Warp: On the benefits of weight averaged rewarded policies

这篇关于Post-Training有多重要?一文带你了解全部细节的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1146655

相关文章

C#使用HttpClient进行Post请求出现超时问题的解决及优化

《C#使用HttpClient进行Post请求出现超时问题的解决及优化》最近我的控制台程序发现有时候总是出现请求超时等问题,通常好几分钟最多只有3-4个请求,在使用apipost发现并发10个5分钟也... 目录优化结论单例HttpClient连接池耗尽和并发并发异步最终优化后优化结论我直接上优化结论吧,

一文带你理解Python中import机制与importlib的妙用

《一文带你理解Python中import机制与importlib的妙用》在Python编程的世界里,import语句是开发者最常用的工具之一,它就像一把钥匙,打开了通往各种功能和库的大门,下面就跟随小... 目录一、python import机制概述1.1 import语句的基本用法1.2 模块缓存机制1.

Linux Mint Xia 22.1重磅发布: 重要更新一览

《LinuxMintXia22.1重磅发布:重要更新一览》Beta版LinuxMint“Xia”22.1发布,新版本基于Ubuntu24.04,内核版本为Linux6.8,这... linux Mint 22.1「Xia」正式发布啦!这次更新带来了诸多优化和改进,进一步巩固了 Mint 在 Linux 桌面

如何评价Ubuntu 24.04 LTS? Ubuntu 24.04 LTS新功能亮点和重要变化

《如何评价Ubuntu24.04LTS?Ubuntu24.04LTS新功能亮点和重要变化》Ubuntu24.04LTS即将发布,带来一系列提升用户体验的显著功能,本文深入探讨了该版本的亮... Ubuntu 24.04 LTS,代号 Noble NumBAT,正式发布下载!如果你在使用 Ubuntu 23.

一文带你搞懂Nginx中的配置文件

《一文带你搞懂Nginx中的配置文件》Nginx(发音为“engine-x”)是一款高性能的Web服务器、反向代理服务器和负载均衡器,广泛应用于全球各类网站和应用中,下面就跟随小编一起来了解下如何... 目录摘要一、Nginx 配置文件结构概述二、全局配置(Global Configuration)1. w

SpringBoot中Get请求和POST请求接收参数示例详解

《SpringBoot中Get请求和POST请求接收参数示例详解》文章详细介绍了SpringBoot中Get请求和POST请求的参数接收方式,包括方法形参接收参数、实体类接收参数、HttpServle... 目录1、Get请求1.1 方法形参接收参数 这种方式一般适用参数比较少的情况,并且前后端参数名称必须

Python开发围棋游戏的实例代码(实现全部功能)

《Python开发围棋游戏的实例代码(实现全部功能)》围棋是一种古老而复杂的策略棋类游戏,起源于中国,已有超过2500年的历史,本文介绍了如何用Python开发一个简单的围棋游戏,实例代码涵盖了游戏的... 目录1. 围棋游戏概述1.1 游戏规则1.2 游戏设计思路2. 环境准备3. 创建棋盘3.1 棋盘类

关于数据埋点,你需要了解这些基本知识

产品汪每天都在和数据打交道,你知道数据来自哪里吗? 移动app端内的用户行为数据大多来自埋点,了解一些埋点知识,能和数据分析师、技术侃大山,参与到前期的数据采集,更重要是让最终的埋点数据能为我所用,否则可怜巴巴等上几个月是常有的事。   埋点类型 根据埋点方式,可以区分为: 手动埋点半自动埋点全自动埋点 秉承“任何事物都有两面性”的道理:自动程度高的,能解决通用统计,便于统一化管理,但个性化定

2014 Multi-University Training Contest 8小记

1002 计算几何 最大的速度才可能拥有无限的面积。 最大的速度的点 求凸包, 凸包上的点( 注意不是端点 ) 才拥有无限的面积 注意 :  凸包上如果有重点则不满足。 另外最大的速度为0也不行的。 int cmp(double x){if(fabs(x) < 1e-8) return 0 ;if(x > 0) return 1 ;return -1 ;}struct poin

2014 Multi-University Training Contest 7小记

1003   数学 , 先暴力再解方程。 在b进制下是个2 , 3 位数的 大概是10000进制以上 。这部分解方程 2-10000 直接暴力 typedef long long LL ;LL n ;int ok(int b){LL m = n ;int c ;while(m){c = m % b ;if(c == 3 || c == 4 || c == 5 ||