O2O : Finetuning Offline World Models in the Real World

2024-06-04 02:12

本文主要是介绍O2O : Finetuning Offline World Models in the Real World,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

CoRL 2023 Oral
paper
code

Intro

算法基于TD-MPC,利用离线数据训练世界模型,然后在线融合基于集成Q的不确定性估计实现Planning。得到的在线数据将联合离线数据共同训练目标策略。

Method

在这里插入图片描述

TD-MPC

TD-MPC由五部分构成:

  1. 状态特征提取 z = h θ ( s ) z = h_\theta(s) z=hθ(s)
  2. 隐动力学模型 z ′ ‘ = d θ ( z , a ) z'`=d_\theta(z,a) z=dθ(z,a)
  3. 奖励模型 r ^ = R θ ( z , a ) \hat{r}=R_\theta(z,a) r^=Rθ(z,a)
  4. planning policy a ^ = π θ ( z ) \hat{a}=\pi_\theta(z) a^=πθ(z)
  5. 终止状态下的 q ^ = Q θ ( z , a ) \hat{q}=Q_\theta(z,a) q^=Qθ(z,a)

通过联合训练进行优化,损失函数为:
L ( θ ) = E ( s , a , r , s ′ ) 0 : h ∼ B ⌊ ∑ t = 0 h ( ∥ z t ′ − s g ( h ϕ ( s t ′ ) ) ∥ 2 2 ⏟ Latent dynamics + ∥ r ^ t − r t ∥ 2 2 ⏟ Reward + ∥ q ^ t − q t ∥ 2 2 ⏟ Value − Q θ ( z t , a ^ t ) ⏟ Action ) ⌋ ( 1 ) \mathcal{L}(\theta)=\mathbb{E}_{(\mathbf{s},\mathbf{a},r,\mathbf{s}^{\prime})_{0:h}\sim\mathcal{B}}\left\lfloor\sum_{t=0}^{h}\left(\underbrace{\|\mathbf{z}_{t}^{\prime}-\mathrm{sg}(h_{\phi}(\mathbf{s}_{t}^{\prime}))\|_{2}^{2}}_{\text{Latent dynamics}}+\underbrace{\|\hat{r}_{t}-r_{t}\|_{2}^{2}}_{\text{Reward}}+\underbrace{\|\hat{q}_{t}-q_{t}\|_{2}^{2}}_{\text{Value}}-\underbrace{Q_{\theta}(\mathbf{z}_{t},\hat{\mathbf{a}}_{t})}_{\text{Action}}\right)\right\rfloor(1) L(θ)=E(s,a,r,s)0:hB t=0h Latent dynamics ztsg(hϕ(st))22+Reward r^trt22+Value q^tqt22Action Qθ(zt,a^t) (1)
在Offline 设定下,分布偏移将导致Q估计以及隐模型以及价值函数的错误估计。启发于IQL,通过只对in-sample的动作尽心TD-backups来估计,缓解过估计问题。因此对模型价值函数利用离线数据进行训练时,此时Q函数采用IQL中的期望回归方法优化
L V ( θ ) = ∣ τ − 1 { Q ϕ ( z t , a t ) − V θ ( z t ) < 0 } ∣ ( Q ϕ ( z t , a t ) − V θ ( z t ) ) 2 , \mathcal{L}_{V}(\theta)=|\tau-1_{\{Q_{\phi}(\mathbf{z}_{t},\mathbf{a}_{t})-V_{\theta}(\mathbf{z}_{t})<0\}}|(Q_{\phi}(\mathbf{z}_{t},\mathbf{a}_{t})-V_{\theta}(\mathbf{z}_{t}))^{2}, LV(θ)=τ1{Qϕ(zt,at)Vθ(zt)<0}(Qϕ(zt,at)Vθ(zt))2,
同时对planning policy采用AWR的更新,即 exp ⁡ ( β ( Q ϕ ( z t , a t ) − V θ ( z t ^ ) ) ) log ⁡ π θ ( a t ∣ z t ) \exp(\beta(Q_\phi(\mathbf{z}_t,\mathbf{a}_t)-V_\theta(\hat{\mathbf{z}_t})))\log\pi_\theta(\mathbf{a}_t|\mathbf{z}_t) exp(β(Qϕ(zt,at)Vθ(zt^)))logπθ(atzt)

Uncertainty Estimation as Test-Time Behavior Regularizatio

离线训练的模型依旧存在OOD数据过估计,需要在线微调。文章提出基于不确定性估计的planning实现在线交互过程中的动作选择。planning一定程度缓解基于约束的离线算法导致的在现阶段探索能力不足。进而导致算法样本效率低的问题。

首先构建集成Q函数模型,计算基于标准差的不确信度,作为惩罚项对奖励进行调整,实现保守的在线planning。
R ^ = γ h ( Q θ ( z h , a h ) − λ u h ) + ∑ t = 0 h − 1 γ t ( R θ ( z t , a t ) − λ u t ) , u t = s t d ( { Q θ ( i ) ( z t , a t ) } i = 1 N ) \hat{\mathcal{R}}=\gamma^{h}\left(Q_{\theta}(\mathbf{z}_{h},\mathbf{a}_{h})-\lambda u_{h}\right)+\sum_{t=0}^{h-1}\gamma^{t}\left(R_{\theta}(\mathbf{z}_{t},\mathbf{a}_{t})-\lambda u_{t}\right),\quad u_{t}=\mathrm{std}\left(\{Q_{\theta}^{(i)}(\mathbf{z}_{t},\mathbf{a}_{t})\}_{i=1}^{N}\right) R^=γh(Qθ(zh,ah)λuh)+t=0h1γt(Rθ(zt,at)λut),ut=std({Qθ(i)(zt,at)}i=1N)

除此外,还维护两个buffer分别存储离线数据于在线数据,通过balance sampling数据训练模型、策略以及价值函数。

结果

在这里插入图片描述
在这里插入图片描述

这篇关于O2O : Finetuning Offline World Models in the Real World的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1028776

相关文章

Retrieval-Augmented Generation for Large Language Models A Survey

Retrieval-Augmented Generation for Large Language Models: A Survey 文献综述 文章目录 Retrieval-Augmented Generation for Large Language Models: A Survey 文献综述 Abstract背景介绍 RAG概述原始RAG先进RAG预检索过程后检索过程 模块化RAGMo

AI 大模型企业应用实战(10)-LLMs和Chat Models

1 模型 来看两种不同类型的模型--LLM 和聊天模型。然后,它将介绍如何使用提示模板来格式化这些模型的输入,以及如何使用输出解析器来处理输出。 LangChain 中的语言模型有两种类型: 1.1 Chat Models 聊天模型通常由 LLM 支持,但专门针对会话进行了调整。提供者 API 使用与纯文本补全模型不同的接口。它们的输入不是单个字符串,而是聊天信息列表,输出则是一条人工智能

AJAX:如何编写一个关于AJAX的Hello World?(ajax发送异步请求(四步操作))

用到的一个Servlet类: package cn.edu.web.servlet;import java.io.IOException;import javax.servlet.ServletException;import javax.servlet.annotation.WebServlet;import javax.servlet.http.HttpServlet;impor

oracle学习之第一个存储过程:打印Hello World

数据库对象:表、视图、索引、序列、同义词、存储过程、存储函数 存储过程:指的是存储在数据库中供所有用户程序调用的子程序叫存储过程、存储函数 存储过程和存储函数的相同点:完成特定功能的程序 存储过程和存储函数的区别:是否用return语句返回值(存储函数可以,但是存储过程不行) --第一个存储过程:打印Hello World/*调用存储过程2种方式:1、exec sayhellow

在Mac OS上使用Visual Studio Code创建C++ Qt的Hello World应用

引言 Qt是一个跨平台的应用程序和用户界面框架,而Visual Studio Code是一个功能强大的编辑器,两者结合可以极大地提升开发效率。本文将指导你在Mac OS上使用Visual Studio Code创建一个简单的Qt 'Hello World'窗口应用。 环境准备 确保你的MacBook OS运行最新的操作系统。安装Homebrew,Mac OS的包管理器。通过Homebrew安装

SpringBoot (一) :入门篇 Hello World

什么是SpringBoot Spring Boot是由Pivotal团队提供的全新框架,其设计目的是用来简化新Spring应用的初始搭建以及开发过程。该框架使用了特定的方式来进行配置,从而使开发人员不再需要定义样板化的配置。通过这种方式,Spring Boot致力于在蓬勃发展的快速应用开发领域(rapid application development)成为领导者。 SpringBoot有什么

从同—视角理解扩散模型(Understanding Diffusion Models A Unified Perspective)

从同—视角理解扩散模型 Understanding Diffusion Models A Unified Perspective【全公式推导】【免费视频讲解】 B站视频讲解 视频的论文笔记 从同一视角理解扩散模型【视频讲解笔记】 配合视频讲解的同步笔记。 整个系列完整的论文笔记内容如下,仅为了不用—一回复,共计14个视频讲解笔记,故设定了一个比较低的价格(粉丝仅6毛),大家可以自取。

Autoencoder(AE)、Variational Autoencoder(VAE)和Diffusion Models(DM)了解

Autoencoder (AE) 工作原理: Autoencoder就像一个数据压缩机器。它由两部分组成: 编码器:将输入数据压缩成一个小小的代码。解码器:将这个小代码还原成尽可能接近原始输入的数据。 优点和应用: 简单易懂:用于学习数据的特征和去除噪声。应用场景:例如可以用来缩小图像的大小但保留关键特征,或者去除文本数据中的错误。 挑战: 数据损坏:如果输入数据太乱,编码器可能无法有

android (No cached version available for offline mode)----bug解析处理

错误日志 Execution failed for task ':base:generateDebugRFile'.> Could not resolve all files for configuration ':base:debugCompileClasspath'.> Could not download core-1.3.0.aar (androidx.core:core:1.3.0)

论文阅读--Cross-view Transformers for real-time Map-view Semantic Segmentation

一种新的2D维度的bev特征提取方案,其通过引入相机先验信息(相机内参和外参)构建了一个多视图交叉注意力机制,能够将多视图特征映射为BEV特征。 cross view attention:BEV位置编码+由根据相机标定结果(内参和外参)演算得到的相机位置编码+多视图特征做attention得到 整体上文章的网络前端使用CNN作为特征抽取网络,中端使用CNN多级特征作为输入在多视图下优化BEV特