@@ -428,86 +428,114 @@ def _nuisance_tuning(
428428 ):
429429 x , y = check_X_y (self ._dml_data .x , self ._dml_data .y , force_all_finite = False )
430430 x , d = check_X_y (x , self ._dml_data .d , force_all_finite = False )
431- # time indicator is used for selection (selection not available in DoubleMLData yet)
432431 x , s = check_X_y (x , self ._dml_data .s , force_all_finite = False )
433432
434433 if self ._score == "nonignorable" :
435434 z , _ = check_X_y (self ._dml_data .z , y , force_all_finite = False )
436- dx = np .column_stack ((x , d , z ))
437- else :
438- dx = np .column_stack ((x , d ))
439435
440436 if scoring_methods is None :
441437 scoring_methods = {"ml_g" : None , "ml_pi" : None , "ml_m" : None }
442438
443- # nuisance training sets conditional on d
444- _ , smpls_d0_s1 , _ , smpls_d1_s1 = _get_cond_smpls_2d (smpls , d , s )
445- train_inds = [train_index for (train_index , _ ) in smpls ]
446- train_inds_d0_s1 = [train_index for (train_index , _ ) in smpls_d0_s1 ]
447- train_inds_d1_s1 = [train_index for (train_index , _ ) in smpls_d1_s1 ]
448-
449- # hyperparameter tuning for ML
450- g_d0_tune_res = _dml_tune (
451- y ,
452- x ,
453- train_inds_d0_s1 ,
454- self ._learner ["ml_g" ],
455- param_grids ["ml_g" ],
456- scoring_methods ["ml_g" ],
457- n_folds_tune ,
458- n_jobs_cv ,
459- search_mode ,
460- n_iter_randomized_search ,
461- )
462- g_d1_tune_res = _dml_tune (
463- y ,
464- x ,
465- train_inds_d1_s1 ,
466- self ._learner ["ml_g" ],
467- param_grids ["ml_g" ],
468- scoring_methods ["ml_g" ],
469- n_folds_tune ,
470- n_jobs_cv ,
471- search_mode ,
472- n_iter_randomized_search ,
473- )
474- pi_tune_res = _dml_tune (
475- s ,
476- dx ,
477- train_inds ,
478- self ._learner ["ml_pi" ],
479- param_grids ["ml_pi" ],
480- scoring_methods ["ml_pi" ],
481- n_folds_tune ,
482- n_jobs_cv ,
483- search_mode ,
484- n_iter_randomized_search ,
485- )
486- m_tune_res = _dml_tune (
487- d ,
488- x ,
489- train_inds ,
490- self ._learner ["ml_m" ],
491- param_grids ["ml_m" ],
492- scoring_methods ["ml_m" ],
493- n_folds_tune ,
494- n_jobs_cv ,
495- search_mode ,
496- n_iter_randomized_search ,
497- )
439+ # Nested helper functions
440+ def tune_learner (target , features , train_indices , learner_key ):
441+ return _dml_tune (
442+ target ,
443+ features ,
444+ train_indices ,
445+ self ._learner [learner_key ],
446+ param_grids [learner_key ],
447+ scoring_methods [learner_key ],
448+ n_folds_tune ,
449+ n_jobs_cv ,
450+ search_mode ,
451+ n_iter_randomized_search ,
452+ )
498453
499- g_d0_best_params = [xx .best_params_ for xx in g_d0_tune_res ]
500- g_d1_best_params = [xx .best_params_ for xx in g_d1_tune_res ]
501- pi_best_params = [xx .best_params_ for xx in pi_tune_res ]
502- m_best_params = [xx .best_params_ for xx in m_tune_res ]
454+ def split_inner_folds (train_inds , d , s , random_state = 42 ):
455+ inner_train0_inds , inner_train1_inds = [], []
456+ for train_index in train_inds :
457+ stratify_vec = d [train_index ] + 2 * s [train_index ]
458+ inner0 , inner1 = train_test_split (train_index , test_size = 0.5 , stratify = stratify_vec , random_state = random_state )
459+ inner_train0_inds .append (inner0 )
460+ inner_train1_inds .append (inner1 )
461+ return inner_train0_inds , inner_train1_inds
462+
463+ def filter_by_ds (inner_train1_inds , d , s ):
464+ inner1_d0_s1 , inner1_d1_s1 = [], []
465+ for inner1 in inner_train1_inds :
466+ d_fold , s_fold = d [inner1 ], s [inner1 ]
467+ mask_d0_s1 = (d_fold == 0 ) & (s_fold == 1 )
468+ mask_d1_s1 = (d_fold == 1 ) & (s_fold == 1 )
469+
470+ inner1_d0_s1 .append (inner1 [mask_d0_s1 ])
471+ inner1_d1_s1 .append (inner1 [mask_d1_s1 ])
472+ return inner1_d0_s1 , inner1_d1_s1
503473
504- params = { "ml_g_d0" : g_d0_best_params , "ml_g_d1" : g_d1_best_params , "ml_pi" : pi_best_params , "ml_m" : m_best_params }
474+ if self . _score == "nonignorable" :
505475
506- tune_res = {"g_d0_tune" : g_d0_tune_res , "g_d1_tune" : g_d1_tune_res , "pi_tune" : pi_tune_res , "m_tune" : m_tune_res }
476+ train_inds = [train_index for (train_index , _ ) in smpls ]
477+
478+ # inner folds: split train set into two halves (pi-tuning vs. m/g-tuning)
479+ inner_train0_inds , inner_train1_inds = split_inner_folds (train_inds , d , s )
480+ # split inner1 by (d,s) to build g-models for treated/control
481+ inner_train1_d0_s1 , inner_train1_d1_s1 = filter_by_ds (inner_train1_inds , d , s )
482+
483+ # Tune ml_pi
484+ x_d_z = np .column_stack ((x , d , z ))
485+ pi_tune_res = []
486+ pi_hat_full = np .full (shape = s .shape , fill_value = np .nan )
487+ for inner0 , inner1 in zip (inner_train0_inds , inner_train1_inds ):
488+ res = tune_learner (s , x_d_z , [inner0 ], "ml_pi" )
489+ best_params = res [0 ].best_params_
490+
491+ # Fit tuned model and predict
492+ ml_pi_temp = clone (self ._learner ["ml_pi" ])
493+ ml_pi_temp .set_params (** best_params )
494+ ml_pi_temp .fit (x_d_z [inner0 ], s [inner0 ])
495+ pi_hat_full [inner1 ] = _predict_zero_one_propensity (ml_pi_temp , x_d_z )[inner1 ]
496+ pi_tune_res .append (res [0 ])
497+
498+ # Tune ml_m with x + pi-hats
499+ x_pi = np .column_stack ([x , pi_hat_full .reshape (- 1 , 1 )])
500+ m_tune_res = tune_learner (d , x_pi , inner_train1_inds , "ml_m" )
501+
502+ # Tune ml_g for d=0 and d=1
503+ x_pi_d = np .column_stack ([x , d .reshape (- 1 , 1 ), pi_hat_full .reshape (- 1 , 1 )])
504+ g_d0_tune_res = tune_learner (y , x_pi_d , inner_train1_d0_s1 , "ml_g" )
505+ g_d1_tune_res = tune_learner (y , x_pi_d , inner_train1_d1_s1 , "ml_g" )
507506
508- res = {"params" : params , "tune_res" : tune_res }
507+ else :
508+ # nuisance training sets conditional on d
509+ _ , smpls_d0_s1 , _ , smpls_d1_s1 = _get_cond_smpls_2d (smpls , d , s )
510+ train_inds = [train_index for (train_index , _ ) in smpls ]
511+ train_inds_d0_s1 = [train_index for (train_index , _ ) in smpls_d0_s1 ]
512+ train_inds_d1_s1 = [train_index for (train_index , _ ) in smpls_d1_s1 ]
513+
514+ # Tune ml_g for d=0 and d=1
515+ g_d0_tune_res = tune_learner (y , x , train_inds_d0_s1 , "ml_g" )
516+ g_d1_tune_res = tune_learner (y , x , train_inds_d1_s1 , "ml_g" )
517+
518+ # Tune ml_pi and ml_m
519+ x_d = np .column_stack ((x , d ))
520+ pi_tune_res = tune_learner (s , x_d , train_inds , "ml_pi" )
521+ m_tune_res = tune_learner (d , x , train_inds , "ml_m" )
522+
523+ # Collect results
524+ params = {
525+ "ml_g_d0" : [res .best_params_ for res in g_d0_tune_res ],
526+ "ml_g_d1" : [res .best_params_ for res in g_d1_tune_res ],
527+ "ml_pi" : [res .best_params_ for res in pi_tune_res ],
528+ "ml_m" : [res .best_params_ for res in m_tune_res ],
529+ }
530+
531+ tune_res = {
532+ "g_d0_tune" : g_d0_tune_res ,
533+ "g_d1_tune" : g_d1_tune_res ,
534+ "pi_tune" : pi_tune_res ,
535+ "m_tune" : m_tune_res ,
536+ }
509537
510- return res
538+ return { "params" : params , "tune_res" : tune_res }
511539
512540 def _sensitivity_element_est (self , preds ):
513541 pass
0 commit comments