@@ -479,16 +479,16 @@ def x_cols(self, value):
479479 raise ValueError ("Invalid covariates x_cols. At least one covariate is no data column." )
480480 assert set (value ).issubset (set (self .all_variables ))
481481
482- if reset_value :
482+ if not reset_value :
483+ self ._x_cols = value
484+ else :
483485 previous_value = self ._x_cols
484486 self ._x_cols = value
485487 try :
486488 self ._check_disjoint_sets ()
487489 except ValueError :
488490 self ._x_cols = previous_value
489491 raise
490- else :
491- self ._x_cols = value
492492
493493 else :
494494 excluded_cols = {self .y_col } | set (self .d_cols )
@@ -524,16 +524,16 @@ def d_cols(self, value):
524524 if not set (value ).issubset (set (self .all_variables )):
525525 raise ValueError ("Invalid treatment variable(s) d_cols. At least one treatment variable is no data column." )
526526
527- if reset_value :
527+ if not reset_value :
528+ self ._d_cols = value
529+ else :
528530 previous_value = self ._d_cols
529531 self ._d_cols = value
530532 try :
531533 self ._check_disjoint_sets ()
532534 except ValueError :
533535 self ._d_cols = previous_value
534536 raise
535- else :
536- self ._d_cols = value
537537
538538 if reset_value :
539539 self .set_x_d (self .d_cols [0 ])
@@ -557,16 +557,16 @@ def y_col(self, value):
557557 if value not in self .all_variables :
558558 raise ValueError (f"Invalid outcome variable y_col. { value } is no data column." )
559559
560- if reset_value :
560+ if not reset_value :
561+ self ._y_col = value
562+ else :
561563 previous_value = self ._y_col
562564 self ._y_col = value
563565 try :
564566 self ._check_disjoint_sets ()
565567 except ValueError :
566568 self ._y_col = previous_value
567569 raise
568- else :
569- self ._y_col = value
570570
571571 if reset_value :
572572 self ._set_y_z ()
@@ -598,16 +598,16 @@ def z_cols(self, value):
598598 "Invalid instrumental variable(s) z_cols. At least one instrumental variable is no data column."
599599 )
600600
601- if reset_value :
601+ if not reset_value :
602+ self ._z_cols = value
603+ else :
602604 previous_value = self ._z_cols
603605 self ._z_cols = value
604606 try :
605607 self ._check_disjoint_sets ()
606608 except ValueError :
607609 self ._z_cols = previous_value
608610 raise
609- else :
610- self ._z_cols = value
611611
612612 else :
613613 self ._z_cols = None
0 commit comments