@@ -263,9 +263,9 @@ def _check_data(self, obj_dml_data):
263263 return
264264
265265 def _nuisance_est (self , smpls , n_jobs_cv , external_predictions , return_models = False ):
266- x , y = check_X_y (self ._dml_data .x , self ._dml_data .y , force_all_finite = False )
267- x , z = check_X_y (x , np .ravel (self ._dml_data .z ), force_all_finite = False )
268- x , d = check_X_y (x , self ._dml_data .d , force_all_finite = False )
266+ x , y = check_X_y (self ._dml_data .x , self ._dml_data .y , ensure_all_finite = False )
267+ x , z = check_X_y (x , np .ravel (self ._dml_data .z ), ensure_all_finite = False )
268+ x , d = check_X_y (x , self ._dml_data .d , ensure_all_finite = False )
269269
270270 # get train indices for z == 0 and z == 1
271271 smpls_z0 , smpls_z1 = _get_cond_smpls (smpls , z )
@@ -448,9 +448,9 @@ def _score_elements(self, y, z, d, g_hat0, g_hat1, m_hat, r_hat0, r_hat1, smpls)
448448 def _nuisance_tuning (
449449 self , smpls , param_grids , scoring_methods , n_folds_tune , n_jobs_cv , search_mode , n_iter_randomized_search
450450 ):
451- x , y = check_X_y (self ._dml_data .x , self ._dml_data .y , force_all_finite = False )
452- x , z = check_X_y (x , np .ravel (self ._dml_data .z ), force_all_finite = False )
453- x , d = check_X_y (x , self ._dml_data .d , force_all_finite = False )
451+ x , y = check_X_y (self ._dml_data .x , self ._dml_data .y , ensure_all_finite = False )
452+ x , z = check_X_y (x , np .ravel (self ._dml_data .z ), ensure_all_finite = False )
453+ x , d = check_X_y (x , self ._dml_data .d , ensure_all_finite = False )
454454
455455 # get train indices for z == 0 and z == 1
456456 smpls_z0 , smpls_z1 = _get_cond_smpls (smpls , z )
0 commit comments