本文主要介绍如何使用 TensorFlow(tf)创建 tfRecord 文件,希望能为大家解决编程问题提供一定的参考,需要的开发者们跟着小编一起来学习吧!
项目详情请见我的 GitHub 仓库,代码可直接运行:https://github.com/SamXiaosheng/create-tfRecord
下面依次是 main 文件的代码,以及用于创建和读取 tfRecord 文件的 tfRecord.py 的代码:
"""Demo driver: build a TFRecord file from an image directory and preview it.

The script serializes every image found in ``--image_dir`` into
``test.tfRecord`` and then reads the records back one at a time through the
TF 1.x queue-runner input pipeline, printing each label and showing each
image in an OpenCV window.
"""
import tensorflow as tf
import cv2

from tfRecord import create_tfrecords, read_and_decode

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('image_dir', './image/',
                           """Directory containing the input images """)


def main(_):
    """Create ``test.tfRecord`` from FLAGS.image_dir, then display its records.

    The reader queue is built without an epoch limit, so records are served
    repeatedly; press ``q`` in the image window to stop the preview.
    """
    create_tfrecords(FLAGS.image_dir)
    image_batch, label_batch = read_and_decode('test.tfRecord')

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                image, label = sess.run([image_batch, label_batch])
                print(label)
                # batch_size is 1, so the first element is the only image.
                cv2.imshow('image', image[0])
                # waitKey both paces the preview (200 ms per image) and
                # gives the user a way out of the otherwise endless loop.
                if cv2.waitKey(200) & 0xFF == ord('q'):
                    break
        except tf.errors.OutOfRangeError:
            # Raised when an input queue is closed and drained.
            pass
        finally:
            # Always stop and join the queue-runner threads, even when the
            # loop exits through an exception or Ctrl-C.
            coord.request_stop()
            coord.join(threads)
            cv2.destroyAllWindows()


if __name__ == '__main__':
    tf.app.run()
"""Helpers for writing images to a TFRecord file and reading them back."""
import tensorflow as tf
import numpy as np
import os
import cv2


def read_and_decode(filename, image_shape=(227, 227, 3)):
    """Build a TF 1.x queue-based pipeline that yields one record per run.

    Args:
        filename: Path of the TFRecord file to read.
        image_shape: (height, width, channels) each stored image is reshaped
            to. It must match the size used when the file was written.

    Returns:
        A tuple ``(image_batch, label_batch)`` of tensors with batch size 1.
    """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        })

    # The dtype given here must be the dtype of the array that was serialized
    # with tobytes(); a mismatch changes the number of decoded elements.
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, list(image_shape))

    # NOTE(review): create_tfrecords stores the enumerate index as the label;
    # indices above 255 cannot be represented in uint8. Consider a wider
    # dtype if callers do not depend on uint8.
    label = tf.cast(features['label'], tf.uint8)

    # batch_size=1 / capacity=1 so that each run reads exactly one sample.
    image_batch, label_batch = tf.train.shuffle_batch(
        [img, label],
        batch_size=1,
        capacity=1,
        min_after_dequeue=0)
    return image_batch, label_batch


def dirtomdfbatchmsra(dirpath):
    """List the image files directly inside ``dirpath``.

    Args:
        dirpath: Directory to scan (not recursive).

    Returns:
        A tuple ``(gt_maps, images)``: the sorted names ending in ``png`` and
        the sorted names ending in ``jpg``. Sorting keeps the two lists in a
        stable order so that entries can be paired by position.
    """
    filenames = sorted(os.listdir(dirpath))
    images = [fn for fn in filenames if fn.endswith('jpg')]
    gt_maps = [fn for fn in filenames if fn.endswith('png')]
    return gt_maps, images


def _write_examples(writer, image_dir, names, size):
    """Serialize each named image as one Example; the label is its index."""
    for index, name in enumerate(names):
        path = os.path.join(image_dir, name)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError("cv2.imread could not read image: %s" % path)
        # Resize to a uniform size so every record has the same byte length.
        img = cv2.resize(img.astype(np.uint8), size)
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[img.tobytes()])),
        }))
        writer.write(example.SerializeToString())


def create_tfrecords(image_dir, output_path="test.tfRecord", size=(227, 227)):
    """Write every jpg and png image in ``image_dir`` to a TFRecord file.

    All jpg images are written first, followed by all png images. Within each
    group the label is the image's position in the sorted file list.

    Args:
        image_dir: Directory containing the images.
        output_path: Destination TFRecord file.
        size: (width, height) passed to ``cv2.resize``.

    Raises:
        IOError: If an image file cannot be read by OpenCV.
    """
    image_png, image_jpg = dirtomdfbatchmsra(image_dir)
    writer = tf.python_io.TFRecordWriter(output_path)
    try:
        _write_examples(writer, image_dir, image_jpg, size)
        _write_examples(writer, image_dir, image_png, size)
    finally:
        # Close the file even if an image fails to load part-way through.
        writer.close()
这篇关于使用 TensorFlow 创建 tfRecord 文件的文章就介绍到这儿,希望我们推荐的文章对程序员们有所帮助!