本文主要介绍如何在 TensorFlow 中读取 TFRecord 格式的数据（II），希望能为大家解决编程问题提供一定的参考价值，需要的开发者们可以随小编一起学习！
tensorflow读取数据-tfrecord格式(II)
上一篇博文介绍了 TensorFlow 中 TFRecord 格式的读写方法，接下来以保存和读取图片数据为例，详细展示 Python 实现代码。
1、单张图片（single picture）
# -*- coding: utf-8 -*-
"""
Spyder Editor"""############single picture
import os
import tensorflow as tf
import cv2
from matplotlib import pyplot as plt
import numpy as np


def write_tfrecords(input, output):
    """Decode one image file and store it as a single record in a TFRecord file.

    One ``tf.train.Example`` is written to ``output`` with three features:

    * ``name``  -- the constant bytes ``b"example"``;
    * ``shape`` -- the decoded array's shape as three int64 values;
    * ``data``  -- the raw uint8 pixel bytes of the decoded image.

    Args:
        input: Path of the image file to read. (The name shadows the builtin
            ``input``; it is kept so existing keyword callers still work.)
        output: Path of the TFRecord file to create or overwrite.
    """
    # decode_image picks the decoder from the file header, so the .png file
    # used below is decoded as PNG rather than being routed through
    # decode_jpeg. Assumes a still image (PNG/JPEG/BMP -> rank-3 array);
    # an animated GIF would decode to rank 4 and break the 3-entry shape.
    image = tf.image.decode_image(tf.read_file(input))
    with tf.Session() as sess:
        image = sess.run(image)
    shape = image.shape
    # Raw pixel bytes; tobytes() replaces the deprecated ndarray.tostring().
    image_data = image.tobytes()
    print(type(image))
    print(len(image_data))
    name = bytes("example", encoding='utf8')
    print(type(name))
    # Build the Example, filling in one Feature per field.
    example = tf.train.Example(features=tf.train.Features(feature={
        'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
        'shape': tf.train.Feature(int64_list=tf.train.Int64List(
            value=[shape[0], shape[1], shape[2]])),
        'data': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[image_data])),
    }))
    # The writer is opened only after decoding succeeded, and the context
    # manager closes (and flushes) it even if writing raises.
    with tf.python_io.TFRecordWriter(output) as writer:
        writer.write(example.SerializeToString())


write_tfrecords('/Users/mac/MyProjects/SR/datasets/lr_img/lr_062.png', 'example.tfrecord')


def _parse_record(example_proto):
    """Parse one serialized Example produced by ``write_tfrecords``.

    Returns a dict of tensors: ``name`` (scalar string), ``shape`` (int64
    vector of length 3) and ``data`` (scalar string holding raw pixels).
    """
    features = {
        'name': tf.FixedLenFeature((), tf.string),
        'shape': tf.FixedLenFeature([3], tf.int64),
        'data': tf.FixedLenFeature((), tf.string),
    }
    return tf.parse_single_example(example_proto, features=features)


def read_tfrecords(input_file):
    """Read the first record of a TFRecord file, show it, and save it as JPEG.

    The decoded image is displayed with matplotlib and re-encoded to
    ``example_encode.jpg`` in the current directory.

    Args:
        input_file: Path of a TFRecord file written by ``write_tfrecords``.
    """
    dataset = tf.data.TFRecordDataset(input_file).map(_parse_record)
    iterator = dataset.make_one_shot_iterator()
    with tf.Session() as sess:
        features = sess.run(iterator.get_next())
        name = features['name'].decode()
        img_data = features['data']
        shape = features['shape']
        print('=======')
        print(name)
        print(type(shape))
        print(len(img_data))
        # Rebuild the pixel array from the raw bytes; frombuffer replaces the
        # deprecated binary mode of np.fromstring.
        image_data = np.frombuffer(img_data, dtype=np.uint8).reshape(shape)
        # Display the image.
        plt.figure()
        plt.imshow(image_data)
        plt.show()
        # Re-encode as JPEG and save; `with` guarantees the file is flushed
        # and closed. encode_jpeg accepts only 1 or 3 channels, so an RGBA
        # source image would fail here.
        jpeg_bytes = sess.run(tf.image.encode_jpeg(image_data))
        with tf.gfile.GFile('example_encode.jpg', 'wb') as f:
            f.write(jpeg_bytes)


read_tfrecords('example.tfrecord')
2、多张图片（multi pictures）
# -*- coding: utf-8 -*-
"""
Spyder EditorThis is a tfrecords script file.
"""##############multi pictures
import os
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np


def write_tfrecords(input_path, output):
    """Write every ``.png`` image in a directory into one TFRecord file.

    Files are processed in sorted name order so that record ``i`` always maps
    to the same source file. Each image becomes one ``tf.train.Example``:

    * ``name``  -- the constant bytes ``b"train"``;
    * ``shape`` -- the decoded array's shape as three int64 values;
    * ``data``  -- the raw uint8 pixel bytes.

    Args:
        input_path: Directory to scan (with or without a trailing separator).
        output: Path of the TFRecord file to create or overwrite.
    """
    file_names = sorted(f for f in os.listdir(input_path) if f.endswith('.png'))
    name = bytes("train", encoding='utf8')
    # Build the read+decode ops ONCE and feed each path through a placeholder.
    # Creating new ops (and a new Session) per file makes the default graph
    # grow with every image processed.
    path_ph = tf.placeholder(tf.string, shape=[])
    decoded = tf.image.decode_image(tf.read_file(path_ph))
    with tf.Session() as sess, tf.python_io.TFRecordWriter(output) as writer:
        for file_name in file_names:
            # os.path.join works whether or not input_path ends with '/'.
            file_path = os.path.join(input_path, file_name)
            image = sess.run(decoded, feed_dict={path_ph: file_path})
            shape = image.shape
            example = tf.train.Example(features=tf.train.Features(feature={
                'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
                'shape': tf.train.Feature(int64_list=tf.train.Int64List(
                    value=[shape[0], shape[1], shape[2]])),
                'data': tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[image.tobytes()])),
            }))
            writer.write(example.SerializeToString())


def _parse_record(example_proto):
    """Parse one serialized Example produced by ``write_tfrecords``.

    Returns a dict of tensors: ``name`` (scalar string), ``shape`` (int64
    vector of length 3) and ``data`` (scalar string holding raw pixels).
    """
    features = {
        'name': tf.FixedLenFeature((), tf.string),
        'shape': tf.FixedLenFeature([3], tf.int64),
        'data': tf.FixedLenFeature((), tf.string),
    }
    return tf.parse_single_example(example_proto, features=features)


def read_tfrecords(num, input_file):
    """Decode up to ``num`` records and save each as ``train_encode<i>.jpg``.

    Args:
        num: Maximum number of records to read. Reading stops early, without
            error, if the file holds fewer records.
        input_file: Path of a TFRecord file written by ``write_tfrecords``.
    """
    dataset = tf.data.TFRecordDataset(input_file).map(_parse_record)
    iterator = dataset.make_one_shot_iterator()
    # Create every op once, outside the loop. Calling get_next() or
    # encode_jpeg() per iteration would add new nodes to the graph each time.
    next_features = iterator.get_next()
    pixels_ph = tf.placeholder(tf.uint8, shape=[None, None, None])
    encoded = tf.image.encode_jpeg(pixels_ph)
    with tf.Session() as sess:
        for i in range(num):
            try:
                features = sess.run(next_features)
            except tf.errors.OutOfRangeError:
                break  # fewer records than `num`
            print('=======')
            # Rebuild the pixel array from the raw bytes (np.frombuffer
            # replaces the deprecated binary mode of np.fromstring).
            image_data = np.frombuffer(features['data'], dtype=np.uint8)
            image_data = image_data.reshape(features['shape'])
            # Re-encode as JPEG (1 or 3 channels only) and save; `with`
            # guarantees the file is flushed and closed.
            jpeg_bytes = sess.run(encoded, feed_dict={pixels_ph: image_data})
            with tf.gfile.GFile('train_encode' + str(i) + '.jpg', 'wb') as f:
                f.write(jpeg_bytes)


if __name__ == '__main__':
    input_path = '/Users/MyProjects/SR/datasets/lr_img/'
    output = 'train.tfrecords'
    write_tfrecords(input_path, output)
    print('Write tfrecords: %s done' % output)
    file_names = [f for f in os.listdir(input_path) if f.endswith('.png')]
    num = len(file_names)
    read_tfrecords(num, output)
这篇关于 TensorFlow 读取数据——TFRecord 格式（II）的文章就介绍到这儿，希望我们推荐的文章对程序员们有所帮助！