diff --git a/brian2tools/baseexport/collector.py b/brian2tools/baseexport/collector.py index 9f497600..9369f174 100644 --- a/brian2tools/baseexport/collector.py +++ b/brian2tools/baseexport/collector.py @@ -49,8 +49,6 @@ def collect_NeuronGroup(group, run_namespace): # get user defined stateupdation method if isinstance(group.method_choice, str): neuron_dict['user_method'] = group.method_choice - # if not specified by user - # TODO collect from run time else: neuron_dict['user_method'] = None @@ -83,6 +81,11 @@ def collect_NeuronGroup(group, run_namespace): if isinstance(obj, StateUpdater): neuron_dict['when'] = obj.when neuron_dict['order'] = obj.order + # capture the runtime-resolved method (may differ from user_method + # when the user did not specify one explicitly) + if neuron_dict['user_method'] is None and isinstance( + obj.method_choice, str): + neuron_dict['resolved_method'] = obj.method_choice # resolve group-specific identifiers identifiers = group.resolve_all(identifiers, run_namespace) @@ -129,8 +132,6 @@ def collect_SpatialNeuron(group, run_namespace): # get user defined stateupdation method if isinstance(group.method_choice, str): neuron_dict['user_method'] = group.method_choice - # if not specified by user - # TODO collect from run time else: neuron_dict['user_method'] = None @@ -173,7 +174,12 @@ def collect_SpatialNeuron(group, run_namespace): if isinstance(obj, StateUpdater): neuron_dict['when'] = obj.when neuron_dict['order'] = obj.order - + # capture the runtime-resolved method (may differ from user_method + # when the user did not specify one explicitly) + if neuron_dict['user_method'] is None and isinstance( + obj.method_choice, str): + neuron_dict['resolved_method'] = obj.method_choice + # resolve group-specific identifiers identifiers = group.resolve_all(identifiers, run_namespace) # with the identifiers connected to group, prune away unwanted @@ -181,7 +187,7 @@ def collect_SpatialNeuron(group, run_namespace): # check the dictionary is not empty if identifiers: neuron_dict['identifiers'] = identifiers - + return neuron_dict @@ -625,9 +631,11 @@ def collect_Synapses(synapses, run_namespace): 'dt': obj.clock.dt, 'order': obj.order, 'when': obj.when } - # check delay is defined + # only capture scalar (homogeneous) delays here; per-synapse + # heterogeneous delays are already recorded via the generic + # initializer mechanism in initializers_connectors if obj.variables['delay'].scalar: - path.update({'delay': obj.delay[:]}) + path.update({'delay': obj.delay[:]}) pathways.append(path) # check any identifiers specific to pathway expression _, _, unknown = analyse_identifiers(obj.code, obj.variables) @@ -648,6 +656,37 @@ def collect_Synapses(synapses, run_namespace): return synapse_dict +def collect_NetworkOperation(net_op): + """ + Collect details of a `brian2.core.magic.NetworkOperation` and represent + them in dictionary format. + + The function body cannot be serialized automatically. A warning is logged + so the user knows the operation is present but its code is not captured. + + Parameters + ---------- + net_op : brian2.core.magic.NetworkOperation + + Returns + ------- + net_op_dict : dict + """ + from brian2.utils.logger import get_logger + _logger = get_logger(__name__) + _logger.warn( + f"NetworkOperation '{net_op.name}' cannot be fully serialized: " + "the function body is not captured. Only scheduling metadata is stored." + ) + return { + 'name': net_op.name, + 'dt': net_op.clock.dt, + 'when': net_op.when, + 'order': net_op.order, + 'unsupported': 'function body not serializable', + } + + def collect_PoissonInput(poinp, run_namespace): """ Collect details of `PoissonInput` and represent them in dictionary diff --git a/brian2tools/baseexport/device.py b/brian2tools/baseexport/device.py index 8d716611..eba86270 100644 --- a/brian2tools/baseexport/device.py +++ b/brian2tools/baseexport/device.py @@ -9,6 +9,7 @@ Synapses, get_local_namespace, ) +from brian2.core.operations import NetworkOperation from brian2.core.variables import DynamicArrayVariable from brian2.devices.device import Device, RuntimeDevice, all_devices from brian2.groups import NeuronGroup @@ -66,7 +67,7 @@ def __init__(self): self.supported_objs = (NeuronGroup, SpatialNeuron, SpikeGeneratorGroup, PoissonGroup, StateMonitor, SpikeMonitor, EventMonitor, PopulationRateMonitor, Synapses, - PoissonInput) + PoissonInput, NetworkOperation) self.runs = [] self.initializers_connectors = [] self.array_cache = {} @@ -148,7 +149,9 @@ def network_run(self, network, duration, namespace=None, level=0, **kwds): run_inactive = [] # inactive objects for the run # dictionary to store objects and its collector functions - collector_map={'neurongroup': {'f': collect_NeuronGroup, 'n': True}, + collector_map={'networkoperation': {'f': collect_NetworkOperation, + 'n': False}, + 'neurongroup': {'f': collect_NeuronGroup, 'n': True}, 'spatialneuron': {'f': collect_SpatialNeuron, 'n': True}, 'poissongroup': {'f': collect_PoissonGroup, 'n': True}, 'spikegeneratorgroup': {'f': collect_SpikeGenerator,