@@ -428,235 +428,114 @@ def _nuisance_tuning(
     ):
         x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
         x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
-        # time indicator is used for selection (selection not available in DoubleMLData yet)
         x, s = check_X_y(x, self._dml_data.s, force_all_finite=False)
 
         if self._score == "nonignorable":
             z, _ = check_X_y(self._dml_data.z, y, force_all_finite=False)
-            dx = np.column_stack((x, d, z))
-        else:
-            dx = np.column_stack((x, d))
 
         if scoring_methods is None:
             scoring_methods = {"ml_g": None, "ml_pi": None, "ml_m": None}
 
+        # Nested helper functions
+        def tune_learner(target, features, train_indices, learner_key):
+            return _dml_tune(
+                target,
+                features,
+                train_indices,
+                self._learner[learner_key],
+                param_grids[learner_key],
+                scoring_methods[learner_key],
+                n_folds_tune,
+                n_jobs_cv,
+                search_mode,
+                n_iter_randomized_search,
+            )
+
+        def split_inner_folds(train_inds, d, s, random_state=42):
+            inner_train0_inds, inner_train1_inds = [], []
+            for train_index in train_inds:
+                stratify_vec = d[train_index] + 2 * s[train_index]
+                inner0, inner1 = train_test_split(train_index, test_size=0.5, stratify=stratify_vec, random_state=random_state)
+                inner_train0_inds.append(inner0)
+                inner_train1_inds.append(inner1)
+            return inner_train0_inds, inner_train1_inds
+
+        def filter_by_ds(inner_train1_inds, d, s):
+            inner1_d0_s1, inner1_d1_s1 = [], []
+            for inner1 in inner_train1_inds:
+                d_fold, s_fold = d[inner1], s[inner1]
+                mask_d0_s1 = (d_fold == 0) & (s_fold == 1)
+                mask_d1_s1 = (d_fold == 1) & (s_fold == 1)
+
+                inner1_d0_s1.append(inner1[mask_d0_s1])
+                inner1_d1_s1.append(inner1[mask_d1_s1])
+            return inner1_d0_s1, inner1_d1_s1
+
         if self._score == "nonignorable":
 
             train_inds = [train_index for (train_index, _) in smpls]
 
             # inner folds: split train set into two halves (pi-tuning vs. m/g-tuning)
-            def get_inner_train_inds(train_inds, d, s, random_state=42):
-                inner_train0_inds = []
-                inner_train1_inds = []
-
-                for train_index in train_inds:
-                    d_fold = d[train_index]
-                    s_fold = s[train_index]
-                    stratify_vec = d_fold + 2 * s_fold
-
-                    inner0, inner1 = train_test_split(
-                        train_index, test_size=0.5, stratify=stratify_vec, random_state=random_state
-                    )
-
-                    inner_train0_inds.append(inner0)
-                    inner_train1_inds.append(inner1)
-
-                return inner_train0_inds, inner_train1_inds
-
-            inner_train0_inds, inner_train1_inds = get_inner_train_inds(train_inds, d, s)
-
+            inner_train0_inds, inner_train1_inds = split_inner_folds(train_inds, d, s)
             # split inner1 by (d,s) to build g-models for treated/control
-            def filter_inner1_by_ds(inner_train1_inds, d, s):
-                inner1_d0_s1 = []
-                inner1_d1_s1 = []
-
-                for inner1 in inner_train1_inds:
-                    d_fold = d[inner1]
-                    s_fold = s[inner1]
-
-                    mask_d0_s1 = (d_fold == 0) & (s_fold == 1)
-                    mask_d1_s1 = (d_fold == 1) & (s_fold == 1)
-
-                    inner1_d0_s1.append(inner1[mask_d0_s1])
-                    inner1_d1_s1.append(inner1[mask_d1_s1])
-
-                return inner1_d0_s1, inner1_d1_s1
-
-            inner_train1_d0_s1, inner_train1_d1_s1 = filter_inner1_by_ds(inner_train1_inds, d, s)
-
-            x_d_z = np.concatenate([x, d.reshape(-1, 1), z.reshape(-1, 1)], axis=1)
-
-            # ml_pi: tune on inner0, predict pi-hat on inner1
-            pi_hat_list = []
-            pi_tune_res_nonignorable = []
+            inner_train1_d0_s1, inner_train1_d1_s1 = filter_by_ds(inner_train1_inds, d, s)
 
+            # Tune ml_pi
+            x_d_z = np.column_stack((x, d, z))
+            pi_tune_res = []
+            pi_hat_full = np.full(shape=s.shape, fill_value=np.nan)
             for inner0, inner1 in zip(inner_train0_inds, inner_train1_inds):
+                res = tune_learner(s, x_d_z, [inner0], "ml_pi")
+                best_params = res[0].best_params_
 
-                # tune pi on inner0
-                pi_tune_res = _dml_tune(
-                    s,
-                    x_d_z,
-                    [inner0],
-                    self._learner["ml_pi"],
-                    param_grids["ml_pi"],
-                    scoring_methods["ml_pi"],
-                    n_folds_tune,
-                    n_jobs_cv,
-                    search_mode,
-                    n_iter_randomized_search,
-                )
-                best_params = pi_tune_res[0].best_params_
-
-                # fit tuned model
+                # Fit tuned model and predict
                 ml_pi_temp = clone(self._learner["ml_pi"])
                 ml_pi_temp.set_params(**best_params)
                 ml_pi_temp.fit(x_d_z[inner0], s[inner0])
+                pi_hat_full[inner1] = _predict_zero_one_propensity(ml_pi_temp, x_d_z)[inner1]
+                pi_tune_res.append(res[0])
 
-                # predict proba on inner1
-                pi_hat_all = _predict_zero_one_propensity(ml_pi_temp, x_d_z)
-                pi_hat = pi_hat_all[inner1]
-                pi_hat_list.append((inner1, pi_hat))  # (index, value) tuple
-
-                # save best params
-                pi_tune_res_nonignorable.append(pi_tune_res[0])
-
-            pi_hat_full = np.full(shape=s.shape, fill_value=np.nan)
-
-            for inner1, pi_hat in pi_hat_list:
-                pi_hat_full[inner1] = pi_hat
-
-            # ml_m: tune with x + pi-hats
-            x_pi = np.concatenate([x, pi_hat_full.reshape(-1, 1)], axis=1)
-
-            m_tune_res = _dml_tune(
-                d,
-                x_pi,
-                inner_train1_inds,
-                self._learner["ml_m"],
-                param_grids["ml_m"],
-                scoring_methods["ml_m"],
-                n_folds_tune,
-                n_jobs_cv,
-                search_mode,
-                n_iter_randomized_search,
-            )
-
-            # ml_g: tune with x + d + pi-hats for d=0, d=1
-            x_pi_d = np.concatenate([x, d.reshape(-1, 1), pi_hat_full.reshape(-1, 1)], axis=1)
-
-            g_d0_tune_res = _dml_tune(
-                y,
-                x_pi_d,
-                inner_train1_d0_s1,
-                self._learner["ml_g"],
-                param_grids["ml_g"],
-                scoring_methods["ml_g"],
-                n_folds_tune,
-                n_jobs_cv,
-                search_mode,
-                n_iter_randomized_search,
-            )
-            g_d1_tune_res = _dml_tune(
-                y,
-                x_pi_d,
-                inner_train1_d1_s1,
-                self._learner["ml_g"],
-                param_grids["ml_g"],
-                scoring_methods["ml_g"],
-                n_folds_tune,
-                n_jobs_cv,
-                search_mode,
-                n_iter_randomized_search,
-            )
-
-            g_d0_best_params = [xx.best_params_ for xx in g_d0_tune_res]
-            g_d1_best_params = [xx.best_params_ for xx in g_d1_tune_res]
-            pi_best_params = [xx.best_params_ for xx in pi_tune_res_nonignorable]
-            m_best_params = [xx.best_params_ for xx in m_tune_res]
+            # Tune ml_m with x + pi-hats
+            x_pi = np.column_stack([x, pi_hat_full.reshape(-1, 1)])
+            m_tune_res = tune_learner(d, x_pi, inner_train1_inds, "ml_m")
 
-            params = {"ml_g_d0": g_d0_best_params, "ml_g_d1": g_d1_best_params, "ml_pi": pi_best_params, "ml_m": m_best_params}
-
-            tune_res = {
-                "g_d0_tune": g_d0_tune_res,
-                "g_d1_tune": g_d1_tune_res,
-                "pi_tune": pi_tune_res_nonignorable,
-                "m_tune": m_tune_res,
-            }
-
-            res = {"params": params, "tune_res": tune_res}
-
-            return res
+            # Tune ml_g for d=0 and d=1
+            x_pi_d = np.column_stack([x, d.reshape(-1, 1), pi_hat_full.reshape(-1, 1)])
+            g_d0_tune_res = tune_learner(y, x_pi_d, inner_train1_d0_s1, "ml_g")
+            g_d1_tune_res = tune_learner(y, x_pi_d, inner_train1_d1_s1, "ml_g")
 
         else:
-
             # nuisance training sets conditional on d
             _, smpls_d0_s1, _, smpls_d1_s1 = _get_cond_smpls_2d(smpls, d, s)
             train_inds = [train_index for (train_index, _) in smpls]
             train_inds_d0_s1 = [train_index for (train_index, _) in smpls_d0_s1]
             train_inds_d1_s1 = [train_index for (train_index, _) in smpls_d1_s1]
 
-            # hyperparameter tuning for ML
-            g_d0_tune_res = _dml_tune(
-                y,
-                x,
-                train_inds_d0_s1,
-                self._learner["ml_g"],
-                param_grids["ml_g"],
-                scoring_methods["ml_g"],
-                n_folds_tune,
-                n_jobs_cv,
-                search_mode,
-                n_iter_randomized_search,
-            )
-            g_d1_tune_res = _dml_tune(
-                y,
-                x,
-                train_inds_d1_s1,
-                self._learner["ml_g"],
-                param_grids["ml_g"],
-                scoring_methods["ml_g"],
-                n_folds_tune,
-                n_jobs_cv,
-                search_mode,
-                n_iter_randomized_search,
-            )
-            pi_tune_res = _dml_tune(
-                s,
-                dx,
-                train_inds,
-                self._learner["ml_pi"],
-                param_grids["ml_pi"],
-                scoring_methods["ml_pi"],
-                n_folds_tune,
-                n_jobs_cv,
-                search_mode,
-                n_iter_randomized_search,
-            )
-            m_tune_res = _dml_tune(
-                d,
-                x,
-                train_inds,
-                self._learner["ml_m"],
-                param_grids["ml_m"],
-                scoring_methods["ml_m"],
-                n_folds_tune,
-                n_jobs_cv,
-                search_mode,
-                n_iter_randomized_search,
-            )
-
-            g_d0_best_params = [xx.best_params_ for xx in g_d0_tune_res]
-            g_d1_best_params = [xx.best_params_ for xx in g_d1_tune_res]
-            pi_best_params = [xx.best_params_ for xx in pi_tune_res]
-            m_best_params = [xx.best_params_ for xx in m_tune_res]
-
-            params = {"ml_g_d0": g_d0_best_params, "ml_g_d1": g_d1_best_params, "ml_pi": pi_best_params, "ml_m": m_best_params}
-
-            tune_res = {"g_d0_tune": g_d0_tune_res, "g_d1_tune": g_d1_tune_res, "pi_tune": pi_tune_res, "m_tune": m_tune_res}
+            # Tune ml_g for d=0 and d=1
+            g_d0_tune_res = tune_learner(y, x, train_inds_d0_s1, "ml_g")
+            g_d1_tune_res = tune_learner(y, x, train_inds_d1_s1, "ml_g")
+
+            # Tune ml_pi and ml_m
+            x_d = np.column_stack((x, d))
+            pi_tune_res = tune_learner(s, x_d, train_inds, "ml_pi")
+            m_tune_res = tune_learner(d, x, train_inds, "ml_m")
+
+        # Collect results
+        params = {
+            "ml_g_d0": [res.best_params_ for res in g_d0_tune_res],
+            "ml_g_d1": [res.best_params_ for res in g_d1_tune_res],
+            "ml_pi": [res.best_params_ for res in pi_tune_res],
+            "ml_m": [res.best_params_ for res in m_tune_res],
+        }
 
-            res = {"params": params, "tune_res": tune_res}
+        tune_res = {
+            "g_d0_tune": g_d0_tune_res,
+            "g_d1_tune": g_d1_tune_res,
+            "pi_tune": pi_tune_res,
+            "m_tune": m_tune_res,
+        }
 
-            return res
+        return {"params": params, "tune_res": tune_res}
 
     def _sensitivity_element_est(self, preds):
         pass
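
Note on the inner split introduced by `split_inner_folds`: stratifying on `d + 2 * s` encodes the four (treatment, selection) cells as 0, 1, 2, 3, so both halves of each outer training fold keep roughly the same mix of treated/untreated and selected/non-selected units. Below is a minimal, self-contained sketch of just that idea; the sample size `n`, the simulated `d`/`s` vectors, and the printout are illustrative assumptions, not part of the diff.

```python
# Sketch of the stratified inner split used above (illustrative data only).
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
n = 200
d = rng.integers(0, 2, size=n)  # treatment indicator
s = rng.integers(0, 2, size=n)  # selection indicator
train_index = np.arange(n)      # one outer training fold

# Encode the four (d, s) cells as 0..3 and stratify on that code,
# so both inner halves preserve the joint distribution of d and s.
stratify_vec = d[train_index] + 2 * s[train_index]
inner0, inner1 = train_test_split(
    train_index, test_size=0.5, stratify=stratify_vec, random_state=42
)

# Each (d, s) cell appears in both halves in (almost) equal proportion.
for code in range(4):
    n0 = int(np.sum(d[inner0] + 2 * s[inner0] == code))
    n1 = int(np.sum(d[inner1] + 2 * s[inner1] == code))
    print(f"(d, s) cell {code}: inner0={n0}, inner1={n1}")
```

This presumably matters because, in the nonignorable branch, `ml_g` is later tuned only on the `d == 1, s == 1` (resp. `d == 0, s == 1`) subset of `inner1`; a stratified split keeps those cells from shrinking too much by chance.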