From b252722c7d80e7e4ae14487c9c1eb0d836fd9e8e Mon Sep 17 00:00:00 2001 From: connorschwartz <46463980+connorschwartz@users.noreply.github.com> Date: Mon, 4 May 2026 21:48:21 -0400 Subject: [PATCH] Support different variables for different model groups --- openavmkit/benchmark.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/openavmkit/benchmark.py b/openavmkit/benchmark.py index 69378dd..33198f5 100644 --- a/openavmkit/benchmark.py +++ b/openavmkit/benchmark.py @@ -490,6 +490,7 @@ def get_variable_recommendations( settings_model = settings.get("modeling", {}) vacant_status = "vacant" if vacant_only else "main" model_entries = settings_model.get("models", {}).get(vacant_status, {}) + model_entries = model_entries.get(model_group, model_entries) entry: dict | None = model_entries.get("model", model_entries.get("default", {})) if variables_to_use is None: variables_to_use: list | None = entry.get("ind_vars", None) @@ -1437,7 +1438,7 @@ def run_one_model( if save_results: t.start("write") main_vacant_hedonic = "hedonic" if hedonic else "vacant" if vacant_only else "main" - location = get_model_location(settings, main_vacant_hedonic, model_name) + location = get_model_location(settings, main_vacant_hedonic, model_name, model_group) _write_model_results(results, outpath, settings, location, verbose=verbose) t.stop("write") @@ -1545,6 +1546,7 @@ def run_one_hedonic_model( settings=settings, save_results=save_results, verbose=verbose, + model_group=model_group, ) return results @@ -1796,6 +1798,7 @@ def _predict_one_model( model_engine: str, outpath: str, settings: dict, + model_group: str, save_results: bool = False, verbose: bool = False, ) -> SingleModelResults: @@ -1868,12 +1871,13 @@ def _predict_one_model( if save_results: mvh = settings.get("modeling", {}).get("models", {}).get(main_vacant_hedonic, {}) + mvh = mvh.get(model_group, mvh) model_entry = mvh.get("model_name", mvh.get("default", {})) location = model_entry.get("location", None) if location is None: location = get_important_field(settings, "loc_neighborhood") - location = get_model_location(settings, main_vacant_hedonic, model_name) + location = get_model_location(settings, main_vacant_hedonic, model_name, model_group) _write_model_results(results, outpath, settings, location, verbose=verbose) return results @@ -2222,9 +2226,11 @@ def _write_model_results(results: SingleModelResults, outpath: str, settings: di def get_model_location( settings: dict, main_vacant_hedonic: str, - model_name: str + model_name: str, + model_group: str ): mvh = settings.get("modeling", {}).get("models", {}).get(main_vacant_hedonic, {}) + mvh = mvh.get(model_group, mvh) model_entry = mvh.get(model_name, mvh.get("default", {})) location = model_entry.get("location", None) if location is None: @@ -3099,6 +3105,7 @@ def _prepare_ds( s_model = s.get("modeling", {}) vacant_status = "vacant" if vacant_only else "main" model_entries = s_model.get("models", {}).get(vacant_status, {}) + model_entries = model_entries.get(model_group, model_entries) entry: dict | None = model_entries.get("model", model_entries.get("default", {})) if ind_vars is None: @@ -3622,6 +3629,7 @@ def _run_hedonic_models( settings=settings, save_results=save_results, verbose=verbose, + model_group=model_group, ) if results is not None: hedonic_results[model_name] = results @@ -4247,6 +4255,7 @@ def _run_models( models_to_skip = settings_model_instructions.get(main_vacant_hedonic,{}).get("skip",{}).get(model_group,[]) model_entries = settings_model.get("models").get(main_vacant_hedonic, {}) + model_entries = model_entries.get(model_group, model_entries) if models_to_run is None: models_to_run = list(model_entries.keys()) @@ -4739,4 +4748,4 @@ def _quick_shap( shaps = _calc_shap(smr.model, X_train, X_train) if plot: - plot_full_beeswarm(shaps, title=title) \ No newline at end of file + plot_full_beeswarm(shaps, title=title)