【猫狗分类】Pytorch VGG16 实现猫狗分类5-预测新图片

2024-06-16 23:52

本文主要是介绍【猫狗分类】Pytorch VGG16 实现猫狗分类5-预测新图片,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

背景
 

好了,现在开尝试预测新的图片,并且让vgg16模型判断是狗还是猫吧。

声明:整个数据和代码来自于b站,链接:使用pytorch框架手把手教你利用VGG16网络编写猫狗分类程序_哔哩哔哩_bilibili

预测

1、导包

from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from net import vgg16

2、设置新照片的路径

test_pth=r'.\img.png'#设置可以检测的图像
test=Image.open(test_pth)

3、处理图片:图片变成tensor

transform=transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor()])
image=transform(test)
  • transforms.Compose:这是一个类,可以将多个变换操作组合在一起。当你需要对数据执行一系列变换时,就会用到它。它接受一个变换函数列表作为参数。

4、设置设备

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")#CPU与GPU的选择

5、加载网络(vgg16net)

net =vgg16()#输入网络

6、加载模型(权重模型)

model=torch.load(r".\DogandCat5.pth",map_location=device)#已训练完成的结果权重输入
net.load_state_dict(model)#模型导入

网络是网络,模型是模型!模型是训练出来的权重模型!网络是认为设定的!

7、模式选择(是训练模式还是推理模式)

net.eval()#设置为推测模式
  • 在PyTorch中,net.eval()是一个非常重要的方法调用,它用于改变模型的状态,使其从训练模式切换到推理(推测)模式。理解这一点很重要,因为模型在两种模式下的行为有所不同:
  • 训练模式 (net.train()): 在这种模式下,模型中的所有层都会处于活跃状态,包括像Dropout和Batch Normalization这样的层,它们会在每次前向传播时根据训练数据进行更新,引入随机性和依赖于批次的统计信息。这对于学习模型参数是非常必要的。

  • 推理模式 (net.eval()): 调用net.eval()后,模型会进入推理模式。这时,Dropout层将不起作用(即总是通过),而Batch Normalization层会使用在训练过程中计算得到的移动平均和方差,而不是 mini-batch 中的统计信息。这意味着模型的输出对于相同的输入将变得确定性,这对于测试和预测非常重要,因为你希望对同一输入多次运行模型时得到相同的结果。

  • 总结来说,当你准备好使用训练好的模型对新数据进行预测,而不是继续修改模型参数时,就应该调用net.eval()来确保模型以正确、一致的方式进行推理

8、传图片到网络,调整输入维度为四维张量

image=torch.reshape(image,(1,3,224,224))#四维图形,RGB三个通道

在PyTorch中,使用torch.reshape或者更常用的torch.Tensor.view方法可以改变张量的形状。对于图像数据,特别是当您准备将图像输入到深度学习模型时,将其调整为适合模型输入维度的四维张量是很常见的操作。

9、开始预测

with torch.no_grad():out=net(image)
out=F.softmax(out,dim=1) #softmax转为概率学问题
out=out.data.cpu().numpy()
print(out)
a=int(out.argmax(1))#输出最大值位置
  • with torch.no_grad():: 这一行代码用来指示PyTorch在接下来的代码块中不记录任何梯度信息。这对于推理(预测)阶段是非常重要的,因为它可以减少内存使用并加速计算过程,因为不需要为反向传播做准备。

  • out=net(image): 在上下文管理器torch.no_grad()内,将处理过的图像image输入到神经网络模型net中进行前向传播,得到模型的原始输出out。这个输出通常是未经处理的概率分布,对于分类任务,它通常代表每个类别的得分

  • out=F.softmax(out, dim=1): 使用F.softmax函数对模型输出out进行处理,该函数会将每一行的数据转换为概率分布,确保所有元素之和为1。这里dim=1表示沿着类别维度(通常对应于神经网络输出的最后一维)进行softmax操作,使得每个样本的预测结果可以解释为各类别的概率。

例举:假设你有一个简单的分类任务,模型需要区分猫、狗、鸟三种动物,即共有3个类别。你使用一个神经网络模型进行预测,对于一个批次内单个样本的输出可能看起来像这样(在未经过softmax处理前):

out_before_softmax = torch.tensor([2.0, 1.0, 0.5], dtype=torch.float32)

这里的输出张量out_before_softmax表示模型对于这个样本属于三个类别的原始打分或logits。注意,这些数值没有直接的概率意义,它们可以是任意实数。

应用Softmax

为了将这些原始分数转化为概率分布,你将使用F.softmax函数,并且指定dim=1,因为在这个一维张量的情况下,类别维度自然就是最后一维。执行操作后:

import torch.nn.functional as F
out_after_softmax = F.softmax(out_before_softmax, dim=1)
print(out_after_softmax)

输出解释

执行上述代码后,你可能会看到类似以下的输出(具体数值可能因四舍五入略有不同):

tensor([0.5561, 0.2476, 0.1963])

现在,out_after_softmax中的每个元素代表样本属于对应类别的概率,且所有概率之和为1(或接近1,由于浮点运算的精度限制)。例如,这里模型认为该样本有大约55.61%的概率是猫,24.76%的概率是狗,以及19.63%的概率是鸟。

总结

通过指定dim=1,你告诉softmax函数沿张量的最后一维进行操作,这在多分类任务中至关重要,因为它确保了每个样本的预测能够被合理地解释为各类别的概率分布。

  • out=out.data.cpu().numpy(): 将张量out从GPU(如果有的话)复制到CPU上,并转换为numpy数组,以便于进一步的处理和显示。这样做是因为后续的操作可能涉及到非PyTorch的库,如matplotlib用于绘图。

  • a=int(out.argmax(1)): 找出概率最大的类别索引,即预测的类别。argmax(1)沿着第1维度(类别维度)找到最大值的索引。argmax函数是用来找出数组或张量中最大值所在的位置(索引)。

10、显示图像

plt.figure()
list=['Cat','Dog']
plt.suptitle("Classes:{}:{:.1%}".format(list[a],out[0,a]))
plt.imshow(test)
plt.show()

  • list=['Cat','Dog']: 定义了一个类别标签列表,这里简化为猫和狗两类。
  • plt.suptitle(...): 设置图表的主标题,显示预测的类别名称及最高概率的百分比。
  • plt.imshow(test): 显示原始测试图像。
  • plt.show(): 显示整个图表,包括图像和标题。

给一张柴犬的照片,预测下:

                        

这篇关于【猫狗分类】Pytorch VGG16 实现猫狗分类5-预测新图片的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1067906

相关文章

MySQL中查找重复值的实现

《MySQL中查找重复值的实现》查找重复值是一项常见需求,比如在数据清理、数据分析、数据质量检查等场景下,我们常常需要找出表中某列或多列的重复值,具有一定的参考价值,感兴趣的可以了解一下... 目录技术背景实现步骤方法一:使用GROUP BY和HAVING子句方法二:仅返回重复值方法三:返回完整记录方法四:

IDEA中新建/切换Git分支的实现步骤

《IDEA中新建/切换Git分支的实现步骤》本文主要介绍了IDEA中新建/切换Git分支的实现步骤,通过菜单创建新分支并选择是否切换,创建后在Git详情或右键Checkout中切换分支,感兴趣的可以了... 前提:项目已被Git托管1、点击上方栏Git->NewBrancjsh...2、输入新的分支的

Python实现对阿里云OSS对象存储的操作详解

《Python实现对阿里云OSS对象存储的操作详解》这篇文章主要为大家详细介绍了Python实现对阿里云OSS对象存储的操作相关知识,包括连接,上传,下载,列举等功能,感兴趣的小伙伴可以了解下... 目录一、直接使用代码二、详细使用1. 环境准备2. 初始化配置3. bucket配置创建4. 文件上传到os

关于集合与数组转换实现方法

《关于集合与数组转换实现方法》:本文主要介绍关于集合与数组转换实现方法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、Arrays.asList()1.1、方法作用1.2、内部实现1.3、修改元素的影响1.4、注意事项2、list.toArray()2.1、方

使用Python实现可恢复式多线程下载器

《使用Python实现可恢复式多线程下载器》在数字时代,大文件下载已成为日常操作,本文将手把手教你用Python打造专业级下载器,实现断点续传,多线程加速,速度限制等功能,感兴趣的小伙伴可以了解下... 目录一、智能续传:从崩溃边缘抢救进度二、多线程加速:榨干网络带宽三、速度控制:做网络的好邻居四、终端交互

java实现docker镜像上传到harbor仓库的方式

《java实现docker镜像上传到harbor仓库的方式》:本文主要介绍java实现docker镜像上传到harbor仓库的方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 前 言2. 编写工具类2.1 引入依赖包2.2 使用当前服务器的docker环境推送镜像2.2

C++20管道运算符的实现示例

《C++20管道运算符的实现示例》本文简要介绍C++20管道运算符的使用与实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录标准库的管道运算符使用自己实现类似的管道运算符我们不打算介绍太多,因为它实际属于c++20最为重要的

Java easyExcel实现导入多sheet的Excel

《JavaeasyExcel实现导入多sheet的Excel》这篇文章主要为大家详细介绍了如何使用JavaeasyExcel实现导入多sheet的Excel,文中的示例代码讲解详细,感兴趣的小伙伴可... 目录1.官网2.Excel样式3.代码1.官网easyExcel官网2.Excel样式3.代码

python实现对数据公钥加密与私钥解密

《python实现对数据公钥加密与私钥解密》这篇文章主要为大家详细介绍了如何使用python实现对数据公钥加密与私钥解密,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录公钥私钥的生成使用公钥加密使用私钥解密公钥私钥的生成这一部分,使用python生成公钥与私钥,然后保存在两个文

浏览器插件cursor实现自动注册、续杯的详细过程

《浏览器插件cursor实现自动注册、续杯的详细过程》Cursor简易注册助手脚本通过自动化邮箱填写和验证码获取流程,大大简化了Cursor的注册过程,它不仅提高了注册效率,还通过友好的用户界面和详细... 目录前言功能概述使用方法安装脚本使用流程邮箱输入页面验证码页面实战演示技术实现核心功能实现1. 随机