使用BP神经网络对鸢尾花数据集分类

2024-03-21 11:50

本文主要是介绍使用BP神经网络对鸢尾花数据集分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

        最近认识的一位大佬搭建的人工智能学习网站,内容通俗易懂,风趣幽默,感兴趣的可以去看看:床长人工智能教程

 废话不多说,请看正文!

使用BP神经网络对鸢尾花数据集分类

from sklearn.datasets import load_iris
from pandas import DataFrame
import pandas as pdx_data = load_iris().data  # 返回iris数据集所有输入特征
y_data = load_iris().target  # 返回iris数据集所有标签
print("x_data from datasets:", x_data)
print("y_data from datasets", y_data)x_data = DataFrame(x_data, columns=['花萼长', '花萼宽', '花瓣长', '花瓣宽'])
pd.set_option('display.unicode.east_asian_width', True)  # 设置列名对齐
print(x_data)x_data['类别'] = y_data  # 新加一列,列标签‘类别’,数据为y_data
print("x_data add a column: \n", x_data)from sklearn.datasets import load_iris
from pandas import DataFrame
import pandas as pd
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as pltimport os
import PySide2dirname = os.path.dirname(PySide2.__file__)
plugin_path = os.path.join(dirname, 'plugins', 'platforms')
os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = plugin_path# 定义超参数和画图用的两个存数据的空列表
lr = 0.1
train_loss_results = []  # 将每轮的loss记录在此列表中,为后续画loss曲线提供数据
test_acc = []  # 将每轮的acc记录在此列表中,为后续画acc曲线提供数据
epoch = 300
loss_all = 0  # 每轮分为4个step(因为一共有120个训练数据,每个batch有32个样本,所以epoch迭代一次120个数据需要4个batch),loss_all记录四个step生成的4个loss的和# ____________________________数据准备______________________________
# 1.数据集的读入
x_data = load_iris().data  # 返回iris数据集所有输入特征
y_data = load_iris().target  # 返回iris数据集所有标签
# print("x_data from datasets:", x_data)
# print("y_data from datasets", y_data)# 2.数据集乱序
np.random.seed(116)  # 使用相同的种子seed,使得乱序后的数据特征和标签仍然可以对齐
np.random.shuffle(x_data)  # 打乱数据集
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)# 3.数据集分出永不相见的训练集和测试集
x_train = x_data[:-30]  # 前120个数据作为训练集
y_train = y_data[:-30]  # 前120个标签作为训练集标签
x_test = x_data[-30:]  # 后30个数据集作为测试集
y_test = y_data[-30:]# 转换x的数据类型,否则后面矩阵相乘时会因为数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)# 配成【输入特征, 标签】对,每次喂入一小撮(batch)(把数据集分为批次,每批次32组数据)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)# ____________________________定义神经网络______________________________
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1))  # 4表示输入的4的特征,3表示3分类
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))  # 3表示3分类# ____________________________训练部分:嵌套循环迭代_______________________
for epoch in range(epoch):  # 数据集级别迭代for step, (x_train, y_train) in enumerate(train_db):  # batch级别迭代with tf.GradientTape() as tape:  # 在with结构中计算前向传播y以及计算总损失lossy = tf.matmul(x_train, w1) + b1  # 神经网络乘加运算y = tf.nn.softmax(y)  # 使输出y符合概率分布(此操作后与独热码同量级,可以相减求loss)y_ = tf.one_hot(y_train, depth=3)  # 将标签值转换为独热码格式,方便计算loss和accloss = tf.reduce_mean(tf.square(y_ - y))  # 采用均值方差损失函数MSEloss_all += loss.numpy()  # 将每个step计算出loss累加,为后续求loss平均值提供数据# 计算loss对各个参数的梯度grads = tape.gradient(loss, [w1, b1])  # 损失函数loss分别对参数w1和b1计算偏导数# 实现梯度更新 w1 = w1 - lr * w1_grad    b = b - lr * b_gradw1.assign_sub(lr * grads[0])  # 参数w1自更新b1.assign_sub(lr * grads[1])  # 参数b1自更新# 求出每个epoch的平均损失print("Epoch {}, loss:{}".format(epoch, loss_all / 4))train_loss_results.append(loss_all / 4)  # 将4个step的loss求平均记录在此变量中loss_all = 0  # loss_all归零为记录下一个epoch的loss做准备# ____________________________测试部分:识别准确率______________________________total_correct, total_number = 0, 0for x_test, y_test in test_db:y = tf.matmul(x_test, w1) + b1  # y为预测结果y = tf.nn.softmax(y)  # y符合概率分布pred = tf.argmax(y, axis=1)  # 返回y中最大值的索引,即预测的分类pred = tf.cast(pred, dtype=y_test.dtype)  # 调整数据类型与标签一致,即为把pred预测值转换为y_test数据类型correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)  # 如果真实值与预测值相同,就正确correct = tf.reduce_sum(correct)  # 将每个batch的correct加起来total_correct += int(correct)  # 将所有batch中的correct数加起来total_number += x_test.shape[0]# 总的准确率等于total_correct / total_numberacc = total_correct / total_numbertest_acc.append(acc)print("test_acc", acc)print("__________________________")# ____________________________acc / loss 可视化___________________________
# 绘制loss曲线
plt.title("Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.plot(train_loss_results, label="$Loss$")  # 逐点画出test_acc值并连线
plt.legend()
plt.show()#  绘制Accuracy曲线
plt.title("Acc Curve")
plt.xlabel("Epoch")import graphviz
plt.ylabel("Acc")
plt.plot(test_acc, label="$Accuracy$")  # 逐点画出test_acc值并连线
plt.legend()
plt.show()

结果:

这篇关于使用BP神经网络对鸢尾花数据集分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/832563

相关文章

postgresql使用UUID函数的方法

《postgresql使用UUID函数的方法》本文给大家介绍postgresql使用UUID函数的方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录PostgreSQL有两种生成uuid的方法。可以先通过sql查看是否已安装扩展函数,和可以安装的扩展函数

如何使用Lombok进行spring 注入

《如何使用Lombok进行spring注入》本文介绍如何用Lombok简化Spring注入,推荐优先使用setter注入,通过注解自动生成getter/setter及构造器,减少冗余代码,提升开发效... Lombok为了开发环境简化代码,好处不用多说。spring 注入方式为2种,构造器注入和setter

MySQL中比较运算符的具体使用

《MySQL中比较运算符的具体使用》本文介绍了SQL中常用的符号类型和非符号类型运算符,符号类型运算符包括等于(=)、安全等于(=)、不等于(/!=)、大小比较(,=,,=)等,感兴趣的可以了解一下... 目录符号类型运算符1. 等于运算符=2. 安全等于运算符<=>3. 不等于运算符<>或!=4. 小于运

使用zip4j实现Java中的ZIP文件加密压缩的操作方法

《使用zip4j实现Java中的ZIP文件加密压缩的操作方法》本文介绍如何通过Maven集成zip4j1.3.2库创建带密码保护的ZIP文件,涵盖依赖配置、代码示例及加密原理,确保数据安全性,感兴趣的... 目录1. zip4j库介绍和版本1.1 zip4j库概述1.2 zip4j的版本演变1.3 zip4

Python 字典 (Dictionary)使用详解

《Python字典(Dictionary)使用详解》字典是python中最重要,最常用的数据结构之一,它提供了高效的键值对存储和查找能力,:本文主要介绍Python字典(Dictionary)... 目录字典1.基本特性2.创建字典3.访问元素4.修改字典5.删除元素6.字典遍历7.字典的高级特性默认字典

使用Python构建一个高效的日志处理系统

《使用Python构建一个高效的日志处理系统》这篇文章主要为大家详细讲解了如何使用Python开发一个专业的日志分析工具,能够自动化处理、分析和可视化各类日志文件,大幅提升运维效率,需要的可以了解下... 目录环境准备工具功能概述完整代码实现代码深度解析1. 类设计与初始化2. 日志解析核心逻辑3. 文件处

一文详解如何使用Java获取PDF页面信息

《一文详解如何使用Java获取PDF页面信息》了解PDF页面属性是我们在处理文档、内容提取、打印设置或页面重组等任务时不可或缺的一环,下面我们就来看看如何使用Java语言获取这些信息吧... 目录引言一、安装和引入PDF处理库引入依赖二、获取 PDF 页数三、获取页面尺寸(宽高)四、获取页面旋转角度五、判断

MyBatis-Plus通用中等、大量数据分批查询和处理方法

《MyBatis-Plus通用中等、大量数据分批查询和处理方法》文章介绍MyBatis-Plus分页查询处理,通过函数式接口与Lambda表达式实现通用逻辑,方法抽象但功能强大,建议扩展分批处理及流式... 目录函数式接口获取分页数据接口数据处理接口通用逻辑工具类使用方法简单查询自定义查询方法总结函数式接口

C++中assign函数的使用

《C++中assign函数的使用》在C++标准模板库中,std::list等容器都提供了assign成员函数,它比操作符更灵活,支持多种初始化方式,下面就来介绍一下assign的用法,具有一定的参考价... 目录​1.assign的基本功能​​语法​2. 具体用法示例​​​(1) 填充n个相同值​​(2)

Spring StateMachine实现状态机使用示例详解

《SpringStateMachine实现状态机使用示例详解》本文介绍SpringStateMachine实现状态机的步骤,包括依赖导入、枚举定义、状态转移规则配置、上下文管理及服务调用示例,重点解... 目录什么是状态机使用示例什么是状态机状态机是计算机科学中的​​核心建模工具​​,用于描述对象在其生命