11from doubleml import DoubleMLPLR
2- import numpy as np
32from sklearn .utils import check_X_y
43
54from .double_ml_aws_lambda import DoubleMLLambda
6- from ._helper import _attach_learner , _attach_smpls , _extract_preds
5+ from ._helper import _attach_learner , _attach_smpls
76
87
98class DoubleMLPLRServerless (DoubleMLPLR , DoubleMLLambda ):
109 def __init__ (self ,
1110 lambda_function_name ,
1211 aws_region ,
1312 obj_dml_data ,
14- ml_g ,
13+ ml_l ,
1514 ml_m ,
1615 n_folds = 5 ,
1716 n_rep = 1 ,
@@ -20,18 +19,18 @@ def __init__(self,
2019 draw_sample_splitting = True ,
2120 apply_cross_fitting = True ):
2221 DoubleMLPLR .__init__ (self ,
23- obj_dml_data ,
24- ml_g ,
25- ml_m ,
26- n_folds ,
27- n_rep ,
28- score ,
29- dml_procedure ,
30- draw_sample_splitting ,
31- apply_cross_fitting )
22+ obj_dml_data = obj_dml_data ,
23+ ml_l = ml_l ,
24+ ml_m = ml_m ,
25+ n_folds = n_folds ,
26+ n_rep = n_rep ,
27+ score = score ,
28+ dml_procedure = dml_procedure ,
29+ draw_sample_splitting = draw_sample_splitting ,
30+ apply_cross_fitting = apply_cross_fitting )
3231 DoubleMLLambda .__init__ (self ,
33- lambda_function_name ,
34- aws_region )
32+ lambda_function_name = lambda_function_name ,
33+ aws_region = aws_region )
3534
3635 def _ml_nuisance_aws_lambda (self , cv_params ):
3736 assert self ._dml_data .n_treat == 1
@@ -42,18 +41,18 @@ def _ml_nuisance_aws_lambda(self, cv_params):
4241
4342 payload = self ._dml_data .get_payload ()
4443
45- payload_ml_g = payload .copy ()
44+ payload_ml_l = payload .copy ()
4645 payload_ml_m = payload .copy ()
4746
48- _attach_learner (payload_ml_g ,
49- 'ml_g ' , self .learner ['ml_g ' ],
47+ _attach_learner (payload_ml_l ,
48+ 'ml_l ' , self .learner ['ml_l ' ],
5049 self ._dml_data .y_col , self ._dml_data .x_cols )
5150
5251 _attach_learner (payload_ml_m ,
5352 'ml_m' , self .learner ['ml_m' ],
5453 self ._dml_data .d_cols [0 ], self ._dml_data .x_cols )
5554
56- payloads = _attach_smpls ([payload_ml_g , payload_ml_m ],
55+ payloads = _attach_smpls ([payload_ml_l , payload_ml_m ],
5756 [self .smpls , self .smpls ],
5857 self .n_folds ,
5958 self .n_rep ,
@@ -70,8 +69,9 @@ def _ml_nuisance_aws_lambda(self, cv_params):
7069 # compute score elements
7170 self ._psi_a [:, i_rep , self ._i_treat ], self ._psi_b [:, i_rep , self ._i_treat ] = \
7271 self ._score_elements (y , d ,
73- preds ['ml_g ' ][:, i_rep ],
72+ preds ['ml_l ' ][:, i_rep ],
7473 preds ['ml_m' ][:, i_rep ],
74+ None ,
7575 self .smpls [i_rep ])
7676
7777 return
0 commit comments