Skip to content

Commit 5c197ab

Browse files
Make work with newer versions of xgboost (#1307)
* make xgboost tests version robust * make xgb functions backwards compatible with 2.0.0.0 version * fix type * use correct version for xgboost switching * add news * more robust xgboost param moving * don't access new xgboost function directly * check for installation * polish news * note in docs --------- Co-authored-by: ‘topepo’ <mxkuhn@gmail.com>
1 parent bea062d commit 5c197ab

File tree

5 files changed

+259
-101
lines changed

5 files changed

+259
-101
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
* `surv_reg()` is now defunct and will error if called. Please use `survival_reg()` instead (#1206).
1313

14+
* Enable parsnip to work with xgboost versions >= 2.0.0.0 (#1227).
1415

1516
# parsnip 1.3.3
1617

R/boost_tree.R

Lines changed: 110 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ xgb_train <- function(
271271
event_level = c("first", "second"),
272272
...
273273
) {
274+
rlang::check_installed("xgboost")
274275
event_level <- rlang::arg_match(event_level, c("first", "second"))
275276
others <- list(...)
276277

@@ -340,31 +341,70 @@ xgb_train <- function(
340341

341342
others <- process_others(others, arg_list)
342343

344+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
345+
if (!is.null(num_class) && num_class > 2) {
346+
arg_list$num_class <- num_class
347+
}
348+
349+
param_names <- names(
350+
formals(
351+
getFromNamespace("xgb.params", ns = "xgboost")
352+
)
353+
)
354+
355+
if (any(param_names %in% names(others))) {
356+
elements <- param_names[param_names %in% names(others)]
357+
358+
for (element in elements) {
359+
arg_list[[element]] <- others[[element]]
360+
others[[element]] <- NULL
361+
}
362+
}
363+
364+
if (is.null(arg_list$objective)) {
365+
if (is.numeric(y)) {
366+
arg_list$objective <- "reg:squarederror"
367+
} else {
368+
if (num_class == 2) {
369+
arg_list$objective <- "binary:logistic"
370+
} else {
371+
arg_list$objective <- "multi:softprob"
372+
}
373+
}
374+
}
375+
}
376+
343377
main_args <- c(
344378
list(
345379
data = quote(x$data),
346-
watchlist = quote(x$watchlist),
347380
params = arg_list,
348381
nrounds = nrounds,
349382
early_stopping_rounds = early_stop
350383
),
351384
others
352385
)
386+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
387+
main_args$evals <- quote(x$watchlist)
388+
} else {
389+
main_args$watchlist <- quote(x$watchlist)
390+
}
353391

354-
if (is.null(main_args$objective)) {
355-
if (is.numeric(y)) {
356-
main_args$objective <- "reg:squarederror"
357-
} else {
358-
if (num_class == 2) {
359-
main_args$objective <- "binary:logistic"
392+
if (utils::packageVersion("xgboost") < "2.0.0.0") {
393+
if (is.null(main_args$objective)) {
394+
if (is.numeric(y)) {
395+
main_args$objective <- "reg:squarederror"
360396
} else {
361-
main_args$objective <- "multi:softprob"
397+
if (num_class == 2) {
398+
main_args$objective <- "binary:logistic"
399+
} else {
400+
main_args$objective <- "multi:softprob"
401+
}
362402
}
363403
}
364-
}
365404

366-
if (!is.null(num_class) && num_class > 2) {
367-
main_args$num_class <- num_class
405+
if (!is.null(num_class) && num_class > 2) {
406+
main_args$num_class <- num_class
407+
}
368408
}
369409

370410
call <- make_call(fun = "xgb.train", ns = "xgboost", main_args)
@@ -471,6 +511,7 @@ as_xgb_data <- function(
471511
event_level = "first",
472512
...
473513
) {
514+
rlang::check_installed("xgboost")
474515
lvls <- levels(y)
475516
n <- nrow(x)
476517

@@ -506,21 +547,52 @@ as_xgb_data <- function(
506547
watch_list <- list(validation = val_data)
507548

508549
info_list <- list(label = y[trn_index])
509-
if (!is.null(weights)) {
510-
info_list$weight <- weights[trn_index]
550+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
551+
if (!is.null(weights)) {
552+
dat <- xgboost::xgb.DMatrix(
553+
data = x[trn_index, , drop = FALSE],
554+
missing = NA,
555+
label = y[trn_index],
556+
weight = weights[trn_index]
557+
)
558+
} else {
559+
dat <- xgboost::xgb.DMatrix(
560+
data = x[trn_index, , drop = FALSE],
561+
missing = NA,
562+
label = y[trn_index]
563+
)
564+
}
565+
} else {
566+
if (!is.null(weights)) {
567+
info_list$weight <- weights[trn_index]
568+
}
569+
dat <- xgboost::xgb.DMatrix(
570+
data = x[trn_index, , drop = FALSE],
571+
missing = NA,
572+
info = info_list
573+
)
511574
}
512-
dat <- xgboost::xgb.DMatrix(
513-
data = x[trn_index, , drop = FALSE],
514-
missing = NA,
515-
info = info_list
516-
)
517575
} else {
518-
info_list <- list(label = y)
519-
if (!is.null(weights)) {
520-
info_list$weight <- weights
576+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
577+
if (!is.null(weights)) {
578+
dat <- xgboost::xgb.DMatrix(
579+
x,
580+
missing = NA,
581+
label = y,
582+
weight = weights
583+
)
584+
} else {
585+
dat <- xgboost::xgb.DMatrix(x, missing = NA, label = y)
586+
}
587+
watch_list <- list(training = dat)
588+
} else {
589+
info_list <- list(label = y)
590+
if (!is.null(weights)) {
591+
info_list$weight <- weights
592+
}
593+
dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list)
594+
watch_list <- list(training = dat)
521595
}
522-
dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list)
523-
watch_list <- list(training = dat)
524596
}
525597
} else {
526598
dat <- xgboost::setinfo(x, "label", y)
@@ -579,12 +651,21 @@ multi_predict._xgb.Booster <-
579651
}
580652

581653
xgb_by_tree <- function(tree, object, new_data, type, ...) {
582-
pred <- xgb_predict(
583-
object$fit,
584-
new_data = new_data,
585-
iterationrange = c(1, tree + 1),
586-
ntreelimit = NULL
587-
)
654+
rlang::check_installed("xgboost")
655+
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
656+
pred <- xgb_predict(
657+
object$fit,
658+
new_data = new_data,
659+
iterationrange = c(1, tree + 1)
660+
)
661+
} else {
662+
pred <- xgb_predict(
663+
object$fit,
664+
new_data = new_data,
665+
iterationrange = c(1, tree + 1),
666+
ntreelimit = NULL
667+
)
668+
}
588669

589670
# switch based on prediction type
590671
if (object$spec$mode == "regression") {

man/rmd/boost_tree_xgboost.Rmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#| include: false
44
```
55

6-
`r descr_models("boost_tree", "xgboost")`
6+
`r descr_models("boost_tree", "xgboost")`. Note that in late 2025, a new version of xgboost was released with differences in its interface and model objects. This version of parsnip should work with both the older and the newer xgboost versions.
77

88
## Tuning Parameters
99

man/rmd/boost_tree_xgboost.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22

33

4-
For this engine, there are multiple modes: classification and regression
4+
For this engine, there are multiple modes: classification and regression. Note that in late 2025, a new version of xgboost was released with differences in its interface and model objects. This version of parsnip should work with both the older and the newer xgboost versions.
55

66
## Tuning Parameters
77

0 commit comments

Comments
 (0)