Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ Suggests:
sandwich,
AER,
rpart,
bbotk
bbotk,
mlr3pipelines
VignetteBuilder: knitr
Collate:
'double_ml.R'
Expand Down
25 changes: 14 additions & 11 deletions R/double_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,7 @@ DoubleML = R6Class("DoubleML",
i_treat = NA_integer_,
fold_specific_params = NULL,
summary_table = NULL,
learner_class = list(),
task_type = list(),
is_cluster_data = FALSE,
n_folds_per_cluster = NA_integer_,
smpls_cluster_ = NULL,
Expand Down Expand Up @@ -1250,28 +1250,31 @@ DoubleML = R6Class("DoubleML",
check_character(learner, max.len = 1),
check_class(learner, "Learner"))

if (test_class(learner, "AutoTuner")) {
stop(paste0(
"Learners of class 'AutoTuner' are not supported."
))
}
if (is.character(learner)) {
# warning("Learner provision by character() will be deprecated in the
# future.")
learner = lrn(learner)
}

if (Regr & test_class(learner, "LearnerRegr")) {
private$learner_class[learner_name] = "LearnerRegr"
}
if (Classif & test_class(learner, "LearnerClassif")) {
private$learner_class[learner_name] = "LearnerClassif"
if ((Regr & learner$task_type == "regr") |
(Classif & learner$task_type == "classif")) {
private$task_type[learner_name] = learner$task_type
}

if ((Regr & !Classif & !test_class(learner, "LearnerRegr"))) {
if ((Regr & !Classif & !learner$task_type == "regr")) {
stop(paste0(
"Invalid learner provided for ", learner_name,
": must be of class 'LearnerRegr'"))
": 'learner$task_type' must be 'regr'"))
}
if ((Classif & !Regr & !test_class(learner, "LearnerClassif"))) {
if ((Classif & !Regr & !learner$task_type == "classif")) {
stop(paste0(
"Invalid learner provided for ", learner_name,
": must be of class 'LearnerClassif'"))
": 'learner$task_type must be 'classif'"))
}
invisible(learner)
},
Expand Down Expand Up @@ -1333,7 +1336,7 @@ DoubleML = R6Class("DoubleML",
this_learner = names(tune_settings$measure)[i_msr]
tune_settings$measure[[this_learner]] = set_default_measure(
tune_settings$measure[[this_learner]],
private$learner_class[[this_learner]])
private$task_type[[this_learner]])
}
}

Expand Down
87 changes: 44 additions & 43 deletions R/double_ml_iivm.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,42 +138,43 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
#' The `DoubleMLData` object providing the data and specifying the variables
#' of the causal model.
#'
#' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr], `character(1)`) \cr
#' An object of the class [mlr3 regression learner][mlr3::LearnerRegr] to
#' pass a learner, possibly with specified parameters, for example
#' `lrn("regr.cv_glmnet", s = "lambda.min")`.
#' Alternatively, a `character(1)` specifying the name of a
#' [mlr3 regression learner][mlr3::LearnerRegr] that is available in
#' [mlr3](https://mlr3.mlr-org.com/index.html) or its extension packages
#' [mlr3learners](https://mlr3learners.mlr-org.com/) or
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/),
#' for example `"regr.cv_glmnet"`. \cr
#' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr],
#' [`Learner`][mlr3::Learner], `character(1)`) \cr
#' A learner of the class [`LearnerRegr`][mlr3::LearnerRegr], which is
#' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
#' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
#' Alternatively, a [`Learner`][mlr3::Learner] object with public field
#' `task_type = "regr"` can be passed, for example of class
#' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
#' be passed with specified parameters, for example
#' `lrn("regr.cv_glmnet", s = "lambda.min")`. \cr
#' `ml_g` refers to the nuisance function \eqn{g_0(Z,X) = E[Y|X,Z]}.
#'
#' @param ml_m ([`LearnerClassif`][mlr3::LearnerClassif], `character(1)`) \cr
#' An object of the class
#' [mlr3 classification learner][mlr3::LearnerClassif] to pass a learner,
#' possibly with specified parameters, for example
#' `lrn("classif.cv_glmnet", s = "lambda.min")`.
#' Alternatively, a `character(1)` specifying the name of
#' a [mlr3 classification learner][mlr3::LearnerClassif] that is available
#' in [mlr3](https://mlr3.mlr-org.com/index.html) or its extension packages
#' [mlr3learners](https://mlr3learners.mlr-org.com/) or
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/),
#' for example `"classif.cv_glmnet"`. \cr
#' @param ml_m ([`LearnerClassif`][mlr3::LearnerClassif],
#' [`Learner`][mlr3::Learner], `character(1)`) \cr
#' A learner of the class [`LearnerClassif`][mlr3::LearnerClassif], which is
#' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
#' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
#' Alternatively, a [`Learner`][mlr3::Learner] object with public field
#' `task_type = "classif"` can be passed, for example of class
#' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
#' be passed with specified parameters, for example
#' `lrn("classif.cv_glmnet", s = "lambda.min")`. \cr
#' `ml_m` refers to the nuisance function \eqn{m_0(X) = E[Z|X]}.
#'
#' @param ml_r ([`LearnerClassif`][mlr3::LearnerClassif], `character(1)`) \cr
#' An object of the class
#' [mlr3 classification learner][mlr3::LearnerClassif] to pass a learner,
#' possibly with specified parameters, for example
#' `lrn("classif.cv_glmnet", s = "lambda.min")`.
#' Alternatively, a `character(1)` specifying the name of a
#' [mlr3 classification learner][mlr3::LearnerClassif] that is available in
#' [mlr3](https://mlr3.mlr-org.com/index.html) or its extension packages
#' [mlr3learners](https://mlr3learners.mlr-org.com/) or
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/),
#' for example `"classif.cv_glmnet"`. \cr
#' @param ml_r ([`LearnerClassif`][mlr3::LearnerClassif],
#' [`Learner`][mlr3::Learner], `character(1)`) \cr
#' A learner of the class [`LearnerClassif`][mlr3::LearnerClassif], which is
#' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
#' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
#' Alternatively, a [`Learner`][mlr3::Learner] object with public field
#' `task_type = "classif"` can be passed, for example of class
#' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
#' be passed with specified parameters, for example
#' `lrn("classif.cv_glmnet", s = "lambda.min")`. \cr
#' `ml_r` refers to the nuisance function \eqn{r_0(Z,X) = E[D|X,Z]}.
#'
#' @param n_folds (`integer(1)`)\cr
Expand Down Expand Up @@ -241,7 +242,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",

private$check_data(self$data)
private$check_score(self$score)
private$learner_class = list(
private$task_type = list(
"ml_g" = NULL,
"ml_m" = NULL,
"ml_r" = NULL)
Expand Down Expand Up @@ -295,7 +296,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
smpls = smpls,
est_params = self$get_params("ml_m"),
return_train_preds = FALSE,
learner_class = private$learner_class$ml_m,
task_type = private$task_type$ml_m,
fold_specific_params = private$fold_specific_params)

g0_hat = dml_cv_predict(self$learner$ml_g,
Expand All @@ -306,7 +307,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
smpls = cond_smpls$smpls_0,
est_params = self$get_params("ml_g0"),
return_train_preds = FALSE,
learner_class = private$learner_class$ml_g,
task_type = private$task_type$ml_g,
fold_specific_params = private$fold_specific_params)

g1_hat = dml_cv_predict(self$learner$ml_g,
Expand All @@ -317,7 +318,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
smpls = cond_smpls$smpls_1,
est_params = self$get_params("ml_g1"),
return_train_preds = FALSE,
learner_class = private$learner_class$ml_g,
task_type = private$task_type$ml_g,
fold_specific_params = private$fold_specific_params)

if (self$subgroups$always_takers == FALSE) {
Expand All @@ -331,7 +332,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
smpls = cond_smpls$smpls_0,
est_params = self$get_params("ml_r0"),
return_train_preds = FALSE,
learner_class = private$learner_class$ml_r,
task_type = private$task_type$ml_r,
fold_specific_params = private$fold_specific_params)
}

Expand All @@ -346,7 +347,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
smpls = cond_smpls$smpls_1,
est_params = self$get_params("ml_r1"),
return_train_preds = FALSE,
learner_class = private$learner_class$ml_r,
task_type = private$task_type$ml_r,
fold_specific_params = private$fold_specific_params)
}

Expand Down Expand Up @@ -421,7 +422,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
nuisance_id = "nuis_m",
param_set$ml_m, tune_settings,
tune_settings$measure$ml_m,
private$learner_class$ml_m)
private$task_type$ml_m)

tuning_result_g0 = dml_tune(self$learner$ml_g,
c(self$data$x_cols, self$data$other_treat_cols),
Expand All @@ -430,7 +431,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
nuisance_id = "nuis_g0",
param_set$ml_g, tune_settings,
tune_settings$measure$ml_g,
private$learner_class$ml_g)
private$task_type$ml_g)

tuning_result_g1 = dml_tune(self$learner$ml_g,
c(self$data$x_cols, self$data$other_treat_cols),
Expand All @@ -439,7 +440,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
nuisance_id = "nuis_g1",
param_set$ml_g, tune_settings,
tune_settings$measure$ml_g,
private$learner_class$ml_g)
private$task_type$ml_g)

if (self$subgroups$always_takers == TRUE) {
tuning_result_r0 = dml_tune(self$learner$ml_r,
Expand All @@ -449,7 +450,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
nuisance_id = "nuis_r0",
param_set$ml_r, tune_settings,
tune_settings$measure$ml_r,
private$learner_class$ml_r)
private$task_type$ml_r)
} else {
tuning_result_r0 = list(list(), "params" = list(list()))
}
Expand All @@ -462,7 +463,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM",
nuisance_id = "nuis_r1",
param_set$ml_r, tune_settings,
tune_settings$measure$ml_r,
private$learner_class$ml_r)
private$task_type$ml_r)
} else {
tuning_result_r1 = list(list(), "params" = list(list()))
}
Expand Down
57 changes: 29 additions & 28 deletions R/double_ml_irm.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,29 +107,30 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
#' The `DoubleMLData` object providing the data and specifying the variables
#' of the causal model.
#'
#' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr], `character(1)`) \cr
#' An object of the class [mlr3 regression learner][mlr3::LearnerRegr] to
#' pass a learner, possibly with specified parameters, for example
#' `lrn("regr.cv_glmnet", s = "lambda.min")`.
#' Alternatively, a `character(1)` specifying the name of a
#' [mlr3 regression learner][mlr3::LearnerRegr] that is available in
#' [mlr3](https://mlr3.mlr-org.com/index.html) or its extension packages
#' [mlr3learners](https://mlr3learners.mlr-org.com/) or
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/), for example
#' `"regr.cv_glmnet"`. \cr
#' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr],
#' [`Learner`][mlr3::Learner], `character(1)`) \cr
#' A learner of the class [`LearnerRegr`][mlr3::LearnerRegr], which is
#' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
#' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
#' Alternatively, a [`Learner`][mlr3::Learner] object with public field
#' `task_type = "regr"` can be passed, for example of class
#' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
#' be passed with specified parameters, for example
#' `lrn("regr.cv_glmnet", s = "lambda.min")`. \cr
#' `ml_g` refers to the nuisance function \eqn{g_0(X) = E[Y|X,D]}.
#'
#' @param ml_m ([`LearnerClassif`][mlr3::LearnerClassif], `character(1)`) \cr
#' An object of the class
#' [mlr3 classification learner][mlr3::LearnerClassif] to pass a learner,
#' possibly with specified parameters, for example
#' `lrn("classif.cv_glmnet", s = "lambda.min")`.
#' Alternatively, a `character(1)` specifying the name of a
#' [mlr3 classification learner][mlr3::LearnerClassif] that is available
#' in [mlr3](https://mlr3.mlr-org.com/index.html) or its extension packages
#' [mlr3learners](https://mlr3learners.mlr-org.com/) or
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/),
#' for example `"classif.cv_glmnet"`. \cr
#' @param ml_m ([`LearnerClassif`][mlr3::LearnerClassif],
#' [`Learner`][mlr3::Learner], `character(1)`) \cr
#' A learner of the class [`LearnerClassif`][mlr3::LearnerClassif], which is
#' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its
#' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or
#' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/).
#' Alternatively, a [`Learner`][mlr3::Learner] object with public field
#' `task_type = "classif"` can be passed, for example of class
#' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly
#' be passed with specified parameters, for example
#' `lrn("classif.cv_glmnet", s = "lambda.min")`. \cr
#' `ml_m` refers to the nuisance function \eqn{m_0(X) = E[D|X]}.
#'
#' @param n_folds (`integer(1)`)\cr
Expand Down Expand Up @@ -185,7 +186,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",

private$check_data(self$data)
private$check_score(self$score)
private$learner_class = list(
private$task_type = list(
"ml_g" = NULL,
"ml_m" = NULL)
ml_g = private$assert_learner(ml_g, "ml_g", Regr = TRUE, Classif = FALSE)
Expand Down Expand Up @@ -227,7 +228,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
smpls = smpls,
est_params = self$get_params("ml_m"),
return_train_preds = FALSE,
learner_class = private$learner_class$ml_m,
task_type = private$task_type$ml_m,
fold_specific_params = private$fold_specific_params)

g0_hat = dml_cv_predict(self$learner$ml_g,
Expand All @@ -238,7 +239,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
smpls = cond_smpls$smpls_0,
est_params = self$get_params("ml_g0"),
return_train_preds = FALSE,
learner_class = private$learner_class$ml_g,
task_type = private$task_type$ml_g,
fold_specific_params = private$fold_specific_params)

g1_hat = NULL
Expand All @@ -251,7 +252,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
smpls = cond_smpls$smpls_1,
est_params = self$get_params("ml_g1"),
return_train_preds = FALSE,
learner_class = private$learner_class$ml_g,
task_type = private$task_type$ml_g,
fold_specific_params = private$fold_specific_params)
}

Expand Down Expand Up @@ -330,7 +331,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
nuisance_id = "nuis_m",
param_set$ml_m, tune_settings,
tune_settings$measure$ml_m,
private$learner_class$ml_m)
private$task_type$ml_m)

tuning_result_g0 = dml_tune(self$learner$ml_g,
c(self$data$x_cols, self$data$other_treat_cols),
Expand All @@ -339,7 +340,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
nuisance_id = "nuis_g0",
param_set$ml_g, tune_settings,
tune_settings$measure$ml_g,
private$learner_class$ml_g)
private$task_type$ml_g)

if ((is.character(self$score) && self$score == "ATE") || is.function(self$score)) {
tuning_result_g1 = dml_tune(self$learner$ml_g,
Expand All @@ -349,7 +350,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM",
nuisance_id = "nuis_g1",
param_set$ml_g, tune_settings,
tune_settings$measure$ml_g,
private$learner_class$ml_g)
private$task_type$ml_g)
} else {
tuning_result_g1 = list(list(), "params" = list(list()))
}
Expand Down
Loading