DI-engine强化学习入门(七)如何自定义神经网络模型

2024-05-12 14:20

本文主要是介绍DI-engine强化学习入门(七)如何自定义神经网络模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在强化学习中,需要根据决策问题和策略选择合适的神经网络。DI-engine中,神经网络模型可以通过两种方式指定:

  1. 使用配置文件中的cfg.policy.model自动生成默认模型。这种方式下,可以在配置文件中指定神经网络的类型(MLP、CNN等)以及超参数(隐层大小、激活函数等),DI-engine会根据这些配置自动构建神经网络模型。这种方式简单易用,适用于常见的标准网络结构。
  2. 自定义模型实例并传入Policy。这种方式下,需要用户自己定义Tensorflow/Pytorch模型类,实现前向传播等接口,然后将实例传入Policy中。这种方式灵活度高,用户可以自由设计任意结构的神经网络。但是需要用户比较熟悉网络定义和 Tensorflow/Pytorch接口。

(注:在强化学习中,策略(Policy)是指智能体(Agent)决策的规则。策略是从状态(State)到动作(Action)的映射,它定义了在给定的状态下,智能体应该采取什么动作。策略可以是确定性的(Deterministic)也可以是随机性的(Stochastic)。)
以上两种方式都会在Policy中封装为neural_net属性,策略学习会通过这个网络完成状态的 embedding 以及动作的选择。这套机制和接口为用户提供了必要的灵活性,可以根据具体问题和需求配置各种神经网络模型。
这是红色的文字

Policy 默认使用的模型是什么
DI-engine 中已经实现的 policy,默认使用 default_model 方法中表明的神经网络模型,例如在 SACPolicy 中:

@POLICY_REGISTRY.register('sac')class SACPolicy(Policy):...    def default_model(self) -> Tuple[str, List[str]]:        if self._cfg.multi_agent:            return 'maqac_continuous', ['ding.model.template.maqac']        else:            return 'qac', ['ding.model.template.qac']...

在这段代码中,我们看到的是一个名为 DI-engine 的强化学习框架中的一个策略(Policy)类的一部分。具体来说,它定义了一个使用Soft Actor-Critic, SAC 算法的策略类。这个段落描述了如何在这个框架内设置和使用策略相关的神经网络模型。

让我们逐步解释这段代码:

  1. @POLICY_REGISTRY.register('sac') 是一个装饰器,它将 SACPolicy 类注册到一个名为 POLICY_REGISTRY 的注册器中,并且用 'sac' 作为这个策略的标识符。这样的注册机制允许框架能够根据名字轻松地查找和实例化策略。
  2. class SACPolicy(Policy): 表明 SACPolicy 是从更一般的 Policy 类派生的,它是一个具体的策略实现,使用了 SAC 算法。
  3. default_model 是 SACPolicy 类的一个方法,它定义了该策略默认使用的模型。这个方法返回两个元素的元组:
  • 'maqac_continuous' 或 'qac':这是在模型注册器中注册的模型名字。根据配置是否是多智能体(multi_agent),它返回不同的模型名。
  • ['ding.model.template.maqac'] 或 ['ding.model.template.qac']:这是包含模型类的文件路径的列表。这个路径告诉 DI-engine 在哪里可以找到定义模型的代码。

4.当使用配置文件时,DI-engine 的入口文件将使用 cfg.policy.model 中的参数(比如 obs_shape, action_shape)来实例化提供的模型类。这个过程是自动化的,意味着用户定义好配置,DI-engine 将负责根据这些配置创建并初始化模型。

5.模型类会根据传入的参数构造适当的神经网络。例如,如果传入的 obs_shape 参数表明观测是一个图像,则模型可能会使用卷积层来处理输入;如果观测是一个向量,则可能使用全连接层。

简而言之,这段代码展示了 DI-engine 如何灵活地处理不同类型的策略和模型,以及如何通过配置文件来方便地自定义和实例化这些策略和模型。这种设计允许研究者和开发者能够轻松试验不同的算法和模型架构,而无需直接修改代码。

如何自定义神经网络模型

在 DI-engine 强化学习框架中,每个策略(如 SACPolicy)通常有一个关联的默认模型(通过 default_model 方法指定),这个默认模型是为特定类型的任务设计的。例如,原始的 qac 模型可能是为处理具有一维观测空间的环境设计的,即观测是一个向量。

但是!如果任务是在一个模拟器(如 dmc2gym,一个DeepMind Control Suite到OpenAI Gym接口的适配器)上运行,并且任务是 cartpole-swingup,而且你希望使用观测为像素的输入(即观测是一个图像),那么默认的 qac 模型不足以处理这样的高维度和多通道的输入。在这种情况下,观测空间的形状是 (3, height, width),其中 3 表示图像的颜色通道数(RGB),height 和 width 分别表示图像的高度和宽度。

在 dmc2gym 文档中,from_pixel 参数设定为 True 意味着环境将提供像素级的观测,而 channels_first 设定为 True 表明图像的通道维度是第一维(这是PyTorch等深度学习库通常采用的格式)。

面对这样的情况,如果你想要使用 SAC 算法处理像素级的观测,你需要自定义一个能够处理这种高维观测的模型。所以我们创建一个新的模型类,该类在内部使用卷积神经网络(CNN)来处理输入的图像数据,并适当地修改网络架构以适应任务的特定要求。

自定义模型完成后,可以将这个模型应用到 SACPolicy 中,替换原本的 qac 模型。涉及到以下几个步骤:

  1. 实现一个新的模型类,它继承自某个基础模型类,并覆盖必要的方法以支持像素级输入。
  2. 在策略配置中指定你的自定义模型,以便 DI-engine 使用你提供的模型而不是默认模型。
  3. 确保你的自定义模型注册到 DI-engine 的模型注册器中,这样框架可以识别和使用它。

自定义 model 基本步骤
1. 明确环境 (env) 和策略 (policy)
首先,需要确定你的强化学习任务的具体环境和任务。例如,我们选择 dmc2gym 环境中的 cartpole-swingup 任务,并且决定观测将以像素数据的形式提供,我们的观测空间是一个图像,其形状为 (3, height, width)。下面我们使用 SAC 算法来进行学习。

在这里,from_pixel = True 表明环境将提供基于像素的观测,channels_first = True 表明图像数据的通道维度在前,这通常是深度学习库(如 PyTorch)的标准格式。

2. 查阅策略中的 default_model 是否适用
接下来,需要检查选择的策略是否具有适用于任务的默认模型。这可以通过查看策略的文档或直接查阅源代码来完成。以 DI-engine 中的 SAC 策略为例,可以查看 SACPolicy 类中的 default_model 方法来了解默认模型:

如果进一步看一下 ding.model.template.qac 中的 QAC 模型,咱们可能会发现它仅支持一维的观测空间,而不支持像 (3, height, width) 这样的图像形状。这意味着对于我们的 cartpole-swingup 任务,需要创建一个自定义模型来处理像素级的观测。

3. 自定义模型 (custom_model) 实现
自定义模型的实现需要遵循一些基本的原则,以确保与 DI-engine 框架的兼容性。

a. 实现功能
自定义模型需要实现默认模型中的所有公共方法。包括:

  • __init__: 构造函数,对模型的各个部分进行初始化。
  • forward: 定义模型如何从输入到输出的前向传递。
  • compute_actor: 计算策略网络的输出,即给定观测值时的动作。
  • compute_critic: 计算价值网络的输出,即动作的价值。

b. 保持返回值类型一致
自定义模型的方法需要保证与原始默认模型的返回值类型一致,以便于替换使用。

c. 利用已实现的 encoder 和 head
在 ding/model/common 下有多种 encoder 和 head 的实现,可用于构建不同部分的模型:

  • Encoder: 负责对输入数据进行编码,使其适合后续的处理。例如,ConvEncoder 用于处理图像观测输入,FCEncoder 用于处理一维观测输入。

点击DI-engine强化学习入门(七)如何自定义神经网络模型 - 古月居可查看全文

这篇关于DI-engine强化学习入门(七)如何自定义神经网络模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/982863

相关文章

51单片机学习记录———定时器

文章目录 前言一、定时器介绍二、STC89C52定时器资源三、定时器框图四、定时器模式五、定时器相关寄存器六、定时器练习 前言 一个学习嵌入式的小白~ 有问题评论区或私信指出~ 提示:以下是本篇文章正文内容,下面案例可供参考 一、定时器介绍 定时器介绍:51单片机的定时器属于单片机的内部资源,其电路的连接和运转均在单片机内部完成。 定时器作用: 1.用于计数系统,可

问题:第一次世界大战的起止时间是 #其他#学习方法#微信

问题:第一次世界大战的起止时间是 A.1913 ~1918 年 B.1913 ~1918 年 C.1914 ~1918 年 D.1914 ~1919 年 参考答案如图所示

[word] word设置上标快捷键 #学习方法#其他#媒体

word设置上标快捷键 办公中,少不了使用word,这个是大家必备的软件,今天给大家分享word设置上标快捷键,希望在办公中能帮到您! 1、添加上标 在录入一些公式,或者是化学产品时,需要添加上标内容,按下快捷键Ctrl+shift++就能将需要的内容设置为上标符号。 word设置上标快捷键的方法就是以上内容了,需要的小伙伴都可以试一试呢!

AssetBundle学习笔记

AssetBundle是unity自定义的资源格式,通过调用引擎的资源打包接口对资源进行打包成.assetbundle格式的资源包。本文介绍了AssetBundle的生成,使用,加载,卸载以及Unity资源更新的一个基本步骤。 目录 1.定义: 2.AssetBundle的生成: 1)设置AssetBundle包的属性——通过编辑器界面 补充:分组策略 2)调用引擎接口API

Javascript高级程序设计(第四版)--学习记录之变量、内存

原始值与引用值 原始值:简单的数据即基础数据类型,按值访问。 引用值:由多个值构成的对象即复杂数据类型,按引用访问。 动态属性 对于引用值而言,可以随时添加、修改和删除其属性和方法。 let person = new Object();person.name = 'Jason';person.age = 42;console.log(person.name,person.age);//'J

一份LLM资源清单围观技术大佬的日常;手把手教你在美国搭建「百万卡」AI数据中心;为啥大模型做不好简单的数学计算? | ShowMeAI日报

👀日报&周刊合集 | 🎡ShowMeAI官网 | 🧡 点赞关注评论拜托啦! 1. 为啥大模型做不好简单的数学计算?从大模型高考数学成绩不及格说起 司南评测体系 OpenCompass 选取 7 个大模型 (6 个开源模型+ GPT-4o),组织参与了 2024 年高考「新课标I卷」的语文、数学、英语考试,然后由经验丰富的判卷老师评判得分。 结果如上图所

大学湖北中医药大学法医学试题及答案,分享几个实用搜题和学习工具 #微信#学习方法#职场发展

今天分享拥有拍照搜题、文字搜题、语音搜题、多重搜题等搜题模式,可以快速查找问题解析,加深对题目答案的理解。 1.快练题 这是一个网站 找题的网站海量题库,在线搜题,快速刷题~为您提供百万优质题库,直接搜索题库名称,支持多种刷题模式:顺序练习、语音听题、本地搜题、顺序阅读、模拟考试、组卷考试、赶快下载吧! 2.彩虹搜题 这是个老公众号了 支持手写输入,截图搜题,详细步骤,解题必备

C++必修:模版的入门到实践

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ 🎈🎈养成好习惯,先赞后看哦~🎈🎈 所属专栏:C++学习 贝蒂的主页:Betty’s blog 1. 泛型编程 首先让我们来思考一个问题,如何实现一个交换函数? void swap(int& x, int& y){int tmp = x;x = y;y = tmp;} 相信大家很快就能写出上面这段代码,但是如果要求这个交换函数支持字符型

零基础STM32单片机编程入门(一)初识STM32单片机

文章目录 一.概要二.单片机型号命名规则三.STM32F103系统架构四.STM32F103C8T6单片机启动流程五.STM32F103C8T6单片机主要外设资源六.编程过程中芯片数据手册的作用1.单片机外设资源情况2.STM32单片机内部框图3.STM32单片机管脚图4.STM32单片机每个管脚可配功能5.单片机功耗数据6.FALSH编程时间,擦写次数7.I/O高低电平电压表格8.外设接口

《offer来了》第二章学习笔记

1.集合 Java四种集合:List、Queue、Set和Map 1.1.List:可重复 有序的Collection ArrayList: 基于数组实现,增删慢,查询快,线程不安全 Vector: 基于数组实现,增删慢,查询快,线程安全 LinkedList: 基于双向链实现,增删快,查询慢,线程不安全 1.2.Queue:队列 ArrayBlockingQueue: