From 11cdb3ea780b7ac35c865bae45bbb7581e4cbad8 Mon Sep 17 00:00:00 2001
From: Atharva Kelkar
Date: Tue, 2 Sep 2025 15:14:31 +0200
Subject: [PATCH 1/3] ENH: Pass fit_params through train methods

Signed-off-by: Atharva Kelkar
---
 econml/dml/_rlearner.py | 6 +++---
 econml/dml/dml.py       | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/econml/dml/_rlearner.py b/econml/dml/_rlearner.py
index 6eca79aab..a594c9890 100644
--- a/econml/dml/_rlearner.py
+++ b/econml/dml/_rlearner.py
@@ -50,12 +50,12 @@ def __init__(self, model_y: ModelSelector, model_t: ModelSelector):
         self._model_y = model_y
         self._model_t = model_t
 
-    def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
+    def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params):
         assert Z is None, "Cannot accept instrument!"
         self._model_t.train(is_selecting, folds, X, W, T, **
-                            filter_none_kwargs(sample_weight=sample_weight, groups=groups))
+                            filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params))
         self._model_y.train(is_selecting, folds, X, W, Y, **
-                            filter_none_kwargs(sample_weight=sample_weight, groups=groups))
+                            filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params))
         return self
 
     def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None,
diff --git a/econml/dml/dml.py b/econml/dml/dml.py
index 23e47f8ec..b77ab9dc0 100644
--- a/econml/dml/dml.py
+++ b/econml/dml/dml.py
@@ -97,7 +97,7 @@ def __init__(self, model: SingleModelSelector, discrete_target):
         self._model = clone(model, safe=False)
         self._discrete_target = discrete_target
 
-    def train(self, is_selecting, folds, X, W, Target, sample_weight=None, groups=None):
+    def train(self, is_selecting, folds, X, W, Target, sample_weight=None, groups=None, **fit_params):
         if self._discrete_target:
             # In this case, the Target is the one-hot-encoding of the treatment variable
             # We need to go back to the label representation of the one-hot so as to call
@@ -108,7 +108,7 @@ def train(self, is_selecting, folds, X, W, Target, sample_weight=None, groups=No
             Target = inverse_onehot(Target)
 
         self._model.train(is_selecting, folds, _combine(X, W, Target.shape[0]), Target,
-                          **filter_none_kwargs(groups=groups, sample_weight=sample_weight))
+                          **filter_none_kwargs(groups=groups, sample_weight=sample_weight, **fit_params))
         return self
 
     @property

From 4a5f48279d561166738d0f7c65bd8cfa27b82c92 Mon Sep 17 00:00:00 2001
From: Atharva Kelkar
Date: Tue, 2 Sep 2025 15:23:57 +0200
Subject: [PATCH 2/3] ENH: Pass fit_params through all train methods

Signed-off-by: Atharva Kelkar
---
 econml/dr/_drlearner.py  |  4 ++--
 econml/iv/dml/_dml.py    | 18 +++++++++---------
 econml/iv/dr/_dr.py      | 26 +++++++++++++-------------
 econml/panel/dml/_dml.py |  6 +++---
 4 files changed, 27 insertions(+), 27 deletions(-)

diff --git a/econml/dr/_drlearner.py b/econml/dr/_drlearner.py
index d7e468724..720532d91 100644
--- a/econml/dr/_drlearner.py
+++ b/econml/dr/_drlearner.py
@@ -70,7 +70,7 @@ def __init__(self,
     def _combine(self, X, W):
         return np.hstack([arr for arr in [X, W] if arr is not None])
 
-    def train(self, is_selecting, folds, Y, T, X=None, W=None, *, sample_weight=None, groups=None):
+    def train(self, is_selecting, folds, Y, T, X=None, W=None, *, sample_weight=None, groups=None, **fit_params):
         if Y.ndim != 1 and (Y.ndim != 2 or Y.shape[1] != 1):
             raise ValueError("The outcome matrix must be of shape ({0}, ) or ({0}, 1), "
                              "instead got {1}.".format(len(X), Y.shape))
@@ -80,7 +80,7 @@ def train(self, is_selecting, folds, Y, T, X=None, W=None, *, sample_weight=None raise AttributeError("Provided crossfit folds contain training splits that " + "don't contain all treatments") XW = self._combine(X, W) - filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight) + filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight, **fit_params) self._model_propensity.train(is_selecting, folds, XW, inverse_onehot(T), groups=groups, **filtered_kwargs) self._model_regression.train(is_selecting, folds, np.hstack([XW, T]), Y, groups=groups, **filtered_kwargs) diff --git a/econml/iv/dml/_dml.py b/econml/iv/dml/_dml.py index 12e690e34..58a62c116 100644 --- a/econml/iv/dml/_dml.py +++ b/econml/iv/dml/_dml.py @@ -53,16 +53,16 @@ def __init__(self, else: self._model_z_xw = model_z - def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): - self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups) - self._model_t_xw.train(is_selecting, folds, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups) + def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params): + self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups, **fit_params) + self._model_t_xw.train(is_selecting, folds, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups, **fit_params) if self._projection: # concat W and Z WZ = _combine(W, Z, Y.shape[0]) self._model_t_xwz.train(is_selecting, folds, X=X, W=WZ, Target=T, - sample_weight=sample_weight, groups=groups) + sample_weight=sample_weight, groups=groups, **fit_params) else: - self._model_z_xw.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups) + self._model_z_xw.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups, **fit_params) return self def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): @@ -720,15 +720,15 @@ def __init__(self, model_y_xw: ModelSelector, model_t_xw: ModelSelector, model_t self._model_t_xw = model_t_xw self._model_t_xwz = model_t_xwz - def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): + def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params): self._model_y_xw.train(is_selecting, folds, X, W, Y, ** - filter_none_kwargs(sample_weight=sample_weight, groups=groups)) + filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params)) self._model_t_xw.train(is_selecting, folds, X, W, T, ** - filter_none_kwargs(sample_weight=sample_weight, groups=groups)) + filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params)) # concat W and Z WZ = _combine(W, Z, Y.shape[0]) self._model_t_xwz.train(is_selecting, folds, X, WZ, T, - **filter_none_kwargs(sample_weight=sample_weight, groups=groups)) + **filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params)) return self def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): diff --git a/econml/iv/dr/_dr.py b/econml/iv/dr/_dr.py index 6de0591b0..5a879004e 100644 --- a/econml/iv/dr/_dr.py +++ b/econml/iv/dr/_dr.py @@ -56,20 +56,20 @@ def __init__(self, *, prel_model_effect, model_y_xw, model_t_xw, model_z, else: self._model_z_xw = model_z - def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, 
sample_weight=None, groups=None): + def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params): # T and Z only allow single continuous or binary, keep the shape of (n,) for continuous and (n,1) for binary T = T.ravel() if not self._discrete_treatment else T Z = Z.ravel() if not self._discrete_instrument else Z - self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups) - self._model_t_xw.train(is_selecting, folds, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups) + self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups, **fit_params) + self._model_t_xw.train(is_selecting, folds, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups, **fit_params) if self._projection: WZ = _combine(W, Z, Y.shape[0]) self._model_t_xwz.train(is_selecting, folds, X=X, W=WZ, Target=T, - sample_weight=sample_weight, groups=groups) + sample_weight=sample_weight, groups=groups, **fit_params) else: - self._model_z_xw.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups) + self._model_z_xw.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups, **fit_params) # TODO: prel_model_effect could allow sample_var and freq_weight? if self._discrete_instrument: @@ -77,7 +77,7 @@ def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight if self._discrete_treatment: T = inverse_onehot(T) self._prel_model_effect.fit(Y, T, Z=Z, X=X, - W=W, sample_weight=sample_weight, groups=groups) + W=W, sample_weight=sample_weight, groups=groups, **fit_params) return self def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): @@ -215,11 +215,11 @@ def _get_target(self, T_res, Z_res, T, Z): def train(self, is_selecting, folds, prel_theta, Y_res, T_res, Z_res, - Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): + Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params): # T and Z only allow single continuous or binary, keep the shape of (n,) for continuous and (n,1) for binary target = self._get_target(T_res, Z_res, T, Z) self._model_tz_xw.train(is_selecting, folds, X=X, W=W, Target=target, - sample_weight=sample_weight, groups=groups) + sample_weight=sample_weight, groups=groups, **fit_params) return self @@ -2386,16 +2386,16 @@ def __init__(self, self._dummy_z = dummy_z self._prel_model_effect = prel_model_effect - def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): - self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups) + def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params): + self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups, **fit_params) # concat W and Z WZ = _combine(W, Z, Y.shape[0]) - self._model_t_xwz.train(is_selecting, folds, X=X, W=WZ, Target=T, sample_weight=sample_weight, groups=groups) - self._dummy_z.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups) + self._model_t_xwz.train(is_selecting, folds, X=X, W=WZ, Target=T, sample_weight=sample_weight, groups=groups, **fit_params) + self._dummy_z.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups, **fit_params) # we need to undo the one-hot encoding for calling 
effect, # since it expects raw values self._prel_model_effect.fit(Y, inverse_onehot(T), Z=inverse_onehot(Z), X=X, W=W, - sample_weight=sample_weight, groups=groups) + sample_weight=sample_weight, groups=groups, **fit_params) return self def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): diff --git a/econml/panel/dml/_dml.py b/econml/panel/dml/_dml.py index cf336eb61..653c06be3 100644 --- a/econml/panel/dml/_dml.py +++ b/econml/panel/dml/_dml.py @@ -40,7 +40,7 @@ def __init__(self, model_y, model_t, n_periods): self._model_t = model_t self.n_periods = n_periods - def train(self, is_selecting, folds, Y, T, X=None, W=None, sample_weight=None, groups=None): + def train(self, is_selecting, folds, Y, T, X=None, W=None, sample_weight=None, groups=None, **fit_params): """Fit a series of nuisance models for each period or period pairs.""" assert Y.shape[0] % self.n_periods == 0, \ "Length of training data should be an integer multiple of time periods." @@ -87,13 +87,13 @@ def _translate_inds(t, inds): self._index_or_None(X, period_filters[t]), self._index_or_None( W, period_filters[t]), - Y[period_filters[self.n_periods - 1]]) + Y[period_filters[self.n_periods - 1]], **fit_params) for j in np.arange(t, self.n_periods): self._model_t_trained[j][t].train( is_selecting, translated_folds, self._index_or_None(X, period_filters[t]), self._index_or_None(W, period_filters[t]), - T[period_filters[j]]) + T[period_filters[j]], **fit_params) return self def predict(self, Y, T, X=None, W=None, sample_weight=None, groups=None): From be106b5121ed85d6fffbbea481c0b1dc9b4af1c2 Mon Sep 17 00:00:00 2001 From: Atharva Kelkar Date: Tue, 2 Sep 2025 15:36:00 +0200 Subject: [PATCH 3/3] Linted Signed-off-by: Atharva Kelkar --- econml/iv/dml/_dml.py | 487 ++++++++----- econml/iv/dr/_dr.py | 1608 ++++++++++++++++++++++++----------------- 2 files changed, 1238 insertions(+), 857 deletions(-) diff --git a/econml/iv/dml/_dml.py b/econml/iv/dml/_dml.py index 58a62c116..b09af039d 100644 --- a/econml/iv/dml/_dml.py +++ b/econml/iv/dml/_dml.py @@ -22,10 +22,18 @@ from ..._ortho_learner import _OrthoLearner from ..._cate_estimator import LinearModelFinalCateEstimatorMixin, LinearCateEstimator from ...sklearn_extensions.linear_model import StatsModels2SLS, StatsModelsLinearRegression -from ...sklearn_extensions.model_selection import (ModelSelector, SingleModelSelector) -from ...utilities import (get_feature_names_or_default, filter_none_kwargs, add_intercept, - cross_product, broadcast_unit_treatments, reshape_treatmentwise_effects, shape, - parse_final_model_params, Summary) +from ...sklearn_extensions.model_selection import ModelSelector, SingleModelSelector +from ...utilities import ( + get_feature_names_or_default, + filter_none_kwargs, + add_intercept, + cross_product, + broadcast_unit_treatments, + reshape_treatmentwise_effects, + shape, + parse_final_model_params, + Summary, +) from ...dml.dml import _make_first_stage_selector, _FinalWrapper from ...dml._rlearner import _ModelFinal from ..._shap import _shap_explain_joint_linear_model_cate, _shap_explain_model_cate @@ -39,12 +47,9 @@ def _combine(W, Z, n_samples): class _OrthoIVNuisanceSelector(ModelSelector): - - def __init__(self, - model_y_xw: SingleModelSelector, - model_t_xw: SingleModelSelector, - model_z: SingleModelSelector, - projection): + def __init__( + self, model_y_xw: SingleModelSelector, model_t_xw: SingleModelSelector, model_z: SingleModelSelector, projection + ): self._model_y_xw = model_y_xw self._model_t_xw = model_t_xw 
self._projection = projection @@ -54,15 +59,22 @@ def __init__(self, self._model_z_xw = model_z def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params): - self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups, **fit_params) - self._model_t_xw.train(is_selecting, folds, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups, **fit_params) + self._model_y_xw.train( + is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups, **fit_params + ) + self._model_t_xw.train( + is_selecting, folds, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups, **fit_params + ) if self._projection: # concat W and Z WZ = _combine(W, Z, Y.shape[0]) - self._model_t_xwz.train(is_selecting, folds, X=X, W=WZ, Target=T, - sample_weight=sample_weight, groups=groups, **fit_params) + self._model_t_xwz.train( + is_selecting, folds, X=X, W=WZ, Target=T, sample_weight=sample_weight, groups=groups, **fit_params + ) else: - self._model_z_xw.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups, **fit_params) + self._model_z_xw.train( + is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups, **fit_params + ) return self def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): @@ -124,11 +136,11 @@ def __init__(self, model_final, featurizer, fit_cate_intercept): self._fit_cate_intercept = fit_cate_intercept if self._fit_cate_intercept: - add_intercept_trans = FunctionTransformer(add_intercept, - validate=True) + add_intercept_trans = FunctionTransformer(add_intercept, validate=True) if featurizer: - self._featurizer = Pipeline([('featurize', self._original_featurizer), - ('add_intercept', add_intercept_trans)]) + self._featurizer = Pipeline( + [('featurize', self._original_featurizer), ('add_intercept', add_intercept_trans)] + ) else: self._featurizer = add_intercept_trans else: @@ -146,8 +158,19 @@ def _combine(self, X, T, fitting=True): F = np.ones((T.shape[0], 1)) return cross_product(F, T) - def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None, - sample_weight=None, freq_weight=None, sample_var=None, groups=None): + def fit( + self, + Y, + T, + X=None, + W=None, + Z=None, + nuisances=None, + sample_weight=None, + freq_weight=None, + sample_var=None, + groups=None, + ): Y_res, T_res, Z_res = nuisances # Track training dimensions to see if Y or T is a vector instead of a 2-dimensional array @@ -156,20 +179,19 @@ def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None, XT_res = self._combine(X, T_res) XZ_res = self._combine(X, Z_res) - filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight, - freq_weight=freq_weight, sample_var=sample_var) + filtered_kwargs = filter_none_kwargs( + sample_weight=sample_weight, freq_weight=freq_weight, sample_var=sample_var + ) self._model_final.fit(XZ_res, XT_res, Y_res, **filtered_kwargs) return self def predict(self, X=None): - X2, T = broadcast_unit_treatments(X if X is not None else np.empty((1, 0)), - self._d_t[0] if self._d_t else 1) + X2, T = broadcast_unit_treatments(X if X is not None else np.empty((1, 0)), self._d_t[0] if self._d_t else 1) XT = self._combine(None if X is None else X2, T, fitting=False) prediction = self._model_final.predict(XT) - return reshape_treatmentwise_effects(prediction, - self._d_t, self._d_y) + return reshape_treatmentwise_effects(prediction, self._d_t, self._d_y) def score(self, Y, T, X=None, W=None, 
Z=None, nuisances=None, sample_weight=None, groups=None): Y_res, T_res, Z_res = nuisances @@ -180,8 +202,9 @@ def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None effects = self.predict(X).reshape((-1, Y_res.shape[1], T_res.shape[1])) Y_res_pred = np.einsum('ijk,ik->ij', effects, T_res).reshape(Y_res.shape) if sample_weight is not None: - return np.linalg.norm(np.average(cross_product(Z_res, Y_res - Y_res_pred), weights=sample_weight, axis=0), - ord=2) + return np.linalg.norm( + np.average(cross_product(Z_res, Y_res - Y_res_pred), weights=sample_weight, axis=0), ord=2 + ) else: return np.linalg.norm(np.mean(cross_product(Z_res, Y_res - Y_res_pred), axis=0), ord=2) @@ -359,24 +382,27 @@ def true_heterogeneity_function(X): (-1.27031..., 0.99694...) """ - def __init__(self, *, - model_y_xw="auto", - model_t_xw="auto", - model_t_xwz="auto", - model_z_xw="auto", - projection=False, - featurizer=None, - fit_cate_intercept=True, - discrete_outcome=False, - discrete_treatment=False, - treatment_featurizer=None, - discrete_instrument=False, - categories='auto', - cv=2, - mc_iters=None, - mc_agg='mean', - random_state=None, - allow_missing=False): + def __init__( + self, + *, + model_y_xw="auto", + model_t_xw="auto", + model_t_xwz="auto", + model_z_xw="auto", + projection=False, + featurizer=None, + fit_cate_intercept=True, + discrete_outcome=False, + discrete_treatment=False, + treatment_featurizer=None, + discrete_instrument=False, + categories='auto', + cv=2, + mc_iters=None, + mc_agg='mean', + random_state=None, + allow_missing=False, + ): self.model_y_xw = clone(model_y_xw, safe=False) self.model_t_xw = clone(model_t_xw, safe=False) self.model_t_xwz = clone(model_t_xwz, safe=False) @@ -385,16 +411,18 @@ def __init__(self, *, self.featurizer = clone(featurizer, safe=False) self.fit_cate_intercept = fit_cate_intercept - super().__init__(discrete_outcome=discrete_outcome, - discrete_instrument=discrete_instrument, - discrete_treatment=discrete_treatment, - treatment_featurizer=treatment_featurizer, - categories=categories, - cv=cv, - mc_iters=mc_iters, - mc_agg=mc_agg, - random_state=random_state, - allow_missing=allow_missing) + super().__init__( + discrete_outcome=discrete_outcome, + discrete_instrument=discrete_instrument, + discrete_treatment=discrete_treatment, + treatment_featurizer=treatment_featurizer, + categories=categories, + cv=cv, + mc_iters=mc_iters, + mc_agg=mc_agg, + random_state=random_state, + allow_missing=allow_missing, + ) def _gen_allowed_missing_vars(self): return ['W'] if self.allow_missing else [] @@ -409,32 +437,44 @@ def _gen_ortho_learner_model_final(self): return _OrthoIVModelFinal(self._gen_model_final(), self._gen_featurizer(), self.fit_cate_intercept) def _gen_ortho_learner_model_nuisance(self): - model_y = _make_first_stage_selector(self.model_y_xw, - is_discrete=self.discrete_outcome, - random_state=self.random_state) + model_y = _make_first_stage_selector( + self.model_y_xw, is_discrete=self.discrete_outcome, random_state=self.random_state + ) - model_t = _make_first_stage_selector(self.model_t_xw, - is_discrete=self.discrete_treatment, - random_state=self.random_state) + model_t = _make_first_stage_selector( + self.model_t_xw, is_discrete=self.discrete_treatment, random_state=self.random_state + ) if self.projection: # train E[T|X,W,Z] - model_z = _make_first_stage_selector(self.model_t_xwz, - is_discrete=self.discrete_treatment, - random_state=self.random_state) + model_z = _make_first_stage_selector( + self.model_t_xwz, 
is_discrete=self.discrete_treatment, random_state=self.random_state + ) else: # train E[Z|X,W] # note: discrete_instrument rather than discrete_treatment in call to _make_first_stage_selector - model_z = _make_first_stage_selector(self.model_z_xw, - is_discrete=self.discrete_instrument, - random_state=self.random_state) - - return _OrthoIVNuisanceSelector(model_y, model_t, model_z, - self.projection) - - def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None, - cache_values=False, inference="auto"): + model_z = _make_first_stage_selector( + self.model_z_xw, is_discrete=self.discrete_instrument, random_state=self.random_state + ) + + return _OrthoIVNuisanceSelector(model_y, model_t, model_z, self.projection) + + def fit( + self, + Y, + T, + *, + Z, + X=None, + W=None, + sample_weight=None, + freq_weight=None, + sample_var=None, + groups=None, + cache_values=False, + inference="auto", + ): """ Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`. @@ -475,18 +515,33 @@ def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None, self: OrthoIV instance """ if self.projection: - assert self.model_z_xw == "auto", ("In the case of projection=True, model_z_xw will not be fitted, " - "please leave it when initializing the estimator!") + assert self.model_z_xw == "auto", ( + "In the case of projection=True, model_z_xw will not be fitted, " + "please leave it when initializing the estimator!" + ) else: - assert self.model_t_xwz == "auto", ("In the case of projection=False, model_t_xwz will not be fitted, " - "please leave it when initializing the estimator!") + assert self.model_t_xwz == "auto", ( + "In the case of projection=False, model_t_xwz will not be fitted, " + "please leave it when initializing the estimator!" + ) # Replacing fit from _OrthoLearner, to reorder arguments and improve the docstring - return super().fit(Y, T, X=X, W=W, Z=Z, - sample_weight=sample_weight, freq_weight=freq_weight, sample_var=sample_var, groups=groups, - cache_values=cache_values, inference=inference) + return super().fit( + Y, + T, + X=X, + W=W, + Z=Z, + sample_weight=sample_weight, + freq_weight=freq_weight, + sample_var=sample_var, + groups=groups, + cache_values=cache_values, + inference=inference, + ) def refit_final(self, *, inference='auto'): return super().refit_final(inference=inference) + refit_final.__doc__ = _OrthoLearner.refit_final.__doc__ def score(self, Y, T, Z, X=None, W=None, sample_weight=None): @@ -701,8 +756,9 @@ def residuals_(self): if not hasattr(self, '_cached_values'): raise AttributeError("Estimator is not fitted yet!") if self._cached_values is None: - raise AttributeError("`fit` was called with `cache_values=False`. " - "Set to `True` to enable residual storage.") + raise AttributeError( + "`fit` was called with `cache_values=False`. Set to `True` to enable residual storage." 
+ ) Y_res, T_res, Z_res = self._cached_values.nuisances return Y_res, T_res, Z_res, self._cached_values.X, self._cached_values.W, self._cached_values.Z @@ -721,14 +777,22 @@ def __init__(self, model_y_xw: ModelSelector, model_t_xw: ModelSelector, model_t self._model_t_xwz = model_t_xwz def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params): - self._model_y_xw.train(is_selecting, folds, X, W, Y, ** - filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params)) - self._model_t_xw.train(is_selecting, folds, X, W, T, ** - filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params)) + self._model_y_xw.train( + is_selecting, folds, X, W, Y, **filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params) + ) + self._model_t_xw.train( + is_selecting, folds, X, W, T, **filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params) + ) # concat W and Z WZ = _combine(W, Z, Y.shape[0]) - self._model_t_xwz.train(is_selecting, folds, X, WZ, T, - **filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params)) + self._model_t_xwz.train( + is_selecting, + folds, + X, + WZ, + T, + **filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params), + ) return self def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): @@ -785,8 +849,21 @@ class _BaseDMLIV(_OrthoLearner): # A helper class that access all the internal fitted objects of a DMLIV Cate Estimator. # Used by both Parametric and Non Parametric DMLIV. # override only so that we can enforce Z to be required - def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None, - cache_values=False, inference=None): + def fit( + self, + Y, + T, + *, + Z, + X=None, + W=None, + sample_weight=None, + freq_weight=None, + sample_var=None, + groups=None, + cache_values=False, + inference=None, + ): """ Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`. @@ -825,9 +902,19 @@ def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None, ------- self """ - return super().fit(Y, T, X=X, W=W, Z=Z, - sample_weight=sample_weight, freq_weight=freq_weight, sample_var=sample_var, groups=groups, - cache_values=cache_values, inference=inference) + return super().fit( + Y, + T, + X=X, + W=W, + Z=Z, + sample_weight=sample_weight, + freq_weight=freq_weight, + sample_var=sample_var, + groups=groups, + cache_values=cache_values, + inference=inference, + ) def score(self, Y, T, Z, X=None, W=None, sample_weight=None): """ @@ -960,8 +1047,9 @@ def residuals_(self): if not hasattr(self, '_cached_values'): raise AttributeError("Estimator is not fitted yet!") if self._cached_values is None: - raise AttributeError("`fit` was called with `cache_values=False`. " - "Set to `True` to enable residual storage.") + raise AttributeError( + "`fit` was called with `cache_values=False`. Set to `True` to enable residual storage." 
+ ) Y_res, T_res = self._cached_values.nuisances return Y_res, T_res, self._cached_values.X, self._cached_values.W, self._cached_values.Z @@ -1151,39 +1239,44 @@ def true_heterogeneity_function(X): """ - def __init__(self, *, - model_y_xw="auto", - model_t_xw="auto", - model_t_xwz="auto", - model_final=StatsModelsLinearRegression(fit_intercept=False), - featurizer=None, - fit_cate_intercept=True, - discrete_outcome=False, - discrete_treatment=False, - treatment_featurizer=None, - discrete_instrument=False, - categories='auto', - cv=2, - mc_iters=None, - mc_agg='mean', - random_state=None, - allow_missing=False): + def __init__( + self, + *, + model_y_xw="auto", + model_t_xw="auto", + model_t_xwz="auto", + model_final=StatsModelsLinearRegression(fit_intercept=False), + featurizer=None, + fit_cate_intercept=True, + discrete_outcome=False, + discrete_treatment=False, + treatment_featurizer=None, + discrete_instrument=False, + categories='auto', + cv=2, + mc_iters=None, + mc_agg='mean', + random_state=None, + allow_missing=False, + ): self.model_y_xw = clone(model_y_xw, safe=False) self.model_t_xw = clone(model_t_xw, safe=False) self.model_t_xwz = clone(model_t_xwz, safe=False) self.model_final = clone(model_final, safe=False) self.featurizer = clone(featurizer, safe=False) self.fit_cate_intercept = fit_cate_intercept - super().__init__(discrete_outcome=discrete_outcome, - discrete_treatment=discrete_treatment, - treatment_featurizer=treatment_featurizer, - discrete_instrument=discrete_instrument, - categories=categories, - cv=cv, - mc_iters=mc_iters, - mc_agg=mc_agg, - random_state=random_state, - allow_missing=allow_missing) + super().__init__( + discrete_outcome=discrete_outcome, + discrete_treatment=discrete_treatment, + treatment_featurizer=treatment_featurizer, + discrete_instrument=discrete_instrument, + categories=categories, + cv=cv, + mc_iters=mc_iters, + mc_agg=mc_agg, + random_state=random_state, + allow_missing=allow_missing, + ) def _gen_featurizer(self): return clone(self.featurizer, safe=False) @@ -1204,10 +1297,9 @@ def _gen_ortho_learner_model_nuisance(self): return _BaseDMLIVNuisanceSelector(self._gen_model_y_xw(), self._gen_model_t_xw(), self._gen_model_t_xwz()) def _gen_ortho_learner_model_final(self): - return _BaseDMLIVModelFinal(_FinalWrapper(self._gen_model_final(), - self.fit_cate_intercept, - self._gen_featurizer(), - False)) + return _BaseDMLIVModelFinal( + _FinalWrapper(self._gen_model_final(), self.fit_cate_intercept, self._gen_featurizer(), False) + ) @property def bias_part_of_coef(self): @@ -1221,12 +1313,18 @@ def shap_values(self, X, *, feature_names=None, treatment_names=None, output_nam if hasattr(self, "featurizer_") and self.featurizer_ is not None: X = self.featurizer_.transform(X) feature_names = self.cate_feature_names(feature_names) - return _shap_explain_joint_linear_model_cate(self.model_final_, X, self._d_t, self._d_y, - self.bias_part_of_coef, - feature_names=feature_names, treatment_names=treatment_names, - output_names=output_names, - input_names=self._input_names, - background_samples=background_samples) + return _shap_explain_joint_linear_model_cate( + self.model_final_, + X, + self._d_t, + self._d_y, + self.bias_part_of_coef, + feature_names=feature_names, + treatment_names=treatment_names, + output_names=output_names, + input_names=self._input_names, + background_samples=background_samples, + ) shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__ @@ -1245,9 +1343,15 @@ def coef_(self): a vector and not a 2D array. 
For binary treatment the n_t dimension is also omitted. """ - return parse_final_model_params(self.model_final_.coef_, self.model_final_.intercept_, - self._d_y, self._d_t, self._d_t_in, self.bias_part_of_coef, - self.fit_cate_intercept_)[0] + return parse_final_model_params( + self.model_final_.coef_, + self.model_final_.intercept_, + self._d_y, + self._d_t, + self._d_t_in, + self.bias_part_of_coef, + self.fit_cate_intercept_, + )[0] @property def intercept_(self): @@ -1264,9 +1368,15 @@ def intercept_(self): """ if not self.fit_cate_intercept_: raise AttributeError("No intercept was fitted!") - return parse_final_model_params(self.model_final_.coef_, self.model_final_.intercept_, - self._d_y, self._d_t, self._d_t_in, self.bias_part_of_coef, - self.fit_cate_intercept_)[1] + return parse_final_model_params( + self.model_final_.coef_, + self.model_final_.intercept_, + self._d_y, + self._d_t, + self._d_t_in, + self.bias_part_of_coef, + self.fit_cate_intercept_, + )[1] def summary(self, decimals=3, feature_names=None, treatment_names=None, output_names=None): """ @@ -1302,11 +1412,11 @@ def summary(self, decimals=3, feature_names=None, treatment_names=None, output_n extra_txt.append("$Y = \\Theta(X)\\cdot \\psi(T) + g(X, W) + \\epsilon$") extra_txt.append("where $\\psi(T)$ is the output of the `treatment_featurizer") extra_txt.append( - "and for every outcome $i$ and featurized treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:") + "and for every outcome $i$ and featurized treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:" + ) else: extra_txt.append("$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$") - extra_txt.append( - "where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:") + extra_txt.append("where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:") if self.featurizer: extra_txt.append("$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$") @@ -1314,9 +1424,11 @@ def summary(self, decimals=3, feature_names=None, treatment_names=None, output_n else: extra_txt.append("$\\Theta_{ij}(X) = X' coef_{ij} + cate\\_intercept_{ij}$") - extra_txt.append("Coefficient Results table portrays the $coef_{ij}$ parameter vector for " - "each outcome $i$ and treatment $j$. " - "Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.") + extra_txt.append( + "Coefficient Results table portrays the $coef_{ij}$ parameter vector for " + "each outcome $i$ and treatment $j$. " + "Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter." 
+ ) smry.add_extra_txt(extra_txt) d_t = self._d_t[0] if self._d_t else 1 @@ -1540,37 +1652,42 @@ def true_heterogeneity_function(X): """ - def __init__(self, *, - model_y_xw="auto", - model_t_xw="auto", - model_t_xwz="auto", - model_final, - discrete_outcome=False, - discrete_treatment=False, - treatment_featurizer=None, - discrete_instrument=False, - featurizer=None, - categories='auto', - cv=2, - mc_iters=None, - mc_agg='mean', - random_state=None, - allow_missing=False): + def __init__( + self, + *, + model_y_xw="auto", + model_t_xw="auto", + model_t_xwz="auto", + model_final, + discrete_outcome=False, + discrete_treatment=False, + treatment_featurizer=None, + discrete_instrument=False, + featurizer=None, + categories='auto', + cv=2, + mc_iters=None, + mc_agg='mean', + random_state=None, + allow_missing=False, + ): self.model_y_xw = clone(model_y_xw, safe=False) self.model_t_xw = clone(model_t_xw, safe=False) self.model_t_xwz = clone(model_t_xwz, safe=False) self.model_final = clone(model_final, safe=False) self.featurizer = clone(featurizer, safe=False) - super().__init__(discrete_outcome=discrete_outcome, - discrete_treatment=discrete_treatment, - discrete_instrument=discrete_instrument, - treatment_featurizer=treatment_featurizer, - categories=categories, - cv=cv, - mc_iters=mc_iters, - mc_agg=mc_agg, - random_state=random_state, - allow_missing=allow_missing) + super().__init__( + discrete_outcome=discrete_outcome, + discrete_treatment=discrete_treatment, + discrete_instrument=discrete_instrument, + treatment_featurizer=treatment_featurizer, + categories=categories, + cv=cv, + mc_iters=mc_iters, + mc_agg=mc_agg, + random_state=random_state, + allow_missing=allow_missing, + ) def _gen_featurizer(self): return clone(self.featurizer, safe=False) @@ -1591,17 +1708,21 @@ def _gen_ortho_learner_model_nuisance(self): return _BaseDMLIVNuisanceSelector(self._gen_model_y_xw(), self._gen_model_t_xw(), self._gen_model_t_xwz()) def _gen_ortho_learner_model_final(self): - return _BaseDMLIVModelFinal(_FinalWrapper(self._gen_model_final(), - False, - self._gen_featurizer(), - True)) + return _BaseDMLIVModelFinal(_FinalWrapper(self._gen_model_final(), False, self._gen_featurizer(), True)) def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100): - return _shap_explain_model_cate(self.const_marginal_effect, self.model_cate, X, self._d_t, self._d_y, - featurizer=self.featurizer_, - feature_names=feature_names, - treatment_names=treatment_names, - output_names=output_names, - input_names=self._input_names, - background_samples=background_samples) + return _shap_explain_model_cate( + self.const_marginal_effect, + self.model_cate, + X, + self._d_t, + self._d_y, + featurizer=self.featurizer_, + feature_names=feature_names, + treatment_names=treatment_names, + output_names=output_names, + input_names=self._input_names, + background_samples=background_samples, + ) + shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__ diff --git a/econml/iv/dr/_dr.py b/econml/iv/dr/_dr.py index 5a879004e..1c551c853 100644 --- a/econml/iv/dr/_dr.py +++ b/econml/iv/dr/_dr.py @@ -21,13 +21,23 @@ from ..._ortho_learner import _OrthoLearner -from ..._cate_estimator import (StatsModelsCateEstimatorMixin, DebiasedLassoCateEstimatorMixin, - ForestModelFinalCateEstimatorMixin, GenericSingleTreatmentModelFinalInference, - LinearCateEstimator) +from ..._cate_estimator import ( + StatsModelsCateEstimatorMixin, + DebiasedLassoCateEstimatorMixin, + 
ForestModelFinalCateEstimatorMixin, + GenericSingleTreatmentModelFinalInference, + LinearCateEstimator, +) from ...sklearn_extensions.linear_model import StatsModelsLinearRegression, DebiasedLasso from ...sklearn_extensions.model_selection import ModelSelector, SingleModelSelector -from ...utilities import (add_intercept, filter_none_kwargs, - inverse_onehot, get_feature_names_or_default, check_high_dimensional, check_input_arrays) +from ...utilities import ( + add_intercept, + filter_none_kwargs, + inverse_onehot, + get_feature_names_or_default, + check_high_dimensional, + check_input_arrays, +) from ...grf import RegressionForest from ...dml.dml import _make_first_stage_selector from ...iv.dml import NonParamDMLIV @@ -42,9 +52,9 @@ def _combine(W, Z, n_samples): class _BaseDRIVNuisanceSelector(ModelSelector): - def __init__(self, *, prel_model_effect, model_y_xw, model_t_xw, model_z, - projection, - discrete_treatment, discrete_instrument): + def __init__( + self, *, prel_model_effect, model_y_xw, model_t_xw, model_z, projection, discrete_treatment, discrete_instrument + ): self._prel_model_effect = prel_model_effect self._model_y_xw = model_y_xw self._model_t_xw = model_t_xw @@ -61,23 +71,29 @@ def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight T = T.ravel() if not self._discrete_treatment else T Z = Z.ravel() if not self._discrete_instrument else Z - self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups, **fit_params) - self._model_t_xw.train(is_selecting, folds, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups, **fit_params) + self._model_y_xw.train( + is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups, **fit_params + ) + self._model_t_xw.train( + is_selecting, folds, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups, **fit_params + ) if self._projection: WZ = _combine(W, Z, Y.shape[0]) - self._model_t_xwz.train(is_selecting, folds, X=X, W=WZ, Target=T, - sample_weight=sample_weight, groups=groups, **fit_params) + self._model_t_xwz.train( + is_selecting, folds, X=X, W=WZ, Target=T, sample_weight=sample_weight, groups=groups, **fit_params + ) else: - self._model_z_xw.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups, **fit_params) + self._model_z_xw.train( + is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups, **fit_params + ) # TODO: prel_model_effect could allow sample_var and freq_weight? 
if self._discrete_instrument: Z = inverse_onehot(Z) if self._discrete_treatment: T = inverse_onehot(T) - self._prel_model_effect.fit(Y, T, Z=Z, X=X, - W=W, sample_weight=sample_weight, groups=groups, **fit_params) + self._prel_model_effect.fit(Y, T, Z=Z, X=X, W=W, sample_weight=sample_weight, groups=groups, **fit_params) return self def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): @@ -99,8 +115,7 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): # since it expects raw values raw_T = inverse_onehot(T) if self._discrete_treatment else T raw_Z = inverse_onehot(Z) if self._discrete_instrument else Z - effect_score = self._prel_model_effect.score(Y, raw_T, - Z=raw_Z, X=X, W=W, sample_weight=sample_weight) + effect_score = self._prel_model_effect.score(Y, raw_T, Z=raw_Z, X=X, W=W, sample_weight=sample_weight) else: effect_score = None @@ -133,8 +148,12 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None) T_pred = np.tile(T_pred.reshape(1, -1), (Y.shape[0], 1)) # for convenience, reshape Z,T to a vector since they are either binary or single dimensional continuous - T = T.reshape(T.shape[0],) - Z = Z.reshape(Z.shape[0],) + T = T.reshape( + T.shape[0], + ) + Z = Z.reshape( + Z.shape[0], + ) # reshape the predictions Y_pred = Y_pred.reshape(Y.shape) T_pred = T_pred.reshape(T.shape) @@ -163,9 +182,7 @@ def predict(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None) class _BaseDRIVNuisanceCovarianceSelector(ModelSelector): - def __init__(self, *, model_tz_xw, - projection, fit_cov_directly, - discrete_treatment, discrete_instrument): + def __init__(self, *, model_tz_xw, projection, fit_cov_directly, discrete_treatment, discrete_instrument): self._model_tz_xw = model_tz_xw self._projection = projection self._fit_cov_directly = fit_cov_directly @@ -189,7 +206,9 @@ def _get_target(self, T_res, Z_res, T, Z): # return shape (n,) T_pred = T - T_res.reshape(T.shape) T_proj = T_pred + Z_res.reshape(T.shape) - target = (T * T_proj).reshape(T.shape[0],) + target = (T * T_proj).reshape( + T.shape[0], + ) else: if self._fit_cov_directly: # we will fit on the covariance (T_res*Z_res) directly @@ -213,13 +232,28 @@ def _get_target(self, T_res, Z_res, T, Z): target = T * Z return target - def train(self, is_selecting, folds, - prel_theta, Y_res, T_res, Z_res, - Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params): + def train( + self, + is_selecting, + folds, + prel_theta, + Y_res, + T_res, + Z_res, + Y, + T, + X=None, + W=None, + Z=None, + sample_weight=None, + groups=None, + **fit_params, + ): # T and Z only allow single continuous or binary, keep the shape of (n,) for continuous and (n,1) for binary target = self._get_target(T_res, Z_res, T, Z) - self._model_tz_xw.train(is_selecting, folds, X=X, W=W, Target=target, - sample_weight=sample_weight, groups=groups, **fit_params) + self._model_tz_xw.train( + is_selecting, folds, X=X, W=W, Target=target, sample_weight=sample_weight, groups=groups, **fit_params + ) return self @@ -240,8 +274,12 @@ def predict(self, prel_theta, Y_res, T_res, Z_res, Y, T, X=None, W=None, Z=None, TZ_pred = np.tile(TZ_pred.reshape(1, -1), (Y.shape[0], 1)) # for convenience, reshape Z,T to a vector since they are either binary or single dimensional continuous - T = T.reshape(T.shape[0],) - Z = Z.reshape(Z.shape[0],) + T = T.reshape( + T.shape[0], + ) + Z = Z.reshape( + Z.shape[0], + ) # reshape the predictions TZ_pred = TZ_pred.reshape(T.shape) @@ -277,11 
+315,11 @@ def __init__(self, model_final, featurizer, fit_cate_intercept, cov_clip, opt_re self._opt_reweighted = opt_reweighted if self._fit_cate_intercept: - add_intercept_trans = FunctionTransformer(add_intercept, - validate=True) + add_intercept_trans = FunctionTransformer(add_intercept, validate=True) if featurizer: - self._featurizer = Pipeline([('featurize', self._original_featurizer), - ('add_intercept', add_intercept_trans)]) + self._featurizer = Pipeline( + [('featurize', self._original_featurizer), ('add_intercept', add_intercept_trans)] + ) else: self._featurizer = add_intercept_trans else: @@ -298,8 +336,7 @@ def _effect_estimate(self, nuisances): # to the model-based preliminary estimate and do not add the correction term. cov_sign = np.sign(cov) cov_sign[cov_sign == 0] = 1 - clipped_cov = cov_sign * np.clip(np.abs(cov), - self._cov_clip, np.inf) + clipped_cov = cov_sign * np.clip(np.abs(cov), self._cov_clip, np.inf) return prel_theta + (res_y - prel_theta * res_t) * res_z / clipped_cov, clipped_cov, res_z def _transform_X(self, X, n=1, fitting=True): @@ -314,20 +351,34 @@ def _transform_X(self, X, n=1, fitting=True): F = np.ones((n, 1)) return F - def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None, - sample_weight=None, freq_weight=None, sample_var=None, groups=None): + def fit( + self, + Y, + T, + X=None, + W=None, + Z=None, + nuisances=None, + sample_weight=None, + freq_weight=None, + sample_var=None, + groups=None, + ): self.d_y = Y.shape[1:] self.d_t = T.shape[1:] theta_dr, clipped_cov, res_z = self._effect_estimate(nuisances) X = self._transform_X(X, n=theta_dr.shape[0]) if self._opt_reweighted and (sample_weight is not None): - sample_weight = sample_weight * clipped_cov.ravel()**2 + sample_weight = sample_weight * clipped_cov.ravel() ** 2 elif self._opt_reweighted: - sample_weight = clipped_cov.ravel()**2 + sample_weight = clipped_cov.ravel() ** 2 target_var = sample_var * (res_z**2 / clipped_cov**2) if sample_var is not None else None - self._model_final.fit(X, theta_dr, **filter_none_kwargs(sample_weight=sample_weight, - freq_weight=freq_weight, sample_var=target_var)) + self._model_final.fit( + X, + theta_dr, + **filter_none_kwargs(sample_weight=sample_weight, freq_weight=freq_weight, sample_var=target_var), + ) return self def predict(self, X=None): @@ -340,52 +391,56 @@ def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None X = self._transform_X(X, fitting=False) if self._opt_reweighted and (sample_weight is not None): - sample_weight = sample_weight * clipped_cov.ravel()**2 + sample_weight = sample_weight * clipped_cov.ravel() ** 2 elif self._opt_reweighted: - sample_weight = clipped_cov.ravel()**2 + sample_weight = clipped_cov.ravel() ** 2 - return np.average((theta_dr.ravel() - self._model_final.predict(X).ravel())**2, - weights=sample_weight, axis=0) + return np.average((theta_dr.ravel() - self._model_final.predict(X).ravel()) ** 2, weights=sample_weight, axis=0) class _BaseDRIV(_OrthoLearner): # A helper class that access all the internal fitted objects of a DRIV Cate Estimator. # Used by both DRIV and IntentToTreatDRIV. 
- def __init__(self, *, - model_final, - featurizer=None, - fit_cate_intercept=False, - cov_clip=1e-3, - opt_reweighted=False, - discrete_outcome=False, - discrete_instrument=False, - discrete_treatment=False, - treatment_featurizer=None, - categories='auto', - cv=2, - mc_iters=None, - mc_agg='mean', - random_state=None, - allow_missing=False, - use_ray=False, - ray_remote_func_options=None): + def __init__( + self, + *, + model_final, + featurizer=None, + fit_cate_intercept=False, + cov_clip=1e-3, + opt_reweighted=False, + discrete_outcome=False, + discrete_instrument=False, + discrete_treatment=False, + treatment_featurizer=None, + categories='auto', + cv=2, + mc_iters=None, + mc_agg='mean', + random_state=None, + allow_missing=False, + use_ray=False, + ray_remote_func_options=None, + ): self.model_final = clone(model_final, safe=False) self.featurizer = clone(featurizer, safe=False) self.fit_cate_intercept = fit_cate_intercept self.cov_clip = cov_clip self.opt_reweighted = opt_reweighted - super().__init__(discrete_outcome=discrete_outcome, - discrete_instrument=discrete_instrument, - discrete_treatment=discrete_treatment, - treatment_featurizer=treatment_featurizer, - categories=categories, - cv=cv, - mc_iters=mc_iters, - mc_agg=mc_agg, - random_state=random_state, - allow_missing=allow_missing, - use_ray=use_ray, - ray_remote_func_options=ray_remote_func_options) + super().__init__( + discrete_outcome=discrete_outcome, + discrete_instrument=discrete_instrument, + discrete_treatment=discrete_treatment, + treatment_featurizer=treatment_featurizer, + categories=categories, + cv=cv, + mc_iters=mc_iters, + mc_agg=mc_agg, + random_state=random_state, + allow_missing=allow_missing, + use_ray=use_ray, + ray_remote_func_options=ray_remote_func_options, + ) def _gen_allowed_missing_vars(self): return ['W'] if self.allow_missing else [] @@ -403,11 +458,16 @@ def _gen_model_final(self): return clone(self.model_final, safe=False) def _gen_ortho_learner_model_final(self): - return _BaseDRIVModelFinal(self._gen_model_final(), self._gen_featurizer(), self.fit_cate_intercept, - self.cov_clip, self.opt_reweighted) + return _BaseDRIVModelFinal( + self._gen_model_final(), self._gen_featurizer(), self.fit_cate_intercept, self.cov_clip, self.opt_reweighted + ) def _check_inputs(self, Y, T, Z, X, W): - Y1, T1, Z1, = check_input_arrays(Y, T, Z) + ( + Y1, + T1, + Z1, + ) = check_input_arrays(Y, T, Z) if len(Y1.shape) > 1 and Y1.shape[1] > 1: raise AssertionError("DRIV only supports single dimensional outcome") if len(T1.shape) > 1 and T1.shape[1] > 1: @@ -424,8 +484,21 @@ def _check_inputs(self, Y, T, Z, X, W): raise AttributeError("DRIV only supports single-dimensional continuous instruments") return Y, T, Z, X, W - def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None, - cache_values=False, inference="auto"): + def fit( + self, + Y, + T, + *, + Z, + X=None, + W=None, + sample_weight=None, + freq_weight=None, + sample_var=None, + groups=None, + cache_values=False, + inference="auto", + ): """ Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`. 
@@ -467,12 +540,23 @@ def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None, """ Y, T, Z, X, W = self._check_inputs(Y, T, Z, X, W) # Replacing fit from _OrthoLearner, to reorder arguments and improve the docstring - return super().fit(Y, T, X=X, W=W, Z=Z, - sample_weight=sample_weight, freq_weight=freq_weight, sample_var=sample_var, groups=groups, - cache_values=cache_values, inference=inference) + return super().fit( + Y, + T, + X=X, + W=W, + Z=Z, + sample_weight=sample_weight, + freq_weight=freq_weight, + sample_var=sample_var, + groups=groups, + cache_values=cache_values, + inference=inference, + ) def refit_final(self, *, inference='auto'): return super().refit_final(inference=inference) + refit_final.__doc__ = _OrthoLearner.refit_final.__doc__ def score(self, Y, T, Z, X=None, W=None, sample_weight=None): @@ -575,13 +659,20 @@ def model_cate(self): return self.ortho_learner_model_final_._model_final def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100): - return _shap_explain_model_cate(self.const_marginal_effect, self.model_cate, X, self._d_t, self._d_y, - featurizer=self.featurizer_, - feature_names=feature_names, - treatment_names=treatment_names, - output_names=output_names, - input_names=self._input_names, - background_samples=background_samples) + return _shap_explain_model_cate( + self.const_marginal_effect, + self.model_cate, + X, + self._d_t, + self._d_y, + featurizer=self.featurizer_, + feature_names=feature_names, + treatment_names=treatment_names, + output_names=output_names, + input_names=self._input_names, + background_samples=background_samples, + ) + shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__ @property @@ -596,43 +687,54 @@ def residuals_(self): if not hasattr(self, '_cached_values'): raise AttributeError("Estimator is not fitted yet!") if self._cached_values is None: - raise AttributeError("`fit` was called with `cache_values=False`. " - "Set to `True` to enable residual storage.") + raise AttributeError( + "`fit` was called with `cache_values=False`. Set to `True` to enable residual storage." 
+ ) prel_theta, Y_res, T_res, Z_res, cov = self._cached_values.nuisances - return (prel_theta, Y_res, T_res, Z_res, cov, self._cached_values.X, self._cached_values.W, - self._cached_values.Z) + return ( + prel_theta, + Y_res, + T_res, + Z_res, + cov, + self._cached_values.X, + self._cached_values.W, + self._cached_values.Z, + ) class _DRIV(_BaseDRIV): """Private Base class for the DRIV algorithm.""" - def __init__(self, *, - model_y_xw="auto", - model_t_xw="auto", - model_z_xw="auto", - model_t_xwz="auto", - model_tz_xw="auto", - fit_cov_directly=True, - prel_model_effect, - model_final, - projection=False, - featurizer=None, - fit_cate_intercept=False, - cov_clip=1e-3, - opt_reweighted=False, - discrete_outcome=False, - discrete_instrument=False, - discrete_treatment=False, - treatment_featurizer=None, - categories='auto', - cv=2, - mc_iters=None, - mc_agg='mean', - random_state=None, - allow_missing=False, - use_ray=False, - ray_remote_func_options=None - ): + def __init__( + self, + *, + model_y_xw="auto", + model_t_xw="auto", + model_z_xw="auto", + model_t_xwz="auto", + model_tz_xw="auto", + fit_cov_directly=True, + prel_model_effect, + model_final, + projection=False, + featurizer=None, + fit_cate_intercept=False, + cov_clip=1e-3, + opt_reweighted=False, + discrete_outcome=False, + discrete_instrument=False, + discrete_treatment=False, + treatment_featurizer=None, + categories='auto', + cv=2, + mc_iters=None, + mc_agg='mean', + random_state=None, + allow_missing=False, + use_ray=False, + ray_remote_func_options=None, + ): self.model_y_xw = clone(model_y_xw, safe=False) self.model_t_xw = clone(model_t_xw, safe=False) self.model_t_xwz = clone(model_t_xwz, safe=False) @@ -641,23 +743,25 @@ def __init__(self, *, self.prel_model_effect = clone(prel_model_effect, safe=False) self.projection = projection self.fit_cov_directly = fit_cov_directly - super().__init__(model_final=model_final, - featurizer=featurizer, - fit_cate_intercept=fit_cate_intercept, - cov_clip=cov_clip, - opt_reweighted=opt_reweighted, - discrete_outcome=discrete_outcome, - discrete_instrument=discrete_instrument, - discrete_treatment=discrete_treatment, - treatment_featurizer=treatment_featurizer, - categories=categories, - cv=cv, - mc_iters=mc_iters, - mc_agg=mc_agg, - random_state=random_state, - allow_missing=allow_missing, - use_ray=use_ray, - ray_remote_func_options=ray_remote_func_options) + super().__init__( + model_final=model_final, + featurizer=featurizer, + fit_cate_intercept=fit_cate_intercept, + cov_clip=cov_clip, + opt_reweighted=opt_reweighted, + discrete_outcome=discrete_outcome, + discrete_instrument=discrete_instrument, + discrete_treatment=discrete_treatment, + treatment_featurizer=treatment_featurizer, + categories=categories, + cv=cv, + mc_iters=mc_iters, + mc_agg=mc_agg, + random_state=random_state, + allow_missing=allow_missing, + use_ray=use_ray, + ray_remote_func_options=ray_remote_func_options, + ) def _gen_prel_model_effect(self): return clone(self.prel_model_effect, safe=False) @@ -668,38 +772,44 @@ def _gen_ortho_learner_model_nuisance(self): if self.projection: # this is a regression model since the instrument E[T|X,W,Z] is always continuous - model_tz_xw = _make_first_stage_selector(self.model_tz_xw, - is_discrete=False, - random_state=self.random_state) + model_tz_xw = _make_first_stage_selector( + self.model_tz_xw, is_discrete=False, random_state=self.random_state + ) # we're using E[T|X,W,Z] as the instrument - model_z = _make_first_stage_selector(self.model_t_xwz, - 
is_discrete=self.discrete_treatment, - random_state=self.random_state) + model_z = _make_first_stage_selector( + self.model_t_xwz, is_discrete=self.discrete_treatment, random_state=self.random_state + ) else: - model_tz_xw = _make_first_stage_selector(self.model_tz_xw, - is_discrete=(self.discrete_treatment and - self.discrete_instrument and - not self.fit_cov_directly), - random_state=self.random_state) - - model_z = _make_first_stage_selector(self.model_z_xw, - is_discrete=self.discrete_instrument, - random_state=self.random_state) - - return [_BaseDRIVNuisanceSelector(prel_model_effect=self._gen_prel_model_effect(), - model_y_xw=model_y_xw, - model_t_xw=model_t_xw, - model_z=model_z, - projection=self.projection, - discrete_treatment=self.discrete_treatment, - discrete_instrument=self.discrete_instrument), - _BaseDRIVNuisanceCovarianceSelector(model_tz_xw=model_tz_xw, - projection=self.projection, - fit_cov_directly=self.fit_cov_directly, - discrete_treatment=self.discrete_treatment, - discrete_instrument=self.discrete_instrument)] + model_tz_xw = _make_first_stage_selector( + self.model_tz_xw, + is_discrete=(self.discrete_treatment and self.discrete_instrument and not self.fit_cov_directly), + random_state=self.random_state, + ) + + model_z = _make_first_stage_selector( + self.model_z_xw, is_discrete=self.discrete_instrument, random_state=self.random_state + ) + + return [ + _BaseDRIVNuisanceSelector( + prel_model_effect=self._gen_prel_model_effect(), + model_y_xw=model_y_xw, + model_t_xw=model_t_xw, + model_z=model_z, + projection=self.projection, + discrete_treatment=self.discrete_treatment, + discrete_instrument=self.discrete_instrument, + ), + _BaseDRIVNuisanceCovarianceSelector( + model_tz_xw=model_tz_xw, + projection=self.projection, + fit_cov_directly=self.fit_cov_directly, + discrete_treatment=self.discrete_treatment, + discrete_instrument=self.discrete_instrument, + ), + ] class DRIV(_DRIV): @@ -904,36 +1014,38 @@ def true_heterogeneity_function(X): array([-4.15076..., 5.99286..., -2.86512...]) """ - def __init__(self, *, - model_y_xw="auto", - model_t_xw="auto", - model_z_xw="auto", - model_t_xwz="auto", - model_tz_xw="auto", - fit_cov_directly=True, - flexible_model_effect="auto", - model_final=None, - prel_cate_approach="driv", - prel_cv=1, - prel_opt_reweighted=True, - projection=False, - featurizer=None, - fit_cate_intercept=False, - cov_clip=1e-3, - opt_reweighted=False, - discrete_outcome=False, - discrete_instrument=False, - discrete_treatment=False, - treatment_featurizer=None, - categories='auto', - cv=2, - mc_iters=None, - mc_agg='mean', - random_state=None, - allow_missing=False, - use_ray=False, - ray_remote_func_options=None - ): + def __init__( + self, + *, + model_y_xw="auto", + model_t_xw="auto", + model_z_xw="auto", + model_t_xwz="auto", + model_tz_xw="auto", + fit_cov_directly=True, + flexible_model_effect="auto", + model_final=None, + prel_cate_approach="driv", + prel_cv=1, + prel_opt_reweighted=True, + projection=False, + featurizer=None, + fit_cate_intercept=False, + cov_clip=1e-3, + opt_reweighted=False, + discrete_outcome=False, + discrete_instrument=False, + discrete_treatment=False, + treatment_featurizer=None, + categories='auto', + cv=2, + mc_iters=None, + mc_agg='mean', + random_state=None, + allow_missing=False, + use_ray=False, + ray_remote_func_options=None, + ): if flexible_model_effect == "auto": self.flexible_model_effect = StatsModelsLinearRegression(fit_intercept=False) else: @@ -941,31 +1053,33 @@ def __init__(self, *, 
self.prel_cate_approach = prel_cate_approach self.prel_cv = prel_cv self.prel_opt_reweighted = prel_opt_reweighted - super().__init__(model_y_xw=model_y_xw, - model_t_xw=model_t_xw, - model_z_xw=model_z_xw, - model_t_xwz=model_t_xwz, - model_tz_xw=model_tz_xw, - fit_cov_directly=fit_cov_directly, - prel_model_effect=self.prel_cate_approach, - model_final=model_final, - projection=projection, - featurizer=featurizer, - fit_cate_intercept=fit_cate_intercept, - cov_clip=cov_clip, - opt_reweighted=opt_reweighted, - discrete_outcome=discrete_outcome, - discrete_instrument=discrete_instrument, - discrete_treatment=discrete_treatment, - treatment_featurizer=treatment_featurizer, - categories=categories, - cv=cv, - mc_iters=mc_iters, - mc_agg=mc_agg, - random_state=random_state, - allow_missing=allow_missing, - use_ray=use_ray, - ray_remote_func_options=ray_remote_func_options) + super().__init__( + model_y_xw=model_y_xw, + model_t_xw=model_t_xw, + model_z_xw=model_z_xw, + model_t_xwz=model_t_xwz, + model_tz_xw=model_tz_xw, + fit_cov_directly=fit_cov_directly, + prel_model_effect=self.prel_cate_approach, + model_final=model_final, + projection=projection, + featurizer=featurizer, + fit_cate_intercept=fit_cate_intercept, + cov_clip=cov_clip, + opt_reweighted=opt_reweighted, + discrete_outcome=discrete_outcome, + discrete_instrument=discrete_instrument, + discrete_treatment=discrete_treatment, + treatment_featurizer=treatment_featurizer, + categories=categories, + cv=cv, + mc_iters=mc_iters, + mc_agg=mc_agg, + random_state=random_state, + allow_missing=allow_missing, + use_ray=use_ray, + ray_remote_func_options=ray_remote_func_options, + ) def _gen_model_final(self): if self.model_final is None: @@ -974,50 +1088,67 @@ def _gen_model_final(self): def _gen_prel_model_effect(self): if self.prel_cate_approach == "driv": - return _DRIV(model_y_xw=clone(self.model_y_xw, safe=False), - model_t_xw=clone(self.model_t_xw, safe=False), - model_z_xw=clone(self.model_z_xw, safe=False), - model_t_xwz=clone(self.model_t_xwz, safe=False), - model_tz_xw=clone(self.model_tz_xw, safe=False), - prel_model_effect=_DummyCATE(), - model_final=clone(self.flexible_model_effect, safe=False), - projection=self.projection, - fit_cov_directly=self.fit_cov_directly, - featurizer=self._gen_featurizer(), - fit_cate_intercept=self.fit_cate_intercept, - cov_clip=self.cov_clip, - opt_reweighted=self.prel_opt_reweighted, - discrete_instrument=self.discrete_instrument, - discrete_treatment=self.discrete_treatment, - discrete_outcome=self.discrete_outcome, - categories=self.categories, - cv=self.prel_cv, - mc_iters=self.mc_iters, - mc_agg=self.mc_agg, - random_state=self.random_state, - allow_missing=self.allow_missing) + return _DRIV( + model_y_xw=clone(self.model_y_xw, safe=False), + model_t_xw=clone(self.model_t_xw, safe=False), + model_z_xw=clone(self.model_z_xw, safe=False), + model_t_xwz=clone(self.model_t_xwz, safe=False), + model_tz_xw=clone(self.model_tz_xw, safe=False), + prel_model_effect=_DummyCATE(), + model_final=clone(self.flexible_model_effect, safe=False), + projection=self.projection, + fit_cov_directly=self.fit_cov_directly, + featurizer=self._gen_featurizer(), + fit_cate_intercept=self.fit_cate_intercept, + cov_clip=self.cov_clip, + opt_reweighted=self.prel_opt_reweighted, + discrete_instrument=self.discrete_instrument, + discrete_treatment=self.discrete_treatment, + discrete_outcome=self.discrete_outcome, + categories=self.categories, + cv=self.prel_cv, + mc_iters=self.mc_iters, + mc_agg=self.mc_agg, + 
random_state=self.random_state, + allow_missing=self.allow_missing, + ) elif self.prel_cate_approach == "dmliv": - return NonParamDMLIV(model_y_xw=clone(self.model_y_xw, safe=False), - model_t_xw=clone(self.model_t_xw, safe=False), - model_t_xwz=clone(self.model_t_xwz, safe=False), - model_final=clone(self.flexible_model_effect, safe=False), - discrete_instrument=self.discrete_instrument, - discrete_treatment=self.discrete_treatment, - discrete_outcome=self.discrete_outcome, - featurizer=self._gen_featurizer(), - categories=self.categories, - cv=self.prel_cv, - mc_iters=self.mc_iters, - mc_agg=self.mc_agg, - random_state=self.random_state, - allow_missing=self.allow_missing) + return NonParamDMLIV( + model_y_xw=clone(self.model_y_xw, safe=False), + model_t_xw=clone(self.model_t_xw, safe=False), + model_t_xwz=clone(self.model_t_xwz, safe=False), + model_final=clone(self.flexible_model_effect, safe=False), + discrete_instrument=self.discrete_instrument, + discrete_treatment=self.discrete_treatment, + discrete_outcome=self.discrete_outcome, + featurizer=self._gen_featurizer(), + categories=self.categories, + cv=self.prel_cv, + mc_iters=self.mc_iters, + mc_agg=self.mc_agg, + random_state=self.random_state, + allow_missing=self.allow_missing, + ) else: raise ValueError( - "We only support 'dmliv' or 'driv' preliminary model effect, " - f"but received '{self.prel_cate_approach}'!") - - def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None, - cache_values=False, inference="auto"): + f"We only support 'dmliv' or 'driv' preliminary model effect, but received '{self.prel_cate_approach}'!" + ) + + def fit( + self, + Y, + T, + *, + Z, + X=None, + W=None, + sample_weight=None, + freq_weight=None, + sample_var=None, + groups=None, + cache_values=False, + inference="auto", + ): """ Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`. @@ -1058,15 +1189,28 @@ def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None, self """ if self.projection: - assert self.model_z_xw == "auto", ("In the case of projection=True, model_z_xw will not be fitted, " - "please keep it as default!") + assert self.model_z_xw == "auto", ( + "In the case of projection=True, model_z_xw will not be fitted, please keep it as default!" + ) if self.prel_cate_approach == "driv" and not self.projection: - assert self.model_t_xwz == "auto", ("In the case of projection=False and prel_cate_approach='driv', " - "model_t_xwz will not be fitted, " - "please keep it as default!") - return super().fit(Y, T, X=X, W=W, Z=Z, - sample_weight=sample_weight, freq_weight=freq_weight, sample_var=sample_var, groups=groups, - cache_values=cache_values, inference=inference) + assert self.model_t_xwz == "auto", ( + "In the case of projection=False and prel_cate_approach='driv', " + "model_t_xwz will not be fitted, " + "please keep it as default!" + ) + return super().fit( + Y, + T, + X=X, + W=W, + Z=Z, + sample_weight=sample_weight, + freq_weight=freq_weight, + sample_var=sample_var, + groups=groups, + cache_values=cache_values, + inference=inference, + ) @property def models_y_xw(self): @@ -1405,69 +1549,86 @@ def true_heterogeneity_function(X): (-1.27151..., 1.01512...) 
""" - def __init__(self, *, - model_y_xw="auto", - model_t_xw="auto", - model_z_xw="auto", - model_t_xwz="auto", - model_tz_xw="auto", - fit_cov_directly=True, - flexible_model_effect="auto", - prel_cate_approach="driv", - prel_cv=1, - prel_opt_reweighted=True, - projection=False, - featurizer=None, - fit_cate_intercept=True, - cov_clip=1e-3, - opt_reweighted=False, - discrete_outcome=False, - discrete_instrument=False, - discrete_treatment=False, - treatment_featurizer=None, - categories='auto', - cv=2, - mc_iters=None, - mc_agg='mean', - random_state=None, - allow_missing=False, - use_ray=False, - ray_remote_func_options=None - ): - super().__init__(model_y_xw=model_y_xw, - model_t_xw=model_t_xw, - model_z_xw=model_z_xw, - model_t_xwz=model_t_xwz, - model_tz_xw=model_tz_xw, - fit_cov_directly=fit_cov_directly, - flexible_model_effect=flexible_model_effect, - model_final=None, - prel_cate_approach=prel_cate_approach, - prel_cv=prel_cv, - prel_opt_reweighted=prel_opt_reweighted, - projection=projection, - featurizer=featurizer, - fit_cate_intercept=fit_cate_intercept, - cov_clip=cov_clip, - opt_reweighted=opt_reweighted, - discrete_outcome=discrete_outcome, - discrete_instrument=discrete_instrument, - discrete_treatment=discrete_treatment, - treatment_featurizer=treatment_featurizer, - categories=categories, - cv=cv, - mc_iters=mc_iters, - mc_agg=mc_agg, - random_state=random_state, - allow_missing=allow_missing, - use_ray=use_ray, - ray_remote_func_options=ray_remote_func_options) + def __init__( + self, + *, + model_y_xw="auto", + model_t_xw="auto", + model_z_xw="auto", + model_t_xwz="auto", + model_tz_xw="auto", + fit_cov_directly=True, + flexible_model_effect="auto", + prel_cate_approach="driv", + prel_cv=1, + prel_opt_reweighted=True, + projection=False, + featurizer=None, + fit_cate_intercept=True, + cov_clip=1e-3, + opt_reweighted=False, + discrete_outcome=False, + discrete_instrument=False, + discrete_treatment=False, + treatment_featurizer=None, + categories='auto', + cv=2, + mc_iters=None, + mc_agg='mean', + random_state=None, + allow_missing=False, + use_ray=False, + ray_remote_func_options=None, + ): + super().__init__( + model_y_xw=model_y_xw, + model_t_xw=model_t_xw, + model_z_xw=model_z_xw, + model_t_xwz=model_t_xwz, + model_tz_xw=model_tz_xw, + fit_cov_directly=fit_cov_directly, + flexible_model_effect=flexible_model_effect, + model_final=None, + prel_cate_approach=prel_cate_approach, + prel_cv=prel_cv, + prel_opt_reweighted=prel_opt_reweighted, + projection=projection, + featurizer=featurizer, + fit_cate_intercept=fit_cate_intercept, + cov_clip=cov_clip, + opt_reweighted=opt_reweighted, + discrete_outcome=discrete_outcome, + discrete_instrument=discrete_instrument, + discrete_treatment=discrete_treatment, + treatment_featurizer=treatment_featurizer, + categories=categories, + cv=cv, + mc_iters=mc_iters, + mc_agg=mc_agg, + random_state=random_state, + allow_missing=allow_missing, + use_ray=use_ray, + ray_remote_func_options=ray_remote_func_options, + ) def _gen_model_final(self): return StatsModelsLinearRegression(fit_intercept=False) - def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None, - cache_values=False, inference='auto'): + def fit( + self, + Y, + T, + *, + Z, + X=None, + W=None, + sample_weight=None, + freq_weight=None, + sample_var=None, + groups=None, + cache_values=False, + inference='auto', + ): """ Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`. 
@@ -1508,9 +1669,19 @@ def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None, self """ # Replacing fit from _OrthoLearner, to reorder arguments and improve the docstring - return super().fit(Y, T, X=X, W=W, Z=Z, - sample_weight=sample_weight, freq_weight=freq_weight, sample_var=sample_var, groups=groups, - cache_values=cache_values, inference=inference) + return super().fit( + Y, + T, + X=X, + W=W, + Z=Z, + sample_weight=sample_weight, + freq_weight=freq_weight, + sample_var=sample_var, + groups=groups, + cache_values=cache_values, + inference=inference, + ) @property def fit_cate_intercept_(self): @@ -1774,41 +1945,44 @@ def true_heterogeneity_function(X): (-1.20712..., 1.07641...) """ - def __init__(self, *, - model_y_xw="auto", - model_t_xw="auto", - model_z_xw="auto", - model_t_xwz="auto", - model_tz_xw="auto", - fit_cov_directly=True, - flexible_model_effect="auto", - prel_cate_approach="driv", - prel_cv=1, - prel_opt_reweighted=True, - projection=False, - featurizer=None, - fit_cate_intercept=True, - alpha='auto', - n_alphas=100, - alpha_cov='auto', - n_alphas_cov=10, - max_iter=1000, - tol=1e-4, - n_jobs=None, - cov_clip=1e-3, - opt_reweighted=False, - discrete_outcome=False, - discrete_instrument=False, - discrete_treatment=False, - treatment_featurizer=None, - categories='auto', - cv=2, - mc_iters=None, - mc_agg='mean', - random_state=None, - allow_missing=False, - use_ray=False, - ray_remote_func_options=None): + def __init__( + self, + *, + model_y_xw="auto", + model_t_xw="auto", + model_z_xw="auto", + model_t_xwz="auto", + model_tz_xw="auto", + fit_cov_directly=True, + flexible_model_effect="auto", + prel_cate_approach="driv", + prel_cv=1, + prel_opt_reweighted=True, + projection=False, + featurizer=None, + fit_cate_intercept=True, + alpha='auto', + n_alphas=100, + alpha_cov='auto', + n_alphas_cov=10, + max_iter=1000, + tol=1e-4, + n_jobs=None, + cov_clip=1e-3, + opt_reweighted=False, + discrete_outcome=False, + discrete_instrument=False, + discrete_treatment=False, + treatment_featurizer=None, + categories='auto', + cv=2, + mc_iters=None, + mc_agg='mean', + random_state=None, + allow_missing=False, + use_ray=False, + ray_remote_func_options=None, + ): self.alpha = alpha self.n_alphas = n_alphas self.alpha_cov = alpha_cov @@ -1816,49 +1990,51 @@ def __init__(self, *, self.max_iter = max_iter self.tol = tol self.n_jobs = n_jobs - super().__init__(model_y_xw=model_y_xw, - model_t_xw=model_t_xw, - model_z_xw=model_z_xw, - model_t_xwz=model_t_xwz, - model_tz_xw=model_tz_xw, - fit_cov_directly=fit_cov_directly, - flexible_model_effect=flexible_model_effect, - model_final=None, - prel_cate_approach=prel_cate_approach, - prel_cv=prel_cv, - prel_opt_reweighted=prel_opt_reweighted, - projection=projection, - featurizer=featurizer, - fit_cate_intercept=fit_cate_intercept, - cov_clip=cov_clip, - opt_reweighted=opt_reweighted, - discrete_outcome=discrete_outcome, - discrete_instrument=discrete_instrument, - discrete_treatment=discrete_treatment, - treatment_featurizer=treatment_featurizer, - categories=categories, - cv=cv, - mc_iters=mc_iters, - mc_agg=mc_agg, - random_state=random_state, - allow_missing=allow_missing, - use_ray=use_ray, - ray_remote_func_options=ray_remote_func_options - ) + super().__init__( + model_y_xw=model_y_xw, + model_t_xw=model_t_xw, + model_z_xw=model_z_xw, + model_t_xwz=model_t_xwz, + model_tz_xw=model_tz_xw, + fit_cov_directly=fit_cov_directly, + flexible_model_effect=flexible_model_effect, + model_final=None, + 
prel_cate_approach=prel_cate_approach, + prel_cv=prel_cv, + prel_opt_reweighted=prel_opt_reweighted, + projection=projection, + featurizer=featurizer, + fit_cate_intercept=fit_cate_intercept, + cov_clip=cov_clip, + opt_reweighted=opt_reweighted, + discrete_outcome=discrete_outcome, + discrete_instrument=discrete_instrument, + discrete_treatment=discrete_treatment, + treatment_featurizer=treatment_featurizer, + categories=categories, + cv=cv, + mc_iters=mc_iters, + mc_agg=mc_agg, + random_state=random_state, + allow_missing=allow_missing, + use_ray=use_ray, + ray_remote_func_options=ray_remote_func_options, + ) def _gen_model_final(self): - return DebiasedLasso(alpha=self.alpha, - n_alphas=self.n_alphas, - alpha_cov=self.alpha_cov, - n_alphas_cov=self.n_alphas_cov, - fit_intercept=False, - max_iter=self.max_iter, - tol=self.tol, - n_jobs=self.n_jobs, - random_state=self.random_state) - - def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, groups=None, - cache_values=False, inference='auto'): + return DebiasedLasso( + alpha=self.alpha, + n_alphas=self.n_alphas, + alpha_cov=self.alpha_cov, + n_alphas_cov=self.n_alphas_cov, + fit_intercept=False, + max_iter=self.max_iter, + tol=self.tol, + n_jobs=self.n_jobs, + random_state=self.random_state, + ) + + def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, groups=None, cache_values=False, inference='auto'): """ Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`. @@ -1893,13 +2069,26 @@ def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, groups=None, """ # TODO: support freq_weight and sample_var in debiased lasso # Replacing fit from _OrthoLearner, to reorder arguments and improve the docstring - check_high_dimensional(X, T, threshold=5, featurizer=self.featurizer, - discrete_treatment=self.discrete_treatment, - msg="The number of features in the final model (< 5) is too small for a sparse model. " - "We recommend using the LinearDRLearner for this low-dimensional setting.") - return super().fit(Y, T, X=X, W=W, Z=Z, - sample_weight=sample_weight, groups=groups, - cache_values=cache_values, inference=inference) + check_high_dimensional( + X, + T, + threshold=5, + featurizer=self.featurizer, + discrete_treatment=self.discrete_treatment, + msg="The number of features in the final model (< 5) is too small for a sparse model. 
" + "We recommend using the LinearDRLearner for this low-dimensional setting.", + ) + return super().fit( + Y, + T, + X=X, + W=W, + Z=Z, + sample_weight=sample_weight, + groups=groups, + cache_values=cache_values, + inference=inference, + ) @property def fit_cate_intercept_(self): @@ -2064,8 +2253,7 @@ class ForestDRIV(ForestModelFinalCateEstimatorMixin, DRIV): The weighted impurity decrease equation is the following:: - N_t / N * (impurity - N_t_R / N_t * right_impurity - - N_t_L / N_t * left_impurity) + N_t / N * (impurity - N_t_R / N_t * right_impurity - N_t_L / N_t * left_impurity) where ``N`` is the total number of split samples, ``N_t`` is the number of split samples at the current node, ``N_t_L`` is the number of split samples in the @@ -2223,46 +2411,49 @@ def true_heterogeneity_function(X): array([ 1.36983..., 10.23568..., -0.17805...])) """ - def __init__(self, *, - model_y_xw="auto", - model_t_xw="auto", - model_z_xw="auto", - model_t_xwz="auto", - model_tz_xw="auto", - fit_cov_directly=True, - flexible_model_effect="auto", - prel_cate_approach="driv", - prel_cv=1, - prel_opt_reweighted=True, - projection=False, - featurizer=None, - n_estimators=1000, - max_depth=None, - min_samples_split=5, - min_samples_leaf=5, - min_weight_fraction_leaf=0., - max_features="auto", - min_impurity_decrease=0., - max_samples=.45, - min_balancedness_tol=.45, - honest=True, - subforest_size=4, - n_jobs=-1, - verbose=0, - cov_clip=1e-3, - opt_reweighted=False, - discrete_outcome=False, - discrete_instrument=False, - discrete_treatment=False, - treatment_featurizer=None, - categories='auto', - cv=2, - mc_iters=None, - mc_agg='mean', - random_state=None, - allow_missing=False, - use_ray=False, - ray_remote_func_options=None): + def __init__( + self, + *, + model_y_xw="auto", + model_t_xw="auto", + model_z_xw="auto", + model_t_xwz="auto", + model_tz_xw="auto", + fit_cov_directly=True, + flexible_model_effect="auto", + prel_cate_approach="driv", + prel_cv=1, + prel_opt_reweighted=True, + projection=False, + featurizer=None, + n_estimators=1000, + max_depth=None, + min_samples_split=5, + min_samples_leaf=5, + min_weight_fraction_leaf=0.0, + max_features="auto", + min_impurity_decrease=0.0, + max_samples=0.45, + min_balancedness_tol=0.45, + honest=True, + subforest_size=4, + n_jobs=-1, + verbose=0, + cov_clip=1e-3, + opt_reweighted=False, + discrete_outcome=False, + discrete_instrument=False, + discrete_treatment=False, + treatment_featurizer=None, + categories='auto', + cv=2, + mc_iters=None, + mc_agg='mean', + random_state=None, + allow_missing=False, + use_ray=False, + ray_remote_func_options=None, + ): self.n_estimators = n_estimators self.max_depth = max_depth self.min_samples_split = min_samples_split @@ -2276,55 +2467,58 @@ def __init__(self, *, self.subforest_size = subforest_size self.n_jobs = n_jobs self.verbose = verbose - super().__init__(model_y_xw=model_y_xw, - model_t_xw=model_t_xw, - model_z_xw=model_z_xw, - model_t_xwz=model_t_xwz, - model_tz_xw=model_tz_xw, - fit_cov_directly=fit_cov_directly, - flexible_model_effect=flexible_model_effect, - model_final=None, - prel_cate_approach=prel_cate_approach, - prel_cv=prel_cv, - prel_opt_reweighted=prel_opt_reweighted, - projection=projection, - featurizer=featurizer, - fit_cate_intercept=False, - cov_clip=cov_clip, - opt_reweighted=opt_reweighted, - discrete_outcome=discrete_outcome, - discrete_instrument=discrete_instrument, - discrete_treatment=discrete_treatment, - treatment_featurizer=treatment_featurizer, - categories=categories, - cv=cv, - 
mc_iters=mc_iters, - mc_agg=mc_agg, - random_state=random_state, - allow_missing=allow_missing, - use_ray=use_ray, - ray_remote_func_options=ray_remote_func_options) + super().__init__( + model_y_xw=model_y_xw, + model_t_xw=model_t_xw, + model_z_xw=model_z_xw, + model_t_xwz=model_t_xwz, + model_tz_xw=model_tz_xw, + fit_cov_directly=fit_cov_directly, + flexible_model_effect=flexible_model_effect, + model_final=None, + prel_cate_approach=prel_cate_approach, + prel_cv=prel_cv, + prel_opt_reweighted=prel_opt_reweighted, + projection=projection, + featurizer=featurizer, + fit_cate_intercept=False, + cov_clip=cov_clip, + opt_reweighted=opt_reweighted, + discrete_outcome=discrete_outcome, + discrete_instrument=discrete_instrument, + discrete_treatment=discrete_treatment, + treatment_featurizer=treatment_featurizer, + categories=categories, + cv=cv, + mc_iters=mc_iters, + mc_agg=mc_agg, + random_state=random_state, + allow_missing=allow_missing, + use_ray=use_ray, + ray_remote_func_options=ray_remote_func_options, + ) def _gen_model_final(self): - return RegressionForest(n_estimators=self.n_estimators, - max_depth=self.max_depth, - min_samples_split=self.min_samples_split, - min_samples_leaf=self.min_samples_leaf, - min_weight_fraction_leaf=self.min_weight_fraction_leaf, - max_features=self.max_features, - min_impurity_decrease=self.min_impurity_decrease, - max_samples=self.max_samples, - min_balancedness_tol=self.min_balancedness_tol, - honest=self.honest, - inference=True, - subforest_size=self.subforest_size, - n_jobs=self.n_jobs, - random_state=self.random_state, - verbose=self.verbose, - warm_start=False) - - def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, groups=None, - cache_values=False, inference='auto'): + return RegressionForest( + n_estimators=self.n_estimators, + max_depth=self.max_depth, + min_samples_split=self.min_samples_split, + min_samples_leaf=self.min_samples_leaf, + min_weight_fraction_leaf=self.min_weight_fraction_leaf, + max_features=self.max_features, + min_impurity_decrease=self.min_impurity_decrease, + max_samples=self.max_samples, + min_balancedness_tol=self.min_balancedness_tol, + honest=self.honest, + inference=True, + subforest_size=self.subforest_size, + n_jobs=self.n_jobs, + random_state=self.random_state, + verbose=self.verbose, + warm_start=False, + ) + + def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, groups=None, cache_values=False, inference='auto'): """ Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`. 
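
The _gen_model_final above forwards the forest hyperparameters from the constructor into a RegressionForest, and the fit override (next hunk) rejects X=None. A minimal illustrative sketch of that surface, with invented data and only knobs that appear in the signature above:

    # Illustrative sketch (not part of this patch): ForestDRIV requires X,
    # since the final-stage RegressionForest is fit on the features.
    import numpy as np
    from econml.iv.dr import ForestDRIV

    rng = np.random.default_rng(1)
    n = 5000
    X = rng.uniform(-1, 1, size=(n, 4))
    Z = rng.binomial(1, 0.5, size=n)
    T = rng.binomial(1, 0.2 + 0.6 * Z)          # instrument shifts take-up
    Y = np.where(X[:, 0] > 0, 2.0, -1.0) * T + rng.normal(size=n)

    est = ForestDRIV(
        n_estimators=500,
        min_samples_leaf=10,
        max_depth=6,
        discrete_instrument=True,
        discrete_treatment=True,
    )
    est.fit(Y, T, Z=Z, X=X)            # raises ValueError if X is None
    print(est.effect(X[:3]))
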
@@ -2361,9 +2555,17 @@ def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, groups=None, raise ValueError("This estimator does not support X=None!") # Replacing fit from _OrthoLearner, to reorder arguments and improve the docstring - return super().fit(Y, T, X=X, W=W, Z=Z, - sample_weight=sample_weight, groups=groups, - cache_values=cache_values, inference=inference) + return super().fit( + Y, + T, + X=X, + W=W, + Z=Z, + sample_weight=sample_weight, + groups=groups, + cache_values=cache_values, + inference=inference, + ) @property def model_final(self): @@ -2376,26 +2578,42 @@ def model_final(self, model): class _IntentToTreatDRIVNuisanceSelector(ModelSelector): - def __init__(self, - model_y_xw: SingleModelSelector, - model_t_xwz: SingleModelSelector, - dummy_z: SingleModelSelector, - prel_model_effect): + def __init__( + self, + model_y_xw: SingleModelSelector, + model_t_xwz: SingleModelSelector, + dummy_z: SingleModelSelector, + prel_model_effect, + ): self._model_y_xw = model_y_xw self._model_t_xwz = model_t_xwz self._dummy_z = dummy_z self._prel_model_effect = prel_model_effect def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params): - self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups, **fit_params) + self._model_y_xw.train( + is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups, **fit_params + ) # concat W and Z WZ = _combine(W, Z, Y.shape[0]) - self._model_t_xwz.train(is_selecting, folds, X=X, W=WZ, Target=T, sample_weight=sample_weight, groups=groups, **fit_params) - self._dummy_z.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups, **fit_params) + self._model_t_xwz.train( + is_selecting, folds, X=X, W=WZ, Target=T, sample_weight=sample_weight, groups=groups, **fit_params + ) + self._dummy_z.train( + is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups, **fit_params + ) # we need to undo the one-hot encoding for calling effect, # since it expects raw values - self._prel_model_effect.fit(Y, inverse_onehot(T), Z=inverse_onehot(Z), X=X, W=W, - sample_weight=sample_weight, groups=groups, **fit_params) + self._prel_model_effect.fit( + Y, + inverse_onehot(T), + Z=inverse_onehot(Z), + X=X, + W=W, + sample_weight=sample_weight, + groups=groups, + **fit_params, + ) return self def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): @@ -2412,8 +2630,9 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): if hasattr(self._prel_model_effect, 'score'): # we need to undo the one-hot encoding for calling effect, # since it expects raw values - effect_score = self._prel_model_effect.score(Y, inverse_onehot(T), - inverse_onehot(Z), X=X, W=W, sample_weight=sample_weight) + effect_score = self._prel_model_effect.score( + Y, inverse_onehot(T), inverse_onehot(Z), X=X, W=W, sample_weight=sample_weight + ) else: effect_score = None @@ -2472,54 +2691,59 @@ def predict_proba(self, X): class _IntentToTreatDRIV(_BaseDRIV): """Base class for the DRIV algorithm for the intent-to-treat A/B test setting.""" - def __init__(self, *, - model_y_xw="auto", - model_t_xwz="auto", - prel_model_effect, - model_final, - z_propensity="auto", - featurizer=None, - fit_cate_intercept=False, - discrete_outcome=False, - cov_clip=1e-3, - opt_reweighted=False, - categories='auto', - cv=3, - mc_iters=None, - mc_agg='mean', - random_state=None, - 
allow_missing=False, - use_ray=False, - ray_remote_func_options=None): + def __init__( + self, + *, + model_y_xw="auto", + model_t_xwz="auto", + prel_model_effect, + model_final, + z_propensity="auto", + featurizer=None, + fit_cate_intercept=False, + discrete_outcome=False, + cov_clip=1e-3, + opt_reweighted=False, + categories='auto', + cv=3, + mc_iters=None, + mc_agg='mean', + random_state=None, + allow_missing=False, + use_ray=False, + ray_remote_func_options=None, + ): self.model_y_xw = clone(model_y_xw, safe=False) self.model_t_xwz = clone(model_t_xwz, safe=False) self.prel_model_effect = clone(prel_model_effect, safe=False) self.z_propensity = z_propensity - super().__init__(model_final=model_final, - featurizer=featurizer, - fit_cate_intercept=fit_cate_intercept, - discrete_outcome=discrete_outcome, - cov_clip=cov_clip, - cv=cv, - mc_iters=mc_iters, - mc_agg=mc_agg, - discrete_instrument=True, - discrete_treatment=True, - categories=categories, - opt_reweighted=opt_reweighted, - random_state=random_state, - allow_missing=allow_missing, - use_ray=use_ray, - ray_remote_func_options=ray_remote_func_options) + super().__init__( + model_final=model_final, + featurizer=featurizer, + fit_cate_intercept=fit_cate_intercept, + discrete_outcome=discrete_outcome, + cov_clip=cov_clip, + cv=cv, + mc_iters=mc_iters, + mc_agg=mc_agg, + discrete_instrument=True, + discrete_treatment=True, + categories=categories, + opt_reweighted=opt_reweighted, + random_state=random_state, + allow_missing=allow_missing, + use_ray=use_ray, + ray_remote_func_options=ray_remote_func_options, + ) def _gen_prel_model_effect(self): return clone(self.prel_model_effect, safe=False) def _gen_ortho_learner_model_nuisance(self): - model_y_xw = _make_first_stage_selector(self.model_y_xw, - is_discrete=self.discrete_outcome, - random_state=self.random_state) + model_y_xw = _make_first_stage_selector( + self.model_y_xw, is_discrete=self.discrete_outcome, random_state=self.random_state + ) model_t_xwz = _make_first_stage_selector(self.model_t_xwz, is_discrete=True, random_state=self.random_state) if self.z_propensity == "auto": @@ -2704,29 +2928,31 @@ def true_heterogeneity_function(X): array([-4.52684..., 6.38767..., -2.67082...]) """ - def __init__(self, *, - model_y_xw="auto", - model_t_xwz="auto", - prel_cate_approach="driv", - flexible_model_effect="auto", - model_final=None, - prel_cv=1, - prel_opt_reweighted=True, - z_propensity="auto", - featurizer=None, - fit_cate_intercept=False, - discrete_outcome=False, - cov_clip=1e-3, - cv=3, - mc_iters=None, - mc_agg='mean', - opt_reweighted=False, - categories='auto', - random_state=None, - allow_missing=False, - use_ray=False, - ray_remote_func_options=None): - + def __init__( + self, + *, + model_y_xw="auto", + model_t_xwz="auto", + prel_cate_approach="driv", + flexible_model_effect="auto", + model_final=None, + prel_cv=1, + prel_opt_reweighted=True, + z_propensity="auto", + featurizer=None, + fit_cate_intercept=False, + discrete_outcome=False, + cov_clip=1e-3, + cv=3, + mc_iters=None, + mc_agg='mean', + opt_reweighted=False, + categories='auto', + random_state=None, + allow_missing=False, + use_ray=False, + ray_remote_func_options=None, + ): # maybe shouldn't expose fit_cate_intercept in this class? 
if flexible_model_effect == "auto": self.flexible_model_effect = StatsModelsLinearRegression(fit_intercept=False) @@ -2735,24 +2961,26 @@ def __init__(self, *, self.prel_cate_approach = prel_cate_approach self.prel_cv = prel_cv self.prel_opt_reweighted = prel_opt_reweighted - super().__init__(model_y_xw=model_y_xw, - model_t_xwz=model_t_xwz, - prel_model_effect=self.prel_cate_approach, - model_final=model_final, - z_propensity=z_propensity, - featurizer=featurizer, - fit_cate_intercept=fit_cate_intercept, - discrete_outcome=discrete_outcome, - cov_clip=cov_clip, - opt_reweighted=opt_reweighted, - categories=categories, - cv=cv, - mc_iters=mc_iters, - mc_agg=mc_agg, - random_state=random_state, - allow_missing=allow_missing, - use_ray=use_ray, - ray_remote_func_options=ray_remote_func_options) + super().__init__( + model_y_xw=model_y_xw, + model_t_xwz=model_t_xwz, + prel_model_effect=self.prel_cate_approach, + model_final=model_final, + z_propensity=z_propensity, + featurizer=featurizer, + fit_cate_intercept=fit_cate_intercept, + discrete_outcome=discrete_outcome, + cov_clip=cov_clip, + opt_reweighted=opt_reweighted, + categories=categories, + cv=cv, + mc_iters=mc_iters, + mc_agg=mc_agg, + random_state=random_state, + allow_missing=allow_missing, + use_ray=use_ray, + ray_remote_func_options=ray_remote_func_options, + ) def _gen_model_final(self): if self.model_final is None: @@ -2761,36 +2989,40 @@ def _gen_model_final(self): def _gen_prel_model_effect(self): if self.prel_cate_approach == "driv": - return _IntentToTreatDRIV(model_y_xw=clone(self.model_y_xw, safe=False), - model_t_xwz=clone(self.model_t_xwz, safe=False), - prel_model_effect=_DummyCATE(), - model_final=clone(self.flexible_model_effect, safe=False), - featurizer=self._gen_featurizer(), - fit_cate_intercept=self.fit_cate_intercept, - cov_clip=self.cov_clip, - categories=self.categories, - opt_reweighted=self.prel_opt_reweighted, - cv=self.prel_cv, - random_state=self.random_state, - allow_missing=self.allow_missing) + return _IntentToTreatDRIV( + model_y_xw=clone(self.model_y_xw, safe=False), + model_t_xwz=clone(self.model_t_xwz, safe=False), + prel_model_effect=_DummyCATE(), + model_final=clone(self.flexible_model_effect, safe=False), + featurizer=self._gen_featurizer(), + fit_cate_intercept=self.fit_cate_intercept, + cov_clip=self.cov_clip, + categories=self.categories, + opt_reweighted=self.prel_opt_reweighted, + cv=self.prel_cv, + random_state=self.random_state, + allow_missing=self.allow_missing, + ) elif self.prel_cate_approach == "dmliv": - return NonParamDMLIV(model_y_xw=clone(self.model_y_xw, safe=False), - model_t_xw=clone(self.model_t_xwz, safe=False), - model_t_xwz=clone(self.model_t_xwz, safe=False), - model_final=clone(self.flexible_model_effect, safe=False), - discrete_instrument=True, - discrete_treatment=True, - featurizer=self._gen_featurizer(), - categories=self.categories, - cv=self.prel_cv, - mc_iters=self.mc_iters, - mc_agg=self.mc_agg, - random_state=self.random_state, - allow_missing=self.allow_missing) + return NonParamDMLIV( + model_y_xw=clone(self.model_y_xw, safe=False), + model_t_xw=clone(self.model_t_xwz, safe=False), + model_t_xwz=clone(self.model_t_xwz, safe=False), + model_final=clone(self.flexible_model_effect, safe=False), + discrete_instrument=True, + discrete_treatment=True, + featurizer=self._gen_featurizer(), + categories=self.categories, + cv=self.prel_cv, + mc_iters=self.mc_iters, + mc_agg=self.mc_agg, + random_state=self.random_state, + allow_missing=self.allow_missing, + ) else: raise 
ValueError( - "We only support 'dmliv' or 'driv' preliminary model effect, " - f"but received '{self.prel_cate_approach}'!") + f"We only support 'dmliv' or 'driv' preliminary model effect, but received '{self.prel_cate_approach}'!" + ) @property def models_y_xw(self): @@ -3018,57 +3250,75 @@ def true_heterogeneity_function(X): (-2.07685..., 1.49784...) """ - def __init__(self, *, - model_y_xw="auto", - model_t_xwz="auto", - prel_cate_approach="driv", - flexible_model_effect="auto", - prel_cv=1, - prel_opt_reweighted=True, - z_propensity="auto", - featurizer=None, - fit_cate_intercept=True, - discrete_outcome=False, - cov_clip=1e-3, - cv=3, - mc_iters=None, - mc_agg='mean', - opt_reweighted=False, - categories='auto', - random_state=None, - allow_missing=False, - enable_federation=False, - use_ray=False, - ray_remote_func_options=None): - super().__init__(model_y_xw=model_y_xw, - model_t_xwz=model_t_xwz, - flexible_model_effect=flexible_model_effect, - model_final=None, - prel_cate_approach=prel_cate_approach, - prel_cv=prel_cv, - prel_opt_reweighted=prel_opt_reweighted, - z_propensity=z_propensity, - featurizer=featurizer, - fit_cate_intercept=fit_cate_intercept, - discrete_outcome=discrete_outcome, - cov_clip=cov_clip, - cv=cv, - mc_iters=mc_iters, - mc_agg=mc_agg, - opt_reweighted=opt_reweighted, - categories=categories, - random_state=random_state, - allow_missing=allow_missing, - use_ray=use_ray, - ray_remote_func_options=ray_remote_func_options) + def __init__( + self, + *, + model_y_xw="auto", + model_t_xwz="auto", + prel_cate_approach="driv", + flexible_model_effect="auto", + prel_cv=1, + prel_opt_reweighted=True, + z_propensity="auto", + featurizer=None, + fit_cate_intercept=True, + discrete_outcome=False, + cov_clip=1e-3, + cv=3, + mc_iters=None, + mc_agg='mean', + opt_reweighted=False, + categories='auto', + random_state=None, + allow_missing=False, + enable_federation=False, + use_ray=False, + ray_remote_func_options=None, + ): + super().__init__( + model_y_xw=model_y_xw, + model_t_xwz=model_t_xwz, + flexible_model_effect=flexible_model_effect, + model_final=None, + prel_cate_approach=prel_cate_approach, + prel_cv=prel_cv, + prel_opt_reweighted=prel_opt_reweighted, + z_propensity=z_propensity, + featurizer=featurizer, + fit_cate_intercept=fit_cate_intercept, + discrete_outcome=discrete_outcome, + cov_clip=cov_clip, + cv=cv, + mc_iters=mc_iters, + mc_agg=mc_agg, + opt_reweighted=opt_reweighted, + categories=categories, + random_state=random_state, + allow_missing=allow_missing, + use_ray=use_ray, + ray_remote_func_options=ray_remote_func_options, + ) self.enable_federation = enable_federation def _gen_model_final(self): return StatsModelsLinearRegression(fit_intercept=False, enable_federation=self.enable_federation) # override only so that we can update the docstring to indicate support for `StatsModelsInference` - def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None, sample_var=None, groups=None, - cache_values=False, inference='auto'): + def fit( + self, + Y, + T, + *, + Z, + X=None, + W=None, + sample_weight=None, + freq_weight=None, + sample_var=None, + groups=None, + cache_values=False, + inference='auto', + ): """ Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`. 
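
The intent-to-treat variants above pin discrete_instrument=True and discrete_treatment=True internally, so the caller only supplies the assignment Z and the realized take-up T. A minimal illustrative sketch of that workflow (invented data; defaults from the signatures above):

    # Illustrative sketch (not part of this patch): A/B assignment Z with
    # imperfect compliance; LinearIntentToTreatDRIV fits a linear CATE.
    import numpy as np
    from econml.iv.dr import LinearIntentToTreatDRIV

    rng = np.random.default_rng(2)
    n = 3000
    X = rng.normal(size=(n, 2))
    Z = rng.binomial(1, 0.5, size=n)            # randomized assignment
    T = Z * rng.binomial(1, 0.7, size=n)        # imperfect compliance
    Y = (0.5 + X[:, 1]) * T + rng.normal(scale=0.3, size=n)

    est = LinearIntentToTreatDRIV(cv=3)         # cv=3 is the default above
    est.fit(Y, T, Z=Z, X=X)
    print(est.effect(X[:3]))
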
@@ -3109,9 +3359,19 @@ def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None, self : instance """ # TODO: do correct adjustment for sample_var - return super().fit(Y, T, Z=Z, X=X, W=W, - sample_weight=sample_weight, freq_weight=freq_weight, sample_var=sample_var, groups=groups, - cache_values=cache_values, inference=inference) + return super().fit( + Y, + T, + Z=Z, + X=X, + W=W, + sample_weight=sample_weight, + freq_weight=freq_weight, + sample_var=sample_var, + groups=groups, + cache_values=cache_values, + inference=inference, + ) @property def fit_cate_intercept_(self):
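
For reference, the **fit_params pass-through that this series threads through the nuisance selectors (see _IntentToTreatDRIVNuisanceSelector.train above) reduces to the following self-contained toy. All names here are hypothetical; the point is only the shape of the forwarding: extra keyword arguments flow unchanged from the outer train call to every inner model.

    # Self-contained toy of the **fit_params pass-through pattern.
    class ToyInnerModel:
        def train(self, data, sample_weight=None, groups=None, **fit_params):
            print("inner model got:", fit_params)

    class ToyNuisanceSelector:
        def __init__(self, model_a, model_b):
            self._model_a = model_a
            self._model_b = model_b

        def train(self, data, sample_weight=None, groups=None, **fit_params):
            # same shape as the selectors above: every inner train call
            # receives the caller's fit_params verbatim
            self._model_a.train(data, sample_weight=sample_weight,
                                groups=groups, **fit_params)
            self._model_b.train(data, sample_weight=sample_weight,
                                groups=groups, **fit_params)
            return self

    selector = ToyNuisanceSelector(ToyInnerModel(), ToyInnerModel())
    selector.train([1, 2, 3], early_stopping_rounds=10)
    # prints twice: inner model got: {'early_stopping_rounds': 10}
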