Skip to content

Commit 2b46bc6

Browse files
committed
update sklearn dependency to 1.6.0
1 parent e430fe1 commit 2b46bc6

File tree

1 file changed

+3
-22
lines changed

1 file changed

+3
-22
lines changed

doubleml/utils/global_learner.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,6 @@
1-
from sklearn import __version__ as sklearn_version
21
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, clone, is_classifier, is_regressor
32
from sklearn.utils.multiclass import unique_labels
4-
from sklearn.utils.validation import _check_sample_weight, check_is_fitted
5-
6-
7-
def parse_version(version):
8-
return tuple(map(int, version.split(".")[:2]))
9-
10-
11-
# TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
12-
sklearn_supports_validation = parse_version(sklearn_version) >= (1, 6)
13-
if sklearn_supports_validation:
14-
from sklearn.utils.validation import validate_data
3+
from sklearn.utils.validation import _check_sample_weight, check_is_fitted, validate_data
154

165

176
class GlobalRegressor(RegressorMixin, BaseEstimator):
@@ -45,11 +34,7 @@ def fit(self, X, y, sample_weight=None):
4534
if not is_regressor(self.base_estimator):
4635
raise ValueError(f"base_estimator must be a regressor. Got {self.base_estimator.__class__.__name__} instead.")
4736

48-
# TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
49-
if sklearn_supports_validation:
50-
X, y = validate_data(self, X, y)
51-
else:
52-
X, y = self._validate_data(X, y)
37+
X, y = validate_data(self, X, y)
5338
_check_sample_weight(sample_weight, X)
5439
self._fitted_learner = clone(self.base_estimator)
5540
self._fitted_learner.fit(X, y)
@@ -101,11 +86,7 @@ def fit(self, X, y, sample_weight=None):
10186
if not is_classifier(self.base_estimator):
10287
raise ValueError(f"base_estimator must be a classifier. Got {self.base_estimator.__class__.__name__} instead.")
10388

104-
# TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
105-
if sklearn_supports_validation:
106-
X, y = validate_data(self, X, y)
107-
else:
108-
X, y = self._validate_data(X, y)
89+
X, y = validate_data(self, X, y)
10990
_check_sample_weight(sample_weight, X)
11091
self.classes_ = unique_labels(y)
11192
self._fitted_learner = clone(self.base_estimator)

0 commit comments

Comments
 (0)