使用Keras处理深度学习中的多分类问题——路透社新闻分类

2023-12-19 09:10

本文主要是介绍使用Keras处理深度学习中的多分类问题——路透社新闻分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

简介

本文将着手构建一个网络,将路透社新闻划分为46个互斥的主题,与二分类问题不同,这是一个多分类问题。

关于二分类问题的处理方式,请参考:使用Keras处理深度学习中的二分类问题——Imdb影评分类。

对于某个新闻,它只能划分到46个类别中的一个,所以这个问题又是单标签、多分类问题。如果每条新闻可以划分到不同的主题,那就是多标签、多分类问题了。

路透社数据集由路透社在1986年发布,包含许多短新闻及其对应的主题,它是一个简单、广泛使用的文本分类数据集。

它包含46个主题,某些主题的样本会比较多,有些比较少,但训练集中每个主题至少有10个样本。

路透社数据集内置于Keras,可以直接加载。

加载数据

代码:

from keras.datasets import reuters(train_data, train_labels), (test_data, test_labels) = reuters.load_data(num_words = 10000)

查看数据:

在这里插入图片描述

可以把数据解码为新闻文本:

# 将索引解码为新闻文本
path = r"E:\practice\tf2\new_multi_calssify\reuters_word_index.json"
word_index = reuters.get_word_index(path) # 省略 path 默认会下载文件到 C:\Users\Administrator\.keras\datasets
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])
decoded_newswire = ' '.join([reverse_word_index.get(i - 3, '?') for i in train_data[0]])

如下:

在这里插入图片描述

准备数据

对数据进行加工,以便输入到网络:

# 准备数据
import numpy as np# 数据向量化
def vectorize_sequences(sequences, dimension=10000):results = np.zeros((len(sequences), dimension))for i, sequence in enumerate(sequences):results[i, sequence] = 1.return resultsx_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)# 标签向量化,使用one-hot编码,使用Keras内置的函数
from keras.utils.np_utils import to_categoricalone_hot_train_labels = to_categorical(train_labels)
one_hot_test_labels = to_categorical(test_labels)
构建网络

本次仍构建二层密集连接层组成的网络,由于要训练的类别较多,使用64个隐藏单元。使用的函数与二分类中一致:

from keras import models
from keras import layersmodel = models.Sequential()
model.add(layers.Dense(64, activation = 'relu', input_shape = (10000,)))
model.add(layers.Dense(64, activation = 'relu'))
model.add(layers.Dense(46, activation = 'softmax'))
编译

使用与二分类中一样的优化器、损失函数,还是一行代码:

model.compile(optimizer = 'rmsprop', loss = 'categorical_crossentropy', metrics = ['accuracy'])
训练模型

首先从训练数据中留出了用于验证的的数据,训练结果保存在history中,用于绘图查看效果:

# 留出验证数据
x_val = x_train[:1000]
partial_x_train = x_train[1000:]y_val = one_hot_train_labels[:1000]
partial_y_train = one_hot_train_labels[1000:]# 训练,这里训练了9轮次,因为我已经知道了从9轮次后数据会出现过拟合,你可以先使用较大参数,分析结果后确定一个合适的轮次再训练一遍
history = model.fit(partial_x_train, partial_y_train, epochs = 9, batch_size = 512, validation_data = (x_val, y_val))
绘图表示结果

分别绘制损失曲线和精度曲线:

import matplotlib.pyplot as plt# 损失曲线
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(1, len(loss) + 1)plt.plot(epochs, loss, 'bo', label = 'Training loss')
plt.plot(epochs, val_loss, 'b', label = 'Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()plt.show()# 精度曲线
plt.clf()acc = history.history['accuracy']
val_acc = history.history['val_accuracy']plt.plot(epochs, acc, 'bo', label = 'Training accuracy')
plt.plot(epochs, val_acc, 'b', label = 'Validation accuracy')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel(Accuracy)
plt.legend()plt.show()

如图:

在这里插入图片描述
在这里插入图片描述

评估并作出预测

直接上代码:

# 评估模型
results = model.evaluate(x_test, one_hot_test_labels)
# 查看结果
results# 在新数据上生成预测
predictions = model.predict(x_test)# 查看结果
predictions[0].shape
np.sum(predictions[0])
np.argmax(predictions[0])

效果如下:

在这里插入图片描述

小结

本例展示了从数据搜集到使用模型预测新数据的整个流程,包括:

  • 数据搜集
  • 数据预处理
  • 构建网络
  • 编译网络
  • 构建验证数据
  • 训练网络
  • 结果绘制
  • 评估模型
  • 预测新数据

这个流程是通用流程,但并非唯一流程,可能会增加一些流程,也可能会减小一些流程,甚至在前述工作基础上,再返回头来调整一些参数,这些都是正常的。

模型能够在前所未有的数据上得到很好的预测结果,这才是终极目标。

参考资料

《Python深度学习》

这篇关于使用Keras处理深度学习中的多分类问题——路透社新闻分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/511646

相关文章

Python Transformers库(NLP处理库)案例代码讲解

《PythonTransformers库(NLP处理库)案例代码讲解》本文介绍transformers库的全面讲解,包含基础知识、高级用法、案例代码及学习路径,内容经过组织,适合不同阶段的学习者,对... 目录一、基础知识1. Transformers 库简介2. 安装与环境配置3. 快速上手示例二、核心模

解决Maven项目idea找不到本地仓库jar包问题以及使用mvn install:install-file

《解决Maven项目idea找不到本地仓库jar包问题以及使用mvninstall:install-file》:本文主要介绍解决Maven项目idea找不到本地仓库jar包问题以及使用mvnin... 目录Maven项目idea找不到本地仓库jar包以及使用mvn install:install-file基

一文详解Java异常处理你都了解哪些知识

《一文详解Java异常处理你都了解哪些知识》:本文主要介绍Java异常处理的相关资料,包括异常的分类、捕获和处理异常的语法、常见的异常类型以及自定义异常的实现,文中通过代码介绍的非常详细,需要的朋... 目录前言一、什么是异常二、异常的分类2.1 受检异常2.2 非受检异常三、异常处理的语法3.1 try-

Python使用getopt处理命令行参数示例解析(最佳实践)

《Python使用getopt处理命令行参数示例解析(最佳实践)》getopt模块是Python标准库中一个简单但强大的命令行参数处理工具,它特别适合那些需要快速实现基本命令行参数解析的场景,或者需要... 目录为什么需要处理命令行参数?getopt模块基础实际应用示例与其他参数处理方式的比较常见问http

Java Response返回值的最佳处理方案

《JavaResponse返回值的最佳处理方案》在开发Web应用程序时,我们经常需要通过HTTP请求从服务器获取响应数据,这些数据可以是JSON、XML、甚至是文件,本篇文章将详细解析Java中处理... 目录摘要概述核心问题:关键技术点:源码解析示例 1:使用HttpURLConnection获取Resp

C 语言中enum枚举的定义和使用小结

《C语言中enum枚举的定义和使用小结》在C语言里,enum(枚举)是一种用户自定义的数据类型,它能够让你创建一组具名的整数常量,下面我会从定义、使用、特性等方面详细介绍enum,感兴趣的朋友一起看... 目录1、引言2、基本定义3、定义枚举变量4、自定义枚举常量的值5、枚举与switch语句结合使用6、枚

使用Python从PPT文档中提取图片和图片信息(如坐标、宽度和高度等)

《使用Python从PPT文档中提取图片和图片信息(如坐标、宽度和高度等)》PPT是一种高效的信息展示工具,广泛应用于教育、商务和设计等多个领域,PPT文档中常常包含丰富的图片内容,这些图片不仅提升了... 目录一、引言二、环境与工具三、python 提取PPT背景图片3.1 提取幻灯片背景图片3.2 提取

usb接口驱动异常问题常用解决方案

《usb接口驱动异常问题常用解决方案》当遇到USB接口驱动异常时,可以通过多种方法来解决,其中主要就包括重装USB控制器、禁用USB选择性暂停设置、更新或安装新的主板驱动等... usb接口驱动异常怎么办,USB接口驱动异常是常见问题,通常由驱动损坏、系统更新冲突、硬件故障或电源管理设置导致。以下是常用解决

Java中Switch Case多个条件处理方法举例

《Java中SwitchCase多个条件处理方法举例》Java中switch语句用于根据变量值执行不同代码块,适用于多个条件的处理,:本文主要介绍Java中SwitchCase多个条件处理的相... 目录前言基本语法处理多个条件示例1:合并相同代码的多个case示例2:通过字符串合并多个case进阶用法使用

Mysql如何解决死锁问题

《Mysql如何解决死锁问题》:本文主要介绍Mysql如何解决死锁问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录【一】mysql中锁分类和加锁情况【1】按锁的粒度分类全局锁表级锁行级锁【2】按锁的模式分类【二】加锁方式的影响因素【三】Mysql的死锁情况【1