Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ Suggests:
MASS,
pracma
LazyData: TRUE
RoxygenNote: 6.0.1
RoxygenNote: 7.1.2
VignetteBuilder: knitr
35 changes: 26 additions & 9 deletions R/ccf.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,21 @@ canonical_correlation_forest.formula = function(
#' @param newdata A data frame or a matrix containing the test data.
#' @param verbose Optional argument to control if additional information are
#' printed to the output. Default is \code{FALSE}.
#' @param probClass Optional argument specifying name of class to compute probabilities for
#' @param ... Additional parameters passed on to prediction from individual
#' canonical correlation trees.
#' @export
predict.canonical_correlation_forest = function(
object, newdata, verbose = FALSE, ...) {
object, newdata, verbose = FALSE, probClass = NULL, ...) {
if (missing(newdata)) {
stop("Argument 'newdata' is missing.")
}
if(!is.null(probClass)){
classNames = names(m1$forest[[1]]$trainingCounts)
if(!(probClass %in% classNames)){
stop(paste0("Argument probClass = ", probClass, " is not in list of class names. Options are: ", paste(classNames, collapse = ', ')))
}
}

ntree <- length(object$forest)
treePredictions <- matrix(NA, nrow = nrow(newdata), ncol = ntree)
Expand All @@ -140,20 +147,30 @@ predict.canonical_correlation_forest = function(
if (verbose) {
cat("calculating predictions\n")
}

# returns list of list
treePredictions = lapply(object$forest, predict, newdata)
treePredictions = lapply(object$forest, predict, newdata, probClass = probClass)
# convert to matrix
treePredictions = do.call(cbind, treePredictions)
if (verbose) {
cat("Majority vote\n")
}
treePredictions <- apply(treePredictions, 1, function(row) {
names(which.max(table(row)))
})

return(treePredictions)
if(!is.null(probClass)){
if (verbose) {
cat("Mean probability\n")
}
return(rowMeans(treePredictions))
}else{
if (verbose) {
cat("Majority vote\n")
}
treePredictions <- apply(treePredictions, 1, function(row) {
names(which.max(table(row)))
})

return(treePredictions)
}
}


#' Visualization of canonical correlation forest
#'
#' TODO: document
Expand Down
12 changes: 8 additions & 4 deletions R/cct.R
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,14 @@ canonical_correlation_tree = function(
}

#' @export
predict.canonical_correlation_tree = function(object, newData, ...){
predict.canonical_correlation_tree = function(object, newData, probClass = NULL, ...){
tree = object
if (tree$isLeaf) {
return(tree$classIndex)
if(!is.null(probClass)){
return(tree$trainingCounts[names(tree$trainingCounts) == probClass]/sum(tree$trainingCounts))
}else{
return(tree$classIndex)
}
}
nr_of_features = length(tree$decisionProjection)
# TODO use formula instead of all but last column
Expand All @@ -277,12 +281,12 @@ predict.canonical_correlation_tree = function(object, newData, ...){
if (any(lessThanPartPoint)) {
currentNodeClasses[lessThanPartPoint, ] =
predict.canonical_correlation_tree(tree$refLeftChild,
X[lessThanPartPoint, ,drop = FALSE]) #nolint
X[lessThanPartPoint, ,drop = FALSE], probClass = probClass) #nolint
}
if (any(!lessThanPartPoint)) {
currentNodeClasses[!lessThanPartPoint, ] =
predict.canonical_correlation_tree(tree$refRightChild,
X[!lessThanPartPoint, ,drop = FALSE]) #nolint
X[!lessThanPartPoint, ,drop = FALSE], probClass = probClass) #nolint
}
return(currentNodeClasses)
}
Expand Down
13 changes: 10 additions & 3 deletions man/canonical_correlation_tree.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 10 additions & 6 deletions man/ccf.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions man/predict.canonical_correlation_forest.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions man/spirals.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.