@@ -463,7 +463,9 @@ def x_cols(self):
463463 @x_cols .setter
464464 def x_cols (self , value ):
465465 reset_value = hasattr (self , "_x_cols" )
466+
466467 if value is not None :
468+ # Basic checks
467469 if isinstance (value , str ):
468470 value = [value ]
469471 if not isinstance (value , list ):
@@ -476,7 +478,17 @@ def x_cols(self, value):
476478 if not set (value ).issubset (set (self .all_variables )):
477479 raise ValueError ("Invalid covariates x_cols. At least one covariate is no data column." )
478480 assert set (value ).issubset (set (self .all_variables ))
479- self ._x_cols = value
481+
482+ if reset_value :
483+ previous_value = self ._x_cols
484+ self ._x_cols = value
485+ try :
486+ self ._check_disjoint_sets ()
487+ except ValueError :
488+ self ._x_cols = previous_value
489+ raise
490+ else :
491+ self ._x_cols = value
480492
481493 else :
482494 excluded_cols = {self .y_col } | set (self .d_cols )
@@ -486,8 +498,6 @@ def x_cols(self, value):
486498 self ._x_cols = [col for col in self .data .columns if col not in excluded_cols ]
487499
488500 if reset_value :
489- self ._check_disjoint_sets ()
490- # by default, we initialize to the first treatment variable
491501 self .set_x_d (self .d_cols [0 ])
492502
493503 @property
@@ -500,6 +510,8 @@ def d_cols(self):
500510 @d_cols .setter
501511 def d_cols (self , value ):
502512 reset_value = hasattr (self , "_d_cols" )
513+
514+ # Basic checks
503515 if isinstance (value , str ):
504516 value = [value ]
505517 if not isinstance (value , list ):
@@ -511,10 +523,19 @@ def d_cols(self, value):
511523 raise ValueError ("Invalid treatment variable(s) d_cols: Contains duplicate values." )
512524 if not set (value ).issubset (set (self .all_variables )):
513525 raise ValueError ("Invalid treatment variable(s) d_cols. At least one treatment variable is no data column." )
514- self ._d_cols = value
526+
527+ if reset_value :
528+ previous_value = self ._d_cols
529+ self ._d_cols = value
530+ try :
531+ self ._check_disjoint_sets ()
532+ except ValueError :
533+ self ._d_cols = previous_value
534+ raise
535+ else :
536+ self ._d_cols = value
537+
515538 if reset_value :
516- self ._check_disjoint_sets ()
517- # by default, we initialize to the first treatment variable
518539 self .set_x_d (self .d_cols [0 ])
519540
520541 @property
@@ -541,9 +562,9 @@ def y_col(self, value):
541562 self ._y_col = value
542563 try :
543564 self ._check_disjoint_sets ()
544- except ValueError as e :
565+ except ValueError :
545566 self ._y_col = previous_value
546- raise e
567+ raise
547568 else :
548569 self ._y_col = value
549570
@@ -560,7 +581,9 @@ def z_cols(self):
560581 @z_cols .setter
561582 def z_cols (self , value ):
562583 reset_value = hasattr (self , "_z_cols" )
584+
563585 if value is not None :
586+ # Basic validation
564587 if isinstance (value , str ):
565588 value = [value ]
566589 if not isinstance (value , list ):
@@ -574,12 +597,22 @@ def z_cols(self, value):
574597 raise ValueError (
575598 "Invalid instrumental variable(s) z_cols. At least one instrumental variable is no data column."
576599 )
577- self ._z_cols = value
600+
601+ if reset_value :
602+ previous_value = self ._z_cols
603+ self ._z_cols = value
604+ try :
605+ self ._check_disjoint_sets ()
606+ except ValueError :
607+ self ._z_cols = previous_value
608+ raise
609+ else :
610+ self ._z_cols = value
611+
578612 else :
579613 self ._z_cols = None
580614
581615 if reset_value :
582- self ._check_disjoint_sets ()
583616 self ._set_y_z ()
584617
585618 @property
0 commit comments