diff --git a/R/forest.wrapper.R b/R/forest.wrapper.R index 82bc272..fae5174 100644 --- a/R/forest.wrapper.R +++ b/R/forest.wrapper.R @@ -26,12 +26,12 @@ vt.forest <- function(forest.type = "one", vt.data, interactions = T, method = " params <- list(...) if (forest.type == "one"){ - if(! "rf" %in% names(params) ){ + if(! "model" %in% names(params) ){ rf <- randomForest(x = vt.data$getX(interactions = interactions, trt = NULL), y = vt.data$getY(), ...) } else{ - rf <- params["rf"] + rf <- params[["model"]] } vt.difft <- VT.forest.one(vt.object = vt.data, model = rf, interactions = interactions, method = method) @@ -41,19 +41,19 @@ vt.forest <- function(forest.type = "one", vt.data, interactions = T, method = " y = vt.data$getY(1), ...) } else - rf_trt1 <- params["model_trt1"] + rf_trt1 <- params[["model_trt1"]] if(! "model_trt0" %in% names(params) ){ rf_trt0 <- randomForest(x = vt.data$getX(trt = 1, interactions = interactions), y = vt.data$getY(1), ...) } else - rf_trt0 <- params["rf_trt0"] + rf_trt0 <- params[["model_trt0"]] - vt.difft <- VT.forest.double(vt.object = vt.data, model_trt1 = rf_trt1, model_trt0 = rf_trt0, method = method, ...) + vt.difft <- VT.forest.double(vt.object = vt.data, model_trt1 = rf_trt1, model_trt0 = rf_trt0, method = method) } else if (forest.type == "fold"){ - fold <- ifelse(! "fold" %in% names(params) , 5, params["fold"]) - fold <- ifelse(! "ratio" %in% names(params) , 1, params["ratio"]) + fold <- ifelse(! "fold" %in% names(params) , 5, as.numeric(params["fold"])) + ratio <- ifelse(! "ratio" %in% names(params) , 1, as.numeric(params["ratio"])) vt.difft <- aVirtualTwins:::VT.forest.fold(vt.object = vt.data, fold = fold, ratio = ratio, interactions = interactions, method = method)