RLHF学习

2024-01-27 00:04
文章标签 学习 rlhf

本文主要是介绍RLHF学习,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

整体流程

三个步骤分解:

  1. 预训练一个语言模型 (LM) ;
  2. 聚合问答数据并训练一个奖励模型 (Reward Model,RM) ;
  3. 用强化学习 (RL) 方式微调 LM。

在这里插入图片描述

在这里插入图片描述

RW

RM 的训练是 RLHF 区别于旧范式的开端。这一模型接收一系列文本并返回一个标量奖励,数值上对应人的偏好。我们可以用端到端的方式用 LM 建模,或者用模块化的系统建模 (比如对输出进行排名,再将排名转换为奖励) 。这一奖励数值将对后续无缝接入现有的 RL 算法至关重要。

  • 关于模型选择方面:
    RM 可以是另一个经过微调的 LM,也可以是根据偏好数据从头开始训练的 LM。例如 Anthropic 提出了一种特殊的预训练方式,即用偏好模型预训练 (Preference Model Pretraining,PMP) 来替换一般预训练后的微调过程。因为前者被认为对样本数据的利用率更高。但对于哪种 RM 更好尚无定论。

  • 过程:
    在这里插入图片描述

  • Bradley-Terry(BT)模型是一个常见选择(在可以获得多个排序答案的情况下,Plackett-Luce 是更一般的排序模型)

  • **排序损失:**在最后一层 transformer 层后添加一个线性层以获得奖励值的标量预测。为了确保奖励函数具有较低的方差,之前的工作会对奖励进行归一化
    在这里插入图片描述

RW代码

from dataclasses import dataclass, field
from typing import Optionalimport tyro
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfigfrom trl import RewardConfig, RewardTrainer, is_xpu_availabletqdm.pandas()@dataclass
class ScriptArguments:model_name: str = "facebook/opt-350m""""the model name"""dataset_name: str = "Anthropic/hh-rlhf""""the dataset name"""dataset_text_field: str = "text""""the text field of the dataset"""eval_split: str = "none""""the dataset split to evaluate on; default to 'none' (no evaluation)"""load_in_8bit: bool = False"""load the model in 8 bits precision"""load_in_4bit: bool = False"""load the model in 4 bits precision"""trust_remote_code: bool = True"""Enable `trust_remote_code`"""reward_config: RewardConfig = field(default_factory=lambda: RewardConfig(output_dir="output",per_device_train_batch_size=64,num_train_epochs=1,gradient_accumulation_steps=16,gradient_checkpointing=True,gradient_checkpointing_kwargs={"use_reentrant": False},learning_rate=1.41e-5,report_to="tensorboard",remove_unused_columns=False,optim="adamw_torch",logging_steps=500,evaluation_strategy="no",max_length=512,))use_peft: bool = False"""whether to use peft"""peft_config: Optional[LoraConfig] = field(default_factory=lambda: LoraConfig(r=16,lora_alpha=16,bias="none",task_type="SEQ_CLS",modules_to_save=["scores"],),)args = tyro.cli(ScriptArguments)
args.reward_config.evaluation_strategy = "steps" if args.eval_split != "none" else "no"# Step 1: Load the model
if args.load_in_8bit and args.load_in_4bit:raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif args.load_in_8bit or args.load_in_4bit:quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit)# Copy the model to each devicedevice_map = ({"": f"xpu:{Accelerator().local_process_index}"}if is_xpu_available()else {"": Accelerator().local_process_index})
else:device_map = Nonequantization_config = Nonemodel = AutoModelForSequenceClassification.from_pretrained(args.model_name,quantization_config=quantization_config,device_map=device_map,trust_remote_code=args.trust_remote_code,num_labels=1,
)# Step 2: Load the dataset and pre-process it
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
train_dataset = load_dataset(args.dataset_name, split="train")# Tokenize chosen/rejected pairs of inputs
# Adapt this section to your needs for custom datasets
def preprocess_function(examples):new_examples = {"input_ids_chosen": [],"attention_mask_chosen": [],"input_ids_rejected": [],"attention_mask_rejected": [],}for chosen, rejected in zip(examples["chosen"], examples["rejected"]):tokenized_chosen = tokenizer(chosen)tokenized_rejected = tokenizer(rejected)new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])return new_examples# Preprocess the dataset and filter out examples that are longer than args.max_length
train_dataset = train_dataset.map(preprocess_function,batched=True,num_proc=4,
)
train_dataset = train_dataset.filter(lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_lengthand len(x["input_ids_rejected"]) <= args.reward_config.max_length
)if args.eval_split == "none":eval_dataset = None
else:eval_dataset = load_dataset(args.dataset_name, split=args.eval_split)eval_dataset = eval_dataset.map(preprocess_function,batched=True,num_proc=4,)eval_dataset = eval_dataset.filter(lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_lengthand len(x["input_ids_rejected"]) <= args.reward_config.max_length)# Step 4: Define the LoraConfig
if args.use_peft:peft_config = args.peft_config
else:peft_config = None# Step 5: Define the Trainer
trainer = RewardTrainer(model=model,tokenizer=tokenizer,args=args.reward_config,train_dataset=train_dataset,eval_dataset=eval_dataset,peft_config=peft_config,
)trainer.train()

RLHF

  • 动手学强化学习: https://hrl.boyuai.com/chapter/2/actor-critic%E7%AE%97%E6%B3%95

让我们首先将微调任务表述为 RL 问题。

  • 首先,该 策略 (policy) 是一个接受提示并返回一系列文本 (或文本的概率分布) 的 LM。
  • 这个策略的 行动空间 (action space) 是 LM 的词表对应的所有词元 (一般在 50k 数量级)
  • 观察空间 (observation space) 是可能的输入词元序列,也比较大 (词汇量 ^ 输入标记的数量) 。
  • 奖励函数 是偏好模型和策略转变约束 (Policy shift constraint) 的结合。
    在这里插入图片描述

在这里插入图片描述

  • KL散度这一项被用于惩罚 RL 策略在每个训练批次中生成大幅偏离初始模型,以确保模型输出合理连贯的文本。如果去掉这一惩罚项可能导致模型在优化中生成乱码文本来愚弄奖励模型提供高奖励值。

可视化进度条的一种方法:

with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:for i_episode in range(episode):if (i_episode + 1) % 10 == 0:pbar.set_postfix({'episode':'%d' % (num_episodes / 10 * i + i_episode + 1),'return':'%.3f' % np.mean(return_list[-10:])})pbar.update(1)

策略梯度

  • 基于值函数的方法主要是学习值函数,然后根据值函数导出一个策略,学习过程中并不存在一个显式的策略;而基于策略的方法则是直接显式地学习一个目标策略。

AC算法

  • 基于值函数的方法只学习一个价值函数,而基于策略的方法只学习一个策略函数

  • Actor-Critic 算法本质上是基于策略的算法,因为这一系列算法的目标都是优化一个带参数的策略,只是会额外学习价值函数,从而帮助策略函数更好地学习。

  • Actor-Critic 算法估计一个动作价值函数 Q Q Q,代替蒙特卡洛采样得到的回报,这便是 Q ( s , a ) Q(s,a) Q(s,a)。这个时候,我们可以把状态价值函数 V V V作为基线,从 Q Q Q函数减去这个 V V V函数则得到了函数 A A A,我们称之为优势函数(advantage function)

Actor-Critic 分为两个部分:Actor(策略网络)和 Critic(价值网络)

  • Actor 要做的是与环境交互,并在 Critic 价值函数的指导下用策略梯度学习一个更好的策略。
  • Critic 要做的是通过 Actor 与环境交互收集的数据学习一个价值函数,这个价值函数会用于判断在当前状态什么动作是好的,什么动作不是好的,进而帮助 Actor 进行策略更新。

在这里插入图片描述
在这里插入图片描述

PPO

PPO惩罚

PPO-惩罚(PPO-Penalty)用拉格朗日乘数法直接将 KL 散度的限制放进了目标函数中,这就变成了一个无约束的优化问题,在迭代的过程中不断更新 KL 散度前的系数。
在这里插入图片描述

PPO截断

在这里插入图片描述

  • 对于连续动作,让策略网络输出连续动作高斯分布(Gaussian distribution)的均值和标准差。后续的连续动作则在该高斯分布中采样得到。
PPO的训练中存在的问题

PPO会找捷径,只要有机会,PPO 算法就会利用这些缺陷。

  1. 显然,当从概率低于 SFT 模型的策略中采样令牌时,这将导致负 KL 惩罚。但平均而言,它将是正的,否则您将无法从策略中正确采样。使用 KL 惩罚项是为了推动模型的输出保持接近基本策略的输出。一般来说,KL 散度衡量两个分布之间的距离,并且始终为正值。

  2. 某些生成策略可以强制生成某些token或抑制某些token。例如,当批量生成完成的序列时,会进行填充;当设置最小长度时,EOS 令牌会被抑制。该模型可以为那些导致负 KL 的标记分配非常高或非常低的概率。当 PPO 算法针对奖励进行优化时,它会追逐这些负面惩罚,从而导致不稳定。

    • 生成响应时需要小心,我们建议在采用更复杂的生成方法之前始终先使用简单的采样策略。
  3. 损失偶尔会出现峰值,这可能会导致进一步的不稳定。

  4. 字符串的重复会导致奖励的突然增加。

DPO

与以往的 RLHF 方法(先学习一个奖励函数,然后通过强化学习优化)不同,我们的方法跳过了奖励建模步骤,直接使用偏好数据优化语言模型。

  • 我们的核心观点是利用从奖励函数到最优策略的解析映射,将对奖励函数的损失转化为对策略的损失。这种变量转换的方法使我们能够跳过显式的奖励建模步骤,同时仍然在现有的人类偏好模型(如 Bradley-Terry 模型)下进行优化。实质上,策略网络既代表语言模型,又代表奖励。

在这里插入图片描述
在这里插入图片描述

RLHF开源工具

  • TRL
  • RL4LM

TRL实践

  • demo.py
# 0. imports
import torch
from transformers import GPT2Tokenizerfrom trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token# 2. initialize trainer
ppo_config = {"batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)# 3. encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device)# 4. generate model response
generation_kwargs = {"min_length": -1,"top_k": 0.0,"top_p": 1.0,"do_sample": True,"pad_token_id": tokenizer.eos_token_id,"max_new_tokens": 20,
}
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])# 5. define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0, device=model.pretrained_model.device)]# 6. train model with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

这篇关于RLHF学习的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/648457

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识

线性代数|机器学习-P36在图中找聚类

文章目录 1. 常见图结构2. 谱聚类 感觉后面几节课的内容跨越太大,需要补充太多的知识点,教授讲得内容跨越较大,一般一节课的内容是书本上的一章节内容,所以看视频比较吃力,需要先预习课本内容后才能够很好的理解教授讲解的知识点。 1. 常见图结构 假设我们有如下图结构: Adjacency Matrix:行和列表示的是节点的位置,A[i,j]表示的第 i 个节点和第 j 个

Node.js学习记录(二)

目录 一、express 1、初识express 2、安装express 3、创建并启动web服务器 4、监听 GET&POST 请求、响应内容给客户端 5、获取URL中携带的查询参数 6、获取URL中动态参数 7、静态资源托管 二、工具nodemon 三、express路由 1、express中路由 2、路由的匹配 3、路由模块化 4、路由模块添加前缀 四、中间件