DataLoader基础用法

2024-06-09 19:36
文章标签 基础 用法 dataloader

本文主要是介绍DataLoader基础用法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DataLoader 是 PyTorch 中一个非常有用的工具,用于将数据集进行批处理,并提供一个迭代器来简化模型训练和评估过程。以下是 DataLoader 的常见用法和功能介绍:

基本用法

  1. 创建数据集
    首先,需要一个数据集。数据集可以是 PyTorch 提供的内置数据集,也可以是自定义的数据集。数据集需要继承 torch.utils.data.Dataset 并实现 __len____getitem__ 方法。

    import torch
    import torch.utils.data as Dataclass MyDataSet(Data.Dataset):def __init__(self, enc_inputs, dec_inputs, dec_outputs):self.enc_inputs = enc_inputsself.dec_inputs = dec_inputsself.dec_outputs = dec_outputsdef __len__(self):return len(self.enc_inputs)def __getitem__(self, idx):return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]
    
  2. 创建 DataLoader
    DataLoader 用于将数据集封装成批次,并提供一个迭代器来进行数据的加载。常见的参数包括数据集、批量大小、是否打乱数据、使用的进程数等。

    enc_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    dec_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    dec_outputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])dataset = MyDataSet(enc_inputs, dec_inputs, dec_outputs)
    loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)
    
  3. 迭代数据
    使用 DataLoader 的迭代器来访问批次数据。

    for batch in loader:enc_batch, dec_batch, output_batch = batchprint(enc_batch)print(dec_batch)print(output_batch)
    

常见参数

  1. dataset

    • 数据集对象,必须继承 torch.utils.data.Dataset 类。
  2. batch_size

    • 每个批次的大小,默认为 1。
  3. shuffle

    • 是否在每个 epoch 开始时打乱数据,默认为 False
  4. num_workers

    • 使用多少个子进程来加载数据。0 表示数据将在主进程中加载。对于大型数据集,增加 num_workers 可以加快数据加载速度。
  5. drop_last

    • 如果设置为 True,则丢弃不能整除 batch_size 的最后一个不完整的批次。
  6. pin_memory

    • 如果设置为 True,DataLoader 将在返回前将张量复制到 CUDA 固定内存中。这对 GPU 训练有所帮助。

进阶用法

  1. 自定义 collate_fn

    • collate_fn 用于指定如何将多个样本合并成一个批次。默认情况下,DataLoader 将使用 default_collate,它会将相同类型的数据合并在一起。例如,所有张量数据将合并成一个张量。
    def my_collate_fn(batch):enc_inputs, dec_inputs, dec_outputs = zip(*batch)return torch.stack(enc_inputs, 0), torch.stack(dec_inputs, 0), torch.stack(dec_outputs, 0)loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn)
    
  2. 使用 Sampler

    • Sampler 用于指定如何抽样数据。PyTorch 提供了一些内置的采样器,如 RandomSamplerSequentialSampler
    from torch.utils.data.sampler import RandomSamplersampler = RandomSampler(dataset)
    loader = Data.DataLoader(dataset=dataset, batch_size=2, sampler=sampler)
    

完整示例

import torch
import torch.utils.data as Dataclass MyDataSet(Data.Dataset):def __init__(self, enc_inputs, dec_inputs, dec_outputs):self.enc_inputs = enc_inputsself.dec_inputs = dec_inputsself.dec_outputs = dec_outputsdef __len__(self):return len(self.enc_inputs)def __getitem__(self, idx):return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]enc_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
dec_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
dec_outputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])dataset = MyDataSet(enc_inputs, dec_inputs, dec_outputs)
loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)for batch in loader:enc_batch, dec_batch, output_batch = batchprint("Encoder batch:", enc_batch)print("Decoder batch:", dec_batch)print("Output batch:", output_batch)

通过使用 DataLoader,我们可以轻松地处理和批量化我们的数据,这对于大型数据集和深度学习模型的训练是非常重要的。

这篇关于DataLoader基础用法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1046105

相关文章

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

【Linux 从基础到进阶】Ansible自动化运维工具使用

Ansible自动化运维工具使用 Ansible 是一款开源的自动化运维工具,采用无代理架构(agentless),基于 SSH 连接进行管理,具有简单易用、灵活强大、可扩展性高等特点。它广泛用于服务器管理、应用部署、配置管理等任务。本文将介绍 Ansible 的安装、基本使用方法及一些实际运维场景中的应用,旨在帮助运维人员快速上手并熟练运用 Ansible。 1. Ansible的核心概念

AI基础 L9 Local Search II 局部搜索

Local Beam search 对于当前的所有k个状态,生成它们的所有可能后继状态。 检查生成的后继状态中是否有任何状态是解决方案。 如果所有后继状态都不是解决方案,则从所有后继状态中选择k个最佳状态。 当达到预设的迭代次数或满足某个终止条件时,算法停止。 — Choose k successors randomly, biased towards good ones — Close

bytes.split的用法和注意事项

当然,我很乐意详细介绍 bytes.Split 的用法和注意事项。这个函数是 Go 标准库中 bytes 包的一个重要组成部分,用于分割字节切片。 基本用法 bytes.Split 的函数签名如下: func Split(s, sep []byte) [][]byte s 是要分割的字节切片sep 是用作分隔符的字节切片返回值是一个二维字节切片,包含分割后的结果 基本使用示例: pa

音视频入门基础:WAV专题(10)——FFmpeg源码中计算WAV音频文件每个packet的pts、dts的实现

一、引言 从文章《音视频入门基础:WAV专题(6)——通过FFprobe显示WAV音频文件每个数据包的信息》中我们可以知道,通过FFprobe命令可以打印WAV音频文件每个packet(也称为数据包或多媒体包)的信息,这些信息包含该packet的pts、dts: 打印出来的“pts”实际是AVPacket结构体中的成员变量pts,是以AVStream->time_base为单位的显

C 语言基础之数组

文章目录 什么是数组数组变量的声明多维数组 什么是数组 数组,顾名思义,就是一组数。 假如班上有 30 个同学,让你编程统计每个人的分数,求最高分、最低分、平均分等。如果不知道数组,你只能这样写代码: int ZhangSan_score = 95;int LiSi_score = 90;......int LiuDong_score = 100;int Zhou

c++基础版

c++基础版 Windows环境搭建第一个C++程序c++程序运行原理注释常亮字面常亮符号常亮 变量数据类型整型实型常量类型确定char类型字符串布尔类型 控制台输入随机数产生枚举定义数组数组便利 指针基础野指针空指针指针运算动态内存分配 结构体结构体默认值结构体数组结构体指针结构体指针数组函数无返回值函数和void类型地址传递函数传递数组 引用函数引用传参返回指针的正确写法函数返回数组

【QT】基础入门学习

文章目录 浅析Qt应用程序的主函数使用qDebug()函数常用快捷键Qt 编码风格信号槽连接模型实现方案 信号和槽的工作机制Qt对象树机制 浅析Qt应用程序的主函数 #include "mywindow.h"#include <QApplication>// 程序的入口int main(int argc, char *argv[]){// argc是命令行参数个数,argv是

【MRI基础】TR 和 TE 时间概念

重复时间 (TR) 磁共振成像 (MRI) 中的 TR(重复时间,repetition time)是施加于同一切片的连续脉冲序列之间的时间间隔。具体而言,TR 是施加一个 RF(射频)脉冲与施加下一个 RF 脉冲之间的持续时间。TR 以毫秒 (ms) 为单位,主要控制后续脉冲之前的纵向弛豫程度(T1 弛豫),使其成为显著影响 MRI 中的图像对比度和信号特性的重要参数。 回声时间 (TE)

UVM:callback机制的意义和用法

1. 作用         Callback机制在UVM验证平台,最大用处就是为了提高验证平台的可重用性。在不创建复杂的OOP层次结构前提下,针对组件中的某些行为,在其之前后之后,内置一些函数,增加或者修改UVM组件的操作,增加新的功能,从而实现一个环境多个用例。此外还可以通过Callback机制构建异常的测试用例。 2. 使用步骤         (1)在UVM组件中内嵌callback函