本文主要是介绍DataLoader基础用法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
DataLoader
是 PyTorch 中一个非常有用的工具,用于将数据集进行批处理,并提供一个迭代器来简化模型训练和评估过程。以下是 DataLoader
的常见用法和功能介绍:
基本用法
-
创建数据集:
首先,需要一个数据集。数据集可以是 PyTorch 提供的内置数据集,也可以是自定义的数据集。数据集需要继承torch.utils.data.Dataset
并实现__len__
和__getitem__
方法。import torch import torch.utils.data as Dataclass MyDataSet(Data.Dataset):def __init__(self, enc_inputs, dec_inputs, dec_outputs):self.enc_inputs = enc_inputsself.dec_inputs = dec_inputsself.dec_outputs = dec_outputsdef __len__(self):return len(self.enc_inputs)def __getitem__(self, idx):return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]
-
创建 DataLoader:
DataLoader
用于将数据集封装成批次,并提供一个迭代器来进行数据的加载。常见的参数包括数据集、批量大小、是否打乱数据、使用的进程数等。enc_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) dec_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) dec_outputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])dataset = MyDataSet(enc_inputs, dec_inputs, dec_outputs) loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)
-
迭代数据:
使用DataLoader
的迭代器来访问批次数据。for batch in loader:enc_batch, dec_batch, output_batch = batchprint(enc_batch)print(dec_batch)print(output_batch)
常见参数
-
dataset:
- 数据集对象,必须继承
torch.utils.data.Dataset
类。
- 数据集对象,必须继承
-
batch_size:
- 每个批次的大小,默认为 1。
-
shuffle:
- 是否在每个 epoch 开始时打乱数据,默认为
False
。
- 是否在每个 epoch 开始时打乱数据,默认为
-
num_workers:
- 使用多少个子进程来加载数据。
0
表示数据将在主进程中加载。对于大型数据集,增加num_workers
可以加快数据加载速度。
- 使用多少个子进程来加载数据。
-
drop_last:
- 如果设置为
True
,则丢弃不能整除batch_size
的最后一个不完整的批次。
- 如果设置为
-
pin_memory:
- 如果设置为
True
,DataLoader 将在返回前将张量复制到 CUDA 固定内存中。这对 GPU 训练有所帮助。
- 如果设置为
进阶用法
-
自定义 collate_fn:
collate_fn
用于指定如何将多个样本合并成一个批次。默认情况下,DataLoader
将使用default_collate
,它会将相同类型的数据合并在一起。例如,所有张量数据将合并成一个张量。
def my_collate_fn(batch):enc_inputs, dec_inputs, dec_outputs = zip(*batch)return torch.stack(enc_inputs, 0), torch.stack(dec_inputs, 0), torch.stack(dec_outputs, 0)loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn)
-
使用 Sampler:
Sampler
用于指定如何抽样数据。PyTorch 提供了一些内置的采样器,如RandomSampler
和SequentialSampler
。
from torch.utils.data.sampler import RandomSamplersampler = RandomSampler(dataset) loader = Data.DataLoader(dataset=dataset, batch_size=2, sampler=sampler)
完整示例
import torch
import torch.utils.data as Dataclass MyDataSet(Data.Dataset):def __init__(self, enc_inputs, dec_inputs, dec_outputs):self.enc_inputs = enc_inputsself.dec_inputs = dec_inputsself.dec_outputs = dec_outputsdef __len__(self):return len(self.enc_inputs)def __getitem__(self, idx):return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]enc_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
dec_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
dec_outputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])dataset = MyDataSet(enc_inputs, dec_inputs, dec_outputs)
loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)for batch in loader:enc_batch, dec_batch, output_batch = batchprint("Encoder batch:", enc_batch)print("Decoder batch:", dec_batch)print("Output batch:", output_batch)
通过使用 DataLoader
,我们可以轻松地处理和批量化我们的数据,这对于大型数据集和深度学习模型的训练是非常重要的。
这篇关于DataLoader基础用法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!