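GRU (Gated Recurrent Unit) is a variant of LSTM. Like LSTM, it mitigates the difficulty plain RNNs have with long-range dependencies.
The GRU's structure is similar to the LSTM's but simpler. It combines the LSTM's forget and input gates into a single update gate and adds a reset gate. It also merges the cell state and the hidden state into one state vector. It is one of the most widely used LSTM variants.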
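Having three weight blocks instead of four has a concrete consequence: a GRU layer has about 3/4 of the parameters of an LSTM layer of the same width. The GRU blocks are the update gate, the reset gate and the candidate state. The LSTM blocks are the input, forget and output gates plus the candidate cell. The sketch below assumes the classic Keras 2 GRU formulation, with one bias vector per block (`reset_after=False`). Newer `tf.keras` defaults to `reset_after=True`, which adds a second bias per block. With 512 inputs and 256 units, it reproduces the `blstm1` entry in the model summary later in this post.

```python
def gru_params(input_dim: int, units: int) -> int:
    """Parameters of one GRU layer, assuming one bias per block (Keras 2, reset_after=False).

    Three blocks (update gate, reset gate, candidate state), each with an
    input kernel (input_dim x units), a recurrent kernel (units x units)
    and a bias vector (units).
    """
    return 3 * (input_dim * units + units * units + units)


def lstm_params(input_dim: int, units: int) -> int:
    """Parameters of one LSTM layer: four blocks (input, forget, output gates, candidate cell)."""
    return 4 * (input_dim * units + units * units + units)


if __name__ == "__main__":
    # A Bidirectional wrapper holds two independent layers (forward + backward)
    print(2 * gru_params(512, 256))   # 1181184, matches blstm1
    print(2 * gru_params(256, 256))   # 787968, matches blstm2
    print(2 * lstm_params(512, 256))  # 1574912 for a bidirectional LSTM of the same width
```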
Figure: the core LSTM module.
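In equation form, a standard LSTM cell without peephole connections computes the following at each time step t. This is the common textbook formulation and is assumed to be the variant the figure depicts. σ is the logistic sigmoid and ⊙ is element-wise multiplication:

$$
\begin{aligned}
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
\tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

The forget gate f_t and input gate i_t separately control how much of the old cell state is kept and how much of the candidate is written. A GRU replaces both with a single update gate.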
Figure: the corresponding core module in a GRU.
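The following is a minimal NumPy sketch of one GRU time step. It follows the formulation Keras 2 uses with `reset_after=False` (Cho et al.'s convention, in which the update gate z weights the previous state). The `W`/`U`/`b` dictionaries keyed by gate are an illustrative layout, not Keras's internal weight format. Compared with an LSTM, there is no separate cell state and no output gate. The single gate z decides both what to keep and what to overwrite.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(x, h_prev, W, U, b):
    """One GRU time step for a batch (Keras 2 / reset_after=False formulation).

    x:       (batch, input_dim) input at time t
    h_prev:  (batch, units)     hidden state from time t-1
    W, U, b: illustrative dicts keyed by 'z' (update gate), 'r' (reset gate)
             and 'h' (candidate); W[k]: (input_dim, units),
             U[k]: (units, units), b[k]: (units,)
    """
    z = sigmoid(x @ W['z'] + h_prev @ U['z'] + b['z'])  # update gate
    r = sigmoid(x @ W['r'] + h_prev @ U['r'] + b['r'])  # reset gate
    # candidate state: the reset gate controls how much of h_prev it sees
    h_cand = np.tanh(x @ W['h'] + (r * h_prev) @ U['h'] + b['h'])
    # z near 1 keeps the old state, z near 0 switches to the candidate
    return z * h_prev + (1.0 - z) * h_cand


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    input_dim, units, batch = 4, 3, 2
    W = {k: rng.normal(size=(input_dim, units)) for k in 'zrh'}
    U = {k: rng.normal(size=(units, units)) for k in 'zrh'}
    b = {k: np.zeros(units) for k in 'zrh'}
    h = np.zeros((batch, units))
    for _ in range(5):  # feed a short random sequence
        h = gru_step(rng.normal(size=(batch, input_dim)), h, W, U, b)
    print(h.shape)  # (2, 3)
```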
Network definition (CNN feature extractor plus bidirectional GRUs, trained with CTC loss):
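```python
# Standalone Keras 2.x API (e.g. SGD's lr/decay arguments), matching the summary below.
from keras import backend as K
from keras.layers import (Input, Conv2D, MaxPooling2D, ZeroPadding2D,
                          BatchNormalization, Permute, TimeDistributed,
                          Flatten, Bidirectional, GRU, Dense, Lambda)
from keras.models import Model
from keras.optimizers import SGD

rnnunit = 256  # GRU width; consistent with the (None, None, 256) Dense output in the summary


def ctc_lambda_func(args):
    # Not shown in the original post: the usual Keras pattern that wraps
    # K.ctc_batch_cost so the CTC loss is computed inside the graph.
    y_pred, labels, input_length, label_length = args
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)


def get_model(height, nclass):
    # Grayscale text-line image: fixed height, variable width
    inputs = Input(shape=(height, None, 1), name='the_input')
    # CNN feature extractor
    m = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same', name='conv1')(inputs)
    m = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool1')(m)
    m = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same', name='conv2')(m)
    m = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool2')(m)
    m = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same', name='conv3')(m)
    m = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same', name='conv4')(m)
    # Pad the width only, then pool with stride (2, 1): halve height, keep width resolution
    m = ZeroPadding2D(padding=(0, 1))(m)
    m = MaxPooling2D(pool_size=(2, 2), strides=(2, 1), padding='valid', name='pool3')(m)
    m = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same', name='conv5')(m)
    # NOTE: axis=1 normalizes over the height axis (4 values -> 16 params in the
    # summary); axis=-1 (channels) is the usual choice for channels_last data.
    m = BatchNormalization(axis=1)(m)
    m = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same', name='conv6')(m)
    m = BatchNormalization(axis=1)(m)
    m = ZeroPadding2D(padding=(0, 1))(m)
    m = MaxPooling2D(pool_size=(2, 2), strides=(2, 1), padding='valid', name='pool4')(m)
    m = Conv2D(512, kernel_size=(2, 2), activation='relu', padding='valid', name='conv7')(m)
    # (batch, 1, width, 512) -> (batch, width, 1, 512) -> (batch, width, 512):
    # the width axis becomes the time axis of the recurrent layers
    m = Permute((2, 1, 3), name='permute')(m)
    m = TimeDistributed(Flatten(), name='timedistrib')(m)
    # Two bidirectional GRU layers (named "blstm" in the original code)
    m = Bidirectional(GRU(rnnunit, return_sequences=True), name='blstm1')(m)
    m = Dense(rnnunit, name='blstm1_out', activation='linear')(m)
    m = Bidirectional(GRU(rnnunit, return_sequences=True), name='blstm2')(m)
    # Per-time-step class probabilities
    y_pred = Dense(nclass, name='blstm2_out', activation='softmax')(m)
    basemodel = Model(inputs=inputs, outputs=y_pred)  # used for inference

    # Extra inputs needed only to compute the CTC loss during training
    labels = Input(name='the_labels', shape=[None], dtype='float32')
    input_length = Input(name='input_length', shape=[1], dtype='int64')
    label_length = Input(name='label_length', shape=[1], dtype='int64')
    loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
        [y_pred, labels, input_length, label_length])
    model = Model(inputs=[inputs, labels, input_length, label_length], outputs=[loss_out])

    sgd = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)
    # The 'ctc' layer already outputs the loss, so the Keras loss just passes it through
    # model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adadelta')
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=sgd)
    model.summary()
    return model, basemodel
```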
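The `ctc` Lambda layer turns the per-step softmax outputs into a CTC loss. In the usual Keras pattern, `ctc_lambda_func` calls `K.ctc_batch_cost`. CTC's likelihood for a label sequence is the sum of the probabilities of every frame-level path that collapses to that label once repeated symbols are merged and blanks are removed. The NumPy sketch below computes this quantity with the standard forward (alpha) recursion for a single sequence. It then checks the result against brute-force enumeration on a toy input.

This illustrates the technique and is not the TensorFlow implementation. It assumes the blank is the last class index, which is the convention of the TensorFlow CTC loss behind `K.ctc_batch_cost`. It also works in plain probabilities rather than log space, so it only suits short sequences.

```python
import itertools

import numpy as np


def ctc_loss(probs: np.ndarray, label: list) -> float:
    """CTC negative log-likelihood of `label` via the forward (alpha) recursion.

    probs: (T, C) per-step softmax outputs; class C-1 is assumed to be the blank.
    label: class indices of the target sequence (no blanks).
    Works in plain probabilities, so it is only meant for short toy inputs.
    """
    T, C = probs.shape
    blank = C - 1
    ext = [blank]  # extended label: blank, l1, blank, l2, ..., blank
    for s in label:
        ext += [s, blank]
    S = len(ext)
    alpha = np.zeros((T, S))
    alpha[0, 0] = probs[0, ext[0]]  # start with a blank ...
    if S > 1:
        alpha[0, 1] = probs[0, ext[1]]  # ... or with the first symbol
    for t in range(1, T):
        for s in range(S):
            a = alpha[t - 1, s]  # stay on the same symbol
            if s >= 1:
                a += alpha[t - 1, s - 1]  # advance by one
            # skipping a blank is allowed only between two different symbols
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                a += alpha[t - 1, s - 2]
            alpha[t, s] = a * probs[t, ext[s]]
    # valid alignments end on the last symbol or on the trailing blank
    total = alpha[T - 1, S - 1] + (alpha[T - 1, S - 2] if S > 1 else 0.0)
    return float(-np.log(total))


def collapse(path, blank):
    """CTC collapsing rule: merge repeated symbols, then drop blanks."""
    out, prev = [], None
    for p in path:
        if p != prev and p != blank:
            out.append(p)
        prev = p
    return out


def ctc_loss_brute_force(probs: np.ndarray, label: list) -> float:
    """Same quantity, summing over every length-T path (exponential cost)."""
    T, C = probs.shape
    total = 0.0
    for path in itertools.product(range(C), repeat=T):
        if collapse(path, C - 1) == list(label):
            total += np.prod([probs[t, k] for t, k in enumerate(path)])
    return float(-np.log(total))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    logits = rng.normal(size=(5, 3))  # T=5 steps, classes {0, 1} plus blank=2
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    for label in ([0, 1], [0, 0], [1]):
        print(label, ctc_loss(probs, label), ctc_loss_brute_force(probs, label))
```

A repeated label such as `[0, 0]` can only be produced by paths with a blank between the two symbols. This is why the recursion forbids the skip transition when `ext[s] == ext[s - 2]`.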
With `height=32` and `nclass=5531`, `model.summary()` prints:

```text
____________________________________________________________________________________________________
Layer (type)                     Output Shape            Param #     Connected to
====================================================================================================
the_input (InputLayer)           (None, 32, None, 1)     0
____________________________________________________________________________________________________
conv1 (Conv2D)                   (None, 32, None, 64)    640         the_input[0][0]
____________________________________________________________________________________________________
pool1 (MaxPooling2D)             (None, 16, None, 64)    0           conv1[0][0]
____________________________________________________________________________________________________
conv2 (Conv2D)                   (None, 16, None, 128)   73856       pool1[0][0]
____________________________________________________________________________________________________
pool2 (MaxPooling2D)             (None, 8, None, 128)    0           conv2[0][0]
____________________________________________________________________________________________________
conv3 (Conv2D)                   (None, 8, None, 256)    295168      pool2[0][0]
____________________________________________________________________________________________________
conv4 (Conv2D)                   (None, 8, None, 256)    590080      conv3[0][0]
____________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D) (None, 8, None, 256)    0           conv4[0][0]
____________________________________________________________________________________________________
pool3 (MaxPooling2D)             (None, 4, None, 256)    0           zero_padding2d_1[0][0]
____________________________________________________________________________________________________
conv5 (Conv2D)                   (None, 4, None, 512)    1180160     pool3[0][0]
____________________________________________________________________________________________________
batch_normalization_1 (BatchNorm (None, 4, None, 512)    16          conv5[0][0]
____________________________________________________________________________________________________
conv6 (Conv2D)                   (None, 4, None, 512)    2359808     batch_normalization_1[0][0]
____________________________________________________________________________________________________
batch_normalization_2 (BatchNorm (None, 4, None, 512)    16          conv6[0][0]
____________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D) (None, 4, None, 512)    0           batch_normalization_2[0][0]
____________________________________________________________________________________________________
pool4 (MaxPooling2D)             (None, 2, None, 512)    0           zero_padding2d_2[0][0]
____________________________________________________________________________________________________
conv7 (Conv2D)                   (None, 1, None, 512)    1049088     pool4[0][0]
____________________________________________________________________________________________________
permute (Permute)                (None, None, 1, 512)    0           conv7[0][0]
____________________________________________________________________________________________________
timedistrib (TimeDistributed)    (None, None, 512)       0           permute[0][0]
____________________________________________________________________________________________________
blstm1 (Bidirectional)           (None, None, 512)       1181184     timedistrib[0][0]
____________________________________________________________________________________________________
blstm1_out (Dense)               (None, None, 256)       131328      blstm1[0][0]
____________________________________________________________________________________________________
blstm2 (Bidirectional)           (None, None, 512)       787968      blstm1_out[0][0]
____________________________________________________________________________________________________
blstm2_out (Dense)               (None, None, 5531)      2837403     blstm2[0][0]
____________________________________________________________________________________________________
the_labels (InputLayer)          (None, None)            0
____________________________________________________________________________________________________
input_length (InputLayer)        (None, 1)               0
____________________________________________________________________________________________________
label_length (InputLayer)        (None, 1)               0
____________________________________________________________________________________________________
ctc (Lambda)                     (None, 1)               0           blstm2_out[0][0]
                                                                     the_labels[0][0]
                                                                     input_length[0][0]
                                                                     label_length[0][0]
====================================================================================================
Total params: 10,486,715
Trainable params: 10,486,699
Non-trainable params: 16
```
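The `None` dimension in these shapes is the image width. It becomes the time axis seen by the GRUs and by the CTC loss, so `input_length` must match the number of time steps the CNN produces. Only the pooling layers and `conv7` change sizes, because the other convolutions use `'same'` padding. Replaying the standard output-size formula layer by layer gives a height of 1 and floor(W/4) + 1 time steps for an input of width W. The following minimal sketch shows this; the helper names are illustrative, not from the original project.

```python
def conv_out(size: int, kernel: int, stride: int, pad: int = 0) -> int:
    """Output length of a 'valid' window op applied after padding `pad` on each side."""
    return (size + 2 * pad - kernel) // stride + 1


def output_shape(height: int, width: int) -> tuple:
    """(height, time_steps) after conv7 for the layer settings in get_model."""
    h, w = height, width
    # pool1, pool2: 2x2 window, stride 2 (the 'same' convolutions keep sizes unchanged)
    for _ in range(2):
        h, w = conv_out(h, 2, 2), conv_out(w, 2, 2)
    # pool3, pool4: ZeroPadding2D((0, 1)) pads the width only, then 2x2 window, stride (2, 1)
    for _ in range(2):
        h, w = conv_out(h, 2, 2), conv_out(w, 2, 1, pad=1)
    # conv7: 2x2 kernel, 'valid' padding
    return conv_out(h, 2, 1), conv_out(w, 2, 1)


if __name__ == "__main__":
    print(output_shape(32, 280))  # (1, 71): 280 // 4 + 1 time steps
```

For example, a 280-pixel-wide line image yields 71 time steps. Each label in the batch must therefore fit within 71 steps, and fewer symbols fit when the label contains repeated characters, since CTC needs a blank between them.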
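Model: the character set contains 5,500 characters, including common Chinese characters, upper- and lowercase English letters, punctuation, and special symbols such as @, ¥ and &. The output layer in the summary above has 5,531 classes. Training can be continued from the provided model.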
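Training: the training samples are stored in LMDB format in the `data` folder. `train.py` is the training script. It can save either the model weights only or the model architecture plus weights, and it writes the results to the `models` folder.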
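The trainable `model` takes four inputs, and its "loss" is simply the output of the `ctc` Lambda layer. Each training batch therefore has to supply the images, the padded label indices and the two length arrays, plus a dummy target for the `ctc` output. The sketch below shows one way to assemble such a batch with NumPy.

It assumes the samples have already been read from LMDB, resized to height 32, padded to a common width, and mapped to class indices. The function name `make_ctc_batch` and the label padding value are hypothetical and are not taken from `train.py`. The time-step count uses the fact that the CNN in `get_model` turns width W into floor(W/4) + 1 steps (two stride-2 poolings, two width-padded stride-1 poolings and the valid 2x2 `conv7`).

```python
import numpy as np


def make_ctc_batch(images: np.ndarray, label_seqs: list, pad_value: int = 0):
    """Hypothetical helper: build inputs/targets for the four-input CTC model.

    images:     (batch, height, width, 1) float array, all the same width
    label_seqs: list of lists of class indices, one per image
    Returns (inputs, outputs) dicts keyed by the layer names used in get_model.
    """
    batch, _, width, _ = images.shape
    max_len = max(len(s) for s in label_seqs)
    # Pad labels to a rectangle; label_length tells the CTC loss where each one ends
    labels = np.full((batch, max_len), pad_value, dtype='float32')
    for i, seq in enumerate(label_seqs):
        labels[i, :len(seq)] = seq
    # Time steps produced by get_model's CNN for this width: floor(W/4) + 1
    time_steps = width // 4 + 1
    inputs = {
        'the_input': images,
        'the_labels': labels,
        'input_length': np.full((batch, 1), time_steps, dtype='int64'),
        'label_length': np.array([[len(s)] for s in label_seqs], dtype='int64'),
    }
    # The 'ctc' output already is the loss, so the target is only a placeholder
    outputs = {'ctc': np.zeros((batch, 1), dtype='float32')}
    return inputs, outputs
```

With the `model` returned by `get_model`, such a pair can be passed to `model.fit(inputs, outputs, ...)` or yielded from a generator.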
Testing: `test.py` is the Chinese OCR test script.
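At test time only `basemodel` is needed. For each image it outputs a (time_steps, nclass) probability matrix, which must then be decoded into text. A common approach is greedy (best-path) decoding: take the most likely class at each step, merge consecutive repeats, and drop the blank. The sketch assumes the blank is the last class index, which is the convention of the TensorFlow CTC loss behind `K.ctc_batch_cost`. It also uses a hypothetical `alphabet` string that maps class indices to characters. `test.py` may decode differently, for example with beam search via `K.ctc_decode`.

```python
import numpy as np


def greedy_ctc_decode(probs: np.ndarray, alphabet: str) -> str:
    """Best-path CTC decoding for one image.

    probs:    (time_steps, nclass) softmax output of basemodel for a single image
    alphabet: hypothetical mapping of class indices 0..nclass-2 to characters;
              index nclass-1 is assumed to be the CTC blank
    """
    blank = probs.shape[1] - 1
    best = probs.argmax(axis=1).tolist()  # most likely class per time step
    chars, prev = [], None
    for k in best:
        if k != prev and k != blank:  # merge repeats, skip blanks
            chars.append(alphabet[k])
        prev = k
    return ''.join(chars)


# Usage with the model above (img: (32, width, 1) array):
#   y = basemodel.predict(img[np.newaxis])[0]
#   text = greedy_ctc_decode(y, alphabet)
```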
Recognition results:
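济南华富锻造有限公司 (Jinan Huafu Forging Co., Ltd.)
夺得铜牌后,福民爱流下了激动的泪水。“石川 (After winning the bronze medal, 福民爱 shed tears of excitement. “石川)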
Itturnedoutthat328girswerenamedAbcdeintheUnitedstates
Project download (including the trained model): http://download.csdn.net/download/dcrmg/10248818
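Unless otherwise noted, articles on this site are original. When reposting, please credit the source: Keras GRU 文字识别 (Keras GRU text recognition) - Python技术站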