From 23a96b8986f50286fc1284a03f9856f4fdc0f0b2 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Sun, 8 Mar 2026 00:04:09 -0500 Subject: [PATCH] Tighten SignalObj and nspikeTrain MATLAB fixtures --- nstat/core.py | 12 +++++----- parity/class_fidelity.yml | 12 ++++++---- .../matlab_gold/nspiketrain_exactness.mat | Bin 1073 -> 1451 bytes .../matlab_gold/signalobj_exactness.mat | Bin 1102 -> 1310 bytes tests/test_matlab_gold_fixtures.py | 21 ++++++++++++++++++ tests/test_nspiketrain_fidelity.py | 4 ++-- .../matlab/export_matlab_gold_fixtures.m | 18 +++++++++++++++ 7 files changed, 56 insertions(+), 11 deletions(-) diff --git a/nstat/core.py b/nstat/core.py index 31b03234..3ea85c59 100644 --- a/nstat/core.py +++ b/nstat/core.py @@ -481,7 +481,7 @@ def setMaxTime(self, maxTime: float | None = None, holdVals: int = 0) -> None: timeVec = self.getTime() if float(np.max(timeVec)) < target: minTime = float(np.min(timeVec)) - n_samples = int(round(self.sampleRate * (target - minTime))) + 1 + n_samples = int(float(self.sampleRate) * (target - minTime) + 1.0) n_samples = max(n_samples, timeVec.size) newTime = np.linspace(minTime, target, n_samples, dtype=float) numSamples = int(newTime.size - timeVec.size) @@ -1322,7 +1322,10 @@ def computeStatistics(self, makePlots: int = 0) -> None: self.burstRate = float(self.numBursts / duration) if duration > 0 else np.nan self.numSpikesPerBurst = (burst_end - burst_start + 1).astype(float) self.avgSpikesPerBurst = float(np.mean(self.numSpikesPerBurst + 1.0)) - self.stdSpikesPerBurst = float(np.std(self.numSpikesPerBurst + 1.0)) + if self.numSpikesPerBurst.size > 1: + self.stdSpikesPerBurst = float(np.std(self.numSpikesPerBurst + 1.0, ddof=1)) + elif self.numSpikesPerBurst.size == 1: + self.stdSpikesPerBurst = 0.0 self.Lstatistic = self.getLStatistic() if makePlots == 1: @@ -1480,9 +1483,8 @@ def computeRate(self) -> SignalObj: def restoreToOriginal(self) -> None: self.spikeTimes = self.originalSpikeTimes.copy() - self.sampleRate = float(self.originalSampleRate) - self.minTime = float(self.originalMinTime) - self.maxTime = float(self.originalMaxTime) + self.minTime = float(np.min(self.spikeTimes)) if self.spikeTimes.size else 0.0 + self.maxTime = float(np.max(self.spikeTimes)) if self.spikeTimes.size else 0.0 self.clearSigRep() def partitionNST( diff --git a/parity/class_fidelity.yml b/parity/class_fidelity.yml index 14f6e1a2..a4ff1489 100644 --- a/parity/class_fidelity.yml +++ b/parity/class_fidelity.yml @@ -41,8 +41,12 @@ items: - Structure serialization is close but not exhaustive for every MATLAB-only field. required_remediation: - Extend the committed MATLAB-derived fixtures beyond derivative, integral, spline - resampling, filtering, and xcorr to cover autocorrelation selectors and the remaining + resampling, filtering, `makeCompatible`, and `xcorr` to cover the remaining spectral utility methods. + - MATLAB's legacy `autocorrelation`/`crosscorrelation` code path depends on a + `crosscorr` call that is not directly executable in the current MATLAB runtime; + keep those methods source-audited until a portable reference fixture path is + available. plotting_report_parity: Core plotting and correlation helpers are implemented; some MATLAB-only plot selectors, spectral utilities, and report-style helpers remain lighter. @@ -103,9 +107,9 @@ items: - Some MATLAB visual styling and distribution-fit detail in the ISI plotting helpers remains lighter than MATLAB. required_remediation: - - Extend the committed MATLAB-derived fixtures beyond getSigRep, partitionNST, and - burst-stat summaries to cover ISI plotting traces and the remaining visualization - details. + - Extend the committed MATLAB-derived fixtures beyond getSigRep, partitionNST, restore-bound + semantics, and burst-stat summaries to cover ISI plotting traces and the remaining + visualization details. plotting_report_parity: Raster, ISI, and burst-oriented plotting helpers now execute on the canonical class, though visual detail remains lighter than MATLAB. - matlab_name: nstColl diff --git a/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat b/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat index 645225083aededc764ced2b0a8a657a0d132f399..bdcfad9b3ffa8a060e78055d108198063f3e5780 100644 GIT binary patch delta 408 zcmdnUv6_2=BbSk>m4TUpk%5uf#6abV32Yln_A=M=GcYii0dd8g$H@r^9~hF%6rMF) zQaHf%)L1EzQ9;j3SMQ|0SMT|=XS(%Ip7DL!vTC-eFiY4iW`+k`+!sK)Er`(VyQTxG zTR8IhzYShb=Gijb%Hev!Uysj#LIpo9T`j#%e;sdry%T5rboo4ViZT*48W=Dzyi?*5tpOPTWJCR5fa?D`r~}SxpDkrF?2??(VC3Ua z$k4Tp?IlQ`4O|~exG@?ki*z#=24w~ode+2c<`p=mm1Pw=WK9(IU|G@E**J4%BO}AR HJ6uWt2JMaN delta 27 jcmZ3@y^&*rBbR}Zm7$@6k%5uf#6abV32Yln_A&ziYrzN< diff --git a/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat b/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat index d931cfe1a7cc68b8b6f33cbc6a6c0e2713d64f5b..b15a0a7406133ab5bf57743ede528f6897a29466 100644 GIT binary patch delta 216 zcmX@dF^_A4BbSk>m4T^(k%5uf#6abV32YlnZZWg+GcYi?Oy18dRnPY5$)7t-M<)H4 zbmmUT1}i^>YY7ewo6{W<`fXe6lI^qAIAU5qy}!ZCa8X%$3P_t75Le83oScyGfg#CE z;aS5ag#%nqn={-r#28>^&AGB=(vK-eCb_81kWv%ln_$Cm>;bzCNVgqacL!Lv5t4Qm hnD#|yzVtx!uc%5>c*eBZLF@njb}nNEh6EYTWdKBsT0j5* delta 27 jcmbQob&g|#BbR}Zm7$@6k%5uf#6abV32YlnZZQJ@ZC40D diff --git a/tests/test_matlab_gold_fixtures.py b/tests/test_matlab_gold_fixtures.py index 6693a368..48c2ad98 100644 --- a/tests/test_matlab_gold_fixtures.py +++ b/tests/test_matlab_gold_fixtures.py @@ -27,12 +27,15 @@ def _vector(payload: dict[str, np.ndarray], key: str) -> np.ndarray: def test_signalobj_matches_matlab_gold_fixture() -> None: payload = _load_fixture("signalobj_exactness.mat") signal = SignalObj(_vector(payload, "time"), np.asarray(payload["data"], dtype=float), "sig", "time", "s", "u", ["x1", "x2"]) + signal_1 = signal.getSubSignal(1) + signal_2 = SignalObj(np.arange(0.05, 0.5, 0.1), [0.0, 1.0, 0.0, -1.0, 0.0], "sig2", "time", "s", "u", ["x3"]) filtered = signal.filter(_vector(payload, "filter_b"), _vector(payload, "filter_a")) derivative = signal.derivative integral = signal.integral() resampled = signal.resample(_scalar(payload, "resample_rate")) xcorr = signal.getSubSignal(1).xcorr(signal.getSubSignal(2), int(_scalar(payload, "xcorr_maxlag"))) + compatible_left, compatible_right = signal_1.makeCompatible(signal_2, holdVals=1) np.testing.assert_allclose(filtered.data, np.asarray(payload["filtered_data"], dtype=float), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(derivative.data, np.asarray(payload["derivative_data"], dtype=float), rtol=1e-8, atol=1e-10) @@ -41,6 +44,9 @@ def test_signalobj_matches_matlab_gold_fixture() -> None: np.testing.assert_allclose(resampled.data, np.asarray(payload["resampled_data"], dtype=float), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(xcorr.time, _vector(payload, "xcorr_time"), rtol=1e-12, atol=1e-12) np.testing.assert_allclose(xcorr.data.reshape(-1), _vector(payload, "xcorr_data"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(compatible_left.time, _vector(payload, "compat_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(compatible_left.data.reshape(-1), _vector(payload, "compat_left_data"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(compatible_right.data.reshape(-1), _vector(payload, "compat_right_data"), rtol=1e-8, atol=1e-10) def test_nspiketrain_matches_matlab_gold_fixture() -> None: @@ -71,6 +77,21 @@ def test_nspiketrain_matches_matlab_gold_fixture() -> None: np.testing.assert_allclose(parts.getNST(1).spikeTimes, _vector(payload, "part1_spikes"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(parts.getNST(2).spikeTimes, _vector(payload, "part2_spikes"), rtol=1e-8, atol=1e-10) + restore_train = nspikeTrain(_vector(payload, "spikeTimes"), "restore", 0.2, -0.1, 0.8, "time", "s", "spikes", "spk", -1) + restore_train.setSigRep(0.1, -0.1, 0.8) + restore_train.setMinTime(-0.3) + restore_train.setMaxTime(1.1) + restore_train.restoreToOriginal() + + np.testing.assert_allclose(float(restore_train.minTime), _scalar(payload, "restore_min_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(restore_train.maxTime), _scalar(payload, "restore_max_time"), rtol=1e-12, atol=1e-12) + + burst_train = nspikeTrain([0.0, 0.001, 0.002, 0.007, 0.507, 0.508, 0.509, 0.514], "bursting", 0.001, 0.0, 0.6, "time", "s", "spikes", "spk", 0) + np.testing.assert_allclose(float(burst_train.avgSpikesPerBurst), _scalar(payload, "burst_avgSpikesPerBurst"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(burst_train.stdSpikesPerBurst), _scalar(payload, "burst_stdSpikesPerBurst"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(burst_train.numBursts), _scalar(payload, "burst_numBursts"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(burst_train.numSpikesPerBurst, _vector(payload, "burst_numSpikesPerBurst"), rtol=1e-8, atol=1e-10) + def test_cif_eval_surface_matches_matlab_gold_fixture() -> None: payload = _load_fixture("cif_exactness.mat") diff --git a/tests/test_nspiketrain_fidelity.py b/tests/test_nspiketrain_fidelity.py index c8430566..3afdfd70 100644 --- a/tests/test_nspiketrain_fidelity.py +++ b/tests/test_nspiketrain_fidelity.py @@ -61,8 +61,8 @@ def test_nspiketrain_setsigrep_restore_and_field_access_match_matlab_surface() - assert train.getFieldVal("missing") == [] train.restoreToOriginal() - assert train.sampleRate == 5.0 - np.testing.assert_allclose([train.minTime, train.maxTime], [0.0, 1.0]) + assert train.sampleRate == 10.0 + np.testing.assert_allclose([train.minTime, train.maxTime], [0.2, 0.6]) def test_nspiketrain_compute_statistics_matches_matlab_style_burst_metrics() -> None: diff --git a/tools/parity/matlab/export_matlab_gold_fixtures.m b/tools/parity/matlab/export_matlab_gold_fixtures.m index ac44f983..2c4fe953 100644 --- a/tools/parity/matlab/export_matlab_gold_fixtures.m +++ b/tools/parity/matlab/export_matlab_gold_fixtures.m @@ -28,12 +28,15 @@ function export_signalobj_fixture(fixtureRoot) t = (0:0.1:0.4)'; data = [1 0; 2 1; 4 0; 8 -1; 16 0]; s = SignalObj(t, data, 'sig', 'time', 's', 'u', {'x1', 'x2'}); +s1 = s.getSubSignal(1); +s2 = SignalObj((0.05:0.1:0.45)', [0; 1; 0; -1; 0], 'sig2', 'time', 's', 'u', {'x3'}); filtered = s.filter([0.25 0.5 0.25], 1); resampled = s.resample(20); derivative = s.derivative; integral_sig = s.integral(); xc = xcorr(s.getSubSignal(1), s.getSubSignal(2), 2); +[s1c, s2c] = s1.makeCompatible(s2, 1); payload = struct(); payload.time = s.time; @@ -49,6 +52,9 @@ function export_signalobj_fixture(fixtureRoot) payload.xcorr_maxlag = 2; payload.xcorr_time = xc.time; payload.xcorr_data = xc.data; +payload.compat_time = s1c.time; +payload.compat_left_data = s1c.data; +payload.compat_right_data = s2c.data; save(fullfile(fixtureRoot, 'signalobj_exactness.mat'), '-struct', 'payload'); end @@ -59,6 +65,12 @@ function export_nspiketrain_fixture(fixtureRoot) nst = nspikeTrain(spikeTimes, 'nst', binwidth, 0.0, 0.5, 'time', 's', 'spikes', 'spk', 0); sig = nst.getSigRep(binwidth, 0.0, 0.5); parts = nst.partitionNST([0.0 0.2 0.5]); +restoreTrain = nspikeTrain(spikeTimes, 'restore', 0.2, -0.1, 0.8, 'time', 's', 'spikes', 'spk', -1); +restoreTrain.setSigRep(0.1, -0.1, 0.8); +restoreTrain.setMinTime(-0.3); +restoreTrain.setMaxTime(1.1); +restoreTrain.restoreToOriginal(); +burstTrain = nspikeTrain([0.0; 0.001; 0.002; 0.007; 0.507; 0.508; 0.509; 0.514], 'bursting', 0.001, 0.0, 0.6, 'time', 's', 'spikes', 'spk', 0); payload = struct(); payload.spikeTimes = spikeTimes; @@ -76,6 +88,12 @@ function export_nspiketrain_fixture(fixtureRoot) payload.numSpikesPerBurst = nst.numSpikesPerBurst; payload.part1_spikes = parts.getNST(1).spikeTimes; payload.part2_spikes = parts.getNST(2).spikeTimes; +payload.restore_min_time = restoreTrain.minTime; +payload.restore_max_time = restoreTrain.maxTime; +payload.burst_avgSpikesPerBurst = burstTrain.avgSpikesPerBurst; +payload.burst_stdSpikesPerBurst = burstTrain.stdSpikesPerBurst; +payload.burst_numBursts = burstTrain.numBursts; +payload.burst_numSpikesPerBurst = burstTrain.numSpikesPerBurst; save(fullfile(fixtureRoot, 'nspiketrain_exactness.mat'), '-struct', 'payload'); end