Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/sugarscape_ig/ss_polars/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
def eat(self):
# Only consider cells currently occupied by agents of this set
cells = self.space.cells.filter(pl.col("agent_id").is_not_null())
mask_in_set = cells["agent_id"].is_in(self.index)
mask_in_set = cells["agent_id"].is_in(self.index.implode())
if mask_in_set.any():
cells = cells.filter(mask_in_set)
ids = cells["agent_id"]
Expand Down Expand Up @@ -201,7 +201,7 @@ def get_best_moves(self, neighborhood: pl.DataFrame):
)
if len(best_moves) > 0:
condition = condition | pl.col("blocking_agent_id").is_in(
best_moves["agent_id_center"]
best_moves["agent_id_center"].implode()
)

condition = condition & (pl.col("priority") == 1)
Expand All @@ -212,7 +212,9 @@ def get_best_moves(self, neighborhood: pl.DataFrame):

# Remove agents that have already moved
neighborhood = neighborhood.filter(
~pl.col("agent_id_center").is_in(best_moves["agent_id_center"])
~pl.col("agent_id_center").is_in(
best_moves["agent_id_center"].implode()
)
)

# Remove cells that have been already selected
Expand Down
8 changes: 4 additions & 4 deletions mesa_frames/abstract/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,7 @@ def _place_or_move_agents_to_cells(

if __debug__:
# Check ids presence in model using public API
b_contained = agents.is_in(self.model.sets.ids)
b_contained = agents.is_in(self.model.sets.ids.implode())
if (isinstance(b_contained, Series) and not b_contained.all()) or (
isinstance(b_contained, bool) and not b_contained
):
Expand Down Expand Up @@ -1610,7 +1610,7 @@ def remove_agents(

if __debug__:
# Check ids presence in model via public ids
b_contained = agents.is_in(obj.model.sets.ids)
b_contained = agents.is_in(obj.model.sets.ids.implode())
if (isinstance(b_contained, Series) and not b_contained.all()) or (
isinstance(b_contained, bool) and not b_contained
):
Expand Down Expand Up @@ -1792,7 +1792,7 @@ def _get_df_coords(
if agents is not None:
agents = self._get_ids_srs(agents)
# Check ids presence in model
b_contained = agents.is_in(self.model.sets.ids)
b_contained = agents.is_in(self.model.sets.ids.implode())
if (isinstance(b_contained, Series) and not b_contained.all()) or (
isinstance(b_contained, bool) and not b_contained
):
Expand Down Expand Up @@ -1872,7 +1872,7 @@ def _place_or_move_agents(
warn("Some agents are already present in the grid", RuntimeWarning)

# Check if agents are present in the model using the public ids
b_contained = agents.is_in(self.model.sets.ids)
b_contained = agents.is_in(self.model.sets.ids.implode())
if (isinstance(b_contained, Series) and not b_contained.all()) or (
isinstance(b_contained, bool) and not b_contained
):
Expand Down
37 changes: 21 additions & 16 deletions mesa_frames/concrete/agentset.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,11 @@ def contains(
agents: PolarsIdsLike,
) -> bool | pl.Series:
if isinstance(agents, pl.Series):
return agents.is_in(self._df["unique_id"])
return agents.is_in(self._df["unique_id"].implode())
elif isinstance(agents, Collection) and not isinstance(agents, str):
return pl.Series(agents, dtype=pl.UInt64).is_in(self._df["unique_id"])
return pl.Series(agents, dtype=pl.UInt64).is_in(
self._df["unique_id"].implode()
)
else:
return agents in self._df["unique_id"]

Expand Down Expand Up @@ -322,7 +324,7 @@ def remove(self, agents: PolarsIdsLike | AgentMask, inplace: bool = True) -> Sel
# Normalize to Series of unique_ids
ids = obj._df_index(obj._get_masked_df(agents), "unique_id")
# Validate presence
if not ids.is_in(obj._df["unique_id"]).all():
if not ids.is_in(obj._df["unique_id"].implode()).all():
raise KeyError("Some 'unique_id' of mask are not present in this AgentSet.")
# Remove by ids
return obj._discard(ids)
Expand Down Expand Up @@ -396,8 +398,8 @@ def select(
if filter_func:
mask = mask & filter_func(obj)
if n is not None:
mask = (obj._df["unique_id"]).is_in(
obj._df.filter(mask).sample(n)["unique_id"]
mask = obj._df["unique_id"].is_in(
obj._df.filter(mask).sample(n)["unique_id"].implode()
)
if negate:
mask = mask.not_()
Expand Down Expand Up @@ -456,7 +458,9 @@ def _concatenate_agentsets(
for obj in iter(agentsets):
# Remove agents that are already in the final DataFrame
final_dfs.append(
obj._df.filter(pl.col("unique_id").is_in(final_indices).not_())
obj._df.filter(
pl.col("unique_id").is_in(final_indices.implode()).not_()
)
)
# Add the indices of the active agents of current AgentSet
final_active_indices.append(obj._df.filter(obj._mask)["unique_id"])
Expand All @@ -476,13 +480,13 @@ def _concatenate_agentsets(
final_active_index = pl.concat(
[obj._df.filter(obj._mask)["unique_id"] for obj in agentsets]
)
final_mask = final_df["unique_id"].is_in(final_active_index)
final_mask = final_df["unique_id"].is_in(final_active_index.implode())
self._df = final_df
self._mask = final_mask
# If some ids were removed in the do-method, we need to remove them also from final_df
if not isinstance(original_masked_index, type(None)):
ids_to_remove = original_masked_index.filter(
original_masked_index.is_in(self._df["unique_id"]).not_()
original_masked_index.is_in(self._df["unique_id"].implode()).not_()
)
if not ids_to_remove.is_empty():
self.remove(ids_to_remove, inplace=True)
Expand All @@ -499,7 +503,7 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series:
and len(mask) == len(self._df)
):
return mask
return self._df["unique_id"].is_in(mask)
return self._df["unique_id"].is_in(mask.implode())

if isinstance(mask, pl.Expr):
return mask
Expand Down Expand Up @@ -532,13 +536,13 @@ def _get_masked_df(
):
return self._df.filter(mask)
elif isinstance(mask, pl.DataFrame):
if not mask["unique_id"].is_in(self._df["unique_id"]).all():
if not mask["unique_id"].is_in(self._df["unique_id"].implode()).all():
raise KeyError(
"Some 'unique_id' of mask are not present in DataFrame 'unique_id'."
)
return mask.select("unique_id").join(self._df, on="unique_id", how="left")
elif isinstance(mask, pl.Series):
if not mask.is_in(self._df["unique_id"]).all():
if not mask.is_in(self._df["unique_id"].implode()).all():
raise KeyError(
"Some 'unique_id' of mask are not present in DataFrame 'unique_id'."
)
Expand All @@ -553,7 +557,7 @@ def _get_masked_df(
mask_series = pl.Series(mask, dtype=pl.UInt64)
else:
mask_series = pl.Series([mask], dtype=pl.UInt64)
if not mask_series.is_in(self._df["unique_id"]).all():
if not mask_series.is_in(self._df["unique_id"].implode()).all():
raise KeyError(
"Some 'unique_id' of mask are not present in DataFrame 'unique_id'."
)
Expand Down Expand Up @@ -585,12 +589,13 @@ def _discard(self, ids: PolarsIdsLike) -> Self:
def _update_mask(
self, original_active_indices: pl.Series, new_indices: pl.Series | None = None
) -> None:
original_active = original_active_indices.implode()
if new_indices is not None:
self._mask = self._df["unique_id"].is_in(
original_active_indices
) | self._df["unique_id"].is_in(new_indices)
self._mask = self._df["unique_id"].is_in(original_active) | self._df[
"unique_id"
].is_in(new_indices.implode())
else:
self._mask = self._df["unique_id"].is_in(original_active_indices)
self._mask = self._df["unique_id"].is_in(original_active)

def __getattr__(self, key: str) -> Any:
if key == "name":
Expand Down
2 changes: 1 addition & 1 deletion mesa_frames/concrete/agentsetregistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ def _check_ids_presence(self, other: list[AgentSet]) -> pl.DataFrame:
[
presence_df,
(
new_ids.is_in(presence_df["unique_id"])
new_ids.is_in(presence_df["unique_id"].implode())
.to_frame("present")
.with_columns(unique_id=new_ids)
.select(["unique_id", "present"])
Expand Down
6 changes: 3 additions & 3 deletions mesa_frames/concrete/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def _df_contains(
column: str,
values: Collection[Any],
) -> pl.Series:
return pl.Series("contains", values).is_in(df[column])
return pl.Series("contains", values).is_in(df[column].implode())

def _df_div(
self,
Expand Down Expand Up @@ -290,7 +290,7 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series:
):
return mask
assert isinstance(index_cols, str)
return df[index_cols].is_in(mask)
return df[index_cols].is_in(mask.implode())

def bool_mask_from_df(mask: pl.DataFrame) -> pl.Series:
assert index_cols, list[str]
Expand Down Expand Up @@ -632,7 +632,7 @@ def _srs_contains(
) -> pl.Series:
if not isinstance(values, Collection):
values = [values]
return pl.Series(values).is_in(pl.Series(srs))
return pl.Series(values).is_in(pl.Series(srs).implode())

def _srs_range(
self,
Expand Down
Loading