Construct a Custom Estimator
Construct a custom estimator for training and evaluating TensorFlow models.
estimator(model_fn, model_dir = NULL, config = NULL, params = NULL,
class = NULL)
Arguments
model_fn | The model function. See Model Functions for details on the structure of a model function. |
model_dir | Directory to save model parameters, graph, etc. This can also be used to load checkpoints from the directory into an estimator to continue training a previously saved model. If NULL, a temporary directory is used. |
config | Configuration object. |
params | List of hyperparameters that will be passed into model_fn. |
class | An optional set of R classes to add to the generated object. |
Details
The Estimator object wraps a model specified by a model_fn, a function that, given inputs and a number of other parameters, returns the operations necessary to perform training, evaluation, and prediction.
All outputs (checkpoints, event files, etc.) are written to model_dir, or a subdirectory thereof. If model_dir is not set, a temporary directory is used.
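A minimal base R sketch of the fallback described above. resolve_model_dir() is a hypothetical helper written for illustration, not a function exported by the package; it only mirrors the documented rule that a missing model_dir is replaced by a temporary directory.

```r
# Hypothetical helper (not part of the package): mirrors the documented
# rule that a missing model_dir falls back to a temporary directory.
resolve_model_dir <- function(model_dir = NULL) {
  if (is.null(model_dir)) {
    # No directory supplied: pick a fresh path under the session temp dir.
    model_dir <- tempfile(pattern = "estimator_")
  }
  # Create the directory (and any parents) if it does not exist yet.
  dir.create(model_dir, recursive = TRUE, showWarnings = FALSE)
  normalizePath(model_dir)
}

# Without an argument, a temporary directory is created and returned.
tmp_dir <- resolve_model_dir()

# An explicit directory is used as given.
user_dir <- resolve_model_dir(file.path(tempdir(), "my_model"))
```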
The config argument can be used to pass a run configuration object containing information about the execution environment. It is passed on to the model_fn if the model_fn has a parameter named "config" (and to the input functions in the same manner). If the config argument is not supplied, estimator() instantiates one using defaults suitable for local execution. estimator() makes config available to the model (for instance, to allow specialization based on the number of available workers) and also uses some of its fields to control internals, especially regarding checkpointing.
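The name-based hand-off of config can be illustrated in base R. call_with_config() is a hypothetical stand-in for what happens inside the estimator, and the num_workers field is invented for the example; the real run configuration object has its own fields.

```r
# Hypothetical sketch (not the package's code): forward 'config' to
# model_fn only when model_fn declares a parameter named "config".
call_with_config <- function(model_fn, features, labels, mode,
                             config = NULL) {
  # Stand-in for estimator() building a default local configuration.
  if (is.null(config)) config <- list(num_workers = 1L)
  args <- list(features = features, labels = labels, mode = mode)
  if ("config" %in% names(formals(model_fn))) {
    args$config <- config
  }
  do.call(model_fn, args)
}

# This model function asks for config and can specialize on it.
fn_with_config <- function(features, labels, mode, config) {
  config$num_workers
}

# This one does not declare config, so it is never passed in.
fn_without_config <- function(features, labels, mode) {
  mode
}

default_workers <- call_with_config(fn_with_config, NULL, NULL, "train")
custom_workers <- call_with_config(fn_with_config, NULL, NULL, "train",
                                   config = list(num_workers = 4L))
mode_seen <- call_with_config(fn_without_config, NULL, NULL, "eval")
```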
The params argument contains hyperparameters. It is passed to the model_fn if the model_fn has a parameter named "params", and to the input functions in the same manner. estimator() only passes params along; it does not inspect it. The structure of params is therefore entirely up to the developer.
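Because params is passed along uninspected, its layout is a contract between the developer and their own model_fn. In the sketch below the field names hidden_units and learning_rate are hypothetical, and the model function returns a plain list instead of an estimator_spec() object and is called directly, so the example runs without TensorFlow. With the package loaded, the same list would be supplied as estimator(model_fn, params = params).

```r
# Hypothetical hyperparameters: estimator() never looks inside this list.
params <- list(hidden_units = c(32L, 16L), learning_rate = 0.01)

model_fn <- function(features, labels, mode, params) {
  # The model function alone decides how to read params, including
  # what to do when a field is missing.
  units <- if (is.null(params$hidden_units)) 10L else params$hidden_units
  rate <- if (is.null(params$learning_rate)) 0.1 else params$learning_rate
  # Plain list standing in for an estimator_spec() object.
  list(hidden_units = units, learning_rate = rate)
}

# Called directly here in place of the estimator's own call.
configured <- model_fn(NULL, NULL, "train", params)
defaulted <- model_fn(NULL, NULL, "train", list())
```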
None of the estimator's methods can be overridden in subclasses (its constructor enforces this). Subclasses should use model_fn to configure the base class, and may add methods implementing specialized functionality.
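In R, one way to add such functionality is through the class argument: extra classes on the returned object let you define new S3 methods without replacing existing ones. The sketch below uses a mock object in place of a real estimator, assumes the extra class is placed ahead of "tf_estimator" in the class vector, and introduces a hypothetical generic, describe(), purely for illustration.

```r
# Mock stand-in for estimator(model_fn, class = "my_estimator"); assumes
# the extra class precedes "tf_estimator" in the class vector.
est <- structure(list(model_dir = "runs/demo"),
                 class = c("my_estimator", "tf_estimator"))

# Hypothetical new generic: adds behaviour rather than overriding an
# existing estimator method.
describe <- function(x, ...) UseMethod("describe")

describe.my_estimator <- function(x, ...) {
  paste("my_estimator writing to", x$model_dir)
}

msg <- describe(est)
```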
Model Functions
The model_fn should be an R function of the form:
function(features, labels, mode, params) {
  # 1. Configure the model via TensorFlow operations.
  # 2. Define the loss function for training and evaluation.
  # 3. Define the training optimizer.
  # 4. Define how predictions should be produced.
  # 5. Return the result as an `estimator_spec()` object.
  estimator_spec(mode, predictions, loss, train_op, eval_metric_ops)
}
The model function's inputs are defined as follows:
features | The feature tensor(s). |
labels | The label tensor(s). |
mode | The current mode of operation ("train", "eval", "infer"). These values can be accessed through mode_keys(). |
params | An optional list of hyperparameters, as received through the estimator() constructor. |
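A model function usually branches on mode, because each mode needs different parts of the returned specification. The base R sketch below uses a plain list as a stand-in for mode_keys() (assumed here to carry the three strings listed above) and a hypothetical helper, required_components(), that encodes the requirements documented for TensorFlow's EstimatorSpec: loss and train_op for training, loss for evaluation, and predictions for inference.

```r
# Stand-in for mode_keys(); assumes it exposes these three strings.
modes <- list(train = "train", eval = "eval", infer = "infer")

# Hypothetical helper: the estimator_spec() components that must be
# supplied in each mode, following TensorFlow's EstimatorSpec rules.
required_components <- function(mode) {
  switch(mode,
    train = c("loss", "train_op"),
    eval = "loss",
    infer = "predictions",
    stop("unknown mode: ", mode)
  )
}

# Named list: one character vector of required components per mode.
needed <- lapply(modes, required_components)
```

A model function following this pattern would typically return early in the "infer" branch with only predictions set, since no loss is needed there.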
See estimator_spec() for more details on how the estimator specification should be constructed, and https://www.tensorflow.org/extend/estimators#constructing_the_model_fn for more information on how the model function should be constructed.
See also
Other custom estimator methods: estimator_spec, evaluate.tf_estimator, export_savedmodel.tf_estimator, predict.tf_estimator, train.tf_estimator