Container object for a machine learning experiment.

Format

R6::R6Class object.

Construction

Experiment$new(task = NULL, learner = NULL, ctrl = list())
  • task :: Task
    May be NULL during initialization, but must be set before the Experiment can be trained.

  • learner :: Learner
    May be NULL during initialization, but must be set before the Experiment can be trained.

  • ctrl :: named list()
    Control object, see mlr_control().
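    The construction rule above (task and learner may be NULL until training) amounts to a validation step that is deferred from $new() to $train(). The following is a minimal base-R sketch of such a deferred check. `check_trainable()` is a hypothetical helper, not part of mlr3, and plain character strings stand in for real Task and Learner objects.

```r
# Hypothetical helper (not part of mlr3): validate at train time the inputs
# that were allowed to be NULL at construction time.
check_trainable <- function(task, learner) {
  missing_args <- c(task = is.null(task), learner = is.null(learner))
  if (any(missing_args)) {
    stop("Cannot train: missing ",
         paste(names(missing_args)[missing_args], collapse = " and "),
         call. = FALSE)
  }
  invisible(TRUE)
}

# Construction-time state: the task slot is still NULL.
exp_state <- list(task = NULL, learner = "classif.rpart")

# A training attempt fails with an informative error naming the missing input.
msg <- tryCatch(check_trainable(exp_state$task, exp_state$learner),
                error = conditionMessage)
print(msg)
#> [1] "Cannot train: missing task"
```

    Deferring the check keeps construction cheap and lets the task or learner be filled in later, at the cost of reporting the problem only when training is attempted.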

Fields

  • data :: named list()
    Internal data storage as a named list with the following slots:

    • iteration :: integer(1)
      The iteration of the stored Resampling object that this experiment represents. For a manually constructed experiment this is always 1, as there is only a single train/test split.

    • learner :: Learner
      A clone of the Learner provided during construction.

    • measures :: list() of Measure
      Measures for performance assessment.

    • performance :: named numeric()
      Aggregated scores returned by the measures, named with measure ids.

    • prediction :: Prediction
      Prediction object as returned by the Learner's predict() call.

    • resampling :: Resampling
      Is NULL prior to calling $train(). If the experiment is constructed manually (i.e., not via resample() or benchmark()), a ResamplingCustom object is stored.

    • task :: Task
      A clone of the Task provided during construction.

    • train_log :: data.table::data.table()
      Log for the training step.

    • predict_log :: data.table::data.table()
      Log for the predict step.

    • train_time :: numeric(1)
      Elapsed time of the train step, in seconds.

    • predict_time :: numeric(1)
      Elapsed time of the predict step, in seconds.

    • score_time :: numeric(1)
      Elapsed time of the score step, in seconds.

  • ctrl :: list()
    Control settings passed during initialization.

  • has_errors :: logical(1)
    Flag which is TRUE if any error has been recorded during $train() or $predict().

  • hash :: character(1)
    Hash (unique identifier) for this object. This hash is cached.

  • state :: ordered(1)
    Returns the state of the experiment as an ordered factor with the levels "undefined", "defined", "trained", "predicted", and "scored".

  • train_set :: (integer() | character())
    The row ids of the training set for $train().

  • test_set :: (integer() | character())
    The row ids of the test set for $predict().

  • learner :: Learner
    Access the stored Learner.

  • model :: any
    Access the trained model of the Learner.

  • performance :: numeric()
    Access the performance scores as returned by the Measures stored in the Task.

  • prediction :: Prediction
    Access the individual predictions of the model stored in the Learner.

  • seeds :: integer(3)
    Named integer vector of random number generator seeds passed to set.seed() before external code is called in train(), predict() or score(). The names must be "train", "predict" and "score". A value of NA disables seeding for the respective step (default).

  • task :: Task
    Access the stored Task.

  • timings :: numeric(3)
    Stores the elapsed time for the steps train(), predict() and score() in seconds, with up to millisecond accuracy (cf. proc.time()). Timings are NA if the respective step has not been performed yet.

  • validation_set :: (integer() | character())
    The row ids of the validation set of the Task.
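Because state is an ordered factor, a question such as "has the experiment been trained yet?" reduces to a single comparison against a level name. The base-R sketch below assumes the five levels printed by e$state in the Examples; it only illustrates ordered-factor semantics and does not touch an Experiment object.

```r
# Illustration only: the state levels as printed by e$state in the Examples.
state_levels <- c("undefined", "defined", "trained", "predicted", "scored")

# An ordered factor holding a single state value.
state <- factor("predicted", levels = state_levels, ordered = TRUE)

# Comparisons against level names follow the level order.
state >= "trained"  # predicted comes after trained
#> [1] TRUE
state >= "scored"   # scoring has not happened yet
#> [1] FALSE

# Advancing the state means re-creating the factor with a later level.
state <- factor("scored", levels = state_levels, ordered = TRUE)
state
#> [1] scored
#> Levels: undefined < defined < trained < predicted < scored
```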
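The seeds and timings fields both work per step: an optional seed is applied with set.seed() before a step runs, and the elapsed time is recorded afterwards, staying NA for steps that never ran. A minimal base-R sketch of that bookkeeping follows. It is not mlr3's implementation; the environment `exp_env` and the helper `run_step()` are hypothetical, and the predict seed is set to 42 only to demonstrate reproducibility (the documented default is NA).

```r
# Hypothetical sketch (not mlr3's code): per-step seeds and timings stored as
# named vectors in an environment, so a helper can update them in place.
exp_env <- new.env()
exp_env$seeds <- c(train = NA_integer_, predict = 42L, score = NA_integer_)
exp_env$timings <- c(train = NA_real_, predict = NA_real_, score = NA_real_)

# Hypothetical helper: seed the RNG if a seed is set for `step`, run `fun`,
# and store the elapsed wall-clock time rounded to milliseconds.
run_step <- function(env, step, fun) {
  seed <- env$seeds[[step]]
  if (!is.na(seed)) {
    set.seed(seed)
  }
  start <- proc.time()[["elapsed"]]
  result <- fun()
  env$timings[[step]] <- round(proc.time()[["elapsed"]] - start, 3)
  result
}

# The predict step is seeded, so two runs draw identical random numbers.
p1 <- run_step(exp_env, "predict", function() runif(3))
p2 <- run_step(exp_env, "predict", function() runif(3))
identical(p1, p2)
#> [1] TRUE

# Steps that were never run keep their NA timing (train and score here).
is.na(exp_env$timings[c("train", "score")])
```

Keeping the state in an environment mimics the reference semantics of an R6 object: the helper mutates the shared record instead of returning a modified copy.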

Methods

Examples

e = Experiment$new(
  task = mlr_tasks$get("iris"),
  learner = mlr_learners$get("classif.rpart")
)
print(e)
#> <Experiment> [defined]:
#>  + Task: iris
#>  + Learner: classif.rpart
#>  - Model: [missing]
#>  - Predictions: [missing]
#>  - Performance: [missing]
#>
#> Public: clone(), ctrl, data, has_errors, hash, learner, log(), model,
#>   performance, predict(), prediction, score(), seeds, state, task,
#>   test_set, timings, train_set, train(), validation_set
e$state
#> [1] defined
#> Levels: undefined < defined < trained < predicted < scored
e$train(row_ids = 1:120)
#> INFO [mlr3] Training learner 'classif.rpart' on task 'iris' ...
#> <Experiment> [trained]:
#>  + Task: iris
#>  + Learner: classif.rpart
#>  + Model: [rpart]
#>  - Predictions: [missing]
#>  - Performance: [missing]
#>
#> Public: clone(), ctrl, data, has_errors, hash, learner, log(), model,
#>   performance, predict(), prediction, score(), seeds, state, task,
#>   test_set, timings, train_set, train(), validation_set
e$state
#> [1] trained
#> Levels: undefined < defined < trained < predicted < scored
e$model
#> n= 120
#>
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#>
#> 1) root 120 70 setosa (0.41666667 0.41666667 0.16666667)
#>   2) Petal.Length< 2.45 50  0 setosa (1.00000000 0.00000000 0.00000000) *
#>   3) Petal.Length>=2.45 70 20 versicolor (0.00000000 0.71428571 0.28571429)
#>     6) Petal.Length< 4.95 49  1 versicolor (0.00000000 0.97959184 0.02040816) *
#>     7) Petal.Length>=4.95 21  2 virginica (0.00000000 0.09523810 0.90476190) *
e$predict(row_ids = 121:150)
#> INFO [mlr3] Predicting with model of learner 'classif.rpart' on task 'iris' ...
#> <Experiment> [predicted]:
#>  + Task: iris
#>  + Learner: classif.rpart
#>  + Model: [rpart]
#>  + Predictions: [PredictionClassif]
#>  - Performance: [missing]
#>
#> Public: clone(), ctrl, data, has_errors, hash, learner, log(), model,
#>   performance, predict(), prediction, score(), seeds, state, task,
#>   test_set, timings, train_set, train(), validation_set
e$state
#> [1] predicted
#> Levels: undefined < defined < trained < predicted < scored
e$prediction
#> <PredictionClassif> for 30 observations:
#>     row_id   response     truth
#>  1:    121  virginica virginica
#>  2:    122 versicolor virginica
#>  3:    123  virginica virginica
#> ---
#> 28:    148  virginica virginica
#> 29:    149  virginica virginica
#> 30:    150  virginica virginica
e$score()
#> INFO [mlr3] Scoring predictions of learner 'classif.rpart' on task 'iris' ...
#> <Experiment> [scored]:
#>  + Task: iris
#>  + Learner: classif.rpart
#>  + Model: [rpart]
#>  + Predictions: [PredictionClassif]
#>  + Performance: classif.mmce=0.1666667
#>
#> Public: clone(), ctrl, data, has_errors, hash, learner, log(), model,
#>   performance, predict(), prediction, score(), seeds, state, task,
#>   test_set, timings, train_set, train(), validation_set
e$state
#> [1] scored
#> Levels: undefined < defined < trained < predicted < scored
e$performance
#> classif.mmce
#>    0.1666667
e$train_set
#>   [1]   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18
#>  [19]  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36
#>  [37]  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54
#>  [55]  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72
#>  [73]  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90
#>  [91]  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108
#> [109] 109 110 111 112 113 114 115 116 117 118 119 120
e$test_set
#>  [1] 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
#> [20] 140 141 142 143 144 145 146 147 148 149 150
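The score classif.mmce = 0.1666667 reported by e$performance is the mean misclassification error: the fraction of test rows whose predicted class differs from the true class. With 30 test rows, 0.1666667 corresponds to 5 misclassified observations (5 / 30). The base-R sketch below shows the formula on hypothetical toy labels, not on the actual iris predictions; `mmce()` here is an illustrative function, not the mlr3 Measure object.

```r
# Mean misclassification error: fraction of observations where the
# predicted class (response) differs from the true class (truth).
mmce <- function(truth, response) {
  stopifnot(length(truth) == length(response))
  mean(as.character(truth) != as.character(response))
}

# Hypothetical toy labels: one error (third position) out of six observations.
truth    <- c("setosa", "versicolor", "virginica", "virginica", "setosa", "versicolor")
response <- c("setosa", "versicolor", "versicolor", "virginica", "setosa", "versicolor")
mmce(truth, response)
#> [1] 0.1666667

# The same ratio as in the example: 5 errors among 30 test rows.
5 / 30
#> [1] 0.1666667
```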