最近在做一个鉴黄的项目,数据量比较大,有几百个G,一次性加入内存再去训练模青型是不现实的。

查阅资料发现keras中可以用两种方法解决,一是将数据转为tfrecord,但转换后数据大小会方法不好;另外一种就是利用generator,先一次加入所有数据的路径,然后每个batch的读入

# 读取图片函数
def get_im_cv2(paths, img_rows, img_cols, color_type=1, normalize=True):
    '''
    参数:
        paths:要读取的图片路径列表
        img_rows:图片行
        img_cols:图片列
        color_type:图片颜色通道
    返回: 
        imgs: 图片数组
    '''
    # Load as grayscale
    imgs = []
    for path in paths:
        if color_type == 1:
            img = cv2.imread(path, 0)
        elif color_type == 3:
            img = cv2.imread(path)
        # Reduce size
        resized = cv2.resize(img, (img_cols, img_rows))
        if normalize:
            resized = resized.astype('float32')
            resized /= 127.5
            resized -= 1. 
        
        imgs.append(resized)
        
    return np.array(imgs).reshape(len(paths), img_rows, img_cols, color_type)

 

 

def get_train_batch(X_train, y_train, batch_size, img_w, img_h, color_type, is_argumentation):
    '''
    参数:
        X_train:所有图片路径列表
        y_train: 所有图片对应的标签列表
        batch_size:批次
        img_w:图片宽
        img_h:图片高
        color_type:图片类型
        is_argumentation:是否需要数据增强
    返回: 
        一个generator,x: 获取的批次图片 y: 获取的图片对应的标签
    '''
    while 1:
        for i in range(0, len(X_train), batch_size):
            x = get_im_cv2(X_train[i:i+batch_size], img_w, img_h, color_type)
            y = y_train[i:i+batch_size]
            if is_argumentation:
                # 数据增强
                x, y = img_augmentation(x, y)
            # 最重要的就是这个yield,它代表返回,返回以后循环还是会继续,然后再返回。就比如有一个机器一直在作累加运算,但是会把每次累加中间结果告诉你一样,直到把所有数加完
            yield(np.array(x}, np.array(y))

 

 

 

result = model.fit_generator(generator=get_train_batch(X_train, y_train, train_batch_size, img_w, img_h, color_type, True), 
          steps_per_epoch=1351, 
          epochs=50, verbose=1,
          validation_data=get_train_batch(X_valid, y_valid, valid_batch_size,img_w, img_h, color_type, False),
          validation_steps=52,
          callbacks=[ckpt, early_stop],
          max_queue_size=capacity,
          workers=1)

 

参考:https://www.jianshu.com/p/5bdae9dcfc9c

          https://keras.io/zh/models/model/