如何区分model.predict() 和 model.predict_proba()?

2023-12-01 03:36
文章标签 model 区分 predict proba

本文主要是介绍如何区分model.predict() 和 model.predict_proba()?,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

最近在做关于机器学习和深度学习实验的时候,直接将机器学习模型套了深度学习中,但出现了报错,下面我们来解释一下。

机器学习中:

def model_performance(model,X_train,y_train,X_test,y_test):model.fit(X_train, y_train)predicted = model.predict(X_test) # [0 0 0 ... 0 0 0]print("*****",predicted)predicted_prob = model.predict_proba(X_test) # 输出类别概率 5维度print(predicted_prob)
***** [0 0 0 ... 0 0 0]
[[0.0685155  0.22937195 0.39369586 0.20875027 0.09966642][0.06768499 0.18204635 0.46509778 0.19814811 0.08702277][0.06794315 0.19585645 0.42916814 0.20967341 0.09735886]...[0.05866877 0.13941506 0.37418567 0.24062653 0.18710397][0.05755057 0.16208819 0.4660596  0.21127041 0.10303124][0.04571597 0.12904605 0.40106484 0.24679342 0.17737972]]

上面是他们输出的结果 我们可以看出 model.predict 输出的是一个标签值,而第二个输出的是概率。

在机器学习中

  • 通常,机器学习中的模型输出是类别标签。例如,在二分类问题中,model.predict 可能返回类别标签 0 或 1。在多分类问题中,可能返回多个类别标签中的一个。
  • model.predict_proba 一般用于获取概率信息,返回每个类别的概率值。这在需要概率信息的场景中很有用,比如绘制 ROC 曲线、计算 AUC 等。

深度学习中:(model 中我们定义了一个MLP(多层感知机模型))

def model_performance_DL(model,X_train,y_train,X_test,y_test):# model 已经在上面模型中fit了# y_pred = model.predict_proba(X_test)  #  [[0.11328121 0.15560272 0.41047114 0.21880718 0.10183779]...] 生成概率矩阵y_pred = model.predict(X_test)  #  [[0.11328121 0.15560272 0.41047114 0.21880718 0.10183779]...] 生成概率矩阵print(y_pred)predicted = []   # [2, 2, 2, 2, 2, 2,...]predicted_prob = y_predfor i in range(len(y_pred)):predicted.append(np.argmax(y_pred[i]))print(predicted)
[[0.0914833  0.16449034 0.36516902 0.21704283 0.1618145 ][0.06984996 0.15207583 0.40046582 0.22811058 0.14949782][0.06896283 0.15217893 0.4018344  0.22758484 0.14943895]...[0.08319778 0.16089137 0.37650803 0.22180358 0.15759929][0.06792274 0.15025833 0.40528333 0.22829781 0.14823778][0.07095329 0.15256266 0.3989566  0.22734448 0.15018293]][[0.0914833  0.16449034 0.36516902 0.21704283 0.1618145 ][0.06984996 0.15207583 0.40046582 0.22811058 0.14949782][0.06896283 0.15217893 0.4018344  0.22758484 0.14943895]...[0.08319778 0.16089137 0.37650803 0.22180358 0.15759929][0.06792274 0.15025833 0.40528333 0.22829781 0.14823778][0.07095329 0.15256266 0.3989566  0.22734448 0.15018293]][2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,

我们可以看出来 两者输出的结果一样 那说明他们在深度学习中往往直接输出多类别的概率

在深度学习中

  • 在深度学习中,model.predict 也可能输出类别标签,但更常见的是输出概率分布。即使是在分类任务中,深度学习模型一般会输出每个类别的概率分布
  • model.predict_proba 在深度学习中可能被更直观地表示为 model.predict,因为深度学习模型常常直接输出类别概率,而不是硬性的类别标签。

总体而言,深度学习模型在输出时更倾向于提供类别的概率分布,这使得 model.predict 在深度学习中更类似于 model.predict_proba。在实际应用中,具体的输出形式取决于你的任务和模型的架构。

如果能够帮助你的话 可以给个赞赞奖励一下 谢谢!

这篇关于如何区分model.predict() 和 model.predict_proba()?的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/439635

相关文章

MVC(Model-View-Controller)和MVVM(Model-View-ViewModel)

1、MVC MVC(Model-View-Controller) 是一种常用的架构模式,用于分离应用程序的逻辑、数据和展示。它通过三个核心组件(模型、视图和控制器)将应用程序的业务逻辑与用户界面隔离,促进代码的可维护性、可扩展性和模块化。在 MVC 模式中,各组件可以与多种设计模式结合使用,以增强灵活性和可维护性。以下是 MVC 各组件与常见设计模式的关系和作用: 1. Model(模型)

MySQL表名区分大小写设置

打开 mysql配置文件mysqld.cnf 打开文件,找到[mysqld]在下面增加一行 lower_case_table_names=0 (0:大小写敏感;1:大小写不敏感) 重启mysql服务 docker restart mysqlserver

Circuit Design 贴片晶振的区分

贴片晶振脚位的区分(非常详细,尤其是如何区分四脚的有源无源晶振): http://ruitairt.com/Article/tiepian_1.html 如何区分有源和无源晶振: http://ruitairt.com/Article/yzjddbfqsq_1.html

diffusion model 合集

diffusion model 整理 DDPM: 前向一步到位,从数据集里的图片加噪声,根据随机到的 t t t 决定混合的比例,反向要慢慢迭代,DDPM是用了1000步迭代。模型的输入是带噪声图和 t,t 先生成embedding后,用通道和的方式加到每一层中间去: 训练过程是对每个样本分配一个随机的t,采样一个高斯噪声 ϵ \epsilon ϵ,然后根据 t 对图片和噪声进行混合,将加噪

Redis 命令不区分大小写,键值区分大小写Redis

今天才知道   Redis 命令不区分大小写   但键值区分大小写的

计算两个字符串的最大公共字符串的长度,字符不区分大小写

/*** */package testString;import java.util.Scanner;/***@author: Administrator*@date: 2016-12-28 下午01:08:30*/public class Main {public static void main(String[] args){Scanner sc=new Scanner(Syste

区分变压器损耗

磁芯损耗 铁芯损耗分为两类:涡流损耗和磁滞损耗。 磁滞损耗 当没有次级电流流动时,流过变压器初级绕组的电流会产生磁通量,从而在次级绕组中感应出电压。该初级电流称为励磁电流,由于初级绕组的 CEMF 较大,因此相当小。由于变压器是通过磁通量传输能量的设备,因此集中磁通量可提高变压器的效率。 初级绕组的磁通量缠绕在称为磁芯的铁或钢材料上,以集中磁通量。磁芯材料为磁通量提供了比露天更好的路径。磁

Segment Anything Model(SAM)中的Adapter是什么?

在META团队发布的Segment Anything Model (SAM) 中,Adapter 是一种用于提升模型在特定任务或领域上的性能的机制。具体来说,SAM 是一个通用的分割模型,能够处理多种不同类型的图像分割任务,而 Adapter 的引入是为了更好地让模型适应不同的任务需求。 Adapter 的主要功能是: 模块化设计:Adapter 是一种小规模的、可插拔的网络模块,可以在不改

Vue学习:v-model绑定文本框、单选按钮、下拉菜单、复选框等

v-model指令可以在组件上使用以实现双向绑定,之前学习过v-model绑定文本框和下拉菜单,今天把表单的几个控件单选按钮radio、复选框checkbox、多行文本框textarea都试着绑定了一下。 一、单行文本框和多行文本框 <p>1.单行文本框</p>用户名:<input type="text" v-model="inputMessage"><p>您的用户名是:{{inputMe

Java memory model(JMM)的理解

总结:JMM 是一种规范,目的是解决由于多线程通过共享内存进行通信时,存在的本地内存数据不一致、编译器会对代码指令重排序、处理器会对代码乱序执行等带来的问题。目的是保证并发编程场景中的原子性、可见性、有序性。 总结的很精辟! 感谢Hollis总结