diff --git a/bsb_neuron/adapter.py b/bsb_neuron/adapter.py index 8262b3b..cdfbced 100644 --- a/bsb_neuron/adapter.py +++ b/bsb_neuron/adapter.py @@ -6,6 +6,7 @@ from bsb import ( AdapterError, AdapterProgress, + AdapterCheckpoint, Chunk, DatasetNotFoundError, SimulationData, @@ -42,6 +43,37 @@ def flush(segment): segment.analogsignals.append( AnalogSignal(list(v), sampling_period=p.dt * ms, **annotations) ) + # Free the memory + print(f"Size V vec: {v.size()}") + v.remove(0, v.size() - 1) + + self.create_recorder(flush) + + def record_lfp(self, obj_list, matrices, **annotations): + from patch import p + from quantities import ms + + v_list = [[p.record(obj) for obj in location_list] for location_list in obj_list] + + def flush(segment): + if "units" not in annotations.keys(): + annotations["units"] = "mV" + + V = sum( + [ + np.array(matrices[cell_id]) @ np.array(v_list[cell_id]) + for cell_id in range(len(obj_list)) + ] + ) + # Need to flatten the array to pass AnalogSignal -> V_flat should be something like: [(mea array at time 0), (mea array at time 1)...(final mea array)] + V_flat = V.flatten(order="F") + segment.analogsignals.append( + AnalogSignal(V_flat, sampling_period=p.dt * ms, **annotations) + ) + # Free the memory of the Vectors + for location_list in v_list: + for obj in location_list: + obj.remove(0, obj.size() - 1) self.create_recorder(flush) @@ -120,11 +152,19 @@ def run(self, *simulations: "Simulation"): self.engine.finitialize(self.initial) duration = max(sim.duration for sim in simulations) progress = AdapterProgress(duration) - for oi, i in progress.steps(step=1): + progress_step = 1 + checkpoint = AdapterCheckpoint(simulations) + minimum_step = checkpoint.suitable_step(progress_step) + for oi, i in progress.steps(step=minimum_step): pc.psolve(i) tick = progress.tick(i) for listener in self._progress_listeners: listener(simulations, tick) + if checkpoint.get_status(i): + [ + self.simdata[sim].result.flush() + for sim in checkpoint.checkpoints[i] + ] progress.complete() report("Finished simulation.", level=2) finally: diff --git a/bsb_neuron/devices/__init__.py b/bsb_neuron/devices/__init__.py index ce2852e..344bcc5 100644 --- a/bsb_neuron/devices/__init__.py +++ b/bsb_neuron/devices/__init__.py @@ -1,5 +1,7 @@ from .current_clamp import CurrentClamp from .ion_recorder import IonRecorder +from .lfp_recorder import LFPRecorder +from .membrane_current_recorder import MembraneCurrentRecorder from .spike_generator import SpikeGenerator from .synapse_recorder import SynapseRecorder from .voltage_clamp import VoltageClamp diff --git a/bsb_neuron/devices/lfp_recorder.py b/bsb_neuron/devices/lfp_recorder.py new file mode 100644 index 0000000..9212503 --- /dev/null +++ b/bsb_neuron/devices/lfp_recorder.py @@ -0,0 +1,146 @@ +import typing + +import MEAutility.core as mu +import numpy as np +from bsb import LocationTargetting, config, types +from lfpykit import CellGeometry, LineSourcePotential, RecMEAElectrode + +from .membrane_current_recorder import MembraneCurrentRecorder + + +@config.node +class MeaElectrode: + electrode_name = config.attr(type=str, required=True) + definitions: dict[typing.Any] = config.dict(type=types.any_()) + rotations = config.dict(type=types.or_(types.list(type=int), float), default=None) + shift = config.list(type=int, default=None) + + def __boot__(self): + if self.electrode_name in mu.return_mea_list(): + self.custom = False + else: + if self.definitions: + self.custom = True + self.definitions["electrode_name"] = self.electrode_name + + else: + raise ValueError( + f"Do not find {self.electrode_name} probe. Available models for MEA arrays: {mu.return_mea_list()}" + ) + + def return_probe(self): + # Check if we are using a custom probe and create MEA object + if self.custom: + # Clean definitions, make sure that scaffold objects are not passed to MEA classes + info_dict = {} + for key, value in self.definitions.items(): + if key not in ["scaffold", "_config_parent"]: + info_dict[key] = value + pos = mu.get_positions(info_dict) + if mu.check_if_rect(info_dict): + mea_obj = mu.RectMEA(positions=pos, info=info_dict) + else: + mea_obj = mu.MEA(positions=pos, info=info_dict) + else: + mea_obj = mu.MEA.return_mea(self.electrode_name) + # If a rotation is selected rotate the array + if self.rotations: + mea_obj.rotate(self.rotations["axis"], self.rotations["angle"]) + if self.shift: + mea_obj.move(self.shift) + return mea_obj + + +@config.node +class LFPRecorder(MembraneCurrentRecorder, classmap_entry="lfp_recorder"): + locations = config.attr(type=LocationTargetting, default={"strategy": "everywhere"}) + mea_electrode = config.attr(type=MeaElectrode, required=True) + checkpoints = config.attr(type=types.or_(float, types.list(type=float)), default=[]) + + def implement(self, adapter, simulation, simdata): + my_probe = self.mea_electrode.return_probe() + for model, pop in self.targetting.get_targets( + adapter, simulation, simdata + ).items(): + + origins = simdata.placement[model].load_positions() + list_of_sections = [[] for x in range(len(pop))] + trs_matrices = [0 for x in range(len(pop))] + global_ids = [0 for x in range(len(pop))] + for local_cell_id, target in enumerate(pop): + # collect all locations from the target cell + locations = self.locations.get_locations(target) + n_locs = len(locations) + x_i = np.zeros([n_locs, 2]) + y_i = np.zeros([n_locs, 2]) + z_i = np.zeros([n_locs, 2]) + d_i = np.zeros([n_locs, 2]) + + for i_loc, location in enumerate(locations): + + # get for each location xyz coords and diam + section = location.section + idx_loc = section.locations.index( + location._loc + ) # index in section.location + idx_next_loc = ( + idx_loc + 1 if location.arc(0) < 1 else idx_loc + ) # there is another point after + x_i[i_loc] = [section.x3d(idx_loc), section.x3d(idx_next_loc)] + y_i[i_loc] = [section.y3d(idx_loc), section.y3d(idx_next_loc)] + z_i[i_loc] = [section.z3d(idx_loc), section.z3d(idx_next_loc)] + d_i[i_loc] = [section.diam3d(idx_loc), section.diam3d(idx_next_loc)] + + # note: recording by default done section(loc.arc(0)) + # create CellGeometry of target by using the selected locations + # matrix M (given the probe geometry/properties) + + origin = origins[local_cell_id] + + cell_i = CellGeometry( + x=y_i + origin[0], y=x_i + origin[1], z=z_i + origin[2], d=d_i + ) + lsp = RecMEAElectrode( + cell_i, + sigma_T=0.3, + sigma_S=1.5, + sigma_G=0.0, + h=400.0, + z_shift=-100.0, + method="linesource", + steps=20, + probe=my_probe, + ) + + M_i = lsp.get_transformation_matrix() + + pos_nan = np.logical_not( + np.isnan(np.sum(M_i, 0)) + ) # check for nan, this happens when points with 0 length + + # Store the transform matrix and the cell id in lists + trs_matrices[local_cell_id] = M_i[:, pos_nan] + global_ids[local_cell_id] = target.id + + for i_loc, location in enumerate(locations): + if pos_nan[i_loc]: + section = location.section + x = location.arc(0) + list_of_sections[local_cell_id].append( + section(x).__record_imem__() + ) + + trs_matrix_size = np.sum([np.shape(mat)[1] for mat in trs_matrices]) + obj_list_size = np.sum([len(obj) for obj in list_of_sections]) + if trs_matrix_size != obj_list_size: + raise ValueError( + f" In LFP recorder {self.name} numbers of computed sections do not match! {trs_matrix_size} != {obj_list_size}" + ) + + simdata.result.record_lfp( + list_of_sections, + trs_matrices, + name=self.name, + cell_type=model.name, + id_list=global_ids, + ) diff --git a/bsb_neuron/devices/membrane_current_recorder.py b/bsb_neuron/devices/membrane_current_recorder.py new file mode 100644 index 0000000..155fc92 --- /dev/null +++ b/bsb_neuron/devices/membrane_current_recorder.py @@ -0,0 +1,35 @@ +from bsb import LocationTargetting, config + +from ..device import NeuronDevice + + +@config.node +class MembraneCurrentRecorder(NeuronDevice, classmap_entry="membrane_current_recorder"): + locations = config.attr(type=LocationTargetting, default={"strategy": "everywhere"}) + + def implement(self, adapter, simulation, simdata): + for model, pop in self.targetting.get_targets( + adapter, simulation, simdata + ).items(): + for target in pop: + for location in self.locations.get_locations(target): + self._add_imem_recorder( + simdata.result, + location, + name=self.name, + cell_type=target.cell_model.name, + cell_id=target.id, + loc=location._loc, + ) + + def _get_imem(self, location): + from patch import p + + section = location.section + x = location.arc(0) + return p.record(section(x).__record_imem__()) + + def _add_imem_recorder(self, results, location, **annotations): + section = location.section + x = location.arc(0) + results.record(section(x).__record_imem__(), **annotations, units="nA") diff --git a/pyproject.toml b/pyproject.toml index cf13240..7686fe4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,9 @@ dynamic = ["version", "description"] dependencies = [ "bsb-core~=5.0", "nrn-patch~=4.0", - "arborize[neuron]~=4.1" + "arborize[neuron]~=4.1", + "MEAutility~=1.5", + "LFPykit~=0.5" ] [tool.flit.module]