SAC算法论文解读

2023-11-23 14:31
文章标签 算法 解读 论文 sac

本文主要是介绍SAC算法论文解读,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

SAC算法

原论文:Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor

SAC算法是off-policy算法,此前的off-policy算法存在采样复杂性高和难收敛的问题,使得超参数十分敏感,SAC算法通过在最大预期return的同时最大化熵值,也就是尽量获得最高累计收益的同时保持探索避免过早掉入局部最优解。SAC结合已有的off-policy模型actor-critic框架使得在不同的随机种子上都能达到SOTA效果。

0 介绍

深度强化学习的快速发展,给机器人控制领域带来了许多进展。此前的工作中,面向连续控制任务的算法有TRPO、PPO、DDPG等算法。

PPO是一种on-policy面向离散和连续控制的算法,在许多数据集上取得了较好的效果,但是存在严重的采样效率低下的问题,这对于真实环境中的控制问题采样花费来说是难以接受的;DDPG是一种off-policy的面向连续控制的问题,比PPO采样效率高但是DDPG训练了一种确定性策略(deterministic policy),在每个状态下只选择一个最优的动作,这样很容易掉入局部最优解的情况。

在连续控制问题中,SAC算法结合已有actor-critic框架,使用随机策略(stochastic policy)最大累计收益的同时也保持熵值最大化,提升了采样效率增强了智能体的探索能力,避免了过早陷入局部最优解的情况,同时也增强了模型在不同初始环境的泛化能力和鲁棒性。

1 预备知识

最大熵强化学习

传统的强化学习是最大化累计回报值:
J ( π ) = ∑ t E ( s t , a t ) ∼ ρ π [ r ( s t , a t ) ] J(\pi)=\sum_t\mathbb{E}_{(s_t,a_t)\sim\rho_\pi}[r(s_t,a_t)] J(π)=tE(st,at)ρπ[r(st,at)]
而最大熵的RL算法的目标函数为:
J ( π ) = ∑ t = 0 T E ( s t , a t ) ∼ ρ π [ r ( s t , a t ) + α H ( π ( ⋅ ∣ s t ) ) ] J(\pi)=\sum^{T}_{t=0}\mathbb{E}_{(s_t,a_t)\sim\rho_\pi[r(s_t,a_t)+\alpha\mathcal{H}(\pi(\cdot|s_t))]} J(π)=t=0TE(st,at)ρπ[r(st,at)+αH(π(st))]
其中 α \alpha α为熵的温度系数超参数,用于调整对熵的重视程度。 H ( π ( ⋅ ∣ s t ) ) \mathcal{H}(\pi(\cdot|s_t)) H(π(st))是熵值,可表示为: H ( π ( ⋅ ∣ s t ) ) = − E s t [ log ⁡ π ( ⋅ ∣ s t ) ] \mathcal{H}(\pi(\cdot|s_t))=-\mathbb{E}_{s_t}[\log \pi(\cdot|s_t)] H(π(st))=Est[logπ(st)]

在累计回报值中加入熵值的目的是使策略随机化(stochastic),在遇到一个state有多个同样优秀的动作时鼓励探索,可以随机从这些动作中选出一个形成trajectory,而不是总选择同一个确定性策略(deterministic)导致模型最终无法学到全局最优解。

2 Soft policy Iteration

在model-free强化学习policy iteration中,常将策略更新过程分为policy evaluation和policy improvement两个阶段。

2.1 Soft policy evaluation

标准的Q function:
Q π ( s , a ) = r ( s , a ) + γ E ( s ′ , a ′ ) ∼ ρ π [ Q ( s ′ , a ′ ) ] Q^\pi(s,a)=r(s,a)+\gamma\mathbb{E}_{(s^\prime,a^\prime)\sim\rho_\pi}[Q(s^\prime,a^\prime)] Qπ(s,a)=r(s,a)+γE(s,a)ρπ[Q(s,a)]
标准的V function:
V π ( s ) = E ( s t , a t ) ∼ ρ π [ Q ( s ′ , a ′ ) ] V^\pi(s)=\mathbb{E}_{(s_t,a_t)\sim\rho_\pi}[Q(s^\prime,a^\prime)] Vπ(s)=E(st,at)ρπ[Q(s,a)]
在标准的方程中引入熵得到Soft Value Function:

Soft Q function:
Q s o f t π ( s , a ) = r ( s , a ) + γ E ( s ′ , a ′ ) ∼ ρ π [ Q ( s ′ , a ′ ) − α log ⁡ ( π ( a ′ ∣ s ′ ) ) ] Q^\pi_{soft}(s,a)=r(s,a)+\gamma\mathbb{E}_{(s^\prime,a^\prime)\sim\rho_\pi}[Q(s^\prime,a^\prime)-\alpha\log(\pi(a^\prime|s^\prime))] Qsoftπ(s,a)=r(s,a)+γE(s,a)ρπ[Q(s,a)αlog(π(as))]
Soft V function:
V s o f t π ( s ′ ) = E ( s ′ , a ′ ) ∼ ρ π [ Q s o f t ( s ′ , a ′ ) − α log ⁡ ( π ( a ′ ∣ s ′ ) ) ] V^\pi_{soft}(s^\prime)=\mathbb{E}_{(s^\prime,a^\prime)\sim\rho_\pi}[Q_{soft}(s^\prime,a^\prime)-\alpha\log(\pi(a^\prime|s^\prime))] Vsoftπ(s)=E(s,a)ρπ[Qsoft(s,a)αlog(π(as))]
由此可得Soft Q和V的Bellman方程:

Q soft π ( s , a ) = r ( s , a ) + γ E ( s ′ , a ′ ) ∼ ρ π [ Q ( s ′ , a ′ ) − α log ⁡ ( π ( a ′ ∣ s ′ ) ) ] = r ( s , a ) + γ E s ′ ∼ ρ [ V soft π ( s ′ ) ] \begin{align*} Q^\pi_{\text{soft}}(s,a) &= r(s,a) + \gamma\mathbb{E}_{(s^\prime,a^\prime)\sim\rho_\pi}[Q(s^\prime,a^\prime)-\alpha\log(\pi(a^\prime|s^\prime))]\\ &= r(s,a) + \gamma\mathbb{E}_{s^\prime\sim\rho}[V^\pi_{\text{soft}}(s^\prime)] \end{align*} Qsoftπ(s,a)=r(s,a)+γE(s,a)ρπ[Q(s,a)αlog(π(as))]=r(s,a)+γEsρ[Vsoftπ(s)]

在固定policy下,使用soft Bellman equation更新Q value直到收敛。

2.2 Soft policy improvement

stochastic policy的重要性:面对多模的(multimodal)的Q function,传统的RL只能收敛到一个选择(左图),而更优的办法是右图,让policy也直接符合Q的分布。

请添加图片描述

为了适应更复杂的任务,MERL中的策略不再是以往的高斯分布形式,而是用基于能量的模型(energy-based model)来表示策略:
π ( a t ∣ s t ) ∝ e x p ( − E ( s t , a t ) ) \pi(a_t|s_t)\propto exp(-\mathcal{E}(s_t,a_t)) π(atst)exp(E(st,at))
为了让EBP和值函数联系起来,设置 E ( s t , a t ) = − 1 α Q s o f t ( s t , a t ) \mathcal{E}(s_t,a_t)=-\frac{1}{\alpha}Q_{soft}(s_t,a_t) E(st,at)=α1Qsoft(st,at),因此 π ( a t ∣ s t ) ∝ e x p ( − 1 α Q s o f t ( s t , a t ) ) \pi(a_t|s_t)\propto exp(-\frac{1}{\alpha}Q_{soft}(s_t,a_t)) π(atst)exp(α1Qsoft(st,at))

由soft v function变形可得:
KaTeX parse error: Expected 'EOF', got '&' at position 14: \pi(s_t,a_t)&̲=&exp(\frac{1}{…
定义softmax(注意此处softmax和神经网络不同,神经网络中的softmax实际上是求分布的最大值soft argmax)
s o f t m a x a f ( a ) : = log ⁡ ∫ e x p f ( a ) d a softmax_af(a):=\log\int expf(a)da softmaxaf(a):=logexpf(a)da
因此 V s o f t ( s t ) = α s o f t m a x a ( 1 α Q s o f t ( s t , a t ) ) V_{soft}(s_t)=\alpha softmax_a(\frac{1}{\alpha}Q_{soft}(s_t,a_t)) Vsoft(st)=αsoftmaxa(α1Qsoft(st,at))

根据Soft Q function可化为softmax形式:
Q s o f t ( s t , a t ) = E [ r t + γ s o f t m a x a Q ( s t + 1 , a t + 1 ) ] Q_{soft}(s_t,a_t)=\mathbb{E}[r_t+\gamma softmax_aQ(s_{t+1},a_{t+1})] Qsoft(st,at)=E[rt+γsoftmaxaQ(st+1,at+1)]
因此整个Policy Iteration流程可总结为:

**soft policy evaluation:**固定policy,使用Bellman方程更新Q值直到收敛
Q s o f t π ( s , a ) = r ( s , a ) + γ E ( s ′ , a ′ ) ∼ ρ π [ Q ( s ′ , a ′ ) − α log ⁡ ( π ( a ′ ∣ s ′ ) ) ] Q^\pi_{soft}(s,a)=r(s,a)+\gamma\mathbb{E}_{(s^\prime,a^\prime)\sim\rho_\pi}[Q(s^\prime,a^\prime)-\alpha\log(\pi(a^\prime|s^\prime))] Qsoftπ(s,a)=r(s,a)+γE(s,a)ρπ[Q(s,a)αlog(π(as))]
**soft policy improvement:**更新policy
π ′ = arg ⁡ min ⁡ π k ∈ ∏ D K L ( π k ( ⋅ ∣ s t ) ∣ ∣ e x p ( 1 α Q s o f t π ( s t , ⋅ ) ) Z s o f t π ( s t ) ) \pi^\prime=\arg\min_{\pi_k\in \prod}D_{KL}(\pi_k(\cdot|s_t)||\frac{exp(\frac{1}{\alpha}Q^\pi_{soft}(s_t,\cdot))}{Z_{soft}^\pi(s_t)}) π=argπkminDKL(πk(st)∣∣Zsoftπ(st)exp(α1Qsoftπ(st,)))

3 Soft Actor-Critic框架

请添加图片描述

SAC算法的构建首先是神经网络化,我们用神经网络来表示Q和Policy: Q θ ( s t , a t ) Q_\theta(s_t,a_t) Qθ(st,at) π ϕ ( a t ∣ s t ) \pi_\phi(a_t|s_t) πϕ(atst)。Q网络比较简单,几层的MLP最后输出一个单值表示Q就可以了,Policy网络需要输出一个分布,一般是输出一个Gaussian包含mean和covariance。下面就是构建神经网络的更新公式。

3.1 Critic

构造两个Q网络,参数通过每次更新Q值小的网络参数,Q网络的损失函数为:
J Q ( θ ) = E ( s t , a t , s t + 1 ) ∼ D [ 1 2 ( Q θ ( s t , a t ) − ( r ( s t , a t ) + γ V θ ˉ ( s t + 1 ) ) ) 2 ] J_Q(\theta)=\mathbb{E}_{(s_t,a_t,s_{t+1})\sim \mathcal{D}}[\frac{1}{2}(Q_\theta(s_t,a_t)-(r(s_t,a_t)+\gamma V_{\bar{\theta}}(s_{t+1})))^2] JQ(θ)=E(st,at,st+1)D[21(Qθ(st,at)(r(st,at)+γVθˉ(st+1)))2]
θ ˉ \bar{\theta} θˉ是target soft Q网络的参数,带入V的迭代表达式:
J Q ( θ ) = E ( s t , a t , s t + 1 ) ∼ D [ 1 2 ( Q θ ( s t , a t ) − ( r ( s t , a t ) + γ ( Q θ ˉ ( s t + 1 , a t + 1 ) − α log ⁡ ( π ( a t + 1 ∣ s t + 1 ) ) ) ) ) 2 ] J_Q(\theta)=\mathbb{E}_{(s_t,a_t,s_{t+1})\sim \mathcal{D}}[\frac{1}{2}(Q_\theta(s_t,a_t)-(r(s_t,a_t)+\gamma (Q_{\bar \theta}(s_{t+1},a_{t+1})-\alpha\log(\pi(a_{t+1}|s_{t+1})))))^2] JQ(θ)=E(st,at,st+1)D[21(Qθ(st,at)(r(st,at)+γ(Qθˉ(st+1,at+1)αlog(π(at+1st+1)))))2]

3.2 Actor

Policy网络的损失函数为:
KaTeX parse error: Expected 'EOF', got '&' at position 13: J_\pi(\phi)&̲=&D_{KL}(\pi_k(…
其中策略网络的输出是一个动作分布,即高斯分布的均值和方差,这里的action采用重参数技巧来获得,即:
a t = f ϕ ( ϵ t ; s t ) = f ϕ μ ( s t ) + ϵ t ⋅ f ϕ μ ( s t ) a_t=f_\phi(\epsilon_t;s_t)=f^\mu_\phi(s_t)+\epsilon_t\cdot f^\mu_\phi(s_t) at=fϕ(ϵt;st)=fϕμ(st)+ϵtfϕμ(st)

3.3 Update temperature

前面的SAC中,我们只是人为给定一个固定的temperature α \alpha α作为entropy的权重,但实际上由于reward的不断变化,采用固定的temperature并不合理,会让整个训练不稳定,因此,有必要能够自动调节这个temperature。当policy探索到新的区域时,最优的action还不清楚,应该调高temperature 去探索更多的空间。当某一个区域已经探索得差不多,最优的action基本确定了,那么这个temperature就可以减小。

通过构造一个带约束的优化问题,让熵权重在不同状态下权重可变,得到权重的loss:
J ( α ) = E a t ∼ π t [ − α log ⁡ π t ( a t ∣ π t ) − α H 0 ] J(\alpha)=\mathbb{E}_{a_t\sim\pi_t}[-\alpha \log \pi_t(a_t|\pi_t)-\alpha\mathcal{H}_0] J(α)=Eatπt[αlogπt(atπt)αH0]
soft actor-critic算法用伪代码可表示为:

请添加图片描述

4 实验

请添加图片描述

在连续控制的benchmark上表现效果比大多数SOTA算法(DDPG、PPO、SQL、TD3)好。

5 总结

基于最大熵的强化学习算法优势:

1)学到policy可以作为更复杂具体任务的初始化。因为通过最大熵,policy不仅仅学到一种解决任务的方法,而是所有all。因此这样的policy就更有利于去学习新的任务。比如我们一开始是学走,然后之后要学朝某一个特定方向走。

2)更强的exploration能力,这是显而易见的,能够更容易的在多模态reward (multimodal reward)下找到更好的模式。比如既要求机器人走的好,又要求机器人节约能源。

3)更robust鲁棒,更强的generalization。因为要从不同的方式来探索各种最优的可能性,也因此面对干扰的时候能够更容易做出调整。(干扰会是神经网络学习过程中看到的一种state,既然已经探索到了,学到了就可以更好的做出反应,继续获取高reward)。

虽然SAC算法采用了energy-based模型,但是实际上策略分布仍为高斯分布,存在一定的局限性。

这篇关于SAC算法论文解读的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/418587

相关文章

MySQL中时区参数time_zone解读

《MySQL中时区参数time_zone解读》MySQL时区参数time_zone用于控制系统函数和字段的DEFAULTCURRENT_TIMESTAMP属性,修改时区可能会影响timestamp类型... 目录前言1.时区参数影响2.如何设置3.字段类型选择总结前言mysql 时区参数 time_zon

Python中的随机森林算法与实战

《Python中的随机森林算法与实战》本文详细介绍了随机森林算法,包括其原理、实现步骤、分类和回归案例,并讨论了其优点和缺点,通过面向对象编程实现了一个简单的随机森林模型,并应用于鸢尾花分类和波士顿房... 目录1、随机森林算法概述2、随机森林的原理3、实现步骤4、分类案例:使用随机森林预测鸢尾花品种4.1

MySQL中的锁和MVCC机制解读

《MySQL中的锁和MVCC机制解读》MySQL事务、锁和MVCC机制是确保数据库操作原子性、一致性和隔离性的关键,事务必须遵循ACID原则,锁的类型包括表级锁、行级锁和意向锁,MVCC通过非锁定读和... 目录mysql的锁和MVCC机制事务的概念与ACID特性锁的类型及其工作机制锁的粒度与性能影响多版本

Redis过期键删除策略解读

《Redis过期键删除策略解读》Redis通过惰性删除策略和定期删除策略来管理过期键,惰性删除策略在键被访问时检查是否过期并删除,节省CPU开销但可能导致过期键滞留,定期删除策略定期扫描并删除过期键,... 目录1.Redis使用两种不同的策略来删除过期键,分别是惰性删除策略和定期删除策略1.1惰性删除策略

Redis与缓存解读

《Redis与缓存解读》文章介绍了Redis作为缓存层的优势和缺点,并分析了六种缓存更新策略,包括超时剔除、先删缓存再更新数据库、旁路缓存、先更新数据库再删缓存、先更新数据库再更新缓存、读写穿透和异步... 目录缓存缓存优缺点缓存更新策略超时剔除先删缓存再更新数据库旁路缓存(先更新数据库,再删缓存)先更新数

C#反射编程之GetConstructor()方法解读

《C#反射编程之GetConstructor()方法解读》C#中Type类的GetConstructor()方法用于获取指定类型的构造函数,该方法有多个重载版本,可以根据不同的参数获取不同特性的构造函... 目录C# GetConstructor()方法有4个重载以GetConstructor(Type[]

不懂推荐算法也能设计推荐系统

本文以商业化应用推荐为例,告诉我们不懂推荐算法的产品,也能从产品侧出发, 设计出一款不错的推荐系统。 相信很多新手产品,看到算法二字,多是懵圈的。 什么排序算法、最短路径等都是相对传统的算法(注:传统是指科班出身的产品都会接触过)。但对于推荐算法,多数产品对着网上搜到的资源,都会无从下手。特别当某些推荐算法 和 “AI”扯上关系后,更是加大了理解的难度。 但,不了解推荐算法,就无法做推荐系

康拓展开(hash算法中会用到)

康拓展开是一个全排列到一个自然数的双射(也就是某个全排列与某个自然数一一对应) 公式: X=a[n]*(n-1)!+a[n-1]*(n-2)!+...+a[i]*(i-1)!+...+a[1]*0! 其中,a[i]为整数,并且0<=a[i]<i,1<=i<=n。(a[i]在不同应用中的含义不同); 典型应用: 计算当前排列在所有由小到大全排列中的顺序,也就是说求当前排列是第

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

综合安防管理平台LntonAIServer视频监控汇聚抖动检测算法优势

LntonAIServer视频质量诊断功能中的抖动检测是一个专门针对视频稳定性进行分析的功能。抖动通常是指视频帧之间的不必要运动,这种运动可能是由于摄像机的移动、传输中的错误或编解码问题导致的。抖动检测对于确保视频内容的平滑性和观看体验至关重要。 优势 1. 提高图像质量 - 清晰度提升:减少抖动,提高图像的清晰度和细节表现力,使得监控画面更加真实可信。 - 细节增强:在低光条件下,抖