KNN on CIFAR-10: L1 and L2 distances, with cross-validation

2024-05-15 09:58


The k-Nearest Neighbor (k-NN) classifier

The plain nearest-neighbor (NN) algorithm takes the label of the single closest training image. k-NN instead selects the k training images with the smallest distances and predicts the label that occurs most often among them, which improves generalization.
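The majority vote at the heart of k-NN can be sketched in a couple of lines of NumPy (a minimal illustration, not code from the original post):

```python
import numpy as np

# Labels of the k = 5 nearest training images for one test image (toy data)
closest_y = np.array([3, 1, 3, 2, 3])

# np.bincount counts occurrences of each label; argmax picks the most frequent
pred = np.argmax(np.bincount(closest_y))
print(pred)  # 3: label 3 appears three times, so it wins the vote
```

This is exactly the `np.argmax(np.bincount(...))` idiom used in `predict_label` below.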

Cross-validation

Sometimes the training set is small (and any validation set carved out of it is even smaller). With cross-validation, we split the training set into 5 equal folds, use 4 folds for training and 1 for validation, then rotate so that each fold serves as the validation set exactly once, and finally average the 5 validation results as the algorithm's score.
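The fold rotation described above can be sketched as follows (a toy illustration using `np.array_split`, the same splitting call the full code uses):

```python
import numpy as np

X = np.arange(10).reshape(10, 1)      # toy "training set" of 10 samples
num_folds = 5
folds = np.array_split(X, num_folds)  # 5 folds of 2 samples each

for i in range(num_folds):
    X_val = folds[i]                             # 1 fold held out for validation
    X_tr = np.vstack(folds[:i] + folds[i + 1:])  # remaining 4 folds for training
    print(i, X_tr.shape, X_val.shape)            # (8, 1) and (2, 1) each round
```

Each of the 5 rounds would train a classifier on `X_tr` and score it on `X_val`; averaging the 5 scores gives the cross-validation accuracy.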

import numpy as np
import pickle
import matplotlib.pyplot as plt

'''
Read the training and test sets
'''
file_path = "E:/cifar-10-python/cifar-10-batches-py/"

def unpickle(file):
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='latin1')

'''
Load the CIFAR-10 dataset
'''
def load_CIFAR10(file):
    # Stack the five training batches into one data array and one label array
    dictTrain = unpickle(file + "data_batch_1")
    dataTrain = dictTrain['data']
    labelTrain = dictTrain['labels']
    for i in range(2, 6):
        dictTrain = unpickle(file + "data_batch_" + str(i))
        dataTrain = np.vstack([dataTrain, dictTrain['data']])
        labelTrain = np.hstack([labelTrain, dictTrain['labels']])
    dictTest = unpickle(file + "test_batch")
    dataTest = dictTest['data']
    labelTest = np.array(dictTest['labels'])
    return dataTrain, labelTrain, dataTest, labelTest

class KNearestNeighbor(object):
    def __init__(self):
        self.X_train = None
        self.y_train = None

    def train(self, X_train, y_train):
        """k-NN has no training step: it simply memorizes the data."""
        self.X_train = X_train
        self.y_train = y_train

    def compute_distances_L1(self, X_test):
        """
        Manhattan (L1) distance between every test image and every training image.
        :param X_test: test set, numpy.ndarray
        :return: distance matrix of shape (num_test, num_train)
        """
        dists = np.zeros((X_test.shape[0], self.X_train.shape[0]))
        for i in range(X_test.shape[0]):
            dists[i] = np.sum(np.abs(self.X_train - X_test[i]), axis=1)
        return dists

    def compute_distances_L2(self, X_test):
        """
        Euclidean (L2) distance, fully vectorized via the expansion
        (x - y)^2 = x^2 - 2xy + y^2 (the single-loop version needs no such rewrite).
        :param X_test: test set, numpy.ndarray
        :return: squared-distance matrix of shape (num_test, num_train)
        """
        value_2xy = np.multiply(X_test.dot(self.X_train.T), -2)
        value_x2 = np.sum(np.square(X_test), axis=1, keepdims=True)  # keep dims for broadcasting
        value_y2 = np.sum(np.square(self.X_train), axis=1)
        dists = value_2xy + value_x2 + value_y2
        return dists

    def predict_label(self, dists, k):
        """
        Take the k nearest labels and predict the most frequent one per test image.
        :param dists: distance matrix
        :param k: number of neighbors
        :return: predicted labels (vector)
        """
        y_pred = np.zeros(dists.shape[0])
        for i in range(dists.shape[0]):
            # Labels of the k nearest training images
            closest_y = self.y_train[np.argsort(dists[i, :])[:k]]
            # Most frequent label among the k
            y_pred[i] = np.argmax(np.bincount(closest_y))
        return y_pred

    def predict(self, X_test, k, L):
        """
        :param k: number of neighbors
        :param L: 1 = L1 (Manhattan distance), 2 = L2 (Euclidean distance)
        :return: prediction vector
        """
        if L == 1:
            dists = self.compute_distances_L1(X_test)
        else:
            dists = self.compute_distances_L2(X_test)
        return self.predict_label(dists, k)

def Cross_validation(X_train, y_train):
    """
    5-fold cross-validation to choose the hyperparameter k, with a plot of the results.
    :param X_train: training set
    :param y_train: training labels
    """
    num_folds = 5
    k_choices = [1, 3, 5, 8, 10, 12, 15, 20, 50, 100]
    k_accuracy = {}
    # Split the data into 5 folds; each list element is an ndarray
    X_train_folds = np.array_split(X_train, num_folds)
    y_train_folds = np.array_split(y_train, num_folds)
    for k in k_choices:
        k_accuracy[k] = []
        # For each k, let each fold serve once as the validation set
        for index in range(num_folds):
            X_te = X_train_folds[index]
            y_te = y_train_folds[index]
            X_tr = np.reshape(np.array(X_train_folds[:index] + X_train_folds[index + 1:]),
                              (int(X_train.shape[0] * (num_folds - 1) / num_folds), -1))
            y_tr = np.reshape(y_train_folds[:index] + y_train_folds[index + 1:],
                              int(X_train.shape[0] * (num_folds - 1) / num_folds))
            classify = KNearestNeighbor()
            classify.train(X_tr, y_tr)
            y_te_pred = classify.predict(X_te, k, 2)
            accuracy = np.mean(y_te_pred == y_te)
            k_accuracy[k].append(accuracy)
    for k, accuracylist in k_accuracy.items():
        for accuracy in accuracylist:
            print("k = %d, accuracy = %.3f" % (k, accuracy))
    # Visualize the accuracy for each k
    for k in k_choices:
        accuracies = k_accuracy[k]
        plt.scatter([k] * len(accuracies), accuracies)
    accuracies_mean = np.array([np.mean(v) for k, v in sorted(k_accuracy.items())])
    accuracies_std = np.array([np.std(v) for k, v in sorted(k_accuracy.items())])
    # Error-bar plot from the per-k mean and standard deviation
    plt.errorbar(k_choices, accuracies_mean, yerr=accuracies_std)
    plt.title('Cross-validation on k')
    plt.xlabel('k')
    plt.ylabel('Cross-validation accuracy')
    plt.show()

dataTrain, labelTrain, dataTest, labelTest = load_CIFAR10(file_path)
Cross_validation(dataTrain[:1000, :], labelTrain[:1000])

'''
# Find the hyperparameters that work best on the validation set,
# comparing L1 and L2 across several values of k
k_value = [1, 2, 4, 8, 16, 32, 64, 128]
AccuaracyL1L2 = np.zeros([2, 8])
for k in range(len(k_value)):
    knn = KNearestNeighbor()
    knn.train(dataTrain[:5000, :], labelTrain[:5000])
    label_predict = knn.predict(dataTest[:50, :], k_value[k], 1)
    AccuaracyL1L2[0][k] = np.mean(label_predict == labelTest[:50])
    label_predict = knn.predict(dataTest[:50, :], k_value[k], 2)
    AccuaracyL1L2[1][k] = np.mean(label_predict == labelTest[:50])
accuracy = np.mean(AccuaracyL1L2, axis=1)
print(AccuaracyL1L2)
if accuracy[0] > accuracy[1]:
    print("L1 accuracy is higher than L2")
    print("Best k = %d, best accuracy = %f" % (k_value[np.argmax(AccuaracyL1L2[0])], np.max(AccuaracyL1L2[0])))
else:
    print("L2 accuracy is higher than L1")
    print("Best k = %d, best accuracy = %f" % (k_value[np.argmax(AccuaracyL1L2[1])], np.max(AccuaracyL1L2[1])))
'''
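The vectorized `compute_distances_L2` above relies on the identity ||x − y||² = ||x||² − 2x·y + ||y||². A quick sanity check of that expansion against the naive double loop (my own verification sketch, not part of the original code):

```python
import numpy as np

rng = np.random.default_rng(0)
X_test = rng.standard_normal((3, 4))   # 3 "test" vectors of dimension 4
X_train = rng.standard_normal((5, 4))  # 5 "training" vectors

# Vectorized squared distances: x^2 - 2xy + y^2, broadcast across both axes
d2 = (np.sum(X_test**2, axis=1, keepdims=True)
      - 2 * X_test.dot(X_train.T)
      + np.sum(X_train**2, axis=1))

# Naive reference computation, one pair at a time
ref = np.array([[np.sum((t - x)**2) for x in X_train] for t in X_test])
print(np.allclose(d2, ref))  # True
```

Note that the code keeps squared distances (no square root): the ordering of neighbors is unchanged, so the k-NN predictions are identical and a full `sqrt` over the distance matrix is saved.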

Each value of k gets 5 validation results, one per fold. The figures below (from the original post) show the runs with the L1 (Manhattan) distance; compared with L2, L1 gave somewhat better overall validation accuracy in this cross-validation experiment.

[figure: per-fold accuracies and error-bar plot, L1 distance]

The following figures show the same validation with the L2 (Euclidean) distance.

[figure: per-fold accuracies and error-bar plot, L2 distance]
Reference

https://zhuanlan.zhihu.com/p/20900216




