Tensorflow分类器项目自定义数据读入的实现

yizhihongxing

1.准备工作

在进行Tensorflow分类器项目的自定义数据读入之前,需要做好以下准备工作:

1)安装Tensorflow库

2)准备自定义数据集

这里以mnist手写数字数据集为例,数据集存储方式是将训练数据和测试数据分别存储在不同的文件中,其中每个样本由784个像素值以及对应的数字标签构成,每行代表一张图片。

2.自定义数据读入

Tensorflow已经为我们提供了许多API便于读入数据,但是如果我们想读入自己定义的数据,则需要使用tf.data方法进行读入。具体步骤如下:

1)定义输入流水线

首先定义输入流水线,我们可以使用tf.data.TFRecordDataset来实现这一步骤,代码如下:

def parser(record):
    features = tf.parse_single_example(
        record,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        })
    image = tf.decode_raw(features['image'], tf.uint8)
    image = tf.reshape(image, [28, 28])
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)

    return image, label

def input_fn(is_training, filenames, batch_size, num_epochs=1):
    dataset = tf.data.TFRecordDataset(filenames)

    if is_training:
        dataset = dataset.shuffle(buffer_size=50000)

    dataset = dataset.map(parser, num_parallel_calls=4)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()

    images, labels = iterator.get_next()

    return images, labels

2)调用输入流水线

调用定义好的输入流水线,代码如下:

train_filenames = ['train.tfrecords']
eval_filenames = ['test.tfrecords']

train_input_fn = lambda: input_fn(True, train_filenames, batch_size, num_epochs)
eval_input_fn = lambda: input_fn(False, eval_filenames, batch_size, num_epochs)

其中,train_filenames和eval_filenames是我们存储train数据和test数据的文件名。

3.示例说明

下面给出两个使用自定义数据读入的示例:

(1)实现AlexNet模型

在AlexNet模型中,需要用到imagenet数据集,由于imagenet数据集非常大,所以需要使用集群模式训练。在每个worker上进行训练的时候,我们需要自定义数据读入方式,具体实现可以参考上述的步骤。代码如下:

# define dataset
def _dataset_parser(serialized_example):
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string, ''),
            'image/class/label': tf.FixedLenFeature([], tf.int64, -1),
            'image/height': tf.FixedLenFeature([], tf.int64, 224),
            'image/width': tf.FixedLenFeature([], tf.int64, 224),
            'image/channels': tf.FixedLenFeature([], tf.int64, 3),
        })

    image = tf.decode_raw(features['image/encoded'], tf.uint8)
    image = tf.cast(image, tf.float32)

    label = tf.cast(features['image/class/label'], tf.int32)

    return image, label

class InputPipeline:
    def __init__(self, filenames, batch_size, is_training):
        self.filenames = filenames
        self.batch_size = batch_size
        self.is_training = is_training

        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.map(_dataset_parser, num_parallel_calls=16)

        if self.is_training:
            dataset = dataset.shuffle(buffer_size=10000)

        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(buffer_size=1)

        self.iterator = dataset.make_initializable_iterator()
        self.images, self.labels = self.iterator.get_next()

    def initialize(self, sess):
        sess.run(self.iterator.initializer)

    def get_batch(self):
        return self.images, self.labels

input_pipeline = InputPipeline(filenames_train, batch_size=BATCH_SIZE, is_training=True)

(2)实现Faster RCNN模型

在Faster RCNN模型中,需要使用coco数据集。由于coco数据集也非常大,所以需要使用集群模式训练。在进行训练时,需要自定义数据读入方式,并进行数据增强。具体实现可以参考以下代码:

class InputReader(object):
    def __init__(self, sess, tfrecords=[], is_training=True, num_threads=4):
        '''
        tfrecords: 需要读取的TFRecord文件列表
        is_training: 是否进行训练
        num_threads: 多线程读取tfrecords
        '''
        self.sess = sess
        self.is_training = is_training
        self.num_threads = min(num_threads, len(tfrecords))
        self.reader = tf.TFRecordReader()
        self.queue = tf.train.string_input_producer(tfrecords, shuffle=self.is_training)
        self.batch_size = FLAGS.batch_size if self.is_training else FLAGS.test_batch_size

    def parse_single_example(self, example, is_training):
        '''
        example: 单个样例的字节串形式
        is_training: 是否进行训练
        '''
        if is_training:
            image_size = FLAGS.train_image_size
            min_object_covered = 0.5
            crop_ratio_range = (0.75, 1.25)
            max_attempts = 200
            random_mirror = True
            random_flip = True
        else:
            image_size = FLAGS.test_image_size
            min_object_covered = 0.0
            crop_ratio_range = (1.0, 1.0)
            max_attempts = 1
            random_mirror = False
            random_flip = False

        features = tf.parse_single_example(example, features={
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/format': tf.FixedLenFeature([], tf.string),
            'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
            'image/object/class/label': tf.VarLenFeature(dtype=tf.int64),
            'image/object/class/text': tf.VarLenFeature(dtype=tf.string),
            'image/object/difficult': tf.VarLenFeature(dtype=tf.int64),
            'image/object/truncated': tf.VarLenFeature(dtype=tf.int64),
            'image/filename': tf.FixedLenFeature([], dtype=tf.string)
        })

        image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
        filename = tf.cast(features['image/filename'], tf.string)
        short_size = tf.minimum(tf.shape(image)[:2])
        image = aspect_preserving_resize(image, short_size, image_size)
        scale_factor = tf.cast(image_size / tf.cast(short_size, tf.float32), tf.float32)
        bboxes = tf.stack([
            tf.sparse_tensor_to_dense(features['image/object/bbox/ymin']) * scale_factor,
            tf.sparse_tensor_to_dense(features['image/object/bbox/xmin']) * scale_factor,
            tf.sparse_tensor_to_dense(features['image/object/bbox/ymax']) * scale_factor,
            tf.sparse_tensor_to_dense(features['image/object/bbox/xmax']) * scale_factor
        ], axis=-1)
        labels = tf.sparse_tensor_to_dense(features['image/object/class/label']) - 1
        difficult = tf.sparse_tensor_to_dense(features['image/object/difficult'])
        truncated = tf.sparse_tensor_to_dense(features['image/object/truncated'])

        if is_training:
            image, bboxes = distort_image_with_autoaugment(image, bboxes)

        image = tf.cast(image, tf.float32)

        if is_training:
            image = tf.image.random_crop(image, [image_size, image_size, 3])

        if random_mirror:
            image, bboxes = random_horizontal_flip(image, bboxes)

        if random_flip:
            image, bboxes = random_vertical_flip(image, bboxes)

        image, bboxes = pad_image_and_labels(image, bboxes, None, None, image_size, is_training)

        num_valid_boxes = tf.reduce_sum(tf.cast(tf.logical_not(tf.logical_or(tf.equal(difficult, 1), tf.equal(truncated, 1))), tf.float32))
        return image, bboxes, labels, num_valid_boxes, filename

    def read(self):
        '''
        读取TFRecord文件列表中所有的数据
        '''
        # Read serialized example data from Dataset. `read_up_to` performs dynamic batching so that if there
        # aren't enough queued items to fill a batch, it won't wait until it has `batch_size` items in the
        # queue. Note: This will break incomplete batches in one of three ways, depending on the value of
        # `ALLOWED_PADDING_DECREMENTS`:
        # - If negative, pads with zeros to fill out the current batch.
        # - If zero, discards the remainder so that all batches are full.
        # - If positive, adds padding to the remainder so that it fills the current (incomplete) batch.
        batch_images, batch_bboxes, batch_labels, batch_num_valid_boxes, batch_filenames = \
            tf.train.batch(
                [self.parse_single_example(example, self.is_training) for _ in range(self.num_threads)],
                self.batch_size,
                num_threads=self.num_threads,
                capacity=1024,
                allow_smaller_final_batch=True,
            )

        batch_images = tf.identity(batch_images, 'image')
        batch_labels = tf.identity(batch_labels, 'label')
        batch_bboxes = tf.identity(batch_bboxes, 'bbox')
        batch_num_valid_boxes = tf.identity(batch_num_valid_boxes, 'num_valid_boxes')
        batch_filenames = tf.identity(batch_filenames, 'filename')

        return batch_images, batch_labels, batch_bboxes, batch_num_valid_boxes, batch_filenames

以上是两个使用自定义数据读入的示例,可以根据需求进行修改和参考。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Tensorflow分类器项目自定义数据读入的实现 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • Android屏幕旋转 处理Activity与AsyncTask的最佳解决方案

    这是一个涉及到Android屏幕旋转以及在旋转中处理Activity和AsyncTask的问题。以下是处理这个问题的最佳解决方案。 问题说明 在Android中,当屏幕旋转时,Activity将会被销毁并重新创建。此外,AsyncTask的生命周期会在Activity的生命周期内更改。如果不正确处理屏幕旋转和AsyncTask的生命周期,可能会导致应用程序的…

    人工智能概览 2023年5月25日
    00
  • Python3之外部文件调用Django程序操作model等文件实现方式

    下面为你讲解Python3之外部文件调用Django程序操作model等文件实现方式的攻略: 1. 配置环境及导入模块 首先,确保你已经配置好Django环境,并安装好了相关的Python库,如django、os等。 接下来,在外部文件中导入Django应用的model和相关需要的库: import os import django # 设置 Django配…

    人工智能概览 2023年5月25日
    00
  • 修改Nginx与Apache上传文件大小限制

    针对修改Nginx和Apache上传文件大小限制的问题,我将为您分享以下完整攻略。 修改Nginx上传文件大小限制 Nginx的上传文件大小限制包括两个参数,分别为client_max_body_size和client_body_buffer_size。 1. 修改client_max_body_size 第一步,修改Nginx配置文件中的client_ma…

    人工智能概览 2023年5月25日
    00
  • 对python中的six.moves模块的下载函数urlretrieve详解

    对python中的six.moves模块的下载函数urlretrieve详解 介绍 six.moves是由six模块提供的一个适用于Python 2和3的兼容性工具,致力于让开发者在Python 2/3之间轻松移植。常用的六个子模块:- builtins- configparser- http_client- urllib- queue- xrange si…

    人工智能概览 2023年5月25日
    00
  • Go 代码规范错误处理示例经验总结

    下面是关于“Go 代码规范错误处理示例经验总结”的完整攻略。 什么是错误处理 错误处理是指在软件开发过程中处理程序运行过程中可能出现的错误的一种方式。在Go语言中,错误处理通常使用返回值来表示,而不是抛出异常(类似于Java或Python的做法)。因此,Go程序员需要养成规范正确的错误处理习惯来保证程序的健壮性和可维护性。 错误处理的代码规范 把错误信息放在…

    人工智能概览 2023年5月25日
    00
  • PHP编译configure时常见错误的总结

    PHP编译configure时常见错误的总结 在编译PHP时,configure是非常重要的一个步骤,不能正确进行configure,之后的make和make install都有可能失败,因此,总结一些常见的configure错误并解决这些错误是非常必要的。 1. configure: error: Cannot find OpenSSL’s 这个错误是因为…

    人工智能概览 2023年5月25日
    00
  • 利用Spring Boot如何开发REST服务详解

    利用Spring Boot开发REST服务的详细攻略如下: 1. 搭建Spring Boot项目环境 首先,我们需要创建一个Spring Boot项目。具体步骤如下: 在IDE中创建一个新的Maven项目,并打开“pom.xml”文件。 在“pom.xml”文件中添加Spring Boot的依赖项,如下所示: <dependency> <g…

    人工智能概论 2023年5月25日
    00
  • Bootstrap实现登录校验表单(带验证码)

    实现Bootstrap登录校验表单(带验证码)需要遵循以下步骤: 1. 引入Bootstrap和jQuery库 在标签内引入Bootstrap和jQuery库: <head> <link rel="stylesheet" href="https://cdn.staticfile.org/twitter-boot…

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部