本文主要是介绍OCSVM: A toy example for beginner (2),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
目录
1. 前言
2. 在训练数据中掺入一些单位圆以外的数据
3. 在正常集合和异常集合之间设置隔离带
4. 总结
1. 前言
在上一篇(OCSVM: A toy example for beginner)中,训练数据全部是处于单位圆内的。经过训练,分类器模型也的确搞明白了把单位圆内作为正常区域,在此之外的区域当作是异常数据。在对测试数据的判别中有98%的准确率,还是觉得挺神奇的。因为并没有给标签啊,分类器是怎么学习到要区分单位圆内和单位圆外的呢?仅仅是因为训练数据全部位于单位圆内这样一个规则的区域?
2. 在训练数据中掺入一些单位圆以外的数据
以下实验中,将单位圆以外的数据也加入到训练数据X中,但是密度只有单位圆内的1/10. 让我们看看训练和预测结果会发生什么变化?
n_train = 10000
n_test = 1000
X1 = np.zeros((n_train,2))
Y1 = np.zeros((n_test,2))
k_outside = 0
for i in range(n_train):x1 = random.uniform(-1.,1.)x2 = random.uniform(-1.,1.)if isInUnitCircle(x1,x2):X1[i,:] = [x1,x2]else:k_outside += 1if k_outside%10 == 0:X1[i,:] = [x1,x2]for i in range(n_test):y1 = random.uniform(-1.,1.)y2 = random.uniform(-1.,1.)Y1[i,:] = [y1,y2]clf = OneClassSVM(gamma=0.01, nu=0.01).fit(X1)
#clf = OneClassSVM().fit(X)
X1_pred = clf.predict(X1)
Y1_pred = clf.predict(Y1)
X1_accuracy = np.sum(X1_pred[X1_pred > 0])/len(X1_pred)
print('X1_accuracy = {0:6.3f}'.format(X1_accuracy))Y1_pred_gt = 2 * ((Y1[:,0]**2 + Y1[:,1]**2) < 1) - 1
Y1_diff = Y1_pred_gt * Y1_pred
Y1_accuracy = np.sum(Y1_diff[Y1_diff > 0])/len(Y1_diff)
print('Y1_accuracy = {0:6.3f}'.format(Y1_accuracy))
训练和预测结果:
X1_accuracy = 0.990Y1_accuracy = 0.885
与之前相比测试数据的预测准确度有所下降。接下来以图示的方式看看发生了什么?
unitcirc = np.zeros((100,2))
for i in range(100):unitcirc[i,0] = np.cos(i*2*np.pi/100)unitcirc[i,1] = np.sin(i*2*np.pi/100)X1_pos = X1[X1_pred==1]
X1_neg = X1[X1_pred==-1]
Y1_pos = Y1[Y1_pred==1]
Y1_neg = Y1[Y1_pred==-1]fig, ax = plt.subplots(1,2,figsize = (12,6))
ax[0].scatter(X1_pos[:,0],X1_pos[:,1])
ax[0].scatter(X1_neg[:,0],X1_neg[:,1])
ax[0].plot(unitcirc[:,0],unitcirc[:,1],color='y')
#ax[1].scatter(Y1[:,0],Y1[:,1])
ax[1].scatter(Y1_pos[:,0],Y1_pos[:,1])
ax[1].scatter(Y1_neg[:,0],Y1_neg[:,1])
ax[1].plot(unitcirc[:,0],unitcirc[:,1],color='y')
ax[0].set_title('Train set')
ax[1].set_title('Test set')
fig.suptitle('Left: train set; Right: test set')
在图上叠加了画单位圆以方便分辨单位圆内外的点集。
训练好的分类器依旧基本上学到了把单位圆边界作为正常、异常的分界线。这个可以说还是不错的,毕竟现在单位圆内外其实只有数据密度的差别,分类器捕捉到了密度的差距并据此总结出了分辨规则。但是在单位圆附近(尤其是靠外侧)有一些模糊,所以不管是训练集还是测试集,在单位圆边界靠外侧有很多点被判断为‘正常’点。这个结果应该说是很正常,毕竟密度的估计(假定分类器‘聪明’地捕捉到了以密度作为判断标准)需要针对一定区域,从内到外密度的下降是渐变的,而不是在单位圆边界处断崖式地跌落为单位圆内的1/10.
那如果在训练集中,不仅让单位圆外的数据密度保持为单位圆内的1/10,而且还设置一条“护城河”,让‘正常’集合与‘异常’集合之间有一个明显的无人区过渡带,是不是有助于分类器学习到更精确的判断规则呢?
3. 在正常集合和异常集合之间设置隔离带
在以下例子中,我们将‘正常’集合限定于单单位圆内,而将‘异常’集合限定于距离原点1.5以外的区域。
n_train = 10000
n_test = 1000
X1 = np.zeros((n_train,2))
Y1 = np.zeros((n_test,2))
k_outside = 0
for i in range(n_train):x1 = random.uniform(-1.,1.)x2 = random.uniform(-1.,1.)if isInUnitCircle(x1,x2):X1[i,:] = [x1,x2]elif x1**2 + x2**2 > 1.5:k_outside += 1if k_outside%10 == 0:X1[i,:] = [x1,x2]for i in range(n_test):y1 = random.uniform(-1.,1.)y2 = random.uniform(-1.,1.)Y1[i,:] = [y1,y2]
clf = OneClassSVM(gamma=0.01, nu=0.01).fit(X1)
#clf = OneClassSVM().fit(X)
X1_pred = clf.predict(X1)
Y1_pred = clf.predict(Y1)
X1_accuracy = np.sum(X1_pred[X1_pred > 0])/len(X1_pred)
print('X1_accuracy = {0:6.3f}'.format(X1_accuracy))Y1_pred_gt = 2 * ((Y1[:,0]**2 + Y1[:,1]**2) < 1) - 1
Y1_diff = Y1_pred_gt * Y1_pred
Y1_accuracy = np.sum(Y1_diff[Y1_diff > 0])/len(Y1_diff)
print('Y1_accuracy = {0:6.3f}'.format(Y1_accuracy))
运行结果: X1_accuracy = 0.990Y1_accuracy = 0.995
unitcirc = np.zeros((100,2))
for i in range(100):unitcirc[i,0] = np.cos(i*2*np.pi/100)unitcirc[i,1] = np.sin(i*2*np.pi/100)X1_pos = X1[X1_pred==1]
X1_neg = X1[X1_pred==-1]
Y1_pos = Y1[Y1_pred==1]
Y1_neg = Y1[Y1_pred==-1]fig, ax = plt.subplots(1,2,figsize = (12,6))
ax[0].scatter(X1_pos[:,0],X1_pos[:,1])
ax[0].scatter(X1_neg[:,0],X1_neg[:,1])
ax[0].plot(unitcirc[:,0],unitcirc[:,1],color='y')
#ax[1].scatter(Y1[:,0],Y1[:,1])
ax[1].scatter(Y1_pos[:,0],Y1_pos[:,1])
ax[1].scatter(Y1_neg[:,0],Y1_neg[:,1])
ax[1].plot(unitcirc[:,0],unitcirc[:,1],color='y')
ax[0].set_title('Train set')
ax[1].set_title('Test set')
fig.suptitle('Left: train set; Right: test set')
果然如预期一样,分类器学习到了更加准确的分类规则,甚至针对测试集合的预测准确度超过了针对训练集本身的预测准确度,达到了99.5%!
4. 总结
通过几个简单的实验增进了对于单类SVM的“行为特性”的理解。
接下来要找一些更为多样化、复杂化的例子来“玩弄玩弄”,看看各种参数应该如何选择或者优化。再往后就要整点硬核一点的东西,比如说假模假式地推(抄?^-^)一下公式啊啥的。。。,然后能不能写一个属于自己的OCSVM呢?还有经典的SVM呢?
欲知后事如何,敬请期待!
这篇关于OCSVM: A toy example for beginner (2)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!