强化学习原理python篇08——actor-critic

2024-02-02 13:20

本文主要是介绍强化学习原理python篇08——actor-critic,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

强化学习原理python篇08——actor-critic

  • 前置知识
    • TD Error
    • REINFORCE
    • QAC
    • Advantage actor-critic (A2C)
  • torch实现步骤
    • 第一步
    • 第二步
    • 第三步
    • 训练
    • 结果
  • Ref

本章全篇参考赵世钰老师的教材 Mathmatical-Foundation-of-Reinforcement-Learning Actor-Critic Methods 章节,请各位结合阅读,本合集只专注于数学概念的代码实现。

前置知识

TD Error

如果用 v ^ ( s , w ) \hat v(s,w) v^(s,w)代表状态值函数,则TD Error表示为
r t + 1 + γ v ^ ( s t + 1 , w ) − v ^ ( s t , w ) r_{t+1}+\gamma \hat v(s_{t+1},w) -\hat v(s_{t},w) rt+1+γv^(st+1,w)v^(st,w)

令损失函数
J w = E [ v ( s t ) − v ^ ( s t , w ) ] 2 J_w = E[ v(s_{t}) -\hat v(s_{t},w)]^2 Jw=E[v(st)v^(st,w)]2

则利用梯度下降法最小化 J θ J_\theta Jθ
w k + 1 = w k − α ∇ w J ( w k ) = w k − α [ − 2 E ( [ r t + 1 + γ v ^ ( s t + 1 , w ) − v ^ ( s t , w ) ] ) ] ∇ w v ^ ( s t , w ) ) \begin{align*} w_{k+1} =& w_k -\alpha\nabla_w J(w_k)\\ =& w_k -\alpha[-2E([r_{t+1}+\gamma \hat v(s_{t+1},w) -\hat v(s_{t},w)])]\nabla_w \hat v(s_{t},w)) \end{align*} wk+1==wkαwJ(wk)wkα[2E([rt+1+γv^(st+1,w)v^(st,w)])]wv^(st,w))

用随机梯度来估算,则最小化 J θ J_\theta Jθ
w k + 1 = w k − α ∇ w J ( w k ) = w k + α [ r t + 1 + γ v ^ ( s t + 1 , w ) − v ^ ( s t , w ) ] ∇ w v ^ ( s t , w ) ) = w k + α [ v ( s t ) − v ^ ( s t , w ) ] ∇ w v ^ ( s t , w ) ) \begin{align*} w_{k+1} =& w_k -\alpha\nabla_w J(w_k)\\ =& w_k +\alpha[r_{t+1}+\gamma \hat v(s_{t+1},w) -\hat v(s_{t},w)]\nabla_w \hat v(s_{t},w))\\ =& w_k +\alpha[ v(s_{t}) -\hat v(s_{t},w)]\nabla_w \hat v(s_{t},w))\\ \end{align*} wk+1===wkαwJ(wk)wk+α[rt+1+γv^(st+1,w)v^(st,w)]wv^(st,w))wk+α[v(st)v^(st,w)]wv^(st,w))

对于q—value来说,
w k + 1 = w k − α ∇ w J ( w k ) = w k + α [ r t + 1 + γ q ^ ( s t + 1 , a t + 1 , w ) − q ^ ( s t , a t , w ) ] ∇ w q ^ ( s t , a t , w ) ) \begin{align*} w_{k+1} =& w_k -\alpha\nabla_w J(w_k)\\ =& w_k +\alpha[r_{t+1}+\gamma \hat q(s_{t+1}, a_{t+1},w) -\hat q(s_{t}, a_{t},w)]\nabla_w \hat q(s_{t},a_{t},w))\\ \end{align*} wk+1==wkαwJ(wk)wk+α[rt+1+γq^(st+1,at+1,w)q^(st,at,w)]wq^(st,at,w))

REINFORCE

参考上一节

θ t + 1 = θ t + ∇ θ J ( θ t ) = θ t + ∇ θ E S − d , a − π ( S , Θ ) [ q ( s , a ) ∇ θ l n π ( a ∣ s , θ ) ] \begin {align*} θ_{t+1} =& θ_{t} + \nabla _{\theta}J(θ_t)\\=& θ_{t} + \nabla _{\theta}E_{S-d,a-\pi(S,\Theta)}[q(s,a) \nabla _{\theta}ln\pi(a|s,\theta)] \end {align*} θt+1==θt+θJ(θt)θt+θESd,aπ(S,Θ)[q(s,a)θl(as,θ)]
一般来说, ∇ θ l n π ( a ∣ s , θ ) \nabla _{\theta}ln\pi(a|s,\theta) θl(as,θ)是未知的,可以用随机梯度法来估计,则
θ t + 1 = θ t + ∇ θ J ( θ t ) = θ t + ∇ θ [ q ( s , a ) ∇ θ l n π ( a ∣ s , θ ) ] \begin {align*} θ_{t+1} =& θ_{t} + \nabla _{\theta}J(θ_t)\\=& θ_{t} + \nabla _{\theta}[q(s,a) \nabla _{\theta}ln\pi(a|s,\theta)] \end {align*} θt+1==θt+θJ(θt)θt+θ[q(s,a)θl(as,θ)]

QAC

The simplest actor-critic algorithm

  • actor:更新策略

    我们采用reinforce的方法来更新策略函数 π \pi π θ t + 1 = θ t + ∇ θ [ q ( s , a ) ∇ θ l n π ( a ∣ s , θ ) ] \begin {align*} θ_{t+1} =& θ_{t} + \nabla _{\theta}[q(s,a) \nabla _{\theta}ln\pi(a|s,\theta)] \end {align*} θt+1=θt+θ[q(s,a)θl(as,θ)]

  • critic:更新值

    我们采用优化td-error的方法来更新行动值 q q q
    w k + 1 = w k + α [ r t + 1 + γ q ^ ( s t + 1 , a t + 1 , w ) − q ^ ( s t , a t , w ) ] ∇ w q ^ ( s t , a t , w ) ) \begin{align*} w_{k+1} =& w_k +\alpha[r_{t+1}+\gamma \hat q(s_{t+1}, a_{t+1},w) -\hat q(s_{t}, a_{t},w)]\nabla_w \hat q(s_{t},a_{t},w)) \end{align*} wk+1=wk+α[rt+1+γq^(st+1,at+1,w)q^(st,at,w)]wq^(st,at,w))

Advantage actor-critic (A2C)

减小方差的下一步是使基线与状态相关(这是一个好主意,因为不同的状态可能具有非常不同的基线)。确实,要决定某个特定动作在某种状态下的适用性,我们会使用该动作的折扣总奖励。但是,总奖励本身可以表示为状态的价值加上动作的优势值:Q(s,a)=V(s)+A(s,a)(参见DuelingDQN)。

知道每个状态的价值(至少有一个近似值)后,我们就可以用它来计算策略梯度并更新策略网络,以增加具有良好优势值的动作的执行概率,并减少具有劣势优势值的动作的执行概率。策略网络(返回动作的概率分布)被称为行动者(actor),因为它会告诉我们该做什么。另一个网络称为评论家(critic),因为它能使我们了解自己的动作有多好。这种改进有一个众所周知的名称,即advantage actorcritic方法,通常被简称为A2C。
E S − d , a − π ( S , Θ ) [ q ( s , a ) ∇ θ l n π ( a ∣ s , θ ) ] = E S − d , a − π ( S , Θ ) [ ∇ θ l n π ( a ∣ s , θ ) [ q ( s , a ) − v ( s ) ] ] E_{S-d,a-\pi(S,\Theta)}[q(s,a) \nabla _{\theta}ln\pi(a|s,\theta)]=E_{S-d,a-\pi(S,\Theta)}[\nabla _{\theta}ln\pi(a|s,\theta)[q(s,a) -v(s)]] ESd,aπ(S,Θ)[q(s,a)θl(as,θ)]=ESd,aπ(S,Θ)[θl(as,θ)[q(s,a)v(s)]]

  • Advantage(TD error)

    δ t = r t + 1 + γ v ( s t + 1 ; w t ) − v ( s t ; w t ) \delta_t =r_{t+1}+\gamma v(s_{t+1};w_t)- v(s_t;w_t) δt=rt+1+γv(st+1;wt)v(st;wt)

  • actor:更新策略

    我们采用reinforce的方法来更新策略函数 π \pi π

    θ t + 1 = θ t + a δ t ∇ θ [ ∇ θ l n π ( a ∣ s , θ ) ] \begin {align*} θ_{t+1} =& θ_{t} + a\delta_t\nabla _{\theta}[\nabla _{\theta}ln\pi(a|s,\theta)] \end {align*} θt+1=θt+aδtθ[θl(as,θ)]

  • critic:更新值

    1、我们采用优化td-error的方法来更新状态值 v v v w k + 1 = w k − α ∇ w [ v ( s t , w ) − v ^ ( s t , w ) ] 2 \begin{align*} w_{k+1} =& w_k -\alpha\nabla_w[ v(s_{t},w) -\hat v(s_{t},w)]^2 \end{align*} wk+1=wkαw[v(st,w)v^(st,w)]2

    2、在这里,使用实际发生的discount reward来估算 v ( s t , w ) v(s_{t},w) v(st,w)

    3、 w k + 1 = w k − α ∇ w [ R − v ^ ( s t , w ) ] 2 \begin{align*} w_{k+1} =& w_k -\alpha\nabla_w[R -\hat v(s_{t},w)]^2 \end{align*} wk+1=wkαw[Rv^(st,w)]2

torch实现步骤

第一步

  1. 初始化A2CNet,使其返回策略函数pi(s, a),和价值V(s)
import collections
import copy
import math
import random
import time
from collections import defaultdictimport gym
import gym.spaces
import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils as nn_utils
import torch.optim as optim
from gym.envs.toy_text import frozen_lake
from torch.utils.tensorboard import SummaryWriterclass A2CNet(nn.Module):def __init__(self, obs_size, hidden_size, q_table_size):super().__init__()# 策略函数pi(s, a)self.policy_net = nn.Sequential(nn.Linear(obs_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, q_table_size),nn.Softmax(dim=1),)# 价值V(s)self.v_net = nn.Sequential(nn.Linear(obs_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, 1),)def forward(self, state):if len(torch.Tensor(state).size()) == 1:state = state.reshape(1, -1)return self.policy_net(state), self.v_net(state)

第二步

  1. 使用当前策略πθ在环境中交互N步,并保存状态(st)、动作(at)和奖励(rt)
  2. 如果片段到达结尾,则R=0,否则为Vθ(st),这里采用环境产生的R来近似。
def discount_reward(R, gamma):# r 为历史得分n = len(R)dr = 0for i in range(n):dr += gamma**i * R[i]return drdef generate_episode(env, n_steps, net, gamma, predict=False):episode_history = dict()r_list = []for _ in range(n_steps):episode = []predict_reward = []state, info = env.reset()while True:p, v = net(torch.Tensor(state))p = p.detach().numpy().reshape(-1)action = np.random.choice(list(range(env.action_space.n)), p=p)next_state, reward, terminated, truncted, info = env.step(action)# 如果截断,则展开 v(state) = r + gamma*v(next_state)if truncted and not terminated:reward = reward + gamma * float(net(torch.Tensor(next_state))[1].detach())episode.append([state, action, next_state, reward, terminated])predict_reward.append(reward)state = next_stateif terminated or truncted:episode_history[_] = episoder_list.append(len(episode))episode = []predict_reward = []breakif predict:return np.mean(r_list)return episode_historydef calculate_t_discount_reward(reward_list, gamma):discount_reward = []total_reward = 0for i in reward_list[::-1]:total_reward = total_reward * gamma + idiscount_reward.append(total_reward)return discount_reward[::-1]

第三步

  1. 累积策略梯度 θ t + 1 = θ t + a δ t ∇ θ [ ∇ θ l n π ( a ∣ s , θ ) ] \begin {align*} θ_{t+1} =& θ_{t} + a\delta_t\nabla _{\theta}[\nabla _{\theta}ln\pi(a|s,\theta)] \end {align*} θt+1=θt+aδtθ[θl(as,θ)]

  2. 累积价值梯度
    w k + 1 = w k − α ∇ w [ R − v ^ ( s t , w ) ] 2 \begin{align*} w_{k+1} =& w_k -\alpha\nabla_w[R -\hat v(s_{t},w)]^2 \end{align*} wk+1=wkαw[Rv^(st,w)]2

# actor策略损失函数
def loss(net, batch, gamma, entropy_beta=False):l = 0for episode in batch.values():reward_list = [reward for state, action, next_state, reward, terminated in episode]state = [state for state, action, next_state, reward, terminated in episode]action = [action for state, action, next_state, reward, terminated in episode]# actor策略损失函数## max entropyqt = calculate_t_discount_reward(reward_list, gamma)pi = net(torch.Tensor(state))[0]entropy_loss = -torch.sum((pi * torch.log(pi)), axis=1).mean() * entropy_betapi = pi.gather(dim=1, index=torch.LongTensor(action).reshape(-1, 1))l_policy = -torch.Tensor(qt) @ torch.log(pi)if entropy_beta:l_policy -= entropy_loss# critic损失函数critic_loss = nn.MSELoss()(net(torch.Tensor(state))[1].reshape(-1), torch.Tensor(qt))l += l_policy + critic_lossreturn l / len(batch.values())

训练

## 初始化环境
env = gym.make("CartPole-v1", max_episode_steps=200)
# env = gym.make("CartPole-v1", render_mode = "human")state, info = env.reset()obs_n = env.observation_space.shape[0]
hidden_num = 64
act_n = env.action_space.n
a2c = A2CNet(obs_n, hidden_num, act_n)# 定义优化器
opt = optim.Adam(a2c.parameters(), lr=0.01)# 记录
writer = SummaryWriter(log_dir="logs/PolicyGradient/A2C", comment="test1")epochs = 200
batch_size = 20
gamma = 0.9
entropy_beta = 0.01
# 避免梯度太大
CLIP_GRAD = 0.1for epoch in range(epochs):batch = generate_episode(env, batch_size, a2c, gamma)l = loss(a2c, batch, gamma, entropy_beta)# 反向传播opt.zero_grad()l.backward()# 梯度裁剪nn_utils.clip_grad_norm_(a2c.parameters(), CLIP_GRAD)opt.step()max_steps = generate_episode(env, 10, a2c, gamma, predict=True)writer.add_scalars("Loss",{"loss": l.item(), "max_steps": max_steps},epoch,)print("epoch:{},  Loss: {}, max_steps: {}".format(epoch, l.detach(), max_steps))

结果

在这里插入图片描述
可以看到,对比上一节的几种方法,收敛速度和收敛方向都稳定了不少。

Ref

[1] Mathematical Foundations of Reinforcement Learning,Shiyu Zhao
[2] 深度学习强化学习实践(第二版),Maxim Lapan

这篇关于强化学习原理python篇08——actor-critic的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/670824

相关文章

Golang HashMap实现原理解析

《GolangHashMap实现原理解析》HashMap是一种基于哈希表实现的键值对存储结构,它通过哈希函数将键映射到数组的索引位置,支持高效的插入、查找和删除操作,:本文主要介绍GolangH... 目录HashMap是一种基于哈希表实现的键值对存储结构,它通过哈希函数将键映射到数组的索引位置,支持

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

如何使用 Python 读取 Excel 数据

《如何使用Python读取Excel数据》:本文主要介绍使用Python读取Excel数据的详细教程,通过pandas和openpyxl,你可以轻松读取Excel文件,并进行各种数据处理操... 目录使用 python 读取 Excel 数据的详细教程1. 安装必要的依赖2. 读取 Excel 文件3. 读

Python的time模块一些常用功能(各种与时间相关的函数)

《Python的time模块一些常用功能(各种与时间相关的函数)》Python的time模块提供了各种与时间相关的函数,包括获取当前时间、处理时间间隔、执行时间测量等,:本文主要介绍Python的... 目录1. 获取当前时间2. 时间格式化3. 延时执行4. 时间戳运算5. 计算代码执行时间6. 转换为指

利用Python调试串口的示例代码

《利用Python调试串口的示例代码》在嵌入式开发、物联网设备调试过程中,串口通信是最基础的调试手段本文将带你用Python+ttkbootstrap打造一款高颜值、多功能的串口调试助手,需要的可以了... 目录概述:为什么需要专业的串口调试工具项目架构设计1.1 技术栈选型1.2 关键类说明1.3 线程模

Python ZIP文件操作技巧详解

《PythonZIP文件操作技巧详解》在数据处理和系统开发中,ZIP文件操作是开发者必须掌握的核心技能,Python标准库提供的zipfile模块以简洁的API和跨平台特性,成为处理ZIP文件的首选... 目录一、ZIP文件操作基础三板斧1.1 创建压缩包1.2 解压操作1.3 文件遍历与信息获取二、进阶技

Python Transformers库(NLP处理库)案例代码讲解

《PythonTransformers库(NLP处理库)案例代码讲解》本文介绍transformers库的全面讲解,包含基础知识、高级用法、案例代码及学习路径,内容经过组织,适合不同阶段的学习者,对... 目录一、基础知识1. Transformers 库简介2. 安装与环境配置3. 快速上手示例二、核心模

Python正则表达式语法及re模块中的常用函数详解

《Python正则表达式语法及re模块中的常用函数详解》这篇文章主要给大家介绍了关于Python正则表达式语法及re模块中常用函数的相关资料,正则表达式是一种强大的字符串处理工具,可以用于匹配、切分、... 目录概念、作用和步骤语法re模块中的常用函数总结 概念、作用和步骤概念: 本身也是一个字符串,其中

Python使用getopt处理命令行参数示例解析(最佳实践)

《Python使用getopt处理命令行参数示例解析(最佳实践)》getopt模块是Python标准库中一个简单但强大的命令行参数处理工具,它特别适合那些需要快速实现基本命令行参数解析的场景,或者需要... 目录为什么需要处理命令行参数?getopt模块基础实际应用示例与其他参数处理方式的比较常见问http

python实现svg图片转换为png和gif

《python实现svg图片转换为png和gif》这篇文章主要为大家详细介绍了python如何实现将svg图片格式转换为png和gif,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录python实现svg图片转换为png和gifpython实现图片格式之间的相互转换延展:基于Py