
This is the abstract base class for learner objects like LearnerClassif and LearnerRegr.

Learners are built around the following four key parts:

  • Methods $train() and $predict(), which call the internal (private) methods $.train() and $.predict().

  • A paradox::ParamSet which stores meta-information about available hyperparameters, and also stores hyperparameter settings.

  • Meta-information about the requirements and capabilities of the learner.

  • The fitted model stored in field $model, available after calling $train().
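A minimal sketch of these four parts, assuming mlr3 is installed and using the bundled example task tsk("penguins"):

```r
library(mlr3)

task = tsk("penguins")
learner = lrn("classif.rpart")

# Meta-information about requirements and capabilities
learner$feature_types
learner$properties
learner$predict_types

# The ParamSet: available hyperparameters and their current settings
learner$param_set$ids()
learner$param_set$values

# No model is stored before training
is.null(learner$model)

# $train() calls the private $.train() and stores the fitted model
learner$train(task)
class(learner$model)

# $predict() calls the private $.predict() and returns a Prediction
prediction = learner$predict(task)
```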

Predefined learners are stored in the dictionary mlr_learners, e.g. classif.rpart or regr.rpart.
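The dictionary can be queried via the short-form constructor lrn() or directly. The sketch below assumes that only mlr3 itself is loaded, so only the built-in learners are listed:

```r
library(mlr3)

# Short form: retrieve a learner by its key
learner = lrn("regr.rpart")

# Long form: query the dictionary directly
learner2 = mlr_learners$get("regr.rpart")

# Keys of all currently registered learners; this set grows when
# extension packages such as mlr3learners are loaded
keys = mlr_learners$keys()
head(keys)
```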

More classification and regression learners are implemented in the add-on package mlr3learners. Learners for survival analysis (or, more generally, for probabilistic regression) can be found in mlr3proba. Unsupervised clustering algorithms are implemented in mlr3cluster. The dictionary mlr_learners is automatically populated with the new learners as soon as the respective packages are loaded.

More (experimental) learners can be found in the GitHub repository: https://github.com/mlr-org/mlr3extralearners. A guide on how to extend mlr3 with custom learners can be found in the mlr3book.

To combine the learner with preprocessing operations like factor encoding, mlr3pipelines is recommended. Hyperparameters stored in the param_set can be tuned with mlr3tuning.

Optional Extractors

Specific learner implementations are free to implement additional getters that ease access to certain parts of the model in the inherited subclasses.

For the following operations, extractors are standardized:

  • importance(...): Returns the feature importance score as numeric vector. The higher the score, the more important the variable. The returned vector is named with feature names and sorted in decreasing order. Note that the model might omit features it has not used at all. The learner must be tagged with property "importance". To filter variables using the importance scores, see package mlr3filters.

  • selected_features(...): Returns a subset of selected features as character(). The learner must be tagged with property "selected_features".

  • oob_error(...): Returns the out-of-bag error of the model as numeric(1). The learner must be tagged with property "oob_error".

  • internal_valid_scores: Returns the internal validation score(s) of the model as a named list(). Only available for Learners with the "validation" property. If the learner is not trained yet, this returns NULL.

  • internal_tuned_values: Returns the internally tuned hyperparameters of the model as a named list(). Only available for Learners with the "internal_tuning" property. If the learner is not trained yet, this returns NULL.
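For classif.rpart, which is tagged with the properties "importance" and "selected_features", the extractors could be used as follows; a sketch assuming mlr3 is installed:

```r
library(mlr3)

task = tsk("penguins")
learner = lrn("classif.rpart")

# Extractors are only available if the learner carries the matching property
c("importance", "selected_features", "oob_error") %in% learner$properties

learner$train(task)

# Named numeric vector, sorted in decreasing order
imp = learner$importance()
imp

# Character vector of the features used in the fitted tree
sel = learner$selected_features()
sel
```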

Setting Hyperparameters

All information about hyperparameters is stored in the slot param_set, which is a paradox::ParamSet. The printer gives an overview of the ids of available hyperparameters, their storage type, lower and upper bounds, possible levels (for factors), default values and assigned values. To set hyperparameters, call the set_values() method on the param_set:

lrn = lrn("classif.rpart")
lrn$param_set$set_values(minsplit = 3, cp = 0.01)

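Note that set_values() only updates the given hyperparameters and keeps all other previously set values, whereas assigning a named list directly to the field $values replaces all previously set hyperparameter values. If you work with $values directly and only intend to change one specific hyperparameter value while leaving the others as-is, you can use the helper function mlr3misc::insert_named():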

lrn$param_set$values = mlr3misc::insert_named(lrn$param_set$values, list(cp = 0.001))
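The difference between the two ways of setting values can be checked directly. This sketch assumes paradox 1.0 or later, where set_values() inserts into the existing values by default:

```r
library(mlr3)

learner = lrn("classif.rpart")

# set_values() inserts: cp stays set when minsplit is added afterwards
learner$param_set$set_values(cp = 0.01)
learner$param_set$set_values(minsplit = 3)
after_set = learner$param_set$values
after_set

# Assigning a list to $values replaces every previously set value
learner$param_set$values = list(cp = 0.001)
names(learner$param_set$values)

# insert_named() changes one entry and keeps the rest
learner$param_set$values = mlr3misc::insert_named(
  learner$param_set$values, list(minsplit = 5)
)
learner$param_set$values
```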

If the learner has additional hyperparameters which are not encoded in the ParamSet, you can easily extend the learner. Here, we add a factor hyperparameter with id "foo" and possible levels "a" and "b":

lrn$param_set$add(paradox::ParamFct$new("foo", levels = c("a", "b")))

Implementing Validation

Some Learners, such as XGBoost, other boosting algorithms, or deep learning models (mlr3torch), use validation data during training to prevent overfitting or to log the validation performance. Learners can be configured to receive such an independent validation set during training. To do so, one must:

  • annotate the learner with the "validation" property

  • implement the active binding $internal_valid_scores (see section Optional Extractors), as well as the private method $.extract_internal_valid_scores(), which extracts the (final) internal validation scores from the Learner's model and returns them as a named list() of numeric(1). If the model is not trained yet, this method should return NULL.

  • add the $validate field, which can be either NULL, a ratio in (0, 1), "test", or "predefined":

    • NULL: no validation

    • ratio: only the proportion 1 - ratio of the task is used for training, and the remaining proportion ratio is used for validation.

    • "test" means that the "test" set, i.e. the test set of the current resampling iteration, is used for validation. Warning: This can lead to biased performance estimation. This option is only available if the learner is being trained via resample(), benchmark() or functions that internally use them, e.g. tune() of mlr3tuning or batchmark() of mlr3batchmark. This is especially useful for hyperparameter tuning, where one might, e.g., want to use the same validation data for early stopping and model evaluation.

    • "predefined" means that the task's (manually set) $internal_valid_task is used. See the Task documentation for more information.

For an example of how to do this, see LearnerClassifDebug. Note that in .train(), the $internal_valid_task will only be present if the $validate field of the Learner is set to a non-NULL value.
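A minimal sketch using the built-in test learner classif.debug, which supports validation. The ratio 0.3 is an arbitrary choice for illustration:

```r
library(mlr3)

task = tsk("iris")
learner = lrn("classif.debug")
"validation" %in% learner$properties

# Hold out 30% of the training rows as internal validation data
learner$validate = 0.3
learner$train(task)

# Final internal validation scores as a named list (NULL if none were computed)
learner$internal_valid_scores
```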

Implementing Internal Tuning

Some learners, such as XGBoost or cv.glmnet, can internally tune hyperparameters. XGBoost, for example, can tune the number of boosting rounds based on the validation performance. The function cv.glmnet, on the other hand, can tune the regularization parameter based on an internal cross-validation. Internal tuning can therefore rely on the internal validation data, but does not necessarily do so.

In order to combine this internal hyperparameter tuning with the standard hyperparameter optimization implemented via mlr3tuning, one must:

  • annotate the learner with the "internal_tuning" property

  • implement the active binding $internal_tuned_values (see section Optional Extractors) as well as the private method $.extract_internal_tuned_values() which extracts the internally tuned values from the Learner's model and returns them as a named list(). If the model is not trained yet, this method should return NULL.

  • have at least one parameter tagged with "internal_tuning", which requires also providing an in_tune_fn and a disable_tune_fn; such a parameter should also include a default aggregation function.

For an example of how to do this, see LearnerClassifDebug.
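Whether a learner supports internal tuning, and which hyperparameters are tagged for it, can be inspected on the learner object. A sketch using classif.debug:

```r
library(mlr3)

learner = lrn("classif.debug")

# The learner advertises internal tuning ...
"internal_tuning" %in% learner$properties

# ... and tags the internally tunable hyperparameter(s)
tunable = learner$param_set$ids(tags = "internal_tuning")
tunable

# Internally tuned values are only available after training
learner$internal_tuned_values
```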

Implementing Marshaling

Some Learners have models that cannot be serialized because they contain, e.g., external pointers. In order to still be able to save them, or to use them with parallelization or callr encapsulation, it is necessary to implement how they should be (un-)marshaled. See marshaling for how to do this.
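For a learner with the "marshal" property, marshaling is done via the learner's $marshal() and $unmarshal() methods. A sketch, assuming that the test learner classif.debug carries this property (the first line of the code prints the check):

```r
library(mlr3)

learner = lrn("classif.debug")
"marshal" %in% learner$properties

learner$train(tsk("iris"))

# Convert the model into a serializable form, e.g. before saveRDS()
learner$marshal()
learner$marshaled

# Restore the original model before predicting
learner$unmarshal()
learner$marshaled
prediction = learner$predict(tsk("iris"))
```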

See also

Other Learner: LearnerClassif, LearnerRegr, mlr_learners, mlr_learners_classif.debug, mlr_learners_classif.featureless, mlr_learners_classif.rpart, mlr_learners_regr.debug, mlr_learners_regr.featureless, mlr_learners_regr.rpart

Public fields

id

(character(1))
Identifier of the object. Used in tables, plots and text output.

label

(character(1))
Label for this object. Can be used in tables, plots and text output instead of the ID.

state

(NULL | named list())
Current (internal) state of the learner. Contains all information gathered during train() and predict(). It is not recommended to access elements from state directly. This is an internal data structure which may change in the future.

task_type

(character(1))
Task type, e.g. "classif" or "regr".

For a complete list of possible task types (depending on the loaded packages), see mlr_reflections$task_types$type.

feature_types

(character())
Stores the feature types the learner can handle, e.g. "logical", "numeric", or "factor". A complete list of candidate feature types, grouped by task type, is stored in mlr_reflections$task_feature_types.

properties

(character())
Stores a set of properties/capabilities the learner has. A complete list of candidate properties, grouped by task type, is stored in mlr_reflections$learner_properties.
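The candidate values can be inspected in mlr_reflections, and a learner's own properties are a subset of them. A short sketch:

```r
library(mlr3)

# Candidate properties for classification learners
mlr_reflections$learner_properties$classif

# Candidate feature types and registered task types
mlr_reflections$task_feature_types
mlr_reflections$task_types$type

# The properties of a concrete learner are a subset of the candidates
learner = lrn("classif.rpart")
all(learner$properties %in% mlr_reflections$learner_properties$classif)
```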

packages

(character())
Set of required packages. These packages are loaded, but not attached.

predict_sets

(character())
During resample()/benchmark(), a Learner can predict on multiple sets. By default, a learner only predicts observations in the test set (predict_sets == "test"). To change this behavior, set predict_sets to a non-empty subset of {"train", "test", "internal_valid"}. The "train" predict set contains the train ids from the resampling. This means that if a learner does validation and sets $validate to a ratio (creating the validation data from the training data), the train predictions will include the predictions for the validation data. Each set yields a separate Prediction object. Those can be combined via getters in ResampleResult/BenchmarkResult, or Measures can be configured to operate on specific subsets of the calculated prediction sets.
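A sketch of predicting on both the training and the test rows during resample(), assuming mlr3 is installed; the 3-fold cross-validation is an arbitrary choice:

```r
library(mlr3)

task = tsk("penguins")
# Predict on the training and the test rows of each resampling iteration
learner = lrn("classif.rpart", predict_sets = c("train", "test"))
rr = resample(task, learner, rsmp("cv", folds = 3))

# Each measure chooses which prediction set it evaluates
ce_train = rr$aggregate(msr("classif.ce", predict_sets = "train"))
ce_test = rr$aggregate(msr("classif.ce"))
c(ce_train, ce_test)

# One Prediction object per iteration for the requested set
length(rr$predictions(predict_sets = "train"))
```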

parallel_predict

(logical(1))
If set to TRUE, use future to calculate predictions in parallel (default: FALSE). The row ids of the task will be split into future::nbrOfWorkers() chunks, and predictions are evaluated according to the active future::plan(). This currently only works for methods Learner$predict() and Learner$predict_newdata(), and has no effect during resample() or benchmark() where you have other means to parallelize.

Note that in this case the recorded prediction time is not properly defined and depends on the parallelization backend.

timeout

(named numeric(2))
Timeout for the learner's train and predict steps, in seconds. This works differently for different encapsulation methods, see mlr3misc::encapsulate(). Default is c(train = Inf, predict = Inf). Also see the section on error handling in the mlr3book: https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling

man

(character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. Defaults to NA, but can be set by child classes.

Active bindings

data_formats

(character())
Supported data formats. Always "data.table". This is deprecated and will be removed in the future.

model

(any)
The fitted model. Only available after $train() has been called.

timings

(named numeric(2))
Elapsed time in seconds for the steps "train" and "predict".

When predictions for multiple predict sets were made during resample() or benchmark(), the predict time shows the cumulative duration of all predictions. If learner$predict() is called manually, the last predict time gets overwritten.

Measured via mlr3misc::encapsulate().

log

(data.table::data.table())
Returns the output (including warnings and errors) as a table with columns

  • "stage" ("train" or "predict"),

  • "class" ("output", "warning", or "error"), and

  • "msg" (character()).
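Logging only happens under encapsulation. A sketch assuming the evaluate package is installed, using classif.debug with warning_train = 1 so that a warning is signaled in every training run:

```r
library(mlr3)

# warning_train = 1: probability of signaling a warning during training
learner = lrn("classif.debug", warning_train = 1)
# Without encapsulation nothing is logged; "evaluate" captures the output
learner$encapsulate("evaluate", fallback = lrn("classif.featureless"))
learner$train(tsk("iris"))

learner$log
learner$warnings
# Only the train time is set so far; the predict time is still missing
learner$timings
```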

warnings

(character())
Logged warnings as vector.

errors

(character())
Logged errors as vector.

hash

(character(1))
Hash (unique identifier) for this object. The hash is calculated based on the learner id, the parameter settings, the predict type, the fallback hash, the parallel predict setting, the validate setting, and the predict sets.

phash

(character(1))
Hash (unique identifier) for this partial object, excluding some components which are varied systematically during tuning (parameter values).
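The difference between $hash and $phash can be seen with two learners that differ only in a hyperparameter value; a short sketch:

```r
library(mlr3)

learner1 = lrn("classif.rpart", cp = 0.01)
learner2 = lrn("classif.rpart", cp = 0.1)

# Different hyperparameter values lead to different hashes ...
learner1$hash == learner2$hash
# ... but the partial hash ignores the parameter values
learner1$phash == learner2$phash
```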

predict_type

(character(1))
Stores the currently active predict type, e.g. "response". Must be an element of $predict_types. A few learners already use the predict type during training, so there is no guarantee that changing the predict type after training has any effect or does not lead to errors.
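A sketch switching classif.rpart to probability predictions before training, assuming mlr3 is installed:

```r
library(mlr3)

task = tsk("penguins")
learner = lrn("classif.rpart")
learner$predict_types

# Request class probabilities in addition to the response
learner$predict_type = "prob"
learner$train(task)
prediction = learner$predict(task)

# One column per class, one row per observation
head(prediction$prob)
```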

param_set

(paradox::ParamSet)
Set of hyperparameters.

fallback

(Learner)
Returns the fallback learner set with $encapsulate().

encapsulation

(character(2))
Returns the encapsulation settings set with $encapsulate().

hotstart_stack

(HotstartStack)
Stores the HotstartStack.

predict_types

(character())
Stores the possible predict types the learner is capable of. A complete list of candidate predict types, grouped by task type, is stored in mlr_reflections$learner_predict_types. This field is read-only.

Methods


Method new()

Creates a new instance of this R6 class.

Note that this object is typically constructed via a derived class, e.g. LearnerClassif or LearnerRegr.

Usage

Learner$new(
  id,
  task_type,
  param_set = ps(),
  predict_types = character(),
  feature_types = character(),
  properties = character(),
  data_formats,
  packages = character(),
  label = NA_character_,
  man = NA_character_
)

Arguments

id

(character(1))
Identifier for the new instance.

task_type

(character(1))
Type of task, e.g. "regr" or "classif". Must be an element of mlr_reflections$task_types$type.

param_set

(paradox::ParamSet)
Set of hyperparameters.

predict_types

(character())
Supported predict types. Must be a subset of mlr_reflections$learner_predict_types.

feature_types

(character())
Feature types the learner operates on. Must be a subset of mlr_reflections$task_feature_types.

properties

(character())
Set of properties of the Learner. Must be a subset of mlr_reflections$learner_properties. The following properties are currently standardized and understood by learners in mlr3:

  • "missings": The learner can handle missing values in the data.

  • "weights": The learner supports observation weights.

  • "importance": The learner supports extraction of importance scores, i.e. comes with an $importance() extractor function (see section on optional extractors in Learner).

  • "selected_features": The learner supports extraction of the set of selected features, i.e. comes with a $selected_features() extractor function (see section on optional extractors in Learner).

  • "oob_error": The learner supports extraction of the estimated out-of-bag error, i.e. comes with an $oob_error() extractor function (see section on optional extractors in Learner).

  • "validation": The learner can use a validation task during training.

  • "internal_tuning": The learner is able to internally optimize hyperparameters (those are also tagged with "internal_tuning").

  • "marshal": To save learners with this property, you need to call $marshal() first. If a learner is in a marshaled state, you first need to call $unmarshal() to use its model, e.g. for prediction.

data_formats

(character())
Deprecated: ignored, and will be removed in the future.

packages

(character())
Set of required packages. A warning is signaled by the constructor if at least one of the packages is not installed. The packages are loaded (not attached) later on demand via requireNamespace().

label

(character(1))
Label for the new instance.

man

(character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. The referenced help page can be opened via method $help().
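To illustrate how these constructor arguments fit together, here is a minimal sketch of a custom learner derived from LearnerClassif. The class LearnerClassifMajority and the id "classif.majority" are hypothetical and exist only for this example; the learner always predicts the most frequent class of the training data. For the full guide on custom learners, see the mlr3book.

```r
library(mlr3)
library(R6)

# Hypothetical example learner: predicts the majority class of the training data
LearnerClassifMajority = R6Class("LearnerClassifMajority",
  inherit = LearnerClassif,
  public = list(
    initialize = function() {
      super$initialize(
        id = "classif.majority",
        param_set = paradox::ps(),  # no hyperparameters
        predict_types = "response",
        feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
        properties = c("twoclass", "multiclass", "missings"),
        label = "Majority Class (hypothetical example)"
      )
    }
  ),
  private = list(
    # The returned object is stored as the model
    .train = function(task) {
      counts = table(task$truth())
      names(counts)[which.max(counts)]
    },
    # Must return a list with one element per predict type
    .predict = function(task) {
      list(response = rep(self$model, task$nrow))
    }
  )
)

learner = LearnerClassifMajority$new()
learner$train(tsk("penguins"))
prediction = learner$predict(tsk("penguins"))
```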


Method format()

Helper for print outputs.

Usage

Learner$format(...)

Arguments

...

(ignored).


Method print()

Printer.

Usage

Learner$print(...)

Arguments

...

(ignored).


Method help()

Opens the corresponding help page referenced by field $man.

Usage

Learner$help()


Method train()

Train the learner on a set of observations of the provided task. Mutates the learner by reference, i.e. stores the model alongside other information in field $state.

Usage

Learner$train(task, row_ids = NULL)

Arguments

task

(Task).

row_ids

(integer())
Vector of training indices as subset of task$row_ids. For a simple split into training and test set, see partition().

Returns

Returns the object itself, but modified by reference. You need to explicitly $clone() the object beforehand if you want to keep the object in its previous state.


Method predict()

Uses the information stored during $train() in $state to create a new Prediction for a set of observations of the provided task.

Usage

Learner$predict(task, row_ids = NULL)

Arguments

task

(Task).

row_ids

(integer())
Vector of test indices as subset of task$row_ids. For a simple split into training and test set, see partition().

Returns

Prediction.
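A typical train/predict cycle on a holdout split created with partition(); the ratio of 0.7 is an arbitrary choice for this sketch:

```r
library(mlr3)

task = tsk("penguins")
learner = lrn("classif.rpart")

# Split the row ids into a training and a test set
split = partition(task, ratio = 0.7)

learner$train(task, row_ids = split$train)
prediction = learner$predict(task, row_ids = split$test)

# Evaluate the predictions on the test rows
acc = prediction$score(msr("classif.acc"))
acc
```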


Method predict_newdata()

Uses the model fitted during $train() to create a new Prediction based on the new data in newdata. Object task is the task used during $train() and required for conversion of newdata. If the learner's $train() method has been called, there is a (size reduced) version of the training task stored in the learner. If the learner has been fitted via resample() or benchmark(), you need to pass the corresponding task stored in the ResampleResult or BenchmarkResult, respectively. Further, auto_convert is used for type conversions to ensure compatibility of features between $train() and $predict().

Usage

Learner$predict_newdata(newdata, task = NULL)

Arguments

newdata

(any object supported by as_data_backend())
New data to predict on. All data formats convertible by as_data_backend() are supported, e.g. data.frame() or DataBackend. If a DataBackend is provided as newdata, the row ids are preserved, otherwise they are set to the sequence 1:nrow(newdata).

task

(Task).

Returns

Prediction.
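A sketch predicting on a plain data.frame; here the "new" data are simply the first five rows of the training task, standing in for genuinely new observations:

```r
library(mlr3)

task = tsk("penguins")
learner = lrn("classif.rpart")
learner$train(task)

# Feature columns only, as new observations would typically arrive
newdata = as.data.frame(task$data(rows = 1:5, cols = task$feature_names))
prediction = learner$predict_newdata(newdata)

# For a data.frame, the row ids are set to 1:nrow(newdata)
prediction$row_ids
prediction$response
```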


Method reset()

Reset the learner, i.e. un-train by resetting the state.

Usage

Learner$reset()

Returns

Returns the object itself, but modified by reference. You need to explicitly $clone() the object beforehand if you want to keep the object in its previous state.
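Because $train() and $reset() modify the learner in place, a copy has to be made with $clone() beforehand if the previous state is needed. A short sketch:

```r
library(mlr3)

task = tsk("penguins")
learner = lrn("classif.rpart")

# Keep an untrained copy before training modifies the learner
untrained = learner$clone(deep = TRUE)
learner$train(task)
trained_model = learner$model

# $reset() discards the state, including the fitted model
learner$reset()
is.null(learner$model)
is.null(untrained$model)
```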


Method base_learner()

Extracts the base learner from nested learner objects like GraphLearner in mlr3pipelines or AutoTuner in mlr3tuning. Returns the Learner itself for regular learners.

Usage

Learner$base_learner(recursive = Inf)

Arguments

recursive

(integer(1))
Depth of recursion for multiple nested objects.

Returns

Learner.


Method encapsulate()

Sets the encapsulation method and fallback learner for the train and predict steps. There are currently four different methods implemented:

  • "none": Just runs the learner in the current session and measures the elapsed time. Does not keep a log, output is printed directly to the console. Works well together with traceback().

  • "try": Similar to "none", but catches errors. Output is printed to the console and not logged.

  • "evaluate": Uses the package evaluate to call the learner, measure time and do the logging.

  • "callr": Uses the package callr to call the learner, measure time and do the logging. This encapsulation spawns a separate R session in which the learner is called. While this comes with a considerable overhead, it also protects your session from being torn down by segfaults.

The fallback learner is fitted to create valid predictions in case either the model fitting or the prediction of the original learner fails. If the training step or the predict step of the original learner fails, the fallback is used to predict all prediction sets. If the original learner only partially fails during the predict step (usually by failing to predict some observations or by producing some NA predictions), these missing predictions are imputed by the fallback. Note that the fallback is always trained, as it is not known in advance whether prediction will fail. If the training step fails, the $model field of the original learner is NULL.

Also see the section on error handling in the mlr3book: https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling

Usage

Learner$encapsulate(method, fallback = NULL)

Arguments

method

(character(1))
One of "none", "try", "evaluate" or "callr". See the description for details.

fallback

(Learner)
The fallback learner for failed predictions.

Returns

self (invisibly).
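A sketch of a fallback in action, assuming the evaluate package is installed. classif.debug with error_train = 1 fails in every training run, so the featureless fallback creates the predictions:

```r
library(mlr3)

task = tsk("iris")
# error_train = 1: probability of raising an error during training
learner = lrn("classif.debug", error_train = 1)
learner$encapsulate("evaluate", fallback = lrn("classif.featureless"))

learner$train(task)
learner$errors          # the logged training error
is.null(learner$model)  # training of the original learner failed

# Predictions are created by the fallback learner instead
prediction = learner$predict(task)
```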


Method configure()

Sets parameter values and fields of the learner. All arguments whose names match the name of a parameter of the paradox::ParamSet are set as parameters. All remaining arguments are assumed to be regular fields.

Usage

Learner$configure(..., .values = list())

Arguments

...

(named any)
Named arguments to set parameter values and fields.

.values

(named any)
Named list of parameter values and fields.
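A short sketch: cp is a hyperparameter of classif.rpart and is therefore set in the ParamSet, while predict_type is a regular field of the learner:

```r
library(mlr3)

learner = lrn("classif.rpart")

# Hyperparameters and fields can be mixed in one call
learner$configure(cp = 0.05, predict_type = "prob")
learner$param_set$values$cp
learner$predict_type

# The same can be passed as a named list via .values
learner$configure(.values = list(minsplit = 10))
learner$param_set$values$minsplit
```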


Method clone()

The objects of this class are cloneable with this method.

Usage

Learner$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.