使用Keras处理深度学习中的多分类问题——路透社新闻分类

2023-12-19 09:10

本文主要是介绍使用Keras处理深度学习中的多分类问题——路透社新闻分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

简介

本文将着手构建一个网络,将路透社新闻划分为46个互斥的主题,与二分类问题不同,这是一个多分类问题。

关于二分类问题的处理方式,请参考:使用Keras处理深度学习中的二分类问题——Imdb影评分类。

对于某个新闻,它只能划分到46个类别中的一个,所以这个问题又是单标签、多分类问题。如果每条新闻可以划分到不同的主题,那就是多标签、多分类问题了。

路透社数据集由路透社在1986年发布,包含许多短新闻及其对应的主题,它是一个简单、广泛使用的文本分类数据集。

它包含46个主题,某些主题的样本会比较多,有些比较少,但训练集中每个主题至少有10个样本。

路透社数据集内置于Keras,可以直接加载。

加载数据

代码:

from keras.datasets import reuters(train_data, train_labels), (test_data, test_labels) = reuters.load_data(num_words = 10000)

查看数据:

在这里插入图片描述

可以把数据解码为新闻文本:

# 将索引解码为新闻文本
path = r"E:\practice\tf2\new_multi_calssify\reuters_word_index.json"
word_index = reuters.get_word_index(path) # 省略 path 默认会下载文件到 C:\Users\Administrator\.keras\datasets
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])
decoded_newswire = ' '.join([reverse_word_index.get(i - 3, '?') for i in train_data[0]])

如下:

在这里插入图片描述

准备数据

对数据进行加工,以便输入到网络:

# 准备数据
import numpy as np# 数据向量化
def vectorize_sequences(sequences, dimension=10000):results = np.zeros((len(sequences), dimension))for i, sequence in enumerate(sequences):results[i, sequence] = 1.return resultsx_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)# 标签向量化,使用one-hot编码,使用Keras内置的函数
from keras.utils.np_utils import to_categoricalone_hot_train_labels = to_categorical(train_labels)
one_hot_test_labels = to_categorical(test_labels)
构建网络

本次仍构建二层密集连接层组成的网络,由于要训练的类别较多,使用64个隐藏单元。使用的函数与二分类中一致:

from keras import models
from keras import layersmodel = models.Sequential()
model.add(layers.Dense(64, activation = 'relu', input_shape = (10000,)))
model.add(layers.Dense(64, activation = 'relu'))
model.add(layers.Dense(46, activation = 'softmax'))
编译

使用与二分类中一样的优化器、损失函数,还是一行代码:

model.compile(optimizer = 'rmsprop', loss = 'categorical_crossentropy', metrics = ['accuracy'])
训练模型

首先从训练数据中留出了用于验证的的数据,训练结果保存在history中,用于绘图查看效果:

# 留出验证数据
x_val = x_train[:1000]
partial_x_train = x_train[1000:]y_val = one_hot_train_labels[:1000]
partial_y_train = one_hot_train_labels[1000:]# 训练,这里训练了9轮次,因为我已经知道了从9轮次后数据会出现过拟合,你可以先使用较大参数,分析结果后确定一个合适的轮次再训练一遍
history = model.fit(partial_x_train, partial_y_train, epochs = 9, batch_size = 512, validation_data = (x_val, y_val))
绘图表示结果

分别绘制损失曲线和精度曲线:

import matplotlib.pyplot as plt# 损失曲线
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(1, len(loss) + 1)plt.plot(epochs, loss, 'bo', label = 'Training loss')
plt.plot(epochs, val_loss, 'b', label = 'Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()plt.show()# 精度曲线
plt.clf()acc = history.history['accuracy']
val_acc = history.history['val_accuracy']plt.plot(epochs, acc, 'bo', label = 'Training accuracy')
plt.plot(epochs, val_acc, 'b', label = 'Validation accuracy')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel(Accuracy)
plt.legend()plt.show()

如图:

在这里插入图片描述
在这里插入图片描述

评估并作出预测

直接上代码:

# 评估模型
results = model.evaluate(x_test, one_hot_test_labels)
# 查看结果
results# 在新数据上生成预测
predictions = model.predict(x_test)# 查看结果
predictions[0].shape
np.sum(predictions[0])
np.argmax(predictions[0])

效果如下:

在这里插入图片描述

小结

本例展示了从数据搜集到使用模型预测新数据的整个流程,包括:

  • 数据搜集
  • 数据预处理
  • 构建网络
  • 编译网络
  • 构建验证数据
  • 训练网络
  • 结果绘制
  • 评估模型
  • 预测新数据

这个流程是通用流程,但并非唯一流程,可能会增加一些流程,也可能会减小一些流程,甚至在前述工作基础上,再返回头来调整一些参数,这些都是正常的。

模型能够在前所未有的数据上得到很好的预测结果,这才是终极目标。

参考资料

《Python深度学习》

这篇关于使用Keras处理深度学习中的多分类问题——路透社新闻分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/511646

相关文章

Pandas使用SQLite3实战

《Pandas使用SQLite3实战》本文主要介绍了Pandas使用SQLite3实战,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学... 目录1 环境准备2 从 SQLite3VlfrWQzgt 读取数据到 DataFrame基础用法:读

JSON Web Token在登陆中的使用过程

《JSONWebToken在登陆中的使用过程》:本文主要介绍JSONWebToken在登陆中的使用过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录JWT 介绍微服务架构中的 JWT 使用结合微服务网关的 JWT 验证1. 用户登录,生成 JWT2. 自定义过滤

Java中StopWatch的使用示例详解

《Java中StopWatch的使用示例详解》stopWatch是org.springframework.util包下的一个工具类,使用它可直观的输出代码执行耗时,以及执行时间百分比,这篇文章主要介绍... 目录stopWatch 是org.springframework.util 包下的一个工具类,使用它

Java使用Curator进行ZooKeeper操作的详细教程

《Java使用Curator进行ZooKeeper操作的详细教程》ApacheCurator是一个基于ZooKeeper的Java客户端库,它极大地简化了使用ZooKeeper的开发工作,在分布式系统... 目录1、简述2、核心功能2.1 CuratorFramework2.2 Recipes3、示例实践3

Springboot处理跨域的实现方式(附Demo)

《Springboot处理跨域的实现方式(附Demo)》:本文主要介绍Springboot处理跨域的实现方式(附Demo),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不... 目录Springboot处理跨域的方式1. 基本知识2. @CrossOrigin3. 全局跨域设置4.

springboot security使用jwt认证方式

《springbootsecurity使用jwt认证方式》:本文主要介绍springbootsecurity使用jwt认证方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录前言代码示例依赖定义mapper定义用户信息的实体beansecurity相关的类提供登录接口测试提供一

go中空接口的具体使用

《go中空接口的具体使用》空接口是一种特殊的接口类型,它不包含任何方法,本文主要介绍了go中空接口的具体使用,具有一定的参考价值,感兴趣的可以了解一下... 目录接口-空接口1. 什么是空接口?2. 如何使用空接口?第一,第二,第三,3. 空接口几个要注意的坑坑1:坑2:坑3:接口-空接口1. 什么是空接

springboot security快速使用示例详解

《springbootsecurity快速使用示例详解》:本文主要介绍springbootsecurity快速使用示例,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝... 目录创www.chinasem.cn建spring boot项目生成脚手架配置依赖接口示例代码项目结构启用s

Python如何使用__slots__实现节省内存和性能优化

《Python如何使用__slots__实现节省内存和性能优化》你有想过,一个小小的__slots__能让你的Python类内存消耗直接减半吗,没错,今天咱们要聊的就是这个让人眼前一亮的技巧,感兴趣的... 目录背景:内存吃得满满的类__slots__:你的内存管理小助手举个大概的例子:看看效果如何?1.

java中使用POI生成Excel并导出过程

《java中使用POI生成Excel并导出过程》:本文主要介绍java中使用POI生成Excel并导出过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录需求说明及实现方式需求完成通用代码版本1版本2结果展示type参数为atype参数为b总结注:本文章中代码均为