aVirtualTwins/R/tree.class.R

55 lines
1.4 KiB
R

# VT.TREE.CLASS -----------------------------------------------------------
#' Classification tree to find subgroups
#'
#' Reference class extending \code{VT.tree}. On construction it turns the
#' values in \code{vt.difft$difft} into a binary outcome: 1 when a value is
#' \code{>= threshold} (\code{sens = ">"}) or \code{<= threshold} (any other
#' \code{sens}), 0 otherwise. \code{run(...)} then fits an \code{rpart}
#' classification tree on that outcome and stores the 0/1 predictions in
#' \code{Ahat}.
#'
#' See \code{\link{VT.tree}}
#'
#' @include tree.R
#'
#' @name VT.tree.class
#'
#' @export VT.tree.class
#'
#' @import methods
#'
VT.tree.class <- setRefClass(
  Class = "VT.tree.class",
  contains = "VT.tree",
  methods = list(
    initialize = function(vt.difft, threshold = 0.05, sens = ">", screening = NULL) {
      "Build the binary outcome (1/0) from vt.difft$difft, threshold and sens."
      callSuper(vt.difft, threshold, sens, screening)
      .self$name <- .self$computeNameOfTree("class")
      # Any sens other than ">" is handled as the lower-tail rule (<=).
      if (.self$sens == ">") {
        .self$outcome <- ifelse(.self$vt.difft$difft >= .self$threshold, 1, 0)
      } else {
        .self$outcome <- ifelse(.self$vt.difft$difft <= .self$threshold, 1, 0)
      }
    },
    run = function(...) {
      "VT.tree.class:run(...) Compute classification tree with rpart parameters"
      callSuper()
      data <- .self$getData()
      # NOTE(review): assumes column 1 of getData() is the binary outcome, i.e.
      # the response named .self$name used in the formula below - confirm
      # against VT.tree$getData().
      if (sum(data[, 1]) != 0) {
        # `...` is forwarded to rpart::rpart (e.g. control parameters).
        .self$tree <- rpart::rpart(
          stats::as.formula(paste0(.self$name, " ~ .")),
          data = data,
          method = "class",
          ...
        )
        # predict(type = "class") returns a factor whose levels are only the
        # classes present when fitting. Converting via the labels maps "0"/"1"
        # back to 0/1 even when a single class was seen; relying on the
        # factor codes (as.numeric(f) - 1) turned an all-"1" prediction into 0.
        pred <- stats::predict(.self$tree, data, type = "class")
        .self$Ahat <- as.numeric(as.character(pred))
      } else {
        # No positive outcome at all: no tree is fitted, Ahat is the outcome.
        .self$Ahat <- .self$outcome
      }
      # Returns the tree field invisibly (left unset when no tree was fitted).
      invisible(.self$tree)
    },
    sumup = function() {
      "Print the label 'Classification Tree', then delegate to VT.tree's sumup()."
      cat("Classification Tree")
      callSuper()
    }
  )
)
# Locked fields may be assigned only once (see ?setRefClass, $lock()).
VT.tree.class$lock("threshold", "vt.difft")