Train an Estimator

Train an estimator on a set of input data provided by the input_fn().

# S3 method for tf_estimator
train(object, input_fn, steps = NULL, hooks = NULL,
  max_steps = NULL, saving_listeners = NULL, ...)

Arguments

object

A TensorFlow estimator.

input_fn

An input function, typically generated by the input_fn() helper function.
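A minimal sketch of building an input function with the input_fn() helper and passing it to train(). It assumes the tfestimators package with a working TensorFlow 1.x installation. The choice of the built-in mtcars data, with disp and cyl as features and mpg as the response, and the batch size and epoch count are illustrative only:

```r
library(tfestimators)

# Numeric feature columns for two predictors from mtcars
cols <- feature_columns(
  column_numeric("disp"),
  column_numeric("cyl")
)

# A linear regression estimator built on those columns
model <- linear_regressor(feature_columns = cols)

# input_fn() turns a data.frame into an input function;
# num_epochs controls how many passes over the data it provides
mtcars_input <- input_fn(
  mtcars,
  features = c("disp", "cyl"),
  response = "mpg",
  batch_size = 32,
  num_epochs = 10
)

# steps = NULL (the default): train until the input function runs out of data
history <- train(model, input_fn = mtcars_input)
```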

steps

The number of steps for which the model should be trained on this particular train() invocation. Steps accumulate across calls: calling train() twice with steps = 100 runs 200 steps in total. If NULL (the default), training continues either forever or until the supplied input_fn() has provided all of its data.
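A sketch contrasting an explicit steps value with the NULL default, under the same assumptions (tfestimators, TensorFlow 1.x, illustrative mtcars columns). The helper make_input() is hypothetical and exists only to build a fresh input function for each call. Since mtcars has 32 rows, a batch size of 32 with num_epochs = 100 supplies about 100 batches, so the second call should stop once that data is used up:

```r
library(tfestimators)

model <- linear_regressor(
  feature_columns = feature_columns(column_numeric("disp"))
)

# Hypothetical helper returning a fresh input function over mtcars
make_input <- function(num_epochs) {
  input_fn(mtcars, features = "disp", response = "mpg",
           batch_size = 32, num_epochs = num_epochs)
}

# Stop after 20 steps, even though more data is available
h_fixed <- train(model, input_fn = make_input(num_epochs = 100), steps = 20)

# steps = NULL: keep going until the input function is exhausted
h_all <- train(model, input_fn = make_input(num_epochs = 100))
```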

hooks

A list of R functions, to be used as callbacks inside the training loop. If they are not provided, hook_history_saver(every_n_step = 10), which saves the metrics history, and hook_progress_bar(), which displays the progress bar, are attached by default.
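A sketch of passing hooks explicitly, here to record metrics every 5 steps rather than every 10. It assumes tfestimators with TensorFlow 1.x and illustrative mtcars columns, and it assumes that an explicitly supplied hook_history_saver() takes the place of the default one, which the description above implies but does not state outright:

```r
library(tfestimators)

model <- linear_regressor(
  feature_columns = feature_columns(column_numeric("disp"))
)

mtcars_input <- input_fn(mtcars, features = "disp", response = "mpg",
                         batch_size = 32, num_epochs = 50)

# Assumption: this history saver replaces the default every_n_step = 10 one;
# the progress bar hook is kept so training output looks the same
history <- train(
  model,
  input_fn = mtcars_input,
  hooks = list(
    hook_history_saver(every_n_step = 5),
    hook_progress_bar()
  )
)
```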

max_steps

The total number of steps for which the model should be trained. If set, steps must be NULL. If the estimator has already been trained for a total of max_steps steps, no further training is performed.
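A sketch of how steps (an increment) and max_steps (a running total) interact on the same estimator, assuming tfestimators with TensorFlow 1.x, illustrative mtcars columns, and a hypothetical make_input() helper. Each call resumes from the estimator's latest checkpoint, so the global step carries over between calls; a further call with max_steps = 150 after this sequence would, per the description above, do nothing:

```r
library(tfestimators)

model <- linear_regressor(
  feature_columns = feature_columns(column_numeric("disp"))
)

# Hypothetical helper: an input function that can supply far more
# batches than these calls need
make_input <- function() {
  input_fn(mtcars, features = "disp", response = "mpg",
           batch_size = 32, num_epochs = 1000)
}

# steps is incremental: these two calls run 100 + 100 = 200 steps
h1 <- train(model, input_fn = make_input(), steps = 100)
h2 <- train(model, input_fn = make_input(), steps = 100)

# max_steps is a total, not an increment: the global step is already
# 200, so this call trains only about 50 more steps to reach 250
h3 <- train(model, input_fn = make_input(), max_steps = 250)
```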

saving_listeners

(Available since TensorFlow v1.4) A list of CheckpointSaverListener objects, used as callbacks that run immediately before or after each checkpoint is saved.

...

Optional arguments, passed on to the estimator's train() method.

Value

A data.frame of the training loss history.
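A sketch of capturing and inspecting the returned history, assuming tfestimators with TensorFlow 1.x and illustrative mtcars columns. The column names of the data.frame are not listed on this page, so the example inspects them instead of assuming them:

```r
library(tfestimators)

model <- linear_regressor(
  feature_columns = feature_columns(column_numeric("disp"))
)

mtcars_input <- input_fn(mtcars, features = "disp", response = "mpg",
                         batch_size = 32, num_epochs = 50)

# train() returns the training loss history
history <- train(model, input_fn = mtcars_input)

# Look at the structure and column names before using them
str(history)
print(names(history))
```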

See also

Other custom estimator methods: estimator_spec, estimator, evaluate.tf_estimator, export_savedmodel.tf_estimator, predict.tf_estimator