diff --git a/DESCRIPTION b/DESCRIPTION index 7304e27..b3f7f8d 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -26,5 +26,5 @@ Suggests: MASS, pracma LazyData: TRUE -RoxygenNote: 6.0.1 +RoxygenNote: 7.1.2 VignetteBuilder: knitr diff --git a/R/ccf.R b/R/ccf.R index 69b8ca0..bb286a4 100644 --- a/R/ccf.R +++ b/R/ccf.R @@ -124,14 +124,21 @@ canonical_correlation_forest.formula = function( #' @param newdata A data frame or a matrix containing the test data. #' @param verbose Optional argument to control if additional information are #' printed to the output. Default is \code{FALSE}. +#' @param probClass Optional argument specifying name of class to compute probabilities for #' @param ... Additional parameters passed on to prediction from individual #' canonical correlation trees. #' @export predict.canonical_correlation_forest = function( - object, newdata, verbose = FALSE, ...) { + object, newdata, verbose = FALSE, probClass = NULL, ...) { if (missing(newdata)) { stop("Argument 'newdata' is missing.") } + if(!is.null(probClass)){ + classNames = names(object$forest[[1]]$trainingCounts) + if(!(probClass %in% classNames)){ + stop(paste0("Argument probClass = ", probClass, " is not in list of class names. 
Options are: ", paste(classNames, collapse = ', '))) + } + } ntree <- length(object$forest) treePredictions <- matrix(NA, nrow = nrow(newdata), ncol = ntree) @@ -140,20 +147,30 @@ predict.canonical_correlation_forest = function( if (verbose) { cat("calculating predictions\n") } + # returns list of list - treePredictions = lapply(object$forest, predict, newdata) + treePredictions = lapply(object$forest, predict, newdata, probClass = probClass) # convert to matrix treePredictions = do.call(cbind, treePredictions) - if (verbose) { - cat("Majority vote\n") - } - treePredictions <- apply(treePredictions, 1, function(row) { - names(which.max(table(row))) - }) - return(treePredictions) + if(!is.null(probClass)){ + if (verbose) { + cat("Mean probability\n") + } + return(rowMeans(treePredictions)) + }else{ + if (verbose) { + cat("Majority vote\n") + } + treePredictions <- apply(treePredictions, 1, function(row) { + names(which.max(table(row))) + }) + + return(treePredictions) + } } + #' Visualization of canonical correlation forest #' #' TODO: document diff --git a/R/cct.R b/R/cct.R index 91f3f23..ac6e005 100644 --- a/R/cct.R +++ b/R/cct.R @@ -257,10 +257,14 @@ canonical_correlation_tree = function( } #' @export -predict.canonical_correlation_tree = function(object, newData, ...){ +predict.canonical_correlation_tree = function(object, newData, probClass = NULL, ...){ tree = object if (tree$isLeaf) { - return(tree$classIndex) + if(!is.null(probClass)){ + return(tree$trainingCounts[names(tree$trainingCounts) == probClass]/sum(tree$trainingCounts)) + }else{ + return(tree$classIndex) + } } nr_of_features = length(tree$decisionProjection) # TODO use formula instead of all but last column @@ -277,12 +281,12 @@ predict.canonical_correlation_tree = function(object, newData, ...){ if (any(lessThanPartPoint)) { currentNodeClasses[lessThanPartPoint, ] = predict.canonical_correlation_tree(tree$refLeftChild, - X[lessThanPartPoint, ,drop = FALSE]) #nolint + X[lessThanPartPoint, ,drop = 
FALSE], probClass = probClass) #nolint } if (any(!lessThanPartPoint)) { currentNodeClasses[!lessThanPartPoint, ] = predict.canonical_correlation_tree(tree$refRightChild, - X[!lessThanPartPoint, ,drop = FALSE]) #nolint + X[!lessThanPartPoint, ,drop = FALSE], probClass = probClass) #nolint } return(currentNodeClasses) } diff --git a/man/canonical_correlation_tree.Rd b/man/canonical_correlation_tree.Rd index d30eb67..4a93da8 100644 --- a/man/canonical_correlation_tree.Rd +++ b/man/canonical_correlation_tree.Rd @@ -4,9 +4,16 @@ \alias{canonical_correlation_tree} \title{Computes a canonical correlation tree} \usage{ -canonical_correlation_tree(X, Y, depth = 0, minPointsForSplit = 2, - maxDepthSplit = Inf, xVariationTolerance = 1e-10, - projectionBootstrap = FALSE, ancestralProbs = NULL) +canonical_correlation_tree( + X, + Y, + depth = 0, + minPointsForSplit = 2, + maxDepthSplit = Inf, + xVariationTolerance = 1e-10, + projectionBootstrap = FALSE, + ancestralProbs = NULL +) } \arguments{ \item{X}{Predictor matrix of size \eqn{n \times p} with \eqn{n} observations and \eqn{p} diff --git a/man/ccf.Rd b/man/ccf.Rd index fb0ec10..23cbdab 100644 --- a/man/ccf.Rd +++ b/man/ccf.Rd @@ -6,14 +6,18 @@ \alias{canonical_correlation_forest.formula} \title{Canonical correlation forest} \usage{ -canonical_correlation_forest(x, y = NULL, ntree = 200, verbose = FALSE, - ...) +canonical_correlation_forest(x, y = NULL, ntree = 200, verbose = FALSE, ...) -\method{canonical_correlation_forest}{default}(x, y = NULL, ntree = 200, - verbose = FALSE, projectionBootstrap = FALSE, ...) +\method{canonical_correlation_forest}{default}( + x, + y = NULL, + ntree = 200, + verbose = FALSE, + projectionBootstrap = FALSE, + ... +) -\method{canonical_correlation_forest}{formula}(x, y = NULL, ntree = 200, - verbose = FALSE, ...) +\method{canonical_correlation_forest}{formula}(x, y = NULL, ntree = 200, verbose = FALSE, ...) 
} \arguments{ \item{x}{Numeric matrix (n * p) with n observations of p variables} diff --git a/man/predict.canonical_correlation_forest.Rd b/man/predict.canonical_correlation_forest.Rd index 1ac83b5..3638212 100644 --- a/man/predict.canonical_correlation_forest.Rd +++ b/man/predict.canonical_correlation_forest.Rd @@ -4,8 +4,7 @@ \alias{predict.canonical_correlation_forest} \title{Prediction from canonical correlation forest} \usage{ -\method{predict}{canonical_correlation_forest}(object, newdata, - verbose = FALSE, ...) +\method{predict}{canonical_correlation_forest}(object, newdata, verbose = FALSE, probClass = NULL, ...) } \arguments{ \item{object}{An object of class \code{canonical_correlation_forest}, as created @@ -16,6 +15,8 @@ by the function \code{\link{canonical_correlation_forest}}.} \item{verbose}{Optional argument to control if additional information are printed to the output. Default is \code{FALSE}.} +\item{probClass}{Optional argument specifying name of class to compute probabilities for} + \item{...}{Additional parameters passed on to prediction from individual canonical correlation trees.} } diff --git a/man/spirals.Rd b/man/spirals.Rd index dd5cb64..95cbc8a 100644 --- a/man/spirals.Rd +++ b/man/spirals.Rd @@ -4,12 +4,14 @@ \name{spirals} \alias{spirals} \title{Spiral dataset} -\format{A data frame with 10000 rows and 3 variables: +\format{ +A data frame with 10000 rows and 3 variables: \describe{ \item{x}{numeric scalar: x-coordinate} \item{y}{numeric scalar: y-coordinate} \item{class}{integer: either 1,2 or 3} -}} +} +} \source{ Created by T. Rainforth, URL: \url{https://bitbucket.org/twgr/ccf/raw/49d5fce6fc006bc9a8949c7149fc9524535ce418/Datasets/spirals.csv}