Keras是什么,以及相关的基础知识,这里就不做详细介绍,请参考Keras学习站点http://keras-cn.readthedocs.io/en/latest/
Tensorflow作为backend时的训练逻辑梳理,主要是结合项目,研究了下源代码!
我们的项目是智能问答机器人,基于双向RNN(准确的说是GRU)网络,这里网络结构,就不做介绍,只研究其中的训练逻辑,我们的训练是基于fit_generator,即基于生成器模型,节省内存,有助效率提升。
什么是生成器以及生成器的工作原理,这里不表,属于python的基础范畴。
1. Keras的训练,是基于batch进行的,每一个batch训练过程,进行一次loss和acc的调整
1.1 .主要核心代码
A. /home/anaconda2/lib/python2.7/site-packages/keras/legacy/interfaces.py
1)里面的装饰器函数generate_legacy_interface里面。这里涉及到fit_generator这个最为核心的入口函数的执行过程。
2)python里面装饰器工作原理,非常类似java代码里面的AOP切面编程逻辑,即在正常的业务逻辑执行前,将before或者after或者两者都执行一下。
3)训练函数原型及重要参数解释
def fit_generator(self, generator, #生成器,一个yield的函数,迭代返回数据
steps_per_epoch, #一次训练周期(具体epoch是什么含义,要理解清楚)里面进行多少次batch
epochs=1, #设置进行几次全数据集的训练,每一次全数据集训练过程被定义成一个epoch,其实这个是可以灵活应用的
verbose=1, #一个开关,打开时,打印清晰的训练数据,即加载ProgbarLogger这个回调函数
callbacks=None, #设置业务需要的回调函数,我们的模型中添加了ModelCheckpoint这个回调函数
validation_data=None, #验证用的数据源设置,evaluate_generator函数要用到这个数据源,我们的项目里面,这里也是一个生成器
validation_steps=None, #设置验证多少次数据后取平均值作为此epoch训练后的效果,val_loss,val_acc的值受这个参数直接影响
class_weight=None, #此参数以及后续参数,我们的项目采用的都是默认值,可以参考官方文档了解细节
max_queue_size=10,
workers=1,
use_multiprocessing=False,
initial_epoch=0)
B. /home/anaconda2/lib/python2.7/site-packages/keras/callbacks.py
1)这里重点有ModelCheckpoint这个回调函数,涉及到业务参数,其他回调都是keras框架默认行为。
2)callback这个类,其实是一个容器,具体表现为一个List,可以在git_generator运行时,基于该函数的入参,构建一个Callback的实例,即一个list里面装入业务需要的callback实例,这里默认会有BaseLogger以及History这个callback,然后会判断verbose为true时,会添加ProgbarLogger这个callback,除此之外,就是fit_generator函数入参callbacks传入的参数。一般都会传递ModelCheckpoint这个。
3)在git_generator这个基于生成器模式训练的过程中,每一个epoch结束(on_epoch_end)时,都要调用这个callback函数(ModelCheckpoint)进行模型数据写文件的操作
2. Keras训练时用到的几个重要回调函数(主要工作在on_batch_end里面)
回调函数是基于抽象类Callback实现的。下面是Callback的成员函数,便于理解。
def __init__(self): self.validation_data = None def set_params(self, params): self.params = params def set_model(self, model): self.model = model def on_epoch_begin(self, epoch, logs=None): pass def on_epoch_end(self, epoch, logs=None): pass def on_batch_begin(self, batch, logs=None): pass def on_batch_end(self, batch, logs=None): pass def on_train_begin(self, logs=None): pass def on_train_end(self, logs=None): pass
A. keras.callbacks.BaseLogger
统计该batch里面训练的loss以及acc的值,计入totals,乘以batch_size后。
def on_batch_end(self, batch, logs=None): logs = logs or {} batch_size = logs.get('size', 0) self.seen += batch_size for k, v in logs.items(): if k in self.totals: self.totals[k] += v * batch_size else: self.totals[k] = v * batch_size
在BaseLogger这个类的on_epoch_end函数里,执行对这个epoch训练数据的loss以及acc求平均值。
def on_epoch_end(self, epoch, logs=None): if logs is not None: for k in self.params['metrics']: if k in self.totals: # Make value available to next callbacks. logs[k] = self.totals[k] / self.seen
B. keras.callbacks.ModelCheckpoint
在on_epoch_end时会保存模型数据进入文件
def on_epoch_end(self, epoch, logs=None): logs = logs or {} self.epochs_since_last_save += 1 if self.epochs_since_last_save >= self.period: self.epochs_since_last_save = 0 filepath = self.filepath.format(epoch=epoch, **logs) if self.save_best_only: current = logs.get(self.monitor) if current is None: warnings.warn('Can save best model only with %s available, ' 'skipping.' % (self.monitor), RuntimeWarning) else: if self.monitor_op(current, self.best): if self.verbose > 0: print('Epoch %05d: %s improved from %0.5f to %0.5f,' ' saving model to %s' % (epoch, self.monitor, self.best, current, filepath)) self.best = current if self.save_weights_only: self.model.save_weights(filepath, overwrite=True) else: self.model.save(filepath, overwrite=True) else: if self.verbose > 0: print('Epoch %05d: %s did not improve' % (epoch, self.monitor)) else: if self.verbose > 0: print('Epoch %05d: saving model to %s' % (epoch, filepath)) if self.save_weights_only: self.model.save_weights(filepath, overwrite=True) else: self.model.save(filepath, overwrite=True)
C.keras.callbacks.History
主要记录每一次epoch训练的结果,结果包含loss以及acc的值
D. keras.callbacks.ProgbarLogger
这个函数里面实现训练中间状态数据信息的输出,主要涉及进度相关信息。
3. 具体训练逻辑过程
A. 训练函数分析
a. model.fit_generator 训练入口函数(参考上面的函数原型定义), 我们项目中用tk_data_generator函数作为训练数据提供者(生成器)
1) callbacks.on_train_begin()
2) while epoch < epochs:
3) callbacks.on_epoch_begin(epoch)
4) while steps_done < steps_per_epoch:
5) generator_output = next(output_generator) #生成器next函数取输入数据进行训练,每次取一个batch大小的量
6) callbacks.on_batch_begin(batch_index, batch_logs)
7) outs = self.train_on_batch(x, y,sample_weight=sample_weight,class_weight=class_weight)
8) callbacks.on_batch_end(batch_index, batch_logs)
end of while steps_done < steps_per_epoch
self.evaluate_generator(...) #当一个epoch的最后一次batch执行完毕,执行一次训练效果的评估
9) callbacks.on_epoch_end(epoch, epoch_logs) #在这个执行过程中实现模型数据的保存操作
end of while epoch < epochs
10) callbacks.on_train_end()
b. 特别介绍下train_on_batch
train_on_batch (keras中的trainning.py)
|_self._standardize_user_data
|_self._make_train_function
|_self.train_function (tensorflow的函数)
|_updated = session.run(self.outputs + [self.updates_op], feed_dict=feed_dict,**self.session_kwargs)
B. 训练和验证的对比
a. 在每一个epoch的最后一个迭代(最后一次batch)时,要进行此轮epoch的校验(evaluate)
日志如下:
141/141 [==============================] - 12228s - loss: 0.5715 - acc: 0.6960 - val_loss: 0.5082 - val_acc: 0.7450 第一个141表示batch_index已经达到141,即steps_per_epoch参数规定的最后一步 第二个141表示steps_per_epoch,即一个epoch里面进行多少次batch处理 12228s 表示此batch处理结束所花费的时间 loss:此epoch里面的平均损失值 acc:此epoch里面的平均准确率 val_loss:此epoch训练完后进行的evaluate得到的损失值 val_acc:此epoch训练完后进行的evaluate得到的正确率
b. 验证逻辑,和训练逻辑差不多,只是将validation_steps指定次数的test的值进行取平均值,得到validation_steps次test的均值作为本epoch训练的最终效果
self.evaluate_generator(validation_data,validation_steps,max_queue_size=max_queue_size,workers=workers,use_multiprocessing=use_multiprocessing)
1) while steps_done < steps:
2) generator_output = next(output_generator)
3) outs = self.test_on_batch(x, y, sample_weight=sample_weight)
4)对上述while得到的每次outs进行 averages.append(np.average([out[i] for out in all_outs],weights=batch_sizes))
其中重点test_on_batch
test_on_batch(self, x, y, sample_weight=None)
|_self._standardize_user_data(x, y,sample_weight=sample_weight,check_batch_axis=True)
|_self._make_test_function()
|_self.test_function(ins)
|_updated = session.run(self.outputs + [self.updates_op],feed_dict=feed_dict,**self.session_kwargs)
c. train和test的重要区别,应该体现在下面的两个函数上
def _make_train_function(self): if not hasattr(self, 'train_function'): raise RuntimeError('You must compile your model before using it.') if self.train_function is None: inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights if self.uses_learning_phase and not isinstance(K.learning_phase(), int): inputs += [K.learning_phase()] with K.name_scope('training'): with K.name_scope(self.optimizer.__class__.__name__): training_updates = self.optimizer.get_updates( params=self._collected_trainable_weights, loss=self.total_loss) updates = self.updates + training_updates # Gets loss and metrics. Updates weights at each call. self.train_function = K.function(inputs, [self.total_loss] + self.metrics_tensors, updates=updates, name='train_function', **self._function_kwargs)
def _make_test_function(self): if not hasattr(self, 'test_function'): raise RuntimeError('You must compile your model before using it.') if self.test_function is None: inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights if self.uses_learning_phase and not isinstance(K.learning_phase(), int): inputs += [K.learning_phase()] # Return loss and metrics, no gradient updates. # Does update the network states. self.test_function = K.function(inputs, [self.total_loss] + self.metrics_tensors, updates=self.state_updates, name='test_function', **self._function_kwargs)
经过前面的代码逻辑梳理,可以看到不管是train的过程还是test的过程,最终底层都是调用Tensorflow的session.run方法进行loss和acc的获取,细心的观察,会发现两个session.run函数的入参其实有点不同。
结合上面train和test的私有函数中标注红色的注释,以及用K.function生成函数的入参中,可以看出train和test的差异。
总结:
0. 训练过程中,每次权重的更新都是在一个batch上进行一次,是基于batch量的数据为单位进行一次权重的更新
1. 基于生成器模型训练数据,可以提升效率,降低对物理服务器性能,尤其是内存的要求
2. 训练过程中,Callback函数执行了大量的工作,包括loss、acc值的记录,以及训练中间结果的日志反馈,最重要的是模型数据的输出,也是通过callback的方式实现(ModelCheckpoint)
3. 训练(train)和验证(evaluate/validate)的逻辑近乎一样,训练要更新权重,但是验证过程,仅仅更新网络状态,不涉及权重(loss以及acc参数)信息的更新
4. 代码梳理过程中,得出结论,Keras对python编程基本功底要求还是有点高的,采用了推导式编程习惯,生成器,装饰器,回调等编程思想,另外,对矩阵运算,例如numpy.dot以及numpy.multiply的数学逻辑都有一定要求,否则比较难看懂。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Keras/Tensorflow训练逻辑研究 - Python技术站