TensorFLow 不同大小图片的TFrecords存取实例

TensorFlow 不同大小图片的TFRecords存取实例

1. 环境配置

使用 TensorFlow 存取 TFRecords 首先需要安装 TensorFlow 。如果您还没有安装 TensorFlow,请参考官方文档进行安装。

2. 创建TFRecords文件

创建 TFRecord 文件需要使用 TensorFlow 提供的 tf.io.TFRecordWriter() 函数,该函数接收的参数是 TFRecord 文件的路径。在本例中,我们将文件存放到 "../mydata/" 目录下,文件名为 "mydata.tfrecord"。

import os
import tensorflow as tf

# 指定图片目录
image_dir = "../images/"

# 定义类别
labels = {
    "cat": 0,
    "dog": 1
}

# 获取图片列表
image_paths = []
for label in labels:
    path = os.path.join(image_dir, label)
    for filename in os.listdir(path):
        image_paths.append((os.path.join(path, filename), labels[label]))

# 打乱图片顺序
import random
random.shuffle(image_paths)

# 划分训练集和测试集
num_train = int(len(image_paths) * 0.8)
train_paths = image_paths[:num_train]
test_paths = image_paths[num_train:]

# 定义 TFRecord 文件路径
train_tfrecord_path = "../mydata/train.tfrecord"
test_tfrecord_path = "../mydata/test.tfrecord"

# 创建 tfrecord 文件
def create_tfrecord(tfrecord_path, image_paths):
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for image_path, label in image_paths:
            # 读取图片
            with tf.io.gfile.GFile(image_path, "rb") as f:
                image_data = f.read()

            # 解码图片
            image = tf.image.decode_jpeg(image_data)

            # 转换为 Tensor 并改变 shape
            image = tf.image.convert_image_dtype(image, tf.float32)
            image = tf.image.resize(image, [128, 128])
            image = tf.reshape(image, [-1])

            # 定义 Example
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(float_list=tf.train.FloatList(value=image.numpy())),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }))

            # 写入 TFRecord 文件
            writer.write(example.SerializeToString())

# 分别创建训练集和测试集的 TFRecord 文件
create_tfrecord(train_tfrecord_path, train_paths)
create_tfrecord(test_tfrecord_path, test_paths)

在上面的代码中,我们首先指定了图片目录(image_dir)和类别(labels),然后获取了图片列表(image_paths),并打乱图片顺序。接着,我们将数据划分为训练集(train_paths)和测试集(test_paths)。

接下来,我们定义了 TFRecord 文件的路径(train_tfrecord_path和test_tfrecord_path)并分别创建了训练集和测试集的TFRecord文件。

在 create_tfrecord() 函数内,我们首先读取图片并解码,然后将其转换为 Tensor,并改变其 shape。接着,我们使用 tf.train.Example 定义了 Example,并将其写入了 TFRecord 文件。

3. 读取TFRecords文件

使用 TensorFlow 读取 TFRecords 文件可以使用 tf.data.TFRecordDataset() 函数,该函数接收的参数是 TFRecord 文件的路径。在本例中,我们将读取 "../mydata/train.tfrecord" 文件。

import tensorflow as tf

# 定义 TFRecord 文件路径
tfrecord_path = "../mydata/train.tfrecord"

# 定义 Feature 字典,如下所示
feature_description = {
    "image": tf.io.FixedLenFeature([], tf.float32),
    "label": tf.io.FixedLenFeature([], tf.int64)
}

# 定义解析函数,用于解析 Example
def _parse_function(example_proto):
    example = tf.io.parse_single_example(example_proto, feature_description)
    image = tf.reshape(example["image"], [128, 128, 3])
    label = example["label"]
    return image, label

# 读取 TFRecord 文件
dataset = tf.data.TFRecordDataset(tfrecord_path)

# 解析 Example
dataset = dataset.map(_parse_function)

# 设置批次大小并打乱数据
batch_size = 32
dataset = dataset.batch(batch_size).shuffle(buffer_size=batch_size*10)

# 输出数据形状
for images, labels in dataset.take(1):
    print(images.shape, labels.shape)

在上面的代码中,我们定义了以 "../mydata/train.tfrecord" 为路径的 TFRecord 文件,并定义了 Feature 字典(feature_description),用于解析 Example。

然后,我们定义了解析函数(_parse_function),用于解析 Example,并将图片还原为原来的 shape。接着,我们使用 tf.data.TFRecordDataset() 函数读取了 TFRecord 文件,并使用 map() 函数解析 Example。

最后,我们设置了批次大小(batch_size)并打乱了数据。通过遍历 dataset 并输出第一个 batch 的形状,我们可以得到数据集的形状。

4. 结果

通过上面的步骤,我们可以成功地将不同大小的图片保存到 TFRecords 并读取出来。接下来,我们将通过两条示例说明如何使用上述代码:

示例1:增加图片大小并保存到 TFRecords 文件

我们可以将上述代码中的图片大小从 128x128 增加到 224x224,并保存为新的 TFRecords 文件。代码如下所示:

import os
import tensorflow as tf

# 指定图片目录
image_dir = "../images/"

# 定义类别
labels = {
    "cat": 0,
    "dog": 1
}

# 获取图片列表
image_paths = []
for label in labels:
    path = os.path.join(image_dir, label)
    for filename in os.listdir(path):
        image_paths.append((os.path.join(path, filename), labels[label]))

# 打乱图片顺序
import random
random.shuffle(image_paths)

# 划分训练集和测试集
num_train = int(len(image_paths) * 0.8)
train_paths = image_paths[:num_train]
test_paths = image_paths[num_train:]

# 定义 TFRecord 文件路径
train_tfrecord_path = "../mydata/train.tfrecord"
test_tfrecord_path = "../mydata/test.tfrecord"

# 创建 tfrecord 文件
def create_tfrecord(tfrecord_path, image_paths):
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for image_path, label in image_paths:
            # 读取图片
            with tf.io.gfile.GFile(image_path, "rb") as f:
                image_data = f.read()

            # 解码图片
            image = tf.image.decode_jpeg(image_data)

            # 转换为 Tensor 并改变 shape
            image = tf.image.convert_image_dtype(image, tf.float32)
            image = tf.image.resize(image, [224, 224])
            image = tf.reshape(image, [-1])

            # 定义 Example
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(float_list=tf.train.FloatList(value=image.numpy())),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }))

            # 写入 TFRecord 文件
            writer.write(example.SerializeToString())

# 分别创建训练集和测试集的 TFRecord 文件
create_tfrecord(train_tfrecord_path, train_paths)
create_tfrecord(test_tfrecord_path, test_paths)

在上述代码中,我们使用了 tf.image.resize() 函数将图片大小从 128x128 增加到了 224x224。

示例2:使用 TFRecords 数据集训练模型

我们可以将上述代码中的读取 TFRecords 数据集部分用于训练模型。代码如下所示:

import tensorflow as tf

# 定义 TFRecord 文件路径
tfrecord_path = "../mydata/train.tfrecord"

# 定义 Feature 字典,如下所示
feature_description = {
    "image": tf.io.FixedLenFeature([], tf.float32),
    "label": tf.io.FixedLenFeature([], tf.int64)
}

# 定义解析函数,用于解析 Example
def _parse_function(example_proto):
    example = tf.io.parse_single_example(example_proto, feature_description)
    image = tf.reshape(example["image"], [224, 224, 3])
    label = example["label"]
    return image, label

# 读取 TFRecord 文件
dataset = tf.data.TFRecordDataset(tfrecord_path)

# 解析 Example
dataset = dataset.map(_parse_function)

# 设置批次大小并打乱数据
batch_size = 32
dataset = dataset.batch(batch_size).shuffle(buffer_size=batch_size*10)

# 模型定义
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, kernel_size=(3,3), activation='relu', input_shape=(224,224,3)),
    tf.keras.layers.MaxPooling2D(pool_size=(2,2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(2, activation='softmax')
])

# 模型编译
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 模型训练
model.fit(dataset, epochs=5)

在上述代码中,我们将 TFRecord 文件中的数据作为模型的输入,并使用 tf.keras 搭建了一个简单的卷积神经网络模型。经过训练后,我们可以得到模型的准确率。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:TensorFLow 不同大小图片的TFrecords存取实例 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • 关于Java整合RabbitMQ实现生产消费的7种通讯方式

    关于Java整合RabbitMQ实现生产消费的7种通讯方式,以下是详细的攻略。 1. 概述 RabbitMQ是一个流行的开源消息中间件,被广泛用于构建可靠、可扩展和高性能的分布式系统,而Java作为一种流行的编程语言,也提供了丰富的集成库来实现与RabbitMQ的通讯。Java整合RabbitMQ实现生产消费主要有以下7种通讯方式: 原生AMQP协议 Spr…

    人工智能概览 2023年5月25日
    00
  • python中时间转换datetime和pd.to_datetime详析

    Python中时间转换:datetime和pd.to_datetime详析 在Python中,时间的处理是一个常见需求。为了方便处理时间类型变量,Python提供了datetime库来进行时间转换。此外,pandas库也提供了pd.to_datetime函数来进行时间变量的转换。本文将详细介绍datetime和pd.to_datetime的使用方法和区别。 …

    人工智能概论 2023年5月25日
    00
  • Django项目中使用JWT的实现代码

    下面是关于Django项目中使用JWT的实现代码的完整攻略,包括最基本的JWT的使用和带有自定义用户模型的JWT使用: 基本JWT的使用 步骤1:安装相关库 在Django项目中使用JWT,需要安装两个Python库:pyjwt和django-rest-framework-jwt,可以使用以下命令进行安装: pip install pyjwt pip ins…

    人工智能概论 2023年5月25日
    00
  • 谈谈Redis分布式锁的正确实现方法

    谈谈Redis分布式锁的正确实现方法 在分布式系统中,为了避免因为多个线程同时对同一个资源进行写操作而出现的数据竞争问题,我们需要对关键代码段进行加锁,以保证在同一时间内只有一个线程对资源进行写操作。Redis作为一种高性能、高可用、可扩展的非关系型数据库,其分布式锁的实现也备受关注。 Redis分布式锁的基本原理 Redis分布式锁的基本原理是:当多个客户…

    人工智能概览 2023年5月25日
    00
  • Ubuntu中搭建Nginx、PHP环境最简单的方法

    搭建Nginx和PHP环境需要以下步骤: 1. 安装Nginx 在Ubuntu系统中,可以通过以下命令安装Nginx: sudo apt update sudo apt install nginx 安装完成后,可以使用以下命令检查Nginx是否安装成功: nginx -v 这会输出Nginx的版本号,表示安装成功。 2. 安装PHP 在Ubuntu系统中,可…

    人工智能概论 2023年5月25日
    00
  • ASP.NET(C#)读取Excel的文件内容

    下面我将为你详细讲解“ASP.NET(C#)读取Excel的文件内容”的完整攻略。 一、准备工作 在读取Excel文件之前,我们需要进行一些准备工作。 引入命名空间 在使用C#读取Excel文件之前,需要引入System.Data.OleDb命名空间,该命名空间包含了访问Excel文件的相关类。 csharpusing System.Data.OleDb; …

    人工智能概览 2023年5月25日
    00
  • Django路由层如何获取正确的url

    Django框架的路由层负责将HTTP请求映射到相应的视图函数。在Web开发中,获取正确的URL是非常重要的,可以通过以下步骤实现。 1. 定义URL路由模式 在Django应用程序中,首先需要定义URL路由模式。这可以通过在应用程序的urls.py文件中定义来实现。路由模式通常由路径模式、视图函数和URL名称组成。例如,以下代码定义了一个使用正则表达式匹配…

    人工智能概览 2023年5月25日
    00
  • Spring Cloud Eureka服务治理的实现

    Spring Cloud Eureka服务治理的实现 Spring Cloud Eureka是SpringCloud的子项目之一,用于实现服务治理。服务治理是SpringCloud微服务核心思想之一,其主要目的是协调各个微服务之间的通信,以便于负载均衡、故障恢复、服务升级等。在此文档中,我们将详细讲解“Spring Cloud Eureka服务治理的实现”的…

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部