【Keras学习笔记】10:IMDb电影评价数据集文本分类

2023-11-25 17:30

本文主要是介绍【Keras学习笔记】10:IMDb电影评价数据集文本分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

读取数据
import keras
from keras import layers
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
%matplotlib inline
Using TensorFlow backend.
data = keras.datasets.imdb
# 最多提取10000个单词,多的不要
(x_train, y_train), (x_test, y_test) = data.load_data(num_words=10000)
Downloading data from https://s3.amazonaws.com/text-datasets/imdb.npz
17465344/17464789 [==============================] - 761s 44us/step
x_train.shape, y_train.shape, x_test.shape, y_test.shape
((25000,), (25000,), (25000,), (25000,))

数据集已经为每个单词做好数字编码了,所以得到的每个样本都是一个整数形式的向量:

# 看一下第一个样本的前10个单词的数字编码
x_train[0][:10]
[1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]
# 标签y是非0即1的,表示负面和正面评价
y_train
array([1, 0, 0, ..., 0, 1, 0], dtype=int64)

不妨恢复一条样本看一下原始形式是什么样子的。

# 这个得到的是一个字典,里面是{单词:数字序号,单词:数字序号,...}
word_index = data.get_word_index()
Downloading data from https://s3.amazonaws.com/text-datasets/imdb_word_index.json
1646592/1641221 [==============================] - 100s 61us/step

现在要根据数字序号去得到单词,所以把这个字典的k-v反转一下。这里用生成器来将其反转,再转换成字典。

index_word = dict((value, key) for key, value in word_index.items())

用生成器将第一个样本转换成单词序列,注意这个数据集的word=>index映射时是从0开始编码的,但前面得到的word_index里保留了0,1,2三个编码,也就是所有编码加了3,,这里将其减掉。另外,有些词在word_index里找不到,不妨在找不到时候就给个?标识。

" ".join(index_word.get(index-3,'?') for index in x_train[0])
"? this film was just brilliant casting location scenery story direction everyone's really suited the part they played and you could just imagine being there robert ? is an amazing actor and now the same being director ? father came from the same scottish island as myself so i loved the fact there was a real connection with this film the witty remarks throughout the film were great it was just brilliant so much that i bought the film as soon as it was released for ? and would recommend it to everyone to watch and the fly fishing was amazing really cried at the end it was so sad and you know what they say if you cry at a film it must have been good and this definitely was also ? to the two little boy's that played the ? of norman and paul they were just brilliant children are often left out of the ? list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all"

样本有的很短,有的很长,看一下前10个样本的长度:

[len(seq) for seq in x_train[:10]]
[218, 189, 141, 550, 147, 43, 123, 562, 233, 130]

但一定不会超过读取数据集时定义的最大长度10000:

max(max(seq) for seq in x_train)
9999
文本的向量化

因为有10000个单词,可以使用长度为10000的向量,然后将每个词对应一个索引,如果一个词在一条样本中出现了,就将相应位置设置成1(或者+1),这就是次袋模型。

如果设置成1(而不是+1),那么这个向量是有很多为1的分量,其余位置都是0,在学习视频里老师叫它k-hot编码(没查到有这种叫法,估计又是自己扯的),了解一下就好。

def k_hot(seqs, dim=10000):"""数字编码转k-hot编码"""res = np.zeros((len(seqs), dim))for i, seq in enumerate(seqs):res[i, seq] = 1return res
x_tr = k_hot(x_train) 
x_tr.shape
(25000, 10000)
x_ts = k_hot(x_test)
x_ts.shape
(25000, 10000)
建立模型和训练
model = keras.Sequential()
model.add(layers.Dense(32, input_dim=10000, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_4 (Dense)              (None, 32)                320032    
_________________________________________________________________
dense_5 (Dense)              (None, 32)                1056      
_________________________________________________________________
dense_6 (Dense)              (None, 1)                 33        
=================================================================
Total params: 321,121
Trainable params: 321,121
Non-trainable params: 0
_________________________________________________________________
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['acc']
)
history = model.fit(x_tr, y_train, epochs=15, batch_size=256, validation_data=(x_ts, y_test), verbose=0)
plt.plot(history.epoch, history.history.get('val_acc'), c='g', label='validation acc')
plt.plot(history.epoch, history.history.get('acc'), c='b', label='train acc')
plt.legend()
<matplotlib.legend.Legend at 0x1890b908>

在这里插入图片描述

plt.plot(history.epoch, history.history.get('val_loss'), c='g', label='validation loss')
plt.plot(history.epoch, history.history.get('loss'), c='b', label='train loss')
plt.legend()
<matplotlib.legend.Legend at 0x189b79e8>

在这里插入图片描述

可以看到发生了严重的过拟合,下面尝试引入Dropout和正则化项,同时减小网络的规模。

模型优化
from keras import regularizers
model = keras.Sequential()
model.add(layers.Dense(8, input_dim=10000, activation='relu', kernel_regularizer=regularizers.l2(0.005)))
model.add(layers.Dropout(rate=0.4)) # keeep_prob=0.6
model.add(layers.Dense(8, activation='relu', kernel_regularizer=regularizers.l2(0.005)))
model.add(layers.Dense(1, activation='sigmoid'))
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_13 (Dense)             (None, 8)                 80008     
_________________________________________________________________
dropout_3 (Dropout)          (None, 8)                 0         
_________________________________________________________________
dense_14 (Dense)             (None, 8)                 72        
_________________________________________________________________
dense_15 (Dense)             (None, 1)                 9         
=================================================================
Total params: 80,089
Trainable params: 80,089
Non-trainable params: 0
_________________________________________________________________
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['acc']
)
history = model.fit(x_tr, y_train, epochs=15, batch_size=256, validation_data=(x_ts, y_test), verbose=0)
plt.plot(history.epoch, history.history.get('val_acc'), c='g', label='validation acc')
plt.plot(history.epoch, history.history.get('acc'), c='b', label='train acc')
plt.legend()
<matplotlib.legend.Legend at 0x1b2a4f28>

在这里插入图片描述

plt.plot(history.epoch, history.history.get('val_loss'), c='g', label='validation loss')
plt.plot(history.epoch, history.history.get('loss'), c='b', label='train loss')
plt.legend()
<matplotlib.legend.Legend at 0x1b319208>

在这里插入图片描述

好了很多。

这篇关于【Keras学习笔记】10:IMDb电影评价数据集文本分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/424249

相关文章

Python获取中国节假日数据记录入JSON文件

《Python获取中国节假日数据记录入JSON文件》项目系统内置的日历应用为了提升用户体验,特别设置了在调休日期显示“休”的UI图标功能,那么问题是这些调休数据从哪里来呢?我尝试一种更为智能的方法:P... 目录节假日数据获取存入jsON文件节假日数据读取封装完整代码项目系统内置的日历应用为了提升用户体验,

Java利用JSONPath操作JSON数据的技术指南

《Java利用JSONPath操作JSON数据的技术指南》JSONPath是一种强大的工具,用于查询和操作JSON数据,类似于SQL的语法,它为处理复杂的JSON数据结构提供了简单且高效... 目录1、简述2、什么是 jsONPath?3、Java 示例3.1 基本查询3.2 过滤查询3.3 递归搜索3.4

MySQL大表数据的分区与分库分表的实现

《MySQL大表数据的分区与分库分表的实现》数据库的分区和分库分表是两种常用的技术方案,本文主要介绍了MySQL大表数据的分区与分库分表的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有... 目录1. mysql大表数据的分区1.1 什么是分区?1.2 分区的类型1.3 分区的优点1.4 分

Mysql删除几亿条数据表中的部分数据的方法实现

《Mysql删除几亿条数据表中的部分数据的方法实现》在MySQL中删除一个大表中的数据时,需要特别注意操作的性能和对系统的影响,本文主要介绍了Mysql删除几亿条数据表中的部分数据的方法实现,具有一定... 目录1、需求2、方案1. 使用 DELETE 语句分批删除2. 使用 INPLACE ALTER T

Python Dash框架在数据可视化仪表板中的应用与实践记录

《PythonDash框架在数据可视化仪表板中的应用与实践记录》Python的PlotlyDash库提供了一种简便且强大的方式来构建和展示互动式数据仪表板,本篇文章将深入探讨如何使用Dash设计一... 目录python Dash框架在数据可视化仪表板中的应用与实践1. 什么是Plotly Dash?1.1

Redis 中的热点键和数据倾斜示例详解

《Redis中的热点键和数据倾斜示例详解》热点键是指在Redis中被频繁访问的特定键,这些键由于其高访问频率,可能导致Redis服务器的性能问题,尤其是在高并发场景下,本文给大家介绍Redis中的热... 目录Redis 中的热点键和数据倾斜热点键(Hot Key)定义特点应对策略示例数据倾斜(Data S

Python实现将MySQL中所有表的数据都导出为CSV文件并压缩

《Python实现将MySQL中所有表的数据都导出为CSV文件并压缩》这篇文章主要为大家详细介绍了如何使用Python将MySQL数据库中所有表的数据都导出为CSV文件到一个目录,并压缩为zip文件到... python将mysql数据库中所有表的数据都导出为CSV文件到一个目录,并压缩为zip文件到另一个

使用Python实现文本转语音(TTS)并播放音频

《使用Python实现文本转语音(TTS)并播放音频》在开发涉及语音交互或需要语音提示的应用时,文本转语音(TTS)技术是一个非常实用的工具,下面我们来看看如何使用gTTS和playsound库将文本... 目录什么是 gTTS 和 playsound安装依赖库实现步骤 1. 导入库2. 定义文本和语言 3

Python实现常用文本内容提取

《Python实现常用文本内容提取》在日常工作和学习中,我们经常需要从PDF、Word文档中提取文本,本文将介绍如何使用Python编写一个文本内容提取工具,有需要的小伙伴可以参考下... 目录一、引言二、文本内容提取的原理三、文本内容提取的设计四、文本内容提取的实现五、完整代码示例一、引言在日常工作和学

SpringBoot整合jasypt实现重要数据加密

《SpringBoot整合jasypt实现重要数据加密》Jasypt是一个专注于简化Java加密操作的开源工具,:本文主要介绍详细介绍了如何使用jasypt实现重要数据加密,感兴趣的小伙伴可... 目录jasypt简介 jasypt的优点SpringBoot使用jasypt创建mapper接口配置文件加密