DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in

本文主要是介绍DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、数据处理
用于训练 RNN 的 mini-batch 数据不同于通常的数据。 这些数据通常应按时间序列排列。 对于 DI-engine, 这个处理是在 collector 阶段完成的。 用户需要在配置文件中指定 learn_unroll_len 以确保序列数据的长度与算法匹配。 对于大多数情况, learn_unroll_len 应该等于 RNN 的历史长度(a.k.a 时间序列长度),但在某些情况下并非如此。比如,在 r2d2 中, 我们使用burn-in操作, 序列长度等于 learn_unroll_len + burnin_step 。 这里将在下一节中具体解释。

什么是数据处理?
数据处理指的是为循环神经网络(RNN)训练准备时间序列数据的过程。这个过程包括将收集到的数据组织成适当格式的小批量(mini-batches),这些批量数据将用于网络的训练。这一步骤通常发生在DI-engine的collector阶段,也就是数据收集和预处理发生的地方。用户需要在配置文件中指定 learn_unroll_len 以确保序列数据的长度与算法匹配。 对于大多数情况, learn_unroll_len 应该等于 RNN 的历史长度(a.k.a 时间序列长度),但在某些情况下并非如此。比如,在 r2d2 中, 我们使用burn-in操作, 序列长度等于 learn_unroll_len + burnin_step 。例如,如果你设置 learn_unroll_len = 10 和 burnin_step = 5,那么 RNN 实际接收的输入序列长度将是 15:前 5 步为 burn-in(用于预热隐藏状态),接下来的 10 步作为学习的一部分。这样设置可以帮助 RNN 在计算梯度和进行权重更新时,有一个更加准确的隐藏状态作为起点。
部分名词解释

  • mini-batches:在机器学习中,特别是在训练神经网络时,数据一般被分成小的批次进行处理,这些批次被称为 “mini-batch”。一个 mini-batch 包含了一组样本,这组样本用于执行单次迭代的前向传播和反向传播,以更新网络的权重。使用 mini-batches 而不是单个样本或整个数据集(后者称为 “batch” 或 “full-batch”)可以平衡计算效率和内存限制,有助于提高学习的稳定性和收敛速度。
  • collector阶段:在 DI-engine中,collector 阶段是指环境与智能体交互并收集经验数据的过程。在这个阶段,智能体根据其当前的策略执行操作,环境则返回新的状态、奖励和其他可能的信息,如是否达到终止状态。收集到的数据(经常被称为经验或转换)随后被用于训练智能体的模型,例如对策略或价值函数进行更新。

为什么要进行数据处理:

  1. 保持时间依赖性:RNN的核心优势是处理具有时间序列依赖性的数据,比如语言、视频帧、股票价格等。正确的数据处理确保了这些时间依赖性在训练数据中得以保留,使得模型能够学习到数据中的序列特征。
  2. 提高学习效率:通过将数据划分为与模型期望的序列长度匹配的批次,可以提高模型学习的效率。这样做可以确保网络在每次更新时都接收到足够的上下文信息。
  3. 适配算法要求:不同的RNN算法可能需要不同形式的输入数据。例如,标准的RNN只需要过去的信息,而一些变体如LSTM或GRU可能会处理更长的序列。特定的算法,如R2D2,还可能需要额外的步骤(如burn-in),以便更好地初始化网络状态。
  4. 处理不规则长度:在现实世界的数据集中,序列长度往往是不规则的。数据处理确保了每个mini-batch都有统一的序列长度,这通常通过截断过长的序列或填充过短的序列来实现。
  5. 优化内存和计算资源:通过将数据组织成具有固定时间步长的批次,可以更有效地利用GPU等计算资源,因为这些资源在处理固定大小的数据时通常更高效。
  6. 稳定学习过程:特别是在强化学习中,使用如n-step返回或经验回放的技术,可以帮助模型从环境反馈中学习,并减少方差,从而稳定学习过程。

如何进行数据处理

def _get_train_sample(self, data: list) -> Union[None, List[Any]]:    data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)    return get_train_sample(data, self._sequence_len)

 代码段 def _get_train_sample(self, data: list) 是一个方法,它的作用是从收集到的数据中提取用于训练 RNN 的样本。这个方法会在两个步骤中处理数据:

  • N步返回计算(get_nstep_return_data): 这个函数接受原始的经验数据,然后计算所谓的 N 步返回值。N 步返回是一个在强化学习中用于临时差分(Temporal Difference, TD)学习的概念,它考虑了从当前状态开始的未来 N 步的累积奖励。计算这个值需要使用折现因子 gamma。这个步骤的目的是为了让智能体学习如何根据当前的行动预测未来的奖励,这是强化学习中价值函数估计的重要部分。
  • 训练样本获取(get_train_sample): 在得到 N 步返回值之后,这个函数进一步处理数据以生成训练样本。具体地,它会根据 self._sequence_len(即时间序列长度或者 RNN 的历史长度)来选择数据序列。这意味着每个训练样本将是一个具有 self._sequence_len 长度的数据序列,这对于训练 RNN 来说是必要的,因为 RNN 需要一定长度的历史来维护其内部状态(或记忆)。

有关这两个数据处理功能的工作流程见下图:

二、初始化隐藏状态 (Hidden State)
RNN用于处理具有时间依赖性的信息。RNN的隐藏状态(Hidden State)是其记忆的一部分,它能够捕捉到前一时间步长的信息。这些信息对于预测下一个动作或状态非常关键。在此上下文中,初始化RNN的隐藏状态是一个重要的步骤,它确保了RNN在开始新的数据批次处理时具有正确的起始状态。
策略的 _learn_model 需要初始化 RNN。这些隐藏状态来自 _collect_model 保存的 prev_state。 用户需要通过 _process_transition 函数将这些状态添加到 _learn_model 输入数据字典中。 

def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:    transition = {        'obs': obs,        'action': model_output['action'],        'prev_state': model_output['prev_state'], # add ``prev_state`` key here        'reward': timestep.reward,        'done': timestep.done,    }    return transition

点击DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in - 古月居 可查看全文

这篇关于DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/969655

相关文章

golang1.23版本之前 Timer Reset方法无法正确使用

《golang1.23版本之前TimerReset方法无法正确使用》在Go1.23之前,使用`time.Reset`函数时需要先调用`Stop`并明确从timer的channel中抽取出东西,以避... 目录golang1.23 之前 Reset ​到底有什么问题golang1.23 之前到底应该如何正确的

详解Vue如何使用xlsx库导出Excel文件

《详解Vue如何使用xlsx库导出Excel文件》第三方库xlsx提供了强大的功能来处理Excel文件,它可以简化导出Excel文件这个过程,本文将为大家详细介绍一下它的具体使用,需要的小伙伴可以了解... 目录1. 安装依赖2. 创建vue组件3. 解释代码在Vue.js项目中导出Excel文件,使用第三

Linux alias的三种使用场景方式

《Linuxalias的三种使用场景方式》文章介绍了Linux中`alias`命令的三种使用场景:临时别名、用户级别别名和系统级别别名,临时别名仅在当前终端有效,用户级别别名在当前用户下所有终端有效... 目录linux alias三种使用场景一次性适用于当前用户全局生效,所有用户都可调用删除总结Linux

java图像识别工具类(ImageRecognitionUtils)使用实例详解

《java图像识别工具类(ImageRecognitionUtils)使用实例详解》:本文主要介绍如何在Java中使用OpenCV进行图像识别,包括图像加载、预处理、分类、人脸检测和特征提取等步骤... 目录前言1. 图像识别的背景与作用2. 设计目标3. 项目依赖4. 设计与实现 ImageRecogni

python管理工具之conda安装部署及使用详解

《python管理工具之conda安装部署及使用详解》这篇文章详细介绍了如何安装和使用conda来管理Python环境,它涵盖了从安装部署、镜像源配置到具体的conda使用方法,包括创建、激活、安装包... 目录pytpshheraerUhon管理工具:conda部署+使用一、安装部署1、 下载2、 安装3

Mysql虚拟列的使用场景

《Mysql虚拟列的使用场景》MySQL虚拟列是一种在查询时动态生成的特殊列,它不占用存储空间,可以提高查询效率和数据处理便利性,本文给大家介绍Mysql虚拟列的相关知识,感兴趣的朋友一起看看吧... 目录1. 介绍mysql虚拟列1.1 定义和作用1.2 虚拟列与普通列的区别2. MySQL虚拟列的类型2

使用MongoDB进行数据存储的操作流程

《使用MongoDB进行数据存储的操作流程》在现代应用开发中,数据存储是一个至关重要的部分,随着数据量的增大和复杂性的增加,传统的关系型数据库有时难以应对高并发和大数据量的处理需求,MongoDB作为... 目录什么是MongoDB?MongoDB的优势使用MongoDB进行数据存储1. 安装MongoDB

关于@MapperScan和@ComponentScan的使用问题

《关于@MapperScan和@ComponentScan的使用问题》文章介绍了在使用`@MapperScan`和`@ComponentScan`时可能会遇到的包扫描冲突问题,并提供了解决方法,同时,... 目录@MapperScan和@ComponentScan的使用问题报错如下原因解决办法课外拓展总结@

mysql数据库分区的使用

《mysql数据库分区的使用》MySQL分区技术通过将大表分割成多个较小片段,提高查询性能、管理效率和数据存储效率,本文就来介绍一下mysql数据库分区的使用,感兴趣的可以了解一下... 目录【一】分区的基本概念【1】物理存储与逻辑分割【2】查询性能提升【3】数据管理与维护【4】扩展性与并行处理【二】分区的

使用Python实现在Word中添加或删除超链接

《使用Python实现在Word中添加或删除超链接》在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能,本文将为大家介绍一下Python如何实现在Word中添加或... 在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能。通过添加超