强化学习之图解SAC算法

2023-11-23 14:31
文章标签 算法 学习 图解 强化 sac

本文主要是介绍强化学习之图解SAC算法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

强化学习之图解SAC算法

  • 1. 网络结构
  • 2. 产生experience的过程
  • 3. Q Critic网络的更新流程
  • 4. V Critic网络的更新流程
  • 5. Actor网络的更新流程

柔性动作-评价(Soft Actor-Critic,SAC)算法的网络结构有5个。SAC算法解决的问题是离散动作空间和连续动作空间的强化学习问题,是off-policy的强化学习算法(关于on-policy和off-policy的讨论可见:强化学习之图解PPO算法和TD3算法)。

SAC的论文有两篇,一篇是《Soft Actor-Critic Algorithms and Applications》,2019年1月发表,其中SAC算法流程如下所示,它包括1个actor网络,4个Q Critic网络:

在这里插入图片描述

一篇是《Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor》,2018年8月发表,其中SAC算法流程如下所示,它包括1个actor网络,2个V Critic网络(1个V Critic网络,1个Target V Critic网络),2个Q Critic网络:

在这里插入图片描述

本文介绍的算法思路是1个actor网络,2个V Critic网络(1个V Critic网络,1个Target V Critic网络),2个Q Critic网络。而另一种SAC算法思路可以参考openAI的spinning up教程:openAI spinning up

1. 网络结构

关于SAC算法的网络结构图解,笔者认为此链接的讲解也非常地好:Soft Actor-Critic,本文和此链接的说法一致。

在这里插入图片描述

一个actor网络,四个critic网络,分别是状态价值估计 v v v和Target v v v网络;动作-状态价值估计 Q 0 Q_0 Q0 Q 1 Q_1 Q1网络。

actor网络的输入为状态,输出为动作概率 π ( a t ∣ s t ) \pi(a_t|s_t) π(atst)(对于离散动作空间而言)或者动作概率分布参数(对于连续动作空间而言)

critic网络的输入为状态,输出为状态的价值。其中 V Critic 网络的输出为 v ( s ) v(s) v(s),代表状态价值的估计Q Critic 网络的输出为 q ( s , a ) q(s,a) q(s,a),代表动作-状态对价值(以下简称为动作价值的估计

因为在SAC算法中为了鼓励探索,增加了熵的概念,所以它actor和critic网络的训练目标和常规不含熵的算法(如TD3,PPO)的训练目标不一样。

在SAC算法中,如果actor网络输出的动作越能够使一个综合指标(既包含动作价值 q q q,又包含熵 h h h)变大,那么就越好。

如果Q critic网络输出的动作价值 q q q越准确(根据贝尔曼方程可知, q q q是否准确依赖于 v v v是否准确),那么就越好。

如果V critic网络输出的状态价值 v v v越准确,那么就越好。但需要注意的是,因为SAC中加了熵的概念,所以状态价值 v v v并不是我们通常理解的 v ( s ) v(s) v(s),它其中还加了熵这一项。

接下来只说SAC的算法流程,而不对其中的公式做过多的解释,具体SAC算法的推导过程可以参考《最前沿:深度解读Soft Actor-Critic 算法》。

2. 产生experience的过程

已知一个状态 s t s_t st,通过 actor网络 得到所有动作的概率 π ( a ∣ s t ) \pi(a|s_t) π(ast)(图中以三个动作: a 1 , a 2 , a 3 a_1,a_2,a_3 a1,a2,a3为例),然后依概率采样得到动作 a t = a 2 a_t=a_2 at=a2,然后将 a 2 a_2 a2输入到环境中,得到 s t + 1 s_{t+1} st+1 r t + 1 r_{t+1} rt+1,这样就得到一个experience: ( s t , a 2 , s t + 1 , r t + 1 ) (s_t, a_2, s_{t+1}, r_{t+1}) (st,a2,st+1,rt+1),然后将experience放入经验池中。

以上是离散动作的情况,如果是连续动作,就输出概率分布的参数(比如高斯分布的均值和方差),然后按照概率分布去采样得到动作 a t a_t at.

经验池 存在的意义是为了消除experience的相关性,因为强化学习中前后动作通常是强相关的,而将它们打散,放入经验池中,然后在训练神经网络时,随机地从经验池中选出一批experience,这样能够使神经网络训练地更好。

在这里插入图片描述

3. Q Critic网络的更新流程

在这里插入图片描述

拿从经验池buffer中采出的数据 ( s t , a t , s t + 1 , r t + 1 ) (s_t, a_t, s_{t+1}, r_{t+1}) (st,at,st+1,rt+1)进行Critic网络的更新,以 ( s t , a 2 , s t + 1 , r t + 1 ) (s_t, a_2, s_{t+1}, r_{t+1}) (st,a2,st+1,rt+1)为例。

基于最优贝尔曼方程,用 U t ( q ) = r t + γ v ( s t + 1 ) U_{t}^{\left( q \right)}=r_t+\gamma v(s_{t+1}) Ut(q)=rt+γv(st+1)作为状态 s t s_t st真实价值估计,而用实际采用的动作 a 2 a_2 a2 q i ( s t , a 2 ) q_i(s_t,a_2) qi(st,a2) ( 其 中 , i = 0 , 1 ) (其中,i=0,1) (i=0,1)作为状态 s t s_t st预测价值估计,最后用MSEloss作为Loss函数,对神经网络 Q 0 Q_0 Q0, Q 1 Q_1 Q1进行训练。

注意取MSELoss就意味着对从经验池buffer中取一个batch的数据进行了求平均的操作,即:

L o s s = 1 ∣ B ∣ ∑ ( s t , a t , r t + 1 , s t + 1 ) ∈ B [ q i ( s t , a t ; w ( i ) ) − U t ( q ) ] 2 Loss=\frac{1}{|\mathcal{B}|}\sum_{\left( s_t,a_t,r_{t+1},s_{t+1} \right) \in \mathcal{B}}{\left[ q_i\left( s_t,a_t;w^{\left( i \right)} \right) -U_{t}^{\left( q \right)} \right] ^2} Loss=B1(st,at,rt+1,st+1)B[qi(st,at;w(i))Ut(q)]2

pytorch代码如下:

        # train Q criticnext_v_tensor = self.v_target_net(next_state_tensor)q_target_tensor = reward_tensor.unsqueeze(1) + self.gamma *                 (1. - done_tensor.unsqueeze(1)) * next_v_tensorall_q0_pred_tensor = self.q0_net(state_tensor)q0_pred_tensor = torch.gather(all_q0_pred_tensor, 1, action_tensor.unsqueeze(1))q0_loss_tensor = self.q0_loss(q0_pred_tensor, q_target_tensor.detach())self.q0_optimizer.zero_grad()q0_loss_tensor.backward()self.q0_optimizer.step()all_q1_pred_tensor = self.q1_net(state_tensor)q1_pred_tensor = torch.gather(all_q1_pred_tensor, 1, action_tensor.unsqueeze(1))q1_loss_tensor = self.q1_loss(q1_pred_tensor, q_target_tensor.detach())self.q1_optimizer.zero_grad()q1_loss_tensor.backward()self.q1_optimizer.step()

4. V Critic网络的更新流程

在这里插入图片描述

拿从经验池buffer中采出的数据 ( s t , a t , s t + 1 , r t + 1 ) (s_t, a_t, s_{t+1}, r_{t+1}) (st,at,st+1,rt+1)进行V Critic网络的更新,接着 ( s t , a 2 , s t + 1 , r t + 1 ) (s_t, a_2, s_{t+1}, r_{t+1}) (st,a2,st+1,rt+1)的例子。

用含熵的式子进行状态价值估计,即下式作为V critic网络输出的真实值:

U t ( v ) = E a t ′ ∼ π ( ⋅ ∣ s t ; θ ) [ min ⁡ i = 0 , 1 q i ( s t , a t ′ ; w ( i ) ) − α ln ⁡ π ( a t ′ ∣ s t ; θ ) ] = ∑ a t ′ ∈ A ( s t ) π ( a t ′ ∣ s t ; θ ) [ min ⁡ i = 0 , 1 q i ( s t , a t ′ ; w ( i ) ) − α ln ⁡ π ( a t ′ ∣ s t ; θ ) ] U_{t}^{\left( v \right)}=E_{a_{t}^{'}\sim \pi \left( \cdot |s_t;\theta \right)}\left[ \underset{i=0,1}{\min}q_i\left( s_t,a_{t}^{'};w^{\left( i \right)} \right) -\alpha \ln \pi \left( a_{t}^{'}|s_t;\theta \right) \right] \\ =\sum_{a_{t}^{'}\in \mathbb{A}\left( s_t \right)}{\pi \left( a_{t}^{'}|s_t;\theta \right) \left[ \underset{i=0,1}{\min}q_i\left( s_t,a_{t}^{'};w^{\left( i \right)} \right) -\alpha \ln \pi \left( a_{t}^{'}|s_t;\theta \right) \right]} Ut(v)=Eatπ(st;θ)[i=0,1minqi(st,at;w(i))αlnπ(atst;θ)]=atA(st)π(atst;θ)[i=0,1minqi(st,at;w(i))αlnπ(atst;θ)]

可以看到 π ( a t ′ ∣ s t ; θ ) \pi \left( a_{t}^{'}|s_t;\theta \right) π(atst;θ) min ⁡ i = 0 , 1 q i ( s t , a t ′ ; w ( i ) ) \underset{i=0,1}{\min}q_i\left( s_t,a_{t}^{'};w^{\left( i \right)} \right) i=0,1minqi(st,at;w(i)) ln ⁡ π ( a t ′ ∣ s t ; θ ) \ln \pi \left( a_{t}^{'}|s_t;\theta \right) lnπ(atst;θ)这三项和图中的Loss三个输入箭头完全一致。

用V critic网络的输出作为预测值,最后用MSEloss作为Loss函数,对神经网络 V V V进行训练。

注意取MSELoss就意味着对从经验池buffer中取一个batch的数据进行了求平均的操作,即:

L o s s = 1 ∣ B ∣ ∑ ( s t , a t , r t + 1 , s t + 1 ) ∈ B [ v ( s t ; w ( v ) ) − U t ( v ) ] 2 Loss=\frac{1}{|\mathcal{B}|}\sum_{\left( s_t,a_t,r_{t+1},s_{t+1} \right) \in \mathcal{B}}{\left[ v\left( s_t;w^{\left( v \right)} \right) -U_{t}^{\left( v \right)} \right] ^2} Loss=B1(st,at,rt+1,st+1)B[v(st;w(v))Ut(v)]2

pytorch代码如下:

        # train V criticq0_tensor = self.q0_net(state_tensor)q1_tensor = self.q1_net(state_tensor)q01_tensor = torch.min(q0_tensor, q1_tensor)prob_tensor = self.actor_net(state_tensor)ln_prob_tensor = torch.log(prob_tensor.clamp(1e-6, 1.))entropic_q01_tensor = prob_tensor * (q01_tensor -self.alpha * ln_prob_tensor)# OR entropic_q01_tensor = prob_tensor * (q01_tensor - \#         self.alpha * torch.xlogy(prob_tensor, prob_tensor)v_target_tensor = torch.sum(entropic_q01_tensor, dim=-1, keepdim=True)v_pred_tensor = self.v_evaluate_net(state_tensor)v_loss_tensor = self.v_loss(v_pred_tensor, v_target_tensor.detach())self.v_optimizer.zero_grad()v_loss_tensor.backward()self.v_optimizer.step()self.update_net(self.v_target_net, self.v_evaluate_net)

5. Actor网络的更新流程

在这里插入图片描述

对actor网络训练的loss稍微有些复杂,其表达式为:

L o s s = − 1 ∣ B ∣ ∑ ( s t , a t , r t + 1 , s t + 1 ) ∈ B E a t ′ ∼ π ( ⋅ ∣ s t ; θ ) [ q 0 ( s t , a t ′ ) − α ln ⁡ π ( a t ′ ∣ s t ; θ ) ] Loss=-\frac{1}{|\mathcal{B}|}\sum_{\left( s_t,a_t,r_{t+1},s_{t+1} \right) \in \mathcal{B}}{E_{a_{t}^{'}\sim \pi \left( \cdot |s_t;\theta \right)}}\left[ q_0\left( s_t,a_{t}^{'} \right) -\alpha \ln \pi \left( a_{t}^{'}|s_t;\theta \right) \right] Loss=B1(st,at,rt+1,st+1)BEatπ(st;θ)[q0(st,at)αlnπ(atst;θ)]

E a t ′ ∼ π ( ⋅ ∣ s t ; θ ) [ . . . . ] E_{a_{t}^{'}\sim \pi \left( \cdot |s_t;\theta \right)}[....] Eatπ(st;θ)[....]代表需要对中括号里面的项取期望,注意: a t ′ a_{t}^{'} at并不是在buffer中取出的数据 ( s t , a t , r t + 1 , s t + 1 ) \left( s_t,a_t,r_{t+1},s_{t+1} \right) (st,at,rt+1,st+1)中的 a t a_t at,而是重新用actor网络 π \pi π预测的所有可能的动作,因此对于离散动作空间,常有以下的等价计算方法:

E a t ′ ∼ π ( ⋅ ∣ s t ; θ ) [ q 0 ( s t , a t ′ ; w ( 0 ) ) − α ln ⁡ π ( a t ′ ∣ s t ; θ ) ] = ∑ a t ′ ∈ A ( s t ) π ( a t ′ ∣ s t ; θ ) [ q 0 ( s t , a t ′ ; w ( 0 ) ) − α ln ⁡ π ( a t ′ ∣ s t ; θ ) ] E_{a_{t}^{'}\sim \pi \left( \cdot |s_t;\theta \right)}\left[ q_0\left( s_t,a_{t}^{'};w^{\left( 0 \right)} \right) -\alpha \ln \pi \left( a_{t}^{'}|s_t;\theta \right) \right] \\ =\sum_{a_{t}^{'}\in \mathbb{A}\left( s_t \right)}{\pi \left( a_{t}^{'}|s_t;\theta \right) \left[ q_0\left( s_t,a_{t}^{'};w^{\left( 0 \right)} \right) -\alpha \ln \pi \left( a_{t}^{'}|s_t;\theta \right) \right]} Eatπ(st;θ)[q0(st,at;w(0))αlnπ(atst;θ)]=atA(st)π(atst;θ)[q0(st,at;w(0))αlnπ(atst;θ)]

可以看到 π ( a t ′ ∣ s t ; θ ) \pi \left( a_{t}^{'}|s_t;\theta \right) π(atst;θ) q 0 ( s t , a t ′ ; w ( 0 ) ) q_0\left( s_t,a_{t}^{'};w^{\left( 0 \right)} \right) q0(st,at;w(0)) ln ⁡ π ( a t ′ ∣ s t ; θ ) \ln \pi \left( a_{t}^{'}|s_t;\theta \right) lnπ(atst;θ)这三项和图中的Loss三个输入箭头完全一致。需要注意的是 q 0 ( s t , a t ′ ; w ( 0 ) ) q_0\left( s_t,a_{t}^{'};w^{\left( 0 \right)} \right) q0(st,at;w(0))可以用 q 1 ( s t , a t ′ ; w ( 1 ) ) q_1\left( s_t,a_{t}^{'};w^{\left( 1 \right)} \right) q1(st,at;w(1))替换,这两个Q critic网络在功能上是等价的。

B \mathcal{B} B代表经验池buffer,即求Loss的时候还需要对经验池中取出的样本取平均。这样能够体现取出的样本平均意义下的好坏。

其中: α \alpha α是熵的奖励系数,它决定熵 ln ⁡ π ( a t + 1 ∣ s t ; θ ) \ln \pi \left( a_{t+1}|s_t;\theta \right) lnπ(at+1st;θ)的重要性,越大越重要。

pytorch代码如下:

# train actorprob_q_tensor = prob_tensor * (self.alpha * ln_prob_tensor - q0_tensor)actor_loss_tensor = prob_q_tensor.sum(axis=-1).mean()self.actor_optimizer.zero_grad()actor_loss_tensor.backward()self.actor_optimizer.step()

如果觉得文章有帮助,可以关注我、并且给文章点赞并收藏,欢迎大家关注我的知乎同名账号:ReEchooo。主页链接

这篇关于强化学习之图解SAC算法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/418582

相关文章

龙蜥操作系统Anolis OS-23.x安装配置图解教程(保姆级)

《龙蜥操作系统AnolisOS-23.x安装配置图解教程(保姆级)》:本文主要介绍了安装和配置AnolisOS23.2系统,包括分区、软件选择、设置root密码、网络配置、主机名设置和禁用SELinux的步骤,详细内容请阅读本文,希望能对你有所帮助... ‌AnolisOS‌是由阿里云推出的开源操作系统,旨

Python中的随机森林算法与实战

《Python中的随机森林算法与实战》本文详细介绍了随机森林算法,包括其原理、实现步骤、分类和回归案例,并讨论了其优点和缺点,通过面向对象编程实现了一个简单的随机森林模型,并应用于鸢尾花分类和波士顿房... 目录1、随机森林算法概述2、随机森林的原理3、实现步骤4、分类案例:使用随机森林预测鸢尾花品种4.1

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

不懂推荐算法也能设计推荐系统

本文以商业化应用推荐为例,告诉我们不懂推荐算法的产品,也能从产品侧出发, 设计出一款不错的推荐系统。 相信很多新手产品,看到算法二字,多是懵圈的。 什么排序算法、最短路径等都是相对传统的算法(注:传统是指科班出身的产品都会接触过)。但对于推荐算法,多数产品对着网上搜到的资源,都会无从下手。特别当某些推荐算法 和 “AI”扯上关系后,更是加大了理解的难度。 但,不了解推荐算法,就无法做推荐系

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

康拓展开(hash算法中会用到)

康拓展开是一个全排列到一个自然数的双射(也就是某个全排列与某个自然数一一对应) 公式: X=a[n]*(n-1)!+a[n-1]*(n-2)!+...+a[i]*(i-1)!+...+a[1]*0! 其中,a[i]为整数,并且0<=a[i]<i,1<=i<=n。(a[i]在不同应用中的含义不同); 典型应用: 计算当前排列在所有由小到大全排列中的顺序,也就是说求当前排列是第

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

综合安防管理平台LntonAIServer视频监控汇聚抖动检测算法优势

LntonAIServer视频质量诊断功能中的抖动检测是一个专门针对视频稳定性进行分析的功能。抖动通常是指视频帧之间的不必要运动,这种运动可能是由于摄像机的移动、传输中的错误或编解码问题导致的。抖动检测对于确保视频内容的平滑性和观看体验至关重要。 优势 1. 提高图像质量 - 清晰度提升:减少抖动,提高图像的清晰度和细节表现力,使得监控画面更加真实可信。 - 细节增强:在低光条件下,抖