RNN学习:利用LSTM,GRU层解决航空公司评论数据预测问题

2024-03-29 13:48

本文主要是介绍RNN学习:利用LSTM,GRU层解决航空公司评论数据预测问题,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

RNN学习:利用LSTM,GRU解决航空公司评论数据预测问题

文章目录

  • RNN学习:利用LSTM,GRU解决航空公司评论数据预测问题
    • 1.RNN的介绍
      • 1.1 LSTM的简单介绍
      • 1.2 GRU的简单介绍
    • 2.数据集的介绍
    • 3.读取数据并作预处理
    • 4.模型的搭建
    • 结语

1.RNN的介绍

​ RNN,即循环神经网络,即一般的神经网络同层节点与节点之间并无连接,比如CNN隐藏单元之间并没有连接,那么这相对于一些序列问题上的处理就会效果很差。如翻译单词,全文的意思必须是根据所有单词来进行判断。或判断说话人情绪,评论好坏,最终的输出要和前面所有的输入发生关系,所以这里学者们提出循环神经网络,让上一个节点会对下一个节点传递状态向量,每个节点之间输出两个值,一个是我们要的输出,还有一个就是状态向量,该向量输入下一个下一个节点,最终输出为二维数据(None,units)units为RNN的隐藏单元数。

在这里插入图片描述

1.1 LSTM的简单介绍

​ 刚才我们说明了RNN会不断的向下一个节点传递状态,但是经过长时间的多次传递,最终传递的状态可能会引起梯度爆炸或梯度消失等问题,为了解决这个问题,学者们又提出了LSTM层来解决这个问题,LSTM层的内部存在一些门,他会通过训练门的参数控制了上一状态我们需要遗忘多少,并且在这一层状态的更新。

在这里插入图片描述

可以看到在这一个单元中上一层的输出ht-1和状态Ct-1都传递了进来从而经过我们的门来控制该单元遗忘并更新状态。

1.2 GRU的简单介绍

GRU是LSTM结构的一种变体,他可以做到与LSTM性能相当的情况下,计算量会比LSTM减少,他的网络结构如下

在这里插入图片描述

可以看到他作为LSTM的变体,与LSTM的相似之处,他也会有前一次的状态(但是不会有前一层的输出传入)向他传入并且通过训练控制前一次状态在本单元的遗忘与更新。

2.数据集的介绍

本次使用的是Twitter 美国航空公司情绪:2015年2月美国航空公司的Twitter数据,分类为正面,负面和中性推文(https://www.kaggle.com/crowdflower/twitter-airline-sentiment)

在这里插入图片描述

整个数据集使用CSV格式存储,这种文件格式是一种经常用来数据科学存储数据的纯文本文件,可以用EXCEL直接打开。

可以看到该数据集上有该评价的好坏有neural,positive,negative三种,关于评价的具体 文本是text下的,我们在此次任务中只会用到评价文本(作为数据),情绪好坏作为我们的标签,也就是真值。

3.读取数据并作预处理

​ 首先先清楚我们的目标在预处理过程,是想要提取一个序列(这个序列是由我们的评论转换的),和一个标签(标签也要数字化),那么我们接下来就开始从CSV格式文件中提取文本和标签并分别将他们转化成序列和数字。

import tensorflow as tf
keras=tf.keras
layers=keras.layers
import numpy as np
import pandas as pd
import re
data=pd.read_csv('../input/twitter-airline-sentiment/Tweets.csv')
data.head()#文件内部数据太多使用这个默认查看前五行

在这里插入图片描述

然后我们此次只需要提取每个人评论的text,和评论观点的倾向,所以我们提取以下两列

data=data[['airline_sentiment','text']]
data.head()

在这里插入图片描述

我们成功提取每个评论的情绪,和文本,接下来我们先将情绪用数字表示,可以先查看有多少种情绪

data.airline_sentiment.value_counts()#使用该方法可以查看每个值的个数
negative    9178
neutral     3099
positive    2363
Name: airline_sentiment, dtype: int64

可以看到这里有三个倾向的情绪,消极,中立,积极,那么也就是说这是一个多分类单标签问题,那么我们直接对每个情绪进行编码然后转化即可

sentiment_to_index={'positive':0,'neutral':1,'negative':2}
def to_index(sentiment):#写函数来转化return sentiment_to_index.get(sentiment)
data['sentiment']=data.airline_sentiment.apply(to_index)
del data['airline_sentiment'] #删除原有的一列
data.head()

在这里插入图片描述

可以看到我们的标签被成功的转化成对应的数字标签。

并且我们还要注意一点,消极的评论远远多余积极的评论,我们在训练分类问题上最好是将每个类别上的数据的数量都保持一致,防止模型对于某些分类的特征过分学习。也就是说我们在这里使用消极和中立的数量都必须被降为和积极一样,那么这里我们就直接使用切片,对于series数据切片使用iloc函数

data_positive=data[data.airline_sentiment=='positive']
data_negative=data[data.airline_sentiment=='negative']
data_neutral=data[data.airline_sentiment=='neutral']
data_negative=data_negative.iloc[:len(data_positive)]  
data_neutral=data_neutral.iloc[:len(data_positive)]
len(data_negative),len(data_neutral),len(data_positive)(2363, 2363, 2363) #可以看到我们将三个数据全部转化为相同个数

那么接下来我们合并我们的这些数据并且使用sample方法随机打乱(sample的用法是从原有数据随机抽出一部分数据,但是如果我们把抽出数据的规模等于所有数据,就相当于打乱)

data=pd.concat([data_negative,data_positive,data_neutral])
data=data.sample(len(data))  #smaple的意思是从dataframe中随机抽取指定数量的数据
data.head()

那么接下来我们就将每个文本转化为一个序列,怎么转化呢,其实很简单,那就是将每个句子里的单词映射成一个数字,那么整个句子就成为了一个数字序列,那么如何来完成了,接下来我们开始贴代码

token=re.compile('[A-Za-z]+|[!?,.()]')
#我们设置匹配的时候不要特殊字符,只要标点符号和字母,并且大小写不会影响单词原意,我这里也直接将所有大写转化成小写
def constractor_text(text):res_text=token.findall(text)res_text=[word.lower() for word in res_text]return res_text
#上面是使用re库提供的一个正则匹配方法在除去特殊符号其他均匹配情况下效果显著
new_data=data.text.apply(constractor_text)
data['text']=new_data
data.head()

在这里插入图片描述

那么接下里我们将单词全部映射成一个个数字其实想法很简单,先做一个集合将所有单词添加进集合吗,由于集合本身的特性,会自动删除重复的,然后我们将该集合中的单词转化成字典,就可以将单词转化成序列了,这里也简单的贴代码

word_list=list(word_set) #因为集合并没有下标这个概念,所以为了后面的方便我们转化成列表
word_dict=dict((k,v+1) for v,k in enumerate(word_list))
word_dict#同时为了防止填充单词之后填充0影响结果我们将所有数据,转化
{'win': 1,'DEFINITELY': 2,'gfc': 3,'OI': 4,'pearl': 5,'briughy': 6,'necessity': 7,'flyingwithUS': 8,'agreement': 9,...

这里需要非常注意的一点就是,每个评论的数据都是有一定长度的,但最后为了规范化我们一定是要将所有评论长度都处理到相同长度,那么我们填充的数字一般用0来填充,所以我们在字典里不能对0进行赋值,防止影响结果,所以我这里将所有单词对应的编号加一。可以看到我们单词编号从一开始。

好的有了单词的转换表,那么我们接下来编写函数将句子转换成序列

def word_to_vector(text):vector=[word_dict.get(word,0) for word in text]return vector
data['text']=data.text.apply(word_to_vector)
data_text=data['text']
data_text.head()8263    [3228, 11239, 9075, 694, 1133, 4364, 10324, 10...
4953    [1721, 11079, 870, 10, 11285, 9390, 10642, 724...
5489    [1721, 443, 6165, 4999, 4859, 4806, 7367, 7013...
2452    [3436, 10200, 6758, 10, 310, 1660, 8275, 10324...
8219    [3228, 11460, 10774, 10324, 1291, 6804, 516, 7...
Name: text, dtype: object

这里我们可以看到每个句子就都被转换为对应的序列,那么我们接下里将所有向量处理成完全一样的长度,

maxlen=max(len(x) for x in data_text)
max_word=len(word_set)+1
data_text=keras.preprocessing.sequence.pad_sequences(data_text.values,maxlen=maxlen)
data_text.shape
(7089, 40)

可以看到每个序列都被填充到了长度为40,那么我们接下来提取标签然后制作dataset,划分测试集与训练集

label=data.sentiment.values
test_count=int(7089*0.2)
train_count=7089-test_count
test_data=train_data.take(test_count)
train_data=train_data.skip(test_count)
train_data=train_data.shuffle(train_count).repeat().batch(64)
test_data=test_data.batch(64)

划分完毕后我们总算是完成了我们数据的预处理,接下来开始我们模型的搭建。

4.模型的搭建

我们输入的是一个长度为40的序列,但这样并不适合我们模型对他的处理,对此已经有提出词嵌入方法,WORD2VEC的方法,即将每个单词转化成固定维度的向量,向量之间差的大小,表示每个单词之间关系的大小(我理解为单词之间的相似性),这里我们可以用RGB表示颜色的方式来理解,每个颜色的值都可以用一个三维向量来表示,对于单词就是我们设置一个几十个维度的词向量,假设所有词都可以用这个高维向量来表示,那么具体怎么转换,有多种方法,我们这里使用keras提供的embelding层来将所有单词转换成我们设定维度的向量

model=keras.Sequential()
#Embedding层可以吧文本映射为一个密集向量
model.add(layers.Embedding(max_word,50,input_length=maxlen))
#然后我们多次未见的主角GRU登场,用它来处理这种序列数据效果是十分好的
model.add(layers.GRU(64))#LSTM的参数是一个隐藏单元数
model.add(layers.Dense(3,activation='softmax'))
#最后输出这是一个三分类的问题,所以我们激活函数用softmax
model.compile(optimizer=keras.optimizers.Adam(0.0001),loss='sparse_categorical_crossentropy',metrics=['acc'])
#设置模型的优化器这里没什么好说的

5.训练结果分析与网络调整

model.fit(train_data,steps_per_epoch=train_count//64,epochs=10,validation_data=test_data,validation_steps=test_count//64)

这里我们开始训练查看结果却发现

在这里插入图片描述

网络已经达到严重过拟合,测试集准确率极高,但验证集却非常低,两者相差达到20%,那么为了抑制过拟合我这里采取两种方法一是增加网络深度,添加Dropout层抑制过拟合

model=keras.Sequential()
#Embedding层可以吧文本映射为一个密集向量
model.add(layers.Embedding(max_word,50,input_length=maxlen))
model.add(layers.GRU(64))#LSTM的参数是一个隐藏单元数
model.add(layers.Dropout(0.2))
model.add(layers.Dense(32,activation='relu'))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(16,activation='relu'))
model.add(layers.Dense(3,activation='softmax'))

,二是我将数据增加一倍,(等于是复制了一遍数据再,打乱),最终数据翻倍达到14000多条那么我们再次开始训练,查看结果

Epoch 15/15
177/177 [==============================] - 6s 34ms/step - loss: 0.0542 - acc: 0.9852 - val_loss: 0.1551 - val_acc: 0.9592

可以看到在训练最后,过拟合被抑制了,模型无论在训练集,测试集都达到了极高的正确率

结语

本篇博客简单介绍了RNN网络,并且非常具体的展示了如何从CSV文件读取数据,预处理并制作成模型可以接收的数据,在最后利用GRU搭建模型,并且对于训练结果产生过拟合如何去抑制方面做了处理,如果有任何建议或者问题欢迎评论区指出,谢谢!

这篇关于RNN学习:利用LSTM,GRU层解决航空公司评论数据预测问题的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/858763

相关文章

Vue3绑定props默认值问题

《Vue3绑定props默认值问题》使用Vue3的defineProps配合TypeScript的interface定义props类型,并通过withDefaults设置默认值,使组件能安全访问传入的... 目录前言步骤步骤1:使用 defineProps 定义 Props步骤2:设置默认值总结前言使用T

MyBatis-plus处理存储json数据过程

《MyBatis-plus处理存储json数据过程》文章介绍MyBatis-Plus3.4.21处理对象与集合的差异:对象可用内置Handler配合autoResultMap,集合需自定义处理器继承F... 目录1、如果是对象2、如果需要转换的是List集合总结对象和集合分两种情况处理,目前我用的MP的版本

GSON框架下将百度天气JSON数据转JavaBean

《GSON框架下将百度天气JSON数据转JavaBean》这篇文章主要为大家详细介绍了如何在GSON框架下实现将百度天气JSON数据转JavaBean,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下... 目录前言一、百度天气jsON1、请求参数2、返回参数3、属性映射二、GSON属性映射实战1、类对象映

504 Gateway Timeout网关超时的根源及完美解决方法

《504GatewayTimeout网关超时的根源及完美解决方法》在日常开发和运维过程中,504GatewayTimeout错误是常见的网络问题之一,尤其是在使用反向代理(如Nginx)或... 目录引言为什么会出现 504 错误?1. 探索 504 Gateway Timeout 错误的根源 1.1 后端

Web服务器-Nginx-高并发问题

《Web服务器-Nginx-高并发问题》Nginx通过事件驱动、I/O多路复用和异步非阻塞技术高效处理高并发,结合动静分离和限流策略,提升性能与稳定性... 目录前言一、架构1. 原生多进程架构2. 事件驱动模型3. IO多路复用4. 异步非阻塞 I/O5. Nginx高并发配置实战二、动静分离1. 职责2

解决升级JDK报错:module java.base does not“opens java.lang.reflect“to unnamed module问题

《解决升级JDK报错:modulejava.basedoesnot“opensjava.lang.reflect“tounnamedmodule问题》SpringBoot启动错误源于Jav... 目录问题描述原因分析解决方案总结问题描述启动sprintboot时报以下错误原因分析编程异js常是由Ja

C# LiteDB处理时间序列数据的高性能解决方案

《C#LiteDB处理时间序列数据的高性能解决方案》LiteDB作为.NET生态下的轻量级嵌入式NoSQL数据库,一直是时间序列处理的优选方案,本文将为大家大家简单介绍一下LiteDB处理时间序列数... 目录为什么选择LiteDB处理时间序列数据第一章:LiteDB时间序列数据模型设计1.1 核心设计原则

Java+AI驱动实现PDF文件数据提取与解析

《Java+AI驱动实现PDF文件数据提取与解析》本文将和大家分享一套基于AI的体检报告智能评估方案,详细介绍从PDF上传、内容提取到AI分析、数据存储的全流程自动化实现方法,感兴趣的可以了解下... 目录一、核心流程:从上传到评估的完整链路二、第一步:解析 PDF,提取体检报告内容1. 引入依赖2. 封装

深度剖析SpringBoot日志性能提升的原因与解决

《深度剖析SpringBoot日志性能提升的原因与解决》日志记录本该是辅助工具,却为何成了性能瓶颈,SpringBoot如何用代码彻底破解日志导致的高延迟问题,感兴趣的小伙伴可以跟随小编一起学习一下... 目录前言第一章:日志性能陷阱的底层原理1.1 日志级别的“双刃剑”效应1.2 同步日志的“吞吐量杀手”

MySQL 表空却 ibd 文件过大的问题及解决方法

《MySQL表空却ibd文件过大的问题及解决方法》本文给大家介绍MySQL表空却ibd文件过大的问题及解决方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考... 目录一、问题背景:表空却 “吃满” 磁盘的怪事二、问题复现:一步步编程还原异常场景1. 准备测试源表与数据