Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 47 additions & 8 deletions brian2tools/baseexport/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ def collect_NeuronGroup(group, run_namespace):
# get user defined stateupdation method
if isinstance(group.method_choice, str):
neuron_dict['user_method'] = group.method_choice
# if not specified by user
# TODO collect from run time
else:
neuron_dict['user_method'] = None

Expand Down Expand Up @@ -83,6 +81,11 @@ def collect_NeuronGroup(group, run_namespace):
if isinstance(obj, StateUpdater):
neuron_dict['when'] = obj.when
neuron_dict['order'] = obj.order
# capture the runtime-resolved method (may differ from user_method
# when the user did not specify one explicitly)
if neuron_dict['user_method'] is None and isinstance(
obj.method_choice, str):
neuron_dict['resolved_method'] = obj.method_choice

# resolve group-specific identifiers
identifiers = group.resolve_all(identifiers, run_namespace)
Expand Down Expand Up @@ -129,8 +132,6 @@ def collect_SpatialNeuron(group, run_namespace):
# get user defined stateupdation method
if isinstance(group.method_choice, str):
neuron_dict['user_method'] = group.method_choice
# if not specified by user
# TODO collect from run time
else:
neuron_dict['user_method'] = None

Expand Down Expand Up @@ -173,15 +174,20 @@ def collect_SpatialNeuron(group, run_namespace):
if isinstance(obj, StateUpdater):
neuron_dict['when'] = obj.when
neuron_dict['order'] = obj.order

# capture the runtime-resolved method (may differ from user_method
# when the user did not specify one explicitly)
if neuron_dict['user_method'] is None and isinstance(
obj.method_choice, str):
neuron_dict['resolved_method'] = obj.method_choice

# resolve group-specific identifiers
identifiers = group.resolve_all(identifiers, run_namespace)
# with the identifiers connected to group, prune away unwanted
identifiers = _prepare_identifiers(identifiers)
# check the dictionary is not empty
if identifiers:
neuron_dict['identifiers'] = identifiers

return neuron_dict


Expand Down Expand Up @@ -625,9 +631,11 @@ def collect_Synapses(synapses, run_namespace):
'dt': obj.clock.dt, 'order': obj.order,
'when': obj.when
}
# check delay is defined
# only capture scalar (homogeneous) delays here; per-synapse
# heterogeneous delays are already recorded via the generic
# initializer mechanism in initializers_connectors
if obj.variables['delay'].scalar:
path.update({'delay': obj.delay[:]})
path.update({'delay': obj.delay[:]})
pathways.append(path)
# check any identifiers specific to pathway expression
_, _, unknown = analyse_identifiers(obj.code, obj.variables)
Expand All @@ -648,6 +656,37 @@ def collect_Synapses(synapses, run_namespace):
return synapse_dict


def collect_NetworkOperation(net_op):
"""
Collect details of a `brian2.core.magic.NetworkOperation` and represent
them in dictionary format.

The function body cannot be serialized automatically. A warning is logged
so the user knows the operation is present but its code is not captured.

Parameters
----------
net_op : brian2.core.magic.NetworkOperation

Returns
-------
net_op_dict : dict
"""
from brian2.utils.logger import get_logger
_logger = get_logger(__name__)
_logger.warn(
f"NetworkOperation '{net_op.name}' cannot be fully serialized: "
"the function body is not captured. Only scheduling metadata is stored."
)
return {
'name': net_op.name,
'dt': net_op.clock.dt,
'when': net_op.when,
'order': net_op.order,
'unsupported': 'function body not serializable',
}


def collect_PoissonInput(poinp, run_namespace):
"""
Collect details of `PoissonInput` and represent them in dictionary
Expand Down
7 changes: 5 additions & 2 deletions brian2tools/baseexport/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Synapses,
get_local_namespace,
)
from brian2.core.operations import NetworkOperation
from brian2.core.variables import DynamicArrayVariable
from brian2.devices.device import Device, RuntimeDevice, all_devices
from brian2.groups import NeuronGroup
Expand Down Expand Up @@ -66,7 +67,7 @@ def __init__(self):
self.supported_objs = (NeuronGroup, SpatialNeuron, SpikeGeneratorGroup,
PoissonGroup, StateMonitor, SpikeMonitor,
EventMonitor, PopulationRateMonitor, Synapses,
PoissonInput)
PoissonInput, NetworkOperation)
self.runs = []
self.initializers_connectors = []
self.array_cache = {}
Expand Down Expand Up @@ -148,7 +149,9 @@ def network_run(self, network, duration, namespace=None, level=0, **kwds):
run_inactive = [] # inactive objects for the run

# dictionary to store objects and its collector functions
collector_map={'neurongroup': {'f': collect_NeuronGroup, 'n': True},
collector_map={'networkoperation': {'f': collect_NetworkOperation,
'n': False},
'neurongroup': {'f': collect_NeuronGroup, 'n': True},
'spatialneuron': {'f': collect_SpatialNeuron, 'n': True},
'poissongroup': {'f': collect_PoissonGroup, 'n': True},
'spikegeneratorgroup': {'f': collect_SpikeGenerator,
Expand Down