强化学习实践(二):Dynamic Programming(Value \ Policy Iteration)

2024-09-03 03:52

本文主要是介绍强化学习实践(二):Dynamic Programming(Value \ Policy Iteration),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

强化学习实践(二):Dynamic Programming(Value \ Policy Iteration)

  • 伪代码
    • Value Iteration
    • Policy Iteration
    • Truncated Policy Iteration
  • 代码
  • 项目地址

伪代码

具体的理解可以看理论学习篇,以及代码中的注释,以及赵老师原著

Value Iteration

在这里插入图片描述

Policy Iteration

在这里插入图片描述

Truncated Policy Iteration

在这里插入图片描述

代码

import numpy as npfrom environment.env import Env
from environment.vis import Visclass DynamicProgramming:"""动态规划的两个方法, 实际都为Truncated Policy Iteration, 具体代码尽量复刻伪代码的逻辑"""def __init__(self, gamma: float = 0.9, env: Env = None, vis: Vis = None, render: bool = False):self.gamma = gammaself.env = envself.vis = visself.render = renderself.policy = np.zeros(shape=(self.env.state_space_size, self.env.action_space_size), dtype=int)self.qtable = np.zeros(shape=self.env.state_space_size, dtype=float)def value_iteration(self, threshold: float = 0.01) -> None:"""计算每个状态动作对的状态动作价值,然后每个状态选择最大的值对应的动作作为自己的策略,并将值作为自己的状态价值根据Contraction Mapping Theorem, qsa的计算公式满足该理论要求,通过迭代不断优化全局状态价值,并找到对应的最优策略:param threshold: 迭代结束的阈值,前后两次迭代后的全局状态价值的欧氏距离相差小于该阈值时代表优化空间已经不大,结束优化:return: None"""differ = np.infwhile differ > threshold:kth_qtable = self.qtable.copy()for state in self.env.state_space:qsa = np.zeros(shape=self.env.action_space_size, dtype=float)for action in self.env.action_space:qsa[action] = self.calculate_qvalue(state, action)self.policy[state] = np.zeros(shape=self.env.action_space_size)self.policy[state, np.argmax(qsa)] = 1self.qtable[state] = np.max(qsa)differ = np.linalg.norm(kth_qtable - self.qtable, ord=1)if self.render:self.vis.show_policy(self.policy)self.vis.show_value(self.qtable)self.vis.show()def policy_iteration(self, policy_threshold: float = 0.01, value_threshold: float = 0.01, steps: int = 10) -> None:"""step 1:从初始策略开始,求解该策略对应的全局状态价值(在这个过程中本来要无穷次迭代得到真正的状态价值,但实际会设置阈值,截断策略迭代算法)step 2:拿到第K次迭代对应的策略求解出的全局状态价值之后,利用该价值作为初始值,再进行全局状态价值优化以及策略优化这个过程其实相较于值迭代比较难理解Q1:In the policy evaluation step, how to get the state value vπk by solving the Bellman equation?A1:x=f(x)这种满足Contraction Mapping Theorem的迭代求解方式(也可以解析解matrix vector form,但是涉及矩阵逆运算会很慢O(n^3))Q2*:In the policy improvement step, why is the new policy πk+1 better than πk?A2:直观上不是很好理解就得利用数学工具了,赵老师原著Chapter4.P73页对比了前后两次迭代证明了Vπk - Vπk+1 < 0Q3*:Why can this algorithm finally converge to an optimal policy?A3:Chapter4.P75页不仅证明了能达到最优,而且引入这种PE过程会收敛得更快,证明了Vπk>Vk,同一个迭代timing,策略迭代状态价值更接近最优:param policy_threshold: 策略阈值:param value_threshold: 全局状态价值阈值:param steps: 截断的最大迭代次数,只用阈值也行,但这样更方便说明:return: None"""policy_differ = np.infself.init_policy()while policy_differ > policy_threshold:kth_policy = self.policy.copy()# step 1: policy evaluationvalue_differ = np.infwhile value_differ > value_threshold and steps > 0:steps -= 1kth_qtable = self.qtable.copy()for state in self.env.state_space:state_value = 0for action in self.env.action_space:state_value += self.policy[state, action] * self.calculate_qvalue(state, action)self.qtable[state] = state_valuevalue_differ = np.linalg.norm(kth_qtable - self.qtable, ord=1)# step 2: policy improvement 相当于上面的PE给下面提供了一个初始状态(对应策略),之前值迭代的时候是全0为初始值value_differ = np.infwhile value_differ > value_threshold:kth_qtable = self.qtable.copy()for state in self.env.state_space:qsa = np.zeros(shape=self.env.action_space_size, dtype=float)for action in self.env.action_space:qsa[action] = self.calculate_qvalue(state, action)self.policy[state] = np.zeros(shape=self.env.action_space_size)self.policy[state, np.argmax(qsa)] = 1self.qtable[state] = np.max(qsa)value_differ = np.linalg.norm(kth_qtable - self.qtable, ord=1)policy_differ = np.linalg.norm(kth_policy - self.policy, ord=1)if self.render:self.vis.show_policy(self.policy)self.vis.show_value(self.qtable)self.vis.show()def init_policy(self) -> None:"""之前值迭代可以不用初始化,因为只对policy进行了更新,现在策略迭代得初始化,因为首先就要利用policy进行PE:return: None"""random_action = np.random.randint(self.env.action_space_size, size=self.env.state_space_size)for state, action in enumerate(random_action):self.policy[state, action] = 1def calculate_qvalue(self, state: int, action: int) -> float:"""计算状态动作价值函数的元素展开式, 这里就能理解为什么环境模型为什么是这样的数据结构:param state: 当前状态:param action: 当前动作:return: 当前的状态动作价值"""qvalue = 0# immediately reward: sigma(r * p(r | s, a))for reward_type in range(self.env.reward_space_size):qvalue += self.env.reward_space[reward_type] * self.env.rewards_model[state, action, reward_type]# next state expected reward : sigma(vk(s') * p(s' | s, a))for next_state in range(self.env.state_space_size):qvalue += self.gamma * self.env.states_model[state, action, next_state] * self.qtable[next_state]return qvalueif __name__ == "__main__":start_state = [0, 0]target_state = [2, 3]forbid = [[2, 2], [2, 1], [1, 1], [3, 3], [1, 3], [1, 4]]model = DynamicProgramming(vis=Vis(target_state=target_state, forbid=forbid),env=Env(target_state=target_state, forbid=forbid),render=True)model.value_iteration()# model.policy_iteration()

项目地址

RL_Algorithms(正在逐步更新多智能体的算法,STAR HOPE(^ - ^)

这篇关于强化学习实践(二):Dynamic Programming(Value \ Policy Iteration)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1131949

相关文章

Java调用DeepSeek API的最佳实践及详细代码示例

《Java调用DeepSeekAPI的最佳实践及详细代码示例》:本文主要介绍如何使用Java调用DeepSeekAPI,包括获取API密钥、添加HTTP客户端依赖、创建HTTP请求、处理响应、... 目录1. 获取API密钥2. 添加HTTP客户端依赖3. 创建HTTP请求4. 处理响应5. 错误处理6.

golang内存对齐的项目实践

《golang内存对齐的项目实践》本文主要介绍了golang内存对齐的项目实践,内存对齐不仅有助于提高内存访问效率,还确保了与硬件接口的兼容性,是Go语言编程中不可忽视的重要优化手段,下面就来介绍一下... 目录一、结构体中的字段顺序与内存对齐二、内存对齐的原理与规则三、调整结构体字段顺序优化内存对齐四、内

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

C++实现封装的顺序表的操作与实践

《C++实现封装的顺序表的操作与实践》在程序设计中,顺序表是一种常见的线性数据结构,通常用于存储具有固定顺序的元素,与链表不同,顺序表中的元素是连续存储的,因此访问速度较快,但插入和删除操作的效率可能... 目录一、顺序表的基本概念二、顺序表类的设计1. 顺序表类的成员变量2. 构造函数和析构函数三、顺序表

python实现简易SSL的项目实践

《python实现简易SSL的项目实践》本文主要介绍了python实现简易SSL的项目实践,包括CA.py、server.py和client.py三个模块,文中通过示例代码介绍的非常详细,对大家的学习... 目录运行环境运行前准备程序实现与流程说明运行截图代码CA.pyclient.pyserver.py参

使用C++实现单链表的操作与实践

《使用C++实现单链表的操作与实践》在程序设计中,链表是一种常见的数据结构,特别是在动态数据管理、频繁插入和删除元素的场景中,链表相比于数组,具有更高的灵活性和高效性,尤其是在需要频繁修改数据结构的应... 目录一、单链表的基本概念二、单链表类的设计1. 节点的定义2. 链表的类定义三、单链表的操作实现四、

Spring Boot统一异常拦截实践指南(最新推荐)

《SpringBoot统一异常拦截实践指南(最新推荐)》本文介绍了SpringBoot中统一异常处理的重要性及实现方案,包括使用`@ControllerAdvice`和`@ExceptionHand... 目录Spring Boot统一异常拦截实践指南一、为什么需要统一异常处理二、核心实现方案1. 基础组件

SpringBoot项目中Maven剔除无用Jar引用的最佳实践

《SpringBoot项目中Maven剔除无用Jar引用的最佳实践》在SpringBoot项目开发中,Maven是最常用的构建工具之一,通过Maven,我们可以轻松地管理项目所需的依赖,而,... 目录1、引言2、Maven 依赖管理的基础概念2.1 什么是 Maven 依赖2.2 Maven 的依赖传递机

Oracle查询优化之高效实现仅查询前10条记录的方法与实践

《Oracle查询优化之高效实现仅查询前10条记录的方法与实践》:本文主要介绍Oracle查询优化之高效实现仅查询前10条记录的相关资料,包括使用ROWNUM、ROW_NUMBER()函数、FET... 目录1. 使用 ROWNUM 查询2. 使用 ROW_NUMBER() 函数3. 使用 FETCH FI

在C#中获取端口号与系统信息的高效实践

《在C#中获取端口号与系统信息的高效实践》在现代软件开发中,尤其是系统管理、运维、监控和性能优化等场景中,了解计算机硬件和网络的状态至关重要,C#作为一种广泛应用的编程语言,提供了丰富的API来帮助开... 目录引言1. 获取端口号信息1.1 获取活动的 TCP 和 UDP 连接说明:应用场景:2. 获取硬