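When writing a custom estimator, it is crucial to understand how the Estimator relates to model_fn and to the other parameters. In short: the Estimator takes the arguments it receives and feeds them into model_fn, and model_fn is the key consumer of that data.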
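Compared with the estimators in scikit-learn and Spark, TensorFlow's Estimator is more abstract: it receives the model itself, built under the guidance of hyperparameters, as an argument, so the structure and definition of the Estimator barely change when the model changes. In Spark and scikit-learn, by contrast, each algorithm usually gets its own estimator with its own constructor. TensorFlow is more heavily parameterized and allows more freedom, so its parameter management also differs from the other two.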
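Ultimately, an Estimator can only use the data passed to it if it understands that data. Java relies on type checking; Python relies on duck typing, on metadata that describes the incoming data, on a tacit agreement between the parties, or on an explicit protocol. Between Estimator and model_fn there is no type-level constraint; they rely on a convention, and the terms of that convention are given in the English description below.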
Depending on the value of `mode`, different arguments are required. Namely:
* For `mode == ModeKeys.TRAIN`: required fields are `loss` and `train_op`.
* For `mode == ModeKeys.EVAL`: required field is `loss`.
* For `mode == ModeKeys.PREDICT`: required field is `predictions`.
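The three rules amount to a lookup table from mode to required fields, plus a check when the spec is built. The sketch below is plain Python, not TensorFlow: `ModeKeys`, `REQUIRED_FIELDS`, and `ToyEstimatorSpec` are hypothetical stand-ins written for illustration, plain values stand in for tensors and ops, and the real `tf.estimator.EstimatorSpec` performs more validation than this.

```python
# Toy model of the mode -> required-fields rules; NOT TensorFlow code.
from dataclasses import dataclass
from typing import Any, Optional


class ModeKeys:
    """Hypothetical stand-in for tf.estimator.ModeKeys."""
    TRAIN = "train"
    EVAL = "eval"
    PREDICT = "infer"


# Fields that must be set for each mode.
REQUIRED_FIELDS = {
    ModeKeys.TRAIN: ("loss", "train_op"),
    ModeKeys.EVAL: ("loss",),
    ModeKeys.PREDICT: ("predictions",),
}


@dataclass
class ToyEstimatorSpec:
    """Hypothetical stand-in for tf.estimator.EstimatorSpec."""
    mode: str
    predictions: Optional[Any] = None
    loss: Optional[Any] = None
    train_op: Optional[Any] = None

    def __post_init__(self):
        # Reject the spec as soon as it is built if a required field is missing.
        if self.mode not in REQUIRED_FIELDS:
            raise ValueError(f"unknown mode: {self.mode!r}")
        missing = [name for name in REQUIRED_FIELDS[self.mode]
                   if getattr(self, name) is None]
        if missing:
            raise ValueError(f"mode={self.mode!r} requires {', '.join(missing)}")


ToyEstimatorSpec(ModeKeys.PREDICT, predictions=[0.3, 0.7])  # accepted
try:
    ToyEstimatorSpec(ModeKeys.TRAIN, loss=0.5)  # train_op is missing
except ValueError as err:
    print(err)  # mode='train' requires train_op
```

The real `EstimatorSpec` also accepts optional fields such as `eval_metric_ops` and `export_outputs`; they are left out here to keep the focus on the required ones.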
class Estimator(object):
  """Estimator class to train and evaluate TensorFlow models.

  The `Estimator` object wraps a model which is specified by a `model_fn`,
  which, given inputs and a number of other parameters, returns the ops
  necessary to perform training, evaluation, or predictions.

  All outputs (checkpoints, event files, etc.) are written to `model_dir`, or a
  subdirectory thereof. If `model_dir` is not set, a temporary directory is
  used.

  The `config` argument can be passed a `tf.estimator.RunConfig` object
  containing information about the execution environment. It is passed on to
  the `model_fn`, if the `model_fn` has a parameter named "config" (and input
  functions in the same manner). If the `config` parameter is not passed, it is
  instantiated by the `Estimator`. Not passing config means that defaults useful
  for local execution are used. `Estimator` makes config available to the model
  (for instance, to allow specialization based on the number of workers
  available), and also uses some of its fields to control internals, especially
  regarding checkpointing.

  The `params` argument contains hyperparameters. It is passed to the
  `model_fn`, if the `model_fn` has a parameter named "params", and to the input
  functions in the same manner. `Estimator` only passes params along, it does
  not inspect it. The structure of `params` is therefore entirely up to the
  developer.

  None of `Estimator`'s methods can be overridden in subclasses (its
  constructor enforces this). Subclasses should use `model_fn` to configure
  the base class, and may add methods implementing specialized functionality.

  @compatibility(eager)
  Calling methods of `Estimator` will work while eager execution is enabled.
  However, the `model_fn` and `input_fn` are not executed eagerly; `Estimator`
  will switch to graph mode before calling all user-provided functions (incl.
  hooks), so their code has to be compatible with graph mode execution. Note
  that `input_fn` code using `tf.data` generally works in both graph and eager
  modes.
  @end_compatibility
  """
  def __init__(self, model_fn, model_dir=None, config=None, params=None,
               warm_start_from=None):
    """Constructs an `Estimator` instance.

    See [estimators](https://tensorflow.org/guide/estimators) for more
    information.

    To warm-start an `Estimator`:

    ```python
    estimator = tf.estimator.DNNClassifier(
        feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
        hidden_units=[1024, 512, 256],
        warm_start_from="/path/to/checkpoint/dir")
    ```

    For more details on warm-start configuration, see
    `tf.estimator.WarmStartSettings`.

    Args:
      model_fn: Model function. Follows the signature:

        * Args:
          * `features`: This is the first item returned from the `input_fn`
            passed to `train`, `evaluate`, and `predict`. This should be a
            single `tf.Tensor` or `dict` of same.
          * `labels`: This is the second item returned from the `input_fn`
            passed to `train`, `evaluate`, and `predict`. This should be a
            single `tf.Tensor` or `dict` of same (for multi-head models).
            If mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will
            be passed. If the `model_fn`'s signature does not accept
            `mode`, the `model_fn` must still be able to handle
            `labels=None`.
          * `mode`: Optional. Specifies if this is training, evaluation or
            prediction. See `tf.estimator.ModeKeys`.
          * `params`: Optional `dict` of hyperparameters. Will receive what
            is passed to Estimator in `params` parameter. This allows
            Estimators to be configured by hyperparameter tuning.
          * `config`: Optional `estimator.RunConfig` object. Will receive what
            is passed to Estimator as its `config` parameter, or a default
            value. Allows setting up things in your `model_fn` based on
            configuration such as `num_ps_replicas`, or `model_dir`.
        * Returns:
          `tf.estimator.EstimatorSpec`

      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator to
        continue training a previously saved model. If `PathLike` object, the
        path will be resolved. If `None`, the model_dir in `config` will be used
        if set. If both are set, they must be the same. If both are `None`, a
        temporary directory will be used.
      config: `estimator.RunConfig` configuration object.
      params: `dict` of hyper parameters that will be passed into `model_fn`.
        Keys are names of parameters, values are basic Python types.
      warm_start_from: Optional string filepath to a checkpoint or SavedModel to
        warm-start from, or a `tf.estimator.WarmStartSettings` object to fully
        configure warm-starting. If the string filepath is provided instead of
        a `tf.estimator.WarmStartSettings`, then all variables are
        warm-started, and it is assumed that vocabularies and `tf.Tensor`
        names are unchanged.

    Raises:
      ValueError: parameters of `model_fn` don't match `params`.
      ValueError: if this is called via a subclass and if that class overrides
        a member of `Estimator`.
    """
Unless otherwise stated, all articles on this site are original. When reposting, please credit the source: How TensorFlow Estimator and model_fn Communicate - Python技术站