基于LSTM模型的股票价格趋势预测,预测未来一天的开盘价格(附代码详解与注释)

本文主要是介绍基于LSTM模型的股票价格趋势预测,预测未来一天的开盘价格(附代码详解与注释),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一:简介

该股票价格选取了谷歌股票2012年1月3日至2016年12月20日,每天股票开盘的价格,其中2016年11月30日之前的股票价格作为LSTM模型的训练数据集。12月1日至20日的开盘价格作为股票价格的预测集。

数据展示:

测试集数据如该图所示;

二:模型介绍

LSTM模型是基于时间序列的模型,其内的神经元细胞具有记忆功能,即在该问题上,就是之前的开盘价格会影响后期的开盘价格,意思是12月20日早上的开盘价格受12月19日开盘价格的影响,于是,LSTM模型内的记忆细胞就会选择性的记住12月19日的价格。这是对LSTM的直白理解。如有问题请留言指正。严谨的LSTM结果模型如下图所示:

三:代码实现

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
# Part 1- Data Preprocessing
#importing training set
training_set=pd.read_csv('Google_Stock_Price_Train.csv')
#extract open value from the trainng data
training_set=training_set.iloc[:,1:2].values
#Feature Scaling
from sklearn.preprocessing import MinMaxScaler
sc=MinMaxScaler()
training_set=sc.fit_transform(training_set)
#Getting the input and output
X_train= training_set[:1236]
print(X_train)
Y_train=training_set[1:1257]
print(Y_train)
#Reshaping
X_train=np.reshape(X_train,(1236,1,1))
#Part-2 Building RNN
#importing keras library and packages
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
#Initalizing RNN
regressor=Sequential()
regressor.add(LSTM(units=50,activation='sigmoid', input_shape=(1,1)))
#Adding output layer (default argument)
regressor.add(Dense(units=1))
#Compile LSTM
regressor.compile(optimizer='adam',loss='mean_squared_error')
#Fitting the RNN on training set
regressor.fit(X_train,Y_train,batch_size=50,epochs=200)
#Part 3-Making Prediction and Visualizing Results
#Getting real Stock price for 2017
test_set=pd.read_csv('Google_Stock_Price_Test.csv')
real_stock_price=test_set.iloc[:,1:2].values
print(real_stock_price)
real_stock_price1=test_set.iloc[:,1:2].values
print(real_stock_price1)
#Getting predicted Stock price for 2017
inputs=real_stock_price
inputs=sc.transform(inputs)
inputs=np.reshape(inputs,(20,1,1))  #scaling the values
predicted_stock_price = regressor.predict(inputs)
predicted_stock_price = sc.inverse_transform(predicted_stock_price) #scaling to input values
#Visualize the results
x=[]
y1=[]
for i  in range(20):x.append(i)
for j in range(20):y1.append(j)
plt.plot(x,real_stock_price1,'ro',color='red',label='Real Stock Price')
plt.plot(y1,predicted_stock_price,'ro',color='green',label='Predicted Stock Price')
plt.title('Stock Price Prediction')
plt.xlabel('Time')
plt.ylabel('Stock Price')
plt.legend()
plt.show()
#Part 4- Evaluating the RNN
# since it is linear regression problem we will evaluate RMSE
import math
from sklearn.metrics import mean_squared_error
rmse=math.sqrt(mean_squared_error(real_stock_price, predicted_stock_price))
#expressing RMSE in percentage
rmse=rmse/800        # 800 becasue it is average value

训练结果:

............

 50/1236 [>.............................] - ETA: 0s - loss: 1.5879e-04
 950/1236 [======================>.......] - ETA: 0s - loss: 2.3606e-04
1236/1236 [==============================] - 0s 59us/step - loss: 2.5158e-04
Epoch 200/200

  50/1236 [>.............................] - ETA: 0s - loss: 6.5003e-04
1050/1236 [========================>.....] - ETA: 0s - loss: 2.7387e-04
1236/1236 [==============================] - 0s 53us/step - loss: 2.5892e-04

结果展示:

四:总结

该模型的主要目的就是训练该数据集根据前一天的开盘输入,能够预测出后一天的开盘价格,因此,输入训练集时输入0-1236行开盘价格为输入,而标签及为该1-2237行开盘价格,即标签往后推迟一天。最后训练的模型就行在测试集上进行测试。我将输出结果展示为散点图而不是折线图,其目的就是为了好给大家展示该预测结果与其真实值还是相当来说比较准的,但是该预测只是预测开盘价的走势。股票还有很多影响因素,其方法类似。

这篇关于基于LSTM模型的股票价格趋势预测,预测未来一天的开盘价格(附代码详解与注释)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/439654

相关文章

Conda与Python venv虚拟环境的区别与使用方法详解

《Conda与Pythonvenv虚拟环境的区别与使用方法详解》随着Python社区的成长,虚拟环境的概念和技术也在不断发展,:本文主要介绍Conda与Pythonvenv虚拟环境的区别与使用... 目录前言一、Conda 与 python venv 的核心区别1. Conda 的特点2. Python v

Spring Boot中WebSocket常用使用方法详解

《SpringBoot中WebSocket常用使用方法详解》本文从WebSocket的基础概念出发,详细介绍了SpringBoot集成WebSocket的步骤,并重点讲解了常用的使用方法,包括简单消... 目录一、WebSocket基础概念1.1 什么是WebSocket1.2 WebSocket与HTTP

java中反射Reflection的4个作用详解

《java中反射Reflection的4个作用详解》反射Reflection是Java等编程语言中的一个重要特性,它允许程序在运行时进行自我检查和对内部成员(如字段、方法、类等)的操作,本文将详细介绍... 目录作用1、在运行时判断任意一个对象所属的类作用2、在运行时构造任意一个类的对象作用3、在运行时判断

MySQL 中的 CAST 函数详解及常见用法

《MySQL中的CAST函数详解及常见用法》CAST函数是MySQL中用于数据类型转换的重要函数,它允许你将一个值从一种数据类型转换为另一种数据类型,本文给大家介绍MySQL中的CAST... 目录mysql 中的 CAST 函数详解一、基本语法二、支持的数据类型三、常见用法示例1. 字符串转数字2. 数字

SpringBoot中SM2公钥加密、私钥解密的实现示例详解

《SpringBoot中SM2公钥加密、私钥解密的实现示例详解》本文介绍了如何在SpringBoot项目中实现SM2公钥加密和私钥解密的功能,通过使用Hutool库和BouncyCastle依赖,简化... 目录一、前言1、加密信息(示例)2、加密结果(示例)二、实现代码1、yml文件配置2、创建SM2工具

MyBatis-Plus 中 nested() 与 and() 方法详解(最佳实践场景)

《MyBatis-Plus中nested()与and()方法详解(最佳实践场景)》在MyBatis-Plus的条件构造器中,nested()和and()都是用于构建复杂查询条件的关键方法,但... 目录MyBATis-Plus 中nested()与and()方法详解一、核心区别对比二、方法详解1.and()

Spring IoC 容器的使用详解(最新整理)

《SpringIoC容器的使用详解(最新整理)》文章介绍了Spring框架中的应用分层思想与IoC容器原理,通过分层解耦业务逻辑、数据访问等模块,IoC容器利用@Component注解管理Bean... 目录1. 应用分层2. IoC 的介绍3. IoC 容器的使用3.1. bean 的存储3.2. 方法注

MySQL 删除数据详解(最新整理)

《MySQL删除数据详解(最新整理)》:本文主要介绍MySQL删除数据的相关知识,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、前言二、mysql 中的三种删除方式1.DELETE语句✅ 基本语法: 示例:2.TRUNCATE语句✅ 基本语

Python内置函数之classmethod函数使用详解

《Python内置函数之classmethod函数使用详解》:本文主要介绍Python内置函数之classmethod函数使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 类方法定义与基本语法2. 类方法 vs 实例方法 vs 静态方法3. 核心特性与用法(1编程客

Python函数作用域示例详解

《Python函数作用域示例详解》本文介绍了Python中的LEGB作用域规则,详细解析了变量查找的四个层级,通过具体代码示例,展示了各层级的变量访问规则和特性,对python函数作用域相关知识感兴趣... 目录一、LEGB 规则二、作用域实例2.1 局部作用域(Local)2.2 闭包作用域(Enclos