Pytorch.Dataloader 详细深度解读和微修改源代码心得

2024-08-30 07:38

本文主要是介绍Pytorch.Dataloader 详细深度解读和微修改源代码心得,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

关于pytorch 的dataloader库,使用pytorch 基本都会用到的一个库

今天遇到了一个问题,我在训练的时候,采用batch_size =2 去训练,最终的loss抖动太大了,看得出来应该是某些样本在打标的时候打的不好导致的,需要找出这些样本重新修正。但是一开始是采用的dataloader默认库。然后输入进去的图像dataset 传出来之后是经过shuffle的,没有办法定位到哪张图片。PS(如果有知道的大佬知道简单的直接调用的方法,可以评论学习下,万分感谢)

如果全部自己写的话,工作量就比较大了,测试流程较长。所以想是不是可以简单的修改源码来达到目的。

研究了一下源码,还是简单的。

现在将整个过程分享下:

首先,来看看dataloader这个库到底是干了什么事情

1.先定位到源码目录

在torch.utils.data下有2个文件是我们目前需要的,一个dataloader.py,这个文件封装了我们调用的dataloader类,返回的是一个迭代对象,也就是我们网络的输入。另一个是

sampler,实现了拆分输入数据为多个batch,在dataloader内会有关键的调用。了解了两个文件大致目的,我们详细解读下。来看看我们今天的主题:

class DataLoader(object):   __initialized = Falsedef __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None, multiprocessing_context=None):torch._C._log_api_usage_once("python.data_loader")if num_workers < 0:raise ValueError('num_workers option should be non-negative; ''use num_workers=0 to disable multiprocessing.')if timeout < 0:raise ValueError('timeout option should be non-negative')self.dataset = datasetself.num_workers = num_workersself.pin_memory = pin_memoryself.timeout = timeoutself.worker_init_fn = worker_init_fnself.multiprocessing_context = multiprocessing_contextif isinstance(dataset, IterableDataset):self._dataset_kind = _DatasetKind.Iterableif shuffle is not False:raise ValueError("DataLoader with IterableDataset: expected unspecified ""shuffle option, but got shuffle={}".format(shuffle))elif sampler is not None:# See NOTE [ Custom Samplers and IterableDataset ]raise ValueError("DataLoader with IterableDataset: expected unspecified ""sampler option, but got sampler={}".format(sampler))elif batch_sampler is not None:# See NOTE [ Custom Samplers and IterableDataset ]raise ValueError("DataLoader with IterableDataset: expected unspecified ""batch_sampler option, but got batch_sampler={}".format(batch_sampler))else:self._dataset_kind = _DatasetKind.Mapif sampler is not None and shuffle:raise ValueError('sampler option is mutually exclusive with ''shuffle')if batch_sampler is not None:# auto_collation with custom batch_samplerif batch_size != 1 or shuffle or sampler is not None or drop_last:raise ValueError('batch_sampler option is mutually exclusive ''with batch_size, shuffle, sampler, and ''drop_last')batch_size = Nonedrop_last = Falseelif batch_size is None:# no auto_collationif shuffle or drop_last:raise ValueError('batch_size=None option disables auto-batching ''and is mutually exclusive with ''shuffle, and drop_last')if sampler is None:  # give default samplersif self._dataset_kind == _DatasetKind.Iterable:# See NOTE [ Custom Samplers and IterableDataset ]sampler = _InfiniteConstantSampler()else:  # map-styleif shuffle:sampler = RandomSampler(dataset)else:sampler = SequentialSampler(dataset)if batch_size is not None and batch_sampler is None:# auto_collation without custom batch_samplerbatch_sampler = BatchSampler(sampler, batch_size, drop_last)self.batch_size = batch_sizeself.drop_last = drop_lastself.sampler = samplerself.batch_sampler = batch_samplerif collate_fn is None:if self._auto_collation:collate_fn = _utils.collate.default_collateelse:collate_fn = _utils.collate.default_convertself.collate_fn = collate_fnself.__initialized = True

__init__()函数里面有个重点,在

if batch_size is not None and batch_sampler is None:# auto_collation without custom batch_samplerbatch_sampler = BatchSampler(sampler, batch_size, drop_last)

这个函数就得到了打乱顺序的batch的index,以list的形式存放,比如[12,242,456,13]

看下这个class,

class BatchSampler(Sampler):r"""Wraps another sampler to yield a mini-batch of indices.Args:sampler (Sampler): Base sampler.batch_size (int): Size of mini-batch.drop_last (bool): If ``True``, the sampler will drop the last batch ifits size would be less than ``batch_size``Example:>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))[[0, 1, 2], [3, 4, 5], [6, 7, 8]]"""def __init__(self, sampler, batch_size, drop_last):if not isinstance(sampler, Sampler):raise ValueError("sampler should be an instance of ""torch.utils.data.Sampler, but got sampler={}".format(sampler))if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \batch_size <= 0:raise ValueError("batch_size should be a positive integer value, ""but got batch_size={}".format(batch_size))if not isinstance(drop_last, bool):raise ValueError("drop_last should be a boolean value, but got ""drop_last={}".format(drop_last))self.sampler = samplerself.batch_size = batch_sizeself.drop_last = drop_lastdef __iter__(self):batch = []for idx in self.sampler:batch.append(idx)if len(batch) == self.batch_size:yield batchbatch = []if len(batch) > 0 and not self.drop_last:yield batchdef __len__(self):if self.drop_last:return len(self.sampler) // self.batch_sizeelse:return (len(self.sampler) + self.batch_size - 1) // self.batch_size

在__iter__()里面实现了随机采样idx放入list中

 

下面看下dataloader中最重要的函数,也就是我们需要修改的地方

    def __iter__(self):if self.num_workers == 0:return _SingleProcessDataLoaderIter(self)else:return _MultiProcessingDataLoaderIter(self)

dataloarder通过这个函数返回迭代值,当我们的num_worker大于0的时候,也就是采用多进程方式读取数据。

进入下面这个_MultiProcessingDataLoaderIter(self) 函数

class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):def __init__(self, loader):super(_MultiProcessingDataLoaderIter, self).__init__(loader)assert self._num_workers > 0if loader.multiprocessing_context is None:multiprocessing_context = multiprocessingelse:multiprocessing_context = loader.multiprocessing_contextself._worker_init_fn = loader.worker_init_fnself._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))self._worker_result_queue = multiprocessing_context.Queue()self._worker_pids_set = Falseself._shutdown = Falseself._send_idx = 0  # idx of the next task to be sent to workersself._rcvd_idx = 0  # idx of the next task to be returned in __next__self._task_info = {}self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)self._workers_done_event = multiprocessing_context.Event()self._index_queues = []self._workers = []self._workers_status = []for i in range(self._num_workers):index_queue = multiprocessing_context.Queue()# index_queue.cancel_join_thread()w = multiprocessing_context.Process(target=_utils.worker._worker_loop,args=(self._dataset_kind, self._dataset, index_queue,self._worker_result_queue, self._workers_done_event,self._auto_collation, self._collate_fn, self._drop_last,self._base_seed + i, self._worker_init_fn, i, self._num_workers))w.daemon = Truew.start()self._index_queues.append(index_queue)self._workers.append(w)self._workers_status.append(True)if self._pin_memory:self._pin_memory_thread_done_event = threading.Event()self._data_queue = queue.Queue()pin_memory_thread = threading.Thread(target=_utils.pin_memory._pin_memory_loop,args=(self._worker_result_queue, self._data_queue,torch.cuda.current_device(),self._pin_memory_thread_done_event))pin_memory_thread.daemon = Truepin_memory_thread.start()# Similar to workers (see comment above), we only register# pin_memory_thread once it is started.self._pin_memory_thread = pin_memory_threadelse:self._data_queue = self._worker_result_queue_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))_utils.signal_handling._set_SIGCHLD_handler()self._worker_pids_set = True# prime the prefetch loopfor _ in range(2 * self._num_workers):self._try_put_index()def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):try:data = self._data_queue.get(timeout=timeout)return (True, data)except Exception as e:failed_workers = []for worker_id, w in enumerate(self._workers):if self._workers_status[worker_id] and not w.is_alive():failed_workers.append(w)self._shutdown_worker(worker_id)if len(failed_workers) > 0:pids_str = ', '.join(str(w.pid) for w in failed_workers)raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))if isinstance(e, queue.Empty):return (False, None)raisedef _get_data(self):self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)if self._timeout > 0:success, data = self._try_get_data(self._timeout)if success:return dataelse:raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))elif self._pin_memory:while self._pin_memory_thread.is_alive():success, data = self._try_get_data()if success:return dataelse:# while condition is false, i.e., pin_memory_thread died.raise RuntimeError('Pin memory thread exited unexpectedly')# In this case, `self._data_queue` is a `queue.Queue`,. But we don't# need to call `.task_done()` because we don't use `.join()`.else:while True:success, data = self._try_get_data()if success:return datadef __next__(self):while True:while self._rcvd_idx < self._send_idx:info = self._task_info[self._rcvd_idx]worker_id = info[0]if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still activebreakdel self._task_info[self._rcvd_idx]self._rcvd_idx += 1else:# no valid `self._rcvd_idx` is found (i.e., didn't break)self._shutdown_workers()raise StopIterationif len(self._task_info[self._rcvd_idx]) == 2:data = self._task_info.pop(self._rcvd_idx)[1]return self._process_data(data)assert not self._shutdown and self._tasks_outstanding > 0idx, data = self._get_data()self._tasks_outstanding -= 1if self._dataset_kind == _DatasetKind.Iterable:# Check for _IterableDatasetStopIterationif isinstance(data, _utils.worker._IterableDatasetStopIteration):self._shutdown_worker(data.worker_id)self._try_put_index()continueif idx != self._rcvd_idx:# store out-of-order samplesself._task_info[idx] += (data,)else:del self._task_info[idx]return self._process_data(data)next = __next__  # Python 2 compatibilitydef _try_put_index(self):assert self._tasks_outstanding < 2 * self._num_workerstry:index = self._next_index()except StopIteration:returnfor _ in range(self._num_workers):  # find the next active worker, if anyworker_queue_idx = next(self._worker_queue_idx_cycle)if self._workers_status[worker_queue_idx]:breakelse:# not found (i.e., didn't break)returnself._index_queues[worker_queue_idx].put((self._send_idx, index))self._task_info[self._send_idx] = (worker_queue_idx,)self._tasks_outstanding += 1self._send_idx += 1def _process_data(self, data):self._rcvd_idx += 1self._try_put_index()if isinstance(data, ExceptionWrapper):data.reraise()return datadef _shutdown_worker(self, worker_id):assert self._workers_status[worker_id]q = self._index_queues[worker_id]q.put(None)self._workers_status[worker_id] = Falsedef _shutdown_workers(self):python_exit_status = _utils.python_exit_statusif python_exit_status is True or python_exit_status is None:# See (2) of the note. If Python is shutting down, do no-op.returnif not self._shutdown:self._shutdown = Truetry:if hasattr(self, '_pin_memory_thread'):# Use hasattr in case error happens before we set the attribute.self._pin_memory_thread_done_event.set()# Send something to pin_memory_thread in case it is waiting# so that it can wake up and check `pin_memory_thread_done_event`self._worker_result_queue.put((None, None))self._pin_memory_thread.join()self._worker_result_queue.close()# Exit workers now.self._workers_done_event.set()for worker_id in range(len(self._workers)):if self._workers_status[worker_id]:self._shutdown_worker(worker_id)for w in self._workers:w.join()for q in self._index_queues:q.cancel_join_thread()q.close()finally:if self._worker_pids_set:_utils.signal_handling._remove_worker_pids(id(self))self._worker_pids_set = Falsedef __del__(self):self._shutdown_workers()

找到我们的__next__()函数,也就是这个类的输出,我们看下输出的是

return self._process_data(data)

再进入self._process_data(data)

def _process_data(self, data):self._rcvd_idx += 1self._try_put_index()if isinstance(data, ExceptionWrapper):data.reraise()return data

发现了我们的data还是data,没有变,所以我们需要找到这个data值的来源,在输出这个地方加入我们需要自己传出去的值

当然,不要在源码上修改,我们copy一份文件dataloader_util 放在自己项目合适的位置,然后修改代码传出自己需要的数据。

我们刚才已经找到了我们的打乱的batch集合,也找到了我们数据出口。

接下去我们讲batch一起传出来即可

1.修改import

from torch.utils.data.sampler import Sampler, SequentialSampler, RandomSampler, BatchSampler
from torch.utils.data.dataset import IterableDataset

2.在__init__里面加上代码用来保存我们的batch:

# 新增数据下标self._index_list = []

3.在下面函数的最后一行赋值

    def _try_put_index(self):assert self._tasks_outstanding < 2 * self._num_workerstry:index = self._next_index()except StopIteration:returnfor _ in range(self._num_workers):  # find the next active worker, if anyworker_queue_idx = next(self._worker_queue_idx_cycle)if self._workers_status[worker_queue_idx]:breakelse:# not found (i.e., didn't break)returnself._index_queues[worker_queue_idx].put((self._send_idx, index))self._task_info[self._send_idx] = (worker_queue_idx,)self._tasks_outstanding += 1self._send_idx += 1self._index_list = index

4.在下面函数中将返回值加上

    def _process_data(self, data):self._rcvd_idx += 1self._try_put_index()if isinstance(data, ExceptionWrapper):data.reraise()return data, self._index_list

5.在训练的时候修改下获取数据的代码

    for i, (train_images, index_) in enumerate(train_loader):images, labels, training_mask = train_imagesshuffle_image_path = [train_loader.dataset.data_list[x][0] for x in index_]

记得调用这个train_loader 的路径改成自己copy出来的,现在这个shuffle_image_path 就是我要的数据了。可以愉快的打印保存了。

是不是很简单呢?其他的获取自己相应要修改的数据,也可以这么获取

这篇关于Pytorch.Dataloader 详细深度解读和微修改源代码心得的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1120196

相关文章

Java调用DeepSeek API的最佳实践及详细代码示例

《Java调用DeepSeekAPI的最佳实践及详细代码示例》:本文主要介绍如何使用Java调用DeepSeekAPI,包括获取API密钥、添加HTTP客户端依赖、创建HTTP请求、处理响应、... 目录1. 获取API密钥2. 添加HTTP客户端依赖3. 创建HTTP请求4. 处理响应5. 错误处理6.

Spring AI集成DeepSeek的详细步骤

《SpringAI集成DeepSeek的详细步骤》DeepSeek作为一款卓越的国产AI模型,越来越多的公司考虑在自己的应用中集成,对于Java应用来说,我们可以借助SpringAI集成DeepSe... 目录DeepSeek 介绍Spring AI 是什么?1、环境准备2、构建项目2.1、pom依赖2.2

Goland debug失效详细解决步骤(合集)

《Golanddebug失效详细解决步骤(合集)》今天用Goland开发时,打断点,以debug方式运行,发现程序并没有断住,程序跳过了断点,直接运行结束,网上搜寻了大量文章,最后得以解决,特此在这... 目录Bug:Goland debug失效详细解决步骤【合集】情况一:Go或Goland架构不对情况二:

Python itertools中accumulate函数用法及使用运用详细讲解

《Pythonitertools中accumulate函数用法及使用运用详细讲解》:本文主要介绍Python的itertools库中的accumulate函数,该函数可以计算累积和或通过指定函数... 目录1.1前言:1.2定义:1.3衍生用法:1.3Leetcode的实际运用:总结 1.1前言:本文将详

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

最长公共子序列问题的深度分析与Java实现方式

《最长公共子序列问题的深度分析与Java实现方式》本文详细介绍了最长公共子序列(LCS)问题,包括其概念、暴力解法、动态规划解法,并提供了Java代码实现,暴力解法虽然简单,但在大数据处理中效率较低,... 目录最长公共子序列问题概述问题理解与示例分析暴力解法思路与示例代码动态规划解法DP 表的构建与意义动

Spring Boot整合log4j2日志配置的详细教程

《SpringBoot整合log4j2日志配置的详细教程》:本文主要介绍SpringBoot项目中整合Log4j2日志框架的步骤和配置,包括常用日志框架的比较、配置参数介绍、Log4j2配置详解... 目录前言一、常用日志框架二、配置参数介绍1. 日志级别2. 输出形式3. 日志格式3.1 PatternL

Springboot 中使用Sentinel的详细步骤

《Springboot中使用Sentinel的详细步骤》文章介绍了如何在SpringBoot中使用Sentinel进行限流和熔断降级,首先添加依赖,配置Sentinel控制台地址,定义受保护的资源,... 目录步骤 1: 添加 Sentinel 依赖步骤 2: 配置 Sentinel步骤 3: 定义受保护的

修改若依框架Token的过期时间问题

《修改若依框架Token的过期时间问题》本文介绍了如何修改若依框架中Token的过期时间,通过修改`application.yml`文件中的配置来实现,默认单位为分钟,希望此经验对大家有所帮助,也欢迎... 目录修改若依框架Token的过期时间修改Token的过期时间关闭Token的过期时js间总结修改若依