pytorch -- DataLoader

2024-02-25 21:12
文章标签 pytorch dataloader

本文主要是介绍pytorch -- DataLoader,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  1. 定义
    提供了给定数据集的迭代器
torch.utils.data.DataLoader(dataset, 
batch_size=1, 每次拿多少数据
shuffle=None, 是否打乱
sampler=None, 
batch_sampler=None, 
num_workers=0, 多进程(加载数据时采用)默认是0,使用主进程加载数据
collate_fn=None, 
pin_memory=False, 
drop_last=False, 总数据/batch_size之后的余数是舍去(True)还是不舍去(False)
timeout=0, 
worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')

官方描述
2. 使用

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvisiondataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()
])
# 测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset",transform=dataset_transform,train=False,download=True)
# 测试集中的第一张图片及target
img,target = test_set[0]
print(img.shape,target)
# 在终端中查看tensorboard --logdir=dataloader
writer = SummaryWriter("dataloader")
test_loader = DataLoader(dataset=test_set,batch_size=64,shuffle=True,num_workers=0,drop_last=True)
for epoch in range(2):step = 0for data in test_loader:imgs,targets = data# print(imgs.shape)# print(targets)writer.add_images("Epoch:{}".format(epoch),imgs,step)step += 1
writer.close()

这篇关于pytorch -- DataLoader的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/746734

相关文章

基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别

转发来源:https://swift.ctolib.com/ooooverflow-chinese-ocr.html chinese-ocr 基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别 环境部署 sh setup.sh 使用环境: python 3.6 + tensorflow 1.10 +pytorch 0.4.1 注:CPU环境

PyTorch模型_trace实战:深入理解与应用

pytorch使用trace模型 1、使用trace生成torchscript模型2、使用trace的模型预测 1、使用trace生成torchscript模型 def save_trace(model, input, save_path):traced_script_model = torch.jit.trace(model, input)<

pytorch国内镜像源安装及测试

一、安装命令:  pip install torch torchvision torchaudio -i https://pypi.tuna.tsinghua.edu.cn/simple  二、测试: import torchx = torch.rand(5, 3)print(x)

PyTorch nn.MSELoss() 均方误差损失函数详解和要点提醒

文章目录 nn.MSELoss() 均方误差损失函数参数数学公式元素版本 要点附录 参考链接 nn.MSELoss() 均方误差损失函数 torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean') Creates a criterion that measures the mean squared err

动手学深度学习(Pytorch版)代码实践 -计算机视觉-37微调

37微调 import osimport torchimport torchvisionfrom torch import nnimport liliPytorch as lpimport matplotlib.pyplot as pltfrom d2l import torch as d2l# 获取数据集d2l.DATA_HUB['hotdog'] = (d2l.DATA_U

WSL+Anconda(pytorch深度学习)环境配置

动机 最近在读point cloud相关论文,准备拉github上相应的code跑一下,但是之前没有深度学习的经验,在配置环境方面踩了超级多的坑,依次来记录一下。 一开始我直接将code拉到了windows本地来运行,遇到了数不清的问题(如:torch版本问题、numpy版本、bash命令无法运行等问题),经过请教,决定将project放到linux系统上进行运行。所以安装WSL(Window

动手学深度学习(Pytorch版)代码实践 -计算机视觉-36图像增广

6 图片增广 import matplotlib.pyplot as pltimport numpy as npimport torch import torchvisionfrom d2l import torch as d2lfrom torch import nn from PIL import Imageimport liliPytorch as lpfrom tor

pytorch 使用GPU加速常见的问题

pytorch如何使用gpu加速 print(torch.cuda.is_available())# 设置gpu设备device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# net使用GPUnet.to(device)# 数据copy到gpuinputData = inputData.to(devi

深度学习:关于损失函数的一些前置知识(PyTorch Loss)

在之前进行实验的时候发现:调用 Pytorch 中的 Loss 函数之前如果对其没有一定的了解,可能会影响实验效果和调试效率。以 CrossEntropyLoss 为例,最初设计实验的时候没有注意到该函数默认返回的是均值,以为是总和,于是最后计算完 Loss 之后,手动做了个均值,导致实际 Loss 被错误缩放,实验效果不佳,在后来 Debug 排除代码模型架构问题的时候才发觉这一点,着实花费了

《PyTorch计算机视觉实战》:一、二章

目录 第一章:人工神经网络基础 比较人工智能和传统机器学习 人工神经网络(Artificial Neural Network,ANN) 是一种受人类大脑运作方式启发而构建的监督学习算法。神经网络与人类大脑中神经元连接和激活的方式比较类似,神经网络接收输入并通过一个函数传递,导致随后的某些神经元被激活,从而产生输出。 有几种标准的 ANN 架构。通用近似定理认为,总是可以找到一