|
4 | 4 | import numpy as np |
5 | 5 | import scipy |
6 | 6 | from numpy.testing import assert_almost_equal |
7 | | -from sklearn.datasets import make_regression |
| 7 | +from sklearn.datasets import make_regression, make_classification |
| 8 | +import json |
8 | 9 |
|
9 | 10 | try: |
10 | 11 | from xgboost import XGBClassifier, XGBRegressor |
@@ -137,6 +138,111 @@ def test_xgbclassifier_sparse_no_base_score(self): |
137 | 138 | got = sess.run(None, feeds)[1] |
138 | 139 | assert_almost_equal(expected.reshape((-1, 2)), got, decimal=4) |
139 | 140 |
|
| 141 | + @unittest.skipIf(XGBRegressor is None, "xgboost is not available") |
| 142 | + def test_xgbclassifier_multiclass_base_score(self): |
| 143 | + """Test multiclass classifier - xgboost 3 can have different base_scores per class""" |
| 144 | + X, y = make_classification( |
| 145 | + n_samples=200, n_features=10, n_classes=3, |
| 146 | + n_informative=5, n_redundant=0, random_state=42 |
| 147 | + ) |
| 148 | + X = X.astype(np.float32) |
| 149 | + |
| 150 | + clf = XGBClassifier(n_estimators=3, max_depth=4, random_state=42) |
| 151 | + clf.fit(X, y) |
| 152 | + expected = clf.predict_proba(X).astype(np.float32) |
| 153 | + |
| 154 | + onx = convert_xgboost( |
| 155 | + clf, |
| 156 | + initial_types=[("X", FloatTensorType(shape=[None, None]))], |
| 157 | + target_opset=TARGET_OPSET, |
| 158 | + ) |
| 159 | + feeds = {"X": X} |
| 160 | + |
| 161 | + sess = InferenceSession( |
| 162 | + onx.SerializeToString(), providers=["CPUExecutionProvider"] |
| 163 | + ) |
| 164 | + got = sess.run(None, feeds)[1] |
| 165 | + assert_almost_equal(expected, got, decimal=4) |
| 166 | + |
| 167 | + @unittest.skipIf(XGBRegressor is None, "xgboost is not available") |
| 168 | + def test_xgbclassifier_multiclass_base_score_in_onnx(self): |
| 169 | + """Verify that base_values are actually present in the ONNX graph""" |
| 170 | + X, y = make_classification( |
| 171 | + n_samples=200, n_features=10, n_classes=3, |
| 172 | + n_informative=5, n_redundant=0, random_state=42 |
| 173 | + ) |
| 174 | + X = X.astype(np.float32) |
| 175 | + |
| 176 | + clf = XGBClassifier(n_estimators=3, max_depth=4, random_state=42) |
| 177 | + clf.fit(X, y) |
| 178 | + |
| 179 | + config = json.loads(clf.get_booster().save_config()) |
| 180 | + base_score_str = config["learner"]["learner_model_param"]["base_score"] |
| 181 | + |
| 182 | + onx = convert_xgboost( |
| 183 | + clf, |
| 184 | + initial_types=[("X", FloatTensorType(shape=[None, None]))], |
| 185 | + target_opset=TARGET_OPSET, |
| 186 | + ) |
| 187 | + |
| 188 | + tree_ensemble_node = None |
| 189 | + for node in onx.graph.node: |
| 190 | + if node.op_type == "TreeEnsembleClassifier": |
| 191 | + tree_ensemble_node = node |
| 192 | + break |
| 193 | + |
| 194 | + self.assertIsNotNone(tree_ensemble_node, "TreeEnsembleClassifier node not found") |
| 195 | + |
| 196 | + base_values = None |
| 197 | + for attr in tree_ensemble_node.attribute: |
| 198 | + if attr.name == "base_values": |
| 199 | + base_values = list(attr.floats) |
| 200 | + break |
| 201 | + |
| 202 | + self.assertIsNotNone(base_values, "base_values attribute not found in ONNX model") |
| 203 | + self.assertEqual(len(base_values), 3, "base_values should have 3 elements for 3-class problem") |
| 204 | + |
| 205 | + # In xgboost 3+, base_score is a string array like "[3.4E-1,3.3E-1,3.3E-1]" |
| 206 | + # Verify that base_values in ONNX match the xgboost config |
| 207 | + if base_score_str.startswith("[") and base_score_str.endswith("]"): |
| 208 | + expected_base_scores = json.loads(base_score_str) |
| 209 | + for i, val in enumerate(base_values): |
| 210 | + if i < len(expected_base_scores): |
| 211 | + self.assertAlmostEqual(val, expected_base_scores[i], places=5) |
| 212 | + |
| 213 | + @unittest.skipIf(XGBRegressor is None, "xgboost is not available") |
| 214 | + def test_xgbregressor_base_score_in_onnx(self): |
| 215 | + """Verify that regressor base_values are present in the ONNX graph""" |
| 216 | + X, y = make_regression(n_samples=200, n_features=10, random_state=42) |
| 217 | + X = X.astype(np.float32) |
| 218 | + y = y.astype(np.float32) |
| 219 | + |
| 220 | + reg = XGBRegressor(n_estimators=3, max_depth=4, random_state=42) |
| 221 | + reg.fit(X, y) |
| 222 | + |
| 223 | + onx = convert_xgboost( |
| 224 | + reg, |
| 225 | + initial_types=[("X", FloatTensorType(shape=[None, None]))], |
| 226 | + target_opset=TARGET_OPSET, |
| 227 | + ) |
| 228 | + |
| 229 | + tree_ensemble_node = None |
| 230 | + for node in onx.graph.node: |
| 231 | + if node.op_type == "TreeEnsembleRegressor": |
| 232 | + tree_ensemble_node = node |
| 233 | + break |
| 234 | + |
| 235 | + self.assertIsNotNone(tree_ensemble_node, "TreeEnsembleRegressor node not found") |
| 236 | + |
| 237 | + base_values = None |
| 238 | + for attr in tree_ensemble_node.attribute: |
| 239 | + if attr.name == "base_values": |
| 240 | + base_values = list(attr.floats) |
| 241 | + break |
| 242 | + |
| 243 | + self.assertIsNotNone(base_values, "base_values attribute not found in ONNX model") |
| 244 | + self.assertGreater(len(base_values), 0, "base_values should not be empty") |
| 245 | + |
140 | 246 |
|
# Run the full test suite when the file is executed directly.
if __name__ == "__main__":
    unittest.main(verbosity=2)
0 commit comments