【深度学习|Pytorch】torchvision.datasets.ImageFolder详解

2024-04-04 07:36

本文主要是介绍【深度学习|Pytorch】torchvision.datasets.ImageFolder详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

ImageFolder详解

  • 1、数据准备
  • 2、ImageFolder类的定义
    • transforms.ToTensor()解析
  • 3、ImageFolder返回对象

1、数据准备

创建一个文件夹,比如叫dataset,将cat和dog文件夹都放在dataset文件夹路径下:
在这里插入图片描述

2、ImageFolder类的定义

class ImageFolder(DatasetFolder):def __init__(self,root: str,transform: Optional[Callable] = None,target_transform: Optional[Callable] = None,loader: Callable[[str], Any] = default_loader,is_valid_file: Optional[Callable[[str], bool]] = None,):

可以看到,ImageFolder类有这几个参数:
root:图片存储的根目录,即存放不同类别图片文件夹的前一个路径。
transform:即对加载的这些图片进行的前处理的方式,这里可以传入一个实例化的torchvision.Compose()对象,里面包含了各种预处理的操作。
target_transform:对图片类别进行预处理,通常来说不会用到这一步,因此可以直接不传入参数,默认图像标签没有变换,如果需要进行标签的处理,同样可以传入一个实例化的torchvision.Compose()对象。
loader:表示图像数据加载的方式,通常采用默认的加载方式,ImageFolder加载图像的方式为调用PIL库,因此图像的通道顺序是RGB而非opencv的BGR
is_valid_file:获取图像文件路径的函数,并且可以检查是否有损坏的文件。
示例代码:

ROOT_TEST = 'dataset' #dataset/cat, dataset/dog
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
val_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),normalize
])# 加载训练数据集
val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)

transforms.ToTensor()解析

这里需要特别说一下ToTensor()这个函数的作用,刚接触深度学习的我那时以为只是单纯的将图像的ndarray和PIL格式转成Tensor格式,后来查看了一下源码之后发现,事情并没有这么简答!

   """Convert a PIL Image or ndarray to tensor and scale the values accordingly.This transform does not support torchscript.Converts a PIL Image or numpy.ndarray (H x W x C) in the range[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)or if the numpy.ndarray has dtype = np.uint8In the other cases, tensors are returned without scaling... note::Because the input image is scaled to [0.0, 1.0], this transformation should not be used whentransforming target image masks. See the `references`_ for implementing the transforms for image masks... _references: https://github.com/pytorch/vision/tree/main/references/segmentation"""

这是关于ToTensor()函数的注解,这里明确指出了ToTensor()可以将PIL和ndarray格式的图像数据转成Tensor并缩放它们的值,这里的缩放他们的值的意思在下面也指出了,即将[0, 255]的像素值域归一化[0, 1.0],并且图像转换成Tensor格式之后,维度的顺序也会发生一点变化,从一开始的HWC变成了CHW的排列方式。

3、ImageFolder返回对象

以第一部分为例,我们用一个val_dataset接收了ImageFolder的返回值,那么这个Val_dataset对象里面包含了什么呢:
val_dataset.classes:存放着根目录下的子文件夹的名称(类别名称)的列表。
val_dataset.class_to_idx:存放着类别名称和各自的索引,字典类型。
val_dataset.extensions:存放着ImageFolder可以读取的图像格式名称,元组类型。
val_dataset.targets:存放着根目录下每一张图的类别索引。
val_dataset.transform:我们提供的transform的方式。
val_dataset.imgs:存放着根目录下每一张图的路径和类别索引。元组列表类型。
以上是关于这个ImageFolder返回的对象的属性的解析。

此外,我们可以通过一个for循环来遍历整个val_dataset的所有图像数据,其中val_dataset[i]是一个元组类型的数据,val_dataset[i][0]代表了前处理后的图像数据,类型为tensor,以AlexNet为例,此时的tensor应该是3 * 224 * 224的维度。val_dataset[i][1]代表了图像的类别索引。
完整示例代码:

import torch
from AlexNet import AlexNet
from torch.autograd import Variable
from torchvision import transforms
from torchvision.transforms import ToPILImage
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader# ROOT_TRAIN = 'D:/pycharm/AlexNet/data/train'
ROOT_TEST = 'dataset'# 将图像的像素值归一化到[-1,1]之间
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])val_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),normalize
])# 加载训练数据集
val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)# 如果有NVIDA显卡,转到GPU训练,否则用CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'# 模型实例化,将模型转到device
model = AlexNet().to(device)# 加载train.py里训练好的模型
model.load_state_dict(torch.load(r'save_model/model_best.pth'))# 结果类型
classes = ["cat","dog"
]# 把Tensor转化为图片,方便可视化
show = ToPILImage()# 进入验证阶段
model.eval()
for i in range(10):x, y = val_dataset[i][0], val_dataset[i][1]# show():显示图片# show(x).show()# torch.unsqueeze(input, dim),input(Tensor):输入张量,dim (int):插入维度的索引,最终扩展张量维度为4维x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False).to(device)with torch.no_grad():pred = model(x)# argmax(input):返回指定维度最大值的序号# 得到预测类别中最高的那一类,再把最高的这一类对应classes中的那一类predicted, actual = classes[torch.argmax(pred[0])], classes[y]# 输出预测值与真实值print(f'predicted:"{predicted}", actual:"{actual}"')

这篇关于【深度学习|Pytorch】torchvision.datasets.ImageFolder详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/875182

相关文章

Mysql 中的多表连接和连接类型详解

《Mysql中的多表连接和连接类型详解》这篇文章详细介绍了MySQL中的多表连接及其各种类型,包括内连接、左连接、右连接、全外连接、自连接和交叉连接,通过这些连接方式,可以将分散在不同表中的相关数据... 目录什么是多表连接?1. 内连接(INNER JOIN)2. 左连接(LEFT JOIN 或 LEFT

Java中switch-case结构的使用方法举例详解

《Java中switch-case结构的使用方法举例详解》:本文主要介绍Java中switch-case结构使用的相关资料,switch-case结构是Java中处理多个分支条件的一种有效方式,它... 目录前言一、switch-case结构的基本语法二、使用示例三、注意事项四、总结前言对于Java初学者

Linux内核之内核裁剪详解

《Linux内核之内核裁剪详解》Linux内核裁剪是通过移除不必要的功能和模块,调整配置参数来优化内核,以满足特定需求,裁剪的方法包括使用配置选项、模块化设计和优化配置参数,图形裁剪工具如makeme... 目录简介一、 裁剪的原因二、裁剪的方法三、图形裁剪工具四、操作说明五、make menuconfig

Node.js 中 http 模块的深度剖析与实战应用小结

《Node.js中http模块的深度剖析与实战应用小结》本文详细介绍了Node.js中的http模块,从创建HTTP服务器、处理请求与响应,到获取请求参数,每个环节都通过代码示例进行解析,旨在帮... 目录Node.js 中 http 模块的深度剖析与实战应用一、引言二、创建 HTTP 服务器:基石搭建(一

详解Java中的敏感信息处理

《详解Java中的敏感信息处理》平时开发中常常会遇到像用户的手机号、姓名、身份证等敏感信息需要处理,这篇文章主要为大家整理了一些常用的方法,希望对大家有所帮助... 目录前后端传输AES 对称加密RSA 非对称加密混合加密数据库加密MD5 + Salt/SHA + SaltAES 加密平时开发中遇到像用户的

Springboot使用RabbitMQ实现关闭超时订单(示例详解)

《Springboot使用RabbitMQ实现关闭超时订单(示例详解)》介绍了如何在SpringBoot项目中使用RabbitMQ实现订单的延时处理和超时关闭,通过配置RabbitMQ的交换机、队列和... 目录1.maven中引入rabbitmq的依赖:2.application.yml中进行rabbit

C语言线程池的常见实现方式详解

《C语言线程池的常见实现方式详解》本文介绍了如何使用C语言实现一个基本的线程池,线程池的实现包括工作线程、任务队列、任务调度、线程池的初始化、任务添加、销毁等步骤,感兴趣的朋友跟随小编一起看看吧... 目录1. 线程池的基本结构2. 线程池的实现步骤3. 线程池的核心数据结构4. 线程池的详细实现4.1 初

Python绘制土地利用和土地覆盖类型图示例详解

《Python绘制土地利用和土地覆盖类型图示例详解》本文介绍了如何使用Python绘制土地利用和土地覆盖类型图,并提供了详细的代码示例,通过安装所需的库,准备地理数据,使用geopandas和matp... 目录一、所需库的安装二、数据准备三、绘制土地利用和土地覆盖类型图四、代码解释五、其他可视化形式1.

SpringBoot使用Apache POI库读取Excel文件的操作详解

《SpringBoot使用ApachePOI库读取Excel文件的操作详解》在日常开发中,我们经常需要处理Excel文件中的数据,无论是从数据库导入数据、处理数据报表,还是批量生成数据,都可能会遇到... 目录项目背景依赖导入读取Excel模板的实现代码实现代码解析ExcelDemoInfoDTO 数据传输

如何用Java结合经纬度位置计算目标点的日出日落时间详解

《如何用Java结合经纬度位置计算目标点的日出日落时间详解》这篇文章主详细讲解了如何基于目标点的经纬度计算日出日落时间,提供了在线API和Java库两种计算方法,并通过实际案例展示了其应用,需要的朋友... 目录前言一、应用示例1、天安门升旗时间2、湖南省日出日落信息二、Java日出日落计算1、在线API2