diff --git a/Murray/main.py b/Murray/main.py index befb3aa..073963d 100644 --- a/Murray/main.py +++ b/Murray/main.py @@ -25,7 +25,7 @@ def is_streamlit_context(): ) -def select_treatments(similarity_matrix, treatment_size, excluded_locations): +def select_treatments(similarity_matrix, treatment_size, excluded_treatment_locations): """ Selects n combinations of treatments based on a similarity DataFrame, excluding certain states from the treatment selection but allowing their inclusion in the control. @@ -34,23 +34,23 @@ def select_treatments(similarity_matrix, treatment_size, excluded_locations): Args: similarity_matrix (pd.DataFrame): DataFrame containing correlations between locations in a standard matrix format treatment_size (int): Number of treatments to select for each combination. - excluded_locations (list): List of locations to exclude from the treatment selection. + excluded_treatment_locations (list): List of locations to exclude from the treatment selection. Returns: list: A list of unique combinations, each combination being a list of states. 
""" - # Filter out empty strings from excluded_locations - excluded_locations = [loc for loc in excluded_locations if loc.strip()] + # Filter out empty strings from excluded_treatment_locations + excluded_treatment_locations = [loc for loc in excluded_treatment_locations if loc.strip()] logger.debug( - f"select_treatments called: treatment_size={treatment_size}, excluded_locations={excluded_locations}" + f"select_treatments called: treatment_size={treatment_size}, excluded_treatment_locations={excluded_treatment_locations}" ) missing_locations = [ location - for location in excluded_locations + for location in excluded_treatment_locations if location not in similarity_matrix.index or location not in similarity_matrix.columns ] @@ -64,8 +64,8 @@ def select_treatments(similarity_matrix, treatment_size, excluded_locations): ) similarity_matrix_filtered = similarity_matrix.loc[ - ~similarity_matrix.index.isin(excluded_locations), - ~similarity_matrix.columns.isin(excluded_locations), + ~similarity_matrix.index.isin(excluded_treatment_locations), + ~similarity_matrix.columns.isin(excluded_treatment_locations), ] logger.debug( @@ -88,8 +88,8 @@ def select_treatments(similarity_matrix, treatment_size, excluded_locations): n_combinations = max_combinations if n_combinations > 5000: n_combinations = 5000 - # if n_combinations > 2000: - # n_combinations = 2000 + #if n_combinations > 2000: + # n_combinations = 2000 logger.debug(f"Generating {n_combinations} combinations") @@ -107,7 +107,7 @@ def select_treatments(similarity_matrix, treatment_size, excluded_locations): def select_controls( - correlation_matrix, treatment_group, min_correlation=0.8, fallback_n=1 + correlation_matrix, treatment_group, min_correlation=0.8, fallback_n=1, excluded_control_locations=None ): """ Dynamically selects control group states based on correlation values. @@ -118,6 +118,7 @@ def select_controls( treatment_group (list): List of states in the treatment group. 
min_correlation (float): Minimum correlation threshold to consider a state as part of the control group. fallback_n (int): Number of top correlated states to select if no state meets the min_correlation. + excluded_control_locations (list): List of states to exclude from the control group. Returns: list: List of states selected as the control group. @@ -126,6 +127,9 @@ def select_controls( f"select_controls called: treatment_group={treatment_group}, min_correlation={min_correlation}" ) + if excluded_control_locations is None: + excluded_control_locations = [] + control_group = set() for treatment_location in treatment_group: @@ -140,6 +144,7 @@ def select_controls( treatment_row[ (treatment_row >= min_correlation) & (~treatment_row.index.isin(treatment_group)) + & (~treatment_row.index.isin(excluded_control_locations)) ] .sort_values(ascending=False) .index.tolist() @@ -150,7 +155,10 @@ def select_controls( f"No states meet min_correlation {min_correlation} for {treatment_location}, using fallback" ) similar_states = ( - treatment_row[~treatment_row.index.isin(treatment_group)] + treatment_row[ + ~treatment_row.index.isin(treatment_group) + & (~treatment_row.index.isin(excluded_control_locations)) + ] .sort_values(ascending=False) .head(fallback_n) .index.tolist() @@ -343,7 +351,7 @@ def smape(A, F): def evaluate_group( - treatment_group, data, total_Y, correlation_matrix, min_holdout, df_pivot, treatment_period=None + treatment_group, data, total_Y, correlation_matrix, min_holdout, df_pivot, treatment_period=None, excluded_control_locations=None ): """ Evaluates a treatment group and returns error metrics. 
@@ -377,6 +385,7 @@ def evaluate_group( correlation_matrix=correlation_matrix, treatment_group=treatment_group, min_correlation=0.8, + excluded_control_locations=excluded_control_locations, ) logger.debug(f"Control group selected: {control_group}") @@ -462,7 +471,7 @@ def evaluate_group( def select_treatments_exclusive( - similarity_matrix, treatment_size, excluded_locations, used_treatment_locations=None + similarity_matrix, treatment_size, excluded_treatment_locations, used_treatment_locations=None ): """ Improved treatment selection for multi-cell mode ensuring treatment location exclusivity. @@ -471,7 +480,7 @@ def select_treatments_exclusive( Args: similarity_matrix (pd.DataFrame): DataFrame containing correlations between locations treatment_size (int): Number of treatments to select for each combination - excluded_locations (list): List of locations to exclude globally + excluded_treatment_locations (list): List of locations to exclude globally used_treatment_locations (set): Set of treatment locations already used in previous cells Returns: @@ -480,9 +489,9 @@ def select_treatments_exclusive( if used_treatment_locations is None: used_treatment_locations = set() - # Filter out empty strings from excluded_locations - excluded_locations = [loc for loc in excluded_locations if loc.strip()] - all_excluded = set(excluded_locations) | used_treatment_locations + # Filter out empty strings from excluded_treatment_locations + excluded_treatment_locations = [loc for loc in excluded_treatment_locations if loc.strip()] + all_excluded = set(excluded_treatment_locations) | used_treatment_locations logger.debug( f"select_treatments_exclusive: treatment_size={treatment_size}, excluded={len(all_excluded)} locations" @@ -490,7 +499,7 @@ def select_treatments_exclusive( missing_locations = [ location - for location in excluded_locations + for location in excluded_treatment_locations if location not in similarity_matrix.index or location not in similarity_matrix.columns ] @@ 
-564,9 +573,10 @@ def select_controls_exclusive( correlation_matrix, treatment_group, used_treatment_locations=None, - excluded_locations=None, + excluded_treatment_locations=None, min_correlation=0.8, fallback_n=1, + excluded_control_locations=None, ): """ Dynamically selects control group states based on correlation values. @@ -574,31 +584,35 @@ def select_controls_exclusive( - Excludes current treatment group locations - Excludes globally excluded locations - Excludes locations that have been used as treatment in previous cells + - Excludes explicitly excluded control locations - ALLOWS reuse of control locations from previous cells Args: correlation_matrix (pd.DataFrame): Correlation matrix between states. treatment_group (list): List of states in the treatment group. used_treatment_locations (set): Set of treatment locations already used in previous cells. - excluded_locations (list): List of globally excluded locations. + excluded_treatment_locations (list): List of globally excluded locations. min_correlation (float): Minimum correlation threshold to consider a state as part of the control group. fallback_n (int): Number of top correlated states to select if no state meets the min_correlation. + excluded_control_locations (list): List of states to exclude from the control group. Returns: list: List of states selected as the control group. 
""" if used_treatment_locations is None: used_treatment_locations = set() - if excluded_locations is None: - excluded_locations = [] + if excluded_treatment_locations is None: + excluded_treatment_locations = [] + if excluded_control_locations is None: + excluded_control_locations = [] logger.debug( - f"select_controls_exclusive called: treatment_group={treatment_group}, used_treatment_locations={len(used_treatment_locations)}, excluded_locations={len(excluded_locations)}" + f"select_controls_exclusive called: treatment_group={treatment_group}, used_treatment_locations={len(used_treatment_locations)}, excluded_treatment_locations={len(excluded_treatment_locations)}" ) control_group = set() all_excluded = ( - set(treatment_group) | used_treatment_locations | set(excluded_locations) + set(treatment_group) | used_treatment_locations | set(excluded_treatment_locations) | set(excluded_control_locations) ) for treatment_location in treatment_group: @@ -655,8 +669,9 @@ def evaluate_group_exclusive( min_holdout, df_pivot, used_treatment_locations=None, - excluded_locations=None, + excluded_treatment_locations=None, treatment_period=None, + excluded_control_locations=None, ): """ Evaluates a treatment group with location exclusivity for multi-cell mode. 
@@ -672,7 +687,7 @@ def evaluate_group_exclusive( min_holdout (float): Minimum required holdout percentage df_pivot (pd.DataFrame): Pivoted data with time as index and locations as columns used_treatment_locations (set): Set of locations already used as treatment in other cells - excluded_locations (list): List of globally excluded locations + excluded_treatment_locations (list): List of globally excluded locations treatment_period (int): Number of periods for treatment (if None, uses 80/20 split) Returns: @@ -702,8 +717,9 @@ def evaluate_group_exclusive( correlation_matrix=correlation_matrix, treatment_group=treatment_group, used_treatment_locations=used_treatment_locations, - excluded_locations=excluded_locations, + excluded_treatment_locations=excluded_treatment_locations, min_correlation=0.8, + excluded_control_locations=excluded_control_locations, ) logger.debug(f"Control group selected: {control_group}") @@ -790,7 +806,7 @@ def evaluate_group_exclusive( def BetterGroups( similarity_matrix, - excluded_locations, + excluded_treatment_locations, data, correlation_matrix, maximum_treatment_percentage=0.50, @@ -798,6 +814,7 @@ def BetterGroups( status_updater=None, multicell_config=None, global_optimization=False, + excluded_control_locations=None, ): """ Enhanced simulates and evaluates treatment groups for geo-experiments. 
@@ -809,7 +826,7 @@ def BetterGroups( Args: similarity_matrix (pd.DataFrame): Correlation matrix for treatment selection - excluded_locations (list): List of locations to exclude from treatment selection + excluded_treatment_locations (list): List of locations to exclude from treatment selection data (pd.DataFrame): Input data with 'location', 'time', and 'Y' columns correlation_matrix (pd.DataFrame): Market correlation matrix for control selection maximum_treatment_percentage (float): Maximum treatment percentage (default: 0.50) @@ -829,8 +846,8 @@ def BetterGroups( """ unique_locations = data["location"].unique() no_locations = len(unique_locations) - # max_group_size = round(no_locations * 0.35) - # min_elements_in_treatment = round(no_locations * 0.20) + #max_group_size = round(no_locations * 0.35) + #min_elements_in_treatment = round(no_locations * 0.20) max_group_size = round(no_locations * 0.45) min_elements_in_treatment = round(no_locations * 0.15) min_holdout = 100 - (maximum_treatment_percentage * 100) @@ -856,7 +873,7 @@ def BetterGroups( similarity_matrix=similarity_matrix, allowed_sizes=sizes, total_cells_needed=top_n, - excluded_locations=excluded_locations, + excluded_treatment_locations=excluded_treatment_locations, data=data, correlation_matrix=correlation_matrix, maximum_treatment_percentage=maximum_treatment_percentage, @@ -872,7 +889,7 @@ def BetterGroups( for size in sizes: groups = select_treatments_exclusive( - similarity_matrix, size, excluded_locations, used_treatment_locations + similarity_matrix, size, excluded_treatment_locations, used_treatment_locations ) if not groups: logger.warning( @@ -887,7 +904,7 @@ def BetterGroups( select_treatments_exclusive( similarity_matrix, s, - excluded_locations, + excluded_treatment_locations, used_treatment_locations, ) ) @@ -908,7 +925,9 @@ def BetterGroups( [min_holdout] * total_groups, [df_pivot] * total_groups, [used_treatment_locations] * total_groups, - [excluded_locations] * total_groups, + 
[excluded_treatment_locations] * total_groups, + [None] * total_groups, + [excluded_control_locations] * total_groups, ) for idx, result in enumerate(futures): @@ -976,7 +995,8 @@ def BetterGroups( min_holdout=min_holdout, df_pivot=df_pivot, used_treatment_locations=current_used_treatments, - excluded_locations=excluded_locations, + excluded_treatment_locations=excluded_treatment_locations, + excluded_control_locations=excluded_control_locations, ) if result is not None: final_results.append(result) @@ -1019,7 +1039,7 @@ def BetterGroups( logger.info(f"Starting single-cell mode") possible_groups = [] for size in range(min_elements_in_treatment, max_group_size + 1): - groups = select_treatments(similarity_matrix, size, excluded_locations) + groups = select_treatments(similarity_matrix, size, excluded_treatment_locations) possible_groups.extend(groups) if not possible_groups: @@ -1038,6 +1058,8 @@ def BetterGroups( [correlation_matrix] * total_groups, [min_holdout] * total_groups, [df_pivot] * total_groups, + [None] * total_groups, + [excluded_control_locations] * total_groups, ) for idx, result in enumerate(futures): results.append(result) @@ -1153,7 +1175,7 @@ def update_progress(selected_count): return selected_cells -def _validate_multicell_config(similarity_matrix, allowed_sizes, total_cells_needed, excluded_locations, data): +def _validate_multicell_config(similarity_matrix, allowed_sizes, total_cells_needed, excluded_treatment_locations, data): """ Validate multicell configuration and provide actionable error messages. 
@@ -1161,7 +1183,7 @@ def _validate_multicell_config(similarity_matrix, allowed_sizes, total_cells_nee similarity_matrix: Correlation matrix for treatment selection allowed_sizes: List of allowed cell sizes to choose from total_cells_needed: Total number of cells in final experiment - excluded_locations: Globally excluded locations + excluded_treatment_locations: Globally excluded locations data: Input data Returns: @@ -1173,7 +1195,7 @@ def _validate_multicell_config(similarity_matrix, allowed_sizes, total_cells_nee unique_locations = data["location"].unique() total_locations = len(unique_locations) - excluded_count = len(set(excluded_locations)) + excluded_count = len(set(excluded_treatment_locations)) available_locations = total_locations - excluded_count # Check basic feasibility @@ -1228,7 +1250,7 @@ def optimize_global_multicell( similarity_matrix, allowed_sizes, total_cells_needed, - excluded_locations, + excluded_treatment_locations, data, correlation_matrix, maximum_treatment_percentage, @@ -1246,7 +1268,7 @@ def optimize_global_multicell( similarity_matrix: Correlation matrix for treatment selection allowed_sizes: List of allowed cell sizes to choose from total_cells_needed: Total number of cells in final experiment - excluded_locations: Globally excluded locations + excluded_treatment_locations: Globally excluded locations data: Input data correlation_matrix: Market correlation matrix maximum_treatment_percentage: Max treatment percentage @@ -1264,7 +1286,7 @@ def optimize_global_multicell( # Pre-flight validation is_valid, warnings, suggestions = _validate_multicell_config( - similarity_matrix, allowed_sizes, total_cells_needed, excluded_locations, data + similarity_matrix, allowed_sizes, total_cells_needed, excluded_treatment_locations, data ) for warning in warnings: @@ -1297,7 +1319,7 @@ def optimize_global_multicell( logger.info(f"Generating candidates for size {size}") groups = select_treatments_exclusive( - similarity_matrix, size, 
excluded_locations, used_treatment_locations=set() + similarity_matrix, size, excluded_treatment_locations, used_treatment_locations=set() ) if not groups: @@ -1315,7 +1337,9 @@ [min_holdout] * len(groups), [df_pivot] * len(groups), [set()] * len(groups), # No used locations in phase 1 - [excluded_locations] * len(groups), + [excluded_treatment_locations] * len(groups), + [None] * len(groups), + [None] * len(groups), # TODO: excluded_control_locations is not a parameter of optimize_global_multicell and is undefined here (would raise NameError); add the parameter and plumb it through before using it ) for result in futures: @@ -1391,7 +1415,7 @@ correlation_matrix=correlation_matrix, treatment_group=treatment_group, used_treatment_locations=all_treatment_locations, # Exclude ALL treatments - excluded_locations=excluded_locations, + excluded_treatment_locations=excluded_treatment_locations, min_correlation=0.8, ) @@ -2049,7 +2073,7 @@ def run_geo_analysis_streamlit_app( significance_level, deltas_range, periods_range, - excluded_locations, + excluded_treatment_locations, progress_bar_1=None, status_text_1=None, progress_bar_2=None, @@ -2061,6 +2085,7 @@ inference_type="iid", global_optimization=False, progress_updater=None, + excluded_control_locations=None, ): """ Runs a complete geo analysis pipeline including market correlation, group optimization, @@ -2072,7 +2097,7 @@ data (pd.DataFrame): DataFrame containing the geographic and time series data. significance_level (float): Significance level for statistical testing. deltas_range (tuple): Range of delta values to evaluate as (start, stop, step). periods_range (tuple): Range of treatment periods to evaluate as (start, stop, step). - excluded_locations (list): List of states to exclude from the analysis. + excluded_treatment_locations (list): List of states to exclude from the analysis. progress_bar_1 (callable): Progress bar updater for group optimization phase. status_text_1 (callable): Status text updater for group optimization phase. progress_bar_2 (callable): Progress bar updater for sensitivity evaluation phase. 
@@ -2099,7 +2124,8 @@ def run_geo_analysis_streamlit_app( logger.info(f" - significance_level: {significance_level}") logger.info(f" - deltas_range: {deltas_range}") logger.info(f" - periods_range: {periods_range}") - logger.info(f" - excluded_locations: {excluded_locations}") + logger.info(f" - excluded_treatment_locations: {excluded_treatment_locations}") + logger.info(f" - excluded_control_locations: {excluded_control_locations}") logger.info(f" - multicell_config: {multicell_config}") logger.info("=" * 80) @@ -2125,13 +2151,14 @@ def run_geo_analysis_streamlit_app( simulation_results = BetterGroups( similarity_matrix=correlation_matrix, maximum_treatment_percentage=maximum_treatment_percentage, - excluded_locations=excluded_locations, + excluded_treatment_locations=excluded_treatment_locations, data=data, correlation_matrix=correlation_matrix, progress_updater=progress_bar_1, status_updater=status_text_1, multicell_config=multicell_config, global_optimization=global_optimization, + excluded_control_locations=excluded_control_locations, ) if simulation_results is None: @@ -2245,7 +2272,8 @@ def run_geo_analysis( significance_level, deltas_range, periods_range, - excluded_locations, + excluded_treatment_locations, + excluded_control_locations=None, progress_bar_1=None, status_text_1=None, progress_bar_2=None, @@ -2266,7 +2294,8 @@ def run_geo_analysis( significance_level (float): Significance level for statistical testing. deltas_range (tuple): Range of delta values to evaluate as (start, stop, step). periods_range (tuple): Range of treatment periods to evaluate as (start, stop, step). - excluded_locations (list): List of states to exclude from the analysis. + excluded_treatment_locations (list): List of states to exclude from the analysis. + excluded_control_locations (list): List of states to exclude from the control group. progress_bar_1 (optional): First progress bar for UI updates. status_text_1 (optional): First status text for UI updates. progress_bar_2 (optional): Second progress bar for UI updates. 
@@ -2290,7 +2317,8 @@ def run_geo_analysis( logger.info(f" - significance_level: {significance_level}") logger.info(f" - deltas_range: {deltas_range}") logger.info(f" - periods_range: {periods_range}") - logger.info(f" - excluded_locations: {excluded_locations}") + logger.info(f" - excluded_treatment_locations: {excluded_treatment_locations}") + logger.info(f" - excluded_control_locations: {excluded_control_locations}") logger.info("=" * 80) if progress_bar_1 or progress_bar_2 or status_text_1 or status_text_2 is None: @@ -2306,11 +2334,12 @@ def run_geo_analysis( simulation_results = BetterGroups( similarity_matrix=correlation_matrix, maximum_treatment_percentage=maximum_treatment_percentage, - excluded_locations=excluded_locations, + excluded_treatment_locations=excluded_treatment_locations, data=data, correlation_matrix=correlation_matrix, progress_updater=progress_bar_1, status_updater=status_text_1, + excluded_control_locations=excluded_control_locations, ) # Step 3: Evaluate sensitivity for different deltas and periods diff --git a/Murray/post_analysis.py b/Murray/post_analysis.py index 47d32a1..022bff0 100644 --- a/Murray/post_analysis.py +++ b/Murray/post_analysis.py @@ -18,6 +18,7 @@ def run_geo_evaluation( n_permutations=50000, inference_type="iid", significance_level=0.1, + excluded_control_locations=None, ): logger.info("Starting run_geo_evaluation") logger.info(f"Input data shape: {data_input.shape}") @@ -55,6 +56,7 @@ def smape(A, F): correlation_matrix=correlation_matrix, treatment_group=treatment_group, min_correlation=0.8, + excluded_control_locations=excluded_control_locations, ) logger.info(f"Control group selected: {control_group}") @@ -100,6 +102,13 @@ def smape(A, F): time_train = time_index[:start_position_treatment] time_test = time_index[start_position_treatment:] + if len(X_train) == 0: + raise ValueError( + f"No pre-treatment periods available for training. 
" + f"The treatment start date ({start_treatment}) is at or before the beginning of the dataset. " + f"Please select a later treatment start date." + ) + logger.info("Fitting synthetic control model...") model = SyntheticControl(use_ridge_adjustment=True, ridge_alpha=1.0) model.fit(X_train, y_train, time_train=time_train) diff --git a/api.py b/api.py index 2d781be..43182f1 100644 --- a/api.py +++ b/api.py @@ -124,7 +124,8 @@ async def analyze_design( date_column: str = Form(None), location_column: str = Form(None), target_column: str = Form(None), - excluded_locations: str = Form(None), + excluded_treatment_locations: str = Form(None), + excluded_control_locations: str = Form(None), maximum_treatment_percentage: float = Form(0.3), significance_level: float = Form(0.1), deltas_range: str = Form("0.01,0.1,0.01"), @@ -179,12 +180,17 @@ async def analyze_design( # Parse parameters deltas_range = tuple(map(float, deltas_range.split(','))) periods_range = tuple(map(int, periods_range.split(','))) - # Handle empty excluded_locations properly - if excluded_locations and excluded_locations.strip(): - excluded_locations = tuple(map(str, excluded_locations.split(','))) + # Handle empty excluded_treatment_locations properly + if excluded_treatment_locations and excluded_treatment_locations.strip(): + excluded_treatment_locations = tuple(map(str, excluded_treatment_locations.split(','))) else: - excluded_locations = tuple() - + excluded_treatment_locations = tuple() + # Handle empty excluded_control_locations properly + if excluded_control_locations and excluded_control_locations.strip(): + excluded_control_locations = tuple(map(str, excluded_control_locations.split(','))) + else: + excluded_control_locations = tuple() + # Process multicell parameters multicell_config = None if enable_multicell: @@ -220,7 +226,8 @@ async def analyze_design( date_column=date_column, location_column=location_column, target_column=target_column, - excluded_locations=excluded_locations, + 
excluded_treatment_locations=excluded_treatment_locations, + excluded_control_locations=excluded_control_locations, maximum_treatment_percentage=maximum_treatment_percentage, significance_level=significance_level, deltas_range=deltas_range, diff --git a/experimental_design.py b/experimental_design.py index 54dc01d..d10b963 100644 --- a/experimental_design.py +++ b/experimental_design.py @@ -764,9 +764,12 @@ def reset_states(): unsafe_allow_html=True, ) # Location exclusion and test type selection - excluded_locations = st.multiselect( + excluded_treatment_locations = st.multiselect( "Select excluded locations", cleaned["location"].unique() ) + excluded_control_locations = st.multiselect( + "Select excluded controls", cleaned["location"].unique() + ) # Analyze data and recommend statistical test data_analysis = analyze_data_characteristics(cleaned, col_target="Y") @@ -954,7 +957,8 @@ def reset_states(): # Track parameter changes to reset simulation state current_params = { - "excluded_locations": excluded_locations, + "excluded_treatment_locations": excluded_treatment_locations, + "excluded_control_locations": excluded_control_locations, "maximum_treatment_percentage_pre": maximum_treatment_percentage_pre, "significance_level_pre": significance_level_pre, "deltas_range": (delta_min, delta_max, delta_step), @@ -1016,7 +1020,8 @@ def reset_states(): # Run main geo analysis simulation results = run_geo_analysis_streamlit_app( data=cleaned, - excluded_locations=excluded_locations, + excluded_treatment_locations=excluded_treatment_locations, + excluded_control_locations=excluded_control_locations, maximum_treatment_percentage=maximum_treatment_percentage, significance_level=significance_level, deltas_range=deltas_range, @@ -1055,7 +1060,8 @@ def reset_states(): # Show current settings for debugging with st.expander("🔧 Current Analysis Settings", expanded=False): st.write(f"**Maximum Treatment Percentage:** {maximum_treatment_percentage*100:.1f}%") - st.write(f"**Excluded 
Locations:** {len(excluded_locations)} locations") + st.write(f"**Excluded Locations:** {len(excluded_treatment_locations)} locations") + st.write(f"**Excluded Controls:** {len(excluded_control_locations)} controls") st.write(f"**Total Locations Available:** {len(cleaned['location'].unique())} locations") st.write(f"**Time Periods:** {len(cleaned['time'].unique())} periods") st.write(f"**Total Y Sum:** {cleaned['Y'].sum():,.2f}") diff --git a/experimental_evaluation.py b/experimental_evaluation.py index 5cb5b4f..f24d7b7 100644 --- a/experimental_evaluation.py +++ b/experimental_evaluation.py @@ -766,6 +766,9 @@ def reset_states(): treatment_group = st.multiselect( "Select treatment group", data1["location"].unique() ) + excluded_control_locations = st.multiselect( + "Select excluded control locations", data1["location"].unique() + ) spend = st.number_input("Select spend") mmm_option = st.selectbox( "Select the option to calculate the iROAS or iCPA", ["iROAS", "iCPA"] @@ -795,6 +798,7 @@ def reset_states(): "start_treatment": start_treatment, "end_treatment": end_treatment, "treatment_group": treatment_group, + "excluded_control_locations": excluded_control_locations, "spend": spend, "mmm_option": mmm_option, "col_target": col_target, @@ -814,14 +818,20 @@ def reset_states(): update_metrics("experimental_evaluation") with st.spinner("Running analysis..."): + try: + results = run_geo_evaluation( + data1, + start_treatment, + end_treatment, + treatment_group, + spend, + excluded_control_locations=excluded_control_locations, + ) + except ValueError as e: + st.error(f"Error: {e}") + st.session_state.evaluation_button_clicked = False + st.stop() - results = run_geo_evaluation( - data1, - start_treatment, - end_treatment, - treatment_group, - spend, - ) treatment = results["treatment"] st.session_state.treatment = treatment counterfactual = results["counterfactual"] diff --git a/tasks.py b/tasks.py index f6562be..10d0e83 100644 --- a/tasks.py +++ b/tasks.py @@ -106,7 
+106,8 @@ def analyze_design_task( date_column: str, location_column: str, target_column: str, - excluded_locations: tuple, + excluded_treatment_locations: tuple, + excluded_control_locations: tuple, maximum_treatment_percentage: float, significance_level: float, deltas_range: tuple, @@ -257,7 +258,8 @@ def advance_to_sensitivity(): results = run_geo_analysis_streamlit_app( data=data, - excluded_locations=excluded_locations, + excluded_treatment_locations=excluded_treatment_locations, + excluded_control_locations=excluded_control_locations, maximum_treatment_percentage=maximum_treatment_percentage, significance_level=significance_level, deltas_range=deltas_range,