强化学习原理python篇06——DQN

2024-01-28 06:04
文章标签 python 学习 原理 强化 06 dqn

本文主要是介绍强化学习原理python篇06——DQN,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

强化学习原理python篇05——DQN

  • DQN 算法
    • 定义DQN网络
    • 初始化环境
    • 开始训练
    • 可视化结果

本章全篇参考赵世钰老师的教材 Mathmatical-Foundation-of-Reinforcement-Learning Deep Q-learning 章节,请各位结合阅读,本合集只专注于数学概念的代码实现。

DQN 算法

1)使用随机权重 ( w ← 1.0 ) (w←1.0) w1.0初始化目标网络 Q ( s , a , w ) Q(s, a, w) Q(s,a,w)和网络 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w) Q Q Q Q ^ \hat Q Q^相同,清空回放缓冲区。

2)以概率ε选择一个随机动作a,否则 a = a r g m a x Q ( s , a , w ) a=argmaxQ(s,a,w) a=argmaxQ(s,a,w)

3)在模拟器中执行动作a,观察奖励r和下一个状态s’。

4)将转移过程(s, a, r, s’)存储在回放缓冲区中。

5)从回放缓冲区中采样一个随机的小批量转移过程。

6)对于回放缓冲区中的每个转移过程,如果片段在此步结束,则计算目标 y = r y=r y=r,否则计算 y = r + γ m a x Q ^ ( s , a , w ) y=r+\gamma max \hat Q(s, a, w) y=r+γmaxQ^(s,a,w)

7)计算损失: L = ( Q ( s , a , w ) – y ) 2 L=(Q(s, a, w)–y)^2 L=(Q(s,a,w)y)2

8)固定网络 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w)不变,通过最小化模型参数的损失,使用SGD算法更新 Q ( s , a ) Q(s, a) Q(s,a)

9)每N步,将权重从目标网络 Q Q Q复制到 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w)

10)从步骤2开始重复,直到收敛为止。

定义DQN网络

import collections
import copy
import random
from collections import defaultdict
import math
import gym
import gym.spaces
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gym.envs.toy_text import frozen_lake
from torch.utils.tensorboard import SummaryWriterclass Net(nn.Module):def __init__(self, obs_size, hidden_size, q_table_size):super(Net, self).__init__()self.net = nn.Sequential(# 输入为状态,样本为(1*n)nn.Linear(obs_size, hidden_size),nn.ReLU(),# nn.Linear(hidden_size, hidden_size),# nn.ReLU(),nn.Linear(hidden_size, q_table_size),)def forward(self, state):return self.net(state)class DQN:def __init__(self, env, tgt_net, net):self.env = envself.tgt_net = tgt_netself.net = netdef generate_train_data(self, batch_size, epsilon):state, _ = env.reset()train_data = []while len(train_data)<batch_size*2:q_table_tgt = self.tgt_net(torch.Tensor(state)).detach()if np.random.uniform(0, 1, 1) > epsilon:action = self.env.action_space.sample()else:action = int(torch.argmax(q_table_tgt))new_state, reward,terminated, truncted, info = env.step(action)train_data.append([state, action, reward, new_state, terminated])state = new_stateif terminated:state, _ = env.reset()continuerandom.shuffle(train_data)                return train_data[:batch_size]def calculate_y_hat_and_y(self, batch):# 6)对于回放缓冲区中的每个转移过程,如果片段在此步结束,则计算目标$y=r$,否则计算$y=r+\gamma max \hat Q(s, a, w)$ 。y = []state_space = []action_space = []for state, action, reward, new_state, terminated in batch:# y值if terminated:y.append(reward)else:# 下一步的 qtable 的最大值q_table_net = self.net(torch.Tensor(np.array([new_state]))).detach()y.append(reward + gamma * float(torch.max(q_table_net)))# y hat的值state_space.append(state)action_space.append(action)idx = [list(range(len(action_space))), action_space]y_hat = self.tgt_net(torch.Tensor(np.array(state_space)))[idx]return y_hat, torch.tensor(y)def update_net_parameters(self, update=True):self.net.load_state_dict(self.tgt_net.state_dict())

初始化环境

   # 初始化环境
env = gym.make("CartPole-v1")
# env = DiscreteOneHotWrapper(env)hidden_num = 64
# 定义网络
net = Net(env.observation_space.shape[0],hidden_num, env.action_space.n)
tgt_net = Net(env.observation_space.shape[0],hidden_num, env.action_space.n)
dqn = DQN(env=env, net=net, tgt_net=tgt_net)# 初始化参数
# dqn.init_net_and_target_net_weight()# 定义优化器
opt = optim.Adam(tgt_net.parameters(), lr=0.001)# 定义损失函数
loss = nn.MSELoss()# 记录训练过程
# writer = SummaryWriter(log_dir="logs/DQN", comment="DQN")

开始训练

gamma = 0.8
for i in range(10000):batch = dqn.generate_train_data(256, 0.8)y_hat, y = dqn.calculate_y_hat_and_y(batch)opt.zero_grad()l = loss(y_hat, y)l.backward()opt.step()print("MSE: {}".format(l.item()))if i % 5 == 0:dqn.update_net_parameters(update=True)

输出:

MSE: 0.027348674833774567
MSE: 0.1803671419620514
MSE: 0.06523636728525162
MSE: 0.08363766968250275
MSE: 0.062360599637031555
MSE: 0.004909628536552191
MSE: 0.05730309337377548
MSE: 0.03543371334671974
MSE: 0.08458714932203293

可视化结果

env = gym.make("CartPole-v1", render_mode = "human")
env = gym.wrappers.RecordVideo(env, video_folder="video")state, info = env.reset()
total_rewards = 0while True:q_table_state = dqn.tgt_net(torch.Tensor(state)).detach()# if np.random.uniform(0, 1, 1) > 0.9:#     action = env.action_space.sample()# else:action = int(torch.argmax(q_table_state))state, reward, terminated, truncted, info = env.step(action)if terminated:break

这篇关于强化学习原理python篇06——DQN的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/652678

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

python: 多模块(.py)中全局变量的导入

文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块(文件)中使用 global 声明变量时,这个变量只在该模块的全局命名空

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过滤推荐功能2.基于用户的协同过滤推荐功能 前言     在信息过载的时代,推荐系统成为连接用户与内容的桥梁。本文聚焦于

hdu4407(容斥原理)

题意:给一串数字1,2,......n,两个操作:1、修改第k个数字,2、查询区间[l,r]中与n互质的数之和。 解题思路:咱一看,像线段树,但是如果用线段树做,那么每个区间一定要记录所有的素因子,这样会超内存。然后我就做不来了。后来看了题解,原来是用容斥原理来做的。还记得这道题目吗?求区间[1,r]中与p互质的数的个数,如果不会的话就先去做那题吧。现在这题是求区间[l,r]中与n互质的数的和

【Python编程】Linux创建虚拟环境并配置与notebook相连接

1.创建 使用 venv 创建虚拟环境。例如,在当前目录下创建一个名为 myenv 的虚拟环境: python3 -m venv myenv 2.激活 激活虚拟环境使其成为当前终端会话的活动环境。运行: source myenv/bin/activate 3.与notebook连接 在虚拟环境中,使用 pip 安装 Jupyter 和 ipykernel: pip instal

06 C++Lambda表达式

lambda表达式的定义 没有显式模版形参的lambda表达式 [捕获] 前属性 (形参列表) 说明符 异常 后属性 尾随类型 约束 {函数体} 有显式模版形参的lambda表达式 [捕获] <模版形参> 模版约束 前属性 (形参列表) 说明符 异常 后属性 尾随类型 约束 {函数体} 含义 捕获:包含零个或者多个捕获符的逗号分隔列表 模板形参:用于泛型lambda提供个模板形参的名

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]