Actor-Critic 跑 CartPole-v1

2023-12-16 11:30
文章标签 v1 actor critic cartpole

本文主要是介绍Actor-Critic 跑 CartPole-v1,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

gym-0.26.1
CartPole-v1
Actor-Critic

这里采用 时序差分残差
ψ t = r t + γ V π θ ( s t + 1 ) − V π θ ( s t ) \psi_t = r_t + \gamma V_{\pi _ \theta} (s_{t+1}) - V_{\pi _ \theta}({s_t}) ψt=rt+γVπθ(st+1)Vπθ(st)
详细请参考 动手学强化学习
简单来说就是 reforce 是采用蒙特卡洛搜索方法来估计Q(s,a) ,然后这里先是把状态价值函数V作为基线, 然后利用Q = r + gamma * V 得到上式。

代码如下

import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
from d2l import torch as d2l
import rl_utils
from tqdm import tqdmclass PolicyNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super().__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, action_dim)def forward(self, X):X = F.relu(self.fc1(X))return F.softmax(self.fc2(X), dim=1)class ValueNet(nn.Module):def __init__(self, state_dim, hidden_dim):super().__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, 1)def forward(self, X):X = F.relu(self.fc1(X))return self.fc2(X)class ActorCritic:def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, device):# 策略网络self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)# 价值网络self.critic = ValueNet(state_dim, hidden_dim).to(device)# 策略网络优化器self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr = actor_lr)#价值网络优化器self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr = critic_lr)self.gamma = gammaself.device = devicedef take_action(self, state):state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)probs = self.actor(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def update(self, transition_dict):states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).reshape(-1,1).to(self.device)rewards = torch.tensor(transition_dict['rewards']).reshape(-1,1).to(device)next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'], dtype=torch.float).reshape(-1,1).to(self.device)# 时分差分目标td_target = rewards + self.gamma * self.critic(next_states) * (1- dones)# 时分差序目标td_delta = td_target - self.critic(states)log_probs = torch.log(self.actor(states).gather(1, actions))actor_loss = torch.mean(-log_probs * td_delta.detach())# 均方误差critic_loss= torch.mean(F.mse_loss(self.critic(states), td_target.detach()))self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()# 计算策略网络的梯度actor_loss.backward()# 计算价值网络的梯度critic_loss.backward()# 更新策略网络梯度self.actor_optimizer.step()# 跟新价值网络梯度self.critic_optimizer.step()def train(env, agent, num_episodes):return_list = []for i in range(10):with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}state = env.reset()[0]done ,truncated = False, Falsewhile not done and not truncated:action = agent.take_action(state)next_state, reward, done, truncated, info = env.step(action)transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)agent.update(transition_dict)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)return return_list
actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = d2l.try_gpu()env_name = 'CartPole-v1'
env = gym.make(env_name)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = ActorCritic(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, device)return_list = train(env, agent, num_episodes)
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Return')
plt.title(f'Actor-Critic on {env_name}')
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Return')
plt.title(f'Actor-Critic on {env_name}')
plt.show()

jupyter运行结果如下


reforce学习更加稳定,而且总体return也要高一些。

这篇关于Actor-Critic 跑 CartPole-v1的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/500294

相关文章

scala并发编程原生线程Actor、Case Class下的消息传递和偏函数实战

参考代码: import scala.actors._case class Person(name:String,age:Int)class HelloActor extends Actor{def act(){while(true){receive{case Person(name,age)=>{ //偏函数println("Name: "+ name + ":" +"Age:"

TokuDB7.5.7-2.4.1使用TokuDB时的内存注意事项v1

tokudb_cache_size指定TokuDB自己的cache大小,该值默认会为50%的RAM(?TokuDB will allocate 50% of the installed RAM for its own cache)。在如下集中情况,需要手动配置tokudb_cache_size的值: 1 TokuDB和其他占内存型的任务一起跑在同一台机器上 一个保收的配置是其他所有任务都在运行

TokuDB7.5.7-2.1使用TokuDB的系统和硬件需求v1

1 操作系统需求 TokuDB到目前位置只支持64位的Linux系统(所以现在不支持在window上编译运行) 2 硬件需求 内存:至少1G;如果想较好性能,建议2G以上。 外存:建议为数据目录(tokudb_data_dir)和日志目录(tokudb_log_dir)配置足够大的存储空间。

MySQL变量-binlog_format:决定binlog的存储格式v1

1 global和session都可 2 三个值: STATEMENT:sql语句的格式 ROW:具体数据行记录的格式 MIXED:混合格式

UE的Gameplay框架(二) —— Actor和Component

这篇博客聊一下UE的Gameplay框架很重要的一部分 Actor 和 Component 文章目录 ActorComponentSceneComponent注册组件 Actor生命周期参考资料 Actor 如UE文档所述,所有可以放入关卡的对象都是 Actor,比如摄像机、静态网格体、玩家起始位置。Actor 支持三维变换,例如平移、旋转和缩放。在 C++

爆改YOLOv8|利用全新的聚焦式线性注意力模块Focused Linear Attention 改进yolov8(v1)

1,本文介绍 全新的聚焦线性注意力模块(Focused Linear Attention)是一种旨在提高计算效率和准确性的注意力机制。传统的自注意力机制在处理长序列数据时通常计算复杂度较高,限制了其在大规模数据上的应用。聚焦线性注意力模块则通过优化注意力计算的方式,显著降低了计算复杂度。 核心特点: 线性时间复杂度:与传统的自注意力机制不同,聚焦线性注意力模块采用了线性时间复杂度的计算方法

强化学习第十章:Actor-Critic 方法

强化学习第十章:Actor-Critic 方法 什么叫Actor-Critic最简单的AC,QAC(Q Actor-Critic)优势函数的AC,A2C(Advantage Actor-Critic)异策略AC,Off-Policy AC确定性策略梯度,DPG总结参考资料 什么叫Actor-Critic 一句话,策略由动作来执行,执行者叫Actor,评价执行好坏的叫Critic(

docker pull报错: Error response from daemon: Get https://../v1/_ping: http: server gave HTTP response

问题描述,安装好docker私有库之后,不管是从私有库pull还是push,都会报错: Error response from daemon: Get https://xxx.xxx.xxx.xxx:5000/v1/_ping: http: server gave HTTP response to HTTPS client 原因是由于客户端采用https,docker registry未采用h

【并发】Java并发的四种风味:Thread、Executor、ForkJoin和Actor

本文由  ImportNew -  shenggordon 翻译自  Oleg Shelajev。欢迎加入 翻译小组。转载请见文末要求。 这篇文章讨论了Java应用中并行处理的多种方法。从自己管理Java线程,到各种更好几的解决方法,Executor服务、ForkJoin 框架以及计算中的Actor模型。 Java并发编程的4种风格:Threads,Executors,ForkJoin

CV-笔记-重读YOLO目标检测系列 v1

目录 如何检测定义label训练的时候损失函数缺点引用 将对象检测定义为一个回归问题,回归到空间分离的边界框和相关的类概率。与最先进的检测系统相比,YOLO会产生更多的定位错误,但不太可能预测背景上的误报less likely to predict false positives on background(假阳少)都看做一个回归问题,所以不需要复杂的pipeline。titan