diff --git a/bsb_neuron/adapter.py b/bsb_neuron/adapter.py index dc9c929..8262b3b 100644 --- a/bsb_neuron/adapter.py +++ b/bsb_neuron/adapter.py @@ -4,7 +4,6 @@ import numpy as np from bsb import ( - MPI, AdapterError, AdapterProgress, Chunk, @@ -61,8 +60,8 @@ def fill_parameter_data(parameters, data): class NeuronAdapter(SimulatorAdapter): initial = -65 - def __init__(self): - super().__init__() + def __init__(self, comm=None): + super().__init__(comm=comm) self.network = None self.next_gid = 0 @@ -72,7 +71,7 @@ def engine(self): return engine - def prepare(self, simulation, comm=None): + def prepare(self, simulation): self.simdata[simulation] = NeuronSimulationData( simulation, result=NeuronResult(simulation) ) @@ -97,14 +96,14 @@ def prepare(self, simulation, comm=None): def load_balance(self, simulation): simdata = self.simdata[simulation] chunk_stats = simulation.scaffold.storage.get_chunk_stats() - size = MPI.get_size() + size = self.comm.get_size() all_chunks = [Chunk.from_id(int(chunk), None) for chunk in chunk_stats.keys()] simdata.node_chunk_alloc = [all_chunks[rank::size] for rank in range(0, size)] simdata.chunk_node_map = {} for node, chunks in enumerate(simdata.node_chunk_alloc): for chunk in chunks: simdata.chunk_node_map[chunk] = node - simdata.chunks = simdata.node_chunk_alloc[MPI.get_rank()] + simdata.chunks = simdata.node_chunk_alloc[self.comm.get_rank()] simdata.placement = { model: model.get_placement_set(chunks=simdata.chunks) for model in simulation.cell_models.values()