diff --git a/openavmkit/data.py b/openavmkit/data.py index 10cda32..32d2ea5 100644 --- a/openavmkit/data.py +++ b/openavmkit/data.py @@ -842,7 +842,7 @@ def _enrich_sup_spatial_lag_for_model_group( df_sub = df_sub[~pd.isna(df_sub["latitude"]) & ~pd.isna(df_sub["longitude"])] # Choose the number of nearest neighbors to use - k = 5 # adjust this number as needed + k = s_sl.get("sale_price", 5) # adjust this number as needed df_sub_train = df_sub.loc[df_sub["key_sale"].isin(train_keys)].copy() @@ -878,7 +878,7 @@ def _enrich_sup_spatial_lag_for_model_group( # Query the tree: for each parcel in df_universe, find the k nearest sales # distances: shape (n_universe, k); indices: corresponding indices in df_sales - distances, indices = sales_tree.query(universe_coords, k=k) + distances, indices = sales_tree.query(universe_coords, k=min(len(sales_coords_train), k)) # Ensure that distances and indices are 2D arrays (if k==1, reshape them) if k == 1: