(路透社数据集)新闻分类:多分类问题实战

2023-12-19 09:10

本文主要是介绍(路透社数据集)新闻分类:多分类问题实战,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

  • 前言
  • 一、电影评论分类实战
    • 1-1、数据集介绍&数据集导入&分割数据集
    • 1-2、字典的键值对颠倒&数字评论解码
    • 1-3、将整数序列转化为张量(训练数据和标签)
    • 1-4、搭建神经网络&选择损失函数和优化器&划分出验证集
    • 1-5、开始训练&绘制训练损失和验证损失&绘制训练准确率和验证准确率
    • 1-6、在测试集上验证准确率
  • 二、调参总结
  • 三、碎碎念(绘制3D爱心代码)
  • 总结


前言

对于路透社数据集的评论分类实战

一、电影评论分类实战

1-1、数据集介绍&数据集导入&分割数据集

from keras.datasets import reuters# 加载路透社数据集,包含许多短新闻及其对应的主题,它包含46个不同的主题。
# 加载数据:训练数据、训练标签;测试数据、测试标签。
# 将数据限定为前10000个最常出现的单词。
(train_data, train_labels), (test_data, test_labels) = reuters.load_data(num_words=10000)# 查看训练数据
train_data[0:2]

输出:可以看到单词序列已经被转化为了整数序列,否则的话我们还需要手动搭建词典并且将其转化为整数序列。
在这里插入图片描述

1-2、字典的键值对颠倒&数字评论解码

# 将单词映射为整数索引的字典。
word_index = reuters.get_word_index()# 键值颠倒,将整数索引映射为单词。
# 颠倒之后,前边是整数索引,后边是对应的单词。
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])# 将评论解码,注意,索引减去了3,是因为012是特殊含义的字符。
decoded_review = ' '.join(# 根据整数索引,查找对应的单词,然后使用空格来进行连接,如果没有找到相关的索引,那就用问号代替[reverse_word_index.get(i - 3, '?') for i in train_data[0]])# 看一下颠倒后的词典
print(reverse_word_index)
# 查看一下解码后的评论
print(decoded_review)

输出reverse_word_index
在这里插入图片描述
输出decoded_review:

在这里插入图片描述

1-3、将整数序列转化为张量(训练数据和标签)

import numpy as np
def vectorize_sequences(sequences,dimension=10000):"""将整数序列转化为二进制矩阵的函数"""results = np.zeros((len(sequences), dimension))for i, sequences in enumerate(sequences):# 相应列上的元素置为1,其他位置上的元素都为0。results[i, sequences] = 1return results# 这里只是预处理的一种方式,即单词序列编码为二进制向量,当然也可以采用其他方式,
# 比如说直接填充列表,然后使其具有相同的长度,然后将其转化为张量,并且网络第一层使用能够处理这种整数张量的层,即Embedding层。
# 训练数据向量化,即将其转化为二进制矩阵
x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)
# 将每个标签表示为全零向量,只有标签索引对应的元素为1
from keras.utils.np_utils import to_categorical
# keras内置这种转化方法,原理的话,与上边将整数序列转化为二进制矩阵的函数没有差别,唯一的不同是传入的维度是46,而不是10000
one_hot_train_labels = to_categorical(train_labels)
one_hot_test_labels = to_categorical(test_labels)# 查看一下训练集
print(one_hot_test_labels[0])
# 查看x_train
print(x_train)

输出one_hot_test_labels[0]
在这里插入图片描述
输出x_train
在这里插入图片描述

1-4、搭建神经网络&选择损失函数和优化器&划分出验证集

units = 64
from keras import models
from keras import layers
model = models.Sequential()
model.add(layers.Dense(units, activation='relu', input_shape=(10000,)))
model.add(layers.Dense(units, activation='relu'))
# 因为这里是46个类别,所以最后一层激活函数使用softmax,即对于每个输入样本,网络都会输出一个46维的向量,这个向量的每个元素代表不同的输出类别
model.add(layers.Dense(46, activation='softmax'))# one-hot编码标签对应categorical_crossentropy(分类交叉熵损失函数)
# 标签直接转化为张量对应sparse_categorical_crossentropy(稀疏交叉熵损失)
model.compile(optimizer='rmsprop',# 这类问题的损失一般都会使用分类交叉熵损失函数。loss = 'categorical_crossentropy',metrics = ['accuracy'])
x_val = x_train[:1000]
partial_x_train = x_train[1000:]y_val = one_hot_train_labels[:1000]
partial_y_train = one_hot_train_labels[1000:]

1-5、开始训练&绘制训练损失和验证损失&绘制训练准确率和验证准确率

epochs = 10history = model.fit(partial_x_train,partial_y_train,epochs=epochs,batch_size=512,validation_data=(x_val, y_val))

训练过程
在这里插入图片描述

绘制训练损失和验证损失

import plotly.express as px
import plotly.graph_objects as gohistory_dic = history.history
loss_val = history_dic['loss']
val_loss_values = history_dic['val_loss']
# epochs = range(1, len(loss_val)+1)
# np.linspace:作为序列生成器, numpy.linspace()函数用于在线性空间中以均匀步长生成数字序列
# 左闭右闭,所以是从整数120.
# 参数:起始、结束、生成的点
epochs = np.linspace(1, epochs, epochs)
fig = go.Figure()# Add traces
fig.add_trace(go.Scatter(x=epochs, y=loss_val,mode='markers',name='Training loss'))
fig.add_trace(go.Scatter(x=epochs, y=val_loss_values,mode='lines+markers',name='Validation loss'))
fig.show()

输出
在这里插入图片描述

绘制训练准确率和验证准确率

acc = history_dic['accuracy']
val_acc = history_dic['val_accuracy']
fig = go.Figure()# Add traces
fig.add_trace(go.Scatter(x=epochs, y=acc,mode='markers',name='Training acc'))
fig.add_trace(go.Scatter(x=epochs, y=val_acc,mode='lines+markers',name='Validation acc'))
fig.show()

输出
在这里插入图片描述

1-6、在测试集上验证准确率

# 两层、64个隐藏单元
# 训练轮次:20 损失:1.22 准确率:0.78
# 训练轮次:10 损失:0.96 准确率:0.79
# 训练轮次:9 损失:1.00 准确率:0.77
# 训练轮次:6 损失:1.01 准确率:0.77# 两层、128个隐藏单元 
# 训练轮次:20 损失:1.31 准确率:0.77
# 训练轮次:4 损失:0.97 准确率:0.78# 注意:准确率会浮动,一般在0.2的范围内浮动。model.evaluate(x_test, one_hot_test_labels)

在这里插入图片描述

二、调参总结

调参总结
1、训练轮次:先选择较大的轮次,一般设置为20,观察数据在验证集上的表现,训练是为了拟合一般数据,所以当模型在验证集上准确率下降时,那就不要再继续训练了。
2、隐藏单元设置:二分类选择较小的单元数,如果是多分类的话,可以试着设置较大的单元数,比如说64、128等。
3、隐藏层数设置:同隐藏单元的设置规则,这里设置的层数较少,如果数据复杂,可以多加几层来观察数据的整体表现。
4、标签直接设置为one-hot编码时,则对应设置损失为categorical_crossentropy(分类交叉熵损失函数),若标签直接转化为张量,则对应设置损失为sparse_categorical_crossentropy(稀疏交叉熵损失)。


三、碎碎念(绘制3D爱心代码)

# 刚打开csdn看到一个绘制3D爱心的代码,于是我直接白嫖过来。
import numpy as np
import wxgl.glplot as glta = np.linspace(0, 2*np.pi, 500)
b = np.linspace(0.5*np.pi, -0.5*np.pi, 500)
lons, lats = np.meshgrid(a, b)
w = np.sqrt(np.abs(a - np.pi)) * 2
x = 2 * np.cos(lats) * np.sin(lons) * w
y = -2 * np.cos(lats) * np.cos(lons) * w
z = 2 * np.sin(lats)glt.mesh(x, y, z, color='crimson') # crimson - 绯红
glt.show()

输出
在这里插入图片描述

总结

七夕不快乐,呱呱呱。

这篇关于(路透社数据集)新闻分类:多分类问题实战的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/511650

相关文章

C语言小项目实战之通讯录功能

《C语言小项目实战之通讯录功能》:本文主要介绍如何设计和实现一个简单的通讯录管理系统,包括联系人信息的存储、增加、删除、查找、修改和排序等功能,文中通过代码介绍的非常详细,需要的朋友可以参考下... 目录功能介绍:添加联系人模块显示联系人模块删除联系人模块查找联系人模块修改联系人模块排序联系人模块源代码如下

Java中注解与元数据示例详解

《Java中注解与元数据示例详解》Java注解和元数据是编程中重要的概念,用于描述程序元素的属性和用途,:本文主要介绍Java中注解与元数据的相关资料,文中通过代码介绍的非常详细,需要的朋友可以参... 目录一、引言二、元数据的概念2.1 定义2.2 作用三、Java 注解的基础3.1 注解的定义3.2 内

将sqlserver数据迁移到mysql的详细步骤记录

《将sqlserver数据迁移到mysql的详细步骤记录》:本文主要介绍将SQLServer数据迁移到MySQL的步骤,包括导出数据、转换数据格式和导入数据,通过示例和工具说明,帮助大家顺利完成... 目录前言一、导出SQL Server 数据二、转换数据格式为mysql兼容格式三、导入数据到MySQL数据

C++中使用vector存储并遍历数据的基本步骤

《C++中使用vector存储并遍历数据的基本步骤》C++标准模板库(STL)提供了多种容器类型,包括顺序容器、关联容器、无序关联容器和容器适配器,每种容器都有其特定的用途和特性,:本文主要介绍C... 目录(1)容器及简要描述‌php顺序容器‌‌关联容器‌‌无序关联容器‌(基于哈希表):‌容器适配器‌:(

C#提取PDF表单数据的实现流程

《C#提取PDF表单数据的实现流程》PDF表单是一种常见的数据收集工具,广泛应用于调查问卷、业务合同等场景,凭借出色的跨平台兼容性和标准化特点,PDF表单在各行各业中得到了广泛应用,本文将探讨如何使用... 目录引言使用工具C# 提取多个PDF表单域的数据C# 提取特定PDF表单域的数据引言PDF表单是一

一文详解Python中数据清洗与处理的常用方法

《一文详解Python中数据清洗与处理的常用方法》在数据处理与分析过程中,缺失值、重复值、异常值等问题是常见的挑战,本文总结了多种数据清洗与处理方法,文中的示例代码简洁易懂,有需要的小伙伴可以参考下... 目录缺失值处理重复值处理异常值处理数据类型转换文本清洗数据分组统计数据分箱数据标准化在数据处理与分析过

大数据小内存排序问题如何巧妙解决

《大数据小内存排序问题如何巧妙解决》文章介绍了大数据小内存排序的三种方法:数据库排序、分治法和位图法,数据库排序简单但速度慢,对设备要求高;分治法高效但实现复杂;位图法可读性差,但存储空间受限... 目录三种方法:方法概要数据库排序(http://www.chinasem.cn对数据库设备要求较高)分治法(常

Vue项目中Element UI组件未注册的问题原因及解决方法

《Vue项目中ElementUI组件未注册的问题原因及解决方法》在Vue项目中使用ElementUI组件库时,开发者可能会遇到一些常见问题,例如组件未正确注册导致的警告或错误,本文将详细探讨这些问题... 目录引言一、问题背景1.1 错误信息分析1.2 问题原因二、解决方法2.1 全局引入 Element

Python将大量遥感数据的值缩放指定倍数的方法(推荐)

《Python将大量遥感数据的值缩放指定倍数的方法(推荐)》本文介绍基于Python中的gdal模块,批量读取大量多波段遥感影像文件,分别对各波段数据加以数值处理,并将所得处理后数据保存为新的遥感影像... 本文介绍基于python中的gdal模块,批量读取大量多波段遥感影像文件,分别对各波段数据加以数值处

使用MongoDB进行数据存储的操作流程

《使用MongoDB进行数据存储的操作流程》在现代应用开发中,数据存储是一个至关重要的部分,随着数据量的增大和复杂性的增加,传统的关系型数据库有时难以应对高并发和大数据量的处理需求,MongoDB作为... 目录什么是MongoDB?MongoDB的优势使用MongoDB进行数据存储1. 安装MongoDB