diff --git a/nstat/core.py b/nstat/core.py index 31b03234..3ea85c59 100644 --- a/nstat/core.py +++ b/nstat/core.py @@ -481,7 +481,7 @@ def setMaxTime(self, maxTime: float | None = None, holdVals: int = 0) -> None: timeVec = self.getTime() if float(np.max(timeVec)) < target: minTime = float(np.min(timeVec)) - n_samples = int(round(self.sampleRate * (target - minTime))) + 1 + n_samples = int(float(self.sampleRate) * (target - minTime) + 1.0) n_samples = max(n_samples, timeVec.size) newTime = np.linspace(minTime, target, n_samples, dtype=float) numSamples = int(newTime.size - timeVec.size) @@ -1322,7 +1322,10 @@ def computeStatistics(self, makePlots: int = 0) -> None: self.burstRate = float(self.numBursts / duration) if duration > 0 else np.nan self.numSpikesPerBurst = (burst_end - burst_start + 1).astype(float) self.avgSpikesPerBurst = float(np.mean(self.numSpikesPerBurst + 1.0)) - self.stdSpikesPerBurst = float(np.std(self.numSpikesPerBurst + 1.0)) + if self.numSpikesPerBurst.size > 1: + self.stdSpikesPerBurst = float(np.std(self.numSpikesPerBurst + 1.0, ddof=1)) + elif self.numSpikesPerBurst.size == 1: + self.stdSpikesPerBurst = 0.0 self.Lstatistic = self.getLStatistic() if makePlots == 1: @@ -1480,9 +1483,8 @@ def computeRate(self) -> SignalObj: def restoreToOriginal(self) -> None: self.spikeTimes = self.originalSpikeTimes.copy() - self.sampleRate = float(self.originalSampleRate) - self.minTime = float(self.originalMinTime) - self.maxTime = float(self.originalMaxTime) + self.minTime = float(np.min(self.spikeTimes)) if self.spikeTimes.size else 0.0 + self.maxTime = float(np.max(self.spikeTimes)) if self.spikeTimes.size else 0.0 self.clearSigRep() def partitionNST( diff --git a/parity/class_fidelity.yml b/parity/class_fidelity.yml index 14f6e1a2..a4ff1489 100644 --- a/parity/class_fidelity.yml +++ b/parity/class_fidelity.yml @@ -41,8 +41,12 @@ items: - Structure serialization is close but not exhaustive for every MATLAB-only field. required_remediation: - Extend the committed MATLAB-derived fixtures beyond derivative, integral, spline - resampling, filtering, and xcorr to cover autocorrelation selectors and the remaining + resampling, filtering, `makeCompatible`, and `xcorr` to cover the remaining spectral utility methods. + - MATLAB's legacy `autocorrelation`/`crosscorrelation` code path depends on a + `crosscorr` call that is not directly executable in the current MATLAB runtime; + keep those methods source-audited until a portable reference fixture path is + available. plotting_report_parity: Core plotting and correlation helpers are implemented; some MATLAB-only plot selectors, spectral utilities, and report-style helpers remain lighter. @@ -103,9 +107,9 @@ items: - Some MATLAB visual styling and distribution-fit detail in the ISI plotting helpers remains lighter than MATLAB. required_remediation: - - Extend the committed MATLAB-derived fixtures beyond getSigRep, partitionNST, and - burst-stat summaries to cover ISI plotting traces and the remaining visualization - details. + - Extend the committed MATLAB-derived fixtures beyond getSigRep, partitionNST, restore-bound + semantics, and burst-stat summaries to cover ISI plotting traces and the remaining + visualization details. plotting_report_parity: Raster, ISI, and burst-oriented plotting helpers now execute on the canonical class, though visual detail remains lighter than MATLAB. - matlab_name: nstColl diff --git a/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat b/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat index 64522508..bdcfad9b 100644 Binary files a/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat and b/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat b/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat index d931cfe1..b15a0a74 100644 Binary files a/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat and b/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat differ diff --git a/tests/test_matlab_gold_fixtures.py b/tests/test_matlab_gold_fixtures.py index 6693a368..48c2ad98 100644 --- a/tests/test_matlab_gold_fixtures.py +++ b/tests/test_matlab_gold_fixtures.py @@ -27,12 +27,15 @@ def _vector(payload: dict[str, np.ndarray], key: str) -> np.ndarray: def test_signalobj_matches_matlab_gold_fixture() -> None: payload = _load_fixture("signalobj_exactness.mat") signal = SignalObj(_vector(payload, "time"), np.asarray(payload["data"], dtype=float), "sig", "time", "s", "u", ["x1", "x2"]) + signal_1 = signal.getSubSignal(1) + signal_2 = SignalObj(np.arange(0.05, 0.5, 0.1), [0.0, 1.0, 0.0, -1.0, 0.0], "sig2", "time", "s", "u", ["x3"]) filtered = signal.filter(_vector(payload, "filter_b"), _vector(payload, "filter_a")) derivative = signal.derivative integral = signal.integral() resampled = signal.resample(_scalar(payload, "resample_rate")) xcorr = signal.getSubSignal(1).xcorr(signal.getSubSignal(2), int(_scalar(payload, "xcorr_maxlag"))) + compatible_left, compatible_right = signal_1.makeCompatible(signal_2, holdVals=1) np.testing.assert_allclose(filtered.data, np.asarray(payload["filtered_data"], dtype=float), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(derivative.data, np.asarray(payload["derivative_data"], dtype=float), rtol=1e-8, atol=1e-10) @@ -41,6 +44,9 @@ def test_signalobj_matches_matlab_gold_fixture() -> None: np.testing.assert_allclose(resampled.data, np.asarray(payload["resampled_data"], dtype=float), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(xcorr.time, _vector(payload, "xcorr_time"), rtol=1e-12, atol=1e-12) np.testing.assert_allclose(xcorr.data.reshape(-1), _vector(payload, "xcorr_data"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(compatible_left.time, _vector(payload, "compat_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(compatible_left.data.reshape(-1), _vector(payload, "compat_left_data"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(compatible_right.data.reshape(-1), _vector(payload, "compat_right_data"), rtol=1e-8, atol=1e-10) def test_nspiketrain_matches_matlab_gold_fixture() -> None: @@ -71,6 +77,21 @@ def test_nspiketrain_matches_matlab_gold_fixture() -> None: np.testing.assert_allclose(parts.getNST(1).spikeTimes, _vector(payload, "part1_spikes"), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(parts.getNST(2).spikeTimes, _vector(payload, "part2_spikes"), rtol=1e-8, atol=1e-10) + restore_train = nspikeTrain(_vector(payload, "spikeTimes"), "restore", 0.2, -0.1, 0.8, "time", "s", "spikes", "spk", -1) + restore_train.setSigRep(0.1, -0.1, 0.8) + restore_train.setMinTime(-0.3) + restore_train.setMaxTime(1.1) + restore_train.restoreToOriginal() + + np.testing.assert_allclose(float(restore_train.minTime), _scalar(payload, "restore_min_time"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(restore_train.maxTime), _scalar(payload, "restore_max_time"), rtol=1e-12, atol=1e-12) + + burst_train = nspikeTrain([0.0, 0.001, 0.002, 0.007, 0.507, 0.508, 0.509, 0.514], "bursting", 0.001, 0.0, 0.6, "time", "s", "spikes", "spk", 0) + np.testing.assert_allclose(float(burst_train.avgSpikesPerBurst), _scalar(payload, "burst_avgSpikesPerBurst"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(burst_train.stdSpikesPerBurst), _scalar(payload, "burst_stdSpikesPerBurst"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(float(burst_train.numBursts), _scalar(payload, "burst_numBursts"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(burst_train.numSpikesPerBurst, _vector(payload, "burst_numSpikesPerBurst"), rtol=1e-8, atol=1e-10) + def test_cif_eval_surface_matches_matlab_gold_fixture() -> None: payload = _load_fixture("cif_exactness.mat") diff --git a/tests/test_nspiketrain_fidelity.py b/tests/test_nspiketrain_fidelity.py index c8430566..3afdfd70 100644 --- a/tests/test_nspiketrain_fidelity.py +++ b/tests/test_nspiketrain_fidelity.py @@ -61,8 +61,8 @@ def test_nspiketrain_setsigrep_restore_and_field_access_match_matlab_surface() - assert train.getFieldVal("missing") == [] train.restoreToOriginal() - assert train.sampleRate == 5.0 - np.testing.assert_allclose([train.minTime, train.maxTime], [0.0, 1.0]) + assert train.sampleRate == 10.0 + np.testing.assert_allclose([train.minTime, train.maxTime], [0.2, 0.6]) def test_nspiketrain_compute_statistics_matches_matlab_style_burst_metrics() -> None: diff --git a/tools/parity/matlab/export_matlab_gold_fixtures.m b/tools/parity/matlab/export_matlab_gold_fixtures.m index ac44f983..2c4fe953 100644 --- a/tools/parity/matlab/export_matlab_gold_fixtures.m +++ b/tools/parity/matlab/export_matlab_gold_fixtures.m @@ -28,12 +28,15 @@ function export_signalobj_fixture(fixtureRoot) t = (0:0.1:0.4)'; data = [1 0; 2 1; 4 0; 8 -1; 16 0]; s = SignalObj(t, data, 'sig', 'time', 's', 'u', {'x1', 'x2'}); +s1 = s.getSubSignal(1); +s2 = SignalObj((0.05:0.1:0.45)', [0; 1; 0; -1; 0], 'sig2', 'time', 's', 'u', {'x3'}); filtered = s.filter([0.25 0.5 0.25], 1); resampled = s.resample(20); derivative = s.derivative; integral_sig = s.integral(); xc = xcorr(s.getSubSignal(1), s.getSubSignal(2), 2); +[s1c, s2c] = s1.makeCompatible(s2, 1); payload = struct(); payload.time = s.time; @@ -49,6 +52,9 @@ function export_signalobj_fixture(fixtureRoot) payload.xcorr_maxlag = 2; payload.xcorr_time = xc.time; payload.xcorr_data = xc.data; +payload.compat_time = s1c.time; +payload.compat_left_data = s1c.data; +payload.compat_right_data = s2c.data; save(fullfile(fixtureRoot, 'signalobj_exactness.mat'), '-struct', 'payload'); end @@ -59,6 +65,12 @@ function export_nspiketrain_fixture(fixtureRoot) nst = nspikeTrain(spikeTimes, 'nst', binwidth, 0.0, 0.5, 'time', 's', 'spikes', 'spk', 0); sig = nst.getSigRep(binwidth, 0.0, 0.5); parts = nst.partitionNST([0.0 0.2 0.5]); +restoreTrain = nspikeTrain(spikeTimes, 'restore', 0.2, -0.1, 0.8, 'time', 's', 'spikes', 'spk', -1); +restoreTrain.setSigRep(0.1, -0.1, 0.8); +restoreTrain.setMinTime(-0.3); +restoreTrain.setMaxTime(1.1); +restoreTrain.restoreToOriginal(); +burstTrain = nspikeTrain([0.0; 0.001; 0.002; 0.007; 0.507; 0.508; 0.509; 0.514], 'bursting', 0.001, 0.0, 0.6, 'time', 's', 'spikes', 'spk', 0); payload = struct(); payload.spikeTimes = spikeTimes; @@ -76,6 +88,12 @@ function export_nspiketrain_fixture(fixtureRoot) payload.numSpikesPerBurst = nst.numSpikesPerBurst; payload.part1_spikes = parts.getNST(1).spikeTimes; payload.part2_spikes = parts.getNST(2).spikeTimes; +payload.restore_min_time = restoreTrain.minTime; +payload.restore_max_time = restoreTrain.maxTime; +payload.burst_avgSpikesPerBurst = burstTrain.avgSpikesPerBurst; +payload.burst_stdSpikesPerBurst = burstTrain.stdSpikesPerBurst; +payload.burst_numBursts = burstTrain.numBursts; +payload.burst_numSpikesPerBurst = burstTrain.numSpikesPerBurst; save(fullfile(fixtureRoot, 'nspiketrain_exactness.mat'), '-struct', 'payload'); end