diff --git a/.gitignore b/.gitignore index 2342bc42..bda6bad3 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,10 @@ docs/_build/ # Local report artifacts output/ tmp/ + +# Internal-only porting/parity notes (kept outside public repo) +CLEANROOM_POLICY.md +PARITY_SPEC.md +CANONICAL_VALIDATION_ARTIFACTS.md +DISCREPANCIES.md +parity/CYCLE_VALIDATION_CHECKLIST.md diff --git a/CANONICAL_VALIDATION_ARTIFACTS.md b/CANONICAL_VALIDATION_ARTIFACTS.md deleted file mode 100644 index de5f053f..00000000 --- a/CANONICAL_VALIDATION_ARTIFACTS.md +++ /dev/null @@ -1,43 +0,0 @@ -# Canonical Validation Artifacts - -This document records the canonical gate-mode validation artifact set and -reproduction command for `nSTAT-python` parity checks. - -## Pinned MATLAB reference -- Repository: `https://github.com/cajigaslab/nSTAT.git` -- Commit SHA: `470fde8f9f6b60fe8f9ec51155e34478b6d541f6` -- Config source: [`parity/matlab_reference.yml`](./parity/matlab_reference.yml) - -## Canonical local artifact set (latest) -Generated on: `2026-03-03` (America/New_York) - -- PDF: [`output/pdf/nstat_python_validation_report_20260303_232103.pdf`](./output/pdf/nstat_python_validation_report_20260303_232103.pdf) -- JSON: [`output/pdf/nstat_python_validation_report_20260303_232103.json`](./output/pdf/nstat_python_validation_report_20260303_232103.json) -- CSV: [`output/pdf/nstat_python_validation_report_20260303_232103.csv`](./output/pdf/nstat_python_validation_report_20260303_232103.csv) - -## Reproduction command -```bash -python tools/reports/generate_validation_pdf.py \ - --repo-root "$PWD" \ - --matlab-help-root /tmp/upstream-nstat/helpfiles \ - --notebook-group all \ - --timeout 900 \ - --skip-command-tests \ - --parity-mode gate \ - --enforce-unique-images \ - --min-unique-images-per-topic 1 \ - --max-cross-topic-reuse-ratio 1.0 -``` - -## CI canonical names -CI workflows normalize the latest gate-mode report to stable artifact names: - -- `output/pdf/validation_gate_mode_latest.pdf` -- 
`output/pdf/validation_gate_mode_latest.json` -- `output/pdf/validation_gate_mode_latest.csv` - -Image-mode parity artifacts are emitted under: - -- `output/pdf/image_mode_parity/summary.json` -- `output/pdf/image_mode_parity/pairs.json` -- `output/pdf/image_mode_parity/diff/` diff --git a/CLEANROOM_POLICY.md b/CLEANROOM_POLICY.md deleted file mode 100644 index d99e5fc4..00000000 --- a/CLEANROOM_POLICY.md +++ /dev/null @@ -1,13 +0,0 @@ -# Clean-Room Policy for nSTAT-python - -This repository is a clean-room Python implementation of nSTAT. - -## Rules -- No MATLAB runtime/build/test dependency is allowed. -- No code/docs/workflows/notebooks/config files are copied from MATLAB nSTAT. -- Only example data may be shared across repositories, and only when explicitly listed in the allowlist. - -## Enforcement -- CI runs a hash-overlap compliance job against `cajigaslab/nSTAT`. -- Non-data hash collisions fail the build. -- Shared data files must be explicitly listed in `tools/compliance/shared_data_allowlist.yml`. diff --git a/DISCREPANCIES.md b/DISCREPANCIES.md deleted file mode 100644 index aa9bab1f..00000000 --- a/DISCREPANCIES.md +++ /dev/null @@ -1,28 +0,0 @@ -# nSTAT-python Discrepancy Log - -This log tracks MATLAB-vs-Python parity issues with minimal repro details. 
- -| ID | Scope | Symptom | Minimal Repro | Suspected Cause | Status | Fix / PR | -|---|---|---|---|---|---|---| -| DSP-001 | `ExplicitStimulusWhiskerData` notebook | Strict line-port remained partial and notebook used synthetic stimulus instead of MATLAB gold fixture arrays | `python tools/parity/sync_parity_artifacts.py --matlab-root ` then inspect `parity/function_example_alignment_report.json` topic row | Notebook template had extra synthetic workflow lines and lacked fixture-backed assertion | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_explicit_stimulus_whisker_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | -| DSP-002 | `HybridFilterExample` notebook | Strict line-port partial | same as above | Python notebook contained extra simulation scaffolding and lacked MATLAB-fixture numeric assertions | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_hybrid_filter_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | -| DSP-003 | `ValidationDataSet` notebook | Strict line-port partial | same as above | Python workflow was synthetic-only and lacked MATLAB-gold fixture parity assertions | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_validation_dataset_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | -| DSP-004 | `PPSimExample` notebook | Strict line-port partial | same as above | Python execution cell had synthetic scaffolding and no direct MATLAB fixture comparison | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_ppsimexample_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | -| DSP-005 | `StimulusDecode2D` notebook | Strict line-port partial | same as above | Python workflow lacked MATLAB-gold 2D decode fixture metrics | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_stimulus_decode_2d_matlab_gold_comparison`; PR 
[#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | -| DSP-006 | `SignalObjExamples` notebook | Strict line-port partial and no standalone MATLAB-gold notebook assertion | same as above; run `pytest tests/test_parity_matlab_gold.py -k SignalObjExamples` | Notebook template and parity suite did not include deterministic SignalObj fixture metrics | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_signal_obj_examples_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | -| DSP-007 | `HistoryExamples` notebook | Strict line-port partial with missing fixture-backed numeric parity | same as above; run `pytest tests/test_parity_matlab_gold.py -k HistoryExamples` | Missing MATLAB-gold fixture export and parity assertion for history basis generation | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_history_examples_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | -| DSP-008 | `PPThinning` notebook | Strict line-port partial and missing thinning summary parity checks | same as above; run `pytest tests/test_parity_matlab_gold.py -k PPThinning` | Notebook lacked fixture-backed acceptance-rate and spike-count comparisons | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_ppthinning_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | -| DSP-009 | `HippocampalPlaceCellExample` notebook | Strict line-port partial due anchor drift beyond baseline snapshot rows | `python tools/parity/generate_equivalence_audit.py ...` and inspect strict line-port section | Snapshot ratio accounting only covered first 64 rows and omitted extended MATLAB anchors | Resolved | Regression: `tests/test_equivalence_audit_report.py::test_top_mismatch_topics_meet_line_port_regression_thresholds`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | -| DSP-010 | `NetworkTutorial` notebook | Strict line-port partial with no standalone 
MATLAB-gold fixture metrics | same as above; run `pytest tests/test_parity_matlab_gold.py -k NetworkTutorial` | Missing deterministic fixture assertions and incomplete MATLAB anchor coverage | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_network_tutorial_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | -| DSP-011 | `publish_all_helpfiles` notebook | Strict line-port partial from missing extended MATLAB publish anchors | `python tools/parity/generate_equivalence_audit.py ...` and inspect strict line-port section | Baseline snapshot excluded long-form publish steps needed for strict verification | Resolved | Regression: `tests/test_equivalence_audit_report.py::test_top_mismatch_topics_meet_line_port_regression_thresholds`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | -| DSP-012 | image-mode parity gate | New page-SSIM checker failed across platforms at an initial strict threshold (`SSIM>=0.80`) | `python tools/reports/check_pdf_image_parity.py --ssim-threshold 0.80 --max-failing-pages 0` | Renderer/font differences and page composition drift produce false negatives at aggressive threshold | Resolved | Regression: `.github/workflows/image-mode-parity.yml` + `tools/reports/check_pdf_image_parity.py`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | - -## Rules -- Every parity bug fix must include a regression test that would fail before the fix. -- Close an item only when: - - parity test(s) pass locally and in CI - - corresponding row in `parity/function_example_alignment_report.json` is updated - - PR/commit link is recorded. - -## Open discrepancies -- None at this time. New mismatches should be added with a minimal reproducible command and linked regression coverage. 
diff --git a/PARITY_SPEC.md b/PARITY_SPEC.md deleted file mode 100644 index bf484f0a..00000000 --- a/PARITY_SPEC.md +++ /dev/null @@ -1,93 +0,0 @@ -# nSTAT-to-nSTAT-python Parity Specification - -This document defines how `nSTAT-python` is measured against MATLAB `nSTAT`. - -## Gold Standard Baseline -- MATLAB reference repository: `/Users/iahncajigas/Library/CloudStorage/Dropbox/Research/Matlab/nSTAT_currentRelease_Local` -- Baseline lock file: `baseline/baseline_lock.yml` -- Frozen MATLAB example-data snapshot: `parity/matlab_gold_snapshot_20260302.yml` -- MATLAB commit hash is pinned in the lock file and must be updated intentionally. - -## Scope -- In scope: - - Core computational classes and workflows (Signal, spike trains, trial assembly, CIF/GLM fitting, decoding) - - Help pages and class/topic discoverability - - Executable notebooks that mirror MATLAB example workflows - - CI parity discovery and reporting -- Out of scope: - - MATLAB desktop help-browser internals - - Simulink-only integrations - - MATLAB documentation-only setup/reference examples that have no executable - computational workflow parity target in Python: - - `DocumentationSetup2025b` - - `FitResSummaryExamples` - - `FitResultExamples` - - `FitResultReference` - -## Parity Contract -1. Every in-scope MATLAB class has a Python implementation route through: - - Primary Python API (`nstat.*`), and/or - - Compatibility API (`nstat.compat.matlab.*`) -2. Example workflows in `parity/example_mapping.yaml` must have: - - A notebook in `notebooks/` - - A help page in `docs/help/examples/` -3. Class help pages must exist for all mapped MATLAB classes. -4. Numerical behavior is validated by tolerance-driven tests under `tests/parity/`. -5. 
Parity discovery produces machine-readable artifacts: - - `parity/matlab_api_inventory.json` - - `parity/python_api_inventory.json` - - `parity/parity_gap_report.json` - - `parity/method_probe_report.json` - - `parity/function_example_alignment_report.json` - -## Severity Model -- `high`: missing class implementation route, missing notebook/help artifact, missing mapped class help page. -- `medium`: missing mapped method/alias, metadata/TOC mismatch. -- `low`: informational or optional parity signal. - -CI enforces `--fail-on high` for parity discovery so missing critical artifacts block merges. - -## Source Files -- Method mapping: `parity/method_mapping.yaml` -- Example mapping: `parity/example_mapping.yaml` -- Discovery scripts: `tools/parity/` - -## Current Status (2026-03-02) -- Baseline lock refreshed: - - MATLAB commit: `1b5237b3176f6fc8aa3199d471e4bb7845a3ad5a` - - Python commit: `8b69adf11dc0ff340e416ce97ffc90eebc011c41` -- Latest structural parity snapshot (`parity/parity_gap_report.json`): - - `summary.high = 0` - - `summary.medium = 0` - - `summary.low = 0` -- Latest functional equivalence audit (`parity/function_example_alignment_report.json`): - - Method-level audit: - - `total_methods = 501` - - `contract_verified_methods = 480` - - `contract_explicit_verified_methods = 277` - - `probe_verified_methods = 203` - - `unverified_behavior_methods = 0` - - `missing_symbol_methods = 0` - - Example-level audit: - - `total_topics = 30` - - `pending_manual_review_topics = 0` - - `missing_artifact_topics = 0` - - `missing_executable_topics = 0` - - `matlab_doc_only_topics = 4` - - `validated_topics = 26` -- Updated visual validation report: - - `output/pdf/nstat_python_validation_report_20260302_145510.pdf` (all notebooks, gate mode) - -## Acceptance Checklist -- [x] Class and example inventory artifacts regenerate successfully. -- [x] High-severity parity issues remain at zero. -- [x] Full notebook suite is passing on the validated commit. 
-- [x] Visual validation PDF has been regenerated after parity changes. -- [x] Structural method-mapping gaps are closed (`parity/parity_gap_report.json`). -- [ ] Functional parity contracts cover all mapped methods (`parity/function_example_alignment_report.json` currently 480/501; 21 methods explicitly excluded by policy). -- [x] Example workflows complete line-by-line manual review and output-lock verification for in-scope topics (0 pending manual review). -- [x] Method-closure sprint backlog generated (`parity/method_closure_sprint.md`). - -## Notes -- This repository is a clean-room implementation. MATLAB code is a behavioral reference only. -- Runtime MATLAB dependency is prohibited for normal package use. diff --git a/README.md b/README.md index e593c464..e50b231c 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,9 @@ # nSTAT-python -`nSTAT-python` is a clean-room Python implementation of the nSTAT toolbox. +`nSTAT-python` is a Python toolbox for neural spike-train analysis, modeling, and decoding. 
[![test-and-build](https://github.com/cajigaslab/nSTAT-python/actions/workflows/ci.yml/badge.svg)](https://github.com/cajigaslab/nSTAT-python/actions/workflows/ci.yml) -[![parity-gate](https://github.com/cajigaslab/nSTAT-python/actions/workflows/parity-gate.yml/badge.svg)](https://github.com/cajigaslab/nSTAT-python/actions/workflows/parity-gate.yml) -[![performance-parity](https://github.com/cajigaslab/nSTAT-python/actions/workflows/performance-parity.yml/badge.svg)](https://github.com/cajigaslab/nSTAT-python/actions/workflows/performance-parity.yml) -[![image-mode-parity](https://github.com/cajigaslab/nSTAT-python/actions/workflows/image-mode-parity.yml/badge.svg)](https://github.com/cajigaslab/nSTAT-python/actions/workflows/image-mode-parity.yml) [![pages](https://github.com/cajigaslab/nSTAT-python/actions/workflows/pages.yml/badge.svg)](https://github.com/cajigaslab/nSTAT-python/actions/workflows/pages.yml) -[![validation-pdf](https://github.com/cajigaslab/nSTAT-python/actions/workflows/validation-pdf.yml/badge.svg)](https://github.com/cajigaslab/nSTAT-python/actions/workflows/validation-pdf.yml) - -## Design goals -- Zero MATLAB runtime dependency -- Class-structure parity with MATLAB nSTAT -- Python-native implementation and docs -- Searchable help pages on GitHub Pages -- Executable learning notebooks ## Installation @@ -32,7 +21,7 @@ python -m pip install -e .[dev,docs,notebooks] ## How to install nSTAT (post-install setup) -Run the Python-native setup helper `nstat_install` (no MATLAB required): +Run the setup helper: ```bash nstat-install @@ -44,179 +33,197 @@ Equivalent Python API: from nstat.install import nstat_install report = nstat_install() -print(report.cache_dir) ``` -## Quick start - -```python -import numpy as np -from nstat.signal import Covariate -from nstat.spikes import SpikeTrain - -t = np.linspace(0.0, 1.0, 1001) -x = np.sin(2 * np.pi * 5 * t) -cov = Covariate(time=t, data=x, name="stimulus", labels=["stim"]) -spikes = 
SpikeTrain(spike_times=np.array([0.11, 0.42, 0.77]), t_start=0.0, t_end=1.0) -print(cov.sample_rate_hz, spikes.firing_rate_hz()) -``` +## Examples -## Documentation and help pages -- Docs home: [cajigaslab.github.io/nSTAT-python](https://cajigaslab.github.io/nSTAT-python/) -- Help index: [cajigaslab.github.io/nSTAT-python/help](https://cajigaslab.github.io/nSTAT-python/help/) -- Canonical validation artifacts: [CANONICAL_VALIDATION_ARTIFACTS.md](./CANONICAL_VALIDATION_ARTIFACTS.md) +> These examples generate figures with `matplotlib` and save PNGs under `examples/readme_examples/images/`. +> The images below show the expected output. -## Data policy -Only example data may be shared with MATLAB nSTAT. All non-data files are unique to this repository. +Examples below require `matplotlib`: -## MATLAB Data Mirror +```bash +python -m pip install matplotlib +``` -Use the bundled workflow to mirror MATLAB toolbox example data into this repo with checksums: +### Example 1 — Multi-taper spectrum of a signal +Run: ```bash -python tools/data_mirror/run_mirror_workflow.py \ - --source-root /path/to/matlab/nSTAT/data \ - --version 20260302 \ - --clean +python examples/readme_examples/example1_multitaper_spectrum.py ``` -This command performs: -1. Source snapshot manifest generation. -2. Byte-for-byte mirrored copy into `data/shared/matlab_gold_/`. -3. Shared-data allowlist regeneration. -4. Dataset API manifest regeneration (`data/datasets_manifest.json`). -5. Strict checksum verification. 
+```python +import matplotlib +matplotlib.use("Agg") -To re-verify later: +from pathlib import Path -```bash -python tools/data_mirror/verify_matlab_data.py \ - --manifest data/shared/matlab_gold_20260302.manifest.json \ - --strict +import matplotlib.pyplot as plt +import numpy as np + +from nstat.compat.matlab import SignalObj + +rng = np.random.default_rng(0) +fs_hz = 1000.0 +dt = 1.0 / fs_hz +duration_s = 2.0 +time = np.arange(0.0, duration_s, dt, dtype=float) + +signal = ( + 1.0 * np.sin(2.0 * np.pi * 10.0 * time) + + 0.6 * np.sin(2.0 * np.pi * 40.0 * time + 0.3) + + 0.2 * np.sin(2.0 * np.pi * 75.0 * time) + + 0.12 * rng.standard_normal(time.size) +) + +sig_obj = SignalObj(time=time, data=signal, name="synthetic_signal", units="a.u.") +freq_hz, psd = sig_obj.MTMspectrum() + +fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(7.0, 5.0), sharex=False) +preview_mask = time <= 1.0 +ax1.plot(time[preview_mask], signal[preview_mask], color="black", linewidth=1.0) +ax1.set_xlabel("time (s)") +ax1.set_ylabel("amplitude") +ax1.set_title("Synthetic signal (first 1 s)") +ax2.plot(freq_hz, psd, color="tab:blue", linewidth=1.2) +ax2.set_xlim(0.0, 150.0) +ax2.set_xlabel("frequency (Hz)") +ax2.set_ylabel("power spectral density") +ax2.set_title("Multi-taper spectrum") +fig.tight_layout() + +out_dir = Path("examples/readme_examples/images") +out_dir.mkdir(parents=True, exist_ok=True) +fig.savefig(out_dir / "readme_example1_multitaper_spectrum.png", dpi=180) ``` -## MATLAB parity workflow +**Expected output** +![Multi-taper spectrum](examples/readme_examples/images/readme_example1_multitaper_spectrum.png) -Generate parity inventories and the machine-readable gap report: +### Example 2 — Simulate a spike train from a time-varying CIF +Run: ```bash -python tools/parity/build_parity_snapshot.py \ - --matlab-root /path/to/matlab/nSTAT \ - --fail-on high +python examples/readme_examples/example2_simulate_cif_spiketrain.py ``` -Artifacts are written to: -- 
`parity/matlab_api_inventory.json` -- `parity/python_api_inventory.json` -- `parity/parity_gap_report.json` -- `parity/TIER1_PORT_BACKLOG.md` +```python +import matplotlib +matplotlib.use("Agg") -Tier-1 progress gate: +from pathlib import Path -```bash -python tools/parity/check_tier1_progress.py \ - --report parity/parity_gap_report.json \ - --policy parity/tier1_gate_policy.yml +import matplotlib.pyplot as plt +import numpy as np + +from nstat.compat.matlab import CIF, Covariate + +np.random.seed(0) +dt = 0.001 +duration_s = 2.0 +time = np.arange(0.0, duration_s + 0.5 * dt, dt, dtype=float) + +lambda_t = 12.0 + 5.5 * np.sin(2.0 * np.pi * 2.0 * time) + 1.5 * np.cos(2.0 * np.pi * 6.0 * time) +lambda_t = np.clip(lambda_t, 0.1, None) + +lambda_cov = Covariate(time=time, data=lambda_t, name="Lambda(t)", units="spikes/s", labels=["lambda"]) +coll = CIF.simulateCIFByThinningFromLambda(lambda_cov, 1, dt) +spike_times = coll.getNST(0).spike_times + +fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(7.0, 5.0), sharex=True, gridspec_kw={"height_ratios": [2.0, 1.0]}) +ax1.plot(time, lambda_t, color="tab:blue", linewidth=1.4) +ax1.set_ylabel("rate (spikes/s)") +ax1.set_title("Time-varying CIF") +ax2.vlines(spike_times, 0.0, 1.0, color="black", linewidth=0.8) +ax2.set_ylim(0.0, 1.0) +ax2.set_xlabel("time (s)") +ax2.set_ylabel("spikes") +ax2.set_title("Simulated spike train") +fig.tight_layout() + +out_dir = Path("examples/readme_examples/images") +out_dir.mkdir(parents=True, exist_ok=True) +fig.savefig(out_dir / "readme_example2_simulate_cif_spiketrain.png", dpi=180) ``` -MATLAB-style adapters are available under: -- `nstat.compat.matlab` +**Expected output** +![CIF spike train simulation](examples/readme_examples/images/readme_example2_simulate_cif_spiketrain.png) -Sync parity artifacts (functional audit + sprint backlog + help dashboard/docs): +### Example 3 — Spike-train raster (collection, nstColl.plot) +Run: ```bash -python tools/parity/sync_parity_artifacts.py \ - 
--matlab-root /path/to/matlab/nSTAT +python examples/readme_examples/example3_spike_train_raster_nstcoll.py ``` -## RC Release Automation +```python +import matplotlib +matplotlib.use("Agg") -Use the GitHub Actions workflow `.github/workflows/release-rc.yml` to: -1. Rebuild parity artifacts and enforce gates. -2. Generate a full validation PDF. -3. Auto-generate RC release notes from parity reports. -4. Publish/update a pre-release with the latest PDF asset attached. +from pathlib import Path -You can trigger it from GitHub Actions (`release-rc`) with an input tag -like `v1.0.0-rc3`. +import matplotlib.pyplot as plt +import numpy as np -## Stable Release Promotion +from nstat.compat.matlab import CIF, Covariate -Use `.github/workflows/release-stable.yml` to promote a validated RC to a stable release. -The workflow: -1. Checks out the RC tag commit. -2. Runs hard checks (lint, typing, unit tests, docs build). -3. Runs parity and numeric-drift gates. -4. Regenerates the validation PDF. -5. Creates/pushes the stable tag and publishes a non-prerelease release. 
+np.random.seed(0) +dt = 0.001 +duration_s = 2.0 +n_units = 20 +time = np.arange(0.0, duration_s + 0.5 * dt, dt, dtype=float) -Inputs: -- `rc_tag` (for example `v1.0.0-rc3`) -- `stable_tag` (for example `v1.0.0`) +lambda_t = 9.0 + 4.0 * np.sin(2.0 * np.pi * 1.5 * time) + 2.0 * np.sin(2.0 * np.pi * 4.0 * time + 0.25) +lambda_t = np.clip(lambda_t, 0.1, None) -## PR-Native Parity Gate +lambda_cov = Covariate(time=time, data=lambda_t, name="Lambda(t)", units="spikes/s", labels=["lambda"]) +coll = CIF.simulateCIFByThinningFromLambda(lambda_cov, n_units, dt) -`.github/workflows/parity-gate.yml` runs on every pull request and enforces: -- Parity snapshot gate (`--fail-on medium`) -- Numeric drift thresholds -- Functional parity policy -- Example-output parity policy -- Synchronized parity artifacts (`parity/*`, `docs/help/*`, `docs/notebooks.md`, `baseline/help_mapping.json`) -- Full gate-mode validation PDF generation with canonical artifact names: - - `output/pdf/validation_gate_mode_latest.pdf` - - `output/pdf/validation_gate_mode_latest.json` - - `output/pdf/validation_gate_mode_latest.csv` +fig, ax = plt.subplots(figsize=(7.0, 4.5)) +plt.sca(ax) +coll.plot() +ax.set_xlabel("time (s)") +ax.set_ylabel("unit index") +ax.set_title("Spike-train raster (nstColl.plot)") +ax.set_ylim(0.5, n_units + 0.5) +fig.tight_layout() -## Function-Level Performance Parity +out_dir = Path("examples/readme_examples/images") +out_dir.mkdir(parents=True, exist_ok=True) +fig.savefig(out_dir / "readme_example3_spike_train_raster_nstcoll.png", dpi=180) +``` -Run deterministic Python workload benchmarks: +**Expected output** +![Spike train raster](examples/readme_examples/images/readme_example3_spike_train_raster_nstcoll.png) -```bash -python tools/performance/run_python_benchmarks.py \ - --tiers S,M,L \ - --repeats 7 \ - --warmup 2 \ - --out-json output/performance/python_performance_report.json \ - --out-csv output/performance/python_performance_report.csv -``` +## Examples and notebooks 
-Compare Python runtime/memory metrics against MATLAB baseline fixtures: +- Python scripts and notebooks: `notebooks/` +- Learning notebooks are executable and suitable for local exploration or CI smoke runs. -```bash -python tools/performance/compare_matlab_python_performance.py \ - --python-report output/performance/python_performance_report.json \ - --matlab-report tests/performance/fixtures/matlab/performance_baseline_470fde8.json \ - --policy parity/performance_gate_policy.yml \ - --previous-python-report tests/performance/fixtures/python/performance_baseline_linux_latest.json \ - --report-out parity/performance_parity_report.json \ - --csv-out parity/performance_parity_report.csv \ - --fail-on-regression \ - --require-regression-env-match -``` +## Documentation + +- Docs home: [cajigaslab.github.io/nSTAT-python](https://cajigaslab.github.io/nSTAT-python/) +- Help index: [cajigaslab.github.io/nSTAT-python/help](https://cajigaslab.github.io/nSTAT-python/help/) -Generate MATLAB baseline report (controlled environment): +## Developer notes + +- Run tests: ```bash -matlab -batch "addpath('matlab/benchmark'); run_matlab_performance_benchmarks( ... - 'tests/performance/fixtures/matlab/performance_baseline_470fde8.json', ... - 'tests/performance/fixtures/matlab/performance_baseline_470fde8.csv', ... - '/path/to/nSTAT')" +pytest -q ``` -## Branch Protection Automation +- Build docs: -To apply required checks on `main` (admin token required): +```bash +sphinx-build -b html docs docs/_build +``` -Current required checks on `main`: -- `unit-lint (3.11)` -- `unit-lint (3.12)` -- `docs-smoke-notebooks` -- `matlab-data-integrity` -- `cleanroom-compliance` -- `parity-checks` -- `build-validation-pdf` -- `image-mode-parity` -- `performance-parity` +## Cite -## Paper reference -Cajigas I, Malik WQ, Brown EN. nSTAT: Open-source neural spike train analysis toolbox for Matlab. *J Neurosci Methods* (2012), DOI: `10.1016/j.jneumeth.2012.08.009`, PMID: `22981419`. 
Cajigas, I., Malik, W. Q., & Brown, E. N. (2012). +nSTAT: Open-source neural spike train analysis toolbox for Matlab. +Journal of Neuroscience Methods, 211, 245–264. +https://doi.org/10.1016/j.jneumeth.2012.08.009 diff --git a/examples/readme_examples/example1_multitaper_spectrum.py b/examples/readme_examples/example1_multitaper_spectrum.py new file mode 100644 index 00000000..6a92194d --- /dev/null +++ b/examples/readme_examples/example1_multitaper_spectrum.py @@ -0,0 +1,53 @@ +import matplotlib +matplotlib.use("Agg") + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + +from nstat.compat.matlab import SignalObj + + +def main() -> None: + rng = np.random.default_rng(0) + fs_hz = 1000.0 + dt = 1.0 / fs_hz + duration_s = 2.0 + time = np.arange(0.0, duration_s, dt, dtype=float) + + signal = ( + 1.0 * np.sin(2.0 * np.pi * 10.0 * time) + + 0.6 * np.sin(2.0 * np.pi * 40.0 * time + 0.3) + + 0.2 * np.sin(2.0 * np.pi * 75.0 * time) + + 0.12 * rng.standard_normal(time.size) + ) + + sig_obj = SignalObj(time=time, data=signal, name="synthetic_signal", units="a.u.") + freq_hz, psd = sig_obj.MTMspectrum() + + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(7.0, 5.0), sharex=False) + + preview_mask = time <= 1.0 + ax1.plot(time[preview_mask], signal[preview_mask], color="black", linewidth=1.0) + ax1.set_xlabel("time (s)") + ax1.set_ylabel("amplitude") + ax1.set_title("Synthetic signal (first 1 s)") + + ax2.plot(freq_hz, psd, color="tab:blue", linewidth=1.2) + ax2.set_xlim(0.0, 150.0) + ax2.set_xlabel("frequency (Hz)") + ax2.set_ylabel("power spectral density") + ax2.set_title("Multi-taper spectrum") + + fig.tight_layout() + + out_dir = Path(__file__).resolve().parent / "images" + out_dir.mkdir(parents=True, exist_ok=True) + out_path = out_dir / "readme_example1_multitaper_spectrum.png" + fig.savefig(out_path, dpi=180) + plt.close(fig) + + +if __name__ == "__main__": + main() diff --git a/examples/readme_examples/example2_simulate_cif_spiketrain.py
b/examples/readme_examples/example2_simulate_cif_spiketrain.py new file mode 100644 index 00000000..7cb9d9f1 --- /dev/null +++ b/examples/readme_examples/example2_simulate_cif_spiketrain.py @@ -0,0 +1,65 @@ +import matplotlib +matplotlib.use("Agg") + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + +from nstat.compat.matlab import CIF, Covariate + + +def main() -> None: + np.random.seed(0) + + dt = 0.001 + duration_s = 2.0 + time = np.arange(0.0, duration_s + 0.5 * dt, dt, dtype=float) + + lambda_t = ( + 12.0 + + 5.5 * np.sin(2.0 * np.pi * 2.0 * time) + + 1.5 * np.cos(2.0 * np.pi * 6.0 * time) + ) + lambda_t = np.clip(lambda_t, 0.1, None) + + lambda_cov = Covariate( + time=time, + data=lambda_t, + name="Lambda(t)", + units="spikes/s", + labels=["lambda"], + ) + + coll = CIF.simulateCIFByThinningFromLambda(lambda_cov, 1, dt) + spike_times = coll.getNST(0).spike_times + + fig, (ax1, ax2) = plt.subplots( + 2, + 1, + figsize=(7.0, 5.0), + sharex=True, + gridspec_kw={"height_ratios": [2.0, 1.0]}, + ) + + ax1.plot(time, lambda_t, color="tab:blue", linewidth=1.4) + ax1.set_ylabel("rate (spikes/s)") + ax1.set_title("Time-varying CIF") + + ax2.vlines(spike_times, 0.0, 1.0, color="black", linewidth=0.8) + ax2.set_ylim(0.0, 1.0) + ax2.set_xlabel("time (s)") + ax2.set_ylabel("spikes") + ax2.set_title("Simulated spike train") + + fig.tight_layout() + + out_dir = Path(__file__).resolve().parent / "images" + out_dir.mkdir(parents=True, exist_ok=True) + out_path = out_dir / "readme_example2_simulate_cif_spiketrain.png" + fig.savefig(out_path, dpi=180) + plt.close(fig) + + +if __name__ == "__main__": + main() diff --git a/examples/readme_examples/example3_spike_train_raster_nstcoll.py b/examples/readme_examples/example3_spike_train_raster_nstcoll.py new file mode 100644 index 00000000..51e70a8c --- /dev/null +++ b/examples/readme_examples/example3_spike_train_raster_nstcoll.py @@ -0,0 +1,55 @@ +import matplotlib +matplotlib.use("Agg") + +from 
pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + +from nstat.compat.matlab import CIF, Covariate + + +def main() -> None: + np.random.seed(0) + + dt = 0.001 + duration_s = 2.0 + n_units = 20 + time = np.arange(0.0, duration_s + 0.5 * dt, dt, dtype=float) + + lambda_t = ( + 9.0 + + 4.0 * np.sin(2.0 * np.pi * 1.5 * time) + + 2.0 * np.sin(2.0 * np.pi * 4.0 * time + 0.25) + ) + lambda_t = np.clip(lambda_t, 0.1, None) + + lambda_cov = Covariate( + time=time, + data=lambda_t, + name="Lambda(t)", + units="spikes/s", + labels=["lambda"], + ) + + coll = CIF.simulateCIFByThinningFromLambda(lambda_cov, n_units, dt) + + fig, ax = plt.subplots(figsize=(7.0, 4.5)) + plt.sca(ax) + coll.plot() + ax.set_xlabel("time (s)") + ax.set_ylabel("unit index") + ax.set_title("Spike-train raster (nstColl.plot)") + ax.set_ylim(0.5, n_units + 0.5) + + fig.tight_layout() + + out_dir = Path(__file__).resolve().parent / "images" + out_dir.mkdir(parents=True, exist_ok=True) + out_path = out_dir / "readme_example3_spike_train_raster_nstcoll.png" + fig.savefig(out_path, dpi=180) + plt.close(fig) + + +if __name__ == "__main__": + main() diff --git a/examples/readme_examples/images/readme_example1_multitaper_spectrum.png b/examples/readme_examples/images/readme_example1_multitaper_spectrum.png new file mode 100644 index 00000000..f8ee297c Binary files /dev/null and b/examples/readme_examples/images/readme_example1_multitaper_spectrum.png differ diff --git a/examples/readme_examples/images/readme_example2_simulate_cif_spiketrain.png b/examples/readme_examples/images/readme_example2_simulate_cif_spiketrain.png new file mode 100644 index 00000000..d615713f Binary files /dev/null and b/examples/readme_examples/images/readme_example2_simulate_cif_spiketrain.png differ diff --git a/examples/readme_examples/images/readme_example3_spike_train_raster_nstcoll.png b/examples/readme_examples/images/readme_example3_spike_train_raster_nstcoll.png new file mode 100644 index 
00000000..3f99df40 Binary files /dev/null and b/examples/readme_examples/images/readme_example3_spike_train_raster_nstcoll.png differ diff --git a/parity/CYCLE_VALIDATION_CHECKLIST.md b/parity/CYCLE_VALIDATION_CHECKLIST.md deleted file mode 100644 index f15173bb..00000000 --- a/parity/CYCLE_VALIDATION_CHECKLIST.md +++ /dev/null @@ -1,47 +0,0 @@ -# Cycle Validation Checklist (2026-03-04) - -Commands used each cycle: -- `pytest -q` -- `python tools/parity/build_numeric_drift_report.py --fixtures-manifest tests/parity/fixtures/matlab_gold/manifest.yml --thresholds parity/numeric_drift_thresholds.yml --report-out parity/numeric_drift_report.json --fail-on-violation` -- `python tools/parity/check_functional_parity_progress.py --report parity/function_example_alignment_report.json --policy parity/functional_gate_policy.yml` -- `python tools/parity/check_example_output_spec.py --report parity/function_example_alignment_report.json --spec parity/example_output_spec.yml` -- `python tools/reports/generate_validation_pdf.py --repo-root "$PWD" --matlab-help-root /Users/iahncajigas/Library/CloudStorage/Dropbox/Research/Matlab/nSTAT_currentRelease_Local/helpfiles --notebook-group all --timeout 900 --skip-command-tests --parity-mode gate --enforce-unique-images --min-unique-images-per-topic 1 --max-cross-topic-reuse-ratio 1.0` -- `python tools/reports/generate_validation_pdf.py --repo-root "$PWD" --matlab-help-root /Users/iahncajigas/Library/CloudStorage/Dropbox/Research/Matlab/nSTAT_currentRelease_Local/helpfiles --notebook-group all --timeout 900 --skip-command-tests --parity-mode image --skip-parity-check` -- `python tools/reports/build_image_parity_pdfs.py --report-json --python-out output/pdf/image_mode_parity/python_pages.pdf --matlab-out output/pdf/image_mode_parity/matlab_pages.pdf --pairs-json output/pdf/image_mode_parity/pairs.json` -- `python tools/reports/check_pdf_image_parity.py --python-pdf output/pdf/image_mode_parity/python_pages.pdf --matlab-pdf 
output/pdf/image_mode_parity/matlab_pages.pdf --out-dir output/pdf/image_mode_parity --dpi 150 --ssim-threshold 0.70 --max-failing-pages 0` -- `python tools/performance/run_python_benchmarks.py --tiers S --repeats 5 --warmup 1 --out-json output/performance/python_performance_report.json --out-csv output/performance/python_performance_report.csv` -- `python tools/performance/compare_matlab_python_performance.py --python-report output/performance/python_performance_report.json --matlab-report tests/performance/fixtures/matlab/performance_baseline_470fde8.json --policy parity/performance_gate_policy.yml --previous-python-report tests/performance/fixtures/python/performance_baseline_linux_latest.json --report-out output/performance/performance_parity_report.json --csv-out output/performance/performance_parity_report.csv --fail-on-regression --require-regression-env-match` -- Local macOS reruns use `tests/performance/fixtures/python/performance_baseline_20260303.json` with the same command to satisfy strict env matching. - -## Cycle 1 -- Log: `output/cycle/cycle1.log` -- `pytest`: PASS -- numeric drift (0 failed topics): PASS -- functional parity (no gaps/partials): PASS -- example output spec: PASS -- gate-mode validation PDF (0 parity failures, 0 uniqueness violations): PASS -- image-mode parity (0 failing pages): PASS -- performance-parity (0 regression failures): PASS -- Fixes applied in cycle: comparator option to require regression env match + regression test coverage. 
- -## Cycle 2 -- Log: `output/cycle/cycle2.log` -- `pytest`: PASS -- numeric drift (0 failed topics): PASS -- functional parity (no gaps/partials): PASS -- example output spec: PASS -- gate-mode validation PDF (0 parity failures, 0 uniqueness violations): PASS -- image-mode parity (0 failing pages): PASS -- performance-parity (0 regression failures): PASS -- Fixes applied in cycle: Linux baseline + strict regression env matching in workflow/tests, decoding `computeSpikeRateCIs` vectorization, and added deterministic performance workloads for `nspikeTrain.getSigRep` and `Analysis.fitGLM`. - -## Cycle 3 -- Log: `output/cycle/cycle3.log` -- `pytest`: PASS -- numeric drift (0 failed topics): PASS -- functional parity (no gaps/partials): PASS -- example output spec: PASS -- gate-mode validation PDF (0 parity failures, 0 uniqueness violations): PASS -- image-mode parity (0 failing pages): PASS -- performance-parity (0 regression failures): PASS -- Fixes applied in cycle: none required; full acceptance suite rerun clean after Cycle 2 changes. 
diff --git a/src/nstat/compat/matlab/__init__.py b/src/nstat/compat/matlab/__init__.py index 09829dfb..f30f5674 100644 --- a/src/nstat/compat/matlab/__init__.py +++ b/src/nstat/compat/matlab/__init__.py @@ -3300,6 +3300,119 @@ def _build_unit_impulse_basis(numBasis: int, minTime: float, maxTime: float, del time = np.asarray(basis_sig.time, dtype=float).reshape(-1) return basis_mat, time + @staticmethod + def _draw_xk_samples_spec(xK_arr: np.ndarray, Wku_arr: np.ndarray, Mc: int, rng: np.random.Generator) -> np.ndarray: + # MATLAB mirror: for r=1:numBasis, for c=1:Mc, xKdraw(r,:,c)=xK(r,:)+chol(WkuTemp)*z + numBasis, K = xK_arr.shape + xK_draw = np.zeros((numBasis, K, int(Mc)), dtype=float) + for r in range(numBasis): + if Wku_arr.ndim == 4: + Wku_temp = np.asarray(Wku_arr[r, r, :, :], dtype=float) + elif Wku_arr.ndim == 3: + Wku_temp = np.asarray(Wku_arr[r, :, :], dtype=float) + elif Wku_arr.ndim == 2: + Wku_temp = np.asarray(Wku_arr, dtype=float) + else: + Wku_temp = np.asarray(0.0, dtype=float) + if Wku_temp.ndim == 0: + chol_m = np.diag(np.repeat(float(np.sqrt(max(Wku_temp.item(), 0.0))), K)) + else: + chol_m = DecodingAlgorithms._chol_like_matlab(Wku_temp) + if chol_m.shape != (K, K): + raise ValueError("Wku covariance slice must be KxK") + for c in range(int(Mc)): + z = rng.normal(0.0, 1.0, size=(K,)) + xK_draw[r, :, c] = xK_arr[r, :] + (chol_m @ z) + return xK_draw + + @staticmethod + def _draw_xk_samples_fast(xK_arr: np.ndarray, Wku_arr: np.ndarray, Mc: int, rng: np.random.Generator) -> np.ndarray: + # Fast equivalent of _draw_xk_samples_spec with identical RNG ordering. 
+ numBasis, K = xK_arr.shape + xK_draw = np.zeros((numBasis, K, int(Mc)), dtype=float) + for r in range(numBasis): + if Wku_arr.ndim == 4: + Wku_temp = np.asarray(Wku_arr[r, r, :, :], dtype=float) + elif Wku_arr.ndim == 3: + Wku_temp = np.asarray(Wku_arr[r, :, :], dtype=float) + elif Wku_arr.ndim == 2: + Wku_temp = np.asarray(Wku_arr, dtype=float) + else: + Wku_temp = np.asarray(0.0, dtype=float) + if Wku_temp.ndim == 0: + chol_m = np.diag(np.repeat(float(np.sqrt(max(Wku_temp.item(), 0.0))), K)) + else: + chol_m = DecodingAlgorithms._chol_like_matlab(Wku_temp) + if chol_m.shape != (K, K): + raise ValueError("Wku covariance slice must be KxK") + z_draw = rng.normal(0.0, 1.0, size=(int(Mc), K)) + xK_draw[r, :, :] = xK_arr[r, :][:, None] + (chol_m @ z_draw.T) + return xK_draw + + @staticmethod + def _compute_draw_rates_spec( + basis_mat: np.ndarray, + xK_draw: np.ndarray, + draw_index: int, + Hk: list[np.ndarray], + gamma_vec: np.ndarray, + fit_type: str, + delta: float, + ) -> np.ndarray: + # MATLAB mirror: for each draw c and trial k, evaluate lambda_k(t). 
+ K = xK_draw.shape[1] + n_time = basis_mat.shape[0] + rates = np.zeros((n_time, K), dtype=float) + for k in range(K): + stim_k = basis_mat @ xK_draw[:, k, draw_index] + hk = Hk[k] + cols = min(hk.shape[1], gamma_vec.size) + if cols > 0 and np.any(np.abs(gamma_vec[:cols]) > 0.0): + hist_lin = hk[:, :cols] @ gamma_vec[:cols] + else: + hist_lin = np.zeros(stim_k.shape[0], dtype=float) + eta = stim_k + hist_lin + if fit_type == "poisson": + lam = np.exp(eta) + else: + exp_eta = np.exp(eta) + lam = exp_eta / (1.0 + exp_eta) + rates[:, k] = lam / float(delta) + return rates + + @staticmethod + def _compute_draw_rates_fast( + basis_mat: np.ndarray, + xK_draw: np.ndarray, + draw_index: int, + hist_term: np.ndarray, + fit_type: str, + delta: float, + ) -> np.ndarray: + stim_ck = basis_mat @ xK_draw[:, :, draw_index] + eta = stim_ck + hist_term + if fit_type == "poisson": + rates = np.exp(eta) + else: + exp_eta = np.exp(eta) + rates = exp_eta / (1.0 + exp_eta) + return rates / float(delta) + + @staticmethod + def _compute_prob_mat_spec(spike_rate: np.ndarray, Mc: int) -> np.ndarray: + # MATLAB mirror: upper-triangle probability matrix P(rate_m > rate_k). 
+ K = spike_rate.shape[1] + prob_mat = np.zeros((K, K), dtype=float) + for k in range(K): + for m in range(k + 1, K): + prob_mat[k, m] = float(np.sum(spike_rate[:, m] > spike_rate[:, k])) / float(Mc) + return prob_mat + + @staticmethod + def _compute_prob_mat_fast(spike_rate: np.ndarray) -> np.ndarray: + prob_full = np.mean(spike_rate[:, None, :] > spike_rate[:, :, None], axis=0) + return np.triu(np.asarray(prob_full, dtype=float), k=1) + @staticmethod def _compute_spike_rate_cis_matlab( xK: np.ndarray, @@ -3313,7 +3426,10 @@ def _compute_spike_rate_cis_matlab( windowTimes: Any = None, Mc: int = 500, alphaVal: float = 0.05, + implementation: str = "fast", ) -> tuple[_Covariate, np.ndarray, np.ndarray]: + # MATLAB reference block: DecodingAlgorithms.computeSpikeRateCIs + # Keep a readable spec path; use fast helpers for CI/runtime workflows. xK_arr = np.asarray(xK, dtype=float) if xK_arr.ndim != 2: raise ValueError("xK must be 2D with shape (numBasis, K)") @@ -3330,7 +3446,11 @@ def _compute_spike_rate_cis_matlab( raise ValueError("alphaVal must be in (0, 1)") if int(Mc) <= 0: raise ValueError("Mc must be > 0") + impl = str(implementation).lower() + if impl not in {"spec", "fast"}: + raise ValueError("implementation must be 'spec' or 'fast'") + # MATLAB block: construct unit-impulse basis on [0, Tmax]. min_time = 0.0 max_time = float(dN_arr.shape[1] - 1) * float(delta) basis_mat, basis_time = DecodingAlgorithms._build_unit_impulse_basis(numBasis, min_time, max_time, float(delta)) @@ -3342,6 +3462,7 @@ def _compute_spike_rate_cis_matlab( basis_time = basis_time[: dN_arr.shape[1]] time = basis_time + # MATLAB block: build history design matrices H{k} when windowTimes provided. 
window_vals = np.asarray([] if windowTimes is None else windowTimes, dtype=float).reshape(-1) if window_vals.size > 0: if window_vals.size == 1: @@ -3365,27 +3486,13 @@ def _compute_spike_rate_cis_matlab( Hk = [np.zeros((dN_arr.shape[1], 1), dtype=float) for _ in range(K)] gamma_vec = np.zeros(1, dtype=float) + # MATLAB block: Monte Carlo coefficient draws xKdraw. Wku_arr = np.asarray(Wku, dtype=float) - xK_draw = np.zeros((numBasis, K, int(Mc)), dtype=float) rng = np.random.default_rng(0) - for r in range(numBasis): - if Wku_arr.ndim == 4: - Wku_temp = np.asarray(Wku_arr[r, r, :, :], dtype=float) - elif Wku_arr.ndim == 3: - Wku_temp = np.asarray(Wku_arr[r, :, :], dtype=float) - elif Wku_arr.ndim == 2: - Wku_temp = np.asarray(Wku_arr, dtype=float) - else: - Wku_temp = np.asarray(0.0, dtype=float) - if Wku_temp.ndim == 0: - chol_m = np.diag(np.repeat(float(np.sqrt(max(Wku_temp.item(), 0.0))), K)) - else: - chol_m = DecodingAlgorithms._chol_like_matlab(Wku_temp) - if chol_m.shape != (K, K): - raise ValueError("Wku covariance slice must be KxK") - # Preserve MATLAB-parity draw ordering by sampling (Mc, K), where each row is one Monte-Carlo draw. 
- z_draw = rng.normal(0.0, 1.0, size=(int(Mc), K)) - xK_draw[r, :, :] = xK_arr[r, :][:, None] + (chol_m @ z_draw.T) + if impl == "fast": + xK_draw = DecodingAlgorithms._draw_xk_samples_fast(xK_arr, Wku_arr, int(Mc), rng) + else: + xK_draw = DecodingAlgorithms._draw_xk_samples_spec(xK_arr, Wku_arr, int(Mc), rng) spike_rate = np.zeros((int(Mc), K), dtype=float) mask = (time >= float(t0)) & (time <= float(tf)) @@ -3394,7 +3501,8 @@ def _compute_spike_rate_cis_matlab( if integrate_fn is None: integrate_fn = getattr(np, "trapz", None) # pragma: no cover - NumPy<2 fallback - if window_vals.size > 0 and np.any(np.abs(gamma_vec) > 0.0): + use_history = window_vals.size > 0 and np.any(np.abs(gamma_vec) > 0.0) + if use_history and impl == "fast": hist_term = np.zeros((dN_arr.shape[1], K), dtype=float) for k in range(K): hk = Hk[k] @@ -3403,14 +3511,27 @@ def _compute_spike_rate_cis_matlab( else: hist_term = np.zeros((dN_arr.shape[1], K), dtype=float) + # MATLAB block: for each draw c, integrate trial rates over [t0, tf]. for c in range(int(Mc)): - stim_ck = basis_mat @ xK_draw[:, :, c] - eta = stim_ck + hist_term - if fit_type == "poisson": - rates = np.exp(eta) / float(delta) + if impl == "fast": + rates = DecodingAlgorithms._compute_draw_rates_fast( + basis_mat=basis_mat, + xK_draw=xK_draw, + draw_index=c, + hist_term=hist_term, + fit_type=fit_type, + delta=float(delta), + ) else: - exp_eta = np.exp(eta) - rates = (exp_eta / (1.0 + exp_eta)) / float(delta) + rates = DecodingAlgorithms._compute_draw_rates_spec( + basis_mat=basis_mat, + xK_draw=xK_draw, + draw_index=c, + Hk=Hk, + gamma_vec=gamma_vec, + fit_type=fit_type, + delta=float(delta), + ) if np.sum(mask) < 2: integral_vals = np.zeros(K, dtype=float) else: @@ -3435,6 +3556,7 @@ def _compute_spike_rate_cis_matlab( CIs[k, 0] = float(lo[-1]) if lo.size else float(vals[0]) CIs[k, 1] = float(hi[0]) if hi.size else float(vals[-1]) + # MATLAB block: emit Covariate with attached confidence interval. 
spike_rate_sig = Covariate( time=np.arange(1, K + 1, dtype=float), data=np.mean(spike_rate, axis=0), @@ -3452,9 +3574,10 @@ def _compute_spike_rate_cis_matlab( ci_obj.setColor("b") spike_rate_sig.setConfInterval(ci_obj) - # prob_mat(k,m) = P(rate_m > rate_k), with MATLAB-style upper-triangle usage. - prob_full = np.mean(spike_rate[:, None, :] > spike_rate[:, :, None], axis=0) - prob_mat = np.triu(np.asarray(prob_full, dtype=float), k=1) + if impl == "fast": + prob_mat = DecodingAlgorithms._compute_prob_mat_fast(spike_rate) + else: + prob_mat = DecodingAlgorithms._compute_prob_mat_spec(spike_rate, int(Mc)) sig_mat = (prob_mat > (1.0 - float(alphaVal))).astype(float) return spike_rate_sig, prob_mat, sig_mat