RLHF介绍及实践测试

2023-12-22 07:36
文章标签 实践 介绍 测试 rlhf

本文主要是介绍RLHF介绍及实践测试,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

介绍

RLHF(Reinforcement Learning Hyperparameter Optimization Framework)是一种用于强化学习模型的超参数优化框架。它结合了强化学习中的经典方法贝叶斯优化技术能够更高效地找到最佳超参数组合。下面是强化学习微调的完整 RLHF 流程:

  • RLHF-Stage1 是 supervised-fintuning,即使用上文提到的数据集进行模型微调,目的是将大模型能力往垂直领域迁移;
  • RLHF-Stage2 训练奖励模型,它通过对于同一个 prompt 的不同输出进行人工排序,得到对应分数,监督训练奖励模型,目的是训练一个自动评估函数
  • RLHF-Stage3 使用了强化学习算法训练优化LM,目前多个组织找到的可行方案是使用策略梯度强化学习 (Policy Gradient RL) 算法、近端策略优化 (Proximal Policy Optimization,PPO) 微调初始 LM 的部分或全部参数。

ps: 与lora微调的区别是:RLHF多了强化学习的过程,lora微调相当于RLHF-Stage1的SFT

参考学习资料:如何看待Geoffrey Hinton对RLHF的看法? - 知乎【科普向】Chat GPT背后的技术:什么是RLHF(人类反馈强化学习)? - 哔哩哔哩

框架

  • DeepspeedChat:暂不支持LLama、chatglm,IDEA的微调https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat
  • Trlx:GitHub - CarperAI/trlx: A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
  • ColossalAI-Chat:暂不支持chatglm,IDEA的微调https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat

三个框架对比介绍:

RLHF几大常用框架实践对比(trlx、deepspeedchat、colossalaichat) - 知乎

实践

本次实践采用ColossalAI框架分步训练(暂不支持TP策略,支持DP策略)

官方训练介绍:https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat#rlhf-training-stage3---training-model-with-reinforcement-learning-by-human-feedback

conda环境:conda activate coati

RLHF Training Stage1 - Supervised instructs tuning

数据准备:https://huggingface.co/datasets/yizhongw/self_instruct/viewer/super_natural_instructions/train

train_sft.sh:执行监督训练shell脚本

CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 train_sft.py \--pretrain "/data/jupyter/LLM/models/llama-7b-hf/" \  #微调训练底模--model 'llama' \--strategy colossalai_zero2 \ #微调策略方法--log_interval 10 \--save_path  /data/jupyter/your_production/ColossalAI/applications/Chat/models/sft-7b \ #保存路径--dataset "yizhongw/self_instruct" \ #huggingface数据集--batch_size 1 \--accumulation_steps 8 \--lr 2e-5 \--max_datasets_size 512 \--max_epochs 1 \--lora_rank 1

ps:

  • 更多参数说明参考

train_sft.py

  • 训练方法:执行

./train_sft.sh

  •  该步训练的坑较少,只要显存足够,一般不会遇到问题。

RLHF Training Stage2 - Training reward model

数据准备:https://huggingface.co/datasets/Anthropic/hh-rlhf/viewer/Anthropic--hh-rlhf/train?row=1

train_rm.sh:执行奖励函数训练脚本

torchrun --standalone --nproc_per_node=1 train_reward_model.py \--pretrain  "/data/jupyter/your_prodcution/ColossalAI/applications/Chat/models/sft-7b" \ #这里是第一步训练保存的模型路径--model 'llama' \--strategy colossalai_gemini \ #训练策略,这里只能该策略,其他策略实测单张3090 24G显存不足--loss_fn 'log_exp'\--save_path /data/jupyter/your_prodcution/ColossalAI/applications/Chat/models/rmstatic.pt \ #保存模型路径,这里仅为模型权重--dataset 'Anthropic/hh-rlhf'\ #huggingface数据集--lora_rank 1 \--batch_size 1 \--max_len 128 

ps:

  • 更多参数说明参考

train_reward_model.py

  • pretrain的模型是第一步训练保存的模型
  • strategy只能执行colossalai_gemini,其他会显存不足
  • max_len设置为128、256可以跑通,但512会出现显存不足

RLHF Training Stage2 - Training reward model

数据准备:

使用generate_prompt_dataset.py对目标数据生成prompt数据(instructions)https://github.com/XueFuzhao/InstructionWild/tree/main/data#instructwild-data

使用步骤一的pretrain dataset(including the instruction and corresponding response)https://huggingface.co/datasets/yizhongw/self_instruct/viewer/super_natural_instructions/train

train_prompts.sh:执行LM微调训练脚本

torchrun --standalone --nproc_per_node=2 train_prompts.py \--pretrain "/data/jupyter/your_production/ColossalAI/applications/Chat/models/sft-7b" \--model 'llama' \--strategy colossalai_gemini \--prompt_dataset /data/jupyter/LLM/datasets/InstructionWild/data1 \--pretrain_dataset /data/jupyter/LLM/datasets/self_instruct \--rm_pretrain /your/pretrain/rm/definition \--rm_path /data/jupyter/your_production/ColossalAI/applications/Chat/models/rmstatic.pt

ps:

  • 因显存不足,该过程暂无法跑通,底层代码多处封装cuda使用,较难使用仅cpu运行
  • rm_pretrain本意应为训练第二步保存的模型结构,但第二步训练保存的是pt文件,无保存模型结构(colossalai_gemini无法执行save_pretrained,原作者也没有这样保存,colossalai_zero2策略可以,但显存不足),所以在第三步作者是分两步完成模型加载

state_dict = torch.load(args.rm_path, map_location='cpu') reward_model = LlamaRM(pretrained=args.rm_pretrain) reward_model.load_state_dict(state_dict)

  • 这里存在有问题:第二步RM保存pt文件是有两层lora训练的,LlamaRM是无lora的,导致加载直接报错,修改为:

reward_model = LlamaRM(pretrained=pretrain, lora_rank=lora_rank)

  • critic加载第二步RM保存pt文件,存在问题,LlamaCritic是三层lora,pt是二层lora导致报错:

_IncompatibleKeys(missing_keys=['value_head.lora_A', 'value_head.lora_B'], unexpected_keys=[])

修改critic.load_state_dict(state_dict, strict=False)可解决;

  • critic的lora加载顺序可能有问题:先加载value_head后convert_to_lora,导致value_head不可训练,该层参数随机化;

self.model = model self.value_head = value_head self.use_action_mask = use_action_mask self.convert_to_lora()

这篇关于RLHF介绍及实践测试的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/523107

相关文章

Oracle查询优化之高效实现仅查询前10条记录的方法与实践

《Oracle查询优化之高效实现仅查询前10条记录的方法与实践》:本文主要介绍Oracle查询优化之高效实现仅查询前10条记录的相关资料,包括使用ROWNUM、ROW_NUMBER()函数、FET... 目录1. 使用 ROWNUM 查询2. 使用 ROW_NUMBER() 函数3. 使用 FETCH FI

Python进阶之Excel基本操作介绍

《Python进阶之Excel基本操作介绍》在现实中,很多工作都需要与数据打交道,Excel作为常用的数据处理工具,一直备受人们的青睐,本文主要为大家介绍了一些Python中Excel的基本操作,希望... 目录概述写入使用 xlwt使用 XlsxWriter读取修改概述在现实中,很多工作都需要与数据打交

在C#中获取端口号与系统信息的高效实践

《在C#中获取端口号与系统信息的高效实践》在现代软件开发中,尤其是系统管理、运维、监控和性能优化等场景中,了解计算机硬件和网络的状态至关重要,C#作为一种广泛应用的编程语言,提供了丰富的API来帮助开... 目录引言1. 获取端口号信息1.1 获取活动的 TCP 和 UDP 连接说明:应用场景:2. 获取硬

Java内存泄漏问题的排查、优化与最佳实践

《Java内存泄漏问题的排查、优化与最佳实践》在Java开发中,内存泄漏是一个常见且令人头疼的问题,内存泄漏指的是程序在运行过程中,已经不再使用的对象没有被及时释放,从而导致内存占用不断增加,最终... 目录引言1. 什么是内存泄漏?常见的内存泄漏情况2. 如何排查 Java 中的内存泄漏?2.1 使用 J

java脚本使用不同版本jdk的说明介绍

《java脚本使用不同版本jdk的说明介绍》本文介绍了在Java中执行JavaScript脚本的几种方式,包括使用ScriptEngine、Nashorn和GraalVM,ScriptEngine适用... 目录Java脚本使用不同版本jdk的说明1.使用ScriptEngine执行javascript2.

Python实现NLP的完整流程介绍

《Python实现NLP的完整流程介绍》这篇文章主要为大家详细介绍了Python实现NLP的完整流程,文中的示例代码讲解详细,具有一定的借鉴价值,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 编程安装和导入必要的库2. 文本数据准备3. 文本预处理3.1 小写化3.2 分词(Tokenizatio

Linux中Curl参数详解实践应用

《Linux中Curl参数详解实践应用》在现代网络开发和运维工作中,curl命令是一个不可或缺的工具,它是一个利用URL语法在命令行下工作的文件传输工具,支持多种协议,如HTTP、HTTPS、FTP等... 目录引言一、基础请求参数1. -X 或 --request2. -d 或 --data3. -H 或

Docker集成CI/CD的项目实践

《Docker集成CI/CD的项目实践》本文主要介绍了Docker集成CI/CD的项目实践,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学... 目录一、引言1.1 什么是 CI/CD?1.2 docker 在 CI/CD 中的作用二、Docke

如何测试计算机的内存是否存在问题? 判断电脑内存故障的多种方法

《如何测试计算机的内存是否存在问题?判断电脑内存故障的多种方法》内存是电脑中非常重要的组件之一,如果内存出现故障,可能会导致电脑出现各种问题,如蓝屏、死机、程序崩溃等,如何判断内存是否出现故障呢?下... 如果你的电脑是崩溃、冻结还是不稳定,那么它的内存可能有问题。要进行检查,你可以使用Windows 11

基于MySQL Binlog的Elasticsearch数据同步实践

一、为什么要做 随着马蜂窝的逐渐发展,我们的业务数据越来越多,单纯使用 MySQL 已经不能满足我们的数据查询需求,例如对于商品、订单等数据的多维度检索。 使用 Elasticsearch 存储业务数据可以很好的解决我们业务中的搜索需求。而数据进行异构存储后,随之而来的就是数据同步的问题。 二、现有方法及问题 对于数据同步,我们目前的解决方案是建立数据中间表。把需要检索的业务数据,统一放到一张M