TFRecord格式存储数据与队列读取实例

下面详细讲解“TFRecord格式存储数据与队列读取实例”的完整攻略。本文将包含两个具体的示例说明,以帮助读者更好地理解和掌握相关知识。

什么是TFRecord格式?

TFRecord是一种TensorFlow的数据格式,它是一种二进制格式,可以更加高效地存储数据,方便数据的快速读取和处理。

使用TFRecord的好处包括:

  • 无需通过大量的代码去读取和处理数据;
  • 快速的并行化数据处理的方法;
  • 可以将多个文件合并成一个文件,方便读取。

TFRecord格式通常用于存储大量的数据。

TFRecord格式的存储

可以使用Python的Protocol Buffer库来存储数据。Protocol Buffer是Google开发的用于序列化结构化数据的一种格式。它可以将数据进行编码,然后以二进制格式进行存储。

存储数据的步骤如下:

  1. 创建一个tf.train.Example对象,这个对象包含了需要存储的信息。

```python
import tensorflow as tf

# 创建一个字典用于存储特征
feature_dict = {'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}

# 创建一个Example对象
example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
```

  1. tf.train.Example对象序列化为一个字符串:

python
serialized_example = example.SerializeToString()

  1. 将字符串写入TFRecord文件:

python
writer.write(serialized_example)

使用队列读取TFRecord格式的数据

可以先创建一个输入队列,然后通过读取队列中的元素来获取数据,具体的步骤如下:

  1. 创建一个输入队列:

python
filename_queue = tf.train.string_input_producer([filename])

  1. 定义一个TFRecordReader对象读取数据:

python
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

  1. 将序列化后的字符串解析为一个tf.train.Example对象:

python
features = tf.parse_single_example(serialized_example, features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64)})
image = tf.decode_raw(features['image'], tf.uint8)
label = tf.cast(features['label'], tf.int32)

  1. 对读取的数据进行预处理:

python
image = tf.reshape(image, [height, width, num_channels])
image = tf.cast(image, tf.float32)
image /= 255.0

  1. 创建一个batch:

python
images_batch, labels_batch = tf.train.shuffle_batch([image, label], batch_size=batch_size,
capacity=capacity, min_after_dequeue=min_after_dequeue)

通过以上步骤,一个输入数据的队列就创建好了,现在就可以通过sess.run的方式去读取数据了。

示例1:存储MNIST图像数据和标签数据为TFRecord文件

下面是一个例子,展示如何将MNIST图像数据和标签数据保存为TFRecord文件。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 数据集参数
num_examples = input_data.train.num_examples
num_classes = 10

# 文件路径和名称
filename = 'mnist.tfrecords'

# 将数据写入TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
for i in range(num_examples):
    image_raw = input_data.train.images[i].tostring()
    label_raw = input_data.train.labels[i].astype(int).tostring()

    feature_dict = {'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
                    'label_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_raw]))}

    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
    writer.write(example.SerializeToString())
writer.close()

示例2:读取TFRecord格式的数据并进行训练

下面是一个例子,展示如何从TFRecord文件中读取数据,并进行训练。

import tensorflow as tf

# 数据集参数
batch_size = 128
capacity = 10000
min_after_dequeue = 3000

# 文件路径和名称
filename = 'mnist.tfrecords'

# 创建一个输入队列
filename_queue = tf.train.string_input_producer([filename])

# 定义一个TFRecordReader对象读取数据
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

# 将序列化后的字符串解析为一个Example对象
features = tf.parse_single_example(serialized_example, features={
    'image_raw': tf.FixedLenFeature([], tf.string),
    'label_raw': tf.FixedLenFeature([], tf.string)})

# 将图像数据解析为一个Tensor
image = tf.decode_raw(features['image_raw'], tf.uint8)
image.set_shape([784])

# 将标签数据解析为一个Tensor
label = tf.decode_raw(features['label_raw'], tf.int32)
label.set_shape([])

# 进行数据预处理
image = tf.cast(image, tf.float32)
image /= 255.0

# 创建一个batch
images_batch, labels_batch = tf.train.shuffle_batch([image, label], batch_size=batch_size,
                                                    capacity=capacity, min_after_dequeue=min_after_dequeue)

# 构建模型
x = tf.placeholder(tf.float32, [batch_size, 784])
y = tf.placeholder(tf.int32, [batch_size])
logits = tf.layers.dense(x, num_classes, activation=None)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y))
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)

# 训练模型
with tf.Session() as sess:
    # 初始化所有变量
    sess.run(tf.global_variables_initializer())

    # 启动输入队列
    tf.train.start_queue_runners()

    # 迭代训练
    for i in range(1000):
        images, labels = sess.run([images_batch, labels_batch])
        _, l = sess.run([train_op, loss], feed_dict={x: images, y: labels})

        if i % 10 == 0:
            print('Step {:5d}: loss = {:.3f}'.format(i, l))

通过上面两个示例,相信大家已经初步了解了如何使用TFRecord格式存储数据并进行队列读取。当然,实际应用中,还需要根据不同的数据集,进行一些细节的调整和处理。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:TFRecord格式存储数据与队列读取实例 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • tensorflow1.0 构建卷积神经网络

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import os os.environ[“CUDA_DEVICE_ORDER”] = “0,1” mnist = input_data.read_data_sets(“MNIST_data”,…

    卷积神经网络 2023年4月6日
    00
  • MINST手写数字识别(二)—— 卷积神经网络(CNN)

          今天我们的主角是keras,其简洁性和易用性简直出乎David 9我的预期。大家都知道keras是在TensorFlow上又包装了一层,向简洁易用的深度学习又迈出了坚实的一步。       所以,今天就来带大家写keras中的Hello World , 做一个手写数字识别的cnn。回顾cnn架构: 我们要处理的是这样的灰度像素图:   我们先来看…

    2023年4月7日
    00
  • 1-11 为什么使用卷积?

    为什么使用卷积?(Why convolutions?) 和只用全连接层相比,卷积层的两个主要优势在于参数共享和稀疏连接: 假设有一张 32×32×3 维度的图片,假设用了 6 个大小为 5×5 的过滤器,输出维度为 28×28×6。32×32×3=3072, 28×28×6=4704。我们构建一个神经网络,其中一层含有 3072 个单元,下一层含有 4074…

    2023年4月8日
    00
  • <转>卷积神经网络是如何学习到平移不变的特征

    After some thought, I do not believe that pooling operations are responsible for the translation invariant property in CNNs. I believe that invariance (at least to translation) is …

    2023年4月8日
    00
  • tensorflow中的卷积和池化层(一)

    在官方tutorial的帮助下,我们已经使用了最简单的CNN用于Mnist的问题,而其实在这个过程中,主要的问题在于如何设置CNN网络,这和Caffe等框架的原理是一样的,但是tf的设置似乎更加简洁、方便,这其实完全类似于Caffe的python接口,但是由于框架底层的实现不一样,tf无论是在单机还是分布式设备上的实现效率都受到一致认可。 CNN网络中的卷积…

    卷积神经网络 2023年4月6日
    00
  • 循环卷积与任意长度FFT

    在之前的DFT中有n^2的循环卷积 考虑式子为 的暴力卷积 拆分nk为 对于Xk,k^2/2是常值 于是 可以发现后半部分是关于n和(k-n)的卷积。 可以得到点值。 逆运算可以推出相对的式子即可。

    2023年4月8日
    00
  • EdgeFormer: 向视觉 Transformer 学习,构建一个比 MobileViT 更好更快的卷积网络

    ​  前言 本文主要探究了轻量模型的设计。通过使用 Vision Transformer 的优势来改进卷积网络,从而获得更好的性能。 欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。 ​ 论文:https://arxiv.org/abs/2203.03952 代码:https://github.com/hkz…

    卷积神经网络 2023年4月7日
    00
  • pytorch DistributedDataParallel 多卡训练结果变差的解决方案

    为了解决pytorch DistributedDataParallel多卡训练结果变差的问题,我们可以采用以下解决方案: 数据加载器设置shuffle参数 在使用多卡训练时,我们需要使用torch.utils.data.DistributedSampler并设置shuffle参数为True。这可以确保数据在多机多卡之间均匀地分配,从而避免了训练结果变差的原因…

    卷积神经网络 2023年5月15日
    00
合作推广
合作推广
分享本页
返回顶部