sklearn中SVM的可视化

2024-06-21 13:32
文章标签 可视化 svm sklearn

本文主要是介绍sklearn中SVM的可视化,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • 第一部分:如何绘制三维散点图和分类平面
    • 第二部分:sklearn中的SVM参数介绍
    • 第三部分:源代码and数据

最近遇到一个简单的二分类任务,本来可用一维的线性分类器来解决,但是为了获得更好的泛化性能,我选取了三个特征,变成了一个三维空间的二分类任务。目的就是使两类样本之间的间隔再大一些,为了满足这种需求,自然而然的想到使用SVM作为分类器,并且该任务是线性可分,自然的选用LinearSVM——核函数为线性函数。为了充分理解SVM,我还对SVM的分类平面、支持向量、bad case均进行可视化,通过本文可以了解:

1 .如何用matplotlib绘制三维散点图
2 .sklearn中SVM的核函数

先看图:

这里写图片描述

其中蓝色的是负样本,红色为正样本,带绿色圈圈的是支持向量,蓝色平面就是分类平面;

再看测试集,其中绿色圈圈,圈出来的是分类错误的样本:
这里写图片描述

第一部分:如何绘制三维散点图和分类平面

这里采用sklearn里面的SVC作为我的分类器,由于分类任务较为简单,几乎是线性可分,所以采用线性核函数,通过以下语句构建SVM并训练:

cls = svm.SVC(kernel='linear', C=1.5)
cls.fit(x_train, y_train)

其中C为惩罚因子,C越大,模型越不能容忍错误,则会使模型更容易过拟合,反之C越小,模型对错误样本容忍性很强,可能导致模型欠拟合。关于系数C的理论可以参考博客,点这里呀

训练好我们的cls之后就是如何绘制分类平面了,这就得知道我们分类平面的表达式是什么。SVM分类平面通式为Wφ(X)+b = 0 ,当采用线性核函数时,分类平面简化为:WX+b=0 (φ(X)=X),其中W,X为向量,b为标量,想进一步了解核函数作用的朋友可以参考博客,点这里呀

本文用例X是一个三维向量,因此W也应该是一个三维的向量,W和b 分别可从cls的coef_ , intercept_这两个属性中获取,具体如下:

w = cls.coef_  
b = cls.intercept_

则绘制分类平面步骤:

	ax = plt.subplot(111, projection='3d')x = np.arange(0,1,0.01)y = np.arange(0,1,0.11)x, y = np.meshgrid(x, y)z = (w[0,0]*x + w[0,1]*y + b) / (-w[0,2])surf = ax.plot_surface(x, y, z, rstride=1, cstride=1)

首先,创建一个3d的画布,其次要构建分类平面表达式 z = (w[0,0]x + w[0,1]y + b) / (-w[0,2])
其实是这样演变的:
Wφ(X)+b = 0 $ \Rightarrow $ WX+b=0 $ \Rightarrow $ w1
x1+w2
x2+w3*x3 + b = 0

有了分类平面,我们还想知道,支持向量是哪些,那么可以通过cls中的support_ 属性获取支持向量的idx,然后依据idx去训练集中找到我们的支持向量

第二部分:sklearn中的SVM参数介绍

SVM中最关键的就是核函数的选择,上一部分中仅仅采用了最简单的线性核函数(其实等于没用核函数,哈哈哈),SVM中常用的核函数有高斯核(rbf,径向基)、多项式核以及sigmoid核。在这里就简单介绍sklearn中SVM的这些核函数具体使用方法。

1.高斯核(rbf) 表达式:$ K(x,z)=exp(−γ||x−z||^{2})$
涉及参数 γ,默认值为 1/特征维度
创建一个高斯核的SVM分类器:
cls = svm.SVC(kenerl = ‘rbf’,gamma = 0.5 )

2.sigmoid核函数表达式: K ( x , z ) = t a n h ( γ x ∙ z + r ) K(x,z)=tanh(γx∙z+r) K(x,z)=tanhγxz+r)
涉及两个参数:γ,r
γ通过gamma设置,默认值为1/特征维度; r通过coef0设置,默认值为0;
创建一个sigmoid核函数的SVM:
cls = svm.SVC(kenerl = ‘sigmoid’,gamma = 0.3,coef0=0)

3.多项式核表达式: K ( x , z ) = ( γ x ∙ z + r ) d K(x,z)=(γx∙z+r)^{d} K(x,z)=γxz+r)d
涉及三个参数:γ,r,d
γ通过gamma设置,默认值为1/特征维度; r通过coef0设置,默认值为0;,d通过degree设置,默认值为3
创建一个二阶多项式核的SVM:
cls = svm.SVC(kenerl = ‘poly’,gamma = 0.3,coef0=0,dgree=2 )

在SVC中还有一个参数可以控制样本的权重,用以解决unbalance问题,class_weight,具体参考官方文档:
http://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html
推荐博客:
https://www.cnblogs.com/pinard/p/6117515.html

最后留下一个疑问,倘若我采用非线性核,如高斯核函数,我应该如何绘制分类平面?

我的思路是这样的,分类平面表达式:Wφ(X) + b =0, 当采用非线性核的时候,我们如何能知道这个映射函数φ(·)呢?

第三部分:源代码and数据

代码+数据文件可从:
1.CSDN下载:https://download.csdn.net/download/u011995719/10557270
2.百度云: https://pan.baidu.com/s/1s5Xu_h2nlTSum7jeoKniGQ 密码: gtc8

代码:

# coding: utf-8
import numpy as np
import csv
from sklearn import svm
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D"""
采用 sklearn 中的svm 
"""
train_path = './train_set.csv'
test_path = './test_set.csv'def load_data(data_path):X,Y = [],[]csv_reader = csv.reader(open(data_path,'r'))for row in csv_reader:a = row[0][1:-1].split()X.append(np.array(a))Y.append(np.array(row[1]))return X, Ydef find_badcase(X, Y):bad_list = []y = cls.predict(X)for i in range(len(X)):if y[i] != Y[i]:bad_list.append(i)return bad_listif __name__=="__main__":# load datax_train, y_train = load_data(train_path)x_test, y_test = load_data(test_path)# trainingcls = svm.SVC(kernel='linear', C=1.5)cls.fit(x_train, y_train)# accuracyprint('Test score: %.4f' % cls.score(x_test, y_test))print('Train score: %.4f' % cls.score(x_train, y_train))# print bad case idbad_idx = find_badcase(x_test,y_test)n_Support_vector = cls.n_support_  # 支持向量个数sv_idx = cls.support_  # 支持向量索引w = cls.coef_  # 方向向量Wb = cls.intercept_# plot# 绘制分类平面ax = plt.subplot(111, projection='3d')x = np.arange(0,1,0.01)y = np.arange(0,1,0.11)x, y = np.meshgrid(x, y)z = (w[0,0]*x + w[0,1]*y + b) / (-w[0,2])surf = ax.plot_surface(x, y, z, rstride=1, cstride=1)# 绘制三维散点图x_array = np.array(x_train, dtype=float)y_array = np.array(y_train, dtype=int)pos = x_array[np.where(y_array==1)]neg = x_array[np.where(y_array==-1)]ax.scatter(pos[:,0], pos[:,1], pos[:,2], c='r', label='pos')ax.scatter(neg[:,0], neg[:,1], neg[:,2], c='b', label='neg')# 绘制支持向量X = np.array(x_train,dtype=float)for i in range(len(sv_idx)):ax.scatter(X[sv_idx[i],0], X[sv_idx[i],1], X[sv_idx[i],2],s=50,c='',marker='o', edgecolors='g')# 绘制 bad case# x_test = np.array(x_test,dtype=float)# for i in range(len(bad_idx)):#     j = bad_idx[i]#     ax.scatter(x_test[j,0], x_test[j,1], x_test[j,2],s=60,#                c='',marker='o', edgecolors='g')ax.set_zlabel('Z')    # 坐标轴ax.set_ylabel('Y')ax.set_xlabel('X')ax.set_zlim([0, 1])plt.legend(loc='upper left')ax.view_init(35,300)plt.show()

再次强调:请问如何绘制非线性核的分类平面呢??

这篇关于sklearn中SVM的可视化的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1081356

相关文章

Python中的可视化设计与UI界面实现

《Python中的可视化设计与UI界面实现》本文介绍了如何使用Python创建用户界面(UI),包括使用Tkinter、PyQt、Kivy等库进行基本窗口、动态图表和动画效果的实现,通过示例代码,展示... 目录从像素到界面:python带你玩转UI设计示例:使用Tkinter创建一个简单的窗口绘图魔法:用

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

Python:豆瓣电影商业数据分析-爬取全数据【附带爬虫豆瓣,数据处理过程,数据分析,可视化,以及完整PPT报告】

**爬取豆瓣电影信息,分析近年电影行业的发展情况** 本文是完整的数据分析展现,代码有完整版,包含豆瓣电影爬取的具体方式【附带爬虫豆瓣,数据处理过程,数据分析,可视化,以及完整PPT报告】   最近MBA在学习《商业数据分析》,大实训作业给了数据要进行数据分析,所以先拿豆瓣电影练练手,网络上爬取豆瓣电影TOP250较多,但对于豆瓣电影全数据的爬取教程很少,所以我自己做一版。 目

基于SSM+Vue+MySQL的可视化高校公寓管理系统

系统展示 管理员界面 宿管界面 学生界面 系统背景   当前社会各行业领域竞争压力非常大,随着当前时代的信息化,科学化发展,让社会各行业领域都争相使用新的信息技术,对行业内的各种相关数据进行科学化,规范化管理。这样的大环境让那些止步不前,不接受信息改革带来的信息技术的企业随时面临被淘汰,被取代的风险。所以当今,各个行业领域,不管是传统的教育行业

SVM编程实现python

深入解析python版SVM源码系列--简化版SMO算法 SVM使用SMO算法来解决其中涉及到的二次规划问题。一个简单版本的SMO算法的实现如下: ''' 随机选择随机数,不等于J '''def selectJrand(i,m):j=i #we want to select any J not equal to iwhile (j==i):j = int(random

「大数据分析」图形可视化,如何选择大数据可视化图形?

​图形可视化技术,在大数据分析中,是一个非常重要的关键部分。我们前期通过数据获取,数据处理,数据分析,得出结果,这些过程都是比较抽象的。如果是非数据分析专业人员,很难清楚我们这些工作,到底做了些什么事情。即使是专业人员,在不清楚项目,不了解业务规则,不熟悉技术细节的情况下。要搞清楚我们的大数据分析,这一系列过程,也是比较困难的。 我们在数据处理和分析完成后,一般来说,都需要形成结论报告。怎样让大

11Python的Pandas:可视化

Pandas本身并没有直接的可视化功能,但它与其他Python库(如Matplotlib和Seaborn)无缝集成,允许你快速创建各种图表和可视化。这里是一些使用Pandas数据进行可视化的常见方法: 1. 使用Matplotlib Pandas中的plot()方法实际上是基于Matplotlib的,你可以使用它来绘制各种基本图表,例如折线图、柱状图、散点图等。 import pandas

【全网最全】2024年数学建模国赛A题30页完整建模文档+17页成品论文+保奖matla代码+可视化图表等(后续会更新)

您的点赞收藏是我继续更新的最大动力! 一定要点击如下的卡片,那是获取资料的入口! 【全网最全】2024年数学建模国赛A题30页完整建模文档+17页成品论文+保奖matla代码+可视化图表等(后续会更新)「首先来看看目前已有的资料,还会不断更新哦~一次购买,后续不会再被收费哦,保证是全网最全资源,随着后续内容更新,价格会上涨,越早购买,价格越低,让大家再也不需要到处买断片资料啦~💰💸👋」�

Python利用pyecharts实现数据可视化

小编会持续更新知识笔记,如果感兴趣可以三连支持。闲来无事,水文一篇,不过上手实践一下倒还是挺好玩的,这一块知识说不定以后真可以尝试拿来做数据库的报表显示。         有梦别怕苦,想赢别喊累。 目录 前言 JSON数据格式的转换 pyecharts简介和入门使用 前言       小编我今天闲来无事,打算学习一下py,结果你猜怎么着,竟然看到py可以将数据

【mysql zeppelin】zeppelin 大数据可视化分析工具安装教程精要

Apache Zeppelin是一款大数据分析和可视化工具,可以让数据分析师在一个基于Web页面的笔记本中,使用不同的语言,对不同数据源中的数据进行交互式分析,并对分析结果进行可视化的工具。下面我们主要讲解如何安装和配置的精要部分。 一、zeppelin 安装和配置登录用户 官方网站: https://zeppelin.apache.org/ 下载地址: https://zeppelin