PyTorch Demo-2 : 分类模型评估

2024-09-05 01:38

本文主要是介绍PyTorch Demo-2 : 分类模型评估,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 预训练模型加载和预测

1.1 加载预训练参数

根据训练函数中保存的训练参数,使用 torch.load() 进行读取,再加载 model.load_state_dict()

def load_pretrained_model(model, path):"""Load the pretrained model:param model: the defined model:param path: path of the ".pth" file"""state = torch.load(path)model.load_state_dict(state['net'])model = CIFAR10_Net()
load_pretrained_model(model, 'ckpt.pth')

1.2 预测

使用预测时一定要加 model.eval() !!!(大坑)
1.2.1 单张图片预测

直接用读图的API,再transform到Tensor,读图的API主要包含 PIL,opencv,skimage等,如果要用 transforms.Resize() 等数据增强的操作需要用PIL打开,或者在前面加上 transforms.ToPILImage()

from PIL import Image
from torchvision import transformsimg = Image.open('test.jpg')
trans = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
img = trans(img)
img = img.unsqueeze(0)model.eval()
out = model(img)
_, pred = torch.max(out, 1)
print(pred)
1.2.2 多张图

读取多张图像,转换为Tensor再concat,或者先concat再转换。剩下的操作同上。

2. 模型评估

2.1 基本概念

  • 真阳性(True Positive,TP):预测为正,实际也为正

  • 假阳性(Flase Positive,FP):预测为正,实际为负

  • 假阴性(False Negative,FN):预测与负、实际为正

  • 真阴性(True Negative,TN):预测为负、实际也为负

  • 假阳率(False Positive Rate,FPR):原本是错的预测为对的比例(越小越好,0为理想状态)
    F P R = F P F P + T N FPR = \frac{FP}{FP+TN} FPR=FP+TNFP

  • 真阳率(True Positive Rate,TPR):原本是对的预测为对的比例(越大越好,1为理想状态)

T P R = T P T P + F N TPR = \frac{TP}{TP+FN} TPR=TP+FNTP

  • 精确率(查准率,Precision):预测为对的当中,原本为对的比例(越大越好,1为理想状态)

P r e c i s i o n = T P T P + F P Precision = \frac{TP}{TP+FP} Precision=TP+FPTP

  • 召回率(查全率,Recall):原本为对的当中,预测为对的比例(越大越好,1为理想状态)

R e c a l l = T P T P + F N Recall = \frac{TP}{TP+FN} Recall=TP+FNTP

  • 准确率(Accuracy):预测对的(包括原本是对预测为对,原本是错的预测为错两种情形)占整个的比例(越大越好,1为理想状态)

A c c u r a c y = T P + T N T P + T N + F P + F N Accuracy = \frac{TP+TN}{TP+TN+FP+FN} Accuracy=TP+TN+FP+FNTP+TN

2.2 F-score

F分数是对准确率和召回率做一个权衡,公式为:
F β = ( 1 + β 2 ) ⋅ P r e c i s i o n ⋅ R e c a l l ( β 2 ⋅ P r e c i s i o n ) + R e c a l l F_{\beta} = (1+\beta^2)·\frac{Precision·Recall}{(\beta^2·Precision)+Recall} Fβ=(1+β2)(β2Precision)+RecallPrecisionRecall
β \beta β 用于调和Precision和Recall的重要性, 当 β \beta β 为1时同等重要,称为F1-score。

  • 微平均Micro-F1:计算出所有类别总的Precision和Recall,然后计算F1
  • 宏平均Macro-F1:计算出每一个类的Precison和Recall后计算F1,最后将F1平均
from sklearn.metrics import f1_scorema_f1 = f1_score(labels, y_pred, average='macro')
mi_f1 = f1_score(labels, y_pred, average='micro')
print(ma_f1, mi_f1)
"""
0.786153332589687,0.7858
"""

2.2 混淆矩阵

混淆矩阵是机器学习中总结分类模型预测结果的情形分析表,以矩阵形式将数据集中的记录按照真实的类别与分类模型预测的类别判断两个标准进行汇总。其中矩阵的行表示真实值,矩阵的列表示预测值。

以二分类为例:混淆矩阵分别用”0“和”1“代表负样本和正样本。FP代表实际类标签为”0“,但预测类标签为”1“的样本数量。其余,类似推理。

cm

调用 sklearn 计算混淆矩阵:

from sklearn.metrics import confusion_matrixcm = confusion_matrix(y_true, y_pred, labels=None, sample_weight=None)

使用 matplotlib 绘制混淆矩阵:

import matplotlib.pyplot as pltdef plot_confusion_matrix(cm, labels_name, title):plt.figure(figsize=(8, 8))cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]    # 归一化plt.imshow(cm, interpolation='nearest')    # 在特定的窗口上显示图像plt.title(title)    # 图像标题plt.colorbar()num_local = np.array(range(len(labels_name)))    plt.xticks(num_local, labels_name, rotation=90)    # 将标签印在x轴坐标上plt.yticks(num_local, labels_name)    # 将标签印在y轴坐标上plt.ylabel('True label')    plt.xlabel('Predicted label')plt.savefig('%s.jpg' % title, bbox_inches='tight')classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
plot_confusion_matrix(cm, labels_name=classes, title='CIFAR10_cm')

CM

2.3 ROC曲线 & AUC

要生成一个ROC曲线,只需要真阳性率(TPR)和假阳性率(FPR)。TPR决定了一个分类器或者一个诊断测试在所有阳性样本中能正确区分的阳性案例的性能.而FPR是决定了在所有阴性的样本中有多少假阳性的判断。ROC曲线中分别将FPR和TPR定义为x和y轴,这样就描述了真阳性(获利)和假阳性(成本)之间的博弈。而TPR就可以定义为灵敏度,而FPR就定义为1-特异度,因此ROC曲线有时候也叫做灵敏度和1-特异度图像。每一个预测结果在ROC曲线中以一个点代表。
有了ROC曲线后,可以引出AUC的含义:ROC曲线下的面积(越大越好,1为理想状态)。

(1)ROC曲线图中的四个点
第一个点,(0,1),即FPR=0, TPR=1,这意味着FN(false negative)=0,并且FP(false positive)=0。这是一个完美的分类器,它将所有的样本都正确分类。第二个点,(1,0),即FPR=1,TPR=0,类似地分析可以发现这是一个最糟糕的分类器,因为它成功避开了所有的正确答案。第三个点,(0,0),即FPR=TPR=0,即FP(false positive)=TP(true positive)=0,可以发现该分类器预测所有的样本都为负样本(negative)。类似的,第四个点(1,1),分类器实际上预测所有的样本都为正样本。经过以上的分析,,ROC曲线越接近左上角,该分类器的性能越好。

(2)ROC曲线图中的一条特殊线

考虑ROC曲线图中的虚线y=x上的点。这条对角线上的点其实表示的是一个采用随机猜测策略的分类器的结果,例如(0.5,0.5),表示该分类器随机对于一半的样本猜测其为正样本,另外一半的样本为负样本。

对于多分类问题,ROC曲线的获取主要有两种方法:
假设测试样本个数为m,类别个数为n。在训练完成后,计算出每个测试样本的在各类别下的概率或置信度,得到一个[m, n]形状的矩阵P,每一行表示一个测试样本在各类别下概率值(按类别标签排序)。相应地,将每个测试样本的标签转换为类似二进制的形式,每个位置用来标记是否属于对应的类别(也按标签排序,这样才和前面对应),由此也可以获得一个[m, n]的标签矩阵L。
①方法一:

每种类别下,都可以得到m个测试样本为该类别的概率(矩阵P中的列)。所以,根据概率矩阵P和标签矩阵L中对应的每一列,可以计算出各个阈值下的假正例率(FPR)和真正例率(TPR),从而绘制出一条ROC曲线。这样总共可以绘制出n条ROC曲线。最后对n条ROC曲线取平均,即可得到最终的ROC曲线。
②方法二:
首先,对于一个测试样本:1)标签只由0和1组成,1的位置表明了它的类别(可对应二分类问题中的‘’正’’),0就表示其他类别(‘’负‘’);2)要是分类器对该测试样本分类正确,则该样本标签中1对应的位置在概率矩阵P中的值是大于0对应的位置的概率值的。基于这两点,将标签矩阵L和概率矩阵P分别按行展开,转置后形成两列,这就得到了一个二分类的结果。所以,此方法经过计算后可以直接得到最终的ROC曲线。
上面的两个方法得到的ROC曲线是不同的,当然曲线下的面积AUC也是不一样的。 在python中,方法1和方法2分别对应 sklearn.metrics.roc_auc_score 函数中参数average值为"macro"和"micro"的情况。

调用 sklearn 计算ROC和AUC,需要的参数包含模型输出的得分(最后一层FC的输出经过softmax激活后的结果),独热编码的标签(one-hot处理):

from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarizescores = torch.softmax(out, dim=1).cpu().numpy() # out = model(data)
binary_label = label_binarize(labels, classes=list(range(num_classes))) # num_classes=10fpr = {}
tpr = {}
roc_auc = {}for i in range(num_classes):fpr[i], tpr[i], _ = roc_curve(binary_label[:, i], scores[:, i])roc_auc[i] = auc(fpr[i], tpr[i])# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(binary_label.ravel(), scores.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])# Compute macro-average ROC curve and ROC area
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(num_classes)]))
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(num_classes):mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
# Finally average it and compute AUC
mean_tpr /= num_classes
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

使用 matplotlib 绘制ROC曲线:

plt.figure(figsize=(8, 8))
plt.plot(fpr["micro"], tpr["micro"],label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc["micro"]),color='deeppink', linestyle=':', linewidth=4)plt.plot(fpr["macro"], tpr["macro"],label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc["macro"]),color='navy', linestyle=':', linewidth=4)for i in range(10):plt.plot(fpr[i], tpr[i], lw=2,label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.grid()
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Multi-class ROC')
plt.legend(loc="lower right")
plt.savefig('Multi-class ROC.jpg', bbox_inches='tight')
plt.show()

ROC

2.4 CMC曲线(RANK曲线)

CMC曲线是算一种topk的击中概率,主要用来评估闭集中rank的正确率。

假如在人脸识别中,底库中有100个人,现在来了1个待识别的人脸(假如label为m1),与底库中的人脸比对后将底库中的人脸按照得分从高到低进行排序,我们发现:

如果识别结果是m1、m2、m3、m4、m5……,则此时rank-1的正确率为100%;rank-2的正确率也为100%;rank-5的正确率也为100%;
如果识别结果是m2、m1、m3、m4、m5……,则此时rank-1的正确率为0%;rank-2的正确率为100%;rank-5的正确率也为100%;
如果识别结果是m2、m3、m4、m5、m1……,则此时rank-1的正确率为0%;rank-2的正确率为0%;rank-5的正确率为100%;
同理,当待识别的人脸集合有很多时,则采取取平均值的做法。例如待识别人脸有3个(假如label为m1,m2,m3),同样对每一个人脸都有一个从高到低的得分:

比如人脸1结果为m1、m2、m3、m4、m5……,人脸2结果为m2、m1、m3、m4、m5……,人脸3结果m3、m1、m2、m4、m5……,则此时rank-1的正确率为(1+1+1)/3=100%;rank-2的正确率也为(1+1+1)/3=100%;rank-5的正确率也为(1+1+1)/3=100%;
比如人脸1结果为m4、m2、m3、m5、m6……,人脸2结果为m1、m2、m3、m4、m5……,人脸3结果m3、m1、m2、m4、m5……,则此时rank-1的正确率为(0+0+1)/3=33.33%;rank-2的正确率为(0+1+1)/3=66.66%;rank-5的正确率也为(0+1+1)/3=66.66%;

PyTorch中提供了 Tensor.topk() 函数,来获取前k大的值,这里k不超过类别个数。定义函数返回topk的正确个数,准确率则直接用正确个数除以总个数。

def accuracy(output, target, topk=(1,)):"""Computes the accuracy over the k top predictions for the specified values of k:param output: tensor, output of model:param target: tensor, label of input data:param topk: tuple, the k top predictions"""with torch.no_grad():maxk = max(topk)batch_size = target.size(0)_, pred = output.topk(maxk, 1, True, True)pred = pred.t()correct = pred.eq(target.view(1, -1).expand_as(pred))res = []for k in topk:correct_k = correct[:k].view(-1).float().sum(0, keepdim=True).item()# 每个 rank 包含的正确个数res.append(correct_k)return res"""
correct_k = accuracy(out, label, topk=(1, 5))
acc1, acc5 = correct_k / len(dataset)
"""

绘制CMC曲线,横坐标为k,纵坐标为topk准确率。

def Draw_CMC_Curve(acc, topk=10):""":param acc: list, the topk accuracy:param topk: int"""plt.figure(figsize=(8, 8))plt.plot(list(range(topk)), acc)plt.title('CMC_Curve')plt.grid()plt.xlim([0.0, topk-1])plt.ylim([0.0, 1.05])plt.xlabel('Rank')plt.ylabel('Accuracy')plt.savefig('CMC_Curve.jpg', bbox_inches='tight')plt.close()

CMC

Reference:

[1] ImageNet training in PyTorch

[2] PRID:行人重识别常用评测指标(rank-n、Precision & Recall、F-score、mAP 、CMC、ROC)

[3] ROC原理介绍及利用python实现二分类和多分类的ROC曲线

这篇关于PyTorch Demo-2 : 分类模型评估的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1137606

相关文章

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Spring AI Alibaba接入大模型时的依赖问题小结

《SpringAIAlibaba接入大模型时的依赖问题小结》文章介绍了如何在pom.xml文件中配置SpringAIAlibaba依赖,并提供了一个示例pom.xml文件,同时,建议将Maven仓... 目录(一)pom.XML文件:(二)application.yml配置文件(一)pom.xml文件:首

如何在本地部署 DeepSeek Janus Pro 文生图大模型

《如何在本地部署DeepSeekJanusPro文生图大模型》DeepSeekJanusPro模型在本地成功部署,支持图片理解和文生图功能,通过Gradio界面进行交互,展示了其强大的多模态处... 目录什么是 Janus Pro1. 安装 conda2. 创建 python 虚拟环境3. 克隆 janus

本地私有化部署DeepSeek模型的详细教程

《本地私有化部署DeepSeek模型的详细教程》DeepSeek模型是一种强大的语言模型,本地私有化部署可以让用户在自己的环境中安全、高效地使用该模型,避免数据传输到外部带来的安全风险,同时也能根据自... 目录一、引言二、环境准备(一)硬件要求(二)软件要求(三)创建虚拟环境三、安装依赖库四、获取 Dee

C#使用DeepSeek API实现自然语言处理,文本分类和情感分析

《C#使用DeepSeekAPI实现自然语言处理,文本分类和情感分析》在C#中使用DeepSeekAPI可以实现多种功能,例如自然语言处理、文本分类、情感分析等,本文主要为大家介绍了具体实现步骤,... 目录准备工作文本生成文本分类问答系统代码生成翻译功能文本摘要文本校对图像描述生成总结在C#中使用Deep

DeepSeek模型本地部署的详细教程

《DeepSeek模型本地部署的详细教程》DeepSeek作为一款开源且性能强大的大语言模型,提供了灵活的本地部署方案,让用户能够在本地环境中高效运行模型,同时保护数据隐私,在本地成功部署DeepSe... 目录一、环境准备(一)硬件需求(二)软件依赖二、安装Ollama三、下载并部署DeepSeek模型选

Golang的CSP模型简介(最新推荐)

《Golang的CSP模型简介(最新推荐)》Golang采用了CSP(CommunicatingSequentialProcesses,通信顺序进程)并发模型,通过goroutine和channe... 目录前言一、介绍1. 什么是 CSP 模型2. Goroutine3. Channel4. Channe

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)

《Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)》:本文主要介绍Python基于火山引擎豆包大模型搭建QQ机器人详细的相关资料,包括开通模型、配置APIKEY鉴权和SD... 目录豆包大模型概述开通模型付费安装 SDK 环境配置 API KEY 鉴权Ark 模型接口Prompt