From 4ebe1a704f245cd51e0e4ecfe0d9d5ba0b375d5b Mon Sep 17 00:00:00 2001 From: Rojan Shrestha Date: Tue, 29 Jul 2025 22:06:05 +0545 Subject: [PATCH 1/3] Added post_treatment_variable_name parameter and sklearn model summary for did --- causalpy/experiments/diff_in_diff.py | 65 +++++++++++++++++------ docs/source/_static/interrogate_badge.svg | 8 +-- 2 files changed, 52 insertions(+), 21 deletions(-) diff --git a/causalpy/experiments/diff_in_diff.py b/causalpy/experiments/diff_in_diff.py index 04b62370..2b307f00 100644 --- a/causalpy/experiments/diff_in_diff.py +++ b/causalpy/experiments/diff_in_diff.py @@ -26,7 +26,6 @@ from causalpy.custom_exceptions import ( DataException, - FormulaException, ) from causalpy.plot_utils import plot_xY from causalpy.pymc_models import PyMCModel @@ -84,6 +83,7 @@ def __init__( formula: str, time_variable_name: str, group_variable_name: str, + post_treatment_variable_name: str = "post_treatment", model=None, **kwargs, ) -> None: @@ -95,6 +95,7 @@ def __init__( self.formula = formula self.time_variable_name = time_variable_name self.group_variable_name = group_variable_name + self.post_treatment_variable_name = post_treatment_variable_name self.input_validation() y, X = dmatrices(formula, self.data) @@ -128,6 +129,12 @@ def __init__( } self.model.fit(X=self.X, y=self.y, coords=COORDS) elif isinstance(self.model, RegressorMixin): + # For scikit-learn models, automatically set fit_intercept=False + # This ensures the intercept is included in the coefficients array rather than being a separate intercept_ attribute + # without this, the intercept is not included in the coefficients array hence would be displayed as 0 in the model summary + # TODO: later, this should be handled in ScikitLearnAdaptor itself + if hasattr(self.model, "fit_intercept"): + self.model.fit_intercept = False self.model.fit(X=self.X, y=self.y) else: raise ValueError("Model type not recognized") @@ -173,7 +180,7 @@ def __init__( # just the treated group .query(f"{self.group_variable_name} == 1") # just the treatment period(s) - .query("post_treatment == True") + .query(f"{self.post_treatment_variable_name} == True") # drop the outcome variable .drop(self.outcome_variable_name, axis=1) # We may have multiple units per time point, we only want one time point @@ -189,7 +196,10 @@ def __init__( # INTERVENTION: set the interaction term between the group and the # post_treatment variable to zero. This is the counterfactual. for i, label in enumerate(self.labels): - if "post_treatment" in label and self.group_variable_name in label: + if ( + self.post_treatment_variable_name in label + and self.group_variable_name in label + ): new_x.iloc[:, i] = 0 self.y_pred_counterfactual = self.model.predict(np.asarray(new_x)) @@ -198,16 +208,24 @@ def __init__( # This is the coefficient on the interaction term coeff_names = self.model.idata.posterior.coords["coeffs"].data for i, label in enumerate(coeff_names): - if "post_treatment" in label and self.group_variable_name in label: + if ( + self.post_treatment_variable_name in label + and self.group_variable_name in label + ): self.causal_impact = self.model.idata.posterior["beta"].isel( {"coeffs": i} ) elif isinstance(self.model, RegressorMixin): # This is the coefficient on the interaction term - # TODO: CHECK FOR CORRECTNESS - self.causal_impact = ( - self.y_pred_treatment[1] - self.y_pred_counterfactual[0] - ).item() + # Store the coefficient into dictionary {intercept:value} + coef_map = dict(zip(self.labels, self.model.get_coeffs())) + # Create and find the interaction term based on the values user provided + interaction_term = ( + f"{self.group_variable_name}:{self.post_treatment_variable_name}" + ) + matched_key = next((k for k in coef_map if interaction_term in k), None) + att = coef_map.get(matched_key) + self.causal_impact = att else: raise ValueError("Model type not recognized") @@ -215,15 +233,28 @@ def __init__( def input_validation(self): """Validate the input data and model formula for correctness""" - if "post_treatment" not in self.formula: - raise FormulaException( - "A predictor called `post_treatment` should be in the formula" - ) - - if "post_treatment" not in self.data.columns: - raise DataException( - "Require a boolean column labelling observations which are `treated`" - ) + if ( + self.post_treatment_variable_name not in self.formula + or self.post_treatment_variable_name not in self.data.columns + ): + if self.post_treatment_variable_name == "post_treatment": + # Default case - user didn't specify custom name, so guide them to use "post_treatment" + raise DataException( + "Missing 'post_treatment' in formula or dataset.\n" + "Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n" + "1) Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment')\n" + "2) and ensure dataset has boolean column 'post_treatment'.\n" + "To use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." + ) + else: + # Custom case - user specified custom name, so remind them what they specified + raise DataException( + f"Missing required variable '{self.post_treatment_variable_name}' in formula or dataset.\n\n" + f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', " + f"please ensure:\n" + f"1) formula includes '{self.post_treatment_variable_name}'\n" + f"2) dataset has boolean column named '{self.post_treatment_variable_name}'" + ) if "unit" not in self.data.columns: raise DataException( diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 4704ef6c..3e6a538d 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,10 +1,10 @@ - interrogate: 95.5% + interrogate: 93.6% - + @@ -12,8 +12,8 @@ interrogate interrogate - 95.5% - 95.5% + 93.6% + 93.6% From 7fbb27a3d3f5214768b043f10d1504a92a0edce0 Mon Sep 17 00:00:00 2001 From: Rojan Shrestha Date: Wed, 30 Jul 2025 11:20:51 +0545 Subject: [PATCH 2/3] Refactor DiD validation: segregate FormulaException and DataException --- causalpy/experiments/diff_in_diff.py | 38 +++++++++++++++++++--------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/causalpy/experiments/diff_in_diff.py b/causalpy/experiments/diff_in_diff.py index 2b307f00..a093359a 100644 --- a/causalpy/experiments/diff_in_diff.py +++ b/causalpy/experiments/diff_in_diff.py @@ -26,6 +26,7 @@ from causalpy.custom_exceptions import ( DataException, + FormulaException, ) from causalpy.plot_utils import plot_xY from causalpy.pymc_models import PyMCModel @@ -233,27 +234,40 @@ def __init__( def input_validation(self): """Validate the input data and model formula for correctness""" - if ( - self.post_treatment_variable_name not in self.formula - or self.post_treatment_variable_name not in self.data.columns - ): + # Check if post_treatment_variable_name is in formula + if self.post_treatment_variable_name not in self.formula: + if self.post_treatment_variable_name == "post_treatment": + # Default case - user didn't specify custom name, so guide them to use "post_treatment" + raise FormulaException( + "Missing 'post_treatment' in formula.\n" + "Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n" + "Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment').\n" + "Or to use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." + ) + else: + # Custom case - user specified custom name, so remind them what they specified + raise FormulaException( + f"Missing required variable '{self.post_treatment_variable_name}' in formula.\n\n" + f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', " + f"please ensure formula includes '{self.post_treatment_variable_name}'" + ) + + # Check if post_treatment_variable_name is in data columns + if self.post_treatment_variable_name not in self.data.columns: if self.post_treatment_variable_name == "post_treatment": # Default case - user didn't specify custom name, so guide them to use "post_treatment" raise DataException( - "Missing 'post_treatment' in formula or dataset.\n" + "Missing 'post_treatment' column in dataset.\n" "Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n" - "1) Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment')\n" - "2) and ensure dataset has boolean column 'post_treatment'.\n" - "To use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." + "Ensure dataset has boolean column 'post_treatment'.\n" + "or to use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." ) else: # Custom case - user specified custom name, so remind them what they specified raise DataException( - f"Missing required variable '{self.post_treatment_variable_name}' in formula or dataset.\n\n" + f"Missing required column '{self.post_treatment_variable_name}' in dataset.\n\n" f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', " - f"please ensure:\n" - f"1) formula includes '{self.post_treatment_variable_name}'\n" - f"2) dataset has boolean column named '{self.post_treatment_variable_name}'" + f"please ensure dataset has boolean column named '{self.post_treatment_variable_name}'" ) if "unit" not in self.data.columns: From c232d89411d82fb9883f3c2e134a780431b55f28 Mon Sep 17 00:00:00 2001 From: Rojan Shrestha Date: Tue, 5 Aug 2025 00:31:52 +0545 Subject: [PATCH 3/3] added validations for interactions, test coverage expanded to test interaction terms,more generic messages --- causalpy/experiments/diff_in_diff.py | 96 +++++++++++------ causalpy/tests/test_input_validation.py | 119 +++++++++++++++++++++- docs/source/_static/interrogate_badge.svg | 6 +- 3 files changed, 186 insertions(+), 35 deletions(-) diff --git a/causalpy/experiments/diff_in_diff.py b/causalpy/experiments/diff_in_diff.py index a093359a..132cd2ae 100644 --- a/causalpy/experiments/diff_in_diff.py +++ b/causalpy/experiments/diff_in_diff.py @@ -15,6 +15,8 @@ Difference in differences """ +import re + import arviz as az import numpy as np import pandas as pd @@ -233,42 +235,21 @@ def __init__( return def input_validation(self): + # Validate formula structure and interaction interaction terms + self._validate_formula_interaction_terms() + """Validate the input data and model formula for correctness""" # Check if post_treatment_variable_name is in formula if self.post_treatment_variable_name not in self.formula: - if self.post_treatment_variable_name == "post_treatment": - # Default case - user didn't specify custom name, so guide them to use "post_treatment" - raise FormulaException( - "Missing 'post_treatment' in formula.\n" - "Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n" - "Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment').\n" - "Or to use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." - ) - else: - # Custom case - user specified custom name, so remind them what they specified - raise FormulaException( - f"Missing required variable '{self.post_treatment_variable_name}' in formula.\n\n" - f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', " - f"please ensure formula includes '{self.post_treatment_variable_name}'" - ) + raise FormulaException( + f"Missing required variable '{self.post_treatment_variable_name}' in formula" + ) # Check if post_treatment_variable_name is in data columns if self.post_treatment_variable_name not in self.data.columns: - if self.post_treatment_variable_name == "post_treatment": - # Default case - user didn't specify custom name, so guide them to use "post_treatment" - raise DataException( - "Missing 'post_treatment' column in dataset.\n" - "Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n" - "Ensure dataset has boolean column 'post_treatment'.\n" - "or to use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'." - ) - else: - # Custom case - user specified custom name, so remind them what they specified - raise DataException( - f"Missing required column '{self.post_treatment_variable_name}' in dataset.\n\n" - f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', " - f"please ensure dataset has boolean column named '{self.post_treatment_variable_name}'" - ) + raise DataException( + f"Missing required column '{self.post_treatment_variable_name}' in dataset" + ) if "unit" not in self.data.columns: raise DataException( @@ -281,6 +262,61 @@ def input_validation(self): coded. Consisting of 0's and 1's only.""" ) + def _get_interaction_terms(self): + """ + Extract interaction terms from the formula. + Returns a list of interaction terms (those with '*' or ':'). + """ + # Define interaction indicators + INTERACTION_INDICATORS = ["*", ":"] + + # Remove whitespace + formula = self.formula.replace(" ", "") + + # Extract right-hand side of the formula + rhs = formula.split("~")[1] + + # Split terms by '+' or '-' while keeping them intact + terms = re.split(r"(?=[+-])", rhs) + + # Clean up terms and get interaction terms (those with '*' or ':') + interaction_terms = [] + for term in terms: + # Remove leading + or - for processing + clean_term = term.lstrip("+-") + if any(indicator in clean_term for indicator in INTERACTION_INDICATORS): + interaction_terms.append(clean_term) + + return interaction_terms + + def _validate_formula_interaction_terms(self): + """ + Validate that the formula contains at most one interaction term and no three-way or higher-order interactions. + Raises FormulaException if more than one interaction term is found or if any interaction term has more than 2 variables. + """ + # Define interaction indicators + INTERACTION_INDICATORS = ["*", ":"] + + # Get interaction terms + interaction_terms = self._get_interaction_terms() + + # Check for interaction terms with more than 2 variables (more than one '*' or ':') + for term in interaction_terms: + total_indicators = sum( + term.count(indicator) for indicator in INTERACTION_INDICATORS + ) + if ( + total_indicators >= 2 + ): # 3 or more variables (e.g., a*b*c or a:b:c has 2 symbols) + raise FormulaException( + f"Formula contains interaction term with more than 2 variables: {term}. Only two-way interactions are allowed." + ) + + if len(interaction_terms) > 1: + raise FormulaException( + f"Formula contains more than 1 interaction term: {interaction_terms}. Maximum of 1 allowed." + ) + def summary(self, round_to=None) -> None: """Print summary of main results and model coefficients. diff --git a/causalpy/tests/test_input_validation.py b/causalpy/tests/test_input_validation.py index 43fd9208..69ca3753 100644 --- a/causalpy/tests/test_input_validation.py +++ b/causalpy/tests/test_input_validation.py @@ -30,18 +30,29 @@ def test_did_validation_post_treatment_formula(): - """Test that we get a FormulaException if do not include post_treatment in the - formula""" + """Test that we get a FormulaException for invalid formulas and missing post_treatment variables""" df = pd.DataFrame( { "group": [0, 0, 1, 1], "t": [0, 1, 0, 1], "unit": [0, 0, 1, 1], "post_treatment": [0, 1, 0, 1], + "male": [0, 1, 0, 1], # Additional variable for testing "y": [1, 2, 3, 4], } ) + df_with_custom = pd.DataFrame( + { + "group": [0, 0, 1, 1], + "t": [0, 1, 0, 1], + "unit": [0, 0, 1, 1], + "custom_post": [0, 1, 0, 1], # Custom column name + "y": [1, 2, 3, 4], + } + ) + + # Test 1: Missing post_treatment variable in formula with pytest.raises(FormulaException): _ = cp.DifferenceInDifferences( df, @@ -51,6 +62,7 @@ def test_did_validation_post_treatment_formula(): model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), ) + # Test 2: Missing post_treatment variable in formula (duplicate test) with pytest.raises(FormulaException): _ = cp.DifferenceInDifferences( df, @@ -60,6 +72,88 @@ def test_did_validation_post_treatment_formula(): model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), ) + # Test 3: Custom post_treatment_variable_name but formula uses different name + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df_with_custom, + formula="y ~ 1 + group*post_treatment", # Formula uses 'post_treatment' + time_variable_name="t", + group_variable_name="group", + post_treatment_variable_name="custom_post", # But user specifies 'custom_post' + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 4: Default post_treatment_variable_name but formula uses different name + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group*custom_post", # Formula uses 'custom_post' + time_variable_name="t", + group_variable_name="group", + # post_treatment_variable_name defaults to "post_treatment" + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 5: Repeated interaction terms (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group*post_treatment + group*post_treatment", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 6: Three-way interactions using * (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group*post_treatment*male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 7: Three-way interactions using : (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group:post_treatment:male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 8: Multiple different interaction terms using * (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group*post_treatment + group*male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 9: Multiple different interaction terms using : (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group:post_treatment + group:male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + + # Test 10: Mixed issues - multiple terms + three-way interaction (should be invalid) + with pytest.raises(FormulaException): + _ = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group + group*post_treatment + group:post_treatment:male", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + def test_did_validation_post_treatment_data(): """Test that we get a DataException if do not include post_treatment in the data""" @@ -91,6 +185,27 @@ def test_did_validation_post_treatment_data(): model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), ) + # Test 2: Custom post_treatment_variable_name but column doesn't exist in data + df_with_post = pd.DataFrame( + { + "group": [0, 0, 1, 1], + "t": [0, 1, 0, 1], + "unit": [0, 0, 1, 1], + "post_treatment": [0, 1, 0, 1], # Data has 'post_treatment' + "y": [1, 2, 3, 4], + } + ) + + with pytest.raises(DataException): + _ = cp.DifferenceInDifferences( + df_with_post, + formula="y ~ 1 + group*custom_post", # Formula uses 'custom_post' + time_variable_name="t", + group_variable_name="group", + post_treatment_variable_name="custom_post", # User specifies 'custom_post' + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + def test_did_validation_unit_data(): """Test that we get a DataException if do not include unit in the data""" diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 3e6a538d..08c36d5e 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 93.6% + interrogate: 92.6% @@ -12,8 +12,8 @@ interrogate interrogate - 93.6% - 93.6% + 92.6% + 92.6%