RNN学习:利用LSTM,GRU层解决航空公司评论数据预测问题

2024-03-29 13:48

本文主要是介绍RNN学习:利用LSTM,GRU层解决航空公司评论数据预测问题,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

RNN学习:利用LSTM,GRU解决航空公司评论数据预测问题

文章目录

  • RNN学习:利用LSTM,GRU解决航空公司评论数据预测问题
    • 1.RNN的介绍
      • 1.1 LSTM的简单介绍
      • 1.2 GRU的简单介绍
    • 2.数据集的介绍
    • 3.读取数据并作预处理
    • 4.模型的搭建
    • 结语

1.RNN的介绍

​ RNN,即循环神经网络,即一般的神经网络同层节点与节点之间并无连接,比如CNN隐藏单元之间并没有连接,那么这相对于一些序列问题上的处理就会效果很差。如翻译单词,全文的意思必须是根据所有单词来进行判断。或判断说话人情绪,评论好坏,最终的输出要和前面所有的输入发生关系,所以这里学者们提出循环神经网络,让上一个节点会对下一个节点传递状态向量,每个节点之间输出两个值,一个是我们要的输出,还有一个就是状态向量,该向量输入下一个下一个节点,最终输出为二维数据(None,units)units为RNN的隐藏单元数。

在这里插入图片描述

1.1 LSTM的简单介绍

​ 刚才我们说明了RNN会不断的向下一个节点传递状态,但是经过长时间的多次传递,最终传递的状态可能会引起梯度爆炸或梯度消失等问题,为了解决这个问题,学者们又提出了LSTM层来解决这个问题,LSTM层的内部存在一些门,他会通过训练门的参数控制了上一状态我们需要遗忘多少,并且在这一层状态的更新。

在这里插入图片描述

可以看到在这一个单元中上一层的输出ht-1和状态Ct-1都传递了进来从而经过我们的门来控制该单元遗忘并更新状态。

1.2 GRU的简单介绍

GRU是LSTM结构的一种变体,他可以做到与LSTM性能相当的情况下,计算量会比LSTM减少,他的网络结构如下

在这里插入图片描述

可以看到他作为LSTM的变体,与LSTM的相似之处,他也会有前一次的状态(但是不会有前一层的输出传入)向他传入并且通过训练控制前一次状态在本单元的遗忘与更新。

2.数据集的介绍

本次使用的是Twitter 美国航空公司情绪:2015年2月美国航空公司的Twitter数据,分类为正面,负面和中性推文(https://www.kaggle.com/crowdflower/twitter-airline-sentiment)

在这里插入图片描述

整个数据集使用CSV格式存储,这种文件格式是一种经常用来数据科学存储数据的纯文本文件,可以用EXCEL直接打开。

可以看到该数据集上有该评价的好坏有neural,positive,negative三种,关于评价的具体 文本是text下的,我们在此次任务中只会用到评价文本(作为数据),情绪好坏作为我们的标签,也就是真值。

3.读取数据并作预处理

​ 首先先清楚我们的目标在预处理过程,是想要提取一个序列(这个序列是由我们的评论转换的),和一个标签(标签也要数字化),那么我们接下来就开始从CSV格式文件中提取文本和标签并分别将他们转化成序列和数字。

import tensorflow as tf
keras=tf.keras
layers=keras.layers
import numpy as np
import pandas as pd
import re
data=pd.read_csv('../input/twitter-airline-sentiment/Tweets.csv')
data.head()#文件内部数据太多使用这个默认查看前五行

在这里插入图片描述

然后我们此次只需要提取每个人评论的text,和评论观点的倾向,所以我们提取以下两列

data=data[['airline_sentiment','text']]
data.head()

在这里插入图片描述

我们成功提取每个评论的情绪,和文本,接下来我们先将情绪用数字表示,可以先查看有多少种情绪

data.airline_sentiment.value_counts()#使用该方法可以查看每个值的个数
negative    9178
neutral     3099
positive    2363
Name: airline_sentiment, dtype: int64

可以看到这里有三个倾向的情绪,消极,中立,积极,那么也就是说这是一个多分类单标签问题,那么我们直接对每个情绪进行编码然后转化即可

sentiment_to_index={'positive':0,'neutral':1,'negative':2}
def to_index(sentiment):#写函数来转化return sentiment_to_index.get(sentiment)
data['sentiment']=data.airline_sentiment.apply(to_index)
del data['airline_sentiment'] #删除原有的一列
data.head()

在这里插入图片描述

可以看到我们的标签被成功的转化成对应的数字标签。

并且我们还要注意一点,消极的评论远远多余积极的评论,我们在训练分类问题上最好是将每个类别上的数据的数量都保持一致,防止模型对于某些分类的特征过分学习。也就是说我们在这里使用消极和中立的数量都必须被降为和积极一样,那么这里我们就直接使用切片,对于series数据切片使用iloc函数

data_positive=data[data.airline_sentiment=='positive']
data_negative=data[data.airline_sentiment=='negative']
data_neutral=data[data.airline_sentiment=='neutral']
data_negative=data_negative.iloc[:len(data_positive)]  
data_neutral=data_neutral.iloc[:len(data_positive)]
len(data_negative),len(data_neutral),len(data_positive)(2363, 2363, 2363) #可以看到我们将三个数据全部转化为相同个数

那么接下来我们合并我们的这些数据并且使用sample方法随机打乱(sample的用法是从原有数据随机抽出一部分数据,但是如果我们把抽出数据的规模等于所有数据,就相当于打乱)

data=pd.concat([data_negative,data_positive,data_neutral])
data=data.sample(len(data))  #smaple的意思是从dataframe中随机抽取指定数量的数据
data.head()

那么接下来我们就将每个文本转化为一个序列,怎么转化呢,其实很简单,那就是将每个句子里的单词映射成一个数字,那么整个句子就成为了一个数字序列,那么如何来完成了,接下来我们开始贴代码

token=re.compile('[A-Za-z]+|[!?,.()]')
#我们设置匹配的时候不要特殊字符,只要标点符号和字母,并且大小写不会影响单词原意,我这里也直接将所有大写转化成小写
def constractor_text(text):res_text=token.findall(text)res_text=[word.lower() for word in res_text]return res_text
#上面是使用re库提供的一个正则匹配方法在除去特殊符号其他均匹配情况下效果显著
new_data=data.text.apply(constractor_text)
data['text']=new_data
data.head()

在这里插入图片描述

那么接下里我们将单词全部映射成一个个数字其实想法很简单,先做一个集合将所有单词添加进集合吗,由于集合本身的特性,会自动删除重复的,然后我们将该集合中的单词转化成字典,就可以将单词转化成序列了,这里也简单的贴代码

word_list=list(word_set) #因为集合并没有下标这个概念,所以为了后面的方便我们转化成列表
word_dict=dict((k,v+1) for v,k in enumerate(word_list))
word_dict#同时为了防止填充单词之后填充0影响结果我们将所有数据,转化
{'win': 1,'DEFINITELY': 2,'gfc': 3,'OI': 4,'pearl': 5,'briughy': 6,'necessity': 7,'flyingwithUS': 8,'agreement': 9,...

这里需要非常注意的一点就是,每个评论的数据都是有一定长度的,但最后为了规范化我们一定是要将所有评论长度都处理到相同长度,那么我们填充的数字一般用0来填充,所以我们在字典里不能对0进行赋值,防止影响结果,所以我这里将所有单词对应的编号加一。可以看到我们单词编号从一开始。

好的有了单词的转换表,那么我们接下来编写函数将句子转换成序列

def word_to_vector(text):vector=[word_dict.get(word,0) for word in text]return vector
data['text']=data.text.apply(word_to_vector)
data_text=data['text']
data_text.head()8263    [3228, 11239, 9075, 694, 1133, 4364, 10324, 10...
4953    [1721, 11079, 870, 10, 11285, 9390, 10642, 724...
5489    [1721, 443, 6165, 4999, 4859, 4806, 7367, 7013...
2452    [3436, 10200, 6758, 10, 310, 1660, 8275, 10324...
8219    [3228, 11460, 10774, 10324, 1291, 6804, 516, 7...
Name: text, dtype: object

这里我们可以看到每个句子就都被转换为对应的序列,那么我们接下里将所有向量处理成完全一样的长度,

maxlen=max(len(x) for x in data_text)
max_word=len(word_set)+1
data_text=keras.preprocessing.sequence.pad_sequences(data_text.values,maxlen=maxlen)
data_text.shape
(7089, 40)

可以看到每个序列都被填充到了长度为40,那么我们接下来提取标签然后制作dataset,划分测试集与训练集

label=data.sentiment.values
test_count=int(7089*0.2)
train_count=7089-test_count
test_data=train_data.take(test_count)
train_data=train_data.skip(test_count)
train_data=train_data.shuffle(train_count).repeat().batch(64)
test_data=test_data.batch(64)

划分完毕后我们总算是完成了我们数据的预处理,接下来开始我们模型的搭建。

4.模型的搭建

我们输入的是一个长度为40的序列,但这样并不适合我们模型对他的处理,对此已经有提出词嵌入方法,WORD2VEC的方法,即将每个单词转化成固定维度的向量,向量之间差的大小,表示每个单词之间关系的大小(我理解为单词之间的相似性),这里我们可以用RGB表示颜色的方式来理解,每个颜色的值都可以用一个三维向量来表示,对于单词就是我们设置一个几十个维度的词向量,假设所有词都可以用这个高维向量来表示,那么具体怎么转换,有多种方法,我们这里使用keras提供的embelding层来将所有单词转换成我们设定维度的向量

model=keras.Sequential()
#Embedding层可以吧文本映射为一个密集向量
model.add(layers.Embedding(max_word,50,input_length=maxlen))
#然后我们多次未见的主角GRU登场,用它来处理这种序列数据效果是十分好的
model.add(layers.GRU(64))#LSTM的参数是一个隐藏单元数
model.add(layers.Dense(3,activation='softmax'))
#最后输出这是一个三分类的问题,所以我们激活函数用softmax
model.compile(optimizer=keras.optimizers.Adam(0.0001),loss='sparse_categorical_crossentropy',metrics=['acc'])
#设置模型的优化器这里没什么好说的

5.训练结果分析与网络调整

model.fit(train_data,steps_per_epoch=train_count//64,epochs=10,validation_data=test_data,validation_steps=test_count//64)

这里我们开始训练查看结果却发现

在这里插入图片描述

网络已经达到严重过拟合,测试集准确率极高,但验证集却非常低,两者相差达到20%,那么为了抑制过拟合我这里采取两种方法一是增加网络深度,添加Dropout层抑制过拟合

model=keras.Sequential()
#Embedding层可以吧文本映射为一个密集向量
model.add(layers.Embedding(max_word,50,input_length=maxlen))
model.add(layers.GRU(64))#LSTM的参数是一个隐藏单元数
model.add(layers.Dropout(0.2))
model.add(layers.Dense(32,activation='relu'))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(16,activation='relu'))
model.add(layers.Dense(3,activation='softmax'))

,二是我将数据增加一倍,(等于是复制了一遍数据再,打乱),最终数据翻倍达到14000多条那么我们再次开始训练,查看结果

Epoch 15/15
177/177 [==============================] - 6s 34ms/step - loss: 0.0542 - acc: 0.9852 - val_loss: 0.1551 - val_acc: 0.9592

可以看到在训练最后,过拟合被抑制了,模型无论在训练集,测试集都达到了极高的正确率

结语

本篇博客简单介绍了RNN网络,并且非常具体的展示了如何从CSV文件读取数据,预处理并制作成模型可以接收的数据,在最后利用GRU搭建模型,并且对于训练结果产生过拟合如何去抑制方面做了处理,如果有任何建议或者问题欢迎评论区指出,谢谢!

这篇关于RNN学习:利用LSTM,GRU层解决航空公司评论数据预测问题的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/858763

相关文章

如何使用 Python 读取 Excel 数据

《如何使用Python读取Excel数据》:本文主要介绍使用Python读取Excel数据的详细教程,通过pandas和openpyxl,你可以轻松读取Excel文件,并进行各种数据处理操... 目录使用 python 读取 Excel 数据的详细教程1. 安装必要的依赖2. 读取 Excel 文件3. 读

关于MongoDB图片URL存储异常问题以及解决

《关于MongoDB图片URL存储异常问题以及解决》:本文主要介绍关于MongoDB图片URL存储异常问题以及解决方案,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录MongoDB图片URL存储异常问题项目场景问题描述原因分析解决方案预防措施js总结MongoDB图

SpringBoot项目中报错The field screenShot exceeds its maximum permitted size of 1048576 bytes.的问题及解决

《SpringBoot项目中报错ThefieldscreenShotexceedsitsmaximumpermittedsizeof1048576bytes.的问题及解决》这篇文章... 目录项目场景问题描述原因分析解决方案总结项目场景javascript提示:项目相关背景:项目场景:基于Spring

解决Maven项目idea找不到本地仓库jar包问题以及使用mvn install:install-file

《解决Maven项目idea找不到本地仓库jar包问题以及使用mvninstall:install-file》:本文主要介绍解决Maven项目idea找不到本地仓库jar包问题以及使用mvnin... 目录Maven项目idea找不到本地仓库jar包以及使用mvn install:install-file基

Spring 请求之传递 JSON 数据的操作方法

《Spring请求之传递JSON数据的操作方法》JSON就是一种数据格式,有自己的格式和语法,使用文本表示一个对象或数组的信息,因此JSON本质是字符串,主要负责在不同的语言中数据传递和交换,这... 目录jsON 概念JSON 语法JSON 的语法JSON 的两种结构JSON 字符串和 Java 对象互转

最详细安装 PostgreSQL方法及常见问题解决

《最详细安装PostgreSQL方法及常见问题解决》:本文主要介绍最详细安装PostgreSQL方法及常见问题解决,介绍了在Windows系统上安装PostgreSQL及Linux系统上安装Po... 目录一、在 Windows 系统上安装 PostgreSQL1. 下载 PostgreSQL 安装包2.

C++如何通过Qt反射机制实现数据类序列化

《C++如何通过Qt反射机制实现数据类序列化》在C++工程中经常需要使用数据类,并对数据类进行存储、打印、调试等操作,所以本文就来聊聊C++如何通过Qt反射机制实现数据类序列化吧... 目录设计预期设计思路代码实现使用方法在 C++ 工程中经常需要使用数据类,并对数据类进行存储、打印、调试等操作。由于数据类

usb接口驱动异常问题常用解决方案

《usb接口驱动异常问题常用解决方案》当遇到USB接口驱动异常时,可以通过多种方法来解决,其中主要就包括重装USB控制器、禁用USB选择性暂停设置、更新或安装新的主板驱动等... usb接口驱动异常怎么办,USB接口驱动异常是常见问题,通常由驱动损坏、系统更新冲突、硬件故障或电源管理设置导致。以下是常用解决

Mysql如何解决死锁问题

《Mysql如何解决死锁问题》:本文主要介绍Mysql如何解决死锁问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录【一】mysql中锁分类和加锁情况【1】按锁的粒度分类全局锁表级锁行级锁【2】按锁的模式分类【二】加锁方式的影响因素【三】Mysql的死锁情况【1

SpringBoot内嵌Tomcat临时目录问题及解决

《SpringBoot内嵌Tomcat临时目录问题及解决》:本文主要介绍SpringBoot内嵌Tomcat临时目录问题及解决,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,... 目录SprinjavascriptgBoot内嵌Tomcat临时目录问题1.背景2.方案3.代码中配置t