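Building and Reading Your Own Dataset with TFRecord in TensorFlow
References:
http://blog.csdn.net/freedom098/article/details/56011858
A video tutorial by Kevin on Youku
Goals:
1. Store a custom dataset in TFRecord format.
2. Read the data back from the TFRecord file and display it as an image with a plotting tool.
Using a single image as an example:
I. Storing the image in a TFRecord file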
import tensorflow as tf  # written against the TensorFlow 1.x API

# Build an integer feature
def _int64_feature(value):
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

# Build a bytes (string) feature
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# Read images from their paths and store them in a TFRecord file
def convert_to_tfrecord(images, labels, save_dir, name):
    '''Convert all images and labels to one TFRecord file.
    Args:
        images: list of image file paths, string type
        labels: list of labels, int type
        save_dir: the directory to save the TFRecord file, e.g.: '/home/folder1/'
        name: the name of the TFRecord file, string type, e.g.: 'train'
    Return:
        no return
    Note:
        converting needs some time, be patient...
    '''
    filename = save_dir + name + '.tfrecords'
    n_samples = len(labels)

    # Check that the number of images matches the number of labels
    if len(images) != n_samples:
        raise ValueError('Images size %d does not match label size %d.' % (len(images), n_samples))

    # Build the PNG decoding op once and feed each file into it
    png_bytes = tf.placeholder(tf.string)
    img_data = tf.image.decode_png(png_bytes, channels=3)

    writer = tf.python_io.TFRecordWriter(filename)
    print('\nTransform start......')
    with tf.Session() as sess:
        for i in range(n_samples):
            try:
                # Read the encoded file in binary mode
                image_raw_data = tf.gfile.FastGFile(images[i], 'rb').read()
                # Decode to a uint8 array and keep its raw pixel bytes,
                # which is what tf.decode_raw expects when reading back
                image_raw = sess.run(img_data, feed_dict={png_bytes: image_raw_data}).tobytes()
                label = int(labels[i])
                example = tf.train.Example(features=tf.train.Features(feature={
                    'label': _int64_feature(label),
                    'image_raw': _bytes_feature(image_raw)}))
                writer.write(example.SerializeToString())
            except (IOError, tf.errors.OpError) as e:
                print('Could not read:', images[i])
                print('error: %s' % e)
                print('Skip it!\n')
    writer.close()
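convert_to_tfrecord expects two parallel lists: image paths and integer labels. The post does not show how those lists are built, so the following is only a minimal sketch under an assumed folder-per-class layout (root_dir/<integer label>/<image files>). The helper name list_images_and_labels, its parameters, and the layout itself are hypothetical and are not part of the original code.

```python
import os


def list_images_and_labels(root_dir, extension='.png'):
    """Collect image paths and integer labels from a folder-per-class layout.

    Assumed (hypothetical) layout: root_dir/<integer label>/<image files>.
    """
    images, labels = [], []
    for class_name in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, class_name)
        # Skip stray files and folders whose name is not an integer label.
        if not os.path.isdir(class_dir) or not class_name.isdigit():
            continue
        for file_name in sorted(os.listdir(class_dir)):
            # Keep only files with the requested extension.
            if file_name.lower().endswith(extension):
                images.append(os.path.join(class_dir, file_name))
                labels.append(int(class_name))
    return images, labels
```

The two returned lists have the same length and can be passed as the images and labels arguments of convert_to_tfrecord.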
II. Reading the data and plotting it
import matplotlib.pyplot as plt
import tensorflow as tf  # written against the TensorFlow 1.x API

# Read the data back from the TFRecord file
def read_and_decode(tfrecords_file):
    '''Read and decode a TFRecord file, then display the first image.
    Args:
        tfrecords_file: the path of the TFRecord file
    '''
    # Make an input queue from the TFRecord file
    filename_queue = tf.train.string_input_producer([tfrecords_file])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Parse the serialized example that was read
    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })

    # Decode the byte string into a flat uint8 array
    image = tf.decode_raw(img_features['image_raw'], tf.uint8)
    label = tf.cast(img_features['label'], tf.int32)

    # Restore the image shape [height, width, channels]
    image = tf.reshape(image, [465, 315, 3])

    sess = tf.Session()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    image, label = sess.run([image, label])
    print(image)
    plt.imshow(image)
    plt.show()
    # Stop the queue threads before closing the session
    coord.request_stop()
    coord.join(threads)
    sess.close()

read_and_decode('/home/tensor/Desktop/tia.tfrecords')
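The reader hardcodes the shape [465, 315, 3] because a raw byte string carries no shape information: tf.decode_raw only yields a flat array, and the reshape has to supply the dimensions. The sketch below illustrates that round trip in plain Python, without TensorFlow. It is an analogy rather than the code TensorFlow runs, and the helper names pixels_to_raw and raw_to_pixels are hypothetical.

```python
def pixels_to_raw(pixels):
    """Flatten a nested [height][width][channels] list of 0-255 ints into bytes."""
    return bytes(value for row in pixels for pixel in row for value in pixel)


def raw_to_pixels(raw, height, width, channels):
    """Rebuild the nested list; the caller must supply the shape."""
    expected = height * width * channels
    if len(raw) != expected:
        raise ValueError('got %d bytes, expected %d' % (len(raw), expected))
    row_size = width * channels
    return [
        [list(raw[r * row_size + c * channels:r * row_size + (c + 1) * channels])
         for c in range(width)]
        for r in range(height)
    ]


# A tiny 2 x 2 RGB image stands in for the 465 x 315 x 3 picture.
pixels = [[[255, 0, 0], [0, 255, 0]],
          [[0, 0, 255], [10, 20, 30]]]
raw = pixels_to_raw(pixels)
restored = raw_to_pixels(raw, 2, 2, 3)
```

If the dimensions passed when reading differ from those of the stored image, the byte count no longer matches and the reshape fails, so images of varying size would need their height and width stored as extra features.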
III. Result