本文主要是介绍在pytorch中load超大训练数据,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
在pytorch中load超大训练数据
by joeyqzhou
相关代码地址: https://github.com/joeyqzhou/blog/tree/master/pytorch%E4%B8%ADload%E8%B6%85%E5%A4%A7%E8%AE%AD%E7%BB%83%E6%95%B0%E6%8D%AE
最简单方式:
1 单线程获取数据到内存中
2 train的过程
for epoch in range(num_epochs):for i in range(inst_size): #截取 batch_x, batch_y#batch_x, batch_y 转换为tensor#model.forward()#loss.backward()#optimizer.step()
这种方式代码简单。缺点load数据过慢,数据全部存储在内存当中。
当训练数据过大的时候load很慢,内存会溢出
多进程load数据
如下是一个多进程load数据的例子
from multiprocessing import Pooldef process_line(line):return "FOO: %s" % lineif __name__ == "__main__":pool = Pool(4)file = "train.txt" #你的输入数据ret = []with open(file) as source_file:# chunk the work into batches of 4 lines at a timeresults = pool.map(process_line, source_file, 4)
这篇关于在pytorch中load超大训练数据的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!