tensorflow将图片保存为tfrecord和tfrecord的读取

2023-12-22 16:18

本文主要是介绍tensorflow将图片保存为tfrecord和tfrecord的读取,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

tensorflow官方提供了3种方法来读取数据:

  • 预加载数据(preloaded data):在TensorFlow图中定义常量或变量来保存所有的数据,适用于数据量不太大的情况。
  • 填充数据(feeding):通过Python产生数据,然后再把数据填充到后端。
  • 从文件读取数据(reading from file):从文件中直接读取,然后通过队列管理器从文件中读取数据。

本文主要介绍第三种方法,通过tfrecord文件来保存和读取数据,对于前两种读取数据的方式也会进行一个简单的介绍。

项目下载github地址:https://github.com/steelOneself/tensorflow_learn/tree/master/tf_records_writer_read

一、预加载数据

    a = tf.constant([1,2,3])b = tf.constant([4,5,6])c = tf.add(a,b)with tf.Session() as sess:print(sess.run(c))#[5 7 9]

这种方式加载数据比较简单,它是直接将数据嵌入在数据流图中,当训练数据较大时,比较消耗内存。

二、填充数据

通过先定义placeholder然后再通过feed_dict来喂养数据,这种方式在TensorFlow中使用的也是比较多的,但是也存在数据量大时比较消耗内存的缺点,下面介绍一种更高效的数据读取方式,通过tfrecord文件来读取数据。

    x = tf.placeholder(tf.int16)y = tf.placeholder(tf.int16)z = tf.add(x,y)with tf.Session() as sess:print(sess.run(z,feed_dict={x:[1,2,3],y:[4,5,6]}))#[5 7 9]

三、从文件读取数据

通过slim来实现将图片保存为tfrecord文件和tfrecord文件的读取,slim是基于TensorFlow的一个更高级别的封装模型,通过slim来编程可以实现更高效率和更简洁的代码。

在本次实验中使用的数据集是kaggle的dog vs cat,数据集下载地址:https://www.kaggle.com/c/dogs-vs-cats/data

1、tfrecord文件的保存

a、参数设置

  • dataset_dir_path:训练集图片存放的上级目录(train下还有一个train目录用来存放图片),在dog vs cat数据集中,dog和cat类的区别是依靠图片的名称,如果你的数据集通过文件夹的名称来划分图片类标的,可能需要对代码进行部分修改。
  • label_name_to_num:字符串类标与数字类标的对应关系,在将图片保存为tfrecord文件的时候,需要将字符串转为整数类标0和1,方便后的训练。
  • label_num_to_name:数字类标与字符串类标的对应关系。
  • val_size:验证集在训练集中所占的比例,训练集一共有25000张图片,用20000张来训练,5000张来进行验证。
  • batch_size:在读取tfrecord文件的时候,每次读取图片的数量。
#数据所在的目录路径
dataset_dir_path = "D:/dataset/kaggle/cat_or_dog/train"
#类标名称和数字的对应关系
label_name_to_num = {"cat":0,"dog":1}
label_num_to_name = {value:key for key,value in label_name_to_num.items()}
#设置验证集占整个数据集的比例
val_size = 0.2
batch_size = 1

b、获取训练集所有的图片路径

获取训练目录下所有的dog和cat的图片路径,将它们分开保存,便于后面训练集和验证集数据的划分,保证每类图片在所占的比例相同。

#获取文件所在路径dataset_dir = os.path.join(dataset_dir,split_name)#遍历目录下的所有图片for filename in os.listdir(dataset_dir):#获取文件的路径file_path = os.path.join(dataset_dir,filename)if file_path.endswith("jpg") and os.path.exists(file_path):#获取类别的名称label_name = filename.split(".")[0]if label_name == "cat":cat_img_paths.append(file_path)elif label_name == "dog":dog_img_paths.append(file_path)return cat_img_paths,dog_img_paths

c、设置需要保存的图片信息

对于训练集的图片主要保存图片的字节数据、图片的格式、图片的标签、图片的高和宽,测试集保存为tfrecord文件的时候需要保存图片的名称,因为在提交数据的时候需要用到图片的名称信息。在保存图片信息的时候,需要先将这些信息转换为byte数据才能写入到tfrecord文件中。

def int64_feature(values):if not isinstance(values, (tuple, list)):values = [values]return tf.train.Feature(int64_list=tf.train.Int64List(value=values))def bytes_feature(values):return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))#将图片信息转换为tfrecords可以保存的序列化信息
def image_to_tfexample(split_name,image_data, image_format, height, width, img_info):''':param split_name: train或val或test:param image_data: 图片的二进制数据:param image_format: 图片的格式:param height: 图片的高:param width: 图片的宽:param img_info: 图片的标签或图片的名称,当split_name为test时,img_info为图片的名称否则为图片标签:return:'''if split_name == "test":return tf.train.Example(features=tf.train.Features(feature={'image/encoded': bytes_feature(image_data),'image/format': bytes_feature(image_format),'image/img_name': bytes_feature(img_info),'image/height': int64_feature(height),'image/width': int64_feature(width),}))else:return tf.train.Example(features=tf.train.Features(feature={'image/encoded': bytes_feature(image_data),'image/format': bytes_feature(image_format),'image/label': int64_feature(img_info),'image/height': int64_feature(height),'image/width': int64_feature(width),}))

d、保存tfrecord文件

主要是通过TFRecordWriter来保存tfrecord文件,在将图片信息保存为tfrecord文件的时候,需要先将图片信息序列化为字符串才能进行写入。ImageReader类可以将图片字节数据解码为指定格式的图片,获取图片的宽和高信息。_get_dataset_filename函数是通过数据集的名称和split_name的名称来组合获取tfrecord文件的名称,tfrecord名称如下:

def _convert_tfrecord_dataset(split_name, filenames, label_name_to_id, 
dataset_dir, tfrecord_filename, _NUM_SHARDS):''':param split_name:train或val或test:param filenames:图片的路径列表:param label_name_to_id:标签名与数字标签的对应关系:param dataset_dir:数据存放的目录:param tfrecord_filename:文件保存的前缀名:param _NUM_SHARDS:将整个数据集分为几个文件:return:'''assert split_name in ['train', 'val','test']#计算平均每一个tfrecords文件保存多少张图片num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))with tf.Graph().as_default():image_reader = ImageReader()with tf.Session('') as sess:for shard_id in range(_NUM_SHARDS):#获取tfrecord文件的名称output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id,tfrecord_filename = tfrecord_filename, _NUM_SHARDS = _NUM_SHARDS)#写tfrecords文件with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:start_ndx = shard_id * num_per_shardend_ndx = min((shard_id+1) * num_per_shard, len(filenames))for i in range(start_ndx, end_ndx):#更新控制台中已经完成的图片数量sys.stdout.write('\r>> Converting image %d/%d shard %d' % (i+1, len(filenames), shard_id))sys.stdout.flush()#读取图片,将图片数据读取为bytesimage_data = tf.gfile.FastGFile(filenames[i], 'rb').read()#获取图片的高和宽height, width = image_reader.read_image_dims(sess, image_data)#获取路径中的图片名称img_name = os.path.basename(filenames[i])if split_name == "test":#需要将图片名称转换为二进制example = image_to_tfexample(split_name,image_data, b'jpg', height, width, img_name.encode())tfrecord_writer.write(example.SerializeToString())else:#获取图片的类别class_name = img_name.split(".")[0]label_id = label_name_to_id[class_name]example = image_to_tfexample(split_name,image_data, b'jpg', height, width, label_id)tfrecord_writer.write(example.SerializeToString())sys.stdout.write('\n')sys.stdout.flush()

e、将数据集分为验证集和训练集保存为tfrecord文件

先获取数据集中所有图片的路径和图片的标签信息,将不同类别的图片分为训练集和验证集,并保证训练集和验证集中不同类别的图片数量保持相同,在保存为tfrecord文件之前,打乱所有图片的路径。将训练集分为了2个tfrecord文件,验证集保存为1个tfrecord文件。

#生成tfrecord文件
def generate_tfreocrd():#获取目录下所有的猫和狗图片的路径cat_img_paths,dog_img_paths = _get_dateset_imgPaths(dataset_dir_path,"train")#打乱路径列表的顺序np.random.shuffle(cat_img_paths)np.random.shuffle(dog_img_paths)#计算不同类别验证集所占的图片数量cat_val_num = int(len(cat_img_paths) * val_size)dog_val_num = int(len(dog_img_paths) * val_size)#将所有的图片路径分为训练集和验证集train_img_paths = cat_img_paths[cat_val_num:]val_img_paths = cat_img_paths[:cat_val_num]train_img_paths.extend(dog_img_paths[dog_val_num:])val_img_paths.extend(dog_img_paths[:dog_val_num])#打乱训练集和验证集的顺序np.random.shuffle(train_img_paths)np.random.shuffle(val_img_paths)#将训练集保存为tfrecord文件_convert_tfrecord_dataset("train",train_img_paths,label_name_to_num,dataset_dir_path,"catVSdog",2)#将验证集保存为tfrecord文件_convert_tfrecord_dataset("val",val_img_paths,label_name_to_num,dataset_dir_path,"catVSdog",1)

通过控制台你能够看到tfrecord文件的保存进度

2、从tfrecord文件中读取数据

a、读取tfrecord文件,将数据转换为dataset

通过TFRecordReader来读取tfrecord文件,在读取tfrecord文件时需要通过tf.FixedLenFeature来反序列化存储的图片信息,这里我们只读取图片数据和图片的标签,再通过slim模块将图片数据和标签信息存储为一个dataset。

    #创建一个tfrecord读文件对象reader = tf.TFRecordReaderkeys_to_feature = {"image/encoded":tf.FixedLenFeature((),tf.string,default_value=""),"image/format":tf.FixedLenFeature((),tf.string,default_value="jpg"),"image/label":tf.FixedLenFeature([],tf.int64,default_value=tf.zeros([],tf.int64))}items_to_handles = {"image":slim.tfexample_decoder.Image(),"label":slim.tfexample_decoder.Tensor("image/label")}items_to_descriptions = {"image":"a 3-channel RGB image","img_name":"a image label"}#创建一个tfrecoder解析对象decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_feature,items_to_handles)#读取所有的tfrecord文件,创建数据集dataset = slim.dataset.Dataset(data_sources = tfrecord_paths,decoder = decoder,reader = reader,num_readers = 4,num_samples = num_imgs,num_classes = num_classes,labels_to_name = labels_to_name,items_to_descriptions = items_to_descriptions)

b、获取batch数据

preprocessing_image对图片进行预处理,对图片进行数据增强,输出后的图片尺寸由height和width参数决定,固定图片的尺寸方便CNN的模型训练。

def load_batch(split_name,dataset,batch_size,height,width):data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset,common_queue_capacity = 24 + 3 * batch_size,common_queue_min = 24)raw_image,img_label = data_provider.get(["image","label"])#Perform the correct preprocessing for this image depending if it is training or evaluatingimage = preprocess_image(raw_image, height, width,True)#As for the raw images, we just do a simple reshape to batch it upraw_image = tf.expand_dims(raw_image, 0)raw_image = tf.image.resize_nearest_neighbor(raw_image, [height, width])raw_image = tf.squeeze(raw_image)#获取一个batch数据images,raw_image,labels = tf.train.batch([image,raw_image,img_label],batch_size=batch_size,num_threads=4,capacity=4*batch_size,allow_smaller_final_batch=True)return images,raw_image,labels

c、读取tfrecord文件

#读取tfrecord文件
def read_tfrecord():#从tfreocrd文件中读取数据train_dataset = get_dataset_by_tfrecords("train",dataset_dir_path,"catVSdog",2,label_num_to_name)images,raw_images,labels = load_batch("train",train_dataset,batch_size,227,227)with tf.Session() as sess:threads = tf.train.start_queue_runners(sess)for i in range(6):train_img,train_label = sess.run([raw_images,labels])plt.subplot(2,3,i+1)plt.imshow(np.array(train_img[0]))plt.title("image label:%s"%str(label_num_to_name[train_label[0]]))plt.show()

读取训练集的tfrecord文件,只从tfrecord文件中获取了图片数据和图片的标签,images表示的是预处理后的图片,raw_images表示的是没有经过预处理的图片。

这篇关于tensorflow将图片保存为tfrecord和tfrecord的读取的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/524608

相关文章

C#实现将Excel表格转换为图片(JPG/ PNG)

《C#实现将Excel表格转换为图片(JPG/PNG)》Excel表格可能会因为不同设备或字体缺失等问题,导致格式错乱或数据显示异常,转换为图片后,能确保数据的排版等保持一致,下面我们看看如何使用C... 目录通过C# 转换Excel工作表到图片通过C# 转换指定单元格区域到图片知识扩展C# 将 Excel

JS+HTML实现在线图片水印添加工具

《JS+HTML实现在线图片水印添加工具》在社交媒体和内容创作日益频繁的今天,如何保护原创内容、展示品牌身份成了一个不得不面对的问题,本文将实现一个完全基于HTML+CSS构建的现代化图片水印在线工具... 目录概述功能亮点使用方法技术解析延伸思考运行效果项目源码下载总结概述在社交媒体和内容创作日益频繁的

使用Node.js制作图片上传服务的详细教程

《使用Node.js制作图片上传服务的详细教程》在现代Web应用开发中,图片上传是一项常见且重要的功能,借助Node.js强大的生态系统,我们可以轻松搭建高效的图片上传服务,本文将深入探讨如何使用No... 目录准备工作搭建 Express 服务器配置 multer 进行图片上传处理图片上传请求完整代码示例

基于Python实现高效PPT转图片工具

《基于Python实现高效PPT转图片工具》在日常工作中,PPT是我们常用的演示工具,但有时候我们需要将PPT的内容提取为图片格式以便于展示或保存,所以本文将用Python实现PPT转PNG工具,希望... 目录1. 概述2. 功能使用2.1 安装依赖2.2 使用步骤2.3 代码实现2.4 GUI界面3.效

Python实现AVIF图片与其他图片格式间的批量转换

《Python实现AVIF图片与其他图片格式间的批量转换》这篇文章主要为大家详细介绍了如何使用Pillow库实现AVIF与其他格式的相互转换,即将AVIF转换为常见的格式,比如JPG或PNG,需要的小... 目录环境配置1.将单个 AVIF 图片转换为 JPG 和 PNG2.批量转换目录下所有 AVIF 图

详解如何通过Python批量转换图片为PDF

《详解如何通过Python批量转换图片为PDF》:本文主要介绍如何基于Python+Tkinter开发的图片批量转PDF工具,可以支持批量添加图片,拖拽等操作,感兴趣的小伙伴可以参考一下... 目录1. 概述2. 功能亮点2.1 主要功能2.2 界面设计3. 使用指南3.1 运行环境3.2 使用步骤4. 核

Java图片压缩三种高效压缩方案详细解析

《Java图片压缩三种高效压缩方案详细解析》图片压缩通常涉及减少图片的尺寸缩放、调整图片的质量(针对JPEG、PNG等)、使用特定的算法来减少图片的数据量等,:本文主要介绍Java图片压缩三种高效... 目录一、基于OpenCV的智能尺寸压缩技术亮点:适用场景:二、JPEG质量参数压缩关键技术:压缩效果对比

使用Python开发一个简单的本地图片服务器

《使用Python开发一个简单的本地图片服务器》本文介绍了如何结合wxPython构建的图形用户界面GUI和Python内建的Web服务器功能,在本地网络中搭建一个私人的,即开即用的网页相册,文中的示... 目录项目目标核心技术栈代码深度解析完整代码工作流程主要功能与优势潜在改进与思考运行结果总结你是否曾经

Python FastAPI+Celery+RabbitMQ实现分布式图片水印处理系统

《PythonFastAPI+Celery+RabbitMQ实现分布式图片水印处理系统》这篇文章主要为大家详细介绍了PythonFastAPI如何结合Celery以及RabbitMQ实现简单的分布式... 实现思路FastAPI 服务器Celery 任务队列RabbitMQ 作为消息代理定时任务处理完整

使用C#代码在PDF文档中添加、删除和替换图片

《使用C#代码在PDF文档中添加、删除和替换图片》在当今数字化文档处理场景中,动态操作PDF文档中的图像已成为企业级应用开发的核心需求之一,本文将介绍如何在.NET平台使用C#代码在PDF文档中添加、... 目录引言用C#添加图片到PDF文档用C#删除PDF文档中的图片用C#替换PDF文档中的图片引言在当