Skip to content

Commit 58e2097

Browse files
committed
Support xgboost 3
Signed-off-by: Asher Wright <asherw@squareup.com>
1 parent aee69af commit 58e2097

File tree

4 files changed

+149
-16
lines changed

4 files changed

+149
-16
lines changed

onnxmltools/convert/xgboost/_parse.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,16 @@ def _get_attributes(booster):
121121
kwargs["objective"] = "binary:logistic"
122122

123123
if "base_score" not in kwargs:
124-
kwargs["base_score"] = 0.5
124+
kwargs["base_score"] = [0.5]
125125
elif isinstance(kwargs["base_score"], str):
126-
kwargs["base_score"] = float(kwargs["base_score"])
126+
base_score_str = kwargs["base_score"]
127+
if base_score_str.startswith("[") and base_score_str.endswith("]"):
128+
# xgboost >= 3.0: base_score is a string array
129+
bs = json.loads(base_score_str)
130+
kwargs["base_score"] = [float(x) for x in bs]
131+
else:
132+
# xgboost >= 2, < 3: base_score is a string float
133+
kwargs["base_score"] = [float(base_score_str)]
127134
return kwargs
128135

129136

onnxmltools/convert/xgboost/common.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,15 @@ def get_xgb_params(xgb_node):
2929
if xgb_node.n_estimators is not None:
3030
params["n_estimators"] = xgb_node.n_estimators
3131
if "base_score" in config["learner"]["learner_model_param"]:
32-
bs = float(config["learner"]["learner_model_param"]["base_score"])
33-
# xgboost >= 2.0
34-
params["base_score"] = bs
32+
base_score_raw = config["learner"]["learner_model_param"]["base_score"]
33+
# xgboost >= 3.0: base_score is a string array
34+
if base_score_raw.startswith("[") and base_score_raw.endswith("]"):
35+
base_score = json.loads(base_score_raw)
36+
params["base_score"] = [float(x) for x in base_score]
37+
else:
38+
# xgboost >= 2, < 3: base_score is a string float
39+
params["base_score"] = [float(base_score_raw)]
40+
3541
if "num_target" in config["learner"]["learner_model_param"]:
3642
params["n_targets"] = int(
3743
config["learner"]["learner_model_param"]["num_target"]

onnxmltools/convert/xgboost/operator_converters/XGBoost.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def common_members(xgb_node, inputs):
5353
else:
5454
best_ntree_limit = params.get("best_ntree_limit", None)
5555
if base_score is None:
56-
base_score = 0.5
56+
base_score = [0.5]
5757
booster = xgb_node.get_booster()
5858
# The json format was available in October 2017.
5959
# XGBoost 0.7 was the first version released with it.
@@ -259,7 +259,10 @@ def convert(scope, operator, container):
259259
raise RuntimeError("Objective '{}' not supported.".format(objective))
260260

261261
attr_pairs = XGBRegressorConverter._get_default_tree_attribute_pairs()
262-
attr_pairs["base_values"] = [base_score]
262+
if isinstance(base_score, list):
263+
attr_pairs["base_values"] = base_score
264+
else:
265+
attr_pairs["base_values"] = [base_score]
263266

264267
if best_ntree_limit and best_ntree_limit < len(js_trees):
265268
js_trees = js_trees[:best_ntree_limit]
@@ -288,7 +291,8 @@ def convert(scope, operator, container):
288291

289292
if objective == "count:poisson":
290293
cst = scope.get_unique_variable_name("poisson")
291-
container.add_initializer(cst, TensorProto.FLOAT, [1], [base_score])
294+
# base_score is always a list
295+
container.add_initializer(cst, TensorProto.FLOAT, [len(base_score)], base_score)
292296
new_name = scope.get_unique_variable_name("exp")
293297
container.add_node("Exp", names, [new_name])
294298
container.add_node("Mul", [new_name, cst], operator.output_full_names)
@@ -350,17 +354,27 @@ def convert(scope, operator, container):
350354
attr_pairs["post_transform"] = "LOGISTIC"
351355
attr_pairs["class_ids"] = [0 for v in attr_pairs["class_treeids"]]
352356
if js_trees[0].get("leaf", None) == 0:
353-
attr_pairs["base_values"] = [base_score]
354-
elif base_score != 0.5:
355-
# 0.5 -> cst = 0
356-
cst = -np.log(1 / np.float32(base_score) - 1.0)
357-
attr_pairs["base_values"] = [cst]
357+
# base_score is always a list
358+
attr_pairs["base_values"] = base_score
359+
else:
360+
# Transform base_score - for binary, use first element
361+
bs_val = base_score[0]
362+
if bs_val != 0.5:
363+
# 0.5 -> cst = 0
364+
cst = -np.log(1 / np.float32(bs_val) - 1.0)
365+
attr_pairs["base_values"] = [cst]
358366
else:
359-
attr_pairs["base_values"] = [base_score]
367+
# base_score is always a list
368+
attr_pairs["base_values"] = base_score
360369
else:
361370
# See https://github.com/dmlc/xgboost/blob/main/src/common/math.h#L35.
362371
attr_pairs["post_transform"] = "SOFTMAX"
363-
attr_pairs["base_values"] = [base_score for n in range(ncl)]
372+
# base_score is always a list - may have different values per class in xgboost 3+
373+
# If base_score has fewer elements than classes, replicate to match
374+
if len(base_score) == 1:
375+
attr_pairs["base_values"] = base_score * ncl
376+
else:
377+
attr_pairs["base_values"] = base_score
364378
attr_pairs["class_ids"] = [v % ncl for v in attr_pairs["class_treeids"]]
365379

366380
classes = xgb_node.classes_

tests/xgboost/test_xgboost_converters_base_score.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import numpy as np
55
import scipy
66
from numpy.testing import assert_almost_equal
7-
from sklearn.datasets import make_regression
7+
from sklearn.datasets import make_regression, make_classification
8+
import json
89

910
try:
1011
from xgboost import XGBClassifier, XGBRegressor
@@ -137,6 +138,111 @@ def test_xgbclassifier_sparse_no_base_score(self):
137138
got = sess.run(None, feeds)[1]
138139
assert_almost_equal(expected.reshape((-1, 2)), got, decimal=4)
139140

141+
@unittest.skipIf(XGBRegressor is None, "xgboost is not available")
142+
def test_xgbclassifier_multiclass_base_score(self):
143+
"""Test multiclass classifier - xgboost 3 can have different base_scores per class"""
144+
X, y = make_classification(
145+
n_samples=200, n_features=10, n_classes=3,
146+
n_informative=5, n_redundant=0, random_state=42
147+
)
148+
X = X.astype(np.float32)
149+
150+
clf = XGBClassifier(n_estimators=3, max_depth=4, random_state=42)
151+
clf.fit(X, y)
152+
expected = clf.predict_proba(X).astype(np.float32)
153+
154+
onx = convert_xgboost(
155+
clf,
156+
initial_types=[("X", FloatTensorType(shape=[None, None]))],
157+
target_opset=TARGET_OPSET,
158+
)
159+
feeds = {"X": X}
160+
161+
sess = InferenceSession(
162+
onx.SerializeToString(), providers=["CPUExecutionProvider"]
163+
)
164+
got = sess.run(None, feeds)[1]
165+
assert_almost_equal(expected, got, decimal=4)
166+
167+
@unittest.skipIf(XGBRegressor is None, "xgboost is not available")
168+
def test_xgbclassifier_multiclass_base_score_in_onnx(self):
169+
"""Verify that base_values are actually present in the ONNX graph"""
170+
X, y = make_classification(
171+
n_samples=200, n_features=10, n_classes=3,
172+
n_informative=5, n_redundant=0, random_state=42
173+
)
174+
X = X.astype(np.float32)
175+
176+
clf = XGBClassifier(n_estimators=3, max_depth=4, random_state=42)
177+
clf.fit(X, y)
178+
179+
config = json.loads(clf.get_booster().save_config())
180+
base_score_str = config["learner"]["learner_model_param"]["base_score"]
181+
182+
onx = convert_xgboost(
183+
clf,
184+
initial_types=[("X", FloatTensorType(shape=[None, None]))],
185+
target_opset=TARGET_OPSET,
186+
)
187+
188+
tree_ensemble_node = None
189+
for node in onx.graph.node:
190+
if node.op_type == "TreeEnsembleClassifier":
191+
tree_ensemble_node = node
192+
break
193+
194+
self.assertIsNotNone(tree_ensemble_node, "TreeEnsembleClassifier node not found")
195+
196+
base_values = None
197+
for attr in tree_ensemble_node.attribute:
198+
if attr.name == "base_values":
199+
base_values = list(attr.floats)
200+
break
201+
202+
self.assertIsNotNone(base_values, "base_values attribute not found in ONNX model")
203+
self.assertEqual(len(base_values), 3, "base_values should have 3 elements for 3-class problem")
204+
205+
# In xgboost 3+, base_score is a string array like "[3.4E-1,3.3E-1,3.3E-1]"
206+
# Verify that base_values in ONNX match the xgboost config
207+
if base_score_str.startswith("[") and base_score_str.endswith("]"):
208+
expected_base_scores = json.loads(base_score_str)
209+
for i, val in enumerate(base_values):
210+
if i < len(expected_base_scores):
211+
self.assertAlmostEqual(val, expected_base_scores[i], places=5)
212+
213+
@unittest.skipIf(XGBRegressor is None, "xgboost is not available")
214+
def test_xgbregressor_base_score_in_onnx(self):
215+
"""Verify that regressor base_values are present in the ONNX graph"""
216+
X, y = make_regression(n_samples=200, n_features=10, random_state=42)
217+
X = X.astype(np.float32)
218+
y = y.astype(np.float32)
219+
220+
reg = XGBRegressor(n_estimators=3, max_depth=4, random_state=42)
221+
reg.fit(X, y)
222+
223+
onx = convert_xgboost(
224+
reg,
225+
initial_types=[("X", FloatTensorType(shape=[None, None]))],
226+
target_opset=TARGET_OPSET,
227+
)
228+
229+
tree_ensemble_node = None
230+
for node in onx.graph.node:
231+
if node.op_type == "TreeEnsembleRegressor":
232+
tree_ensemble_node = node
233+
break
234+
235+
self.assertIsNotNone(tree_ensemble_node, "TreeEnsembleRegressor node not found")
236+
237+
base_values = None
238+
for attr in tree_ensemble_node.attribute:
239+
if attr.name == "base_values":
240+
base_values = list(attr.floats)
241+
break
242+
243+
self.assertIsNotNone(base_values, "base_values attribute not found in ONNX model")
244+
self.assertGreater(len(base_values), 0, "base_values should not be empty")
245+
140246

141247
if __name__ == "__main__":
142248
unittest.main(verbosity=2)

0 commit comments

Comments (0)