ETH开源PPO算法学习

2024-02-29 04:04
文章标签 算法 学习 开源 eth ppo

本文主要是介绍ETH开源PPO算法学习,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言

项目地址:https://github.com/leggedrobotics/rsl_rl

项目简介:快速简单的强化学习算法实现,设计为完全在 GPU 上运行。这段代码是 NVIDIA Isaac GYM 提供的 rl-pytorch 的进化版。

下载源码,查看目录,整个项目模块化得非常好,每个部分各司其职。下面我们自底向上地进行讲解加粗的部分。

rsl_rl/
│ __init__.py

├─algorithms/
│ │ __init__.py
│ │ ppo.py # PPO算法的实现
│ │
├─env/
│ │ __init__.py
│ │ vec_env.py # 实现并行处理多个环境的向量化环境
│ │
├─modules/
│ │ __init__.py
│ │ actor_critic.py # 定义 Actor-Critic 网络结构
│ │ actor_critic_recurrent.py # 定义包含循环层的 Actor-Critic 网络
│ │ normalizer.py # 数据正规化工具,有助于训练过程的稳定性
│ │
├─runners/
│ │ __init__.py
│ │ on_policy_runner.py # 实现用于执行 on-policy 算法训练循环的运行器
│ │
├─storage/
│ │ __init__.py
│ │ rollout_storage.py # 存储和管理策略 rollout 数据的工具
│ │
└─utils/
│ __init__.py
│ neptune_utils.py # 用于与 Neptune.ai 集成的工具
│ utils.py # 通用实用工具函数
│ wandb_utils.py # 用于与 Weights & Biases 集成的工具

rollout 数据储存和管理(rollout_storage.py)

定义了一个名为 RolloutStorage 的类,用于存储和管理在强化学习训练过程中从环境中收集到的数据(称为rollouts)。

  • 定义Transition

用于存储单个时间步的所有相关数据,包括观察值、动作、奖励、完成标志(dones)、值函数估计、动作的对数概率、动作的均值和标准差,以及可能的隐藏状态(对于使用循环网络的情况)。

  • 特权观察值(Privileged Observations)

除了self.observations外还有self.privileged_observations的使用,在强化学习中是指那些在训练期间可用但在实际部署或测试时不可用的额外信息。这些信息通常提供了环境的内部状态或其他有助于学习的提示,但在现实世界应用中可能难以获得或完全不可用。在训练期间使用特权观察值的一种常见方法是通过教师-学生架构(我们常常也称作特权学习),其中一个拥有全部信息的教师模型(可以访问特权观察值)来指导一个学生模型(只能访问普通观察值)。学生模型的目标是模仿教师模型的决策,尽管它没有直接访问特权信息。

  • 奖励和优势的计算
    def compute_returns(self, last_values, gamma, lam):advantage = 0for step in reversed(range(self.num_transitions_per_env)):if step == self.num_transitions_per_env - 1:next_values = last_valueselse:next_values = self.values[step + 1]next_is_not_terminal = 1.0 - self.dones[step].float()delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]advantage = delta + next_is_not_terminal * gamma * lam * advantageself.returns[step] = advantage + self.values[step]# Compute and normalize the advantagesself.advantages = self.returns - self.valuesself.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)

这段代码实现的是在强化学习中计算回报(returns)和优势(advantages)的逻辑,具体是使用了一种称为广义优势估算(Generalized Advantage Estimation, GAE)的方法。GAE是一种权衡偏差和方差以及平滑回报信号的技术,由以下几个数学公式定义:

  1. TD残差(Temporal Difference Residual):
    δ t = R t + γ V ( S t + 1 ) ( 1 − d o n e t ) − V ( S t ) \delta_t = R_t + \gamma V(S_{t+1}) (1 - done_t) - V(S_t) δt=Rt+γV(St+1)(1donet)V(St)
    其中, δ t \delta_t δt是时刻 t t t的TD残差, R t R_t Rt是奖励, γ \gamma γ是折扣因子, V ( S t ) V(S_t) V(St)是状态 S t S_t St的价值函数估计, d o n e t done_t donet是表示当前状态是否为终止状态的指示函数(如果当前状态为终止状态,则 d o n e t = 1 done_t = 1 donet=1;否则, d o n e t = 0 done_t = 0 donet=0)。如果 d o n e t = 1 done_t = 1 donet=1,那么 γ V ( S t + 1 ) \gamma V(S_{t+1}) γV(St+1)项将为 0,因为终止状态之后没有未来回报。

  2. GAE优势估计:
    A t G A E ( γ , λ ) = ∑ l = 0 ∞ ( γ λ ) l δ t + l A_t^{GAE(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l} AtGAE(γ,λ)=l=0(γλ)lδt+l
    在代码中,这个无限求和是通过迭代地计算来近似的,具体的迭代公式为:
    A t = δ t + ( γ λ ) A t + 1 ( 1 − d o n e t ) A_t = \delta_t + (\gamma \lambda) A_{t+1} (1 - done_t) At=δt+(γλ)At+1(1donet)
    其中, A t A_t At是时刻 t t t的优势估计, λ \lambda λ是用来平衡TD估计和蒙特卡罗估计之间权重的参数。

  3. 回报的计算:
    G t = A t + V ( S t ) G_t = A_t + V(S_t) Gt=At+V(St)
    其中, G t G_t Gt是时刻 t t t的回报估计。

代码中使用的变量名与数学符号的对应关系:

变量名数学符号含义
rewards[step] R t R_t Rt时刻 t t t的奖励
gamma γ \gamma γ折扣因子,用于计算未来奖励的现值
values[step] V ( S t ) V(S_t) V(St)状态 S t S_t St在当前策略下的价值函数估计
dones[step] d o n e t done_t donet指示当前状态 S t S_t St是否为终止状态的标志(1 表示终止,0 表示非终止)
delta δ t \delta_t δt时刻 t t t的 TD 残差
advantage A t A_t At时刻 t t t的优势估计,根据 GAE 方法计算
lam λ \lambda λ用于 GAE 计算中平衡 TD 估计和蒙特卡罗估计之间权重的参数
returns[step] G t G_t Gt时刻 t t t的回报估计
advantages A t n o r m A_t^{norm} Atnorm标准化后的优势估计
mu_A, sigma_A μ A \mu_A μA, σ A \sigma_A σA优势估计的平均值和标准差
epsilon ϵ \epsilon ϵ避免除零错误而加的小常数,通常取值为 1e-8

代码中的循环从最后一个转换开始向前迭代,使用以上的数学公式来计算每一步的优势和回报。最后,它还对优势进行了标准化处理,即从每个优势中减去所有优势的平均值,并除以标准差,以减少训练期间的方差并加速收敛。标准化公式如下:
A t n o r m = A t − μ A σ A + ϵ A_t^{norm} = \frac{A_t - \mu_A}{\sigma_A + \epsilon} Atnorm=σA+ϵAtμA
其中, μ A \mu_A μA是优势的平均值, σ A \sigma_A σA是优势的标准差, ϵ \epsilon ϵ​ 是为了防止除以零而加的一个小常数(在代码中为 1e-8)。

  • 轨迹的平均长度

类中并没有显式存储轨迹的长度,轨迹长度隐含在self.dones之中。代码中使用的方法是:将每个环境中最后一步置为‘1’,然后flatten(展开)、拼接所有环境中的dones得到flat_dones,差分数组中为‘1’位置的索引得到智能体在每个环境中的步数,即轨迹长度。这个统计量有助于了解训练过程中智能体的表现。

  • mini-batch迭代器

mini_batch_generator 函数通过在多个训练周期(num_epochs)内,从经验回放缓冲区中随机选择小批量数据(包括观察值 observations、动作 actions、奖励 rewards 等)来生成小批量数据集。该函数利用 torch.randperm 生成随机索引 indices 来随机化数据抽样,进而支持基于批处理的学习方法,如梯度下降。通过每次只处理必要的数据量,该生成器在优化模型参数的同时,也优化了内存使用,确保了训练过程的高效性和灵活性。

(未完待续)

在这里插入图片描述

这篇关于ETH开源PPO算法学习的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/757525

相关文章

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

Python中的随机森林算法与实战

《Python中的随机森林算法与实战》本文详细介绍了随机森林算法,包括其原理、实现步骤、分类和回归案例,并讨论了其优点和缺点,通过面向对象编程实现了一个简单的随机森林模型,并应用于鸢尾花分类和波士顿房... 目录1、随机森林算法概述2、随机森林的原理3、实现步骤4、分类案例:使用随机森林预测鸢尾花品种4.1

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

不懂推荐算法也能设计推荐系统

本文以商业化应用推荐为例,告诉我们不懂推荐算法的产品,也能从产品侧出发, 设计出一款不错的推荐系统。 相信很多新手产品,看到算法二字,多是懵圈的。 什么排序算法、最短路径等都是相对传统的算法(注:传统是指科班出身的产品都会接触过)。但对于推荐算法,多数产品对着网上搜到的资源,都会无从下手。特别当某些推荐算法 和 “AI”扯上关系后,更是加大了理解的难度。 但,不了解推荐算法,就无法做推荐系

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

康拓展开(hash算法中会用到)

康拓展开是一个全排列到一个自然数的双射(也就是某个全排列与某个自然数一一对应) 公式: X=a[n]*(n-1)!+a[n-1]*(n-2)!+...+a[i]*(i-1)!+...+a[1]*0! 其中,a[i]为整数,并且0<=a[i]<i,1<=i<=n。(a[i]在不同应用中的含义不同); 典型应用: 计算当前排列在所有由小到大全排列中的顺序,也就是说求当前排列是第

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

综合安防管理平台LntonAIServer视频监控汇聚抖动检测算法优势

LntonAIServer视频质量诊断功能中的抖动检测是一个专门针对视频稳定性进行分析的功能。抖动通常是指视频帧之间的不必要运动,这种运动可能是由于摄像机的移动、传输中的错误或编解码问题导致的。抖动检测对于确保视频内容的平滑性和观看体验至关重要。 优势 1. 提高图像质量 - 清晰度提升:减少抖动,提高图像的清晰度和细节表现力,使得监控画面更加真实可信。 - 细节增强:在低光条件下,抖