TensorFLow 不同大小图片的TFrecords存取实例

TensorFlow 不同大小图片的TFRecords存取实例

1. 环境配置

使用 TensorFlow 存取 TFRecords 首先需要安装 TensorFlow 。如果您还没有安装 TensorFlow,请参考官方文档进行安装。

2. 创建TFRecords文件

创建 TFRecord 文件需要使用 TensorFlow 提供的 tf.io.TFRecordWriter() 函数,该函数接收的参数是 TFRecord 文件的路径。在本例中,我们将文件存放到 "../mydata/" 目录下,文件名为 "mydata.tfrecord"。

import os
import tensorflow as tf

# 指定图片目录
image_dir = "../images/"

# 定义类别
labels = {
    "cat": 0,
    "dog": 1
}

# 获取图片列表
image_paths = []
for label in labels:
    path = os.path.join(image_dir, label)
    for filename in os.listdir(path):
        image_paths.append((os.path.join(path, filename), labels[label]))

# 打乱图片顺序
import random
random.shuffle(image_paths)

# 划分训练集和测试集
num_train = int(len(image_paths) * 0.8)
train_paths = image_paths[:num_train]
test_paths = image_paths[num_train:]

# 定义 TFRecord 文件路径
train_tfrecord_path = "../mydata/train.tfrecord"
test_tfrecord_path = "../mydata/test.tfrecord"

# 创建 tfrecord 文件
def create_tfrecord(tfrecord_path, image_paths):
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for image_path, label in image_paths:
            # 读取图片
            with tf.io.gfile.GFile(image_path, "rb") as f:
                image_data = f.read()

            # 解码图片
            image = tf.image.decode_jpeg(image_data)

            # 转换为 Tensor 并改变 shape
            image = tf.image.convert_image_dtype(image, tf.float32)
            image = tf.image.resize(image, [128, 128])
            image = tf.reshape(image, [-1])

            # 定义 Example
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(float_list=tf.train.FloatList(value=image.numpy())),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }))

            # 写入 TFRecord 文件
            writer.write(example.SerializeToString())

# 分别创建训练集和测试集的 TFRecord 文件
create_tfrecord(train_tfrecord_path, train_paths)
create_tfrecord(test_tfrecord_path, test_paths)

在上面的代码中,我们首先指定了图片目录(image_dir)和类别(labels),然后获取了图片列表(image_paths),并打乱图片顺序。接着,我们将数据划分为训练集(train_paths)和测试集(test_paths)。

接下来,我们定义了 TFRecord 文件的路径(train_tfrecord_path和test_tfrecord_path)并分别创建了训练集和测试集的TFRecord文件。

在 create_tfrecord() 函数内,我们首先读取图片并解码,然后将其转换为 Tensor,并改变其 shape。接着,我们使用 tf.train.Example 定义了 Example,并将其写入了 TFRecord 文件。

3. 读取TFRecords文件

使用 TensorFlow 读取 TFRecords 文件可以使用 tf.data.TFRecordDataset() 函数,该函数接收的参数是 TFRecord 文件的路径。在本例中,我们将读取 "../mydata/train.tfrecord" 文件。

import tensorflow as tf

# 定义 TFRecord 文件路径
tfrecord_path = "../mydata/train.tfrecord"

# 定义 Feature 字典,如下所示
feature_description = {
    "image": tf.io.FixedLenFeature([], tf.float32),
    "label": tf.io.FixedLenFeature([], tf.int64)
}

# 定义解析函数,用于解析 Example
def _parse_function(example_proto):
    example = tf.io.parse_single_example(example_proto, feature_description)
    image = tf.reshape(example["image"], [128, 128, 3])
    label = example["label"]
    return image, label

# 读取 TFRecord 文件
dataset = tf.data.TFRecordDataset(tfrecord_path)

# 解析 Example
dataset = dataset.map(_parse_function)

# 设置批次大小并打乱数据
batch_size = 32
dataset = dataset.batch(batch_size).shuffle(buffer_size=batch_size*10)

# 输出数据形状
for images, labels in dataset.take(1):
    print(images.shape, labels.shape)

在上面的代码中,我们定义了以 "../mydata/train.tfrecord" 为路径的 TFRecord 文件,并定义了 Feature 字典(feature_description),用于解析 Example。

然后,我们定义了解析函数(_parse_function),用于解析 Example,并将图片还原为原来的 shape。接着,我们使用 tf.data.TFRecordDataset() 函数读取了 TFRecord 文件,并使用 map() 函数解析 Example。

最后,我们设置了批次大小(batch_size)并打乱了数据。通过遍历 dataset 并输出第一个 batch 的形状,我们可以得到数据集的形状。

4. 结果

通过上面的步骤,我们可以成功地将不同大小的图片保存到 TFRecords 并读取出来。接下来,我们将通过两条示例说明如何使用上述代码:

示例1:增加图片大小并保存到 TFRecords 文件

我们可以将上述代码中的图片大小从 128x128 增加到 224x224,并保存为新的 TFRecords 文件。代码如下所示:

import os
import tensorflow as tf

# 指定图片目录
image_dir = "../images/"

# 定义类别
labels = {
    "cat": 0,
    "dog": 1
}

# 获取图片列表
image_paths = []
for label in labels:
    path = os.path.join(image_dir, label)
    for filename in os.listdir(path):
        image_paths.append((os.path.join(path, filename), labels[label]))

# 打乱图片顺序
import random
random.shuffle(image_paths)

# 划分训练集和测试集
num_train = int(len(image_paths) * 0.8)
train_paths = image_paths[:num_train]
test_paths = image_paths[num_train:]

# 定义 TFRecord 文件路径
train_tfrecord_path = "../mydata/train.tfrecord"
test_tfrecord_path = "../mydata/test.tfrecord"

# 创建 tfrecord 文件
def create_tfrecord(tfrecord_path, image_paths):
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for image_path, label in image_paths:
            # 读取图片
            with tf.io.gfile.GFile(image_path, "rb") as f:
                image_data = f.read()

            # 解码图片
            image = tf.image.decode_jpeg(image_data)

            # 转换为 Tensor 并改变 shape
            image = tf.image.convert_image_dtype(image, tf.float32)
            image = tf.image.resize(image, [224, 224])
            image = tf.reshape(image, [-1])

            # 定义 Example
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(float_list=tf.train.FloatList(value=image.numpy())),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }))

            # 写入 TFRecord 文件
            writer.write(example.SerializeToString())

# 分别创建训练集和测试集的 TFRecord 文件
create_tfrecord(train_tfrecord_path, train_paths)
create_tfrecord(test_tfrecord_path, test_paths)

在上述代码中,我们使用了 tf.image.resize() 函数将图片大小从 128x128 增加到了 224x224。

示例2:使用 TFRecords 数据集训练模型

我们可以将上述代码中的读取 TFRecords 数据集部分用于训练模型。代码如下所示:

import tensorflow as tf

# 定义 TFRecord 文件路径
tfrecord_path = "../mydata/train.tfrecord"

# 定义 Feature 字典,如下所示
feature_description = {
    "image": tf.io.FixedLenFeature([], tf.float32),
    "label": tf.io.FixedLenFeature([], tf.int64)
}

# 定义解析函数,用于解析 Example
def _parse_function(example_proto):
    example = tf.io.parse_single_example(example_proto, feature_description)
    image = tf.reshape(example["image"], [224, 224, 3])
    label = example["label"]
    return image, label

# 读取 TFRecord 文件
dataset = tf.data.TFRecordDataset(tfrecord_path)

# 解析 Example
dataset = dataset.map(_parse_function)

# 设置批次大小并打乱数据
batch_size = 32
dataset = dataset.batch(batch_size).shuffle(buffer_size=batch_size*10)

# 模型定义
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, kernel_size=(3,3), activation='relu', input_shape=(224,224,3)),
    tf.keras.layers.MaxPooling2D(pool_size=(2,2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(2, activation='softmax')
])

# 模型编译
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 模型训练
model.fit(dataset, epochs=5)

在上述代码中,我们将 TFRecord 文件中的数据作为模型的输入,并使用 tf.keras 搭建了一个简单的卷积神经网络模型。经过训练后,我们可以得到模型的准确率。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:TensorFLow 不同大小图片的TFrecords存取实例 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • dpn网络的pytorch实现方式

    下面是关于“dpn网络的pytorch实现方式”的完整攻略: DPN网络简介 DPN(Dual Path Network)网络是一种深度卷积神经网络。与传统的卷积神经网络不同,DPN网络引入了双向路径机制,以提高网络的性能和稳定性。其核心思想是将特征图分成两个路径,分别进行特征提取和特征融合。 DPN网络的pytorch实现方式 下面是DPN网络的pytor…

    人工智能概论 2023年5月25日
    00
  • Spring Boot应用Docker化的步骤详解

    下面我来详细讲解如何将Spring Boot应用Docker化的步骤。 一、安装Docker首先需要在本地安装Docker,可以去Docker官网(https://www.docker.com/)下载对应系统的版本进行安装。 二、编写Dockerfile文件我们需要编写一个Dockerfile文件,用来定义如何构建Docker镜像。以下是一个示例的Docke…

    人工智能概览 2023年5月25日
    00
  • google jQuery 引用文件,jQuery 引用地址集合(jquery 1.2.6至jquery1.5.2)

    下面就来详细讲解一下“Google jQuery 引用文件,jQuery 引用地址集合(jQuery1.2.6至jQuery1.5.2)”的完整攻略。 1. Google jQuery 引用文件 Google 提供了 CDN(内容分发网络)来加速开发者网页内容的传输。通过使用 Google 提供的在线库,可以让用户在访问网站时更快地下载页面所需的文件和其他内…

    人工智能概论 2023年5月25日
    00
  • python3+PyQt5实现使用剪贴板做复制与粘帖示例

    下面我来为你详细讲解“python3+PyQt5实现使用剪贴板做复制与粘帖示例”的完整攻略。 1.准备工作 在开始编写代码之前,我们需要先安装必要的依赖包: Python3 PyQt5 对于python依赖库可以使用pip安装 pip3 install PyQt5 2. 剪贴板基础用法 在PyQt中,使用剪贴板操作非常简单。可以通过QApplication.…

    人工智能概览 2023年5月25日
    00
  • C++之openFrameworks框架介绍

    C++之openFrameworks框架介绍 什么是openFrameworks openFrameworks是一个开源的C++跨平台创意编程框架,旨在使创意编程变得更加容易、更容易使用并且开放。它通过封装大量的C++库和硬件驱动程序,提供了一种快速开发原型、制作交互式的多媒体应用程序、绘画、制作自动化等领域的框架。它支持多种操作系统,如Linux、MacO…

    人工智能概览 2023年5月25日
    00
  • Windows环境下配置Qt 5.8+opencv 3.1.0开发环境的方法

    下面是详细的“Windows环境下配置Qt 5.8+opencv 3.1.0开发环境的方法”的攻略: 环境要求 Windows操作系统 Qt5.8+(建议使用官方安装包,如qt-opensource-windows-x86-5.8.0.exe) opencv3.1.0+ (建议使用官方安装包,如opencv-3.1.0.exe) 步骤 1. 安装Qt5 安装…

    人工智能概览 2023年5月25日
    00
  • centos7如何设置密码规则?centos7设置密码规则的方法

    下面是详细讲解“centos7如何设置密码规则?centos7设置密码规则的方法”的完整攻略。 设置密码规则 CentOS 7使用强密码来保护用户的帐户。在CentOS 7中,通过修改PAM(Pluggable Authentication Modules,可插入身份验证模块)配置文件,可以设置密码规则来确保用户密码的强度。下面是设置密码规则的步骤: 步骤1…

    人工智能概览 2023年5月25日
    00
  • java+MongoDB实现存图片、下载图片的方法示例

    接下来我将详细讲解“java+MongoDB实现存图片、下载图片的方法示例”的完整攻略。 1. 简介 MongoDB是一个NoSQL数据库,它简化了复杂查询和数据模型。它很好地支持面向文档的数据存储,使得存储和检索图片等二进制数据变得更容易。Java是一种广泛使用的编程语言,支持面向对象编程。它也非常适合用于与MongoDB一起工作,以实现存储和检索二进制数…

    人工智能概论 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部