Pytorch:复写Dataset函数详解,以及Dataloader如何调用

2024-08-24 05:36

本文主要是介绍Pytorch:复写Dataset函数详解,以及Dataloader如何调用,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在 PyTorch 中,DatasetDataLoader 是数据加载和处理的重要组件。下面详细介绍 Dataset 类的作用及其 __len__()__getitem__() 方法,以及它们如何与 DataLoader 协作,包括数据打乱(shuffle)和批处理(batching)等功能。

Dataset

Dataset 是一个抽象基类,用于表示一个数据集。你需要继承这个基类并实现以下两个方法:

1. __len__()

  • 作用: 返回数据集中样本的总数量。

  • 返回值: 一个整数,表示数据集中样本的数量。

  • 用例: 当你需要知道数据集的大小时,例如在创建 DataLoader 对象时,DataLoader 需要知道数据集中有多少样本才能正确地进行批处理和打乱操作。

这一步确定你取数据的范围,如果你想一次取两个数据,需要在__len__()里面控制索引的长度

class MyDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)

2. __getitem__(index)

  • 作用: 根据给定的索引返回数据集中的一个样本。

  • 参数: index,一个整数,表示要获取的样本的索引。

  • 返回值: 返回一个样本的数据(和可能的标签),这可以是任何类型,例如一个图像和其标签、一个文本片段等等。

这里其实可操作性很大,比如你想每次dataloader得到的batch里面包含图片的路径,那么在这里return。

class MyDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]

DataLoader

DataLoader 是用于批量加载数据的工具,它接受一个 Dataset 对象并提供了以下功能:

1. 批处理(Batching)

  • 作用: 将数据集分成小批量,每次从数据集中取出一个批次的数据进行训练或评估。

  • 实现: DataLoader 根据 batch_size 参数将数据分批,每个批次包含 batch_size 个样本。

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

2. 数据打乱(Shuffling)

  • 作用: 在每个 epoch 开始时打乱数据的顺序,有助于模型更好地泛化。

  • 实现: 如果 shuffle=TrueDataLoader 会在每个 epoch 开始时创建一个打乱的索引列表,然后按这些索引顺序提取样本。这里的打乱索引范围就是从_len_函数获取的

过程:

  1. 创建一个索引列表 [0, 1, 2, ..., len(dataset) - 1]
  2. 如果 shuffle=True,打乱这个索引列表。
  3. 使用打乱后的索引列表来从数据集中提取样本。
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

3. 并行加载(Multi-threaded Loading)

  • 作用: 使用多个子进程并行加载数据,减少数据预处理和加载的时间。

  • 参数: num_workers 指定用于数据加载的子进程数量。

  • 实现: DataLoader 会启动 num_workers 个子进程来调用 Dataset__getitem__() 方法并加载数据。

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

总结

  1. Dataset:

    • __len__(): 返回数据集的总样本数。DataLoader 使用这个方法来确定数据的总量,以便进行正确的批处理和打乱。
    • __getitem__(index): 根据索引返回单个数据样本。DataLoader 会调用这个方法来获取每个批次的数据样本。
  2. DataLoader:

    • 负责将数据分批、打乱数据、并行加载等任务。
    • 使用 Dataset__len__() 来了解数据集的大小,并使用 __getitem__() 来获取每个样本。

通过这种方式,Dataset 提供了数据访问的接口,而 DataLoader 管理数据的加载、打乱和批处理等高级功能。

这篇关于Pytorch:复写Dataset函数详解,以及Dataloader如何调用的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1101593

相关文章

Spring Security基于数据库验证流程详解

Spring Security 校验流程图 相关解释说明(认真看哦) AbstractAuthenticationProcessingFilter 抽象类 /*** 调用 #requiresAuthentication(HttpServletRequest, HttpServletResponse) 决定是否需要进行验证操作。* 如果需要验证,则会调用 #attemptAuthentica

hdu1171(母函数或多重背包)

题意:把物品分成两份,使得价值最接近 可以用背包,或者是母函数来解,母函数(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v) 其中指数为价值,每一项的数目为(该物品数+1)个 代码如下: #include<iostream>#include<algorithm>

OpenHarmony鸿蒙开发( Beta5.0)无感配网详解

1、简介 无感配网是指在设备联网过程中无需输入热点相关账号信息,即可快速实现设备配网,是一种兼顾高效性、可靠性和安全性的配网方式。 2、配网原理 2.1 通信原理 手机和智能设备之间的信息传递,利用特有的NAN协议实现。利用手机和智能设备之间的WiFi 感知订阅、发布能力,实现了数字管家应用和设备之间的发现。在完成设备间的认证和响应后,即可发送相关配网数据。同时还支持与常规Sof

如何在页面调用utility bar并传递参数至lwc组件

1.在app的utility item中添加lwc组件: 2.调用utility bar api的方式有两种: 方法一,通过lwc调用: import {LightningElement,api ,wire } from 'lwc';import { publish, MessageContext } from 'lightning/messageService';import Ca

6.1.数据结构-c/c++堆详解下篇(堆排序,TopK问题)

上篇:6.1.数据结构-c/c++模拟实现堆上篇(向下,上调整算法,建堆,增删数据)-CSDN博客 本章重点 1.使用堆来完成堆排序 2.使用堆解决TopK问题 目录 一.堆排序 1.1 思路 1.2 代码 1.3 简单测试 二.TopK问题 2.1 思路(求最小): 2.2 C语言代码(手写堆) 2.3 C++代码(使用优先级队列 priority_queue)

K8S(Kubernetes)开源的容器编排平台安装步骤详解

K8S(Kubernetes)是一个开源的容器编排平台,用于自动化部署、扩展和管理容器化应用程序。以下是K8S容器编排平台的安装步骤、使用方式及特点的概述: 安装步骤: 安装Docker:K8S需要基于Docker来运行容器化应用程序。首先要在所有节点上安装Docker引擎。 安装Kubernetes Master:在集群中选择一台主机作为Master节点,安装K8S的控制平面组件,如AP

C++操作符重载实例(独立函数)

C++操作符重载实例,我们把坐标值CVector的加法进行重载,计算c3=c1+c2时,也就是计算x3=x1+x2,y3=y1+y2,今天我们以独立函数的方式重载操作符+(加号),以下是C++代码: c1802.cpp源代码: D:\YcjWork\CppTour>vim c1802.cpp #include <iostream>using namespace std;/*** 以独立函数

嵌入式Openharmony系统构建与启动详解

大家好,今天主要给大家分享一下,如何构建Openharmony子系统以及系统的启动过程分解。 第一:OpenHarmony系统构建      首先熟悉一下,构建系统是一种自动化处理工具的集合,通过将源代码文件进行一系列处理,最终生成和用户可以使用的目标文件。这里的目标文件包括静态链接库文件、动态链接库文件、可执行文件、脚本文件、配置文件等。      我们在编写hellowor

LabVIEW FIFO详解

在LabVIEW的FPGA开发中,FIFO(先入先出队列)是常用的数据传输机制。通过配置FIFO的属性,工程师可以在FPGA和主机之间,或不同FPGA VIs之间进行高效的数据传输。根据具体需求,FIFO有多种类型与实现方式,包括目标范围内FIFO(Target-Scoped)、DMA FIFO以及点对点流(Peer-to-Peer)。 FIFO类型 **目标范围FIFO(Target-Sc

019、JOptionPane类的常用静态方法详解

目录 JOptionPane类的常用静态方法详解 1. showInputDialog()方法 1.1基本用法 1.2带有默认值的输入框 1.3带有选项的输入对话框 1.4自定义图标的输入对话框 2. showConfirmDialog()方法 2.1基本用法 2.2自定义按钮和图标 2.3带有自定义组件的确认对话框 3. showMessageDialog()方法 3.1