DataLoader基础用法

2024-06-09 19:36
文章标签 基础 用法 dataloader

本文主要是介绍DataLoader基础用法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DataLoader 是 PyTorch 中一个非常有用的工具,用于将数据集进行批处理,并提供一个迭代器来简化模型训练和评估过程。以下是 DataLoader 的常见用法和功能介绍:

基本用法

  1. 创建数据集
    首先,需要一个数据集。数据集可以是 PyTorch 提供的内置数据集,也可以是自定义的数据集。数据集需要继承 torch.utils.data.Dataset 并实现 __len____getitem__ 方法。

    import torch
    import torch.utils.data as Dataclass MyDataSet(Data.Dataset):def __init__(self, enc_inputs, dec_inputs, dec_outputs):self.enc_inputs = enc_inputsself.dec_inputs = dec_inputsself.dec_outputs = dec_outputsdef __len__(self):return len(self.enc_inputs)def __getitem__(self, idx):return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]
    
  2. 创建 DataLoader
    DataLoader 用于将数据集封装成批次,并提供一个迭代器来进行数据的加载。常见的参数包括数据集、批量大小、是否打乱数据、使用的进程数等。

    enc_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    dec_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    dec_outputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])dataset = MyDataSet(enc_inputs, dec_inputs, dec_outputs)
    loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)
    
  3. 迭代数据
    使用 DataLoader 的迭代器来访问批次数据。

    for batch in loader:enc_batch, dec_batch, output_batch = batchprint(enc_batch)print(dec_batch)print(output_batch)
    

常见参数

  1. dataset

    • 数据集对象,必须继承 torch.utils.data.Dataset 类。
  2. batch_size

    • 每个批次的大小,默认为 1。
  3. shuffle

    • 是否在每个 epoch 开始时打乱数据,默认为 False
  4. num_workers

    • 使用多少个子进程来加载数据。0 表示数据将在主进程中加载。对于大型数据集,增加 num_workers 可以加快数据加载速度。
  5. drop_last

    • 如果设置为 True,则丢弃不能整除 batch_size 的最后一个不完整的批次。
  6. pin_memory

    • 如果设置为 True,DataLoader 将在返回前将张量复制到 CUDA 固定内存中。这对 GPU 训练有所帮助。

进阶用法

  1. 自定义 collate_fn

    • collate_fn 用于指定如何将多个样本合并成一个批次。默认情况下,DataLoader 将使用 default_collate,它会将相同类型的数据合并在一起。例如,所有张量数据将合并成一个张量。
    def my_collate_fn(batch):enc_inputs, dec_inputs, dec_outputs = zip(*batch)return torch.stack(enc_inputs, 0), torch.stack(dec_inputs, 0), torch.stack(dec_outputs, 0)loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn)
    
  2. 使用 Sampler

    • Sampler 用于指定如何抽样数据。PyTorch 提供了一些内置的采样器,如 RandomSamplerSequentialSampler
    from torch.utils.data.sampler import RandomSamplersampler = RandomSampler(dataset)
    loader = Data.DataLoader(dataset=dataset, batch_size=2, sampler=sampler)
    

完整示例

import torch
import torch.utils.data as Dataclass MyDataSet(Data.Dataset):def __init__(self, enc_inputs, dec_inputs, dec_outputs):self.enc_inputs = enc_inputsself.dec_inputs = dec_inputsself.dec_outputs = dec_outputsdef __len__(self):return len(self.enc_inputs)def __getitem__(self, idx):return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]enc_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
dec_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
dec_outputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])dataset = MyDataSet(enc_inputs, dec_inputs, dec_outputs)
loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)for batch in loader:enc_batch, dec_batch, output_batch = batchprint("Encoder batch:", enc_batch)print("Decoder batch:", dec_batch)print("Output batch:", output_batch)

通过使用 DataLoader,我们可以轻松地处理和批量化我们的数据,这对于大型数据集和深度学习模型的训练是非常重要的。

这篇关于DataLoader基础用法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1046105

相关文章

JavaScript中的reduce方法执行过程、使用场景及进阶用法

《JavaScript中的reduce方法执行过程、使用场景及进阶用法》:本文主要介绍JavaScript中的reduce方法执行过程、使用场景及进阶用法的相关资料,reduce是JavaScri... 目录1. 什么是reduce2. reduce语法2.1 语法2.2 参数说明3. reduce执行过程

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

Python itertools中accumulate函数用法及使用运用详细讲解

《Pythonitertools中accumulate函数用法及使用运用详细讲解》:本文主要介绍Python的itertools库中的accumulate函数,该函数可以计算累积和或通过指定函数... 目录1.1前言:1.2定义:1.3衍生用法:1.3Leetcode的实际运用:总结 1.1前言:本文将详

MyBatis-Flex BaseMapper的接口基本用法小结

《MyBatis-FlexBaseMapper的接口基本用法小结》本文主要介绍了MyBatis-FlexBaseMapper的接口基本用法小结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具... 目录MyBATis-Flex简单介绍特性基础方法INSERT① insert② insertSelec

深入解析Spring TransactionTemplate 高级用法(示例代码)

《深入解析SpringTransactionTemplate高级用法(示例代码)》TransactionTemplate是Spring框架中一个强大的工具,它允许开发者以编程方式控制事务,通过... 目录1. TransactionTemplate 的核心概念2. 核心接口和类3. TransactionT

数据库使用之union、union all、各种join的用法区别解析

《数据库使用之union、unionall、各种join的用法区别解析》:本文主要介绍SQL中的Union和UnionAll的区别,包括去重与否以及使用时的注意事项,还详细解释了Join关键字,... 目录一、Union 和Union All1、区别:2、注意点:3、具体举例二、Join关键字的区别&php

oracle中exists和not exists用法举例详解

《oracle中exists和notexists用法举例详解》:本文主要介绍oracle中exists和notexists用法的相关资料,EXISTS用于检测子查询是否返回任何行,而NOTE... 目录基本概念:举例语法pub_name总结 exists (sql 返回结果集为真)not exists (s

MySQL中my.ini文件的基础配置和优化配置方式

《MySQL中my.ini文件的基础配置和优化配置方式》文章讨论了数据库异步同步的优化思路,包括三个主要方面:幂等性、时序和延迟,作者还分享了MySQL配置文件的优化经验,并鼓励读者提供支持... 目录mysql my.ini文件的配置和优化配置优化思路MySQL配置文件优化总结MySQL my.ini文件

Springboot中Jackson用法详解

《Springboot中Jackson用法详解》Springboot自带默认json解析Jackson,可以在不引入其他json解析包情况下,解析json字段,下面我们就来聊聊Springboot中J... 目录前言Jackson用法将对象解析为json字符串将json解析为对象将json文件转换为json

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]