以 MNIST 数据集为例。如果需要用自己的数据集训练,可以参考我之前的博客:
tensorflow2 keras 调用官方提供的模型训练分类与测试 [https://www.cnblogs.com/yanghailin/p/12601043.html]
其他相关文章可以在博客中自行查找。

训练脚本,同时打印网络结构,保存了网络图和loss,acc图,保存训练的模型

以下为训练脚本,保存了网络图和loss,acc图

import numpy as np
import keras
from keras.utils import plot_model
from keras.datasets import mnist

from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
import matplotlib.pyplot as plt

# Training settings and output file locations.
EPOCH = 50                              # number of training epochs
path_save_net_pic = './net_pic.jpg'     # network diagram written by plot_model()
path_save_loss_pic = './loss-acc.jpg'   # acc/loss curve figure
path_save_model = 'mnist_model.h5'      # whole model (architecture + weights)

class LossHistory(keras.callbacks.Callback):
    """Keras callback that records loss and accuracy during training.

    Values are kept per batch and per epoch in the dicts ``losses``,
    ``accuracy``, ``val_loss`` and ``val_acc`` (each keyed by 'batch' and
    'epoch'). ``loss_plot`` draws them and saves the figure to
    ``path_save_loss_pic``.
    """

    @staticmethod
    def _first_present(logs, *names):
        """Return the value of the first key in ``names`` found in ``logs``.

        Older Keras logs accuracy as 'acc' / 'val_acc', newer releases as
        'accuracy' / 'val_accuracy'; trying both keeps the curves populated
        on either version instead of silently recording None.
        """
        for name in names:
            if name in logs:
                return logs[name]
        return None

    def on_train_begin(self, logs=None):
        # Reset the history at the start of every fit() call.
        self.losses = {'batch': [], 'epoch': []}
        self.accuracy = {'batch': [], 'epoch': []}
        self.val_loss = {'batch': [], 'epoch': []}
        self.val_acc = {'batch': [], 'epoch': []}

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.losses['batch'].append(logs.get('loss'))
        self.accuracy['batch'].append(self._first_present(logs, 'acc', 'accuracy'))
        # val_* entries are typically None per batch (loss_plot only draws the
        # validation curves for 'epoch'); they keep all lists the same length.
        self.val_loss['batch'].append(logs.get('val_loss'))
        self.val_acc['batch'].append(self._first_present(logs, 'val_acc', 'val_accuracy'))

    def on_epoch_end(self, batch, logs=None):
        # `batch` receives the epoch index here (Keras passes it positionally).
        logs = logs or {}
        self.losses['epoch'].append(logs.get('loss'))
        self.accuracy['epoch'].append(self._first_present(logs, 'acc', 'accuracy'))
        self.val_loss['epoch'].append(logs.get('val_loss'))
        self.val_acc['epoch'].append(self._first_present(logs, 'val_acc', 'val_accuracy'))

    def loss_plot(self, loss_type):
        """Plot the recorded curves and save the figure to ``path_save_loss_pic``.

        Args:
            loss_type: 'batch' or 'epoch'; validation curves are drawn only
                for 'epoch'.
        """
        iters = range(len(self.losses[loss_type]))
        plt.figure()
        # acc
        plt.plot(iters, self.accuracy[loss_type], 'r', label='train acc')
        # loss
        plt.plot(iters, self.losses[loss_type], 'g', label='train loss')
        if loss_type == 'epoch':
            # val_acc
            plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc')
            # val_loss
            plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss')
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel('acc-loss')
        plt.legend(loc="upper right")
        plt.savefig(path_save_loss_pic)
        plt.show()



(x_train, y_train), (x_test, y_test) = mnist.load_data()

print(x_train[0].shape)
print(y_train)

########################### preprocess x ##################################
# Reshape the images into the standard 4-D tensor layout
# (num_samples, height, width, depth), depth = 1 for grayscale, and convert
# the pixel values to float.
width = 28
height = 28
depth = 1
# reshape() does not transpose, so rows (height) must precede columns
# (width) to keep the pixels in place for non-square images.
x_train = x_train.reshape(x_train.shape[0], height, width, depth).astype('float32')
x_test = x_test.reshape(x_test.shape[0], height, width, depth).astype('float32')

# Normalise pixel values into the range 0 - 1.
x_train /= 255.0
x_test /= 255.0
classes = 10


from keras.utils import to_categorical
# Pass num_classes explicitly so the one-hot vectors are always `classes`
# wide, even if the largest label is absent from a (custom) dataset split.
y_train_ohe = to_categorical(y_train, num_classes=classes)
y_test_ohe = to_categorical(y_test, num_classes=classes)

###################### build the CNN ###############################
model = Sequential()
# Conv layer: 32 filters of 3x3x1, stride 1, zero padding ('same'), ReLU.
model.add(Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), padding='same',
                 input_shape=(height, width, depth), activation='relu'))
# 2x2 max pooling.
model.add(MaxPooling2D(pool_size=(2, 2)))
# Dropout with rate 0.35.
model.add(Dropout(0.35))

# Second conv layer: 64 filters of 3x3.
model.add(Conv2D(64, kernel_size=(3, 3), strides=(1, 1), padding='same', activation='relu'))
# Optional extra layers, currently disabled:
#model.add(MaxPooling2D(pool_size=(2, 2)))
#model.add(Dropout(0.5))
#model.add(Conv2D(64, kernel_size=(3,3), strides=(1, 1), padding='same', activation='relu'))
#model.add((MaxPooling2D(pool_size=(2, 2))))
model.add(Dropout(0.25))

model.add(Conv2D(16, kernel_size=(5, 5), strides=(1, 1), padding='same', activation='relu'))

# Flatten the feature maps to feed the fully-connected layers.
model.add(Flatten())

# Fully-connected classifier head.
model.add(Dense(64, activation='relu'))
model.add(Dense(16, activation='relu'))
model.add(Dropout(0.4))
model.add(Dense(classes, activation='softmax'))

################################ compile ##########################
# Cross-entropy loss for multi-class classification with one-hot labels.
model.compile(loss='categorical_crossentropy', optimizer='adagrad', metrics=['accuracy'])
model.summary()
plot_model(model, to_file=path_save_net_pic, show_shapes=True)

history = LossHistory()
######################### train ################################
# NOTE: the test set doubles as validation data, so the final test accuracy
# is not an independent hold-out estimate.
model.fit(x_train, y_train_ohe, validation_data=(x_test, y_test_ohe),
          epochs=EPOCH, batch_size=64, callbacks=[history])

######################## evaluate ################################
scores = model.evaluate(x_test, y_test_ohe, verbose=0)

######################## save model ################################
# Save the whole model (architecture + weights) to one file.
model.save(path_save_model)

print('Test score:', scores[0])
print('Test accuracy:', scores[1])

# Plot the per-epoch acc/loss curves.
history.loss_plot('epoch')

网络结构与参数量(model.summary() 的输出):

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 28, 28, 32)        320       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 32)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 64)        18496     
_________________________________________________________________
dropout_2 (Dropout)          (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 14, 14, 16)        25616     
_________________________________________________________________
flatten_1 (Flatten)          (None, 3136)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                200768    
_________________________________________________________________
dense_2 (Dense)              (None, 16)                1040      
_________________________________________________________________
dropout_3 (Dropout)          (None, 16)                0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                170       
=================================================================
Total params: 246,410
Trainable params: 246,410
Non-trainable params: 0

训练日志

Epoch 41/50
60000/60000 [==============================] - 5s 89us/step - loss: 0.1466 - acc: 0.9366 - val_loss: 0.0417 - val_acc: 0.9906
Epoch 42/50
60000/60000 [==============================] - 5s 88us/step - loss: 0.1418 - acc: 0.9387 - val_loss: 0.0416 - val_acc: 0.9905
Epoch 43/50
60000/60000 [==============================] - 5s 89us/step - loss: 0.1421 - acc: 0.9379 - val_loss: 0.0405 - val_acc: 0.9913
Epoch 44/50
60000/60000 [==============================] - 5s 90us/step - loss: 0.1430 - acc: 0.9372 - val_loss: 0.0401 - val_acc: 0.9916
Epoch 45/50
60000/60000 [==============================] - 5s 90us/step - loss: 0.1441 - acc: 0.9367 - val_loss: 0.0408 - val_acc: 0.9911
Epoch 46/50
60000/60000 [==============================] - 5s 88us/step - loss: 0.1397 - acc: 0.9391 - val_loss: 0.0413 - val_acc: 0.9913
Epoch 47/50
60000/60000 [==============================] - 5s 88us/step - loss: 0.1396 - acc: 0.9392 - val_loss: 0.0413 - val_acc: 0.9909
Epoch 48/50
60000/60000 [==============================] - 5s 88us/step - loss: 0.1385 - acc: 0.9402 - val_loss: 0.0405 - val_acc: 0.9916
Epoch 49/50
60000/60000 [==============================] - 5s 89us/step - loss: 0.1382 - acc: 0.9409 - val_loss: 0.0423 - val_acc: 0.9912
Epoch 50/50
60000/60000 [==============================] - 5s 88us/step - loss: 0.1394 - acc: 0.9410 - val_loss: 0.0412 - val_acc: 0.9917
Test score: 0.04119344436894603
Test accuracy: 0.9917

网络图
(图:网络结构图,即训练脚本用 plot_model 保存的 net_pic.jpg)
loss-acc图:
(图:loss/acc 曲线,即训练脚本保存的 loss-acc.jpg)

加载模型(保存的模型文件中已包含网络结构,因此只需加载这一个文件),对单张图片进行预测:

(配图:单张图片预测示例)

import numpy as np
from keras.models import load_model
import matplotlib.pyplot as plt
from PIL import Image
from keras.models import Model
import cv2

# Image to classify and the trained model produced by the training script.
path_img = '/data_2/everyday/0515/xian/code/6/2.jpeg'
path_model = '/data_2/everyday/0515/xian/code/6/mnist_model.h5'
model = load_model(path_model)

def pre_pic(picName):
    """Load an image file as a 28x28 grayscale numpy array.

    Args:
        picName: path of the image file.

    Returns:
        The image resized to 28x28, converted to grayscale ("L" mode) and
        turned into a numpy array.
    """
    # The context manager closes the underlying file once we are done.
    with Image.open(picName) as img:
        # Image.ANTIALIAS was removed in Pillow 10; it was an alias of
        # LANCZOS, so this is the same high-quality downsampling filter.
        reIm = img.resize((28, 28), Image.LANCZOS)
    # Convert to single-channel grayscale, then to a numpy array.
    im_arr = np.array(reIm.convert("L"))
    return im_arr


# cv2.imread returns None (instead of raising) when the file is missing or
# unreadable; fail early with a clear message.
img = cv2.imread(path_img)
if img is None:
    raise IOError('cannot read image: %s' % path_img)
# Inverse binary threshold: pixels above 120 become 0 and the rest 255,
# presumably to match MNIST's light-digit-on-dark-background style.
_, img_1 = cv2.threshold(img, 120, 255, cv2.THRESH_BINARY_INV)

im1 = img_1.copy()
im1 = cv2.resize(im1, (28, 28))
im1 = cv2.cvtColor(im1, cv2.COLOR_BGR2GRAY)
#im1 = pre_pic(path_img)
print('输入数字:')

plt.imshow(im1, cmap=plt.get_cmap('gray'))
plt.show()

# Shape to (batch=1, 28, 28, channels=1) and scale to [0, 1] as in training.
im1 = im1.reshape((1, 28, 28, 1))
im1 = im1.astype('float32') / 255

# Sequential.predict_classes was removed in newer Keras / tf.keras; the argmax
# over the softmax output yields the same predicted class index.
predict = np.argmax(model.predict(im1), axis=-1)


print('answer=:')
print(predict)
cv2.imshow("pic", img)
cv2.imshow("pic_bi", img_1)
cv2.waitKey(0)

显示中间某层的feature map

首先看网络

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 28, 28, 32)        320       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 32)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 64)        18496     
_________________________________________________________________
dropout_2 (Dropout)          (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 14, 14, 16)        25616     
_________________________________________________________________
flatten_1 (Flatten)          (None, 3136)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                200768    
_________________________________________________________________
dense_2 (Dense)              (None, 16)                1040      
_________________________________________________________________
dropout_3 (Dropout)          (None, 16)                0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                170       
=================================================================
Total params: 246,410
Trainable params: 246,410
Non-trainable params: 0

例如查看 conv2d_1 (Conv2D) (None, 28, 28, 32) 这一层的 feature map:

28*28是长宽,32是通道数
我输入的是这张图:
(图:输入的数字图片)
经过conv2d_1层之后,可以得到32个feature map图如下:
(图:conv2d_1 层输出的 32 个 feature map,即保存的 conv2d_1.jpg)
代码如下:

import numpy as np
from keras.models import load_model
import matplotlib.pyplot as plt
from PIL import Image
from keras.models import Model

model_path = '/data_2/everyday/0515/xian/code/6/mnist_model.h5'  # trained model file
model = load_model(model_path)  # load the model file (architecture + weights)

def pre_pic(picName):
    """Load an image file as a 28x28 grayscale numpy array.

    Args:
        picName: path of the image file.

    Returns:
        The image resized to 28x28, converted to grayscale ("L" mode) and
        turned into a numpy array.
    """
    # The context manager closes the underlying file once we are done.
    with Image.open(picName) as img:
        # Image.ANTIALIAS was removed in Pillow 10; it was an alias of
        # LANCZOS, so this is the same high-quality downsampling filter.
        reIm = img.resize((28, 28), Image.LANCZOS)
    # Convert to single-channel grayscale, then to a numpy array.
    im_arr = np.array(reIm.convert("L"))
    return im_arr

im1 = pre_pic('/data_2/everyday/0515/xian/code/6/9.jpeg')
print('输入数字:')

plt.imshow(im1, cmap=plt.get_cmap('gray'))
plt.show()

# Shape to (batch=1, 28, 28, channels=1) and scale to [0, 1] as in training.
im1 = im1.reshape((1, 28, 28, 1))
im1 = im1.astype('float32') / 255

# Sub-model ending at the layer to inspect; valid names are the ones printed
# by model.summary().
layer_name = 'conv2d_1'
m1 = Model(inputs=model.input, outputs=model.get_layer(layer_name).output)

predict = m1.predict(im1)
print("predict-shape", predict.shape)
#print ('识别为:')
#print (predict)

# One subplot per channel, 8 per row (32 channels -> 4 x 8 grid). The channel
# count comes from the layer output, so changing layer_name needs no edits.
n_channels = predict.shape[-1]
n_cols = 8
n_rows = (n_channels + n_cols - 1) // n_cols

plt.figure()
for i in range(n_channels):
    x_output = predict[0, :, :, i]
    max_value = np.max(x_output)  # avoid shadowing the builtin max()
    print("max value is :", max_value)
    if max_value != 0:
        x_output = x_output.astype("float32") / max_value * 255
    else:
        # A channel whose maximum is 0 (e.g. one ReLU zeroed everywhere) would
        # become NaN when divided; plot its raw values instead.
        x_output = x_output.astype("float32")
    print("x_output.shape=", x_output.shape)

    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(x_output, cmap=plt.get_cmap('gray'))
plt.savefig("./" + layer_name + ".jpg")
plt.show()

同样地,查看 pool 层之后的 feature map:

max_pooling2d_1 (MaxPooling2 (None, 14, 14, 32)
(图:max_pooling2d_1 层输出的 32 个 feature map,即保存的 max_pooling2d_1.jpg)

import numpy as np
from keras.models import load_model
import matplotlib.pyplot as plt
from PIL import Image
from keras.models import Model

model_path = '/data_2/everyday/0515/xian/code/6/mnist_model.h5'  # trained model file
model = load_model(model_path)  # load the model file (architecture + weights)

def pre_pic(picName):
    """Load an image file as a 28x28 grayscale numpy array.

    Args:
        picName: path of the image file.

    Returns:
        The image resized to 28x28, converted to grayscale ("L" mode) and
        turned into a numpy array.
    """
    # The context manager closes the underlying file once we are done.
    with Image.open(picName) as img:
        # Image.ANTIALIAS was removed in Pillow 10; it was an alias of
        # LANCZOS, so this is the same high-quality downsampling filter.
        reIm = img.resize((28, 28), Image.LANCZOS)
    # Convert to single-channel grayscale, then to a numpy array.
    im_arr = np.array(reIm.convert("L"))
    return im_arr

im1 = pre_pic('/data_2/everyday/0515/xian/code/6/9.jpeg')
print('输入数字:')

plt.imshow(im1, cmap=plt.get_cmap('gray'))
plt.show()

# Shape to (batch=1, 28, 28, channels=1) and scale to [0, 1] as in training.
im1 = im1.reshape((1, 28, 28, 1))
im1 = im1.astype('float32') / 255


# Sub-model ending at the layer to inspect; valid names are the ones printed
# by model.summary().
layer_name = 'max_pooling2d_1'
m1 = Model(inputs=model.input, outputs=model.get_layer(layer_name).output)

predict = m1.predict(im1)
print("predict-shape", predict.shape)
#print ('识别为:')
#print (predict)


# One subplot per channel, 8 per row (32 channels -> 4 x 8 grid). The channel
# count comes from the layer output, so changing layer_name needs no edits.
n_channels = predict.shape[-1]
n_cols = 8
n_rows = (n_channels + n_cols - 1) // n_cols

plt.figure()
for i in range(n_channels):
    x_output = predict[0, :, :, i]
    max_value = np.max(x_output)  # avoid shadowing the builtin max()
    print("max value is :", max_value)
    if max_value != 0:
        x_output = x_output.astype("float32") / max_value * 255
    else:
        # A channel whose maximum is 0 (e.g. one ReLU zeroed everywhere) would
        # become NaN when divided; plot its raw values instead.
        x_output = x_output.astype("float32")
    print("x_output.shape=", x_output.shape)

    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(x_output, cmap=plt.get_cmap('gray'))
plt.savefig("./" + layer_name + ".jpg")
plt.show()

同样地,查看 conv2d_2 层之后的 feature map:

conv2d_2 (Conv2D) (None, 14, 14, 64)
(图:conv2d_2 层输出的 64 个 feature map,即保存的 conv2d_2.jpg)
代码基本相同,注意其中参数的变化(层名、通道数 64、子图布局 8 行 8 列、保存的文件名):

import numpy as np
from keras.models import load_model
import matplotlib.pyplot as plt
from PIL import Image
from keras.models import Model

model_path = '/data_2/everyday/0515/xian/code/6/mnist_model.h5'  # trained model file
model = load_model(model_path)  # load the model file (architecture + weights)

def pre_pic(picName):
    """Load an image file as a 28x28 grayscale numpy array.

    Args:
        picName: path of the image file.

    Returns:
        The image resized to 28x28, converted to grayscale ("L" mode) and
        turned into a numpy array.
    """
    # The context manager closes the underlying file once we are done.
    with Image.open(picName) as img:
        # Image.ANTIALIAS was removed in Pillow 10; it was an alias of
        # LANCZOS, so this is the same high-quality downsampling filter.
        reIm = img.resize((28, 28), Image.LANCZOS)
    # Convert to single-channel grayscale, then to a numpy array.
    im_arr = np.array(reIm.convert("L"))
    return im_arr

im1 = pre_pic('/data_2/everyday/0515/xian/code/6/9.jpeg')
print('输入数字:')

plt.imshow(im1, cmap=plt.get_cmap('gray'))
plt.show()

# Shape to (batch=1, 28, 28, channels=1) and scale to [0, 1] as in training.
im1 = im1.reshape((1, 28, 28, 1))
im1 = im1.astype('float32') / 255


# Sub-model ending at the layer to inspect; valid names are the ones printed
# by model.summary().
layer_name = 'conv2d_2'
m1 = Model(inputs=model.input, outputs=model.get_layer(layer_name).output)

predict = m1.predict(im1)
print("predict-shape", predict.shape)
#print ('识别为:')
#print (predict)


# One subplot per channel, 8 per row (64 channels -> 8 x 8 grid). The channel
# count comes from the layer output, so changing layer_name needs no edits.
n_channels = predict.shape[-1]
n_cols = 8
n_rows = (n_channels + n_cols - 1) // n_cols

plt.figure()
for i in range(n_channels):
    x_output = predict[0, :, :, i]
    max_value = np.max(x_output)  # avoid shadowing the builtin max()
    print("max value is :", max_value)
    if max_value != 0:
        x_output = x_output.astype("float32") / max_value * 255
    else:
        # A channel whose maximum is 0 (e.g. one ReLU zeroed everywhere) would
        # become NaN when divided; plot its raw values instead.
        x_output = x_output.astype("float32")
    print("x_output.shape=", x_output.shape)

    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(x_output, cmap=plt.get_cmap('gray'))
plt.savefig("./" + layer_name + ".jpg")
plt.show()

下面给出调用resnet18的例子

import numpy as np
import keras
from keras.utils import plot_model
from keras.datasets import mnist

from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
import matplotlib.pyplot as plt
import resnet

#%matplotlib inline


# Training settings and output file locations for the ResNet-18 run.
EPOCH = 30                                      # number of training epochs
path_save_net_pic = './net_pic.jpg'             # plot_model() diagram
path_save_loss_pic = './loss-acc.jpg'           # acc/loss curve figure
path_save_model = '3-0-mnist_model_resnet.h5'   # whole model (architecture + weights)

class LossHistory(keras.callbacks.Callback):
    """Keras callback that records loss and accuracy during training.

    Values are kept per batch and per epoch in the dicts ``losses``,
    ``accuracy``, ``val_loss`` and ``val_acc`` (each keyed by 'batch' and
    'epoch'). ``loss_plot`` draws them and saves the figure to
    ``path_save_loss_pic``.
    """

    @staticmethod
    def _first_present(logs, *names):
        """Return the value of the first key in ``names`` found in ``logs``.

        Older Keras logs accuracy as 'acc' / 'val_acc', newer releases as
        'accuracy' / 'val_accuracy'; trying both keeps the curves populated
        on either version instead of silently recording None.
        """
        for name in names:
            if name in logs:
                return logs[name]
        return None

    def on_train_begin(self, logs=None):
        # Reset the history at the start of every fit() call.
        self.losses = {'batch': [], 'epoch': []}
        self.accuracy = {'batch': [], 'epoch': []}
        self.val_loss = {'batch': [], 'epoch': []}
        self.val_acc = {'batch': [], 'epoch': []}

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.losses['batch'].append(logs.get('loss'))
        self.accuracy['batch'].append(self._first_present(logs, 'acc', 'accuracy'))
        # val_* entries are typically None per batch (loss_plot only draws the
        # validation curves for 'epoch'); they keep all lists the same length.
        self.val_loss['batch'].append(logs.get('val_loss'))
        self.val_acc['batch'].append(self._first_present(logs, 'val_acc', 'val_accuracy'))

    def on_epoch_end(self, batch, logs=None):
        # `batch` receives the epoch index here (Keras passes it positionally).
        logs = logs or {}
        self.losses['epoch'].append(logs.get('loss'))
        self.accuracy['epoch'].append(self._first_present(logs, 'acc', 'accuracy'))
        self.val_loss['epoch'].append(logs.get('val_loss'))
        self.val_acc['epoch'].append(self._first_present(logs, 'val_acc', 'val_accuracy'))

    def loss_plot(self, loss_type):
        """Plot the recorded curves and save the figure to ``path_save_loss_pic``.

        Args:
            loss_type: 'batch' or 'epoch'; validation curves are drawn only
                for 'epoch'.
        """
        iters = range(len(self.losses[loss_type]))
        plt.figure()
        # acc
        plt.plot(iters, self.accuracy[loss_type], 'r', label='train acc')
        # loss
        plt.plot(iters, self.losses[loss_type], 'g', label='train loss')
        if loss_type == 'epoch':
            # val_acc
            plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc')
            # val_loss
            plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss')
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel('acc-loss')
        plt.legend(loc="upper right")
        plt.savefig(path_save_loss_pic)
        plt.show()



(x_train, y_train), (x_test, y_test) = mnist.load_data()

print(x_train[0].shape)
print(y_train)

########################### preprocess x ##################################
# Reshape the images into the standard 4-D tensor layout
# (num_samples, height, width, depth), depth = 1 for grayscale, and convert
# the pixel values to float.
width = 28
height = 28
depth = 1
# reshape() does not transpose, so rows (height) must precede columns
# (width). This also matches the (height, width, depth) input that
# ResnetBuilder builds from the (depth, height, width) tuple below.
x_train = x_train.reshape(x_train.shape[0], height, width, depth).astype('float32')
x_test = x_test.reshape(x_test.shape[0], height, width, depth).astype('float32')

# Normalise pixel values into the range 0 - 1.
x_train /= 255.0
x_test /= 255.0
classes = 10


from keras.utils import to_categorical
# Pass num_classes explicitly so the one-hot vectors are always `classes`
# wide, even if the largest label is absent from a (custom) dataset split.
y_train_ohe = to_categorical(y_train, num_classes=classes)
y_test_ohe = to_categorical(y_test, num_classes=classes)


# ResnetBuilder takes the shape as (channels, rows, cols).
model = resnet.ResnetBuilder.build_resnet_18((depth, height, width), classes)


# Cross-entropy loss for multi-class classification with one-hot labels.
model.compile(loss='categorical_crossentropy', optimizer='adagrad', metrics=['accuracy'])
model.summary()
plot_model(model, to_file=path_save_net_pic, show_shapes=True)

history = LossHistory()
######################### train ################################
# NOTE: the test set doubles as validation data, so the final test accuracy
# is not an independent hold-out estimate.
model.fit(x_train, y_train_ohe, validation_data=(x_test, y_test_ohe),
          epochs=EPOCH, batch_size=64, callbacks=[history])

######################## evaluate ################################
scores = model.evaluate(x_test, y_test_ohe, verbose=0)

######################## save model ################################
# Save the whole model (architecture + weights) to one file.
model.save(path_save_model)

print('Test score:', scores[0])
print('Test accuracy:', scores[1])

# Plot the per-epoch acc/loss curves.
history.loss_plot('epoch')

resnet18 的实现(即上面训练脚本中 import resnet 导入的 resnet.py):

from __future__ import division

import six
from keras.models import Model
from keras.layers import (
    Input,
    Activation,
    Dense,
    Flatten
)
from keras.layers.convolutional import (
    Conv2D,
    MaxPooling2D,
    AveragePooling2D
)
from keras.layers.merge import add
from keras.layers.normalization import BatchNormalization
from keras.regularizers import l2
from keras import backend as K


def _bn_relu(input):
    """Apply batch normalisation over the channel axis, then a ReLU.

    Uses the module-level CHANNEL_AXIS set by _handle_dim_ordering().
    """
    normalized = BatchNormalization(axis=CHANNEL_AXIS)(input)
    relu = Activation("relu")
    return relu(normalized)


def _conv_bn_relu(**conv_params):
    """Return a callable that applies Conv2D -> BatchNorm -> ReLU to a tensor.

    Keyword arguments read from ``conv_params``:
        filters, kernel_size: required Conv2D settings.
        strides: default (1, 1).
        kernel_initializer: default "he_normal".
        padding: default "same".
        kernel_regularizer: default l2(1e-4).
    """
    n_filters = conv_params["filters"]
    ksize = conv_params["kernel_size"]
    stride = conv_params.get("strides", (1, 1))
    init = conv_params.get("kernel_initializer", "he_normal")
    pad = conv_params.get("padding", "same")
    reg = conv_params.get("kernel_regularizer", l2(1.e-4))

    def f(x):
        conv_layer = Conv2D(filters=n_filters, kernel_size=ksize,
                            strides=stride, padding=pad,
                            kernel_initializer=init,
                            kernel_regularizer=reg)
        return _bn_relu(conv_layer(x))

    return f


def _bn_relu_conv(**conv_params):
    """Return a callable that applies BatchNorm -> ReLU -> Conv2D to a tensor.

    This pre-activation ordering is the improved scheme proposed in
    http://arxiv.org/pdf/1603.05027v2.pdf.

    Keyword arguments read from ``conv_params``:
        filters, kernel_size: required Conv2D settings.
        strides: default (1, 1).
        kernel_initializer: default "he_normal".
        padding: default "same".
        kernel_regularizer: default l2(1e-4).
    """
    n_filters = conv_params["filters"]
    ksize = conv_params["kernel_size"]
    stride = conv_params.get("strides", (1, 1))
    init = conv_params.get("kernel_initializer", "he_normal")
    pad = conv_params.get("padding", "same")
    reg = conv_params.get("kernel_regularizer", l2(1.e-4))

    def f(x):
        pre_activated = _bn_relu(x)
        conv_layer = Conv2D(filters=n_filters, kernel_size=ksize,
                            strides=stride, padding=pad,
                            kernel_initializer=init,
                            kernel_regularizer=reg)
        return conv_layer(pre_activated)

    return f


def _shortcut(input, residual):
    """Merge ``input`` into ``residual`` by element-wise addition.

    If the two tensors differ in spatial size or channel count, ``input`` is
    first projected by a strided 1x1 convolution so the shapes agree;
    otherwise it is added unchanged (identity shortcut). The per-axis
    strides are the rounded input/residual size ratios, which should be
    integers for a correctly configured network.
    """
    in_shape = K.int_shape(input)
    res_shape = K.int_shape(residual)
    row_stride = int(round(in_shape[ROW_AXIS] / res_shape[ROW_AXIS]))
    col_stride = int(round(in_shape[COL_AXIS] / res_shape[COL_AXIS]))
    channels_match = in_shape[CHANNEL_AXIS] == res_shape[CHANNEL_AXIS]

    # Identity shortcut when nothing needs reshaping.
    if row_stride <= 1 and col_stride <= 1 and channels_match:
        return add([input, residual])

    projection = Conv2D(filters=res_shape[CHANNEL_AXIS],
                        kernel_size=(1, 1),
                        strides=(row_stride, col_stride),
                        padding="valid",
                        kernel_initializer="he_normal",
                        kernel_regularizer=l2(0.0001))(input)
    return add([projection, residual])


def _residual_block(block_function, filters, repetitions, is_first_layer=False):
    """Builds a residual block with repeating bottleneck blocks.
    """
    def f(input):
        for i in range(repetitions):
            init_strides = (1, 1)
            if i == 0 and not is_first_layer:
                init_strides = (2, 2)
            input = block_function(filters=filters, init_strides=init_strides,
                                   is_first_block_of_first_layer=(is_first_layer and i == 0))(input)
        return input

    return f


def basic_block(filters, init_strides=(1, 1), is_first_block_of_first_layer=False):
    """Two stacked 3x3 convolutions plus a shortcut (ResNets with <= 34 layers).

    Uses the pre-activation (BN -> ReLU -> conv) ordering proposed in
    http://arxiv.org/pdf/1603.05027v2.pdf; ``init_strides`` is applied by
    the first convolution.
    """
    def f(tensor):
        if not is_first_block_of_first_layer:
            first_conv = _bn_relu_conv(filters=filters, kernel_size=(3, 3),
                                       strides=init_strides)(tensor)
        else:
            # The stem already did bn->relu->maxpool, so skip a second bn->relu.
            first_conv = Conv2D(filters=filters, kernel_size=(3, 3),
                                strides=init_strides,
                                padding="same",
                                kernel_initializer="he_normal",
                                kernel_regularizer=l2(1e-4))(tensor)

        second_conv = _bn_relu_conv(filters=filters, kernel_size=(3, 3))(first_conv)
        return _shortcut(tensor, second_conv)

    return f


def bottleneck(filters, init_strides=(1, 1), is_first_block_of_first_layer=False):
    """1x1 -> 3x3 -> 1x1 bottleneck unit plus a shortcut (ResNets > 34 layers).

    Uses the pre-activation ordering proposed in
    http://arxiv.org/pdf/1603.05027v2.pdf.

    Returns:
        A callable whose final conv layer has ``filters * 4`` channels.
    """
    def f(tensor):
        if not is_first_block_of_first_layer:
            reduced = _bn_relu_conv(filters=filters, kernel_size=(1, 1),
                                    strides=init_strides)(tensor)
        else:
            # The stem already did bn->relu->maxpool, so skip a second bn->relu.
            reduced = Conv2D(filters=filters, kernel_size=(1, 1),
                             strides=init_strides,
                             padding="same",
                             kernel_initializer="he_normal",
                             kernel_regularizer=l2(1e-4))(tensor)

        spatial = _bn_relu_conv(filters=filters, kernel_size=(3, 3))(reduced)
        expanded = _bn_relu_conv(filters=filters * 4, kernel_size=(1, 1))(spatial)
        return _shortcut(tensor, expanded)

    return f


def _handle_dim_ordering():
    """Set the module-level ROW_AXIS, COL_AXIS and CHANNEL_AXIS globals.

    Channels-last backends use (batch, rows, cols, channels); channels-first
    backends use (batch, channels, rows, cols).
    """
    global ROW_AXIS
    global COL_AXIS
    global CHANNEL_AXIS
    # K.image_dim_ordering() ('tf'/'th') is a legacy API that newer Keras and
    # tf.keras no longer provide; image_data_format() is its replacement and
    # reports 'channels_last' where the old call reported 'tf'.
    if K.image_data_format() == 'channels_last':
        ROW_AXIS = 1
        COL_AXIS = 2
        CHANNEL_AXIS = 3
    else:
        CHANNEL_AXIS = 1
        ROW_AXIS = 2
        COL_AXIS = 3


def _get_block(identifier):
    """Resolve a block function given either the function itself or its name.

    A string is looked up among this module's globals; an unknown (or falsy)
    entry raises ValueError. Any non-string is returned unchanged.
    """
    if not isinstance(identifier, six.string_types):
        return identifier
    block = globals().get(identifier)
    if block:
        return block
    raise ValueError('Invalid {}'.format(identifier))


class ResnetBuilder(object):
    """Factory for ResNet-style Keras models (ResNet-18/34/50/101/152)."""

    @staticmethod
    def build(input_shape, num_outputs, block_fn, repetitions):
        """Builds a custom ResNet like architecture.

        Args:
            input_shape: The input shape in the form (nb_channels, nb_rows, nb_cols);
                it is permuted to (nb_rows, nb_cols, nb_channels) for
                channels-last backends.
            num_outputs: The number of outputs at final softmax layer
            block_fn: The block function to use. This is either `basic_block` or `bottleneck`
                (or the name of one of them as a string).
                The original paper used basic_block for layers < 50
            repetitions: Number of repetitions of various block units.
                At each block unit, the number of filters are doubled and the input size is halved

        Returns:
            The keras `Model`.

        Raises:
            ValueError: if ``input_shape`` does not have exactly 3 entries.
        """
        _handle_dim_ordering()
        if len(input_shape) != 3:
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError("Input shape should be a tuple (nb_channels, nb_rows, nb_cols)")

        # Permute dimension order if necessary. image_data_format() replaces
        # the legacy K.image_dim_ordering(), which newer Keras / tf.keras lack.
        if K.image_data_format() == 'channels_last':
            input_shape = (input_shape[1], input_shape[2], input_shape[0])

        # Load function from str if needed.
        block_fn = _get_block(block_fn)

        input = Input(shape=input_shape)
        # Stem: 7x7/2 conv -> BN -> ReLU, then 3x3/2 max pooling.
        conv1 = _conv_bn_relu(filters=64, kernel_size=(7, 7), strides=(2, 2))(input)
        pool1 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="same")(conv1)

        block = pool1
        filters = 64
        for i, r in enumerate(repetitions):
            block = _residual_block(block_fn, filters=filters, repetitions=r, is_first_layer=(i == 0))(block)
            filters *= 2

        # Last activation
        block = _bn_relu(block)

        # Classifier block: average-pool over the whole remaining spatial
        # extent, flatten, then a softmax dense layer.
        block_shape = K.int_shape(block)
        pool2 = AveragePooling2D(pool_size=(block_shape[ROW_AXIS], block_shape[COL_AXIS]),
                                 strides=(1, 1))(block)
        flatten1 = Flatten()(pool2)
        dense = Dense(units=num_outputs, kernel_initializer="he_normal",
                      activation="softmax")(flatten1)

        model = Model(inputs=input, outputs=dense)
        return model

    @staticmethod
    def build_resnet_18(input_shape, num_outputs):
        """ResNet-18: basic blocks repeated [2, 2, 2, 2]."""
        return ResnetBuilder.build(input_shape, num_outputs, basic_block, [2, 2, 2, 2])

    @staticmethod
    def build_resnet_34(input_shape, num_outputs):
        """ResNet-34: basic blocks repeated [3, 4, 6, 3]."""
        return ResnetBuilder.build(input_shape, num_outputs, basic_block, [3, 4, 6, 3])

    @staticmethod
    def build_resnet_50(input_shape, num_outputs):
        """ResNet-50: bottleneck blocks repeated [3, 4, 6, 3]."""
        return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 4, 6, 3])

    @staticmethod
    def build_resnet_101(input_shape, num_outputs):
        """ResNet-101: bottleneck blocks repeated [3, 4, 23, 3]."""
        return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 4, 23, 3])

    @staticmethod
    def build_resnet_152(input_shape, num_outputs):
        """ResNet-152: bottleneck blocks repeated [3, 8, 36, 3]."""
        return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 8, 36, 3])