Predict function not working #13

@sreedta8

Here is my training model:

fit40 <- stan4bart(
  formula = sales ~
    hdummys + tv_ads + dig_ads + prt_ads + # linear component ("fixef")
    (1 | dmaseqid) + # multilevel ("ranef"); dmaseqid is a factor variable
    bart(. - region - coupons - hdummys - tv_ads - dig_ads - prt_ads), # use bart for other variables
  verbose = -1, # suppress ALL output
  data = train, # 8400 rows
  # low numbers for illustration; using only 1 chain
  chains = 1, iter = 100,
  bart_args = list(n.trees = 5, keepTrees = TRUE)
)

This runs without a problem. Then I call the predict function as follows:

predict(fit40, newdata = test, # test data has 2520 rows
  type = c("ev", "ppd", "indiv.fixef", "indiv.ranef", "indiv.bart"),
  combine_chains = FALSE, # has only 1 chain, no need to combine
  sample_new_levels = TRUE)

I get the following warning and error:

Warning message in validateXTest(newdata, attr(data@x, "term.labels"), ncol(data@x), :
“column names of 'test' does not equal that of 'x': 'dmaseqid.1, dmaseqid.2, dmaseqid.3, dmaseqid.4, dmaseqid.5, dmaseqid.6, dmaseqid.7, dmaseqid.8, dmaseqid.9, dmaseqid.10, dmaseqid.11, dmaseqid.12, dmaseqid.13, dmaseqid.14, dmaseqid.15, dmaseqid.16, dmaseqid.17, dmaseqid.18, dmaseqid.19, dmaseqid.20, dmaseqid.21, dmaseqid.22, dmaseqid.23, dmaseqid.24, dmaseqid.25, dmaseqid.26, dmaseqid.27, dmaseqid.28, dmaseqid.29, dmaseqid.30, dmaseqid.31, dmaseqid.32, dmaseqid.33, dmaseqid.34, dmaseqid.35, dmaseqid.36, dmaseqid.37, dmaseqid.38, dmaseqid.39, dmaseqid.40, dmaseqid.41, dmaseqid.42, dmaseqid.43, dmaseqid.44, dmaseqid.45, dmaseqid.46, dmaseqid.47, dmaseqid.48, dmaseqid.49, dmaseqid.50, dmaseqid.51, dmaseqid.52, dmaseqid.53, dmaseqid.54, dmaseqid.55, dmaseqid.56, dmaseqid.57, dmaseqid.58, dmaseqid.59, dmaseqid.60, dmaseqid.61, dmaseqid.62, dmaseqid.63, dmaseqid.64, dmaseqid.65, dmaseqid.66, dmaseqid.67, dmaseqid.68, dmaseqid.69, dmaseqid.70, dmaseqid.71, dmaseqid.72, dmaseqid.73, dmaseqid.74, dmaseqid.75, dmaseqid.76, dmaseqid.77, dmaseqid.78, dmaseqid.79, dmaseqid.80, dmaseqid.81, dmaseqid.82, dmaseqid.83, dmaseqid.84, dmaseqid.85, dmaseqid.86, dmaseqid.87, dmaseqid.88, dmaseqid.89, dmaseqid.90, dmaseqid.91, dmaseqid.92, dmaseqid.93, dmaseqid.94, dmaseqid.95, dmaseqid.96, dmaseqid.97, dmaseqid.98, dmaseqid.99, dmaseqid.100, dmaseqid.101, dmaseqid.102, dmaseqid.103, dmaseqid.104, dmaseqid.105, dmaseqid.106, dmaseqid.107, dmaseqid.108, dmaseqid.109, dmaseqid.110, dmaseqid.111, dmaseqid.112, dmaseqid.113, dmaseqid.114, dmaseqid.115, dmaseqid.116, dmaseqid.117, dmaseqid.118, dmaseqid.119, dmaseqid.120, dmaseqid.121, dmaseqid.122, dmaseqid.123, dmaseqid.124, dmaseqid.125, dmaseqid.126, dmaseqid.127, dmaseqid.128, dmaseqid.129, dmaseqid.130, dmaseqid.131, dmaseqid.132, dmaseqid.133, dmaseqid.134, dmaseqid.135, dmaseqid.136, dmaseqid.137, dmaseqid.138, dmaseqid.139, dmaseqid.140, dmaseqid.141, dmaseqid.142, dmaseqid.143, dmaseqid.144, dmaseqid.145, dmaseqid.146, 
dmaseqid.147, dmaseqid.148, dmaseqid.149, dmaseqid.150, dmaseqid.151, dmaseqid.152, dmaseqid.153, dmaseqid.154, dmaseqid.155, dmaseqid.156, dmaseqid.157, dmaseqid.158, dmaseqid.159, dmaseqid.160, dmaseqid.161, dmaseqid.162, dmaseqid.163, dmaseqid.164, dmaseqid.165, dmaseqid.166, dmaseqid.167, dmaseqid.168, dmaseqid.169, dmaseqid.170, dmaseqid.171, dmaseqid.172, dmaseqid.173, dmaseqid.174, dmaseqid.175, dmaseqid.176, dmaseqid.177, dmaseqid.178, dmaseqid.179, dmaseqid.180, dmaseqid.181, dmaseqid.182, dmaseqid.183, dmaseqid.184, dmaseqid.185, dmaseqid.186, dmaseqid.187, dmaseqid.188, dmaseqid.189, dmaseqid.190, dmaseqid.191, dmaseqid.192, dmaseqid.193, dmaseqid.194, dmaseqid.195, dmaseqid.196, dmaseqid.197, dmaseqid.198, dmaseqid.199, dmaseqid.200, dmaseqid.201, dmaseqid.202, dmaseqid.203, dmaseqid.204, dmaseqid.205, dmaseqid.206, dmaseqid.207, dmaseqid.208, dmaseqid.209, dmaseqid.210, hdummys, tv_ads, dig_ads, prt_ads, region, coupons'; match will be made by position”

Error in dimnames(indiv.bart) <- list(observation = NULL, sample = NULL, : length of 'dimnames' [3] must match that of 'dims' [2]
Traceback:

1. predict(fit40, newdata = test, type = c("ev", "ppd", "indiv.fixef", 
 .     "indiv.ranef", "indiv.bart"), combine_chains = FALSE, sample_new_levels = TRUE)
2. predict.stan4bartFit(fit40, newdata = test, type = c("ev", "ppd", 
 .     "indiv.fixef", "indiv.ranef", "indiv.bart"), combine_chains = FALSE, 
 .     sample_new_levels = TRUE)

Does this have a solution? My train and test data frames have exactly the same columns; only the number of rows differs. I read here that using a single chain avoids the error about the number of dimensions associated with the bart component.
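The claim that both data frames share the same columns can be checked with a few lines of base R. The sketch below is only an illustration: `train_demo` and `test_demo` are made-up stand-ins for the real `train` and `test` data, and the check covers column names, column classes, and factor levels, nothing specific to stan4bart.

```r
# Hypothetical stand-ins for the real train/test data frames.
train_demo <- data.frame(
  sales    = c(10, 12, 9, 14),
  tv_ads   = c(1, 0, 2, 1),
  dmaseqid = factor(c("1", "2", "1", "3"), levels = c("1", "2", "3"))
)
test_demo <- data.frame(
  sales    = c(11, 13),
  tv_ads   = c(0, 1),
  dmaseqid = factor(c("2", "3"), levels = c("1", "2", "3"))
)

# Same column names in the same order?
same_names <- identical(names(train_demo), names(test_demo))

# Same class for every column?
same_classes <- identical(
  lapply(train_demo, class),
  lapply(test_demo, class)
)

# Same factor levels for every factor column?
factor_cols <- names(train_demo)[vapply(train_demo, is.factor, logical(1))]
same_levels <- all(vapply(
  factor_cols,
  function(col) identical(levels(train_demo[[col]]), levels(test_demo[[col]])),
  logical(1)
))

print(c(names = same_names, classes = same_classes, levels = same_levels))
```

If all three checks pass on the real data, the name mismatch in the warning would not come from the data frames themselves, since the `dmaseqid.1`, `dmaseqid.2`, ... names in the message are not columns of `train` or `test`.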
