Skip to content

Commit 2f12d4b

Browse files
committed
update z and x cols setter
1 parent 2fcf523 commit 2f12d4b

File tree

1 file changed

+43
-10
lines changed

1 file changed

+43
-10
lines changed

doubleml/data/base_data.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,9 @@ def x_cols(self):
463463
@x_cols.setter
464464
def x_cols(self, value):
465465
reset_value = hasattr(self, "_x_cols")
466+
466467
if value is not None:
468+
# Basic checks
467469
if isinstance(value, str):
468470
value = [value]
469471
if not isinstance(value, list):
@@ -476,7 +478,17 @@ def x_cols(self, value):
476478
if not set(value).issubset(set(self.all_variables)):
477479
raise ValueError("Invalid covariates x_cols. At least one covariate is no data column.")
478480
assert set(value).issubset(set(self.all_variables))
479-
self._x_cols = value
481+
482+
if reset_value:
483+
previous_value = self._x_cols
484+
self._x_cols = value
485+
try:
486+
self._check_disjoint_sets()
487+
except ValueError:
488+
self._x_cols = previous_value
489+
raise
490+
else:
491+
self._x_cols = value
480492

481493
else:
482494
excluded_cols = {self.y_col} | set(self.d_cols)
@@ -486,8 +498,6 @@ def x_cols(self, value):
486498
self._x_cols = [col for col in self.data.columns if col not in excluded_cols]
487499

488500
if reset_value:
489-
self._check_disjoint_sets()
490-
# by default, we initialize to the first treatment variable
491501
self.set_x_d(self.d_cols[0])
492502

493503
@property
@@ -500,6 +510,8 @@ def d_cols(self):
500510
@d_cols.setter
501511
def d_cols(self, value):
502512
reset_value = hasattr(self, "_d_cols")
513+
514+
# Basic checks
503515
if isinstance(value, str):
504516
value = [value]
505517
if not isinstance(value, list):
@@ -511,10 +523,19 @@ def d_cols(self, value):
511523
raise ValueError("Invalid treatment variable(s) d_cols: Contains duplicate values.")
512524
if not set(value).issubset(set(self.all_variables)):
513525
raise ValueError("Invalid treatment variable(s) d_cols. At least one treatment variable is no data column.")
514-
self._d_cols = value
526+
527+
if reset_value:
528+
previous_value = self._d_cols
529+
self._d_cols = value
530+
try:
531+
self._check_disjoint_sets()
532+
except ValueError:
533+
self._d_cols = previous_value
534+
raise
535+
else:
536+
self._d_cols = value
537+
515538
if reset_value:
516-
self._check_disjoint_sets()
517-
# by default, we initialize to the first treatment variable
518539
self.set_x_d(self.d_cols[0])
519540

520541
@property
@@ -541,9 +562,9 @@ def y_col(self, value):
541562
self._y_col = value
542563
try:
543564
self._check_disjoint_sets()
544-
except ValueError as e:
565+
except ValueError:
545566
self._y_col = previous_value
546-
raise e
567+
raise
547568
else:
548569
self._y_col = value
549570

@@ -560,7 +581,9 @@ def z_cols(self):
560581
@z_cols.setter
561582
def z_cols(self, value):
562583
reset_value = hasattr(self, "_z_cols")
584+
563585
if value is not None:
586+
# Basic validation
564587
if isinstance(value, str):
565588
value = [value]
566589
if not isinstance(value, list):
@@ -574,12 +597,22 @@ def z_cols(self, value):
574597
raise ValueError(
575598
"Invalid instrumental variable(s) z_cols. At least one instrumental variable is no data column."
576599
)
577-
self._z_cols = value
600+
601+
if reset_value:
602+
previous_value = self._z_cols
603+
self._z_cols = value
604+
try:
605+
self._check_disjoint_sets()
606+
except ValueError:
607+
self._z_cols = previous_value
608+
raise
609+
else:
610+
self._z_cols = value
611+
578612
else:
579613
self._z_cols = None
580614

581615
if reset_value:
582-
self._check_disjoint_sets()
583616
self._set_y_z()
584617

585618
@property

0 commit comments

Comments
 (0)