关于“keras的ImageDataGenerator和flow()的用法说明”的完整攻略,以下是具体的讲解过程:
1. keras的ImageDataGenerator介绍
keras的ImageDataGenerator是为了在训练深度学习模型时,方便进行数据增强的工具。它可以帮助我们通过对数据集进行一定的变形、裁剪、旋转、翻转等操作,增加数据的数量及多样性,从而提高模型的泛化能力。
2. 数据增强的常用方式
数据增强的方法有很多,常用的包括:
- 数据旋转
- 图像翻转
- 对比度、亮度调整
- 随机切割等。
3. flow()方法的使用说明
ImageDataGenerator的主要功能是对数据进行预处理,使用其进行数据增强的方式主要有两种:
- fit()方法:fit()方法可以对ImageDatagenrator的预处理功能进行初始化,在对模型进行训练之前,使用它来确定数据的预处理方式。
- flow()方法:不同于fit()方法,flow()方法可以直接对图像生成器进行操作,并返回多个批次的增强数据。
接下来我们详细说明一下使用方法。
4. ImageDataGenerator的初始化设置
当我们准备使用ImageDataGenerator来生成数据集时,我们需要对其进行初始化。这里假设我们的数据集是一组猫狗的图像,我们需要针对这些图像进行数据增强。以下是一个具体的示例代码:
from keras.preprocessing.image import ImageDataGenerator
# 初始化ImageDataGenerator
DataGen = ImageDataGenerator(
rescale=1. / 255,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest'
)
上述代码中的各参数的含义如下:
- rescale:归一化参数,因为一般图像的像素数值范围是[0,255],我们可以将其缩放到[0,1]之间
- rotation_range:旋转范围,用于指定图像随机旋转角度的最大值(以度为单位)
- width_shift_range和height_shift_range:宽度移动和高度移动范围,用于指定图像在水平和垂直方向上可移动的最大比例
- shear_range:剪切范围,用于指定逆时针方向中的剪切变换角度,以度为单位
- zoom_range:缩放范围,用于指定可以缩放图像的最大程度
- horizontal_flip:水平翻转,用于确定是否随机应用水平翻转的操作(True,或False)
- fill_mode:填充方式,当进行变形操作时,需要在变形的空白区域填充像素,fill_mode指定填充的方式(常用方式有:nearest、constant、reflect等)
参数设置完成后,我们可以使用fit()方法来对ImageDataGenerator进行初始化,示例如下:
from keras.preprocessing.image import ImageDataGenerator
# 初始化ImageDataGenerator
DataGen = ImageDataGenerator(
rescale=1. / 255,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest'
)
# 载入数据集
trainGen = DataGen.flow_from_directory(
'catsvgdogs/train', # 数据集路径
target_size=(150, 150), # 将所有图像裁剪成150x150的规格
batch_size=32, # 每个批次的大小
class_mode='binary' # 给定分类的依据,本例中只有猫和狗两种分类,故选择binary
)
5. flow()方法的使用
ImageDataGenerator提供了两个方法:fit()和flow()。当我们使用fit()方法对ImageDataGenerator进行初始化后,在对模型进行训练之前,我们需要使用其内部参数来对图像进行预处理。而flow()方法可以直接对图像生成器进行操作,并返回多个批次的增强数据。以下是一个示例代码:
from keras.preprocessing.image import ImageDataGenerator
# 初始化ImageDataGenerator
DataGen = ImageDataGenerator(
rescale=1. / 255,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest'
)
# 载入数据集
trainGen = DataGen.flow_from_directory(
'catsvgdogs/train', # 数据集路径
target_size=(150, 150), # 将所有图像裁剪成150x150的规格
batch_size=32, # 每个批次的大小
class_mode='binary' # 给定分类的依据,本例中只有猫和狗两种分类,故选择binary
)
# 使用flow()方法得到增强后的数据
imageList0, labelList0 = trainGen.next()
imageList1, labelList1 = trainGen.next()
# 输出大小
print('imageList0: ', imageList0.shape)
print('labelList0: ', labelList0.shape)
print('imageList1: ', imageList1.shape)
print('labelList1: ', labelList1.shape)
以上代码中,我们先进行ImageDataGenerator的初始化设置,然后对训练集进行fit操作后便可以直接使用flow()方法来得到其增强后的数据集。需要注意的是,ImageDataGenerator支持使用多个增强操作,当对图像进行增强操作时,每次都会随机选择其中一种方式。
6. 总结
至此,我们已经详细讲解了ImageDataGenerator和flow()方法的使用说明。以上方式适用于基于keras的图像分类任务,在实际应用中可以根据需求进行调整,具体操作难度较小,过程简单易懂。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:keras的ImageDataGenerator和flow()的用法说明 - Python技术站