案例 :基于癌症生存数据建立神经网络(附链接)

2024-04-28 11:58

本文主要是介绍案例 :基于癌症生存数据建立神经网络(附链接),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

作者:Jason Brownlee  翻译:wwl   校对:车前子


本文约4000字,建议阅读3分钟本文介绍了haberman乳腺癌生存二分类数据集,进行神经网络模型拟合。包含数据准备、MLP模型学习机制、模型稳健性评估。

 

根据新数据集开发神经网络预测模型是一个挑战。

一种方法是先对数据集进行探查,然后思考什么模型适用于这个数据集,先尝试一些简单的模型,最后再开发并调优一个稳健的模型。

这个流程适用于为分类、回归预测模型问题开发高效的神经网络。

本教程中,你将学习如何开发一个多层感知机神经网络模型,用于癌症生存二分类数据集。

完成本教程后,你将了解到:

  • 如何加载和汇总癌症生存数据集,根据结果来进行数据准备和模型配置。

  • 如何探索MLP模型拟合数据的学习机制。

  • 如何得到稳健的模型,调优并做预测。

开始吧!

Bernd Thaller拍摄

概览

本教程分为4部份:

  • Haberman 乳腺癌生存数据集

  • 神经网络学习机制

  • 模型鲁棒性评估

  • 最终的模型及预测

Haberman 乳腺癌生存数据集

首先,定义数据集并作数据探查。

我们使用的是“haberman”标准二分类数据集。

数据集描述的是乳腺癌患者的数据,结局事件是患者生存,具体是指病人是否生存了五年活以上,或患者是否存活。

这是学习不平衡数据分类问题的标准的数据集。数据集的背景描述表明,研究是在1958年到1970年期间,在芝加哥大学的Billings医院开展的。

数据集有306个样本,3个输入变量:

  • 病人在手术期间的年龄;

  • 手术的两位数年份;

  • 检测到的腋窝淋巴结阳性数,这是衡量癌症是否已扩散的一种手段。

 

我们只有以上数据,无法选择组成数据集合的病例,以及病例的特征。

尽管这个数据集描述的是乳腺癌患者的生存情况,但考虑到数据集的样本量少,以及这些数据是基于发生在几十年前的乳腺癌病例,因此基于这个数据集的模型并不具备泛化能力。

备注:声明,我们不是要治愈乳腺癌,而是在探索一种标准的分类数据集。

以下是数据集的前5行的抽样。

从以下链接,可以对这个数据集有更多了解:

Haberman Survival Dataset (haberman.csv)(https://github.com/jbrownlee/Datasets/blob/master/haberman.csv)

Haberman Survival Dataset Details (haberman.names)(https://github.com/jbrownlee/Datasets/blob/master/haberman.names)

可以直接从URL中加载数据集,保存为pandas DataFrame,如下:

执行这个例子,可以直接从这个URL加载数据,获得数据集的维度。

本例中,我们可以确定,数据集有4个变量(3个输入1个输出变量),有306行数据。

对于一个神经网络来说,这个数据量不算大,因此一个小的、并适当加入正则项的网络,可能更合适。

另外,相对于直接拆分为训练集和测试集,k折交叉验证有助于生成一个更值得信赖的模型结果,因为单一的模型只需要几秒钟就可以拟合得到。

接下来,可以看一看数据的总结信息,并可视化数据。

执行这个例子,首先加载了数据,接着打印了对每个变量的统计信息。

我们可以看到每个变量的均值和不同,或许在建模之前,需要先进行标准化。

接下来,对每个变量绘制直方图。

我们发现,第一个变量符合高斯分布,另外两个输入变量可能是指数分布。

在每个变量上使用幂变换可以减少概率分布的偏差,从而提高模型的性能。

我们可以看到两个类之间的示例分布有一些偏差,这意味着分类是不平衡的。这是不平衡数据。

有必要了解数据集不平衡的程度。

可以用Counter对象统计每个分类下的样本量,用这个统计结果总结分布的特征。

完整的例子如下:

执行这个例子,会对数据集中类别的分布做一个总结。

类别1包含225个样本,约为数据集的74%,是最多的分类。类别2是未存活的样本,只有81个,占26%。

这个类别的分布是偏态的,但不是非常严重的不平衡。

当我们评估分类准确性的时候,考虑以上信息是有帮助的,因为任何准确度在73.5%以下的模型在这个数据集上都是没有价值的。

现在我们已经熟悉了这个数据集,接下来,一起开发神经网络模型吧。

神经网络学习机制

我们将用TensorFlow根据这个数据集拟合多层感知机模型。

我们无法知道,在这个数据集上表现最好的超参数是多少,所以我们需要经过实验寻找适合的超参数。

考虑到这是个小数据集,用小批尺寸进行批量训练可能是个好主意,例如16或32行。开始时使用Adam版本的随机梯度下降,因为它将自动调整学习速率,并在大多数数据集上运行良好。

在我们认真评估模型之前,先回顾下学习机制并调整模型架构和学习配置,直到我们有了稳定的学习机制,然后看看如何最大限度地利用模型。

可以通过简单地将数据划分为测试集和训练集,并查看学习曲线来实现以上目标。这个可以帮助我们了解模型过拟合还是欠拟合,接下来,我们可以根据结果调整配置。

首先需要确保,输入变量都是浮点值,目标变量是0/1的整型值。

接着,我们把数据集划分为输入变量和输出变量,划分成比例为67/33的训练集和测试集。

还需要保证,训练集和测试集上不同类别数据的分布和整个数据集是一致的。

本例中,我们可以定义一个小的MLP模型,包含一个10节点的隐藏层,一个输出层(这个是任意选择的)。隐藏层的激活函数用ReLu函数,和he_normal 权重初始化函数 ,通常这些设定在实践中表现优秀。

ReLu函数

https://machinelearningmastery.com/rectified-linear-activation-function-for-deep-learning-neural-networks/

权重初始化函数

https://machinelearningmastery.com/weight-initialization-for-deep-learning-neural-networks/

模型的输出是sigmoid激活后的二分类结果,我们将最小化二分类交叉熵损失函数。

二分类交叉熵损失函数

https://machinelearningmastery.com/how-to-choose-loss-functions-when-training-deep-learning-neural-networks/

我们将拟合这个模型,由于是小样本数据,使用200个训练epoch(任意选择的),每个批量是16个样本。

我们认为在原始数据上拟合模型可能是个好主意,但这是个重要的起点。

训练结束,我们将在测试集上评估模型表现,报告分类准确度。

最后,我们将绘制训练过程中的反映交叉熵损失的学习曲线。

把以上操作整合,得到了在癌症生存数据集上的第一个MLP模型的完整代码示例。

运行该示例首先在训练数据集上拟合模型,然后在测试数据集上报告分类准确度。

跟随我的新书 Data Preparation for Machine Learning(https://machinelearningmastery.com/data-preparation-for-machine-learning/),开启你的项目,其中包括所有示例的分步教程和Python源代码文件。

本例中,我们可以看到模型准确度超过73.5%,比上文提到的全预测为一类的准确度高。

在训练集和测试集上的损失值的曲线图如下。我们可以看到模型拟合的很好,没有出现欠拟合和过拟合。

我们已经对这个数据集上简单的MLP模型有了一些概念,我们可以寻求更稳健的模型评估。

模型稳健性评估

K折交叉验证的过程可以对模型效果提供更可靠的评估,虽然执行会慢一点。

这是因为k模型必须进行拟合和评估。当数据集很小时,这不是问题,例如癌症生存数据集。

我们可以用StratifiedKFold这个类,手动循环每个折子,拟合模型,得到模型评估结果,然后整个流程结束后,得到模型评估的平均值。

https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold.html

我们可以应用这个框架得到一个可信赖的MLP模型的结果,对于不同的数据准备、模型架构、学习配置,这个框架都适用。

关键的是,在使用k-折交叉验证前,我们先对模型在这个数据集上的学习机制有了了解。如果我们直接对模型调优 ,可能我们会一下子就得到好的结果,但如果没有的话,我们可能不知道为什么,比如说为什么模型会过拟合或者欠拟合。

如果我们又对模型进行了大的修改,有必要返回去确认模型是在适当收敛的。

上文中评估MLP模型的完整代码示例如下。

运行示例,报告了评价过程的每次迭代模型性能,并报告了运行结束时分类准确度的均值和标准偏差。

跟随我的新书 Data Preparation for Machine Learning(https://machinelearningmastery.com/data-preparation-for-machine-learning/),开启你的项目,其中包括所有示例的分步教程和Python源代码文件。

这个例子中,MLP模型的平均准确度是75.2%,和我们上一部分的模型结果接近。

这证实了我们的期望,即对于这个数据集,基本模型配置可能比简单的模型工作得更好。

但这是个好的结果吗?

事实上,这是个具有挑战的分类问题,74.5%的准确度结果已经不错了。

接下来,让我们看看我们如何拟合最终的模型并用它来预测

最终的模型和预测

当我们选择了模型参数,我们可以在所有数据上训练一个最终的模型,并用模型对新数据进行预测。

在本例中,我们将使用带dropout的模型,和小批量训练。

数据准备和模型拟合按上文实现,尽管是在整个数据集上,而不是在数据集的训练子集上。

我们可以利用这个模型对新的数据进行预测。

首先,定义一行新数据。

备注:我是提取的数据集的第一行数据,预期输出结果是‘1’。

可以做出预测。

然后对预测结果进行转置,得到正确形式下可解释的结果(是一个整数)。

本例中,我们简单的报告下预测结果。

把以上步骤整合起来,对haberman数据集上进行拟合最终模型,并对新数据进行预测的完整代码示例如下所示。

执行示例代码在整个数据集上拟合模型,并对新数据进行预测。

 

跟随我的新书 Data Preparation for Machine Learning(https://machinelearningmastery.com/data-preparation-for-machine-learning/),开启你的项目,其中包括所有示例的分步教程和Python源代码文件。

本例中,我们可以看到预测结果是1

扩展阅读

如果你想在这个方向继续探索,本节提供了更多学习资源

教程

  • How to Develop a Probabilistic Model of Breast Cancer Patient Survival

https://machinelearningmastery.com/how-to-develop-a-probabilistic-model-of-breast-cancer-patient-survival/

  • How to Develop a Neural Net for Predicting Disturbances in the Ionosphere

https://machinelearningmastery.com/predicting-disturbances-in-the-ionosphere/

  • Best Results for Standard Machine Learning Datasets

https://machinelearningmastery.com/results-for-standard-classification-and-regression-machine-learning-datasets/

  • TensorFlow 2 Tutorial: Get Started in Deep Learning With tf.keras

https://machinelearningmastery.com/tensorflow-tutorial-deep-learning-with-tf-keras/

  • A Gentle Introduction to k-fold Cross-Validation

 https://machinelearningmastery.com/k-fold-cross-validation/

总结

在本教程中,您了解了如何应用癌症生存二分类数据集开发多层感知器神经网络模型。

具体来说,你学到了:

  • 如何加载和汇总癌症生存数据集,并使用结果来建议要使用的数据准备和模型配置。

  • 如何在数据集上探索简单MLP模型的学习动态。

  • 如何开发模型性能的稳健估计,调整模型性能并对新数据进行预测。

原标题:

Develop a Neural Network for Cancer Survival Dataset

原文链接:

https://machinelearningmastery.com/neural-network-for-cancer-survival-dataset

END

版权声明:本号内容部分来自互联网,转载请注明原文链接和作者,如有侵权或出处有误请和我们联系。


合作请加QQ:365242293  

数据分析(ID : ecshujufenxi )互联网科技与数据圈自己的微信,也是WeMedia自媒体联盟成员之一,WeMedia联盟覆盖5000万人群。

这篇关于案例 :基于癌症生存数据建立神经网络(附链接)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/943206

相关文章

C++对象布局及多态实现探索之内存布局(整理的很多链接)

本文通过观察对象的内存布局,跟踪函数调用的汇编代码。分析了C++对象内存的布局情况,虚函数的执行方式,以及虚继承,等等 文章链接:http://dev.yesky.com/254/2191254.shtml      论C/C++函数间动态内存的传递 (2005-07-30)   当你涉及到C/C++的核心编程的时候,你会无止境地与内存管理打交道。 文章链接:http://dev.yesky

C++工程编译链接错误汇总VisualStudio

目录 一些小的知识点 make工具 可以使用windows下的事件查看器崩溃的地方 dumpbin工具查看dll是32位还是64位的 _MSC_VER .cc 和.cpp 【VC++目录中的包含目录】 vs 【C/C++常规中的附加包含目录】——头文件所在目录如何怎么添加,添加了以后搜索头文件就会到这些个路径下搜索了 include<> 和 include"" WinMain 和

C/C++的编译和链接过程

目录 从源文件生成可执行文件(书中第2章) 1.Preprocessing预处理——预处理器cpp 2.Compilation编译——编译器cll ps:vs中优化选项设置 3.Assembly汇编——汇编器as ps:vs中汇编输出文件设置 4.Linking链接——链接器ld 符号 模块,库 链接过程——链接器 链接过程 1.简单链接的例子 2.链接过程 3.地址和

【服务器运维】MySQL数据存储至数据盘

查看磁盘及分区 [root@MySQL tmp]# fdisk -lDisk /dev/sda: 21.5 GB, 21474836480 bytes255 heads, 63 sectors/track, 2610 cylindersUnits = cylinders of 16065 * 512 = 8225280 bytesSector size (logical/physical)

人工智能机器学习算法总结神经网络算法(前向及反向传播)

1.定义,意义和优缺点 定义: 神经网络算法是一种模仿人类大脑神经元之间连接方式的机器学习算法。通过多层神经元的组合和激活函数的非线性转换,神经网络能够学习数据的特征和模式,实现对复杂数据的建模和预测。(我们可以借助人类的神经元模型来更好的帮助我们理解该算法的本质,不过这里需要说明的是,虽然名字是神经网络,并且结构等等也是借鉴了神经网络,但其原型以及算法本质上还和生物层面的神经网络运行原理存在

SQL Server中,查询数据库中有多少个表,以及数据库其余类型数据统计查询

sqlserver查询数据库中有多少个表 sql server 数表:select count(1) from sysobjects where xtype='U'数视图:select count(1) from sysobjects where xtype='V'数存储过程select count(1) from sysobjects where xtype='P' SE

python实现最简单循环神经网络(RNNs)

Recurrent Neural Networks(RNNs) 的模型: 上图中红色部分是输入向量。文本、单词、数据都是输入,在网络里都以向量的形式进行表示。 绿色部分是隐藏向量。是加工处理过程。 蓝色部分是输出向量。 python代码表示如下: rnn = RNN()y = rnn.step(x) # x为输入向量,y为输出向量 RNNs神经网络由神经元组成, python

数据时代的数字企业

1.写在前面 讨论数据治理在数字企业中的影响和必要性,并介绍数据治理的核心内容和实践方法。作者强调了数据质量、数据安全、数据隐私和数据合规等方面是数据治理的核心内容,并介绍了具体的实践措施和案例分析。企业需要重视这些方面以实现数字化转型和业务增长。 数字化转型行业小伙伴可以加入我的星球,初衷成为各位数字化转型参考库,星球内容每周更新 个人工作经验资料全部放在这里,包含数据治理、数据要

如何在Java中处理JSON数据?

如何在Java中处理JSON数据? 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿!今天我们将探讨在Java中如何处理JSON数据。JSON(JavaScript Object Notation)作为一种轻量级的数据交换格式,在现代应用程序中被广泛使用。Java通过多种库和API提供了处理JSON的能力,我们将深入了解其用法和最佳

两个基因相关性CPTAC蛋白组数据

目录 蛋白数据下载 ①蛋白数据下载 1,TCGA-选择泛癌数据  2,TCGA-TCPA 3,CPTAC(非TCGA) ②蛋白相关性分析 1,数据整理 2,蛋白相关性分析 PCAS在线分析 蛋白数据下载 CPTAC蛋白组学数据库介绍及数据下载分析 – 王进的个人网站 (jingege.wang) ①蛋白数据下载 可以下载泛癌蛋白数据:UCSC Xena (xena