强化学习原理python篇06——DQN

2024-01-28 06:04
文章标签 python 学习 原理 强化 06 dqn

本文主要是介绍强化学习原理python篇06——DQN,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

强化学习原理python篇05——DQN

  • DQN 算法
    • 定义DQN网络
    • 初始化环境
    • 开始训练
    • 可视化结果

本章全篇参考赵世钰老师的教材 Mathmatical-Foundation-of-Reinforcement-Learning Deep Q-learning 章节,请各位结合阅读,本合集只专注于数学概念的代码实现。

DQN 算法

1)使用随机权重 ( w ← 1.0 ) (w←1.0) w1.0初始化目标网络 Q ( s , a , w ) Q(s, a, w) Q(s,a,w)和网络 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w) Q Q Q Q ^ \hat Q Q^相同,清空回放缓冲区。

2)以概率ε选择一个随机动作a,否则 a = a r g m a x Q ( s , a , w ) a=argmaxQ(s,a,w) a=argmaxQ(s,a,w)

3)在模拟器中执行动作a,观察奖励r和下一个状态s’。

4)将转移过程(s, a, r, s’)存储在回放缓冲区中。

5)从回放缓冲区中采样一个随机的小批量转移过程。

6)对于回放缓冲区中的每个转移过程,如果片段在此步结束,则计算目标 y = r y=r y=r,否则计算 y = r + γ m a x Q ^ ( s , a , w ) y=r+\gamma max \hat Q(s, a, w) y=r+γmaxQ^(s,a,w)

7)计算损失: L = ( Q ( s , a , w ) – y ) 2 L=(Q(s, a, w)–y)^2 L=(Q(s,a,w)y)2

8)固定网络 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w)不变,通过最小化模型参数的损失,使用SGD算法更新 Q ( s , a ) Q(s, a) Q(s,a)

9)每N步,将权重从目标网络 Q Q Q复制到 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w)

10)从步骤2开始重复,直到收敛为止。

定义DQN网络

import collections
import copy
import random
from collections import defaultdict
import math
import gym
import gym.spaces
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gym.envs.toy_text import frozen_lake
from torch.utils.tensorboard import SummaryWriterclass Net(nn.Module):def __init__(self, obs_size, hidden_size, q_table_size):super(Net, self).__init__()self.net = nn.Sequential(# 输入为状态,样本为(1*n)nn.Linear(obs_size, hidden_size),nn.ReLU(),# nn.Linear(hidden_size, hidden_size),# nn.ReLU(),nn.Linear(hidden_size, q_table_size),)def forward(self, state):return self.net(state)class DQN:def __init__(self, env, tgt_net, net):self.env = envself.tgt_net = tgt_netself.net = netdef generate_train_data(self, batch_size, epsilon):state, _ = env.reset()train_data = []while len(train_data)<batch_size*2:q_table_tgt = self.tgt_net(torch.Tensor(state)).detach()if np.random.uniform(0, 1, 1) > epsilon:action = self.env.action_space.sample()else:action = int(torch.argmax(q_table_tgt))new_state, reward,terminated, truncted, info = env.step(action)train_data.append([state, action, reward, new_state, terminated])state = new_stateif terminated:state, _ = env.reset()continuerandom.shuffle(train_data)                return train_data[:batch_size]def calculate_y_hat_and_y(self, batch):# 6)对于回放缓冲区中的每个转移过程,如果片段在此步结束,则计算目标$y=r$,否则计算$y=r+\gamma max \hat Q(s, a, w)$ 。y = []state_space = []action_space = []for state, action, reward, new_state, terminated in batch:# y值if terminated:y.append(reward)else:# 下一步的 qtable 的最大值q_table_net = self.net(torch.Tensor(np.array([new_state]))).detach()y.append(reward + gamma * float(torch.max(q_table_net)))# y hat的值state_space.append(state)action_space.append(action)idx = [list(range(len(action_space))), action_space]y_hat = self.tgt_net(torch.Tensor(np.array(state_space)))[idx]return y_hat, torch.tensor(y)def update_net_parameters(self, update=True):self.net.load_state_dict(self.tgt_net.state_dict())

初始化环境

   # 初始化环境
env = gym.make("CartPole-v1")
# env = DiscreteOneHotWrapper(env)hidden_num = 64
# 定义网络
net = Net(env.observation_space.shape[0],hidden_num, env.action_space.n)
tgt_net = Net(env.observation_space.shape[0],hidden_num, env.action_space.n)
dqn = DQN(env=env, net=net, tgt_net=tgt_net)# 初始化参数
# dqn.init_net_and_target_net_weight()# 定义优化器
opt = optim.Adam(tgt_net.parameters(), lr=0.001)# 定义损失函数
loss = nn.MSELoss()# 记录训练过程
# writer = SummaryWriter(log_dir="logs/DQN", comment="DQN")

开始训练

gamma = 0.8
for i in range(10000):batch = dqn.generate_train_data(256, 0.8)y_hat, y = dqn.calculate_y_hat_and_y(batch)opt.zero_grad()l = loss(y_hat, y)l.backward()opt.step()print("MSE: {}".format(l.item()))if i % 5 == 0:dqn.update_net_parameters(update=True)

输出:

MSE: 0.027348674833774567
MSE: 0.1803671419620514
MSE: 0.06523636728525162
MSE: 0.08363766968250275
MSE: 0.062360599637031555
MSE: 0.004909628536552191
MSE: 0.05730309337377548
MSE: 0.03543371334671974
MSE: 0.08458714932203293

可视化结果

env = gym.make("CartPole-v1", render_mode = "human")
env = gym.wrappers.RecordVideo(env, video_folder="video")state, info = env.reset()
total_rewards = 0while True:q_table_state = dqn.tgt_net(torch.Tensor(state)).detach()# if np.random.uniform(0, 1, 1) > 0.9:#     action = env.action_space.sample()# else:action = int(torch.argmax(q_table_state))state, reward, terminated, truncted, info = env.step(action)if terminated:break

这篇关于强化学习原理python篇06——DQN的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/652678

相关文章

Python使用Pandas对比两列数据取最大值的五种方法

《Python使用Pandas对比两列数据取最大值的五种方法》本文主要介绍使用Pandas对比两列数据取最大值的五种方法,包括使用max方法、apply方法结合lambda函数、函数、clip方法、w... 目录引言一、使用max方法二、使用apply方法结合lambda函数三、使用np.maximum函数

Python调用Orator ORM进行数据库操作

《Python调用OratorORM进行数据库操作》OratorORM是一个功能丰富且灵活的PythonORM库,旨在简化数据库操作,它支持多种数据库并提供了简洁且直观的API,下面我们就... 目录Orator ORM 主要特点安装使用示例总结Orator ORM 是一个功能丰富且灵活的 python O

Python使用国内镜像加速pip安装的方法讲解

《Python使用国内镜像加速pip安装的方法讲解》在Python开发中,pip是一个非常重要的工具,用于安装和管理Python的第三方库,然而,在国内使用pip安装依赖时,往往会因为网络问题而导致速... 目录一、pip 工具简介1. 什么是 pip?2. 什么是 -i 参数?二、国内镜像源的选择三、如何

python使用fastapi实现多语言国际化的操作指南

《python使用fastapi实现多语言国际化的操作指南》本文介绍了使用Python和FastAPI实现多语言国际化的操作指南,包括多语言架构技术栈、翻译管理、前端本地化、语言切换机制以及常见陷阱和... 目录多语言国际化实现指南项目多语言架构技术栈目录结构翻译工作流1. 翻译数据存储2. 翻译生成脚本

如何通过Python实现一个消息队列

《如何通过Python实现一个消息队列》这篇文章主要为大家详细介绍了如何通过Python实现一个简单的消息队列,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录如何通过 python 实现消息队列如何把 http 请求放在队列中执行1. 使用 queue.Queue 和 reque

Python如何实现PDF隐私信息检测

《Python如何实现PDF隐私信息检测》随着越来越多的个人信息以电子形式存储和传输,确保这些信息的安全至关重要,本文将介绍如何使用Python检测PDF文件中的隐私信息,需要的可以参考下... 目录项目背景技术栈代码解析功能说明运行结php果在当今,数据隐私保护变得尤为重要。随着越来越多的个人信息以电子形

使用Python快速实现链接转word文档

《使用Python快速实现链接转word文档》这篇文章主要为大家详细介绍了如何使用Python快速实现链接转word文档功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 演示代码展示from newspaper import Articlefrom docx import

Python Jupyter Notebook导包报错问题及解决

《PythonJupyterNotebook导包报错问题及解决》在conda环境中安装包后,JupyterNotebook导入时出现ImportError,可能是由于包版本不对应或版本太高,解决方... 目录问题解决方法重新安装Jupyter NoteBook 更改Kernel总结问题在conda上安装了

Python如何计算两个不同类型列表的相似度

《Python如何计算两个不同类型列表的相似度》在编程中,经常需要比较两个列表的相似度,尤其是当这两个列表包含不同类型的元素时,下面小编就来讲讲如何使用Python计算两个不同类型列表的相似度吧... 目录摘要引言数字类型相似度欧几里得距离曼哈顿距离字符串类型相似度Levenshtein距离Jaccard相

Python安装时常见报错以及解决方案

《Python安装时常见报错以及解决方案》:本文主要介绍在安装Python、配置环境变量、使用pip以及运行Python脚本时常见的错误及其解决方案,文中介绍的非常详细,需要的朋友可以参考下... 目录一、安装 python 时常见报错及解决方案(一)安装包下载失败(二)权限不足二、配置环境变量时常见报错及