Skip to content

Commit 7a79af4

Browse files
authored
Merge pull request #357 from DoubleML/s-update-setters
Update DoubleMLData setters to fall back to previous values if checks fail
2 parents 9eda55d + c895262 commit 7a79af4

File tree

2 files changed

+97
-10
lines changed

2 files changed

+97
-10
lines changed

doubleml/data/base_data.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,9 @@ def x_cols(self):
463463
@x_cols.setter
464464
def x_cols(self, value):
465465
reset_value = hasattr(self, "_x_cols")
466+
466467
if value is not None:
468+
# Basic checks
467469
if isinstance(value, str):
468470
value = [value]
469471
if not isinstance(value, list):
@@ -476,7 +478,17 @@ def x_cols(self, value):
476478
if not set(value).issubset(set(self.all_variables)):
477479
raise ValueError("Invalid covariates x_cols. At least one covariate is no data column.")
478480
assert set(value).issubset(set(self.all_variables))
479-
self._x_cols = value
481+
482+
if not reset_value:
483+
self._x_cols = value
484+
else:
485+
previous_value = self._x_cols
486+
self._x_cols = value
487+
try:
488+
self._check_disjoint_sets()
489+
except ValueError:
490+
self._x_cols = previous_value
491+
raise
480492

481493
else:
482494
excluded_cols = {self.y_col} | set(self.d_cols)
@@ -486,8 +498,6 @@ def x_cols(self, value):
486498
self._x_cols = [col for col in self.data.columns if col not in excluded_cols]
487499

488500
if reset_value:
489-
self._check_disjoint_sets()
490-
# by default, we initialize to the first treatment variable
491501
self.set_x_d(self.d_cols[0])
492502

493503
@property
@@ -500,6 +510,8 @@ def d_cols(self):
500510
@d_cols.setter
501511
def d_cols(self, value):
502512
reset_value = hasattr(self, "_d_cols")
513+
514+
# Basic checks
503515
if isinstance(value, str):
504516
value = [value]
505517
if not isinstance(value, list):
@@ -511,10 +523,19 @@ def d_cols(self, value):
511523
raise ValueError("Invalid treatment variable(s) d_cols: Contains duplicate values.")
512524
if not set(value).issubset(set(self.all_variables)):
513525
raise ValueError("Invalid treatment variable(s) d_cols. At least one treatment variable is no data column.")
514-
self._d_cols = value
526+
527+
if not reset_value:
528+
self._d_cols = value
529+
else:
530+
previous_value = self._d_cols
531+
self._d_cols = value
532+
try:
533+
self._check_disjoint_sets()
534+
except ValueError:
535+
self._d_cols = previous_value
536+
raise
537+
515538
if reset_value:
516-
self._check_disjoint_sets()
517-
# by default, we initialize to the first treatment variable
518539
self.set_x_d(self.d_cols[0])
519540

520541
@property
@@ -527,15 +548,27 @@ def y_col(self):
527548
@y_col.setter
528549
def y_col(self, value):
529550
reset_value = hasattr(self, "_y_col")
551+
552+
# Basic checks
530553
if not isinstance(value, str):
531554
raise TypeError(
532555
f"The outcome variable y_col must be of str type. {str(value)} of type {str(type(value))} was passed."
533556
)
534557
if value not in self.all_variables:
535558
raise ValueError(f"Invalid outcome variable y_col. {value} is no data column.")
536-
self._y_col = value
559+
560+
if not reset_value:
561+
self._y_col = value
562+
else:
563+
previous_value = self._y_col
564+
self._y_col = value
565+
try:
566+
self._check_disjoint_sets()
567+
except ValueError:
568+
self._y_col = previous_value
569+
raise
570+
537571
if reset_value:
538-
self._check_disjoint_sets()
539572
self._set_y_z()
540573

541574
@property
@@ -548,7 +581,9 @@ def z_cols(self):
548581
@z_cols.setter
549582
def z_cols(self, value):
550583
reset_value = hasattr(self, "_z_cols")
584+
551585
if value is not None:
586+
# Basic validation
552587
if isinstance(value, str):
553588
value = [value]
554589
if not isinstance(value, list):
@@ -562,12 +597,22 @@ def z_cols(self, value):
562597
raise ValueError(
563598
"Invalid instrumental variable(s) z_cols. At least one instrumental variable is no data column."
564599
)
565-
self._z_cols = value
600+
601+
if not reset_value:
602+
self._z_cols = value
603+
else:
604+
previous_value = self._z_cols
605+
self._z_cols = value
606+
try:
607+
self._check_disjoint_sets()
608+
except ValueError:
609+
self._z_cols = previous_value
610+
raise
611+
566612
else:
567613
self._z_cols = None
568614

569615
if reset_value:
570-
self._check_disjoint_sets()
571616
self._set_y_z()
572617

573618
@property

doubleml/data/tests/test_dml_data.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,3 +619,45 @@ def test_dml_data_w_missing_d(generate_data1):
619619
assert dml_data.force_all_d_finite is False
620620
dml_data.force_all_d_finite = "allow-nan"
621621
assert dml_data.force_all_d_finite == "allow-nan"
622+
623+
624+
@pytest.mark.ci
625+
def test_property_setter_rollback_on_validation_failure():
626+
"""Test that property setters don't mutate the object if validation fails."""
627+
np.random.seed(3141)
628+
dml_data = make_plr_CCDDHNR2018(n_obs=100)
629+
630+
# Store original values
631+
original_y_col = dml_data.y_col
632+
original_d_cols = dml_data.d_cols.copy()
633+
original_x_cols = dml_data.x_cols.copy()
634+
original_z_cols = dml_data.z_cols
635+
636+
# Test y_col setter - try to set y_col to a value that's already in d_cols
637+
with pytest.raises(
638+
ValueError, match=r"d cannot be set as outcome variable ``y_col`` and treatment variable in ``d_cols``"
639+
):
640+
dml_data.y_col = "d"
641+
# Object should remain unchanged
642+
assert dml_data.y_col == original_y_col
643+
644+
# Test d_cols setter - try to set d_cols to include the outcome variable
645+
with pytest.raises(
646+
ValueError, match=r"y cannot be set as outcome variable ``y_col`` and treatment variable in ``d_cols``"
647+
):
648+
dml_data.d_cols = ["y", "d"]
649+
# Object should remain unchanged
650+
assert dml_data.d_cols == original_d_cols
651+
652+
# Test x_cols setter - try to set x_cols to include the outcome variable
653+
with pytest.raises(ValueError, match=r"y cannot be set as outcome variable ``y_col`` and covariate in ``x_cols``"):
654+
dml_data.x_cols = ["X1", "y", "X2"]
655+
# Object should remain unchanged
656+
assert dml_data.x_cols == original_x_cols
657+
658+
# Test z_cols setter - try to set z_cols to include the outcome variable
659+
msg = r"At least one variable/column is set as outcome variable \(``y_col``\) and instrumental variable \(``z_cols``\)"
660+
with pytest.raises(ValueError, match=msg):
661+
dml_data.z_cols = ["y"]
662+
# Object should remain unchanged
663+
assert dml_data.z_cols == original_z_cols

0 commit comments

Comments
 (0)