|
1 | | -from sklearn import __version__ as sklearn_version |
2 | 1 | from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, clone, is_classifier, is_regressor |
3 | 2 | from sklearn.utils.multiclass import unique_labels |
4 | | -from sklearn.utils.validation import _check_sample_weight, check_is_fitted |
5 | | - |
6 | | - |
7 | | -def parse_version(version): |
8 | | - return tuple(map(int, version.split(".")[:2])) |
9 | | - |
10 | | - |
11 | | -# TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0 |
12 | | -sklearn_supports_validation = parse_version(sklearn_version) >= (1, 6) |
13 | | -if sklearn_supports_validation: |
14 | | - from sklearn.utils.validation import validate_data |
| 3 | +from sklearn.utils.validation import _check_sample_weight, check_is_fitted, validate_data |
15 | 4 |
|
16 | 5 |
|
17 | 6 | class GlobalRegressor(RegressorMixin, BaseEstimator): |
@@ -45,11 +34,7 @@ def fit(self, X, y, sample_weight=None): |
45 | 34 | if not is_regressor(self.base_estimator): |
46 | 35 | raise ValueError(f"base_estimator must be a regressor. Got {self.base_estimator.__class__.__name__} instead.") |
47 | 36 |
|
48 | | - # TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0 |
49 | | - if sklearn_supports_validation: |
50 | | - X, y = validate_data(self, X, y) |
51 | | - else: |
52 | | - X, y = self._validate_data(X, y) |
| 37 | + X, y = validate_data(self, X, y) |
53 | 38 | _check_sample_weight(sample_weight, X) |
54 | 39 | self._fitted_learner = clone(self.base_estimator) |
55 | 40 | self._fitted_learner.fit(X, y) |
@@ -101,11 +86,7 @@ def fit(self, X, y, sample_weight=None): |
101 | 86 | if not is_classifier(self.base_estimator): |
102 | 87 | raise ValueError(f"base_estimator must be a classifier. Got {self.base_estimator.__class__.__name__} instead.") |
103 | 88 |
|
104 | | - # TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0 |
105 | | - if sklearn_supports_validation: |
106 | | - X, y = validate_data(self, X, y) |
107 | | - else: |
108 | | - X, y = self._validate_data(X, y) |
| 89 | + X, y = validate_data(self, X, y) |
109 | 90 | _check_sample_weight(sample_weight, X) |
110 | 91 | self.classes_ = unique_labels(y) |
111 | 92 | self._fitted_learner = clone(self.base_estimator) |
|
0 commit comments