Skip to content

Conversation

@MalteKurz
Copy link
Member

@MalteKurz MalteKurz commented Oct 10, 2022

Description

This PR implements the often requested feature to store the estimated models for nuisance parameters. To use it, call the method fit() with option store_models=True. Example:

library(DoubleML) library(mlr3) library(mlr3learners) library(data.table) set.seed(2) ml_g = lrn("regr.ranger", num.trees = 10, max.depth = 2) ml_m = ml_g$clone() obj_dml_data = make_plr_CCDDHNR2018(alpha = 0.5) dml_plr_obj = DoubleMLPLR$new(obj_dml_data, ml_g, ml_m) dml_plr_obj$fit(store_models=TRUE)

The estimated models can then be found in the attribute dml_plr_obj$models:

dml_plr_obj$models $ml_l $ml_l$d $ml_l$d[[1]] $ml_l$d[[1]][[1]] <LearnerRegrRanger:regr.ranger> * Model: ranger * Parameters: num.threads=1, num.trees=10, max.depth=2 * Packages: mlr3, mlr3learners, ranger * Predict Types: [response], se * Feature Types: logical, integer, numeric, character, factor, ordered * Properties: hotstart_backward, importance, oob_error, weights $ml_l$d[[1]][[2]] <LearnerRegrRanger:regr.ranger> * Model: ranger * Parameters: num.threads=1, num.trees=10, max.depth=2 * Packages: mlr3, mlr3learners, ranger * Predict Types: [response], se * Feature Types: logical, integer, numeric, character, factor, ordered * Properties: hotstart_backward, importance, oob_error, weights $ml_l$d[[1]][[3]] <LearnerRegrRanger:regr.ranger> * Model: ranger * Parameters: num.threads=1, num.trees=10, max.depth=2 * Packages: mlr3, mlr3learners, ranger * Predict Types: [response], se * Feature Types: logical, integer, numeric, character, factor, ordered * Properties: hotstart_backward, importance, oob_error, weights $ml_l$d[[1]][[4]] <LearnerRegrRanger:regr.ranger> * Model: ranger * Parameters: num.threads=1, num.trees=10, max.depth=2 * Packages: mlr3, mlr3learners, ranger * Predict Types: [response], se * Feature Types: logical, integer, numeric, character, factor, ordered * Properties: hotstart_backward, importance, oob_error, weights $ml_l$d[[1]][[5]] <LearnerRegrRanger:regr.ranger> * Model: ranger * Parameters: num.threads=1, num.trees=10, max.depth=2 * Packages: mlr3, mlr3learners, ranger * Predict Types: [response], se * Feature Types: logical, integer, numeric, character, factor, ordered * Properties: hotstart_backward, importance, oob_error, weights $ml_m $ml_m$d $ml_m$d[[1]] $ml_m$d[[1]][[1]] <LearnerRegrRanger:regr.ranger> * Model: ranger * Parameters: num.threads=1, num.trees=10, max.depth=2 * Packages: mlr3, mlr3learners, ranger * Predict Types: [response], se * Feature Types: logical, integer, numeric, character, factor, ordered * Properties: hotstart_backward, importance, oob_error, weights $ml_m$d[[1]][[2]] <LearnerRegrRanger:regr.ranger> * Model: ranger * Parameters: num.threads=1, num.trees=10, max.depth=2 * Packages: mlr3, mlr3learners, ranger * Predict Types: [response], se * Feature Types: logical, integer, numeric, character, factor, ordered * Properties: hotstart_backward, importance, oob_error, weights $ml_m$d[[1]][[3]] <LearnerRegrRanger:regr.ranger> * Model: ranger * Parameters: num.threads=1, num.trees=10, max.depth=2 * Packages: mlr3, mlr3learners, ranger * Predict Types: [response], se * Feature Types: logical, integer, numeric, character, factor, ordered * Properties: hotstart_backward, importance, oob_error, weights $ml_m$d[[1]][[4]] <LearnerRegrRanger:regr.ranger> * Model: ranger * Parameters: num.threads=1, num.trees=10, max.depth=2 * Packages: mlr3, mlr3learners, ranger * Predict Types: [response], se * Feature Types: logical, integer, numeric, character, factor, ordered * Properties: hotstart_backward, importance, oob_error, weights $ml_m$d[[1]][[5]] <LearnerRegrRanger:regr.ranger> * Model: ranger * Parameters: num.threads=1, num.trees=10, max.depth=2 * Packages: mlr3, mlr3learners, ranger * Predict Types: [response], se * Feature Types: logical, integer, numeric, character, factor, ordered * Properties: hotstart_backward, importance, oob_error, weights 

Note that the number of fitted models depends on the settings and the considered model. The outer named list contains one entry for each nuisance part (here ml_l and ml_m). For each nuisance part there is a named list containing an entry for each treatment variable (here only 'd'). The next inner part is a list of length n_rep (repeated cross-fitting) and then a list of length n_folds (number of folds per repeated cross fit).

PR Checklist

  • The title of the pull request summarizes the changes made.
  • The PR contains a detailed description of all changes and additions.
  • The code passes R CMD check and all (unit) tests (see our contributing guidelines for details).
  • Enhancements or new feature are equipped with unit tests.
  • The changes adhere to the "mlr-style" standards (see our contributing guidelines for details).
Copy link
Member

@PhilippBach PhilippBach left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @MalteKurz for preparing the PR. As you indicate in the check boxes of the PR, there are still the tests missing; Either I can do them in the next days myself or I'll add my review to them later (in case you're faster 😃 ). Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

3 participants