将自己的数据集加载到dataloader中

2024-04-23 18:04
文章标签 数据 加载 dataloader

本文主要是介绍将自己的数据集加载到dataloader中,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

from torch.utils.data import Dataset
class YourDataset(Dataset):  # 继承Dataset类# 构造函数必须存在def __init__(self, root_dir, ann_file, transform=None):self.ann_file = ann_fileself.root_dir = root_dirself.img_label = self.load_annotations()  # img_label是一个字典self.img = [os.path.join(self.root_dir, img) for img in list(self.img_label.keys())]self.label = [label for label in list(self.img_label.values())]self.transform = transform  # 数据需要做的预处理操作def __len__(self):return len(self.img)# 获取图像和标签交给模型,该函数必须存在# 不要修改参数,每次调用时会传入随机的idx# 一个batch的数据就是由__getitem__函数处理数据传入得到的def __getitem__(self, idx):image = Image.open(self.img[idx]).convert('RGB')  # img保存了图像的路径label = self.label[idx]if self.transform:image = self.transform(image)  # 对数据进行预处理操作label = torch.from_numpy(np.array(label))  # 转换label的数据类型,由list->numpy->tensorreturn image, labeldef load_annotations(self):data_infos = {}with open(self.ann_file) as f:samples = [x.strip().split(' ') for x in f.readlines()]for filename, gt_label in samples:data_infos[filename] = np.array(gt_label, dtype=np.int64)return data_infos

前言

使用开源算法进行小规模的训练,例如分类任务需要加载自己的数据集和类别,需要使用dataloader格式的数据集。现在通过以下的方式将自己的数据集制作为dataloader格式:

参考

参考文档:深度学习(17)--DataLoader自定义数据集制作_自定义dataloader-CSDN博客

首先感谢分享,但是文档中有较多不明白的地方,并且也没有引用,在此上做了相对的改进:

实现过程

1 从txt文件中读取图片文件名和对应label

import numpy as np
def load_annotation(ann_file): # 参数为文本文件的路径# 创建一个字典结构用于保存数据,key作为图像的名字,value作为图像的标签data_infos = {}with open(ann_file) as f:# strip()去除一些换行符等# split(' ')是以空格为分隔符# samples是一个list,格式为图像名字,图像标签# eg:[['image11.jpg,'0'],['image22.jpg,'1'],['image33.jpg,'3']]samples = [x.strip().split(' ') for x in f.readlines()]for filename, gt_label in samples:# filename是图像名字--'image11.jpg',gt_label--'0'是标签,加载到字典data_infos中去# value值设置为array(gt_label,dtype=int64)类型data_infos[filename] = np.array(gt_label, dtype=np.int64)# 得到的字典格式:{'image11.jpg':array(0,dtype=int64),'image22.jpg':array(1,dtype=int64)}return data_infos

文件格式如下:

c1.jpg 1
c2.jpg 1
c3.jpg 1
c4.jpg 1
c5.jpg 1
l1.jpg 2
l2.jpg 2
l3.jpg 2
l4.jpg 2
l5.jpg 2
l6.jpg 2
q1.jpg 3
q2.jpg 3
q3.jpg 3
q4.jpg 3
q5.jpg 3
q6.jpg 3

2 文件名和标签存入list

img_label = load_annotation('testload.txt')
image_name = list(img_label.keys())  # 取keys值
label = list(img_label.values())  # 取labels值
print(img_label.keys()) # 参看你的数据

3 整合为数据类

from torch.utils.data import Dataset
class YourDataset(Dataset):  # 继承Dataset类# 构造函数必须存在def __init__(self, root_dir, ann_file, transform=None):self.ann_file = ann_fileself.root_dir = root_dirself.img_label = self.load_annotations()  # img_label是一个字典self.img = [os.path.join(self.root_dir, img) for img in list(self.img_label.keys())]self.label = [label for label in list(self.img_label.values())]self.transform = transform  # 数据需要做的预处理操作def __len__(self):return len(self.img)# 获取图像和标签交给模型,该函数必须存在# 不要修改参数,每次调用时会传入随机的idx# 一个batch的数据就是由__getitem__函数处理数据传入得到的def __getitem__(self, idx):image = Image.open(self.img[idx]).convert('RGB')  # img保存了图像的路径label = self.label[idx]if self.transform:image = self.transform(image)  # 对数据进行预处理操作label = torch.from_numpy(np.array(label))  # 转换label的数据类型,由list->numpy->tensorreturn image, labeldef load_annotations(self):data_infos = {}with open(self.ann_file) as f:samples = [x.strip().split(' ') for x in f.readlines()]for filename, gt_label in samples:data_infos[filename] = np.array(gt_label, dtype=np.int64)return data_infos

4 使用transform函数进行数据处理

# 创建一个字典结构的数据类型来进行图像预处理操作:key - value
import torchvision.transforms as transforms
data_transforms = {# 对训练集的预处理'train': transforms.Compose([transforms.Resize([256, 256]),  # 卷积神经网络处理的数据大小必须相同,通过Resize来设置# 数据增强transforms.RandomRotation(45),  # 随机旋转,-45到45度之间随机选#transforms.CenterCrop(64),  # 从中心开始裁剪,将原本96x96大小的图片数据裁剪为64x64大小的图片数据,可以获取更多的参数transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转 选择一个概率概率,50%的概率进行水平翻转transforms.RandomVerticalFlip(p=0.5),  # 随 机垂直翻转,50%的概率进行竖直翻转#transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相#transforms.RandomGrayscale(p=0.025),  # 概率转换成灰度率,3通道就是R=G=B(三颜色通道转为单一颜色通道,很少进行此处理)# 将数据转为Tensor类型transforms.ToTensor(),# 标准化transforms.Normalize([0.5, 0.5, 0.5], [0.224, 0.224, 0.225])  # 设置均值,标准差,分别对应R、G、B三个颜色通道的三个均值和标准差值,(x-μ)/σ]),# 对验证集的预处理(不需要进行数据增强)'valid': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])# 均值和标准差数值的设置和训练集的相同(验证集的数据对我们来说是未知的,不能利用其中的数据再计算出相关的均值和标准差)]),
}

5 进行实例化

import os
from torch.utils.data import DataLoader
# 训练集
train_dataset = YourDataset(root_dir='.../...', ann_file='xxx.txt', transform=data_transforms['train'])
# 测试集
#valid_dataset = YourDataset(root_dir=valid_dir, ann_file='./.../valid.txt', transform=data_transforms['valid'])
# 实例化DataLoader(使用封装好的DataLoader包)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
#valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=True)

6 调用使用你的数据

dataloaders = {'train': train_loader}
#dataloaders = {'train': train_loader, "valid": valid_loader}
for inputs, labels in dataloaders['train']:print("处理训练集")
#for inputs, labels in dataloaders['valid']:
#    print("处理验证集")#或者使用枚举
#for i, (imgs, labels) in enumerate(dataloaders['train']):#测试完成后你可以将以上6步写入你的程序,或封装成数据读取包

7 检查和展示你的数据

# 检查训练集
from PIL import Image
import torch
import matplotlib.pyplot as plt
image1, label1 = next(iter(train_loader))  # iter表示train_loader进行迭代,next取一个batch的数据
sample = image1[0].squeeze()  # 通过squeeze()压缩一个维度,有时候维度为1x3x64x64,去除这个1
# 此时的sample是3x64x64的结构,而需要图像展示则需要转换结构为64X64X3,同时需要转换为numpy数据结构
sample = sample.permute((1, 2, 0)).numpy()
# 标准化还原 x = (x-μ) / σ -> x = x*σ + μ (预处理中进行了标准化,需要还原)
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label1[0].numpy()))# 检查训练集
#image2, label2 = next(iter(valid_loader))  # iter表示train_loader进行迭代,next取一个batch的数据
#sample = image2[0].squeeze()  # 通过squeeze()压缩一个维度,有时候维度为1x3x64x64,去除这个1
# 此时的sample是3x64x64的结构,而需要图像展示则需要转换结构为64X64X3,同时需要转换为numpy数据结构
#sample = sample.permute((1, 2, 0)).numpy()
# 标准化还原 x = (x-μ) / σ -> x = x*σ + μ (预处理中进行了标准化,需要还原)
#sample *= [0.229, 0.224, 0.225]
#sample += [0.485, 0.456, 0.406]
#plt.imshow(sample)
#plt.show()
#print('Label is: {}'.format(label2[0].numpy()))

这篇关于将自己的数据集加载到dataloader中的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/929501

相关文章

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

基于MySQL Binlog的Elasticsearch数据同步实践

一、为什么要做 随着马蜂窝的逐渐发展,我们的业务数据越来越多,单纯使用 MySQL 已经不能满足我们的数据查询需求,例如对于商品、订单等数据的多维度检索。 使用 Elasticsearch 存储业务数据可以很好的解决我们业务中的搜索需求。而数据进行异构存储后,随之而来的就是数据同步的问题。 二、现有方法及问题 对于数据同步,我们目前的解决方案是建立数据中间表。把需要检索的业务数据,统一放到一张M

关于数据埋点,你需要了解这些基本知识

产品汪每天都在和数据打交道,你知道数据来自哪里吗? 移动app端内的用户行为数据大多来自埋点,了解一些埋点知识,能和数据分析师、技术侃大山,参与到前期的数据采集,更重要是让最终的埋点数据能为我所用,否则可怜巴巴等上几个月是常有的事。   埋点类型 根据埋点方式,可以区分为: 手动埋点半自动埋点全自动埋点 秉承“任何事物都有两面性”的道理:自动程度高的,能解决通用统计,便于统一化管理,但个性化定

使用SecondaryNameNode恢复NameNode的数据

1)需求: NameNode进程挂了并且存储的数据也丢失了,如何恢复NameNode 此种方式恢复的数据可能存在小部分数据的丢失。 2)故障模拟 (1)kill -9 NameNode进程 [lytfly@hadoop102 current]$ kill -9 19886 (2)删除NameNode存储的数据(/opt/module/hadoop-3.1.4/data/tmp/dfs/na

异构存储(冷热数据分离)

异构存储主要解决不同的数据,存储在不同类型的硬盘中,达到最佳性能的问题。 异构存储Shell操作 (1)查看当前有哪些存储策略可以用 [lytfly@hadoop102 hadoop-3.1.4]$ hdfs storagepolicies -listPolicies (2)为指定路径(数据存储目录)设置指定的存储策略 hdfs storagepolicies -setStoragePo

Hadoop集群数据均衡之磁盘间数据均衡

生产环境,由于硬盘空间不足,往往需要增加一块硬盘。刚加载的硬盘没有数据时,可以执行磁盘数据均衡命令。(Hadoop3.x新特性) plan后面带的节点的名字必须是已经存在的,并且是需要均衡的节点。 如果节点不存在,会报如下错误: 如果节点只有一个硬盘的话,不会创建均衡计划: (1)生成均衡计划 hdfs diskbalancer -plan hadoop102 (2)执行均衡计划 hd

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

烟火目标检测数据集 7800张 烟火检测 带标注 voc yolo

一个包含7800张带标注图像的数据集,专门用于烟火目标检测,是一个非常有价值的资源,尤其对于那些致力于公共安全、事件管理和烟花表演监控等领域的人士而言。下面是对此数据集的一个详细介绍: 数据集名称:烟火目标检测数据集 数据集规模: 图片数量:7800张类别:主要包含烟火类目标,可能还包括其他相关类别,如烟火发射装置、背景等。格式:图像文件通常为JPEG或PNG格式;标注文件可能为X

pandas数据过滤

Pandas 数据过滤方法 Pandas 提供了多种方法来过滤数据,可以根据不同的条件进行筛选。以下是一些常见的 Pandas 数据过滤方法,结合实例进行讲解,希望能帮你快速理解。 1. 基于条件筛选行 可以使用布尔索引来根据条件过滤行。 import pandas as pd# 创建示例数据data = {'Name': ['Alice', 'Bob', 'Charlie', 'Dav

SWAP作物生长模型安装教程、数据制备、敏感性分析、气候变化影响、R模型敏感性分析与贝叶斯优化、Fortran源代码分析、气候数据降尺度与变化影响分析

查看原文>>>全流程SWAP农业模型数据制备、敏感性分析及气候变化影响实践技术应用 SWAP模型是由荷兰瓦赫宁根大学开发的先进农作物模型,它综合考虑了土壤-水分-大气以及植被间的相互作用;是一种描述作物生长过程的一种机理性作物生长模型。它不但运用Richard方程,使其能够精确的模拟土壤中水分的运动,而且耦合了WOFOST作物模型使作物的生长描述更为科学。 本文让更多的科研人员和农业工作者