Hi,
Many thanks for creating this package. It's working really well for me, but I've hit a snag. While working through the vignette on "Working with Saved Trees", I could not get the predictions from the saved trees to match those from the original BART model.
A reproducible example is below. I simulate some data, use the extract() function to extract the trees, and then apply the getPredictionsForTree() function from the vignette to the first observation in the training data.
If I run this over the 3 trees in chain 1, sample 1, and sum the results, I would expect to obtain the first element of predict() applied to the original bart object. However, the two values are nowhere near each other (-0.34 versus 6.38 in the output below), and I can't figure out why. I have also written my own code to rebuild the trees, and it matches getPredictionsForTree(), so I wonder whether there is a standardisation step or something similar that I'm missing. The leaf values in the extracted trees look much smaller than I would expect given the scale of the response variable, so perhaps something is going on there? I have tried various things, but to no avail so far. Any guidance on what I'm doing wrong would be greatly appreciated.
Many thanks,
TJ
library(dbarts)
## simulate data
f <- function(x) {
  10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
    10 * x[,4] + 5 * x[,5]
}
set.seed(99)
sigma <- 1.0
n <- 100
x <- matrix(runif(n * 10), n, 10)
y <- rnorm(n, f(x), sigma)
data <- data.frame(x, y)
## fit BART model
bartFit <- bart(
  y ~ ., data,
  ndpost = 4,       # number of posterior samples
  nskip = 1000,     # number of "warmup" samples to discard
  nchain = 2,       # number of independent, parallel chains
  nthread = 1,      # units of parallel execution
  ntree = 3,        # number of trees per chain
  seed = 2,         # chosen to generate a deep tree
  keeptrees = TRUE,
  verbose = FALSE
)
## code from vignette
getPredictionsForTree <- function(tree, x) {
  predictions <- rep(NA_real_, nrow(x))
  getPredictionsForTreeRecursive <- function(tree, indices) {
    if (tree$var[1] == -1) {
      # Assigns in the calling environment by using <<-
      predictions[indices] <<- tree$value[1]
      return(1)
    }
    goesLeft <- x[indices, tree$var[1]] <= tree$value[1]
    headOfLeftBranch <- tree[-1,]
    n_nodes.left <- getPredictionsForTreeRecursive(
      headOfLeftBranch, indices[goesLeft])
    headOfRightBranch <- tree[seq.int(2 + n_nodes.left, nrow(tree)),]
    n_nodes.right <- getPredictionsForTreeRecursive(
      headOfRightBranch, indices[!goesLeft])
    return(1 + n_nodes.left + n_nodes.right)
  }
  getPredictionsForTreeRecursive(tree, seq_len(nrow(x)))
  return(predictions)
}
## extract trees
trees <- extract(bartFit, "trees")
## sum over the 3 trees in chain 1, sample 1, for the first observation
p <- numeric(3)
for (i in 1:3) {
  treeOfInterest <- subset(trees, chain == 1 & sample == 1 & tree == i)
  p[i] <- getPredictionsForTree(treeOfInterest, x[1, , drop = FALSE])
}
sum(p)
[1] -0.3402133
## this does not equal the first sample for the equivalent observation
predict(bartFit, x, type = "bart")[1, 1]
[1] 6.383313