2015-05-31 21:09:04 +02:00
|
|
|
# VT.TREE.REG -------------------------------------------------------------
|
|
|
|
|
2015-06-11 09:32:10 +02:00
|
|
|
#' Regression tree to find subgroups
|
|
|
|
#'
|
|
|
|
#' See \code{\link{VT.tree}}
|
|
|
|
#'
|
2015-06-01 23:22:49 +02:00
|
|
|
#' @include tree.R
|
|
|
|
#'
|
2015-06-11 13:34:55 +02:00
|
|
|
#' @export VT.tree.reg
|
|
|
|
#'
|
2015-06-11 09:32:10 +02:00
|
|
|
#' @name VT.tree.reg
|
2016-10-09 02:44:17 +02:00
|
|
|
#'
|
|
|
|
#' @importFrom rpart rpart
|
2015-06-01 23:22:49 +02:00
|
|
|
|
2015-05-31 21:09:04 +02:00
|
|
|
VT.tree.reg <- setRefClass(
  Class = "VT.tree.reg",

  # Regression-tree variant of VT.tree: fits an rpart tree to the outcome
  # taken from `vt.difft$difft` and turns the fitted values into a 0/1
  # subgroup indicator stored in `Ahat`.
  contains = "VT.tree",

  methods = list(
    # Set up the regression tree.
    #
    # vt.difft:  object passed to the parent initializer; its `difft`
    #            field becomes this tree's `outcome`.
    # threshold: cut-off applied to the fitted values in run()
    #            (default 0.05).
    # sens:      ">" flags observations whose fitted value is >= threshold;
    #            any other value flags observations <= threshold.
    # screening: forwarded unchanged to the parent initializer.
    initialize = function(vt.difft, threshold = 0.05, sens = ">", screening = NULL){
      callSuper(vt.difft, threshold, sens, screening)

      # Tree name built by the parent helper with the "reg" tag; it is also
      # used as the response column name in run().
      .self$name <- .self$computeNameOfTree("reg")

      .self$outcome <- .self$vt.difft$difft
    },

    # Fit the regression tree and derive the subgroup indicator.
    #
    # ...: extra arguments forwarded to rpart::rpart (e.g. `control`).
    # Side effects: sets the `tree` and `Ahat` fields.
    # Returns: the fitted rpart object, invisibly.
    run = function(...){
      callSuper()

      data <- .self$getData()

      # Model "<name> ~ .": the column named after this tree regressed on
      # every other column of `data`. stats:: is spelled out (as for
      # predict below) so no undeclared import is needed.
      .self$tree <- rpart::rpart(
        stats::as.formula(paste(.self$name, "~ .")),
        data = data,
        ...
      )

      # Fitted values on the training data (no newdata supplied).
      fitted_values <- stats::predict(.self$tree)

      # 1 = observation falls in the subgroup, 0 = it does not. The bound
      # is inclusive in both directions.
      if (.self$sens == ">") {
        res <- ifelse(fitted_values >= .self$threshold, 1, 0)
      } else {
        res <- ifelse(fitted_values <= .self$threshold, 1, 0)
      }

      .self$Ahat <- res

      # NOTE(review): disabled hook kept from the original; `vt.forest` is
      # not defined in this class.
      # if(sum(res) != 0) .self$vt.forest$addAhatColumn(name, res)

      invisible(.self$tree)
    },

    # Print the tree type label, then the parent summary.
    sumup = function(){
      cat("Regression Tree")
      callSuper()
    }
  )
)
|
|
|
|
|
|
|
|
# Lock the `threshold` and `vt.difft` fields: a locked reference-class field
# may be assigned only once (here, during initialization); any later
# assignment raises an error.
VT.tree.reg$lock("threshold", "vt.difft")
|
|
|
|
|