DAY9-深度学习100例-循环神经网络(RNN)实现股票预测

2024-03-20 03:50

本文主要是介绍DAY9-深度学习100例-循环神经网络(RNN)实现股票预测,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!


活动地址:CSDN21天学习挑战赛

  • 本文为🔗365天深度学习训练营 中的学习记录博客
  • 参考文章地址: 🔗深度学习100例-循环神经网络(RNN)实现股票预测 | 第9天

文章目录

  • 前言
  • 一、RNN
  • 二、准备工作
    • 1.设置GPU
    • 2.加载数据
  • 三、数据预处理
    • 1.归一化
    • 2.设置训练集、测试集
  • 四、构建模型
  • 五、激活模型
  • 六、训练模型
  • 七、结果可视化
    • 1.绘制loss图
    • 2.预测
    • 3.评估


前言

今天会通过RNN实现股票开盘价格的预测。

一、RNN

传统神经网络的结构:输入层—隐藏层—输出层,结构简单。
在这里插入图片描述
RNN与传统神经网络的区别在于:每次会将前一次的输出结果带入隐含层中,一起参与训练。
在这里插入图片描述
举例:用户说了一句"What time is it?",神经网络会先将这句话分为五个基本单元(四个单词+一个问号)。
在这里插入图片描述
然后按照顺序输入到RNN中,先输入"What"得到输出O1,随后按照顺序将"time"输入,得到输出O2,可以看到"What"的输出对于O2来说也有影响,以此类推,前面所有输入的输出结果都对后续的输出产生了影响。当神经网络判断意图时,只需要最后一层的输出O5即可。

在这里插入图片描述

二、准备工作

1.设置GPU

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")

2.加载数据

import os,math
from tensorflow.keras.layers import Dropout, Dense, SimpleRNN
from sklearn.preprocessing   import MinMaxScaler
from sklearn                 import metrics
import numpy             as np
import pandas            as pd
import tensorflow        as tf
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
data = pd.read_csv('SH600519.csv')  # 读取股票文件data

在这里插入图片描述

"""
前(2426-300=2126)天的开盘价作为训练集,表格从0开始计数,2:3 是提取[2:3)列,前闭后开,故提取出C列开盘价
后300天的开盘价作为测试集
"""
training_set = data.iloc[0:2426 - 300, 2:3].values  
test_set = data.iloc[2426 - 300:, 2:3].values  

三、数据预处理

1.归一化

sc           = MinMaxScaler(feature_range=(0, 1))
training_set = sc.fit_transform(training_set)
test_set     = sc.transform(test_set) 

2.设置训练集、测试集

x_train = []
y_train = []x_test = []
y_test = []"""
使用前60天的开盘价作为输入特征x_train第61天的开盘价作为输入标签y_trainfor循环共构建2426-300-60=2066组训练数据。共构建300-60=260组测试数据
"""
for i in range(60, len(training_set)):x_train.append(training_set[i - 60:i, 0])y_train.append(training_set[i, 0])for i in range(60, len(test_set)):x_test.append(test_set[i - 60:i, 0])y_test.append(test_set[i, 0])# 对训练集进行打乱
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
"""
将训练数据调整为数组(array)调整后的形状:
x_train:(2066, 60, 1)
y_train:(2066,)
x_test :(240, 60, 1)
y_test :(240,)
"""
x_train, y_train = np.array(x_train), np.array(y_train) # x_train形状为:(2066, 60, 1)
x_test,  y_test  = np.array(x_test),  np.array(y_test)"""
输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]
"""
x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
x_test  = np.reshape(x_test,  (x_test.shape[0], 60, 1))

四、构建模型

model = tf.keras.Sequential([SimpleRNN(100, return_sequences=True), #布尔值。是返回输出序列中的最后一个输出,还是全部序列。Dropout(0.1),                         #防止过拟合SimpleRNN(100),Dropout(0.1),Dense(1)
])

五、激活模型

# 该应用只观测loss数值,不观测准确率,所以删去metrics选项,一会在每个epoch迭代显示时只显示loss值
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss='mean_squared_error')  # 损失函数用均方误差

六、训练模型

history = model.fit(x_train, y_train, batch_size=64, epochs=20, validation_data=(x_test, y_test), validation_freq=1)                  #测试的epoch间隔数model.summary()

在这里插入图片描述
在这里插入图片描述

七、结果可视化

1.绘制loss图

plt.plot(history.history['loss']    , label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

在这里插入图片描述

2.预测

predicted_stock_price = model.predict(x_test)                       # 测试集输入模型进行预测
predicted_stock_price = sc.inverse_transform(predicted_stock_price) # 对预测数据还原---从(01)反归一化到原始范围
real_stock_price = sc.inverse_transform(test_set[60:])              # 对真实数据还原---从(01)反归一化到原始范围# 画出真实数据和预测数据的对比曲线
plt.plot(real_stock_price, color='red', label='Stock Price')
plt.plot(predicted_stock_price, color='blue', label='Predicted Stock Price')
plt.title('Stock Price Prediction')
plt.xlabel('Time')
plt.ylabel('Stock Price')
plt.legend()
plt.show()

在这里插入图片描述

3.评估

"""
MSE  :均方误差    ----->  预测值减真实值求平方后求均值
RMSE :均方根误差  ----->  对均方误差开方
MAE  :平均绝对误差----->  预测值减真实值求绝对值后求均值
R2   :决定系数,可以简单理解为反映模型拟合优度的重要的统计量详细介绍可以参考文章:https://blog.csdn.net/qq_38251616/article/details/107997435
"""
MSE   = metrics.mean_squared_error(predicted_stock_price, real_stock_price)
RMSE  = metrics.mean_squared_error(predicted_stock_price, real_stock_price)**0.5
MAE   = metrics.mean_absolute_error(predicted_stock_price, real_stock_price)
R2    = metrics.r2_score(predicted_stock_price, real_stock_price)print('均方误差: %.5f' % MSE)
print('均方根误差: %.5f' % RMSE)
print('平均绝对误差: %.5f' % MAE)
print('R2: %.5f' % R2)
均方误差: 2832.56937
均方根误差: 53.22189
平均绝对误差: 46.33860
R2: 0.59097

这篇关于DAY9-深度学习100例-循环神经网络(RNN)实现股票预测的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/828195

相关文章

Go语言开发实现查询IP信息的MCP服务器

《Go语言开发实现查询IP信息的MCP服务器》随着MCP的快速普及和广泛应用,MCP服务器也层出不穷,本文将详细介绍如何在Go语言中使用go-mcp库来开发一个查询IP信息的MCP... 目录前言mcp-ip-geo 服务器目录结构说明查询 IP 信息功能实现工具实现工具管理查询单个 IP 信息工具的实现服

SpringBoot基于配置实现短信服务策略的动态切换

《SpringBoot基于配置实现短信服务策略的动态切换》这篇文章主要为大家详细介绍了SpringBoot在接入多个短信服务商(如阿里云、腾讯云、华为云)后,如何根据配置或环境切换使用不同的服务商,需... 目录目标功能示例配置(application.yml)配置类绑定短信发送策略接口示例:阿里云 & 腾

python实现svg图片转换为png和gif

《python实现svg图片转换为png和gif》这篇文章主要为大家详细介绍了python如何实现将svg图片格式转换为png和gif,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录python实现svg图片转换为png和gifpython实现图片格式之间的相互转换延展:基于Py

Python利用ElementTree实现快速解析XML文件

《Python利用ElementTree实现快速解析XML文件》ElementTree是Python标准库的一部分,而且是Python标准库中用于解析和操作XML数据的模块,下面小编就来和大家详细讲讲... 目录一、XML文件解析到底有多重要二、ElementTree快速入门1. 加载XML的两种方式2.

Java的栈与队列实现代码解析

《Java的栈与队列实现代码解析》栈是常见的线性数据结构,栈的特点是以先进后出的形式,后进先出,先进后出,分为栈底和栈顶,栈应用于内存的分配,表达式求值,存储临时的数据和方法的调用等,本文给大家介绍J... 目录栈的概念(Stack)栈的实现代码队列(Queue)模拟实现队列(双链表实现)循环队列(循环数组

C++如何通过Qt反射机制实现数据类序列化

《C++如何通过Qt反射机制实现数据类序列化》在C++工程中经常需要使用数据类,并对数据类进行存储、打印、调试等操作,所以本文就来聊聊C++如何通过Qt反射机制实现数据类序列化吧... 目录设计预期设计思路代码实现使用方法在 C++ 工程中经常需要使用数据类,并对数据类进行存储、打印、调试等操作。由于数据类

Python实现图片分割的多种方法总结

《Python实现图片分割的多种方法总结》图片分割是图像处理中的一个重要任务,它的目标是将图像划分为多个区域或者对象,本文为大家整理了一些常用的分割方法,大家可以根据需求自行选择... 目录1. 基于传统图像处理的分割方法(1) 使用固定阈值分割图片(2) 自适应阈值分割(3) 使用图像边缘检测分割(4)

Android实现在线预览office文档的示例详解

《Android实现在线预览office文档的示例详解》在移动端展示在线Office文档(如Word、Excel、PPT)是一项常见需求,这篇文章为大家重点介绍了两种方案的实现方法,希望对大家有一定的... 目录一、项目概述二、相关技术知识三、实现思路3.1 方案一:WebView + Office Onl

C# foreach 循环中获取索引的实现方式

《C#foreach循环中获取索引的实现方式》:本文主要介绍C#foreach循环中获取索引的实现方式,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、手动维护索引变量二、LINQ Select + 元组解构三、扩展方法封装索引四、使用 for 循环替代

Spring Security+JWT如何实现前后端分离权限控制

《SpringSecurity+JWT如何实现前后端分离权限控制》本篇将手把手教你用SpringSecurity+JWT搭建一套完整的登录认证与权限控制体系,具有很好的参考价值,希望对大家... 目录Spring Security+JWT实现前后端分离权限控制实战一、为什么要用 JWT?二、JWT 基本结构