Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ Imports:
survival,
lattice,
splines,
Rcpp (>= 0.11.5)
Rcpp (>= 0.11.5),
Matrix
Suggests:
testthat (>= 0.11.0),
knitr,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export(to_old_gbm)
export(training_params)
export(trees)
import(lattice)
importFrom(Matrix,sparseMatrix)
importFrom(Rcpp,sourceCpp)
importFrom(grDevices,rainbow)
importFrom(graphics,abline)
Expand Down
65 changes: 41 additions & 24 deletions R/gbm-predict.r
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
#' Predict method for GBM Model Fits
#'
#'
#' Predicted values based on a generalized boosted model object - from
#' gbmt
#'
#'
#' \code{predict.GBMFit} produces predicted values for each
#' observation in a new dataset \code{newdata} using the first
#' \code{num_trees} iterations of the boosting sequence. If
#' \code{num_trees} is a vector then the result is a matrix with each
#' column representing the predictions from gbm models with
#' \code{num_trees[1]} iterations, \code{num_trees[2]} iterations, and
#' so on.
#'
#'
#' The predictions from \code{gbmt} do not include the offset
#' term. The user may add the value of the offset to the predicted
#' value if desired.
#'
#'
#' If \code{gbm_fit_obj} was fit using \code{\link{gbmt}}, there will
#' be no \code{Terms} component. Therefore, the user has greater
#' responsibility to make sure that \code{newdata} is of the same
#' format (order and number of variables) as the one originally used
#' to fit the model.
#'
#'
#' @param object Object of class inheriting from \code{GBMFit}.
#' @param newdata Data frame of observations for which to make
#' predictions
Expand All @@ -30,51 +30,55 @@
#' @param type The scale on which gbm makes the predictions
#' @param single.tree If \code{single.tree=TRUE} then \code{gbm_predict}
#' returns only the predictions from tree(s) \code{n.trees}
#' @param nodes If \code{nodes=TRUE} then \code{gbm_predict} returns, instead
#' of predictions, a sparse matrix indicating which internal and terminal
#' nodes each observation passes through in the first \code{max(n.trees)}
#' trees of the ensemble.
#' @param \dots further arguments passed to or from other methods
#' @return Returns a vector of predictions. By default the predictions
#' are on the scale of f(x). For example, for the Bernoulli loss the
#' returned value is on the log odds scale, poisson loss on the log
#' scale, and coxph is on the log hazard scale.
#'
#'
#' If \code{type="response"} then \code{gbmt} converts back to the same scale as
#' the outcome. Currently the only effect this will have is returning
#' probabilities for bernoulli and expected counts for poisson. For the other
#' distributions "response" and "link" return the same.
#' @seealso \code{\link{gbmt}}
#' @keywords models regression
#' @importFrom stats predict
#' @export
#' @importFrom Matrix sparseMatrix
#' @export
predict.GBMFit <- function(object, newdata, n.trees,
type="link", single.tree=FALSE,
type="link", single.tree=FALSE, nodes=FALSE,
...)
{
# Check inputs
if(!is.element(type, c("link","response" ))) {
stop("type must be either 'link' or 'response'")
}

if(missing(newdata) || !is.data.frame(newdata)) {
stop("newdata must be provided as a data frame")
}

if(missing(n.trees)) {
stop("Number of trees to be used in prediction must be provided.")
}

if (length(n.trees) == 0) {
stop("n.trees cannot be NULL or a vector of zero length")
}

if(any(n.trees != as.integer(n.trees)) || is.na(all(n.trees == as.integer(n.trees)))
|| any(n.trees < 0)) {
stop("n.trees must be a vector of non-negative integers")
}

if(!is.null(attr(object$Terms,"offset")))
{
warning("predict.GBMFit does not add the offset to the predicted values.")
}

# Get data
if(!is.null(object$Terms)) {
x <- model.frame(terms(reformulate(object$variables$var_names)),
Expand All @@ -83,7 +87,7 @@ predict.GBMFit <- function(object, newdata, n.trees,
} else {
x <- newdata
}

# Convert predictor factors into appropriate numeric
for(i in seq_len(ncol(x))) {
if(is.factor(x[,i])) {
Expand All @@ -92,27 +96,40 @@ predict.GBMFit <- function(object, newdata, n.trees,
} else {
new_compare <- levels(x[,i])
}

if (!identical(object$variables$var_levels[[i]], new_compare)) {
x[,i] <- factor(x[,i], union(object$variables$var_levels[[i]], levels(x[,i])))
}

x[,i] <- as.numeric(x[,i])-1
}
}

if(any(n.trees > object$params$num_trees)) {
n.trees[n.trees > object$params$num_trees] <- object$params$num_trees
warning("Number of trees exceeded number fit so far. Using ", paste(n.trees,collapse=" "),".")
}

i.ntree.order <- order(n.trees)

# Next if block for compatibility with objects created with 1.6
if(is.null(object$num.classes)) {
object$num.classes <- 1
}


if (nodes) {

sparse_nodes_info <- .Call("gbm_pred_sparse_nodes",
X=as.matrix(as.data.frame(x)),
n.trees=as.integer(n.trees[order(n.trees)]),
trees=trees(object),
c.split=object$c.split,
var.type=as.integer(object$variables$var_type),
PACKAGE = "gbm3")

return(do.call(Matrix::sparseMatrix, c(sparse_nodes_info, index1=FALSE)))
}

predF <- .Call("gbm_pred",
X=as.matrix(as.data.frame(x)),
n.trees=as.integer(n.trees[order(n.trees)]),
Expand All @@ -122,14 +139,14 @@ predict.GBMFit <- function(object, newdata, n.trees,
var.type=as.integer(object$variables$var_type),
single.tree = as.integer(single.tree),
PACKAGE = "gbm3")

# Convert into matrix of predictions
if((length(n.trees) > 1) || (!is.null(object$num.classes) && (object$num.classes > 1))) {
predF <- matrix(predF, ncol=length(n.trees), byrow=FALSE)
colnames(predF) <- n.trees
predF[, order(n.trees)] <- predF
}

# Adjust scale of predictions
if(type=="response") {
predF <- adjust_pred_scale(predF, object$distribution)
Expand All @@ -139,4 +156,4 @@ predict.GBMFit <- function(object, newdata, n.trees,
}
return(predF)
}

2 changes: 1 addition & 1 deletion man/gbm_dist.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion man/predict.GBMFit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/to_old_gbm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

127 changes: 126 additions & 1 deletion src/gbmentry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ SEXP gbm(SEXP response, SEXP intResponse, SEXP offset_vec, SEXP covariates,

// Set up parameters for initialization
DataDistParams datadistparams(
response, intResponse, offset_vec, covariates, covar_order,
response, intResponse, offset_vec, covariates, covar_order,
obs_weight, misc, prior_coeff_var, row_to_obs_id, var_classes,
monotonicity_vec, dist_family, fraction_inbag, num_rows_in_training,
num_obs_in_training, number_offeatures, parallel);
Expand Down Expand Up @@ -456,4 +456,129 @@ SEXP gbm_plot(
END_RCPP
} // gbm_plot


//-----------------------------------
// Function: gbm_pred_sparse_nodes
//
// Returns: An Rcpp::List holding the components of a sparse node-membership
//          matrix in zero-based triplet form:
//            i    - row (observation) index of every visited node,
//            j    - column (node) index of every visited node,
//            dims - (number of observations, total number of node columns).
//
// Description: Drops every observation down the previously fitted trees and
//              records each node (internal and terminal) it passes through.
//              Every tree contributes one column per node; the columns of
//              successive trees are laid out side by side. Each tree is
//              walked at most once, so the result covers the trees up to
//              the largest count requested in num_trees.
//
// Parameters:
//  covariates - SEXP containing the predictor values - becomes
//               const Rcpp::NumericMatrix.
//  num_trees - SEXP containing an int or vector of ints specifying the
//              number of trees to traverse - stored as const
//              Rcpp::IntegerVector.
//  fitted_trees - SEXP containing lists defining the previously fitted
//                 trees - stored as const Rcpp::GenericVector. Elements
//                 0..4 of each tree are read as split variable, split
//                 code, left node, right node and missing node.
//  categorical_splits - SEXP containing list of the categories of the split
//                       variables defining a tree - stored as a const
//                       Rcpp::GenericVector.
//  variable_type - SEXP containing integers specifying whether the variable
//                  is continuous (0) or nominal - stored as const
//                  Rcpp::IntegerVector.
//
// Throws: gbm_exception::InvalidArgument if the number of covariate columns
//         differs from the length of variable_type, or if more trees are
//         requested than have been fitted.
//-----------------------------------

SEXP gbm_pred_sparse_nodes(SEXP covariates, SEXP num_trees, SEXP fitted_trees,
                           SEXP categorical_splits, SEXP variable_type) {
  BEGIN_RCPP
  const Rcpp::NumericMatrix kCovarMat(covariates);
  const int kNumCovarRows = kCovarMat.nrow();
  const Rcpp::IntegerVector kTrees(num_trees);
  const Rcpp::GenericVector kFittedTrees(fitted_trees);
  const Rcpp::IntegerVector kVarType(variable_type);
  const Rcpp::GenericVector kSplits(categorical_splits);

  if ((kCovarMat.ncol() != kVarType.size())) {
    throw gbm_exception::InvalidArgument("shape mismatch");
  }

  // Triplet (row, column) coordinates of every visited node.
  std::vector<int> sparse_rows;
  std::vector<int> sparse_cols;

  // Column offset of the tree currently being walked; grows by the node
  // count of each finished tree so that node ids stay unique across trees.
  int num_node_cols = 0;

  // tree_num is deliberately NOT reset between prediction iterations: trees
  // already walked for a smaller requested count are not walked again.
  int tree_num = 0;
  for (int prediction_iteration = 0; prediction_iteration < kTrees.size();
       prediction_iteration++) {
    const int kCurrTree = kTrees[prediction_iteration];

    // Guard against indexing past the end of the fitted tree list.
    if (kCurrTree > kFittedTrees.size()) {
      throw gbm_exception::InvalidArgument(
          "number of trees requested exceeds number of fitted trees");
    }

    while (tree_num < kCurrTree) {
      const Rcpp::GenericVector kThisFitTree = kFittedTrees[tree_num];
      const Rcpp::IntegerVector kThisSplitVar = kThisFitTree[0];
      const Rcpp::NumericVector kThisSplitCode = kThisFitTree[1];
      const Rcpp::IntegerVector kThisLeftNode = kThisFitTree[2];
      const Rcpp::IntegerVector kThisRightNode = kThisFitTree[3];
      const Rcpp::IntegerVector kThisMissingNode = kThisFitTree[4];

      for (int obs_num = 0; obs_num < kNumCovarRows; obs_num++) {
        // Every observation starts at the root, which is always recorded.
        int iCurrentNode = 0;
        sparse_rows.push_back(obs_num);
        sparse_cols.push_back(iCurrentNode + num_node_cols);

        // A split variable of -1 marks a terminal node.
        while (kThisSplitVar[iCurrentNode] != -1) {
          const int kSplitVar = kThisSplitVar[iCurrentNode];
          // Column-major offset, widened before multiplying so that large
          // covariate matrices cannot overflow int arithmetic.
          const double dX =
              kCovarMat[static_cast<R_xlen_t>(kSplitVar) * kNumCovarRows +
                        obs_num];
          // missing?
          if (ISNA(dX)) {
            iCurrentNode = kThisMissingNode[iCurrentNode];
          }
          // continuous?
          else if (kVarType[kSplitVar] == 0) {
            if (dX < kThisSplitCode[iCurrentNode]) {
              iCurrentNode = kThisLeftNode[iCurrentNode];
            } else {
              iCurrentNode = kThisRightNode[iCurrentNode];
            }
          } else  // categorical
          {
            // For categorical splits the split code indexes into the list
            // of per-level direction indicators.
            const Rcpp::IntegerVector kMySplits =
                kSplits[kThisSplitCode[iCurrentNode]];
            if (kMySplits.size() < (int)dX + 1) {
              // level index beyond those known to this split
              iCurrentNode = kThisMissingNode[iCurrentNode];
            } else {
              const int iCatSplitIndicator = kMySplits[(int)dX];
              if (iCatSplitIndicator == -1) {
                iCurrentNode = kThisLeftNode[iCurrentNode];
              } else if (iCatSplitIndicator == 1) {
                iCurrentNode = kThisRightNode[iCurrentNode];
              } else  // categorical level not present in training
              {
                iCurrentNode = kThisMissingNode[iCurrentNode];
              }
            }
          }
          // Record the node just moved to (internal or terminal).
          sparse_rows.push_back(obs_num);
          sparse_cols.push_back(iCurrentNode + num_node_cols);
        }
      }  // iObs
      num_node_cols += kThisSplitVar.size();
      tree_num++;
    }  // iTree
  }    // iPredIteration

  Rcpp::IntegerVector dims(2, 0);
  dims[0] = kNumCovarRows;
  dims[1] = num_node_cols;

  return (Rcpp::List::create(Rcpp::Named("i") = Rcpp::wrap(sparse_rows),
                             Rcpp::Named("j") = Rcpp::wrap(sparse_cols),
                             Rcpp::Named("dims") = Rcpp::wrap(dims)));
  END_RCPP
}
} // end extern "C"