
Possible extract bug #77

@tjmckinley


Hi,

Many thanks for creating this package; it's working really well for me, but now I've hit a snag. I was trying to work through the vignette on "Working with Saved Trees", and I cannot get the predictions from the saved trees to match those from the original bart model.

A reproducible example is below. I simulate some data, use the extract() function to pull out the saved trees, and then apply the getPredictionsForTree() function from the vignette to the first individual in the training data.

If I run this over the 3 trees in chain 1, sample 1 and sum the results, I think I should obtain the first element of predict() applied to the original bart object. However, it's nowhere near, and I can't figure out why. I have also written my own code to rebuild the trees, and it matches getPredictionsForTree(), so I wonder whether there's a standardisation step or something similar that I'm missing. The leaf values in the extracted trees look much smaller than I would expect given the scale of the response variable, so I suspect something is happening there that I haven't accounted for. I have tried various things, but to no avail so far. Any guidance on what I'm doing wrong would be greatly appreciated.
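One possible explanation, stated here as an assumption rather than something confirmed by the dbarts documentation: if the package rescales a continuous response to the interval [-0.5, 0.5] using the training minimum and maximum before fitting, the leaf values in the saved trees would live on that internal scale. In that case, a sum of tree predictions would need mapping back before it can be compared with predict(). The helper names below are hypothetical, and the sketch only shows the affine transform and its inverse.

```r
# Hypothetical helpers: min/max rescaling of a response to [-0.5, 0.5] and its
# inverse. ASSUMPTION: this mirrors how the response is standardised internally;
# it is not taken from the dbarts source.
toInternalScale <- function(y, yMin, yMax) {
  (y - yMin) / (yMax - yMin) - 0.5
}
fromInternalScale <- function(s, yMin, yMax) {
  (s + 0.5) * (yMax - yMin) + yMin
}

# Small made-up response vector to show the round trip
yToy <- c(2, 5, 11, 14)
sToy <- toInternalScale(yToy, min(yToy), max(yToy))
sToy                                           # -0.50 -0.25 0.25 0.50
fromInternalScale(sToy, min(yToy), max(yToy))  # 2 5 11 14
```

Under that assumption, the check in the example would compare something like fromInternalScale(sum(p), min(y), max(y)) against predict(bartFit, x, type = "bart")[1, 1], rather than sum(p) directly.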

Many thanks,

TJ

library(dbarts)

## simulate data
f <- function(x) {
  10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
  10 * x[,4] + 5 * x[,5]
}
set.seed(99)
sigma <- 1.0
n <- 100
x <- matrix(runif(n * 10), n, 10)
y <- rnorm(n, f(x), sigma)
data <- data.frame(x, y)

## fit BART model
bartFit <- bart(
  y ~ ., data,
  ndpost = 4, # number of posterior samples
  nskip = 1000, # number of "warmup" samples to discard
  nchain = 2, # number of independent, parallel chains
  nthread = 1, # units of parallel execution
  ntree = 3, # number of trees per chain
  seed = 2, # chosen to generate a deep tree
  keeptrees = TRUE,
  verbose = FALSE
)

## code from vignette
getPredictionsForTree <- function(tree, x) {
  predictions <- rep(NA_real_, nrow(x))
  getPredictionsForTreeRecursive <- function(tree, indices) {
    if (tree$var[1] == -1) {
      # Assigns in the calling environment by using <<-
      predictions[indices] <<- tree$value[1]
      return(1)
    }
    goesLeft <- x[indices, tree$var[1]] <= tree$value[1]
    headOfLeftBranch <- tree[-1,]
    n_nodes.left <- getPredictionsForTreeRecursive(
      headOfLeftBranch, indices[goesLeft])
    headOfRightBranch <- tree[seq.int(2 + n_nodes.left, nrow(tree)),]
    n_nodes.right <- getPredictionsForTreeRecursive(
      headOfRightBranch, indices[!goesLeft])
    return(1 + n_nodes.left + n_nodes.right)
  }
  getPredictionsForTreeRecursive(tree, seq_len(nrow(x)))
  return(predictions)
}

## extract trees
trees <- extract(bartFit, "trees")

## sum over trees in sample 1
p <- numeric(3)
for(i in 1:3) {
  treeOfInterest <- subset(trees, chain == 1 & sample == 1 & tree == i) 
  p[i] <- getPredictionsForTree(treeOfInterest, x[1, , drop = FALSE])
}
sum(p)
#> [1] -0.3402133
## this does not equal the first sample for the equivalent observation
predict(bartFit, x, type = "bart")[1, 1]
#> [1] 6.383313
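One way to rule out the traversal itself is to evaluate a small hand-built tree in the same pre-order layout that getPredictionsForTree() reads (a leaf has var == -1; otherwise value is the split point, and the left subtree immediately follows its parent). The sketch below uses an independent, non-recursive walk. The toy tree, its values, and the helper predictOneRow() are hypothetical; only the var/value conventions come from the vignette code above.

```r
# Hypothetical toy tree in the pre-order layout read by getPredictionsForTree():
# var == -1 marks a leaf (value = leaf prediction); otherwise var is the column
# used for the split and value is the cut point (x <= value goes left).
toyTree <- data.frame(
  var   = c(1,   -1,   2,   -1,  -1),
  value = c(0.5, -0.2, 0.3, 0.1, 0.4)
)

# Hypothetical helper, not part of dbarts: walk the table for one observation
# with a row pointer instead of recursing over data-frame slices.
predictOneRow <- function(tree, xrow) {
  # Index of the first row after the subtree that starts at row i
  skipSubtree <- function(i) {
    if (tree$var[i] == -1) return(i + 1)
    skipSubtree(skipSubtree(i + 1))
  }
  i <- 1
  while (tree$var[i] != -1) {
    if (xrow[tree$var[i]] <= tree$value[i]) {
      i <- i + 1               # left child comes right after its parent
    } else {
      i <- skipSubtree(i + 1)  # right child comes after the whole left subtree
    }
  }
  tree$value[i]
}

xToy <- rbind(c(0.2, 0.9), c(0.7, 0.1), c(0.7, 0.8))
toyPred <- apply(xToy, 1, function(r) predictOneRow(toyTree, r))
toyPred  # -0.2 0.1 0.4
```

If an independent walk like this agrees with getPredictionsForTree() on the extracted trees, the gap between sum(p) and predict() is more likely to come from the scale of the leaf values than from how the trees are read.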
