Generate Predictions with an Estimator

Generate predicted labels or values for the input data provided by input_fn().

# S3 method for tf_estimator
predict(object, input_fn, checkpoint_path = NULL,
  predict_keys = c("predictions", "classes", "class_ids", "logistic",
  "logits", "probabilities"), hooks = NULL, as_iterable = FALSE,
  simplify = TRUE, ...)
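Here is a minimal end-to-end sketch. It is adapted from the mtcars pattern in the tfestimators README and assumes a TensorFlow 1.x installation that tfestimators supports. The helper mtcars_input_fn() and the choice of linear_regressor() with the disp and cyl features are illustrative only. predict() does not require them.

```r
library(tfestimators)

# Illustrative input function: "disp" and "cyl" as features, "mpg" as response
mtcars_input_fn <- function(data) {
  input_fn(data, features = c("disp", "cyl"), response = "mpg")
}

# Fit a simple linear regressor so that a checkpoint exists in model_dir
model <- linear_regressor(
  feature_columns = feature_columns(column_numeric("disp", "cyl"))
)
train(model, input_fn = mtcars_input_fn(mtcars))

# Predict for the first three rows; with the defaults (as_iterable = FALSE,
# simplify = TRUE) the generator is consumed and the results are simplified
predictions <- predict(model, input_fn = mtcars_input_fn(mtcars[1:3, ]))
print(predictions)
```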

Arguments

object

A TensorFlow estimator.

input_fn

An input function, typically generated by the input_fn() helper function.

checkpoint_path

The path to a specific model checkpoint to be used for prediction. If NULL (the default), the latest checkpoint in model_dir is used.
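The following sketch pins prediction to a specific checkpoint. It assumes the estimator was created with an explicit model_dir (a temporary directory here) and that tf$train$latest_checkpoint() from the tensorflow R package is available, as it is under TensorFlow 1.x. In practice you would usually pass an older checkpoint prefix. The latest one is used here only so the example is self-contained. The mtcars setup is illustrative.

```r
library(tfestimators)
library(tensorflow)

# Illustrative input function: "disp" and "cyl" as features, "mpg" as response
mtcars_input_fn <- function(data) {
  input_fn(data, features = c("disp", "cyl"), response = "mpg")
}

# Create the estimator with a known model_dir so its checkpoints can be located
dir <- tempfile("mtcars_model_")
model <- linear_regressor(
  feature_columns = feature_columns(column_numeric("disp", "cyl")),
  model_dir = dir
)
train(model, input_fn = mtcars_input_fn(mtcars))

# Checkpoint prefix (of the form <model_dir>/model.ckpt-<step>) written by train()
ckpt <- tf$train$latest_checkpoint(dir)

# Predict from that checkpoint explicitly instead of relying on the NULL default
predictions <- predict(
  model,
  input_fn = mtcars_input_fn(mtcars[1:3, ]),
  checkpoint_path = ckpt
)
```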

predict_keys

The types of predictions that should be produced, as a character vector containing a subset of the keys shown in the default value. When this argument is not specified (the default), all available predicted values are returned.
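This sketch restricts the output and reuses the illustrative mtcars setup under the same TensorFlow 1.x assumption. A regression estimator such as linear_regressor() produces only the "predictions" key, so that is the key requested here. With a classifier you could instead request a single key such as "class_ids".

```r
library(tfestimators)

# Illustrative input function: "disp" and "cyl" as features, "mpg" as response
mtcars_input_fn <- function(data) {
  input_fn(data, features = c("disp", "cyl"), response = "mpg")
}

model <- linear_regressor(
  feature_columns = feature_columns(column_numeric("disp", "cyl"))
)
train(model, input_fn = mtcars_input_fn(mtcars))

# Ask only for the "predictions" output rather than every available key
predictions <- predict(
  model,
  input_fn = mtcars_input_fn(mtcars[1:3, ]),
  predict_keys = "predictions"
)
```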

hooks

A list of R functions to be used as callbacks while predictions are generated. If no hooks are provided, hook_history_saver(every_n_step = 10) and hook_progress_bar() are attached by default to save the metrics history and display a progress bar.

as_iterable

Boolean; should a raw Python generator be returned? When FALSE (the default), the predicted values will be consumed from the generator and returned as an R object.
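This sketch consumes the generator by hand. It assumes the reticulate package is installed, because its iterate() function drains a Python iterator into an R list. Each element should hold the prediction for one input row. The mtcars setup is illustrative, and a TensorFlow 1.x installation is assumed.

```r
library(tfestimators)
library(reticulate)

# Illustrative input function: "disp" and "cyl" as features, "mpg" as response
mtcars_input_fn <- function(data) {
  input_fn(data, features = c("disp", "cyl"), response = "mpg")
}

model <- linear_regressor(
  feature_columns = feature_columns(column_numeric("disp", "cyl"))
)
train(model, input_fn = mtcars_input_fn(mtcars))

# as_iterable = TRUE returns the raw Python generator instead of an R object
gen <- predict(
  model,
  input_fn = mtcars_input_fn(mtcars[1:3, ]),
  as_iterable = TRUE
)

# Pull every element out of the generator into an R list
items <- iterate(gen)
length(items)
```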

simplify

Whether to simplify prediction results into a tibble, as opposed to a list. Defaults to TRUE.
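This sketch shows the unsimplified output using the same illustrative mtcars setup and TensorFlow 1.x assumption. With simplify = FALSE the predictions stay a list and are not combined into a tibble.

```r
library(tfestimators)

# Illustrative input function: "disp" and "cyl" as features, "mpg" as response
mtcars_input_fn <- function(data) {
  input_fn(data, features = c("disp", "cyl"), response = "mpg")
}

model <- linear_regressor(
  feature_columns = feature_columns(column_numeric("disp", "cyl"))
)
train(model, input_fn = mtcars_input_fn(mtcars))

# Keep the raw list structure instead of a tibble
raw <- predict(
  model,
  input_fn = mtcars_input_fn(mtcars[1:3, ]),
  simplify = FALSE
)
str(raw, max.level = 1)
```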

...

Optional arguments passed on to the estimator's predict() method.

Yields

Evaluated values of the prediction tensors.

Raises

ValueError: Could not find a trained model in model_dir.

ValueError: If the batch lengths of the predictions are not the same.

ValueError: If there is a conflict between predict_keys and predictions; for example, if predict_keys is not NULL but EstimatorSpec.predictions is not a dict.

See also

Other custom estimator methods: estimator_spec, estimator, evaluate.tf_estimator, export_savedmodel.tf_estimator, train.tf_estimator