DeepSpeed-Chat RLHF 阶段代码解读(0) —— 原始 PPO 代码解读

2024-03-03 12:44

本文主要是介绍DeepSpeed-Chat RLHF 阶段代码解读(0) —— 原始 PPO 代码解读,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

为了理解 DeepSpeed-Chat RLHF 的 RLHF 全部过程,这个系列会分三篇文章分别介绍:
原始 PPO 代码解读RLHF 奖励函数代码解读RLHF PPO 代码解读
这是系列的第一篇文章,我们来一步一步的看 PPO 算法的代码实现,对于 PPO 算法原理不太了解的同学,可以参考之前的文章:
深度强化学习(DRL)算法 2 —— PPO 之 Clipped Surrogate Objective 篇
深度强化学习(DRL)算法 2 —— PPO 之 GAE 篇

Clipped Surrogate 函数实现

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py
for start in range(0, args.batch_size, args.minibatch_size):end = start + args.minibatch_sizemb_inds = b_inds[start:end]_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])logratio = newlogprob - b_logprobs[mb_inds]ratio = logratio.exp()mb_advantages = b_advantages[mb_inds]if args.norm_adv:mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)# Policy losspg_loss1 = -mb_advantages * ratiopg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)pg_loss = torch.max(pg_loss1, pg_loss2).mean()

Clipped Surrogate 函数的实现很简单,这里不再赘述,理解算法原理,代码自然而然就可以看懂,核心是 get_action_and_value 函数的理解。

def get_action_and_value(self, x, action=None):logits = self.actor(x)# probs 相当于计算 softmaxprobs = Categorical(logits=logits)if action is None:action = probs.sample()# probs.log_prob(action) 计算的是 p(a|s) 的 log 形式,方便计算 Clipped Surrogate 函数里的 ratioreturn action, probs.log_prob(action), probs.entropy(), self.critic(x) 

GAE 实现

直接来看 gae 可能比较抽象,我们先来看蒙特卡洛方法实现的优势估计,对蒙特卡洛方法不熟悉的同学,可以参考之前的文章。
深度强化学习(DRL)算法 附录 3 —— 蒙特卡洛方法(MC)和时序差分(TD)
两种方法都采用了反向迭代(因为反向迭代更好计算)的方式来实现优势估计。

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)returns = torch.zeros_like(rewards).to(device)
for t in reversed(range(args.num_steps)):if t == args.num_steps - 1:nextnonterminal = 1.0 - next_donenext_return = last_valueelse:nextnonterminal = 1.0 - dones[t+1]next_return = returns[t+1]returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
advantages = returns - values

上面的代码做了什么事情呢,last_value 对应最后的 step(对应 step t) 产生的期望回报,如果 step t-1 整个流程没有结束,那么 t-1 时刻的期望回报就是 reward(t-1) + args.gamma * nextnonterminal * next_return,这样一步一步往后推,就可以计算每一个 step 的期望回报,从而得到每一步的优势,还没理解的话,看下面每个时间步的拆解。关于 last_value 的使用,这里由于没有后续的回报可以累积,所以直接使用 last_value 作为最后一个时间步的回报。关于下面为啥用 return[t-1] 替换原始公式的 value[t-1],这样计算的话就相当于蒙特卡洛方法的优势估计,如果next_return = returns[t+1] 改成 next_value = values[t+1] 就相当于 TD(1) 的优势估计。

# t
return(t) = v(t)
# t - 1
return(t-1) = reward(t-1) + gamma * return(t) = reward(t-1) + gamma * return(t)
# t - 2
return(t-2) = reward(t-2) + gamma * return(t-1) = reward(t-2) + gamma * (reward(t-1) + gamma * return(t))
......
# 我们可以看到一步一步往前推,最后就得到蒙特卡洛方法的优势估计

理解了上面讲的蒙特卡洛方法实现的优势估计,再来看 gae 的实现,我们可以看到代码实现上十分的相似,只是多了 delta 的计算,这里的 delta 对应的就是之前 PPO GAE 篇里介绍的 delta。

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
for t in reversed(range(args.num_steps)):if t == args.num_steps - 1:nextnonterminal = 1.0 - next_donenextvalues = last_valueelse:nextnonterminal = 1.0 - dones[t+1]nextvalues = values[t+1]delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values

这里通过反向迭代的方式计算 GAE advantage,可能理解上比较抽象,举个例子,就很好理解了:

# advantage(t)
adv[t] = lastgaelam = rewards[t] + gamma * values[t+1] - values[t]
# t-1
adv[t-1] = lastgaelam = rewards[t-1] + gamma * values[t] - values[t-1] + gamma * lambda * lastgaelam
# t-2
adv[t-2] = lastgaelam = rewards[t-2] + gamma * values[t-1] - values[t-2] + gamma * lambda * lastgaelam
...

可以看到,逐项展开,每一时刻的 GAE Advantage 和 PPO GAE 篇里介绍的公式是一模一样的,这里 GAE 就是一种数学公式表达,核心思想是 n step 的优势估计的加权平均,通过数学技巧恰好是上面的形式。

参考

  1. The 37 Implementation Details of Proximal Policy Optimization · The ICLR Blog Track (iclr-blog-track.github.io)
  2. HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION

这篇关于DeepSpeed-Chat RLHF 阶段代码解读(0) —— 原始 PPO 代码解读的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/769650

相关文章

Java调用DeepSeek API的最佳实践及详细代码示例

《Java调用DeepSeekAPI的最佳实践及详细代码示例》:本文主要介绍如何使用Java调用DeepSeekAPI,包括获取API密钥、添加HTTP客户端依赖、创建HTTP请求、处理响应、... 目录1. 获取API密钥2. 添加HTTP客户端依赖3. 创建HTTP请求4. 处理响应5. 错误处理6.

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

Python中顺序结构和循环结构示例代码

《Python中顺序结构和循环结构示例代码》:本文主要介绍Python中的条件语句和循环语句,条件语句用于根据条件执行不同的代码块,循环语句用于重复执行一段代码,文章还详细说明了range函数的使... 目录一、条件语句(1)条件语句的定义(2)条件语句的语法(a)单分支 if(b)双分支 if-else(

MySQL数据库函数之JSON_EXTRACT示例代码

《MySQL数据库函数之JSON_EXTRACT示例代码》:本文主要介绍MySQL数据库函数之JSON_EXTRACT的相关资料,JSON_EXTRACT()函数用于从JSON文档中提取值,支持对... 目录前言基本语法路径表达式示例示例 1: 提取简单值示例 2: 提取嵌套值示例 3: 提取数组中的值注意

CSS3中使用flex和grid实现等高元素布局的示例代码

《CSS3中使用flex和grid实现等高元素布局的示例代码》:本文主要介绍了使用CSS3中的Flexbox和Grid布局实现等高元素布局的方法,通过简单的两列实现、每行放置3列以及全部代码的展示,展示了这两种布局方式的实现细节和效果,详细内容请阅读本文,希望能对你有所帮助... 过往的实现方法是使用浮动加

JAVA调用Deepseek的api完成基本对话简单代码示例

《JAVA调用Deepseek的api完成基本对话简单代码示例》:本文主要介绍JAVA调用Deepseek的api完成基本对话的相关资料,文中详细讲解了如何获取DeepSeekAPI密钥、添加H... 获取API密钥首先,从DeepSeek平台获取API密钥,用于身份验证。添加HTTP客户端依赖使用Jav

Java实现状态模式的示例代码

《Java实现状态模式的示例代码》状态模式是一种行为型设计模式,允许对象根据其内部状态改变行为,本文主要介绍了Java实现状态模式的示例代码,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来... 目录一、简介1、定义2、状态模式的结构二、Java实现案例1、电灯开关状态案例2、番茄工作法状态案例

nginx-rtmp-module模块实现视频点播的示例代码

《nginx-rtmp-module模块实现视频点播的示例代码》本文主要介绍了nginx-rtmp-module模块实现视频点播,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习... 目录预置条件Nginx点播基本配置点播远程文件指定多个播放位置参考预置条件配置点播服务器 192.

MySQL中的MVCC底层原理解读

《MySQL中的MVCC底层原理解读》本文详细介绍了MySQL中的多版本并发控制(MVCC)机制,包括版本链、ReadView以及在不同事务隔离级别下MVCC的工作原理,通过一个具体的示例演示了在可重... 目录简介ReadView版本链演示过程总结简介MVCC(Multi-Version Concurr

关于Gateway路由匹配规则解读

《关于Gateway路由匹配规则解读》本文详细介绍了SpringCloudGateway的路由匹配规则,包括基本概念、常用属性、实际应用以及注意事项,路由匹配规则决定了请求如何被转发到目标服务,是Ga... 目录Gateway路由匹配规则一、基本概念二、常用属性三、实际应用四、注意事项总结Gateway路由