From 7ab10b56b686cc3b296f62c8a464c7cc10ae66f8 Mon Sep 17 00:00:00 2001 From: KarlaMelgarejo Date: Fri, 27 Feb 2026 18:30:57 -0600 Subject: [PATCH 1/4] feat: add excluded_controls parameter to geolift pipeline Add ability to exclude specific locations from the control group selection process. The parameter is propagated through the full pipeline: - select_controls() and select_controls_exclusive() - evaluate_group() and evaluate_group_exclusive() - BetterGroups() including all executor.map calls (single-cell, multi-cell, global) - run_geo_analysis_streamlit_app() - experimental_design.py UI widget (below excluded_locations) - api.py Form parameter with parsing - tasks.py Celery task parameter Co-Authored-By: Claude Sonnet 4.6 --- Murray/main.py | 35 +++++++++++++++++++++++++++++++---- api.py | 9 ++++++++- experimental_design.py | 6 ++++++ tasks.py | 2 ++ 4 files changed, 47 insertions(+), 5 deletions(-) diff --git a/Murray/main.py b/Murray/main.py index befb3aa..28b81a7 100644 --- a/Murray/main.py +++ b/Murray/main.py @@ -107,7 +107,7 @@ def select_treatments(similarity_matrix, treatment_size, excluded_locations): def select_controls( - correlation_matrix, treatment_group, min_correlation=0.8, fallback_n=1 + correlation_matrix, treatment_group, min_correlation=0.8, fallback_n=1, excluded_controls=None ): """ Dynamically selects control group states based on correlation values. @@ -118,6 +118,7 @@ def select_controls( treatment_group (list): List of states in the treatment group. min_correlation (float): Minimum correlation threshold to consider a state as part of the control group. fallback_n (int): Number of top correlated states to select if no state meets the min_correlation. + excluded_controls (list): List of states to exclude from the control group. Returns: list: List of states selected as the control group. @@ -126,6 +127,9 @@ def select_controls( f"select_controls called: treatment_group={treatment_group}, min_correlation={min_correlation}" ) + if excluded_controls is None: + excluded_controls = [] + control_group = set() for treatment_location in treatment_group: @@ -140,6 +144,7 @@ def select_controls( treatment_row[ (treatment_row >= min_correlation) & (~treatment_row.index.isin(treatment_group)) + & (~treatment_row.index.isin(excluded_controls)) ] .sort_values(ascending=False) .index.tolist() @@ -150,7 +155,10 @@ def select_controls( f"No states meet min_correlation {min_correlation} for {treatment_location}, using fallback" ) similar_states = ( - treatment_row[~treatment_row.index.isin(treatment_group)] + treatment_row[ + ~treatment_row.index.isin(treatment_group) + & (~treatment_row.index.isin(excluded_controls)) + ] .sort_values(ascending=False) .head(fallback_n) .index.tolist() @@ -343,7 +351,7 @@ def smape(A, F): def evaluate_group( - treatment_group, data, total_Y, correlation_matrix, min_holdout, df_pivot, treatment_period=None + treatment_group, data, total_Y, correlation_matrix, min_holdout, df_pivot, treatment_period=None, excluded_controls=None ): """ Evaluates a treatment group and returns error metrics. @@ -377,6 +385,7 @@ def evaluate_group( correlation_matrix=correlation_matrix, treatment_group=treatment_group, min_correlation=0.8, + excluded_controls=excluded_controls, ) logger.debug(f"Control group selected: {control_group}") @@ -567,6 +576,7 @@ def select_controls_exclusive( excluded_locations=None, min_correlation=0.8, fallback_n=1, + excluded_controls=None, ): """ Dynamically selects control group states based on correlation values. @@ -574,6 +584,7 @@ def select_controls_exclusive( - Excludes current treatment group locations - Excludes globally excluded locations - Excludes locations that have been used as treatment in previous cells + - Excludes explicitly excluded control locations - ALLOWS reuse of control locations from previous cells Args: @@ -583,6 +594,7 @@ def select_controls_exclusive( excluded_locations (list): List of globally excluded locations. min_correlation (float): Minimum correlation threshold to consider a state as part of the control group. fallback_n (int): Number of top correlated states to select if no state meets the min_correlation. + excluded_controls (list): List of states to exclude from the control group. Returns: list: List of states selected as the control group. @@ -591,6 +603,8 @@ def select_controls_exclusive( used_treatment_locations = set() if excluded_locations is None: excluded_locations = [] + if excluded_controls is None: + excluded_controls = [] logger.debug( f"select_controls_exclusive called: treatment_group={treatment_group}, used_treatment_locations={len(used_treatment_locations)}, excluded_locations={len(excluded_locations)}" @@ -598,7 +612,7 @@ def select_controls_exclusive( control_group = set() all_excluded = ( - set(treatment_group) | used_treatment_locations | set(excluded_locations) + set(treatment_group) | used_treatment_locations | set(excluded_locations) | set(excluded_controls) ) for treatment_location in treatment_group: @@ -657,6 +671,7 @@ def evaluate_group_exclusive( used_treatment_locations=None, excluded_locations=None, treatment_period=None, + excluded_controls=None, ): """ Evaluates a treatment group with location exclusivity for multi-cell mode. @@ -704,6 +719,7 @@ def evaluate_group_exclusive( used_treatment_locations=used_treatment_locations, excluded_locations=excluded_locations, min_correlation=0.8, + excluded_controls=excluded_controls, ) logger.debug(f"Control group selected: {control_group}") @@ -798,6 +814,7 @@ def BetterGroups( status_updater=None, multicell_config=None, global_optimization=False, + excluded_controls=None, ): """ Enhanced simulates and evaluates treatment groups for geo-experiments. @@ -909,6 +926,8 @@ def BetterGroups( [df_pivot] * total_groups, [used_treatment_locations] * total_groups, [excluded_locations] * total_groups, + [None] * total_groups, + [excluded_controls] * total_groups, ) for idx, result in enumerate(futures): @@ -977,6 +996,7 @@ def BetterGroups( df_pivot=df_pivot, used_treatment_locations=current_used_treatments, excluded_locations=excluded_locations, + excluded_controls=excluded_controls, ) if result is not None: final_results.append(result) @@ -1038,6 +1058,8 @@ def BetterGroups( [correlation_matrix] * total_groups, [min_holdout] * total_groups, [df_pivot] * total_groups, + [None] * total_groups, + [excluded_controls] * total_groups, ) for idx, result in enumerate(futures): results.append(result) @@ -1316,6 +1338,8 @@ def optimize_global_multicell( [df_pivot] * len(groups), [set()] * len(groups), # No used locations in phase 1 [excluded_locations] * len(groups), + [None] * len(groups), + [excluded_controls] * len(groups), ) for result in futures: @@ -2061,6 +2085,7 @@ def run_geo_analysis_streamlit_app( inference_type="iid", global_optimization=False, progress_updater=None, + excluded_controls=None, ): """ Runs a complete geo analysis pipeline including market correlation, group optimization, @@ -2132,6 +2157,7 @@ def run_geo_analysis_streamlit_app( status_updater=status_text_1, multicell_config=multicell_config, global_optimization=global_optimization, + excluded_controls=excluded_controls, ) if simulation_results is None: @@ -2311,6 +2337,7 @@ def run_geo_analysis( correlation_matrix=correlation_matrix, progress_updater=progress_bar_1, status_updater=status_text_1, + excluded_controls=excluded_controls, ) # Step 3: Evaluate sensitivity for different deltas and periods diff --git a/api.py b/api.py index 2d781be..a511bc2 100644 --- a/api.py +++ b/api.py @@ -125,6 +125,7 @@ async def analyze_design( location_column: str = Form(None), target_column: str = Form(None), excluded_locations: str = Form(None), + excluded_controls: str = Form(None), maximum_treatment_percentage: float = Form(0.3), significance_level: float = Form(0.1), deltas_range: str = Form("0.01,0.1,0.01"), @@ -184,7 +185,12 @@ async def analyze_design( excluded_locations = tuple(map(str, excluded_locations.split(','))) else: excluded_locations = tuple() - + # Handle empty excluded_controls properly + if excluded_controls and excluded_controls.strip(): + excluded_controls = tuple(map(str, excluded_controls.split(','))) + else: + excluded_controls = tuple() + # Process multicell parameters multicell_config = None if enable_multicell: @@ -221,6 +227,7 @@ async def analyze_design( location_column=location_column, target_column=target_column, excluded_locations=excluded_locations, + excluded_controls=excluded_controls, maximum_treatment_percentage=maximum_treatment_percentage, significance_level=significance_level, deltas_range=deltas_range, diff --git a/experimental_design.py b/experimental_design.py index 54dc01d..6ef8fec 100644 --- a/experimental_design.py +++ b/experimental_design.py @@ -767,6 +767,9 @@ def reset_states(): excluded_locations = st.multiselect( "Select excluded locations", cleaned["location"].unique() ) + excluded_controls = st.multiselect( + "Select excluded controls", cleaned["location"].unique() + ) # Analyze data and recommend statistical test data_analysis = analyze_data_characteristics(cleaned, col_target="Y") @@ -955,6 +958,7 @@ def reset_states(): # Track parameter changes to reset simulation state current_params = { "excluded_locations": excluded_locations, + "excluded_controls": excluded_controls, "maximum_treatment_percentage_pre": maximum_treatment_percentage_pre, "significance_level_pre": significance_level_pre, "deltas_range": (delta_min, delta_max, delta_step), @@ -1017,6 +1021,7 @@ def reset_states(): results = run_geo_analysis_streamlit_app( data=cleaned, excluded_locations=excluded_locations, + excluded_controls=excluded_controls, maximum_treatment_percentage=maximum_treatment_percentage, significance_level=significance_level, deltas_range=deltas_range, @@ -1056,6 +1061,7 @@ def reset_states(): with st.expander("🔧 Current Analysis Settings", expanded=False): st.write(f"**Maximum Treatment Percentage:** {maximum_treatment_percentage*100:.1f}%") st.write(f"**Excluded Locations:** {len(excluded_locations)} locations") + st.write(f"**Excluded Controls:** {len(excluded_controls)} controls") st.write(f"**Total Locations Available:** {len(cleaned['location'].unique())} locations") st.write(f"**Time Periods:** {len(cleaned['time'].unique())} periods") st.write(f"**Total Y Sum:** {cleaned['Y'].sum():,.2f}") diff --git a/tasks.py b/tasks.py index f6562be..289defd 100644 --- a/tasks.py +++ b/tasks.py @@ -107,6 +107,7 @@ def analyze_design_task( location_column: str, target_column: str, excluded_locations: tuple, + excluded_controls: tuple, maximum_treatment_percentage: float, significance_level: float, deltas_range: tuple, @@ -258,6 +259,7 @@ def advance_to_sensitivity(): results = run_geo_analysis_streamlit_app( data=data, excluded_locations=excluded_locations, + excluded_controls=excluded_controls, maximum_treatment_percentage=maximum_treatment_percentage, significance_level=significance_level, deltas_range=deltas_range, From 7205baca96e37ccac1b098276f17e138aa8af309 Mon Sep 17 00:00:00 2001 From: KarlaMelgarejo Date: Tue, 3 Mar 2026 17:55:42 -0600 Subject: [PATCH 2/4] feat: rename excluded params and add excluded_control_locations to evaluation - Rename excluded_locations -> excluded_treatment_locations across all files - Rename excluded_controls -> excluded_control_locations across all files - Add excluded_control_locations parameter to run_geo_evaluation() in post_analysis.py - Add excluded_control_locations widget in experimental_evaluation.py UI - Add logging of excluded_control_locations in run_geo_analysis_streamlit_app - Add validation in run_geo_evaluation for empty pre-treatment period with clear error message - Handle ValueError gracefully in experimental_evaluation.py UI Co-Authored-By: Claude Sonnet 4.6 --- Murray/main.py | 158 +++++++++++++++++++------------------ Murray/post_analysis.py | 9 +++ api.py | 24 +++--- experimental_design.py | 16 ++-- experimental_evaluation.py | 24 ++++-- tasks.py | 8 +- 6 files changed, 130 insertions(+), 109 deletions(-) diff --git a/Murray/main.py b/Murray/main.py index 28b81a7..ba76375 100644 --- a/Murray/main.py +++ b/Murray/main.py @@ -25,7 +25,7 @@ def is_streamlit_context(): ) -def select_treatments(similarity_matrix, treatment_size, excluded_locations): +def select_treatments(similarity_matrix, treatment_size, excluded_treatment_locations): """ Selects n combinations of treatments based on a similarity DataFrame, excluding certain states from the treatment selection but allowing their inclusion in the control. @@ -34,23 +34,23 @@ def select_treatments(similarity_matrix, treatment_size, excluded_locations): Args: similarity_matrix (pd.DataFrame): DataFrame containing correlations between locations in a standard matrix format treatment_size (int): Number of treatments to select for each combination. - excluded_locations (list): List of locations to exclude from the treatment selection. + excluded_treatment_locations (list): List of locations to exclude from the treatment selection. Returns: list: A list of unique combinations, each combination being a list of states. """ - # Filter out empty strings from excluded_locations - excluded_locations = [loc for loc in excluded_locations if loc.strip()] + # Filter out empty strings from excluded_treatment_locations + excluded_treatment_locations = [loc for loc in excluded_treatment_locations if loc.strip()] logger.debug( - f"select_treatments called: treatment_size={treatment_size}, excluded_locations={excluded_locations}" + f"select_treatments called: treatment_size={treatment_size}, excluded_treatment_locations={excluded_treatment_locations}" ) missing_locations = [ location - for location in excluded_locations + for location in excluded_treatment_locations if location not in similarity_matrix.index or location not in similarity_matrix.columns ] @@ -64,8 +64,8 @@ def select_treatments(similarity_matrix, treatment_size, excluded_locations): ) similarity_matrix_filtered = similarity_matrix.loc[ - ~similarity_matrix.index.isin(excluded_locations), - ~similarity_matrix.columns.isin(excluded_locations), + ~similarity_matrix.index.isin(excluded_treatment_locations), + ~similarity_matrix.columns.isin(excluded_treatment_locations), ] logger.debug( @@ -86,10 +86,10 @@ def select_treatments(similarity_matrix, treatment_size, excluded_locations): max_combinations = comb(n, r) n_combinations = max_combinations - if n_combinations > 5000: - n_combinations = 5000 - # if n_combinations > 2000: - # n_combinations = 2000 + # if n_combinations > 5000: + # n_combinations = 5000 + if n_combinations > 2000: + n_combinations = 2000 logger.debug(f"Generating {n_combinations} combinations") @@ -107,7 +107,7 @@ def select_treatments(similarity_matrix, treatment_size, excluded_locations): def select_controls( - correlation_matrix, treatment_group, min_correlation=0.8, fallback_n=1, excluded_controls=None + correlation_matrix, treatment_group, min_correlation=0.8, fallback_n=1, excluded_control_locations=None ): """ Dynamically selects control group states based on correlation values. @@ -118,7 +118,7 @@ def select_controls( treatment_group (list): List of states in the treatment group. min_correlation (float): Minimum correlation threshold to consider a state as part of the control group. fallback_n (int): Number of top correlated states to select if no state meets the min_correlation. - excluded_controls (list): List of states to exclude from the control group. + excluded_control_locations (list): List of states to exclude from the control group. Returns: list: List of states selected as the control group. @@ -127,8 +127,8 @@ def select_controls( f"select_controls called: treatment_group={treatment_group}, min_correlation={min_correlation}" ) - if excluded_controls is None: - excluded_controls = [] + if excluded_control_locations is None: + excluded_control_locations = [] control_group = set() @@ -144,7 +144,7 @@ def select_controls( treatment_row[ (treatment_row >= min_correlation) & (~treatment_row.index.isin(treatment_group)) - & (~treatment_row.index.isin(excluded_controls)) + & (~treatment_row.index.isin(excluded_control_locations)) ] .sort_values(ascending=False) .index.tolist() @@ -157,7 +157,7 @@ def select_controls( similar_states = ( treatment_row[ ~treatment_row.index.isin(treatment_group) - & (~treatment_row.index.isin(excluded_controls)) + & (~treatment_row.index.isin(excluded_control_locations)) ] .sort_values(ascending=False) .head(fallback_n) @@ -351,7 +351,7 @@ def smape(A, F): def evaluate_group( - treatment_group, data, total_Y, correlation_matrix, min_holdout, df_pivot, treatment_period=None, excluded_controls=None + treatment_group, data, total_Y, correlation_matrix, min_holdout, df_pivot, treatment_period=None, excluded_control_locations=None ): """ Evaluates a treatment group and returns error metrics. @@ -385,7 +385,7 @@ def evaluate_group( correlation_matrix=correlation_matrix, treatment_group=treatment_group, min_correlation=0.8, - excluded_controls=excluded_controls, + excluded_control_locations=excluded_control_locations, ) logger.debug(f"Control group selected: {control_group}") @@ -471,7 +471,7 @@ def evaluate_group( def select_treatments_exclusive( - similarity_matrix, treatment_size, excluded_locations, used_treatment_locations=None + similarity_matrix, treatment_size, excluded_treatment_locations, used_treatment_locations=None ): """ Improved treatment selection for multi-cell mode ensuring treatment location exclusivity. @@ -480,7 +480,7 @@ def select_treatments_exclusive( Args: similarity_matrix (pd.DataFrame): DataFrame containing correlations between locations treatment_size (int): Number of treatments to select for each combination - excluded_locations (list): List of locations to exclude globally + excluded_treatment_locations (list): List of locations to exclude globally used_treatment_locations (set): Set of treatment locations already used in previous cells Returns: @@ -489,9 +489,9 @@ def select_treatments_exclusive( if used_treatment_locations is None: used_treatment_locations = set() - # Filter out empty strings from excluded_locations - excluded_locations = [loc for loc in excluded_locations if loc.strip()] - all_excluded = set(excluded_locations) | used_treatment_locations + # Filter out empty strings from excluded_treatment_locations + excluded_treatment_locations = [loc for loc in excluded_treatment_locations if loc.strip()] + all_excluded = set(excluded_treatment_locations) | used_treatment_locations logger.debug( f"select_treatments_exclusive: treatment_size={treatment_size}, excluded={len(all_excluded)} locations" @@ -499,7 +499,7 @@ def select_treatments_exclusive( missing_locations = [ location - for location in excluded_locations + for location in excluded_treatment_locations if location not in similarity_matrix.index or location not in similarity_matrix.columns ] @@ -573,10 +573,10 @@ def select_controls_exclusive( correlation_matrix, treatment_group, used_treatment_locations=None, - excluded_locations=None, + excluded_treatment_locations=None, min_correlation=0.8, fallback_n=1, - excluded_controls=None, + excluded_control_locations=None, ): """ Dynamically selects control group states based on correlation values. @@ -591,28 +591,28 @@ def select_controls_exclusive( correlation_matrix (pd.DataFrame): Correlation matrix between states. treatment_group (list): List of states in the treatment group. used_treatment_locations (set): Set of treatment locations already used in previous cells. - excluded_locations (list): List of globally excluded locations. + excluded_treatment_locations (list): List of globally excluded locations. min_correlation (float): Minimum correlation threshold to consider a state as part of the control group. fallback_n (int): Number of top correlated states to select if no state meets the min_correlation. - excluded_controls (list): List of states to exclude from the control group. + excluded_control_locations (list): List of states to exclude from the control group. Returns: list: List of states selected as the control group. """ if used_treatment_locations is None: used_treatment_locations = set() - if excluded_locations is None: - excluded_locations = [] - if excluded_controls is None: - excluded_controls = [] + if excluded_treatment_locations is None: + excluded_treatment_locations = [] + if excluded_control_locations is None: + excluded_control_locations = [] logger.debug( - f"select_controls_exclusive called: treatment_group={treatment_group}, used_treatment_locations={len(used_treatment_locations)}, excluded_locations={len(excluded_locations)}" + f"select_controls_exclusive called: treatment_group={treatment_group}, used_treatment_locations={len(used_treatment_locations)}, excluded_treatment_locations={len(excluded_treatment_locations)}" ) control_group = set() all_excluded = ( - set(treatment_group) | used_treatment_locations | set(excluded_locations) | set(excluded_controls) + set(treatment_group) | used_treatment_locations | set(excluded_treatment_locations) | set(excluded_control_locations) ) for treatment_location in treatment_group: @@ -669,9 +669,9 @@ def evaluate_group_exclusive( min_holdout, df_pivot, used_treatment_locations=None, - excluded_locations=None, + excluded_treatment_locations=None, treatment_period=None, - excluded_controls=None, + excluded_control_locations=None, ): """ Evaluates a treatment group with location exclusivity for multi-cell mode. @@ -687,7 +687,7 @@ def evaluate_group_exclusive( min_holdout (float): Minimum required holdout percentage df_pivot (pd.DataFrame): Pivoted data with time as index and locations as columns used_treatment_locations (set): Set of locations already used as treatment in other cells - excluded_locations (list): List of globally excluded locations + excluded_treatment_locations (list): List of globally excluded locations treatment_period (int): Number of periods for treatment (if None, uses 80/20 split) Returns: @@ -717,9 +717,9 @@ def evaluate_group_exclusive( correlation_matrix=correlation_matrix, treatment_group=treatment_group, used_treatment_locations=used_treatment_locations, - excluded_locations=excluded_locations, + excluded_treatment_locations=excluded_treatment_locations, min_correlation=0.8, - excluded_controls=excluded_controls, + excluded_control_locations=excluded_control_locations, ) logger.debug(f"Control group selected: {control_group}") @@ -806,7 +806,7 @@ def evaluate_group_exclusive( def BetterGroups( similarity_matrix, - excluded_locations, + excluded_treatment_locations, data, correlation_matrix, maximum_treatment_percentage=0.50, @@ -814,7 +814,7 @@ def BetterGroups( status_updater=None, multicell_config=None, global_optimization=False, - excluded_controls=None, + excluded_control_locations=None, ): """ Enhanced simulates and evaluates treatment groups for geo-experiments. @@ -826,7 +826,7 @@ def BetterGroups( Args: similarity_matrix (pd.DataFrame): Correlation matrix for treatment selection - excluded_locations (list): List of locations to exclude from treatment selection + excluded_treatment_locations (list): List of locations to exclude from treatment selection data (pd.DataFrame): Input data with 'location', 'time', and 'Y' columns correlation_matrix (pd.DataFrame): Market correlation matrix for control selection maximum_treatment_percentage (float): Maximum treatment percentage (default: 0.50) @@ -846,10 +846,10 @@ def BetterGroups( """ unique_locations = data["location"].unique() no_locations = len(unique_locations) - # max_group_size = round(no_locations * 0.35) - # min_elements_in_treatment = round(no_locations * 0.20) - max_group_size = round(no_locations * 0.45) - min_elements_in_treatment = round(no_locations * 0.15) + max_group_size = round(no_locations * 0.35) + min_elements_in_treatment = round(no_locations * 0.20) + # max_group_size = round(no_locations * 0.45) + #min_elements_in_treatment = round(no_locations * 0.15) min_holdout = 100 - (maximum_treatment_percentage * 100) total_Y = data["Y"].sum() @@ -873,7 +873,7 @@ def BetterGroups( similarity_matrix=similarity_matrix, allowed_sizes=sizes, total_cells_needed=top_n, - excluded_locations=excluded_locations, + excluded_treatment_locations=excluded_treatment_locations, data=data, correlation_matrix=correlation_matrix, maximum_treatment_percentage=maximum_treatment_percentage, @@ -889,7 +889,7 @@ def BetterGroups( for size in sizes: groups = select_treatments_exclusive( - similarity_matrix, size, excluded_locations, used_treatment_locations + similarity_matrix, size, excluded_treatment_locations, used_treatment_locations ) if not groups: logger.warning( @@ -904,7 +904,7 @@ def BetterGroups( select_treatments_exclusive( similarity_matrix, s, - excluded_locations, + excluded_treatment_locations, used_treatment_locations, ) ) @@ -925,9 +925,9 @@ def BetterGroups( [min_holdout] * total_groups, [df_pivot] * total_groups, [used_treatment_locations] * total_groups, - [excluded_locations] * total_groups, + [excluded_treatment_locations] * total_groups, [None] * total_groups, - [excluded_controls] * total_groups, + [excluded_control_locations] * total_groups, ) for idx, result in enumerate(futures): @@ -995,8 +995,8 @@ def BetterGroups( min_holdout=min_holdout, df_pivot=df_pivot, used_treatment_locations=current_used_treatments, - excluded_locations=excluded_locations, - excluded_controls=excluded_controls, + excluded_treatment_locations=excluded_treatment_locations, + excluded_control_locations=excluded_control_locations, ) if result is not None: final_results.append(result) @@ -1039,7 +1039,7 @@ def BetterGroups( logger.info(f"Starting single-cell mode") possible_groups = [] for size in range(min_elements_in_treatment, max_group_size + 1): - groups = select_treatments(similarity_matrix, size, excluded_locations) + groups = select_treatments(similarity_matrix, size, excluded_treatment_locations) possible_groups.extend(groups) if not possible_groups: @@ -1059,7 +1059,7 @@ def BetterGroups( [min_holdout] * total_groups, [df_pivot] * total_groups, [None] * total_groups, - [excluded_controls] * total_groups, + [excluded_control_locations] * total_groups, ) for idx, result in enumerate(futures): results.append(result) @@ -1175,7 +1175,7 @@ def update_progress(selected_count): return selected_cells -def _validate_multicell_config(similarity_matrix, allowed_sizes, total_cells_needed, excluded_locations, data): +def _validate_multicell_config(similarity_matrix, allowed_sizes, total_cells_needed, excluded_treatment_locations, data): """ Validate multicell configuration and provide actionable error messages. @@ -1183,7 +1183,7 @@ def _validate_multicell_config(similarity_matrix, allowed_sizes, total_cells_nee similarity_matrix: Correlation matrix for treatment selection allowed_sizes: List of allowed cell sizes to choose from total_cells_needed: Total number of cells in final experiment - excluded_locations: Globally excluded locations + excluded_treatment_locations: Globally excluded locations data: Input data Returns: @@ -1195,7 +1195,7 @@ def _validate_multicell_config(similarity_matrix, allowed_sizes, total_cells_nee unique_locations = data["location"].unique() total_locations = len(unique_locations) - excluded_count = len(set(excluded_locations)) + excluded_count = len(set(excluded_treatment_locations)) available_locations = total_locations - excluded_count # Check basic feasibility @@ -1250,7 +1250,7 @@ def optimize_global_multicell( similarity_matrix, allowed_sizes, total_cells_needed, - excluded_locations, + excluded_treatment_locations, data, correlation_matrix, maximum_treatment_percentage, @@ -1268,7 +1268,7 @@ def optimize_global_multicell( similarity_matrix: Correlation matrix for treatment selection allowed_sizes: List of allowed cell sizes to choose from total_cells_needed: Total number of cells in final experiment - excluded_locations: Globally excluded locations + excluded_treatment_locations: Globally excluded locations data: Input data correlation_matrix: Market correlation matrix maximum_treatment_percentage: Max treatment percentage @@ -1286,7 +1286,7 @@ def optimize_global_multicell( # Pre-flight validation is_valid, warnings, suggestions = _validate_multicell_config( - similarity_matrix, allowed_sizes, total_cells_needed, excluded_locations, data + similarity_matrix, allowed_sizes, total_cells_needed, excluded_treatment_locations, data ) for warning in warnings: @@ -1319,7 +1319,7 @@ def optimize_global_multicell( logger.info(f"Generating candidates for size {size}") groups = select_treatments_exclusive( - similarity_matrix, size, excluded_locations, used_treatment_locations=set() + similarity_matrix, size, excluded_treatment_locations, used_treatment_locations=set() ) if not groups: @@ -1337,9 +1337,9 @@ def optimize_global_multicell( [min_holdout] * len(groups), [df_pivot] * len(groups), [set()] * len(groups), # No used locations in phase 1 - [excluded_locations] * len(groups), + [excluded_treatment_locations] * len(groups), [None] * len(groups), - [excluded_controls] * len(groups), + [excluded_control_locations] * len(groups), ) for result in futures: @@ -1415,7 +1415,7 @@ def optimize_global_multicell( correlation_matrix=correlation_matrix, treatment_group=treatment_group, used_treatment_locations=all_treatment_locations, # Exclude ALL treatments - excluded_locations=excluded_locations, + excluded_treatment_locations=excluded_treatment_locations, min_correlation=0.8, ) @@ -2073,7 +2073,7 @@ def run_geo_analysis_streamlit_app( significance_level, deltas_range, periods_range, - excluded_locations, + excluded_treatment_locations, progress_bar_1=None, status_text_1=None, progress_bar_2=None, @@ -2085,7 +2085,7 @@ def run_geo_analysis_streamlit_app( inference_type="iid", global_optimization=False, progress_updater=None, - excluded_controls=None, + excluded_control_locations=None, ): """ Runs a complete geo analysis pipeline including market correlation, group optimization, @@ -2097,7 +2097,7 @@ def run_geo_analysis_streamlit_app( significance_level (float): Significance level for statistical testing. deltas_range (tuple): Range of delta values to evaluate as (start, stop, step). periods_range (tuple): Range of treatment periods to evaluate as (start, stop, step). - excluded_locations (list): List of states to exclude from the analysis. + excluded_treatment_locations (list): List of states to exclude from the analysis. progress_bar_1 (callable): Progress bar updater for group optimization phase. status_text_1 (callable): Status text updater for group optimization phase. progress_bar_2 (callable): Progress bar updater for sensitivity evaluation phase. @@ -2124,7 +2124,8 @@ def run_geo_analysis_streamlit_app( logger.info(f" - significance_level: {significance_level}") logger.info(f" - deltas_range: {deltas_range}") logger.info(f" - periods_range: {periods_range}") - logger.info(f" - excluded_locations: {excluded_locations}") + logger.info(f" - excluded_treatment_locations: {excluded_treatment_locations}") + logger.info(f" - excluded_control_locations: {excluded_control_locations}") logger.info(f" - multicell_config: {multicell_config}") logger.info("=" * 80) @@ -2150,14 +2151,14 @@ def run_geo_analysis_streamlit_app( simulation_results = BetterGroups( similarity_matrix=correlation_matrix, maximum_treatment_percentage=maximum_treatment_percentage, - excluded_locations=excluded_locations, + excluded_treatment_locations=excluded_treatment_locations, data=data, correlation_matrix=correlation_matrix, progress_updater=progress_bar_1, status_updater=status_text_1, multicell_config=multicell_config, global_optimization=global_optimization, - excluded_controls=excluded_controls, + excluded_control_locations=excluded_control_locations, ) if simulation_results is None: @@ -2271,7 +2272,7 @@ def run_geo_analysis( significance_level, deltas_range, periods_range, - excluded_locations, + excluded_treatment_locations, progress_bar_1=None, status_text_1=None, progress_bar_2=None, @@ -2292,7 +2293,7 @@ def run_geo_analysis( significance_level (float): Significance level for statistical testing. deltas_range (tuple): Range of delta values to evaluate as (start, stop, step). periods_range (tuple): Range of treatment periods to evaluate as (start, stop, step). - excluded_locations (list): List of states to exclude from the analysis. + excluded_treatment_locations (list): List of states to exclude from the analysis. progress_bar_1 (optional): First progress bar for UI updates. status_text_1 (optional): First status text for UI updates. progress_bar_2 (optional): Second progress bar for UI updates. @@ -2316,7 +2317,8 @@ def run_geo_analysis( logger.info(f" - significance_level: {significance_level}") logger.info(f" - deltas_range: {deltas_range}") logger.info(f" - periods_range: {periods_range}") - logger.info(f" - excluded_locations: {excluded_locations}") + logger.info(f" - excluded_treatment_locations: {excluded_treatment_locations}") + logger.info(f" - excluded_control_locations: {excluded_control_locations}") logger.info("=" * 80) if progress_bar_1 or progress_bar_2 or status_text_1 or status_text_2 is None: @@ -2332,12 +2334,12 @@ def run_geo_analysis( simulation_results = BetterGroups( similarity_matrix=correlation_matrix, maximum_treatment_percentage=maximum_treatment_percentage, - excluded_locations=excluded_locations, + excluded_treatment_locations=excluded_treatment_locations, data=data, correlation_matrix=correlation_matrix, progress_updater=progress_bar_1, status_updater=status_text_1, - excluded_controls=excluded_controls, + excluded_control_locations=excluded_control_locations, ) # Step 3: Evaluate sensitivity for different deltas and periods diff --git a/Murray/post_analysis.py b/Murray/post_analysis.py index 47d32a1..022bff0 100644 --- a/Murray/post_analysis.py +++ b/Murray/post_analysis.py @@ -18,6 +18,7 @@ def run_geo_evaluation( n_permutations=50000, inference_type="iid", significance_level=0.1, + excluded_control_locations=None, ): logger.info("Starting run_geo_evaluation") logger.info(f"Input data shape: {data_input.shape}") @@ -55,6 +56,7 @@ def smape(A, F): correlation_matrix=correlation_matrix, treatment_group=treatment_group, min_correlation=0.8, + excluded_control_locations=excluded_control_locations, ) logger.info(f"Control group selected: {control_group}") @@ -100,6 +102,13 @@ def smape(A, F): time_train = time_index[:start_position_treatment] time_test = time_index[start_position_treatment:] + if len(X_train) == 0: + raise ValueError( + f"No pre-treatment periods available for training. " + f"The treatment start date ({start_treatment}) is at or before the beginning of the dataset. " + f"Please select a later treatment start date." + ) + logger.info("Fitting synthetic control model...") model = SyntheticControl(use_ridge_adjustment=True, ridge_alpha=1.0) model.fit(X_train, y_train, time_train=time_train) diff --git a/api.py b/api.py index a511bc2..43182f1 100644 --- a/api.py +++ b/api.py @@ -124,8 +124,8 @@ async def analyze_design( date_column: str = Form(None), location_column: str = Form(None), target_column: str = Form(None), - excluded_locations: str = Form(None), - excluded_controls: str = Form(None), + excluded_treatment_locations: str = Form(None), + excluded_control_locations: str = Form(None), maximum_treatment_percentage: float = Form(0.3), significance_level: float = Form(0.1), deltas_range: str = Form("0.01,0.1,0.01"), @@ -180,16 +180,16 @@ async def analyze_design( # Parse parameters deltas_range = tuple(map(float, deltas_range.split(','))) periods_range = tuple(map(int, periods_range.split(','))) - # Handle empty excluded_locations properly - if excluded_locations and excluded_locations.strip(): - excluded_locations = tuple(map(str, excluded_locations.split(','))) + # Handle empty excluded_treatment_locations properly + if excluded_treatment_locations and excluded_treatment_locations.strip(): + excluded_treatment_locations = tuple(map(str, excluded_treatment_locations.split(','))) else: - excluded_locations = tuple() - # Handle empty excluded_controls properly - if excluded_controls and excluded_controls.strip(): - excluded_controls = tuple(map(str, excluded_controls.split(','))) + excluded_treatment_locations = tuple() + # Handle empty excluded_control_locations properly + if excluded_control_locations and excluded_control_locations.strip(): + excluded_control_locations = tuple(map(str, excluded_control_locations.split(','))) else: - excluded_controls = tuple() + excluded_control_locations = tuple() # Process multicell parameters multicell_config = None @@ -226,8 +226,8 @@ async def analyze_design( date_column=date_column, location_column=location_column, target_column=target_column, - excluded_locations=excluded_locations, - excluded_controls=excluded_controls, + excluded_treatment_locations=excluded_treatment_locations, + excluded_control_locations=excluded_control_locations, maximum_treatment_percentage=maximum_treatment_percentage, significance_level=significance_level, deltas_range=deltas_range, diff --git a/experimental_design.py b/experimental_design.py index 6ef8fec..d10b963 100644 --- a/experimental_design.py +++ b/experimental_design.py @@ -764,10 +764,10 @@ def reset_states(): unsafe_allow_html=True, ) # Location exclusion and test type selection - excluded_locations = st.multiselect( + excluded_treatment_locations = st.multiselect( "Select excluded locations", cleaned["location"].unique() ) - excluded_controls = st.multiselect( + excluded_control_locations = st.multiselect( "Select excluded controls", cleaned["location"].unique() ) @@ -957,8 +957,8 @@ def reset_states(): # Track parameter changes to reset simulation state current_params = { - "excluded_locations": excluded_locations, - "excluded_controls": excluded_controls, + "excluded_treatment_locations": excluded_treatment_locations, + "excluded_control_locations": excluded_control_locations, "maximum_treatment_percentage_pre": maximum_treatment_percentage_pre, "significance_level_pre": significance_level_pre, "deltas_range": (delta_min, delta_max, delta_step), @@ -1020,8 +1020,8 @@ def reset_states(): # Run main geo analysis simulation results = run_geo_analysis_streamlit_app( data=cleaned, - excluded_locations=excluded_locations, - excluded_controls=excluded_controls, + excluded_treatment_locations=excluded_treatment_locations, + excluded_control_locations=excluded_control_locations, maximum_treatment_percentage=maximum_treatment_percentage, significance_level=significance_level, deltas_range=deltas_range, @@ -1060,8 +1060,8 @@ def reset_states(): # Show current settings for debugging with st.expander("🔧 Current Analysis Settings", expanded=False): st.write(f"**Maximum Treatment Percentage:** {maximum_treatment_percentage*100:.1f}%") - st.write(f"**Excluded Locations:** {len(excluded_locations)} locations") - st.write(f"**Excluded Controls:** {len(excluded_controls)} controls") + st.write(f"**Excluded Locations:** {len(excluded_treatment_locations)} locations") + st.write(f"**Excluded Controls:** {len(excluded_control_locations)} controls") st.write(f"**Total Locations Available:** {len(cleaned['location'].unique())} locations") st.write(f"**Time Periods:** {len(cleaned['time'].unique())} periods") st.write(f"**Total Y Sum:** {cleaned['Y'].sum():,.2f}") diff --git a/experimental_evaluation.py b/experimental_evaluation.py index 5cb5b4f..f24d7b7 100644 --- a/experimental_evaluation.py +++ b/experimental_evaluation.py @@ -766,6 +766,9 @@ def reset_states(): treatment_group = st.multiselect( "Select treatment group", data1["location"].unique() ) + excluded_control_locations = st.multiselect( + "Select excluded control locations", data1["location"].unique() + ) spend = st.number_input("Select spend") mmm_option = st.selectbox( "Select the option to calculate the iROAS or iCPA", ["iROAS", "iCPA"] @@ -795,6 +798,7 @@ def reset_states(): "start_treatment": start_treatment, "end_treatment": end_treatment, "treatment_group": treatment_group, + "excluded_control_locations": excluded_control_locations, "spend": spend, "mmm_option": mmm_option, "col_target": col_target, @@ -814,14 +818,20 @@ def reset_states(): update_metrics("experimental_evaluation") with st.spinner("Running analysis..."): + try: + results = run_geo_evaluation( + data1, + start_treatment, + end_treatment, + treatment_group, + spend, + excluded_control_locations=excluded_control_locations, + ) + except ValueError as e: + st.error(f"Error: {e}") + st.session_state.evaluation_button_clicked = False + st.stop() - results = run_geo_evaluation( - data1, - start_treatment, - end_treatment, - treatment_group, - spend, - ) treatment = results["treatment"] st.session_state.treatment = treatment counterfactual = results["counterfactual"] diff --git a/tasks.py b/tasks.py index 289defd..10d0e83 100644 --- a/tasks.py +++ b/tasks.py @@ -106,8 +106,8 @@ def analyze_design_task( date_column: str, location_column: str, target_column: str, - excluded_locations: tuple, - excluded_controls: tuple, + excluded_treatment_locations: tuple, + excluded_control_locations: tuple, maximum_treatment_percentage: float, significance_level: float, deltas_range: tuple, @@ -258,8 +258,8 @@ def advance_to_sensitivity(): results = run_geo_analysis_streamlit_app( data=data, - excluded_locations=excluded_locations, - excluded_controls=excluded_controls, + excluded_treatment_locations=excluded_treatment_locations, + excluded_control_locations=excluded_control_locations, maximum_treatment_percentage=maximum_treatment_percentage, significance_level=significance_level, deltas_range=deltas_range, From 0fe4ab5346914e54b8c4f4b292fac45d787d9672 Mon Sep 17 00:00:00 2001 From: KarlaMelgarejo Date: Wed, 4 Mar 2026 15:03:21 -0600 Subject: [PATCH 3/4] fix: increase max combinations limit from 2000 to 5000 Co-Authored-By: Claude Sonnet 4.6 --- Murray/main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Murray/main.py b/Murray/main.py index ba76375..685b239 100644 --- a/Murray/main.py +++ b/Murray/main.py @@ -86,10 +86,10 @@ def select_treatments(similarity_matrix, treatment_size, excluded_treatment_loca max_combinations = comb(n, r) n_combinations = max_combinations - # if n_combinations > 5000: - # n_combinations = 5000 - if n_combinations > 2000: - n_combinations = 2000 + if n_combinations > 5000: + n_combinations = 5000 + #if n_combinations > 2000: + # n_combinations = 2000 logger.debug(f"Generating {n_combinations} combinations") From 43586f5d1f267730dfad546742c91404d3c6e347 Mon Sep 17 00:00:00 2001 From: KarlaMelgarejo Date: Tue, 7 Apr 2026 16:06:46 -0600 Subject: [PATCH 4/4] Mi ultima version (falta revisar el smape) --- Murray/main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Murray/main.py b/Murray/main.py index 685b239..073963d 100644 --- a/Murray/main.py +++ b/Murray/main.py @@ -846,10 +846,10 @@ def BetterGroups( """ unique_locations = data["location"].unique() no_locations = len(unique_locations) - max_group_size = round(no_locations * 0.35) - min_elements_in_treatment = round(no_locations * 0.20) - # max_group_size = round(no_locations * 0.45) - #min_elements_in_treatment = round(no_locations * 0.15) + #max_group_size = round(no_locations * 0.35) + #min_elements_in_treatment = round(no_locations * 0.20) + max_group_size = round(no_locations * 0.45) + min_elements_in_treatment = round(no_locations * 0.15) min_holdout = 100 - (maximum_treatment_percentage * 100) total_Y = data["Y"].sum()