【lightgbm/xgboost/nn代码整理三】keras做二分类,多分类以及回

2024-06-12 22:38

本文主要是介绍【lightgbm/xgboost/nn代码整理三】keras做二分类,多分类以及回,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【lightgbm/xgboost/nn代码整理三】keras做二分类,多分类以及回归任务

浏览更多内容,可访问:http://www.growai.cn

1.简介

该部分是比较基础的深度网络部分,是基于keras实现的多层感知机网络(mlp),使用nn个人感觉最大的一个好处就是目标函数自定义很方便,下面将从数据处理、网络搭建和模型训练三个部分介绍。如果只是想要阅读代码,可直接移步到尾部链接。

2. 数据处理

神经网络对数据的要求比较多,不能处理缺失值,并且数据分布对其影响也很大,输入模型前需要对数据做预处理。具体需要做如下处理

  • onehot:参考上一节

  • 填充:常用的有均值填充,常数值填充,中位数填充等,根据数据场景做选择,这里直接填充的常数值-1

    for i in train_x.columns:if train_x[i].isnull().sum() != 0:train_x[i] = train_x[i].fillna(-1)test[i] = test[i].fillna(-1)
    
  • 归一化:如果各个特征值差距很大,会严重影响模型参数分布,需要对整体数据进行归一化处理

    scaler = StandardScaler()
    train_X = scaler.fit_transform(train_x)
    test_X = scaler.transform(test)
    

3.模型部分

def MLP(dropout_rate=0.25, activation='relu'):start_neurons = 512model = Sequential()model.add(Dense(start_neurons, input_dim=train_X.shape[1], activation=activation))model.add(BatchNormalization())model.add(Dropout(dropout_rate))model.add(Dense(start_neurons // 2, activation=activation))model.add(BatchNormalization())model.add(Dropout(dropout_rate))model.add(Dense(start_neurons // 4, activation=activation))model.add(BatchNormalization())model.add(Dropout(dropout_rate))model.add(Dense(start_neurons // 8, activation=activation))model.add(BatchNormalization())model.add(Dropout(dropout_rate / 2))model.add(Dense(classes, activation='sigmoid'))return model

这里定义的是四层感知网络,为了提高网络的性能,添加的dropout层和BN层。Dropout的具体工作原理是随机的使一些神经元失活,从而达到防止过拟合的作用。直观的理解的话,dropout有点像集成学习中的bagging的思路,每次训练的时候只训练一部分神经元,相当于训练了多个弱分类器,预测的时候则是全部分类器同时作用。而bagging的作用也是为了减少方差(防止过拟合)。BN,Batch Normalization,就是在深度神经网络训练过程中使得每一层神经网络的输入保持相近的分布,可以加速训练。

针对不同的网络,输出层的激活函数不同

  • 二分类:sigmoid
  • 多分类:softmax
  • 回归:linear

4. 模型训练

首先需要定义网络模型,然后定义loss优化和目标函数,keras训练函数和sklearn很相似,直接调用fit函数即可。

model = MLP(dropout_rate=0.5, activation='relu')
model.compile(optimizer='adam', loss='binary_crossentropy',  metrics=['accuracy'])
history = model.fit(x_train, y_train,validation_data=[x_valid, y_valid],epochs=epochs,batch_size=batch_size,callbacks=[call_ES, ],shuffle=True,verbose=1)
  • optimizer:loss优化函数,常用的有sgd, rmsprop, adam等

  • loss:常用的loss损失函数

    • 二分类:binary_crossentropy等
    • 多分类:categorical_crossentropy等
    • 回归:mse,mae等
  • metrics:评价函数:

    • 分类:accuracy等
    • 回归:mse, mae等
  • callbacks:这个是回调函数,该函数是在加载完一次数据后调用,可以用他来加载loss,打印tensorboard,提前停止等,这里给出了提前停止的代码

    call_ES = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=patience, verbose=1, mode='auto', baseline=None)
    

模型预测部分

##分类
predictions = model.predict_proba(test_X, batch_size=batch_size)##回归&分类
oof_preds[val_] = model.predict(x_valid, batch_size=batch_size)

分类任务可以通过第一个式子预测每个类别的概率。对于二分类任务可以自定义阈值,得到最终的分类结果

threshold = 0.5
result = []
for pred in predictions:result.append(1 if pred > threshold else 0)

对于多分类:

result = np.argmax(predictions, axis=1)

代码地址:data_mining_models

写在后面

欢迎您关注作者知乎:ML与DL成长之路

推荐关注公众号:AI成长社,ML与DL的成长圣地。

这篇关于【lightgbm/xgboost/nn代码整理三】keras做二分类,多分类以及回的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1055538

相关文章

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

认识、理解、分类——acm之搜索

普通搜索方法有两种:1、广度优先搜索;2、深度优先搜索; 更多搜索方法: 3、双向广度优先搜索; 4、启发式搜索(包括A*算法等); 搜索通常会用到的知识点:状态压缩(位压缩,利用hash思想压缩)。

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

数论入门整理(updating)

一、gcd lcm 基础中的基础,一般用来处理计算第一步什么的,分数化简之类。 LL gcd(LL a, LL b) { return b ? gcd(b, a % b) : a; } <pre name="code" class="cpp">LL lcm(LL a, LL b){LL c = gcd(a, b);return a / c * b;} 例题:

计算机毕业设计 大学志愿填报系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点赞 👍 收藏 ⭐评论 📝 🍅 文末获取源码联系 👇🏻 精彩专栏推荐订阅 👇🏻 不然下次找不到哟~Java毕业设计项目~热门选题推荐《1000套》 目录 1.技术选型 2.开发工具 3.功能

代码随想录冲冲冲 Day39 动态规划Part7

198. 打家劫舍 dp数组的意义是在第i位的时候偷的最大钱数是多少 如果nums的size为0 总价值当然就是0 如果nums的size为1 总价值是nums[0] 遍历顺序就是从小到大遍历 之后是递推公式 对于dp[i]的最大价值来说有两种可能 1.偷第i个 那么最大价值就是dp[i-2]+nums[i] 2.不偷第i个 那么价值就是dp[i-1] 之后取这两个的最大值就是d

pip-tools:打造可重复、可控的 Python 开发环境,解决依赖关系,让代码更稳定

在 Python 开发中,管理依赖关系是一项繁琐且容易出错的任务。手动更新依赖版本、处理冲突、确保一致性等等,都可能让开发者感到头疼。而 pip-tools 为开发者提供了一套稳定可靠的解决方案。 什么是 pip-tools? pip-tools 是一组命令行工具,旨在简化 Python 依赖关系的管理,确保项目环境的稳定性和可重复性。它主要包含两个核心工具:pip-compile 和 pip

D4代码AC集

贪心问题解决的步骤: (局部贪心能导致全局贪心)    1.确定贪心策略    2.验证贪心策略是否正确 排队接水 #include<bits/stdc++.h>using namespace std;int main(){int w,n,a[32000];cin>>w>>n;for(int i=1;i<=n;i++){cin>>a[i];}sort(a+1,a+n+1);int i=1

html css jquery选项卡 代码练习小项目

在学习 html 和 css jquery 结合使用的时候 做好是能尝试做一些简单的小功能,来提高自己的 逻辑能力,熟悉代码的编写语法 下面分享一段代码 使用html css jquery选项卡 代码练习 <div class="box"><dl class="tab"><dd class="active">手机</dd><dd>家电</dd><dd>服装</dd><dd>数码</dd><dd