TensorFlow 不同大小图片的TFRecords存取实例
1. 环境配置
使用 TensorFlow 存取 TFRecords 首先需要安装 TensorFlow 。如果您还没有安装 TensorFlow,请参考官方文档进行安装。
2. 创建TFRecords文件
创建 TFRecord 文件需要使用 TensorFlow 提供的 tf.io.TFRecordWriter() 函数,该函数接收的参数是 TFRecord 文件的路径。在本例中,我们将文件存放到 "../mydata/" 目录下,文件名为 "mydata.tfrecord"。
import os
import tensorflow as tf
# 指定图片目录
image_dir = "../images/"
# 定义类别
labels = {
"cat": 0,
"dog": 1
}
# 获取图片列表
image_paths = []
for label in labels:
path = os.path.join(image_dir, label)
for filename in os.listdir(path):
image_paths.append((os.path.join(path, filename), labels[label]))
# 打乱图片顺序
import random
random.shuffle(image_paths)
# 划分训练集和测试集
num_train = int(len(image_paths) * 0.8)
train_paths = image_paths[:num_train]
test_paths = image_paths[num_train:]
# 定义 TFRecord 文件路径
train_tfrecord_path = "../mydata/train.tfrecord"
test_tfrecord_path = "../mydata/test.tfrecord"
# 创建 tfrecord 文件
def create_tfrecord(tfrecord_path, image_paths):
with tf.io.TFRecordWriter(tfrecord_path) as writer:
for image_path, label in image_paths:
# 读取图片
with tf.io.gfile.GFile(image_path, "rb") as f:
image_data = f.read()
# 解码图片
image = tf.image.decode_jpeg(image_data)
# 转换为 Tensor 并改变 shape
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [128, 128])
image = tf.reshape(image, [-1])
# 定义 Example
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(float_list=tf.train.FloatList(value=image.numpy())),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))
# 写入 TFRecord 文件
writer.write(example.SerializeToString())
# 分别创建训练集和测试集的 TFRecord 文件
create_tfrecord(train_tfrecord_path, train_paths)
create_tfrecord(test_tfrecord_path, test_paths)
在上面的代码中,我们首先指定了图片目录(image_dir)和类别(labels),然后获取了图片列表(image_paths),并打乱图片顺序。接着,我们将数据划分为训练集(train_paths)和测试集(test_paths)。
接下来,我们定义了 TFRecord 文件的路径(train_tfrecord_path和test_tfrecord_path)并分别创建了训练集和测试集的TFRecord文件。
在 create_tfrecord() 函数内,我们首先读取图片并解码,然后将其转换为 Tensor,并改变其 shape。接着,我们使用 tf.train.Example 定义了 Example,并将其写入了 TFRecord 文件。
3. 读取TFRecords文件
使用 TensorFlow 读取 TFRecords 文件可以使用 tf.data.TFRecordDataset() 函数,该函数接收的参数是 TFRecord 文件的路径。在本例中,我们将读取 "../mydata/train.tfrecord" 文件。
import tensorflow as tf
# 定义 TFRecord 文件路径
tfrecord_path = "../mydata/train.tfrecord"
# 定义 Feature 字典,如下所示
feature_description = {
"image": tf.io.FixedLenFeature([], tf.float32),
"label": tf.io.FixedLenFeature([], tf.int64)
}
# 定义解析函数,用于解析 Example
def _parse_function(example_proto):
example = tf.io.parse_single_example(example_proto, feature_description)
image = tf.reshape(example["image"], [128, 128, 3])
label = example["label"]
return image, label
# 读取 TFRecord 文件
dataset = tf.data.TFRecordDataset(tfrecord_path)
# 解析 Example
dataset = dataset.map(_parse_function)
# 设置批次大小并打乱数据
batch_size = 32
dataset = dataset.batch(batch_size).shuffle(buffer_size=batch_size*10)
# 输出数据形状
for images, labels in dataset.take(1):
print(images.shape, labels.shape)
在上面的代码中,我们定义了以 "../mydata/train.tfrecord" 为路径的 TFRecord 文件,并定义了 Feature 字典(feature_description),用于解析 Example。
然后,我们定义了解析函数(_parse_function),用于解析 Example,并将图片还原为原来的 shape。接着,我们使用 tf.data.TFRecordDataset() 函数读取了 TFRecord 文件,并使用 map() 函数解析 Example。
最后,我们设置了批次大小(batch_size)并打乱了数据。通过遍历 dataset 并输出第一个 batch 的形状,我们可以得到数据集的形状。
4. 结果
通过上面的步骤,我们可以成功地将不同大小的图片保存到 TFRecords 并读取出来。接下来,我们将通过两条示例说明如何使用上述代码:
示例1:增加图片大小并保存到 TFRecords 文件
我们可以将上述代码中的图片大小从 128x128 增加到 224x224,并保存为新的 TFRecords 文件。代码如下所示:
import os
import tensorflow as tf
# 指定图片目录
image_dir = "../images/"
# 定义类别
labels = {
"cat": 0,
"dog": 1
}
# 获取图片列表
image_paths = []
for label in labels:
path = os.path.join(image_dir, label)
for filename in os.listdir(path):
image_paths.append((os.path.join(path, filename), labels[label]))
# 打乱图片顺序
import random
random.shuffle(image_paths)
# 划分训练集和测试集
num_train = int(len(image_paths) * 0.8)
train_paths = image_paths[:num_train]
test_paths = image_paths[num_train:]
# 定义 TFRecord 文件路径
train_tfrecord_path = "../mydata/train.tfrecord"
test_tfrecord_path = "../mydata/test.tfrecord"
# 创建 tfrecord 文件
def create_tfrecord(tfrecord_path, image_paths):
with tf.io.TFRecordWriter(tfrecord_path) as writer:
for image_path, label in image_paths:
# 读取图片
with tf.io.gfile.GFile(image_path, "rb") as f:
image_data = f.read()
# 解码图片
image = tf.image.decode_jpeg(image_data)
# 转换为 Tensor 并改变 shape
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [224, 224])
image = tf.reshape(image, [-1])
# 定义 Example
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(float_list=tf.train.FloatList(value=image.numpy())),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))
# 写入 TFRecord 文件
writer.write(example.SerializeToString())
# 分别创建训练集和测试集的 TFRecord 文件
create_tfrecord(train_tfrecord_path, train_paths)
create_tfrecord(test_tfrecord_path, test_paths)
在上述代码中,我们使用了 tf.image.resize() 函数将图片大小从 128x128 增加到了 224x224。
示例2:使用 TFRecords 数据集训练模型
我们可以将上述代码中的读取 TFRecords 数据集部分用于训练模型。代码如下所示:
import tensorflow as tf
# 定义 TFRecord 文件路径
tfrecord_path = "../mydata/train.tfrecord"
# 定义 Feature 字典,如下所示
feature_description = {
"image": tf.io.FixedLenFeature([], tf.float32),
"label": tf.io.FixedLenFeature([], tf.int64)
}
# 定义解析函数,用于解析 Example
def _parse_function(example_proto):
example = tf.io.parse_single_example(example_proto, feature_description)
image = tf.reshape(example["image"], [224, 224, 3])
label = example["label"]
return image, label
# 读取 TFRecord 文件
dataset = tf.data.TFRecordDataset(tfrecord_path)
# 解析 Example
dataset = dataset.map(_parse_function)
# 设置批次大小并打乱数据
batch_size = 32
dataset = dataset.batch(batch_size).shuffle(buffer_size=batch_size*10)
# 模型定义
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, kernel_size=(3,3), activation='relu', input_shape=(224,224,3)),
tf.keras.layers.MaxPooling2D(pool_size=(2,2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(2, activation='softmax')
])
# 模型编译
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 模型训练
model.fit(dataset, epochs=5)
在上述代码中,我们将 TFRecord 文件中的数据作为模型的输入,并使用 tf.keras 搭建了一个简单的卷积神经网络模型。经过训练后,我们可以得到模型的准确率。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:TensorFLow 不同大小图片的TFrecords存取实例 - Python技术站