1
0
Fork 0
mirror of https://github.com/prise6/aVirtualTwins.git synced 2024-05-08 20:46:31 +02:00
aVirtualTwins/R/forest.double.R

54 lines
1.8 KiB
R

# VT.FOREST.DOUBLE --------------------------------------------------------
# IF RUNNING DOUBLE FOREST COMPUTATION
#' A reference class to compute twins via double random forests
#'
#' Two forests are supplied: one for treatment T = 1 and one for T = 0.
#' \code{twin1} is filled with each subject's prediction from the forest of
#' its own treatment arm (obtained without \code{newdata}, i.e. OOB
#' predictions). \code{twin2} is filled with each subject's prediction from
#' the forest of the opposite arm, computed on that subject's covariates.
#'
#' @include forest.R
#'
#' @field model_trt1 a caret/RandomForest/randomForest object for treatment T = 1
#' @field model_trt0 a caret/RandomForest/randomForest object for treatment T = 0
#'
#' @import methods
VT.forest.double <- setRefClass(
  Class = "VT.forest.double",
  contains = "VT.forest",
  fields = list(
    model_trt1 = "ANY",
    model_trt0 = "ANY"
  ),
  methods = list(
    initialize = function(vt.object, model_trt1, model_trt0, ...) {
      # Validate both forests before storing them, then delegate the rest of
      # the initialization (vt.object handling) to the parent VT.forest class.
      .self$checkModel(model_trt1)
      .self$checkModel(model_trt0)
      .self$model_trt1 <- model_trt1
      .self$model_trt0 <- model_trt0
      callSuper(vt.object, ...)
    },
    computeTwin1 = function() {
      "Compute twin1 with OOB predictions from double forests"
      # Treatment indicator; assumes column 2 of vt.object$data holds T coded
      # as 0/1 -- NOTE(review): confirm against the VT.object definition.
      trt <- .self$vt.object$data[, 2]
      # Treated subjects (T = 1): predictions of the T = 1 forest.
      .self$twin1[trt == 1] <- VT.predict(
        rfor = .self$model_trt1,
        type = .self$vt.object$type
      )
      # Control subjects (T = 0): predictions of the T = 0 forest.
      .self$twin1[trt == 0] <- VT.predict(
        rfor = .self$model_trt0,
        type = .self$vt.object$type
      )
      # Returned invisibly, consistent with computeTwin2; the field itself
      # is the primary result.
      invisible(.self$twin1)
    },
    computeTwin2 = function() {
      "Compute twin2 by the other part of data in the other forest"
      # Treatment indicator; assumes column 2 of vt.object$data holds T coded
      # as 0/1 -- NOTE(review): confirm against the VT.object definition.
      trt <- .self$vt.object$data[, 2]
      # Treated subjects (T = 1): predicted by the T = 0 forest on their own
      # covariates (counterfactual arm).
      .self$twin2[trt == 1] <- VT.predict(
        rfor = .self$model_trt0,
        newdata = .self$vt.object$getX(1, interactions = FALSE),
        type = .self$vt.object$type
      )
      # Control subjects (T = 0): predicted by the T = 1 forest on their own
      # covariates (counterfactual arm).
      .self$twin2[trt == 0] <- VT.predict(
        rfor = .self$model_trt1,
        newdata = .self$vt.object$getX(0, interactions = FALSE),
        type = .self$vt.object$type
      )
      invisible(.self$twin2)
    }
  )
)