Container object for machine learning experiments.

Format

R6Class object.

Usage

# Construction
e = Experiment$new(task, learner, ...)

# Members
e$ctrl
e$data
e$has_errors
e$hash
e$learner
e$logs
e$model
e$performance
e$prediction
e$state
e$task
e$test_set
e$timings
e$train_set
e$validation_set

# Methods
e$predict(row_ids, newdata, ctrl = list())
e$score(ctrl = list())
e$train(row_ids, ctrl = list())

Arguments

  • task (Task): Task to conduct the experiment on.

  • learner (Learner): Learner to conduct the experiment with.

  • row_ids (integer() | character()): Subset of the task's row ids to work on. Invalid row ids are silently ignored.

  • newdata (data.frame): New data to predict on. Will be appended to the task.

Details

  • $new() initializes a new machine learning experiment which can grow in a stepwise fashion.

  • $predict() uses the previously fitted model to predict new observations. The predictions are stored internally as a Prediction object and can be accessed via e$prediction in data.table() format.

  • $score() quantifies stored predictions using the task's Measure and stores the resulting performance values. The performance can be accessed via e$performance.

  • $train() fits the induced Learner on the (subset of the) task and stores the model in the Learner. The model can be accessed via e$model.

  • $ctrl (list). List of control settings passed to $train(), $predict() and $score().

  • $data stores the internal representation of an Experiment as a named list with the following slots:

    • iteration (integer(1)). If the experiment is constructed manually, this is always 1.

    • learner (Learner).

    • measures (list of Measure). Measures used for performance assessment.

    • performance (named numeric). Performance values returned by the measures.

    • predict_log. Log for the predict step.

    • predict_time (numeric(1)). Elapsed time in seconds.

    • prediction (Prediction).

    • resampling (Resampling). Is NULL prior to calling $train(). If the experiment is constructed manually (i.e., not via resample() or benchmark()), a ResamplingCustom object is stored here.

    • score_time (numeric(1)). Elapsed time in seconds.

    • task (Task).

    • train_log: Log for the training step.

    • train_time (numeric(1)). Elapsed time in seconds.

  • $has_errors (logical(1)). Whether the Experiment showed errors either during training or prediction.

  • $hash (character(1)). The hash of the experiment.

  • $logs (named list(2)) returns a list with names train and predict. Both store an object of class Log if logging of the learner has been enabled via mlr_control(), and are NULL if logging was disabled or the respective step has not been performed yet.

  • $state (ordered(1)) returns the state of the experiment: "defined", "trained", "predicted", or "scored".

  • $task and $learner can be used to access the Task and Learner.

  • $timings (named numeric(3)) holds the elapsed time for the steps train, predict and score in seconds with up to millisecond accuracy (cf. proc.time()). Timings are NA if the respective step has not been performed yet.

  • $train_set and $test_set (integer() | character()) return the row ids of the training set or test set, respectively.

  • $validation_set (integer() | character()) returns the row ids of the validation set (see Task).
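The ordered nature of $state described above can be illustrated without mlr3: it behaves like any base-R ordered factor, so later stages compare as greater than earlier ones. A minimal sketch (the variable names are illustrative, not part of the API):

```r
# $state is an ordered factor; its levels mirror those documented above.
state_levels = c("undefined", "defined", "trained", "predicted", "scored")
state = ordered("trained", levels = state_levels)

state >= "defined"    # TRUE: the experiment has at least been constructed
state < "predicted"   # TRUE: $predict() has not been run yet
```

Because the levels are ordered, code can gate on progress with a single comparison, e.g. `if (state >= "trained")`, instead of enumerating all later stages.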

Examples

e = Experiment$new(
  task = mlr_tasks$get("iris"),
  learner = mlr_learners$get("classif.rpart")
)
print(e)
#> <Experiment> [defined]:
#> + Task: iris
#> + Learner: rpart
#> - Model: [missing]
#> - Predictions: [missing]
#> - Performance: [missing]
#>
#> Public: clone(), ctrl, data, has_errors, hash, learner, logs, model,
#>   performance, predict(), prediction, score(), state, task, test_set,
#>   timings, train_set, train(), validation_set
e$state
#> [1] defined
#> Levels: undefined < defined < trained < predicted < scored
e$train(row_ids = 1:120)
#> INFO [mlr3] Training learner 'rpart' on task 'iris' ...
#> <Experiment> [trained]:
#> + Task: iris
#> + Learner: rpart
#> + Model: [rpart]
#> - Predictions: [missing]
#> - Performance: [missing]
#>
#> Public: clone(), ctrl, data, has_errors, hash, learner, logs, model,
#>   performance, predict(), prediction, score(), state, task, test_set,
#>   timings, train_set, train(), validation_set
e$state
#> [1] trained
#> Levels: undefined < defined < trained < predicted < scored
e$model
#> n= 120
#>
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#>
#> 1) root 120 70 setosa (0.41666667 0.41666667 0.16666667)
#>   2) Petal.Length< 2.45 50 0 setosa (1.00000000 0.00000000 0.00000000) *
#>   3) Petal.Length>=2.45 70 20 versicolor (0.00000000 0.71428571 0.28571429)
#>     6) Petal.Length< 4.95 49 1 versicolor (0.00000000 0.97959184 0.02040816) *
#>     7) Petal.Length>=4.95 21 2 virginica (0.00000000 0.09523810 0.90476190) *
e$predict(row_ids = 121:150)
#> INFO [mlr3] Predicting with model of learner 'rpart' on task 'iris' ...
#> <Experiment> [predicted]:
#> + Task: iris
#> + Learner: rpart
#> + Model: [rpart]
#> + Predictions: [PredictionClassif]
#> - Performance: [missing]
#>
#> Public: clone(), ctrl, data, has_errors, hash, learner, logs, model,
#>   performance, predict(), prediction, score(), state, task, test_set,
#>   timings, train_set, train(), validation_set
e$state
#> [1] predicted
#> Levels: undefined < defined < trained < predicted < scored
e$prediction
#> <PredictionClassif> for 30 observations:
#>     row_id   response     truth
#>  1:    121  virginica virginica
#>  2:    122 versicolor virginica
#>  3:    123  virginica virginica
#> ---
#> 28:    148  virginica virginica
#> 29:    149  virginica virginica
#> 30:    150  virginica virginica
e$score()
#> INFO [mlr3] Scoring predictions of learner 'rpart' on task 'iris' ...
#> <Experiment> [scored]:
#> + Task: iris
#> + Learner: rpart
#> + Model: [rpart]
#> + Predictions: [PredictionClassif]
#> + Performance: mmce=0.1666667
#>
#> Public: clone(), ctrl, data, has_errors, hash, learner, logs, model,
#>   performance, predict(), prediction, score(), state, task, test_set,
#>   timings, train_set, train(), validation_set
e$state
#> [1] scored
#> Levels: undefined < defined < trained < predicted < scored
e$performance
#>      mmce
#> 0.1666667
e$train_set
#>   [1]   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18
#>  [19]  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36
#>  [37]  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54
#>  [55]  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72
#>  [73]  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90
#>  [91]  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108
#> [109] 109 110 111 112 113 114 115 116 117 118 119 120
e$test_set
#>  [1] 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
#> [20] 140 141 142 143 144 145 146 147 148 149 150
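The seconds-based timings behind $timings come from elapsed-time measurements as returned by proc.time(). A base-R sketch of that bookkeeping, with unperformed steps kept as NA just as $timings reports them (the helper name time_step and the initialisation pattern are illustrative assumptions, not mlr3 API):

```r
# Measure the elapsed wall-clock seconds a step takes (cf. proc.time()).
time_step = function(expr) {
  start = proc.time()[["elapsed"]]
  force(expr)  # evaluate the step
  proc.time()[["elapsed"]] - start
}

# Steps that have not been performed yet stay NA, mirroring $timings.
timings = c(train = NA_real_, predict = NA_real_, score = NA_real_)
timings[["train"]] = time_step(Sys.sleep(0.01))
```

After running only the train step, `timings` has a numeric value for "train" while "predict" and "score" remain NA, matching the semantics documented for $timings above.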