@@ -463,7 +463,9 @@ def x_cols(self):
463463 @x_cols .setter
464464 def x_cols (self , value ):
465465 reset_value = hasattr (self , "_x_cols" )
466+
466467 if value is not None :
468+ # Basic checks
467469 if isinstance (value , str ):
468470 value = [value ]
469471 if not isinstance (value , list ):
@@ -476,7 +478,17 @@ def x_cols(self, value):
476478 if not set (value ).issubset (set (self .all_variables )):
477479 raise ValueError ("Invalid covariates x_cols. At least one covariate is no data column." )
478480 assert set (value ).issubset (set (self .all_variables ))
479- self ._x_cols = value
481+
482+ if not reset_value :
483+ self ._x_cols = value
484+ else :
485+ previous_value = self ._x_cols
486+ self ._x_cols = value
487+ try :
488+ self ._check_disjoint_sets ()
489+ except ValueError :
490+ self ._x_cols = previous_value
491+ raise
480492
481493 else :
482494 excluded_cols = {self .y_col } | set (self .d_cols )
@@ -486,8 +498,6 @@ def x_cols(self, value):
486498 self ._x_cols = [col for col in self .data .columns if col not in excluded_cols ]
487499
488500 if reset_value :
489- self ._check_disjoint_sets ()
490- # by default, we initialize to the first treatment variable
491501 self .set_x_d (self .d_cols [0 ])
492502
493503 @property
@@ -500,6 +510,8 @@ def d_cols(self):
500510 @d_cols .setter
501511 def d_cols (self , value ):
502512 reset_value = hasattr (self , "_d_cols" )
513+
514+ # Basic checks
503515 if isinstance (value , str ):
504516 value = [value ]
505517 if not isinstance (value , list ):
@@ -511,10 +523,19 @@ def d_cols(self, value):
511523 raise ValueError ("Invalid treatment variable(s) d_cols: Contains duplicate values." )
512524 if not set (value ).issubset (set (self .all_variables )):
513525 raise ValueError ("Invalid treatment variable(s) d_cols. At least one treatment variable is no data column." )
514- self ._d_cols = value
526+
527+ if not reset_value :
528+ self ._d_cols = value
529+ else :
530+ previous_value = self ._d_cols
531+ self ._d_cols = value
532+ try :
533+ self ._check_disjoint_sets ()
534+ except ValueError :
535+ self ._d_cols = previous_value
536+ raise
537+
515538 if reset_value :
516- self ._check_disjoint_sets ()
517- # by default, we initialize to the first treatment variable
518539 self .set_x_d (self .d_cols [0 ])
519540
520541 @property
@@ -527,15 +548,27 @@ def y_col(self):
527548 @y_col .setter
528549 def y_col (self , value ):
529550 reset_value = hasattr (self , "_y_col" )
551+
552+ # Basic checks
530553 if not isinstance (value , str ):
531554 raise TypeError (
532555 f"The outcome variable y_col must be of str type. { str (value )} of type { str (type (value ))} was passed."
533556 )
534557 if value not in self .all_variables :
535558 raise ValueError (f"Invalid outcome variable y_col. { value } is no data column." )
536- self ._y_col = value
559+
560+ if not reset_value :
561+ self ._y_col = value
562+ else :
563+ previous_value = self ._y_col
564+ self ._y_col = value
565+ try :
566+ self ._check_disjoint_sets ()
567+ except ValueError :
568+ self ._y_col = previous_value
569+ raise
570+
537571 if reset_value :
538- self ._check_disjoint_sets ()
539572 self ._set_y_z ()
540573
541574 @property
@@ -548,7 +581,9 @@ def z_cols(self):
548581 @z_cols .setter
549582 def z_cols (self , value ):
550583 reset_value = hasattr (self , "_z_cols" )
584+
551585 if value is not None :
586+ # Basic validation
552587 if isinstance (value , str ):
553588 value = [value ]
554589 if not isinstance (value , list ):
@@ -562,12 +597,22 @@ def z_cols(self, value):
562597 raise ValueError (
563598 "Invalid instrumental variable(s) z_cols. At least one instrumental variable is no data column."
564599 )
565- self ._z_cols = value
600+
601+ if not reset_value :
602+ self ._z_cols = value
603+ else :
604+ previous_value = self ._z_cols
605+ self ._z_cols = value
606+ try :
607+ self ._check_disjoint_sets ()
608+ except ValueError :
609+ self ._z_cols = previous_value
610+ raise
611+
566612 else :
567613 self ._z_cols = None
568614
569615 if reset_value :
570- self ._check_disjoint_sets ()
571616 self ._set_y_z ()
572617
573618 @property
0 commit comments