Pytorch.Dataloader 详细深度解读和微修改源代码心得

2024-08-30 07:38

本文主要是介绍Pytorch.Dataloader 详细深度解读和微修改源代码心得,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

关于pytorch 的dataloader库,使用pytorch 基本都会用到的一个库

今天遇到了一个问题,我在训练的时候,采用batch_size =2 去训练,最终的loss抖动太大了,看得出来应该是某些样本在打标的时候打的不好导致的,需要找出这些样本重新修正。但是一开始是采用的dataloader默认库。然后输入进去的图像dataset 传出来之后是经过shuffle的,没有办法定位到哪张图片。PS(如果有知道的大佬知道简单的直接调用的方法,可以评论学习下,万分感谢)

如果全部自己写的话,工作量就比较大了,测试流程较长。所以想是不是可以简单的修改源码来达到目的。

研究了一下源码,还是简单的。

现在将整个过程分享下:

首先,来看看dataloader这个库到底是干了什么事情

1.先定位到源码目录

在torch.utils.data下有2个文件是我们目前需要的,一个dataloader.py,这个文件封装了我们调用的dataloader类,返回的是一个迭代对象,也就是我们网络的输入。另一个是

sampler,实现了拆分输入数据为多个batch,在dataloader内会有关键的调用。了解了两个文件大致目的,我们详细解读下。来看看我们今天的主题:

class DataLoader(object):   __initialized = Falsedef __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None, multiprocessing_context=None):torch._C._log_api_usage_once("python.data_loader")if num_workers < 0:raise ValueError('num_workers option should be non-negative; ''use num_workers=0 to disable multiprocessing.')if timeout < 0:raise ValueError('timeout option should be non-negative')self.dataset = datasetself.num_workers = num_workersself.pin_memory = pin_memoryself.timeout = timeoutself.worker_init_fn = worker_init_fnself.multiprocessing_context = multiprocessing_contextif isinstance(dataset, IterableDataset):self._dataset_kind = _DatasetKind.Iterableif shuffle is not False:raise ValueError("DataLoader with IterableDataset: expected unspecified ""shuffle option, but got shuffle={}".format(shuffle))elif sampler is not None:# See NOTE [ Custom Samplers and IterableDataset ]raise ValueError("DataLoader with IterableDataset: expected unspecified ""sampler option, but got sampler={}".format(sampler))elif batch_sampler is not None:# See NOTE [ Custom Samplers and IterableDataset ]raise ValueError("DataLoader with IterableDataset: expected unspecified ""batch_sampler option, but got batch_sampler={}".format(batch_sampler))else:self._dataset_kind = _DatasetKind.Mapif sampler is not None and shuffle:raise ValueError('sampler option is mutually exclusive with ''shuffle')if batch_sampler is not None:# auto_collation with custom batch_samplerif batch_size != 1 or shuffle or sampler is not None or drop_last:raise ValueError('batch_sampler option is mutually exclusive ''with batch_size, shuffle, sampler, and ''drop_last')batch_size = Nonedrop_last = Falseelif batch_size is None:# no auto_collationif shuffle or drop_last:raise ValueError('batch_size=None option disables auto-batching ''and is mutually exclusive with ''shuffle, and drop_last')if sampler is None:  # give default samplersif self._dataset_kind == _DatasetKind.Iterable:# See NOTE [ Custom Samplers and IterableDataset ]sampler = _InfiniteConstantSampler()else:  # map-styleif shuffle:sampler = RandomSampler(dataset)else:sampler = SequentialSampler(dataset)if batch_size is not None and batch_sampler is None:# auto_collation without custom batch_samplerbatch_sampler = BatchSampler(sampler, batch_size, drop_last)self.batch_size = batch_sizeself.drop_last = drop_lastself.sampler = samplerself.batch_sampler = batch_samplerif collate_fn is None:if self._auto_collation:collate_fn = _utils.collate.default_collateelse:collate_fn = _utils.collate.default_convertself.collate_fn = collate_fnself.__initialized = True

__init__()函数里面有个重点,在

if batch_size is not None and batch_sampler is None:# auto_collation without custom batch_samplerbatch_sampler = BatchSampler(sampler, batch_size, drop_last)

这个函数就得到了打乱顺序的batch的index,以list的形式存放,比如[12,242,456,13]

看下这个class,

class BatchSampler(Sampler):r"""Wraps another sampler to yield a mini-batch of indices.Args:sampler (Sampler): Base sampler.batch_size (int): Size of mini-batch.drop_last (bool): If ``True``, the sampler will drop the last batch ifits size would be less than ``batch_size``Example:>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))[[0, 1, 2], [3, 4, 5], [6, 7, 8]]"""def __init__(self, sampler, batch_size, drop_last):if not isinstance(sampler, Sampler):raise ValueError("sampler should be an instance of ""torch.utils.data.Sampler, but got sampler={}".format(sampler))if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \batch_size <= 0:raise ValueError("batch_size should be a positive integer value, ""but got batch_size={}".format(batch_size))if not isinstance(drop_last, bool):raise ValueError("drop_last should be a boolean value, but got ""drop_last={}".format(drop_last))self.sampler = samplerself.batch_size = batch_sizeself.drop_last = drop_lastdef __iter__(self):batch = []for idx in self.sampler:batch.append(idx)if len(batch) == self.batch_size:yield batchbatch = []if len(batch) > 0 and not self.drop_last:yield batchdef __len__(self):if self.drop_last:return len(self.sampler) // self.batch_sizeelse:return (len(self.sampler) + self.batch_size - 1) // self.batch_size

在__iter__()里面实现了随机采样idx放入list中

 

下面看下dataloader中最重要的函数,也就是我们需要修改的地方

    def __iter__(self):if self.num_workers == 0:return _SingleProcessDataLoaderIter(self)else:return _MultiProcessingDataLoaderIter(self)

dataloarder通过这个函数返回迭代值,当我们的num_worker大于0的时候,也就是采用多进程方式读取数据。

进入下面这个_MultiProcessingDataLoaderIter(self) 函数

class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):def __init__(self, loader):super(_MultiProcessingDataLoaderIter, self).__init__(loader)assert self._num_workers > 0if loader.multiprocessing_context is None:multiprocessing_context = multiprocessingelse:multiprocessing_context = loader.multiprocessing_contextself._worker_init_fn = loader.worker_init_fnself._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))self._worker_result_queue = multiprocessing_context.Queue()self._worker_pids_set = Falseself._shutdown = Falseself._send_idx = 0  # idx of the next task to be sent to workersself._rcvd_idx = 0  # idx of the next task to be returned in __next__self._task_info = {}self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)self._workers_done_event = multiprocessing_context.Event()self._index_queues = []self._workers = []self._workers_status = []for i in range(self._num_workers):index_queue = multiprocessing_context.Queue()# index_queue.cancel_join_thread()w = multiprocessing_context.Process(target=_utils.worker._worker_loop,args=(self._dataset_kind, self._dataset, index_queue,self._worker_result_queue, self._workers_done_event,self._auto_collation, self._collate_fn, self._drop_last,self._base_seed + i, self._worker_init_fn, i, self._num_workers))w.daemon = Truew.start()self._index_queues.append(index_queue)self._workers.append(w)self._workers_status.append(True)if self._pin_memory:self._pin_memory_thread_done_event = threading.Event()self._data_queue = queue.Queue()pin_memory_thread = threading.Thread(target=_utils.pin_memory._pin_memory_loop,args=(self._worker_result_queue, self._data_queue,torch.cuda.current_device(),self._pin_memory_thread_done_event))pin_memory_thread.daemon = Truepin_memory_thread.start()# Similar to workers (see comment above), we only register# pin_memory_thread once it is started.self._pin_memory_thread = pin_memory_threadelse:self._data_queue = self._worker_result_queue_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))_utils.signal_handling._set_SIGCHLD_handler()self._worker_pids_set = True# prime the prefetch loopfor _ in range(2 * self._num_workers):self._try_put_index()def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):try:data = self._data_queue.get(timeout=timeout)return (True, data)except Exception as e:failed_workers = []for worker_id, w in enumerate(self._workers):if self._workers_status[worker_id] and not w.is_alive():failed_workers.append(w)self._shutdown_worker(worker_id)if len(failed_workers) > 0:pids_str = ', '.join(str(w.pid) for w in failed_workers)raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))if isinstance(e, queue.Empty):return (False, None)raisedef _get_data(self):self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)if self._timeout > 0:success, data = self._try_get_data(self._timeout)if success:return dataelse:raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))elif self._pin_memory:while self._pin_memory_thread.is_alive():success, data = self._try_get_data()if success:return dataelse:# while condition is false, i.e., pin_memory_thread died.raise RuntimeError('Pin memory thread exited unexpectedly')# In this case, `self._data_queue` is a `queue.Queue`,. But we don't# need to call `.task_done()` because we don't use `.join()`.else:while True:success, data = self._try_get_data()if success:return datadef __next__(self):while True:while self._rcvd_idx < self._send_idx:info = self._task_info[self._rcvd_idx]worker_id = info[0]if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still activebreakdel self._task_info[self._rcvd_idx]self._rcvd_idx += 1else:# no valid `self._rcvd_idx` is found (i.e., didn't break)self._shutdown_workers()raise StopIterationif len(self._task_info[self._rcvd_idx]) == 2:data = self._task_info.pop(self._rcvd_idx)[1]return self._process_data(data)assert not self._shutdown and self._tasks_outstanding > 0idx, data = self._get_data()self._tasks_outstanding -= 1if self._dataset_kind == _DatasetKind.Iterable:# Check for _IterableDatasetStopIterationif isinstance(data, _utils.worker._IterableDatasetStopIteration):self._shutdown_worker(data.worker_id)self._try_put_index()continueif idx != self._rcvd_idx:# store out-of-order samplesself._task_info[idx] += (data,)else:del self._task_info[idx]return self._process_data(data)next = __next__  # Python 2 compatibilitydef _try_put_index(self):assert self._tasks_outstanding < 2 * self._num_workerstry:index = self._next_index()except StopIteration:returnfor _ in range(self._num_workers):  # find the next active worker, if anyworker_queue_idx = next(self._worker_queue_idx_cycle)if self._workers_status[worker_queue_idx]:breakelse:# not found (i.e., didn't break)returnself._index_queues[worker_queue_idx].put((self._send_idx, index))self._task_info[self._send_idx] = (worker_queue_idx,)self._tasks_outstanding += 1self._send_idx += 1def _process_data(self, data):self._rcvd_idx += 1self._try_put_index()if isinstance(data, ExceptionWrapper):data.reraise()return datadef _shutdown_worker(self, worker_id):assert self._workers_status[worker_id]q = self._index_queues[worker_id]q.put(None)self._workers_status[worker_id] = Falsedef _shutdown_workers(self):python_exit_status = _utils.python_exit_statusif python_exit_status is True or python_exit_status is None:# See (2) of the note. If Python is shutting down, do no-op.returnif not self._shutdown:self._shutdown = Truetry:if hasattr(self, '_pin_memory_thread'):# Use hasattr in case error happens before we set the attribute.self._pin_memory_thread_done_event.set()# Send something to pin_memory_thread in case it is waiting# so that it can wake up and check `pin_memory_thread_done_event`self._worker_result_queue.put((None, None))self._pin_memory_thread.join()self._worker_result_queue.close()# Exit workers now.self._workers_done_event.set()for worker_id in range(len(self._workers)):if self._workers_status[worker_id]:self._shutdown_worker(worker_id)for w in self._workers:w.join()for q in self._index_queues:q.cancel_join_thread()q.close()finally:if self._worker_pids_set:_utils.signal_handling._remove_worker_pids(id(self))self._worker_pids_set = Falsedef __del__(self):self._shutdown_workers()

找到我们的__next__()函数,也就是这个类的输出,我们看下输出的是

return self._process_data(data)

再进入self._process_data(data)

def _process_data(self, data):self._rcvd_idx += 1self._try_put_index()if isinstance(data, ExceptionWrapper):data.reraise()return data

发现了我们的data还是data,没有变,所以我们需要找到这个data值的来源,在输出这个地方加入我们需要自己传出去的值

当然,不要在源码上修改,我们copy一份文件dataloader_util 放在自己项目合适的位置,然后修改代码传出自己需要的数据。

我们刚才已经找到了我们的打乱的batch集合,也找到了我们数据出口。

接下去我们讲batch一起传出来即可

1.修改import

from torch.utils.data.sampler import Sampler, SequentialSampler, RandomSampler, BatchSampler
from torch.utils.data.dataset import IterableDataset

2.在__init__里面加上代码用来保存我们的batch:

# 新增数据下标self._index_list = []

3.在下面函数的最后一行赋值

    def _try_put_index(self):assert self._tasks_outstanding < 2 * self._num_workerstry:index = self._next_index()except StopIteration:returnfor _ in range(self._num_workers):  # find the next active worker, if anyworker_queue_idx = next(self._worker_queue_idx_cycle)if self._workers_status[worker_queue_idx]:breakelse:# not found (i.e., didn't break)returnself._index_queues[worker_queue_idx].put((self._send_idx, index))self._task_info[self._send_idx] = (worker_queue_idx,)self._tasks_outstanding += 1self._send_idx += 1self._index_list = index

4.在下面函数中将返回值加上

    def _process_data(self, data):self._rcvd_idx += 1self._try_put_index()if isinstance(data, ExceptionWrapper):data.reraise()return data, self._index_list

5.在训练的时候修改下获取数据的代码

    for i, (train_images, index_) in enumerate(train_loader):images, labels, training_mask = train_imagesshuffle_image_path = [train_loader.dataset.data_list[x][0] for x in index_]

记得调用这个train_loader 的路径改成自己copy出来的,现在这个shuffle_image_path 就是我要的数据了。可以愉快的打印保存了。

是不是很简单呢?其他的获取自己相应要修改的数据,也可以这么获取

这篇关于Pytorch.Dataloader 详细深度解读和微修改源代码心得的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1120196

相关文章

Spring Boot集成Druid实现数据源管理与监控的详细步骤

《SpringBoot集成Druid实现数据源管理与监控的详细步骤》本文介绍如何在SpringBoot项目中集成Druid数据库连接池,包括环境搭建、Maven依赖配置、SpringBoot配置文件... 目录1. 引言1.1 环境准备1.2 Druid介绍2. 配置Druid连接池3. 查看Druid监控

创建Java keystore文件的完整指南及详细步骤

《创建Javakeystore文件的完整指南及详细步骤》本文详解Java中keystore的创建与配置,涵盖私钥管理、自签名与CA证书生成、SSL/TLS应用,强调安全存储及验证机制,确保通信加密和... 目录1. 秘密键(私钥)的理解与管理私钥的定义与重要性私钥的管理策略私钥的生成与存储2. 证书的创建与

使用Docker构建Python Flask程序的详细教程

《使用Docker构建PythonFlask程序的详细教程》在当今的软件开发领域,容器化技术正变得越来越流行,而Docker无疑是其中的佼佼者,本文我们就来聊聊如何使用Docker构建一个简单的Py... 目录引言一、准备工作二、创建 Flask 应用程序三、创建 dockerfile四、构建 Docker

Python设置Cookie永不超时的详细指南

《Python设置Cookie永不超时的详细指南》Cookie是一种存储在用户浏览器中的小型数据片段,用于记录用户的登录状态、偏好设置等信息,下面小编就来和大家详细讲讲Python如何设置Cookie... 目录一、Cookie的作用与重要性二、Cookie过期的原因三、实现Cookie永不超时的方法(一)

解读GC日志中的各项指标用法

《解读GC日志中的各项指标用法》:本文主要介绍GC日志中的各项指标用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、基础 GC 日志格式(以 G1 为例)1. Minor GC 日志2. Full GC 日志二、关键指标解析1. GC 类型与触发原因2. 堆

Java设计模式---迭代器模式(Iterator)解读

《Java设计模式---迭代器模式(Iterator)解读》:本文主要介绍Java设计模式---迭代器模式(Iterator),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,... 目录1、迭代器(Iterator)1.1、结构1.2、常用方法1.3、本质1、解耦集合与遍历逻辑2、统一

深度解析Java DTO(最新推荐)

《深度解析JavaDTO(最新推荐)》DTO(DataTransferObject)是一种用于在不同层(如Controller层、Service层)之间传输数据的对象设计模式,其核心目的是封装数据,... 目录一、什么是DTO?DTO的核心特点:二、为什么需要DTO?(对比Entity)三、实际应用场景解析

深度解析Java项目中包和包之间的联系

《深度解析Java项目中包和包之间的联系》文章浏览阅读850次,点赞13次,收藏8次。本文详细介绍了Java分层架构中的几个关键包:DTO、Controller、Service和Mapper。_jav... 目录前言一、各大包1.DTO1.1、DTO的核心用途1.2. DTO与实体类(Entity)的区别1

SpringBoot整合liteflow的详细过程

《SpringBoot整合liteflow的详细过程》:本文主要介绍SpringBoot整合liteflow的详细过程,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋...  liteflow 是什么? 能做什么?总之一句话:能帮你规范写代码逻辑 ,编排并解耦业务逻辑,代码

MySQL之InnoDB存储页的独立表空间解读

《MySQL之InnoDB存储页的独立表空间解读》:本文主要介绍MySQL之InnoDB存储页的独立表空间,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、独立表空间【1】表空间大小【2】区【3】组【4】段【5】区的类型【6】XDES Entry区结构【