DAY9-深度学习100例-循环神经网络(RNN)实现股票预测

2024-03-20 03:50

本文主要是介绍DAY9-深度学习100例-循环神经网络(RNN)实现股票预测,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!


活动地址:CSDN21天学习挑战赛

  • 本文为🔗365天深度学习训练营 中的学习记录博客
  • 参考文章地址: 🔗深度学习100例-循环神经网络(RNN)实现股票预测 | 第9天

文章目录

  • 前言
  • 一、RNN
  • 二、准备工作
    • 1.设置GPU
    • 2.加载数据
  • 三、数据预处理
    • 1.归一化
    • 2.设置训练集、测试集
  • 四、构建模型
  • 五、激活模型
  • 六、训练模型
  • 七、结果可视化
    • 1.绘制loss图
    • 2.预测
    • 3.评估


前言

今天会通过RNN实现股票开盘价格的预测。

一、RNN

传统神经网络的结构:输入层—隐藏层—输出层,结构简单。
在这里插入图片描述
RNN与传统神经网络的区别在于:每次会将前一次的输出结果带入隐含层中,一起参与训练。
在这里插入图片描述
举例:用户说了一句"What time is it?",神经网络会先将这句话分为五个基本单元(四个单词+一个问号)。
在这里插入图片描述
然后按照顺序输入到RNN中,先输入"What"得到输出O1,随后按照顺序将"time"输入,得到输出O2,可以看到"What"的输出对于O2来说也有影响,以此类推,前面所有输入的输出结果都对后续的输出产生了影响。当神经网络判断意图时,只需要最后一层的输出O5即可。

在这里插入图片描述

二、准备工作

1.设置GPU

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")

2.加载数据

import os,math
from tensorflow.keras.layers import Dropout, Dense, SimpleRNN
from sklearn.preprocessing   import MinMaxScaler
from sklearn                 import metrics
import numpy             as np
import pandas            as pd
import tensorflow        as tf
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
data = pd.read_csv('SH600519.csv')  # 读取股票文件data

在这里插入图片描述

"""
前(2426-300=2126)天的开盘价作为训练集,表格从0开始计数,2:3 是提取[2:3)列,前闭后开,故提取出C列开盘价
后300天的开盘价作为测试集
"""
training_set = data.iloc[0:2426 - 300, 2:3].values  
test_set = data.iloc[2426 - 300:, 2:3].values  

三、数据预处理

1.归一化

sc           = MinMaxScaler(feature_range=(0, 1))
training_set = sc.fit_transform(training_set)
test_set     = sc.transform(test_set) 

2.设置训练集、测试集

x_train = []
y_train = []x_test = []
y_test = []"""
使用前60天的开盘价作为输入特征x_train第61天的开盘价作为输入标签y_trainfor循环共构建2426-300-60=2066组训练数据。共构建300-60=260组测试数据
"""
for i in range(60, len(training_set)):x_train.append(training_set[i - 60:i, 0])y_train.append(training_set[i, 0])for i in range(60, len(test_set)):x_test.append(test_set[i - 60:i, 0])y_test.append(test_set[i, 0])# 对训练集进行打乱
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
"""
将训练数据调整为数组(array)调整后的形状:
x_train:(2066, 60, 1)
y_train:(2066,)
x_test :(240, 60, 1)
y_test :(240,)
"""
x_train, y_train = np.array(x_train), np.array(y_train) # x_train形状为:(2066, 60, 1)
x_test,  y_test  = np.array(x_test),  np.array(y_test)"""
输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]
"""
x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
x_test  = np.reshape(x_test,  (x_test.shape[0], 60, 1))

四、构建模型

model = tf.keras.Sequential([SimpleRNN(100, return_sequences=True), #布尔值。是返回输出序列中的最后一个输出,还是全部序列。Dropout(0.1),                         #防止过拟合SimpleRNN(100),Dropout(0.1),Dense(1)
])

五、激活模型

# 该应用只观测loss数值,不观测准确率,所以删去metrics选项,一会在每个epoch迭代显示时只显示loss值
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss='mean_squared_error')  # 损失函数用均方误差

六、训练模型

history = model.fit(x_train, y_train, batch_size=64, epochs=20, validation_data=(x_test, y_test), validation_freq=1)                  #测试的epoch间隔数model.summary()

在这里插入图片描述
在这里插入图片描述

七、结果可视化

1.绘制loss图

plt.plot(history.history['loss']    , label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

在这里插入图片描述

2.预测

predicted_stock_price = model.predict(x_test)                       # 测试集输入模型进行预测
predicted_stock_price = sc.inverse_transform(predicted_stock_price) # 对预测数据还原---从(01)反归一化到原始范围
real_stock_price = sc.inverse_transform(test_set[60:])              # 对真实数据还原---从(01)反归一化到原始范围# 画出真实数据和预测数据的对比曲线
plt.plot(real_stock_price, color='red', label='Stock Price')
plt.plot(predicted_stock_price, color='blue', label='Predicted Stock Price')
plt.title('Stock Price Prediction')
plt.xlabel('Time')
plt.ylabel('Stock Price')
plt.legend()
plt.show()

在这里插入图片描述

3.评估

"""
MSE  :均方误差    ----->  预测值减真实值求平方后求均值
RMSE :均方根误差  ----->  对均方误差开方
MAE  :平均绝对误差----->  预测值减真实值求绝对值后求均值
R2   :决定系数,可以简单理解为反映模型拟合优度的重要的统计量详细介绍可以参考文章:https://blog.csdn.net/qq_38251616/article/details/107997435
"""
MSE   = metrics.mean_squared_error(predicted_stock_price, real_stock_price)
RMSE  = metrics.mean_squared_error(predicted_stock_price, real_stock_price)**0.5
MAE   = metrics.mean_absolute_error(predicted_stock_price, real_stock_price)
R2    = metrics.r2_score(predicted_stock_price, real_stock_price)print('均方误差: %.5f' % MSE)
print('均方根误差: %.5f' % RMSE)
print('平均绝对误差: %.5f' % MAE)
print('R2: %.5f' % R2)
均方误差: 2832.56937
均方根误差: 53.22189
平均绝对误差: 46.33860
R2: 0.59097

这篇关于DAY9-深度学习100例-循环神经网络(RNN)实现股票预测的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/828195

相关文章

使用Python实现在Word中添加或删除超链接

《使用Python实现在Word中添加或删除超链接》在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能,本文将为大家介绍一下Python如何实现在Word中添加或... 在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能。通过添加超

windos server2022里的DFS配置的实现

《windosserver2022里的DFS配置的实现》DFS是WindowsServer操作系统提供的一种功能,用于在多台服务器上集中管理共享文件夹和文件的分布式存储解决方案,本文就来介绍一下wi... 目录什么是DFS?优势:应用场景:DFS配置步骤什么是DFS?DFS指的是分布式文件系统(Distr

NFS实现多服务器文件的共享的方法步骤

《NFS实现多服务器文件的共享的方法步骤》NFS允许网络中的计算机之间共享资源,客户端可以透明地读写远端NFS服务器上的文件,本文就来介绍一下NFS实现多服务器文件的共享的方法步骤,感兴趣的可以了解一... 目录一、简介二、部署1、准备1、服务端和客户端:安装nfs-utils2、服务端:创建共享目录3、服

C#使用yield关键字实现提升迭代性能与效率

《C#使用yield关键字实现提升迭代性能与效率》yield关键字在C#中简化了数据迭代的方式,实现了按需生成数据,自动维护迭代状态,本文主要来聊聊如何使用yield关键字实现提升迭代性能与效率,感兴... 目录前言传统迭代和yield迭代方式对比yield延迟加载按需获取数据yield break显式示迭

Python实现高效地读写大型文件

《Python实现高效地读写大型文件》Python如何读写的是大型文件,有没有什么方法来提高效率呢,这篇文章就来和大家聊聊如何在Python中高效地读写大型文件,需要的可以了解下... 目录一、逐行读取大型文件二、分块读取大型文件三、使用 mmap 模块进行内存映射文件操作(适用于大文件)四、使用 pand

python实现pdf转word和excel的示例代码

《python实现pdf转word和excel的示例代码》本文主要介绍了python实现pdf转word和excel的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价... 目录一、引言二、python编程1,PDF转Word2,PDF转Excel三、前端页面效果展示总结一

Python xmltodict实现简化XML数据处理

《Pythonxmltodict实现简化XML数据处理》Python社区为提供了xmltodict库,它专为简化XML与Python数据结构的转换而设计,本文主要来为大家介绍一下如何使用xmltod... 目录一、引言二、XMLtodict介绍设计理念适用场景三、功能参数与属性1、parse函数2、unpa

C#实现获得某个枚举的所有名称

《C#实现获得某个枚举的所有名称》这篇文章主要为大家详细介绍了C#如何实现获得某个枚举的所有名称,文中的示例代码讲解详细,具有一定的借鉴价值,有需要的小伙伴可以参考一下... C#中获得某个枚举的所有名称using System;using System.Collections.Generic;usi

Go语言实现将中文转化为拼音功能

《Go语言实现将中文转化为拼音功能》这篇文章主要为大家详细介绍了Go语言中如何实现将中文转化为拼音功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 有这么一个需求:新用户入职 创建一系列账号比较麻烦,打算通过接口传入姓名进行初始化。想把姓名转化成拼音。因为有些账号即需要中文也需要英

C# 读写ini文件操作实现

《C#读写ini文件操作实现》本文主要介绍了C#读写ini文件操作实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录一、INI文件结构二、读取INI文件中的数据在C#应用程序中,常将INI文件作为配置文件,用于存储应用程序的