DeepSpeed-Chat RLHF 阶段代码解读(0) —— 原始 PPO 代码解读

2024-03-03 12:44

本文主要是介绍DeepSpeed-Chat RLHF 阶段代码解读(0) —— 原始 PPO 代码解读,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

为了理解 DeepSpeed-Chat RLHF 的 RLHF 全部过程,这个系列会分三篇文章分别介绍:
原始 PPO 代码解读RLHF 奖励函数代码解读RLHF PPO 代码解读
这是系列的第一篇文章,我们来一步一步的看 PPO 算法的代码实现,对于 PPO 算法原理不太了解的同学,可以参考之前的文章:
深度强化学习(DRL)算法 2 —— PPO 之 Clipped Surrogate Objective 篇
深度强化学习(DRL)算法 2 —— PPO 之 GAE 篇

Clipped Surrogate 函数实现

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py
for start in range(0, args.batch_size, args.minibatch_size):end = start + args.minibatch_sizemb_inds = b_inds[start:end]_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])logratio = newlogprob - b_logprobs[mb_inds]ratio = logratio.exp()mb_advantages = b_advantages[mb_inds]if args.norm_adv:mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)# Policy losspg_loss1 = -mb_advantages * ratiopg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)pg_loss = torch.max(pg_loss1, pg_loss2).mean()

Clipped Surrogate 函数的实现很简单,这里不再赘述,理解算法原理,代码自然而然就可以看懂,核心是 get_action_and_value 函数的理解。

def get_action_and_value(self, x, action=None):logits = self.actor(x)# probs 相当于计算 softmaxprobs = Categorical(logits=logits)if action is None:action = probs.sample()# probs.log_prob(action) 计算的是 p(a|s) 的 log 形式,方便计算 Clipped Surrogate 函数里的 ratioreturn action, probs.log_prob(action), probs.entropy(), self.critic(x) 

GAE 实现

直接来看 gae 可能比较抽象,我们先来看蒙特卡洛方法实现的优势估计,对蒙特卡洛方法不熟悉的同学,可以参考之前的文章。
深度强化学习(DRL)算法 附录 3 —— 蒙特卡洛方法(MC)和时序差分(TD)
两种方法都采用了反向迭代(因为反向迭代更好计算)的方式来实现优势估计。

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)returns = torch.zeros_like(rewards).to(device)
for t in reversed(range(args.num_steps)):if t == args.num_steps - 1:nextnonterminal = 1.0 - next_donenext_return = last_valueelse:nextnonterminal = 1.0 - dones[t+1]next_return = returns[t+1]returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
advantages = returns - values

上面的代码做了什么事情呢,last_value 对应最后的 step(对应 step t) 产生的期望回报,如果 step t-1 整个流程没有结束,那么 t-1 时刻的期望回报就是 reward(t-1) + args.gamma * nextnonterminal * next_return,这样一步一步往后推,就可以计算每一个 step 的期望回报,从而得到每一步的优势,还没理解的话,看下面每个时间步的拆解。关于 last_value 的使用,这里由于没有后续的回报可以累积,所以直接使用 last_value 作为最后一个时间步的回报。关于下面为啥用 return[t-1] 替换原始公式的 value[t-1],这样计算的话就相当于蒙特卡洛方法的优势估计,如果next_return = returns[t+1] 改成 next_value = values[t+1] 就相当于 TD(1) 的优势估计。

# t
return(t) = v(t)
# t - 1
return(t-1) = reward(t-1) + gamma * return(t) = reward(t-1) + gamma * return(t)
# t - 2
return(t-2) = reward(t-2) + gamma * return(t-1) = reward(t-2) + gamma * (reward(t-1) + gamma * return(t))
......
# 我们可以看到一步一步往前推,最后就得到蒙特卡洛方法的优势估计

理解了上面讲的蒙特卡洛方法实现的优势估计,再来看 gae 的实现,我们可以看到代码实现上十分的相似,只是多了 delta 的计算,这里的 delta 对应的就是之前 PPO GAE 篇里介绍的 delta。

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
for t in reversed(range(args.num_steps)):if t == args.num_steps - 1:nextnonterminal = 1.0 - next_donenextvalues = last_valueelse:nextnonterminal = 1.0 - dones[t+1]nextvalues = values[t+1]delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values

这里通过反向迭代的方式计算 GAE advantage,可能理解上比较抽象,举个例子,就很好理解了:

# advantage(t)
adv[t] = lastgaelam = rewards[t] + gamma * values[t+1] - values[t]
# t-1
adv[t-1] = lastgaelam = rewards[t-1] + gamma * values[t] - values[t-1] + gamma * lambda * lastgaelam
# t-2
adv[t-2] = lastgaelam = rewards[t-2] + gamma * values[t-1] - values[t-2] + gamma * lambda * lastgaelam
...

可以看到,逐项展开,每一时刻的 GAE Advantage 和 PPO GAE 篇里介绍的公式是一模一样的,这里 GAE 就是一种数学公式表达,核心思想是 n step 的优势估计的加权平均,通过数学技巧恰好是上面的形式。

参考

  1. The 37 Implementation Details of Proximal Policy Optimization · The ICLR Blog Track (iclr-blog-track.github.io)
  2. HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION

这篇关于DeepSpeed-Chat RLHF 阶段代码解读(0) —— 原始 PPO 代码解读的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/769650

相关文章

利用Python调试串口的示例代码

《利用Python调试串口的示例代码》在嵌入式开发、物联网设备调试过程中,串口通信是最基础的调试手段本文将带你用Python+ttkbootstrap打造一款高颜值、多功能的串口调试助手,需要的可以了... 目录概述:为什么需要专业的串口调试工具项目架构设计1.1 技术栈选型1.2 关键类说明1.3 线程模

Python Transformers库(NLP处理库)案例代码讲解

《PythonTransformers库(NLP处理库)案例代码讲解》本文介绍transformers库的全面讲解,包含基础知识、高级用法、案例代码及学习路径,内容经过组织,适合不同阶段的学习者,对... 目录一、基础知识1. Transformers 库简介2. 安装与环境配置3. 快速上手示例二、核心模

Java的栈与队列实现代码解析

《Java的栈与队列实现代码解析》栈是常见的线性数据结构,栈的特点是以先进后出的形式,后进先出,先进后出,分为栈底和栈顶,栈应用于内存的分配,表达式求值,存储临时的数据和方法的调用等,本文给大家介绍J... 目录栈的概念(Stack)栈的实现代码队列(Queue)模拟实现队列(双链表实现)循环队列(循环数组

Mysql用户授权(GRANT)语法及示例解读

《Mysql用户授权(GRANT)语法及示例解读》:本文主要介绍Mysql用户授权(GRANT)语法及示例,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录mysql用户授权(GRANT)语法授予用户权限语法GRANT语句中的<权限类型>的使用WITH GRANT

使用Java将DOCX文档解析为Markdown文档的代码实现

《使用Java将DOCX文档解析为Markdown文档的代码实现》在现代文档处理中,Markdown(MD)因其简洁的语法和良好的可读性,逐渐成为开发者、技术写作者和内容创作者的首选格式,然而,许多文... 目录引言1. 工具和库介绍2. 安装依赖库3. 使用Apache POI解析DOCX文档4. 将解析

C++使用printf语句实现进制转换的示例代码

《C++使用printf语句实现进制转换的示例代码》在C语言中,printf函数可以直接实现部分进制转换功能,通过格式说明符(formatspecifier)快速输出不同进制的数值,下面给大家分享C+... 目录一、printf 原生支持的进制转换1. 十进制、八进制、十六进制转换2. 显示进制前缀3. 指

python3 gunicorn配置文件的用法解读

《python3gunicorn配置文件的用法解读》:本文主要介绍python3gunicorn配置文件的使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录python3 gunicorn配置文件配置文件服务启动、重启、关闭启动重启关闭总结python3 gun

关于pandas的read_csv方法使用解读

《关于pandas的read_csv方法使用解读》:本文主要介绍关于pandas的read_csv方法使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录pandas的read_csv方法解读read_csv中的参数基本参数通用解析参数空值处理相关参数时间处理相关

使用Python实现全能手机虚拟键盘的示例代码

《使用Python实现全能手机虚拟键盘的示例代码》在数字化办公时代,你是否遇到过这样的场景:会议室投影电脑突然键盘失灵、躺在沙发上想远程控制书房电脑、或者需要给长辈远程协助操作?今天我要分享的Pyth... 目录一、项目概述:不止于键盘的远程控制方案1.1 创新价值1.2 技术栈全景二、需求实现步骤一、需求

Java中Date、LocalDate、LocalDateTime、LocalTime、时间戳之间的相互转换代码

《Java中Date、LocalDate、LocalDateTime、LocalTime、时间戳之间的相互转换代码》:本文主要介绍Java中日期时间转换的多种方法,包括将Date转换为LocalD... 目录一、Date转LocalDateTime二、Date转LocalDate三、LocalDateTim