DeepSpeed-Chat RLHF 阶段代码解读(0) —— 原始 PPO 代码解读

2024-03-03 12:44

本文主要是介绍DeepSpeed-Chat RLHF 阶段代码解读(0) —— 原始 PPO 代码解读,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

为了理解 DeepSpeed-Chat RLHF 的 RLHF 全部过程,这个系列会分三篇文章分别介绍:
原始 PPO 代码解读RLHF 奖励函数代码解读RLHF PPO 代码解读
这是系列的第一篇文章,我们来一步一步的看 PPO 算法的代码实现,对于 PPO 算法原理不太了解的同学,可以参考之前的文章:
深度强化学习(DRL)算法 2 —— PPO 之 Clipped Surrogate Objective 篇
深度强化学习(DRL)算法 2 —— PPO 之 GAE 篇

Clipped Surrogate 函数实现

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py
for start in range(0, args.batch_size, args.minibatch_size):end = start + args.minibatch_sizemb_inds = b_inds[start:end]_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])logratio = newlogprob - b_logprobs[mb_inds]ratio = logratio.exp()mb_advantages = b_advantages[mb_inds]if args.norm_adv:mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)# Policy losspg_loss1 = -mb_advantages * ratiopg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)pg_loss = torch.max(pg_loss1, pg_loss2).mean()

Clipped Surrogate 函数的实现很简单,这里不再赘述,理解算法原理,代码自然而然就可以看懂,核心是 get_action_and_value 函数的理解。

def get_action_and_value(self, x, action=None):logits = self.actor(x)# probs 相当于计算 softmaxprobs = Categorical(logits=logits)if action is None:action = probs.sample()# probs.log_prob(action) 计算的是 p(a|s) 的 log 形式,方便计算 Clipped Surrogate 函数里的 ratioreturn action, probs.log_prob(action), probs.entropy(), self.critic(x) 

GAE 实现

直接来看 gae 可能比较抽象,我们先来看蒙特卡洛方法实现的优势估计,对蒙特卡洛方法不熟悉的同学,可以参考之前的文章。
深度强化学习(DRL)算法 附录 3 —— 蒙特卡洛方法(MC)和时序差分(TD)
两种方法都采用了反向迭代(因为反向迭代更好计算)的方式来实现优势估计。

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)returns = torch.zeros_like(rewards).to(device)
for t in reversed(range(args.num_steps)):if t == args.num_steps - 1:nextnonterminal = 1.0 - next_donenext_return = last_valueelse:nextnonterminal = 1.0 - dones[t+1]next_return = returns[t+1]returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
advantages = returns - values

上面的代码做了什么事情呢,last_value 对应最后的 step(对应 step t) 产生的期望回报,如果 step t-1 整个流程没有结束,那么 t-1 时刻的期望回报就是 reward(t-1) + args.gamma * nextnonterminal * next_return,这样一步一步往后推,就可以计算每一个 step 的期望回报,从而得到每一步的优势,还没理解的话,看下面每个时间步的拆解。关于 last_value 的使用,这里由于没有后续的回报可以累积,所以直接使用 last_value 作为最后一个时间步的回报。关于下面为啥用 return[t-1] 替换原始公式的 value[t-1],这样计算的话就相当于蒙特卡洛方法的优势估计,如果next_return = returns[t+1] 改成 next_value = values[t+1] 就相当于 TD(1) 的优势估计。

# t
return(t) = v(t)
# t - 1
return(t-1) = reward(t-1) + gamma * return(t) = reward(t-1) + gamma * return(t)
# t - 2
return(t-2) = reward(t-2) + gamma * return(t-1) = reward(t-2) + gamma * (reward(t-1) + gamma * return(t))
......
# 我们可以看到一步一步往前推,最后就得到蒙特卡洛方法的优势估计

理解了上面讲的蒙特卡洛方法实现的优势估计,再来看 gae 的实现,我们可以看到代码实现上十分的相似,只是多了 delta 的计算,这里的 delta 对应的就是之前 PPO GAE 篇里介绍的 delta。

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
for t in reversed(range(args.num_steps)):if t == args.num_steps - 1:nextnonterminal = 1.0 - next_donenextvalues = last_valueelse:nextnonterminal = 1.0 - dones[t+1]nextvalues = values[t+1]delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values

这里通过反向迭代的方式计算 GAE advantage,可能理解上比较抽象,举个例子,就很好理解了:

# advantage(t)
adv[t] = lastgaelam = rewards[t] + gamma * values[t+1] - values[t]
# t-1
adv[t-1] = lastgaelam = rewards[t-1] + gamma * values[t] - values[t-1] + gamma * lambda * lastgaelam
# t-2
adv[t-2] = lastgaelam = rewards[t-2] + gamma * values[t-1] - values[t-2] + gamma * lambda * lastgaelam
...

可以看到,逐项展开,每一时刻的 GAE Advantage 和 PPO GAE 篇里介绍的公式是一模一样的,这里 GAE 就是一种数学公式表达,核心思想是 n step 的优势估计的加权平均,通过数学技巧恰好是上面的形式。

参考

  1. The 37 Implementation Details of Proximal Policy Optimization · The ICLR Blog Track (iclr-blog-track.github.io)
  2. HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION

这篇关于DeepSpeed-Chat RLHF 阶段代码解读(0) —— 原始 PPO 代码解读的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/769650

相关文章

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

MCU7.keil中build产生的hex文件解读

1.hex文件大致解读 闲来无事,查看了MCU6.用keil新建项目的hex文件 用FlexHex打开 给我的第一印象是:经过软件的解释之后,发现这些数据排列地十分整齐 :02000F0080FE71:03000000020003F8:0C000300787FE4F6D8FD75810702000F3D:00000001FF 把解释后的数据当作十六进制来观察 1.每一行数据

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

计算机毕业设计 大学志愿填报系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点赞 👍 收藏 ⭐评论 📝 🍅 文末获取源码联系 👇🏻 精彩专栏推荐订阅 👇🏻 不然下次找不到哟~Java毕业设计项目~热门选题推荐《1000套》 目录 1.技术选型 2.开发工具 3.功能

代码随想录冲冲冲 Day39 动态规划Part7

198. 打家劫舍 dp数组的意义是在第i位的时候偷的最大钱数是多少 如果nums的size为0 总价值当然就是0 如果nums的size为1 总价值是nums[0] 遍历顺序就是从小到大遍历 之后是递推公式 对于dp[i]的最大价值来说有两种可能 1.偷第i个 那么最大价值就是dp[i-2]+nums[i] 2.不偷第i个 那么价值就是dp[i-1] 之后取这两个的最大值就是d

pip-tools:打造可重复、可控的 Python 开发环境,解决依赖关系,让代码更稳定

在 Python 开发中,管理依赖关系是一项繁琐且容易出错的任务。手动更新依赖版本、处理冲突、确保一致性等等,都可能让开发者感到头疼。而 pip-tools 为开发者提供了一套稳定可靠的解决方案。 什么是 pip-tools? pip-tools 是一组命令行工具,旨在简化 Python 依赖关系的管理,确保项目环境的稳定性和可重复性。它主要包含两个核心工具:pip-compile 和 pip

D4代码AC集

贪心问题解决的步骤: (局部贪心能导致全局贪心)    1.确定贪心策略    2.验证贪心策略是否正确 排队接水 #include<bits/stdc++.h>using namespace std;int main(){int w,n,a[32000];cin>>w>>n;for(int i=1;i<=n;i++){cin>>a[i];}sort(a+1,a+n+1);int i=1

Spring 源码解读:自定义实现Bean定义的注册与解析

引言 在Spring框架中,Bean的注册与解析是整个依赖注入流程的核心步骤。通过Bean定义,Spring容器知道如何创建、配置和管理每个Bean实例。本篇文章将通过实现一个简化版的Bean定义注册与解析机制,帮助你理解Spring框架背后的设计逻辑。我们还将对比Spring中的BeanDefinition和BeanDefinitionRegistry,以全面掌握Bean注册和解析的核心原理。

GPT系列之:GPT-1,GPT-2,GPT-3详细解读

一、GPT1 论文:Improving Language Understanding by Generative Pre-Training 链接:https://cdn.openai.com/research-covers/languageunsupervised/language_understanding_paper.pdf 启发点:生成loss和微调loss同时作用,让下游任务来适应预训