2015-05-31 21:09:04 +02:00
|
|
|
# VT.FOREST.FOLD ----------------------------------------------------------
|
|
|
|
|
2015-06-10 16:48:18 +02:00
|
|
|
#' Difft via k random forests
|
|
|
|
#'
|
2015-06-01 23:22:49 +02:00
|
|
|
#' A reference class to compute twins via k random forests
|
|
|
|
#'
|
2015-06-10 16:48:18 +02:00
|
|
|
#' \code{VT.forest.fold} extends \code{VT.forest}
|
2015-06-01 23:22:49 +02:00
|
|
|
#'
|
2015-06-10 16:48:18 +02:00
|
|
|
#' Twins are estimated by k-fold cross-validation. A forest is grown on
|
|
|
|
#' (k-1)/k of the data and then used to estimate twin1 and twin2 on the
|
|
|
|
#' remaining 1/k of the data.
|
2015-06-01 23:22:49 +02:00
|
|
|
#'
|
2015-06-10 16:48:18 +02:00
|
|
|
#' @include forest.R
|
|
|
|
#'
|
|
|
|
#' @field interactions logical, TRUE if the model has been computed with
|
|
|
|
#' interactions
|
|
|
|
#' @field fold numeric, number of folds, i.e. number of forests (k)
|
|
|
|
#' @field ratio numeric (experimental), used to balance \code{sampsize}. Defaults to 1.
|
|
|
|
#' @field groups vector giving the fold (group) each observation belongs to
|
2015-06-12 00:34:25 +02:00
|
|
|
#' @field ... field from parent class : \code{\link{VT.forest}}
|
2015-06-10 16:48:18 +02:00
|
|
|
#'
|
|
|
|
#' @name VT.forest.fold
|
|
|
|
#'
|
|
|
|
#' @seealso \code{\link{VT.difft}}, \code{\link{VT.forest}},
|
|
|
|
#' \code{\link{VT.forest.one}}, \code{\link{VT.forest.double}}
|
|
|
|
#'
|
2015-06-01 23:22:49 +02:00
|
|
|
#' @import methods
|
2015-07-25 02:10:28 +02:00
|
|
|
#'
|
|
|
|
#' @export VT.forest.fold
|
|
|
|
#'
|
|
|
|
|
2015-05-31 21:09:04 +02:00
|
|
|
VT.forest.fold <- setRefClass(
  Class = "VT.forest.fold",

  contains = "VT.forest",

  fields = list(
    interactions = "logical",
    fold = "numeric",
    ratio = "numeric",
    groups = "vector"
  ),

  methods = list(
    # Store the k-fold settings, then hand `vt.object` and any remaining
    # arguments to the VT.forest initializer via callSuper().
    #
    # vt.object    : object holding the data (vt.object$data) the forests are
    #                grown on.
    # fold         : number of folds k, i.e. number of forests grown by run().
    # ratio        : (experimental) cap on the majority-class sample size of
    #                each forest, as a multiple of the minority-class size;
    #                1 (the default) samples both classes equally.
    # interactions : forwarded to vt.object$getX() to build the predictors.
    initialize = function(vt.object, fold, ratio = 1, interactions = TRUE,
                          ...) {
      .self$fold <- fold
      .self$ratio <- ratio
      .self$interactions <- interactions
      callSuper(vt.object, ...)
    },

    # Randomly split the observations into `fold` folds of (near-)equal size,
    # grow one forest per fold (see runOneForest) and then compute difft.
    #
    # parallel : accepted for interface compatibility; currently unused.
    # ...      : forwarded to randomForest::randomForest().
    run = function(parallel = FALSE, ...) {
      k <- as.integer(.self$fold)
      # With fewer than two folds no observation is left to train on.
      if (length(k) != 1L || is.na(k) || k < 2L) {
        stop("fold must be a single number greater than or equal to 2.",
             call. = FALSE)
      }
      n <- nrow(.self$vt.object$data)
      # Balanced assignment: fold sizes differ by at most one, whereas
      # drawing fold labels with replacement can leave folds empty.
      # Indexing with sample.int() avoids sample()'s length-one special case.
      fold_ids <- rep_len(seq_len(k), n)
      .self$groups <- fold_ids[sample.int(n)]
      for (g in seq_len(k)) {
        # A fold can only be empty when n < k; then there is nothing to predict.
        if (any(.self$groups == g)) {
          .self$runOneForest(g, ...)
        }
      }
      .self$computeDifft()
    },

    # Grow one random forest on the observations outside fold `group` and use
    # it to fill twin1 and twin2 for the observations inside that fold.
    #
    # group : fold label (a value of .self$groups) to hold out.
    # ...   : forwarded to randomForest::randomForest().
    runOneForest = function(group, ...) {
      if (!requireNamespace("randomForest", quietly = TRUE)) {
        stop("randomForest package must be loaded.")
      }
      in_train <- .self$groups != group
      data <- .self$vt.object$getX(interactions = .self$interactions)
      # Column 1 of getX() is excluded from the predictors (here and in both
      # computeTwin methods); the response is column 1 of vt.object$data.
      # drop = FALSE keeps X two-dimensional even with a single predictor.
      X <- data[in_train, -1, drop = FALSE]
      Y <- .self$vt.object$data[in_train, 1]

      # Class counts in the training folds, in table(Y) order
      # (for a factor: 1 -> levels(Y)[1], 2 -> levels(Y)[2]).
      counts <- as.vector(table(Y))
      n1 <- counts[1]
      n2 <- counts[2]
      # Keep the whole minority class and cap the majority class at `ratio`
      # times the minority size (never more than its own count), applied
      # the same way whichever class is the minority.
      if (n2 <= n1) {
        samp2 <- n2
        samp1 <- min(n1, round(.self$ratio * n2))
      } else {
        samp1 <- n1
        samp2 <- min(n2, round(.self$ratio * n1))
      }

      # Call through the namespace: requireNamespace() loads randomForest
      # but does not attach it, so an unqualified call may not resolve.
      rf <- randomForest::randomForest(
        x = X, y = Y, sampsize = c(samp1, samp2), keep.forest = TRUE, ...
      )
      .self$computeTwin1(rf, group)
      .self$computeTwin2(rf, group)
    },

    # Predict the held-out fold `group` with the treatment as currently stored
    # in vt.object, writing the predictions into twin1 for those rows.
    # Returns twin1 invisibly.
    computeTwin1 = function(rfor, group) {
      held_out <- .self$groups == group
      data <- .self$vt.object$getX(interactions = .self$interactions)
      # Same predictor columns as the training matrix in runOneForest.
      newdata <- data[held_out, -1, drop = FALSE]
      .self$twin1[held_out] <- VT.predict(rfor = rfor, newdata = newdata,
                                          type = .self$vt.object$type)
      invisible(.self$twin1)
    },

    # Predict the held-out fold `group` with the treatment switched via
    # vt.object$switchTreatment(), writing the predictions into twin2 for
    # those rows. The treatment is switched back on exit, even if prediction
    # fails. Returns twin2 invisibly.
    computeTwin2 = function(rfor, group) {
      held_out <- .self$groups == group
      .self$vt.object$switchTreatment()
      on.exit(.self$vt.object$switchTreatment(), add = TRUE)
      data <- .self$vt.object$getX(interactions = .self$interactions)
      # Drop column 1 so newdata matches the training predictors
      # (previously every column was kept here).
      newdata <- data[held_out, -1, drop = FALSE]
      .self$twin2[held_out] <- VT.predict(rfor = rfor, newdata = newdata,
                                          type = .self$vt.object$type)
      invisible(.self$twin2)
    }
  )
)
|