diff --git a/bsb/__init__.py b/bsb/__init__.py index 7f741f94..ca97caaf 100644 --- a/bsb/__init__.py +++ b/bsb/__init__.py @@ -131,6 +131,7 @@ def __dir__(): import bsb.trees import bsb.voxels +AdapterCheckpoint: typing.Type["bsb.simulation.adapter.AdapterCheckpoint"] AdapterError: typing.Type["bsb.exceptions.AdapterError"] AdapterProgress: typing.Type["bsb.simulation.adapter.AdapterProgress"] AfterConnectivityHook: typing.Type["bsb.postprocessing.AfterConnectivityHook"] diff --git a/bsb/simulation/adapter.py b/bsb/simulation/adapter.py index 50cf9cdc..4e411690 100644 --- a/bsb/simulation/adapter.py +++ b/bsb/simulation/adapter.py @@ -46,6 +46,66 @@ def complete(self): return +class AdapterCheckpoint: + """Class that manages checkpointing of a simulation. In self.checkpoints a dictionary is saved with the checkpoint time as key and the value + is a list of simulations that have to flush at that checkpoint. + The get_status() method should be called in SimulatorAdapter run() to check if a checkpoint is reached. + """ + + def __init__(self, simulations): + self.simulations = simulations + self.resolutions = [] + self.checkpoints = {} + for sim in simulations: + for device in sim.devices.values(): + self.resolutions.append(sim.resolution) + device_ckp = device.get_checkpoints(sim.duration, sim.resolution) + for checkpoint in device_ckp: + if checkpoint not in self.checkpoints: + self.checkpoints[checkpoint] = [sim] + else: + self.checkpoints[checkpoint].append(sim) + self.iterator = iter(self.sort_checkpoints()) + self.status = next(self.iterator, None) + + def sort_checkpoints(self): + return sorted(self.checkpoints.keys()) + + def get_status(self, i): + """Checks whether the current simulation time has reached a checkpoint. If so, it advances to the next checkpoint""" + if self.status == i: + self.status = next(self.iterator, None) + return True + else: + return False + + def suitable_step(self, pstep): + """ + Check the greatest common divisor between progression step (pstep) and checkpoints value. + + :return: gdc interval (float). + """ + sorted = np.array(self.sort_checkpoints()) + max_resolution = max(self.resolutions) + if pstep == int(pstep): + check_multiple = sorted % pstep + else: + check_multiple = sorted / pstep - np.array(sorted / pstep, dtype=int) + if all(check_multiple == 0): + return pstep + elif any(sorted / max_resolution != np.array(sorted / max_resolution, dtype=int)): + raise ValueError( + f"Provided checkpoints are not multiple of resolution: {max_resolution}" + ) + else: + # We are here because pstep is too large. Now we look for the GDC between pstep and our checkpoints + converted = np.array(sorted / max_resolution, dtype=int) + min_step = int(pstep / max_resolution) + for i in range(0, len(converted)): + min_step = np.gcd(min_step, converted[i]) + return min_step * max_resolution + + class SimulationData: def __init__(self, simulation: "Simulation", result=None): self.chunks = None @@ -133,4 +193,4 @@ def add_progress_listener(self, listener): self._progress_listeners.append(listener) -__all__ = ["AdapterProgress", "SimulationData", "SimulatorAdapter"] +__all__ = ["AdapterCheckpoint", "AdapterProgress", "SimulationData", "SimulatorAdapter"] diff --git a/bsb/simulation/device.py b/bsb/simulation/device.py index dfcce51e..c4b55ace 100644 --- a/bsb/simulation/device.py +++ b/bsb/simulation/device.py @@ -1,9 +1,13 @@ +from bsb import types + from .. import config from .component import SimulationComponent @config.node class DeviceModel(SimulationComponent): + checkpoints = config.attr(type=types.or_(float, types.list(type=float)), default=None) + def implement(self, adapter, simulation, simdata): raise NotImplementedError( "The " @@ -11,5 +15,25 @@ def implement(self, adapter, simulation, simdata): + " device does not implement any `implement` function." ) + def get_checkpoints(self, duration, resolution): + """If checkpoints attribute is not set return an empty list, otherwise return a list of checkpoints (in ms). + If only a float it is provided it is assumed to be the time interval between checkpoints + """ + if self.checkpoints: + if isinstance(self.checkpoints, float): + import numpy as np + + multiple = self.checkpoints / resolution + if multiple != int(multiple): + raise ValueError( + f"In device {self.name} , Checkpoints must be a multiple of {resolution}" + ) + chkp_array = np.delete(np.arange(0, duration, self.checkpoints), 0) + return chkp_array + else: + return self.checkpoints + else: + return [] + __all__ = ["DeviceModel"] diff --git a/docs/simulation/advanced.rst b/docs/simulation/advanced.rst new file mode 100644 index 00000000..a4ef2499 --- /dev/null +++ b/docs/simulation/advanced.rst @@ -0,0 +1,83 @@ +##################### +Simulation Components +##################### + +The `Simulation` object encapsulates all the parameters necessary to adapt a reconstructed network to a specific simulator. +It primarily handles the conversion of cell types and connectivity into simulator-compatible formats and prepares the experimental configuration. + +A `Simulation` is defined by the following attributes: + +* ``simulator``:*str* - Specifies the simulation software to be used. +* ``duration``: *float* - Total duration of the simulation in milliseconds. +* ``cell_models``: *dict* - Contains simulator-specific representations of the network's :doc:`CellTypes `. +* ``connection_models``: *dict* - Provides instructions for handling the network’s :doc:`ConnectivityStrategies `. +* ``devices``: *dict* - Lists the simulation devices to be included. +* ``post_prepare``: *Callable* - A hook that is executed after the simulation has been prepared. + + + + +Simulator Adapters +================== + +The `SimulatorAdapter` is the core abstraction responsible for adapting BSB simulation data and +execution flow to the specifics of a target simulator. + +This class manages the simulation pipeline, which consists of the following stages: + +**Prepare** → **Post Prepare** → **Run** → **Collect** + +As an abstract base class, it is designed to be extended to implement simulator-specific behavior. +At a minimum, custom adapters must implement the :guilabel:`prepare()` and :guilabel:`run()` methods: + +* :guilabel:`prepare()`: Initializes the simulation by configuring parameters, creating cell_models and connection_models, and invoking the :guilabel:`implement()` method on the simulation devices. +* :guilabel:`run()`: Executes the simulation. Typically, this involves stepping through simulation time intervals until the defined duration is reached, using the simulator’s solver to compute the system’s evolution. + +The **Post Prepare** and **Collect** phases do not need to be implemented. After the :guilabel:`prepare()` method completes, the adapter will execute any functions specified in the ``post_prepare`` hook. +Once the simulation ends, the **Collect** phase gathers and finalizes results. + +Adapter Iterators +----------------- +To monitor simulation progress at defined intervals, you can use the `AdapterProgress` class. +This utility handles iteration over simulation time steps. +Example usage: + +.. code-block:: python + + def run(self, *simulations: "Simulation"): + + duration = max(sim.duration for sim in simulations) + progress = AdapterProgress(duration) + my_interval=1 + for oi, i in progress.steps(step=my_interval): + my_solver(oi,i) ## call the solver from time oi to time i + tick = progress.tick(i) + progress.complete() + +* :guilabel:`steps(step=...)`: Yields time intervals of the specified step size (default is 1 ms). +* :guilabel:`tick(i)`: Returns a ``SimpleNamespace`` object with current progress information. + +If intermediate result collection is needed before simulation ends, use the `AdapterCheckpoint` class. +It coordinates checkpoints from all registered devices and merges them into a unified schedule. + +.. code-block:: python + + def run(self, *simulations: "Simulation"): + + sim = simulations["sim_name"] + duration = sim.duration + progress = AdapterProgress(duration) + my_interval=1 + checkpoint = AdapterCheckpoint(simulations) + optimal_interval = checkpoint.suitable_step(my_interval) + for oi, i in progress.steps(step=optimal_interval): + my_solver(oi,i) ## call the solver from time oi to time i + tick = progress.tick(i) + + if checkpoint.get_status(i): + self.simdata[sim].result.flush() + + progress.complete() + +* :guilabel:`suitable_step(interval)`: Suggests an optimal interval compatible with the defined checkpoints. +* :guilabel:`get_status(time)`: Returns True if a checkpoint has been reached at the given simulation time. \ No newline at end of file diff --git a/docs/simulation/intro.rst b/docs/simulation/intro.rst index 0bf1a5dc..cb940bb7 100644 --- a/docs/simulation/intro.rst +++ b/docs/simulation/intro.rst @@ -20,6 +20,8 @@ All of the above is simulation backend specific and is covered in the correspond * :doc:`NEURON `. * :doc:`ARBOR `. +To familiarize with general aspects of simulation components you can dive in the :doc:`Simulation components page `. + Running Simulations =================== diff --git a/docs/simulation/nest.rst b/docs/simulation/nest.rst index d47cf25c..8e05dab6 100644 --- a/docs/simulation/nest.rst +++ b/docs/simulation/nest.rst @@ -146,11 +146,12 @@ NEST provides two types of devices: *recording* devices, for extracting informat and *stimulation* devices, for delivering stimuli. The ``bsb-nest`` module provides interfaces for NEST devices through the ``NestDevice`` object. -To properly configure a device, you need to specify three attributes: +To properly configure a device, you need to specify four attributes: * :guilabel:`weight` : *float* specifying the connection weight between the device and its target (required). * :guilabel:`delay` : *float* specifying the transmission delay between the device and its target (required). * :guilabel:`targeting` : Specifies the targets of the device, which can be a population or a NEST rule. + * :guilabel:`receptor_type` : *int* ID of the postsynaptic target receptor. For example, to create a device named ``my_new_device`` of class ``device_type``, with a weight of 1 and a delay of 0.1 ms, targeting the population of ``my_cell_model``: @@ -182,6 +183,38 @@ and a delay of 0.1 ms, targeting the population of ``my_cell_model``: } ) +By default, device results are collected only at the end of the simulation. However, if intermediate result collection is required, +you can specify a series of time checkpoints at which the simulation will pause and gather partial results. +These checkpoints can be configured using the :guilabel:`checkpoints` attribute. This attribute accepts either: + + * A *list* of *float* values, each representing a specific time (in milliseconds) at which to collect results. + * A single *float* value, which will be interpreted as a fixed time interval between consecutive checkpoints (in ms). + +Example configuration: + +.. tab-set-code:: + + .. code-block:: json + + "my_new_device": { + "device": "device_type", + "weight": 1, + "delay": 0.1, + "checkpoints": 100, + } + .. code-block:: python + + config.simulations["my_simulation_name"].devices=dict( + my_new_device={ + "device": "device_type", + "weight": 1, + "delay": 0.1, + "checkpoints": 100, + } + ) + +In the example above, the device ``my_new_device`` will collect results every 100 milliseconds during the simulation. + Stimulation devices ------------------- diff --git a/docs/simulation/simulation-toc.rst b/docs/simulation/simulation-toc.rst index b37cdc36..8001f117 100644 --- a/docs/simulation/simulation-toc.rst +++ b/docs/simulation/simulation-toc.rst @@ -9,3 +9,4 @@ Simulation nest neuron arbor + advanced diff --git a/tests/test_simulation.py b/tests/test_simulation.py index d55bb5e9..21d4e34e 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -1,8 +1,10 @@ import unittest +import numpy as np +from bsb_arbor.device import ArborDevice from bsb_test import FixedPosConfigFixture, NumpyTestCase, RandomStorageFixture -from bsb import Scaffold +from bsb import AdapterCheckpoint, Scaffold, config, get_simulation_adapter class TestSimulate( @@ -36,3 +38,144 @@ def test_simulate(self): devices=dict(), ) self.network.run_simulation("test") + + +@config.node +class MockDevice(ArborDevice): + def implement(self): + pass + + def implement_probes(self, simdata, target): + pass + + def implement_generators(self, simdata, target): + pass + + +class test_adaptercheckpoint( + FixedPosConfigFixture, + RandomStorageFixture, + NumpyTestCase, + unittest.TestCase, + engine_name="hdf5", +): + + def setUp(self): + super().setUp() + self.cfg.connectivity.add( + "all_to_all", + dict( + strategy="bsb.connectivity.AllToAll", + presynaptic=dict(cell_types=["test_cell"]), + postsynaptic=dict(cell_types=["test_cell"]), + ), + ) + self.network = Scaffold(self.cfg, self.storage) + self.network.compile(clear=True) + self.network.simulations.add( + "test", + simulator="arbor", + duration=100, + resolution=0.25, + cell_models=dict(), + connection_models=dict(), + devices=dict( + test_mock={ + "device": MockDevice, + "targetting": {"strategy": "all"}, + "resolution": 0.25, + "checkpoints": [12.5, 25, 50], + } + ), + ) + + def test_wrong_value(self): + """Check that checkpoints values are multiple of simulation resolution""" + self.network.simulations.add( + "wtest", + simulator="arbor", + duration=100, + resolution=1, + cell_models=dict(), + connection_models=dict(), + devices=dict( + test_mock={ + "device": MockDevice, + "targetting": {"strategy": "all"}, + "resolution": 1, + "checkpoints": [12.5, 25, 50], + } + ), + ) + sim = self.network.simulations["wtest"] + AC = AdapterCheckpoint([sim]) + with self.assertRaises(ValueError): + AC.suitable_step(1) + + def test_checkpoints(self): + + sim = self.network.simulations["test"] + print(self.network.simulations["test"].simulator) + AC = AdapterCheckpoint([sim]) + min_step = AC.suitable_step(1) + self.assertEqual( + min_step, 0.5, "Suitable step should lower the progression step from 1 to 0.5" + ) + self.assertEqual( + AC.sort_checkpoints(), + [12.5, 25, 50], + "Do not return the correct checkpoints list", + ) + + def test_multi_sim(self): + self.network.simulations.add( + "2_sim", + simulator="arbor", + duration=100, + resolution=0.25, + cell_models=dict(), + connection_models=dict(), + devices=dict( + test_mock={ + "device": MockDevice, + "targetting": {"strategy": "all"}, + "resolution": 0.25, + "checkpoints": [17, 20, 25, 64, 71], + } + ), + ) + sim = [self.network.simulations["test"], self.network.simulations["2_sim"]] + AC = AdapterCheckpoint(sim) + min_step = AC.suitable_step(1) + self.assertEqual( + min_step, 0.5, "Suitable step should lower the progression step from 1 to 0.5" + ) + + time_iterator = iter( + np.arange(0, self.network.simulations["test"].duration, min_step) + ) + check_points = [] + sim_ref = [] + for step in time_iterator: + if AC.get_status(step): + check_points.append(step) + sim_ref.append([sim.name for sim in AC.checkpoints[step]]) + expected_sim_order = [ + ["test"], + ["2_sim"], + ["2_sim"], + ["test", "2_sim"], + ["test"], + ["2_sim"], + ["2_sim"], + ] + self.assertEqual( + sim_ref, + expected_sim_order, + "The references to simulations are wrongly assigned", + ) + self.assertEqual( + check_points, + [12.5, 17, 20, 25, 50, 64, 71], + "Do not return the correct checkpoints list", + )