From f9d7336a932be69cce9fe58c87d68ce4e2362584 Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Fri, 19 Sep 2025 18:28:49 +0200 Subject: [PATCH] refactor: optimize agent presence checks by using implode() for unique_id comparisons --- examples/sugarscape_ig/ss_polars/agents.py | 8 +++-- mesa_frames/abstract/space.py | 8 ++--- mesa_frames/concrete/agentset.py | 37 ++++++++++++---------- mesa_frames/concrete/agentsetregistry.py | 2 +- mesa_frames/concrete/mixin.py | 6 ++-- 5 files changed, 34 insertions(+), 27 deletions(-) diff --git a/examples/sugarscape_ig/ss_polars/agents.py b/examples/sugarscape_ig/ss_polars/agents.py index 32ca91f5..f74e94c1 100644 --- a/examples/sugarscape_ig/ss_polars/agents.py +++ b/examples/sugarscape_ig/ss_polars/agents.py @@ -37,7 +37,7 @@ def __init__( def eat(self): # Only consider cells currently occupied by agents of this set cells = self.space.cells.filter(pl.col("agent_id").is_not_null()) - mask_in_set = cells["agent_id"].is_in(self.index) + mask_in_set = cells["agent_id"].is_in(self.index.implode()) if mask_in_set.any(): cells = cells.filter(mask_in_set) ids = cells["agent_id"] @@ -201,7 +201,7 @@ def get_best_moves(self, neighborhood: pl.DataFrame): ) if len(best_moves) > 0: condition = condition | pl.col("blocking_agent_id").is_in( - best_moves["agent_id_center"] + best_moves["agent_id_center"].implode() ) condition = condition & (pl.col("priority") == 1) @@ -212,7 +212,9 @@ def get_best_moves(self, neighborhood: pl.DataFrame): # Remove agents that have already moved neighborhood = neighborhood.filter( - ~pl.col("agent_id_center").is_in(best_moves["agent_id_center"]) + ~pl.col("agent_id_center").is_in( + best_moves["agent_id_center"].implode() + ) ) # Remove cells that have been already selected diff --git a/mesa_frames/abstract/space.py b/mesa_frames/abstract/space.py index 39abe6bd..5d52b368 100644 --- a/mesa_frames/abstract/space.py +++ b/mesa_frames/abstract/space.py @@ -984,7 +984,7 @@ def _place_or_move_agents_to_cells( if __debug__: # Check ids presence in model using public API - b_contained = agents.is_in(self.model.sets.ids) + b_contained = agents.is_in(self.model.sets.ids.implode()) if (isinstance(b_contained, Series) and not b_contained.all()) or ( isinstance(b_contained, bool) and not b_contained ): @@ -1610,7 +1610,7 @@ def remove_agents( if __debug__: # Check ids presence in model via public ids - b_contained = agents.is_in(obj.model.sets.ids) + b_contained = agents.is_in(obj.model.sets.ids.implode()) if (isinstance(b_contained, Series) and not b_contained.all()) or ( isinstance(b_contained, bool) and not b_contained ): @@ -1792,7 +1792,7 @@ def _get_df_coords( if agents is not None: agents = self._get_ids_srs(agents) # Check ids presence in model - b_contained = agents.is_in(self.model.sets.ids) + b_contained = agents.is_in(self.model.sets.ids.implode()) if (isinstance(b_contained, Series) and not b_contained.all()) or ( isinstance(b_contained, bool) and not b_contained ): @@ -1872,7 +1872,7 @@ def _place_or_move_agents( warn("Some agents are already present in the grid", RuntimeWarning) # Check if agents are present in the model using the public ids - b_contained = agents.is_in(self.model.sets.ids) + b_contained = agents.is_in(self.model.sets.ids.implode()) if (isinstance(b_contained, Series) and not b_contained.all()) or ( isinstance(b_contained, bool) and not b_contained ): diff --git a/mesa_frames/concrete/agentset.py b/mesa_frames/concrete/agentset.py index 2a9b1a55..8ff429d6 100644 --- a/mesa_frames/concrete/agentset.py +++ b/mesa_frames/concrete/agentset.py @@ -232,9 +232,11 @@ def contains( agents: PolarsIdsLike, ) -> bool | pl.Series: if isinstance(agents, pl.Series): - return agents.is_in(self._df["unique_id"]) + return agents.is_in(self._df["unique_id"].implode()) elif isinstance(agents, Collection) and not isinstance(agents, str): - return pl.Series(agents, dtype=pl.UInt64).is_in(self._df["unique_id"]) + return pl.Series(agents, dtype=pl.UInt64).is_in( + self._df["unique_id"].implode() + ) else: return agents in self._df["unique_id"] @@ -322,7 +324,7 @@ def remove(self, agents: PolarsIdsLike | AgentMask, inplace: bool = True) -> Sel # Normalize to Series of unique_ids ids = obj._df_index(obj._get_masked_df(agents), "unique_id") # Validate presence - if not ids.is_in(obj._df["unique_id"]).all(): + if not ids.is_in(obj._df["unique_id"].implode()).all(): raise KeyError("Some 'unique_id' of mask are not present in this AgentSet.") # Remove by ids return obj._discard(ids) @@ -396,8 +398,8 @@ def select( if filter_func: mask = mask & filter_func(obj) if n is not None: - mask = (obj._df["unique_id"]).is_in( - obj._df.filter(mask).sample(n)["unique_id"] + mask = obj._df["unique_id"].is_in( + obj._df.filter(mask).sample(n)["unique_id"].implode() ) if negate: mask = mask.not_() @@ -456,7 +458,9 @@ def _concatenate_agentsets( for obj in iter(agentsets): # Remove agents that are already in the final DataFrame final_dfs.append( - obj._df.filter(pl.col("unique_id").is_in(final_indices).not_()) + obj._df.filter( + pl.col("unique_id").is_in(final_indices.implode()).not_() + ) ) # Add the indices of the active agents of current AgentSet final_active_indices.append(obj._df.filter(obj._mask)["unique_id"]) @@ -476,13 +480,13 @@ def _concatenate_agentsets( final_active_index = pl.concat( [obj._df.filter(obj._mask)["unique_id"] for obj in agentsets] ) - final_mask = final_df["unique_id"].is_in(final_active_index) + final_mask = final_df["unique_id"].is_in(final_active_index.implode()) self._df = final_df self._mask = final_mask # If some ids were removed in the do-method, we need to remove them also from final_df if not isinstance(original_masked_index, type(None)): ids_to_remove = original_masked_index.filter( - original_masked_index.is_in(self._df["unique_id"]).not_() + original_masked_index.is_in(self._df["unique_id"].implode()).not_() ) if not ids_to_remove.is_empty(): self.remove(ids_to_remove, inplace=True) @@ -499,7 +503,7 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series: and len(mask) == len(self._df) ): return mask - return self._df["unique_id"].is_in(mask) + return self._df["unique_id"].is_in(mask.implode()) if isinstance(mask, pl.Expr): return mask @@ -532,13 +536,13 @@ def _get_masked_df( ): return self._df.filter(mask) elif isinstance(mask, pl.DataFrame): - if not mask["unique_id"].is_in(self._df["unique_id"]).all(): + if not mask["unique_id"].is_in(self._df["unique_id"].implode()).all(): raise KeyError( "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." ) return mask.select("unique_id").join(self._df, on="unique_id", how="left") elif isinstance(mask, pl.Series): - if not mask.is_in(self._df["unique_id"]).all(): + if not mask.is_in(self._df["unique_id"].implode()).all(): raise KeyError( "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." ) @@ -553,7 +557,7 @@ def _get_masked_df( mask_series = pl.Series(mask, dtype=pl.UInt64) else: mask_series = pl.Series([mask], dtype=pl.UInt64) - if not mask_series.is_in(self._df["unique_id"]).all(): + if not mask_series.is_in(self._df["unique_id"].implode()).all(): raise KeyError( "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." ) @@ -585,12 +589,13 @@ def _discard(self, ids: PolarsIdsLike) -> Self: def _update_mask( self, original_active_indices: pl.Series, new_indices: pl.Series | None = None ) -> None: + original_active = original_active_indices.implode() if new_indices is not None: - self._mask = self._df["unique_id"].is_in( - original_active_indices - ) | self._df["unique_id"].is_in(new_indices) + self._mask = self._df["unique_id"].is_in(original_active) | self._df[ + "unique_id" + ].is_in(new_indices.implode()) else: - self._mask = self._df["unique_id"].is_in(original_active_indices) + self._mask = self._df["unique_id"].is_in(original_active) def __getattr__(self, key: str) -> Any: if key == "name": diff --git a/mesa_frames/concrete/agentsetregistry.py b/mesa_frames/concrete/agentsetregistry.py index 4b486ba2..dbccf601 100644 --- a/mesa_frames/concrete/agentsetregistry.py +++ b/mesa_frames/concrete/agentsetregistry.py @@ -575,7 +575,7 @@ def _check_ids_presence(self, other: list[AgentSet]) -> pl.DataFrame: [ presence_df, ( - new_ids.is_in(presence_df["unique_id"]) + new_ids.is_in(presence_df["unique_id"].implode()) .to_frame("present") .with_columns(unique_id=new_ids) .select(["unique_id", "present"]) diff --git a/mesa_frames/concrete/mixin.py b/mesa_frames/concrete/mixin.py index 4900536e..69be0423 100644 --- a/mesa_frames/concrete/mixin.py +++ b/mesa_frames/concrete/mixin.py @@ -206,7 +206,7 @@ def _df_contains( column: str, values: Collection[Any], ) -> pl.Series: - return pl.Series("contains", values).is_in(df[column]) + return pl.Series("contains", values).is_in(df[column].implode()) def _df_div( self, @@ -290,7 +290,7 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series: ): return mask assert isinstance(index_cols, str) - return df[index_cols].is_in(mask) + return df[index_cols].is_in(mask.implode()) def bool_mask_from_df(mask: pl.DataFrame) -> pl.Series: assert index_cols, list[str] @@ -632,7 +632,7 @@ def _srs_contains( ) -> pl.Series: if not isinstance(values, Collection): values = [values] - return pl.Series(values).is_in(pl.Series(srs)) + return pl.Series(values).is_in(pl.Series(srs).implode()) def _srs_range( self,