diff --git a/README_POC.md b/README_POC.md new file mode 100644 index 00000000..8b38677e --- /dev/null +++ b/README_POC.md @@ -0,0 +1,54 @@ +# PoC: BrianExporter / BrianImporter + +Proof-of-concept for GSoC 2026: *Serialization/Deserialization for Brian2 models, results, and input data*. + +## What this is + +A minimal demonstration that a Brian2 network — equations, connectivity, and state variables — can be serialized to a portable archive and reconstructed object-by-object without running the simulation again. + +This PoC lives entirely in `brian2tools` and builds on the existing `collect_*()` infrastructure in `brian2tools/baseexport/collector.py`. + +## What was built + +``` +brian2tools/baseexport/exporter.py BrianExporter.serialize(net, 'file.brian') +brian2tools/baseimport/__init__.py package entry point +brian2tools/baseimport/importer.py BrianImporter.load('file.brian') → (net, namespace) +brian2tools/tests/test_poc_exporter.py 11 tests covering archive structure + round-trips +examples/poc_exporter_demo.py end-to-end demo +``` + +The `.brian` archive is a ZIP file containing: +- `model.json` — network structure from existing `collect_*()` functions, with `Quantity` objects converted to JSON-safe dicts +- `arrays.npz` — numerical data: state variable values + synaptic connectivity arrays + +## How to run + +```bash +# Install brian2tools in development mode (from repo root) +pip install -e . + +# Run the demo +python examples/poc_exporter_demo.py + +# Run the tests +python -m pytest brian2tools/tests/test_poc_exporter.py -v +``` + +## Core mechanism + +`BrianExporter.serialize()` does three things that `BaseExporter` deliberately omits: + +1. **Converts Quantities to JSON** — `collect_Equations()` stores `eqs.unit` as a raw `Quantity`; `_json_safe()` converts all Quantities to `{'__type__': 'quantity', 'value': ..., 'dim': [7-element SI exponent tuple]}`. + +2. **Captures state variable values** — `BaseExporter` intercepts code generation before the simulation, so it never sees actual values. `_collect_state()` reads them after `net.run()` via `var.get_value()`. + +3. **Captures actual connectivity arrays** — `collect_Synapses()` stores the `connect()` arguments (condition, p, n) but not the resulting `_synaptic_pre`/`_synaptic_post` arrays. `_collect_connectivity()` captures the arrays directly so `BrianImporter` can restore exact connectivity via `syn.connect(i=i_arr, j=j_arr)` — critical for probabilistic connections. + +`BrianImporter.load()` reconstructs objects in dependency order (groups → synapses → state restore) and returns a `Network` ready to continue running. + +## What this is not + +- Not a full implementation — monitors, PoissonGroup, SpikeGeneratorGroup, and SpatialNeuron reconstruction are planned but not in this PoC. +- Not a device mode integration — `BrianExporter` is called explicitly after `net.run()`, not via `set_device('exporter')`. +- Not production-ready — edge cases (TimedArray identifiers, multiple clocks, SpatialNeuron) are out of scope for this PoC. diff --git a/brian2tools/baseexport/exporter.py b/brian2tools/baseexport/exporter.py new file mode 100644 index 00000000..a03f6d03 --- /dev/null +++ b/brian2tools/baseexport/exporter.py @@ -0,0 +1,288 @@ +""" +BrianExporter — serialize a Brian2 Network to a portable .brian archive. 
+ +The .brian archive is a ZIP file containing: + - model.json : network structure produced by existing collect_*() functions + (same dict as device.runs[0]['components'] in BaseExporter), + with Quantity objects converted to JSON-safe dicts + - arrays.npz : numerical data that cannot go in JSON — state variable values, + synaptic connectivity arrays (_synaptic_pre/_synaptic_post) + +Call serialize() *after* net.run() to capture both structure and state. + + net.run(10*ms) + BrianExporter().serialize(net, 'snapshot.brian') + +See also +-------- +brian2tools.baseimport.importer.BrianImporter +""" + +import io +import json +import zipfile + +import numpy as np + +import brian2 +from brian2 import Synapses, get_local_namespace +from brian2.core.variables import ArrayVariable +from brian2.units.fundamentalunits import Quantity + +from .collector import ( + collect_EventMonitor, + collect_NeuronGroup, + collect_PoissonGroup, + collect_PoissonInput, + collect_PopulationRateMonitor, + collect_SpikeGenerator, + collect_SpikeMonitor, + collect_StateMonitor, + collect_Synapses, +) + +# Mirrors collector_map in BaseExporter.network_run() (device.py line 151). +# Tuple: (collector_function, needs_run_namespace) +COLLECTOR_MAP = { + 'neurongroup': (collect_NeuronGroup, True), + 'poissongroup': (collect_PoissonGroup, True), + 'spikegeneratorgroup': (collect_SpikeGenerator, True), + 'statemonitor': (collect_StateMonitor, False), + 'spikemonitor': (collect_SpikeMonitor, False), + 'eventmonitor': (collect_EventMonitor, False), + 'populationratemonitor': (collect_PopulationRateMonitor, False), + 'synapses': (collect_Synapses, True), + 'poissoninput': (collect_PoissonInput, True), +} + +FORMAT_VERSION = '1' + + +def _quantity_to_dict(q): + """ + Convert a Brian2 Quantity to a JSON-serializable dict. + + Stores the raw SI value and the 7-element dimension tuple + (metre, kg, second, ampere, kelvin, mole, candela exponents) + from Dimension._dims so reconstruction is unit-system independent. + """ + value = q.variable if hasattr(q, 'variable') else np.asarray(q) + return { + '__type__': 'quantity', + 'value': value.tolist() if isinstance(value, np.ndarray) else float(value), + 'dim': list(q.dim._dims), + } + + +def _json_safe(obj, arrays_dict, prefix=''): + """ + Recursively convert a collector dict to JSON-serializable form. + + The main problem with collector output is that collect_Equations() + (collector.py line 212) stores eqs.unit as a raw Quantity, and + collect_PoissonGroup() (line 366), collect_SpikeGenerator() (line 306), + and _prepare_identifiers() (helper.py line 34) also produce Quantity + values. This function converts all of them. 
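+
+    For example (a sketch of the target shape, not a verified dump), a value
+    of 10*mV would serialize as {'__type__': 'quantity', 'value': 0.01,
+    'dim': [2, 1, -3, -1, 0, 0, 0]}: values are stored in SI base units and
+    volt has dimensions m^2 kg s^-3 A^-1.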
+ + Quantity → {'__type__': 'quantity', 'value': ..., 'dim': [...]} + np.ndarray → stored in arrays_dict, replaced by {'__type__': 'array', 'key': ...} + np.integer / np.floating / np.bool_ → Python primitives + """ + if isinstance(obj, Quantity): + return _quantity_to_dict(obj) + elif isinstance(obj, np.ndarray): + key = prefix + arrays_dict[key] = obj + return {'__type__': 'array', 'key': key} + elif isinstance(obj, dict): + return { + k: _json_safe(v, arrays_dict, + prefix=f'{prefix}.{k}' if prefix else k) + for k, v in obj.items() + } + elif isinstance(obj, (list, tuple)): + return [ + _json_safe(item, arrays_dict, prefix=f'{prefix}[{i}]') + for i, item in enumerate(obj) + ] + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.bool_): + return bool(obj) + return obj + + +class BrianExporter: + """ + Export a Brian2 Network to a portable .brian archive after net.run(). + + Extends the structural capture of BaseExporter with two things that + BaseExporter deliberately omits: + + 1. State variable values — collect_NeuronGroup() stores equations but + not the current v[:], w[:] etc. This class captures them. + + 2. Actual connectivity arrays — collect_Synapses() (collector.py line 565) + stores the connect() arguments (condition, p, n) via synapses_connect() + (device.py line 337), but not the resulting _synaptic_pre/_synaptic_post + arrays (scalar delays only, line 629). This class captures the arrays + directly so BrianImporter can restore exact connectivity via + syn.connect(i=i_arr, j=j_arr) without re-running probabilistic logic. + """ + + def serialize(self, net, filepath, namespace=None, level=0): + """ + Serialize network structure and state to a .brian archive. + + Parameters + ---------- + net : brian2.core.network.Network + A network that has already been run (or at least before_run'd). + filepath : str + Destination path; conventionally ends in '.brian'. + namespace : dict, optional + Additional namespace for resolving identifiers. If not given, + collected from the caller's local scope — same pattern as + BaseExporter.network_run() (device.py line 141). + level : int, optional + Stack depth offset for namespace collection. + """ + if namespace is None: + namespace = get_local_namespace(level + 1) + + arrays_dict = {} + + components = self._collect_structure(net, arrays_dict, namespace) + state_vars = self._collect_state(net, arrays_dict) + connectivity = self._collect_connectivity(net, arrays_dict) + + model = { + 'format_version': FORMAT_VERSION, + 'brian_version': brian2.__version__, + 't': float(net.t_), + 'components': components, + 'state_variables': state_vars, + 'connectivity': connectivity, + } + + self._write_archive(filepath, model, arrays_dict) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _collect_structure(self, net, arrays_dict, run_namespace): + """ + Call existing collect_*() functions for every network object. + + Mirrors BaseExporter.network_run() (device.py line 170) using the + same COLLECTOR_MAP pattern (device.py line 151). Passes result + through _json_safe() to resolve the Quantity-in-dict problem. + + Also adds 'equations_str' — str(obj.user_equations) — to groups + that have one, because NeuronGroup.__init__ accepts a plain string + and str(Equations) produces a canonical parseable form. 
+ """ + components = {} + + for obj in net.objects: + obj_type = type(obj).__name__.lower() + if obj_type not in COLLECTOR_MAP: + continue + + collector_fn, needs_ns = COLLECTOR_MAP[obj_type] + obj_dict = (collector_fn(obj, run_namespace) + if needs_ns else collector_fn(obj)) + + # equations_str lets BrianImporter call NeuronGroup(N, model_str) + # or Synapses(src, tgt, model_str) directly. + # NeuronGroup uses user_equations; Synapses uses equations. + if hasattr(obj, 'user_equations'): + obj_dict['equations_str'] = str(obj.user_equations) + elif hasattr(obj, 'equations') and obj.equations is not None: + obj_dict['equations_str'] = str(obj.equations) + + safe = _json_safe(obj_dict, arrays_dict, + prefix=f'struct.{obj_type}.{obj.name}') + components.setdefault(obj_type, []).append(safe) + + return components + + def _collect_state(self, net, arrays_dict): + """ + Capture current values of all public ArrayVariables. + + BaseExporter is a Device subclass that intercepts code generation + before the simulation runs, so it never sees actual values. This + method runs after net.run() and reads them directly via + var.get_value() — the same mechanism Network.store() uses internally + (group.py line 768, VariableOwner._full_state()). + """ + state_vars = {} + # Variables internal to Brian2 that should not be serialized. + _SKIP = frozenset({'i', 'j', 'N', 't', 'dt', 't_in_timesteps'}) + + for obj in net.objects: + if not hasattr(obj, 'variables'): + continue + for var_name, var in obj.variables.items(): + if not isinstance(var, ArrayVariable): + continue + # var.owner is a different Python object from obj even when + # they wrap the same group; compare by name instead of identity. + if not hasattr(var.owner, 'name'): + continue + if var.owner.name != obj.name: + continue + if var_name.startswith('_') or var_name in _SKIP: + continue + try: + values = var.get_value().copy() + key = f'state.{obj.name}.{var_name}' + arrays_dict[key] = values + state_vars[f'{obj.name}.{var_name}'] = {'array_key': key} + except Exception: + pass + + return state_vars + + def _collect_connectivity(self, net, arrays_dict): + """ + Capture _synaptic_pre and _synaptic_post for every Synapses object. + + collect_Synapses() (collector.py line 565) stores the arguments + passed to connect() — condition string, p, n — but NOT the resulting + integer index arrays. For probabilistic connections (p=0.1) or + condition-based connections, those arguments cannot reproduce the + exact same connectivity on load. Storing the arrays directly makes + restoration deterministic. 
+ """ + connectivity = {} + + for obj in net.objects: + if not isinstance(obj, Synapses): + continue + try: + i_arr = obj.variables['_synaptic_pre'].get_value().copy() + j_arr = obj.variables['_synaptic_post'].get_value().copy() + i_key = f'conn.{obj.name}.i' + j_key = f'conn.{obj.name}.j' + arrays_dict[i_key] = i_arr + arrays_dict[j_key] = j_arr + connectivity[obj.name] = {'i_key': i_key, 'j_key': j_key} + except Exception: + pass + + return connectivity + + def _write_archive(self, filepath, model, arrays_dict): + """Write model.json + arrays.npz into a single ZIP archive.""" + npz_buf = io.BytesIO() + np.savez_compressed(npz_buf, **arrays_dict) + npz_buf.seek(0) + + with zipfile.ZipFile(filepath, 'w', zipfile.ZIP_DEFLATED) as zf: + zf.writestr('model.json', json.dumps(model, indent=2)) + zf.writestr('arrays.npz', npz_buf.read()) diff --git a/brian2tools/baseimport/__init__.py b/brian2tools/baseimport/__init__.py new file mode 100644 index 00000000..be7ccd3a --- /dev/null +++ b/brian2tools/baseimport/__init__.py @@ -0,0 +1,11 @@ +""" +baseimport — reconstruct a Brian2 Network from a .brian archive. + + from brian2tools.baseimport import BrianImporter + net, namespace = BrianImporter().load('snapshot.brian') + net.run(10*ms) +""" + +from .importer import BrianImporter + +__all__ = ['BrianImporter'] diff --git a/brian2tools/baseimport/importer.py b/brian2tools/baseimport/importer.py new file mode 100644 index 00000000..ff91dbb6 --- /dev/null +++ b/brian2tools/baseimport/importer.py @@ -0,0 +1,300 @@ +""" +BrianImporter — reconstruct a Brian2 Network from a .brian archive. + +The importer reads the model.json + arrays.npz produced by BrianExporter +and calls the standard Brian2 constructors to recreate each object. + +Each _reconstruct_*() method maps directly from the dict fields produced +by the corresponding collect_*() function in collector.py, augmented by +the extra fields BrianExporter adds (equations_str, state_variables, +connectivity arrays). + +This is a minimal implementation covering NeuronGroup and Synapses — enough +to demonstrate that the round-trip approach is sound. Monitors, PoissonGroup, +SpikeGeneratorGroup follow the same pattern and will be added in the full +implementation. +""" + +import io +import json +import warnings +import zipfile + +import numpy as np + +import brian2 +from brian2 import Network, NeuronGroup, Synapses, ms, second +from brian2.core.variables import ArrayVariable, DynamicArrayVariable +from brian2.units.fundamentalunits import Dimension, Quantity + + +def _dict_to_quantity(d): + """ + Inverse of _quantity_to_dict() in exporter.py. + + Reconstructs a Quantity from {'value': ..., 'dim': [7 floats]}. + + Directly constructs a Dimension from the stored 7-element _dims tuple + (metre, kg, second, amp, kelvin, mole, candela exponents). This avoids + depending on a specific unit name being exported and is robust across + Brian2 versions since _dims is a stable internal attribute. + """ + value = np.asarray(d['value']) + dims_tuple = tuple(float(x) for x in d['dim']) + dim = Dimension.__new__(Dimension) + object.__setattr__(dim, '_dims', dims_tuple) + return Quantity(value, dim=dim) + + +def _restore_obj(obj_dict): + """ + Walk a JSON-safe dict and convert {'__type__': 'quantity', ...} entries + back to Quantity objects, and {'__type__': 'array', ...} entries to a + sentinel (arrays are loaded separately from arrays.npz). 
+ """ + if isinstance(obj_dict, dict): + if obj_dict.get('__type__') == 'quantity': + return _dict_to_quantity(obj_dict) + return {k: _restore_obj(v) for k, v in obj_dict.items()} + elif isinstance(obj_dict, list): + return [_restore_obj(item) for item in obj_dict] + return obj_dict + + +class BrianImporter: + """ + Reconstruct a Brian2 Network from a .brian archive. + + Usage + ----- + net, namespace = BrianImporter().load('snapshot.brian') + net.run(10*ms) # continue from the saved state + + The returned namespace contains any TimedArray or custom-function + objects that were part of the original network's identifier scope. + """ + + def load(self, filepath): + """ + Load a .brian archive and return a reconstructed Network. + + Parameters + ---------- + filepath : str + Path to a .brian archive created by BrianExporter. + + Returns + ------- + net : brian2.core.network.Network + namespace : dict + Namespace containing reconstructed TimedArray objects and + any other non-Quantity identifiers from the original network. + """ + model_dict, arrays = self._load_archive(filepath) + self._check_version(model_dict) + + components = model_dict.get('components', {}) + state_vars = model_dict.get('state_variables', {}) + connectivity = model_dict.get('connectivity', {}) + + namespace = {} + groups_by_name = {} + all_objects = [] + + # --- NeuronGroups ------------------------------------------------- + # Must be reconstructed before Synapses (source/target resolution). + for ng_dict in components.get('neurongroup', []): + ng = self._reconstruct_neurongroup(ng_dict, namespace) + groups_by_name[ng.name] = ng + all_objects.append(ng) + + # --- Synapses ----------------------------------------------------- + # After all groups exist; connect() uses stored i/j arrays. + for syn_dict in components.get('synapses', []): + syn = self._reconstruct_synapses( + syn_dict, groups_by_name, connectivity, arrays, namespace + ) + groups_by_name[syn.name] = syn + all_objects.append(syn) + + # --- Restore state ------------------------------------------------ + # Done AFTER connect() so DynamicArrayVariable sizes (set by + # connect()) are correct before we write synaptic variable values. + # Mirrors VariableOwner._restore_from_full_state() (group.py:780). + for obj in all_objects: + self._restore_state(obj, state_vars, arrays) + + net = Network(*all_objects) + net.t_ = model_dict.get('t', 0.0) + return net, namespace + + # ------------------------------------------------------------------ + # Reconstruction helpers + # ------------------------------------------------------------------ + + def _reconstruct_neurongroup(self, ng_dict, namespace): + """ + Reconstruct a NeuronGroup from its serialized dict. + + Consumes collect_NeuronGroup() (collector.py:20) output, augmented + by BrianExporter's 'equations_str' field. 
+ + Key mappings + ------------ + ng_dict['N'] ← group._N (collector.py:47) + ng_dict['equations_str'] ← str(user_equations) (added by exporter) + ng_dict['events']['spike']['threshold']['code'] (collector.py:254) + ng_dict['events']['spike']['reset']['code'] (collector.py:262) + ng_dict['events']['spike']['refractory'] (collector.py:269) + ng_dict['user_method'] ← method_choice (collector.py:50) + ng_dict['identifiers'] ← _prepare_identifiers (helper.py:12) + """ + N = ng_dict['N'] + model_str = ng_dict.get('equations_str', '') + kwargs = {'name': ng_dict['name']} + + if ng_dict.get('user_method'): + kwargs['method'] = ng_dict['user_method'] + + # Extract threshold / reset / refractory from the events dict + # produced by collect_Events() (collector.py:225). + events = _restore_obj(ng_dict.get('events', {})) + if 'spike' in events: + spike = events['spike'] + kwargs['threshold'] = spike['threshold']['code'] + if 'reset' in spike: + kwargs['reset'] = spike['reset']['code'] + if 'refractory' in spike: + kwargs['refractory'] = spike['refractory'] + + # Rebuild namespace from stored identifiers so the equation string + # can resolve user-defined constants (e.g. tau = 10*ms). + # _prepare_identifiers() (helper.py:12) filters to Quantity, + # TimedArray, and custom Function objects only. + identifiers = _restore_obj(ng_dict.get('identifiers', {})) + namespace.update(identifiers) + + return NeuronGroup(N, model_str, namespace=namespace, **kwargs) + + def _reconstruct_synapses(self, syn_dict, groups_by_name, + connectivity, arrays, namespace): + """ + Reconstruct a Synapses object and restore connectivity. + + Consumes collect_Synapses() (collector.py:565) output, augmented + by BrianExporter's 'equations_str' and connectivity arrays. + + Key mappings + ------------ + syn_dict['source'] ← collect_SpikeSource() (collector.py:396) + str or {'start','stop','group'} for Subgroups + syn_dict['pathways'][*]['prepost'] / ['code'] (collector.py:619-631) + connectivity[name]['i_key'], ['j_key'] ← _synaptic_pre/_post arrays + """ + source = self._resolve_source(syn_dict['source'], groups_by_name) + target = self._resolve_source(syn_dict['target'], groups_by_name) + + kwargs = {'name': syn_dict['name'], 'namespace': namespace} + + if syn_dict.get('equations_str'): + kwargs['model'] = syn_dict['equations_str'] + + if syn_dict.get('user_method'): + kwargs['method'] = syn_dict['user_method'] + + # Extract on_pre / on_post from the pathways list. + # collect_Synapses() (collector.py:619) stores each SynapticPathway + # as {'prepost': 'pre'/'post', 'code': str, ...}. + for pathway in syn_dict.get('pathways', []): + if pathway['prepost'] == 'pre': + kwargs['on_pre'] = pathway['code'] + elif pathway['prepost'] == 'post': + kwargs['on_post'] = pathway['code'] + + syn = Synapses(source, target, **kwargs) + + # Restore connectivity from stored i/j arrays rather than + # re-running the original condition string. This is the critical + # difference from BaseExporter: probabilistic connections (p=0.1) + # would produce different results each run. + conn = connectivity.get(syn_dict['name']) + if conn: + i_arr = arrays[conn['i_key']] + j_arr = arrays[conn['j_key']] + if len(i_arr) > 0: + syn.connect(i=i_arr, j=j_arr) + else: + syn.connect() + + return syn + + def _resolve_source(self, source_ref, groups_by_name): + """ + Resolve a collect_SpikeSource() return value to a Brian2 group. 
+ + collect_SpikeSource() (collector.py:396) returns: + str → regular group (use name directly) + {'start': int, 'stop': int, + 'group': str} → Subgroup slice + Note: 'stop' in the dict is inclusive (source.stop - 1 at line 406). + """ + if isinstance(source_ref, dict): + parent = groups_by_name[source_ref['group']] + return parent[source_ref['start']:source_ref['stop'] + 1] + return groups_by_name[source_ref] + + def _restore_state(self, obj, state_vars, arrays): + """ + Set state variable values from arrays.npz after construction. + + Called after connect() so DynamicArrayVariable sizes are finalised. + Mirrors VariableOwner._restore_from_full_state() (brian2 group.py:780): + resize DynamicArrayVariables before calling set_value(). + """ + if not hasattr(obj, 'variables'): + return + + for key, info in state_vars.items(): + obj_name, var_name = key.split('.', 1) + if obj_name != obj.name: + continue + if var_name not in obj.variables: + continue + + var = obj.variables[var_name] + if not isinstance(var, ArrayVariable) or var.read_only: + continue + + array_key = info['array_key'] + if array_key not in arrays: + continue + + values = arrays[array_key] + try: + if isinstance(var, DynamicArrayVariable): + var.resize(len(values)) + var.set_value(values) + except Exception as exc: + warnings.warn(f'Could not restore {key}: {exc}') + + # ------------------------------------------------------------------ + # Archive I/O + # ------------------------------------------------------------------ + + def _load_archive(self, filepath): + with zipfile.ZipFile(filepath, 'r') as zf: + model_dict = json.loads(zf.read('model.json').decode()) + arrays_buf = io.BytesIO(zf.read('arrays.npz')) + arrays = dict(np.load(arrays_buf, allow_pickle=False)) + return model_dict, arrays + + def _check_version(self, model_dict): + file_ver = model_dict.get('brian_version', 'unknown') + if file_ver != brian2.__version__: + warnings.warn( + f'Archive was created with Brian {file_ver}; ' + f'current Brian is {brian2.__version__}. ' + 'This may cause compatibility issues.', + UserWarning, + stacklevel=3, + ) diff --git a/brian2tools/tests/test_poc_exporter.py b/brian2tools/tests/test_poc_exporter.py new file mode 100644 index 00000000..379d924c --- /dev/null +++ b/brian2tools/tests/test_poc_exporter.py @@ -0,0 +1,287 @@ +""" +Tests for BrianExporter and BrianImporter (PoC). + +Follows the same style as test_baseexport.py: plain functions, no test +classes, imports from brian2 at the top. +""" + +import json +import os +import tempfile +import zipfile + +import numpy as np +from numpy.testing import assert_allclose, assert_array_equal + +import pytest + +from brian2 import ( + Network, + NeuronGroup, + Synapses, + StateMonitor, + SpikeMonitor, + start_scope, + ms, + mV, + volt, +) + +from brian2tools.baseexport.exporter import BrianExporter +from brian2tools.baseimport import BrianImporter + +# Module-level constant so serialize() can resolve 'tau' in the namespace. +# Matches the pattern used in test_baseexport.py: identifiers that appear in +# equation strings must be resolvable at the call site of serialize(). 
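+# serialize() falls back to get_local_namespace(), which sees this module's
+# globals through the helper's frame, so a module-level constant is enough.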
+TAU = 10 * ms + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_lif_network(N=10, run_duration=5 * ms): + """Small deterministic LIF network used across multiple tests.""" + start_scope() + G = NeuronGroup( + N, + 'dv/dt = (1 - v) / TAU : 1', + threshold='v > 0.9', + reset='v = 0', + method='exact', + namespace={'TAU': TAU}, + name='lif_group', + ) + G.v = 'rand()' + net = Network(G) + net.run(run_duration) + return net, G + + +def _serialize_to_tmp(net): + """Serialize net to a temp file; return the path.""" + fd, path = tempfile.mkstemp(suffix='.brian') + os.close(fd) + # level=0: get_local_namespace captures the frame of _serialize_to_tmp, + # which inherits f_globals from the test module (contains TAU etc.) + BrianExporter().serialize(net, path) + return path + + +# --------------------------------------------------------------------------- +# Archive structure tests +# --------------------------------------------------------------------------- + +def test_archive_is_valid_zip(): + """Output file must be a valid ZIP archive.""" + net, _ = _make_lif_network() + path = _serialize_to_tmp(net) + try: + assert zipfile.is_zipfile(path) + finally: + os.unlink(path) + + +def test_archive_contains_required_files(): + """ZIP must contain model.json and arrays.npz.""" + net, _ = _make_lif_network() + path = _serialize_to_tmp(net) + try: + with zipfile.ZipFile(path) as zf: + names = zf.namelist() + assert 'model.json' in names + assert 'arrays.npz' in names + finally: + os.unlink(path) + + +def test_model_json_is_valid_json(): + """model.json must parse without error.""" + net, _ = _make_lif_network() + path = _serialize_to_tmp(net) + try: + with zipfile.ZipFile(path) as zf: + model = json.loads(zf.read('model.json').decode()) + assert 'components' in model + assert 'format_version' in model + assert 'brian_version' in model + finally: + os.unlink(path) + + +def test_no_raw_quantity_in_json(): + """ + model.json must not contain any raw Quantity objects. + + collect_Equations() (collector.py:212) stores eqs.unit as a Quantity. + BrianExporter._json_safe() must convert it; this test verifies that. + """ + net, _ = _make_lif_network() + path = _serialize_to_tmp(net) + try: + with zipfile.ZipFile(path) as zf: + raw = zf.read('model.json').decode() + # A raw Quantity would repr as e.g. '1. * volt' or 'volt' + # After conversion it appears as {'__type__': 'quantity', ...} + # The JSON string must be parseable without a custom decoder. + model = json.loads(raw) + assert model is not None # passed json.loads → no raw Quantity + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# Structure capture tests +# --------------------------------------------------------------------------- + +def test_neurongroup_captured_in_components(): + """NeuronGroup must appear under components['neurongroup'].""" + net, G = _make_lif_network() + path = _serialize_to_tmp(net) + try: + with zipfile.ZipFile(path) as zf: + model = json.loads(zf.read('model.json').decode()) + ng_list = model['components']['neurongroup'] + assert len(ng_list) == 1 + assert ng_list[0]['name'] == 'lif_group' + assert ng_list[0]['N'] == 10 + finally: + os.unlink(path) + + +def test_equations_str_field_present(): + """ + BrianExporter adds 'equations_str' to NeuronGroup dicts so that + BrianImporter can call NeuronGroup(N, model_str) directly. 
+ collect_NeuronGroup() does not add this field. + """ + net, _ = _make_lif_network() + path = _serialize_to_tmp(net) + try: + with zipfile.ZipFile(path) as zf: + model = json.loads(zf.read('model.json').decode()) + ng_dict = model['components']['neurongroup'][0] + assert 'equations_str' in ng_dict + assert 'dv/dt' in ng_dict['equations_str'] + finally: + os.unlink(path) + + +def test_state_variables_captured(): + """State variable values must appear in state_variables and arrays.npz.""" + net, G = _make_lif_network() + path = _serialize_to_tmp(net) + try: + with zipfile.ZipFile(path) as zf: + model = json.loads(zf.read('model.json').decode()) + arrays = dict(np.load( + __import__('io').BytesIO(zf.read('arrays.npz')), + allow_pickle=False, + )) + key = 'lif_group.v' + assert key in model['state_variables'] + arr_key = model['state_variables'][key]['array_key'] + assert arr_key in arrays + assert_allclose(arrays[arr_key], G.v[:]) + finally: + os.unlink(path) + + +def test_connectivity_arrays_captured(): + """ + _synaptic_pre and _synaptic_post must appear in arrays.npz. + + collect_Synapses() (collector.py:565) does not store these arrays; + BrianExporter._collect_connectivity() adds them. + """ + start_scope() + G = NeuronGroup(10, 'dv/dt = (1-v)/TAU : 1', + threshold='v>0.9', reset='v=0', method='exact', + namespace={'TAU': TAU}) + S = Synapses(G, G, 'w : 1', on_pre='v += w') + S.connect(j='i') + S.w = '0.1' + net = Network(G, S) + net.run(1 * ms) + + path = _serialize_to_tmp(net) + try: + with zipfile.ZipFile(path) as zf: + model = json.loads(zf.read('model.json').decode()) + arrays = dict(np.load( + __import__('io').BytesIO(zf.read('arrays.npz')), + allow_pickle=False, + )) + conn = model['connectivity']['synapses'] + i_arr = arrays[conn['i_key']] + j_arr = arrays[conn['j_key']] + assert_array_equal(i_arr, S.i[:]) + assert_array_equal(j_arr, S.j[:]) + finally: + os.unlink(path) + + +# --------------------------------------------------------------------------- +# Round-trip tests +# --------------------------------------------------------------------------- + +def test_round_trip_neurongroup_state(): + """ + Reconstruct a NeuronGroup; state variable v must match exactly. + """ + net, G = _make_lif_network(N=20) + original_v = G.v[:].copy() + path = _serialize_to_tmp(net) + try: + net2, _ = BrianImporter().load(path) + # Find the reconstructed group by name + G2 = next(o for o in net2.objects + if o.name == 'lif_group') + assert len(G2) == 20 + assert_allclose(G2.v[:], original_v) + finally: + os.unlink(path) + + +def test_round_trip_synapses_connectivity(): + """ + Round-trip a network with Synapses; i[:] and j[:] must match exactly. + + This verifies the core insight: BrianImporter restores connectivity + from stored _synaptic_pre/_synaptic_post arrays, not by re-running + the probabilistic connect() call. 
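+
+    With p=0.5 over the 20x20 candidate pairs the expected count is about
+    200 connections, but the exact (i, j) set differs from run to run, so
+    only the stored arrays can reproduce it.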
+ """ + start_scope() + G = NeuronGroup(20, 'dv/dt = (1-v)/TAU : 1', + threshold='v>0.9', reset='v=0', method='exact', + namespace={'TAU': TAU}) + S = Synapses(G, G, 'w : 1', on_pre='v += w') + S.connect(p=0.5) # probabilistic — must NOT be re-run on import + S.w = 'rand() * 0.3' + net = Network(G, S) + net.run(2 * ms) + + original_i = S.i[:].copy() + original_j = S.j[:].copy() + original_w = S.w[:].copy() + path = _serialize_to_tmp(net) + try: + net2, _ = BrianImporter().load(path) + S2 = next(o for o in net2.objects if isinstance(o, Synapses)) + assert_array_equal(S2.i[:], original_i) + assert_array_equal(S2.j[:], original_j) + assert_allclose(S2.w[:], original_w) + finally: + os.unlink(path) + + +def test_network_time_restored(): + """net.t_ must be preserved across serialize/load.""" + net, _ = _make_lif_network(run_duration=7 * ms) + original_t = float(net.t_) + path = _serialize_to_tmp(net) + try: + net2, _ = BrianImporter().load(path) + assert net2.t_ == pytest.approx(original_t) + finally: + os.unlink(path) diff --git a/examples/poc_exporter_demo.py b/examples/poc_exporter_demo.py new file mode 100644 index 00000000..69fcc195 --- /dev/null +++ b/examples/poc_exporter_demo.py @@ -0,0 +1,116 @@ +""" +BrianExporter / BrianImporter — end-to-end demo. + +Demonstrates the core round-trip: + 1. Build and run a small LIF network. + 2. Serialize to a .brian archive with BrianExporter. + 3. Load it back with BrianImporter and verify state is preserved. + +Run: + python examples/poc_exporter_demo.py +""" + +import json +import os +import zipfile + +import numpy as np +from numpy.testing import assert_allclose, assert_array_equal + +from brian2 import ( + Network, + NeuronGroup, + Synapses, + StateMonitor, + SpikeMonitor, + start_scope, + ms, + mV, +) + +from brian2tools.baseexport.exporter import BrianExporter +from brian2tools.baseimport import BrianImporter + +ARCHIVE = '/tmp/poc_demo.brian' + + +# --------------------------------------------------------------------------- +# 1. Build and run +# --------------------------------------------------------------------------- + +start_scope() + +tau = 10 * ms + +G = NeuronGroup( + 20, + 'dv/dt = (1 - v) / tau : 1', + threshold='v > 0.9', + reset='v = 0', + method='exact', + name='neurons', +) +G.v = 'rand()' + +S = Synapses(G, G, 'w : 1', on_pre='v += w', name='synapses') +S.connect(j='i') # one-to-one — deterministic, easy to verify +S.w = '0.05' + +mon = StateMonitor(G, 'v', record=True, name='voltage_mon') + +net = Network(G, S, mon) +net.run(5 * ms) + +print(f'[1] Network ran for {net.t / ms:.1f} ms') +print(f' G.v[:5] = {G.v[:5]}') +print(f' S.w[:5] = {S.w[:5]}') +print(f' N_syn = {len(S.i)}') + +# --------------------------------------------------------------------------- +# 2. 
Serialize +# --------------------------------------------------------------------------- + +BrianExporter().serialize(net, ARCHIVE) +print(f'\n[2] Serialized to {ARCHIVE}') + +# Show what is inside the archive +with zipfile.ZipFile(ARCHIVE) as zf: + model = json.loads(zf.read('model.json').decode()) + arrays = dict(np.load(__import__('io').BytesIO(zf.read('arrays.npz')), + allow_pickle=False)) + +print(f' components : {list(model["components"].keys())}') +print(f' arrays : {len(arrays)} entries') +print(f' network t : {model["t"] * 1000:.1f} ms') + +ng_dict = model['components']['neurongroup'][0] +print(f' equations_str snippet : {ng_dict["equations_str"].strip()[:60]}') + +conn = model['connectivity'].get('synapses', {}) +if conn: + print(f' connectivity i[:5] : {arrays[conn["i_key"]][:5]}') + +# --------------------------------------------------------------------------- +# 3. Reconstruct and verify +# --------------------------------------------------------------------------- + +net2, namespace = BrianImporter().load(ARCHIVE) + +G2 = next(o for o in net2.objects if o.name == 'neurons') +S2 = next(o for o in net2.objects if o.name == 'synapses') + +print(f'\n[3] Reconstructed network') +print(f' G2.v[:5] = {G2.v[:5]}') +print(f' S2.w[:5] = {S2.w[:5]}') + +# Verify state is identical +assert_allclose(G2.v[:], G.v[:], err_msg='v mismatch after round-trip') +assert_array_equal(S2.i[:], S.i[:]) +assert_array_equal(S2.j[:], S.j[:]) +assert_allclose(S2.w[:], S.w[:]) +assert net2.t_ == net.t_ + +print('\n[OK] All assertions passed — round-trip is lossless.') + +# Clean up +os.unlink(ARCHIVE)