本文主要是介绍【学习】pytorch框架的数据管理—— 理解Dataloader,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
参考:https://spite-triangle.github.io/artificial_intelligence/#/./README
1.标准数据集
使用:以 CIFAR10 数据集为例,其他数据集类似。
# root:数据存放路径
# train:区分训练集,还是测试集
# transform:对数据集中的图进行预处理
# target_transfrom:对期望输出进行预处理
# download:从网上直接下载数据集
torchvision.datasets.CIFAR10(root: str, train: bool=True, transform=None, target_transform=None, download=False)
2. 自定义数据集
常用的文件路径操作:
rootPath = '..\\asset'
path = '..\\asset\\cat.jpeg'
# 测试路径
os.path.exists(rootPath)
# 文件类型判断
os.path.isfile(path)
os.path.islink(path)
os.path.isdir(path)
# 获取绝对路径
os.path.abspath(rootPath)
# 罗列出文件夹下的所有文件名
os.listdir(rootPath)
# 路径拼接
os.path.join(rootPath,'cat.jpeg')
数据集:
class ImgaeAssets(torch.utils.data.Dataset):""" 自定义数据集类 """def __init__(self,path):self.root = pathself.files = os.listdir(path)passdef __getitem__(self,id):""" 用于数据集中的样本获取 """filePath = os.path.join(self.root,self.files[id])img = Image.open(filePath)return imgdef __len__(self):""" 数据的数量 """return len(self.files)# 创建数据集assets = ImgaeAssets('../asset')# 获取数据img = assets[0]img.show()
##重点 Dataloader
- 作用: 控制数据集 dataSets 的获取
用 dataloader 将 dataset 中的数据取出打包成 batch 的过程中,会通过 sampler 从 dataset 中取出 batch_size 个样本,然后通过 collect function 将取出的样本整理并打包成最终的 batch。
sampler 获取从 dataset 中获取样本,首先通过 len 获取总样本数,然后根据总样本数生成索引序列(数组的索引号),最后根据索引号通过 getitem 加载真正的样本数据(dataset 只预先加载了数据的文件路径,真正的文件并没直接加载)。
通过 sampler 获取到的数据样本,其实是一个「tuple(tensor) 类型数组」,并非真正的一个 tensor。将 tensor 数组最终整合成一个 tensor 就需要通过 dataset 的 collect function 实现。
# dataset:设置数据集
# batch_size:一个 batch 包含多少样本
# shuffle:下一次 epoch 是否需要将数据打乱,再划分 batch
# drop_last:当最后一个 batch 不具有 batch_size 个样本时,是否需要舍弃
# num_workers:线程数
# collate_fn:自定义 collate_fn
# sampler:自定义采集
torch.utils.data.DataLoader(dataset,batch_size,shuffle=False,drop_last=False,num_workers=0,worker_init_fn,collate_fn,sampler)
这篇关于【学习】pytorch框架的数据管理—— 理解Dataloader的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!