【tensorflow框架神经网络实现鸢尾花分类】

2024-03-28 21:12

本文主要是介绍【tensorflow框架神经网络实现鸢尾花分类】,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 1、数据获取
  • 2、数据集构建
  • 3、模型的训练验证
  • 可视化训练过程

1、数据获取

  • 从sklearn中获取鸢尾花数据,并合并处理
from sklearn.datasets import load_iris
import pandas as pdx_data = load_iris().data
y_data = load_iris().targetx_data = pd.DataFrame(x_data, columns=['花萼长度','花萼宽度','花瓣长度','花瓣宽度'])
pd.set_option('display.unicode.east_asian_width', True)x_data['类别'] = y_data
x_data

在这里插入图片描述

2、数据集构建

  • 数据集构建包括:
    • 数据读取
    • 数据打乱
    • 数据划分
    • 小批量迭代器生成
import tensorflow as tf
import numpy as np
from sklearn.datasets import load_iris# 1、从sklearn包中datasets读取数据集
x_data = load_iris().data
y_data = load_iris().target# 2、数据打乱
np.random.seed(1)   # 使用相同的seed,使输入特征/标签一一对应
np.random.shuffle(x_data)
np.random.seed(1)
np.random.shuffle(y_data)
tf.random.set_seed(1)# 3、训练集、测试集划分
x_train, x_test = x_data[:-30], x_data[-30:]
y_train, y_test = y_data[:-30], y_data[-30:]# 4、小批量数据
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
train_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

3、模型的训练验证

# 定义超参数,预设变量
lr = 0.1
loss_all = 0
Epoch = 500
train_loss_list = []
test_acc = []# 定义神经网络的可训练参数
w = tf.Variable(tf.random.truncated_normal([4,3], stddev=0.1, seed=1))
b = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))# 循环迭代,训练参数
for epoch in range(Epoch):for step, (x_, y_) in enumerate(train_db):with tf.GradientTape() as tape:x_ = tf.cast(x_, tf.float32)y_pre = tf.matmul(x_, w) + by_pre = tf.nn.softmax(y_pre)y_lab = tf.one_hot(y_, depth=3)loss = tf.reduce_mean(tf.square(y_lab - y_pre))loss_all += loss.numpy()grads = tape.gradient(loss, [w,b])w.assign_sub(lr * grads[0])b.assign_sub(lr * grads[1])print(f'Epoch: {epoch}, loss: {loss_all/4}')train_loss_list.append(loss_all/4)loss_all = 0# 测试部分total_correct, total_number = 0, 0for x_,y_ in test_db:x_ = tf.cast(x_, tf.float32)y_pre = tf.matmul(x_, w) + by_pre = tf.nn.softmax(y_pre)y_p = tf.argmax(y_pre, axis=1)y_p = tf.cast(y_p, dtype=y_.dtype)correct = tf.cast(tf.equal(y_p, y_), dtype=tf.int32)correct = tf.reduce_sum(correct)total_correct += int(correct) total_number += x_.shape[0]acc = total_correct / total_numbertest_acc.append(acc)print("Test_acc:", acc)print("-"*30)

在这里插入图片描述

可视化训练过程

# 绘制测试Acc曲线和训练loss曲线
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(train_loss_list,'b-')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')ax1 = ax.twinx()
ax1.plot(test_acc,'r-')
ax1.set_ylabel('Acc')ax1.spines['left'].set_color('blue')
ax1.spines['right'].set_color('red')

在这里插入图片描述

这篇关于【tensorflow框架神经网络实现鸢尾花分类】的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/856707

相关文章

Linux下删除乱码文件和目录的实现方式

《Linux下删除乱码文件和目录的实现方式》:本文主要介绍Linux下删除乱码文件和目录的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux下删除乱码文件和目录方法1方法2总结Linux下删除乱码文件和目录方法1使用ls -i命令找到文件或目录

SpringBoot+EasyExcel实现自定义复杂样式导入导出

《SpringBoot+EasyExcel实现自定义复杂样式导入导出》这篇文章主要为大家详细介绍了SpringBoot如何结果EasyExcel实现自定义复杂样式导入导出功能,文中的示例代码讲解详细,... 目录安装处理自定义导出复杂场景1、列不固定,动态列2、动态下拉3、自定义锁定行/列,添加密码4、合并

mybatis执行insert返回id实现详解

《mybatis执行insert返回id实现详解》MyBatis插入操作默认返回受影响行数,需通过useGeneratedKeys+keyProperty或selectKey获取主键ID,确保主键为自... 目录 两种方式获取自增 ID:1. ​​useGeneratedKeys+keyProperty(推

Spring Boot集成Druid实现数据源管理与监控的详细步骤

《SpringBoot集成Druid实现数据源管理与监控的详细步骤》本文介绍如何在SpringBoot项目中集成Druid数据库连接池,包括环境搭建、Maven依赖配置、SpringBoot配置文件... 目录1. 引言1.1 环境准备1.2 Druid介绍2. 配置Druid连接池3. 查看Druid监控

Linux在线解压jar包的实现方式

《Linux在线解压jar包的实现方式》:本文主要介绍Linux在线解压jar包的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux在线解压jar包解压 jar包的步骤总结Linux在线解压jar包在 Centos 中解压 jar 包可以使用 u

c++ 类成员变量默认初始值的实现

《c++类成员变量默认初始值的实现》本文主要介绍了c++类成员变量默认初始值,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录C++类成员变量初始化c++类的变量的初始化在C++中,如果使用类成员变量时未给定其初始值,那么它将被

Qt使用QSqlDatabase连接MySQL实现增删改查功能

《Qt使用QSqlDatabase连接MySQL实现增删改查功能》这篇文章主要为大家详细介绍了Qt如何使用QSqlDatabase连接MySQL实现增删改查功能,文中的示例代码讲解详细,感兴趣的小伙伴... 目录一、创建数据表二、连接mysql数据库三、封装成一个完整的轻量级 ORM 风格类3.1 表结构

基于Python实现一个图片拆分工具

《基于Python实现一个图片拆分工具》这篇文章主要为大家详细介绍了如何基于Python实现一个图片拆分工具,可以根据需要的行数和列数进行拆分,感兴趣的小伙伴可以跟随小编一起学习一下... 简单介绍先自己选择输入的图片,默认是输出到项目文件夹中,可以自己选择其他的文件夹,选择需要拆分的行数和列数,可以通过

Python中将嵌套列表扁平化的多种实现方法

《Python中将嵌套列表扁平化的多种实现方法》在Python编程中,我们常常会遇到需要将嵌套列表(即列表中包含列表)转换为一个一维的扁平列表的需求,本文将给大家介绍了多种实现这一目标的方法,需要的朋... 目录python中将嵌套列表扁平化的方法技术背景实现步骤1. 使用嵌套列表推导式2. 使用itert

Python使用pip工具实现包自动更新的多种方法

《Python使用pip工具实现包自动更新的多种方法》本文深入探讨了使用Python的pip工具实现包自动更新的各种方法和技术,我们将从基础概念开始,逐步介绍手动更新方法、自动化脚本编写、结合CI/C... 目录1. 背景介绍1.1 目的和范围1.2 预期读者1.3 文档结构概述1.4 术语表1.4.1 核