pytorch强化学习(2)——重写DQN

2024-03-07 01:36
文章标签 学习 重写 pytorch 强化 dqn

本文主要是介绍pytorch强化学习(2)——重写DQN,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

思路

在q-learning当中,Q函数的输入是状态state和action,输出是q-value。

而DQN就是使用神经网络来拟合Q函数,所以从直观上来说,我觉得神经网络的输入应该是状态state和action,输出应该是q-value。

但是,网上绝大多数DQN的代码实现都把state作为网络输入,把所有action的q-value的组合作为网络输出。我觉得这是不直观的、令人费解的,于是我按照自己的想法写了一份DQN代码。

在下面的代码中,神经网络的输入是state和action的连接,若干个浮点数表示state,一个整数表示action。神经网络的输出只有一个元素,代表q-value的值。

代码

env.py

import gym
from DQN_brain import DQN
import matplotlib.pyplot as plt
import numpylr = 1e-3  # 学习率
gamma = 0.9  # 折扣因子
epsilon = 0.9  # 贪心系数
n_hidden = 50  # 隐含层神经元个数env = gym.make("CartPole-v1")
n_states = env.observation_space.shape[0]  # 4
n_actions = env.action_space.n  # 2 动作的个数dqn = DQN(n_states, n_hidden, n_actions, lr, gamma, epsilon)if __name__ == '__main__':reward_list = []for i in range(100):# 获取初始环境state = env.reset()[0]  # len=4total_reward = 0done = Falsewhile True:# 获取最优动作action = dqn.optimal_action(state)# 有一定概率不采取最优动作,而是随机选择一个动作执行,这一点很重要if numpy.random.random() > epsilon:action = numpy.random.randint(n_actions)# 更新环境next_state, reward, done, _, _ = env.step(action)dqn.learning(state, next_state, action, reward, done)# 更新一些变量state = next_statetotal_reward += rewardif done:breakprint("第%d回合,total_reward=%f" % (i, total_reward))reward_list.append(total_reward)# 绘图episodes_list = list(range(len(reward_list)))plt.plot(episodes_list, reward_list)plt.xlabel('Episodes')plt.ylabel('Returns')plt.title('DQN Returns')plt.show()

DQN_brain.py

import torch
from torch import nn, Tensorclass Net(nn.Module):# 构造有2个隐含层的网络def __init__(self, input_dim: int, n_hidden: int, output_dim: int):super().__init__()self.network = nn.Sequential(torch.nn.Linear(input_dim, n_hidden, dtype=torch.float),torch.nn.ReLU(),torch.nn.Linear(n_hidden, n_hidden, dtype=torch.float),torch.nn.ReLU(),torch.nn.Linear(n_hidden, n_hidden, dtype=torch.float),torch.nn.ReLU(),torch.nn.Linear(n_hidden, output_dim, dtype=torch.float),)# 前传,直接调用Net对象,其实就是调用forward函数def forward(self, x):  # [b,n_states]return self.network(x)class DQN:def __init__(self, n_states: int, n_hidden: int, n_actions: int, lr: float, gamma: float, epsilon: float):# 属性分配self.n_states = n_states  # 状态的特征数self.n_hidden = n_hidden  # 隐含层个数self.n_actions = n_actions  # 动作数self.lr = lr  # 训练时的学习率self.gamma = gamma  # 折扣因子,对下一状态的回报的缩放self.epsilon = epsilon  # 贪婪策略,有1-epsilon的概率探索# 实例化训练网络,网络的输入是state+action,# 网络的输出是只有一个元素的一维向量,代表该动作在该状态下的q-valueself.q_net = Net(self.n_states + 1, self.n_hidden, 1)# 优化器,更新训练网络的参数self.q_optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)self.criterion = torch.nn.MSELoss()  # 损失函数# 把状态和动作转化为tensor并连接起来def _concat_input(self, state: list[float], action: int):state_tensor = torch.tensor(state, dtype=torch.float)action_tensor = torch.tensor([action], dtype=torch.float)return torch.concat([state_tensor, action_tensor])# 获取q-value值最大的actiondef optimal_action(self, state: list[float]):q_values = torch.tensor([], dtype=torch.float)# 获取所有action的q-valuefor action in range(self.n_actions):q_values = torch.concat([q_values, self.get_q_value(state, action)])# 返回值最大的那个下标,item()函数只能对只有单个元素的tensor使用return torch.argmax(q_values).item()# 更新网络def learning(self,state: list[float],next_state: list[float],action: int,reward: float,done: bool) -> None:# 下一状态的最优动作next_optimal_action = self.optimal_action(next_state)# 当前状态q_valueq_value = self.get_q_value(state, action)# 下一状态q_valuenext_q_value = self.get_q_value(next_state, next_optimal_action)# q_target计算q_target = reward + self.gamma * next_q_value * (1. - float(done))# 计算loss,然后反向传播,然后梯度下降loss: Tensor = self.criterion(q_value, q_target)self.q_optimizer.zero_grad()loss.backward()self.q_optimizer.step()# 根据状态和动作获取q_valuedef get_q_value(self, state: list[float], action: int) -> Tensor:return self.q_net(self._concat_input(state, action))# tensor([5.5241], grad_fn=<ViewBackward0>)

这篇关于pytorch强化学习(2)——重写DQN的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/782042

相关文章

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

pytorch+torchvision+python版本对应及环境安装

《pytorch+torchvision+python版本对应及环境安装》本文主要介绍了pytorch+torchvision+python版本对应及环境安装,安装过程中需要注意Numpy版本的降级,... 目录一、版本对应二、安装命令(pip)1. 版本2. 安装全过程3. 命令相关解释参考文章一、版本对

Java进阶学习之如何开启远程调式

《Java进阶学习之如何开启远程调式》Java开发中的远程调试是一项至关重要的技能,特别是在处理生产环境的问题或者协作开发时,:本文主要介绍Java进阶学习之如何开启远程调式的相关资料,需要的朋友... 目录概述Java远程调试的开启与底层原理开启Java远程调试底层原理JVM参数总结&nbsMbKKXJx

从零教你安装pytorch并在pycharm中使用

《从零教你安装pytorch并在pycharm中使用》本文详细介绍了如何使用Anaconda包管理工具创建虚拟环境,并安装CUDA加速平台和PyTorch库,同时在PyCharm中配置和使用PyTor... 目录背景介绍安装Anaconda安装CUDA安装pytorch报错解决——fbgemm.dll连接p

pycharm远程连接服务器运行pytorch的过程详解

《pycharm远程连接服务器运行pytorch的过程详解》:本文主要介绍在Linux环境下使用Anaconda管理不同版本的Python环境,并通过PyCharm远程连接服务器来运行PyTorc... 目录linux部署pytorch背景介绍Anaconda安装Linux安装pytorch虚拟环境安装cu

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用