本文主要是介绍pytorch -- DataLoader,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
- 定义
提供了给定数据集的迭代器
torch.utils.data.DataLoader(dataset,
batch_size=1, 每次拿多少数据
shuffle=None, 是否打乱
sampler=None,
batch_sampler=None,
num_workers=0, 多进程(加载数据时采用)默认是0,使用主进程加载数据
collate_fn=None,
pin_memory=False,
drop_last=False, 总数据/batch_size之后的余数是舍去(True)还是不舍去(False)
timeout=0,
worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')
官方描述
2. 使用
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvisiondataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()
])
# 测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset",transform=dataset_transform,train=False,download=True)
# 测试集中的第一张图片及target
img,target = test_set[0]
print(img.shape,target)
# 在终端中查看tensorboard --logdir=dataloader
writer = SummaryWriter("dataloader")
test_loader = DataLoader(dataset=test_set,batch_size=64,shuffle=True,num_workers=0,drop_last=True)
for epoch in range(2):step = 0for data in test_loader:imgs,targets = data# print(imgs.shape)# print(targets)writer.add_images("Epoch:{}".format(epoch),imgs,step)step += 1
writer.close()
这篇关于pytorch -- DataLoader的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!