PyTorch DataLoader: An In-Depth Source-Code Walkthrough and Notes on Small Modifications

2024-08-30 07:38

This article walks through PyTorch's DataLoader in depth and shares notes on making small modifications to its source code. Hopefully it offers a useful reference to developers facing the same problem.

PyTorch's DataLoader is a module that almost every PyTorch project ends up using.

Today I ran into a problem. I was training with batch_size=2 and the loss fluctuated far too much; it looked like certain samples had been labeled badly and needed to be found and re-annotated. But I was using the stock DataLoader, and the batches coming out of the dataset are shuffled, so there was no way to trace a batch back to its source images. (PS: if anyone knows a simpler, directly callable way to do this, please leave a comment — much appreciated.)

Writing the whole pipeline myself would be a lot of work with a long test cycle, so I wondered whether a small modification to the source code could achieve the same goal.

After studying the source for a while, it turned out to be simple.

Here is the whole process:

First, let's see what the DataLoader module actually does.

1. Locate the source files
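A quick way to find the two files we need (dataloader.py and sampler.py) on disk is to print their modules' __file__ attributes:

import torch.utils.data.dataloader
import torch.utils.data.sampler

# every imported Python module exposes its on-disk path via __file__
print(torch.utils.data.dataloader.__file__)
print(torch.utils.data.sampler.__file__)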

Under torch.utils.data there are two files we currently need. One is dataloader.py, which defines the DataLoader class we call; it returns an iterable object whose output feeds our network. The other is sampler.py, which splits the input indices into batches and is invoked at key points inside dataloader.py. With the purpose of the two files clear, let's read them in detail, starting with today's main subject:

class DataLoader(object):
    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None):
        torch._C._log_api_usage_once("python.data_loader")

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        self.dataset = dataset
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
        else:
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if shuffle or drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with '
                                 'shuffle, and drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.__initialized = True

There is one key point inside __init__():

if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)

This call is what produces the shuffled index set for each batch, stored as a list, e.g. [12, 242, 456, 13].

Here is the class:

class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

__iter__() is where the indices drawn from the underlying sampler (a random one, when shuffle=True) are collected into a list per batch.
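A quick check of this behavior with a shuffled base sampler (the exact output varies per run, since RandomSampler permutes the indices):

from torch.utils.data.sampler import RandomSampler, BatchSampler

# with shuffle=True the DataLoader wraps a RandomSampler, so each batch's
# indices come out permuted, e.g. [[7, 2, 9, 0], [5, 1, 3, 8], [6, 4]]
print(list(BatchSampler(RandomSampler(range(10)), batch_size=4, drop_last=False)))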

 

Now for the most important function in dataloader.py for our purposes — the one we need to modify:

    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

The DataLoader hands out its iterator through this function; when num_workers is greater than 0, data is read with multiple worker processes.
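A tiny check of which iterator class each setting produces (TensorDataset here is just a stand-in dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == '__main__':  # worker processes need the main-module guard on spawn platforms
    ds = TensorDataset(torch.arange(10).float())
    single = DataLoader(ds, batch_size=2)                # num_workers=0
    multi = DataLoader(ds, batch_size=2, num_workers=2)  # num_workers>0
    # prints _SingleProcessDataLoaderIter, then _MultiProcessingDataLoaderIter
    print(type(iter(single)).__name__)
    print(type(iter(multi)).__name__)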

Step into the _MultiProcessingDataLoaderIter class:

class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):

    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)

        assert self._num_workers > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn
        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
        self._worker_result_queue = multiprocessing_context.Queue()
        self._worker_pids_set = False
        self._shutdown = False
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        self._task_info = {}
        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        self._workers_status = []
        for i in range(self._num_workers):
            index_queue = multiprocessing_context.Queue()
            # index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed + i, self._worker_init_fn, i, self._num_workers))
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)
            self._workers_status.append(True)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()
            self._data_queue = queue.Queue()
            pin_memory_thread = threading.Thread(
                target=_utils.pin_memory._pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue

        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True

        # prime the prefetch loop
        for _ in range(2 * self._num_workers):
            self._try_put_index()

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            failed_workers = []
            for worker_id, w in enumerate(self._workers):
                if self._workers_status[worker_id] and not w.is_alive():
                    failed_workers.append(w)
                    self._shutdown_worker(worker_id)
            if len(failed_workers) > 0:
                pids_str = ', '.join(str(w.pid) for w in failed_workers)
                raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
            if isinstance(e, queue.Empty):
                return (False, None)
            raise

    def _get_data(self):
        # Fetches data from `self._data_queue`, checking worker status by
        # calling `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` in a loop.
        if self._timeout > 0:
            success, data = self._try_get_data(self._timeout)
            if success:
                return data
            else:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
        elif self._pin_memory:
            while self._pin_memory_thread.is_alive():
                success, data = self._try_get_data()
                if success:
                    return data
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
            # In this case, `self._data_queue` is a `queue.Queue`. But we don't
            # need to call `.task_done()` because we don't use `.join()`.
        else:
            while True:
                success, data = self._try_get_data()
                if success:
                    return data

    def __next__(self):
        while True:
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                self._shutdown_workers()
                raise StopIteration

            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data()
            self._tasks_outstanding -= 1

            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    self._shutdown_worker(data.worker_id)
                    self._try_put_index()
                    continue

            if idx != self._rcvd_idx:
                # store out-of-order samples
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]
                return self._process_data(data)

    next = __next__  # Python 2 compatibility

    def _try_put_index(self):
        assert self._tasks_outstanding < 2 * self._num_workers
        try:
            index = self._next_index()
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return

        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._tasks_outstanding += 1
        self._send_idx += 1

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data

    def _shutdown_worker(self, worker_id):
        assert self._workers_status[worker_id]
        q = self._index_queues[worker_id]
        q.put(None)
        self._workers_status[worker_id] = False

    def _shutdown_workers(self):
        python_exit_status = _utils.python_exit_status
        if python_exit_status is True or python_exit_status is None:
            # See (2) of the note. If Python is shutting down, do no-op.
            return
        if not self._shutdown:
            self._shutdown = True
            try:
                if hasattr(self, '_pin_memory_thread'):
                    # Use hasattr in case error happens before we set the attribute.
                    self._pin_memory_thread_done_event.set()
                    # Send something to pin_memory_thread in case it is waiting
                    # so that it can wake up and check `pin_memory_thread_done_event`
                    self._worker_result_queue.put((None, None))
                    self._pin_memory_thread.join()
                    self._worker_result_queue.close()

                # Exit workers now.
                self._workers_done_event.set()
                for worker_id in range(len(self._workers)):
                    if self._workers_status[worker_id]:
                        self._shutdown_worker(worker_id)
                for w in self._workers:
                    w.join()
                for q in self._index_queues:
                    q.cancel_join_thread()
                    q.close()
            finally:
                if self._worker_pids_set:
                    _utils.signal_handling._remove_worker_pids(id(self))
                    self._worker_pids_set = False

    def __del__(self):
        self._shutdown_workers()

Find the __next__() function — the output path of this class. What it returns is:

return self._process_data(data)

Then step into self._process_data(data):

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data

Here data is returned exactly as it came in, unchanged. So we need to trace where this data value originates, and at this exit point attach the extra value we want to pass out.

Of course, don't edit the installed source directly. Copy the file into a suitable place in your own project as dataloader_util, then modify the copy to export the extra data.
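For example (a hypothetical sketch — train_dataset stands for whatever map-style dataset the project already uses):

# import the patched DataLoader from the local copy instead of torch.utils.data
from dataloader_util import DataLoader

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4)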

We have now located both the shuffled batch indices and the data exit point.

All that remains is to pass the batch indices out together with the data.

1. Update the imports (the original dataloader.py uses relative imports such as "from . import ..."; in the standalone copy they must be rewritten as absolute imports, for example):

from torch.utils.data.sampler import Sampler, SequentialSampler, RandomSampler, BatchSampler
from torch.utils.data.dataset import IterableDataset

2. In _MultiProcessingDataLoaderIter.__init__, add a field to hold the current batch's indices:

        # new: holds the indices of the current batch
        self._index_list = []

3. Assign the saved indices in the last line of the following function:

    def _try_put_index(self):
        assert self._tasks_outstanding < 2 * self._num_workers
        try:
            index = self._next_index()
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return

        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._tasks_outstanding += 1
        self._send_idx += 1
        self._index_list = index  # new: save the indices of the batch just dispatched

4. Add the saved indices to the return value of the following function:

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data, self._index_list  # new: return the saved indices as well
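One caveat worth flagging: _try_put_index is also used for prefetching (the loader primes 2 * num_workers tasks in __init__), so a single self._index_list holds the indices of the most recently dispatched batch, which can run ahead of the batch actually being returned. If the recovered paths ever look offset, a variant that keys the saved indices by task id keeps them aligned — a sketch, where _index_map is a hypothetical dict of my own (not part of PyTorch) initialized to {} in __init__:

    # in __init__ (hypothetical addition, replacing self._index_list):
    #     self._index_map = {}

    # in _try_put_index, next to self._task_info[self._send_idx] = ...:
    #     self._index_map[self._send_idx] = index

    def _process_data(self, data):
        # pop the indices recorded for exactly this task id, so they always
        # match the batch being returned, even with prefetching in flight
        batch_indices = self._index_map.pop(self._rcvd_idx)
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data, batch_indices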

5. In the training loop, adjust the code that unpacks each batch:

    for i, (train_images, index_) in enumerate(train_loader):
        images, labels, training_mask = train_images
        # data_list is specific to this project's dataset: entry [x][0] is the image path
        shuffle_image_path = [train_loader.dataset.data_list[x][0] for x in index_]

Remember to import the DataLoader used to build train_loader from the copied file rather than from torch.utils.data. Now shuffle_image_path holds exactly the paths I wanted, ready to be printed and saved.
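To close the loop on the original problem — finding badly labeled samples — the recovered paths can be logged whenever a batch's loss spikes. A rough sketch, where model, criterion, and LOSS_THRESHOLD stand in for your own training setup:

    suspicious_paths = []
    for i, (train_images, index_) in enumerate(train_loader):
        images, labels, training_mask = train_images
        loss = criterion(model(images), labels)   # the usual forward pass + loss
        if loss.item() > LOSS_THRESHOLD:          # hypothetical cutoff for loss jitter
            # map the shuffled batch indices back to source image paths
            suspicious_paths.extend(train_loader.dataset.data_list[x][0] for x in index_)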

Simple, isn't it? Any other internal value you need can be surfaced the same way.

That concludes this walkthrough of PyTorch's DataLoader and the small source-code modifications. I hope it proves helpful.


