1
0
Fork 0
mirror of https://github.com/prise6/aVirtualTwins.git synced 2024-05-11 21:06:32 +02:00
aVirtualTwins/R/forest.fold.R
2015-05-31 21:09:04 +02:00

80 lines
2.2 KiB
R

# VT.FOREST.FOLD ----------------------------------------------------------
# Cross-validated variant of VT.forest.
#
# The rows of vt.object$data are partitioned into `fold` groups. For each
# group a random forest is grown on the remaining groups and used to predict
# the held-out rows twice:
#   - twin1: with the observed treatment;
#   - twin2: with the treatment switched.
# Every individual is therefore predicted out-of-fold.
#
# Fields:
#   interactions - logical, forwarded to vt.object$getX() to select the design
#                  matrix (presumably whether treatment-interaction columns
#                  are added -- confirm against VT.object).
#   fold         - number of folds.
#   ratio        - upper bound on (majority class draw) / (minority class
#                  count) for randomForest's stratified `sampsize`.
#   groups       - fold index of each row of vt.object$data; set by run().
VT.forest.fold <- setRefClass(
  Class = "VT.forest.fold",
  contains = "VT.forest",
  fields = list(
    interactions = "logical",
    fold = "numeric",
    ratio = "numeric",
    groups = "vector"
  ),
  methods = list(
    initialize = function(vt.object, fold, ratio, interactions = TRUE, ...){
      .self$fold <- fold
      .self$ratio <- ratio
      .self$interactions <- interactions
      callSuper(vt.object, ...)
    },

    # Assign rows to folds, fit one forest per fold, then compute the twin
    # differences via computeDifft() (whose result is returned).
    # `parallel` is accepted for interface compatibility but is not used.
    # `...` is forwarded to randomForest().
    run = function(parallel = FALSE, ...){
      n <- nrow(.self$vt.object$data)
      # Balanced random partition: fold sizes differ by at most one, and no
      # fold is empty as long as n >= fold (sampling with replacement gave
      # unbalanced and possibly empty folds).
      .self$groups <- sample(rep_len(seq_len(.self$fold), n))
      for(g in seq_len(.self$fold)){
        # Only possible when fold > n: nothing to predict for this fold.
        if(!any(.self$groups == g)) next
        .self$runOneForest(g, ...)
      }
      .self$computeDifft()
    },

    # Grow a forest on every fold except `group`, then fill twin1/twin2 for
    # the rows of `group`. `...` is forwarded to randomForest().
    runOneForest = function(group, ...){
      train <- .self$groups != group
      X <- .self$vt.object$getX(interactions = .self$interactions)
      X <- X[train, , drop = FALSE]
      Y <- .self$vt.object$data[train, 1]
      # Class counts: [1] -> levels(Y)[1], [2] -> levels(Y)[2].
      Yeff <- as.vector(table(Y))
      # Stratified draw: keep every minority-class row and draw at most
      # `ratio` times as many rows from the majority class (ties treat
      # class 2 as the minority).
      if(Yeff[2] <= Yeff[1]){
        samp2 <- Yeff[2]
        samp1 <- min(Yeff[1], round(.self$ratio * Yeff[2], digits = 0))
      }else{
        samp1 <- Yeff[1]
        samp2 <- min(Yeff[2], round(.self$ratio * Yeff[1], digits = 0))
      }
      rf <- randomForest(x = X, y = Y, sampsize = c(samp1, samp2),
                         keep.forest = TRUE, ...)
      .self$computeTwin1(rf, group)
      .self$computeTwin2(rf, group)
    },

    # Predict the held-out rows of `group` with the observed treatment and
    # store the result in twin1. Returns twin1 invisibly.
    computeTwin1 = function(rfor, group){
      test <- .self$groups == group
      # Predict from the same getX() design matrix the forest was trained on
      # (see runOneForest), not raw data[, -1], so the columns match even
      # when getX() adds derived columns.
      data <- .self$vt.object$getX(interactions = .self$interactions)
      data <- data[test, , drop = FALSE]
      .self$twin1[test] <- VT.predict(rfor = rfor, newdata = data,
                                      type = .self$vt.object$type)
      invisible(.self$twin1)
    },

    # Predict the held-out rows of `group` with the treatment switched and
    # store the result in twin2. Returns twin2 invisibly.
    computeTwin2 = function(rfor, group){
      .self$vt.object$switchTreatment()
      # Always switch back, even if prediction fails, so the shared
      # vt.object is never left with the treatment switched.
      on.exit(.self$vt.object$switchTreatment(), add = TRUE)
      test <- .self$groups == group
      data <- .self$vt.object$getX(interactions = .self$interactions)
      data <- data[test, , drop = FALSE]
      .self$twin2[test] <- VT.predict(rfor = rfor, newdata = data,
                                      type = .self$vt.object$type)
      invisible(.self$twin2)
    }
  )
)