@@ -440,74 +440,223 @@ def _nuisance_tuning(
440440 if scoring_methods is None :
441441 scoring_methods = {"ml_g" : None , "ml_pi" : None , "ml_m" : None }
442442
443- # nuisance training sets conditional on d
444- _ , smpls_d0_s1 , _ , smpls_d1_s1 = _get_cond_smpls_2d (smpls , d , s )
445- train_inds = [train_index for (train_index , _ ) in smpls ]
446- train_inds_d0_s1 = [train_index for (train_index , _ ) in smpls_d0_s1 ]
447- train_inds_d1_s1 = [train_index for (train_index , _ ) in smpls_d1_s1 ]
448-
449- # hyperparameter tuning for ML
450- g_d0_tune_res = _dml_tune (
451- y ,
452- x ,
453- train_inds_d0_s1 ,
454- self ._learner ["ml_g" ],
455- param_grids ["ml_g" ],
456- scoring_methods ["ml_g" ],
457- n_folds_tune ,
458- n_jobs_cv ,
459- search_mode ,
460- n_iter_randomized_search ,
461- )
462- g_d1_tune_res = _dml_tune (
463- y ,
464- x ,
465- train_inds_d1_s1 ,
466- self ._learner ["ml_g" ],
467- param_grids ["ml_g" ],
468- scoring_methods ["ml_g" ],
469- n_folds_tune ,
470- n_jobs_cv ,
471- search_mode ,
472- n_iter_randomized_search ,
473- )
474- pi_tune_res = _dml_tune (
475- s ,
476- dx ,
477- train_inds ,
478- self ._learner ["ml_pi" ],
479- param_grids ["ml_pi" ],
480- scoring_methods ["ml_pi" ],
481- n_folds_tune ,
482- n_jobs_cv ,
483- search_mode ,
484- n_iter_randomized_search ,
485- )
486- m_tune_res = _dml_tune (
487- d ,
488- x ,
489- train_inds ,
490- self ._learner ["ml_m" ],
491- param_grids ["ml_m" ],
492- scoring_methods ["ml_m" ],
493- n_folds_tune ,
494- n_jobs_cv ,
495- search_mode ,
496- n_iter_randomized_search ,
497- )
443+ if self ._score == "nonignorable" :
444+
445+ train_inds = [train_index for (train_index , _ ) in smpls ]
446+
447+ # inner folds: split train set into two halves (pi-tuning vs. m/g-tuning)
448+ def get_inner_train_inds (train_inds , d , s , random_state = 42 ):
449+ inner_train0_inds = []
450+ inner_train1_inds = []
451+
452+ for train_index in train_inds :
453+ d_fold = d [train_index ]
454+ s_fold = s [train_index ]
455+ stratify_vec = d_fold + 2 * s_fold
456+
457+ inner0 , inner1 = train_test_split (
458+ train_index , test_size = 0.5 , stratify = stratify_vec , random_state = random_state
459+ )
460+
461+ inner_train0_inds .append (inner0 )
462+ inner_train1_inds .append (inner1 )
463+
464+ return inner_train0_inds , inner_train1_inds
465+
466+ inner_train0_inds , inner_train1_inds = get_inner_train_inds (train_inds , d , s )
467+
468+ # split inner1 by (d,s) to build g-models for treated/control
469+ def filter_inner1_by_ds (inner_train1_inds , d , s ):
470+ inner1_d0_s1 = []
471+ inner1_d1_s1 = []
472+
473+ for inner1 in inner_train1_inds :
474+ d_fold = d [inner1 ]
475+ s_fold = s [inner1 ]
476+
477+ mask_d0_s1 = (d_fold == 0 ) & (s_fold == 1 )
478+ mask_d1_s1 = (d_fold == 1 ) & (s_fold == 1 )
479+
480+ inner1_d0_s1 .append (inner1 [mask_d0_s1 ])
481+ inner1_d1_s1 .append (inner1 [mask_d1_s1 ])
482+
483+ return inner1_d0_s1 , inner1_d1_s1
484+
485+ inner_train1_d0_s1 , inner_train1_d1_s1 = filter_inner1_by_ds (inner_train1_inds , d , s )
486+
487+ x_d_z = np .concatenate ([x , d .reshape (- 1 , 1 ), z .reshape (- 1 , 1 )], axis = 1 )
488+
489+ # ml_pi: tune on inner0, predict pi-hat on inner1
490+ pi_hat_list = []
491+ pi_tune_res_nonignorable = []
492+
493+ for inner0 , inner1 in zip (inner_train0_inds , inner_train1_inds ):
494+
495+ # tune pi on inner0
496+ pi_tune_res = _dml_tune (
497+ s ,
498+ x_d_z ,
499+ [inner0 ],
500+ self ._learner ["ml_pi" ],
501+ param_grids ["ml_pi" ],
502+ scoring_methods ["ml_pi" ],
503+ n_folds_tune ,
504+ n_jobs_cv ,
505+ search_mode ,
506+ n_iter_randomized_search ,
507+ )
508+ best_params = pi_tune_res [0 ].best_params_
509+
510+ # fit tuned model
511+ ml_pi_temp = clone (self ._learner ["ml_pi" ])
512+ ml_pi_temp .set_params (** best_params )
513+ ml_pi_temp .fit (x_d_z [inner0 ], s [inner0 ])
514+
515+ # predict proba on inner1
516+ pi_hat_all = _predict_zero_one_propensity (ml_pi_temp , x_d_z )
517+ pi_hat = pi_hat_all [inner1 ]
518+ pi_hat_list .append ((inner1 , pi_hat )) # (index, value) tuple
519+
520+ # save best params
521+ pi_tune_res_nonignorable .append (pi_tune_res [0 ])
522+
523+ pi_hat_full = np .full (shape = s .shape , fill_value = np .nan )
524+
525+ for inner1 , pi_hat in pi_hat_list :
526+ pi_hat_full [inner1 ] = pi_hat
527+
528+ # ml_m: tune with x + pi-hats
529+ x_pi = np .concatenate ([x , pi_hat_full .reshape (- 1 , 1 )], axis = 1 )
530+
531+ m_tune_res = _dml_tune (
532+ d ,
533+ x_pi ,
534+ inner_train1_inds ,
535+ self ._learner ["ml_m" ],
536+ param_grids ["ml_m" ],
537+ scoring_methods ["ml_m" ],
538+ n_folds_tune ,
539+ n_jobs_cv ,
540+ search_mode ,
541+ n_iter_randomized_search ,
542+ )
543+
544+ # ml_g: tune with x + d + pi-hats for d=0, d=1
545+ x_pi_d = np .concatenate ([x , d .reshape (- 1 , 1 ), pi_hat_full .reshape (- 1 , 1 )], axis = 1 )
546+
547+ g_d0_tune_res = _dml_tune (
548+ y ,
549+ x_pi_d ,
550+ inner_train1_d0_s1 ,
551+ self ._learner ["ml_g" ],
552+ param_grids ["ml_g" ],
553+ scoring_methods ["ml_g" ],
554+ n_folds_tune ,
555+ n_jobs_cv ,
556+ search_mode ,
557+ n_iter_randomized_search ,
558+ )
559+ g_d1_tune_res = _dml_tune (
560+ y ,
561+ x_pi_d ,
562+ inner_train1_d1_s1 ,
563+ self ._learner ["ml_g" ],
564+ param_grids ["ml_g" ],
565+ scoring_methods ["ml_g" ],
566+ n_folds_tune ,
567+ n_jobs_cv ,
568+ search_mode ,
569+ n_iter_randomized_search ,
570+ )
571+
572+ g_d0_best_params = [xx .best_params_ for xx in g_d0_tune_res ]
573+ g_d1_best_params = [xx .best_params_ for xx in g_d1_tune_res ]
574+ pi_best_params = [xx .best_params_ for xx in pi_tune_res_nonignorable ]
575+ m_best_params = [xx .best_params_ for xx in m_tune_res ]
576+
577+ params = {"ml_g_d0" : g_d0_best_params , "ml_g_d1" : g_d1_best_params , "ml_pi" : pi_best_params , "ml_m" : m_best_params }
578+
579+ tune_res = {
580+ "g_d0_tune" : g_d0_tune_res ,
581+ "g_d1_tune" : g_d1_tune_res ,
582+ "pi_tune" : pi_tune_res_nonignorable ,
583+ "m_tune" : m_tune_res ,
584+ }
585+
586+ res = {"params" : params , "tune_res" : tune_res }
587+
588+ return res
589+
590+ else :
591+
592+ # nuisance training sets conditional on d
593+ _ , smpls_d0_s1 , _ , smpls_d1_s1 = _get_cond_smpls_2d (smpls , d , s )
594+ train_inds = [train_index for (train_index , _ ) in smpls ]
595+ train_inds_d0_s1 = [train_index for (train_index , _ ) in smpls_d0_s1 ]
596+ train_inds_d1_s1 = [train_index for (train_index , _ ) in smpls_d1_s1 ]
597+
598+ # hyperparameter tuning for ML
599+ g_d0_tune_res = _dml_tune (
600+ y ,
601+ x ,
602+ train_inds_d0_s1 ,
603+ self ._learner ["ml_g" ],
604+ param_grids ["ml_g" ],
605+ scoring_methods ["ml_g" ],
606+ n_folds_tune ,
607+ n_jobs_cv ,
608+ search_mode ,
609+ n_iter_randomized_search ,
610+ )
611+ g_d1_tune_res = _dml_tune (
612+ y ,
613+ x ,
614+ train_inds_d1_s1 ,
615+ self ._learner ["ml_g" ],
616+ param_grids ["ml_g" ],
617+ scoring_methods ["ml_g" ],
618+ n_folds_tune ,
619+ n_jobs_cv ,
620+ search_mode ,
621+ n_iter_randomized_search ,
622+ )
623+ pi_tune_res = _dml_tune (
624+ s ,
625+ dx ,
626+ train_inds ,
627+ self ._learner ["ml_pi" ],
628+ param_grids ["ml_pi" ],
629+ scoring_methods ["ml_pi" ],
630+ n_folds_tune ,
631+ n_jobs_cv ,
632+ search_mode ,
633+ n_iter_randomized_search ,
634+ )
635+ m_tune_res = _dml_tune (
636+ d ,
637+ x ,
638+ train_inds ,
639+ self ._learner ["ml_m" ],
640+ param_grids ["ml_m" ],
641+ scoring_methods ["ml_m" ],
642+ n_folds_tune ,
643+ n_jobs_cv ,
644+ search_mode ,
645+ n_iter_randomized_search ,
646+ )
498647
499- g_d0_best_params = [xx .best_params_ for xx in g_d0_tune_res ]
500- g_d1_best_params = [xx .best_params_ for xx in g_d1_tune_res ]
501- pi_best_params = [xx .best_params_ for xx in pi_tune_res ]
502- m_best_params = [xx .best_params_ for xx in m_tune_res ]
648+ g_d0_best_params = [xx .best_params_ for xx in g_d0_tune_res ]
649+ g_d1_best_params = [xx .best_params_ for xx in g_d1_tune_res ]
650+ pi_best_params = [xx .best_params_ for xx in pi_tune_res ]
651+ m_best_params = [xx .best_params_ for xx in m_tune_res ]
503652
504- params = {"ml_g_d0" : g_d0_best_params , "ml_g_d1" : g_d1_best_params , "ml_pi" : pi_best_params , "ml_m" : m_best_params }
653+ params = {"ml_g_d0" : g_d0_best_params , "ml_g_d1" : g_d1_best_params , "ml_pi" : pi_best_params , "ml_m" : m_best_params }
505654
506- tune_res = {"g_d0_tune" : g_d0_tune_res , "g_d1_tune" : g_d1_tune_res , "pi_tune" : pi_tune_res , "m_tune" : m_tune_res }
655+ tune_res = {"g_d0_tune" : g_d0_tune_res , "g_d1_tune" : g_d1_tune_res , "pi_tune" : pi_tune_res , "m_tune" : m_tune_res }
507656
508- res = {"params" : params , "tune_res" : tune_res }
657+ res = {"params" : params , "tune_res" : tune_res }
509658
510- return res
659+ return res
511660
512661 def _sensitivity_element_est (self , preds ):
513662 pass
0 commit comments