本文主要是介绍自己造轮子:深度学习dataloader自己实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
自己造轮子:深度学习dataloader自己实现
**摘要:**因为计算机性能的限制,所有的深度学习框架都是采用批量随机梯度下降,所以每次计算都要读取batch_size的数据。这里以自己实现的方式介绍深度学习框架实现批量读取数据的原理,不涉及具体细节和一些逻辑,只注重大体流程和原理。
总体流程:
- 采用yield写一个生成器函数实现批量图片/标注信息的读取
- 采用multiprocessing/threading加速文件读取
- 时间对比
深度学习大体流程
for i in range(epoch):data, lable = dataloader.next(batch_size=16) # 读取batch_size的数据output = model(data) # 前向传播loss = crition(output, label) # 求损失函数loss.backward() # 反向传播
在dataloader的时候,一般会采用多个进程(num_workers
)加快文件I/O的速度,避免网络反向传播过了,还没有数据。
1. 用yield写一个生成器函数
# coding:utf-8
# 自己造轮子,实现深度学习批量数据的读取
import os
import glob
import numpy as np
import cv2 def get_images(path):files = []for ext in ['jpg', 'png', 'jpeg', 'JPG']:files.extend(glob.glob(os.path.join(path, '*.{}'.format(ext))))return filesdef dataset(batch_size=2, path='/media/chenjun/data/1_deeplearning/7_ammeter_data/test'):"""写一个读取图片的生成器batch_size:批量大小path:图片路径"""# 1. 读取所有图片名字image_list = get_images(path)index = np.arange(0, len(image_list))while True:np.random.shuffle(index)images = []image_names = []for i in index:try:im_name = image_list[i]im = cv2.imread(im_name) # 读取图片# 读取相应图片的标注信息# text_polys = fun1()images.append(im[:,:, ::-1].astype(np.float32)) # cv2读取图片的顺序为BGR,转换成RGB格式image_names.append(im_name)if len(images) == batch_size:yield images, image_names # 采用函数生成器,生成一个可迭代对象images = []image_names = []except Exception as e:import tracebacktraceback.print_exc()continue # 所有图片已经读完一遍,跳出for循环,再打乱图片的顺序进行第二次读取
2. 使用muitlprocessing加速文件读取速度
<!-- 采用正常模式进行图片读取,读取100个batch -->
import time
mydataset = dataset()
start = time.time()
for _ in range(100):im, im_name = next(mydataset)
# print(im_name)
print('use time:{}'.format(time.time() - start))
>>> use time:0.16786599159240723<!-- 采用muitlprocessing模式进行图片读取,读取100个batch -->
import multiprocessing
def data_generator(data, q):for _ in range(100): # 循环多少次generator_output = next(data)q.put(generator_output)q = multiprocessing.Queue()
start2 = time.time()
thread = multiprocessing.Process(target=data_generator, args=(dataset(), q))
thread.start() # 多进程开始读取图片
print('mulprocess time is:{}'.format(time.time() - start2))
>>> mulprocess time is:0.002292633056640625
可以看到读取100个batch,时间提高了80倍。
同时,一般的深度学习框架都会使用几个多进程处理上面的功能。
eg:
for _ in range(workers):if self._use_multiprocessing:# Reset random seed else all children processes# share the same seednp.random.seed(self.random_seed)thread = multiprocessing.Process(target=data_generator_task)
网上的资料显示threading的效率没有muitlprocessing高,这里就不测试了。
reference
[1] 莫烦python
[2] argman/EAST
这篇关于自己造轮子:深度学习dataloader自己实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!