xgb-练习

2024-08-28 02:44
文章标签 xgb 练习

本文主要介绍 xgb 练习，希望能为大家解决编程问题提供一定的参考，有需要的开发者可以跟着小编一起学习！

以下代码未经验证，仅供练习使用。注意：尽管标题为 xgb，代码中实际并未使用 XGBoost，而是比较了随机森林、梯度提升、逻辑回归和 SVM 等 scikit-learn 模型。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import io
import logging
import re
import time
import warnings

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import shap
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline as ImbPipeline
from plotly.subplots import make_subplots
from scipy import stats
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.exceptions import ConvergenceWarning
from sklearn.feature_selection import SelectKBest, f_classif, RFE
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve, precision_recall_curve
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.svm import SVC

# Configure warning filters and logging
warnings.filterwarnings("ignore", category=ConvergenceWarning)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class CreditRiskModel:
    """End-to-end credit-risk classification workflow.

    Loads a CSV that has a ``target`` column, preprocesses it, balances the
    training split with SMOTE, selects features, trains and compares several
    scikit-learn classifiers, tunes the best one, and writes plots, model
    artifacts and a text report to the current working directory.

    The target is treated as binary: every evaluation step uses
    ``predict_proba(...)[:, 1]`` as the positive-class score.
    """

    # Models that shap.TreeExplainer can explain directly.
    _TREE_MODELS = ('RandomForest', 'GradientBoosting')

    def __init__(self, data_path):
        """Store the CSV path; every other attribute is populated by ``run()``.

        Args:
            data_path: Path to the input CSV. It must contain a ``target`` column.
        """
        self.data_path = data_path
        self.df = None
        self.X = None
        self.y = None
        self.X_train = None
        self.X_test = None
        self.y_train = None
        self.y_test = None
        self.preprocessor = None       # fitted ColumnTransformer
        self.selector = None           # fitted SelectKBest
        self.selected_features = None  # names of the columns kept by the selector
        self.models = {}
        self.results = {}
        self.best_model = None         # key into self.models / self.results

    # ------------------------------------------------------------------ data
    def load_and_explore_data(self):
        """Read the CSV into ``self.df``, log summary statistics and save plots."""
        logging.info("Loading and exploring data...")
        self.df = pd.read_csv(self.data_path)
        logging.info(f"Dataset shape: {self.df.shape}")
        logging.info("\nDataset info:")
        # DataFrame.info() prints to stdout by default; capture it for the log.
        info_buffer = io.StringIO()
        self.df.info(buf=info_buffer)
        logging.info(info_buffer.getvalue())
        logging.info("\nDataset description:")
        logging.info(self.df.describe())
        logging.info("\nTarget variable distribution:")
        logging.info(self.df['target'].value_counts(normalize=True))
        self.visualize_data()

    def visualize_data(self):
        """Save a correlation heatmap, the target distribution and numeric histograms."""
        logging.info("Generating data visualizations...")
        numeric_df = self.df.select_dtypes(include='number')

        # Correlation heatmap, limited to numeric columns: DataFrame.corr()
        # raises on string columns in pandas >= 2.0.
        plt.figure(figsize=(12, 10))
        sns.heatmap(numeric_df.corr(), annot=False, cmap='coolwarm')
        plt.title('特征相关性热力图')
        plt.tight_layout()
        plt.savefig('correlation_heatmap.png')
        plt.close()

        # Target distribution
        plt.figure(figsize=(8, 6))
        sns.countplot(x='target', data=self.df)
        plt.title('目标变量分布')
        plt.savefig('target_distribution.png')
        plt.close()

        # Histograms of the numeric columns, three per row.
        num_features = list(numeric_df.columns)
        n_cols = 3
        n_rows = max(1, (len(num_features) + n_cols - 1) // n_cols)
        fig = make_subplots(rows=n_rows, cols=n_cols,
                            subplot_titles=[str(col) for col in num_features])
        for i, col in enumerate(num_features):
            fig.add_trace(go.Histogram(x=self.df[col], name=str(col)),
                          row=i // n_cols + 1, col=i % n_cols + 1)
        # 300 px per subplot row that was actually created.
        fig.update_layout(height=300 * n_rows, width=1000, title_text="数值型特征分布")
        fig.write_html("numeric_features_distribution.html")

    def preprocess_data(self):
        """Split the data, fit the preprocessor on the training split and balance it.

        Sets ``self.X``/``self.y`` (raw features and target), ``self.preprocessor``
        (a fitted ColumnTransformer) and the transformed ``self.X_train``,
        ``self.X_test``, ``self.y_train`` and ``self.y_test``. SMOTE is applied
        to the training split only, so the test set keeps the real class balance.
        """
        logging.info("Preprocessing data...")
        self.X = self.df.drop('target', axis=1)
        self.y = self.df['target']

        numeric_features = self.X.select_dtypes(include='number').columns
        categorical_features = self.X.select_dtypes(include=['object', 'category']).columns

        numeric_transformer = Pipeline(steps=[
            ('imputer', SimpleImputer(strategy='median')),
            ('scaler', StandardScaler()),
        ])
        categorical_transformer = Pipeline(steps=[
            ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
            ('onehot', OneHotEncoder(handle_unknown='ignore')),
        ])
        # sparse_threshold=0 forces a dense ndarray, which the later DataFrame,
        # SHAP and plotting steps need.
        self.preprocessor = ColumnTransformer(
            transformers=[
                ('num', numeric_transformer, numeric_features),
                ('cat', categorical_transformer, categorical_features),
            ],
            sparse_threshold=0,
        )

        # Split first so that no test-set statistics leak into the fitted transformers.
        X_train, X_test, y_train, y_test = train_test_split(
            self.X, self.y, test_size=0.2, random_state=42, stratify=self.y)

        X_train = self.preprocessor.fit_transform(X_train)
        self.X_test = self.preprocessor.transform(X_test)
        self.y_test = y_test

        # SMOTE is a resampler: fit_resample returns both X and y, and it has no
        # transform(), so it cannot be the final step of a pipeline that is
        # fit_transform()-ed. Apply it to the training split only.
        smote = SMOTE(random_state=42)
        self.X_train, self.y_train = smote.fit_resample(X_train, y_train)

    def select_features(self, k=20):
        """Keep the ``k`` features with the highest ANOVA F-score.

        Args:
            k: Number of features to keep, or ``'all'``. A numeric ``k`` is
               capped at the number of available features.
        """
        if k != 'all':
            k = min(k, self.X_train.shape[1])
        logging.info(f"Selecting top {k} features...")
        self.selector = SelectKBest(f_classif, k=k)
        self.X_train = self.selector.fit_transform(self.X_train, self.y_train)
        self.X_test = self.selector.transform(self.X_test)
        selected_feature_indices = self.selector.get_support(indices=True)
        self.selected_features = self.preprocessor.get_feature_names_out()[selected_feature_indices]
        logging.info(f"Selected features: {self.selected_features}")

    # ---------------------------------------------------------------- models
    def train_models(self):
        """Fit the candidate classifiers on the (balanced, selected) training data."""
        logging.info("Training multiple models...")
        models = {
            'RandomForest': RandomForestClassifier(random_state=42),
            'GradientBoosting': GradientBoostingClassifier(random_state=42),
            'LogisticRegression': LogisticRegression(random_state=42),
            'SVM': SVC(probability=True, random_state=42),
        }
        for name, model in models.items():
            logging.info(f"Training {name}...")
            model.fit(self.X_train, self.y_train)
            self.models[name] = model

    def _compute_metrics(self, model):
        """Return ``(metrics, positive_class_probabilities)`` for ``model`` on the test set."""
        y_pred = model.predict(self.X_test)
        y_pred_proba = model.predict_proba(self.X_test)[:, 1]
        metrics = {
            'accuracy': model.score(self.X_test, self.y_test),
            'roc_auc': roc_auc_score(self.y_test, y_pred_proba),
            'classification_report': classification_report(self.y_test, y_pred),
            'confusion_matrix': confusion_matrix(self.y_test, y_pred),
        }
        return metrics, y_pred_proba

    def evaluate_models(self):
        """Score every model on the test set, save ROC/PR curves, pick the best by ROC AUC."""
        logging.info("Evaluating models...")
        results = {}
        for name, model in self.models.items():
            logging.info(f"Evaluating {name}...")
            results[name], y_pred_proba = self._compute_metrics(model)

            # ROC curve
            fpr, tpr, _ = roc_curve(self.y_test, y_pred_proba)
            plt.figure()
            plt.plot(fpr, tpr, label=f'ROC curve (AUC = {results[name]["roc_auc"]:.2f})')
            plt.plot([0, 1], [0, 1], 'k--')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title(f'ROC Curve - {name}')
            plt.legend(loc="lower right")
            plt.savefig(f'roc_curve_{name}.png')
            plt.close()

            # Precision-recall curve
            precision, recall, _ = precision_recall_curve(self.y_test, y_pred_proba)
            plt.figure()
            plt.plot(recall, precision, label='Precision-Recall curve')
            plt.xlabel('Recall')
            plt.ylabel('Precision')
            plt.title(f'Precision-Recall Curve - {name}')
            plt.legend(loc="lower left")
            plt.savefig(f'precision_recall_curve_{name}.png')
            plt.close()

        self.results = results
        self.best_model = max(results, key=lambda x: results[x]['roc_auc'])
        logging.info(f"Best model: {self.best_model}")

    def hyperparameter_tuning(self):
        """Grid-search the best model's hyperparameters and replace it with the tuned one."""
        logging.info("Performing hyperparameter tuning for the best model...")
        search_spaces = {
            'RandomForest': (
                RandomForestClassifier(random_state=42),
                {'n_estimators': [100, 200, 300],
                 'max_depth': [None, 10, 20, 30],
                 'min_samples_split': [2, 5, 10],
                 'min_samples_leaf': [1, 2, 4]},
            ),
            'GradientBoosting': (
                GradientBoostingClassifier(random_state=42),
                {'n_estimators': [100, 200, 300],
                 'learning_rate': [0.01, 0.1, 0.2],
                 'max_depth': [3, 4, 5],
                 'min_samples_split': [2, 5, 10],
                 'min_samples_leaf': [1, 2, 4]},
            ),
            'LogisticRegression': (
                LogisticRegression(random_state=42),
                {'C': [0.001, 0.01, 0.1, 1, 10, 100],
                 'penalty': ['l1', 'l2'],
                 'solver': ['liblinear', 'saga']},
            ),
            'SVM': (
                SVC(probability=True, random_state=42),
                {'C': [0.1, 1, 10],
                 'kernel': ['rbf', 'poly'],
                 'gamma': ['scale', 'auto', 0.1, 1]},
            ),
        }
        model, param_grid = search_spaces[self.best_model]

        # NOTE(review): X_train already contains SMOTE samples, so CV folds mix
        # synthetic points built from validation rows and the CV score is
        # optimistic. Resampling inside each fold (an imblearn Pipeline) avoids this.
        grid_search = GridSearchCV(model, param_grid, cv=5, scoring='roc_auc', n_jobs=-1)
        grid_search.fit(self.X_train, self.y_train)
        logging.info(f"Best parameters: {grid_search.best_params_}")
        logging.info(f"Best cross-validation score: {grid_search.best_score_:.4f}")

        tuned_model = grid_search.best_estimator_
        self.models[self.best_model] = tuned_model
        # Refresh the stored test metrics so the report describes the tuned model.
        self.results[self.best_model], _ = self._compute_metrics(tuned_model)

    # -------------------------------------------------------- interpretation
    def feature_importance(self):
        """Save a bar chart of impurity-based importances when the best model exposes them."""
        logging.info("Calculating feature importance...")
        model = self.models[self.best_model]
        if not hasattr(model, 'feature_importances_'):
            return
        importances = model.feature_importances_
        indices = np.argsort(importances)[::-1]
        plt.figure(figsize=(12, 8))
        plt.title("Feature Importances")
        plt.bar(range(len(importances)), importances[indices])
        plt.xticks(range(len(importances)), [self.selected_features[i] for i in indices], rotation=90)
        plt.tight_layout()
        plt.savefig('feature_importances.png')
        plt.close()

    @staticmethod
    def _positive_class_shap(shap_values):
        """Return the (n_samples, n_features) SHAP matrix for the positive class.

        Depending on the model and the shap version, ``TreeExplainer.shap_values``
        returns a list with one array per class, a 3-D array with the class on
        the last axis, or a single 2-D array (e.g. binary GradientBoosting).
        """
        if isinstance(shap_values, list):
            return shap_values[1]
        shap_values = np.asarray(shap_values)
        if shap_values.ndim == 3:
            return shap_values[:, :, 1]
        return shap_values

    def model_interpretation(self):
        """Save SHAP summary plots for the best model (tree-based models only)."""
        logging.info("Interpreting model with SHAP...")
        if self.best_model not in self._TREE_MODELS:
            # TreeExplainer only supports tree ensembles; it would raise for
            # LogisticRegression / SVC.
            logging.info(f"Skipping SHAP: TreeExplainer does not support {self.best_model}.")
            return
        explainer = shap.TreeExplainer(self.models[self.best_model])
        shap_values = self._positive_class_shap(explainer.shap_values(self.X_test))

        # show=False stops summary_plot from calling plt.show(), so the figure
        # is still open when savefig() runs.
        plt.figure(figsize=(10, 8))
        shap.summary_plot(shap_values, self.X_test, feature_names=self.selected_features,
                          plot_type="bar", show=False)
        plt.title("Feature Importance (SHAP values)")
        plt.tight_layout()
        plt.savefig('shap_feature_importance.png')
        plt.close()

        plt.figure(figsize=(12, 8))
        shap.summary_plot(shap_values, self.X_test, feature_names=self.selected_features, show=False)
        plt.title("Feature Impact (SHAP values)")
        plt.tight_layout()
        plt.savefig('shap_feature_impact.png')
        plt.close()

    # --------------------------------------------------------------- outputs
    def save_model(self):
        """Persist the best model together with the fitted preprocessor and feature selector."""
        logging.info("Saving the best model...")
        joblib.dump(self.models[self.best_model], f'best_model_{self.best_model}.joblib')
        joblib.dump(self.preprocessor, 'preprocessor.joblib')
        # The model expects only the k selected columns, so the selector is
        # needed to map preprocessor output to model input at inference time.
        joblib.dump(self.selector, 'feature_selector.joblib')

    def generate_report(self):
        """Write a plain-text summary of the data and the best model's performance."""
        logging.info("Generating final report...")
        best = self.results[self.best_model]
        lines = [
            "Credit Risk Model Report",
            "========================",
            "",
            "Data Summary:",
            "-------------",
            f"Total samples: {len(self.df)}",
            f"Features: {len(self.X.columns)}",
            "Target distribution:",
            f"{self.df['target'].value_counts(normalize=True)}",
            "",
            "Model Performance:",
            "------------------",
            f"Best Model: {self.best_model}",
            f"ROC AUC Score: {best['roc_auc']:.4f}",
            "Classification Report:",
            f"{best['classification_report']}",
            "Confusion Matrix:",
            f"{best['confusion_matrix']}",
            "",
            "Top Features:",
            "-------------",
            ', '.join(str(f) for f in self.selected_features[:10]),
            "",
            "Model Interpretation:",
            "---------------------",
            "Please refer to the SHAP plots for detailed feature importance and impact analysis.",
            "",
            "Notes:",
            "------",
            "- The model has been trained on balanced data using SMOTE.",
            "- Hyperparameter tuning was performed using GridSearchCV.",
            "- The model, preprocessor and feature selector have been saved for future use.",
            "",
            "Next Steps:",
            "-----------",
            "1. Monitor model performance in production.",
            "2. Regularly retrain the model with new data.",
            "3. Consider adding more relevant features if available.",
            "4. Explore more advanced techniques like stacking or neural networks.",
        ]
        report = "\n".join(lines) + "\n"
        with open('credit_risk_model_report.txt', 'w', encoding='utf-8') as f:
            f.write(report)
        logging.info("Report generated and saved as 'credit_risk_model_report.txt'")

    def detect_anomalies(self):
        """Flag ~10% of rows as anomalies with IsolationForest and plot them.

        Adds an ``anomaly`` column (-1 = anomaly, 1 = normal) to ``self.df``.
        """
        logging.info("Detecting anomalies in the dataset...")
        from sklearn.ensemble import IsolationForest
        iso_forest = IsolationForest(contamination=0.1, random_state=42)
        # IsolationForest needs numeric, NaN-free input: reuse the fitted
        # preprocessor (impute + scale + one-hot) rather than the raw columns.
        anomalies = iso_forest.fit_predict(self.preprocessor.transform(self.X))
        self.df['anomaly'] = anomalies

        plt.figure(figsize=(12, 8))
        plt.scatter(self.df.index, self.df.iloc[:, 0], c=self.df['anomaly'], cmap='viridis')
        plt.title('Anomaly Detection Results')
        plt.xlabel('Index')
        plt.ylabel('Feature 1')
        plt.colorbar(label='Anomaly (-1) vs Normal (1)')
        plt.savefig('anomaly_detection.png')
        plt.close()
        logging.info(f"Detected {int(np.sum(anomalies == -1))} anomalies in the dataset.")

    def perform_cross_validation(self):
        """Log 5-fold ROC AUC of the best model on the training data."""
        logging.info("Performing cross-validation...")
        # NOTE(review): as in hyperparameter_tuning, the folds include SMOTE
        # samples, so these scores are optimistic.
        cv_scores = cross_val_score(self.models[self.best_model], self.X_train, self.y_train,
                                    cv=5, scoring='roc_auc')
        logging.info(f"Cross-validation ROC AUC scores: {cv_scores}")
        logging.info(f"Mean ROC AUC: {np.mean(cv_scores):.4f} (+/- {np.std(cv_scores) * 2:.4f})")

    def analyze_misclassifications(self):
        """Save one box plot per selected feature comparing correct vs. incorrect predictions."""
        logging.info("Analyzing misclassifications...")
        y_pred = self.models[self.best_model].predict(self.X_test)
        # X_test is a NumPy array after preprocessing/selection; wrap it so
        # columns can be addressed by feature name.
        X_test_df = pd.DataFrame(self.X_test, columns=self.selected_features)
        outcome = np.where(y_pred == np.asarray(self.y_test), 'Correct', 'Incorrect')
        logging.info(f"Misclassified test samples: {int(np.sum(outcome == 'Incorrect'))} of {len(outcome)}")

        for feature in self.selected_features:
            plt.figure(figsize=(10, 6))
            sns.boxplot(x=outcome, y=X_test_df[feature])
            plt.title(f'Distribution of {feature} for Correct and Incorrect Predictions')
            # One-hot feature names embed raw category values, which may contain
            # characters that are not valid in file names.
            safe_name = re.sub(r'[^\w.-]', '_', str(feature))
            plt.savefig(f'misclassification_analysis_{safe_name}.png')
            plt.close()
        logging.info("Misclassification analysis plots saved.")

    def run(self):
        """Execute the full workflow in order."""
        self.load_and_explore_data()
        self.preprocess_data()
        self.select_features()
        self.train_models()
        self.evaluate_models()
        self.hyperparameter_tuning()
        self.feature_importance()
        self.model_interpretation()
        self.save_model()
        self.detect_anomalies()
        self.perform_cross_validation()
        self.analyze_misclassifications()
        self.generate_report()


def main():
    """Run the credit-risk workflow on ``credit_risk_data.csv`` and log the elapsed time."""
    start_time = time.time()
    logging.info("Starting Credit Risk Modeling process...")
    model = CreditRiskModel('credit_risk_data.csv')
    model.run()
    end_time = time.time()
    logging.info(f"Credit Risk Modeling process completed in {end_time - start_time:.2f} seconds.")


if __name__ == "__main__":
    main()

关于 xgb 练习的文章就介绍到这里，希望我们推荐的文章对各位程序员有所帮助！



http://www.chinasem.cn/article/1113506

相关文章

RabbitMQ练习(AMQP 0-9-1 Overview)

1、What is AMQP 0-9-1 AMQP 0-9-1(高级消息队列协议)是一种网络协议,它允许遵从该协议的客户端(Publisher或者Consumer)应用程序与遵从该协议的消息中间件代理(Broker,如RabbitMQ)进行通信。 AMQP 0-9-1模型的核心概念包括消息发布者(producers/publisher)、消息(messages)、交换机(exchanges)、

【Rust练习】12.枚举

练习题来自:https://practice-zh.course.rs/compound-types/enum.html 1 // 修复错误enum Number {Zero,One,Two,}enum Number1 {Zero = 0,One,Two,}// C语言风格的枚举定义enum Number2 {Zero = 0.0,One = 1.0,Two = 2.0,}fn m

MySql 事务练习

事务(transaction) -- 事务 transaction-- 事务是一组操作的集合,是一个不可分割的工作单位,事务会将所有的操作作为一个整体一起向系统提交或撤销请求-- 事务的操作要么同时成功,要么同时失败-- MySql的事务默认是自动提交的,当执行一个DML语句,MySql会立即自动隐式提交事务-- 常见案例:银行转账-- 逻辑:A给B转账1000:1.查询

html css jquery选项卡 代码练习小项目

在学习 html 和 css jquery 结合使用的时候 做好是能尝试做一些简单的小功能,来提高自己的 逻辑能力,熟悉代码的编写语法 下面分享一段代码 使用html css jquery选项卡 代码练习 <div class="box"><dl class="tab"><dd class="active">手机</dd><dd>家电</dd><dd>服装</dd><dd>数码</dd><dd

014.Python爬虫系列_解析练习

我 的 个 人 主 页:👉👉 失心疯的个人主页 👈👈 入 门 教 程 推 荐 :👉👉 Python零基础入门教程合集 👈👈 虚 拟 环 境 搭 建 :👉👉 Python项目虚拟环境(超详细讲解) 👈👈 PyQt5 系 列 教 程:👉👉 Python GUI(PyQt5)文章合集 👈👈 Oracle数据库教程:👉👉 Oracle数据库文章合集 👈👈 优

如何快速练习键盘盲打

盲打是指在不看键盘的情况下进行打字,这样可以显著提高打字速度和效率。以下是一些练习盲打的方法: 熟悉键盘布局:首先,你需要熟悉键盘上的字母和符号的位置。可以通过键盘图或者键盘贴纸来帮助记忆。 使用在线打字练习工具:有许多在线的打字练习网站,如Typing.com、10FastFingers等,它们提供了不同难度的练习和测试。 练习基本键位:先从学习手指放在键盘上的“家位”开始,通常是左手的

anaconda3下的python编程练习-csv翻译器

相关理解和命令 一、环境配置1、conda命令2、pip命令3、python命令 二、开发思路三、开发步骤 一、环境配置 1、conda命令 镜像源配置 conda config --show channels //查看镜像源conda config --remove-key channels //删除添加源,恢复默认源#添加镜像源conda config --ad

推荐练习键盘盲打的网站

对于初学者来说,以下是一些推荐的在线打字练习网站: 打字侠:这是一个专业的在线打字练习平台,提供科学合理的课程设置和个性化学习计划,适合各个水平的用户。它还提供实时反馈和数据分析,帮助你提升打字速度和准确度。 dazidazi.com:这个网站提供了基础的打字练习,适合初学者从零开始学习打字。 Type.fun打字星球:提供了丰富的盲打课程和科学的打字课程设计,还有诗词歌赋、经典名著等多样

综合DHCP、ACL、NAT、Telnet和PPPoE进行网络设计练习

描述:企业内网和运营商网络如上图所示。 公网IP段:12.1.1.0/24。 内网IP段:192.168.1.0/24。 公网口PPPOE 拨号采用CHAP认证,用户名:admin 密码:Admin@123 财务PC 配置静态IP:192.168.1.8 R1使用模拟器中的AR201型号,作为交换路由一体机,下图的WAN口为E0/0/8口,可以在该接口下配置IP地址。 可以通过

JAVA学习-练习试用Java实现“删除有序数组中的重复项”

问题: 给你一个有序数组 nums ,请你 原地 删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。 不要使用额外的数组空间,你必须在 原地 修改输入数组 并在使用 O(1) 额外空间的条件下完成。 说明: 为什么返回数值是整数,但输出的答案是数组呢? 请注意,输入数组是以「引用」方式传递的,这意味着在函数里修改输入数组对于调用者是可见的。 你可以想象内部操作如下