This is the abstract base class for learner objects like LearnerClassif and LearnerRegr.

Learners are built around the following key parts:

  • Methods $train() and $predict(), which call internal methods (either the deprecated public methods $train_internal()/$predict_internal() or the private methods $.train()/$.predict()).

  • A paradox::ParamSet which stores meta-information about available hyperparameters, and also stores hyperparameter settings.

  • Meta-information about the requirements and capabilities of the learner.

  • The fitted model stored in field $model, available after calling $train().

Predefined learners are stored in the dictionary mlr_learners, e.g. classif.rpart or regr.rpart.

More classification and regression learners are implemented in the add-on package mlr3learners. Learners for survival analysis (or, more generally, for probabilistic regression) can be found in mlr3proba. Unsupervised cluster algorithms are implemented in mlr3cluster. The dictionary mlr_learners is automatically populated with the new learners as soon as the respective packages are loaded.

More (experimental) learners can be found in the GitHub repository: https://github.com/mlr-org/mlr3extralearners. A guide on how to extend mlr3 with custom learners can be found in the mlr3book.
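As a minimal sketch (assuming mlr3 is installed; the variable names keys and learner are only illustrative), a predefined learner can be retrieved from mlr_learners and its meta-information inspected:

```r
library(mlr3)

# Keys of all learners currently registered in the dictionary;
# loading add-on packages such as mlr3learners adds further entries
keys = mlr_learners$keys()
print(head(keys))

# lrn() is the short form of mlr_learners$get()
learner = lrn("classif.rpart")

# Meta-information about requirements and capabilities
print(learner$task_type)      # "classif"
print(learner$predict_types)
print(learner$feature_types)
print(learner$properties)
print(learner$packages)
```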

Optional Extractors

Specific learner implementations are free to implement additional getters in their subclasses to ease access to certain parts of the model.

For the following operations, extractors are standardized:

  • importance(...): Returns the feature importance scores as a numeric vector. The higher the score, the more important the variable. The returned vector is named with the feature names and sorted in decreasing order. Note that the model might omit features it has not used at all. The learner must be tagged with property "importance". To filter variables using the importance scores, see package mlr3filters.

  • selected_features(...): Returns a subset of selected features as character(). The learner must be tagged with property "selected_features".

  • oob_error(...): Returns the out-of-bag error of the model as numeric(1). The learner must be tagged with property "oob_error".
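A minimal sketch of using these extractors, assuming mlr3 is installed and using the iris task that ships with it. Checking the property before calling an extractor guards against learners that do not implement it:

```r
library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart")
learner$train(task)

# Only call an extractor if the learner is tagged with the matching property
if ("importance" %in% learner$properties) {
  imp = learner$importance()          # named numeric, sorted decreasingly
  print(imp)
}
if ("selected_features" %in% learner$properties) {
  sel = learner$selected_features()   # character() of feature names
  print(sel)
}

# classif.rpart does not provide an out-of-bag error
print("oob_error" %in% learner$properties)
```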

Setting Hyperparameters

All information about hyperparameters is stored in the slot param_set, which is a paradox::ParamSet. The printer gives an overview of the ids of available hyperparameters, their storage type, lower and upper bounds, possible levels (for factors), default values and assigned values. To set hyperparameters, assign a named list to the subslot values:

lrn = lrn("classif.rpart")
lrn$param_set$values = list(minsplit = 3, cp = 0.01)

Note that this operation replaces all previously set hyperparameter values. If you only intend to change one specific hyperparameter value and leave the others as-is, you can use the helper function mlr3misc::insert_named():

lrn$param_set$values = mlr3misc::insert_named(lrn$param_set$values, list(cp = 0.001))

If the learner has additional hyperparameters which are not encoded in its ParamSet, you can extend the ParamSet. Here, we add a factor hyperparameter with id "foo" and possible levels "a" and "b":

lrn$param_set$add(paradox::ParamFct$new("foo", levels = c("a", "b")))
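The difference between replacing all values and updating a single value can be checked directly. This is a small sketch assuming mlr3 (and thereby mlr3misc) is installed; the objects learner and other are illustrative:

```r
library(mlr3)

learner = lrn("classif.rpart")

# Assigning a list replaces all previously set values
learner$param_set$values = list(minsplit = 3, cp = 0.01)

# insert_named() updates cp and keeps minsplit
learner$param_set$values = mlr3misc::insert_named(
  learner$param_set$values, list(cp = 0.001)
)
print(learner$param_set$values)

# In contrast, a plain assignment drops minsplit again
other = lrn("classif.rpart")
other$param_set$values = list(minsplit = 3)
other$param_set$values = list(cp = 0.001)
print(is.null(other$param_set$values$minsplit))
```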


Public fields

id

(character(1))
Identifier of the object. Used in tables, plot and text output.

state

(NULL | named list())
Current (internal) state of the learner. Contains all information gathered during train() and predict(). It is not recommended to access elements from state directly. This is an internal data structure which may change in the future.

task_type

(character(1))
Task type, e.g. "classif" or "regr". For a complete list of possible task types (depending on the loaded packages), see mlr_reflections$task_types$type.

predict_types

(character())
Stores the possible predict types the learner is capable of. A complete list of candidate predict types, grouped by task type, is stored in mlr_reflections$learner_predict_types.

feature_types

(character())
Stores the feature types the learner can handle, e.g. "logical", "numeric", or "factor". A complete list of candidate feature types, grouped by task type, is stored in mlr_reflections$task_feature_types.

properties

(character())
Stores a set of properties/capabilities the learner has. A complete list of candidate properties, grouped by task type, is stored in mlr_reflections$learner_properties.

data_formats

(character())
Supported data formats, e.g. "data.table" or "Matrix".

packages

(character())
Set of required packages. These packages are loaded, but not attached.

predict_sets

(character())
During resample()/benchmark(), a Learner can predict on multiple sets. By default, a learner only predicts observations in the test set (predict_sets == "test"). To change this behaviour, set predict_sets to a non-empty subset of {"train", "test"}. Each set yields a separate Prediction object. These can be combined via getters in ResampleResult/BenchmarkResult, or Measures can be altered to operate on specific subsets of the calculated prediction sets.

fallback

(Learner)
Learner which is fitted to impute predictions in case either the model fitting or the prediction of the top learner fails. Requires encapsulation to be enabled; otherwise, errors are not caught and execution terminates before the fallback learner kicks in.

man

(character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. Defaults to NA, but can be set by child classes.
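To illustrate the predict_sets field, here is a minimal sketch assuming mlr3 is installed; the holdout split, the seed and the object names are arbitrary choices for illustration:

```r
library(mlr3)
set.seed(1)

task = tsk("iris")
learner = lrn("classif.rpart")

# Predict on both the training and the test set during resampling
learner$predict_sets = c("train", "test")
rr = resample(task, learner, rsmp("holdout"))

# Each predict set yields its own Prediction object
pred_train = rr$prediction("train")
pred_test = rr$prediction("test")

# The training error typically underestimates the test error
print(pred_train$score(msr("classif.ce")))
print(pred_test$score(msr("classif.ce")))
```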

Active bindings

model

(any)
The fitted model. Only available after $train() has been called.

timings

(named numeric(2))
Elapsed time in seconds for the steps "train" and "predict". Measured via mlr3misc::encapsulate().

log

(data.table::data.table())
Returns the output (including warnings and errors) as a table with columns

  • "stage" ("train" or "predict"),

  • "class" ("output", "warning", or "error"), and

  • "msg" (character()).

warnings

(character())
Logged warnings as vector.

errors

(character())
Logged errors as vector.

hash

(character(1))
Hash (unique identifier) for this object.

phash

(character(1))
Hash (unique identifier) for this partial object, excluding some components which are varied systematically during tuning (parameter values) or feature selection (feature names).

predict_type

(character(1))
Stores the currently active predict type, e.g. "response". Must be an element of $predict_types.

param_set

(paradox::ParamSet)
Set of hyperparameters.

encapsulate

(named character())
Controls how to execute the code in internal train and predict methods. Must be a named character vector with names "train" and "predict". Possible values are "none", "evaluate" (requires package evaluate) and "callr" (requires package callr). See mlr3misc::encapsulate() for more details.
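The fields encapsulate and fallback work together with the log getters. The following sketch assumes an mlr3 version matching this page (where encapsulate is assigned as a field), an installed evaluate package, and that the "error_train" hyperparameter of the built-in debug learner is a probability, so that a value of 1 makes training always fail:

```r
library(mlr3)

task = tsk("iris")

# classif.debug with error_train = 1 raises an error during training
learner = lrn("classif.debug", error_train = 1)

# Catch errors via package evaluate and impute with a featureless learner
learner$encapsulate = c(train = "evaluate", predict = "evaluate")
learner$fallback = lrn("classif.featureless")

learner$train(task)        # the error is logged instead of being raised
print(learner$errors)      # logged error messages
print(learner$log)         # columns stage, class and msg

# No model could be fitted, so predictions come from the fallback learner
prediction = learner$predict(task)
print(head(prediction$response))
print(learner$timings)     # elapsed seconds for "train" and "predict"
```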

Methods

Public methods


Method new()

Creates a new instance of this R6 class.

Note that this object is typically constructed via a derived class, e.g. LearnerClassif or LearnerRegr.

Usage

Learner$new(
  id,
  task_type,
  param_set = ParamSet$new(),
  predict_types = character(),
  feature_types = character(),
  properties = character(),
  data_formats = "data.table",
  packages = character(),
  man = NA_character_
)

Arguments

id

(character(1))
Identifier for the new instance.

task_type

(character(1))
Type of task, e.g. "regr" or "classif". Must be an element of mlr_reflections$task_types$type.

param_set

(paradox::ParamSet)
Set of hyperparameters.

predict_types

(character())
Supported predict types. Must be a subset of mlr_reflections$learner_predict_types.

feature_types

(character())
Feature types the learner operates on. Must be a subset of mlr_reflections$task_feature_types.

properties

(character())
Set of properties of the Learner. Must be a subset of mlr_reflections$learner_properties. The following properties are currently standardized and understood by learners in mlr3:

  • "missings": The learner can handle missing values in the data.

  • "weights": The learner supports observation weights.

  • "importance": The learner supports extraction of importance scores, i.e. comes with an $importance() extractor function (see section on optional extractors in Learner).

  • "selected_features": The learner supports extraction of the set of selected features, i.e. comes with a $selected_features() extractor function (see section on optional extractors in Learner).

  • "oob_error": The learner supports extraction of the estimated out-of-bag error, i.e. comes with an $oob_error() extractor function (see section on optional extractors in Learner).

data_formats

(character())
Set of supported data formats which can be processed during $train() and $predict(), e.g. "data.table".

packages

(character())
Set of required packages. The constructor signals a warning if at least one of the packages is not installed. The packages are loaded (but not attached) later on demand via requireNamespace().

man

(character(1))
String in the format [pkg]::[topic] pointing to a manual page for this object. The referenced help page can be opened via method $help().
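Since the constructor is usually called from a subclass, a custom learner passes its meta-information to super$initialize() and implements the private methods $.train() and $.predict(). The following is a hypothetical sketch: the class LearnerClassifMajority and the id "classif.majority" are made up for illustration, and the code assumes an mlr3 version matching this page, where $.predict() returns a Prediction object. The learner always predicts the most frequent class seen during training:

```r
library(mlr3)

# Hypothetical learner: always predicts the most frequent training class
LearnerClassifMajority = R6::R6Class("LearnerClassifMajority",
  inherit = LearnerClassif,
  public = list(
    initialize = function() {
      super$initialize(
        id = "classif.majority",
        feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
        predict_types = "response",
        properties = c("twoclass", "multiclass")
      )
    }
  ),
  private = list(
    # The return value of .train() is stored as the model
    .train = function(task) {
      counts = table(task$truth())
      names(counts)[which.max(counts)]
    },
    # .predict() builds a Prediction for all rows of the task it receives
    .predict = function(task) {
      response = rep(self$model, task$nrow)
      PredictionClassif$new(task = task, response = response)
    }
  )
)

learner = LearnerClassifMajority$new()
task = tsk("iris")
learner$train(task, row_ids = 1:100)
prediction = learner$predict(task, row_ids = 101:150)
print(table(prediction$response))
```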


Method format()

Helper for print outputs.

Usage

Learner$format()


Method print()

Printer.

Usage

Learner$print(...)

Arguments

...

(ignored).


Method help()

Opens the corresponding help page referenced by field $man.

Usage

Learner$help()


Method train()

Train the learner on a set of observations of the provided task. Mutates the learner by reference, i.e. stores the model alongside other information in field $state.

Usage

Learner$train(task, row_ids = NULL)

Arguments

task

(Task).

row_ids

(integer())
Vector of training indices.

Returns

Returns the object itself, but modified by reference. You need to explicitly $clone() the object beforehand if you want to keep the object in its previous state.
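Because $train() mutates the learner by reference, an untrained copy has to be created beforehand. A minimal sketch, assuming mlr3 and rpart are installed; the choice of every second observation as training set is arbitrary:

```r
library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart")
train_ids = seq(1, task$nrow, by = 2)  # every second observation

# Keep an untrained copy: $train() modifies the learner by reference
untrained = learner$clone(deep = TRUE)

learner$train(task, row_ids = train_ids)

print(is.null(untrained$model))  # the copy is still untrained
print(class(learner$model))      # the fitted rpart object
```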


Method predict()

Uses the information stored during $train() in $state to create a new Prediction for a set of observations of the provided task.

Usage

Learner$predict(task, row_ids = NULL)

Arguments

task

(Task).

row_ids

(integer())
Vector of test indices.

Returns

Prediction.
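A short sketch of predicting on held-out rows, assuming mlr3 and rpart are installed. It also switches the predict_type to "prob", which must be one of the learner's $predict_types; the train/test split is arbitrary:

```r
library(mlr3)

task = tsk("iris")
train_ids = seq(1, task$nrow, by = 2)
test_ids = setdiff(seq_len(task$nrow), train_ids)

learner = lrn("classif.rpart")
learner$predict_type = "prob"   # must be an element of learner$predict_types
learner$train(task, row_ids = train_ids)

prediction = learner$predict(task, row_ids = test_ids)
print(head(prediction$prob))                 # class probabilities
print(prediction$score(msr("classif.acc")))  # accuracy on the test rows
```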


Method predict_newdata()

Uses the model fitted during $train() to create a new Prediction based on the new data in newdata. Object task is the task used during $train() and required for conversion of newdata. If the learner's $train() method has been called, there is a (size reduced) version of the training task stored in the learner. If the learner has been fitted via resample() or benchmark(), you need to pass the corresponding task stored in the ResampleResult or BenchmarkResult, respectively.

Usage

Learner$predict_newdata(newdata, task = NULL)

Arguments

newdata

(data.frame())
New data to predict on. Row ids are automatically created via auto-incrementing.

task

(Task).

Returns

Prediction.
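A minimal sketch of predicting on a plain data.frame, assuming mlr3 is installed. Here a few rows of the iris data set stand in for new observations, and task is omitted because the learner was trained via $train():

```r
library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart")
learner$train(task)

# New observations as a plain data.frame containing the feature columns
newdata = iris[c(1, 51, 101), c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")]

# task can be omitted: $train() stored a reduced copy of the training task
prediction = learner$predict_newdata(newdata)
print(prediction$response)
```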


Method reset()

Reset the learner, i.e. un-train by resetting the state.

Usage

Learner$reset()

Returns

Returns the object itself, but modified by reference. You need to explicitly $clone() the object beforehand if you want to keep the object in its previous state.
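A short sketch of un-training a learner, assuming mlr3 is installed:

```r
library(mlr3)

learner = lrn("classif.rpart")
learner$train(tsk("iris"))
print(is.null(learner$model))  # FALSE: a model has been fitted

learner$reset()
print(is.null(learner$model))  # TRUE: the state has been cleared
```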


Method clone()

The objects of this class are cloneable with this method.

Usage

Learner$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.