Container object for a machine learning experiment. After initialization with a Task and a Learner, the experiment is conducted by calling the methods $train(), $predict() and $score().

Format

R6::R6Class object.

Construction

Experiment$new(task = NULL, learner = NULL, ctrl = list())
  • task :: (Task | character(1))
    May be NULL during initialization, but is mandatory to train the Experiment. Instead of a Task object, it is also possible to provide a key to retrieve a task from the mlr_tasks dictionary.

  • learner :: (Learner | character(1))
    May be NULL during initialization, but is mandatory to train the Experiment. Instead of a Learner object, it is also possible to provide a key to retrieve a learner from the mlr_learners dictionary.

  • ctrl :: named list()
    Control object, see mlr_control().
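Both task and learner accept either an object or a character key. The following base-R sketch shows that lookup pattern under stated assumptions. The dictionary is modelled as a plain environment, and the names `toy_dict` and `resolve_key_or_object()` are hypothetical. This is not the mlr3 implementation, which uses the mlr_tasks and mlr_learners dictionaries.

```r
# Hypothetical stand-in for a dictionary such as mlr_tasks:
# an environment mapping string keys to objects (here, plain lists).
toy_dict = new.env()
assign("iris", list(id = "iris", data = iris), envir = toy_dict)

# Hypothetical helper: return `x` unchanged unless it is a single string,
# in which case look it up in `dict` and fail with an informative error.
resolve_key_or_object = function(x, dict) {
  if (is.null(x)) {
    return(NULL)  # NULL is allowed during initialization
  }
  if (is.character(x) && length(x) == 1L) {
    if (!exists(x, envir = dict, inherits = FALSE)) {
      stop(sprintf("Key '%s' not found in dictionary", x))
    }
    return(get(x, envir = dict, inherits = FALSE))
  }
  x  # already an object: pass it through
}

task = resolve_key_or_object("iris", toy_dict)
task$id  # "iris"
```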

Fields

  • ctrl :: list()
    Control settings passed during initialization.

  • data :: named list()
    See section "Internal Data Storage".

  • has_errors :: logical(1)
    Flag which is TRUE if any error has been recorded during $train() or $predict().

  • hash :: character(1)
    Hash (unique identifier) for this object.

  • model :: any
    Access the trained model of the Learner. Only available after the learner has been trained.

  • performance :: named numeric()
    Access the performance scores as returned by the Measures used in $score(), by default those stored in the Task.

  • prediction :: Prediction
    Access the individual predictions of the model stored in the Learner.

  • seeds :: integer(3)
    Named integer vector of random number generator seeds passed to set.seed() prior to calling external code in train(), predict() or score(). Names must be "train", "predict" and "score". Set a seed to NA to disable seeding for that step (the default).

  • state :: ordered(1)
    Returns the state of the experiment as an ordered factor with levels "undefined", "defined", "trained", "predicted", and "scored".

  • task :: Task
    Access to the stored Task.

  • test_set :: (integer() | character())
    The row ids of the Task for the test set used in $predict().

  • timings :: named numeric(3)
    Stores the elapsed time for the steps train(), predict() and score() in seconds with up to millisecond accuracy (cf. proc.time()). Timings are NA if the respective step has not been performed yet.

  • train_set :: (integer() | character())
    The row ids of the Task for the training set used in $train().

  • validation_set :: (integer() | character())
    The row ids of the validation set of the Task. Validation sets are not yet completely integrated into the package.
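The seeds and timings fields describe per-step behaviour: seed the random number generator with set.seed() unless the seed is NA, then record the elapsed time. A minimal base-R sketch of that pattern follows; the helper `run_step()` is hypothetical and mirrors the documented semantics rather than the package's actual code.

```r
# Hypothetical helper: run `fun` with an optional seed and report the
# elapsed wall-clock time in seconds, measured with proc.time().
run_step = function(fun, seed = NA_integer_) {
  if (!is.na(seed)) {
    set.seed(seed)  # seeding is skipped when the seed is NA
  }
  start = proc.time()[["elapsed"]]
  result = fun()
  elapsed = proc.time()[["elapsed"]] - start
  list(result = result, time = round(elapsed, 3))  # millisecond accuracy
}

# Named seeds, one per step; NA disables seeding for that step.
seeds = c(train = 1L, predict = NA_integer_, score = NA_integer_)

a = run_step(function() sample(10), seed = seeds[["train"]])
b = run_step(function() sample(10), seed = seeds[["train"]])
identical(a$result, b$result)  # TRUE: same seed, same draw
```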

Methods

  • train(row_ids = NULL, ctrl = list())
    (integer() | character(), list()) -> self
    Fits the induced Learner on the row_ids of the Task and stores the model inside the Learner object. If no row_ids are provided, trains the model on all rows of the Task with row role "use". The fitted model can be accessed via $model.

  • predict(row_ids = NULL, newdata = NULL, ctrl = list())
    (integer() | character(), data.frame(), list()) -> self
    Uses the previously fitted model to predict new observations. New observations are addressed either as row_ids referencing rows in the stored Task, or as a data.frame() via newdata. The latter fuses the new observations into the stored Task and thereby mutates the Experiment; to avoid such side effects, clone the experiment first. The resulting predictions are stored internally as a Prediction object and can be accessed via $prediction.

  • score(measures = NULL, ctrl = list())
    (list() of Measure, list()) -> self
    Quantifies the stored predictions using the list of Measures provided here, falling back to the default measures of the Task. The performance values are stored internally and can be accessed via $performance.

  • log(steps = c("train", "predict"))
    character() -> Log
    Returns a Log for the specified steps.
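The methods above move an experiment through the ordered states listed under state. The sketch below models that progression in base R with an ordered factor, using the levels shown in the printed output of e$state. The transition rule in the hypothetical `advance()` helper (each step requires the state preceding its target, or a later one) is an assumption for illustration, not taken from mlr3.

```r
# Ordered levels as printed by `e$state` in the examples.
state_levels = c("undefined", "defined", "trained", "predicted", "scored")

make_state = function(x) factor(x, levels = state_levels, ordered = TRUE)

# Hypothetical transition rule: e.g. predict() needs "trained" or later.
advance = function(state, step) {
  target = make_state(switch(step,
    train = "trained",
    predict = "predicted",
    score = "scored",
    stop("unknown step: ", step)
  ))
  required = make_state(state_levels[as.integer(target) - 1L])
  if (state < required) {
    stop(sprintf("cannot %s in state '%s'", step, as.character(state)))
  }
  target
}

s = make_state("defined")
s = advance(s, "train")
s = advance(s, "predict")
s = advance(s, "score")
s >= "trained"  # TRUE: ordered factors support comparisons
```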

Internal Data Storage

All data is stored in the field data as a named list(). Directly accessing the elements is not recommended, but sometimes required, especially if you aim to extend mlr3. The data object contains the following items:

  • task :: Task
    A clone of the Task which was provided during construction. Also accessible via e$task.

  • learner :: Learner
    A clone of the Learner which was provided during construction. Also accessible via e$learner.

  • resampling :: Resampling
    Is NULL prior to calling $train(). If the experiment is constructed manually (i.e., not via resample() or benchmark()), a ResamplingCustom object is stored. The combination of resampling and iteration (next item) is used to extract the training and test set indices. These are directly accessible via e$train_set and e$test_set.

  • iteration :: integer(1)
    Refers to the iteration number of the stored Resampling object. If the experiment is constructed manually, this is always 1, as there is only one training-test split.

  • train_log :: data.table::data.table()
    Log for the training step. May be NULL if no encapsulation has been enabled via mlr_control().

  • train_time :: numeric(1)
    Elapsed time during train in seconds with up to millisecond accuracy (cf. proc.time()).

  • predict_log :: data.table::data.table()
    Log for the predict step. May be NULL if no encapsulation has been enabled via mlr_control().

  • predict_time :: numeric(1)
    Elapsed time during predict in seconds with up to millisecond accuracy (cf. proc.time()).

  • prediction :: Prediction
    Prediction object as returned by the Learner's predict() call.

  • measures :: list() of Measure
    Measures which were used for performance assessment.

  • performance :: named numeric()
    Aggregated scores returned by the measures, named with measure ids.

  • score_time :: numeric(1)
    Elapsed time during score in seconds with up to millisecond accuracy (cf. proc.time()).
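The resampling and iteration items together determine train_set and test_set. A base-R sketch follows, assuming a resampling can be represented as a list of per-iteration train/test index pairs; `custom_resampling`, `cv_resampling`, `train_set()` and `test_set()` are hypothetical names and do not reproduce ResamplingCustom.

```r
# Hypothetical resampling: one train/test pair per iteration.
custom_resampling = list(
  list(train = 1:120, test = 121:150)  # manual experiment: a single split
)

# Hypothetical accessors mirroring e$train_set and e$test_set.
train_set = function(resampling, iteration) resampling[[iteration]]$train
test_set = function(resampling, iteration) resampling[[iteration]]$test

iteration = 1L  # always 1 for a manually constructed experiment
length(train_set(custom_resampling, iteration))  # 120
length(test_set(custom_resampling, iteration))   # 30

# A three-fold cross-validation would instead store three pairs,
# and `iteration` would select one of them.
set.seed(1)
folds = split(sample(150), rep(1:3, length.out = 150))
cv_resampling = lapply(folds, function(test) {
  list(train = setdiff(seq_len(150), test), test = test)
})
length(train_set(cv_resampling, 2L))  # 100
```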

Examples

e = Experiment$new(task = "iris", learner = "classif.rpart")
print(e)
#> <Experiment> [defined]:
#> + Task: iris
#> + Learner: classif.rpart
#> - Model: [missing]
#> - Predictions: [missing]
#> - Performance: [missing]
#>
#> Public: clone(), ctrl, data, has_errors, hash, learner, log(), model,
#> performance, predict(), prediction, score(), seeds, state, task,
#> test_set, timings, train_set, train(), validation_set
e$state
#> [1] defined
#> Levels: undefined < defined < trained < predicted < scored
e$train(row_ids = 1:120)
#> INFO [mlr3] Training learner 'classif.rpart' on task 'iris' ...
#> <Experiment> [trained]:
#> + Task: iris
#> + Learner: classif.rpart
#> + Model: [rpart]
#> - Predictions: [missing]
#> - Performance: [missing]
#>
#> Public: clone(), ctrl, data, has_errors, hash, learner, log(), model,
#> performance, predict(), prediction, score(), seeds, state, task,
#> test_set, timings, train_set, train(), validation_set
e$state
#> [1] trained
#> Levels: undefined < defined < trained < predicted < scored
e$model
#> n= 120
#>
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#>
#> 1) root 120 70 setosa (0.41666667 0.41666667 0.16666667)
#>   2) Petal.Length< 2.45 50  0 setosa (1.00000000 0.00000000 0.00000000) *
#>   3) Petal.Length>=2.45 70 20 versicolor (0.00000000 0.71428571 0.28571429)
#>     6) Petal.Length< 4.95 49  1 versicolor (0.00000000 0.97959184 0.02040816) *
#>     7) Petal.Length>=4.95 21  2 virginica (0.00000000 0.09523810 0.90476190) *
e$predict(row_ids = 121:150)
#> INFO [mlr3] Predicting with model of learner 'classif.rpart' on task 'iris' ...
#> <Experiment> [predicted]:
#> + Task: iris
#> + Learner: classif.rpart
#> + Model: [rpart]
#> + Predictions: [PredictionClassif]
#> - Performance: [missing]
#>
#> Public: clone(), ctrl, data, has_errors, hash, learner, log(), model,
#> performance, predict(), prediction, score(), seeds, state, task,
#> test_set, timings, train_set, train(), validation_set
e$state
#> [1] predicted
#> Levels: undefined < defined < trained < predicted < scored
e$prediction
#> <PredictionClassif> for 30 observations:
#>     row_id     truth   response
#>  1:    121 virginica  virginica
#>  2:    122 virginica versicolor
#>  3:    123 virginica  virginica
#> ---
#> 28:    148 virginica  virginica
#> 29:    149 virginica  virginica
#> 30:    150 virginica  virginica
e$score()
#> INFO [mlr3] Scoring predictions of learner 'classif.rpart' on task 'iris' ...
#> <Experiment> [scored]:
#> + Task: iris
#> + Learner: classif.rpart
#> + Model: [rpart]
#> + Predictions: [PredictionClassif]
#> + Performance: classif.ce=0.1666667
#>
#> Public: clone(), ctrl, data, has_errors, hash, learner, log(), model,
#> performance, predict(), prediction, score(), seeds, state, task,
#> test_set, timings, train_set, train(), validation_set
e$state
#> [1] scored
#> Levels: undefined < defined < trained < predicted < scored
e$performance
#> classif.ce
#>  0.1666667
e$train_set
#>   [1]   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18
#>  [19]  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36
#>  [37]  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54
#>  [55]  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72
#>  [73]  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90
#>  [91]  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108
#> [109] 109 110 111 112 113 114 115 116 117 118 119 120
e$test_set
#>  [1] 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
#> [20] 140 141 142 143 144 145 146 147 148 149 150
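The score classif.ce=0.1666667 reported above is the classification error: the fraction of test observations whose predicted response differs from the truth. With 30 test rows, 0.1666667 corresponds to 5 misclassified observations. The base-R sketch below shows that computation on made-up toy vectors (not the iris predictions from the example); it illustrates the formula only and does not call the mlr3 measure.

```r
# Toy labels for illustration only; not taken from the example above.
truth = factor(c("setosa", "versicolor", "virginica", "virginica"))
response = factor(c("setosa", "virginica", "virginica", "virginica"),
                  levels = levels(truth))

# Classification error: share of mismatched predictions.
ce = mean(truth != response)
ce  # 0.25
```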