深度学习Week5-心脏病预测(RNN)

2024-03-21 01:10

本文主要是介绍深度学习Week5-心脏病预测(RNN),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

🍨 本文为[🔗365天深度学习训练营]中的学习记录博客
🍦 参考文章:[🔗深度学习100例-循环神经网络(RNN)心脏病预测]
🍖 原作者:[K同学啊|接辅导、项目定制]

  • 难度:新手入门⭐

🍺要求:

  1. 本地读取并加载数据。(✔)
  2. 了解循环神经网络(RNN)的构建过程(✔)
  3. 测试集accuracy到达87%(✔)

🍻拔高:

  1. 测试集accuracy到达89%(X)

环境:

  • 语言环境:Python3.8
  • 编译器:pycharm社区版
  • 深度学习框架:TensorFlow2.4.1
  • 数据地址:🔗百度网盘

总体流程

一、前期准备

1.设置GPU

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0], "GPU")print(gpus)

这里遇到了bug,一是h5py版本出错了,卸了重装就行,二是缺少CUDNN64_8.DLL文件,这里推荐一个下载dll文件的网址,不收费可中文,非常好用。

cudnn64_8.dll 搜索结果 | DLL‑files.com

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

2. 导入数据

数据是csv表格,里面有各种参数,我们通过对数据的处理训练,实现预测功能

数据介绍:

age:1) 年龄

sex:2) 性别

cp:3) 胸痛类型 (4 values)

trestbps:4) 静息血压

chol:5) 血清胆甾醇 (mg/dl

fbs:6) 空腹血糖 > 120 mg/dl

restecg:7) 静息心电图结果 (值 0,1 ,2)

thalach:8) 达到的最大心率

exang:9) 运动诱发的心绞痛

oldpeak:10)  相对于静止状态,运动引起的ST段压低

slope:11) 运动峰值 ST 段的斜率

ca:12) 荧光透视着色的主要血管数量 (0-3)

thal:13) 0 = 正常;1 = 固定缺陷;2 = 可逆转的缺陷

target:14) 0 = 心脏病发作的几率较小 1 = 心脏病发作的几率更大

把数据文件放在和代码.py同一个目录下。

import pandas as pd
import numpy as npdf = pd.read_csv("heart.csv")
print(df)

 3.(次要)检查数据

由于数据可能有误输入为0的,可以检查一下有没有空值,大多情况下可以跳过这一步

print(df.isnull().sum())

age         0
sex         0
cp          0
trestbps    0
chol        0
fbs         0
restecg     0
thalach     0
exang       0
oldpeak     0
slope       0
ca          0
thal        0
target      0
dtype: int64

每组数据空值个数都是0,数据正常。

二、数据预处理

1. 划分训练集与测试集

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_splitX = df.iloc[:,:-1] #iloc函数:对数据进行位置索引,从而在数据表中提取出相应的数据。
y = df.iloc[:,-1]X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.1, random_state = 1)print(X_train.shape, y_train.shape)

(272, 13) (272,)

2. 标准化

# 将每一列特征标准化为标准正太分布,注意,标准化是针对每一列而言的
sc      = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test  = sc.transform(X_test)X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_test  = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)

三、构建RNN模型

函数原型

tf.keras.layers.SimpleRNN(units,activation='tanh',use_bias=True,kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',bias_initializer='zeros',kernel_regularizer=None,recurrent_regularizer=None,bias_regularizer=None,activity_regularizer=None,kernel_constraint=None,recurrent_constraint=None,
bias_constraint=None,dropout=0.0,recurrent_dropout=0.0,return_sequences=False,return_state=False,
go_backwards=False,stateful=False,unroll=False,**kwargs)

关键参数说明:

units: 正整数,输出空间的维度。

activation: 要使用的激活函数。 默认:双曲正切(tanh)。 如果传入 None,则不使用激活函数 (即 线性激活:a(x) = x)。

use_bias: 布尔值,该层是否使用偏置向量。

kernel_initializer: kernel 权值矩阵的初始化器, 用于输入的线性转换 (详见 initializers)。

recurrent_initializer: recurrent_kernel 权值矩阵 的初始化器,用于循环层状态的线性转换 (详见 initializers)。

bias_initializer:偏置向量的初始化器 (详见initializers).

dropout: 在 0 和 1 之间的浮点数。 单元的丢弃比例,用于输入的线性转换。

import tensorflow
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,LSTM,SimpleRNNmodel = Sequential()
model.add(SimpleRNN(200, input_shape= (13,1), activation='relu'))
model.add(Dense(100, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()

 四、编译模型

设置学习率啥的 和CNN一样(吧)

opt = tf.keras.optimizers.Adam(learning_rate=1e-4)model.compile(loss='binary_crossentropy',optimizer=opt,metrics="accuracy")

 五、训练模型

因为都是数据,训练比图片数据快不少,100轮起步

epochs = 100history = model.fit(X_train, y_train, epochs=epochs, batch_size=128, validation_data=(X_test, y_test),verbose=1)


六、模型评估

import matplotlib.pyplot as pltacc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

 R1周的这个数据量有点少,这个抖动是正常的,可以通过扩充数据集来解决。

最后输出一下准确率(其实看第一百轮准确率即可):

scores = model.evaluate(X_test, y_test, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))

accuracy: 90.32%

这准确率太不稳定了,我训练了两次,一次84%,这次90%...

*拔高

结合week4猴痘识别的经验,先把训练的history中的batch_size降低到32

效果相当的不好,改回去准备从学习率入手

然而学习率无论是调大还是调小,准确率都稳定在87%左右,甚至调的幅度太大准确率会有明显下降...

可能是数据集太小的缘故,以后发现优化方法再和大家分享

这篇关于深度学习Week5-心脏病预测(RNN)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/831332

相关文章

Python与DeepSeek的深度融合实战

《Python与DeepSeek的深度融合实战》Python作为最受欢迎的编程语言之一,以其简洁易读的语法、丰富的库和广泛的应用场景,成为了无数开发者的首选,而DeepSeek,作为人工智能领域的新星... 目录一、python与DeepSeek的结合优势二、模型训练1. 数据准备2. 模型架构与参数设置3

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

最长公共子序列问题的深度分析与Java实现方式

《最长公共子序列问题的深度分析与Java实现方式》本文详细介绍了最长公共子序列(LCS)问题,包括其概念、暴力解法、动态规划解法,并提供了Java代码实现,暴力解法虽然简单,但在大数据处理中效率较低,... 目录最长公共子序列问题概述问题理解与示例分析暴力解法思路与示例代码动态规划解法DP 表的构建与意义动

Go中sync.Once源码的深度讲解

《Go中sync.Once源码的深度讲解》sync.Once是Go语言标准库中的一个同步原语,用于确保某个操作只执行一次,本文将从源码出发为大家详细介绍一下sync.Once的具体使用,x希望对大家有... 目录概念简单示例源码解读总结概念sync.Once是Go语言标准库中的一个同步原语,用于确保某个操

五大特性引领创新! 深度操作系统 deepin 25 Preview预览版发布

《五大特性引领创新!深度操作系统deepin25Preview预览版发布》今日,深度操作系统正式推出deepin25Preview版本,该版本集成了五大核心特性:磐石系统、全新DDE、Tr... 深度操作系统今日发布了 deepin 25 Preview,新版本囊括五大特性:磐石系统、全新 DDE、Tree

Node.js 中 http 模块的深度剖析与实战应用小结

《Node.js中http模块的深度剖析与实战应用小结》本文详细介绍了Node.js中的http模块,从创建HTTP服务器、处理请求与响应,到获取请求参数,每个环节都通过代码示例进行解析,旨在帮... 目录Node.js 中 http 模块的深度剖析与实战应用一、引言二、创建 HTTP 服务器:基石搭建(一

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;