diff --git a/bsb/simulation/targetting.py b/bsb/simulation/targetting.py index 2e610b51..9960161e 100755 --- a/bsb/simulation/targetting.py +++ b/bsb/simulation/targetting.py @@ -9,6 +9,7 @@ from ..config import refs, types if typing.TYPE_CHECKING: + from ..cell_types import CellType from .cell import CellModel @@ -58,6 +59,19 @@ def get_targets(self, adapter, simulation, simdata): } +class CellTypeFilter: + cell_types: list["CellType"] = config.reflist(refs.cell_type_ref, required=False) + only_local: bool = config.attr(type=bool, default=True) + + def get_targets(self, adapter, simulation, simdata): + chunks = simdata.chunks if self.only_local else None + return { + cell_name: cell_type.get_placement_set(chunks=chunks) + for cell_name, cell_type in simulation.scaffold.cell_types.items() + if not self.cell_types or cell_type in self.cell_types + } + + class FractionFilter: count = config.attr( type=int, required=types.mut_excl("fraction", "count", required=False) @@ -141,13 +155,20 @@ class ByIdTargetting(FractionFilter, CellTargetting, classmap_entry="by_id"): @FractionFilter.filter def get_targets(self, adapter, simulation, simdata): - by_name = {model.name: model for model in simdata.populations.keys()} - return { - model: simdata.populations[model][ids] - for model_name, ids in self.ids.items() - if (model := by_name.get(model_name)) is not None + by_name = { + model.name: model + for model, pop in simdata.populations.items() + if len(pop) > 0 } + dict_target = {} + for model_name, ids in self.ids.items(): + if (model := by_name.get(model_name)) is not None: + pop = simdata.populations[model] + my_ids = simdata.placement[model].convert_to_local(ids) + dict_target[model] = pop[my_ids] + return dict_target + @config.node class ByLabelTargetting( @@ -209,6 +230,36 @@ def get_targets(self, adapter, simulation, simdata): } +@config.node +class SphericalTargettingCellTypes( + CellTypeFilter, FractionFilter, Targetting, classmap_entry="sphere_cell_types" +): + """ + Targets all cell types in a sphere. + """ + + origin: list[float] = config.attr(type=types.list(type=float, size=3), required=True) + radius: float = config.attr(type=float, required=True) + + @FractionFilter.filter + def get_targets(self, adapter, simulation, simdata): + """ + Target all or certain cells within a sphere of specified radius. + """ + return { + model: ps.load_ids()[ + ( + np.sum( + (ps.load_positions() - self.origin) ** 2, + axis=1, + ) + < self.radius**2 + ) + ] + for model, ps in super().get_targets(adapter, simulation, simdata).items() + } + + @config.node class SphericalTargetting( CellModelFilter, FractionFilter, CellTargetting, classmap_entry="sphere" @@ -247,7 +298,7 @@ def get_targets(self, adapter, simulation, simdata): ) class LocationTargetting: def get_locations(self, cell): - return cell.locations + return [v for v in cell.locations.values()] @config.node