本文主要是介绍Pytorch.Dataloader 详细深度解读和微修改源代码心得,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
关于pytorch 的dataloader库,使用pytorch 基本都会用到的一个库
今天遇到了一个问题:我在训练时采用 batch_size=2,最终的 loss 抖动太大,看得出来应该是某些样本标注得不好导致的,需要找出这些样本重新修正。但我一开始用的是默认的 DataLoader,输入的图像 dataset 传出来之后是经过 shuffle 的,没有办法定位到是哪张图片。PS:如果有大佬知道简单的直接调用方法,欢迎评论交流,万分感谢!
如果全部自己写的话,工作量就比较大了,测试流程较长。所以想是不是可以简单的修改源码来达到目的。
研究了一下源码,还是简单的。
现在将整个过程分享下:
首先,来看看dataloader这个库到底是干了什么事情
1.先定位到源码目录
在 torch.utils.data 下有 2 个文件是我们目前需要的。一个是 dataloader.py,这个文件封装了我们调用的 DataLoader 类,返回的是一个可迭代对象,迭代得到的就是我们网络的输入。另一个是 sampler.py,负责生成样本下标并把它们拆分成一个个 batch,在 dataloader 内会被关键地调用。了解了两个文件的大致作用,我们来详细解读一下今天的主题:
class DataLoader(object):
    """Excerpt of ``torch.utils.data.DataLoader`` (constructor only).

    Validates the combination of options, picks a default sampler when none
    is given (``RandomSampler`` if ``shuffle`` else ``SequentialSampler`` for
    map-style datasets) and, when ``batch_size`` is set, wraps that sampler in
    a ``BatchSampler`` which produces the list of dataset indices of each batch.

    Raises:
        ValueError: for a negative ``num_workers`` or ``timeout``, or for
            mutually exclusive options (e.g. ``sampler`` together with
            ``shuffle``, or ``batch_sampler`` with ``batch_size``/``shuffle``/
            ``sampler``/``drop_last``).
    """
    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None):
        torch._C._log_api_usage_once("python.data_loader")

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        self.dataset = dataset
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # Iterable datasets control their own order, so no sampling
            # options may be given.
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
        else:
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if shuffle or drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with '
                                 'shuffle, and drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler: this is where the
            # (possibly shuffled) index lists of every batch come from.
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.__initialized = True

    @property
    def _auto_collation(self):
        # Batching through collate_fn is active exactly when a batch sampler
        # exists; __init__ reads this to choose the default collate_fn.
        return self.batch_sampler is not None
__init__() 函数里面有个重点,在下面这几行:
if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler: wrap the sampler so that
    # each iteration yields a list of up to batch_size dataset indices.
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
这一行创建了 BatchSampler:迭代它时,会按照 sampler 给出的顺序(shuffle=True 时就是 RandomSampler 打乱后的顺序),每次产出一个 batch 的 index 列表,比如 batch_size=4 时的 [12, 242, 456, 13]。
看下这个class,
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    The order of the indices is exactly the order in which ``sampler`` yields
    them; batches are only shuffled if the base sampler is random
    (e.g. ``RandomSampler``).

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        # bool is a subclass of int, so it is rejected explicitly.
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        # Collect indices from the base sampler; emit a batch whenever it is full.
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # Trailing partial batch, unless the caller asked to drop it.
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            # Ceiling division: the last partial batch counts as one.
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
BatchSampler 的 __iter__() 本身并不做随机:它只是把 sampler 依次给出的 idx 放进 list,凑满 batch_size 个就 yield 一次。随机性来自 shuffle=True 时使用的 RandomSampler。
下面看下dataloader中最重要的函数,也就是我们需要修改的地方
def __iter__(self):
    """Return a new iterator over this loader's batches.

    ``num_workers == 0`` selects ``_SingleProcessDataLoaderIter``; any
    positive value selects ``_MultiProcessingDataLoaderIter``, which starts
    ``num_workers`` worker processes.
    """
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        return _MultiProcessingDataLoaderIter(self)
DataLoader 通过这个函数返回迭代器;当 num_workers 大于 0 时,就采用多进程方式读取数据。
进入下面这个 _MultiProcessingDataLoaderIter(self) 类:
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    """Iterator used by ``DataLoader`` when ``num_workers > 0``.

    Each task (a batch of dataset indices) is sent to one worker process
    through that worker's index queue. Results come back through a shared
    result queue (through an extra pin-memory thread when ``pin_memory`` is
    set) and are handed out strictly in task order.

    Bookkeeping:
      * ``_send_idx``  -- id of the next task to be sent to a worker.
      * ``_rcvd_idx``  -- id of the next task to be returned by ``__next__``.
      * ``_task_info`` -- task id -> ``(worker_id,)`` while outstanding, or
        ``(worker_id, data)`` once its data arrived out of order.
    """

    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)

        assert self._num_workers > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn
        # Tasks are dispatched to workers round-robin.
        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
        self._worker_result_queue = multiprocessing_context.Queue()
        self._worker_pids_set = False
        self._shutdown = False
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        self._task_info = {}
        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        # True while worker i is active; _shutdown_worker() sets it to False.
        self._workers_status = []
        for i in range(self._num_workers):
            index_queue = multiprocessing_context.Queue()
            # index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed + i, self._worker_init_fn, i, self._num_workers))
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)
            self._workers_status.append(True)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()
            self._data_queue = queue.Queue()
            pin_memory_thread = threading.Thread(
                target=_utils.pin_memory._pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Register pin_memory_thread only once it has started;
            # _shutdown_workers() tests for this attribute with hasattr().
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue

        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True

        # prime the prefetch loop: keep 2 * num_workers tasks in flight
        for _ in range(2 * self._num_workers):
            self._try_put_index()

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        """Try once to fetch an item from ``self._data_queue``.

        Returns ``(True, data)`` on success and ``(False, None)`` when the
        queue stayed empty for ``timeout`` seconds. On any failure it first
        checks whether a worker process has died; if so, those workers are
        shut down and ``RuntimeError`` is raised. Other errors propagate.
        """
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            # On timeout or error, check manually whether any worker died.
            failed_workers = []
            for worker_id, w in enumerate(self._workers):
                if self._workers_status[worker_id] and not w.is_alive():
                    failed_workers.append(w)
                    self._shutdown_worker(worker_id)
            if len(failed_workers) > 0:
                pids_str = ', '.join(str(w.pid) for w in failed_workers)
                raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
            if isinstance(e, queue.Empty):
                return (False, None)
            raise

    def _get_data(self):
        """Wait for the next item on ``self._data_queue`` and return it.

        Raises ``RuntimeError`` if the loader's ``timeout`` expires or the
        pin-memory thread has died.
        """
        # Except for the explicit-timeout case, poll in a loop by calling
        # `self._try_get_data()` with its bounded default timeout, so worker
        # failures are detected between attempts. Every successfully fetched
        # item must be returned: an item taken off the queue and dropped is
        # lost for good.
        if self._timeout > 0:
            success, data = self._try_get_data(self._timeout)
            if success:
                return data
            else:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
        elif self._pin_memory:
            # Here `self._data_queue` is a `queue.Queue`. `.task_done()` is not
            # needed because `.join()` is never called on it.
            while self._pin_memory_thread.is_alive():
                success, data = self._try_get_data()
                if success:
                    return data
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
        else:
            while True:
                success, data = self._try_get_data()
                if success:
                    return data

    def __next__(self):
        """Return the next batch in task order, or raise ``StopIteration``."""
        while True:
            # Skip task ids whose worker was shut down without producing data.
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                self._shutdown_workers()
                raise StopIteration

            # The data for `self._rcvd_idx` may already have arrived out of order.
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data()
            self._tasks_outstanding -= 1

            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    self._shutdown_worker(data.worker_id)
                    self._try_put_index()
                    continue

            if idx != self._rcvd_idx:
                # store out-of-order samples
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]
                return self._process_data(data)

    next = __next__  # Python 2 compatibility

    def _try_put_index(self):
        """Send the next batch of indices to the next active worker, if any.

        Returns without dispatching when the sampler is exhausted or no
        worker is active.
        """
        assert self._tasks_outstanding < 2 * self._num_workers
        try:
            index = self._next_index()
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return

        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._tasks_outstanding += 1
        self._send_idx += 1

    def _process_data(self, data):
        """Advance to the next task id, refill the pipeline and return ``data``.

        If the worker sent back an ``ExceptionWrapper`` instead of data, its
        ``reraise()`` is called.
        """
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data

    def _shutdown_worker(self, worker_id):
        """Mark worker ``worker_id`` inactive and put ``None`` on its index queue."""
        assert self._workers_status[worker_id]
        q = self._index_queues[worker_id]
        # Presumably the worker loop treats None as its stop signal -- confirm in _utils.worker.
        q.put(None)
        self._workers_status[worker_id] = False

    def _shutdown_workers(self):
        """Stop the pin-memory thread and all workers, then close the queues.

        Runs at most once (guarded by ``self._shutdown``) and does nothing
        while Python itself is shutting down.
        """
        python_exit_status = _utils.python_exit_status
        if python_exit_status is True or python_exit_status is None:
            # See (2) of the note. If Python is shutting down, do no-op.
            return
        if not self._shutdown:
            self._shutdown = True
            try:
                if hasattr(self, '_pin_memory_thread'):
                    # Use hasattr in case error happens before we set the attribute.
                    self._pin_memory_thread_done_event.set()
                    # Send something to pin_memory_thread in case it is waiting
                    # so that it can wake up and check `pin_memory_thread_done_event`
                    self._worker_result_queue.put((None, None))
                    self._pin_memory_thread.join()
                    self._worker_result_queue.close()

                # Exit workers now.
                self._workers_done_event.set()
                for worker_id in range(len(self._workers)):
                    if self._workers_status[worker_id]:
                        self._shutdown_worker(worker_id)
                for w in self._workers:
                    w.join()
                for q in self._index_queues:
                    q.cancel_join_thread()
                    q.close()
            finally:
                if self._worker_pids_set:
                    _utils.signal_handling._remove_worker_pids(id(self))
                    self._worker_pids_set = False

    def __del__(self):
        self._shutdown_workers()
找到我们的__next__()函数,也就是这个类的输出,我们看下输出的是
# Both return paths of __next__ end here: every batch leaves through _process_data().
return self._process_data(data)
再进入self._process_data(data)
def _process_data(self, data):
    """Finish handing out one batch that came back from the workers.

    Advances the receive counter and calls ``_try_put_index`` so another task
    is dispatched. If the worker sent back an ``ExceptionWrapper`` instead of
    data, its ``reraise()`` is called. Otherwise the batch is returned unchanged.
    """
    self._rcvd_idx += 1
    self._try_put_index()
    if isinstance(data, ExceptionWrapper):
        data.reraise()
    return data
发现了我们的data还是data,没有变,所以我们需要找到这个data值的来源,在输出这个地方加入我们需要自己传出去的值
当然,不要直接在源码上修改:我们把 dataloader.py 复制一份(例如命名为 dataloader_util.py),放到自己项目中合适的位置,然后修改代码,传出自己需要的数据。
我们刚才已经找到了我们的打乱的batch集合,也找到了我们数据出口。
接下来我们将每个 batch 的下标一起传出来即可。
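1.修改 import。复制出来的文件已不在 torch 包内,原来的相对导入 `from . import ...` 要改成下面的绝对导入;文件中用到的 `_utils` 同理需要改成 `from torch.utils.data import _utils`: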
from torch.utils.data.sampler import Sampler, SequentialSampler, RandomSampler, BatchSampler
from torch.utils.data.dataset import IterableDataset
2.在__init__里面加上代码用来保存我们的batch:
# New: dataset indices of the batch being handed out.
# Put this in __init__ before the prefetch-priming loop
# (`for _ in range(2 * self._num_workers): self._try_put_index()`),
# because that loop already calls _try_put_index().
self._index_list = []
3.在下面函数的最后一行赋值
def _try_put_index(self):
    """Dispatch the next batch of dataset indices to an active worker.

    Modified copy of ``_MultiProcessingDataLoaderIter._try_put_index``.
    Besides sending the task, it records which dataset indices the task
    carries, keyed by task id. That way the indices of a batch can be looked
    up when that task's result is returned.

    Returns without dispatching when the sampler is exhausted or no worker
    is active.
    """
    assert self._tasks_outstanding < 2 * self._num_workers
    try:
        index = self._next_index()
    except StopIteration:
        return
    for _ in range(self._num_workers):  # find the next active worker, if any
        worker_queue_idx = next(self._worker_queue_idx_cycle)
        if self._workers_status[worker_queue_idx]:
            break
    else:
        # not found (i.e., didn't break)
        return

    self._index_queues[worker_queue_idx].put((self._send_idx, index))
    self._task_info[self._send_idx] = (worker_queue_idx,)
    self._tasks_outstanding += 1
    # Up to 2 * num_workers tasks are in flight and results may arrive out
    # of order, so a single "last index" cannot tell which batch is being
    # returned. Keep the indices per task id (_send_idx) instead.
    if not hasattr(self, '_batch_indices'):
        self._batch_indices = {}
    self._batch_indices[self._send_idx] = index
    self._send_idx += 1
    # Kept for backward compatibility: indices of the most recently *sent*
    # task, which is generally NOT the batch that will be returned next.
    self._index_list = index
4.在下面函数中将返回值加上
def _process_data(self, data):
    """Return one finished batch together with the dataset indices it came from.

    Modified copy of ``_MultiProcessingDataLoaderIter._process_data``.

    Returns:
        ``(data, batch_indices)``. ``batch_indices`` is the list the batch
        sampler produced for this batch.
    """
    # self._rcvd_idx is the task id of the batch being returned. Look its
    # indices up *before* _try_put_index() dispatches another task, since that
    # call rewrites self._index_list with the indices of the newly sent task.
    batch_indices = getattr(self, '_batch_indices', {}).pop(self._rcvd_idx, None)
    if batch_indices is None:
        # No per-task record (task id -> indices) available: fall back to the
        # shared attribute.
        batch_indices = self._index_list
    self._rcvd_idx += 1
    self._try_put_index()
    if isinstance(data, ExceptionWrapper):
        data.reraise()
    return data, batch_indices
5.在训练的时候修改下获取数据的代码
for i, (train_images, index_) in enumerate(train_loader):
    # train_images is the collated batch; index_ holds the dataset indices
    # returned alongside it by the modified _process_data.
    images, labels, training_mask = train_images
    # Map each dataset index back to its image path. Assumes data_list[x][0]
    # is the path -- adjust to your own Dataset.
    shuffle_image_path = [train_loader.dataset.data_list[x][0] for x in index_]
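记得把创建 train_loader 时用到的 DataLoader 换成自己复制出来的那份(从 dataloader_util 导入),现在这个 shuffle_image_path 就是我要的数据了,可以愉快地打印保存了。注意:这里只修改了 _MultiProcessingDataLoaderIter,所以 num_workers 必须大于 0;num_workers=0 时走的是 _SingleProcessDataLoaderIter,不会返回下标。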
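是不是很简单呢?想获取其他数据,也可以用同样的方法。另外,如果不想改源码,还有一个更简单的办法:在自定义 Dataset 的 __getitem__ 里把 index(或图片路径)和样本一起返回,default_collate 会把它们和样本一起组成 batch 传出来。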
这篇关于Pytorch.Dataloader 详细深度解读和微修改源代码心得的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!