TensorFLow 不同大小图片的TFrecords存取实例

TensorFlow 不同大小图片的TFRecords存取实例

1. 环境配置

使用 TensorFlow 存取 TFRecords 首先需要安装 TensorFlow 。如果您还没有安装 TensorFlow,请参考官方文档进行安装。

2. 创建TFRecords文件

创建 TFRecord 文件需要使用 TensorFlow 提供的 tf.io.TFRecordWriter() 函数,该函数接收的参数是 TFRecord 文件的路径。在本例中,我们将文件存放到 "../mydata/" 目录下,文件名为 "mydata.tfrecord"。

import os
import tensorflow as tf

# 指定图片目录
image_dir = "../images/"

# 定义类别
labels = {
    "cat": 0,
    "dog": 1
}

# 获取图片列表
image_paths = []
for label in labels:
    path = os.path.join(image_dir, label)
    for filename in os.listdir(path):
        image_paths.append((os.path.join(path, filename), labels[label]))

# 打乱图片顺序
import random
random.shuffle(image_paths)

# 划分训练集和测试集
num_train = int(len(image_paths) * 0.8)
train_paths = image_paths[:num_train]
test_paths = image_paths[num_train:]

# 定义 TFRecord 文件路径
train_tfrecord_path = "../mydata/train.tfrecord"
test_tfrecord_path = "../mydata/test.tfrecord"

# 创建 tfrecord 文件
def create_tfrecord(tfrecord_path, image_paths):
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for image_path, label in image_paths:
            # 读取图片
            with tf.io.gfile.GFile(image_path, "rb") as f:
                image_data = f.read()

            # 解码图片
            image = tf.image.decode_jpeg(image_data)

            # 转换为 Tensor 并改变 shape
            image = tf.image.convert_image_dtype(image, tf.float32)
            image = tf.image.resize(image, [128, 128])
            image = tf.reshape(image, [-1])

            # 定义 Example
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(float_list=tf.train.FloatList(value=image.numpy())),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }))

            # 写入 TFRecord 文件
            writer.write(example.SerializeToString())

# 分别创建训练集和测试集的 TFRecord 文件
create_tfrecord(train_tfrecord_path, train_paths)
create_tfrecord(test_tfrecord_path, test_paths)

在上面的代码中,我们首先指定了图片目录(image_dir)和类别(labels),然后获取了图片列表(image_paths),并打乱图片顺序。接着,我们将数据划分为训练集(train_paths)和测试集(test_paths)。

接下来,我们定义了 TFRecord 文件的路径(train_tfrecord_path和test_tfrecord_path)并分别创建了训练集和测试集的TFRecord文件。

在 create_tfrecord() 函数内,我们首先读取图片并解码,然后将其转换为 Tensor,并改变其 shape。接着,我们使用 tf.train.Example 定义了 Example,并将其写入了 TFRecord 文件。

3. 读取TFRecords文件

使用 TensorFlow 读取 TFRecords 文件可以使用 tf.data.TFRecordDataset() 函数,该函数接收的参数是 TFRecord 文件的路径。在本例中,我们将读取 "../mydata/train.tfrecord" 文件。

import tensorflow as tf

# 定义 TFRecord 文件路径
tfrecord_path = "../mydata/train.tfrecord"

# 定义 Feature 字典,如下所示
feature_description = {
    "image": tf.io.FixedLenFeature([], tf.float32),
    "label": tf.io.FixedLenFeature([], tf.int64)
}

# 定义解析函数,用于解析 Example
def _parse_function(example_proto):
    example = tf.io.parse_single_example(example_proto, feature_description)
    image = tf.reshape(example["image"], [128, 128, 3])
    label = example["label"]
    return image, label

# 读取 TFRecord 文件
dataset = tf.data.TFRecordDataset(tfrecord_path)

# 解析 Example
dataset = dataset.map(_parse_function)

# 设置批次大小并打乱数据
batch_size = 32
dataset = dataset.batch(batch_size).shuffle(buffer_size=batch_size*10)

# 输出数据形状
for images, labels in dataset.take(1):
    print(images.shape, labels.shape)

在上面的代码中,我们定义了以 "../mydata/train.tfrecord" 为路径的 TFRecord 文件,并定义了 Feature 字典(feature_description),用于解析 Example。

然后,我们定义了解析函数(_parse_function),用于解析 Example,并将图片还原为原来的 shape。接着,我们使用 tf.data.TFRecordDataset() 函数读取了 TFRecord 文件,并使用 map() 函数解析 Example。

最后,我们设置了批次大小(batch_size)并打乱了数据。通过遍历 dataset 并输出第一个 batch 的形状,我们可以得到数据集的形状。

4. 结果

通过上面的步骤,我们可以成功地将不同大小的图片保存到 TFRecords 并读取出来。接下来,我们将通过两条示例说明如何使用上述代码:

示例1:增加图片大小并保存到 TFRecords 文件

我们可以将上述代码中的图片大小从 128x128 增加到 224x224,并保存为新的 TFRecords 文件。代码如下所示:

import os
import tensorflow as tf

# 指定图片目录
image_dir = "../images/"

# 定义类别
labels = {
    "cat": 0,
    "dog": 1
}

# 获取图片列表
image_paths = []
for label in labels:
    path = os.path.join(image_dir, label)
    for filename in os.listdir(path):
        image_paths.append((os.path.join(path, filename), labels[label]))

# 打乱图片顺序
import random
random.shuffle(image_paths)

# 划分训练集和测试集
num_train = int(len(image_paths) * 0.8)
train_paths = image_paths[:num_train]
test_paths = image_paths[num_train:]

# 定义 TFRecord 文件路径
train_tfrecord_path = "../mydata/train.tfrecord"
test_tfrecord_path = "../mydata/test.tfrecord"

# 创建 tfrecord 文件
def create_tfrecord(tfrecord_path, image_paths):
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for image_path, label in image_paths:
            # 读取图片
            with tf.io.gfile.GFile(image_path, "rb") as f:
                image_data = f.read()

            # 解码图片
            image = tf.image.decode_jpeg(image_data)

            # 转换为 Tensor 并改变 shape
            image = tf.image.convert_image_dtype(image, tf.float32)
            image = tf.image.resize(image, [224, 224])
            image = tf.reshape(image, [-1])

            # 定义 Example
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(float_list=tf.train.FloatList(value=image.numpy())),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }))

            # 写入 TFRecord 文件
            writer.write(example.SerializeToString())

# 分别创建训练集和测试集的 TFRecord 文件
create_tfrecord(train_tfrecord_path, train_paths)
create_tfrecord(test_tfrecord_path, test_paths)

在上述代码中,我们使用了 tf.image.resize() 函数将图片大小从 128x128 增加到了 224x224。

示例2:使用 TFRecords 数据集训练模型

我们可以将上述代码中的读取 TFRecords 数据集部分用于训练模型。代码如下所示:

import tensorflow as tf

# 定义 TFRecord 文件路径
tfrecord_path = "../mydata/train.tfrecord"

# 定义 Feature 字典,如下所示
feature_description = {
    "image": tf.io.FixedLenFeature([], tf.float32),
    "label": tf.io.FixedLenFeature([], tf.int64)
}

# 定义解析函数,用于解析 Example
def _parse_function(example_proto):
    example = tf.io.parse_single_example(example_proto, feature_description)
    image = tf.reshape(example["image"], [224, 224, 3])
    label = example["label"]
    return image, label

# 读取 TFRecord 文件
dataset = tf.data.TFRecordDataset(tfrecord_path)

# 解析 Example
dataset = dataset.map(_parse_function)

# 设置批次大小并打乱数据
batch_size = 32
dataset = dataset.batch(batch_size).shuffle(buffer_size=batch_size*10)

# 模型定义
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, kernel_size=(3,3), activation='relu', input_shape=(224,224,3)),
    tf.keras.layers.MaxPooling2D(pool_size=(2,2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(2, activation='softmax')
])

# 模型编译
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 模型训练
model.fit(dataset, epochs=5)

在上述代码中,我们将 TFRecord 文件中的数据作为模型的输入,并使用 tf.keras 搭建了一个简单的卷积神经网络模型。经过训练后,我们可以得到模型的准确率。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:TensorFLow 不同大小图片的TFrecords存取实例 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • Django自定义用户表+自定义admin后台中的字段实例

    下面详细讲解一下Django自定义用户表+自定义admin后台中的字段实例的完整攻略。 首先,在Django中自定义用户表时,需要继承Django默认的AbstractBaseUser和PermissionsMixin类,具体做法如下: from django.contrib.auth.models import AbstractBaseUser, Perm…

    人工智能概览 2023年5月25日
    00
  • nginx 平滑重启的实现方法

    下面来讲解“nginx 平滑重启的实现方法”的完整攻略。 什么是nginx平滑重启? nginx是一款优秀的Web服务器,为了稳定性,在nginx运行过程中,如果需要重新加载配置文件或升级程序,都需要通过重启来完成,但是重启会导致服务短暂中断,可能会造成一定的损失。相比之下,nginx的平滑重启就可以在重新加载配置文件或升级程序的时候不中断服务,这对于线上环…

    人工智能概览 2023年5月25日
    00
  • Android Studio配置(Android Studio4.1为例)

    下面我来为你讲解详细的Android Studio配置攻略。 环境准备 在开始配置之前需要确认一下你的环境是否满足要求,需要具备以下条件: 系统:Windows、MacOS或Linux,建议使用64位操作系统 JDK:建议使用JDK8或者OpenJDK8 内存:至少需要8GB RAM,推荐16GB RAM 安装Android Studio 下载安装包 首先需…

    人工智能概览 2023年5月25日
    00
  • SpringCloud hystrix断路器与局部降级全面介绍

    SpringCloud Hystrix断路器与局部降级全面介绍 什么是Hystrix断路器 Hystrix是Netflix发布的一款容错框架,用于处理分布式系统的延迟和容错问题。Hystrix在整合了SpringCloud项目之后,是同步、异步请求的断路器。 断路器是对延迟和故障的容错,当请求后端服务出现链路故障、返回超时等,断路器会直接断开请求链路,避免系…

    人工智能概览 2023年5月25日
    00
  • AngularJS轻松实现双击排序的功能

    下面是“AngularJS轻松实现双击排序的功能”的完整攻略: 1. 概述 在AngularJS中实现双击排序的功能可以通过使用ng-repeat、ng-click和双击事件结合起来实现。其中ng-repeat用于循环生成视图,ng-click用于处理排序事件,双击事件用于响应用户的行为。 2. 示例说明 下面是两个示例,分别演示了如何使用AngularJS…

    人工智能概论 2023年5月24日
    00
  • 基于javascript处理nginx请求过程详解

    基于JavaScript处理Nginx请求过程详解 本篇攻略旨在介绍使用JavaScript与Nginx一同处理web请求的过程。首先需要了解Nginx的基本架构,它是由主进程(Master Process)和多个工作进程(Worker Process)组成的,其中主进程用于监听端口和管理工作进程,而工作进程用于处理来自客户端的请求。我们将基于这个架构使用J…

    人工智能概览 2023年5月25日
    00
  • Python脚本调试工具安装过程

    下面是Python脚本调试工具安装过程的完整攻略。 安装过程 步骤1:安装Python 首先需要安装Python,可以在Python官网下载安装包进行安装,或使用系统自带的Python环境。 步骤2:安装调试工具 常用的Python脚本调试工具有pdb、ipdb、pudb等。具体安装方法如下: 使用pip安装pdb 如果已经安装了Python,可以使用pip…

    人工智能概览 2023年5月25日
    00
  • 使用Lua编写Nginx服务器的认证模块的方法

    下面是详细讲解如何使用Lua编写Nginx服务器的认证模块。 1. 什么是Nginx Nginx是一个高性能的HTTP和反向代理服务器,也是一个IMAP/POP3/SMTP代理服务器。常用于静态文件的服务和监视HTTP流量的代理服务器,同时具有负载均衡、容错、安全性高等特点。 2. 认证模块简介 Nginx服务器提供了一种叫做“模块”的技术,可以通过编写自定…

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部