From 21f86c6bab271a226ce38d6529a9e0b53f71aa5b Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Sun, 8 Mar 2026 00:27:12 -0500 Subject: [PATCH] Expand MATLAB reference cross-checks --- nstat/matlab_reference.py | 4 +++- tests/test_matlab_reference.py | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/nstat/matlab_reference.py b/nstat/matlab_reference.py index 99fc8792..3d38cfe1 100644 --- a/nstat/matlab_reference.py +++ b/nstat/matlab_reference.py @@ -71,7 +71,7 @@ def run_point_process_reference(*, matlab_repo: str | Path | None = None, seed: for i=1:sC.numSpikeTrains ppSpikeCounts(i) = length(sC.getNST(i).spikeTimes); end - ppLambdaHead = lambda.data(1:5,1)'; + ppLambdaHead = lambda.data(1:10,1)'; """, nargout=0, ) @@ -106,6 +106,7 @@ def run_simulated_network_reference(*, matlab_repo: str | Path | None = None, se [tout,~,yout] = sim('SimulatedNetwork2',[stim.minTime stim.maxTime],options,stim.dataToStructure); netSpikeCounts = [sum(yout(:,1)>.5), sum(yout(:,2)>.5)]; netProbHead = yout(1:5,3:4); + netStateHead = yout(1:5,1:2); netActual = [0 1; -4 0]; """, nargout=0, @@ -113,6 +114,7 @@ def run_simulated_network_reference(*, matlab_repo: str | Path | None = None, se return { "spike_counts": _to_numpy(engine.workspace["netSpikeCounts"]).reshape(-1), "prob_head": _to_numpy(engine.workspace["netProbHead"]), + "state_head": _to_numpy(engine.workspace["netStateHead"]), "actual_network": _to_numpy(engine.workspace["netActual"]), } diff --git a/tests/test_matlab_reference.py b/tests/test_matlab_reference.py index 697df4d8..83d8cc4f 100644 --- a/tests/test_matlab_reference.py +++ b/tests/test_matlab_reference.py @@ -30,8 +30,10 @@ def test_matlab_reference_executes_only_when_engine_is_available() -> None: network = run_simulated_network_reference(matlab_repo=MATLAB_REPO_ROOT) assert point_process["spike_counts"].shape == (5,) + assert point_process["lambda_head"].shape == (10,) assert network["spike_counts"].shape == (2,) assert network["prob_head"].shape == (5, 2) + assert network["state_head"].shape == (5, 2) np.testing.assert_allclose(network["actual_network"], np.array([[0.0, 1.0], [-4.0, 0.0]], dtype=float)) @@ -57,7 +59,7 @@ def test_native_point_process_simulation_matches_matlab_lambda_head_when_engine_ ) matlab_ref = run_point_process_reference(matlab_repo=MATLAB_REPO_ROOT, seed=5) - np.testing.assert_allclose(lambda_cov.data[:5, 0], matlab_ref["lambda_head"], rtol=1e-6, atol=1e-8) + np.testing.assert_allclose(lambda_cov.data[:10, 0], matlab_ref["lambda_head"], rtol=1e-6, atol=1e-8) @pytest.mark.skipif(not MATLAB_REPO_ROOT.exists(), reason="MATLAB reference repo not available") @@ -69,4 +71,5 @@ def test_native_network_simulation_preserves_matlab_connectivity_layout_when_eng matlab_ref = run_simulated_network_reference(matlab_repo=MATLAB_REPO_ROOT, seed=4) np.testing.assert_allclose(native.actual_network, matlab_ref["actual_network"]) - assert matlab_ref["prob_head"].shape == (5, 2) + np.testing.assert_allclose(native.lambda_delta[:5], matlab_ref["prob_head"], rtol=1e-6, atol=1e-8) + assert np.all((matlab_ref["state_head"] == 0.0) | (matlab_ref["state_head"] == 1.0))