[cleanrl] ppo_continuous_action Source Code Walkthrough

2023-12-12 05:44

This article walks through the source code of CleanRL's ppo_continuous_action script; hopefully it serves as a useful reference for anyone studying the implementation.

1 Imports (no commentary)

import os
import random
import time
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter

2 The Args class (no commentary)

Defines every hyperparameter of the run; see the English docstrings for what each field means.

@dataclass
class Args:
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""
    save_model: bool = False
    """whether to save model into the `runs/{run_name}` folder"""
    upload_model: bool = False
    """whether to upload the saved model to huggingface"""
    hf_entity: str = ""
    """the user or org name of the model repository from the Hugging Face Hub"""

    # Algorithm specific arguments
    env_id: str = "HalfCheetah-v4"
    """the id of the environment"""
    total_timesteps: int = 1000000
    """total timesteps of the experiments"""
    learning_rate: float = 3e-4
    """the learning rate of the optimizer"""
    num_envs: int = 1
    """the number of parallel game environments"""
    num_steps: int = 2048
    """the number of steps to run in each environment per policy rollout"""
    anneal_lr: bool = True
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.99
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for the general advantage estimation"""
    num_minibatches: int = 32
    """the number of mini-batches"""
    update_epochs: int = 10
    """the K epochs to update the policy"""
    norm_adv: bool = True
    """Toggles advantages normalization"""
    clip_coef: float = 0.2
    """the surrogate clipping coefficient"""
    clip_vloss: bool = True
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    ent_coef: float = 0.0
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    target_kl: float = None
    """the target KL divergence threshold"""

    # to be filled in runtime
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""

3 Defining the Agent

The raw gym environment is modified with several gym.wrappers:

  • FlattenObservation: flattens the observation into a 1-D vector
  • RecordEpisodeStatistics: records per-episode statistics
  • ClipAction: clips the action so it satisfies the action_space bounds
  • NormalizeObservation: normalizes the observation
  • TransformObservation: applies a transform to the observation (here, clipping to [-10, 10])
  • NormalizeReward: normalizes the reward
  • TransformReward: applies a transform to the reward (here, clipping to [-10, 10])
def make_env(env_id, idx, capture_video, run_name, gamma):
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = gym.wrappers.ClipAction(env)
        env = gym.wrappers.NormalizeObservation(env)
        env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
        env = gym.wrappers.NormalizeReward(env, gamma=gamma)
        env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))
        return env

    return thunk
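As a quick usage check, the thunk returned by make_env can be called directly. This is only a sketch: it assumes gymnasium's classic-control Pendulum-v1 is installed (avoiding the MuJoCo dependency of HalfCheetah-v4) and a gymnasium version whose wrapper signatures match the script above.

# Minimal usage sketch (assumed env: Pendulum-v1; not part of the original script).
env = make_env("Pendulum-v1", idx=0, capture_video=False, run_name="debug", gamma=0.99)()
obs, _ = env.reset(seed=0)
print(obs.shape)                                    # flattened 1-D observation vector
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(obs.min() >= -10.0, obs.max() <= 10.0)        # TransformObservation clips obs to [-10, 10]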

layer_init initializes the parameters of each layer: orthogonal weights and a constant bias.

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

The Agent class for PPO with continuous actions uses an actor-critic structure; both the actor network and the critic network are MLPs with Tanh activations.

The critic takes a (batch_size, obs_dim) input, passes it through two hidden layers of width 64, and outputs a (batch_size, 1) tensor, i.e. it maps observations to state values. It is exposed through the get_value function.

The actor consists of two parts:

  • self.actor_mean maps the observation to the action mean: input (batch_size, obs_dim), two hidden layers of width 64, output (batch_size, action_dim)
  • self.actor_logstd is a (1, action_dim) Parameter that stores the log of the action standard deviation (torch.exp is applied to it later, which keeps the standard deviation positive)

In cleanrl's implementation, the actor produces the continuous-action distribution as a diagonal Gaussian, i.e. actions are sampled from Normal(action_mean, action_std).

get_action_and_value computes:

  • the action distribution probs
  • an action sample probs.sample()
  • the log-likelihood probs.log_prob(action).sum(1)
  • the entropy probs.entropy().sum(1)
  • the state value self.critic(x)

In the log-likelihood and entropy computations, sum(1) combines the mutually independent action dimensions into a joint quantity: independent densities multiply, so their log-probabilities (and entropies) add.
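To make the sum(1) point concrete, here is a tiny standalone sketch with made-up numbers (not part of the original script): one Normal per action dimension, with log-probabilities and entropies summed over the action axis.

import torch
from torch.distributions.normal import Normal

action_mean = torch.tensor([[0.0, 1.0, -0.5]])   # (batch_size=1, action_dim=3), toy values
action_std = torch.exp(torch.zeros(1, 3))        # exp(logstd) keeps the std strictly positive
probs = Normal(action_mean, action_std)          # diagonal Gaussian: one Normal per dimension

action = probs.sample()                          # shape (1, 3)
joint_logprob = probs.log_prob(action).sum(1)    # independent dims: log-probs add up -> shape (1,)
joint_entropy = probs.entropy().sum(1)           # entropies of independent dims also add -> shape (1,)
print(joint_logprob.shape, joint_entropy.shape)

The full Agent class follows.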

class Agent(nn.Module):
    def __init__(self, envs):
        super().__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),
        )
        self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        action_mean = self.actor_mean(x)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)
        probs = Normal(action_mean, action_std)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
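A quick shape check of the Agent, using a hypothetical stand-in for envs that only provides the two attributes the constructor reads; the sizes 17 and 6 are example dimensions, roughly matching HalfCheetah. This is a sketch for illustration, not part of the original script.

from types import SimpleNamespace

dummy_envs = SimpleNamespace(
    single_observation_space=gym.spaces.Box(-np.inf, np.inf, shape=(17,)),
    single_action_space=gym.spaces.Box(-1.0, 1.0, shape=(6,)),
)
agent = Agent(dummy_envs)
x = torch.zeros(4, 17)                                      # a batch of 4 observations
action, logprob, entropy, value = agent.get_action_and_value(x)
print(action.shape, logprob.shape, entropy.shape, value.shape)
# torch.Size([4, 6]) torch.Size([4]) torch.Size([4]) torch.Size([4, 1])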

4 Training the Agent

Set a few parameters; some of them deserve a short explanation (a worked example with the default values follows this list):

  • batch_size: the product of num_envs and num_steps, i.e. how many samples one rollout iteration collects
  • minibatch_size: each update step draws a smaller minibatch from the full batch
  • num_iterations: how many iterations the whole training run consists of
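Plugging in the default Args values gives a quick sanity check (a worked example, not from the original post):

# num_envs=1, num_steps=2048, num_minibatches=32, total_timesteps=1_000_000 (the defaults)
batch_size = 1 * 2048                 # 2048 samples collected per iteration
minibatch_size = 2048 // 32           # 64 samples per gradient step
num_iterations = 1_000_000 // 2048    # 488 iterations over the whole run

The corresponding lines in the script: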
args = tyro.cli(Args)
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
args.num_iterations = args.total_timesteps // args.batch_size
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
    import wandb

    wandb.init(
        project=args.wandb_project_name,
        entity=args.wandb_entity,
        sync_tensorboard=True,
        config=vars(args),
        name=run_name,
        monitor_gym=True,
        save_code=True,
    )
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic

device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

Instantiate the envs, the agent, and the optimizer.

# env setup
envs = gym.vector.SyncVectorEnv(
    [make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)]
)
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

agent = Agent(envs).to(device)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

Define the rollout storage:

  • obs: observed environment states
  • actions: sampled actions
  • logprobs: log-likelihoods of the sampled actions
  • rewards: immediate rewards
  • dones: whether the episode has ended
  • values: state values
# ALGO Logic: Storage setup
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
values = torch.zeros((args.num_steps, args.num_envs)).to(device)

next_obs holds the observation after the most recent step, and next_done whether that step ended the episode. These two variables are used to bootstrap the value of the state that follows the last action.

# TRY NOT TO MODIFY: start the game
global_step = 0
start_time = time.time()
next_obs, _ = envs.reset(seed=args.seed)
next_obs = torch.Tensor(next_obs).to(device)
next_done = torch.zeros(args.num_envs).to(device)

Inside the step for-loop, the actor and critic collect samples under the current (old) policy. Since the old policy does not take part in gradient descent, the relevant computations are wrapped in torch.no_grad().

for iteration in range(1, args.num_iterations + 1):
    # Annealing the rate if instructed to do so.
    if args.anneal_lr:
        frac = 1.0 - (iteration - 1.0) / args.num_iterations
        lrnow = frac * args.learning_rate
        optimizer.param_groups[0]["lr"] = lrnow

    for step in range(0, args.num_steps):
        global_step += args.num_envs
        obs[step] = next_obs
        dones[step] = next_done

        # ALGO LOGIC: action logic
        with torch.no_grad():
            action, logprob, _, value = agent.get_action_and_value(next_obs)
            values[step] = value.flatten()
        actions[step] = action
        logprobs[step] = logprob

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
        next_done = np.logical_or(terminations, truncations)
        rewards[step] = torch.tensor(reward).to(device).view(-1)
        next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

        if "final_info" in infos:
            for info in infos["final_info"]:
                if info and "episode" in info:
                    print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                    writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                    writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

This part computes GAE (Generalized Advantage Estimation) from the values and rewards. Starting from the last reward, it iterates backwards (a toy numeric version follows the formulas):

  • $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$
  • $a_t = \delta_t + \gamma \lambda a_{t+1}$
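Here is a toy numeric sketch of that backward recursion (made-up numbers, a single environment, no episode boundaries), mirroring the loop in the real code below:

# Toy GAE sketch with invented values; not part of the original script.
import numpy as np

gamma, lam = 0.99, 0.95
rewards = np.array([1.0, 0.0, 2.0])
values = np.array([0.5, 0.4, 0.6])       # V(s_0), V(s_1), V(s_2)
next_value = 0.3                         # bootstrap value V(s_3)

advantages = np.zeros(3)
lastgaelam = 0.0
for t in reversed(range(3)):
    nextvalue = next_value if t == 2 else values[t + 1]
    delta = rewards[t] + gamma * nextvalue - values[t]    # delta_t
    lastgaelam = delta + gamma * lam * lastgaelam         # a_t = delta_t + gamma * lambda * a_{t+1}
    advantages[t] = lastgaelam
returns = advantages + values                             # regression targets for the value function
print(advantages, returns)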
###############################################
for iteration in range(1, args.num_iterations + 1):  # [continues the iteration for-loop from the previous snippet]
###############################################
    # bootstrap value if not done
    with torch.no_grad():
        next_value = agent.get_value(next_obs).reshape(1, -1)
        advantages = torch.zeros_like(rewards).to(device)
        lastgaelam = 0
        for t in reversed(range(args.num_steps)):
            if t == args.num_steps - 1:
                nextnonterminal = 1.0 - next_done
                nextvalues = next_value
            else:
                nextnonterminal = 1.0 - dones[t + 1]
                nextvalues = values[t + 1]
            delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
            advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
        returns = advantages + values

The tensors above have shape (num_steps, num_envs, XX_dim); here they are flattened to (batch_size, XX_dim), because the minibatches used for training are drawn from this flat batch.
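A shape-only sketch of what the flattening does, with hypothetical sizes:

num_steps, num_envs, obs_dim = 2048, 1, 17         # example sizes, not from the script
obs_demo = torch.zeros(num_steps, num_envs, obs_dim)
b_obs_demo = obs_demo.reshape((-1,) + (obs_dim,))
print(b_obs_demo.shape)                             # torch.Size([2048, 17]) == (batch_size, obs_dim)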

###############################################
for iteration in range(1, args.num_iterations + 1):  # [continues the iteration for-loop from the previous snippet]
###############################################
    # flatten the batch
    b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
    b_logprobs = logprobs.reshape(-1)
    b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
    b_advantages = advantages.reshape(-1)
    b_returns = returns.reshape(-1)
    b_values = values.reshape(-1)

    # Optimizing the policy and value network
    b_inds = np.arange(args.batch_size)
    clipfracs = []

Minibatches are drawn via b_inds, so the indices are shuffled first; then, inside the start for-loop, each minibatch is fetched and the new newlogprob, entropy, and newvalue are computed. From the new and old logprob the ratio is computed, which is used for the PPO clipping below.

###############################################
for iteration in range(1, args.num_iterations + 1):
    ......
###############################################
    for epoch in range(args.update_epochs):
        np.random.shuffle(b_inds)
        for start in range(0, args.batch_size, args.minibatch_size):
            end = start + args.minibatch_size
            mb_inds = b_inds[start:end]

            _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
            logratio = newlogprob - b_logprobs[mb_inds]
            ratio = logratio.exp()

First, the KL divergence is estimated with a Monte Carlo approximation (approx_kl, following the kl-approx blog post linked in the code); then the minibatch advantages are fetched and, if requested, normalized. Finally the PPO clipping is applied and the policy loss is computed.
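A toy sketch of the two KL estimators and the clipped surrogate, with made-up numbers (not part of the original script):

import torch

logratio = torch.tensor([0.05, -0.10, 0.20])        # toy new_logprob - old_logprob values
ratio = logratio.exp()

old_approx_kl = (-logratio).mean()                   # simple estimator (can be negative)
approx_kl = ((ratio - 1) - logratio).mean()          # the estimator used by the script, always >= 0

clip_coef = 0.2
mb_advantages = torch.tensor([1.0, -0.5, 2.0])       # toy advantages
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2).mean()       # elementwise max = pessimistic (clipped) objective
print(old_approx_kl, approx_kl, pg_loss)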

###############################################
for iteration in range(1, args.num_iterations + 1):
    ......
    for epoch in range(args.update_epochs):
        ......
        for start in range(0, args.batch_size, args.minibatch_size):  # [continues the start for-loop from the previous snippet]
###############################################
            with torch.no_grad():
                # calculate approx_kl http://joschu.net/blog/kl-approx.html
                old_approx_kl = (-logratio).mean()
                approx_kl = ((ratio - 1) - logratio).mean()
                clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

            mb_advantages = b_advantages[mb_inds]
            if args.norm_adv:
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

            # Policy loss
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

The value loss is computed from the old b_returns and the new newvalue; a clipped value loss is also provided (clip_vloss).
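A toy sketch of the clipped value loss with made-up numbers: the new prediction may only move clip_coef away from the old one, and the larger of the two squared errors is kept.

import torch

clip_coef = 0.2
b_returns = torch.tensor([1.0, 0.0])                 # toy return targets
b_values = torch.tensor([0.5, 0.1])                  # old value predictions
newvalue = torch.tensor([1.2, -0.6])                 # new value predictions

v_unclipped = (newvalue - b_returns) ** 2
v_clipped_pred = b_values + torch.clamp(newvalue - b_values, -clip_coef, clip_coef)
v_clipped = (v_clipped_pred - b_returns) ** 2
v_loss = 0.5 * torch.max(v_unclipped, v_clipped).mean()   # keep the more pessimistic error
print(v_loss)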

###############################################
for iteration in range(1, args.num_iterations + 1):
    ......
    for epoch in range(args.update_epochs):
        ......
        for start in range(0, args.batch_size, args.minibatch_size):  # [continues the start for-loop from the previous snippet]
###############################################
            # Value loss
            newvalue = newvalue.view(-1)
            if args.clip_vloss:
                v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                v_clipped = b_values[mb_inds] + torch.clamp(
                    newvalue - b_values[mb_inds],
                    -args.clip_coef,
                    args.clip_coef,
                )
                v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()
            else:
                v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

The total loss is a weighted sum of the policy loss, the value loss, and the entropy; it is back-propagated and the parameters are updated. The KL divergence estimated earlier enables early stopping: if it exceeds the target_kl threshold, training on the current batch stops (one could also stop only the current minibatch).

###############################################
for iteration in range(1, args.num_iterations + 1):
    ......
    for epoch in range(args.update_epochs):
        ......
        for start in range(0, args.batch_size, args.minibatch_size):  # [continues the start for-loop from the previous snippet]
###############################################
            entropy_loss = entropy.mean()
            loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
            optimizer.step()

        if args.target_kl is not None and approx_kl > args.target_kl:
            break

TensorBoard logging; nothing special here.

###############################################
for iteration in range(1, args.num_iterations + 1):  # [continues the iteration for-loop from the previous snippet]
###############################################
    y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
    var_y = np.var(y_true)
    explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

    # TRY NOT TO MODIFY: record rewards for plotting purposes
    writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
    writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
    writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
    writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
    writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
    writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
    writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
    writer.add_scalar("losses/explained_variance", explained_var, global_step)
    print("SPS:", int(global_step / (time.time() - start_time)))
    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

Model saving and evaluation; also nothing special.

if args.save_model:
    model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
    torch.save(agent.state_dict(), model_path)
    print(f"model saved to {model_path}")
    from cleanrl_utils.evals.ppo_eval import evaluate

    episodic_returns = evaluate(
        model_path,
        make_env,
        args.env_id,
        eval_episodes=10,
        run_name=f"{run_name}-eval",
        Model=Agent,
        device=device,
        gamma=args.gamma,
    )
    for idx, episodic_return in enumerate(episodic_returns):
        writer.add_scalar("eval/episodic_return", episodic_return, idx)

    if args.upload_model:
        from cleanrl_utils.huggingface import push_to_hub

        repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
        repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
        push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()

That wraps up this walkthrough of [cleanrl] ppo_continuous_action; hopefully it is useful.


