[NLP] LLM---<训练中文LLama2(五)>对SFT后的LLama2进行DPO训练

2023-11-07 05:20

本文主要是介绍[NLP] LLM---<训练中文LLama2(五)>对SFT后的LLama2进行DPO训练,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

当前关于LLM的共识

大型语言模型(LLM)使 NLP 中微调模型的过程变得更加复杂。最初,当 ChatGPT 等模型首次出现时,最主要的方法是先训练奖励模型,然后优化 LLM 策略。从人类反馈中强化学习(RLHF)极大地推动了NLP的发展,并将NLP中许多长期面临的挑战抛在了一边。基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。

然而,它也给 NLP 引入了一些 RL 相关的复杂性: 既要构建一个好的奖励函数,并训练一个模型用以估计每个状态的价值 (value); 又要注意最终生成的 LLM 不能与原始模型相差太远,如果太远的话会使得模型容易产生乱码而非有意义的文本。该过程非常复杂,涉及到许多复杂的组件,而这些组件本身在训练过程中又是动态变化的,因此把它们料理好并不容易。

现在主流的LLM,比如chatglm、chinese-alpaca,主要进行了三步操作:

Step1:知识学习,CLM,大规模语料库上的预训练,本步的模型拥有续写的功能

Step2:知识表达,指令微调,在指令数据上进行微调,本步骤可以使用Lora等节省显存的方式,本模型可以听懂人类指令并进行回答的功能

Step3:偏好学习,RLHF或本文所提的DPO,可以让模型的输出更符合人类偏好,通俗说就是同样一句话,得调教的让模型输出人类喜欢的表达方式,好比高情商的人说话让人舒服

第二步,还是多多少少学习了一点知识,第三步则几乎不学知识,只学表达方式了。

RLHF太耗时耗力了,得提前训练好RewardModel,然后PPO阶段,得加载4个模型,2个推理,2个训练,实在是太不友好了。

下图是SFT+RLHF的过程,对应上文的Step2和Step3,主要包括指令微调模型、训练奖励模型和PPO优化。

现在大多数目前开源的LLM模型都只做了前2步:预训练和指令微调。

而其中原因就是第3步人类反馈强化学习(RLHF)实现起来很困难:

1.需要人类反馈数据(很难收集)
2.奖励模型训练(很难训练)
3. PPO强化学习微调(不仅很耗资源,而且也很难训练)

但是能不能不要最后一步呢,一般来说还是有RLHF比较好,有主要有以下几个原因:

  1. 提高安全性和可控性;
  2. 改进交互性;
  3. 克服数据集偏差;
  4. 提供个性化体验;
  5. 符合道德规范;
  6. 持续优化和改进。

RLHF使得ChatGPT这样的大型对话模型既具备强大能力,又能够接受人类价值观的指导,生成更智能、安全、有益的对话回复。这是未来可信赖和可解释AI的重要发展方向。

所以这一步还是非常重要。那如何解决人类反馈强化学习(RLHF)训练这个难题呢?

DPO (Differentiable Policy Optimization) 算法

Rafailov、Sharma、Mitchell 等人最近发表了一篇论文 Direct Preference Optimization,论文提出将现有方法使用的基于强化学习的目标转换为可以通过简单的二元交叉熵损失直接优化的目标,这一做法大大简化了 LLM 的提纯过程。

DPO 是为实现对 LLM 的精确控制而引入的一种方法。从人类反馈强化学习(RLHF)的基础是训练奖励模型,然后使用近端策略优化(PPO)使语言模型的输出与人类的偏好相一致。这种方法虽然有效,但既复杂又不稳定。DPO 将受限奖励最大化问题视为人类偏好数据的分类问题。这种方法稳定、高效、计算量小。它无需进行奖励模型拟合、大量采样和超参数调整。

DPO(Direct Preference Optimization)是一种直接偏好优化算法,它与PPO(Proximal Policy Optimization)优化的目标相同。主要思路是:

1.定义policy模型(策略模型)和reference模型(参考模型),Policy模型是需要训练的对话生成模型,reference模型是给定的预训练模型或人工构建的模型。

2.对于给定prompt,计算两模型对正样本和负样本的概率,正样本是人类选择的回复,负样本是被拒绝的回复。

3.通过两个模型概率的差值构建DPO损失函数,惩罚policy模型对正样本概率的下降和负样本概率的上升。通过最小化DPO损失进行模型训练。

相比之下DPO就很友好,只需要加载2个模型,其中一个推理,另外一个训练,直接在偏好数据上进行训练即可:

DPO 拒绝有害问题 实战部分

数据集

数据集其实就是标准的RLHF奖励模型的训练集,下载地址在这

Anthropic/hh-rlhf · Datasets at Hugging Face

dikw/hh_rlhf_cn · Datasets at Hugging Face

其样式就是:一个context,一个选择的正样本,一个拒绝的负样本。希望这些样本能够让LLM 尽可能生成用户选择的无害回复,而不要生成有害的回复。

微调代码
下方这段代码实现了基于DPO (Differentiable Policy Optimization) 的对话模型微调。主要步骤包括:

  1. 加载预训练语言模型(这里使用llama-2-7b)并准备量化训练,采用int4量化的+少量lora 参数。
  2. 定义参考模型(int4量化的模型),也使用同样的预训练模型。
  3. 加载Helpful/Harmless数据集,并转换成所需格式。
  4. 定义DPO训练参数,包括batch size,学习率等。
  5. 定义DPO训练器,传入policy模型,参考模型,训练参数等。
  6. 进行DPO微调训练。
  7. 保存微调后的模型,只保存量lora 参数。

关键点:

1. 使用DPO损失函数实现安全性约束的模型训练。不需要额外在训练一个奖励模型。
2. 这也导致整个训练过程只需要策略模型和参考模型 2个LLM模型,不需要额外的显存去加载奖励模型。
3. 整个训练过程策略模型和参考模型可以进行4int的模型量化 + 少量的lora 参数

综上,这段代码对预训练语言模型进行DPO微调,以实现安全可控的对话生成

#!/usr/bin/env python
# coding: utf-8from typing import Dictimport torch
from datasets import Dataset, load_dataset
from trl import DPOTrainer
import bitsandbytes as bnbfrom transformers import TrainingArguments
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from peft import (LoraConfig,get_peft_model,prepare_model_for_kbit_training
)output_dir1 = "./dpo_output_dir1"
output_dir2 = "./dpo_output_dir2"base_model = "/home/work/llama-2-7b"###准备训练数据
dataset = load_dataset("json", data_files="./dpo_dataset/harmless_base_cn_train.jsonl")
train_val = dataset["train"].train_test_split(test_size=2000, shuffle=True, seed=42
)
train_data = train_val["train"]
val_data = train_val["test"]def extract_anthropic_prompt(prompt_and_response):final = ""for sample in prompt_and_response:final += sample["role"] + "\n" + sample["text"]final += "\n"return finaldef get_hh(dataset, split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.The dataset is converted to a dictionary with the following structure:{'prompt': List[str],'chosen': List[str],'rejected': List[str],}Prompts should be structured as follows:\n\nHuman: <prompt>\n\nAssistant:Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:."""dataset = datasetif sanity_check:dataset = dataset.select(range(min(len(dataset), 1000)))def split_prompt_and_responses(sample) -> Dict[str, str]:prompt = extract_anthropic_prompt(sample["context"])return {"prompt": prompt,"chosen": sample["chosen"]["role"] + "\n" + sample["chosen"]["text"],"rejected": sample["rejected"]["role"] + "\n" + sample["rejected"]["text"],}return dataset.map(split_prompt_and_responses)train_dataset = get_hh(train_data, "train", sanity_check=True)
eval_dataset = get_hh(val_data, "test", sanity_check=True)def find_all_linear_names(model):# cls = bnb.nn.Linear8bitLtcls = bnb.nn.Linear4bitlora_module_names = set()for name, module in model.named_modules():if isinstance(module, cls):names = name.split('.')lora_module_names.add(names[0] if len(names) == 1 else names[-1])if 'lm_head' in lora_module_names:  # needed for 16-bitlora_module_names.remove('lm_head')return list(lora_module_names)def print_trainable_parameters(model):"""Prints the number of trainable parameters in the model."""trainable_params = 0all_param = 0for _, param in model.named_parameters():all_param += param.numel()if param.requires_grad:trainable_params += param.numel()print(f"trainable params: {trainable_params} || all params: {all_param} || trainables%: {100 * trainable_params / all_param}")tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 trainingbnb_4bit_compute_dtype = "float16"
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_4bit_quant_type = "nf4"
use_nested_quant = Falsebnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type=bnb_4bit_quant_type,bnb_4bit_compute_dtype=compute_dtype,bnb_4bit_use_double_quant=use_nested_quant,
)model = AutoModelForCausalLM.from_pretrained(base_model,trust_remote_code=True,quantization_config=bnb_config,device_map="auto")
model.config.use_cache = False
model = prepare_model_for_kbit_training(model)modules = find_all_linear_names(model)
config = LoraConfig(r=8,lora_alpha=16,lora_dropout=0.05,bias="none",target_modules=modules,task_type="CAUSAL_LM",
)model = get_peft_model(model, config)
print_trainable_parameters(model)###定义参考模型
model_ref = AutoModelForCausalLM.from_pretrained(base_model,trust_remote_code=True,quantization_config=bnb_config,device_map="auto")
###定义dpo训练参数
training_args = TrainingArguments(per_device_train_batch_size=1,max_steps=100,remove_unused_columns=False,gradient_accumulation_steps=2,learning_rate=3e-4,evaluation_strategy="steps",output_dir="./test",
)###定义dpo训练器
dpo_trainer = DPOTrainer(model,model_ref,args=training_args,beta=0.1,train_dataset=train_dataset,eval_dataset=eval_dataset,tokenizer=tokenizer,
)
###训练
dpo_trainer.train()
###模型保存
dpo_trainer.save_model(output_dir1)dpo_trainer.model.save_pretrained(output_dir2)
tokenizer.save_pretrained(output_dir2)

训练过程

其中看出加载了2遍int4量化的模型到显存中,需要训练的策略模型只有一部分lora参数,而参考模型就是原始模型本身.

模型保存

保存下来的参数也就是lora参数,这部分lora 参数就学会了如何拒绝回答有害问题。

至此,我们就学会了如何利用使用DPO +Qlora 实现在完成RLHF的实战。

使用场景

核心原则:偏好数据集中的good/bad response都是和SFT model的训练数据同分布的,也可以说模型是可以生成good/bad response的。

场景1

已有一个SFT model,为了让它更好,对它的output进行偏好标注,然后使用DPO进行训练,这是最正常的使用场景,但是偏好数据集确实避免不了的

场景2

场景1的改进版本,偏好标注不由人来做,而是让gpt4或者一个reward model来标注好坏,至于reward model怎么来,就各凭本事吧

场景3

没有SFT model只有偏好数据集,那就先在偏好数据即中的进行训练,然后在进行DPO的训练。先SFT就是为了符合上文的核心原则

OpenAI独家绝技RLHF也被开源超越啦?!DPO让小白轻松玩转RLHF![已开源] - 知乎 (zhihu.com)

RLHF中的「RL」是必需的吗?有人用二进制交叉熵直接微调LLM,效果更好 - 知乎 (zhihu.com)

直接偏好优化:你的语言模型其实是一个奖励模型 - 知乎 (zhihu.com)

消费级显卡搞定RLHF——DPO算法+QLora微调LLM拒绝有害问题回答实战 - 知乎 (zhihu.com)

使用 DPO 微调 Llama 2 - 知乎 (zhihu.com)

DPO(Direct Preference Optimization):LLM的直接偏好优化 - 知乎 (zhihu.com)

DPO: Direct Preference Optimization 论文解读及代码实践 - 知乎 (zhihu.com)GitHub - mzbac/llama2-fine-tune: Scripts for fine-tuning Llama2 via SFT and DPO.

DPO——RLHF 的替代之《Direct Preference Optimization: Your Language Model is Secretly a Reward Model》论文阅读 - 知乎 (zhihu.com)

这篇关于[NLP] LLM---<训练中文LLama2(五)>对SFT后的LLama2进行DPO训练的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/361503

相关文章

中文分词jieba库的使用与实景应用(一)

知识星球:https://articles.zsxq.com/id_fxvgc803qmr2.html 目录 一.定义: 精确模式(默认模式): 全模式: 搜索引擎模式: paddle 模式(基于深度学习的分词模式): 二 自定义词典 三.文本解析   调整词出现的频率 四. 关键词提取 A. 基于TF-IDF算法的关键词提取 B. 基于TextRank算法的关键词提取

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

业务中14个需要进行A/B测试的时刻[信息图]

在本指南中,我们将全面了解有关 A/B测试 的所有内容。 我们将介绍不同类型的A/B测试,如何有效地规划和启动测试,如何评估测试是否成功,您应该关注哪些指标,多年来我们发现的常见错误等等。 什么是A/B测试? A/B测试(有时称为“分割测试”)是一种实验类型,其中您创建两种或多种内容变体——如登录页面、电子邮件或广告——并将它们显示给不同的受众群体,以查看哪一种效果最好。 本质上,A/B测

vscode中文乱码问题,注释,终端,调试乱码一劳永逸版

忘记咋回事突然出现了乱码问题,很多方法都试了,注释乱码解决了,终端又乱码,调试窗口也乱码,最后经过本人不懈努力,终于全部解决了,现在分享给大家我的方法。 乱码的原因是各个地方用的编码格式不统一,所以把他们设成统一的utf8. 1.电脑的编码格式 开始-设置-时间和语言-语言和区域 管理语言设置-更改系统区域设置-勾选Bata版:使用utf8-确定-然后按指示重启 2.vscode

MiniGPT-3D, 首个高效的3D点云大语言模型,仅需一张RTX3090显卡,训练一天时间,已开源

项目主页:https://tangyuan96.github.io/minigpt_3d_project_page/ 代码:https://github.com/TangYuan96/MiniGPT-3D 论文:https://arxiv.org/pdf/2405.01413 MiniGPT-3D在多个任务上取得了SoTA,被ACM MM2024接收,只拥有47.8M的可训练参数,在一张RTX

遮罩,在指定元素上进行遮罩

废话不多说,直接上代码: ps:依赖 jquer.js 1.首先,定义一个 Overlay.js  代码如下: /*遮罩 Overlay js 对象*/function Overlay(options){//{targetId:'',viewHtml:'',viewWidth:'',viewHeight:''}try{this.state=false;//遮罩状态 true 激活,f

利用matlab bar函数绘制较为复杂的柱状图,并在图中进行适当标注

示例代码和结果如下:小疑问:如何自动选择合适的坐标位置对柱状图的数值大小进行标注?😂 clear; close all;x = 1:3;aa=[28.6321521955954 26.2453660695847 21.69102348512086.93747104431360 6.25442246899816 3.342835958564245.51365061796319 4.87

Spark MLlib模型训练—聚类算法 PIC(Power Iteration Clustering)

Spark MLlib模型训练—聚类算法 PIC(Power Iteration Clustering) Power Iteration Clustering (PIC) 是一种基于图的聚类算法,用于在大规模数据集上进行高效的社区检测。PIC 算法的核心思想是通过迭代图的幂运算来发现数据中的潜在簇。该算法适用于处理大规模图数据,特别是在社交网络分析、推荐系统和生物信息学等领域具有广泛应用。Spa

SigLIP——采用sigmoid损失的图文预训练方式

SigLIP——采用sigmoid损失的图文预训练方式 FesianXu 20240825 at Wechat Search Team 前言 CLIP中的infoNCE损失是一种对比性损失,在SigLIP这个工作中,作者提出采用非对比性的sigmoid损失,能够更高效地进行图文预训练,本文进行介绍。如有谬误请见谅并联系指出,本文遵守CC 4.0 BY-SA版权协议,转载请联系作者并注

解决Office Word不能切换中文输入

我们在使用WORD的时可能会经常碰到WORD中无法输入中文的情况。因为,虽然我们安装了搜狗输入法,但是到我们在WORD中使用搜狗的输入法的切换中英文的按键的时候会发现根本没有效果,无法将输入法切换成中文的。下面我就介绍一下如何在WORD中把搜狗输入法切换到中文。