Skip to content

Commit 5b04677

Browse files
committed
update rdd weight check
1 parent f1e7668 commit 5b04677

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

doubleml/utils/_checks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
import inspect
21
import warnings
32

43
import numpy as np
54
from sklearn.utils.multiclass import type_of_target
5+
from sklearn.utils.validation import has_fit_parameter
66

77

88
def _check_in_zero_one(value, name, include_zero=True, include_one=True):
@@ -514,7 +514,7 @@ def _check_sample_splitting(all_smpls, all_smpls_cluster, dml_data, is_cluster_d
514514

515515

516516
def _check_supports_sample_weights(learner, learner_name):
517-
if "sample_weight" not in inspect.signature(learner.fit).parameters:
517+
if not has_fit_parameter(learner, "sample_weight"):
518518
raise ValueError(
519519
f"The {learner_name} learner {str(learner)} does not support sample weights. "
520520
"Please choose a learner that supports sample weights."

0 commit comments

Comments
 (0)