Pytorch.Dataloader 详细深度解读和微修改源代码心得

2024-08-30 07:38

本文主要是介绍Pytorch.Dataloader 详细深度解读和微修改源代码心得,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

关于pytorch 的dataloader库,使用pytorch 基本都会用到的一个库

今天遇到了一个问题,我在训练的时候,采用batch_size =2 去训练,最终的loss抖动太大了,看得出来应该是某些样本在打标的时候打的不好导致的,需要找出这些样本重新修正。但是一开始是采用的dataloader默认库。然后输入进去的图像dataset 传出来之后是经过shuffle的,没有办法定位到哪张图片。PS(如果有知道的大佬知道简单的直接调用的方法,可以评论学习下,万分感谢)

如果全部自己写的话,工作量就比较大了,测试流程较长。所以想是不是可以简单的修改源码来达到目的。

研究了一下源码,还是简单的。

现在将整个过程分享下:

首先,来看看dataloader这个库到底是干了什么事情

1.先定位到源码目录

在torch.utils.data下有2个文件是我们目前需要的,一个dataloader.py,这个文件封装了我们调用的dataloader类,返回的是一个迭代对象,也就是我们网络的输入。另一个是

sampler,实现了拆分输入数据为多个batch,在dataloader内会有关键的调用。了解了两个文件大致目的,我们详细解读下。来看看我们今天的主题:

class DataLoader(object):   __initialized = Falsedef __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None, multiprocessing_context=None):torch._C._log_api_usage_once("python.data_loader")if num_workers < 0:raise ValueError('num_workers option should be non-negative; ''use num_workers=0 to disable multiprocessing.')if timeout < 0:raise ValueError('timeout option should be non-negative')self.dataset = datasetself.num_workers = num_workersself.pin_memory = pin_memoryself.timeout = timeoutself.worker_init_fn = worker_init_fnself.multiprocessing_context = multiprocessing_contextif isinstance(dataset, IterableDataset):self._dataset_kind = _DatasetKind.Iterableif shuffle is not False:raise ValueError("DataLoader with IterableDataset: expected unspecified ""shuffle option, but got shuffle={}".format(shuffle))elif sampler is not None:# See NOTE [ Custom Samplers and IterableDataset ]raise ValueError("DataLoader with IterableDataset: expected unspecified ""sampler option, but got sampler={}".format(sampler))elif batch_sampler is not None:# See NOTE [ Custom Samplers and IterableDataset ]raise ValueError("DataLoader with IterableDataset: expected unspecified ""batch_sampler option, but got batch_sampler={}".format(batch_sampler))else:self._dataset_kind = _DatasetKind.Mapif sampler is not None and shuffle:raise ValueError('sampler option is mutually exclusive with ''shuffle')if batch_sampler is not None:# auto_collation with custom batch_samplerif batch_size != 1 or shuffle or sampler is not None or drop_last:raise ValueError('batch_sampler option is mutually exclusive ''with batch_size, shuffle, sampler, and ''drop_last')batch_size = Nonedrop_last = Falseelif batch_size is None:# no auto_collationif shuffle or drop_last:raise ValueError('batch_size=None option disables auto-batching ''and is mutually exclusive with ''shuffle, and drop_last')if sampler is None:  # give default samplersif self._dataset_kind == _DatasetKind.Iterable:# See NOTE [ Custom Samplers and IterableDataset ]sampler = _InfiniteConstantSampler()else:  # map-styleif shuffle:sampler = RandomSampler(dataset)else:sampler = SequentialSampler(dataset)if batch_size is not None and batch_sampler is None:# auto_collation without custom batch_samplerbatch_sampler = BatchSampler(sampler, batch_size, drop_last)self.batch_size = batch_sizeself.drop_last = drop_lastself.sampler = samplerself.batch_sampler = batch_samplerif collate_fn is None:if self._auto_collation:collate_fn = _utils.collate.default_collateelse:collate_fn = _utils.collate.default_convertself.collate_fn = collate_fnself.__initialized = True

__init__()函数里面有个重点,在

if batch_size is not None and batch_sampler is None:# auto_collation without custom batch_samplerbatch_sampler = BatchSampler(sampler, batch_size, drop_last)

这个函数就得到了打乱顺序的batch的index,以list的形式存放,比如[12,242,456,13]

看下这个class,

class BatchSampler(Sampler):r"""Wraps another sampler to yield a mini-batch of indices.Args:sampler (Sampler): Base sampler.batch_size (int): Size of mini-batch.drop_last (bool): If ``True``, the sampler will drop the last batch ifits size would be less than ``batch_size``Example:>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))[[0, 1, 2], [3, 4, 5], [6, 7, 8]]"""def __init__(self, sampler, batch_size, drop_last):if not isinstance(sampler, Sampler):raise ValueError("sampler should be an instance of ""torch.utils.data.Sampler, but got sampler={}".format(sampler))if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \batch_size <= 0:raise ValueError("batch_size should be a positive integer value, ""but got batch_size={}".format(batch_size))if not isinstance(drop_last, bool):raise ValueError("drop_last should be a boolean value, but got ""drop_last={}".format(drop_last))self.sampler = samplerself.batch_size = batch_sizeself.drop_last = drop_lastdef __iter__(self):batch = []for idx in self.sampler:batch.append(idx)if len(batch) == self.batch_size:yield batchbatch = []if len(batch) > 0 and not self.drop_last:yield batchdef __len__(self):if self.drop_last:return len(self.sampler) // self.batch_sizeelse:return (len(self.sampler) + self.batch_size - 1) // self.batch_size

在__iter__()里面实现了随机采样idx放入list中

 

下面看下dataloader中最重要的函数,也就是我们需要修改的地方

    def __iter__(self):if self.num_workers == 0:return _SingleProcessDataLoaderIter(self)else:return _MultiProcessingDataLoaderIter(self)

dataloarder通过这个函数返回迭代值,当我们的num_worker大于0的时候,也就是采用多进程方式读取数据。

进入下面这个_MultiProcessingDataLoaderIter(self) 函数

class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):def __init__(self, loader):super(_MultiProcessingDataLoaderIter, self).__init__(loader)assert self._num_workers > 0if loader.multiprocessing_context is None:multiprocessing_context = multiprocessingelse:multiprocessing_context = loader.multiprocessing_contextself._worker_init_fn = loader.worker_init_fnself._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))self._worker_result_queue = multiprocessing_context.Queue()self._worker_pids_set = Falseself._shutdown = Falseself._send_idx = 0  # idx of the next task to be sent to workersself._rcvd_idx = 0  # idx of the next task to be returned in __next__self._task_info = {}self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)self._workers_done_event = multiprocessing_context.Event()self._index_queues = []self._workers = []self._workers_status = []for i in range(self._num_workers):index_queue = multiprocessing_context.Queue()# index_queue.cancel_join_thread()w = multiprocessing_context.Process(target=_utils.worker._worker_loop,args=(self._dataset_kind, self._dataset, index_queue,self._worker_result_queue, self._workers_done_event,self._auto_collation, self._collate_fn, self._drop_last,self._base_seed + i, self._worker_init_fn, i, self._num_workers))w.daemon = Truew.start()self._index_queues.append(index_queue)self._workers.append(w)self._workers_status.append(True)if self._pin_memory:self._pin_memory_thread_done_event = threading.Event()self._data_queue = queue.Queue()pin_memory_thread = threading.Thread(target=_utils.pin_memory._pin_memory_loop,args=(self._worker_result_queue, self._data_queue,torch.cuda.current_device(),self._pin_memory_thread_done_event))pin_memory_thread.daemon = Truepin_memory_thread.start()# Similar to workers (see comment above), we only register# pin_memory_thread once it is started.self._pin_memory_thread = pin_memory_threadelse:self._data_queue = self._worker_result_queue_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))_utils.signal_handling._set_SIGCHLD_handler()self._worker_pids_set = True# prime the prefetch loopfor _ in range(2 * self._num_workers):self._try_put_index()def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):try:data = self._data_queue.get(timeout=timeout)return (True, data)except Exception as e:failed_workers = []for worker_id, w in enumerate(self._workers):if self._workers_status[worker_id] and not w.is_alive():failed_workers.append(w)self._shutdown_worker(worker_id)if len(failed_workers) > 0:pids_str = ', '.join(str(w.pid) for w in failed_workers)raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))if isinstance(e, queue.Empty):return (False, None)raisedef _get_data(self):self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)if self._timeout > 0:success, data = self._try_get_data(self._timeout)if success:return dataelse:raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))elif self._pin_memory:while self._pin_memory_thread.is_alive():success, data = self._try_get_data()if success:return dataelse:# while condition is false, i.e., pin_memory_thread died.raise RuntimeError('Pin memory thread exited unexpectedly')# In this case, `self._data_queue` is a `queue.Queue`,. But we don't# need to call `.task_done()` because we don't use `.join()`.else:while True:success, data = self._try_get_data()if success:return datadef __next__(self):while True:while self._rcvd_idx < self._send_idx:info = self._task_info[self._rcvd_idx]worker_id = info[0]if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still activebreakdel self._task_info[self._rcvd_idx]self._rcvd_idx += 1else:# no valid `self._rcvd_idx` is found (i.e., didn't break)self._shutdown_workers()raise StopIterationif len(self._task_info[self._rcvd_idx]) == 2:data = self._task_info.pop(self._rcvd_idx)[1]return self._process_data(data)assert not self._shutdown and self._tasks_outstanding > 0idx, data = self._get_data()self._tasks_outstanding -= 1if self._dataset_kind == _DatasetKind.Iterable:# Check for _IterableDatasetStopIterationif isinstance(data, _utils.worker._IterableDatasetStopIteration):self._shutdown_worker(data.worker_id)self._try_put_index()continueif idx != self._rcvd_idx:# store out-of-order samplesself._task_info[idx] += (data,)else:del self._task_info[idx]return self._process_data(data)next = __next__  # Python 2 compatibilitydef _try_put_index(self):assert self._tasks_outstanding < 2 * self._num_workerstry:index = self._next_index()except StopIteration:returnfor _ in range(self._num_workers):  # find the next active worker, if anyworker_queue_idx = next(self._worker_queue_idx_cycle)if self._workers_status[worker_queue_idx]:breakelse:# not found (i.e., didn't break)returnself._index_queues[worker_queue_idx].put((self._send_idx, index))self._task_info[self._send_idx] = (worker_queue_idx,)self._tasks_outstanding += 1self._send_idx += 1def _process_data(self, data):self._rcvd_idx += 1self._try_put_index()if isinstance(data, ExceptionWrapper):data.reraise()return datadef _shutdown_worker(self, worker_id):assert self._workers_status[worker_id]q = self._index_queues[worker_id]q.put(None)self._workers_status[worker_id] = Falsedef _shutdown_workers(self):python_exit_status = _utils.python_exit_statusif python_exit_status is True or python_exit_status is None:# See (2) of the note. If Python is shutting down, do no-op.returnif not self._shutdown:self._shutdown = Truetry:if hasattr(self, '_pin_memory_thread'):# Use hasattr in case error happens before we set the attribute.self._pin_memory_thread_done_event.set()# Send something to pin_memory_thread in case it is waiting# so that it can wake up and check `pin_memory_thread_done_event`self._worker_result_queue.put((None, None))self._pin_memory_thread.join()self._worker_result_queue.close()# Exit workers now.self._workers_done_event.set()for worker_id in range(len(self._workers)):if self._workers_status[worker_id]:self._shutdown_worker(worker_id)for w in self._workers:w.join()for q in self._index_queues:q.cancel_join_thread()q.close()finally:if self._worker_pids_set:_utils.signal_handling._remove_worker_pids(id(self))self._worker_pids_set = Falsedef __del__(self):self._shutdown_workers()

找到我们的__next__()函数,也就是这个类的输出,我们看下输出的是

return self._process_data(data)

再进入self._process_data(data)

def _process_data(self, data):self._rcvd_idx += 1self._try_put_index()if isinstance(data, ExceptionWrapper):data.reraise()return data

发现了我们的data还是data,没有变,所以我们需要找到这个data值的来源,在输出这个地方加入我们需要自己传出去的值

当然,不要在源码上修改,我们copy一份文件dataloader_util 放在自己项目合适的位置,然后修改代码传出自己需要的数据。

我们刚才已经找到了我们的打乱的batch集合,也找到了我们数据出口。

接下去我们讲batch一起传出来即可

1.修改import

from torch.utils.data.sampler import Sampler, SequentialSampler, RandomSampler, BatchSampler
from torch.utils.data.dataset import IterableDataset

2.在__init__里面加上代码用来保存我们的batch:

# 新增数据下标self._index_list = []

3.在下面函数的最后一行赋值

    def _try_put_index(self):assert self._tasks_outstanding < 2 * self._num_workerstry:index = self._next_index()except StopIteration:returnfor _ in range(self._num_workers):  # find the next active worker, if anyworker_queue_idx = next(self._worker_queue_idx_cycle)if self._workers_status[worker_queue_idx]:breakelse:# not found (i.e., didn't break)returnself._index_queues[worker_queue_idx].put((self._send_idx, index))self._task_info[self._send_idx] = (worker_queue_idx,)self._tasks_outstanding += 1self._send_idx += 1self._index_list = index

4.在下面函数中将返回值加上

    def _process_data(self, data):self._rcvd_idx += 1self._try_put_index()if isinstance(data, ExceptionWrapper):data.reraise()return data, self._index_list

5.在训练的时候修改下获取数据的代码

    for i, (train_images, index_) in enumerate(train_loader):images, labels, training_mask = train_imagesshuffle_image_path = [train_loader.dataset.data_list[x][0] for x in index_]

记得调用这个train_loader 的路径改成自己copy出来的,现在这个shuffle_image_path 就是我要的数据了。可以愉快的打印保存了。

是不是很简单呢?其他的获取自己相应要修改的数据,也可以这么获取

这篇关于Pytorch.Dataloader 详细深度解读和微修改源代码心得的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1120196

相关文章

SpringCloud动态配置注解@RefreshScope与@Component的深度解析

《SpringCloud动态配置注解@RefreshScope与@Component的深度解析》在现代微服务架构中,动态配置管理是一个关键需求,本文将为大家介绍SpringCloud中相关的注解@Re... 目录引言1. @RefreshScope 的作用与原理1.1 什么是 @RefreshScope1.

Java使用Curator进行ZooKeeper操作的详细教程

《Java使用Curator进行ZooKeeper操作的详细教程》ApacheCurator是一个基于ZooKeeper的Java客户端库,它极大地简化了使用ZooKeeper的开发工作,在分布式系统... 目录1、简述2、核心功能2.1 CuratorFramework2.2 Recipes3、示例实践3

java之Objects.nonNull用法代码解读

《java之Objects.nonNull用法代码解读》:本文主要介绍java之Objects.nonNull用法代码,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录Java之Objects.nonwww.chinasem.cnNull用法代码Objects.nonN

Python实现无痛修改第三方库源码的方法详解

《Python实现无痛修改第三方库源码的方法详解》很多时候,我们下载的第三方库是不会有需求不满足的情况,但也有极少的情况,第三方库没有兼顾到需求,本文将介绍几个修改源码的操作,大家可以根据需求进行选择... 目录需求不符合模拟示例 1. 修改源文件2. 继承修改3. 猴子补丁4. 追踪局部变量需求不符合很

Python 中的异步与同步深度解析(实践记录)

《Python中的异步与同步深度解析(实践记录)》在Python编程世界里,异步和同步的概念是理解程序执行流程和性能优化的关键,这篇文章将带你深入了解它们的差异,以及阻塞和非阻塞的特性,同时通过实际... 目录python中的异步与同步:深度解析与实践异步与同步的定义异步同步阻塞与非阻塞的概念阻塞非阻塞同步

通过Docker Compose部署MySQL的详细教程

《通过DockerCompose部署MySQL的详细教程》DockerCompose作为Docker官方的容器编排工具,为MySQL数据库部署带来了显著优势,下面小编就来为大家详细介绍一... 目录一、docker Compose 部署 mysql 的优势二、环境准备与基础配置2.1 项目目录结构2.2 基

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

SpringCloud负载均衡spring-cloud-starter-loadbalancer解读

《SpringCloud负载均衡spring-cloud-starter-loadbalancer解读》:本文主要介绍SpringCloud负载均衡spring-cloud-starter-loa... 目录简述主要特点使用负载均衡算法1. 轮询负载均衡策略(Round Robin)2. 随机负载均衡策略(

Linux修改pip和conda缓存路径的几种方法

《Linux修改pip和conda缓存路径的几种方法》在Python生态中,pip和conda是两种常见的软件包管理工具,它们在安装、更新和卸载软件包时都会使用缓存来提高效率,适当地修改它们的缓存路径... 目录一、pip 和 conda 的缓存机制1. pip 的缓存机制默认缓存路径2. conda 的缓

Linux修改pip临时目录方法的详解

《Linux修改pip临时目录方法的详解》在Linux系统中,pip在安装Python包时会使用临时目录(TMPDIR),但默认的临时目录可能会受到存储空间不足或权限问题的影响,所以本文将详细介绍如何... 目录引言一、为什么要修改 pip 的临时目录?1. 解决存储空间不足的问题2. 解决权限问题3. 提