Skip to content

Commit ba1c51f

Browse files
committed
adjust manual plr for classifier ml_l
1 parent ec91b15 commit ba1c51f

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

doubleml/plm/tests/_utils_plr_manual.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,10 @@ def fit_nuisance_plr_classifier(
153153
y, x, d, learner_l, learner_m, learner_g, smpls, fit_g=True, l_params=None, m_params=None, g_params=None
154154
):
155155
ml_l = clone(learner_l)
156-
l_hat = fit_predict(y, x, ml_l, l_params, smpls)
156+
if is_classifier(learner_l):
157+
l_hat = fit_predict_proba(y, x, ml_l, l_params, smpls)
158+
else:
159+
l_hat = fit_predict(y, x, ml_l, l_params, smpls)
157160

158161
ml_m = clone(learner_m)
159162
m_hat = fit_predict_proba(d, x, ml_m, m_params, smpls)

0 commit comments

Comments
 (0)