- 什么是tfrecord文件?
tfrecord是tensorflow中定义的一种二进制数据存储格式,它可以将一个或多个样本数据转化成二进制序列,并将多个二进制序列拼接成一个二进制文件。这种方式将大量的数据存储在单个文件中,具有良好的读写性能,有利于数据加载和处理。
- 如何生成tfrecord文件?
生成tfrecord文件需要以下四个步骤:
(1)将数据存储到一个或多个特定格式的数据文件中,如csv、txt、图片等文件。
(2)使用tensorflow提供的dataset API或tf.python_io.TFRecordWriter将数据文件中的数据转化为Example格式的protobuf消息。
(3)将Example消息写入到TFRecord文件中。
(4)对于需要测量的度量指标,可以采用tf.summary方式将它们汇总到TensorBoard中。
下面是一个根据图片数据生成tfrecord文件的示例代码:
import tensorflow as tf
import numpy as np
import os
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def convert_to_example(image_data, label):
example = tf.train.Example(features=tf.train.Features(feature={
'image': _bytes_feature(image_data),
'label': _int64_feature(label),
}))
return example
def _add_to_tfrecord(data_dir, name, tfrecord_writer):
img_dir = os.path.join(data_dir, name)
if name == 'cats':
label = 0
else:
label = 1
for img_name in os.listdir(img_dir):
img_path = os.path.join(img_dir, img_name)
img_data = tf.gfile.FastGFile(img_path, 'rb').read()
example = convert_to_example(img_data, label)
tfrecord_writer.write(example.SerializeToString())
def run(data_dir, output_dir, shuffling=True):
if not tf.gfile.Exists(output_dir):
tf.gfile.MakeDirs(output_dir)
tfrecord_filename = os.path.join(output_dir, 'cats_vs_dogs.tfrecords')
with tf.python_io.TFRecordWriter(tfrecord_filename) as tfrecord_writer:
name_list = ['cats', 'dogs']
if shuffling:
np.random.shuffle(name_list)
for name in name_list:
_add_to_tfrecord(data_dir, name, tfrecord_writer)
print('Successfully encoded dataset.')
- 如何读取tfrecord文件?
读取tfrecord文件需要以下三个步骤:
(1)创建一个TFRecordReader实例。
(2)使用该实例读取TFRecord文件中的Example数据。
(3)使用tf.parse_single_example对Example消息进行解析。
下面是一个从已有的tfrecord文件读取并解码的示例代码:
import tensorflow as tf
def decode(serialized_example):
features = tf.parse_single_example(
serialized_example,
features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
})
image = tf.decode_raw(features['image'], tf.uint8)
image.set_shape([None])
image = tf.reshape(image, [28, 28, 1])
label = tf.cast(features['label'], tf.int32)
return image, label
def input_fn(data_dir, mode, batch_size):
tfrecord_filename = os.path.join(data_dir, mode + '.tfrecords')
dataset = tf.data.TFRecordDataset(tfrecord_filename)
if mode == 'train':
dataset = dataset.map(decode).repeat().batch(batch_size).shuffle(buffer_size=10000)
else:
dataset = dataset.map(decode).batch(batch_size)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
以上便是tfrecord文件的生成与读取的完整攻略,希望对你有所帮助。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:tensorflow学习笔记之tfrecord文件的生成与读取 - Python技术站