Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions nstat/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def setMaxTime(self, maxTime: float | None = None, holdVals: int = 0) -> None:
timeVec = self.getTime()
if float(np.max(timeVec)) < target:
minTime = float(np.min(timeVec))
n_samples = int(round(self.sampleRate * (target - minTime))) + 1
n_samples = int(float(self.sampleRate) * (target - minTime) + 1.0)
n_samples = max(n_samples, timeVec.size)
newTime = np.linspace(minTime, target, n_samples, dtype=float)
numSamples = int(newTime.size - timeVec.size)
Expand Down Expand Up @@ -1322,7 +1322,10 @@ def computeStatistics(self, makePlots: int = 0) -> None:
self.burstRate = float(self.numBursts / duration) if duration > 0 else np.nan
self.numSpikesPerBurst = (burst_end - burst_start + 1).astype(float)
self.avgSpikesPerBurst = float(np.mean(self.numSpikesPerBurst + 1.0))
self.stdSpikesPerBurst = float(np.std(self.numSpikesPerBurst + 1.0))
if self.numSpikesPerBurst.size > 1:
self.stdSpikesPerBurst = float(np.std(self.numSpikesPerBurst + 1.0, ddof=1))
elif self.numSpikesPerBurst.size == 1:
self.stdSpikesPerBurst = 0.0

self.Lstatistic = self.getLStatistic()
if makePlots == 1:
Expand Down Expand Up @@ -1480,9 +1483,8 @@ def computeRate(self) -> SignalObj:

def restoreToOriginal(self) -> None:
self.spikeTimes = self.originalSpikeTimes.copy()
self.sampleRate = float(self.originalSampleRate)
self.minTime = float(self.originalMinTime)
self.maxTime = float(self.originalMaxTime)
self.minTime = float(np.min(self.spikeTimes)) if self.spikeTimes.size else 0.0
self.maxTime = float(np.max(self.spikeTimes)) if self.spikeTimes.size else 0.0
self.clearSigRep()

def partitionNST(
Expand Down
12 changes: 8 additions & 4 deletions parity/class_fidelity.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ items:
- Structure serialization is close but not exhaustive for every MATLAB-only field.
required_remediation:
- Extend the committed MATLAB-derived fixtures beyond derivative, integral, spline
resampling, filtering, and xcorr to cover autocorrelation selectors and the remaining
resampling, filtering, `makeCompatible`, and `xcorr` to cover the remaining
spectral utility methods.
- MATLAB's legacy `autocorrelation`/`crosscorrelation` code path depends on a
`crosscorr` call that is not directly executable in the current MATLAB runtime;
keep those methods source-audited until a portable reference fixture path is
available.
plotting_report_parity: Core plotting and correlation helpers are implemented; some
MATLAB-only plot selectors, spectral utilities, and report-style helpers remain
lighter.
Expand Down Expand Up @@ -103,9 +107,9 @@ items:
- Some MATLAB visual styling and distribution-fit detail in the ISI plotting helpers
remains lighter than MATLAB.
required_remediation:
- Extend the committed MATLAB-derived fixtures beyond getSigRep, partitionNST, and
burst-stat summaries to cover ISI plotting traces and the remaining visualization
details.
- Extend the committed MATLAB-derived fixtures beyond getSigRep, partitionNST, restore-bound
semantics, and burst-stat summaries to cover ISI plotting traces and the remaining
visualization details.
plotting_report_parity: Raster, ISI, and burst-oriented plotting helpers now execute
on the canonical class, though visual detail remains lighter than MATLAB.
- matlab_name: nstColl
Expand Down
Binary file modified tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat
Binary file not shown.
Binary file modified tests/parity/fixtures/matlab_gold/signalobj_exactness.mat
Binary file not shown.
21 changes: 21 additions & 0 deletions tests/test_matlab_gold_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@ def _vector(payload: dict[str, np.ndarray], key: str) -> np.ndarray:
def test_signalobj_matches_matlab_gold_fixture() -> None:
payload = _load_fixture("signalobj_exactness.mat")
signal = SignalObj(_vector(payload, "time"), np.asarray(payload["data"], dtype=float), "sig", "time", "s", "u", ["x1", "x2"])
signal_1 = signal.getSubSignal(1)
signal_2 = SignalObj(np.arange(0.05, 0.5, 0.1), [0.0, 1.0, 0.0, -1.0, 0.0], "sig2", "time", "s", "u", ["x3"])

filtered = signal.filter(_vector(payload, "filter_b"), _vector(payload, "filter_a"))
derivative = signal.derivative
integral = signal.integral()
resampled = signal.resample(_scalar(payload, "resample_rate"))
xcorr = signal.getSubSignal(1).xcorr(signal.getSubSignal(2), int(_scalar(payload, "xcorr_maxlag")))
compatible_left, compatible_right = signal_1.makeCompatible(signal_2, holdVals=1)

np.testing.assert_allclose(filtered.data, np.asarray(payload["filtered_data"], dtype=float), rtol=1e-8, atol=1e-10)
np.testing.assert_allclose(derivative.data, np.asarray(payload["derivative_data"], dtype=float), rtol=1e-8, atol=1e-10)
Expand All @@ -41,6 +44,9 @@ def test_signalobj_matches_matlab_gold_fixture() -> None:
np.testing.assert_allclose(resampled.data, np.asarray(payload["resampled_data"], dtype=float), rtol=1e-8, atol=1e-10)
np.testing.assert_allclose(xcorr.time, _vector(payload, "xcorr_time"), rtol=1e-12, atol=1e-12)
np.testing.assert_allclose(xcorr.data.reshape(-1), _vector(payload, "xcorr_data"), rtol=1e-8, atol=1e-10)
np.testing.assert_allclose(compatible_left.time, _vector(payload, "compat_time"), rtol=1e-12, atol=1e-12)
np.testing.assert_allclose(compatible_left.data.reshape(-1), _vector(payload, "compat_left_data"), rtol=1e-8, atol=1e-10)
np.testing.assert_allclose(compatible_right.data.reshape(-1), _vector(payload, "compat_right_data"), rtol=1e-8, atol=1e-10)


def test_nspiketrain_matches_matlab_gold_fixture() -> None:
Expand Down Expand Up @@ -71,6 +77,21 @@ def test_nspiketrain_matches_matlab_gold_fixture() -> None:
np.testing.assert_allclose(parts.getNST(1).spikeTimes, _vector(payload, "part1_spikes"), rtol=1e-8, atol=1e-10)
np.testing.assert_allclose(parts.getNST(2).spikeTimes, _vector(payload, "part2_spikes"), rtol=1e-8, atol=1e-10)

restore_train = nspikeTrain(_vector(payload, "spikeTimes"), "restore", 0.2, -0.1, 0.8, "time", "s", "spikes", "spk", -1)
restore_train.setSigRep(0.1, -0.1, 0.8)
restore_train.setMinTime(-0.3)
restore_train.setMaxTime(1.1)
restore_train.restoreToOriginal()

np.testing.assert_allclose(float(restore_train.minTime), _scalar(payload, "restore_min_time"), rtol=1e-12, atol=1e-12)
np.testing.assert_allclose(float(restore_train.maxTime), _scalar(payload, "restore_max_time"), rtol=1e-12, atol=1e-12)

burst_train = nspikeTrain([0.0, 0.001, 0.002, 0.007, 0.507, 0.508, 0.509, 0.514], "bursting", 0.001, 0.0, 0.6, "time", "s", "spikes", "spk", 0)
np.testing.assert_allclose(float(burst_train.avgSpikesPerBurst), _scalar(payload, "burst_avgSpikesPerBurst"), rtol=1e-8, atol=1e-10)
np.testing.assert_allclose(float(burst_train.stdSpikesPerBurst), _scalar(payload, "burst_stdSpikesPerBurst"), rtol=1e-8, atol=1e-10)
np.testing.assert_allclose(float(burst_train.numBursts), _scalar(payload, "burst_numBursts"), rtol=1e-8, atol=1e-10)
np.testing.assert_allclose(burst_train.numSpikesPerBurst, _vector(payload, "burst_numSpikesPerBurst"), rtol=1e-8, atol=1e-10)


def test_cif_eval_surface_matches_matlab_gold_fixture() -> None:
payload = _load_fixture("cif_exactness.mat")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_nspiketrain_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def test_nspiketrain_setsigrep_restore_and_field_access_match_matlab_surface() -
assert train.getFieldVal("missing") == []

train.restoreToOriginal()
assert train.sampleRate == 5.0
np.testing.assert_allclose([train.minTime, train.maxTime], [0.0, 1.0])
assert train.sampleRate == 10.0
np.testing.assert_allclose([train.minTime, train.maxTime], [0.2, 0.6])


def test_nspiketrain_compute_statistics_matches_matlab_style_burst_metrics() -> None:
Expand Down
18 changes: 18 additions & 0 deletions tools/parity/matlab/export_matlab_gold_fixtures.m
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@ function export_signalobj_fixture(fixtureRoot)
t = (0:0.1:0.4)';
data = [1 0; 2 1; 4 0; 8 -1; 16 0];
s = SignalObj(t, data, 'sig', 'time', 's', 'u', {'x1', 'x2'});
s1 = s.getSubSignal(1);
s2 = SignalObj((0.05:0.1:0.45)', [0; 1; 0; -1; 0], 'sig2', 'time', 's', 'u', {'x3'});

filtered = s.filter([0.25 0.5 0.25], 1);
resampled = s.resample(20);
derivative = s.derivative;
integral_sig = s.integral();
xc = xcorr(s.getSubSignal(1), s.getSubSignal(2), 2);
[s1c, s2c] = s1.makeCompatible(s2, 1);

payload = struct();
payload.time = s.time;
Expand All @@ -49,6 +52,9 @@ function export_signalobj_fixture(fixtureRoot)
payload.xcorr_maxlag = 2;
payload.xcorr_time = xc.time;
payload.xcorr_data = xc.data;
payload.compat_time = s1c.time;
payload.compat_left_data = s1c.data;
payload.compat_right_data = s2c.data;

save(fullfile(fixtureRoot, 'signalobj_exactness.mat'), '-struct', 'payload');
end
Expand All @@ -59,6 +65,12 @@ function export_nspiketrain_fixture(fixtureRoot)
nst = nspikeTrain(spikeTimes, 'nst', binwidth, 0.0, 0.5, 'time', 's', 'spikes', 'spk', 0);
sig = nst.getSigRep(binwidth, 0.0, 0.5);
parts = nst.partitionNST([0.0 0.2 0.5]);
restoreTrain = nspikeTrain(spikeTimes, 'restore', 0.2, -0.1, 0.8, 'time', 's', 'spikes', 'spk', -1);
restoreTrain.setSigRep(0.1, -0.1, 0.8);
restoreTrain.setMinTime(-0.3);
restoreTrain.setMaxTime(1.1);
restoreTrain.restoreToOriginal();
burstTrain = nspikeTrain([0.0; 0.001; 0.002; 0.007; 0.507; 0.508; 0.509; 0.514], 'bursting', 0.001, 0.0, 0.6, 'time', 's', 'spikes', 'spk', 0);

payload = struct();
payload.spikeTimes = spikeTimes;
Expand All @@ -76,6 +88,12 @@ function export_nspiketrain_fixture(fixtureRoot)
payload.numSpikesPerBurst = nst.numSpikesPerBurst;
payload.part1_spikes = parts.getNST(1).spikeTimes;
payload.part2_spikes = parts.getNST(2).spikeTimes;
payload.restore_min_time = restoreTrain.minTime;
payload.restore_max_time = restoreTrain.maxTime;
payload.burst_avgSpikesPerBurst = burstTrain.avgSpikesPerBurst;
payload.burst_stdSpikesPerBurst = burstTrain.stdSpikesPerBurst;
payload.burst_numBursts = burstTrain.numBursts;
payload.burst_numSpikesPerBurst = burstTrain.numSpikesPerBurst;

save(fullfile(fixtureRoot, 'nspiketrain_exactness.mat'), '-struct', 'payload');
end
Expand Down