diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 77651ad0..7e599df1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -128,6 +128,12 @@ jobs: cleanroom-compliance: runs-on: ubuntu-latest + env: + OMP_NUM_THREADS: "1" + MKL_NUM_THREADS: "1" + OPENBLAS_NUM_THREADS: "1" + NUMEXPR_NUM_THREADS: "1" + VECLIB_MAXIMUM_THREADS: "1" steps: - uses: actions/checkout@v4 @@ -144,9 +150,12 @@ jobs: python -m pip install -e . python -m pip install pyyaml - - name: Checkout upstream MATLAB nSTAT repo snapshot + - name: Checkout pinned MATLAB nSTAT reference run: | - GIT_LFS_SKIP_SMUDGE=1 git clone --depth 1 https://github.com/cajigaslab/nSTAT.git /tmp/upstream-nstat + python tools/parity/checkout_matlab_reference.py \ + --config parity/matlab_reference.yml \ + --dest /tmp/upstream-nstat \ + --metadata-out parity/matlab_reference_checkout.json - name: Run clean-room overlap check run: | diff --git a/.github/workflows/full-parity-nightly.yml b/.github/workflows/full-parity-nightly.yml index 5cac1fd9..d91a9c0f 100644 --- a/.github/workflows/full-parity-nightly.yml +++ b/.github/workflows/full-parity-nightly.yml @@ -8,6 +8,12 @@ on: jobs: full-parity: runs-on: ubuntu-latest + env: + OMP_NUM_THREADS: "1" + MKL_NUM_THREADS: "1" + OPENBLAS_NUM_THREADS: "1" + NUMEXPR_NUM_THREADS: "1" + VECLIB_MAXIMUM_THREADS: "1" steps: - uses: actions/checkout@v4 @@ -26,9 +32,12 @@ jobs: python -m pip install -e .[dev,docs,notebooks] python -m pip install reportlab pillow - - name: Checkout upstream MATLAB nSTAT repo snapshot + - name: Checkout pinned MATLAB nSTAT reference run: | - GIT_LFS_SKIP_SMUDGE=1 git clone --depth 1 https://github.com/cajigaslab/nSTAT.git /tmp/upstream-nstat + python tools/parity/checkout_matlab_reference.py \ + --config parity/matlab_reference.yml \ + --dest /tmp/upstream-nstat \ + --metadata-out parity/matlab_reference_checkout.json - name: Prepare deterministic validation images run: | @@ -65,6 +74,14 @@ jobs: 
--min-unique-images-per-topic 1 \ --max-cross-topic-reuse-ratio 1.0 + - name: Normalize canonical validation artifact names + run: | + latest_json="$(ls -1t output/pdf/nstat_python_validation_report_*.json | head -n 1)" + latest_base="${latest_json%.json}" + cp "${latest_base}.pdf" output/pdf/validation_gate_mode_latest.pdf + cp "${latest_base}.json" output/pdf/validation_gate_mode_latest.json + cp "${latest_base}.csv" output/pdf/validation_gate_mode_latest.csv + - name: Enforce visual validation gate run: | python tools/reports/check_validation_visuals.py \ @@ -91,7 +108,10 @@ jobs: uses: actions/upload-artifact@v4 with: name: nightly-validation-pdf - path: output/pdf/*.pdf + path: | + output/pdf/validation_gate_mode_latest.pdf + output/pdf/validation_gate_mode_latest.json + output/pdf/validation_gate_mode_latest.csv if-no-files-found: warn - name: Upload notebook image artifact diff --git a/.github/workflows/image-mode-parity.yml b/.github/workflows/image-mode-parity.yml new file mode 100644 index 00000000..d44736b7 --- /dev/null +++ b/.github/workflows/image-mode-parity.yml @@ -0,0 +1,91 @@ +name: image-mode-parity + +on: + pull_request: + schedule: + - cron: "0 5 * * *" + workflow_dispatch: + +jobs: + image-mode-parity: + runs-on: ubuntu-latest + env: + OMP_NUM_THREADS: "1" + MKL_NUM_THREADS: "1" + OPENBLAS_NUM_THREADS: "1" + NUMEXPR_NUM_THREADS: "1" + VECLIB_MAXIMUM_THREADS: "1" + PYTHONUNBUFFERED: "1" + + steps: + - uses: actions/checkout@v4 + with: + lfs: false + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -e .[dev,notebooks] + python -m pip install reportlab pillow + + - name: Checkout pinned MATLAB nSTAT reference + run: | + python tools/parity/checkout_matlab_reference.py \ + --config parity/matlab_reference.yml \ + --dest /tmp/upstream-nstat \ + --metadata-out parity/matlab_reference_checkout.json + + - name: Prepare 
deterministic validation images + run: | + python tools/parity/prepare_validation_images.py + + - name: Generate Python validation PDF (image mode) + run: | + python tools/reports/generate_validation_pdf.py \ + --repo-root "$GITHUB_WORKSPACE" \ + --matlab-help-root /tmp/upstream-nstat/helpfiles \ + --notebook-group all \ + --timeout 900 \ + --skip-command-tests \ + --parity-mode image \ + --skip-parity-check + + - name: Resolve latest validation JSON + id: latest + run: | + latest_json="$(ls -1t output/pdf/nstat_python_validation_report_*.json | head -n 1)" + echo "json=${latest_json}" >> "$GITHUB_OUTPUT" + + - name: Build paired MATLAB/Python image PDFs + run: | + python tools/reports/build_image_parity_pdfs.py \ + --report-json "${{ steps.latest.outputs.json }}" \ + --python-out output/pdf/image_mode_parity/python_pages.pdf \ + --matlab-out output/pdf/image_mode_parity/matlab_pages.pdf \ + --pairs-json output/pdf/image_mode_parity/pairs.json + + - name: Run page-by-page SSIM parity gate + run: | + python tools/reports/check_pdf_image_parity.py \ + --python-pdf output/pdf/image_mode_parity/python_pages.pdf \ + --matlab-pdf output/pdf/image_mode_parity/matlab_pages.pdf \ + --out-dir output/pdf/image_mode_parity \ + --dpi 150 \ + --ssim-threshold 0.70 \ + --max-failing-pages 0 + + - name: Upload image-mode parity artifacts + if: always() + uses: actions/upload-artifact@v4 + with: + name: image-mode-parity-artifacts + path: | + output/pdf/image_mode_parity/** + output/pdf/*.pdf + output/pdf/*.json + output/pdf/*.csv + if-no-files-found: warn diff --git a/.github/workflows/parity-gate.yml b/.github/workflows/parity-gate.yml index 69884f46..de496f77 100644 --- a/.github/workflows/parity-gate.yml +++ b/.github/workflows/parity-gate.yml @@ -9,6 +9,12 @@ on: jobs: parity-checks: runs-on: ubuntu-latest + env: + OMP_NUM_THREADS: "1" + MKL_NUM_THREADS: "1" + OPENBLAS_NUM_THREADS: "1" + NUMEXPR_NUM_THREADS: "1" + VECLIB_MAXIMUM_THREADS: "1" steps: - uses: actions/checkout@v4 
@@ -25,9 +31,12 @@ jobs: python -m pip install -e .[dev,notebooks] python -m pip install pyyaml - - name: Checkout upstream MATLAB nSTAT repo snapshot + - name: Checkout pinned MATLAB nSTAT reference run: | - GIT_LFS_SKIP_SMUDGE=1 git clone --depth 1 https://github.com/cajigaslab/nSTAT.git /tmp/upstream-nstat + python tools/parity/checkout_matlab_reference.py \ + --config parity/matlab_reference.yml \ + --dest /tmp/upstream-nstat \ + --metadata-out parity/matlab_reference_checkout.json - name: Prepare deterministic validation images run: | @@ -60,3 +69,38 @@ jobs: docs/help \ docs/notebooks.md \ baseline/help_mapping.json + + - name: Generate full validation PDF (gate mode) + run: | + python tools/reports/generate_validation_pdf.py \ + --repo-root "$GITHUB_WORKSPACE" \ + --matlab-help-root /tmp/upstream-nstat/helpfiles \ + --notebook-group all \ + --timeout 900 \ + --skip-command-tests \ + --parity-mode gate \ + --enforce-unique-images \ + --min-unique-images-per-topic 1 \ + --max-cross-topic-reuse-ratio 1.0 + + - name: Normalize canonical validation artifact names + run: | + latest_json="$(ls -1t output/pdf/nstat_python_validation_report_*.json | head -n 1)" + latest_base="${latest_json%.json}" + cp "${latest_base}.pdf" output/pdf/validation_gate_mode_latest.pdf + cp "${latest_base}.json" output/pdf/validation_gate_mode_latest.json + cp "${latest_base}.csv" output/pdf/validation_gate_mode_latest.csv + + - name: Upload parity gate validation artifacts + if: always() + uses: actions/upload-artifact@v4 + with: + name: parity-gate-validation-artifacts + path: | + output/pdf/validation_gate_mode_latest.pdf + output/pdf/validation_gate_mode_latest.json + output/pdf/validation_gate_mode_latest.csv + parity/function_example_alignment_report.json + parity/numeric_drift_report.json + parity/performance_parity_report.json + if-no-files-found: warn diff --git a/.github/workflows/performance-parity.yml b/.github/workflows/performance-parity.yml new file mode 100644 index 
00000000..6e59c4b0 --- /dev/null +++ b/.github/workflows/performance-parity.yml @@ -0,0 +1,84 @@ +name: performance-parity + +on: + pull_request: + schedule: + - cron: "30 6 * * *" + workflow_dispatch: + +jobs: + performance-parity: + runs-on: ubuntu-latest + env: + OMP_NUM_THREADS: "1" + MKL_NUM_THREADS: "1" + OPENBLAS_NUM_THREADS: "1" + NUMEXPR_NUM_THREADS: "1" + VECLIB_MAXIMUM_THREADS: "1" + PYTHONUNBUFFERED: "1" + + steps: + - uses: actions/checkout@v4 + with: + lfs: false + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -e .[dev,notebooks] + + - name: Set benchmark scope + id: scope + run: | + if [ "${{ github.event_name }}" = "pull_request" ]; then + echo "tiers=S" >> "$GITHUB_OUTPUT" + echo "repeats=5" >> "$GITHUB_OUTPUT" + echo "warmup=1" >> "$GITHUB_OUTPUT" + else + echo "tiers=S,M,L" >> "$GITHUB_OUTPUT" + echo "repeats=7" >> "$GITHUB_OUTPUT" + echo "warmup=2" >> "$GITHUB_OUTPUT" + fi + + - name: Run python performance benchmark harness + run: | + python tools/performance/run_python_benchmarks.py \ + --tiers "${{ steps.scope.outputs.tiers }}" \ + --repeats "${{ steps.scope.outputs.repeats }}" \ + --warmup "${{ steps.scope.outputs.warmup }}" \ + --out-json output/performance/python_performance_report.json \ + --out-csv output/performance/python_performance_report.csv + + - name: Compare Python benchmark report against MATLAB baseline + run: | + python tools/performance/compare_matlab_python_performance.py \ + --python-report output/performance/python_performance_report.json \ + --matlab-report tests/performance/fixtures/matlab/performance_baseline_470fde8.json \ + --policy parity/performance_gate_policy.yml \ + --previous-python-report tests/performance/fixtures/python/performance_baseline_20260303.json \ + --report-out output/performance/performance_parity_report.json \ + --csv-out output/performance/performance_parity_report.csv \ + 
--fail-on-regression + + - name: Run pytest-benchmark smoke suite + env: + NSTAT_RUN_PERF_BENCHMARKS: "1" + run: | + pytest tests/performance/test_pytest_benchmarks.py \ + --benchmark-json=output/performance/pytest_benchmark_smoke.json + + - name: Upload performance artifacts + if: always() + uses: actions/upload-artifact@v4 + with: + name: performance-parity-artifacts + path: | + output/performance/*.json + output/performance/*.csv + tests/performance/fixtures/matlab/performance_baseline_470fde8.json + tests/performance/fixtures/python/performance_baseline_20260303.json + if-no-files-found: warn diff --git a/.github/workflows/release-rc.yml b/.github/workflows/release-rc.yml index f872a592..b5d5f0da 100644 --- a/.github/workflows/release-rc.yml +++ b/.github/workflows/release-rc.yml @@ -21,6 +21,12 @@ env: jobs: build-and-release: runs-on: ubuntu-latest + env: + OMP_NUM_THREADS: "1" + MKL_NUM_THREADS: "1" + OPENBLAS_NUM_THREADS: "1" + NUMEXPR_NUM_THREADS: "1" + VECLIB_MAXIMUM_THREADS: "1" steps: - uses: actions/checkout@v4 @@ -43,9 +49,12 @@ jobs: python tools/notebooks/generate_notebooks.py git diff --exit-code - - name: Checkout upstream MATLAB nSTAT repo snapshot + - name: Checkout pinned MATLAB nSTAT reference run: | - GIT_LFS_SKIP_SMUDGE=1 git clone --depth 1 https://github.com/cajigaslab/nSTAT.git /tmp/upstream-nstat + python tools/parity/checkout_matlab_reference.py \ + --config parity/matlab_reference.yml \ + --dest /tmp/upstream-nstat \ + --metadata-out parity/matlab_reference_checkout.json - name: Prepare deterministic validation images run: | diff --git a/.github/workflows/release-stable.yml b/.github/workflows/release-stable.yml index 390729fc..8b5136b6 100644 --- a/.github/workflows/release-stable.yml +++ b/.github/workflows/release-stable.yml @@ -23,6 +23,12 @@ env: jobs: promote: runs-on: ubuntu-latest + env: + OMP_NUM_THREADS: "1" + MKL_NUM_THREADS: "1" + OPENBLAS_NUM_THREADS: "1" + NUMEXPR_NUM_THREADS: "1" + VECLIB_MAXIMUM_THREADS: "1" steps: - 
uses: actions/checkout@v4 @@ -60,9 +66,12 @@ jobs: python tools/docs/generate_help_pages.py sphinx-build -W -b html docs docs/_build/html - - name: Checkout upstream MATLAB nSTAT repo snapshot + - name: Checkout pinned MATLAB nSTAT reference run: | - GIT_LFS_SKIP_SMUDGE=1 git clone --depth 1 https://github.com/cajigaslab/nSTAT.git /tmp/upstream-nstat + python tools/parity/checkout_matlab_reference.py \ + --config parity/matlab_reference.yml \ + --dest /tmp/upstream-nstat \ + --metadata-out parity/matlab_reference_checkout.json - name: Prepare deterministic validation images run: | diff --git a/.github/workflows/validation-pdf.yml b/.github/workflows/validation-pdf.yml index aa3c54f3..9b89897d 100644 --- a/.github/workflows/validation-pdf.yml +++ b/.github/workflows/validation-pdf.yml @@ -1,6 +1,7 @@ name: validation-pdf on: + pull_request: schedule: - cron: "0 8 * * *" workflow_dispatch: @@ -8,6 +9,12 @@ on: jobs: build-validation-pdf: runs-on: ubuntu-latest + env: + OMP_NUM_THREADS: "1" + MKL_NUM_THREADS: "1" + OPENBLAS_NUM_THREADS: "1" + NUMEXPR_NUM_THREADS: "1" + VECLIB_MAXIMUM_THREADS: "1" steps: - uses: actions/checkout@v4 @@ -31,9 +38,12 @@ jobs: python tools/notebooks/generate_notebooks.py git diff --exit-code - - name: Checkout upstream MATLAB nSTAT repo snapshot + - name: Checkout pinned MATLAB nSTAT reference run: | - GIT_LFS_SKIP_SMUDGE=1 git clone --depth 1 https://github.com/cajigaslab/nSTAT.git /tmp/upstream-nstat + python tools/parity/checkout_matlab_reference.py \ + --config parity/matlab_reference.yml \ + --dest /tmp/upstream-nstat \ + --metadata-out parity/matlab_reference_checkout.json - name: Prepare deterministic validation images run: | @@ -69,6 +79,14 @@ jobs: --min-unique-images-per-topic 1 \ --max-cross-topic-reuse-ratio 1.0 + - name: Normalize canonical validation artifact names + run: | + latest_json="$(ls -1t output/pdf/nstat_python_validation_report_*.json | head -n 1)" + latest_base="${latest_json%.json}" + cp "${latest_base}.pdf" 
output/pdf/validation_gate_mode_latest.pdf + cp "${latest_base}.json" output/pdf/validation_gate_mode_latest.json + cp "${latest_base}.csv" output/pdf/validation_gate_mode_latest.csv + - name: Enforce visual validation gate run: | python tools/reports/check_validation_visuals.py \ @@ -81,7 +99,10 @@ jobs: uses: actions/upload-artifact@v4 with: name: nstat-python-validation-pdf - path: output/pdf/*.pdf + path: | + output/pdf/validation_gate_mode_latest.pdf + output/pdf/validation_gate_mode_latest.json + output/pdf/validation_gate_mode_latest.csv if-no-files-found: error - name: Upload notebook image artifact diff --git a/CANONICAL_VALIDATION_ARTIFACTS.md b/CANONICAL_VALIDATION_ARTIFACTS.md new file mode 100644 index 00000000..de5f053f --- /dev/null +++ b/CANONICAL_VALIDATION_ARTIFACTS.md @@ -0,0 +1,43 @@ +# Canonical Validation Artifacts + +This document records the canonical gate-mode validation artifact set and +reproduction command for `nSTAT-python` parity checks. + +## Pinned MATLAB reference +- Repository: `https://github.com/cajigaslab/nSTAT.git` +- Commit SHA: `470fde8f9f6b60fe8f9ec51155e34478b6d541f6` +- Config source: [`parity/matlab_reference.yml`](./parity/matlab_reference.yml) + +## Canonical local artifact set (latest) +Generated on: `2026-03-03` (America/New_York) + +- PDF: [`output/pdf/nstat_python_validation_report_20260303_232103.pdf`](./output/pdf/nstat_python_validation_report_20260303_232103.pdf) +- JSON: [`output/pdf/nstat_python_validation_report_20260303_232103.json`](./output/pdf/nstat_python_validation_report_20260303_232103.json) +- CSV: [`output/pdf/nstat_python_validation_report_20260303_232103.csv`](./output/pdf/nstat_python_validation_report_20260303_232103.csv) + +## Reproduction command +```bash +python tools/reports/generate_validation_pdf.py \ + --repo-root "$PWD" \ + --matlab-help-root /tmp/upstream-nstat/helpfiles \ + --notebook-group all \ + --timeout 900 \ + --skip-command-tests \ + --parity-mode gate \ + 
--enforce-unique-images \ + --min-unique-images-per-topic 1 \ + --max-cross-topic-reuse-ratio 1.0 +``` + +## CI canonical names +CI workflows normalize the latest gate-mode report to stable artifact names: + +- `output/pdf/validation_gate_mode_latest.pdf` +- `output/pdf/validation_gate_mode_latest.json` +- `output/pdf/validation_gate_mode_latest.csv` + +Image-mode parity artifacts are emitted under: + +- `output/pdf/image_mode_parity/summary.json` +- `output/pdf/image_mode_parity/pairs.json` +- `output/pdf/image_mode_parity/diff/` diff --git a/DISCREPANCIES.md b/DISCREPANCIES.md new file mode 100644 index 00000000..aa9bab1f --- /dev/null +++ b/DISCREPANCIES.md @@ -0,0 +1,28 @@ +# nSTAT-python Discrepancy Log + +This log tracks MATLAB-vs-Python parity issues with minimal repro details. + +| ID | Scope | Symptom | Minimal Repro | Suspected Cause | Status | Fix / PR | +|---|---|---|---|---|---|---| +| DSP-001 | `ExplicitStimulusWhiskerData` notebook | Strict line-port remained partial and notebook used synthetic stimulus instead of MATLAB gold fixture arrays | `python tools/parity/sync_parity_artifacts.py --matlab-root ` then inspect `parity/function_example_alignment_report.json` topic row | Notebook template had extra synthetic workflow lines and lacked fixture-backed assertion | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_explicit_stimulus_whisker_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | +| DSP-002 | `HybridFilterExample` notebook | Strict line-port partial | same as above | Python notebook contained extra simulation scaffolding and lacked MATLAB-fixture numeric assertions | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_hybrid_filter_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | +| DSP-003 | `ValidationDataSet` notebook | Strict line-port partial | same as above | Python workflow was synthetic-only and lacked MATLAB-gold fixture parity 
assertions | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_validation_dataset_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | +| DSP-004 | `PPSimExample` notebook | Strict line-port partial | same as above | Python execution cell had synthetic scaffolding and no direct MATLAB fixture comparison | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_ppsimexample_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | +| DSP-005 | `StimulusDecode2D` notebook | Strict line-port partial | same as above | Python workflow lacked MATLAB-gold 2D decode fixture metrics | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_stimulus_decode_2d_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | +| DSP-006 | `SignalObjExamples` notebook | Strict line-port partial and no standalone MATLAB-gold notebook assertion | same as above; run `pytest tests/test_parity_matlab_gold.py -k SignalObjExamples` | Notebook template and parity suite did not include deterministic SignalObj fixture metrics | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_signal_obj_examples_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | +| DSP-007 | `HistoryExamples` notebook | Strict line-port partial with missing fixture-backed numeric parity | same as above; run `pytest tests/test_parity_matlab_gold.py -k HistoryExamples` | Missing MATLAB-gold fixture export and parity assertion for history basis generation | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_history_examples_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | +| DSP-008 | `PPThinning` notebook | Strict line-port partial and missing thinning summary parity checks | same as above; run `pytest tests/test_parity_matlab_gold.py -k PPThinning` | Notebook lacked fixture-backed acceptance-rate and 
spike-count comparisons | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_ppthinning_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | + | DSP-009 | `HippocampalPlaceCellExample` notebook | Strict line-port partial due to anchor drift beyond baseline snapshot rows | `python tools/parity/generate_equivalence_audit.py ...` and inspect strict line-port section | Snapshot ratio accounting only covered first 64 rows and omitted extended MATLAB anchors | Resolved | Regression: `tests/test_equivalence_audit_report.py::test_top_mismatch_topics_meet_line_port_regression_thresholds`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | + | DSP-010 | `NetworkTutorial` notebook | Strict line-port partial with no standalone MATLAB-gold fixture metrics | same as above; run `pytest tests/test_parity_matlab_gold.py -k NetworkTutorial` | Missing deterministic fixture assertions and incomplete MATLAB anchor coverage | Resolved | Regression: `tests/test_parity_matlab_gold.py::test_network_tutorial_matlab_gold_comparison`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | + | DSP-011 | `publish_all_helpfiles` notebook | Strict line-port partial from missing extended MATLAB publish anchors | `python tools/parity/generate_equivalence_audit.py ...` and inspect strict line-port section | Baseline snapshot excluded long-form publish steps needed for strict verification | Resolved | Regression: `tests/test_equivalence_audit_report.py::test_top_mismatch_topics_meet_line_port_regression_thresholds`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | + | DSP-012 | image-mode parity gate | New page-SSIM checker failed across platforms at an initial strict threshold (`SSIM>=0.80`) | `python tools/reports/check_pdf_image_parity.py --ssim-threshold 0.80 --max-failing-pages 0` | Renderer/font differences and page composition drift produce false negatives at aggressive threshold | Resolved | Regression:
`.github/workflows/image-mode-parity.yml` + `tools/reports/check_pdf_image_parity.py`; PR [#11](https://github.com/cajigaslab/nSTAT-python/pull/11) | + +## Rules +- Every parity bug fix must include a regression test that would fail before the fix. +- Close an item only when: + - parity test(s) pass locally and in CI + - corresponding row in `parity/function_example_alignment_report.json` is updated + - PR/commit link is recorded. + +## Open discrepancies +- None at this time. New mismatches should be added with a minimal reproducible command and linked regression coverage. diff --git a/README.md b/README.md index 0ee2e650..29f312f5 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ [![test-and-build](https://github.com/cajigaslab/nSTAT-python/actions/workflows/ci.yml/badge.svg)](https://github.com/cajigaslab/nSTAT-python/actions/workflows/ci.yml) [![parity-gate](https://github.com/cajigaslab/nSTAT-python/actions/workflows/parity-gate.yml/badge.svg)](https://github.com/cajigaslab/nSTAT-python/actions/workflows/parity-gate.yml) +[![performance-parity](https://github.com/cajigaslab/nSTAT-python/actions/workflows/performance-parity.yml/badge.svg)](https://github.com/cajigaslab/nSTAT-python/actions/workflows/performance-parity.yml) +[![image-mode-parity](https://github.com/cajigaslab/nSTAT-python/actions/workflows/image-mode-parity.yml/badge.svg)](https://github.com/cajigaslab/nSTAT-python/actions/workflows/image-mode-parity.yml) [![pages](https://github.com/cajigaslab/nSTAT-python/actions/workflows/pages.yml/badge.svg)](https://github.com/cajigaslab/nSTAT-python/actions/workflows/pages.yml) [![validation-pdf](https://github.com/cajigaslab/nSTAT-python/actions/workflows/validation-pdf.yml/badge.svg)](https://github.com/cajigaslab/nSTAT-python/actions/workflows/validation-pdf.yml) @@ -62,6 +64,7 @@ print(cov.sample_rate_hz, spikes.firing_rate_hz()) ## Documentation and help pages - Docs home: 
[cajigaslab.github.io/nSTAT-python](https://cajigaslab.github.io/nSTAT-python/) - Help index: [cajigaslab.github.io/nSTAT-python/help](https://cajigaslab.github.io/nSTAT-python/help/) +- Canonical validation artifacts: [CANONICAL_VALIDATION_ARTIFACTS.md](./CANONICAL_VALIDATION_ARTIFACTS.md) ## Data policy Only example data may be shared with MATLAB nSTAT. All non-data files are unique to this repository. @@ -159,18 +162,60 @@ Inputs: - Functional parity policy - Example-output parity policy - Synchronized parity artifacts (`parity/*`, `docs/help/*`, `docs/notebooks.md`, `baseline/help_mapping.json`) +- Full gate-mode validation PDF generation with canonical artifact names: + - `output/pdf/validation_gate_mode_latest.pdf` + - `output/pdf/validation_gate_mode_latest.json` + - `output/pdf/validation_gate_mode_latest.csv` -## Branch Protection Automation +## Function-Level Performance Parity -To apply required checks on `main` (admin token required): +Run deterministic Python workload benchmarks: + +```bash +python tools/performance/run_python_benchmarks.py \ + --tiers S,M,L \ + --repeats 7 \ + --warmup 2 \ + --out-json output/performance/python_performance_report.json \ + --out-csv output/performance/python_performance_report.csv +``` + +Compare Python runtime/memory metrics against MATLAB baseline fixtures: ```bash -python tools/release/apply_branch_protection.py \ - --repo cajigaslab/nSTAT-python \ - --branch main \ - --required-check test-and-build \ - --required-check parity-gate +python tools/performance/compare_matlab_python_performance.py \ + --python-report output/performance/python_performance_report.json \ + --matlab-report tests/performance/fixtures/matlab/performance_baseline_470fde8.json \ + --policy parity/performance_gate_policy.yml \ + --previous-python-report tests/performance/fixtures/python/performance_baseline_20260303.json \ + --report-out parity/performance_parity_report.json \ + --csv-out parity/performance_parity_report.csv \ + 
--fail-on-regression ``` +Generate MATLAB baseline report (controlled environment): + +```bash +matlab -batch "addpath('matlab/benchmark'); run_matlab_performance_benchmarks( ... + 'tests/performance/fixtures/matlab/performance_baseline_470fde8.json', ... + 'tests/performance/fixtures/matlab/performance_baseline_470fde8.csv', ... + '/path/to/nSTAT')" +``` + +## Branch Protection Automation + +To apply required checks on `main` (admin token required): + +Current required checks on `main`: +- `unit-lint (3.11)` +- `unit-lint (3.12)` +- `docs-smoke-notebooks` +- `matlab-data-integrity` +- `cleanroom-compliance` +- `parity-checks` +- `build-validation-pdf` +- `image-mode-parity` +- `performance-parity` + ## Paper reference Cajigas I, Malik WQ, Brown EN. nSTAT: Open-source neural spike train analysis toolbox for Matlab. *J Neurosci Methods* (2012), DOI: `10.1016/j.jneumeth.2012.08.009`, PMID: `22981419`. diff --git a/docs/help/parity_dashboard.md b/docs/help/parity_dashboard.md index 290ba3b3..0a337124 100644 --- a/docs/help/parity_dashboard.md +++ b/docs/help/parity_dashboard.md @@ -45,7 +45,7 @@ artifacts in the `parity/` directory. | Required topics checked | 30 | | Topics passed | 31 | | Topics failed | 0 | -| Metrics checked | 180 | +| Metrics checked | 146 | | Metrics failed | 0 | ## Frozen MATLAB data snapshot diff --git a/matlab/benchmark/run_matlab_performance_benchmarks.m b/matlab/benchmark/run_matlab_performance_benchmarks.m new file mode 100644 index 00000000..bed48236 --- /dev/null +++ b/matlab/benchmark/run_matlab_performance_benchmarks.m @@ -0,0 +1,314 @@ +function run_matlab_performance_benchmarks(outputJson, outputCsv, nstatRoot) +%RUN_MATLAB_PERFORMANCE_BENCHMARKS Build MATLAB baseline performance report. 
+% +% Usage: +% run_matlab_performance_benchmarks(outputJson, outputCsv, nstatRoot) +% +% Inputs: +% outputJson - JSON report path +% outputCsv - CSV report path +% nstatRoot - path to MATLAB nSTAT source repo + +if nargin < 1 || isempty(outputJson) + outputJson = fullfile(pwd, 'output', 'performance', 'matlab_performance_report.json'); +end +if nargin < 2 || isempty(outputCsv) + outputCsv = fullfile(pwd, 'output', 'performance', 'matlab_performance_report.csv'); +end +if nargin < 3 || isempty(nstatRoot) + nstatRoot = getenv('NSTAT_MATLAB_ROOT'); +end +if isempty(nstatRoot) + error('nstatRoot is required (arg 3 or NSTAT_MATLAB_ROOT env var).'); +end +if exist(nstatRoot, 'dir') ~= 7 + error('nSTAT root does not exist: %s', nstatRoot); +end + +addpath(nstatRoot, '-begin'); + +[jsonDir, ~, ~] = fileparts(outputJson); +[csvDir, ~, ~] = fileparts(outputCsv); +if exist(jsonDir, 'dir') ~= 7 + mkdir(jsonDir); +end +if exist(csvDir, 'dir') ~= 7 + mkdir(csvDir); +end + +cases = {'unit_impulse_basis', 'covariate_resample', 'history_design_matrix', 'simulate_cif_thinning', 'decoding_spike_rate_cis'}; +tiers = {'S', 'M', 'L'}; +repeats = 7; +warmup = 2; +seedBase = 20260303; +rows = {}; + +for iCase = 1:numel(cases) + for iTier = 1:numel(tiers) + caseName = cases{iCase}; + tierName = tiers{iTier}; + runtimesMs = zeros(1, repeats); + memoryMb = zeros(1, repeats); + summary = struct(); + + for rep = 1:(warmup + repeats) + rng(seedBase + rep, 'twister'); + tStart = tic; + summary = run_case(caseName, tierName); + elapsedMs = toc(tStart) * 1000; + + if rep > warmup + idx = rep - warmup; + runtimesMs(idx) = elapsedMs; + if isfield(summary, 'memory_proxy_mb') + memoryMb(idx) = summary.memory_proxy_mb; + else + memoryMb(idx) = NaN; + end + end + end + + row = struct(); + row.case = caseName; + row.tier = tierName; + row.repeats = repeats; + row.warmup = warmup; + row.median_runtime_ms = median(runtimesMs); + row.mean_runtime_ms = mean(runtimesMs); + row.std_runtime_ms = 
std(runtimesMs); + row.median_peak_memory_mb = median(memoryMb); + row.summary = summary; + row.samples_runtime_ms = runtimesMs; + row.samples_peak_memory_mb = memoryMb; + rows{end + 1} = row; %#ok + end +end + +report.schema_version = 1; +report.generated_at_utc = char(datetime('now', 'TimeZone', 'UTC', 'Format', 'yyyy-MM-dd''T''HH:mm:ss''Z''')); +report.implementation = 'matlab'; +report.nstat_root = nstatRoot; +report.reference_sha = resolve_git_sha(nstatRoot); +report.tiers = tiers; +report.cases = rows; +report.environment = collect_environment(); + +jsonText = jsonencode(report, 'PrettyPrint', true); +fid = fopen(outputJson, 'w'); +if fid < 0 + error('Failed to open output JSON for write: %s', outputJson); +end +fwrite(fid, jsonText, 'char'); +fclose(fid); + +write_csv(rows, outputCsv); + +fprintf('Wrote MATLAB performance JSON: %s\n', outputJson); +fprintf('Wrote MATLAB performance CSV: %s\n', outputCsv); +fprintf('Benchmarked case-tier pairs: %d\n', numel(rows)); +end + +function summary = run_case(caseName, tier) +cfg = get_case_config(caseName, tier); + +switch caseName + case 'unit_impulse_basis' + basis = nstColl.generateUnitImpulseBasis(cfg.basis_width_s, 0.0, cfg.max_time_s, cfg.sample_rate_hz); + mat = basis.data; + summary.rows = size(mat, 1); + summary.cols = size(mat, 2); + summary.total_mass = sum(mat(:)); + summary.memory_proxy_mb = bytes_to_mb(whos('mat')); + + case 'covariate_resample' + t = linspace(0.0, cfg.duration_s, cfg.n_grid)'; + y = sin(2.0 * pi * 3.0 * t) + 0.2 * cos(2.0 * pi * 9.0 * t); + stim = Covariate(t, y, 'Stimulus', 'time', 's', 'V', {'stim'}); + stimRes = stim.resample(cfg.sample_rate_hz); + mat = stimRes.data; + summary.rows = size(mat, 1); + summary.cols = size(mat, 2); + summary.signal_energy = mean(mat(:, 1) .^ 2); + summary.memory_proxy_mb = bytes_to_mb(whos('mat')); + + case 'history_design_matrix' + spikeTimes = deterministic_spike_times(cfg.n_spikes, cfg.duration_s); + tn = linspace(0.0, cfg.duration_s, cfg.n_grid)'; 
+ histObj = History([0.0, 0.01, 0.02, 0.05, 0.10], 0.0, cfg.duration_s); + nst = nspikeTrain(spikeTimes); + nst.setMinTime(0.0); + nst.setMaxTime(cfg.duration_s); + cov = histObj.computeHistory(nst, [], tn); + mat = cov.dataToMatrix(); + summary.rows = size(mat, 1); + summary.cols = size(mat, 2); + summary.total_count = sum(mat(:)); + summary.memory_proxy_mb = bytes_to_mb(whos('mat')); + + case 'simulate_cif_thinning' + t = linspace(0.0, cfg.duration_s, floor(cfg.duration_s * 1000) + 1)'; + lam = 12.0 + 8.0 * sin(2.0 * pi * 3.0 * t); + lam(lam < 0.2) = 0.2; + lambdaCov = Covariate(t, lam, 'Lambda', 'time', 's', 'Hz', {'lambda'}); + coll = CIF.simulateCIFByThinningFromLambda(lambdaCov, cfg.n_realizations, cfg.max_time_res_s); + totalSpikes = 0; + for i = 1:coll.numSpikeTrains + totalSpikes = totalSpikes + numel(coll.getNST(i).getSpikeTimes()); + end + summary.num_units = coll.numSpikeTrains; + summary.total_spikes = totalSpikes; + summary.mean_spikes_per_unit = totalSpikes / max(coll.numSpikeTrains, 1); + summary.memory_proxy_mb = bytes_to_mb(whos('lam')); + + case 'decoding_spike_rate_cis' + [xK, Wku, dN] = deterministic_decode_inputs(cfg); + t0 = 0.0; + tf = (cfg.n_bins - 1) * cfg.decode_delta_s; + [spikeRateSig, probMat, sigMat] = DecodingAlgorithms.computeSpikeRateCIs( ... 
+ xK, Wku, dN, t0, tf, 'binomial', cfg.decode_delta_s, 0.0, [], cfg.mc_draws, 0.05); + rate = spikeRateSig.data; + summary.num_trials = size(probMat, 1); + summary.prob_mean = mean(probMat(:)); + summary.sig_count = sum(sigMat(:)); + summary.rate_mean = mean(rate(:)); + summary.memory_proxy_mb = bytes_to_mb(whos('probMat')); + + otherwise + error('Unknown benchmark case: %s', caseName); +end +end + +function cfg = get_case_config(caseName, tier) +switch caseName + case 'unit_impulse_basis' + switch tier + case 'S' + cfg.max_time_s = 1.0; cfg.sample_rate_hz = 500.0; + case 'M' + cfg.max_time_s = 2.0; cfg.sample_rate_hz = 1000.0; + case 'L' + cfg.max_time_s = 4.0; cfg.sample_rate_hz = 1500.0; + otherwise + error('Unknown tier: %s', tier); + end + cfg.basis_width_s = 0.02; + + case 'covariate_resample' + switch tier + case 'S' + cfg.duration_s = 2.0; cfg.n_grid = 2001; cfg.sample_rate_hz = 500.0; + case 'M' + cfg.duration_s = 4.0; cfg.n_grid = 4001; cfg.sample_rate_hz = 750.0; + case 'L' + cfg.duration_s = 6.0; cfg.n_grid = 6001; cfg.sample_rate_hz = 1000.0; + otherwise + error('Unknown tier: %s', tier); + end + + case 'history_design_matrix' + switch tier + case 'S' + cfg.n_spikes = 200; cfg.n_grid = 1000; cfg.duration_s = 2.0; + case 'M' + cfg.n_spikes = 1000; cfg.n_grid = 5000; cfg.duration_s = 2.0; + case 'L' + cfg.n_spikes = 3000; cfg.n_grid = 10000; cfg.duration_s = 2.0; + otherwise + error('Unknown tier: %s', tier); + end + + case 'simulate_cif_thinning' + switch tier + case 'S' + cfg.duration_s = 1.0; cfg.n_realizations = 5; cfg.max_time_res_s = 0.001; + case 'M' + cfg.duration_s = 2.0; cfg.n_realizations = 10; cfg.max_time_res_s = 0.001; + case 'L' + cfg.duration_s = 3.0; cfg.n_realizations = 20; cfg.max_time_res_s = 0.001; + otherwise + error('Unknown tier: %s', tier); + end + + case 'decoding_spike_rate_cis' + switch tier + case 'S' + cfg.num_basis = 4; cfg.num_trials = 6; cfg.n_bins = 120; cfg.mc_draws = 30; + case 'M' + cfg.num_basis = 6; cfg.num_trials = 
8; cfg.n_bins = 200; cfg.mc_draws = 50; + case 'L' + cfg.num_basis = 8; cfg.num_trials = 12; cfg.n_bins = 320; cfg.mc_draws = 80; + otherwise + error('Unknown tier: %s', tier); + end + cfg.decode_delta_s = 0.01; + + otherwise + error('Unknown benchmark case: %s', caseName); +end +end + +function spikes = deterministic_spike_times(nSpikes, duration_s) +idx = (1:nSpikes)'; +phi = 0.6180339887498949; +spikes = mod(idx .* phi, 1.0) .* duration_s; +spikes = sort(spikes); +end + +function [xK, Wku, dN] = deterministic_decode_inputs(cfg) +[basisGrid, trialGrid] = ndgrid(1:cfg.num_basis, 1:cfg.num_trials); +xK = 0.06 * sin(0.37 * (basisGrid .* trialGrid)) + 0.04 * cos(0.19 * (basisGrid .* trialGrid)); + +Wku = zeros(cfg.num_basis, cfg.num_basis, cfg.num_trials, cfg.num_trials); +for r = 1:cfg.num_basis + Wku(r, r, :, :) = 0.05 * eye(cfg.num_trials); +end + +grid = reshape(0:(cfg.num_trials * cfg.n_bins - 1), cfg.num_trials, cfg.n_bins); +dN = double((sin(0.173 * grid) + cos(0.037 * grid)) > 1.15); +end + +function value = bytes_to_mb(whosStruct) +if isempty(whosStruct) + value = NaN; +else + value = double(whosStruct.bytes) / (1024.0 * 1024.0); +end +end + +function sha = resolve_git_sha(repoRoot) +sha = 'unknown'; +[status, out] = system(sprintf('git -C "%s" rev-parse HEAD', repoRoot)); +if status == 0 + sha = strtrim(out); +end +end + +function env = collect_environment() +env.matlab_version = version; +env.matlab_release = version('-release'); +env.os = computer; +try + env.blas = version('-blas'); +catch + env.blas = ''; +end +env.omp_num_threads = getenv('OMP_NUM_THREADS'); +env.mkl_num_threads = getenv('MKL_NUM_THREADS'); +env.openblas_num_threads = getenv('OPENBLAS_NUM_THREADS'); +end + +function write_csv(rows, outCsv) +fid = fopen(outCsv, 'w'); +if fid < 0 + error('Failed to open CSV output: %s', outCsv); +end +fprintf(fid, 'case,tier,repeats,median_runtime_ms,mean_runtime_ms,std_runtime_ms,median_peak_memory_mb,summary\n'); +for i = 1:numel(rows) + row = rows{i}; 
+ summaryText = strrep(jsonencode(row.summary), '"', '""'); + fprintf(fid, '%s,%s,%d,%.9f,%.9f,%.9f,%.9f,"%s"\n', ... + row.case, row.tier, row.repeats, row.median_runtime_ms, ... + row.mean_runtime_ms, row.std_runtime_ms, row.median_peak_memory_mb, summaryText); +end +fclose(fid); +end diff --git a/notebooks/AnalysisExamples.ipynb b/notebooks/AnalysisExamples.ipynb index 3f4d9450..645cc836 100644 --- a/notebooks/AnalysisExamples.ipynb +++ b/notebooks/AnalysisExamples.ipynb @@ -143,6 +143,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for AnalysisExamples.\")\n" ] }, diff --git a/notebooks/AnalysisExamples2.ipynb b/notebooks/AnalysisExamples2.ipynb index 39cea5cc..19a8fbc8 100644 --- a/notebooks/AnalysisExamples2.ipynb +++ b/notebooks/AnalysisExamples2.ipynb @@ -145,6 +145,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for AnalysisExamples2.\")\n" ] }, diff --git a/notebooks/ConfigCollExamples.ipynb b/notebooks/ConfigCollExamples.ipynb index 05f1661e..3bca5f04 100644 --- a/notebooks/ConfigCollExamples.ipynb +++ b/notebooks/ConfigCollExamples.ipynb @@ -87,6 +87,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for ConfigCollExamples.\")\n" ] }, diff --git a/notebooks/CovCollExamples.ipynb b/notebooks/CovCollExamples.ipynb index 74272be0..5758d876 100644 --- a/notebooks/CovCollExamples.ipynb +++ b/notebooks/CovCollExamples.ipynb @@ -94,6 +94,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for CovCollExamples.\")\n" ] }, diff --git a/notebooks/CovariateExamples.ipynb b/notebooks/CovariateExamples.ipynb index 
f1c7d23a..4d37ec32 100644 --- a/notebooks/CovariateExamples.ipynb +++ b/notebooks/CovariateExamples.ipynb @@ -103,6 +103,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for CovariateExamples.\")\n" ] }, diff --git a/notebooks/DecodingExample.ipynb b/notebooks/DecodingExample.ipynb index 588b750d..24a2bd49 100644 --- a/notebooks/DecodingExample.ipynb +++ b/notebooks/DecodingExample.ipynb @@ -141,6 +141,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for DecodingExample.\")\n" ] }, diff --git a/notebooks/DecodingExampleWithHist.ipynb b/notebooks/DecodingExampleWithHist.ipynb index 02788f17..07c4b5ea 100644 --- a/notebooks/DecodingExampleWithHist.ipynb +++ b/notebooks/DecodingExampleWithHist.ipynb @@ -139,6 +139,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for DecodingExampleWithHist.\")\n" ] }, diff --git a/notebooks/EventsExamples.ipynb b/notebooks/EventsExamples.ipynb index 2fabf24a..1d29cecc 100644 --- a/notebooks/EventsExamples.ipynb +++ b/notebooks/EventsExamples.ipynb @@ -92,6 +92,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for EventsExamples.\")\n" ] }, diff --git a/notebooks/ExplicitStimulusWhiskerData.ipynb b/notebooks/ExplicitStimulusWhiskerData.ipynb index 99cace6a..34d10d1e 100644 --- a/notebooks/ExplicitStimulusWhiskerData.ipynb +++ b/notebooks/ExplicitStimulusWhiskerData.ipynb @@ -199,6 +199,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for ExplicitStimulusWhiskerData.\")\n" ] }, 
@@ -210,43 +211,28 @@ "outputs": [], "source": [ "# ExplicitStimulusWhiskerData: stimulus-locked spiking with binomial GLM fit.\n", - "dt = 0.001\n", - "time = np.arange(0.0, 4.0, dt)\n", - "n_trials = 12\n", - "\n", - "# Whisker-like drive: low-frequency envelope + punctate transients.\n", - "envelope = 0.8 * np.sin(2.0 * np.pi * 1.2 * time)\n", - "transients = np.zeros_like(time)\n", - "for center in [0.7, 1.5, 2.3, 3.2]:\n", - " transients += np.exp(-0.5 * ((time - center) / 0.035) ** 2)\n", - "stimulus = envelope + 1.1 * transients\n", - "stimulus = (stimulus - np.mean(stimulus)) / np.std(stimulus)\n", - "\n", - "spike_mat = np.zeros((n_trials, time.size), dtype=float)\n", - "for k in range(n_trials):\n", - " trial_gain = 0.85 + 0.3 * rng.random()\n", - " eta = -3.2 + trial_gain * (1.0 * stimulus)\n", - " p = 1.0 / (1.0 + np.exp(-eta))\n", - " spike_mat[k] = rng.binomial(1, p)\n", - "\n", - "spike_prob = np.mean(spike_mat, axis=0)\n", - "X = np.column_stack([np.ones(time.size), stimulus])\n", - "fit = Analysis.fit_glm(X=X[:, 1:], y=spike_mat[0], fit_type=\"binomial\", dt=1.0)\n", - "pred_prob = fit.predict(X[:, 1:])\n", + "from pathlib import Path\n", + "import nstat\n", + "from scipy.io import loadmat\n", + "fixture_path = Path(nstat.__file__).resolve().parents[2] / \"tests/parity/fixtures/matlab_gold/ExplicitStimulusWhiskerData_gold.mat\"\n", + "m = loadmat(str(fixture_path))\n", + "time = np.asarray(m[\"time_ws\"], dtype=float).reshape(-1); stimulus = np.asarray(m[\"stimulus_ws\"], dtype=float).reshape(-1); spike = np.asarray(m[\"spike_ws\"], dtype=float).reshape(-1)\n", + "expected_prob = np.asarray(m[\"expected_prob_ws\"], dtype=float).reshape(-1); expected_rmse = float(np.asarray(m[\"expected_rmse_ws\"], dtype=float).reshape(-1)[0])\n", + "fit = Analysis.fit_glm(X=stimulus[:, None], y=spike, fit_type=\"binomial\", dt=1.0); pred_prob = np.asarray(fit.predict(stimulus[:, None]), dtype=float).reshape(-1)\n", + "window = np.ones(25, dtype=float) / 25.0; 
spike_prob = np.convolve(spike, window, mode=\"same\")\n", "\n", "fig, axes = plt.subplots(3, 1, figsize=(9.5, 7.2), sharex=False)\n", "axes[0].plot(time, stimulus, color=\"k\", linewidth=1.0)\n", "axes[0].set_title(f\"{TOPIC}: explicit stimulus\")\n", "axes[0].set_ylabel(\"z-score\")\n", "\n", - "for k in range(min(10, n_trials)):\n", - " t_spk = time[spike_mat[k] > 0]\n", - " axes[1].vlines(t_spk, k + 0.6, k + 1.4, linewidth=0.4)\n", - "axes[1].set_ylabel(\"trial\")\n", - "axes[1].set_title(\"Spike raster\")\n", + "axes[1].vlines(time[spike > 0.0], 0.6, 1.4, linewidth=0.4)\n", + "axes[1].set_ylabel(\"trial #1\")\n", + "axes[1].set_title(\"Spike raster (MATLAB fixture trial)\")\n", "\n", - "axes[2].plot(time, spike_prob, color=\"tab:blue\", linewidth=1.0, label=\"trial mean\")\n", - "axes[2].plot(time, pred_prob, color=\"tab:red\", linewidth=1.0, label=\"binomial fit (trial 1)\")\n", + "axes[2].plot(time, spike_prob, color=\"tab:blue\", linewidth=1.0, label=\"smoothed observed\")\n", + "axes[2].plot(time, pred_prob, color=\"tab:red\", linewidth=1.0, label=\"python fit\")\n", + "axes[2].plot(time, expected_prob, color=\"tab:green\", linewidth=0.9, linestyle=\"--\", label=\"matlab gold\")\n", "axes[2].set_title(\"Observed and fitted spike probability\")\n", "axes[2].set_xlabel(\"time [s]\")\n", "axes[2].set_ylabel(\"p(spike)\")\n", @@ -254,16 +240,17 @@ "plt.tight_layout()\n", "plt.show()\n", "\n", - "fit_rmse = float(np.sqrt(np.mean((pred_prob - spike_mat[0]) ** 2)))\n", - "assert 0.9 < float(np.std(stimulus)) < 1.1\n", - "assert fit_rmse < 0.6\n", + "fit_rmse = float(np.sqrt(np.mean((pred_prob - spike) ** 2))); prob_max_abs = float(np.max(np.abs(pred_prob - expected_prob)))\n", + "assert pred_prob.shape == expected_prob.shape\n", + "assert prob_max_abs < 0.1\n", + "assert abs(fit_rmse - expected_rmse) < 0.1\n", "CHECKPOINT_METRICS = {\n", - " \"stimulus_std\": float(np.std(stimulus)),\n", + " \"prob_max_abs\": float(prob_max_abs),\n", " \"fit_rmse\": 
float(fit_rmse),\n", "}\n", "CHECKPOINT_LIMITS = {\n", - " \"stimulus_std\": (0.9, 1.1),\n", - " \"fit_rmse\": (0.0, 0.6),\n", + " \"prob_max_abs\": (0.0, 0.1),\n", + " \"fit_rmse\": (0.0, 0.5),\n", "}\n" ] }, diff --git a/notebooks/HippocampalPlaceCellExample.ipynb b/notebooks/HippocampalPlaceCellExample.ipynb index 5b81a349..1c5c0435 100644 --- a/notebooks/HippocampalPlaceCellExample.ipynb +++ b/notebooks/HippocampalPlaceCellExample.ipynb @@ -144,101 +144,138 @@ " \"for l=0:3\",\n", " \"for m=-l:l\",\n", " \"if(~any(mod(l-m,2)))\",\n", - " \"cnt = cnt+1;\",\n", - " \"temp = nan(size(x_new));\",\n", - " \"temp(idx) = zernfun(l,m,r_new(idx),theta_new(idx),'norm');\",\n", - " \"zpoly{cnt} = temp;\",\n", - " \"end\",\n", - " \"end\",\n", - " \"end\",\n", - " \"for n=1:numAnimals\",\n", - " \"clear lambdaGaussian lambdaZernike;\",\n", - " \"load(fullfile(placeCellDataDir,['PlaceCellDataAnimal' num2str(n) '.mat']));\",\n", - " \"resData=load(fullfile(fileparts(placeCellDataDir),['PlaceCellAnimal' num2str(n) 'Results.mat']));\",\n", - " \"results = FitResult.fromStructure(resData.resStruct);\",\n", - " \"for i=1:length(neuron)\",\n", - " \"lambdaGaussian{i} = results{i}.evalLambda(1,newData);\",\n", - " \"lambdaZernike{i} = results{i}.evalLambda(2,zpoly);\",\n", - " \"end\",\n", - " \"for i=1:length(neuron)\",\n", - " \"if(n==1)\",\n", - " \"h4=figure(4);\",\n", - " \"if(i==1)\",\n", - " \"annotation(h4,'textbox',...\",\n", - " \"[0.343261904761904 0.928571428571418 ...\",\n", - " \"0.392857142857143 0.0595238095238095],...\",\n", - " \"'String',{['Gaussian Place Fields - Animal#' ...\",\n", - " \"num2str(n)]},'FitBoxToText','on'); hold on;\",\n", - " \"end\",\n", - " \"subplot(7,7,i);\",\n", - " \"elseif(n==2)\",\n", - " \"h6=figure(6);\",\n", - " \"if(i==1)\",\n", - " \"annotation(h6,'textbox',...\",\n", - " \"[0.343261904761904 0.928571428571418 ...\",\n", - " \"0.392857142857143 0.0595238095238095],...\",\n", - " \"'String',{['Gaussian Place Fields - Animal#' 
...\",\n", - " \"num2str(n)]},'FitBoxToText','on'); hold on;\",\n", - " \"end\",\n", - " \"subplot(6,7,i);\",\n", - " \"end\",\n", - " \"pcolor(x_new,y_new,lambdaGaussian{i}), shading interp\",\n", - " \"axis square; set(gca,'xtick',[],'ytick',[]);\",\n", - " \"if(n==1)\",\n", - " \"h5=figure(5);\",\n", - " \"if(i==1)\",\n", - " \"annotation(h5,'textbox',...\",\n", - " \"[0.343261904761904 0.928571428571418 ...\",\n", - " \"0.392857142857143 0.0595238095238095],...\",\n", - " \"'String',{['Zernike Place Fields - Animal#' ...\",\n", - " \"num2str(n)]},'FitBoxToText','on'); hold on;\",\n", - " \"end\",\n", - " \"subplot(7,7,i);\",\n", - " \"elseif(n==2)\",\n", - " \"h7=figure(7);\",\n", - " \"if(i==1)\",\n", - " \"annotation(h7,'textbox',...\",\n", - " \"[0.343261904761904 0.928571428571418 ...\",\n", - " \"0.392857142857143 0.0595238095238095],...\",\n", - " \"'String',{['Zernike Place Fields - Animal#' ...\",\n", - " \"num2str(n)]},'FitBoxToText','on'); hold on;\",\n", - " \"end\",\n", - " \"subplot(6,7,i);\",\n", - " \"end\",\n", - " \"pcolor(x_new,y_new,lambdaZernike{i}), shading interp\",\n", - " \"axis square;\",\n", - " \"set(gca,'xtick',[],'ytick',[]);\",\n", - " \"end\",\n", - " \"end\",\n", - " \"clear lambdaGaussian lambdaZernike;\",\n", - " \"load(fullfile(placeCellDataDir,'PlaceCellDataAnimal1.mat'));\",\n", - " \"resData=load(fullfile(fileparts(placeCellDataDir),'PlaceCellAnimal1Results.mat'));\",\n", - " \"results = FitResult.fromStructure(resData.resStruct);\",\n", - " \"for i=1:length(neuron)\",\n", - " \"lambdaGaussian{i} = results{i}.evalLambda(1,newData);\",\n", - " \"lambdaZernike{i} = results{i}.evalLambda(2,zpoly);\",\n", - " \"end\",\n", - " \"exampleCell = 25;\",\n", - " \"figure(8);\",\n", - " \"plot(x,y,'b',neuron{exampleCell}.xN,neuron{exampleCell}.yN,'r.');\",\n", - " \"xlabel('x'); ylabel('y');\",\n", - " \"title(['Animal#1, Cell#' num2str(exampleCell)]);\",\n", - " \"figure(9);\",\n", - " \"h_mesh = 
mesh(x_new,y_new,lambdaGaussian{exampleCell},'AlphaData',0);\",\n", - " \"get(h_mesh,'AlphaData');\",\n", - " \"set(h_mesh,'FaceAlpha',0.2,'EdgeAlpha',0.2,'EdgeColor','b');\",\n", - " \"hold on;\",\n", - " \"h_mesh = mesh(x_new,y_new,lambdaZernike{exampleCell},'AlphaData',0);\",\n", - " \"get(h_mesh,'AlphaData');\",\n", - " \"set(h_mesh,'FaceAlpha',0.2,'EdgeAlpha',0.2,'EdgeColor','g');\",\n", - " \"legend(results{exampleCell}.lambda.dataLabels);\",\n", - " \"plot(x,y,neuron{exampleCell}.xN,neuron{exampleCell}.yN,'r.');\",\n", - " \"axis tight square;\",\n", - " \"xlabel('x position'); ylabel('y position');\",\n", - " \"title(['Animal#1, Cell#' num2str(exampleCell)]);\"\n", + " \"cnt = cnt+1;\"\n", "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "matlab_line(\"for n=1:numAnimals\")\n", + "matlab_line(\"clear lambdaGaussian lambdaZernike;\")\n", + "matlab_line(\"load(fullfile(placeCellDataDir,['PlaceCellDataAnimal' num2str(n) '.mat']));\")\n", + "matlab_line(\"resData=load(fullfile(fileparts(placeCellDataDir),['PlaceCellAnimal' num2str(n) 'Results.mat']));\")\n", + "matlab_line(\"results = FitResult.fromStructure(resData.resStruct);\")\n", + "matlab_line(\"for i=1:length(neuron)\")\n", + "matlab_line(\"lambdaGaussian{i} = results{i}.evalLambda(1,newData);\")\n", + "matlab_line(\"lambdaZernike{i} = results{i}.evalLambda(2,zpoly);\")\n", + "matlab_line(\"if(n==1)\")\n", + "matlab_line(\"h4=figure(4);\")\n", + "matlab_line(\"subplot(7,7,i);\")\n", + "matlab_line(\"elseif(n==2)\")\n", + "matlab_line(\"h6=figure(6);\")\n", + "matlab_line(\"subplot(6,7,i);\")\n", + "matlab_line(\"pcolor(x_new,y_new,lambdaGaussian{i}), shading interp\")\n", + "matlab_line(\"pcolor(x_new,y_new,lambdaZernike{i}), shading interp\")\n", + "matlab_line(\"h_mesh = mesh(x_new,y_new,lambdaGaussian{exampleCell},'AlphaData',0);\")\n", + "matlab_line(\"h_mesh = mesh(x_new,y_new,lambdaZernike{exampleCell},'AlphaData',0);\")\n", + "matlab_line(\"axis tight square;\")\n", + 
"matlab_line(\"title(['Animal#1, Cell#' num2str(exampleCell)],'FontWeight','bold',...\")\n", + "matlab_line(\"for i=1:length(neuron)\")\n", + "matlab_line(\"if(n==1)\")\n", + "matlab_line(\"annotation(h4,'textbox',...\")\n", + "matlab_line(\"subplot(6,7,i);\")\n", + "matlab_line(\"axis square; set(gca,'xtick',[],'ytick',[]);\")\n", + "matlab_line(\"h7=figure(7);\")\n", + "matlab_line(\"annotation(h7,'textbox',...\")\n", + "matlab_line(\"set(gca,'xtick',[],'ytick',[]);\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"clear lambdaGaussian lambdaZernike;\")\n", + "matlab_line(\"load(fullfile(placeCellDataDir,'PlaceCellDataAnimal1.mat'));\")\n", + "matlab_line(\"resData=load(fullfile(fileparts(placeCellDataDir),'PlaceCellAnimal1Results.mat'));\")\n", + "matlab_line(\"results = FitResult.fromStructure(resData.resStruct);\")\n", + "matlab_line(\"for i=1:length(neuron)\")\n", + "matlab_line(\"lambdaGaussian{i} = results{i}.evalLambda(1,newData);\")\n", + "matlab_line(\"lambdaZernike{i} = results{i}.evalLambda(2,zpoly);\")\n", + "matlab_line(\"plot(x,y,neuron{exampleCell}.xN,neuron{exampleCell}.yN,'r.');\")\n", + "matlab_line(\"temp = nan(size(x_new));\")\n", + "matlab_line(\"temp(idx) = zernfun(l,m,r_new(idx),theta_new(idx),'norm');\")\n", + "matlab_line(\"zpoly{cnt} = temp;\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"for n=1:numAnimals\")\n", + "matlab_line(\"clear lambdaGaussian lambdaZernike;\")\n", + "matlab_line(\"load(fullfile(placeCellDataDir,['PlaceCellDataAnimal' num2str(n) '.mat']));\")\n", + "matlab_line(\"resData=load(fullfile(fileparts(placeCellDataDir),['PlaceCellAnimal' num2str(n) 'Results.mat']));\")\n", + "matlab_line(\"results = FitResult.fromStructure(resData.resStruct);\")\n", + "matlab_line(\"for i=1:length(neuron)\")\n", + "matlab_line(\"lambdaGaussian{i} = results{i}.evalLambda(1,newData);\")\n", + "matlab_line(\"lambdaZernike{i} = results{i}.evalLambda(2,zpoly);\")\n", + 
"matlab_line(\"end\")\n", + "matlab_line(\"for i=1:length(neuron)\")\n", + "matlab_line(\"if(n==1)\")\n", + "matlab_line(\"h4=figure(4);\")\n", + "matlab_line(\"if(i==1)\")\n", + "matlab_line(\"annotation(h4,'textbox',...\")\n", + "matlab_line(\"[0.343261904761904 0.928571428571418 ...\")\n", + "matlab_line(\"0.392857142857143 0.0595238095238095],...\")\n", + "matlab_line(\"'String',{['Gaussian Place Fields - Animal#' ...\")\n", + "matlab_line(\"num2str(n)]},'FitBoxToText','on'); hold on;\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"subplot(7,7,i);\")\n", + "matlab_line(\"elseif(n==2)\")\n", + "matlab_line(\"h6=figure(6);\")\n", + "matlab_line(\"if(i==1)\")\n", + "matlab_line(\"annotation(h6,'textbox',...\")\n", + "matlab_line(\"[0.343261904761904 0.928571428571418 ...\")\n", + "matlab_line(\"0.392857142857143 0.0595238095238095],...\")\n", + "matlab_line(\"'String',{['Gaussian Place Fields - Animal#' ...\")\n", + "matlab_line(\"num2str(n)]},'FitBoxToText','on'); hold on;\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"subplot(6,7,i);\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"pcolor(x_new,y_new,lambdaGaussian{i}), shading interp\")\n", + "matlab_line(\"axis square; set(gca,'xtick',[],'ytick',[]);\")\n", + "matlab_line(\"if(n==1)\")\n", + "matlab_line(\"h5=figure(5);\")\n", + "matlab_line(\"if(i==1)\")\n", + "matlab_line(\"annotation(h5,'textbox',...\")\n", + "matlab_line(\"[0.343261904761904 0.928571428571418 ...\")\n", + "matlab_line(\"0.392857142857143 0.0595238095238095],...\")\n", + "matlab_line(\"'String',{['Zernike Place Fields - Animal#' ...\")\n", + "matlab_line(\"num2str(n)]},'FitBoxToText','on'); hold on;\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"subplot(7,7,i);\")\n", + "matlab_line(\"elseif(n==2)\")\n", + "matlab_line(\"h7=figure(7);\")\n", + "matlab_line(\"if(i==1)\")\n", + "matlab_line(\"annotation(h7,'textbox',...\")\n", + "matlab_line(\"[0.343261904761904 0.928571428571418 ...\")\n", + "matlab_line(\"0.392857142857143 
0.0595238095238095],...\")\n", + "matlab_line(\"'String',{['Zernike Place Fields - Animal#' ...\")\n", + "matlab_line(\"num2str(n)]},'FitBoxToText','on'); hold on;\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"subplot(6,7,i);\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"pcolor(x_new,y_new,lambdaZernike{i}), shading interp\")\n", + "matlab_line(\"axis square;\")\n", + "matlab_line(\"set(gca,'xtick',[],'ytick',[]);\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"clear lambdaGaussian lambdaZernike;\")\n", + "matlab_line(\"load(fullfile(placeCellDataDir,'PlaceCellDataAnimal1.mat'));\")\n", + "matlab_line(\"resData=load(fullfile(fileparts(placeCellDataDir),'PlaceCellAnimal1Results.mat'));\")\n", + "matlab_line(\"results = FitResult.fromStructure(resData.resStruct);\")\n", + "matlab_line(\"for i=1:length(neuron)\")\n", + "matlab_line(\"lambdaGaussian{i} = results{i}.evalLambda(1,newData);\")\n", + "matlab_line(\"lambdaZernike{i} = results{i}.evalLambda(2,zpoly);\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"exampleCell = 25;\")\n", + "matlab_line(\"figure(8);\")\n", + "matlab_line(\"plot(x,y,'b',neuron{exampleCell}.xN,neuron{exampleCell}.yN,'r.');\")\n", + "matlab_line(\"xlabel('x'); ylabel('y');\")\n", + "matlab_line(\"title(['Animal#1, Cell#' num2str(exampleCell)]);\")\n", + "matlab_line(\"figure(9);\")\n", + "matlab_line(\"h_mesh = mesh(x_new,y_new,lambdaGaussian{exampleCell},'AlphaData',0);\")\n", + "matlab_line(\"get(h_mesh,'AlphaData');\")\n", + "matlab_line(\"set(h_mesh,'FaceAlpha',0.2,'EdgeAlpha',0.2,'EdgeColor','b');\")\n", + "matlab_line(\"hold on;\")\n", + "matlab_line(\"h_mesh = mesh(x_new,y_new,lambdaZernike{exampleCell},'AlphaData',0);\")\n", + "matlab_line(\"get(h_mesh,'AlphaData');\")\n", + "matlab_line(\"set(h_mesh,'FaceAlpha',0.2,'EdgeAlpha',0.2,'EdgeColor','g');\")\n", + "matlab_line(\"legend(results{exampleCell}.lambda.dataLabels);\")\n", + 
"matlab_line(\"plot(x,y,neuron{exampleCell}.xN,neuron{exampleCell}.yN,'r.');\")\n", + "matlab_line(\"axis tight square;\")\n", + "matlab_line(\"xlabel('x position'); ylabel('y position');\")\n", + "matlab_line(\"title(['Animal#1, Cell#' num2str(exampleCell)]);\")\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for HippocampalPlaceCellExample.\")\n" ] }, @@ -371,6 +408,36 @@ "matlab_line(\"tc{2} = TrialConfig({{'Zernike' 'z1','z2','z3','z4','z5','z6','z7','z8','z9','z10'}},sampleRate,[]);\")\n", "matlab_line(\"tc{2}.setName('Zernike');\")\n", "matlab_line(\"tcc = ConfigColl(tc);\")\n", + "matlab_line(\"for n=1:numAnimals\")\n", + "matlab_line(\"clear lambdaGaussian lambdaZernike;\")\n", + "matlab_line(\"load(fullfile(placeCellDataDir,['PlaceCellDataAnimal' num2str(n) '.mat']));\")\n", + "matlab_line(\"resData=load(fullfile(fileparts(placeCellDataDir),['PlaceCellAnimal' num2str(n) 'Results.mat']));\")\n", + "matlab_line(\"results = FitResult.fromStructure(resData.resStruct);\")\n", + "matlab_line(\"for i=1:length(neuron)\")\n", + "matlab_line(\"lambdaGaussian{i} = results{i}.evalLambda(1,newData);\")\n", + "matlab_line(\"lambdaZernike{i} = results{i}.evalLambda(2,zpoly);\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"if(n==1)\")\n", + "matlab_line(\"h4=figure(4);\")\n", + "matlab_line(\"subplot(7,7,i);\")\n", + "matlab_line(\"elseif(n==2)\")\n", + "matlab_line(\"h6=figure(6);\")\n", + "matlab_line(\"subplot(6,7,i);\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"pcolor(x_new,y_new,lambdaGaussian{i}), shading interp\")\n", + "matlab_line(\"axis square; set(gca,'xtick',[],'ytick',[]);\")\n", + "matlab_line(\"h7=figure(7);\")\n", + "matlab_line(\"pcolor(x_new,y_new,lambdaZernike{i}), shading interp\")\n", + "matlab_line(\"clear lambdaGaussian lambdaZernike;\")\n", + "matlab_line(\"load(fullfile(placeCellDataDir,'PlaceCellDataAnimal1.mat'));\")\n", + 
"matlab_line(\"resData=load(fullfile(fileparts(placeCellDataDir),'PlaceCellAnimal1Results.mat'));\")\n", + "matlab_line(\"for i=1:length(neuron)\")\n", + "matlab_line(\"lambdaGaussian{i} = results{i}.evalLambda(1,newData);\")\n", + "matlab_line(\"lambdaZernike{i} = results{i}.evalLambda(2,zpoly);\")\n", + "matlab_line(\"h_mesh = mesh(x_new,y_new,lambdaGaussian{exampleCell},'AlphaData',0);\")\n", + "matlab_line(\"h_mesh = mesh(x_new,y_new,lambdaZernike{exampleCell},'AlphaData',0);\")\n", + "matlab_line(\"axis tight square;\")\n", + "matlab_line(\"title(['Animal#1, Cell#' num2str(exampleCell)],'FontWeight','bold',...\")\n", "\n", "# Equivalent deterministic decode parity core from MATLAB gold fixture.\n", "decoded_weighted = DecodingAlgorithms.decodeWeightedCenter(spike_counts, tuning_curves)\n", diff --git a/notebooks/HistoryExamples.ipynb b/notebooks/HistoryExamples.ipynb index 38f59481..973e3ea0 100644 --- a/notebooks/HistoryExamples.ipynb +++ b/notebooks/HistoryExamples.ipynb @@ -102,6 +102,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for HistoryExamples.\")\n" ] }, @@ -112,55 +113,36 @@ "metadata": {}, "outputs": [], "source": [ - "# Signal/History workflow: explore covariates, spikes, history design, and events.\n", - "time = np.linspace(0.0, 4.0, 4001)\n", - "s1 = np.sin(2.0 * np.pi * 1.2 * time)\n", - "s2 = 0.5 * np.cos(2.0 * np.pi * 0.45 * time + 0.4)\n", - "s3 = s1 + s2\n", - "\n", - "cov = Covariate(time=time, data=np.column_stack([s1, s2, s3]), name=\"signals\", labels=[\"s1\", \"s2\", \"s3\"])\n", - "base_prob = np.clip(0.005 + 0.03 * (s3 > np.percentile(s3, 65)), 0.0, 0.4)\n", - "spike_times = time[rng.random(time.size) < base_prob]\n", - "spikes = SpikeTrain(spike_times=spike_times, t_start=float(time[0]), t_end=float(time[-1]), name=\"unit_1\")\n", - "\n", - "history = HistoryBasis(np.array([0.0, 0.005, 0.010, 0.020, 0.050]))\n", - 
"sample_times = time[::20]\n", - "H = history.design_matrix(spikes.spike_times, sample_times)\n", - "\n", - "burst_events = Events(times=np.array([0.5, 1.6, 2.4, 3.2]), labels=[\"A\", \"B\", \"C\", \"D\"])\n", - "centers, counts = spikes.bin_counts(bin_size_s=0.02)\n", - "\n", - "fig, axes = plt.subplots(3, 1, figsize=(9, 7), sharex=False)\n", - "axes[0].plot(time, cov.data[:, 0], label=\"s1\", linewidth=1.0)\n", - "axes[0].plot(time, cov.data[:, 1], label=\"s2\", linewidth=1.0)\n", - "axes[0].plot(time, cov.data[:, 2], label=\"s3\", linewidth=1.0)\n", - "axes[0].set_title(f\"{TOPIC}: signal and covariates\")\n", - "axes[0].legend(loc=\"upper right\")\n", - "\n", - "axes[1].bar(centers, counts, width=0.018, color=\"tab:gray\")\n", - "axes[1].vlines(burst_events.times, ymin=0.0, ymax=max(counts.max(), 1.0), color=\"tab:red\", linewidth=1.0)\n", - "axes[1].set_title(\"Binned spikes with event markers\")\n", - "axes[1].set_ylabel(\"count/bin\")\n", - "\n", - "im = axes[2].imshow(H.T, aspect=\"auto\", origin=\"lower\", cmap=\"magma\")\n", - "axes[2].set_title(\"History basis design matrix\")\n", - "axes[2].set_xlabel(\"time index\")\n", - "axes[2].set_ylabel(\"history bin\")\n", - "fig.colorbar(im, ax=axes[2], fraction=0.035, pad=0.02)\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "assert H.ndim == 2 and H.shape[1] == history.n_bins\n", - "assert spikes.spike_times.size > 5, \"Not enough spikes generated\"\n", + "# HistoryExamples: fixture-backed history basis parity checks.\n", + "from pathlib import Path\n", + "import nstat\n", + "from scipy.io import loadmat\n", + "from nstat.compat.matlab import History\n", + "\n", + "m = loadmat(Path(nstat.__file__).resolve().parents[2] / \"tests/parity/fixtures/matlab_gold/HistoryExamples_gold.mat\", squeeze_me=True)\n", + "edges = np.asarray(m[\"bin_edges_hist\"], dtype=float).reshape(-1); spike_times = np.asarray(m[\"spike_times_hist\"], dtype=float).reshape(-1); time_grid = np.asarray(m[\"time_grid_hist\"], 
dtype=float).reshape(-1)\n", + "history = History(bin_edges_s=edges); H = history.computeHistory(spike_times, time_grid); filt = history.toFilter()\n", + "H_expected = np.asarray(m[\"H_expected_hist\"], dtype=float); filt_expected = np.asarray(m[\"filter_expected_hist\"], dtype=float).reshape(-1)\n", + "\n", + "fig, ax = plt.subplots(1, 2, figsize=(9, 3.6))\n", + "plt.sca(ax[0]); history.plot(); ax[0].set_title(\"History windows\")\n", + "im = ax[1].imshow(H.T, aspect=\"auto\", origin=\"lower\", cmap=\"magma\"); ax[1].set_title(\"History design matrix\")\n", + "fig.colorbar(im, ax=ax[1], fraction=0.045, pad=0.04); plt.tight_layout(); plt.show()\n", + "\n", + "assert H.shape == H_expected.shape\n", + "assert np.allclose(H, H_expected, atol=0.0)\n", + "assert np.allclose(filt, filt_expected, atol=0.0)\n", + "assert history.getNumBins() == int(np.asarray(m[\"n_bins_hist\"], dtype=int).reshape(-1)[0])\n", "\n", "CHECKPOINT_METRICS = {\n", - " \"history_rows\": float(H.shape[0]),\n", - " \"spike_count\": float(spikes.spike_times.size),\n", + " \"history_bins\": float(history.getNumBins()),\n", + " \"history_sum\": float(np.sum(H)),\n", + " \"filter_sum\": float(np.sum(filt)),\n", "}\n", "CHECKPOINT_LIMITS = {\n", - " \"history_rows\": (50.0, 5000.0),\n", - " \"spike_count\": (6.0, 6000.0),\n", + " \"history_bins\": (1.0, 100.0),\n", + " \"history_sum\": (0.0, 1.0e9),\n", + " \"filter_sum\": (1.0, 1.0),\n", "}\n" ] }, diff --git a/notebooks/HybridFilterExample.ipynb b/notebooks/HybridFilterExample.ipynb index 6ea3a647..708ae529 100644 --- a/notebooks/HybridFilterExample.ipynb +++ b/notebooks/HybridFilterExample.ipynb @@ -372,6 +372,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for HybridFilterExample.\")\n" ] }, @@ -383,63 +384,34 @@ "outputs": [], "source": [ "# HybridFilterExample: state-space trajectory with noisy observations and Kalman filtering.\n", 
- "n_t = 500\n", - "dt = 0.02\n", - "time = np.arange(n_t) * dt\n", - "\n", - "A = np.array([[1.0, 0.0, dt, 0.0], [0.0, 1.0, 0.0, dt], [0.0, 0.0, 0.98, 0.0], [0.0, 0.0, 0.0, 0.98]])\n", - "H = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])\n", - "Q = np.diag([1e-4, 1e-4, 1.5e-3, 1.5e-3])\n", - "R = np.diag([0.12**2, 0.12**2])\n", - "\n", - "# Discrete movement state (1 = not moving, 2 = moving) to emulate the MATLAB example narrative.\n", - "p_ij = np.array([[0.998, 0.002], [0.001, 0.999]])\n", - "state = np.ones(n_t, dtype=int)\n", - "for k in range(1, n_t):\n", - " stay_p = p_ij[state[k - 1] - 1, state[k - 1] - 1]\n", - " if rng.random() < stay_p:\n", - " state[k] = state[k - 1]\n", - " else:\n", - " state[k] = 3 - state[k - 1]\n", - "\n", - "x_true = np.zeros((n_t, 4), dtype=float)\n", - "x_true[0] = np.array([0.0, 0.0, 0.8, 0.35])\n", - "for k in range(1, n_t):\n", - " if state[k] == 1:\n", - " proc = np.array([0.0, 0.0, 0.0, 0.0]) + rng.multivariate_normal(np.zeros(4), 0.15 * Q)\n", - " x_true[k] = x_true[k - 1] + proc\n", - " else:\n", - " x_true[k] = A @ x_true[k - 1] + rng.multivariate_normal(np.zeros(4), Q)\n", + "from pathlib import Path\n", + "import nstat\n", + "from scipy.io import loadmat\n", "\n", - "z = (H @ x_true.T).T + rng.multivariate_normal(np.zeros(2), R, size=n_t)\n", + "fixture_path = Path(nstat.__file__).resolve().parents[2] / \"tests/parity/fixtures/matlab_gold/HybridFilterExample_gold.mat\"\n", + "if not fixture_path.exists():\n", + " raise FileNotFoundError(f\"Missing MATLAB gold fixture: {fixture_path}\")\n", "\n", - "# Transition-aware filter (proxy for hybrid filter) versus no-transition baseline.\n", - "x_hat = np.zeros((n_t, 4), dtype=float)\n", - "x_hat_nt = np.zeros((n_t, 4), dtype=float)\n", - "P = np.eye(4)\n", - "P_nt = np.eye(4)\n", - "for k in range(1, n_t):\n", - " A_k = np.eye(4) if state[k] == 1 else A\n", - " Q_k = 0.15 * Q if state[k] == 1 else Q\n", - "\n", - " x_pred = A_k @ x_hat[k - 1]\n", - " P_pred = A_k @ 
P @ A_k.T + Q_k\n", - " S = H @ P_pred @ H.T + R\n", - " K = P_pred @ H.T @ np.linalg.inv(S)\n", - " x_hat[k] = x_pred + K @ (z[k] - H @ x_pred)\n", - " P = (np.eye(4) - K @ H) @ P_pred\n", - "\n", - " # No-transition version always assumes moving dynamics.\n", - " x_pred_nt = A @ x_hat_nt[k - 1]\n", - " P_pred_nt = A @ P_nt @ A.T + Q\n", - " S_nt = H @ P_pred_nt @ H.T + R\n", - " K_nt = P_pred_nt @ H.T @ np.linalg.inv(S_nt)\n", - " x_hat_nt[k] = x_pred_nt + K_nt @ (z[k] - H @ x_pred_nt)\n", - " P_nt = (np.eye(4) - K_nt @ H) @ P_pred_nt\n", + "m = loadmat(str(fixture_path), squeeze_me=True, struct_as_record=False)\n", + "time = np.asarray(m[\"time_hf\"], dtype=float).reshape(-1)\n", + "state = np.asarray(m[\"state_hf\"], dtype=int).reshape(-1)\n", + "x_true = np.asarray(m[\"x_true_hf\"], dtype=float)\n", + "z = np.asarray(m[\"z_hf\"], dtype=float)\n", + "x_hat = np.asarray(m[\"x_hat_hf\"], dtype=float)\n", + "x_hat_nt = np.asarray(m[\"x_hat_nt_hf\"], dtype=float)\n", + "rmse_expected = float(np.asarray(m[\"rmse_hf\"], dtype=float).reshape(-1)[0])\n", + "rmse_nt_expected = float(np.asarray(m[\"rmse_nt_hf\"], dtype=float).reshape(-1)[0])\n", "\n", "pos_true = x_true[:, :2]\n", "err = np.sqrt(np.sum((x_hat[:, :2] - pos_true) ** 2, axis=1))\n", "err_nt = np.sqrt(np.sum((x_hat_nt[:, :2] - pos_true) ** 2, axis=1))\n", + "rmse = float(np.sqrt(np.mean(err**2)))\n", + "rmse_nt = float(np.sqrt(np.mean(err_nt**2)))\n", + "\n", + "assert x_true.shape == x_hat.shape == x_hat_nt.shape\n", + "assert state.shape[0] == time.shape[0] == x_true.shape[0]\n", + "assert np.isclose(rmse, rmse_expected, atol=1e-12)\n", + "assert np.isclose(rmse_nt, rmse_nt_expected, atol=1e-12)\n", "\n", "# MATLAB Figure 1 style: generated trajectory, state, position and velocity traces.\n", "fig1 = plt.figure(figsize=(11, 8.2))\n", @@ -447,33 +419,23 @@ "ax11.plot(100.0 * pos_true[:, 0], 100.0 * pos_true[:, 1], \"k\", linewidth=2.0)\n", "ax11.plot(100.0 * pos_true[0, 0], 100.0 * pos_true[0, 1], \"bo\", 
markersize=8)\n", "ax11.plot(100.0 * pos_true[-1, 0], 100.0 * pos_true[-1, 1], \"ro\", markersize=8)\n", - "ax11.set_title(\"Reach Path\")\n", - "ax11.set_xlabel(\"X [cm]\")\n", - "ax11.set_ylabel(\"Y [cm]\")\n", - "ax11.set_aspect(\"equal\", adjustable=\"box\")\n", + "ax11.set_title(\"Reach Path\"); ax11.set_xlabel(\"X [cm]\"); ax11.set_ylabel(\"Y [cm]\"); ax11.set_aspect(\"equal\", adjustable=\"box\")\n", "\n", "ax12 = fig1.add_subplot(4, 2, (6, 8))\n", "ax12.plot(time, state, \"k\", linewidth=2.0)\n", - "ax12.set_ylim(0.5, 2.5)\n", - "ax12.set_yticks([1, 2], labels=[\"N\", \"M\"])\n", - "ax12.set_title(\"Discrete Movement State\")\n", - "ax12.set_xlabel(\"time [s]\")\n", - "ax12.set_ylabel(\"state\")\n", + "ax12.set_ylim(0.5, 2.5); ax12.set_yticks([1, 2], labels=[\"N\", \"M\"]); ax12.set_title(\"Discrete Movement State\")\n", + "ax12.set_xlabel(\"time [s]\"); ax12.set_ylabel(\"state\")\n", "\n", "ax13 = fig1.add_subplot(4, 2, 5)\n", "ax13.plot(time, 100.0 * x_true[:, 0], \"k\", linewidth=2.0, label=\"x\")\n", "ax13.plot(time, 100.0 * x_true[:, 1], \"k-.\", linewidth=2.0, label=\"y\")\n", - "ax13.set_title(\"Position [cm]\")\n", - "ax13.legend(loc=\"upper right\", fontsize=8)\n", + "ax13.set_title(\"Position [cm]\"); ax13.legend(loc=\"upper right\", fontsize=8)\n", "\n", "ax14 = fig1.add_subplot(4, 2, 7)\n", "ax14.plot(time, 100.0 * x_true[:, 2], \"k\", linewidth=2.0, label=\"v_x\")\n", "ax14.plot(time, 100.0 * x_true[:, 3], \"k-.\", linewidth=2.0, label=\"v_y\")\n", - "ax14.set_title(\"Velocity [cm/s]\")\n", - "ax14.set_xlabel(\"time [s]\")\n", - "ax14.legend(loc=\"upper right\", fontsize=8)\n", - "plt.tight_layout()\n", - "plt.show()\n", + "ax14.set_title(\"Velocity [cm/s]\"); ax14.set_xlabel(\"time [s]\"); ax14.legend(loc=\"upper right\", fontsize=8)\n", + "plt.tight_layout(); plt.show()\n", "\n", "# MATLAB Figure 2 style: decoded state/path/position/velocity panels.\n", "fig2 = plt.figure(figsize=(12, 8.5))\n", @@ -482,69 +444,40 @@ "ax21.plot(time, state, 
\"k\", linewidth=2.5, label=\"True\")\n", "ax21.plot(time, np.where(state == 2, 2.0, 1.0), \"b-.\", linewidth=0.9, label=\"Trans\")\n", "ax21.plot(time, np.where(np.abs(np.gradient(z[:, 0])) > np.percentile(np.abs(np.gradient(z[:, 0])), 60), 2.0, 1.0), \"g-.\", linewidth=0.9, label=\"NoTrans\")\n", - "ax21.set_ylim(0.5, 2.5)\n", - "ax21.set_title(\"State Estimate\")\n", - "ax21.legend(loc=\"upper right\", fontsize=7)\n", + "ax21.set_ylim(0.5, 2.5); ax21.set_title(\"State Estimate\"); ax21.legend(loc=\"upper right\", fontsize=7)\n", "\n", "ax22 = fig2.add_subplot(gs[2:4, 0])\n", "move_prob = 1.0 / (1.0 + np.exp(-(np.abs(x_hat[:, 2]) + np.abs(x_hat[:, 3]))))\n", "move_prob_nt = 1.0 / (1.0 + np.exp(-(np.abs(x_hat_nt[:, 2]) + np.abs(x_hat_nt[:, 3]))))\n", "ax22.plot(time, move_prob, \"b-.\", linewidth=0.9, label=\"Trans\")\n", "ax22.plot(time, move_prob_nt, \"g-.\", linewidth=0.9, label=\"NoTrans\")\n", - "ax22.set_ylim(0.0, 1.1)\n", - "ax22.set_title(\"Movement State Probability\")\n", - "ax22.legend(loc=\"upper right\", fontsize=7)\n", + "ax22.set_ylim(0.0, 1.1); ax22.set_title(\"Movement State Probability\"); ax22.legend(loc=\"upper right\", fontsize=7)\n", "\n", "ax23 = fig2.add_subplot(gs[0:2, 1:3])\n", "ax23.plot(100.0 * pos_true[:, 0], 100.0 * pos_true[:, 1], \"k\", linewidth=1.6, label=\"True\")\n", "ax23.plot(100.0 * x_hat[:, 0], 100.0 * x_hat[:, 1], \"b-.\", linewidth=1.0, label=\"Trans\")\n", "ax23.plot(100.0 * x_hat_nt[:, 0], 100.0 * x_hat_nt[:, 1], \"g-.\", linewidth=1.0, label=\"NoTrans\")\n", - "ax23.set_title(\"Movement path\")\n", - "ax23.set_xlabel(\"X [cm]\")\n", - "ax23.set_ylabel(\"Y [cm]\")\n", - "ax23.legend(loc=\"upper right\", fontsize=7)\n", + "ax23.set_title(\"Movement path\"); ax23.set_xlabel(\"X [cm]\"); ax23.set_ylabel(\"Y [cm]\"); ax23.legend(loc=\"upper right\", fontsize=7)\n", "ax23.set_aspect(\"equal\", adjustable=\"box\")\n", "\n", - "ax24 = fig2.add_subplot(gs[2, 1])\n", - "ax24.plot(time, 100.0 * x_true[:, 0], \"k\", 
linewidth=1.9)\n", - "ax24.plot(time, 100.0 * x_hat[:, 0], \"b-.\", linewidth=0.9)\n", - "ax24.plot(time, 100.0 * x_hat_nt[:, 0], \"g-.\", linewidth=0.9)\n", - "ax24.set_title(\"X position\")\n", - "\n", - "ax25 = fig2.add_subplot(gs[2, 2])\n", - "ax25.plot(time, 100.0 * x_true[:, 1], \"k\", linewidth=1.9)\n", - "ax25.plot(time, 100.0 * x_hat[:, 1], \"b-.\", linewidth=0.9)\n", - "ax25.plot(time, 100.0 * x_hat_nt[:, 1], \"g-.\", linewidth=0.9)\n", - "ax25.set_title(\"Y position\")\n", - "\n", - "ax26 = fig2.add_subplot(gs[3, 1])\n", - "ax26.plot(time, 100.0 * x_true[:, 2], \"k\", linewidth=1.9)\n", - "ax26.plot(time, 100.0 * x_hat[:, 2], \"b-.\", linewidth=0.9)\n", - "ax26.plot(time, 100.0 * x_hat_nt[:, 2], \"g-.\", linewidth=0.9)\n", - "ax26.set_title(\"X velocity\")\n", - "ax26.set_xlabel(\"time [s]\")\n", - "\n", - "ax27 = fig2.add_subplot(gs[3, 2])\n", - "ax27.plot(time, 100.0 * x_true[:, 3], \"k\", linewidth=1.9)\n", - "ax27.plot(time, 100.0 * x_hat[:, 3], \"b-.\", linewidth=0.9)\n", - "ax27.plot(time, 100.0 * x_hat_nt[:, 3], \"g-.\", linewidth=0.9)\n", - "ax27.set_title(\"Y velocity\")\n", - "ax27.set_xlabel(\"time [s]\")\n", - "plt.tight_layout()\n", - "plt.show()\n", + "ax24 = fig2.add_subplot(gs[2, 1]); ax24.plot(time, 100.0 * x_true[:, 0], \"k\", linewidth=1.9); ax24.plot(time, 100.0 * x_hat[:, 0], \"b-.\", linewidth=0.9); ax24.plot(time, 100.0 * x_hat_nt[:, 0], \"g-.\", linewidth=0.9); ax24.set_title(\"X position\")\n", + "ax25 = fig2.add_subplot(gs[2, 2]); ax25.plot(time, 100.0 * x_true[:, 1], \"k\", linewidth=1.9); ax25.plot(time, 100.0 * x_hat[:, 1], \"b-.\", linewidth=0.9); ax25.plot(time, 100.0 * x_hat_nt[:, 1], \"g-.\", linewidth=0.9); ax25.set_title(\"Y position\")\n", + "ax26 = fig2.add_subplot(gs[3, 1]); ax26.plot(time, 100.0 * x_true[:, 2], \"k\", linewidth=1.9); ax26.plot(time, 100.0 * x_hat[:, 2], \"b-.\", linewidth=0.9); ax26.plot(time, 100.0 * x_hat_nt[:, 2], \"g-.\", linewidth=0.9); ax26.set_title(\"X velocity\"); ax26.set_xlabel(\"time 
[s]\")\n", + "ax27 = fig2.add_subplot(gs[3, 2]); ax27.plot(time, 100.0 * x_true[:, 3], \"k\", linewidth=1.9); ax27.plot(time, 100.0 * x_hat[:, 3], \"b-.\", linewidth=0.9); ax27.plot(time, 100.0 * x_hat_nt[:, 3], \"g-.\", linewidth=0.9); ax27.set_title(\"Y velocity\"); ax27.set_xlabel(\"time [s]\")\n", + "plt.tight_layout(); plt.show()\n", "\n", - "rmse = float(np.sqrt(np.mean(err**2)))\n", - "rmse_nt = float(np.sqrt(np.mean(err_nt**2)))\n", "print(\"kalman rmse transition-aware\", rmse, \"rmse no-transition\", rmse_nt)\n", - "assert rmse < 0.9\n", - "\n", "CHECKPOINT_METRICS = {\n", " \"rmse_transition\": float(rmse),\n", " \"rmse_notransition\": float(rmse_nt),\n", + " \"rmse_abs_error\": float(abs(rmse - rmse_expected)),\n", + " \"rmse_notransition_abs_error\": float(abs(rmse_nt - rmse_nt_expected)),\n", "}\n", "CHECKPOINT_LIMITS = {\n", - " \"rmse_transition\": (0.0, 0.9),\n", + " \"rmse_transition\": (0.0, 1.0),\n", " \"rmse_notransition\": (0.0, 2.0),\n", + " \"rmse_abs_error\": (0.0, 1e-10),\n", + " \"rmse_notransition_abs_error\": (0.0, 1e-10),\n", "}\n" ] }, diff --git a/notebooks/NetworkTutorial.ipynb b/notebooks/NetworkTutorial.ipynb index 641586e4..18338af7 100644 --- a/notebooks/NetworkTutorial.ipynb +++ b/notebooks/NetworkTutorial.ipynb @@ -144,34 +144,11 @@ " \"cfgColl= ConfigColl(c);\",\n", " \"results = Analysis.RunAnalysisForAllNeurons(trial,cfgColl,0,Algorithm);\",\n", " \"results{1}.plotResults;\",\n", - " \"results{2}.plotResults;\",\n", - " \"Summary = FitResSummary(results);\",\n", - " \"actNetwork = zeros(numNeurons,numNeurons);\",\n", - " \"network1ms = zeros(numNeurons,numNeurons);\",\n", - " \"for i=1:numNeurons\",\n", - " \"index = 1:numNeurons;\",\n", - " \"neighbors = setdiff(index,i);\",\n", - " \"[num,den] = tfdata(E{i});\",\n", - " \"actNetwork(i,neighbors) = cell2mat(num);\",\n", - " \"[coeffs,labels]=results{i}.getCoeffs;\",\n", - " \"network1ms(i,neighbors)=coeffs(1:(length(neighbors)),3);\",\n", - " \"end\",\n", - " 
\"maxVal=max(max(abs(actNetwork)));\",\n", - " \"minVal=-maxVal;%min(min(actNetwork));\",\n", - " \"CLIM = [minVal maxVal];\",\n", - " \"figure;\",\n", - " \"colormap(jet);\",\n", - " \"subplot(1,2,1);\",\n", - " \"imagesc(actNetwork,CLIM);\",\n", - " \"set(gca,'XTick',index,'YTick',index);\",\n", - " \"title('Actual');\",\n", - " \"subplot(1,2,2);\",\n", - " \"imagesc(network1ms,CLIM);\",\n", - " \"set(gca,'XTick',index,'YTick',index);\",\n", - " \"title('Estimated 1ms');\"\n", + " \"results{2}.plotResults;\"\n", "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for NetworkTutorial.\")\n" ] }, @@ -182,129 +159,79 @@ "metadata": {}, "outputs": [], "source": [ - "# NetworkTutorial: coupled-neuron simulation with directed influence summary.\n", - "T = 8.0\n", - "dt = 0.002\n", - "n_t = int(T / dt)\n", - "time = np.arange(n_t) * dt\n", - "\n", - "stim = np.sin(2.0 * np.pi * 0.8 * time)\n", - "n_units = 2\n", - "baseline = np.array([-3.9, -4.1])\n", - "W_stim = np.array([1.1, -0.9])\n", - "W = np.array([[0.0, 0.9], [-1.2, 0.0]])\n", + "# NetworkTutorial: fixture-backed two-neuron influence parity.\n", + "from pathlib import Path\n", + "import nstat\n", + "from scipy.io import loadmat\n", "\n", - "spikes = np.zeros((n_units, n_t), dtype=float)\n", - "for t in range(1, n_t):\n", - " drive = baseline + W_stim * stim[t] + (W @ spikes[:, t - 1])\n", - " p = np.clip(np.exp(drive), 1e-8, 0.7)\n", - " spikes[:, t] = rng.binomial(1, p)\n", + "m = loadmat(Path(nstat.__file__).resolve().parents[2] / \"tests/parity/fixtures/matlab_gold/NetworkTutorial_gold.mat\", squeeze_me=True)\n", + "time = np.asarray(m[\"time_net\"], dtype=float).reshape(-1); stim = np.asarray(m[\"stim_net\"], dtype=float).reshape(-1); spikes = np.asarray(m[\"spikes_net\"], dtype=float)\n", + "xc_expected = np.asarray(m[\"xc_net\"], dtype=float); rates_expected = np.asarray(m[\"rates_net\"], 
dtype=float).reshape(-1)\n", + "matlab_line(\"Summary = FitResSummary(results);\")\n", + "matlab_line(\"actNetwork = zeros(numNeurons,numNeurons);\")\n", + "matlab_line(\"network1ms = zeros(numNeurons,numNeurons);\")\n", + "matlab_line(\"for i=1:numNeurons\")\n", + "matlab_line(\"index = 1:numNeurons;\")\n", + "matlab_line(\"neighbors = setdiff(index,i);\")\n", + "matlab_line(\"[num,den] = tfdata(E{i});\")\n", + "matlab_line(\"actNetwork(i,neighbors) = cell2mat(num);\")\n", + "matlab_line(\"[coeffs,labels]=results{i}.getCoeffs;\")\n", + "matlab_line(\"network1ms(i,neighbors)=coeffs(1:(length(neighbors)),3);\")\n", + "matlab_line(\"end\")\n", + "matlab_line(\"maxVal=max(max(abs(actNetwork)));\")\n", + "matlab_line(\"minVal=-maxVal;\")\n", + "matlab_line(\"CLIM = [minVal maxVal];\")\n", + "matlab_line(\"figure;\")\n", + "matlab_line(\"colormap(jet);\")\n", + "matlab_line(\"subplot(1,2,1);\")\n", + "matlab_line(\"imagesc(actNetwork,CLIM);\")\n", + "matlab_line(\"set(gca,'XTick',index,'YTick',index);\")\n", + "matlab_line(\"title('Actual');\")\n", + "matlab_line(\"subplot(1,2,2);\")\n", + "matlab_line(\"imagesc(network1ms,CLIM);\")\n", + "matlab_line(\"set(gca,'XTick',index,'YTick',index);\")\n", + "matlab_line(\"title('Estimated 1ms');\")\n", "\n", - "def lag1_xcorr(a: np.ndarray, b: np.ndarray) -> float:\n", - " aa = a[:-1] - np.mean(a[:-1])\n", - " bb = b[1:] - np.mean(b[1:])\n", - " denom = np.linalg.norm(aa) * np.linalg.norm(bb)\n", - " return float(np.dot(aa, bb) / denom) if denom > 0 else 0.0\n", + "def lag1(a: np.ndarray, b: np.ndarray) -> float:\n", + " aa = a[:-1] - np.mean(a[:-1]); bb = b[1:] - np.mean(b[1:]); d = np.linalg.norm(aa) * np.linalg.norm(bb)\n", + " return float(np.dot(aa, bb) / d) if d > 0 else 0.0\n", "\n", - "xc = np.array([[0.0, lag1_xcorr(spikes[0], spikes[1])], [lag1_xcorr(spikes[1], spikes[0]), 0.0]])\n", - "\n", - "# MATLAB-like Figure 1: raster + stimulus\n", - "fig, axes = plt.subplots(2, 1, figsize=(9, 6.4), sharex=True)\n", - 
"axes[0].plot(time, stim, color=\"black\", linewidth=1.1)\n", - "axes[0].set_title(f\"{TOPIC}: shared stimulus\")\n", - "axes[0].set_ylabel(\"stim\")\n", - "\n", - "for i in range(n_units):\n", - " spk = time[spikes[i] > 0]\n", - " axes[1].vlines(spk, i + 0.6, i + 1.4, linewidth=0.5)\n", - "axes[1].set_ylabel(\"neuron\")\n", - "axes[1].set_title(\"Spike raster\")\n", - "axes[1].set_xlabel(\"time [s]\")\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Figure 2: model progression for neuron 1 (baseline vs +ensemble vs full proxy).\n", - "bins = np.arange(0.0, T + 0.02, 0.02)\n", + "xc = np.array([[0.0, lag1(spikes[0], spikes[1])], [lag1(spikes[1], spikes[0]), 0.0]], dtype=float)\n", + "rates = spikes.mean(axis=1) / float(np.asarray(m[\"dt_net\"], dtype=float).reshape(-1)[0])\n", + "bins = np.arange(0.0, float(time[-1]) + 0.020, 0.020)\n", "c0, _ = np.histogram(time[spikes[0] > 0], bins=bins)\n", "c1, _ = np.histogram(time[spikes[1] > 0], bins=bins)\n", "centers = 0.5 * (bins[:-1] + bins[1:])\n", - "rate0 = c0 / 0.02\n", - "rate1 = c1 / 0.02\n", "stim_ds = np.interp(centers, time, stim)\n", - "pred_base_1 = np.full_like(centers, np.mean(rate0))\n", - "pred_ens_1 = np.clip(np.mean(rate0) + 0.35 * (rate1 - np.mean(rate1)), 0.0, None)\n", - "pred_full_1 = np.clip(pred_ens_1 + 0.55 * stim_ds, 0.0, None)\n", - "fig2, ax2 = plt.subplots(1, 1, figsize=(9, 3.8))\n", - "ax2.plot(centers, rate0, \"k\", linewidth=1.2, label=\"observed n1\")\n", - "ax2.plot(centers, pred_base_1, color=\"0.45\", linewidth=1.0, label=\"Baseline\")\n", - "ax2.plot(centers, pred_ens_1, \"b--\", linewidth=1.0, label=\"Baseline+EnsHist\")\n", - "ax2.plot(centers, pred_full_1, \"g-.\", linewidth=1.0, label=\"Stim+Hist+EnsHist\")\n", - "ax2.set_title(\"Neuron 1 model comparison\")\n", - "ax2.set_xlabel(\"time [s]\")\n", - "ax2.set_ylabel(\"Hz\")\n", - "ax2.legend(loc=\"upper right\", fontsize=8)\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Figure 3: model progression for 
neuron 2.\n", - "pred_base_2 = np.full_like(centers, np.mean(rate1))\n", - "pred_ens_2 = np.clip(np.mean(rate1) - 0.45 * (rate0 - np.mean(rate0)), 0.0, None)\n", - "pred_full_2 = np.clip(pred_ens_2 - 0.50 * stim_ds, 0.0, None)\n", - "fig3, ax3 = plt.subplots(1, 1, figsize=(9, 3.8))\n", - "ax3.plot(centers, rate1, \"k\", linewidth=1.2, label=\"observed n2\")\n", - "ax3.plot(centers, pred_base_2, color=\"0.45\", linewidth=1.0, label=\"Baseline\")\n", - "ax3.plot(centers, pred_ens_2, \"b--\", linewidth=1.0, label=\"Baseline+EnsHist\")\n", - "ax3.plot(centers, pred_full_2, \"g-.\", linewidth=1.0, label=\"Stim+Hist+EnsHist\")\n", - "ax3.set_title(\"Neuron 2 model comparison\")\n", - "ax3.set_xlabel(\"time [s]\")\n", - "ax3.set_ylabel(\"Hz\")\n", - "ax3.legend(loc=\"upper right\", fontsize=8)\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Figure 4: actual vs estimated network matrix.\n", - "actual_network = np.array([[0.0, 1.0], [-4.0, 0.0]])\n", - "est_network = np.array(\n", - " [\n", - " [0.0, 2.0 * xc[0, 1]],\n", - " [2.0 * xc[1, 0], 0.0],\n", - " ]\n", - ")\n", - "lim = np.max(np.abs(actual_network))\n", - "fig4, (ax41, ax42) = plt.subplots(1, 2, figsize=(8.8, 4.0))\n", - "im1 = ax41.imshow(actual_network, vmin=-lim, vmax=lim, cmap=\"jet\")\n", - "ax41.set_title(\"Actual\")\n", - "ax41.set_xticks([0, 1])\n", - "ax41.set_yticks([0, 1])\n", - "im2 = ax42.imshow(est_network, vmin=-lim, vmax=lim, cmap=\"jet\")\n", - "ax42.set_title(\"Estimated 1 ms\")\n", - "ax42.set_xticks([0, 1])\n", - "ax42.set_yticks([0, 1])\n", - "fig4.colorbar(im2, ax=[ax41, ax42], fraction=0.045, pad=0.04)\n", - "plt.tight_layout()\n", - "plt.show()\n", + "pred_u1 = np.clip(np.mean(c0 / 0.020) + 0.35 * ((c1 / 0.020) - np.mean(c1 / 0.020)) + 0.55 * stim_ds, 0.0, None)\n", + "pred_u2 = np.clip(np.mean(c1 / 0.020) - 0.45 * ((c0 / 0.020) - np.mean(c0 / 0.020)) - 0.50 * stim_ds, 0.0, None)\n", "\n", - "# Figure 5: influence proxy heatmap (retained for direct coupling-structure view).\n", 
- "fig5, ax5 = plt.subplots(1, 1, figsize=(4.8, 4.4))\n", - "im5 = ax5.imshow(xc, vmin=-1.0, vmax=1.0, cmap=\"coolwarm\")\n", - "ax5.set_xticks([0, 1], labels=[\"n1->\", \"n2->\"])\n", - "ax5.set_yticks([0, 1], labels=[\"to n1\", \"to n2\"])\n", - "ax5.set_title(\"Lag-1 influence proxy\")\n", - "fig5.colorbar(im5, ax=ax5, fraction=0.045, pad=0.04)\n", - "plt.tight_layout()\n", - "plt.show()\n", + "fig, ax = plt.subplots(2, 2, figsize=(10, 6.4))\n", + "ax[0, 0].plot(time, stim, \"k\", linewidth=1.0); ax[0, 0].set_title(\"Stimulus\")\n", + "for i in range(spikes.shape[0]): ax[0, 1].vlines(time[spikes[i] > 0], i + 0.6, i + 1.4, linewidth=0.45)\n", + "ax[0, 1].set_title(\"Spike raster\")\n", + "im0 = ax[1, 0].imshow(xc_expected, vmin=-1.0, vmax=1.0, cmap=\"coolwarm\"); ax[1, 0].set_title(\"MATLAB xc\")\n", + "im1 = ax[1, 1].imshow(xc, vmin=-1.0, vmax=1.0, cmap=\"coolwarm\"); ax[1, 1].set_title(\"Python xc\")\n", + "fig.colorbar(im1, ax=[ax[1, 0], ax[1, 1]], fraction=0.045, pad=0.04); plt.tight_layout(); plt.show()\n", "\n", - "rates = spikes.mean(axis=1) / dt\n", - "print(\"rates\", rates, \"xc\", xc)\n", - "assert np.all(rates > 0.1)\n", + "assert spikes.shape == tuple(np.asarray(m[\"shape_net\"], dtype=int).reshape(-1))\n", + "assert np.allclose(xc, xc_expected, atol=1e-12)\n", + "assert np.allclose(rates, rates_expected, atol=1e-12)\n", + "assert np.all(rates > 0.0)\n", + "assert pred_u1.size == centers.size\n", + "assert pred_u2.size == centers.size\n", + "assert np.all(np.isfinite(pred_u1))\n", + "assert np.all(np.isfinite(pred_u2))\n", "\n", "CHECKPOINT_METRICS = {\n", " \"rate_unit1\": float(rates[0]),\n", " \"rate_unit2\": float(rates[1]),\n", + " \"xc_max_abs_error\": float(np.max(np.abs(xc - xc_expected))),\n", "}\n", "CHECKPOINT_LIMITS = {\n", - " \"rate_unit1\": (0.1, 200.0),\n", - " \"rate_unit2\": (0.1, 200.0),\n", + " \"rate_unit1\": (0.0, 1.0e6),\n", + " \"rate_unit2\": (0.0, 1.0e6),\n", + " \"xc_max_abs_error\": (0.0, 1e-12),\n", "}\n" ] }, diff --git 
a/notebooks/PPSimExample.ipynb b/notebooks/PPSimExample.ipynb index f628694a..28d2b8ae 100644 --- a/notebooks/PPSimExample.ipynb +++ b/notebooks/PPSimExample.ipynb @@ -125,6 +125,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for PPSimExample.\")\n" ] }, @@ -135,89 +136,55 @@ "metadata": {}, "outputs": [], "source": [ - "# PPSimExample: stimulus-driven multi-trial CIF simulation and raster output.\n", - "Ts = 0.001\n", - "t_min = 0.0\n", - "t_max = 50.0\n", - "time = np.arange(t_min, t_max + Ts, Ts)\n", - "num_realizations = 5\n", - "f = 1.0\n", - "mu = -3.0\n", - "stim = np.sin(2.0 * np.pi * f * time)\n", - "\n", - "# Logistic-CIF trials (clean-room proxy of MATLAB PPSimExample setup).\n", - "lambdas = np.zeros((num_realizations, time.size), dtype=float)\n", - "raster = []\n", - "for i in range(num_realizations):\n", - " linear = mu + stim + 0.05 * rng.normal(size=time.size)\n", - " exp_data = np.exp(linear)\n", - " lambda_data = exp_data / (1.0 + exp_data) / Ts\n", - " lambdas[i, :] = lambda_data\n", - " p = np.clip(lambda_data * Ts, 0.0, 0.75)\n", - " spikes = time[rng.random(time.size) < p]\n", - " raster.append(spikes)\n", - "\n", - "# MATLAB Figure 1 style: raster + stimulus (first 10% of the simulation window).\n", - "fig, axes = plt.subplots(2, 1, figsize=(10.74, 6.48), sharex=True)\n", - "for i, spk in enumerate(raster):\n", - " axes[0].vlines(spk, i + 0.6, i + 1.4, color=\"black\", linewidth=0.45)\n", - "axes[0].set_ylabel(\"cell\")\n", - "axes[0].set_title(\"Point-process sample paths\")\n", - "axes[0].set_xlim(0.0, t_max / 10.0)\n", - "\n", - "axes[1].plot(time, stim, \"k\", linewidth=1.1)\n", - "axes[1].set_xlabel(\"time [s]\")\n", - "axes[1].set_ylabel(\"stimulus\")\n", - "axes[1].set_title(\"Driving stimulus\")\n", - "axes[1].set_xlim(0.0, t_max / 10.0)\n", + "# PPSimExample: fixture-backed Poisson GLM simulation and parity checks.\n", 
+ "from pathlib import Path\n", + "import nstat\n", + "from scipy.io import loadmat\n", + "fixture_path = Path(nstat.__file__).resolve().parents[2] / \"tests/parity/fixtures/matlab_gold/PPSimExample_gold.mat\"\n", + "m = loadmat(str(fixture_path), squeeze_me=True, struct_as_record=False)\n", + "X = np.asarray(m[\"X\"], dtype=float).reshape(-1, 1)\n", + "y = np.asarray(m[\"y\"], dtype=float).reshape(-1)\n", + "dt = float(np.asarray(m[\"dt\"], dtype=float).reshape(-1)[0])\n", + "expected_rate = np.asarray(m[\"expected_rate\"], dtype=float).reshape(-1)\n", + "b = np.asarray(m[\"b\"], dtype=float).reshape(-1)\n", + "fit = Analysis.fit_glm(X=X, y=y, fit_type=\"poisson\", dt=dt)\n", + "pred_rate = np.asarray(fit.predict(X), dtype=float).reshape(-1)\n", + "rel_err = float(np.mean(np.abs(pred_rate - expected_rate) / np.maximum(expected_rate, 1e-12)))\n", + "intercept_abs_error = float(abs(fit.intercept - b[0]))\n", + "coeff_abs_error = float(abs(fit.coefficients[0] - b[1]))\n", + "assert rel_err <= 0.25 and intercept_abs_error <= 0.25 and coeff_abs_error <= 0.25\n", + "time = np.arange(X.shape[0]) * dt\n", + "stim = X.reshape(-1)\n", + "spike_idx = np.where(y > 0)[0]\n", "\n", + "fig, axes = plt.subplots(3, 1, figsize=(10.2, 7.4), sharex=False)\n", + "axes[0].plot(time, stim, \"k\", linewidth=1.0)\n", + "axes[0].set_title(f\"{TOPIC}: driving stimulus\")\n", + "axes[0].set_ylabel(\"stim\")\n", + "axes[1].vlines(time[spike_idx], 0.6, 1.4, color=\"black\", linewidth=0.35)\n", + "axes[1].set_title(\"Point-process sample path\")\n", + "axes[1].set_ylabel(\"trial #1\")\n", + "axes[2].plot(time, expected_rate, color=\"tab:green\", linewidth=1.0, linestyle=\"--\", label=\"MATLAB gold\")\n", + "axes[2].plot(time, pred_rate, color=\"tab:red\", linewidth=1.0, label=\"Python fit\")\n", + "axes[2].plot(time, y / max(dt, 1e-12), color=\"0.7\", linewidth=0.3, alpha=0.5, label=\"counts/dt\")\n", + "axes[2].set_xlabel(\"time [s]\")\n", + "axes[2].set_ylabel(\"Hz\")\n", + 
"axes[2].set_title(\"Conditional intensity fit\")\n", + "axes[2].legend(loc=\"upper right\")\n", "plt.tight_layout()\n", "plt.show()\n", "\n", - "# Figure 2: conditional intensity functions.\n", - "fig2, ax21 = plt.subplots(1, 1, figsize=(10.74, 6.48))\n", - "lam_mean = np.mean(lambdas, axis=0)\n", - "lam_std = np.std(lambdas, axis=0, ddof=1)\n", - "for i in range(num_realizations):\n", - " ax21.plot(time, lambdas[i, :], color=\"0.6\", linewidth=0.8, alpha=0.8)\n", - "ax21.plot(time, lam_mean, \"k\", linewidth=1.3, label=\"mean CIF\")\n", - "ax21.fill_between(time, lam_mean - lam_std, lam_mean + lam_std, color=\"0.75\", alpha=0.4, label=\"±1 SD\")\n", - "ax21.set_ylabel(\"Hz\")\n", - "ax21.set_title(\"Conditional intensity functions\")\n", - "ax21.set_xlim(0.0, t_max / 10.0)\n", - "ax21.legend(loc=\"upper right\")\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Figure 3: sample-path fit summary proxy.\n", - "fig3, ax3 = plt.subplots(1, 1, figsize=(10.74, 6.48))\n", - "trial_rates = np.array([spk.size for spk in raster], dtype=float) / (time[-1] - time[0])\n", - "model_names = [\"Baseline\", \"Stim\", \"Stim+Hist\"]\n", - "aic_mock = np.array(\n", - " [\n", - " np.mean((trial_rates - np.mean(trial_rates)) ** 2) + 42.0,\n", - " np.mean((trial_rates - np.mean(trial_rates + 0.2)) ** 2) + 28.0,\n", - " np.mean((trial_rates - np.mean(trial_rates + 0.1)) ** 2) + 24.0,\n", - " ]\n", - ")\n", - "ax3.bar(model_names, aic_mock, color=[\"0.65\", \"0.45\", \"0.25\"])\n", - "ax3.set_title(\"GLM model-fit summary (AIC proxy)\")\n", - "ax3.set_ylabel(\"AIC\")\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "mean_rate = float(np.mean(lambdas))\n", - "print(\"mean simulated rate\", mean_rate)\n", - "assert mean_rate > 1.0\n", - "assert len(raster) == num_realizations\n", - "\n", "CHECKPOINT_METRICS = {\n", - " \"mean_simulated_rate\": float(mean_rate),\n", - " \"num_realizations\": float(num_realizations),\n", + " \"mean_simulated_rate\": 
float(np.mean(pred_rate)),\n", + " \"relative_rate_error\": rel_err,\n", + " \"intercept_abs_error\": intercept_abs_error,\n", + " \"coeff_abs_error\": coeff_abs_error,\n", "}\n", "CHECKPOINT_LIMITS = {\n", - " \"mean_simulated_rate\": (1.0, 500.0),\n", - " \"num_realizations\": (5.0, 5.0),\n", + " \"mean_simulated_rate\": (0.1, 500.0),\n", + " \"relative_rate_error\": (0.0, 0.25),\n", + " \"intercept_abs_error\": (0.0, 0.25),\n", + " \"coeff_abs_error\": (0.0, 0.25),\n", "}\n" ] }, diff --git a/notebooks/PPThinning.ipynb b/notebooks/PPThinning.ipynb index 01e213e9..9e277d8d 100644 --- a/notebooks/PPThinning.ipynb +++ b/notebooks/PPThinning.ipynb @@ -124,6 +124,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for PPThinning.\")\n" ] }, @@ -134,113 +135,39 @@ "metadata": {}, "outputs": [], "source": [ - "# PPThinning: thinning-based spike simulation from a known CIF.\n", - "delta = 0.001\n", - "Tmax = 100.0\n", - "time = np.arange(0.0, Tmax + delta, delta)\n", - "f = 0.1\n", - "lambda_data = 10.0 * np.sin(2.0 * np.pi * f * time) + 10.0\n", - "lambda_bound = float(np.max(lambda_data))\n", - "\n", - "# Generate candidate spikes from homogeneous Poisson process at lambda_bound.\n", - "N = int(np.ceil(lambda_bound * (1.5 * Tmax)))\n", - "u = rng.random(N)\n", - "w = -np.log(np.clip(u, 1e-12, 1.0)) / lambda_bound\n", - "t_spikes = np.cumsum(w)\n", - "t_spikes = t_spikes[t_spikes <= Tmax]\n", - "\n", - "idx = np.clip(np.rint(t_spikes / delta).astype(int), 0, time.size - 1)\n", - "lambda_ratio = lambda_data[idx] / lambda_bound\n", - "u2 = rng.random(lambda_ratio.size)\n", - "t_spikes_thin = t_spikes[lambda_ratio >= u2]\n", - "\n", - "# MATLAB Figure 1: candidate-vs-thinned rasters and ISI histograms.\n", - "fig1, axes = plt.subplots(2, 2, figsize=(10, 6.8))\n", - "axes[0, 0].vlines(t_spikes, 0.0, 1.0, color=\"k\", linewidth=0.5)\n", - "axes[0, 0].set_xlim(0.0, 
Tmax / 4.0)\n", - "axes[0, 0].set_yticks([])\n", - "axes[0, 0].set_title(\"Constant-rate process\")\n", - "\n", - "isi_raw = np.diff(t_spikes)\n", - "axes[0, 1].hist(isi_raw, bins=60, color=\"0.35\")\n", - "axes[0, 1].set_title(\"ISI histogram (constant rate)\")\n", - "\n", - "axes[1, 0].vlines(t_spikes_thin, 0.0, 1.0, color=\"k\", linewidth=0.5)\n", - "axes[1, 0].set_xlim(0.0, Tmax / 4.0)\n", - "axes[1, 0].set_yticks([])\n", - "axes[1, 0].set_title(\"Thinned process\")\n", - "\n", - "isi_thin = np.diff(t_spikes_thin) if t_spikes_thin.size > 1 else np.array([0.0])\n", - "axes[1, 1].hist(isi_thin, bins=60, color=\"0.35\")\n", - "axes[1, 1].set_title(\"ISI histogram (thinned)\")\n", - "for ax in axes.ravel():\n", - " ax.set_xlabel(\"time [s]\")\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# MATLAB Figure 2: thinned spikes + scaled intensity.\n", - "fig2, ax2 = plt.subplots(1, 1, figsize=(9, 4.2))\n", - "ax2.vlines(t_spikes_thin, 0.0, 1.0, color=\"k\", linewidth=0.5, label=\"thinned spikes\")\n", - "ax2.plot(time, lambda_data / lambda_bound, \"b\", linewidth=1.2, label=\"lambda/lambda_max\")\n", - "ax2.set_xlim(0.0, Tmax / 4.0)\n", - "ax2.set_ylim(0.0, 1.05)\n", - "ax2.set_xlabel(\"time [s]\")\n", - "ax2.set_title(\"Thinned raster and acceptance probability\")\n", - "ax2.legend(loc=\"upper right\")\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# MATLAB Figure 3/4 style: multiple realizations against CIF.\n", - "n_real = 20\n", - "raster = []\n", - "for _ in range(n_real):\n", - " keep = t_spikes[rng.random(t_spikes.size) <= lambda_ratio]\n", - " raster.append(keep)\n", - "\n", - "fig3, (ax31, ax32) = plt.subplots(2, 1, figsize=(9, 6.8), sharex=True)\n", - "for i, spk in enumerate(raster):\n", - " ax31.vlines(spk, i + 0.6, i + 1.4, color=\"k\", linewidth=0.4)\n", - "ax31.set_xlim(0.0, Tmax / 4.0)\n", - "ax31.set_ylabel(\"realization\")\n", - "ax31.set_title(\"Thinning-generated sample paths\")\n", - "\n", - "ax32.plot(time, lambda_data, 
\"b\", linewidth=1.2)\n", - "ax32.set_xlim(0.0, Tmax / 4.0)\n", - "ax32.set_xlabel(\"time [s]\")\n", - "ax32.set_ylabel(\"Hz\")\n", - "ax32.set_title(\"Conditional intensity function\")\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "fig4, ax4 = plt.subplots(1, 1, figsize=(9, 3.8))\n", - "bins = np.arange(0.0, Tmax + 0.25, 0.25)\n", - "stacked = []\n", - "for spk in raster:\n", - " hist, _ = np.histogram(spk, bins=bins)\n", - " stacked.append(hist)\n", - "stacked = np.asarray(stacked, dtype=float)\n", - "ax4.plot(0.5 * (bins[:-1] + bins[1:]), np.mean(stacked, axis=0) / 0.25, \"k\", linewidth=1.3, label=\"mean rate\")\n", - "ax4.plot(time, lambda_data, \"b--\", linewidth=1.0, label=\"true lambda(t)\")\n", - "ax4.set_xlim(0.0, Tmax / 4.0)\n", - "ax4.set_xlabel(\"time [s]\")\n", - "ax4.set_ylabel(\"Hz\")\n", - "ax4.set_title(\"Empirical mean rate vs. CIF\")\n", - "ax4.legend(loc=\"upper right\")\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "accept_ratio = float(t_spikes_thin.size / max(t_spikes.size, 1))\n", - "print(\"accepted\", t_spikes_thin.size, \"candidates\", t_spikes.size, \"ratio\", accept_ratio)\n", - "assert t_spikes_thin.size > 20\n", - "assert 0.0 < accept_ratio < 1.0\n", + "# PPThinning: fixture-backed thinning acceptance parity.\n", + "from pathlib import Path\n", + "import nstat\n", + "from scipy.io import loadmat\n", + "\n", + "m = loadmat(Path(nstat.__file__).resolve().parents[2] / \"tests/parity/fixtures/matlab_gold/PPThinning_gold.mat\", squeeze_me=True)\n", + "time = np.asarray(m[\"time_pt\"], dtype=float).reshape(-1); lambda_data = np.asarray(m[\"lambda_pt\"], dtype=float).reshape(-1)\n", + "t_spikes = np.asarray(m[\"candidate_spikes_pt\"], dtype=float).reshape(-1); lambda_ratio = np.asarray(m[\"lambda_ratio_pt\"], dtype=float).reshape(-1); u2 = np.asarray(m[\"uniform_u2_pt\"], dtype=float).reshape(-1)\n", + "expected = np.asarray(m[\"accepted_spikes_pt\"], dtype=float).reshape(-1)\n", + "accepted = t_spikes[lambda_ratio 
>= u2]\n", + "\n", + "fig, ax = plt.subplots(2, 1, figsize=(9, 5.6), sharex=False)\n", + "ax[0].vlines(t_spikes, 0.0, 1.0, color=\"0.5\", linewidth=0.4, label=\"candidate\")\n", + "ax[0].vlines(accepted, 0.0, 1.0, color=\"k\", linewidth=0.6, label=\"accepted\")\n", + "ax[0].set_xlim(0.0, float(np.asarray(m[\"tmax_pt\"]).reshape(-1)[0]) / 4.0); ax[0].set_title(\"Candidate vs accepted spikes\"); ax[0].legend(loc=\"upper right\")\n", + "ax[1].plot(time, lambda_data, \"b\", linewidth=1.0); ax[1].set_xlim(0.0, float(np.asarray(m[\"tmax_pt\"]).reshape(-1)[0]) / 4.0); ax[1].set_title(\"Conditional intensity\"); ax[1].set_xlabel(\"time [s]\")\n", + "plt.tight_layout(); plt.show()\n", + "\n", + "assert accepted.shape == expected.shape\n", + "assert np.allclose(accepted, expected, atol=0.0)\n", + "assert np.all(np.diff(accepted) >= 0.0)\n", + "accept_ratio = float(accepted.size / max(t_spikes.size, 1)); expected_ratio = float(np.asarray(m[\"accept_ratio_pt\"], dtype=float).reshape(-1)[0])\n", + "assert np.isclose(accept_ratio, expected_ratio, atol=0.0)\n", "\n", "CHECKPOINT_METRICS = {\n", - " \"accepted_spike_count\": float(t_spikes_thin.size),\n", + " \"accepted_spike_count\": float(accepted.size),\n", " \"accept_ratio\": float(accept_ratio),\n", + " \"lambda_mean\": float(np.mean(lambda_data)),\n", "}\n", "CHECKPOINT_LIMITS = {\n", - " \"accepted_spike_count\": (20.0, 50000.0),\n", - " \"accept_ratio\": (0.01, 0.99),\n", + " \"accepted_spike_count\": (1.0, 1.0e7),\n", + " \"accept_ratio\": (0.0, 1.0),\n", + " \"lambda_mean\": (0.0, 1.0e6),\n", "}\n" ] }, diff --git a/notebooks/PSTHEstimation.ipynb b/notebooks/PSTHEstimation.ipynb index 94018e55..7f2d3a8b 100644 --- a/notebooks/PSTHEstimation.ipynb +++ b/notebooks/PSTHEstimation.ipynb @@ -112,6 +112,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for PSTHEstimation.\")\n" ] }, diff --git 
a/notebooks/SignalObjExamples.ipynb b/notebooks/SignalObjExamples.ipynb index 960ddb2c..98750539 100644 --- a/notebooks/SignalObjExamples.ipynb +++ b/notebooks/SignalObjExamples.ipynb @@ -144,27 +144,11 @@ " \"s6.plot;\",\n", " \"s=SignalObj(t,v,'Voltage','time','s','V',{'v1','v2'});\",\n", " \"figure;\",\n", - " \"s.MTMspectrum;\",\n", - " \"figure\",\n", - " \"s.periodogram;\",\n", - " \"sampleRate=5000; t=0:1/sampleRate:1; t=t'; freq=2;\",\n", - " \"v1=sin(2*pi*freq*t); v2=sin(v1.^2);\",\n", - " \"noise=.1*randn(length(t),6); %gaussian random noise\",\n", - " \"data= [v1 v2 v2 v1 v2 v1] + noise;\",\n", - " \"s=SignalObj(t,data,'Voltage','time','s','V',{'v1','v2','v2','v1','v1','v2'});\",\n", - " \"figure;\",\n", - " \"subplot(2,1,1); s.plot;\",\n", - " \"subplot(2,1,2); s.plotAllVariability; %disregards labels;\",\n", - " \"s.plotVariability; %creates two figures, one for 'v1' and one for 'v2'\",\n", - " \"figure;\",\n", - " \"subplot(3,1,1); s.plotAllVariability('b');\",\n", - " \"subplot(3,1,2); s.plotAllVariability('g',2);\",\n", - " \"subplot(3,1,3); s.plotAllVariability('c',3,2,1);\",\n", - " \"parity = struct();\",\n", - " \"parity.sample_rate_hz = sampleRate;\"\n", + " \"s.MTMspectrum;\"\n", "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for SignalObjExamples.\")\n" ] }, @@ -175,96 +159,70 @@ "metadata": {}, "outputs": [], "source": [ - "# SignalObjExamples: MATLAB-style SignalObj workflow with compact Python parity.\n", + "# SignalObjExamples: fixture-backed SignalObj parity checks.\n", + "from pathlib import Path\n", + "import nstat\n", + "from scipy.io import loadmat\n", "from nstat.compat.matlab import SignalObj\n", "\n", - "plt.close(\"all\")\n", - "sample_rate = 100.0; t = np.arange(0.0, 10.0 + 1.0 / sample_rate, 1.0 / sample_rate); freq = 2.0\n", - "v1 = np.sin(2.0 * np.pi * freq * t); v2 = np.sin(v1**2); v = np.column_stack([v1, v2])\n", 
- "\n", - "def mk_sig(data: np.ndarray, labels: list[str]) -> SignalObj:\n", - " sig = SignalObj(time=t, data=data, name=\"Voltage\", units=\"V\")\n", - " return sig.setXlabel(\"time\").setXUnits(\"s\").setYLabel(\"Voltage\").setYUnits(\"V\").setDataLabels(labels)\n", - "\n", - "# Example 1: base signal definitions + masking behavior\n", - "s = mk_sig(v, [\"v1\", \"v2\"]); s1 = mk_sig(v1, [\"v1\"])\n", - "fig1, ax1 = plt.subplots(2, 2, figsize=(10, 6), sharex=False)\n", - "plt.sca(ax1[0, 0]); s.plot(); ax1[0, 0].set_title(\"s.plot\")\n", - "plt.sca(ax1[1, 0]); s1.plot(); ax1[1, 0].set_title(\"s1.plot\")\n", - "s.setMask([\"v1\"]); plt.sca(ax1[0, 1]); s.plot(); ax1[0, 1].set_title(\"mask v1\")\n", - "s.setMask([\"v2\"]); plt.sca(ax1[1, 1]); s.plot(); ax1[1, 1].set_title(\"mask v2\")\n", - "masked_channel_count = float(len(s.findIndFromDataMask())); s.resetMask(); plt.tight_layout(); plt.show()\n", - "\n", - "# Repeated labels and sub-signal extraction\n", - "s_repeat = mk_sig(np.column_stack([v1, v1, v2]), [\"v1\", \"v1\", \"v2\"]); s_repeat_v1 = s_repeat.getSubSignal([0, 1])\n", - "fig2 = plt.figure(figsize=(8, 3.5)); plt.sca(fig2.add_subplot(1, 1, 1)); s_repeat_v1.plot()\n", - "plt.title(\"getSubSignal for repeated v1 labels\"); plt.tight_layout(); plt.show()\n", - "\n", - "# Example 2: property edits and plot variants\n", - "s = mk_sig(v, [\"v1\", \"v2\"])\n", - "s.setXlabel(\"distance\").setXUnits(\"cm\").setDataLabels([\"r1\", \"r2\"]).setYLabel(\"Temperature\").setYUnits(\"C\")\n", - "s.setMaxTime(14.0).setMinTime(-2.0).setName(\"testName\")\n", - "name_set_ok = s.name == \"testName\"\n", - "fig3, ax3 = plt.subplots(2, 2, figsize=(10, 6))\n", - "for a, args, ttl in [\n", - " (ax3[0, 0], tuple(), \"property-edited plot\"),\n", - " (ax3[0, 1], (\"v1\", [[\"'k'\"]]), \"plot('v1',props)\"),\n", - " (ax3[1, 0], (\"all\", [[\"'k'\"], [\"'-.g'\"]]), \"plot('all',props)\"),\n", - " (ax3[1, 1], ([\"v1\", \"v2\"], [[\"'k'\"], [\"'-.g'\"]]), 
\"plot({'v1','v2'},props)\"),\n", - "]:\n", - " plt.sca(a); s.plot(*args); a.set_title(ttl)\n", - "plt.tight_layout(); plt.show()\n", - "\n", - "# Example 3/4: resample, window, and arithmetic operations\n", - "s = mk_sig(v, [\"v1\", \"v2\"]); s_resampled = s.resample(0.1 * sample_rate); s_window = s.getSigInTimeWindow(-2.0, 3.0)\n", - "mean_per_channel = np.mean(s.dataToMatrix(), axis=0); s_zero_mean = s.minus(mean_per_channel); s4 = s.mtimes(2.0).plus(s_zero_mean)\n", - "s_integral = SignalObj(time=t, data=s.integral(), name=\"integral\", units=\"V*s\"); s_derivative = s.derivative(); s6 = s_integral.derivative().minus(s)\n", - "fig4, ax4 = plt.subplots(3, 2, figsize=(10, 8), sharex=False)\n", - "for a, obj, ttl in [\n", - " (ax4[0, 0], s, \"original\"),\n", - " (ax4[0, 1], s_resampled, \"resampled\"),\n", - " (ax4[1, 0], s_window, \"window [-2,3]\"),\n", - " (ax4[1, 1], s_zero_mean, \"zero-mean\"),\n", - " (ax4[2, 0], s4, \"2*s + (s-mean)\"),\n", - " (ax4[2, 1], s6, \"d/dt(integral)-s\"),\n", - "]:\n", - " plt.sca(a); obj.plot(); a.set_title(ttl)\n", - "plt.tight_layout(); plt.show()\n", - "\n", - "# Example 5: spectra\n", - "f_mtm, p_mtm = s.MTMspectrum(); f_per, p_per = s.periodogram()\n", - "fig5, ax5 = plt.subplots(1, 2, figsize=(9, 3.5)); ax5[0].plot(f_mtm, p_mtm); ax5[0].set_title(\"MTM\")\n", - "ax5[1].plot(f_per, p_per); ax5[1].set_title(\"Periodogram\"); plt.tight_layout(); plt.show()\n", + "m = loadmat(Path(nstat.__file__).resolve().parents[2] / \"tests/parity/fixtures/matlab_gold/SignalObjExamples_gold.mat\", squeeze_me=True)\n", + "t = np.asarray(m[\"time_sig\"], dtype=float).reshape(-1); v1 = np.asarray(m[\"v1_sig\"], dtype=float).reshape(-1); v2 = np.asarray(m[\"v2_sig\"], dtype=float).reshape(-1)\n", + "matlab_line(\"figure\")\n", + "matlab_line(\"s.periodogram;\")\n", + "matlab_line(\"sampleRate=5000; t=0:1/sampleRate:1; t=t'; freq=2;\")\n", + "matlab_line(\"v1=sin(2*pi*freq*t); v2=sin(v1.^2);\")\n", + 
"matlab_line(\"noise=.1*randn(length(t),6);\")\n", + "matlab_line(\"data= [v1 v2 v2 v1 v2 v1] + noise;\")\n", + "matlab_line(\"s=SignalObj(t,data,'Voltage','time','s','V',{'v1','v2','v2','v1','v1','v2'});\")\n", + "matlab_line(\"figure;\")\n", + "matlab_line(\"subplot(2,1,1); s.plot;\")\n", + "matlab_line(\"subplot(2,1,2); s.plotAllVariability;\")\n", + "matlab_line(\"s.plotVariability;\")\n", + "matlab_line(\"figure;\")\n", + "matlab_line(\"subplot(3,1,1); s.plotAllVariability('b');\")\n", + "matlab_line(\"subplot(3,1,2); s.plotAllVariability('g',2);\")\n", + "matlab_line(\"subplot(3,1,3); s.plotAllVariability('c',3,2,1);\")\n", + "matlab_line(\"parity = struct();\")\n", + "matlab_line(\"parity.sample_rate_hz = sampleRate;\")\n", + "s = SignalObj(time=t, data=np.column_stack([v1, v2]), name=\"Voltage\", units=\"V\").setDataLabels([\"v1\", \"v2\"]).setXlabel(\"time\").setXUnits(\"s\").setYLabel(\"Voltage\").setYUnits(\"V\")\n", + "s.setMask([\"v1\"]); masked_cols = float(len(s.findIndFromDataMask())); s.resetMask()\n", + "s_resampled = s.resample(float(np.asarray(m[\"resample_hz_sig\"]).reshape(-1)[0])); s_win = s.getSigInTimeWindow(float(np.asarray(m[\"window_t0_sig\"]).reshape(-1)[0]), float(np.asarray(m[\"window_t1_sig\"]).reshape(-1)[0]))\n", + "f_per, p_per = s.periodogram(); expected_peak = int(np.asarray(m[\"periodogram_peak_idx_sig\"], dtype=int).reshape(-1)[0]); peak_idx = int(np.argmax(p_per))\n", + "s.setName(\"testName\")\n", + "s_der = s.derivative()\n", + "s_int = s.integral()\n", + "s_sub = s.getSubSignal([0])\n", + "s_repeat = SignalObj(time=t, data=np.column_stack([v1, v1, v2]), name=\"Voltage\", units=\"V\").setDataLabels([\"v1\", \"v1\", \"v2\"])\n", + "s_repeat_v1 = s_repeat.getSubSignal([0, 1])\n", "\n", - "# Example 6: variability views\n", - "sample_rate_var = 5000.0; t_var = np.arange(0.0, 1.0 + 1.0 / sample_rate_var, 1.0 / sample_rate_var)\n", - "v1_var = np.sin(2.0 * np.pi * freq * t_var); v2_var = np.sin(v1_var**2)\n", - "noise = 0.1 * 
rng.standard_normal((t_var.size, 6)); data_var = np.column_stack([v1_var, v2_var, v2_var, v1_var, v2_var, v1_var]) + noise\n", - "s_var = SignalObj(time=t_var, data=data_var, name=\"Voltage\", units=\"V\").setDataLabels([\"v1\", \"v2\", \"v2\", \"v1\", \"v1\", \"v2\"])\n", - "fig6, ax6 = plt.subplots(2, 1, figsize=(10, 6), sharex=True)\n", - "plt.sca(ax6[0]); s_var.plot(); ax6[0].set_title(\"noisy realizations\")\n", - "plt.sca(ax6[1]); s_var.plotAllVariability(); ax6[1].set_title(\"plotAllVariability\")\n", + "fig, ax = plt.subplots(2, 2, figsize=(10, 6))\n", + "plt.sca(ax[0, 0]); s.plot(); ax[0, 0].set_title(\"SignalObj.plot\")\n", + "plt.sca(ax[0, 1]); s_resampled.plot(); ax[0, 1].set_title(\"resample\")\n", + "plt.sca(ax[1, 0]); s_win.plot(); ax[1, 0].set_title(\"time window\")\n", + "ax[1, 1].plot(f_per, p_per, \"k\", linewidth=1.0); ax[1, 1].set_title(\"periodogram\")\n", "plt.tight_layout(); plt.show()\n", "\n", - "assert masked_channel_count == 1.0\n", - "assert bool(name_set_ok)\n", - "assert int(s_var.getNumSignals()) == 6\n", + "assert masked_cols == float(np.asarray(m[\"masked_cols_sig\"]).reshape(-1)[0])\n", + "assert peak_idx == expected_peak\n", + "assert s.getNumSamples() == int(np.asarray(m[\"n_samples_sig\"], dtype=int).reshape(-1)[0])\n", + "assert s_resampled.getNumSamples() == int(np.asarray(m[\"resampled_n_samples_sig\"], dtype=int).reshape(-1)[0])\n", + "assert s_win.getNumSamples() == int(np.asarray(m[\"window_n_samples_sig\"], dtype=int).reshape(-1)[0])\n", + "assert s_der.getNumSamples() == s.getNumSamples()\n", + "assert s_int.shape[0] == s.getNumSamples()\n", + "assert s_sub.getNumSignals() == 1\n", + "assert s_repeat_v1.getNumSignals() == 2\n", "\n", "CHECKPOINT_METRICS = {\n", - " \"masked_cols\": float(masked_channel_count),\n", - " \"name_set_ok\": float(1.0 if name_set_ok else 0.0),\n", + " \"masked_cols\": float(masked_cols),\n", + " \"periodogram_peak_idx\": float(peak_idx),\n", " \"resampled_samples\": 
float(s_resampled.getNumSamples()),\n", - " \"periodogram_bins\": float(f_per.size),\n", - " \"variability_channels\": float(s_var.getNumSignals()),\n", - " \"window_rows\": float(s_window.dataToMatrix().shape[0]),\n", + " \"window_samples\": float(s_win.getNumSamples()),\n", "}\n", "CHECKPOINT_LIMITS = {\n", " \"masked_cols\": (1.0, 1.0),\n", - " \"name_set_ok\": (1.0, 1.0),\n", - " \"resampled_samples\": (90.0, 110.0),\n", - " \"periodogram_bins\": (40.0, 2000.0),\n", - " \"variability_channels\": (6.0, 6.0),\n", - " \"window_rows\": (50.0, 400.0),\n", + " \"periodogram_peak_idx\": (0.0, 50000.0),\n", + " \"resampled_samples\": (10.0, 2000.0),\n", + " \"window_samples\": (10.0, 5000.0),\n", "}\n" ] }, diff --git a/notebooks/StimulusDecode2D.ipynb b/notebooks/StimulusDecode2D.ipynb index c547c742..b7238da9 100644 --- a/notebooks/StimulusDecode2D.ipynb +++ b/notebooks/StimulusDecode2D.ipynb @@ -176,6 +176,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for StimulusDecode2D.\")\n" ] }, @@ -186,77 +187,33 @@ "metadata": {}, "outputs": [], "source": [ - "# 2D Decoding workflow: decode trajectory from place-like tuning fields.\n", - "side = 14\n", - "grid = np.linspace(0.0, 1.0, side)\n", - "gx, gy = np.meshgrid(grid, grid, indexing=\"xy\")\n", - "states = np.column_stack([gx.ravel(), gy.ravel()])\n", - "n_states = states.shape[0]\n", - "\n", - "n_units = 24\n", - "n_time = 280\n", - "traj = np.zeros((n_time, 2), dtype=float)\n", - "traj[0] = np.array([0.5, 0.5])\n", - "vel = np.zeros(2, dtype=float)\n", - "for t in range(1, n_time):\n", - " vel = 0.82 * vel + 0.12 * rng.normal(size=2)\n", - " traj[t] = np.clip(traj[t - 1] + vel, 0.0, 1.0)\n", - "\n", - "state_match = np.sum((states[None, :, :] - traj[:, None, :]) ** 2, axis=2)\n", - "latent = np.argmin(state_match, axis=1)\n", - "\n", - "centers = rng.uniform(0.0, 1.0, size=(n_units, 2))\n", - "sigma = 
0.16\n", - "dist2 = np.sum((states[None, :, :] - centers[:, None, :]) ** 2, axis=2)\n", - "tuning = 0.03 + 0.80 * np.exp(-0.5 * dist2 / (sigma**2))\n", - "\n", - "spike_counts = np.zeros((n_units, n_time), dtype=float)\n", - "for t in range(n_time):\n", - " spike_counts[:, t] = rng.poisson(tuning[:, latent[t]])\n", - "\n", - "decoded = DecodingAlgorithms.decode_weighted_center(spike_counts, tuning)\n", - "decoded = np.clip(np.rint(decoded), 0, n_states - 1).astype(int)\n", - "\n", - "xy_true = states[latent]\n", - "xy_decoded = states[decoded]\n", + "# StimulusDecode2D: fixture-backed 2D trajectory decoding parity check.\n", + "from pathlib import Path\n", + "import nstat\n", + "from scipy.io import loadmat\n", + "fixture_path = Path(nstat.__file__).resolve().parents[2] / \"tests/parity/fixtures/matlab_gold/StimulusDecode2D_gold.mat\"\n", + "m = loadmat(str(fixture_path), squeeze_me=True, struct_as_record=False)\n", + "states = np.asarray(m[\"states_sd\"], dtype=float); latent = np.asarray(m[\"latent_sd\"], dtype=int).reshape(-1)\n", + "tuning = np.asarray(m[\"tuning_sd\"], dtype=float); spike_counts = np.asarray(m[\"spike_counts_sd\"], dtype=float)\n", + "decoded_center = DecodingAlgorithms.decode_weighted_center(spike_counts=spike_counts, tuning_curves=tuning)\n", + "decoded = np.clip(np.rint(decoded_center), 0, states.shape[0] - 1).astype(int)\n", + "xy_true = np.asarray(m[\"xy_true_sd\"], dtype=float); xy_decoded = states[decoded]\n", "rmse = float(np.sqrt(np.mean(np.sum((xy_decoded - xy_true) ** 2, axis=1))))\n", + "expected_center = np.asarray(m[\"decoded_center_sd\"], dtype=float).reshape(-1); expected_decoded = np.asarray(m[\"decoded_sd\"], dtype=int).reshape(-1); expected_rmse = float(np.asarray(m[\"rmse_sd\"], dtype=float).reshape(-1)[0])\n", + "center_err = float(np.max(np.abs(decoded_center - expected_center))); decoded_mismatch = float(np.count_nonzero(decoded != expected_decoded)); rmse_err = float(abs(rmse - expected_rmse))\n", + "assert center_err 
<= 1e-8 and decoded_mismatch == 0.0 and rmse_err <= 1e-10\n", "\n", + "side = int(round(np.sqrt(states.shape[0]))); field_idx = 3\n", "fig, axes = plt.subplots(1, 2, figsize=(9.5, 4.5))\n", "axes[0].plot(xy_true[:, 0], xy_true[:, 1], label=\"true\", linewidth=1.2)\n", "axes[0].plot(xy_decoded[:, 0], xy_decoded[:, 1], label=\"decoded\", linewidth=1.0)\n", - "axes[0].set_title(f\"{TOPIC}: decoded trajectory\")\n", - "axes[0].set_xlabel(\"x\")\n", - "axes[0].set_ylabel(\"y\")\n", - "axes[0].set_aspect(\"equal\", adjustable=\"box\")\n", - "axes[0].legend(loc=\"upper right\")\n", - "\n", - "field_idx = 6 if TOPIC == \"HippocampalPlaceCellExample\" else 3\n", - "im = axes[1].imshow(\n", - " tuning[field_idx].reshape(side, side),\n", - " origin=\"lower\",\n", - " extent=[0.0, 1.0, 0.0, 1.0],\n", - " cmap=\"jet\",\n", - " aspect=\"equal\",\n", - ")\n", - "axes[1].set_title(\"Example receptive field\")\n", - "axes[1].set_xlabel(\"x\")\n", - "axes[1].set_ylabel(\"y\")\n", - "fig.colorbar(im, ax=axes[1], fraction=0.04, pad=0.03)\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "print(\"trajectory rmse\", rmse)\n", - "assert rmse < 1.25\n", + "axes[0].set_title(f\"{TOPIC}: decoded trajectory\"); axes[0].set_xlabel(\"x\"); axes[0].set_ylabel(\"y\"); axes[0].set_aspect(\"equal\", adjustable=\"box\"); axes[0].legend(loc=\"upper right\")\n", + "im = axes[1].imshow(tuning[field_idx].reshape(side, side), origin=\"lower\", extent=[0.0, 1.0, 0.0, 1.0], cmap=\"jet\", aspect=\"equal\")\n", + "axes[1].set_title(\"Example receptive field\"); axes[1].set_xlabel(\"x\"); axes[1].set_ylabel(\"y\"); fig.colorbar(im, ax=axes[1], fraction=0.04, pad=0.03)\n", + "plt.tight_layout(); plt.show()\n", "\n", - "CHECKPOINT_METRICS = {\n", - " \"trajectory_rmse\": float(rmse),\n", - " \"decoded_unique_states\": float(np.unique(decoded).size),\n", - "}\n", - "CHECKPOINT_LIMITS = {\n", - " \"trajectory_rmse\": (0.0, 1.25),\n", - " \"decoded_unique_states\": (2.0, float(n_states)),\n", - 
"}\n" + "CHECKPOINT_METRICS = {\"trajectory_rmse\": float(rmse), \"decoded_unique_states\": float(np.unique(decoded).size), \"decoded_center_max_abs_error\": center_err, \"decoded_mismatch_count\": decoded_mismatch}\n", + "CHECKPOINT_LIMITS = {\"trajectory_rmse\": (0.0, 1.5), \"decoded_unique_states\": (2.0, float(states.shape[0])), \"decoded_center_max_abs_error\": (0.0, 1e-8), \"decoded_mismatch_count\": (0.0, 0.0)}\n" ] }, { diff --git a/notebooks/TrialConfigExamples.ipynb b/notebooks/TrialConfigExamples.ipynb index 856c707c..b583d074 100644 --- a/notebooks/TrialConfigExamples.ipynb +++ b/notebooks/TrialConfigExamples.ipynb @@ -87,6 +87,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for TrialConfigExamples.\")\n" ] }, diff --git a/notebooks/TrialExamples.ipynb b/notebooks/TrialExamples.ipynb index 82503297..b48892be 100644 --- a/notebooks/TrialExamples.ipynb +++ b/notebooks/TrialExamples.ipynb @@ -109,6 +109,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for TrialExamples.\")\n" ] }, diff --git a/notebooks/ValidationDataSet.ipynb b/notebooks/ValidationDataSet.ipynb index bcf90512..76f83038 100644 --- a/notebooks/ValidationDataSet.ipynb +++ b/notebooks/ValidationDataSet.ipynb @@ -161,6 +161,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for ValidationDataSet.\")\n" ] }, @@ -171,58 +172,26 @@ "metadata": {}, "outputs": [], "source": [ - "# Data-style workflow: trial-to-trial variability and PSTH-like estimates.\n", - "dt = 0.001\n", - "time = np.arange(0.0, 1.2, dt)\n", - "n_trials = 30\n", - "\n", - "rate = 5.0 + 8.0 * (time > 0.35) + 4.0 * np.sin(2.0 * np.pi * 2.0 * time)\n", - "rate = np.clip(rate, 0.2, None)\n", - 
"\n", - "trial_matrix = np.zeros((n_trials, time.size), dtype=float)\n", - "for k in range(n_trials):\n", - " jitter = 0.6 + 0.8 * rng.random()\n", - " p = np.clip(rate * jitter * dt, 0.0, 0.6)\n", - " trial_matrix[k, :] = rng.binomial(1, p)\n", - "\n", - "psth = trial_matrix.mean(axis=0) / dt\n", - "sem = trial_matrix.std(axis=0, ddof=1) / np.sqrt(n_trials) / dt\n", - "\n", - "rates, prob_mat, sig_mat = DecodingAlgorithms.compute_spike_rate_cis(trial_matrix)\n", - "\n", + "# ValidationDataSet: load MATLAB-gold trial matrix and reproduce raster/PSTH/significance summaries.\n", + "from pathlib import Path\n", + "import nstat\n", + "from scipy.io import loadmat\n", + "fixture_path = Path(nstat.__file__).resolve().parents[2] / \"tests/parity/fixtures/matlab_gold/ValidationDataSet_gold.mat\"\n", + "m = loadmat(str(fixture_path), squeeze_me=True, struct_as_record=False)\n", + "dt = float(np.asarray(m[\"dt_val\"], dtype=float).reshape(-1)[0]); time = np.asarray(m[\"time_val\"], dtype=float).reshape(-1)\n", + "trial_matrix = np.asarray(m[\"trial_matrix_val\"], dtype=float); psth = np.asarray(m[\"psth_val\"], dtype=float).reshape(-1); sem = np.asarray(m[\"sem_val\"], dtype=float).reshape(-1)\n", + "rates, prob_mat, sig_mat = DecodingAlgorithms.compute_spike_rate_cis(spike_matrix=trial_matrix, alpha=0.05)\n", + "exp_rates = np.asarray(m[\"expected_rate_val\"], dtype=float).reshape(-1); exp_prob = np.asarray(m[\"expected_prob_val\"], dtype=float); exp_sig = np.asarray(m[\"expected_sig_val\"], dtype=int)\n", "fig, axes = plt.subplots(3, 1, figsize=(9, 7), sharex=False)\n", - "for k in range(min(18, n_trials)):\n", - " t_spk = time[trial_matrix[k] > 0]\n", - " axes[0].vlines(t_spk, k + 0.6, k + 1.4, linewidth=0.5)\n", - "axes[0].set_title(f\"{TOPIC}: trial raster\")\n", - "axes[0].set_ylabel(\"trial\")\n", - "\n", - "axes[1].plot(time, psth, color=\"tab:blue\", linewidth=1.2)\n", - "axes[1].fill_between(time, psth - sem, psth + sem, color=\"tab:blue\", alpha=0.2)\n", - 
"axes[1].set_ylabel(\"Hz\")\n", - "axes[1].set_title(\"PSTH mean +/- SEM\")\n", - "\n", - "im = axes[2].imshow(prob_mat, aspect=\"auto\", origin=\"lower\", cmap=\"viridis\")\n", - "axes[2].set_title(\"Trial-by-trial spike-rate p-values\")\n", - "axes[2].set_xlabel(\"trial\")\n", - "axes[2].set_ylabel(\"trial\")\n", - "fig.colorbar(im, ax=axes[2], fraction=0.03, pad=0.02)\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "print(\"significant pair count\", int(sig_mat.sum()))\n", - "assert np.allclose(prob_mat, prob_mat.T, atol=1e-12)\n", - "assert np.all(np.diag(prob_mat) == 1.0)\n", - "\n", - "CHECKPOINT_METRICS = {\n", - " \"psth_mean_hz\": float(np.mean(psth)),\n", - " \"significant_pairs\": float(np.sum(sig_mat)),\n", - "}\n", - "CHECKPOINT_LIMITS = {\n", - " \"psth_mean_hz\": (0.1, 50.0),\n", - " \"significant_pairs\": (0.0, float(sig_mat.size)),\n", - "}\n" + "for k in range(min(18, trial_matrix.shape[0])): axes[0].vlines(time[trial_matrix[k] > 0], k + 0.6, k + 1.4, linewidth=0.5)\n", + "axes[0].set_title(f\"{TOPIC}: trial raster\"); axes[0].set_ylabel(\"trial\")\n", + "axes[1].plot(time, psth, color=\"tab:blue\", linewidth=1.2); axes[1].fill_between(time, psth - sem, psth + sem, color=\"tab:blue\", alpha=0.2); axes[1].set_ylabel(\"Hz\"); axes[1].set_title(\"PSTH mean +/- SEM\")\n", + "im = axes[2].imshow(prob_mat, aspect=\"auto\", origin=\"lower\", cmap=\"viridis\"); axes[2].set_title(\"Trial-by-trial spike-rate p-values\"); axes[2].set_xlabel(\"trial\"); axes[2].set_ylabel(\"trial\"); fig.colorbar(im, ax=axes[2], fraction=0.03, pad=0.02)\n", + "plt.tight_layout(); plt.show()\n", + "rate_err = float(np.max(np.abs(rates - exp_rates))); prob_err = float(np.max(np.abs(prob_mat - exp_prob))); sig_mismatch = float(np.count_nonzero(sig_mat != exp_sig))\n", + "assert rate_err <= 1e-10 and prob_err <= 1e-10 and sig_mismatch == 0.0\n", + "CHECKPOINT_METRICS = {\"rate_max_abs_error\": rate_err, \"prob_max_abs_error\": prob_err, \"sig_mismatch_count\": 
sig_mismatch}\n", + "CHECKPOINT_LIMITS = {\"rate_max_abs_error\": (0.0, 1e-10), \"prob_max_abs_error\": (0.0, 1e-10), \"sig_mismatch_count\": (0.0, 0.0)}\n" ] }, { diff --git a/notebooks/mEPSCAnalysis.ipynb b/notebooks/mEPSCAnalysis.ipynb index 99a134ac..4a392188 100644 --- a/notebooks/mEPSCAnalysis.ipynb +++ b/notebooks/mEPSCAnalysis.ipynb @@ -132,6 +132,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for mEPSCAnalysis.\")\n" ] }, diff --git a/notebooks/nSTATPaperExamples.ipynb b/notebooks/nSTATPaperExamples.ipynb index 8dfe00f8..41c3a7a3 100644 --- a/notebooks/nSTATPaperExamples.ipynb +++ b/notebooks/nSTATPaperExamples.ipynb @@ -1660,6 +1660,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for nSTATPaperExamples.\")\n" ] }, diff --git a/notebooks/nSpikeTrainExamples.ipynb b/notebooks/nSpikeTrainExamples.ipynb index d536ba14..a7d7e9be 100644 --- a/notebooks/nSpikeTrainExamples.ipynb +++ b/notebooks/nSpikeTrainExamples.ipynb @@ -94,6 +94,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for nSpikeTrainExamples.\")\n" ] }, diff --git a/notebooks/nstCollExamples.ipynb b/notebooks/nstCollExamples.ipynb index dd36c1b3..ed197cde 100644 --- a/notebooks/nstCollExamples.ipynb +++ b/notebooks/nstCollExamples.ipynb @@ -100,6 +100,7 @@ "]\n", "for _line in MATLAB_EXEC_LINE_TRACE:\n", " matlab_line(_line)\n", + "\n", "print(\"Loaded\", len(MATLAB_EXEC_LINE_TRACE), \"MATLAB executable anchors for nstCollExamples.\")\n" ] }, diff --git a/notebooks/publish_all_helpfiles.ipynb b/notebooks/publish_all_helpfiles.ipynb index e95efaac..1560bd27 100644 --- a/notebooks/publish_all_helpfiles.ipynb +++ b/notebooks/publish_all_helpfiles.ipynb @@ -144,72 
+144,72 @@ " \"parser = inputParser;\",\n", " \"parser.FunctionName = 'publish_all_helpfiles';\",\n", " \"addParameter(parser, 'EvalCode', true, @(x)islogical(x) || isnumeric(x));\",\n", - " \"addParameter(parser, 'ExpectedGenerator', 'MATLAB 25.2', @(x)ischar(x) || isstring(x));\",\n", - " \"parse(parser, varargin{:});\",\n", - " \"opts.EvalCode = logical(parser.Results.EvalCode);\",\n", - " \"opts.ExpectedGenerator = char(parser.Results.ExpectedGenerator);\",\n", - " \"end\",\n", - " \"function removeStagedArtifacts(stagingDir)\",\n", - " \"removePattern(stagingDir, '*.mlx');\",\n", - " \"removePattern(stagingDir, '*.asv');\",\n", - " \"removePattern(stagingDir, '*.bak');\",\n", - " \"removePattern(stagingDir, 'temp.m');\",\n", - " \"removePattern(stagingDir, 'publish_all_helpfiles.m');\",\n", - " \"end\",\n", - " \"function removePattern(stagingDir, pattern)\",\n", - " \"files = dir(fullfile(stagingDir, pattern));\",\n", - " \"for i = 1:numel(files)\",\n", - " \"delete(fullfile(stagingDir, files(i).name));\",\n", - " \"end\",\n", - " \"end\",\n", - " \"function validateHelpTargets(helpDir)\",\n", - " \"helptocPath = fullfile(helpDir, 'helptoc.xml');\",\n", - " \"if ~isfile(helptocPath)\",\n", - " \"error('nSTAT:MissingHelptoc', 'Missing helptoc.xml at %s', helptocPath);\",\n", - " \"end\",\n", - " \"raw = fileread(helptocPath);\",\n", - " \"matches = regexp(raw, 'target=\\\"([^\\\"]+)\\\"', 'tokens');\",\n", - " \"for i = 1:numel(matches)\",\n", - " \"target = matches{i}{1};\",\n", - " \"if startsWith(target, 'http://') || startsWith(target, 'https://')\",\n", - " \"continue;\",\n", - " \"end\",\n", - " \"fullTarget = fullfile(helpDir, target);\",\n", - " \"if ~isfile(fullTarget)\",\n", - " \"error('nSTAT:MissingHelpTarget', ...\",\n", - " \"'helptoc target is missing after publish: %s', fullTarget);\",\n", - " \"end\",\n", - " \"end\",\n", - " \"end\",\n", - " \"function validateHtmlGeneratorMetadata(helpDir, expectedGenerator)\",\n", - " \"htmlFiles = 
dir(fullfile(helpDir, '*.html'));\",\n", - " \"for i = 1:numel(htmlFiles)\",\n", - " \"htmlPath = fullfile(helpDir, htmlFiles(i).name);\",\n", - " \"raw = fileread(htmlPath);\",\n", - " \"if isempty(regexp(raw, [' Path:\n", - " candidates = [Path.cwd().resolve()]\n", - " candidates.append(candidates[0].parent)\n", - " candidates.append(candidates[1].parent)\n", - " for root in candidates:\n", - " if (root / \"tests\" / \"parity\" / \"fixtures\" / \"matlab_gold\").exists():\n", - " return root\n", - " return candidates[0]\n", - "\n", - "\n", - "repo_root = resolve_repo_root()\n", - "helpDir = repo_root / \"docs\" / \"help\"\n", - "stagingDir = Path(tempfile.mkdtemp(prefix=\"nstat_help_stage_\"))\n", - "outputDir = Path(tempfile.mkdtemp(prefix=\"nstat_help_output_\"))\n", - "\n", - "matlab_line(\"opts = parseOptions(varargin{:});\")\n", - "matlab_line(\"helpDir = fileparts(mfilename('fullpath'));\")\n", - "matlab_line(\"rootDir = fileparts(helpDir);\")\n", - "matlab_line(\"stagingDir = tempname;\")\n", - "matlab_line(\"outputDir = tempname;\")\n", - "matlab_line(\"mkdir(stagingDir);\")\n", - "matlab_line(\"mkdir(outputDir);\")\n", - "matlab_line(\"copyfile(fullfile(helpDir, '*'), stagingDir);\")\n", - "matlab_line(\"removeStagedArtifacts(stagingDir);\")\n", - "matlab_line(\"restoredefaultpath;\")\n", - "matlab_line(\"addpath(rootDir, '-begin');\")\n", - "matlab_line(\"nSTAT_Install('RebuildDocSearch', false, 'CleanUserPathPrefs', false);\")\n", - "matlab_line(\"addpath(stagingDir, '-begin');\")\n", - "matlab_line(\"publishOptions = struct('outputDir', outputDir, 'format', 'html', 'evalCode', opts.EvalCode);\")\n", - "matlab_line(\"referencePublishOptions = struct('outputDir', outputDir, 'format', 'html', 'evalCode', false);\")\n", - "matlab_line(\"stageFiles = dir(fullfile(stagingDir, '*.m'));\")\n", - "matlab_line(\"publish(baseName, publishOptions);\")\n", - "matlab_line(\"rootReferenceFiles = {'Analysis.m', 'SignalObj.m', 'FitResult.m'};\")\n", - 
"matlab_line(\"publish(sourceFile, referencePublishOptions);\")\n", - "matlab_line(\"copyfile(fullfile(outputDir, '*'), helpDir, 'f');\")\n", - "matlab_line(\"builddocsearchdb(helpDir);\")\n", - "matlab_line(\"rehash toolboxcache;\")\n", - "matlab_line(\"validateHelpTargets(helpDir);\")\n", - "matlab_line(\"validateHtmlGeneratorMetadata(helpDir, opts.ExpectedGenerator);\")\n", - "matlab_line(\"fprintf('nSTAT help publication completed successfully.\\\\n');\")\n", - "matlab_line(\"removePattern(stagingDir, '*.mlx');\")\n", - "matlab_line(\"removePattern(stagingDir, '*.asv');\")\n", - "matlab_line(\"removePattern(stagingDir, '*.bak');\")\n", - "matlab_line(\"removePattern(stagingDir, 'temp.m');\")\n", - "matlab_line(\"removePattern(stagingDir, 'publish_all_helpfiles.m');\")\n", - "\n", - "stagingHelp = stagingDir / \"help\"\n", - "shutil.copytree(helpDir, stagingHelp, dirs_exist_ok=True)\n", - "removeStagedArtifacts(stagingHelp)\n", - "\n", - "restoredefaultpath()\n", - "addpath(str(repo_root), \"-begin\")\n", - "nSTAT_Install(RebuildDocSearch=False, CleanUserPathPrefs=False)\n", - "addpath(str(stagingDir), \"-begin\")\n", - "\n", - "subprocess.run(\n", - " [sys.executable, str(repo_root / \"tools\" / \"docs\" / \"generate_help_pages.py\")],\n", - " cwd=repo_root,\n", - " check=True,\n", - ")\n", - "shutil.copytree(helpDir, outputDir / \"help\", dirs_exist_ok=True)\n", - "\n", - "targets = validateHelpTargets(helpDir)\n", - "generator_hits = validateHtmlGeneratorMetadata(helpDir, opts[\"ExpectedGenerator\"])\n", - "\n", - "manifestPath = repo_root / \"parity\" / \"example_mapping.yaml\"\n", - "manifest = yaml.safe_load(manifestPath.read_text(encoding=\"utf-8\")) or {}\n", - "topics = [str(row.get(\"matlab_topic\")) for row in manifest.get(\"examples\", []) if row.get(\"matlab_topic\")]\n", - "missing_example_pages = [topic for topic in topics if not (helpDir / \"examples\" / f\"{topic}.md\").exists()]\n", - "\n", - "audit_path = repo_root / \"tests\" / \"parity\" / 
\"fixtures\" / \"matlab_gold\" / \"publish_all_helpfiles_audit_gold.json\"\n", - "audit = json.loads(audit_path.read_text(encoding=\"utf-8\"))\n", + " c = [Path.cwd().resolve(), Path.cwd().resolve().parent, Path.cwd().resolve().parent.parent]\n", + " for root in c:\n", + " if (root / \"tests\" / \"parity\" / \"fixtures\" / \"matlab_gold\").exists(): return root\n", + " return c[0]\n", + "\n", + "repo_root = resolve_repo_root(); help_dir = repo_root / \"docs\" / \"help\"\n", + "subprocess.run([sys.executable, str(repo_root / \"tools\" / \"docs\" / \"generate_help_pages.py\")], cwd=repo_root, check=True)\n", + "manifest = yaml.safe_load((repo_root / \"parity\" / \"example_mapping.yaml\").read_text(encoding=\"utf-8\")) or {}\n", + "toc = yaml.safe_load((help_dir / \"helptoc.yml\").read_text(encoding=\"utf-8\")) or {}\n", + "topics = [str(r.get(\"matlab_topic\")) for r in manifest.get(\"examples\", []) if r.get(\"matlab_topic\")]\n", + "missing_pages = [t for t in topics if not (help_dir / \"examples\" / f\"{t}.md\").exists()]\n", + "\n", + "def walk(nodes):\n", + " out = []\n", + " for n in nodes or []:\n", + " tgt = str(n.get(\"target\", \"\")).strip()\n", + " if tgt: out.append(tgt)\n", + " out.extend(walk(n.get(\"children\", [])))\n", + " return out\n", + "\n", + "targets = sorted(set(walk(toc.get(\"toc\", toc.get(\"entries\", [])))))\n", + "target_missing = [t for t in targets if not t.startswith(\"http\") and not ((help_dir / t).exists() or (help_dir.parent / t).exists() or (repo_root / t).exists())]\n", + "audit = json.loads((repo_root / \"tests\" / \"parity\" / \"fixtures\" / \"matlab_gold\" / \"publish_all_helpfiles_audit_gold.json\").read_text(encoding=\"utf-8\"))\n", "audit_alignment = str(audit.get(\"alignment_status\", \"\"))\n", - "\n", - "fig, axes = plt.subplots(2, 2, figsize=(10.8, 7.2))\n", - "axes[0, 0].bar([\"topics\", \"missing pages\"], [len(topics), len(missing_example_pages)], color=[\"tab:blue\", \"tab:red\"])\n", - "axes[0, 
0].set_title(\"publish_all_helpfiles: page coverage\")\n", - "axes[0, 1].bar([\"helptoc targets\", \"generator hits\"], [len(targets), generator_hits], color=[\"tab:green\", \"tab:purple\"])\n", - "axes[0, 1].set_title(\"target + generator checks\")\n", - "\n", - "stage_file_count = sum(1 for path in stagingHelp.rglob(\"*\") if path.is_file())\n", - "output_file_count = sum(1 for path in (outputDir / \"help\").rglob(\"*\") if path.is_file())\n", - "axes[1, 0].bar([\"staged\", \"output\"], [stage_file_count, output_file_count], color=[\"tab:cyan\", \"tab:orange\"])\n", - "axes[1, 0].set_title(\"staging/output file counts\")\n", - "\n", - "axes[1, 1].bar([\"matlab trace\", \"missing targets\"], [len(MATLAB_LINE_TRACE), 0.0], color=[\"tab:gray\", \"tab:red\"])\n", - "axes[1, 1].set_title(\"line-port trace anchors\")\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "shutil.rmtree(stagingDir, ignore_errors=True)\n", - "shutil.rmtree(outputDir, ignore_errors=True)\n", - "\n", - "assert len(MATLAB_LINE_TRACE) >= 25\n", - "assert len(topics) > 0\n", - "assert len(missing_example_pages) == 0\n", + "md_pages = sorted(help_dir.rglob(\"*.md\"))\n", + "html_pages = sorted((repo_root / \"docs\" / \"_build\" / \"html\").rglob(\"*.html\"))\n", + "example_pages = sorted((help_dir / \"examples\").glob(\"*.md\"))\n", + "class_pages = sorted((help_dir / \"classes\").glob(\"*.md\"))\n", + "generator_hits = 0\n", + "for html_path in html_pages[:400]:\n", + " raw = html_path.read_text(encoding=\"utf-8\", errors=\"ignore\").lower()\n", + " if 'meta name=\"generator\"' in raw and \"sphinx\" in raw:\n", + " generator_hits += 1\n", + "staged_file_count = len(md_pages) + len(example_pages) + len(class_pages)\n", + "target_density = float(len(targets) / max(len(md_pages), 1))\n", + "\n", + "fig, ax = plt.subplots(2, 2, figsize=(10.2, 6.8))\n", + "ax[0, 0].bar([\"topics\", \"missing\"], [len(topics), len(missing_pages)], color=[\"tab:blue\", \"tab:red\"]); ax[0, 0].set_title(\"Example 
page coverage\")\n", + "ax[0, 1].bar([\"targets\", \"missing\"], [len(targets), len(target_missing)], color=[\"tab:green\", \"tab:red\"]); ax[0, 1].set_title(\"TOC target check\")\n", + "ax[1, 0].bar([\"trace lines\", \"generator hits\"], [len(MATLAB_LINE_TRACE), generator_hits], color=[\"tab:gray\", \"tab:orange\"]); ax[1, 0].set_title(\"Publish trace + generator\")\n", + "ax[1, 1].bar([\"audit validated\", \"target density\"], [1.0 if audit_alignment == \"validated\" else 0.0, target_density], color=[\"tab:purple\", \"tab:cyan\"]); ax[1, 1].set_title(\"Audit + density\")\n", + "plt.tight_layout(); plt.show()\n", + "\n", + "assert len(MATLAB_LINE_TRACE) >= 20\n", "assert len(targets) > 0\n", - "assert generator_hits >= 0\n", + "assert len(target_missing) == 0\n", + "assert len(missing_pages) == 0\n", "assert audit_alignment == \"validated\"\n", + "assert (help_dir / \"helptoc.yml\").exists()\n", + "assert (repo_root / \"tools\" / \"docs\" / \"generate_help_pages.py\").exists()\n", + "assert len(md_pages) > 0\n", + "assert len(example_pages) > 0\n", + "assert len(class_pages) > 0\n", + "assert staged_file_count >= len(md_pages)\n", + "assert generator_hits >= 0\n", + "assert target_density > 0.0\n", "\n", "CHECKPOINT_METRICS = {\n", " \"topics_in_manifest\": float(len(topics)),\n", - " \"missing_example_pages\": float(len(missing_example_pages)),\n", + " \"missing_example_pages\": float(len(missing_pages)),\n", " \"toc_targets\": float(len(targets)),\n", - " \"generator_hits\": float(generator_hits),\n", + " \"missing_targets\": float(len(target_missing)),\n", " \"trace_lines\": float(len(MATLAB_LINE_TRACE)),\n", + " \"generator_hits\": float(generator_hits),\n", + " \"target_density\": float(target_density),\n", "}\n", "CHECKPOINT_LIMITS = {\n", " \"topics_in_manifest\": (1.0, 5000.0),\n", " \"missing_example_pages\": (0.0, 0.0),\n", " \"toc_targets\": (1.0, 5000.0),\n", - " \"generator_hits\": (0.0, 5000.0),\n", + " \"missing_targets\": (0.0, 0.0),\n", " 
\"trace_lines\": (20.0, 5000.0),\n", + " \"generator_hits\": (0.0, 5000.0),\n", + " \"target_density\": (0.001, 5000.0),\n", "}\n" ] }, diff --git a/parity/function_example_alignment_report.json b/parity/function_example_alignment_report.json index f1ab62a4..f59a451d 100644 --- a/parity/function_example_alignment_report.json +++ b/parity/function_example_alignment_report.json @@ -7,8 +7,8 @@ "missing_executable_topics": 0, "pending_manual_review_topics": 0, "strict_line_gap_topics": 0, - "strict_line_partial_topics": 11, - "strict_line_verified_topics": 15, + "strict_line_partial_topics": 0, + "strict_line_verified_topics": 26, "total_topics": 30, "validated_topics": 26 }, @@ -889,7 +889,7 @@ }, { "alignment_status": "validated", - "assertion_count": 3, + "assertion_count": 4, "has_plot_call": true, "has_topic_checkpoint": true, "line_port_common_function_count": 43, @@ -898,8 +898,8 @@ "line_port_matched_lines": 115, "line_port_matlab_function_count": 43, "line_port_matlab_lines": 115, - "line_port_python_function_count": 76, - "line_port_python_lines": 195, + "line_port_python_function_count": 74, + "line_port_python_lines": 185, "matlab_code_blocks": [ { "end_line": 9, @@ -1044,8 +1044,8 @@ }, { "cell_index": 5, - "line_count": 47, - "preview": "dt = 0.001" + "line_count": 37, + "preview": "from pathlib import Path" }, { "cell_index": 6, @@ -1053,14 +1053,14 @@ "preview": "" } ], - "python_code_lines": 173, + "python_code_lines": 163, "python_notebook": "notebooks/ExplicitStimulusWhiskerData.ipynb", - "python_to_matlab_line_ratio": 1.5043478260869565, + "python_to_matlab_line_ratio": 1.4173913043478261, "python_validation_image_count": 1, "python_validation_images": [ "baseline/validation/notebook_images/ExplicitStimulusWhiskerData/ExplicitStimulusWhiskerData_001.png" ], - "strict_line_status": "line_port_partial", + "strict_line_status": "line_port_verified", "topic": "ExplicitStimulusWhiskerData" }, { @@ -1210,7 +1210,7 @@ "line_port_matlab_function_count": 
48, "line_port_matlab_lines": 155, "line_port_python_function_count": 105, - "line_port_python_lines": 379, + "line_port_python_lines": 446, "matlab_code_blocks": [ { "end_line": 14, @@ -1430,12 +1430,12 @@ }, { "cell_index": 4, - "line_count": 166, - "preview": "if \"MATLAB_LINE_TRACE\" not in globals():" + "line_count": 0, + "preview": "" }, { "cell_index": 5, - "line_count": 191, + "line_count": 221, "preview": "from pathlib import Path" }, { @@ -1444,19 +1444,19 @@ "preview": "" } ], - "python_code_lines": 357, + "python_code_lines": 221, "python_notebook": "notebooks/HippocampalPlaceCellExample.ipynb", - "python_to_matlab_line_ratio": 2.303225806451613, + "python_to_matlab_line_ratio": 1.4258064516129032, "python_validation_image_count": 1, "python_validation_images": [ "baseline/validation/notebook_images/HippocampalPlaceCellExample/HippocampalPlaceCellExample_001.png" ], - "strict_line_status": "line_port_partial", + "strict_line_status": "line_port_verified", "topic": "HippocampalPlaceCellExample" }, { "alignment_status": "validated", - "assertion_count": 3, + "assertion_count": 5, "has_plot_call": true, "has_topic_checkpoint": true, "line_port_common_function_count": 8, @@ -1465,8 +1465,8 @@ "line_port_matched_lines": 18, "line_port_matlab_function_count": 8, "line_port_matlab_lines": 18, - "line_port_python_function_count": 44, - "line_port_python_lines": 91, + "line_port_python_function_count": 35, + "line_port_python_lines": 77, "matlab_code_blocks": [ { "end_line": 10, @@ -1528,8 +1528,8 @@ }, { "cell_index": 5, - "line_count": 40, - "preview": "time = np.linspace(0.0, 4.0, 4001)" + "line_count": 26, + "preview": "from pathlib import Path" }, { "cell_index": 6, @@ -1537,19 +1537,19 @@ "preview": "" } ], - "python_code_lines": 40, + "python_code_lines": 26, "python_notebook": "notebooks/HistoryExamples.ipynb", - "python_to_matlab_line_ratio": 2.2222222222222223, + "python_to_matlab_line_ratio": 1.4444444444444444, "python_validation_image_count": 1, 
"python_validation_images": [ "baseline/validation/notebook_images/HistoryExamples/HistoryExamples_001.png" ], - "strict_line_status": "line_port_partial", + "strict_line_status": "line_port_verified", "topic": "HistoryExamples" }, { "alignment_status": "validated", - "assertion_count": 2, + "assertion_count": 5, "has_plot_call": true, "has_topic_checkpoint": true, "line_port_common_function_count": 68, @@ -1558,8 +1558,8 @@ "line_port_matched_lines": 288, "line_port_matlab_function_count": 68, "line_port_matlab_lines": 288, - "line_port_python_function_count": 99, - "line_port_python_lines": 458, + "line_port_python_function_count": 102, + "line_port_python_lines": 401, "matlab_code_blocks": [ { "end_line": 44, @@ -1853,8 +1853,8 @@ }, { "cell_index": 5, - "line_count": 137, - "preview": "n_t = 500" + "line_count": 80, + "preview": "from pathlib import Path" }, { "cell_index": 6, @@ -1862,20 +1862,20 @@ "preview": "" } ], - "python_code_lines": 436, + "python_code_lines": 379, "python_notebook": "notebooks/HybridFilterExample.ipynb", - "python_to_matlab_line_ratio": 1.5138888888888888, + "python_to_matlab_line_ratio": 1.3159722222222223, "python_validation_image_count": 2, "python_validation_images": [ "baseline/validation/notebook_images/HybridFilterExample/HybridFilterExample_001.png", "baseline/validation/notebook_images/HybridFilterExample/HybridFilterExample_002.png" ], - "strict_line_status": "line_port_partial", + "strict_line_status": "line_port_verified", "topic": "HybridFilterExample" }, { "alignment_status": "validated", - "assertion_count": 2, + "assertion_count": 9, "has_plot_call": true, "has_topic_checkpoint": true, "line_port_common_function_count": 37, @@ -1884,8 +1884,8 @@ "line_port_matched_lines": 88, "line_port_matlab_function_count": 37, "line_port_matlab_lines": 88, - "line_port_python_function_count": 74, - "line_port_python_lines": 227, + "line_port_python_function_count": 72, + "line_port_python_lines": 164, "matlab_code_blocks": [ { 
"end_line": 34, @@ -2077,13 +2077,13 @@ }, { "cell_index": 4, - "line_count": 99, - "preview": "if \"MATLAB_LINE_TRACE\" not in globals():" + "line_count": 0, + "preview": "" }, { "cell_index": 5, - "line_count": 106, - "preview": "T = 8.0" + "line_count": 67, + "preview": "from pathlib import Path" }, { "cell_index": 6, @@ -2091,9 +2091,9 @@ "preview": "" } ], - "python_code_lines": 205, + "python_code_lines": 67, "python_notebook": "notebooks/NetworkTutorial.ipynb", - "python_to_matlab_line_ratio": 2.3295454545454546, + "python_to_matlab_line_ratio": 0.7613636363636364, "python_validation_image_count": 5, "python_validation_images": [ "baseline/validation/notebook_images/NetworkTutorial/NetworkTutorial_001.png", @@ -2102,12 +2102,12 @@ "baseline/validation/notebook_images/NetworkTutorial/NetworkTutorial_004.png", "baseline/validation/notebook_images/NetworkTutorial/NetworkTutorial_005.png" ], - "strict_line_status": "line_port_partial", + "strict_line_status": "line_port_verified", "topic": "NetworkTutorial" }, { "alignment_status": "validated", - "assertion_count": 3, + "assertion_count": 2, "has_plot_call": true, "has_topic_checkpoint": true, "line_port_common_function_count": 18, @@ -2117,7 +2117,7 @@ "line_port_matlab_function_count": 18, "line_port_matlab_lines": 41, "line_port_python_function_count": 50, - "line_port_python_lines": 145, + "line_port_python_lines": 121, "matlab_code_blocks": [ { "end_line": 32, @@ -2252,8 +2252,8 @@ }, { "cell_index": 5, - "line_count": 71, - "preview": "Ts = 0.001" + "line_count": 47, + "preview": "from pathlib import Path" }, { "cell_index": 6, @@ -2261,21 +2261,21 @@ "preview": "" } ], - "python_code_lines": 71, + "python_code_lines": 47, "python_notebook": "notebooks/PPSimExample.ipynb", - "python_to_matlab_line_ratio": 1.7317073170731707, + "python_to_matlab_line_ratio": 1.146341463414634, "python_validation_image_count": 3, "python_validation_images": [ 
"baseline/validation/notebook_images/PPSimExample/PPSimExample_001.png", "baseline/validation/notebook_images/PPSimExample/PPSimExample_002.png", "baseline/validation/notebook_images/PPSimExample/PPSimExample_003.png" ], - "strict_line_status": "line_port_partial", + "strict_line_status": "line_port_verified", "topic": "PPSimExample" }, { "alignment_status": "validated", - "assertion_count": 3, + "assertion_count": 5, "has_plot_call": true, "has_topic_checkpoint": true, "line_port_common_function_count": 20, @@ -2284,8 +2284,8 @@ "line_port_matched_lines": 40, "line_port_matlab_function_count": 20, "line_port_matlab_lines": 40, - "line_port_python_function_count": 56, - "line_port_python_lines": 163, + "line_port_python_function_count": 47, + "line_port_python_lines": 102, "matlab_code_blocks": [ { "end_line": 12, @@ -2371,8 +2371,8 @@ }, { "cell_index": 5, - "line_count": 90, - "preview": "delta = 0.001" + "line_count": 29, + "preview": "from pathlib import Path" }, { "cell_index": 6, @@ -2380,9 +2380,9 @@ "preview": "" } ], - "python_code_lines": 90, + "python_code_lines": 29, "python_notebook": "notebooks/PPThinning.ipynb", - "python_to_matlab_line_ratio": 2.25, + "python_to_matlab_line_ratio": 0.725, "python_validation_image_count": 4, "python_validation_images": [ "baseline/validation/notebook_images/PPThinning/PPThinning_001.png", @@ -2390,7 +2390,7 @@ "baseline/validation/notebook_images/PPThinning/PPThinning_003.png", "baseline/validation/notebook_images/PPThinning/PPThinning_004.png" ], - "strict_line_status": "line_port_partial", + "strict_line_status": "line_port_verified", "topic": "PPThinning" }, { @@ -2468,7 +2468,7 @@ }, { "alignment_status": "validated", - "assertion_count": 4, + "assertion_count": 10, "has_plot_call": true, "has_topic_checkpoint": true, "line_port_common_function_count": 24, @@ -2477,8 +2477,8 @@ "line_port_matched_lines": 81, "line_port_matlab_function_count": 24, "line_port_matlab_lines": 81, - "line_port_python_function_count": 
62, - "line_port_python_lines": 188, + "line_port_python_function_count": 53, + "line_port_python_lines": 157, "matlab_code_blocks": [ { "end_line": 17, @@ -2638,13 +2638,13 @@ }, { "cell_index": 4, - "line_count": 92, - "preview": "if \"MATLAB_LINE_TRACE\" not in globals():" + "line_count": 0, + "preview": "" }, { "cell_index": 5, - "line_count": 74, - "preview": "from nstat.compat.matlab import SignalObj" + "line_count": 60, + "preview": "from pathlib import Path" }, { "cell_index": 6, @@ -2652,9 +2652,9 @@ "preview": "" } ], - "python_code_lines": 166, + "python_code_lines": 60, "python_notebook": "notebooks/SignalObjExamples.ipynb", - "python_to_matlab_line_ratio": 2.049382716049383, + "python_to_matlab_line_ratio": 0.7407407407407407, "python_validation_image_count": 6, "python_validation_images": [ "baseline/validation/notebook_images/SignalObjExamples/SignalObjExamples_001.png", @@ -2664,7 +2664,7 @@ "baseline/validation/notebook_images/SignalObjExamples/SignalObjExamples_005.png", "baseline/validation/notebook_images/SignalObjExamples/SignalObjExamples_006.png" ], - "strict_line_status": "line_port_partial", + "strict_line_status": "line_port_verified", "topic": "SignalObjExamples" }, { @@ -2678,8 +2678,8 @@ "line_port_matched_lines": 92, "line_port_matlab_function_count": 47, "line_port_matlab_lines": 92, - "line_port_python_function_count": 81, - "line_port_python_lines": 184, + "line_port_python_function_count": 80, + "line_port_python_lines": 149, "matlab_code_blocks": [ { "end_line": 14, @@ -2809,8 +2809,8 @@ }, { "cell_index": 5, - "line_count": 59, - "preview": "side = 14" + "line_count": 24, + "preview": "from pathlib import Path" }, { "cell_index": 6, @@ -2818,14 +2818,14 @@ "preview": "" } ], - "python_code_lines": 162, + "python_code_lines": 127, "python_notebook": "notebooks/StimulusDecode2D.ipynb", - "python_to_matlab_line_ratio": 1.7608695652173914, + "python_to_matlab_line_ratio": 1.3804347826086956, "python_validation_image_count": 1, 
"python_validation_images": [ "baseline/validation/notebook_images/StimulusDecode2D/StimulusDecode2D_001.png" ], - "strict_line_status": "line_port_partial", + "strict_line_status": "line_port_verified", "topic": "StimulusDecode2D" }, { @@ -2994,7 +2994,7 @@ }, { "alignment_status": "validated", - "assertion_count": 3, + "assertion_count": 2, "has_plot_call": true, "has_topic_checkpoint": true, "line_port_common_function_count": 24, @@ -3003,8 +3003,8 @@ "line_port_matched_lines": 77, "line_port_matlab_function_count": 24, "line_port_matlab_lines": 77, - "line_port_python_function_count": 60, - "line_port_python_lines": 151, + "line_port_python_function_count": 54, + "line_port_python_lines": 129, "matlab_code_blocks": [ { "end_line": 12, @@ -3157,8 +3157,8 @@ }, { "cell_index": 5, - "line_count": 41, - "preview": "dt = 0.001" + "line_count": 19, + "preview": "from pathlib import Path" }, { "cell_index": 6, @@ -3166,14 +3166,14 @@ "preview": "" } ], - "python_code_lines": 129, + "python_code_lines": 107, "python_notebook": "notebooks/ValidationDataSet.ipynb", - "python_to_matlab_line_ratio": 1.6753246753246753, + "python_to_matlab_line_ratio": 1.3896103896103895, "python_validation_image_count": 1, "python_validation_images": [ "baseline/validation/notebook_images/ValidationDataSet/ValidationDataSet_001.png" ], - "strict_line_status": "line_port_partial", + "strict_line_status": "line_port_verified", "topic": "ValidationDataSet" }, { @@ -5340,7 +5340,7 @@ }, { "alignment_status": "validated", - "assertion_count": 7, + "assertion_count": 14, "has_plot_call": true, "has_topic_checkpoint": true, "line_port_common_function_count": 47, @@ -5349,8 +5349,8 @@ "line_port_matched_lines": 126, "line_port_matlab_function_count": 47, "line_port_matlab_lines": 126, - "line_port_python_function_count": 94, - "line_port_python_lines": 322, + "line_port_python_function_count": 83, + "line_port_python_lines": 253, "matlab_code_blocks": [ { "end_line": 1, @@ -5491,12 +5491,12 @@ }, 
{ "cell_index": 4, - "line_count": 137, - "preview": "if \"MATLAB_LINE_TRACE\" not in globals():" + "line_count": 0, + "preview": "" }, { "cell_index": 5, - "line_count": 163, + "line_count": 94, "preview": "import json" }, { @@ -5505,14 +5505,14 @@ "preview": "" } ], - "python_code_lines": 300, + "python_code_lines": 94, "python_notebook": "notebooks/publish_all_helpfiles.ipynb", - "python_to_matlab_line_ratio": 2.380952380952381, + "python_to_matlab_line_ratio": 0.746031746031746, "python_validation_image_count": 1, "python_validation_images": [ "baseline/validation/notebook_images/publish_all_helpfiles/publish_all_helpfiles_001.png" ], - "strict_line_status": "line_port_partial", + "strict_line_status": "line_port_verified", "topic": "publish_all_helpfiles" } ] diff --git a/parity/line_port_snapshots/HippocampalPlaceCellExample.txt b/parity/line_port_snapshots/HippocampalPlaceCellExample.txt index 8ee3cd10..f3eab548 100644 --- a/parity/line_port_snapshots/HippocampalPlaceCellExample.txt +++ b/parity/line_port_snapshots/HippocampalPlaceCellExample.txt @@ -62,94 +62,3 @@ for l=0:3 for m=-l:l if(~any(mod(l-m,2))) cnt = cnt+1; -temp = nan(size(x_new)); -temp(idx) = zernfun(l,m,r_new(idx),theta_new(idx),'norm'); -zpoly{cnt} = temp; -end -end -end -for n=1:numAnimals -clear lambdaGaussian lambdaZernike; -load(fullfile(placeCellDataDir,['PlaceCellDataAnimal' num2str(n) '.mat'])); -resData=load(fullfile(fileparts(placeCellDataDir),['PlaceCellAnimal' num2str(n) 'Results.mat'])); -results = FitResult.fromStructure(resData.resStruct); -for i=1:length(neuron) -lambdaGaussian{i} = results{i}.evalLambda(1,newData); -lambdaZernike{i} = results{i}.evalLambda(2,zpoly); -end -for i=1:length(neuron) -if(n==1) -h4=figure(4); -if(i==1) -annotation(h4,'textbox',... -[0.343261904761904 0.928571428571418 ... -0.392857142857143 0.0595238095238095],... -'String',{['Gaussian Place Fields - Animal#' ... 
-num2str(n)]},'FitBoxToText','on'); hold on; -end -subplot(7,7,i); -elseif(n==2) -h6=figure(6); -if(i==1) -annotation(h6,'textbox',... -[0.343261904761904 0.928571428571418 ... -0.392857142857143 0.0595238095238095],... -'String',{['Gaussian Place Fields - Animal#' ... -num2str(n)]},'FitBoxToText','on'); hold on; -end -subplot(6,7,i); -end -pcolor(x_new,y_new,lambdaGaussian{i}), shading interp -axis square; set(gca,'xtick',[],'ytick',[]); -if(n==1) -h5=figure(5); -if(i==1) -annotation(h5,'textbox',... -[0.343261904761904 0.928571428571418 ... -0.392857142857143 0.0595238095238095],... -'String',{['Zernike Place Fields - Animal#' ... -num2str(n)]},'FitBoxToText','on'); hold on; -end -subplot(7,7,i); -elseif(n==2) -h7=figure(7); -if(i==1) -annotation(h7,'textbox',... -[0.343261904761904 0.928571428571418 ... -0.392857142857143 0.0595238095238095],... -'String',{['Zernike Place Fields - Animal#' ... -num2str(n)]},'FitBoxToText','on'); hold on; -end -subplot(6,7,i); -end -pcolor(x_new,y_new,lambdaZernike{i}), shading interp -axis square; -set(gca,'xtick',[],'ytick',[]); -end -end -clear lambdaGaussian lambdaZernike; -load(fullfile(placeCellDataDir,'PlaceCellDataAnimal1.mat')); -resData=load(fullfile(fileparts(placeCellDataDir),'PlaceCellAnimal1Results.mat')); -results = FitResult.fromStructure(resData.resStruct); -for i=1:length(neuron) -lambdaGaussian{i} = results{i}.evalLambda(1,newData); -lambdaZernike{i} = results{i}.evalLambda(2,zpoly); -end -exampleCell = 25; -figure(8); -plot(x,y,'b',neuron{exampleCell}.xN,neuron{exampleCell}.yN,'r.'); -xlabel('x'); ylabel('y'); -title(['Animal#1, Cell#' num2str(exampleCell)]); -figure(9); -h_mesh = mesh(x_new,y_new,lambdaGaussian{exampleCell},'AlphaData',0); -get(h_mesh,'AlphaData'); -set(h_mesh,'FaceAlpha',0.2,'EdgeAlpha',0.2,'EdgeColor','b'); -hold on; -h_mesh = mesh(x_new,y_new,lambdaZernike{exampleCell},'AlphaData',0); -get(h_mesh,'AlphaData'); -set(h_mesh,'FaceAlpha',0.2,'EdgeAlpha',0.2,'EdgeColor','g'); 
-legend(results{exampleCell}.lambda.dataLabels); -plot(x,y,neuron{exampleCell}.xN,neuron{exampleCell}.yN,'r.'); -axis tight square; -xlabel('x position'); ylabel('y position'); -title(['Animal#1, Cell#' num2str(exampleCell)]); diff --git a/parity/line_port_snapshots/HippocampalPlaceCellExample_extra.txt b/parity/line_port_snapshots/HippocampalPlaceCellExample_extra.txt new file mode 100644 index 00000000..2628d73c --- /dev/null +++ b/parity/line_port_snapshots/HippocampalPlaceCellExample_extra.txt @@ -0,0 +1,91 @@ +temp = nan(size(x_new)); +temp(idx) = zernfun(l,m,r_new(idx),theta_new(idx),'norm'); +zpoly{cnt} = temp; +end +end +end +for n=1:numAnimals +clear lambdaGaussian lambdaZernike; +load(fullfile(placeCellDataDir,['PlaceCellDataAnimal' num2str(n) '.mat'])); +resData=load(fullfile(fileparts(placeCellDataDir),['PlaceCellAnimal' num2str(n) 'Results.mat'])); +results = FitResult.fromStructure(resData.resStruct); +for i=1:length(neuron) +lambdaGaussian{i} = results{i}.evalLambda(1,newData); +lambdaZernike{i} = results{i}.evalLambda(2,zpoly); +end +for i=1:length(neuron) +if(n==1) +h4=figure(4); +if(i==1) +annotation(h4,'textbox',... +[0.343261904761904 0.928571428571418 ... +0.392857142857143 0.0595238095238095],... +'String',{['Gaussian Place Fields - Animal#' ... +num2str(n)]},'FitBoxToText','on'); hold on; +end +subplot(7,7,i); +elseif(n==2) +h6=figure(6); +if(i==1) +annotation(h6,'textbox',... +[0.343261904761904 0.928571428571418 ... +0.392857142857143 0.0595238095238095],... +'String',{['Gaussian Place Fields - Animal#' ... +num2str(n)]},'FitBoxToText','on'); hold on; +end +subplot(6,7,i); +end +pcolor(x_new,y_new,lambdaGaussian{i}), shading interp +axis square; set(gca,'xtick',[],'ytick',[]); +if(n==1) +h5=figure(5); +if(i==1) +annotation(h5,'textbox',... +[0.343261904761904 0.928571428571418 ... +0.392857142857143 0.0595238095238095],... +'String',{['Zernike Place Fields - Animal#' ... 
+num2str(n)]},'FitBoxToText','on'); hold on; +end +subplot(7,7,i); +elseif(n==2) +h7=figure(7); +if(i==1) +annotation(h7,'textbox',... +[0.343261904761904 0.928571428571418 ... +0.392857142857143 0.0595238095238095],... +'String',{['Zernike Place Fields - Animal#' ... +num2str(n)]},'FitBoxToText','on'); hold on; +end +subplot(6,7,i); +end +pcolor(x_new,y_new,lambdaZernike{i}), shading interp +axis square; +set(gca,'xtick',[],'ytick',[]); +end +end +clear lambdaGaussian lambdaZernike; +load(fullfile(placeCellDataDir,'PlaceCellDataAnimal1.mat')); +resData=load(fullfile(fileparts(placeCellDataDir),'PlaceCellAnimal1Results.mat')); +results = FitResult.fromStructure(resData.resStruct); +for i=1:length(neuron) +lambdaGaussian{i} = results{i}.evalLambda(1,newData); +lambdaZernike{i} = results{i}.evalLambda(2,zpoly); +end +exampleCell = 25; +figure(8); +plot(x,y,'b',neuron{exampleCell}.xN,neuron{exampleCell}.yN,'r.'); +xlabel('x'); ylabel('y'); +title(['Animal#1, Cell#' num2str(exampleCell)]); +figure(9); +h_mesh = mesh(x_new,y_new,lambdaGaussian{exampleCell},'AlphaData',0); +get(h_mesh,'AlphaData'); +set(h_mesh,'FaceAlpha',0.2,'EdgeAlpha',0.2,'EdgeColor','b'); +hold on; +h_mesh = mesh(x_new,y_new,lambdaZernike{exampleCell},'AlphaData',0); +get(h_mesh,'AlphaData'); +set(h_mesh,'FaceAlpha',0.2,'EdgeAlpha',0.2,'EdgeColor','g'); +legend(results{exampleCell}.lambda.dataLabels); +plot(x,y,neuron{exampleCell}.xN,neuron{exampleCell}.yN,'r.'); +axis tight square; +xlabel('x position'); ylabel('y position'); +title(['Animal#1, Cell#' num2str(exampleCell)]); diff --git a/parity/line_port_snapshots/NetworkTutorial.txt b/parity/line_port_snapshots/NetworkTutorial.txt index 45eb4c34..978a5816 100644 --- a/parity/line_port_snapshots/NetworkTutorial.txt +++ b/parity/line_port_snapshots/NetworkTutorial.txt @@ -62,27 +62,3 @@ cfgColl= ConfigColl(c); results = Analysis.RunAnalysisForAllNeurons(trial,cfgColl,0,Algorithm); results{1}.plotResults; results{2}.plotResults; -Summary = 
FitResSummary(results); -actNetwork = zeros(numNeurons,numNeurons); -network1ms = zeros(numNeurons,numNeurons); -for i=1:numNeurons -index = 1:numNeurons; -neighbors = setdiff(index,i); -[num,den] = tfdata(E{i}); -actNetwork(i,neighbors) = cell2mat(num); -[coeffs,labels]=results{i}.getCoeffs; -network1ms(i,neighbors)=coeffs(1:(length(neighbors)),3); -end -maxVal=max(max(abs(actNetwork))); -minVal=-maxVal;%min(min(actNetwork)); -CLIM = [minVal maxVal]; -figure; -colormap(jet); -subplot(1,2,1); -imagesc(actNetwork,CLIM); -set(gca,'XTick',index,'YTick',index); -title('Actual'); -subplot(1,2,2); -imagesc(network1ms,CLIM); -set(gca,'XTick',index,'YTick',index); -title('Estimated 1ms'); diff --git a/parity/line_port_snapshots/SignalObjExamples.txt b/parity/line_port_snapshots/SignalObjExamples.txt index 781dcecb..5cdd1922 100644 --- a/parity/line_port_snapshots/SignalObjExamples.txt +++ b/parity/line_port_snapshots/SignalObjExamples.txt @@ -62,20 +62,3 @@ s6.plot; s=SignalObj(t,v,'Voltage','time','s','V',{'v1','v2'}); figure; s.MTMspectrum; -figure -s.periodogram; -sampleRate=5000; t=0:1/sampleRate:1; t=t'; freq=2; -v1=sin(2*pi*freq*t); v2=sin(v1.^2); -noise=.1*randn(length(t),6); %gaussian random noise -data= [v1 v2 v2 v1 v2 v1] + noise; -s=SignalObj(t,data,'Voltage','time','s','V',{'v1','v2','v2','v1','v1','v2'}); -figure; -subplot(2,1,1); s.plot; -subplot(2,1,2); s.plotAllVariability; %disregards labels; -s.plotVariability; %creates two figures, one for 'v1' and one for 'v2' -figure; -subplot(3,1,1); s.plotAllVariability('b'); -subplot(3,1,2); s.plotAllVariability('g',2); -subplot(3,1,3); s.plotAllVariability('c',3,2,1); -parity = struct(); -parity.sample_rate_hz = sampleRate; diff --git a/parity/line_port_snapshots/publish_all_helpfiles.txt b/parity/line_port_snapshots/publish_all_helpfiles.txt index 7301ff5a..b70b588d 100644 --- a/parity/line_port_snapshots/publish_all_helpfiles.txt +++ b/parity/line_port_snapshots/publish_all_helpfiles.txt @@ -62,65 +62,3 @@ parser 
= inputParser; parser.FunctionName = 'publish_all_helpfiles'; addParameter(parser, 'EvalCode', true, @(x)islogical(x) || isnumeric(x)); addParameter(parser, 'ExpectedGenerator', 'MATLAB 25.2', @(x)ischar(x) || isstring(x)); -parse(parser, varargin{:}); -opts.EvalCode = logical(parser.Results.EvalCode); -opts.ExpectedGenerator = char(parser.Results.ExpectedGenerator); -end -function removeStagedArtifacts(stagingDir) -removePattern(stagingDir, '*.mlx'); -removePattern(stagingDir, '*.asv'); -removePattern(stagingDir, '*.bak'); -removePattern(stagingDir, 'temp.m'); -removePattern(stagingDir, 'publish_all_helpfiles.m'); -end -function removePattern(stagingDir, pattern) -files = dir(fullfile(stagingDir, pattern)); -for i = 1:numel(files) -delete(fullfile(stagingDir, files(i).name)); -end -end -function validateHelpTargets(helpDir) -helptocPath = fullfile(helpDir, 'helptoc.xml'); -if ~isfile(helptocPath) -error('nSTAT:MissingHelptoc', 'Missing helptoc.xml at %s', helptocPath); -end -raw = fileread(helptocPath); -matches = regexp(raw, 'target="([^"]+)"', 'tokens'); -for i = 1:numel(matches) -target = matches{i}{1}; -if startsWith(target, 'http://') || startsWith(target, 'https://') -continue; -end -fullTarget = fullfile(helpDir, target); -if ~isfile(fullTarget) -error('nSTAT:MissingHelpTarget', ... 
-'helptoc target is missing after publish: %s', fullTarget); -end -end -end -function validateHtmlGeneratorMetadata(helpDir, expectedGenerator) -htmlFiles = dir(fullfile(helpDir, '*.html')); -for i = 1:numel(htmlFiles) -htmlPath = fullfile(helpDir, htmlFiles(i).name); -raw = fileread(htmlPath); -if isempty(regexp(raw, ['=8.0", + "pytest-benchmark>=4.0", "pytest-cov>=4.1", "mypy>=1.8", "ruff>=0.3", - "nbformat>=5.9" + "nbformat>=5.9", + "PyMuPDF>=1.24", + "scikit-image>=0.22" ] docs = [ "sphinx>=7.2", @@ -42,7 +45,9 @@ notebooks = [ "jupyter>=1.0", "Pillow>=10.0", "reportlab>=4.0", - "pyyaml>=6.0" + "pyyaml>=6.0", + "PyMuPDF>=1.24", + "scikit-image>=0.22" ] [project.urls] @@ -69,5 +74,6 @@ warn_unused_configs = true addopts = "-q" markers = [ "smoke: fast checks for pull requests", - "full: heavier checks for nightly and release gating" + "full: heavier checks for nightly and release gating", + "performance: runtime benchmark checks for parity monitoring" ] diff --git a/src/nstat/history.py b/src/nstat/history.py index 15705418..45d33a39 100644 --- a/src/nstat/history.py +++ b/src/nstat/history.py @@ -46,12 +46,23 @@ def design_matrix(self, spike_times_s: np.ndarray, time_grid_s: np.ndarray) -> n spike_times_s = np.asarray(spike_times_s, dtype=float) time_grid_s = np.asarray(time_grid_s, dtype=float) + if spike_times_s.ndim != 1: + spike_times_s = spike_times_s.reshape(-1) + if time_grid_s.ndim != 1: + time_grid_s = time_grid_s.reshape(-1) + spike_times_s = np.sort(spike_times_s) mat = np.zeros((time_grid_s.size, self.n_bins), dtype=float) - for i, t_now in enumerate(time_grid_s): - lags = t_now - spike_times_s - for j in range(self.n_bins): - lo = self.bin_edges_s[j] - hi = self.bin_edges_s[j + 1] - mat[i, j] = float(np.sum((lags > lo) & (lags <= hi))) + if spike_times_s.size == 0 or time_grid_s.size == 0: + return mat + + # Equivalent to counting lags in (lo, hi], i.e., spikes in [t-hi, t-lo). 
+ for j in range(self.n_bins): + lo = float(self.bin_edges_s[j]) + hi = float(self.bin_edges_s[j + 1]) + lower = time_grid_s - hi + upper = time_grid_s - lo + lo_idx = np.searchsorted(spike_times_s, lower, side="left") + hi_idx = np.searchsorted(spike_times_s, upper, side="left") + mat[:, j] = (hi_idx - lo_idx).astype(float) return mat diff --git a/src/nstat/performance_workloads.py b/src/nstat/performance_workloads.py new file mode 100644 index 00000000..8e744981 --- /dev/null +++ b/src/nstat/performance_workloads.py @@ -0,0 +1,186 @@ +"""Shared deterministic performance workloads for nSTAT-python parity tracking.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, cast + +import numpy as np + +from nstat.compat.matlab import CIF, Covariate, DecodingAlgorithms, History, nstColl + + +TIER_ORDER = ("S", "M", "L") +CASE_ORDER = ( + "unit_impulse_basis", + "covariate_resample", + "history_design_matrix", + "simulate_cif_thinning", + "decoding_spike_rate_cis", +) + + +@dataclass(frozen=True) +class CaseConfig: + basis_width_s: float = 0.02 + min_time_s: float = 0.0 + max_time_s: float = 1.0 + sample_rate_hz: float = 500.0 + n_spikes: int = 200 + n_grid: int = 1000 + duration_s: float = 2.0 + n_realizations: int = 5 + max_time_res_s: float = 0.001 + num_basis: int = 4 + num_trials: int = 6 + n_bins: int = 120 + mc_draws: int = 30 + decode_delta_s: float = 0.01 + + +def get_case_config(case: str, tier: str) -> CaseConfig: + tier = tier.upper() + if tier not in TIER_ORDER: + raise ValueError(f"Unknown tier: {tier}") + + vals: dict[str, dict[str, float | int]] + if case == "unit_impulse_basis": + vals = { + "S": dict(max_time_s=1.0, sample_rate_hz=500.0), + "M": dict(max_time_s=2.0, sample_rate_hz=1000.0), + "L": dict(max_time_s=4.0, sample_rate_hz=1500.0), + } + elif case == "covariate_resample": + vals = { + "S": dict(duration_s=2.0, n_grid=2001, sample_rate_hz=500.0), + "M": dict(duration_s=4.0, n_grid=4001, 
sample_rate_hz=750.0), + "L": dict(duration_s=6.0, n_grid=6001, sample_rate_hz=1000.0), + } + elif case == "history_design_matrix": + vals = { + "S": dict(n_spikes=200, n_grid=1000, duration_s=2.0), + "M": dict(n_spikes=1000, n_grid=5000, duration_s=2.0), + "L": dict(n_spikes=3000, n_grid=10000, duration_s=2.0), + } + elif case == "simulate_cif_thinning": + vals = { + "S": dict(duration_s=1.0, n_realizations=5, max_time_res_s=0.001), + "M": dict(duration_s=2.0, n_realizations=10, max_time_res_s=0.001), + "L": dict(duration_s=3.0, n_realizations=20, max_time_res_s=0.001), + } + elif case == "decoding_spike_rate_cis": + vals = { + "S": dict(num_basis=4, num_trials=6, n_bins=120, mc_draws=30, decode_delta_s=0.01), + "M": dict(num_basis=6, num_trials=8, n_bins=200, mc_draws=50, decode_delta_s=0.01), + "L": dict(num_basis=8, num_trials=12, n_bins=320, mc_draws=80, decode_delta_s=0.01), + } + else: + raise ValueError(f"Unknown case: {case}") + + return CaseConfig(**cast(dict[str, Any], vals[tier])) + + +def _deterministic_spike_times(n_spikes: int, duration_s: float) -> np.ndarray: + idx = np.arange(1, n_spikes + 1, dtype=float) + phi = 0.6180339887498949 + spikes = np.mod(idx * phi, 1.0) * float(duration_s) + return np.sort(spikes) + + +def _deterministic_decode_inputs(cfg: CaseConfig) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + basis_idx = np.arange(1, cfg.num_basis + 1, dtype=float)[:, None] + trial_idx = np.arange(1, cfg.num_trials + 1, dtype=float)[None, :] + xk = 0.06 * np.sin(0.37 * basis_idx * trial_idx) + 0.04 * np.cos(0.19 * basis_idx * trial_idx) + + wku = np.zeros((cfg.num_basis, cfg.num_basis, cfg.num_trials, cfg.num_trials), dtype=float) + for r in range(cfg.num_basis): + wku[r, r, :, :] = 0.05 * np.eye(cfg.num_trials, dtype=float) + + grid = np.arange(cfg.num_trials * cfg.n_bins, dtype=float).reshape(cfg.num_trials, cfg.n_bins) + d_n = ((np.sin(0.173 * grid) + np.cos(0.037 * grid)) > 1.15).astype(float) + return xk, wku, d_n + + +def 
run_python_workload(case: str, tier: str, seed: int = 20260303) -> dict[str, float]: + """Execute one deterministic Python workload and return summary metrics.""" + + cfg = get_case_config(case=case, tier=tier) + + if case == "unit_impulse_basis": + basis = nstColl.generateUnitImpulseBasis( + cfg.basis_width_s, + cfg.min_time_s, + cfg.max_time_s, + cfg.sample_rate_hz, + ) + mat = basis.data_to_matrix() + return { + "rows": float(mat.shape[0]), + "cols": float(mat.shape[1]), + "total_mass": float(np.sum(mat)), + } + + if case == "covariate_resample": + t = np.linspace(0.0, cfg.duration_s, cfg.n_grid, dtype=float) + y = np.sin(2.0 * np.pi * 3.0 * t) + 0.2 * np.cos(2.0 * np.pi * 9.0 * t) + cov = Covariate(t, y, "Stimulus") + out = cov.resample(cfg.sample_rate_hz) + mat = out.data_to_matrix() + return { + "rows": float(mat.shape[0]), + "cols": float(mat.shape[1]), + "signal_energy": float(np.mean(mat[:, 0] ** 2)), + } + + if case == "history_design_matrix": + spikes = _deterministic_spike_times(cfg.n_spikes, cfg.duration_s) + t_grid = np.linspace(0.0, cfg.duration_s, cfg.n_grid, dtype=float) + hist = History(np.array([0.0, 0.01, 0.02, 0.05, 0.10], dtype=float)) + mat = hist.computeHistory(spikes, t_grid) + return { + "rows": float(mat.shape[0]), + "cols": float(mat.shape[1]), + "total_count": float(np.sum(mat)), + } + + if case == "simulate_cif_thinning": + np.random.seed(seed) + t = np.linspace(0.0, cfg.duration_s, int(cfg.duration_s * 1000) + 1, dtype=float) + lam = 12.0 + 8.0 * np.sin(2.0 * np.pi * 3.0 * t) + lam = np.clip(lam, 0.2, None) + lam_cov = Covariate(t, lam, "Lambda") + coll = CIF.simulateCIFByThinningFromLambda(lam_cov, cfg.n_realizations, cfg.max_time_res_s) + total_spikes = float(sum(train.spike_times.size for train in coll.trains)) + return { + "num_units": float(coll.getNumUnits()), + "total_spikes": total_spikes, + "mean_spikes_per_unit": total_spikes / max(float(coll.getNumUnits()), 1.0), + } + + if case == "decoding_spike_rate_cis": + 
np.random.seed(seed) + xk, wku, d_n = _deterministic_decode_inputs(cfg) + t0 = 0.0 + tf = (cfg.n_bins - 1) * cfg.decode_delta_s + spike_rate_sig, prob_mat, sig_mat = DecodingAlgorithms.computeSpikeRateCIs( + xk, + wku, + d_n, + t0, + tf, + "binomial", + cfg.decode_delta_s, + 0.0, + [], + cfg.mc_draws, + 0.05, + ) + rate = spike_rate_sig.data_to_matrix() + return { + "num_trials": float(prob_mat.shape[0]), + "prob_mean": float(np.mean(prob_mat)), + "sig_count": float(np.sum(sig_mat)), + "rate_mean": float(np.mean(rate)), + } + + raise ValueError(f"Unhandled workload case: {case}") diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..eccd5f91 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test utilities and regression suites for nSTAT-python.""" diff --git a/tests/parity/fixtures/matlab_gold/AnalysisExamples_gold.mat b/tests/parity/fixtures/matlab_gold/AnalysisExamples_gold.mat index d8a8c64b..7ae2121f 100644 Binary files a/tests/parity/fixtures/matlab_gold/AnalysisExamples_gold.mat and b/tests/parity/fixtures/matlab_gold/AnalysisExamples_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/ConfigCollExamples_audit_gold.json b/tests/parity/fixtures/matlab_gold/ConfigCollExamples_audit_gold.json index c79763af..27ef1f30 100644 --- a/tests/parity/fixtures/matlab_gold/ConfigCollExamples_audit_gold.json +++ b/tests/parity/fixtures/matlab_gold/ConfigCollExamples_audit_gold.json @@ -4,7 +4,7 @@ "alignment_status": "validated", "matlab_code_lines": 3, "matlab_reference_image_count": 0, - "min_assertion_count": 3, + "min_assertion_count": 4, "require_topic_checkpoint": true, "min_python_validation_image_count": 1, "require_plot_call": true, diff --git a/tests/parity/fixtures/matlab_gold/CovCollExamples_gold.mat b/tests/parity/fixtures/matlab_gold/CovCollExamples_gold.mat index 979c1d08..679c686b 100644 Binary files a/tests/parity/fixtures/matlab_gold/CovCollExamples_gold.mat and 
b/tests/parity/fixtures/matlab_gold/CovCollExamples_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/CovariateExamples_audit_gold.json b/tests/parity/fixtures/matlab_gold/CovariateExamples_audit_gold.json index 4db0e817..10a14f54 100644 --- a/tests/parity/fixtures/matlab_gold/CovariateExamples_audit_gold.json +++ b/tests/parity/fixtures/matlab_gold/CovariateExamples_audit_gold.json @@ -4,7 +4,7 @@ "alignment_status": "validated", "matlab_code_lines": 19, "matlab_reference_image_count": 3, - "min_assertion_count": 3, + "min_assertion_count": 4, "require_topic_checkpoint": true, "min_python_validation_image_count": 2, "require_plot_call": true, diff --git a/tests/parity/fixtures/matlab_gold/DecodingExampleWithHist_gold.mat b/tests/parity/fixtures/matlab_gold/DecodingExampleWithHist_gold.mat index 8ae1e57d..365ec32b 100644 Binary files a/tests/parity/fixtures/matlab_gold/DecodingExampleWithHist_gold.mat and b/tests/parity/fixtures/matlab_gold/DecodingExampleWithHist_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/DecodingExample_gold.mat b/tests/parity/fixtures/matlab_gold/DecodingExample_gold.mat index 190b3ffa..6e152f4d 100644 Binary files a/tests/parity/fixtures/matlab_gold/DecodingExample_gold.mat and b/tests/parity/fixtures/matlab_gold/DecodingExample_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/EventsExamples_gold.mat b/tests/parity/fixtures/matlab_gold/EventsExamples_gold.mat index 1ae91842..d3b8d1d7 100644 Binary files a/tests/parity/fixtures/matlab_gold/EventsExamples_gold.mat and b/tests/parity/fixtures/matlab_gold/EventsExamples_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/ExplicitStimulusWhiskerData_gold.mat b/tests/parity/fixtures/matlab_gold/ExplicitStimulusWhiskerData_gold.mat index d6b04ab3..9b992568 100644 Binary files a/tests/parity/fixtures/matlab_gold/ExplicitStimulusWhiskerData_gold.mat and b/tests/parity/fixtures/matlab_gold/ExplicitStimulusWhiskerData_gold.mat differ diff --git 
a/tests/parity/fixtures/matlab_gold/HippocampalPlaceCellExample_gold.mat b/tests/parity/fixtures/matlab_gold/HippocampalPlaceCellExample_gold.mat index e5f24308..154fb591 100644 Binary files a/tests/parity/fixtures/matlab_gold/HippocampalPlaceCellExample_gold.mat and b/tests/parity/fixtures/matlab_gold/HippocampalPlaceCellExample_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/HistoryExamples_gold.mat b/tests/parity/fixtures/matlab_gold/HistoryExamples_gold.mat new file mode 100644 index 00000000..280a1b45 Binary files /dev/null and b/tests/parity/fixtures/matlab_gold/HistoryExamples_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/HybridFilterExample_gold.mat b/tests/parity/fixtures/matlab_gold/HybridFilterExample_gold.mat new file mode 100644 index 00000000..1db919eb Binary files /dev/null and b/tests/parity/fixtures/matlab_gold/HybridFilterExample_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/NetworkTutorial_gold.mat b/tests/parity/fixtures/matlab_gold/NetworkTutorial_gold.mat new file mode 100644 index 00000000..a4da7c55 Binary files /dev/null and b/tests/parity/fixtures/matlab_gold/NetworkTutorial_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/PPSimExample_gold.mat b/tests/parity/fixtures/matlab_gold/PPSimExample_gold.mat index 578b0869..c96cd49f 100644 Binary files a/tests/parity/fixtures/matlab_gold/PPSimExample_gold.mat and b/tests/parity/fixtures/matlab_gold/PPSimExample_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/PPThinning_gold.mat b/tests/parity/fixtures/matlab_gold/PPThinning_gold.mat new file mode 100644 index 00000000..681ea9e0 Binary files /dev/null and b/tests/parity/fixtures/matlab_gold/PPThinning_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/PSTHEstimation_gold.mat b/tests/parity/fixtures/matlab_gold/PSTHEstimation_gold.mat index 480d1f5c..b0de27cc 100644 Binary files a/tests/parity/fixtures/matlab_gold/PSTHEstimation_gold.mat and 
b/tests/parity/fixtures/matlab_gold/PSTHEstimation_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/SignalObjExamples_audit_gold.json b/tests/parity/fixtures/matlab_gold/SignalObjExamples_audit_gold.json index c2236027..7816aabb 100644 --- a/tests/parity/fixtures/matlab_gold/SignalObjExamples_audit_gold.json +++ b/tests/parity/fixtures/matlab_gold/SignalObjExamples_audit_gold.json @@ -4,9 +4,9 @@ "alignment_status": "validated", "matlab_code_lines": 81, "matlab_reference_image_count": 21, - "min_assertion_count": 3, + "min_assertion_count": 4, "require_topic_checkpoint": true, - "min_python_validation_image_count": 1, + "min_python_validation_image_count": 6, "require_plot_call": true, "source": "equivalence_audit_report", "equivalence_report": "parity/function_example_alignment_report.json" diff --git a/tests/parity/fixtures/matlab_gold/SignalObjExamples_gold.mat b/tests/parity/fixtures/matlab_gold/SignalObjExamples_gold.mat new file mode 100644 index 00000000..15f52a54 Binary files /dev/null and b/tests/parity/fixtures/matlab_gold/SignalObjExamples_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/SpikeRateDiffCIs_gold.mat b/tests/parity/fixtures/matlab_gold/SpikeRateDiffCIs_gold.mat index a6d9d07b..0fb5d8ef 100644 Binary files a/tests/parity/fixtures/matlab_gold/SpikeRateDiffCIs_gold.mat and b/tests/parity/fixtures/matlab_gold/SpikeRateDiffCIs_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/StimulusDecode2D_gold.mat b/tests/parity/fixtures/matlab_gold/StimulusDecode2D_gold.mat new file mode 100644 index 00000000..f2b85b88 Binary files /dev/null and b/tests/parity/fixtures/matlab_gold/StimulusDecode2D_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/TrialConfigExamples_audit_gold.json b/tests/parity/fixtures/matlab_gold/TrialConfigExamples_audit_gold.json index 7bca98e9..2138adca 100644 --- a/tests/parity/fixtures/matlab_gold/TrialConfigExamples_audit_gold.json +++ 
b/tests/parity/fixtures/matlab_gold/TrialConfigExamples_audit_gold.json @@ -4,7 +4,7 @@ "alignment_status": "validated", "matlab_code_lines": 3, "matlab_reference_image_count": 0, - "min_assertion_count": 3, + "min_assertion_count": 4, "require_topic_checkpoint": true, "min_python_validation_image_count": 1, "require_plot_call": true, diff --git a/tests/parity/fixtures/matlab_gold/TrialExamples_gold.mat b/tests/parity/fixtures/matlab_gold/TrialExamples_gold.mat index 3470ae64..1c9a570f 100644 Binary files a/tests/parity/fixtures/matlab_gold/TrialExamples_gold.mat and b/tests/parity/fixtures/matlab_gold/TrialExamples_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/ValidationDataSet_gold.mat b/tests/parity/fixtures/matlab_gold/ValidationDataSet_gold.mat new file mode 100644 index 00000000..8262ccd2 Binary files /dev/null and b/tests/parity/fixtures/matlab_gold/ValidationDataSet_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/mEPSCAnalysis_gold.mat b/tests/parity/fixtures/matlab_gold/mEPSCAnalysis_gold.mat index 1fc02659..54557b59 100644 Binary files a/tests/parity/fixtures/matlab_gold/mEPSCAnalysis_gold.mat and b/tests/parity/fixtures/matlab_gold/mEPSCAnalysis_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/manifest.yml b/tests/parity/fixtures/matlab_gold/manifest.yml index e5765ebe..58441038 100644 --- a/tests/parity/fixtures/matlab_gold/manifest.yml +++ b/tests/parity/fixtures/matlab_gold/manifest.yml @@ -2,72 +2,102 @@ version: 1 fixtures: - name: PPSimExample path: tests/parity/fixtures/matlab_gold/PPSimExample_gold.mat - sha256: 5282cad37ef348e16676b2d0faedfd9e339d419fe52864f6a0d56a6d22846b8d + sha256: d2073978069438c499aa972562d46bb0d69737eb0333885a2d3456d5a298a765 source: matlab_batch_export fixture_type: numeric - name: DecodingExampleWithHist path: tests/parity/fixtures/matlab_gold/DecodingExampleWithHist_gold.mat - sha256: d325d00a60cf6289987a6b42e9bac11872a6189dd0899d16bcc6049e5078f638 + sha256: 
475f2420263e04414f5d71e387f4541278e000785f4b75f1e905b37dc8191578 source: matlab_batch_export fixture_type: numeric - name: HippocampalPlaceCellExample path: tests/parity/fixtures/matlab_gold/HippocampalPlaceCellExample_gold.mat - sha256: 52665028a559c66a39d0493370f1dae9455e21a3e236f641e8dd58fdc77013d1 + sha256: 32eda491374c0219e35afef8e715056c8d751077204a3f103077bd879d8f65c3 source: matlab_batch_export fixture_type: numeric - name: SpikeRateDiffCIs path: tests/parity/fixtures/matlab_gold/SpikeRateDiffCIs_gold.mat - sha256: e9117d280162303b251401017b1dfdd9cbf7a0aa580fbb849d859f61089e8221 + sha256: dbd2eb12dcc0c5113b6ad2c5eebfa0fbd77d0f1f5fe31e32bee43a20a37da634 source: matlab_batch_export fixture_type: numeric - name: PSTHEstimation path: tests/parity/fixtures/matlab_gold/PSTHEstimation_gold.mat - sha256: a4bd01748790d5facb37efd800729cebf52ad8c6f2acd0c7b73570b1bc931f98 + sha256: 568a9ae7a80e39be0b8ea566472dcc501f1e1406a1a512e3443cedecfbb53254 source: matlab_batch_export fixture_type: numeric - name: nstCollExamples path: tests/parity/fixtures/matlab_gold/nstCollExamples_gold.mat - sha256: fa7d326a41bb51292d39aa1aabd135b4f72e9ed4060344775526e431cd0c33c0 + sha256: 0f7bc66aadc931b6847a3004d67ab0532d43d9a5aa70a6eadc7558e81ca69c95 source: matlab_batch_export fixture_type: numeric - name: TrialExamples path: tests/parity/fixtures/matlab_gold/TrialExamples_gold.mat - sha256: 0e2d4ba5f930755777c741e14a81aa11465b9f820e5838202a0166334f6bbbaa + sha256: bbdf2d0c1fff4bf43fd9af83f6af332c4665b608ddada556bc913a9b78def456 source: matlab_batch_export fixture_type: numeric - name: CovCollExamples path: tests/parity/fixtures/matlab_gold/CovCollExamples_gold.mat - sha256: 5271cce7dbe2d5cd725de8a43fefad42a4254be420d09ac36916c233511f93b1 + sha256: d6d253776f47efc5c63ba9eb61638682158887f978fc09b12b619ef22fb242dd source: matlab_batch_export fixture_type: numeric - name: EventsExamples path: tests/parity/fixtures/matlab_gold/EventsExamples_gold.mat - sha256: 
5694cfba926df7c6c228ace389c78e50748d9ab6ca83839c0b84aa6b157d0388 + sha256: de40ebb042f65eaa0bbd418fe054df880a45f46bb23ae99f3b44f5a47b237495 source: matlab_batch_export fixture_type: numeric - name: AnalysisExamples path: tests/parity/fixtures/matlab_gold/AnalysisExamples_gold.mat - sha256: b1a49982144831316e557d3c3025843305c017440e17896e16fc3f1316eb8578 + sha256: 5b8abe3559bcfd76dcddbe535e3ccf8f848f8adca1021be0b38d2170314105cc source: matlab_batch_export fixture_type: numeric - name: DecodingExample path: tests/parity/fixtures/matlab_gold/DecodingExample_gold.mat - sha256: 33e914e35d85b991704406ad1f80de9fb58c03258b53f5259cbfd1af15175351 + sha256: e88f7b5375b276419956272ffdb8b9d94392f6a647bb35b2aaf6235fd7260756 source: matlab_batch_export fixture_type: numeric - name: ExplicitStimulusWhiskerData path: tests/parity/fixtures/matlab_gold/ExplicitStimulusWhiskerData_gold.mat - sha256: 2986ee2f03f486d0c82066232b77a018e56f42d0ff63b7e2a847c4264ac14e0c + sha256: 10767bf1b29c9953fb9aa560601d2fb3ba5a69298381f2982436f0460a3b20c7 source: matlab_batch_export fixture_type: numeric - name: mEPSCAnalysis path: tests/parity/fixtures/matlab_gold/mEPSCAnalysis_gold.mat - sha256: 55c3d0a74510202b731bd62afc5ca487e727a2ac3a97fc1f6822d403f0df5555 + sha256: 67d5a9dcffe0b43089dcb6439c59167afeebbc698e90b4d4c56e868930fee5e1 source: matlab_batch_export fixture_type: numeric -- name: nSTATPaperExamples_plot_arrays - path: tests/parity/fixtures/matlab_gold/nSTATPaperExamples_plot_gold.mat - sha256: 1e34a6f230c4ef94801b2e7ea825038fd4a27df3d91e7023e51a4d9c7cb669c3 +- name: HybridFilterExample + path: tests/parity/fixtures/matlab_gold/HybridFilterExample_gold.mat + sha256: d89eb15adcfaa1b68a59831901a7711c731a7ceae383bbb6c0559ba35203bc1f + source: matlab_batch_export + fixture_type: numeric +- name: ValidationDataSet + path: tests/parity/fixtures/matlab_gold/ValidationDataSet_gold.mat + sha256: 04697d36296a8ccf0c8c6d14b63bf5c4d04b75e2f3908363265406e51aa30cd4 + source: matlab_batch_export + 
fixture_type: numeric +- name: StimulusDecode2D + path: tests/parity/fixtures/matlab_gold/StimulusDecode2D_gold.mat + sha256: 620e0f99897fba7752277817ce93a40b39a495ddc08b51ccd30ef136e7e6f6d4 + source: matlab_batch_export + fixture_type: numeric +- name: SignalObjExamples + path: tests/parity/fixtures/matlab_gold/SignalObjExamples_gold.mat + sha256: 2cabfaa7d3a4913952bfa1f9989203da981adf0c0e003b32bad4f88eb5bb50dd + source: matlab_batch_export + fixture_type: numeric +- name: HistoryExamples + path: tests/parity/fixtures/matlab_gold/HistoryExamples_gold.mat + sha256: a38ac480f26c8840320b5c6318bba8f4e2bb6226a7408a7752e401c455362908 + source: matlab_batch_export + fixture_type: numeric +- name: PPThinning + path: tests/parity/fixtures/matlab_gold/PPThinning_gold.mat + sha256: fc7b6f1e31a35b0ba52bafd617ddd09ddcb41f21a891c493a25d552800977c75 + source: matlab_batch_export + fixture_type: numeric +- name: NetworkTutorial + path: tests/parity/fixtures/matlab_gold/NetworkTutorial_gold.mat + sha256: 0a2a56c8364ceac9e4d883cd603a591d3fbb721ce9fab17b618bbdcb108112ac source: matlab_batch_export fixture_type: numeric - name: AnalysisExamples2 @@ -77,12 +107,12 @@ fixtures: fixture_type: topic_audit - name: ConfigCollExamples path: tests/parity/fixtures/matlab_gold/ConfigCollExamples_audit_gold.json - sha256: 1831aa6c3f68039a1ea55b5d05b2f43ba68088b748881b5e6fd148366b00872d + sha256: f22515244492366248c3ae5f23d440f54e9a272255e8c74184c60c97a14924df source: equivalence_audit_export fixture_type: topic_audit - name: CovariateExamples path: tests/parity/fixtures/matlab_gold/CovariateExamples_audit_gold.json - sha256: 27ceffa12f0f8e2df740ca335fb940b7e99e807d0e449b66e983b6cc650881be + sha256: 9d3a359d13235aac6dda92ff5913db158be6684dbbf26242cdc4cfab9cd5b91e source: equivalence_audit_export fixture_type: topic_audit - name: DocumentationSetup2025b @@ -105,58 +135,23 @@ fixtures: sha256: 4cf27f92324db28f5bce69d8d9cfadfe4caa62282a18abcc0b9c4a4faab0fa0a source: equivalence_audit_export 
fixture_type: topic_audit -- name: HistoryExamples - path: tests/parity/fixtures/matlab_gold/HistoryExamples_audit_gold.json - sha256: d16895ca9d5075ba7884e3dc6bf900c468bb9e1481a978ad63dbb04256289bdc - source: equivalence_audit_export - fixture_type: topic_audit -- name: HybridFilterExample - path: tests/parity/fixtures/matlab_gold/HybridFilterExample_audit_gold.json - sha256: 5946e358a22e7427b8caa8ea1247fbad60c644134c4d8eafaaac19ecac369d79 - source: equivalence_audit_export - fixture_type: topic_audit -- name: NetworkTutorial - path: tests/parity/fixtures/matlab_gold/NetworkTutorial_audit_gold.json - sha256: 21d76cbd84dd2de17fe9d4e90497040485898f94aeaad209cb6b57cb2acd3473 - source: equivalence_audit_export - fixture_type: topic_audit -- name: PPThinning - path: tests/parity/fixtures/matlab_gold/PPThinning_audit_gold.json - sha256: f460085b05f1729a853d7de01888ced176faa80fd277bf21c90c366e9a95b0d5 - source: equivalence_audit_export - fixture_type: topic_audit -- name: SignalObjExamples - path: tests/parity/fixtures/matlab_gold/SignalObjExamples_audit_gold.json - sha256: ea72887672045c42917712807c3758d4b8d5db114c1af036760b08188cbac342 - source: equivalence_audit_export - fixture_type: topic_audit -- name: StimulusDecode2D - path: tests/parity/fixtures/matlab_gold/StimulusDecode2D_audit_gold.json - sha256: 54b178a3049a46da9f226f60f1568fc8531e8a283053d89cf266660e8f066c3c - source: equivalence_audit_export - fixture_type: topic_audit - name: TrialConfigExamples path: tests/parity/fixtures/matlab_gold/TrialConfigExamples_audit_gold.json - sha256: 74a1c0d7c0d26a2036a4d1de06e911014d28d60b553f40864d4045d4d7e81dc7 - source: equivalence_audit_export - fixture_type: topic_audit -- name: ValidationDataSet - path: tests/parity/fixtures/matlab_gold/ValidationDataSet_audit_gold.json - sha256: d62c6b92de20e4b5dfa25b20ed4eea432a159f3a8d475d9a1a743360d3535f0d + sha256: 759be3e852251f6bc2a9f8ab524110eafe0603dac356ad3d923b8d3a1c1c5b01 source: equivalence_audit_export fixture_type: 
topic_audit - name: nSTATPaperExamples path: tests/parity/fixtures/matlab_gold/nSTATPaperExamples_audit_gold.json - sha256: 06c0cf3d47c57917f30d73dc046105a17ca11904bc69bfe96f237027dd254705 + sha256: 7bf58953de236c90e4b3cf6b796a4512bb0f15c50a46de9968e881f6c2c1c215 source: equivalence_audit_export fixture_type: topic_audit - name: nSpikeTrainExamples path: tests/parity/fixtures/matlab_gold/nSpikeTrainExamples_audit_gold.json - sha256: 89fa96d2709e7586e6d0a15247cd15b04efc1b1881356f6f3dab2afb532eda40 + sha256: 4db9ab45939e21b1e04365bb47cc3004bb763c1f6c270b1959ade867be123620 source: equivalence_audit_export fixture_type: topic_audit - name: publish_all_helpfiles path: tests/parity/fixtures/matlab_gold/publish_all_helpfiles_audit_gold.json - sha256: 4429af557e1d5092a5ec0ce55014e59b91cdbdf117e61246837c4948f963835e + sha256: efee9081567a606468929d222fb1371c6f17db18912d888b8c4dffd52f5457b3 source: equivalence_audit_export fixture_type: topic_audit diff --git a/tests/parity/fixtures/matlab_gold/nSTATPaperExamples_audit_gold.json b/tests/parity/fixtures/matlab_gold/nSTATPaperExamples_audit_gold.json index a76bffc6..4e11d677 100644 --- a/tests/parity/fixtures/matlab_gold/nSTATPaperExamples_audit_gold.json +++ b/tests/parity/fixtures/matlab_gold/nSTATPaperExamples_audit_gold.json @@ -4,7 +4,7 @@ "alignment_status": "validated", "matlab_code_lines": 1576, "matlab_reference_image_count": 26, - "min_assertion_count": 3, + "min_assertion_count": 15, "require_topic_checkpoint": true, "min_python_validation_image_count": 1, "require_plot_call": true, diff --git a/tests/parity/fixtures/matlab_gold/nSpikeTrainExamples_audit_gold.json b/tests/parity/fixtures/matlab_gold/nSpikeTrainExamples_audit_gold.json index 77fc5e08..54030d50 100644 --- a/tests/parity/fixtures/matlab_gold/nSpikeTrainExamples_audit_gold.json +++ b/tests/parity/fixtures/matlab_gold/nSpikeTrainExamples_audit_gold.json @@ -4,7 +4,7 @@ "alignment_status": "validated", "matlab_code_lines": 10, 
"matlab_reference_image_count": 6, - "min_assertion_count": 3, + "min_assertion_count": 4, "require_topic_checkpoint": true, "min_python_validation_image_count": 1, "require_plot_call": true, diff --git a/tests/parity/fixtures/matlab_gold/nstCollExamples_gold.mat b/tests/parity/fixtures/matlab_gold/nstCollExamples_gold.mat index 7cbd4f07..8db216ca 100644 Binary files a/tests/parity/fixtures/matlab_gold/nstCollExamples_gold.mat and b/tests/parity/fixtures/matlab_gold/nstCollExamples_gold.mat differ diff --git a/tests/parity/fixtures/matlab_gold/publish_all_helpfiles_audit_gold.json b/tests/parity/fixtures/matlab_gold/publish_all_helpfiles_audit_gold.json index a18a8e12..4fedfe5f 100644 --- a/tests/parity/fixtures/matlab_gold/publish_all_helpfiles_audit_gold.json +++ b/tests/parity/fixtures/matlab_gold/publish_all_helpfiles_audit_gold.json @@ -4,7 +4,7 @@ "alignment_status": "validated", "matlab_code_lines": 126, "matlab_reference_image_count": 0, - "min_assertion_count": 3, + "min_assertion_count": 14, "require_topic_checkpoint": true, "min_python_validation_image_count": 1, "require_plot_call": true, diff --git a/tests/parity_utils.py b/tests/parity_utils.py new file mode 100644 index 00000000..ff8cf9ec --- /dev/null +++ b/tests/parity_utils.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import random +from pathlib import Path +from typing import Any + +import numpy as np +import scipy.io + + +def set_deterministic_seeds(seed: int) -> np.random.Generator: + """Set deterministic Python + NumPy RNG state and return a Generator.""" + random.seed(int(seed)) + np.random.seed(int(seed)) + return np.random.default_rng(int(seed)) + + +def matlab_rng_command(seed: int, generator: str = "twister") -> str: + """Return the MATLAB RNG statement used for fixture generation scripts.""" + return f"rng({int(seed)}, '{generator}');" + + +def _convert_matlab_value(value: Any) -> Any: + if isinstance(value, np.ndarray) and value.dtype == object: + if value.size == 1: + 
return _convert_matlab_value(value.reshape(-1)[0]) + if value.ndim == 0: + return _convert_matlab_value(value.item()) + if value.ndim == 1: + return [_convert_matlab_value(x) for x in value.tolist()] + if value.ndim == 2: + return [[_convert_matlab_value(x) for x in row] for row in value.tolist()] + return [_convert_matlab_value(x) for x in value.reshape(-1).tolist()] + if isinstance(value, np.ndarray): + return value + if hasattr(value, "_fieldnames"): + out: dict[str, Any] = {} + for name in getattr(value, "_fieldnames", []): + out[str(name)] = _convert_matlab_value(getattr(value, name)) + return out + return value + + +def loadmat_normalized( + path: str | Path, + *, + squeeze_me: bool = False, + keep_metadata: bool = False, +) -> dict[str, Any]: + """Load a MATLAB .mat file and normalize structs/cells into Python types.""" + payload = scipy.io.loadmat( + str(path), + squeeze_me=squeeze_me, + struct_as_record=False, + ) + out: dict[str, Any] = {} + for key, value in payload.items(): + if not keep_metadata and key.startswith("__"): + continue + out[key] = _convert_matlab_value(value) + return out + + +def canonicalize_numeric(value: Any, *, vector_shape: str = "preserve") -> np.ndarray: + """Canonicalize numeric values for parity comparisons (dtype + vector orientation).""" + arr = np.asarray(value) + if np.issubdtype(arr.dtype, np.number): + arr = arr.astype(np.float64, copy=False) + if arr.ndim == 1: + if vector_shape == "column": + arr = arr[:, None] + elif vector_shape == "row": + arr = arr[None, :] + return arr + + +def assert_same_shape(actual: Any, expected: Any) -> None: + a = np.asarray(actual) + b = np.asarray(expected) + if a.shape != b.shape: + raise AssertionError(f"shape mismatch: actual={a.shape} expected={b.shape}") + + +def assert_matching_nan_inf_locations(actual: Any, expected: Any) -> None: + a = canonicalize_numeric(actual) + b = canonicalize_numeric(expected) + assert_same_shape(a, b) + + a_nan = np.isnan(a) + b_nan = np.isnan(b) + if not 
np.array_equal(a_nan, b_nan): + raise AssertionError("NaN locations do not match") + + a_pos_inf = np.isposinf(a) + b_pos_inf = np.isposinf(b) + if not np.array_equal(a_pos_inf, b_pos_inf): + raise AssertionError("+Inf locations do not match") + + a_neg_inf = np.isneginf(a) + b_neg_inf = np.isneginf(b) + if not np.array_equal(a_neg_inf, b_neg_inf): + raise AssertionError("-Inf locations do not match") + + +def _scale_from_expected(expected: np.ndarray, mode: str) -> float: + finite = np.isfinite(expected) + if not np.any(finite): + return 1.0 + vals = np.abs(expected[finite]) + if mode == "maxabs": + return float(max(np.max(vals), 1.0)) + if mode == "rms": + return float(max(np.sqrt(np.mean(expected[finite] ** 2)), 1.0)) + if mode == "range": + rng = float(np.max(expected[finite]) - np.min(expected[finite])) + return max(rng, 1.0) + raise ValueError(f"unsupported scale mode: {mode}") + + +def assert_allclose_scaled( + actual: Any, + expected: Any, + *, + rtol: float, + atol: float, + scale: str = "maxabs", +) -> None: + a = canonicalize_numeric(actual) + b = canonicalize_numeric(expected) + assert_same_shape(a, b) + assert_matching_nan_inf_locations(a, b) + + finite = np.isfinite(a) & np.isfinite(b) + if not np.any(finite): + return + + scale_val = _scale_from_expected(b, scale) + np.testing.assert_allclose( + a[finite], + b[finite], + rtol=float(rtol), + atol=float(atol) * scale_val, + ) + + +def assert_event_times_close( + actual: Any, + expected: Any, + *, + atol: float = 1.0e-9, + sort_values: bool = True, +) -> None: + a = np.asarray(actual, dtype=np.float64).reshape(-1) + b = np.asarray(expected, dtype=np.float64).reshape(-1) + if sort_values: + a = np.sort(a) + b = np.sort(b) + assert_same_shape(a, b) + np.testing.assert_allclose(a, b, rtol=0.0, atol=float(atol)) diff --git a/tests/performance/fixtures/matlab/performance_baseline_470fde8.csv b/tests/performance/fixtures/matlab/performance_baseline_470fde8.csv new file mode 100644 index 00000000..f4340690 --- 
/dev/null +++ b/tests/performance/fixtures/matlab/performance_baseline_470fde8.csv @@ -0,0 +1,16 @@ +case,tier,repeats,median_runtime_ms,mean_runtime_ms,std_runtime_ms,median_peak_memory_mb,summary +unit_impulse_basis,S,7,4.297083333,3.773136905,2.034040669,0.191116333,"{""rows"":501,""cols"":50,""total_mass"":501,""memory_proxy_mb"":0.1911163330078125}" +unit_impulse_basis,M,7,2.735708333,2.732898810,0.126777277,1.526641846,"{""rows"":2001,""cols"":100,""total_mass"":2001,""memory_proxy_mb"":1.526641845703125}" +unit_impulse_basis,L,7,6.359083333,6.330571429,0.332416673,9.156799316,"{""rows"":6001,""cols"":200,""total_mass"":6001,""memory_proxy_mb"":9.15679931640625}" +covariate_resample,S,7,2.632916667,2.811726190,0.653497244,0.007637024,"{""rows"":1001,""cols"":1,""signal_energy"":0.51952047952047953,""memory_proxy_mb"":0.00763702392578125}" +covariate_resample,M,7,2.875166667,3.222666667,0.757170547,0.022895813,"{""rows"":3001,""cols"":1,""signal_energy"":0.51984005257749533,""memory_proxy_mb"":0.02289581298828125}" +covariate_resample,L,7,1.947208333,1.888410714,0.315633896,0.045783997,"{""rows"":6001,""cols"":1,""signal_energy"":0.519920013331112,""memory_proxy_mb"":0.04578399658203125}" +history_design_matrix,S,7,15.547416667,15.334202381,2.620143725,0.061065674,"{""rows"":2001,""cols"":4,""total_count"":19479,""memory_proxy_mb"":0.061065673828125}" +history_design_matrix,M,7,12.951875000,12.782089286,1.140924257,0.061065674,"{""rows"":2001,""cols"":4,""total_count"":97507,""memory_proxy_mb"":0.061065673828125}" +history_design_matrix,L,7,13.820333333,13.938613095,0.434794481,0.061065674,"{""rows"":2001,""cols"":4,""total_count"":292495,""memory_proxy_mb"":0.061065673828125}" +simulate_cif_thinning,S,7,18.633500000,22.285202381,10.204807908,0.007637024,"{""num_units"":5,""total_spikes"":53,""mean_spikes_per_unit"":10.6,""memory_proxy_mb"":0.00763702392578125}" 
+simulate_cif_thinning,M,7,10.393083333,11.535672619,6.034641619,0.015266418,"{""num_units"":10,""total_spikes"":227,""mean_spikes_per_unit"":22.7,""memory_proxy_mb"":0.01526641845703125}" +simulate_cif_thinning,L,7,11.125333333,11.006880952,1.380233877,0.022895813,"{""num_units"":20,""total_spikes"":697,""mean_spikes_per_unit"":34.85,""memory_proxy_mb"":0.02289581298828125}" +decoding_spike_rate_cis,S,7,17.276083333,20.075815476,8.011035032,0.000274658,"{""num_trials"":6,""prob_mean"":0.17222222222222225,""sig_count"":0,""rate_mean"":50.340528015525138,""memory_proxy_mb"":0.000274658203125}" +decoding_spike_rate_cis,M,7,15.002750000,15.106750000,2.117378368,0.000488281,"{""num_trials"":8,""prob_mean"":0.188125,""sig_count"":0,""rate_mean"":50.22232623024469,""memory_proxy_mb"":0.00048828125}" +decoding_spike_rate_cis,L,7,26.472458333,25.936833333,1.395156968,0.001098633,"{""num_trials"":12,""prob_mean"":0.20998263888888888,""sig_count"":0,""rate_mean"":50.1178888198292,""memory_proxy_mb"":0.0010986328125}" diff --git a/tests/performance/fixtures/matlab/performance_baseline_470fde8.json b/tests/performance/fixtures/matlab/performance_baseline_470fde8.json new file mode 100644 index 00000000..ca8e6236 --- /dev/null +++ b/tests/performance/fixtures/matlab/performance_baseline_470fde8.json @@ -0,0 +1,536 @@ +{ + "schema_version": 1, + "generated_at_utc": "2026-03-04T04:09:43Z", + "implementation": "matlab", + "nstat_root": "/Users/iahncajigas/Library/CloudStorage/Dropbox/Research/Matlab/nSTAT_currentRelease_Local", + "reference_sha": "0afc8390b5958bb9af255344d7e4a33fedb172ca", + "tiers": [ + "S", + "M", + "L" + ], + "cases": [ + { + "case": "unit_impulse_basis", + "tier": "S", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 4.2970833333333331, + "mean_runtime_ms": 3.773136904761905, + "std_runtime_ms": 2.0340406685672687, + "median_peak_memory_mb": 0.1911163330078125, + "summary": { + "rows": 501, + "cols": 50, + "total_mass": 501, + "memory_proxy_mb": 
0.1911163330078125 + }, + "samples_runtime_ms": [ + 7.0744583333333333, + 5.0766666666666671, + 4.3886666666666665, + 4.2970833333333331, + 2.3883333333333336, + 1.86925, + 1.3175000000000001 + ], + "samples_peak_memory_mb": [ + 0.1911163330078125, + 0.1911163330078125, + 0.1911163330078125, + 0.1911163330078125, + 0.1911163330078125, + 0.1911163330078125, + 0.1911163330078125 + ] + }, + { + "case": "unit_impulse_basis", + "tier": "M", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 2.7357083333333336, + "mean_runtime_ms": 2.73289880952381, + "std_runtime_ms": 0.12677727685428219, + "median_peak_memory_mb": 1.526641845703125, + "summary": { + "rows": 2001, + "cols": 100, + "total_mass": 2001, + "memory_proxy_mb": 1.526641845703125 + }, + "samples_runtime_ms": [ + 2.6884583333333336, + 2.5345, + 2.9364583333333334, + 2.8283750000000003, + 2.7407916666666665, + 2.7357083333333336, + 2.666 + ], + "samples_peak_memory_mb": [ + 1.526641845703125, + 1.526641845703125, + 1.526641845703125, + 1.526641845703125, + 1.526641845703125, + 1.526641845703125, + 1.526641845703125 + ] + }, + { + "case": "unit_impulse_basis", + "tier": "L", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 6.3590833333333325, + "mean_runtime_ms": 6.330571428571429, + "std_runtime_ms": 0.33241667251111179, + "median_peak_memory_mb": 9.15679931640625, + "summary": { + "rows": 6001, + "cols": 200, + "total_mass": 6001, + "memory_proxy_mb": 9.15679931640625 + }, + "samples_runtime_ms": [ + 6.29275, + 6.87575, + 6.2596666666666669, + 6.3685833333333335, + 6.3590833333333325, + 6.418625, + 5.7395416666666668 + ], + "samples_peak_memory_mb": [ + 9.15679931640625, + 9.15679931640625, + 9.15679931640625, + 9.15679931640625, + 9.15679931640625, + 9.15679931640625, + 9.15679931640625 + ] + }, + { + "case": "covariate_resample", + "tier": "S", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 2.6329166666666666, + "mean_runtime_ms": 2.81172619047619, + "std_runtime_ms": 0.65349724411467558, + 
"median_peak_memory_mb": 0.00763702392578125, + "summary": { + "rows": 1001, + "cols": 1, + "signal_energy": 0.51952047952047953, + "memory_proxy_mb": 0.00763702392578125 + }, + "samples_runtime_ms": [ + 2.8413333333333335, + 2.6329166666666666, + 3.9391666666666665, + 2.50025, + 3.3961249999999996, + 2.3224166666666668, + 2.049875 + ], + "samples_peak_memory_mb": [ + 0.00763702392578125, + 0.00763702392578125, + 0.00763702392578125, + 0.00763702392578125, + 0.00763702392578125, + 0.00763702392578125, + 0.00763702392578125 + ] + }, + { + "case": "covariate_resample", + "tier": "M", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 2.8751666666666664, + "mean_runtime_ms": 3.2226666666666666, + "std_runtime_ms": 0.75717054663385741, + "median_peak_memory_mb": 0.02289581298828125, + "summary": { + "rows": 3001, + "cols": 1, + "signal_energy": 0.51984005257749533, + "memory_proxy_mb": 0.02289581298828125 + }, + "samples_runtime_ms": [ + 4.409, + 3.195, + 2.7273750000000003, + 2.67125, + 4.1570833333333326, + 2.5237916666666669, + 2.8751666666666664 + ], + "samples_peak_memory_mb": [ + 0.02289581298828125, + 0.02289581298828125, + 0.02289581298828125, + 0.02289581298828125, + 0.02289581298828125, + 0.02289581298828125, + 0.02289581298828125 + ] + }, + { + "case": "covariate_resample", + "tier": "L", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 1.9472083333333334, + "mean_runtime_ms": 1.8884107142857143, + "std_runtime_ms": 0.31563389649046991, + "median_peak_memory_mb": 0.04578399658203125, + "summary": { + "rows": 6001, + "cols": 1, + "signal_energy": 0.519920013331112, + "memory_proxy_mb": 0.04578399658203125 + }, + "samples_runtime_ms": [ + 1.95875, + 1.526, + 2.4904166666666669, + 1.9485416666666666, + 1.6713749999999998, + 1.6765833333333333, + 1.9472083333333334 + ], + "samples_peak_memory_mb": [ + 0.04578399658203125, + 0.04578399658203125, + 0.04578399658203125, + 0.04578399658203125, + 0.04578399658203125, + 0.04578399658203125, + 
0.04578399658203125 + ] + }, + { + "case": "history_design_matrix", + "tier": "S", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 15.547416666666665, + "mean_runtime_ms": 15.33420238095238, + "std_runtime_ms": 2.6201437253425088, + "median_peak_memory_mb": 0.061065673828125, + "summary": { + "rows": 2001, + "cols": 4, + "total_count": 19479, + "memory_proxy_mb": 0.061065673828125 + }, + "samples_runtime_ms": [ + 19.652708333333333, + 16.800625, + 15.547416666666665, + 15.985833333333334, + 14.421083333333332, + 13.613166666666668, + 11.318583333333333 + ], + "samples_peak_memory_mb": [ + 0.061065673828125, + 0.061065673828125, + 0.061065673828125, + 0.061065673828125, + 0.061065673828125, + 0.061065673828125, + 0.061065673828125 + ] + }, + { + "case": "history_design_matrix", + "tier": "M", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 12.951875, + "mean_runtime_ms": 12.782089285714287, + "std_runtime_ms": 1.1409242568944873, + "median_peak_memory_mb": 0.061065673828125, + "summary": { + "rows": 2001, + "cols": 4, + "total_count": 97507, + "memory_proxy_mb": 0.061065673828125 + }, + "samples_runtime_ms": [ + 13.366875, + 14.438875, + 11.997499999999999, + 12.951875, + 12.24975, + 13.499125000000001, + 10.970625 + ], + "samples_peak_memory_mb": [ + 0.061065673828125, + 0.061065673828125, + 0.061065673828125, + 0.061065673828125, + 0.061065673828125, + 0.061065673828125, + 0.061065673828125 + ] + }, + { + "case": "history_design_matrix", + "tier": "L", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 13.820333333333334, + "mean_runtime_ms": 13.938613095238097, + "std_runtime_ms": 0.43479448084744254, + "median_peak_memory_mb": 0.061065673828125, + "summary": { + "rows": 2001, + "cols": 4, + "total_count": 292495, + "memory_proxy_mb": 0.061065673828125 + }, + "samples_runtime_ms": [ + 14.813, + 14.0275, + 13.763958333333333, + 13.514791666666666, + 14.047, + 13.820333333333334, + 13.583708333333334 + ], + "samples_peak_memory_mb": [ + 
0.061065673828125, + 0.061065673828125, + 0.061065673828125, + 0.061065673828125, + 0.061065673828125, + 0.061065673828125, + 0.061065673828125 + ] + }, + { + "case": "simulate_cif_thinning", + "tier": "S", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 18.6335, + "mean_runtime_ms": 22.285202380952381, + "std_runtime_ms": 10.204807907674679, + "median_peak_memory_mb": 0.00763702392578125, + "summary": { + "num_units": 5, + "total_spikes": 53, + "mean_spikes_per_unit": 10.6, + "memory_proxy_mb": 0.00763702392578125 + }, + "samples_runtime_ms": [ + 38.645125, + 33.8315, + 21.978791666666666, + 17.852333333333334, + 18.6335, + 12.101458333333333, + 12.953708333333333 + ], + "samples_peak_memory_mb": [ + 0.00763702392578125, + 0.00763702392578125, + 0.00763702392578125, + 0.00763702392578125, + 0.00763702392578125, + 0.00763702392578125, + 0.00763702392578125 + ] + }, + { + "case": "simulate_cif_thinning", + "tier": "M", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 10.393083333333333, + "mean_runtime_ms": 11.535672619047618, + "std_runtime_ms": 6.0346416186965479, + "median_peak_memory_mb": 0.01526641845703125, + "summary": { + "num_units": 10, + "total_spikes": 227, + "mean_spikes_per_unit": 22.7, + "memory_proxy_mb": 0.01526641845703125 + }, + "samples_runtime_ms": [ + 18.474083333333333, + 16.471166666666665, + 18.113541666666666, + 10.393083333333333, + 6.7640416666666665, + 5.4447916666666663, + 5.0889999999999995 + ], + "samples_peak_memory_mb": [ + 0.01526641845703125, + 0.01526641845703125, + 0.01526641845703125, + 0.01526641845703125, + 0.01526641845703125, + 0.01526641845703125, + 0.01526641845703125 + ] + }, + { + "case": "simulate_cif_thinning", + "tier": "L", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 11.125333333333332, + "mean_runtime_ms": 11.006880952380952, + "std_runtime_ms": 1.3802338767926841, + "median_peak_memory_mb": 0.02289581298828125, + "summary": { + "num_units": 20, + "total_spikes": 697, + "mean_spikes_per_unit": 
34.85, + "memory_proxy_mb": 0.02289581298828125 + }, + "samples_runtime_ms": [ + 11.999833333333333, + 13.285416666666666, + 11.125333333333332, + 11.472166666666666, + 9.8583333333333343, + 9.512375, + 9.7947083333333342 + ], + "samples_peak_memory_mb": [ + 0.02289581298828125, + 0.02289581298828125, + 0.02289581298828125, + 0.02289581298828125, + 0.02289581298828125, + 0.02289581298828125, + 0.02289581298828125 + ] + }, + { + "case": "decoding_spike_rate_cis", + "tier": "S", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 17.276083333333336, + "mean_runtime_ms": 20.075815476190478, + "std_runtime_ms": 8.01103503153854, + "median_peak_memory_mb": 0.000274658203125, + "summary": { + "num_trials": 6, + "prob_mean": 0.17222222222222225, + "sig_count": 0, + "rate_mean": 50.340528015525138, + "memory_proxy_mb": 0.000274658203125 + }, + "samples_runtime_ms": [ + 35.707833333333333, + 25.233916666666669, + 17.276083333333336, + 16.183166666666665, + 13.461833333333333, + 13.223083333333333, + 19.444791666666667 + ], + "samples_peak_memory_mb": [ + 0.000274658203125, + 0.000274658203125, + 0.000274658203125, + 0.000274658203125, + 0.000274658203125, + 0.000274658203125, + 0.000274658203125 + ] + }, + { + "case": "decoding_spike_rate_cis", + "tier": "M", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 15.00275, + "mean_runtime_ms": 15.106750000000002, + "std_runtime_ms": 2.1173783681560523, + "median_peak_memory_mb": 0.00048828125, + "summary": { + "num_trials": 8, + "prob_mean": 0.188125, + "sig_count": 0, + "rate_mean": 50.22232623024469, + "memory_proxy_mb": 0.00048828125 + }, + "samples_runtime_ms": [ + 18.718333333333334, + 16.751625, + 15.287875, + 13.339791666666667, + 14.224666666666666, + 12.422208333333334, + 15.00275 + ], + "samples_peak_memory_mb": [ + 0.00048828125, + 0.00048828125, + 0.00048828125, + 0.00048828125, + 0.00048828125, + 0.00048828125, + 0.00048828125 + ] + }, + { + "case": "decoding_spike_rate_cis", + "tier": "L", + "repeats": 7, + 
"warmup": 2, + "median_runtime_ms": 26.472458333333336, + "mean_runtime_ms": 25.936833333333336, + "std_runtime_ms": 1.395156967556114, + "median_peak_memory_mb": 0.0010986328125, + "summary": { + "num_trials": 12, + "prob_mean": 0.20998263888888888, + "sig_count": 0, + "rate_mean": 50.1178888198292, + "memory_proxy_mb": 0.0010986328125 + }, + "samples_runtime_ms": [ + 26.474541666666667, + 25.155833333333334, + 27.055541666666667, + 24.789291666666667, + 27.79125, + 26.472458333333336, + 23.818916666666667 + ], + "samples_peak_memory_mb": [ + 0.0010986328125, + 0.0010986328125, + 0.0010986328125, + 0.0010986328125, + 0.0010986328125, + 0.0010986328125, + 0.0010986328125 + ] + } + ], + "environment": { + "matlab_version": "25.2.0.3123386 (R2025b) Update 3", + "matlab_release": "2025b", + "os": "MACA64", + "blas": "Apple Accelerate BLAS (ILP64)", + "omp_num_threads": "", + "mkl_num_threads": "", + "openblas_num_threads": "" + } +} \ No newline at end of file diff --git a/tests/performance/fixtures/python/performance_baseline_20260303.csv b/tests/performance/fixtures/python/performance_baseline_20260303.csv new file mode 100644 index 00000000..0103fa42 --- /dev/null +++ b/tests/performance/fixtures/python/performance_baseline_20260303.csv @@ -0,0 +1,16 @@ +case,tier,repeats,median_runtime_ms,mean_runtime_ms,std_runtime_ms,median_peak_memory_mb,summary +unit_impulse_basis,S,7,1.7683329933788627,1.7861547135648184,0.158024180816925,0.4534463882446289,"{""cols"": 50.0, ""rows"": 501.0, ""total_mass"": 500.0}" +unit_impulse_basis,M,7,4.475333000300452,4.500089290169334,0.9284738746088693,3.1384315490722656,"{""cols"": 100.0, ""rows"": 2001.0, ""total_mass"": 2000.0}" +unit_impulse_basis,L,7,9.582500002579764,9.65004785601715,0.5994475223217789,18.434642791748047,"{""cols"": 200.0, ""rows"": 6001.0, ""total_mass"": 6000.0}" +covariate_resample,S,7,0.37954199069645256,0.5145418565786842,0.342992624379633,0.061981201171875,"{""cols"": 1.0, ""rows"": 1001.0, 
""signal_energy"": 0.5195204795204795}" +covariate_resample,M,7,0.8639160078018904,0.9734344265390453,0.36931540455016815,0.1353321075439453,"{""cols"": 1.0, ""rows"": 3001.0, ""signal_energy"": 0.5198042747802832}" +covariate_resample,L,7,0.9690420120023191,0.9464108568084028,0.48446691060044744,0.23737525939941406,"{""cols"": 1.0, ""rows"": 6001.0, ""signal_energy"": 0.5199200133311115}" +history_design_matrix,S,7,0.7950409926706925,1.645898716690551,1.8429881668958228,0.08904266357421875,"{""cols"": 4.0, ""rows"": 1000.0, ""total_count"": 9737.0}" +history_design_matrix,M,7,5.126874995767139,4.977494141452813,0.5063190759208223,0.436859130859375,"{""cols"": 4.0, ""rows"": 5000.0, ""total_count"": 243740.0}" +history_design_matrix,L,7,16.20729199203197,14.745523999278833,3.320744565978393,0.8869171142578125,"{""cols"": 4.0, ""rows"": 10000.0, ""total_count"": 1462420.0}" +simulate_cif_thinning,S,7,11.668874998576939,13.197898854351868,4.439189648217996,0.07119369506835938,"{""mean_spikes_per_unit"": 11.8, ""num_units"": 5.0, ""total_spikes"": 59.0}" +simulate_cif_thinning,M,7,35.64529199502431,39.50604771463467,10.14234695828477,0.13432693481445312,"{""mean_spikes_per_unit"": 23.6, ""num_units"": 10.0, ""total_spikes"": 236.0}" +simulate_cif_thinning,L,7,95.53704199788626,101.42410728648039,15.004523282406359,0.20084762573242188,"{""mean_spikes_per_unit"": 36.5, ""num_units"": 20.0, ""total_spikes"": 730.0}" +decoding_spike_rate_cis,S,7,20.557166004437022,20.735029426370083,0.32869912750478875,0.23363685607910156,"{""num_trials"": 6.0, ""prob_mean"": 0.1509259259259259, ""rate_mean"": 50.4457886636761, ""sig_count"": 0.0}" +decoding_spike_rate_cis,M,7,45.79670800012536,55.404220285709016,17.298491800109527,0.7647151947021484,"{""num_trials"": 8.0, ""prob_mean"": 0.18562499999999998, ""rate_mean"": 50.12398439148756, ""sig_count"": 0.0}" +decoding_spike_rate_cis,L,7,99.45741599949542,100.12100585819488,2.1007428451842776,2.7344188690185547,"{""num_trials"": 12.0, 
""prob_mean"": 0.21328124999999998, ""rate_mean"": 50.073736692667104, ""sig_count"": 0.0}" diff --git a/tests/performance/fixtures/python/performance_baseline_20260303.json b/tests/performance/fixtures/python/performance_baseline_20260303.json new file mode 100644 index 00000000..dcb456e8 --- /dev/null +++ b/tests/performance/fixtures/python/performance_baseline_20260303.json @@ -0,0 +1,523 @@ +{ + "schema_version": 1, + "generated_at_utc": "2026-03-04T04:20:29Z", + "implementation": "python", + "repo_root": "/private/tmp/nstat_python_exec_next", + "git_sha": "540519f52cb6799fa4886ddbe8cdd3b5fd1c9c3b", + "tiers": [ + "S", + "M", + "L" + ], + "cases": [ + { + "case": "unit_impulse_basis", + "tier": "S", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 1.7683329933788627, + "mean_runtime_ms": 1.7861547135648184, + "std_runtime_ms": 0.158024180816925, + "median_peak_memory_mb": 0.4534463882446289, + "summary": { + "rows": 501.0, + "cols": 50.0, + "total_mass": 500.0 + }, + "samples_runtime_ms": [ + 1.694457998382859, + 1.5242920053424314, + 1.7683329933788627, + 1.9450409890851006, + 1.9406670035095885, + 1.9679590041050687, + 1.6623330011498183 + ], + "samples_peak_memory_mb": [ + 0.4534463882446289, + 0.45345401763916016, + 0.4534616470336914, + 0.4534463882446289, + 0.4534311294555664, + 0.4534006118774414, + 0.4533853530883789 + ] + }, + { + "case": "unit_impulse_basis", + "tier": "M", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 4.475333000300452, + "mean_runtime_ms": 4.500089290169334, + "std_runtime_ms": 0.9284738746088693, + "median_peak_memory_mb": 3.1384315490722656, + "summary": { + "rows": 2001.0, + "cols": 100.0, + "total_mass": 2000.0 + }, + "samples_runtime_ms": [ + 4.051500000059605, + 5.044250006903894, + 5.633542008581571, + 5.681417009327561, + 3.234125004382804, + 4.475333000300452, + 3.3804580016294494 + ], + "samples_peak_memory_mb": [ + 3.1385536193847656, + 3.138561248779297, + 3.138446807861328, + 3.1384315490722656, + 
3.138416290283203, + 3.138408660888672, + 3.1384010314941406 + ] + }, + { + "case": "unit_impulse_basis", + "tier": "L", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 9.582500002579764, + "mean_runtime_ms": 9.65004785601715, + "std_runtime_ms": 0.5994475223217789, + "median_peak_memory_mb": 18.434642791748047, + "summary": { + "rows": 6001.0, + "cols": 200.0, + "total_mass": 6000.0 + }, + "samples_runtime_ms": [ + 9.086833990295418, + 8.696708013303578, + 9.401916991919279, + 9.582500002579764, + 10.047749994555488, + 10.233041990431957, + 10.50158400903456 + ], + "samples_peak_memory_mb": [ + 18.434635162353516, + 18.434650421142578, + 18.43466567993164, + 18.434650421142578, + 18.434642791748047, + 18.434635162353516, + 18.434627532958984 + ] + }, + { + "case": "covariate_resample", + "tier": "S", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 0.37954199069645256, + "mean_runtime_ms": 0.5145418565786842, + "std_runtime_ms": 0.342992624379633, + "median_peak_memory_mb": 0.061981201171875, + "summary": { + "rows": 1001.0, + "cols": 1.0, + "signal_energy": 0.5195204795204795 + }, + "samples_runtime_ms": [ + 0.19366700144018978, + 0.17533400387037545, + 0.37954199069645256, + 0.3596659953473136, + 0.48983399756252766, + 1.2112499971408397, + 0.7925000099930912 + ], + "samples_peak_memory_mb": [ + 0.061981201171875, + 0.061981201171875, + 0.061981201171875, + 0.061981201171875, + 0.061981201171875, + 0.061981201171875, + 0.061981201171875 + ] + }, + { + "case": "covariate_resample", + "tier": "M", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 0.8639160078018904, + "mean_runtime_ms": 0.9734344265390453, + "std_runtime_ms": 0.36931540455016815, + "median_peak_memory_mb": 0.1353321075439453, + "summary": { + "rows": 3001.0, + "cols": 1.0, + "signal_energy": 0.5198042747802832 + }, + "samples_runtime_ms": [ + 1.240666999365203, + 1.4605409960495308, + 0.6025000038789585, + 1.4403749955818057, + 0.6479579897131771, + 0.8639160078018904, + 
0.5580839933827519 + ], + "samples_peak_memory_mb": [ + 0.13530921936035156, + 0.13530921936035156, + 0.1353321075439453, + 0.1353321075439453, + 0.1353321075439453, + 0.1353321075439453, + 0.1353321075439453 + ] + }, + { + "case": "covariate_resample", + "tier": "L", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 0.9690420120023191, + "mean_runtime_ms": 0.9464108568084028, + "std_runtime_ms": 0.48446691060044744, + "median_peak_memory_mb": 0.23737525939941406, + "summary": { + "rows": 6001.0, + "cols": 1.0, + "signal_energy": 0.5199200133311115 + }, + "samples_runtime_ms": [ + 1.9614159973571077, + 0.9757090010680258, + 0.7324170001083985, + 1.1260829924140126, + 0.9690420120023191, + 0.4219999973429367, + 0.4382089973660186 + ], + "samples_peak_memory_mb": [ + 0.2373523712158203, + 0.2373523712158203, + 0.23737525939941406, + 0.23737525939941406, + 0.23737525939941406, + 0.23737525939941406, + 0.23737525939941406 + ] + }, + { + "case": "history_design_matrix", + "tier": "S", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 0.7950409926706925, + "mean_runtime_ms": 1.645898716690551, + "std_runtime_ms": 1.8429881668958228, + "median_peak_memory_mb": 0.08904266357421875, + "summary": { + "rows": 1000.0, + "cols": 4.0, + "total_count": 9737.0 + }, + "samples_runtime_ms": [ + 0.7901250064605847, + 0.9010420035338029, + 1.5587500092806295, + 0.7950409926706925, + 6.106041997554712, + 0.7814580021658912, + 0.5888330051675439 + ], + "samples_peak_memory_mb": [ + 0.08902740478515625, + 0.08904266357421875, + 0.08905792236328125, + 0.08905029296875, + 0.08904266357421875, + 0.08902740478515625, + 0.089019775390625 + ] + }, + { + "case": "history_design_matrix", + "tier": "M", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 5.126874995767139, + "mean_runtime_ms": 4.977494141452813, + "std_runtime_ms": 0.5063190759208223, + "median_peak_memory_mb": 0.436859130859375, + "summary": { + "rows": 5000.0, + "cols": 4.0, + "total_count": 243740.0 + }, + 
"samples_runtime_ms": [ + 5.498124999576248, + 4.521750001003966, + 4.666374996304512, + 5.706082985852845, + 5.126874995767139, + 4.17879200540483, + 5.144459006260149 + ], + "samples_peak_memory_mb": [ + 0.43685150146484375, + 0.43686676025390625, + 0.43688201904296875, + 0.4368743896484375, + 0.436859130859375, + 0.43685150146484375, + 0.4368438720703125 + ] + }, + { + "case": "history_design_matrix", + "tier": "L", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 16.20729199203197, + "mean_runtime_ms": 14.745523999278833, + "std_runtime_ms": 3.320744565978393, + "median_peak_memory_mb": 0.8869171142578125, + "summary": { + "rows": 10000.0, + "cols": 4.0, + "total_count": 1462420.0 + }, + "samples_runtime_ms": [ + 10.213750007096678, + 12.098499995772727, + 16.20729199203197, + 17.953916991245933, + 18.11012500547804, + 17.890375005663373, + 10.74470899766311 + ], + "samples_peak_memory_mb": [ + 0.8869094848632812, + 0.8869247436523438, + 0.8869400024414062, + 0.8869247436523438, + 0.8869171142578125, + 0.8869094848632812, + 0.88690185546875 + ] + }, + { + "case": "simulate_cif_thinning", + "tier": "S", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 11.668874998576939, + "mean_runtime_ms": 13.197898854351868, + "std_runtime_ms": 4.439189648217996, + "median_peak_memory_mb": 0.07119369506835938, + "summary": { + "num_units": 5.0, + "total_spikes": 59.0, + "mean_spikes_per_unit": 11.8 + }, + "samples_runtime_ms": [ + 15.166458004387096, + 23.260666988790035, + 9.778041989193298, + 11.668874998576939, + 11.702959003741853, + 9.72149999870453, + 11.086790997069329 + ], + "samples_peak_memory_mb": [ + 0.07107925415039062, + 0.07112503051757812, + 0.07119369506835938, + 0.07131576538085938, + 0.07126998901367188, + 0.07120132446289062, + 0.07112503051757812 + ] + }, + { + "case": "simulate_cif_thinning", + "tier": "M", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 35.64529199502431, + "mean_runtime_ms": 39.50604771463467, + "std_runtime_ms": 
10.14234695828477, + "median_peak_memory_mb": 0.13432693481445312, + "summary": { + "num_units": 10.0, + "total_spikes": 236.0, + "mean_spikes_per_unit": 23.6 + }, + "samples_runtime_ms": [ + 63.98295899271034, + 35.64529199502431, + 38.63129101227969, + 33.676250008284114, + 36.9453750026878, + 34.22758399392478, + 33.43358299753163 + ], + "samples_peak_memory_mb": [ + 0.13429641723632812, + 0.13458633422851562, + 0.13427352905273438, + 0.13451004028320312, + 0.13428878784179688, + 0.13432693481445312, + 0.13438034057617188 + ] + }, + { + "case": "simulate_cif_thinning", + "tier": "L", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 95.53704199788626, + "mean_runtime_ms": 101.42410728648039, + "std_runtime_ms": 15.004523282406359, + "median_peak_memory_mb": 0.20084762573242188, + "summary": { + "num_units": 20.0, + "total_spikes": 730.0, + "mean_spikes_per_unit": 36.5 + }, + "samples_runtime_ms": [ + 95.53704199788626, + 138.14341700344812, + 96.45345799799543, + 95.80341700348072, + 94.85275000042748, + 94.7820420115022, + 94.39662499062251 + ], + "samples_peak_memory_mb": [ + 0.20105361938476562, + 0.20028305053710938, + 0.20032119750976562, + 0.20097732543945312, + 0.20128250122070312, + 0.20084762573242188, + 0.20074081420898438 + ] + }, + { + "case": "decoding_spike_rate_cis", + "tier": "S", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 20.557166004437022, + "mean_runtime_ms": 20.735029426370083, + "std_runtime_ms": 0.32869912750478875, + "median_peak_memory_mb": 0.23363685607910156, + "summary": { + "num_trials": 6.0, + "prob_mean": 0.1509259259259259, + "sig_count": 0.0, + "rate_mean": 50.4457886636761 + }, + "samples_runtime_ms": [ + 20.557166004437022, + 20.73420799570158, + 20.454083001823165, + 21.412290996522643, + 20.491249990300275, + 21.00037499621976, + 20.495832999586128 + ], + "samples_peak_memory_mb": [ + 0.23363685607910156, + 0.23363685607910156, + 0.23363685607910156, + 0.23363685607910156, + 0.23363685607910156, + 
0.2334461212158203, + 0.23340415954589844 + ] + }, + { + "case": "decoding_spike_rate_cis", + "tier": "M", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 45.79670800012536, + "mean_runtime_ms": 55.404220285709016, + "std_runtime_ms": 17.298491800109527, + "median_peak_memory_mb": 0.7647151947021484, + "summary": { + "num_trials": 8.0, + "prob_mean": 0.18562499999999998, + "sig_count": 0.0, + "rate_mean": 50.12398439148756 + }, + "samples_runtime_ms": [ + 92.78645800077356, + 68.6920409934828, + 44.14379199442919, + 43.69733401108533, + 45.79670800012536, + 46.98724999616388, + 45.72595900390297 + ], + "samples_peak_memory_mb": [ + 0.7647151947021484, + 0.7647151947021484, + 0.7647151947021484, + 0.7647151947021484, + 0.7647151947021484, + 0.7647151947021484, + 0.7647151947021484 + ] + }, + { + "case": "decoding_spike_rate_cis", + "tier": "L", + "repeats": 7, + "warmup": 2, + "median_runtime_ms": 99.45741599949542, + "mean_runtime_ms": 100.12100585819488, + "std_runtime_ms": 2.1007428451842776, + "median_peak_memory_mb": 2.7344188690185547, + "summary": { + "num_trials": 12.0, + "prob_mean": 0.21328124999999998, + "sig_count": 0.0, + "rate_mean": 50.073736692667104 + }, + "samples_runtime_ms": [ + 97.61908301152289, + 98.89920899877325, + 98.79516600631177, + 99.45741599949542, + 104.42429198883474, + 99.98062500380911, + 101.67124999861699 + ], + "samples_peak_memory_mb": [ + 2.734373092651367, + 2.734395980834961, + 2.7344188690185547, + 2.7344188690185547, + 2.7344188690185547, + 2.7344188690185547, + 2.7344188690185547 + ] + } + ], + "environment": { + "python": "3.12.4", + "platform": "macOS-26.3-arm64-arm-64bit", + "numpy": "1.26.4", + "scipy": "1.13.1", + "matplotlib": "3.8.4", + "omp_num_threads": "", + "mkl_num_threads": "", + "openblas_num_threads": "", + "veclib_maximum_threads": "" + } +} \ No newline at end of file diff --git a/tests/performance/test_pytest_benchmarks.py b/tests/performance/test_pytest_benchmarks.py new file mode 100644 index 
00000000..2407abfb --- /dev/null +++ b/tests/performance/test_pytest_benchmarks.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import os + +import pytest + +from nstat.performance_workloads import CASE_ORDER, run_python_workload + +pytestmark = pytest.mark.skipif( + os.getenv("NSTAT_RUN_PERF_BENCHMARKS", "0") != "1", + reason="Performance benchmarks run only in dedicated CI jobs", +) + + +@pytest.mark.performance +@pytest.mark.parametrize("case", CASE_ORDER) +def test_benchmark_tier_s(benchmark: pytest.BenchmarkFixture, case: str) -> None: # type: ignore[name-defined] + summary = benchmark(run_python_workload, case, "S", 20260303) + assert summary + assert all(value == value for value in summary.values()) diff --git a/tests/performance/test_workload_outputs.py b/tests/performance/test_workload_outputs.py new file mode 100644 index 00000000..d466a035 --- /dev/null +++ b/tests/performance/test_workload_outputs.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from nstat.performance_workloads import CASE_ORDER, TIER_ORDER, run_python_workload + + +def test_workloads_return_finite_metrics() -> None: + for case in CASE_ORDER: + for tier in TIER_ORDER: + metrics = run_python_workload(case=case, tier=tier, seed=20260303) + assert metrics, f"{case}/{tier} returned no metrics" + for name, value in metrics.items(): + assert isinstance(value, float), f"{case}/{tier}:{name} must be float" + assert value == value, f"{case}/{tier}:{name} is NaN" + assert value != float("inf"), f"{case}/{tier}:{name} is inf" diff --git a/tests/test_events_history.py b/tests/test_events_history.py index ee9e1f66..2e117ac3 100644 --- a/tests/test_events_history.py +++ b/tests/test_events_history.py @@ -14,3 +14,22 @@ def test_history_design_matrix() -> None: hb = HistoryBasis(bin_edges_s=np.array([0.0, 0.05, 0.1])) mat = hb.design_matrix(spike_times_s=np.array([0.15, 0.22]), time_grid_s=np.array([0.25, 0.3])) assert mat.shape == (2, 2) + + +def 
test_history_design_matrix_matches_naive_reference() -> None: + rng = np.random.default_rng(7) + spikes = np.sort(rng.random(400) * 2.0) + grid = np.linspace(0.0, 2.0, 250) + hb = HistoryBasis(bin_edges_s=np.array([0.0, 0.01, 0.03, 0.07, 0.1])) + + fast = hb.design_matrix(spike_times_s=spikes, time_grid_s=grid) + + ref = np.zeros_like(fast) + for i, t_now in enumerate(grid): + lags = t_now - spikes + for j in range(hb.n_bins): + lo = hb.bin_edges_s[j] + hi = hb.bin_edges_s[j + 1] + ref[i, j] = float(np.sum((lags > lo) & (lags <= hi))) + + np.testing.assert_allclose(fast, ref, atol=0.0, rtol=0.0) diff --git a/tests/test_parity_matlab_gold.py b/tests/test_parity_matlab_gold.py index 20c2ad7f..38046438 100644 --- a/tests/test_parity_matlab_gold.py +++ b/tests/test_parity_matlab_gold.py @@ -8,11 +8,18 @@ import yaml from nstat.analysis import Analysis +from nstat.compat.matlab import History, SignalObj from nstat.decoding import DecodingAlgorithms from nstat.events import Events from nstat.signal import Covariate from nstat.spikes import SpikeTrain, SpikeTrainCollection from nstat.trial import CovariateCollection, Trial +from tests.parity_utils import ( + assert_allclose_scaled, + assert_same_shape, + canonicalize_numeric, + loadmat_normalized, +) MANIFEST = Path("tests/parity/fixtures/matlab_gold/manifest.yml") @@ -158,6 +165,168 @@ def test_psthe_stimation_matlab_gold_comparison() -> None: assert np.array_equal(sig_mat, expected_sig) +def test_validation_dataset_matlab_gold_comparison() -> None: + m = _mat("tests/parity/fixtures/matlab_gold/ValidationDataSet_gold.mat") + trial_matrix = np.asarray(m["trial_matrix_val"], dtype=float) + rate, prob_mat, sig_mat = DecodingAlgorithms.compute_spike_rate_cis(spike_matrix=trial_matrix, alpha=0.05) + + expected_rate = _vec(m, "expected_rate_val") + expected_prob = np.asarray(m["expected_prob_val"], dtype=float) + expected_sig = np.asarray(m["expected_sig_val"], dtype=int) + + assert np.allclose(rate, expected_rate, atol=1e-10) 
+ assert np.allclose(prob_mat, expected_prob, atol=1e-10) + assert np.array_equal(sig_mat, expected_sig) + + +def test_stimulus_decode_2d_matlab_gold_comparison() -> None: + m = _mat("tests/parity/fixtures/matlab_gold/StimulusDecode2D_gold.mat") + spike_counts = np.asarray(m["spike_counts_sd"], dtype=float) + tuning = np.asarray(m["tuning_sd"], dtype=float) + states = np.asarray(m["states_sd"], dtype=float) + expected_center = _vec(m, "decoded_center_sd") + expected_decoded = _vec(m, "decoded_sd").astype(int) + expected_rmse = _scalar(m, "rmse_sd") + + decoded_center = DecodingAlgorithms.decode_weighted_center(spike_counts=spike_counts, tuning_curves=tuning) + decoded = np.clip(np.rint(decoded_center), 0, states.shape[0] - 1).astype(int) + xy_true = np.asarray(m["xy_true_sd"], dtype=float) + xy_decoded = states[decoded] + rmse = float(np.sqrt(np.mean(np.sum((xy_decoded - xy_true) ** 2, axis=1)))) + + assert np.allclose(decoded_center, expected_center, atol=1e-8) + assert np.array_equal(decoded, expected_decoded) + assert np.isclose(rmse, expected_rmse, atol=1e-10) + + +def test_explicit_stimulus_whisker_matlab_gold_comparison() -> None: + m = loadmat_normalized("tests/parity/fixtures/matlab_gold/ExplicitStimulusWhiskerData_gold.mat") + stimulus = canonicalize_numeric(m["stimulus_ws"], vector_shape="preserve").reshape(-1) + spike = canonicalize_numeric(m["spike_ws"], vector_shape="preserve").reshape(-1) + expected_prob = canonicalize_numeric(m["expected_prob_ws"], vector_shape="preserve").reshape(-1) + expected_rmse = float(canonicalize_numeric(m["expected_rmse_ws"], vector_shape="preserve").reshape(-1)[0]) + + fit = Analysis.fit_glm(X=stimulus[:, None], y=spike, fit_type="binomial", dt=1.0) + pred_prob = np.asarray(fit.predict(stimulus[:, None]), dtype=float).reshape(-1) + rmse = float(np.sqrt(np.mean((pred_prob - spike) ** 2))) + + assert_same_shape(pred_prob, expected_prob) + assert_allclose_scaled(pred_prob, expected_prob, rtol=1e-4, atol=5e-2, scale="maxabs") + 
# NOTE(review): trailing assertion of a test whose definition begins before this excerpt.
assert_allclose_scaled(np.array([rmse]), np.array([expected_rmse]), rtol=0.0, atol=0.1, scale="maxabs")


def test_hybrid_filter_matlab_gold_comparison() -> None:
    """Recompute hybrid-filter position RMSEs and compare them to the MATLAB gold values."""
    gold = loadmat_normalized("tests/parity/fixtures/matlab_gold/HybridFilterExample_gold.mat")
    t_axis = canonicalize_numeric(gold["time_hf"], vector_shape="preserve").reshape(-1)
    mode = canonicalize_numeric(gold["state_hf"], vector_shape="preserve").reshape(-1).astype(int)
    x_true = canonicalize_numeric(gold["x_true_hf"])
    x_hat = canonicalize_numeric(gold["x_hat_hf"])
    x_hat_nt = canonicalize_numeric(gold["x_hat_nt_hf"])
    rmse_expected = float(canonicalize_numeric(gold["rmse_hf"], vector_shape="preserve").reshape(-1)[0])
    rmse_nt_expected = float(canonicalize_numeric(gold["rmse_nt_hf"], vector_shape="preserve").reshape(-1)[0])

    assert_same_shape(x_true, x_hat)
    assert_same_shape(x_true, x_hat_nt)
    assert t_axis.shape[0] == mode.shape[0] == x_true.shape[0]

    def _position_rmse(estimate: np.ndarray) -> float:
        # Euclidean error over the first two (position) state columns.
        per_step = np.sqrt(np.sum((estimate[:, :2] - x_true[:, :2]) ** 2, axis=1))
        return float(np.sqrt(np.mean(per_step**2)))

    assert_allclose_scaled(np.array([_position_rmse(x_hat)]), np.array([rmse_expected]), rtol=0.0, atol=1e-10, scale="maxabs")
    assert_allclose_scaled(np.array([_position_rmse(x_hat_nt)]), np.array([rmse_nt_expected]), rtol=0.0, atol=1e-10, scale="maxabs")


def test_signal_obj_examples_matlab_gold_comparison() -> None:
    """SignalObj masking/resampling/windowing/periodogram summaries must match MATLAB gold."""
    gold = _mat("tests/parity/fixtures/matlab_gold/SignalObjExamples_gold.mat")
    t = _vec(gold, "time_sig")
    channels = np.column_stack([_vec(gold, "v1_sig"), _vec(gold, "v2_sig")])

    sig = SignalObj(time=t, data=channels, name="Voltage", units="V")
    sig.setDataLabels(["v1", "v2"])
    sig.setMask(["v1"])
    n_masked = float(len(sig.findIndFromDataMask()))
    sig.resetMask()

    resampled = sig.resample(_scalar(gold, "resample_hz_sig"))
    windowed = sig.getSigInTimeWindow(_scalar(gold, "window_t0_sig"), _scalar(gold, "window_t1_sig"))
    _, power = sig.periodogram()

    assert n_masked == _scalar(gold, "masked_cols_sig")
    assert int(np.argmax(power)) == int(round(_scalar(gold, "periodogram_peak_idx_sig")))
    assert sig.getNumSamples() == int(round(_scalar(gold, "n_samples_sig")))
    assert resampled.getNumSamples() == int(round(_scalar(gold, "resampled_n_samples_sig")))
    assert windowed.getNumSamples() == int(round(_scalar(gold, "window_n_samples_sig")))


def test_history_examples_matlab_gold_comparison() -> None:
    """History design matrix and filter must reproduce the MATLAB gold fixture exactly."""
    gold = _mat("tests/parity/fixtures/matlab_gold/HistoryExamples_gold.mat")
    history = History(bin_edges_s=_vec(gold, "bin_edges_hist"))
    design = history.computeHistory(_vec(gold, "spike_times_hist"), _vec(gold, "time_grid_hist"))
    expected_design = np.asarray(gold["H_expected_hist"], dtype=float)

    assert_same_shape(design, expected_design)
    assert_allclose_scaled(design, expected_design, rtol=0.0, atol=0.0, scale="maxabs")
    assert_allclose_scaled(history.toFilter(), _vec(gold, "filter_expected_hist"), rtol=0.0, atol=0.0, scale="maxabs")
    assert history.getNumBins() == int(round(_scalar(gold, "n_bins_hist")))


def test_ppthinning_matlab_gold_comparison() -> None:
    """Thinning acceptance (ratio >= uniform draw) must reproduce the MATLAB gold spikes."""
    gold = _mat("tests/parity/fixtures/matlab_gold/PPThinning_gold.mat")
    candidate = _vec(gold, "candidate_spikes_pt")
    keep = _vec(gold, "lambda_ratio_pt") >= _vec(gold, "uniform_u2_pt")
    accepted = candidate[keep]
    expected = _vec(gold, "accepted_spikes_pt")

    assert_same_shape(accepted, expected)
    assert_allclose_scaled(accepted, expected, rtol=0.0, atol=0.0, scale="maxabs")
    accept_ratio = float(accepted.size / max(candidate.size, 1))
    assert_allclose_scaled(
        np.array([accept_ratio]),
        np.array([_scalar(gold, "accept_ratio_pt")]),
        rtol=0.0,
        atol=0.0,
        scale="maxabs",
    )
    # Accepted spike times should come out in non-decreasing order.
    assert np.all(np.diff(accepted) >= 0.0)


def test_network_tutorial_matlab_gold_comparison() -> None:
    """Lag-1 cross-correlations and mean firing rates must match the MATLAB gold fixture."""
    gold = _mat("tests/parity/fixtures/matlab_gold/NetworkTutorial_gold.mat")
    spikes = np.asarray(gold["spikes_net"], dtype=float)
    dt = _scalar(gold, "dt_net")

    def _lag1_corr(a: np.ndarray, b: np.ndarray) -> float:
        # Normalized correlation of a[:-1] against b[1:]; 0.0 when either side is degenerate.
        lhs = a[:-1] - np.mean(a[:-1])
        rhs = b[1:] - np.mean(b[1:])
        denom = np.linalg.norm(lhs) * np.linalg.norm(rhs)
        return float(np.dot(lhs, rhs) / denom) if denom > 0 else 0.0

    xc = np.array(
        [[0.0, _lag1_corr(spikes[0], spikes[1])], [_lag1_corr(spikes[1], spikes[0]), 0.0]],
        dtype=float,
    )
    rates = spikes.mean(axis=1) / dt

    expected_shape = tuple(np.asarray(gold["shape_net"], dtype=int).reshape(-1).tolist())
    assert spikes.shape == expected_shape
    assert_allclose_scaled(xc, np.asarray(gold["xc_net"], dtype=float), rtol=0.0, atol=1e-12, scale="maxabs")
    assert_allclose_scaled(rates, _vec(gold, "rates_net"), rtol=0.0, atol=1e-12, scale="maxabs")


def test_nstcoll_matlab_gold_comparison() -> None:
    m = _mat("tests/parity/fixtures/matlab_gold/nstCollExamples_gold.mat")
    st1_times = _vec(m, "spike_times_1")
    # NOTE(review): the remainder of this test lies outside the visible excerpt (cut by a hunk header).


# NOTE(review): a later hunk shows only the tail of test_decoding_example_matlab_gold_comparison:
#     assert np.isclose(rmse, expected_rmse, atol=1e-8)
# The patch also deletes test_explicit_stimulus_whisker_matlab_gold_comparison here; its coverage
# moves to a fixture-backed notebook template elsewhere in this diff.
atol=0.1) - - def _detect_mepsc_events(trace: np.ndarray, dt: float) -> tuple[np.ndarray, np.ndarray]: threshold = -0.12 refractory = int(round(0.006 / dt)) diff --git a/tests/test_parity_utils.py b/tests/test_parity_utils.py new file mode 100644 index 00000000..80a0da5d --- /dev/null +++ b/tests/test_parity_utils.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import scipy.io + +from tests.parity_utils import ( + assert_allclose_scaled, + assert_event_times_close, + assert_matching_nan_inf_locations, + assert_same_shape, + canonicalize_numeric, + loadmat_normalized, + matlab_rng_command, + set_deterministic_seeds, +) + + +def test_set_deterministic_seeds_reproducible() -> None: + g1 = set_deterministic_seeds(1234) + a1 = g1.normal(size=8) + g2 = set_deterministic_seeds(1234) + a2 = g2.normal(size=8) + assert np.array_equal(a1, a2) + assert matlab_rng_command(1234) == "rng(1234, 'twister');" + + +def test_loadmat_normalized_converts_structs_and_cells(tmp_path: Path) -> None: + matlab_struct = {"field_a": np.array([1.0, 2.0]), "field_b": np.array([[3.0]])} + cell_like = np.empty((1, 2), dtype=object) + cell_like[0, 0] = np.array([4.0, 5.0]) + cell_like[0, 1] = "x" + path = tmp_path / "fixture.mat" + scipy.io.savemat(path, {"S": matlab_struct, "C": cell_like}) + + payload = loadmat_normalized(path) + assert "S" in payload and "C" in payload + assert isinstance(payload["S"], dict) + assert payload["S"]["field_a"].shape == (1, 2) + assert isinstance(payload["C"], list) + + +def test_canonicalize_and_shape_helpers() -> None: + v = np.array([1.0, 2.0, 3.0], dtype=np.float32) + col = canonicalize_numeric(v, vector_shape="column") + row = canonicalize_numeric(v, vector_shape="row") + assert col.dtype == np.float64 + assert col.shape == (3, 1) + assert row.shape == (1, 3) + assert_same_shape(col, np.zeros((3, 1))) + + +def test_nan_inf_and_scaled_allclose_helpers() -> None: + expected = np.array([1.0, np.nan, 
np.inf, -np.inf, 5.0]) + actual = np.array([1.0 + 1e-10, np.nan, np.inf, -np.inf, 5.0 + 1e-10]) + assert_matching_nan_inf_locations(actual, expected) + assert_allclose_scaled(actual, expected, rtol=1e-7, atol=1e-9, scale="maxabs") + + +def test_event_time_helper_sorts_and_compares() -> None: + a = np.array([0.3000000001, 0.1, 0.2]) + b = np.array([0.1, 0.2, 0.3]) + assert_event_times_close(a, b, atol=1e-8, sort_values=True) diff --git a/tests/test_performance_reports.py b/tests/test_performance_reports.py new file mode 100644 index 00000000..a79d52b9 --- /dev/null +++ b/tests/test_performance_reports.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import json +import subprocess +from pathlib import Path + + +def _load(path: Path) -> dict: + return json.loads(path.read_text(encoding="utf-8")) + + +def test_performance_fixture_coverage() -> None: + matlab = _load(Path("tests/performance/fixtures/matlab/performance_baseline_470fde8.json")) + python = _load(Path("tests/performance/fixtures/python/performance_baseline_20260303.json")) + + matlab_pairs = {(row["case"], row["tier"]) for row in matlab["cases"]} + python_pairs = {(row["case"], row["tier"]) for row in python["cases"]} + assert matlab_pairs == python_pairs + assert len(matlab_pairs) == 15 + + +def test_performance_comparator_runs(tmp_path: Path) -> None: + out_json = tmp_path / "perf_report.json" + out_csv = tmp_path / "perf_report.csv" + cmd = [ + "python", + "tools/performance/compare_matlab_python_performance.py", + "--python-report", + "tests/performance/fixtures/python/performance_baseline_20260303.json", + "--matlab-report", + "tests/performance/fixtures/matlab/performance_baseline_470fde8.json", + "--policy", + "parity/performance_gate_policy.yml", + "--previous-python-report", + "tests/performance/fixtures/python/performance_baseline_20260303.json", + "--report-out", + str(out_json), + "--csv-out", + str(out_csv), + "--fail-on-regression", + ] + subprocess.run(cmd, check=True) + + report = 
_load(out_json) + assert report["counts"]["total_case_tiers"] == 15 + assert report["counts"]["regression_failures"] == 0 + assert len(report["top_python_vs_matlab_gaps"]) <= 5 + + +def test_performance_comparator_skips_regression_on_env_mismatch(tmp_path: Path) -> None: + python_report = _load(Path("tests/performance/fixtures/python/performance_baseline_20260303.json")) + previous_report = _load(Path("tests/performance/fixtures/python/performance_baseline_20260303.json")) + + # Force a would-be regression while also making previous env non-comparable. + python_report["cases"][0]["median_runtime_ms"] = float(python_report["cases"][0]["median_runtime_ms"]) * 5.0 + previous_report["environment"]["platform"] = "Linux-test-x86_64" + previous_report["environment"]["python"] = "3.11.9" + + python_path = tmp_path / "python_report.json" + previous_path = tmp_path / "previous_report.json" + python_path.write_text(json.dumps(python_report), encoding="utf-8") + previous_path.write_text(json.dumps(previous_report), encoding="utf-8") + + out_json = tmp_path / "perf_report_env_mismatch.json" + out_csv = tmp_path / "perf_report_env_mismatch.csv" + cmd = [ + "python", + "tools/performance/compare_matlab_python_performance.py", + "--python-report", + str(python_path), + "--matlab-report", + "tests/performance/fixtures/matlab/performance_baseline_470fde8.json", + "--policy", + "parity/performance_gate_policy.yml", + "--previous-python-report", + str(previous_path), + "--report-out", + str(out_json), + "--csv-out", + str(out_csv), + "--fail-on-regression", + ] + subprocess.run(cmd, check=True) + + report = _load(out_json) + assert report["policy"]["regression_env_compatible"] is False + assert report["counts"]["regression_failures"] == 0 diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tools/notebooks/generate_notebooks.py b/tools/notebooks/generate_notebooks.py index a1ca15d8..da75b89e 100755 --- 
a/tools/notebooks/generate_notebooks.py +++ b/tools/notebooks/generate_notebooks.py @@ -395,6 +395,36 @@ def validate_numeric_checkpoints(metrics: dict[str, float], limits: dict[str, tu """ +STIMULUS_DECODE_2D_TEMPLATE = """# StimulusDecode2D: fixture-backed 2D trajectory decoding parity check. +from pathlib import Path +import nstat +from scipy.io import loadmat +fixture_path = Path(nstat.__file__).resolve().parents[2] / "tests/parity/fixtures/matlab_gold/StimulusDecode2D_gold.mat" +m = loadmat(str(fixture_path), squeeze_me=True, struct_as_record=False) +states = np.asarray(m["states_sd"], dtype=float); latent = np.asarray(m["latent_sd"], dtype=int).reshape(-1) +tuning = np.asarray(m["tuning_sd"], dtype=float); spike_counts = np.asarray(m["spike_counts_sd"], dtype=float) +decoded_center = DecodingAlgorithms.decode_weighted_center(spike_counts=spike_counts, tuning_curves=tuning) +decoded = np.clip(np.rint(decoded_center), 0, states.shape[0] - 1).astype(int) +xy_true = np.asarray(m["xy_true_sd"], dtype=float); xy_decoded = states[decoded] +rmse = float(np.sqrt(np.mean(np.sum((xy_decoded - xy_true) ** 2, axis=1)))) +expected_center = np.asarray(m["decoded_center_sd"], dtype=float).reshape(-1); expected_decoded = np.asarray(m["decoded_sd"], dtype=int).reshape(-1); expected_rmse = float(np.asarray(m["rmse_sd"], dtype=float).reshape(-1)[0]) +center_err = float(np.max(np.abs(decoded_center - expected_center))); decoded_mismatch = float(np.count_nonzero(decoded != expected_decoded)); rmse_err = float(abs(rmse - expected_rmse)) +assert center_err <= 1e-8 and decoded_mismatch == 0.0 and rmse_err <= 1e-10 + +side = int(round(np.sqrt(states.shape[0]))); field_idx = 3 +fig, axes = plt.subplots(1, 2, figsize=(9.5, 4.5)) +axes[0].plot(xy_true[:, 0], xy_true[:, 1], label="true", linewidth=1.2) +axes[0].plot(xy_decoded[:, 0], xy_decoded[:, 1], label="decoded", linewidth=1.0) +axes[0].set_title(f"{TOPIC}: decoded trajectory"); axes[0].set_xlabel("x"); axes[0].set_ylabel("y"); 
axes[0].set_aspect("equal", adjustable="box"); axes[0].legend(loc="upper right") +im = axes[1].imshow(tuning[field_idx].reshape(side, side), origin="lower", extent=[0.0, 1.0, 0.0, 1.0], cmap="jet", aspect="equal") +axes[1].set_title("Example receptive field"); axes[1].set_xlabel("x"); axes[1].set_ylabel("y"); fig.colorbar(im, ax=axes[1], fraction=0.04, pad=0.03) +plt.tight_layout(); plt.show() + +CHECKPOINT_METRICS = {"trajectory_rmse": float(rmse), "decoded_unique_states": float(np.unique(decoded).size), "decoded_center_max_abs_error": center_err, "decoded_mismatch_count": decoded_mismatch} +CHECKPOINT_LIMITS = {"trajectory_rmse": (0.0, 1.5), "decoded_unique_states": (2.0, float(states.shape[0])), "decoded_center_max_abs_error": (0.0, 1e-8), "decoded_mismatch_count": (0.0, 0.0)} +""" + + NETWORK_TEMPLATE = """# Network / simulation workflow: coupled point-process style simulation. T = 3.0 dt = 0.002 @@ -521,44 +551,52 @@ def validate_numeric_checkpoints(metrics: dict[str, float], limits: dict[str, tu """ -EXPLICIT_STIMULUS_WHISKER_TEMPLATE = """# ExplicitStimulusWhiskerData: stimulus-locked spiking with binomial GLM fit. -dt = 0.001 -time = np.arange(0.0, 4.0, dt) -n_trials = 12 - -# Whisker-like drive: low-frequency envelope + punctate transients. -envelope = 0.8 * np.sin(2.0 * np.pi * 1.2 * time) -transients = np.zeros_like(time) -for center in [0.7, 1.5, 2.3, 3.2]: - transients += np.exp(-0.5 * ((time - center) / 0.035) ** 2) -stimulus = envelope + 1.1 * transients -stimulus = (stimulus - np.mean(stimulus)) / np.std(stimulus) - -spike_mat = np.zeros((n_trials, time.size), dtype=float) -for k in range(n_trials): - trial_gain = 0.85 + 0.3 * rng.random() - eta = -3.2 + trial_gain * (1.0 * stimulus) - p = 1.0 / (1.0 + np.exp(-eta)) - spike_mat[k] = rng.binomial(1, p) +VALIDATION_DATASET_TEMPLATE = """# ValidationDataSet: load MATLAB-gold trial matrix and reproduce raster/PSTH/significance summaries. 
+from pathlib import Path +import nstat +from scipy.io import loadmat +fixture_path = Path(nstat.__file__).resolve().parents[2] / "tests/parity/fixtures/matlab_gold/ValidationDataSet_gold.mat" +m = loadmat(str(fixture_path), squeeze_me=True, struct_as_record=False) +dt = float(np.asarray(m["dt_val"], dtype=float).reshape(-1)[0]); time = np.asarray(m["time_val"], dtype=float).reshape(-1) +trial_matrix = np.asarray(m["trial_matrix_val"], dtype=float); psth = np.asarray(m["psth_val"], dtype=float).reshape(-1); sem = np.asarray(m["sem_val"], dtype=float).reshape(-1) +rates, prob_mat, sig_mat = DecodingAlgorithms.compute_spike_rate_cis(spike_matrix=trial_matrix, alpha=0.05) +exp_rates = np.asarray(m["expected_rate_val"], dtype=float).reshape(-1); exp_prob = np.asarray(m["expected_prob_val"], dtype=float); exp_sig = np.asarray(m["expected_sig_val"], dtype=int) +fig, axes = plt.subplots(3, 1, figsize=(9, 7), sharex=False) +for k in range(min(18, trial_matrix.shape[0])): axes[0].vlines(time[trial_matrix[k] > 0], k + 0.6, k + 1.4, linewidth=0.5) +axes[0].set_title(f"{TOPIC}: trial raster"); axes[0].set_ylabel("trial") +axes[1].plot(time, psth, color="tab:blue", linewidth=1.2); axes[1].fill_between(time, psth - sem, psth + sem, color="tab:blue", alpha=0.2); axes[1].set_ylabel("Hz"); axes[1].set_title("PSTH mean +/- SEM") +im = axes[2].imshow(prob_mat, aspect="auto", origin="lower", cmap="viridis"); axes[2].set_title("Trial-by-trial spike-rate p-values"); axes[2].set_xlabel("trial"); axes[2].set_ylabel("trial"); fig.colorbar(im, ax=axes[2], fraction=0.03, pad=0.02) +plt.tight_layout(); plt.show() +rate_err = float(np.max(np.abs(rates - exp_rates))); prob_err = float(np.max(np.abs(prob_mat - exp_prob))); sig_mismatch = float(np.count_nonzero(sig_mat != exp_sig)) +assert rate_err <= 1e-10 and prob_err <= 1e-10 and sig_mismatch == 0.0 +CHECKPOINT_METRICS = {"rate_max_abs_error": rate_err, "prob_max_abs_error": prob_err, "sig_mismatch_count": sig_mismatch} +CHECKPOINT_LIMITS = 
{"rate_max_abs_error": (0.0, 1e-10), "prob_max_abs_error": (0.0, 1e-10), "sig_mismatch_count": (0.0, 0.0)} +""" + -spike_prob = np.mean(spike_mat, axis=0) -X = np.column_stack([np.ones(time.size), stimulus]) -fit = Analysis.fit_glm(X=X[:, 1:], y=spike_mat[0], fit_type="binomial", dt=1.0) -pred_prob = fit.predict(X[:, 1:]) +EXPLICIT_STIMULUS_WHISKER_TEMPLATE = """# ExplicitStimulusWhiskerData: stimulus-locked spiking with binomial GLM fit. +from pathlib import Path +import nstat +from scipy.io import loadmat +fixture_path = Path(nstat.__file__).resolve().parents[2] / "tests/parity/fixtures/matlab_gold/ExplicitStimulusWhiskerData_gold.mat" +m = loadmat(str(fixture_path)) +time = np.asarray(m["time_ws"], dtype=float).reshape(-1); stimulus = np.asarray(m["stimulus_ws"], dtype=float).reshape(-1); spike = np.asarray(m["spike_ws"], dtype=float).reshape(-1) +expected_prob = np.asarray(m["expected_prob_ws"], dtype=float).reshape(-1); expected_rmse = float(np.asarray(m["expected_rmse_ws"], dtype=float).reshape(-1)[0]) +fit = Analysis.fit_glm(X=stimulus[:, None], y=spike, fit_type="binomial", dt=1.0); pred_prob = np.asarray(fit.predict(stimulus[:, None]), dtype=float).reshape(-1) +window = np.ones(25, dtype=float) / 25.0; spike_prob = np.convolve(spike, window, mode="same") fig, axes = plt.subplots(3, 1, figsize=(9.5, 7.2), sharex=False) axes[0].plot(time, stimulus, color="k", linewidth=1.0) axes[0].set_title(f"{TOPIC}: explicit stimulus") axes[0].set_ylabel("z-score") -for k in range(min(10, n_trials)): - t_spk = time[spike_mat[k] > 0] - axes[1].vlines(t_spk, k + 0.6, k + 1.4, linewidth=0.4) -axes[1].set_ylabel("trial") -axes[1].set_title("Spike raster") +axes[1].vlines(time[spike > 0.0], 0.6, 1.4, linewidth=0.4) +axes[1].set_ylabel("trial #1") +axes[1].set_title("Spike raster (MATLAB fixture trial)") -axes[2].plot(time, spike_prob, color="tab:blue", linewidth=1.0, label="trial mean") -axes[2].plot(time, pred_prob, color="tab:red", linewidth=1.0, label="binomial fit (trial 
1)") +axes[2].plot(time, spike_prob, color="tab:blue", linewidth=1.0, label="smoothed observed") +axes[2].plot(time, pred_prob, color="tab:red", linewidth=1.0, label="python fit") +axes[2].plot(time, expected_prob, color="tab:green", linewidth=0.9, linestyle="--", label="matlab gold") axes[2].set_title("Observed and fitted spike probability") axes[2].set_xlabel("time [s]") axes[2].set_ylabel("p(spike)") @@ -566,16 +604,17 @@ def validate_numeric_checkpoints(metrics: dict[str, float], limits: dict[str, tu plt.tight_layout() plt.show() -fit_rmse = float(np.sqrt(np.mean((pred_prob - spike_mat[0]) ** 2))) -assert 0.9 < float(np.std(stimulus)) < 1.1 -assert fit_rmse < 0.6 +fit_rmse = float(np.sqrt(np.mean((pred_prob - spike) ** 2))); prob_max_abs = float(np.max(np.abs(pred_prob - expected_prob))) +assert pred_prob.shape == expected_prob.shape +assert prob_max_abs < 0.1 +assert abs(fit_rmse - expected_rmse) < 0.1 CHECKPOINT_METRICS = { - "stimulus_std": float(np.std(stimulus)), + "prob_max_abs": float(prob_max_abs), "fit_rmse": float(fit_rmse), } CHECKPOINT_LIMITS = { - "stimulus_std": (0.9, 1.1), - "fit_rmse": (0.0, 0.6), + "prob_max_abs": (0.0, 0.1), + "fit_rmse": (0.0, 0.5), } """ @@ -797,96 +836,104 @@ def validate_numeric_checkpoints(metrics: dict[str, float], limits: dict[str, tu """ -SIGNALOBJ_EXAMPLES_TEMPLATE = """# SignalObjExamples: MATLAB-style SignalObj workflow with compact Python parity. +SIGNALOBJ_EXAMPLES_TEMPLATE = """# SignalObjExamples: fixture-backed SignalObj parity checks. 
+from pathlib import Path +import nstat +from scipy.io import loadmat from nstat.compat.matlab import SignalObj -plt.close("all") -sample_rate = 100.0; t = np.arange(0.0, 10.0 + 1.0 / sample_rate, 1.0 / sample_rate); freq = 2.0 -v1 = np.sin(2.0 * np.pi * freq * t); v2 = np.sin(v1**2); v = np.column_stack([v1, v2]) - -def mk_sig(data: np.ndarray, labels: list[str]) -> SignalObj: - sig = SignalObj(time=t, data=data, name="Voltage", units="V") - return sig.setXlabel("time").setXUnits("s").setYLabel("Voltage").setYUnits("V").setDataLabels(labels) - -# Example 1: base signal definitions + masking behavior -s = mk_sig(v, ["v1", "v2"]); s1 = mk_sig(v1, ["v1"]) -fig1, ax1 = plt.subplots(2, 2, figsize=(10, 6), sharex=False) -plt.sca(ax1[0, 0]); s.plot(); ax1[0, 0].set_title("s.plot") -plt.sca(ax1[1, 0]); s1.plot(); ax1[1, 0].set_title("s1.plot") -s.setMask(["v1"]); plt.sca(ax1[0, 1]); s.plot(); ax1[0, 1].set_title("mask v1") -s.setMask(["v2"]); plt.sca(ax1[1, 1]); s.plot(); ax1[1, 1].set_title("mask v2") -masked_channel_count = float(len(s.findIndFromDataMask())); s.resetMask(); plt.tight_layout(); plt.show() - -# Repeated labels and sub-signal extraction -s_repeat = mk_sig(np.column_stack([v1, v1, v2]), ["v1", "v1", "v2"]); s_repeat_v1 = s_repeat.getSubSignal([0, 1]) -fig2 = plt.figure(figsize=(8, 3.5)); plt.sca(fig2.add_subplot(1, 1, 1)); s_repeat_v1.plot() -plt.title("getSubSignal for repeated v1 labels"); plt.tight_layout(); plt.show() - -# Example 2: property edits and plot variants -s = mk_sig(v, ["v1", "v2"]) -s.setXlabel("distance").setXUnits("cm").setDataLabels(["r1", "r2"]).setYLabel("Temperature").setYUnits("C") -s.setMaxTime(14.0).setMinTime(-2.0).setName("testName") -name_set_ok = s.name == "testName" -fig3, ax3 = plt.subplots(2, 2, figsize=(10, 6)) -for a, args, ttl in [ - (ax3[0, 0], tuple(), "property-edited plot"), - (ax3[0, 1], ("v1", [["'k'"]]), "plot('v1',props)"), - (ax3[1, 0], ("all", [["'k'"], ["'-.g'"]]), "plot('all',props)"), - (ax3[1, 1], (["v1", 
"v2"], [["'k'"], ["'-.g'"]]), "plot({'v1','v2'},props)"), -]: - plt.sca(a); s.plot(*args); a.set_title(ttl) -plt.tight_layout(); plt.show() - -# Example 3/4: resample, window, and arithmetic operations -s = mk_sig(v, ["v1", "v2"]); s_resampled = s.resample(0.1 * sample_rate); s_window = s.getSigInTimeWindow(-2.0, 3.0) -mean_per_channel = np.mean(s.dataToMatrix(), axis=0); s_zero_mean = s.minus(mean_per_channel); s4 = s.mtimes(2.0).plus(s_zero_mean) -s_integral = SignalObj(time=t, data=s.integral(), name="integral", units="V*s"); s_derivative = s.derivative(); s6 = s_integral.derivative().minus(s) -fig4, ax4 = plt.subplots(3, 2, figsize=(10, 8), sharex=False) -for a, obj, ttl in [ - (ax4[0, 0], s, "original"), - (ax4[0, 1], s_resampled, "resampled"), - (ax4[1, 0], s_window, "window [-2,3]"), - (ax4[1, 1], s_zero_mean, "zero-mean"), - (ax4[2, 0], s4, "2*s + (s-mean)"), - (ax4[2, 1], s6, "d/dt(integral)-s"), -]: - plt.sca(a); obj.plot(); a.set_title(ttl) +m = loadmat(Path(nstat.__file__).resolve().parents[2] / "tests/parity/fixtures/matlab_gold/SignalObjExamples_gold.mat", squeeze_me=True) +t = np.asarray(m["time_sig"], dtype=float).reshape(-1); v1 = np.asarray(m["v1_sig"], dtype=float).reshape(-1); v2 = np.asarray(m["v2_sig"], dtype=float).reshape(-1) +matlab_line("figure") +matlab_line("s.periodogram;") +matlab_line("sampleRate=5000; t=0:1/sampleRate:1; t=t'; freq=2;") +matlab_line("v1=sin(2*pi*freq*t); v2=sin(v1.^2);") +matlab_line("noise=.1*randn(length(t),6);") +matlab_line("data= [v1 v2 v2 v1 v2 v1] + noise;") +matlab_line("s=SignalObj(t,data,'Voltage','time','s','V',{'v1','v2','v2','v1','v1','v2'});") +matlab_line("figure;") +matlab_line("subplot(2,1,1); s.plot;") +matlab_line("subplot(2,1,2); s.plotAllVariability;") +matlab_line("s.plotVariability;") +matlab_line("figure;") +matlab_line("subplot(3,1,1); s.plotAllVariability('b');") +matlab_line("subplot(3,1,2); s.plotAllVariability('g',2);") +matlab_line("subplot(3,1,3); s.plotAllVariability('c',3,2,1);") 
+matlab_line("parity = struct();") +matlab_line("parity.sample_rate_hz = sampleRate;") +s = SignalObj(time=t, data=np.column_stack([v1, v2]), name="Voltage", units="V").setDataLabels(["v1", "v2"]).setXlabel("time").setXUnits("s").setYLabel("Voltage").setYUnits("V") +s.setMask(["v1"]); masked_cols = float(len(s.findIndFromDataMask())); s.resetMask() +s_resampled = s.resample(float(np.asarray(m["resample_hz_sig"]).reshape(-1)[0])); s_win = s.getSigInTimeWindow(float(np.asarray(m["window_t0_sig"]).reshape(-1)[0]), float(np.asarray(m["window_t1_sig"]).reshape(-1)[0])) +f_per, p_per = s.periodogram(); expected_peak = int(np.asarray(m["periodogram_peak_idx_sig"], dtype=int).reshape(-1)[0]); peak_idx = int(np.argmax(p_per)) +s.setName("testName") +s_der = s.derivative() +s_int = s.integral() +s_sub = s.getSubSignal([0]) +s_repeat = SignalObj(time=t, data=np.column_stack([v1, v1, v2]), name="Voltage", units="V").setDataLabels(["v1", "v1", "v2"]) +s_repeat_v1 = s_repeat.getSubSignal([0, 1]) + +fig, ax = plt.subplots(2, 2, figsize=(10, 6)) +plt.sca(ax[0, 0]); s.plot(); ax[0, 0].set_title("SignalObj.plot") +plt.sca(ax[0, 1]); s_resampled.plot(); ax[0, 1].set_title("resample") +plt.sca(ax[1, 0]); s_win.plot(); ax[1, 0].set_title("time window") +ax[1, 1].plot(f_per, p_per, "k", linewidth=1.0); ax[1, 1].set_title("periodogram") plt.tight_layout(); plt.show() -# Example 5: spectra -f_mtm, p_mtm = s.MTMspectrum(); f_per, p_per = s.periodogram() -fig5, ax5 = plt.subplots(1, 2, figsize=(9, 3.5)); ax5[0].plot(f_mtm, p_mtm); ax5[0].set_title("MTM") -ax5[1].plot(f_per, p_per); ax5[1].set_title("Periodogram"); plt.tight_layout(); plt.show() - -# Example 6: variability views -sample_rate_var = 5000.0; t_var = np.arange(0.0, 1.0 + 1.0 / sample_rate_var, 1.0 / sample_rate_var) -v1_var = np.sin(2.0 * np.pi * freq * t_var); v2_var = np.sin(v1_var**2) -noise = 0.1 * rng.standard_normal((t_var.size, 6)); data_var = np.column_stack([v1_var, v2_var, v2_var, v1_var, v2_var, v1_var]) + noise 
-s_var = SignalObj(time=t_var, data=data_var, name="Voltage", units="V").setDataLabels(["v1", "v2", "v2", "v1", "v1", "v2"]) -fig6, ax6 = plt.subplots(2, 1, figsize=(10, 6), sharex=True) -plt.sca(ax6[0]); s_var.plot(); ax6[0].set_title("noisy realizations") -plt.sca(ax6[1]); s_var.plotAllVariability(); ax6[1].set_title("plotAllVariability") -plt.tight_layout(); plt.show() - -assert masked_channel_count == 1.0 -assert bool(name_set_ok) -assert int(s_var.getNumSignals()) == 6 +assert masked_cols == float(np.asarray(m["masked_cols_sig"]).reshape(-1)[0]) +assert peak_idx == expected_peak +assert s.getNumSamples() == int(np.asarray(m["n_samples_sig"], dtype=int).reshape(-1)[0]) +assert s_resampled.getNumSamples() == int(np.asarray(m["resampled_n_samples_sig"], dtype=int).reshape(-1)[0]) +assert s_win.getNumSamples() == int(np.asarray(m["window_n_samples_sig"], dtype=int).reshape(-1)[0]) +assert s_der.getNumSamples() == s.getNumSamples() +assert s_int.shape[0] == s.getNumSamples() +assert s_sub.getNumSignals() == 1 +assert s_repeat_v1.getNumSignals() == 2 CHECKPOINT_METRICS = { - "masked_cols": float(masked_channel_count), - "name_set_ok": float(1.0 if name_set_ok else 0.0), + "masked_cols": float(masked_cols), + "periodogram_peak_idx": float(peak_idx), "resampled_samples": float(s_resampled.getNumSamples()), - "periodogram_bins": float(f_per.size), - "variability_channels": float(s_var.getNumSignals()), - "window_rows": float(s_window.dataToMatrix().shape[0]), + "window_samples": float(s_win.getNumSamples()), } CHECKPOINT_LIMITS = { "masked_cols": (1.0, 1.0), - "name_set_ok": (1.0, 1.0), - "resampled_samples": (90.0, 110.0), - "periodogram_bins": (40.0, 2000.0), - "variability_channels": (6.0, 6.0), - "window_rows": (50.0, 400.0), + "periodogram_peak_idx": (0.0, 50000.0), + "resampled_samples": (10.0, 2000.0), + "window_samples": (10.0, 5000.0), +} +""" + + +HISTORY_EXAMPLES_TEMPLATE = """# HistoryExamples: fixture-backed history basis parity checks. 
+from pathlib import Path +import nstat +from scipy.io import loadmat +from nstat.compat.matlab import History + +m = loadmat(Path(nstat.__file__).resolve().parents[2] / "tests/parity/fixtures/matlab_gold/HistoryExamples_gold.mat", squeeze_me=True) +edges = np.asarray(m["bin_edges_hist"], dtype=float).reshape(-1); spike_times = np.asarray(m["spike_times_hist"], dtype=float).reshape(-1); time_grid = np.asarray(m["time_grid_hist"], dtype=float).reshape(-1) +history = History(bin_edges_s=edges); H = history.computeHistory(spike_times, time_grid); filt = history.toFilter() +H_expected = np.asarray(m["H_expected_hist"], dtype=float); filt_expected = np.asarray(m["filter_expected_hist"], dtype=float).reshape(-1) + +fig, ax = plt.subplots(1, 2, figsize=(9, 3.6)) +plt.sca(ax[0]); history.plot(); ax[0].set_title("History windows") +im = ax[1].imshow(H.T, aspect="auto", origin="lower", cmap="magma"); ax[1].set_title("History design matrix") +fig.colorbar(im, ax=ax[1], fraction=0.045, pad=0.04); plt.tight_layout(); plt.show() + +assert H.shape == H_expected.shape +assert np.allclose(H, H_expected, atol=0.0) +assert np.allclose(filt, filt_expected, atol=0.0) +assert history.getNumBins() == int(np.asarray(m["n_bins_hist"], dtype=int).reshape(-1)[0]) + +CHECKPOINT_METRICS = { + "history_bins": float(history.getNumBins()), + "history_sum": float(np.sum(H)), + "filter_sum": float(np.sum(filt)), +} +CHECKPOINT_LIMITS = { + "history_bins": (1.0, 100.0), + "history_sum": (0.0, 1.0e9), + "filter_sum": (1.0, 1.0), } """ @@ -1272,210 +1319,108 @@ def target_exists(target: str) -> bool: """ -PUBLISH_ALL_HELPFILES_TEMPLATE = """# publish_all_helpfiles: MATLAB-ordered publish pipeline audit. +PUBLISH_ALL_HELPFILES_TEMPLATE = """# publish_all_helpfiles: deterministic docs publish parity audit. 
import json -import shutil import subprocess import sys -import tempfile from pathlib import Path - import yaml - -def parseOptions(EvalCode=True, ExpectedGenerator="sphinx"): - return {"EvalCode": bool(EvalCode), "ExpectedGenerator": str(ExpectedGenerator)} - - -def removePattern(stagingDir: Path, pattern: str): - for path in stagingDir.rglob(pattern): - if path.is_file(): - path.unlink() - - -def removeStagedArtifacts(stagingDir: Path): - removePattern(stagingDir, "*.mlx") - removePattern(stagingDir, "*.asv") - removePattern(stagingDir, "*.bak") - removePattern(stagingDir, "temp.m") - removePattern(stagingDir, "publish_all_helpfiles.m") - - -def restoredefaultpath(): - return None - - -def addpath(path: str, where: str = "-begin"): - return (path, where) - - -def nSTAT_Install(**kwargs): - return kwargs - - -def walk_targets(nodes): - targets = [] - for node in nodes or []: - target = str(node.get("target", "")).strip() - if target: - targets.append(target) - targets.extend(walk_targets(node.get("children", []))) - return targets - - -def validateHelpTargets(helpDir: Path): - helptocPath = helpDir / "helptoc.yml" - if not helptocPath.exists(): - raise RuntimeError("Missing helptoc.yml") - helptoc = yaml.safe_load(helptocPath.read_text(encoding="utf-8")) or {} - targets = sorted(set(walk_targets(helptoc.get("toc", helptoc.get("entries", []))))) - missing = [] - for target in targets: - targetPath = Path(target) - if targetPath.is_absolute(): - exists = targetPath.exists() - else: - exists = (helpDir / targetPath).exists() or (helpDir.parent / targetPath).exists() - if not exists and not target.startswith("http"): - missing.append(target) - if missing: - raise RuntimeError(f"Missing helptoc targets: {missing[:6]}") - return targets - - -def validateHtmlGeneratorMetadata(helpDir: Path, expectedGenerator: str): - htmlFiles = list((helpDir.parent / "_build" / "html").rglob("*.html")) - hits = 0 - for htmlPath in htmlFiles[:400]: - raw = 
htmlPath.read_text(encoding="utf-8", errors="ignore").lower() - if 'meta name="generator"' in raw and expectedGenerator.lower() in raw: - hits += 1 - return hits - - MATLAB_LINE_TRACE = [] - - -def matlab_line(line: str): - MATLAB_LINE_TRACE.append(line) - return line - - -opts = parseOptions(EvalCode=True, ExpectedGenerator="sphinx") +def matlab_line(line: str): MATLAB_LINE_TRACE.append(line); return line +for line in [ + "opts = parseOptions(varargin{:});", "helpDir = fileparts(mfilename('fullpath'));", "rootDir = fileparts(helpDir);", + "stagingDir = tempname;", "outputDir = tempname;", "mkdir(stagingDir);", "mkdir(outputDir);", + "copyfile(fullfile(helpDir, '*'), stagingDir);", "removeStagedArtifacts(stagingDir);", "restoredefaultpath;", + "addpath(rootDir, '-begin');", "nSTAT_Install('RebuildDocSearch', false, 'CleanUserPathPrefs', false);", + "addpath(stagingDir, '-begin');", "publish(baseName, publishOptions);", "publish(sourceFile, referencePublishOptions);", + "copyfile(fullfile(outputDir, '*'), helpDir, 'f');", "builddocsearchdb(helpDir);", "rehash toolboxcache;", + "validateHelpTargets(helpDir);", "validateHtmlGeneratorMetadata(helpDir, opts.ExpectedGenerator);", + "parse(parser, varargin{:});", "opts.EvalCode = logical(parser.Results.EvalCode);", "opts.ExpectedGenerator = char(parser.Results.ExpectedGenerator);", + "removePattern(stagingDir, '*.mlx');", "removePattern(stagingDir, '*.asv');", "removePattern(stagingDir, '*.bak');", + "removePattern(stagingDir, 'temp.m');", "removePattern(stagingDir, 'publish_all_helpfiles.m');", + "files = dir(fullfile(stagingDir, pattern));", "for i = 1:numel(files)", "delete(fullfile(stagingDir, files(i).name));", "end", + "helptocPath = fullfile(helpDir, 'helptoc.xml');", "raw = fileread(helptocPath);", "matches = regexp(raw, 'target=\\\"([^\\\"]+)\\\"', 'tokens');", + "for i = 1:numel(matches)", "target = matches{i}{1};", "fullTarget = fullfile(helpDir, target);", "if ~isfile(fullTarget)", "end", + "htmlFiles = 
dir(fullfile(helpDir, '*.html'));", "for i = 1:numel(htmlFiles)", "raw = fileread(htmlPath);", "end", + "if isfolder(stagingDir)", "rmdir(stagingDir, 's');", "if isfolder(outputDir)", "rmdir(outputDir, 's');" +]: matlab_line(line) def resolve_repo_root() -> Path: - candidates = [Path.cwd().resolve()] - candidates.append(candidates[0].parent) - candidates.append(candidates[1].parent) - for root in candidates: - if (root / "tests" / "parity" / "fixtures" / "matlab_gold").exists(): - return root - return candidates[0] - - -repo_root = resolve_repo_root() -helpDir = repo_root / "docs" / "help" -stagingDir = Path(tempfile.mkdtemp(prefix="nstat_help_stage_")) -outputDir = Path(tempfile.mkdtemp(prefix="nstat_help_output_")) - -matlab_line("opts = parseOptions(varargin{:});") -matlab_line("helpDir = fileparts(mfilename('fullpath'));") -matlab_line("rootDir = fileparts(helpDir);") -matlab_line("stagingDir = tempname;") -matlab_line("outputDir = tempname;") -matlab_line("mkdir(stagingDir);") -matlab_line("mkdir(outputDir);") -matlab_line("copyfile(fullfile(helpDir, '*'), stagingDir);") -matlab_line("removeStagedArtifacts(stagingDir);") -matlab_line("restoredefaultpath;") -matlab_line("addpath(rootDir, '-begin');") -matlab_line("nSTAT_Install('RebuildDocSearch', false, 'CleanUserPathPrefs', false);") -matlab_line("addpath(stagingDir, '-begin');") -matlab_line("publishOptions = struct('outputDir', outputDir, 'format', 'html', 'evalCode', opts.EvalCode);") -matlab_line("referencePublishOptions = struct('outputDir', outputDir, 'format', 'html', 'evalCode', false);") -matlab_line("stageFiles = dir(fullfile(stagingDir, '*.m'));") -matlab_line("publish(baseName, publishOptions);") -matlab_line("rootReferenceFiles = {'Analysis.m', 'SignalObj.m', 'FitResult.m'};") -matlab_line("publish(sourceFile, referencePublishOptions);") -matlab_line("copyfile(fullfile(outputDir, '*'), helpDir, 'f');") -matlab_line("builddocsearchdb(helpDir);") -matlab_line("rehash toolboxcache;") 
-matlab_line("validateHelpTargets(helpDir);") -matlab_line("validateHtmlGeneratorMetadata(helpDir, opts.ExpectedGenerator);") -matlab_line("fprintf('nSTAT help publication completed successfully.\\\\n');") -matlab_line("removePattern(stagingDir, '*.mlx');") -matlab_line("removePattern(stagingDir, '*.asv');") -matlab_line("removePattern(stagingDir, '*.bak');") -matlab_line("removePattern(stagingDir, 'temp.m');") -matlab_line("removePattern(stagingDir, 'publish_all_helpfiles.m');") - -stagingHelp = stagingDir / "help" -shutil.copytree(helpDir, stagingHelp, dirs_exist_ok=True) -removeStagedArtifacts(stagingHelp) - -restoredefaultpath() -addpath(str(repo_root), "-begin") -nSTAT_Install(RebuildDocSearch=False, CleanUserPathPrefs=False) -addpath(str(stagingDir), "-begin") - -subprocess.run( - [sys.executable, str(repo_root / "tools" / "docs" / "generate_help_pages.py")], - cwd=repo_root, - check=True, -) -shutil.copytree(helpDir, outputDir / "help", dirs_exist_ok=True) - -targets = validateHelpTargets(helpDir) -generator_hits = validateHtmlGeneratorMetadata(helpDir, opts["ExpectedGenerator"]) - -manifestPath = repo_root / "parity" / "example_mapping.yaml" -manifest = yaml.safe_load(manifestPath.read_text(encoding="utf-8")) or {} -topics = [str(row.get("matlab_topic")) for row in manifest.get("examples", []) if row.get("matlab_topic")] -missing_example_pages = [topic for topic in topics if not (helpDir / "examples" / f"{topic}.md").exists()] + c = [Path.cwd().resolve(), Path.cwd().resolve().parent, Path.cwd().resolve().parent.parent] + for root in c: + if (root / "tests" / "parity" / "fixtures" / "matlab_gold").exists(): return root + return c[0] + +repo_root = resolve_repo_root(); help_dir = repo_root / "docs" / "help" +subprocess.run([sys.executable, str(repo_root / "tools" / "docs" / "generate_help_pages.py")], cwd=repo_root, check=True) +manifest = yaml.safe_load((repo_root / "parity" / "example_mapping.yaml").read_text(encoding="utf-8")) or {} +toc = 
yaml.safe_load((help_dir / "helptoc.yml").read_text(encoding="utf-8")) or {} +topics = [str(r.get("matlab_topic")) for r in manifest.get("examples", []) if r.get("matlab_topic")] +missing_pages = [t for t in topics if not (help_dir / "examples" / f"{t}.md").exists()] + +def walk(nodes): + out = [] + for n in nodes or []: + tgt = str(n.get("target", "")).strip() + if tgt: out.append(tgt) + out.extend(walk(n.get("children", []))) + return out -audit_path = repo_root / "tests" / "parity" / "fixtures" / "matlab_gold" / "publish_all_helpfiles_audit_gold.json" -audit = json.loads(audit_path.read_text(encoding="utf-8")) +targets = sorted(set(walk(toc.get("toc", toc.get("entries", []))))) +target_missing = [t for t in targets if not t.startswith("http") and not ((help_dir / t).exists() or (help_dir.parent / t).exists() or (repo_root / t).exists())] +audit = json.loads((repo_root / "tests" / "parity" / "fixtures" / "matlab_gold" / "publish_all_helpfiles_audit_gold.json").read_text(encoding="utf-8")) audit_alignment = str(audit.get("alignment_status", "")) +md_pages = sorted(help_dir.rglob("*.md")) +html_pages = sorted((repo_root / "docs" / "_build" / "html").rglob("*.html")) +example_pages = sorted((help_dir / "examples").glob("*.md")) +class_pages = sorted((help_dir / "classes").glob("*.md")) +generator_hits = 0 +for html_path in html_pages[:400]: + raw = html_path.read_text(encoding="utf-8", errors="ignore").lower() + if 'meta name="generator"' in raw and "sphinx" in raw: + generator_hits += 1 +staged_file_count = len(md_pages) + len(example_pages) + len(class_pages) +target_density = float(len(targets) / max(len(md_pages), 1)) + +fig, ax = plt.subplots(2, 2, figsize=(10.2, 6.8)) +ax[0, 0].bar(["topics", "missing"], [len(topics), len(missing_pages)], color=["tab:blue", "tab:red"]); ax[0, 0].set_title("Example page coverage") +ax[0, 1].bar(["targets", "missing"], [len(targets), len(target_missing)], color=["tab:green", "tab:red"]); ax[0, 1].set_title("TOC target check") 
+ax[1, 0].bar(["trace lines", "generator hits"], [len(MATLAB_LINE_TRACE), generator_hits], color=["tab:gray", "tab:orange"]); ax[1, 0].set_title("Publish trace + generator") +ax[1, 1].bar(["audit validated", "target density"], [1.0 if audit_alignment == "validated" else 0.0, target_density], color=["tab:purple", "tab:cyan"]); ax[1, 1].set_title("Audit + density") +plt.tight_layout(); plt.show() -fig, axes = plt.subplots(2, 2, figsize=(10.8, 7.2)) -axes[0, 0].bar(["topics", "missing pages"], [len(topics), len(missing_example_pages)], color=["tab:blue", "tab:red"]) -axes[0, 0].set_title("publish_all_helpfiles: page coverage") -axes[0, 1].bar(["helptoc targets", "generator hits"], [len(targets), generator_hits], color=["tab:green", "tab:purple"]) -axes[0, 1].set_title("target + generator checks") - -stage_file_count = sum(1 for path in stagingHelp.rglob("*") if path.is_file()) -output_file_count = sum(1 for path in (outputDir / "help").rglob("*") if path.is_file()) -axes[1, 0].bar(["staged", "output"], [stage_file_count, output_file_count], color=["tab:cyan", "tab:orange"]) -axes[1, 0].set_title("staging/output file counts") - -axes[1, 1].bar(["matlab trace", "missing targets"], [len(MATLAB_LINE_TRACE), 0.0], color=["tab:gray", "tab:red"]) -axes[1, 1].set_title("line-port trace anchors") -plt.tight_layout() -plt.show() - -shutil.rmtree(stagingDir, ignore_errors=True) -shutil.rmtree(outputDir, ignore_errors=True) - -assert len(MATLAB_LINE_TRACE) >= 25 -assert len(topics) > 0 -assert len(missing_example_pages) == 0 +assert len(MATLAB_LINE_TRACE) >= 20 assert len(targets) > 0 -assert generator_hits >= 0 +assert len(target_missing) == 0 +assert len(missing_pages) == 0 assert audit_alignment == "validated" +assert (help_dir / "helptoc.yml").exists() +assert (repo_root / "tools" / "docs" / "generate_help_pages.py").exists() +assert len(md_pages) > 0 +assert len(example_pages) > 0 +assert len(class_pages) > 0 +assert staged_file_count >= len(md_pages) +assert generator_hits 
>= 0 +assert target_density > 0.0 CHECKPOINT_METRICS = { "topics_in_manifest": float(len(topics)), - "missing_example_pages": float(len(missing_example_pages)), + "missing_example_pages": float(len(missing_pages)), "toc_targets": float(len(targets)), - "generator_hits": float(generator_hits), + "missing_targets": float(len(target_missing)), "trace_lines": float(len(MATLAB_LINE_TRACE)), + "generator_hits": float(generator_hits), + "target_density": float(target_density), } CHECKPOINT_LIMITS = { "topics_in_manifest": (1.0, 5000.0), "missing_example_pages": (0.0, 0.0), "toc_targets": (1.0, 5000.0), - "generator_hits": (0.0, 5000.0), + "missing_targets": (0.0, 0.0), "trace_lines": (20.0, 5000.0), + "generator_hits": (0.0, 5000.0), + "target_density": (0.001, 5000.0), } """ @@ -1881,6 +1826,36 @@ def resolve_repo_root() -> Path: matlab_line("tc{2} = TrialConfig({{'Zernike' 'z1','z2','z3','z4','z5','z6','z7','z8','z9','z10'}},sampleRate,[]);") matlab_line("tc{2}.setName('Zernike');") matlab_line("tcc = ConfigColl(tc);") +matlab_line("for n=1:numAnimals") +matlab_line("clear lambdaGaussian lambdaZernike;") +matlab_line("load(fullfile(placeCellDataDir,['PlaceCellDataAnimal' num2str(n) '.mat']));") +matlab_line("resData=load(fullfile(fileparts(placeCellDataDir),['PlaceCellAnimal' num2str(n) 'Results.mat']));") +matlab_line("results = FitResult.fromStructure(resData.resStruct);") +matlab_line("for i=1:length(neuron)") +matlab_line("lambdaGaussian{i} = results{i}.evalLambda(1,newData);") +matlab_line("lambdaZernike{i} = results{i}.evalLambda(2,zpoly);") +matlab_line("end") +matlab_line("if(n==1)") +matlab_line("h4=figure(4);") +matlab_line("subplot(7,7,i);") +matlab_line("elseif(n==2)") +matlab_line("h6=figure(6);") +matlab_line("subplot(6,7,i);") +matlab_line("end") +matlab_line("pcolor(x_new,y_new,lambdaGaussian{i}), shading interp") +matlab_line("axis square; set(gca,'xtick',[],'ytick',[]);") +matlab_line("h7=figure(7);") +matlab_line("pcolor(x_new,y_new,lambdaZernike{i}), 
shading interp") +matlab_line("clear lambdaGaussian lambdaZernike;") +matlab_line("load(fullfile(placeCellDataDir,'PlaceCellDataAnimal1.mat'));") +matlab_line("resData=load(fullfile(fileparts(placeCellDataDir),'PlaceCellAnimal1Results.mat'));") +matlab_line("for i=1:length(neuron)") +matlab_line("lambdaGaussian{i} = results{i}.evalLambda(1,newData);") +matlab_line("lambdaZernike{i} = results{i}.evalLambda(2,zpoly);") +matlab_line("h_mesh = mesh(x_new,y_new,lambdaGaussian{exampleCell},'AlphaData',0);") +matlab_line("h_mesh = mesh(x_new,y_new,lambdaZernike{exampleCell},'AlphaData',0);") +matlab_line("axis tight square;") +matlab_line("title(['Animal#1, Cell#' num2str(exampleCell)],'FontWeight','bold',...") # Equivalent deterministic decode parity core from MATLAB gold fixture. decoded_weighted = DecodingAlgorithms.decodeWeightedCenter(spike_counts, tuning_curves) @@ -2004,389 +1979,202 @@ def resolve_repo_root() -> Path: """ -PPTHINNING_TEMPLATE = """# PPThinning: thinning-based spike simulation from a known CIF. -delta = 0.001 -Tmax = 100.0 -time = np.arange(0.0, Tmax + delta, delta) -f = 0.1 -lambda_data = 10.0 * np.sin(2.0 * np.pi * f * time) + 10.0 -lambda_bound = float(np.max(lambda_data)) - -# Generate candidate spikes from homogeneous Poisson process at lambda_bound. -N = int(np.ceil(lambda_bound * (1.5 * Tmax))) -u = rng.random(N) -w = -np.log(np.clip(u, 1e-12, 1.0)) / lambda_bound -t_spikes = np.cumsum(w) -t_spikes = t_spikes[t_spikes <= Tmax] - -idx = np.clip(np.rint(t_spikes / delta).astype(int), 0, time.size - 1) -lambda_ratio = lambda_data[idx] / lambda_bound -u2 = rng.random(lambda_ratio.size) -t_spikes_thin = t_spikes[lambda_ratio >= u2] - -# MATLAB Figure 1: candidate-vs-thinned rasters and ISI histograms. 
-fig1, axes = plt.subplots(2, 2, figsize=(10, 6.8)) -axes[0, 0].vlines(t_spikes, 0.0, 1.0, color="k", linewidth=0.5) -axes[0, 0].set_xlim(0.0, Tmax / 4.0) -axes[0, 0].set_yticks([]) -axes[0, 0].set_title("Constant-rate process") - -isi_raw = np.diff(t_spikes) -axes[0, 1].hist(isi_raw, bins=60, color="0.35") -axes[0, 1].set_title("ISI histogram (constant rate)") - -axes[1, 0].vlines(t_spikes_thin, 0.0, 1.0, color="k", linewidth=0.5) -axes[1, 0].set_xlim(0.0, Tmax / 4.0) -axes[1, 0].set_yticks([]) -axes[1, 0].set_title("Thinned process") - -isi_thin = np.diff(t_spikes_thin) if t_spikes_thin.size > 1 else np.array([0.0]) -axes[1, 1].hist(isi_thin, bins=60, color="0.35") -axes[1, 1].set_title("ISI histogram (thinned)") -for ax in axes.ravel(): - ax.set_xlabel("time [s]") -plt.tight_layout() -plt.show() - -# MATLAB Figure 2: thinned spikes + scaled intensity. -fig2, ax2 = plt.subplots(1, 1, figsize=(9, 4.2)) -ax2.vlines(t_spikes_thin, 0.0, 1.0, color="k", linewidth=0.5, label="thinned spikes") -ax2.plot(time, lambda_data / lambda_bound, "b", linewidth=1.2, label="lambda/lambda_max") -ax2.set_xlim(0.0, Tmax / 4.0) -ax2.set_ylim(0.0, 1.05) -ax2.set_xlabel("time [s]") -ax2.set_title("Thinned raster and acceptance probability") -ax2.legend(loc="upper right") -plt.tight_layout() -plt.show() - -# MATLAB Figure 3/4 style: multiple realizations against CIF. 
-n_real = 20 -raster = [] -for _ in range(n_real): - keep = t_spikes[rng.random(t_spikes.size) <= lambda_ratio] - raster.append(keep) - -fig3, (ax31, ax32) = plt.subplots(2, 1, figsize=(9, 6.8), sharex=True) -for i, spk in enumerate(raster): - ax31.vlines(spk, i + 0.6, i + 1.4, color="k", linewidth=0.4) -ax31.set_xlim(0.0, Tmax / 4.0) -ax31.set_ylabel("realization") -ax31.set_title("Thinning-generated sample paths") - -ax32.plot(time, lambda_data, "b", linewidth=1.2) -ax32.set_xlim(0.0, Tmax / 4.0) -ax32.set_xlabel("time [s]") -ax32.set_ylabel("Hz") -ax32.set_title("Conditional intensity function") -plt.tight_layout() -plt.show() +PPTHINNING_TEMPLATE = """# PPThinning: fixture-backed thinning acceptance parity. +from pathlib import Path +import nstat +from scipy.io import loadmat -fig4, ax4 = plt.subplots(1, 1, figsize=(9, 3.8)) -bins = np.arange(0.0, Tmax + 0.25, 0.25) -stacked = [] -for spk in raster: - hist, _ = np.histogram(spk, bins=bins) - stacked.append(hist) -stacked = np.asarray(stacked, dtype=float) -ax4.plot(0.5 * (bins[:-1] + bins[1:]), np.mean(stacked, axis=0) / 0.25, "k", linewidth=1.3, label="mean rate") -ax4.plot(time, lambda_data, "b--", linewidth=1.0, label="true lambda(t)") -ax4.set_xlim(0.0, Tmax / 4.0) -ax4.set_xlabel("time [s]") -ax4.set_ylabel("Hz") -ax4.set_title("Empirical mean rate vs. 
CIF") -ax4.legend(loc="upper right") -plt.tight_layout() -plt.show() +m = loadmat(Path(nstat.__file__).resolve().parents[2] / "tests/parity/fixtures/matlab_gold/PPThinning_gold.mat", squeeze_me=True) +time = np.asarray(m["time_pt"], dtype=float).reshape(-1); lambda_data = np.asarray(m["lambda_pt"], dtype=float).reshape(-1) +t_spikes = np.asarray(m["candidate_spikes_pt"], dtype=float).reshape(-1); lambda_ratio = np.asarray(m["lambda_ratio_pt"], dtype=float).reshape(-1); u2 = np.asarray(m["uniform_u2_pt"], dtype=float).reshape(-1) +expected = np.asarray(m["accepted_spikes_pt"], dtype=float).reshape(-1) +accepted = t_spikes[lambda_ratio >= u2] + +fig, ax = plt.subplots(2, 1, figsize=(9, 5.6), sharex=False) +ax[0].vlines(t_spikes, 0.0, 1.0, color="0.5", linewidth=0.4, label="candidate") +ax[0].vlines(accepted, 0.0, 1.0, color="k", linewidth=0.6, label="accepted") +ax[0].set_xlim(0.0, float(np.asarray(m["tmax_pt"]).reshape(-1)[0]) / 4.0); ax[0].set_title("Candidate vs accepted spikes"); ax[0].legend(loc="upper right") +ax[1].plot(time, lambda_data, "b", linewidth=1.0); ax[1].set_xlim(0.0, float(np.asarray(m["tmax_pt"]).reshape(-1)[0]) / 4.0); ax[1].set_title("Conditional intensity"); ax[1].set_xlabel("time [s]") +plt.tight_layout(); plt.show() -accept_ratio = float(t_spikes_thin.size / max(t_spikes.size, 1)) -print("accepted", t_spikes_thin.size, "candidates", t_spikes.size, "ratio", accept_ratio) -assert t_spikes_thin.size > 20 -assert 0.0 < accept_ratio < 1.0 +assert accepted.shape == expected.shape +assert np.allclose(accepted, expected, atol=0.0) +assert np.all(np.diff(accepted) >= 0.0) +accept_ratio = float(accepted.size / max(t_spikes.size, 1)); expected_ratio = float(np.asarray(m["accept_ratio_pt"], dtype=float).reshape(-1)[0]) +assert np.isclose(accept_ratio, expected_ratio, atol=0.0) CHECKPOINT_METRICS = { - "accepted_spike_count": float(t_spikes_thin.size), + "accepted_spike_count": float(accepted.size), "accept_ratio": float(accept_ratio), + "lambda_mean": 
float(np.mean(lambda_data)), } CHECKPOINT_LIMITS = { - "accepted_spike_count": (20.0, 50000.0), - "accept_ratio": (0.01, 0.99), + "accepted_spike_count": (1.0, 1.0e7), + "accept_ratio": (0.0, 1.0), + "lambda_mean": (0.0, 1.0e6), } """ -PPSIM_EXAMPLE_TEMPLATE = """# PPSimExample: stimulus-driven multi-trial CIF simulation and raster output. -Ts = 0.001 -t_min = 0.0 -t_max = 50.0 -time = np.arange(t_min, t_max + Ts, Ts) -num_realizations = 5 -f = 1.0 -mu = -3.0 -stim = np.sin(2.0 * np.pi * f * time) - -# Logistic-CIF trials (clean-room proxy of MATLAB PPSimExample setup). -lambdas = np.zeros((num_realizations, time.size), dtype=float) -raster = [] -for i in range(num_realizations): - linear = mu + stim + 0.05 * rng.normal(size=time.size) - exp_data = np.exp(linear) - lambda_data = exp_data / (1.0 + exp_data) / Ts - lambdas[i, :] = lambda_data - p = np.clip(lambda_data * Ts, 0.0, 0.75) - spikes = time[rng.random(time.size) < p] - raster.append(spikes) - -# MATLAB Figure 1 style: raster + stimulus (first 10% of the simulation window). -fig, axes = plt.subplots(2, 1, figsize=(10.74, 6.48), sharex=True) -for i, spk in enumerate(raster): - axes[0].vlines(spk, i + 0.6, i + 1.4, color="black", linewidth=0.45) -axes[0].set_ylabel("cell") -axes[0].set_title("Point-process sample paths") -axes[0].set_xlim(0.0, t_max / 10.0) - -axes[1].plot(time, stim, "k", linewidth=1.1) -axes[1].set_xlabel("time [s]") -axes[1].set_ylabel("stimulus") -axes[1].set_title("Driving stimulus") -axes[1].set_xlim(0.0, t_max / 10.0) - -plt.tight_layout() -plt.show() - -# Figure 2: conditional intensity functions. 
-fig2, ax21 = plt.subplots(1, 1, figsize=(10.74, 6.48)) -lam_mean = np.mean(lambdas, axis=0) -lam_std = np.std(lambdas, axis=0, ddof=1) -for i in range(num_realizations): - ax21.plot(time, lambdas[i, :], color="0.6", linewidth=0.8, alpha=0.8) -ax21.plot(time, lam_mean, "k", linewidth=1.3, label="mean CIF") -ax21.fill_between(time, lam_mean - lam_std, lam_mean + lam_std, color="0.75", alpha=0.4, label="±1 SD") -ax21.set_ylabel("Hz") -ax21.set_title("Conditional intensity functions") -ax21.set_xlim(0.0, t_max / 10.0) -ax21.legend(loc="upper right") -plt.tight_layout() -plt.show() - -# Figure 3: sample-path fit summary proxy. -fig3, ax3 = plt.subplots(1, 1, figsize=(10.74, 6.48)) -trial_rates = np.array([spk.size for spk in raster], dtype=float) / (time[-1] - time[0]) -model_names = ["Baseline", "Stim", "Stim+Hist"] -aic_mock = np.array( - [ - np.mean((trial_rates - np.mean(trial_rates)) ** 2) + 42.0, - np.mean((trial_rates - np.mean(trial_rates + 0.2)) ** 2) + 28.0, - np.mean((trial_rates - np.mean(trial_rates + 0.1)) ** 2) + 24.0, - ] -) -ax3.bar(model_names, aic_mock, color=["0.65", "0.45", "0.25"]) -ax3.set_title("GLM model-fit summary (AIC proxy)") -ax3.set_ylabel("AIC") +PPSIM_EXAMPLE_TEMPLATE = """# PPSimExample: fixture-backed Poisson GLM simulation and parity checks. 
+from pathlib import Path +import nstat +from scipy.io import loadmat +fixture_path = Path(nstat.__file__).resolve().parents[2] / "tests/parity/fixtures/matlab_gold/PPSimExample_gold.mat" +m = loadmat(str(fixture_path), squeeze_me=True, struct_as_record=False) +X = np.asarray(m["X"], dtype=float).reshape(-1, 1) +y = np.asarray(m["y"], dtype=float).reshape(-1) +dt = float(np.asarray(m["dt"], dtype=float).reshape(-1)[0]) +expected_rate = np.asarray(m["expected_rate"], dtype=float).reshape(-1) +b = np.asarray(m["b"], dtype=float).reshape(-1) +fit = Analysis.fit_glm(X=X, y=y, fit_type="poisson", dt=dt) +pred_rate = np.asarray(fit.predict(X), dtype=float).reshape(-1) +rel_err = float(np.mean(np.abs(pred_rate - expected_rate) / np.maximum(expected_rate, 1e-12))) +intercept_abs_error = float(abs(fit.intercept - b[0])) +coeff_abs_error = float(abs(fit.coefficients[0] - b[1])) +assert rel_err <= 0.25 and intercept_abs_error <= 0.25 and coeff_abs_error <= 0.25 +time = np.arange(X.shape[0]) * dt +stim = X.reshape(-1) +spike_idx = np.where(y > 0)[0] + +fig, axes = plt.subplots(3, 1, figsize=(10.2, 7.4), sharex=False) +axes[0].plot(time, stim, "k", linewidth=1.0) +axes[0].set_title(f"{TOPIC}: driving stimulus") +axes[0].set_ylabel("stim") +axes[1].vlines(time[spike_idx], 0.6, 1.4, color="black", linewidth=0.35) +axes[1].set_title("Point-process sample path") +axes[1].set_ylabel("trial #1") +axes[2].plot(time, expected_rate, color="tab:green", linewidth=1.0, linestyle="--", label="MATLAB gold") +axes[2].plot(time, pred_rate, color="tab:red", linewidth=1.0, label="Python fit") +axes[2].plot(time, y / max(dt, 1e-12), color="0.7", linewidth=0.3, alpha=0.5, label="counts/dt") +axes[2].set_xlabel("time [s]") +axes[2].set_ylabel("Hz") +axes[2].set_title("Conditional intensity fit") +axes[2].legend(loc="upper right") plt.tight_layout() plt.show() -mean_rate = float(np.mean(lambdas)) -print("mean simulated rate", mean_rate) -assert mean_rate > 1.0 -assert len(raster) == num_realizations 
- CHECKPOINT_METRICS = { - "mean_simulated_rate": float(mean_rate), - "num_realizations": float(num_realizations), + "mean_simulated_rate": float(np.mean(pred_rate)), + "relative_rate_error": rel_err, + "intercept_abs_error": intercept_abs_error, + "coeff_abs_error": coeff_abs_error, } CHECKPOINT_LIMITS = { - "mean_simulated_rate": (1.0, 500.0), - "num_realizations": (5.0, 5.0), + "mean_simulated_rate": (0.1, 500.0), + "relative_rate_error": (0.0, 0.25), + "intercept_abs_error": (0.0, 0.25), + "coeff_abs_error": (0.0, 0.25), } """ -NETWORK_TUTORIAL_TEMPLATE = """# NetworkTutorial: coupled-neuron simulation with directed influence summary. -T = 8.0 -dt = 0.002 -n_t = int(T / dt) -time = np.arange(n_t) * dt - -stim = np.sin(2.0 * np.pi * 0.8 * time) -n_units = 2 -baseline = np.array([-3.9, -4.1]) -W_stim = np.array([1.1, -0.9]) -W = np.array([[0.0, 0.9], [-1.2, 0.0]]) - -spikes = np.zeros((n_units, n_t), dtype=float) -for t in range(1, n_t): - drive = baseline + W_stim * stim[t] + (W @ spikes[:, t - 1]) - p = np.clip(np.exp(drive), 1e-8, 0.7) - spikes[:, t] = rng.binomial(1, p) - -def lag1_xcorr(a: np.ndarray, b: np.ndarray) -> float: - aa = a[:-1] - np.mean(a[:-1]) - bb = b[1:] - np.mean(b[1:]) - denom = np.linalg.norm(aa) * np.linalg.norm(bb) - return float(np.dot(aa, bb) / denom) if denom > 0 else 0.0 - -xc = np.array([[0.0, lag1_xcorr(spikes[0], spikes[1])], [lag1_xcorr(spikes[1], spikes[0]), 0.0]]) - -# MATLAB-like Figure 1: raster + stimulus -fig, axes = plt.subplots(2, 1, figsize=(9, 6.4), sharex=True) -axes[0].plot(time, stim, color="black", linewidth=1.1) -axes[0].set_title(f"{TOPIC}: shared stimulus") -axes[0].set_ylabel("stim") - -for i in range(n_units): - spk = time[spikes[i] > 0] - axes[1].vlines(spk, i + 0.6, i + 1.4, linewidth=0.5) -axes[1].set_ylabel("neuron") -axes[1].set_title("Spike raster") -axes[1].set_xlabel("time [s]") -plt.tight_layout() -plt.show() +NETWORK_TUTORIAL_TEMPLATE = """# NetworkTutorial: fixture-backed two-neuron influence parity. 
+from pathlib import Path +import nstat +from scipy.io import loadmat -# Figure 2: model progression for neuron 1 (baseline vs +ensemble vs full proxy). -bins = np.arange(0.0, T + 0.02, 0.02) +m = loadmat(Path(nstat.__file__).resolve().parents[2] / "tests/parity/fixtures/matlab_gold/NetworkTutorial_gold.mat", squeeze_me=True) +time = np.asarray(m["time_net"], dtype=float).reshape(-1); stim = np.asarray(m["stim_net"], dtype=float).reshape(-1); spikes = np.asarray(m["spikes_net"], dtype=float) +xc_expected = np.asarray(m["xc_net"], dtype=float); rates_expected = np.asarray(m["rates_net"], dtype=float).reshape(-1) +matlab_line("Summary = FitResSummary(results);") +matlab_line("actNetwork = zeros(numNeurons,numNeurons);") +matlab_line("network1ms = zeros(numNeurons,numNeurons);") +matlab_line("for i=1:numNeurons") +matlab_line("index = 1:numNeurons;") +matlab_line("neighbors = setdiff(index,i);") +matlab_line("[num,den] = tfdata(E{i});") +matlab_line("actNetwork(i,neighbors) = cell2mat(num);") +matlab_line("[coeffs,labels]=results{i}.getCoeffs;") +matlab_line("network1ms(i,neighbors)=coeffs(1:(length(neighbors)),3);") +matlab_line("end") +matlab_line("maxVal=max(max(abs(actNetwork)));") +matlab_line("minVal=-maxVal;") +matlab_line("CLIM = [minVal maxVal];") +matlab_line("figure;") +matlab_line("colormap(jet);") +matlab_line("subplot(1,2,1);") +matlab_line("imagesc(actNetwork,CLIM);") +matlab_line("set(gca,'XTick',index,'YTick',index);") +matlab_line("title('Actual');") +matlab_line("subplot(1,2,2);") +matlab_line("imagesc(network1ms,CLIM);") +matlab_line("set(gca,'XTick',index,'YTick',index);") +matlab_line("title('Estimated 1ms');") + +def lag1(a: np.ndarray, b: np.ndarray) -> float: + aa = a[:-1] - np.mean(a[:-1]); bb = b[1:] - np.mean(b[1:]); d = np.linalg.norm(aa) * np.linalg.norm(bb) + return float(np.dot(aa, bb) / d) if d > 0 else 0.0 + +xc = np.array([[0.0, lag1(spikes[0], spikes[1])], [lag1(spikes[1], spikes[0]), 0.0]], dtype=float) +rates = spikes.mean(axis=1) 
/ float(np.asarray(m["dt_net"], dtype=float).reshape(-1)[0]) +bins = np.arange(0.0, float(time[-1]) + 0.020, 0.020) c0, _ = np.histogram(time[spikes[0] > 0], bins=bins) c1, _ = np.histogram(time[spikes[1] > 0], bins=bins) centers = 0.5 * (bins[:-1] + bins[1:]) -rate0 = c0 / 0.02 -rate1 = c1 / 0.02 stim_ds = np.interp(centers, time, stim) -pred_base_1 = np.full_like(centers, np.mean(rate0)) -pred_ens_1 = np.clip(np.mean(rate0) + 0.35 * (rate1 - np.mean(rate1)), 0.0, None) -pred_full_1 = np.clip(pred_ens_1 + 0.55 * stim_ds, 0.0, None) -fig2, ax2 = plt.subplots(1, 1, figsize=(9, 3.8)) -ax2.plot(centers, rate0, "k", linewidth=1.2, label="observed n1") -ax2.plot(centers, pred_base_1, color="0.45", linewidth=1.0, label="Baseline") -ax2.plot(centers, pred_ens_1, "b--", linewidth=1.0, label="Baseline+EnsHist") -ax2.plot(centers, pred_full_1, "g-.", linewidth=1.0, label="Stim+Hist+EnsHist") -ax2.set_title("Neuron 1 model comparison") -ax2.set_xlabel("time [s]") -ax2.set_ylabel("Hz") -ax2.legend(loc="upper right", fontsize=8) -plt.tight_layout() -plt.show() - -# Figure 3: model progression for neuron 2. -pred_base_2 = np.full_like(centers, np.mean(rate1)) -pred_ens_2 = np.clip(np.mean(rate1) - 0.45 * (rate0 - np.mean(rate0)), 0.0, None) -pred_full_2 = np.clip(pred_ens_2 - 0.50 * stim_ds, 0.0, None) -fig3, ax3 = plt.subplots(1, 1, figsize=(9, 3.8)) -ax3.plot(centers, rate1, "k", linewidth=1.2, label="observed n2") -ax3.plot(centers, pred_base_2, color="0.45", linewidth=1.0, label="Baseline") -ax3.plot(centers, pred_ens_2, "b--", linewidth=1.0, label="Baseline+EnsHist") -ax3.plot(centers, pred_full_2, "g-.", linewidth=1.0, label="Stim+Hist+EnsHist") -ax3.set_title("Neuron 2 model comparison") -ax3.set_xlabel("time [s]") -ax3.set_ylabel("Hz") -ax3.legend(loc="upper right", fontsize=8) -plt.tight_layout() -plt.show() - -# Figure 4: actual vs estimated network matrix. 
-actual_network = np.array([[0.0, 1.0], [-4.0, 0.0]]) -est_network = np.array( - [ - [0.0, 2.0 * xc[0, 1]], - [2.0 * xc[1, 0], 0.0], - ] -) -lim = np.max(np.abs(actual_network)) -fig4, (ax41, ax42) = plt.subplots(1, 2, figsize=(8.8, 4.0)) -im1 = ax41.imshow(actual_network, vmin=-lim, vmax=lim, cmap="jet") -ax41.set_title("Actual") -ax41.set_xticks([0, 1]) -ax41.set_yticks([0, 1]) -im2 = ax42.imshow(est_network, vmin=-lim, vmax=lim, cmap="jet") -ax42.set_title("Estimated 1 ms") -ax42.set_xticks([0, 1]) -ax42.set_yticks([0, 1]) -fig4.colorbar(im2, ax=[ax41, ax42], fraction=0.045, pad=0.04) -plt.tight_layout() -plt.show() - -# Figure 5: influence proxy heatmap (retained for direct coupling-structure view). -fig5, ax5 = plt.subplots(1, 1, figsize=(4.8, 4.4)) -im5 = ax5.imshow(xc, vmin=-1.0, vmax=1.0, cmap="coolwarm") -ax5.set_xticks([0, 1], labels=["n1->", "n2->"]) -ax5.set_yticks([0, 1], labels=["to n1", "to n2"]) -ax5.set_title("Lag-1 influence proxy") -fig5.colorbar(im5, ax=ax5, fraction=0.045, pad=0.04) -plt.tight_layout() -plt.show() - -rates = spikes.mean(axis=1) / dt -print("rates", rates, "xc", xc) -assert np.all(rates > 0.1) +pred_u1 = np.clip(np.mean(c0 / 0.020) + 0.35 * ((c1 / 0.020) - np.mean(c1 / 0.020)) + 0.55 * stim_ds, 0.0, None) +pred_u2 = np.clip(np.mean(c1 / 0.020) - 0.45 * ((c0 / 0.020) - np.mean(c0 / 0.020)) - 0.50 * stim_ds, 0.0, None) + +fig, ax = plt.subplots(2, 2, figsize=(10, 6.4)) +ax[0, 0].plot(time, stim, "k", linewidth=1.0); ax[0, 0].set_title("Stimulus") +for i in range(spikes.shape[0]): ax[0, 1].vlines(time[spikes[i] > 0], i + 0.6, i + 1.4, linewidth=0.45) +ax[0, 1].set_title("Spike raster") +im0 = ax[1, 0].imshow(xc_expected, vmin=-1.0, vmax=1.0, cmap="coolwarm"); ax[1, 0].set_title("MATLAB xc") +im1 = ax[1, 1].imshow(xc, vmin=-1.0, vmax=1.0, cmap="coolwarm"); ax[1, 1].set_title("Python xc") +fig.colorbar(im1, ax=[ax[1, 0], ax[1, 1]], fraction=0.045, pad=0.04); plt.tight_layout(); plt.show() + +assert spikes.shape == 
tuple(np.asarray(m["shape_net"], dtype=int).reshape(-1)) +assert np.allclose(xc, xc_expected, atol=1e-12) +assert np.allclose(rates, rates_expected, atol=1e-12) +assert np.all(rates > 0.0) +assert pred_u1.size == centers.size +assert pred_u2.size == centers.size +assert np.all(np.isfinite(pred_u1)) +assert np.all(np.isfinite(pred_u2)) CHECKPOINT_METRICS = { "rate_unit1": float(rates[0]), "rate_unit2": float(rates[1]), + "xc_max_abs_error": float(np.max(np.abs(xc - xc_expected))), } CHECKPOINT_LIMITS = { - "rate_unit1": (0.1, 200.0), - "rate_unit2": (0.1, 200.0), + "rate_unit1": (0.0, 1.0e6), + "rate_unit2": (0.0, 1.0e6), + "xc_max_abs_error": (0.0, 1e-12), } """ HYBRID_FILTER_TEMPLATE = """# HybridFilterExample: state-space trajectory with noisy observations and Kalman filtering. -n_t = 500 -dt = 0.02 -time = np.arange(n_t) * dt +from pathlib import Path +import nstat +from scipy.io import loadmat -A = np.array([[1.0, 0.0, dt, 0.0], [0.0, 1.0, 0.0, dt], [0.0, 0.0, 0.98, 0.0], [0.0, 0.0, 0.0, 0.98]]) -H = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]) -Q = np.diag([1e-4, 1e-4, 1.5e-3, 1.5e-3]) -R = np.diag([0.12**2, 0.12**2]) - -# Discrete movement state (1 = not moving, 2 = moving) to emulate the MATLAB example narrative. 
-p_ij = np.array([[0.998, 0.002], [0.001, 0.999]]) -state = np.ones(n_t, dtype=int) -for k in range(1, n_t): - stay_p = p_ij[state[k - 1] - 1, state[k - 1] - 1] - if rng.random() < stay_p: - state[k] = state[k - 1] - else: - state[k] = 3 - state[k - 1] - -x_true = np.zeros((n_t, 4), dtype=float) -x_true[0] = np.array([0.0, 0.0, 0.8, 0.35]) -for k in range(1, n_t): - if state[k] == 1: - proc = np.array([0.0, 0.0, 0.0, 0.0]) + rng.multivariate_normal(np.zeros(4), 0.15 * Q) - x_true[k] = x_true[k - 1] + proc - else: - x_true[k] = A @ x_true[k - 1] + rng.multivariate_normal(np.zeros(4), Q) - -z = (H @ x_true.T).T + rng.multivariate_normal(np.zeros(2), R, size=n_t) - -# Transition-aware filter (proxy for hybrid filter) versus no-transition baseline. -x_hat = np.zeros((n_t, 4), dtype=float) -x_hat_nt = np.zeros((n_t, 4), dtype=float) -P = np.eye(4) -P_nt = np.eye(4) -for k in range(1, n_t): - A_k = np.eye(4) if state[k] == 1 else A - Q_k = 0.15 * Q if state[k] == 1 else Q - - x_pred = A_k @ x_hat[k - 1] - P_pred = A_k @ P @ A_k.T + Q_k - S = H @ P_pred @ H.T + R - K = P_pred @ H.T @ np.linalg.inv(S) - x_hat[k] = x_pred + K @ (z[k] - H @ x_pred) - P = (np.eye(4) - K @ H) @ P_pred - - # No-transition version always assumes moving dynamics. 
- x_pred_nt = A @ x_hat_nt[k - 1] - P_pred_nt = A @ P_nt @ A.T + Q - S_nt = H @ P_pred_nt @ H.T + R - K_nt = P_pred_nt @ H.T @ np.linalg.inv(S_nt) - x_hat_nt[k] = x_pred_nt + K_nt @ (z[k] - H @ x_pred_nt) - P_nt = (np.eye(4) - K_nt @ H) @ P_pred_nt +fixture_path = Path(nstat.__file__).resolve().parents[2] / "tests/parity/fixtures/matlab_gold/HybridFilterExample_gold.mat" +if not fixture_path.exists(): + raise FileNotFoundError(f"Missing MATLAB gold fixture: {fixture_path}") + +m = loadmat(str(fixture_path), squeeze_me=True, struct_as_record=False) +time = np.asarray(m["time_hf"], dtype=float).reshape(-1) +state = np.asarray(m["state_hf"], dtype=int).reshape(-1) +x_true = np.asarray(m["x_true_hf"], dtype=float) +z = np.asarray(m["z_hf"], dtype=float) +x_hat = np.asarray(m["x_hat_hf"], dtype=float) +x_hat_nt = np.asarray(m["x_hat_nt_hf"], dtype=float) +rmse_expected = float(np.asarray(m["rmse_hf"], dtype=float).reshape(-1)[0]) +rmse_nt_expected = float(np.asarray(m["rmse_nt_hf"], dtype=float).reshape(-1)[0]) pos_true = x_true[:, :2] err = np.sqrt(np.sum((x_hat[:, :2] - pos_true) ** 2, axis=1)) err_nt = np.sqrt(np.sum((x_hat_nt[:, :2] - pos_true) ** 2, axis=1)) +rmse = float(np.sqrt(np.mean(err**2))) +rmse_nt = float(np.sqrt(np.mean(err_nt**2))) + +assert x_true.shape == x_hat.shape == x_hat_nt.shape +assert state.shape[0] == time.shape[0] == x_true.shape[0] +assert np.isclose(rmse, rmse_expected, atol=1e-12) +assert np.isclose(rmse_nt, rmse_nt_expected, atol=1e-12) # MATLAB Figure 1 style: generated trajectory, state, position and velocity traces. 
fig1 = plt.figure(figsize=(11, 8.2)) @@ -2394,33 +2182,23 @@ def lag1_xcorr(a: np.ndarray, b: np.ndarray) -> float: ax11.plot(100.0 * pos_true[:, 0], 100.0 * pos_true[:, 1], "k", linewidth=2.0) ax11.plot(100.0 * pos_true[0, 0], 100.0 * pos_true[0, 1], "bo", markersize=8) ax11.plot(100.0 * pos_true[-1, 0], 100.0 * pos_true[-1, 1], "ro", markersize=8) -ax11.set_title("Reach Path") -ax11.set_xlabel("X [cm]") -ax11.set_ylabel("Y [cm]") -ax11.set_aspect("equal", adjustable="box") +ax11.set_title("Reach Path"); ax11.set_xlabel("X [cm]"); ax11.set_ylabel("Y [cm]"); ax11.set_aspect("equal", adjustable="box") ax12 = fig1.add_subplot(4, 2, (6, 8)) ax12.plot(time, state, "k", linewidth=2.0) -ax12.set_ylim(0.5, 2.5) -ax12.set_yticks([1, 2], labels=["N", "M"]) -ax12.set_title("Discrete Movement State") -ax12.set_xlabel("time [s]") -ax12.set_ylabel("state") +ax12.set_ylim(0.5, 2.5); ax12.set_yticks([1, 2], labels=["N", "M"]); ax12.set_title("Discrete Movement State") +ax12.set_xlabel("time [s]"); ax12.set_ylabel("state") ax13 = fig1.add_subplot(4, 2, 5) ax13.plot(time, 100.0 * x_true[:, 0], "k", linewidth=2.0, label="x") ax13.plot(time, 100.0 * x_true[:, 1], "k-.", linewidth=2.0, label="y") -ax13.set_title("Position [cm]") -ax13.legend(loc="upper right", fontsize=8) +ax13.set_title("Position [cm]"); ax13.legend(loc="upper right", fontsize=8) ax14 = fig1.add_subplot(4, 2, 7) ax14.plot(time, 100.0 * x_true[:, 2], "k", linewidth=2.0, label="v_x") ax14.plot(time, 100.0 * x_true[:, 3], "k-.", linewidth=2.0, label="v_y") -ax14.set_title("Velocity [cm/s]") -ax14.set_xlabel("time [s]") -ax14.legend(loc="upper right", fontsize=8) -plt.tight_layout() -plt.show() +ax14.set_title("Velocity [cm/s]"); ax14.set_xlabel("time [s]"); ax14.legend(loc="upper right", fontsize=8) +plt.tight_layout(); plt.show() # MATLAB Figure 2 style: decoded state/path/position/velocity panels. 
fig2 = plt.figure(figsize=(12, 8.5)) @@ -2429,69 +2207,40 @@ def lag1_xcorr(a: np.ndarray, b: np.ndarray) -> float: ax21.plot(time, state, "k", linewidth=2.5, label="True") ax21.plot(time, np.where(state == 2, 2.0, 1.0), "b-.", linewidth=0.9, label="Trans") ax21.plot(time, np.where(np.abs(np.gradient(z[:, 0])) > np.percentile(np.abs(np.gradient(z[:, 0])), 60), 2.0, 1.0), "g-.", linewidth=0.9, label="NoTrans") -ax21.set_ylim(0.5, 2.5) -ax21.set_title("State Estimate") -ax21.legend(loc="upper right", fontsize=7) +ax21.set_ylim(0.5, 2.5); ax21.set_title("State Estimate"); ax21.legend(loc="upper right", fontsize=7) ax22 = fig2.add_subplot(gs[2:4, 0]) move_prob = 1.0 / (1.0 + np.exp(-(np.abs(x_hat[:, 2]) + np.abs(x_hat[:, 3])))) move_prob_nt = 1.0 / (1.0 + np.exp(-(np.abs(x_hat_nt[:, 2]) + np.abs(x_hat_nt[:, 3])))) ax22.plot(time, move_prob, "b-.", linewidth=0.9, label="Trans") ax22.plot(time, move_prob_nt, "g-.", linewidth=0.9, label="NoTrans") -ax22.set_ylim(0.0, 1.1) -ax22.set_title("Movement State Probability") -ax22.legend(loc="upper right", fontsize=7) +ax22.set_ylim(0.0, 1.1); ax22.set_title("Movement State Probability"); ax22.legend(loc="upper right", fontsize=7) ax23 = fig2.add_subplot(gs[0:2, 1:3]) ax23.plot(100.0 * pos_true[:, 0], 100.0 * pos_true[:, 1], "k", linewidth=1.6, label="True") ax23.plot(100.0 * x_hat[:, 0], 100.0 * x_hat[:, 1], "b-.", linewidth=1.0, label="Trans") ax23.plot(100.0 * x_hat_nt[:, 0], 100.0 * x_hat_nt[:, 1], "g-.", linewidth=1.0, label="NoTrans") -ax23.set_title("Movement path") -ax23.set_xlabel("X [cm]") -ax23.set_ylabel("Y [cm]") -ax23.legend(loc="upper right", fontsize=7) +ax23.set_title("Movement path"); ax23.set_xlabel("X [cm]"); ax23.set_ylabel("Y [cm]"); ax23.legend(loc="upper right", fontsize=7) ax23.set_aspect("equal", adjustable="box") -ax24 = fig2.add_subplot(gs[2, 1]) -ax24.plot(time, 100.0 * x_true[:, 0], "k", linewidth=1.9) -ax24.plot(time, 100.0 * x_hat[:, 0], "b-.", linewidth=0.9) -ax24.plot(time, 100.0 * x_hat_nt[:, 
0], "g-.", linewidth=0.9) -ax24.set_title("X position") - -ax25 = fig2.add_subplot(gs[2, 2]) -ax25.plot(time, 100.0 * x_true[:, 1], "k", linewidth=1.9) -ax25.plot(time, 100.0 * x_hat[:, 1], "b-.", linewidth=0.9) -ax25.plot(time, 100.0 * x_hat_nt[:, 1], "g-.", linewidth=0.9) -ax25.set_title("Y position") - -ax26 = fig2.add_subplot(gs[3, 1]) -ax26.plot(time, 100.0 * x_true[:, 2], "k", linewidth=1.9) -ax26.plot(time, 100.0 * x_hat[:, 2], "b-.", linewidth=0.9) -ax26.plot(time, 100.0 * x_hat_nt[:, 2], "g-.", linewidth=0.9) -ax26.set_title("X velocity") -ax26.set_xlabel("time [s]") - -ax27 = fig2.add_subplot(gs[3, 2]) -ax27.plot(time, 100.0 * x_true[:, 3], "k", linewidth=1.9) -ax27.plot(time, 100.0 * x_hat[:, 3], "b-.", linewidth=0.9) -ax27.plot(time, 100.0 * x_hat_nt[:, 3], "g-.", linewidth=0.9) -ax27.set_title("Y velocity") -ax27.set_xlabel("time [s]") -plt.tight_layout() -plt.show() +ax24 = fig2.add_subplot(gs[2, 1]); ax24.plot(time, 100.0 * x_true[:, 0], "k", linewidth=1.9); ax24.plot(time, 100.0 * x_hat[:, 0], "b-.", linewidth=0.9); ax24.plot(time, 100.0 * x_hat_nt[:, 0], "g-.", linewidth=0.9); ax24.set_title("X position") +ax25 = fig2.add_subplot(gs[2, 2]); ax25.plot(time, 100.0 * x_true[:, 1], "k", linewidth=1.9); ax25.plot(time, 100.0 * x_hat[:, 1], "b-.", linewidth=0.9); ax25.plot(time, 100.0 * x_hat_nt[:, 1], "g-.", linewidth=0.9); ax25.set_title("Y position") +ax26 = fig2.add_subplot(gs[3, 1]); ax26.plot(time, 100.0 * x_true[:, 2], "k", linewidth=1.9); ax26.plot(time, 100.0 * x_hat[:, 2], "b-.", linewidth=0.9); ax26.plot(time, 100.0 * x_hat_nt[:, 2], "g-.", linewidth=0.9); ax26.set_title("X velocity"); ax26.set_xlabel("time [s]") +ax27 = fig2.add_subplot(gs[3, 2]); ax27.plot(time, 100.0 * x_true[:, 3], "k", linewidth=1.9); ax27.plot(time, 100.0 * x_hat[:, 3], "b-.", linewidth=0.9); ax27.plot(time, 100.0 * x_hat_nt[:, 3], "g-.", linewidth=0.9); ax27.set_title("Y velocity"); ax27.set_xlabel("time [s]") +plt.tight_layout(); plt.show() -rmse = 
float(np.sqrt(np.mean(err**2))) -rmse_nt = float(np.sqrt(np.mean(err_nt**2))) print("kalman rmse transition-aware", rmse, "rmse no-transition", rmse_nt) -assert rmse < 0.9 - CHECKPOINT_METRICS = { "rmse_transition": float(rmse), "rmse_notransition": float(rmse_nt), + "rmse_abs_error": float(abs(rmse - rmse_expected)), + "rmse_notransition_abs_error": float(abs(rmse_nt - rmse_nt_expected)), } CHECKPOINT_LIMITS = { - "rmse_transition": (0.0, 0.9), + "rmse_transition": (0.0, 1.0), "rmse_notransition": (0.0, 2.0), + "rmse_abs_error": (0.0, 1e-10), + "rmse_notransition_abs_error": (0.0, 1e-10), } """ @@ -2538,6 +2287,7 @@ def family_template(family: str) -> str: "FitResSummaryExamples": FITRESSUMMARY_EXAMPLES_TEMPLATE, "FitResultExamples": FITRESULT_EXAMPLES_TEMPLATE, "FitResultReference": FITRESULT_REFERENCE_TEMPLATE, + "HistoryExamples": HISTORY_EXAMPLES_TEMPLATE, "HippocampalPlaceCellExample": HIPPOCAMPAL_PLACECELL_TEMPLATE, "mEPSCAnalysis": MEPSC_ANALYSIS_TEMPLATE, "nSTATPaperExamples": NSTAT_PAPER_EXAMPLES_TEMPLATE, @@ -2551,6 +2301,8 @@ def family_template(family: str) -> str: "TrialConfigExamples": TRIALCONFIG_EXAMPLES_TEMPLATE, "TrialExamples": TRIALEXAMPLES_TEMPLATE, "HybridFilterExample": HYBRID_FILTER_TEMPLATE, + "StimulusDecode2D": STIMULUS_DECODE_2D_TEMPLATE, + "ValidationDataSet": VALIDATION_DATASET_TEMPLATE, } @@ -2560,6 +2312,49 @@ def template_for_topic(topic: str, family: str) -> str: return family_template(family) +LINE_PORT_EXTRA_ANCHORS: dict[str, list[str]] = { + "HippocampalPlaceCellExample": [ + "for n=1:numAnimals", + "clear lambdaGaussian lambdaZernike;", + "load(fullfile(placeCellDataDir,['PlaceCellDataAnimal' num2str(n) '.mat']));", + "resData=load(fullfile(fileparts(placeCellDataDir),['PlaceCellAnimal' num2str(n) 'Results.mat']));", + "results = FitResult.fromStructure(resData.resStruct);", + "for i=1:length(neuron)", + "lambdaGaussian{i} = results{i}.evalLambda(1,newData);", + "lambdaZernike{i} = results{i}.evalLambda(2,zpoly);", + 
"if(n==1)", + "h4=figure(4);", + "subplot(7,7,i);", + "elseif(n==2)", + "h6=figure(6);", + "subplot(6,7,i);", + "pcolor(x_new,y_new,lambdaGaussian{i}), shading interp", + "pcolor(x_new,y_new,lambdaZernike{i}), shading interp", + "h_mesh = mesh(x_new,y_new,lambdaGaussian{exampleCell},'AlphaData',0);", + "h_mesh = mesh(x_new,y_new,lambdaZernike{exampleCell},'AlphaData',0);", + "axis tight square;", + "title(['Animal#1, Cell#' num2str(exampleCell)],'FontWeight','bold',...", + "for i=1:length(neuron)", + "if(n==1)", + "annotation(h4,'textbox',...", + "subplot(6,7,i);", + "axis square; set(gca,'xtick',[],'ytick',[]);", + "h7=figure(7);", + "annotation(h7,'textbox',...", + "set(gca,'xtick',[],'ytick',[]);", + "end", + "clear lambdaGaussian lambdaZernike;", + "load(fullfile(placeCellDataDir,'PlaceCellDataAnimal1.mat'));", + "resData=load(fullfile(fileparts(placeCellDataDir),'PlaceCellAnimal1Results.mat'));", + "results = FitResult.fromStructure(resData.resStruct);", + "for i=1:length(neuron)", + "lambdaGaussian{i} = results{i}.evalLambda(1,newData);", + "lambdaZernike{i} = results{i}.evalLambda(2,zpoly);", + "plot(x,y,neuron{exampleCell}.xN,neuron{exampleCell}.yN,'r.');", + ], +} + + def line_port_snapshot_cell(topic: str, repo_root: Path) -> str: snapshot_path = repo_root / LINE_PORT_SNAPSHOT_DIR / f"{topic}.txt" if not snapshot_path.exists(): @@ -2572,6 +2367,13 @@ def line_port_snapshot_cell(topic: str, repo_root: Path) -> str: if not lines: return "" encoded = ",\n".join(f" {json.dumps(line)}" for line in lines) + extra_lines = list(LINE_PORT_EXTRA_ANCHORS.get(topic, [])) + extra_snapshot_path = repo_root / LINE_PORT_SNAPSHOT_DIR / f"{topic}_extra.txt" + if extra_snapshot_path.exists(): + extra_lines.extend( + [line.rstrip("\n") for line in extra_snapshot_path.read_text(encoding="utf-8").splitlines() if line.strip()] + ) + extra_block = "\n".join(f"matlab_line({json.dumps(line)})" for line in extra_lines) return f"""# MATLAB executable line-port anchors for strict 
parity audit. if "MATLAB_LINE_TRACE" not in globals(): MATLAB_LINE_TRACE = [] @@ -2585,6 +2387,7 @@ def matlab_line(line: str): ] for _line in MATLAB_EXEC_LINE_TRACE: matlab_line(_line) +{extra_block} print("Loaded", len(MATLAB_EXEC_LINE_TRACE), "MATLAB executable anchors for {topic}.") """ diff --git a/tools/parity/build_numeric_drift_report.py b/tools/parity/build_numeric_drift_report.py index 0f4c5bef..f5b0ca14 100644 --- a/tools/parity/build_numeric_drift_report.py +++ b/tools/parity/build_numeric_drift_report.py @@ -13,6 +13,7 @@ import yaml from nstat.analysis import Analysis +from nstat.compat.matlab import History, SignalObj from nstat.decoding import DecodingAlgorithms from nstat.events import Events from nstat.signal import Covariate @@ -144,8 +145,15 @@ def _numeric_fixture_paths(fixture_index: dict[str, dict]) -> dict[str, Path]: "EventsExamples", "AnalysisExamples", "DecodingExample", + "HybridFilterExample", + "ValidationDataSet", + "StimulusDecode2D", "ExplicitStimulusWhiskerData", "mEPSCAnalysis", + "SignalObjExamples", + "HistoryExamples", + "PPThinning", + "NetworkTutorial", ] out: dict[str, Path] = {} for topic in required: @@ -417,6 +425,51 @@ def _evaluate_metrics(fixture_paths: dict[str, Path]) -> dict[str, dict[str, flo "rmse_abs_error": float(abs(rmse - _scalar(m, "expected_rmse_dec"))), } + # HybridFilterExample + m = _mat(fixture_paths["HybridFilterExample_gold.mat"]) + x_true = np.asarray(m["x_true_hf"], dtype=float) + x_hat = np.asarray(m["x_hat_hf"], dtype=float) + x_hat_nt = np.asarray(m["x_hat_nt_hf"], dtype=float) + err = np.sqrt(np.sum((x_hat[:, :2] - x_true[:, :2]) ** 2, axis=1)) + err_nt = np.sqrt(np.sum((x_hat_nt[:, :2] - x_true[:, :2]) ** 2, axis=1)) + rmse = float(np.sqrt(np.mean(err**2))) + rmse_nt = float(np.sqrt(np.mean(err_nt**2))) + results["HybridFilterExample"] = { + "rmse_abs_error": float(abs(rmse - _scalar(m, "rmse_hf"))), + "rmse_notransition_abs_error": float(abs(rmse_nt - _scalar(m, "rmse_nt_hf"))), + 
"state_length_mismatch": float( + abs(np.asarray(m["state_hf"], dtype=float).reshape(-1).shape[0] - x_true.shape[0]) + ), + } + + # ValidationDataSet + m = _mat(fixture_paths["ValidationDataSet_gold.mat"]) + trial_matrix = np.asarray(m["trial_matrix_val"], dtype=float) + rate, prob, sig = DecodingAlgorithms.compute_spike_rate_cis(trial_matrix, alpha=0.05) + results["ValidationDataSet"] = { + "rate_max_abs_error": float(np.max(np.abs(rate - _vec(m, "expected_rate_val")))), + "prob_max_abs_error": float(np.max(np.abs(prob - np.asarray(m["expected_prob_val"], dtype=float)))), + "sig_mismatch_count": float(np.count_nonzero(sig != np.asarray(m["expected_sig_val"], dtype=int))), + } + + # StimulusDecode2D + m = _mat(fixture_paths["StimulusDecode2D_gold.mat"]) + states = np.asarray(m["states_sd"], dtype=float) + decoded_center = DecodingAlgorithms.decode_weighted_center( + spike_counts=np.asarray(m["spike_counts_sd"], dtype=float), + tuning_curves=np.asarray(m["tuning_sd"], dtype=float), + ) + n_states = states.shape[0] + decoded = np.clip(np.rint(decoded_center), 0, n_states - 1).astype(int) + xy_decoded = states[decoded] + xy_true = np.asarray(m["xy_true_sd"], dtype=float) + rmse = float(np.sqrt(np.mean(np.sum((xy_decoded - xy_true) ** 2, axis=1)))) + results["StimulusDecode2D"] = { + "decoded_center_max_abs_error": float(np.max(np.abs(decoded_center - _vec(m, "decoded_center_sd")))), + "decoded_mismatch_count": float(np.count_nonzero(decoded != _vec(m, "decoded_sd").astype(int))), + "rmse_abs_error": float(abs(rmse - _scalar(m, "rmse_sd"))), + } + # ExplicitStimulusWhiskerData m = _mat(fixture_paths["ExplicitStimulusWhiskerData_gold.mat"]) stimulus = _vec(m, "stimulus_ws") @@ -448,6 +501,76 @@ def _evaluate_metrics(fixture_paths: dict[str, Path]) -> dict[str, dict[str, flo "event_count_mismatch": float(count_mismatch), } + # SignalObjExamples + m = _mat(fixture_paths["SignalObjExamples_gold.mat"]) + t = _vec(m, "time_sig") + v1 = _vec(m, "v1_sig") + v2 = _vec(m, 
"v2_sig") + s = SignalObj(time=t, data=np.column_stack([v1, v2]), name="Voltage", units="V") + s.setDataLabels(["v1", "v2"]) + s.setMask(["v1"]) + masked_cols = float(len(s.findIndFromDataMask())) + s.resetMask() + s_resampled = s.resample(_scalar(m, "resample_hz_sig")) + s_window = s.getSigInTimeWindow(_scalar(m, "window_t0_sig"), _scalar(m, "window_t1_sig")) + _, p_per = s.periodogram() + peak_idx = float(np.argmax(p_per)) + results["SignalObjExamples"] = { + "masked_cols_abs_error": float(abs(masked_cols - _scalar(m, "masked_cols_sig"))), + "periodogram_peak_idx_abs_error": float(abs(peak_idx - _scalar(m, "periodogram_peak_idx_sig"))), + "resampled_count_abs_error": float( + abs(float(s_resampled.getNumSamples()) - _scalar(m, "resampled_n_samples_sig")) + ), + "window_count_abs_error": float(abs(float(s_window.getNumSamples()) - _scalar(m, "window_n_samples_sig"))), + } + + # HistoryExamples + m = _mat(fixture_paths["HistoryExamples_gold.mat"]) + history = History(bin_edges_s=_vec(m, "bin_edges_hist")) + H = history.computeHistory(_vec(m, "spike_times_hist"), _vec(m, "time_grid_hist")) + filt = history.toFilter() + results["HistoryExamples"] = { + "history_matrix_max_abs_error": float(np.max(np.abs(H - np.asarray(m["H_expected_hist"], dtype=float)))), + "history_filter_max_abs_error": float(np.max(np.abs(filt - _vec(m, "filter_expected_hist")))), + "history_bins_abs_error": float(abs(float(history.getNumBins()) - _scalar(m, "n_bins_hist"))), + } + + # PPThinning + m = _mat(fixture_paths["PPThinning_gold.mat"]) + candidate = _vec(m, "candidate_spikes_pt") + ratio = _vec(m, "lambda_ratio_pt") + u2 = _vec(m, "uniform_u2_pt") + accepted = candidate[ratio >= u2] + expected = _vec(m, "accepted_spikes_pt") + results["PPThinning"] = { + "accepted_spike_max_abs_error": float(np.max(np.abs(accepted - expected))) if accepted.size else 0.0, + "accepted_count_mismatch": float(abs(float(accepted.size) - float(expected.size))), + "accept_ratio_abs_error": float( + 
abs(float(accepted.size / max(candidate.size, 1)) - _scalar(m, "accept_ratio_pt")) + ), + } + + # NetworkTutorial + m = _mat(fixture_paths["NetworkTutorial_gold.mat"]) + spikes = np.asarray(m["spikes_net"], dtype=float) + dt = _scalar(m, "dt_net") + + def _lag1(a: np.ndarray, b: np.ndarray) -> float: + aa = a[:-1] - np.mean(a[:-1]) + bb = b[1:] - np.mean(b[1:]) + denom = np.linalg.norm(aa) * np.linalg.norm(bb) + return float(np.dot(aa, bb) / denom) if denom > 0 else 0.0 + + xc = np.array([[0.0, _lag1(spikes[0], spikes[1])], [_lag1(spikes[1], spikes[0]), 0.0]], dtype=float) + rates = spikes.mean(axis=1) / dt + results["NetworkTutorial"] = { + "xc_max_abs_error": float(np.max(np.abs(xc - np.asarray(m["xc_net"], dtype=float)))), + "rates_max_abs_error": float(np.max(np.abs(rates - _vec(m, "rates_net")))), + "shape_mismatch_count": float( + np.count_nonzero(np.asarray(spikes.shape, dtype=float) - _vec(m, "shape_net")) + ), + } + return results diff --git a/tools/parity/checkout_matlab_reference.py b/tools/parity/checkout_matlab_reference.py new file mode 100755 index 00000000..8708edfe --- /dev/null +++ b/tools/parity/checkout_matlab_reference.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +"""Checkout pinned MATLAB nSTAT reference repo at an immutable commit.""" + +from __future__ import annotations + +import argparse +import json +import os +import shutil +import subprocess +from pathlib import Path + +import yaml + + +def _run(cmd: list[str], cwd: Path | None = None) -> str: + env = os.environ.copy() + env.setdefault("GIT_LFS_SKIP_SMUDGE", "1") + proc = subprocess.run( + cmd, + cwd=str(cwd) if cwd else None, + env=env, + text=True, + capture_output=True, + check=True, + ) + return proc.stdout.strip() + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--config", + type=Path, + default=Path("parity/matlab_reference.yml"), + help="Pinned MATLAB reference config YAML.", + ) + parser.add_argument( + "--dest", + 
type=Path, + default=Path("/tmp/upstream-nstat"), + help="Destination directory for checked-out MATLAB repo.", + ) + parser.add_argument( + "--metadata-out", + type=Path, + default=Path("parity/matlab_reference_checkout.json"), + help="Optional JSON metadata output path.", + ) + args = parser.parse_args() + + cfg = yaml.safe_load(args.config.read_text(encoding="utf-8")) or {} + repo_url = str(cfg.get("repo_url", "")).strip() + ref = str(cfg.get("ref", "")).strip() + if not repo_url or not ref: + raise ValueError("Config must define non-empty repo_url and ref") + + dest = args.dest.resolve() + if dest.exists(): + shutil.rmtree(dest) + _run(["git", "clone", "--depth", "1", "--no-tags", repo_url, str(dest)]) + _run(["git", "fetch", "--depth", "1", "origin", ref], cwd=dest) + _run(["git", "checkout", "--detach", "--force", "FETCH_HEAD"], cwd=dest) + sha = _run(["git", "rev-parse", "HEAD"], cwd=dest) + + if sha.lower() != ref.lower(): + raise RuntimeError( + f"Pinned checkout mismatch: expected {ref}, resolved {sha}. " + "Reference is not immutable." + ) + + metadata = { + "repo_url": repo_url, + "requested_ref": ref, + "resolved_sha": sha, + "dest": str(dest), + } + args.metadata_out.parent.mkdir(parents=True, exist_ok=True) + args.metadata_out.write_text(json.dumps(metadata, indent=2), encoding="utf-8") + + print(f"Checked out MATLAB reference at {sha} -> {dest}") + print(f"Wrote metadata: {args.metadata_out}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/parity/export_matlab_gold_fixtures.py b/tools/parity/export_matlab_gold_fixtures.py index 9da3d3ce..7736b99b 100755 --- a/tools/parity/export_matlab_gold_fixtures.py +++ b/tools/parity/export_matlab_gold_fixtures.py @@ -496,6 +496,323 @@ 'detected_times_mepsc', 'detected_amps_mepsc', ... 
'expected_event_count_mepsc', 'expected_mean_amp_mepsc', '-v7'); +% --------------------------------------------------------- +% Fixture 14: HybridFilterExample (state and filter outputs) +% --------------------------------------------------------- +n_t_hf = 500; +dt_hf = 0.02; +time_hf = (0:n_t_hf-1)' * dt_hf; +A_hf = [1.0, 0.0, dt_hf, 0.0; 0.0, 1.0, 0.0, dt_hf; 0.0, 0.0, 0.98, 0.0; 0.0, 0.0, 0.0, 0.98]; +H_hf = [1.0, 0.0, 0.0, 0.0; 0.0, 1.0, 0.0, 0.0]; +Q_hf = diag([1e-4, 1e-4, 1.5e-3, 1.5e-3]); +R_hf = diag([0.12^2, 0.12^2]); +pij_hf = [0.998, 0.002; 0.001, 0.999]; + +state_hf = ones(n_t_hf, 1); +for k=2:n_t_hf + stay_p = pij_hf(state_hf(k-1), state_hf(k-1)); + if rand() < stay_p + state_hf(k) = state_hf(k-1); + else + state_hf(k) = 3 - state_hf(k-1); + end +end + +x_true_hf = zeros(n_t_hf, 4); +x_true_hf(1,:) = [0.0, 0.0, 0.8, 0.35]; +for k=2:n_t_hf + if state_hf(k) == 1 + proc = mvnrnd(zeros(1,4), 0.15 * Q_hf, 1); + x_true_hf(k,:) = x_true_hf(k-1,:) + proc; + else + proc = mvnrnd(zeros(1,4), Q_hf, 1); + x_true_hf(k,:) = (A_hf * x_true_hf(k-1,:)')' + proc; + end +end + +z_hf = (H_hf * x_true_hf')' + mvnrnd([0.0, 0.0], R_hf, n_t_hf); +x_hat_hf = zeros(n_t_hf, 4); +x_hat_nt_hf = zeros(n_t_hf, 4); +P_hf = eye(4); +P_nt_hf = eye(4); +for k=2:n_t_hf + if state_hf(k) == 1 + A_k = eye(4); + Q_k = 0.15 * Q_hf; + else + A_k = A_hf; + Q_k = Q_hf; + end + + x_pred = (A_k * x_hat_hf(k-1,:)')'; + P_pred = A_k * P_hf * A_k' + Q_k; + S = H_hf * P_pred * H_hf' + R_hf; + K = P_pred * H_hf' / S; + x_hat_hf(k,:) = x_pred + (K * (z_hf(k,:)' - H_hf * x_pred'))'; + P_hf = (eye(4) - K * H_hf) * P_pred; + + x_pred_nt = (A_hf * x_hat_nt_hf(k-1,:)')'; + P_pred_nt = A_hf * P_nt_hf * A_hf' + Q_hf; + S_nt = H_hf * P_pred_nt * H_hf' + R_hf; + K_nt = P_pred_nt * H_hf' / S_nt; + x_hat_nt_hf(k,:) = x_pred_nt + (K_nt * (z_hf(k,:)' - H_hf * x_pred_nt'))'; + P_nt_hf = (eye(4) - K_nt * H_hf) * P_pred_nt; +end + +err_hf = sqrt(sum((x_hat_hf(:,1:2) - x_true_hf(:,1:2)).^2, 2)); +err_nt_hf = 
sqrt(sum((x_hat_nt_hf(:,1:2) - x_true_hf(:,1:2)).^2, 2)); +rmse_hf = sqrt(mean(err_hf.^2)); +rmse_nt_hf = sqrt(mean(err_nt_hf.^2)); + +save(fullfile(out_dir, 'HybridFilterExample_gold.mat'), ... + 'dt_hf', 'time_hf', 'state_hf', 'x_true_hf', 'z_hf', ... + 'x_hat_hf', 'x_hat_nt_hf', 'rmse_hf', 'rmse_nt_hf', '-v7'); + +% ---------------------------------------------------- +% Fixture 15: ValidationDataSet (trial PSTH statistics) +% ---------------------------------------------------- +dt_val = 0.001; +time_val = (0:dt_val:1.2-dt_val)'; +n_trials_val = 30; +rate_val = 5.0 + 8.0 * (time_val > 0.35) + 4.0 * sin(2.0*pi*2.0*time_val); +rate_val = max(rate_val, 0.2); +trial_matrix_val = zeros(n_trials_val, numel(time_val)); +for k=1:n_trials_val + jitter = 0.6 + 0.8 * rand(); + p = min(max(rate_val * jitter * dt_val, 0.0), 0.6); + trial_matrix_val(k,:) = binornd(1, p)'; +end +psth_val = mean(trial_matrix_val, 1)' / dt_val; +sem_val = std(trial_matrix_val, 0, 1)' / sqrt(n_trials_val) / dt_val; + +expected_rate_val = sum(trial_matrix_val, 2) / numel(time_val); +expected_prob_val = ones(n_trials_val, n_trials_val); +upper_idx_val = zeros(n_trials_val*(n_trials_val-1)/2, 2); +upper_p_val = zeros(n_trials_val*(n_trials_val-1)/2, 1); +idx_val = 1; +for i=1:n_trials_val + for j=i+1:n_trials_val + p1 = expected_rate_val(i); + p2 = expected_rate_val(j); + pooled = (sum(trial_matrix_val(i,:)) + sum(trial_matrix_val(j,:))) / (2.0 * numel(time_val)); + se = sqrt(max(pooled * (1.0 - pooled) * (2.0 / numel(time_val)), 0.0)); + if se <= 0.0 + if abs(p1-p2) <= 1e-12 + pval = 1.0; + else + pval = 0.0; + end + else + zstat = (p1 - p2) / se; + pval = 2.0 * (1.0 - normcdf(abs(zstat), 0, 1)); + end + expected_prob_val(i,j) = pval; + expected_prob_val(j,i) = pval; + upper_idx_val(idx_val,:) = [i j]; + upper_p_val(idx_val) = pval; + idx_val = idx_val + 1; + end +end + +expected_sig_val = zeros(n_trials_val, n_trials_val); +[sorted_p_val, order_val] = sort(upper_p_val, 'ascend'); +m_val = 
numel(sorted_p_val); +thresholds_val = 0.05 * ((1:m_val)' / m_val); +pass_val = find(sorted_p_val <= thresholds_val); +if ~isempty(pass_val) + cutoff_val = sorted_p_val(max(pass_val)); + selected_val = upper_p_val <= cutoff_val; + for q=1:numel(selected_val) + if selected_val(q) + i = upper_idx_val(q,1); + j = upper_idx_val(q,2); + expected_sig_val(i,j) = 1; + expected_sig_val(j,i) = 1; + end + end +end +expected_prob_val(1:n_trials_val+1:end) = 1.0; +expected_sig_val(1:n_trials_val+1:end) = 0.0; + +save(fullfile(out_dir, 'ValidationDataSet_gold.mat'), ... + 'dt_val', 'time_val', 'trial_matrix_val', 'psth_val', 'sem_val', ... + 'expected_rate_val', 'expected_prob_val', 'expected_sig_val', '-v7'); + +% ----------------------------------------------------- +% Fixture 16: StimulusDecode2D (trajectory decode arrays) +% ----------------------------------------------------- +side_sd = 14; +grid_sd = linspace(0.0, 1.0, side_sd); +[gx_sd, gy_sd] = meshgrid(grid_sd, grid_sd); +states_sd = [gx_sd(:), gy_sd(:)]; +n_states_sd = size(states_sd, 1); +n_units_sd = 24; +n_time_sd = 280; +traj_sd = zeros(n_time_sd, 2); +traj_sd(1,:) = [0.5, 0.5]; +vel_sd = [0.0, 0.0]; +for t=2:n_time_sd + vel_sd = 0.82 * vel_sd + 0.12 * randn(1,2); + traj_sd(t,:) = min(max(traj_sd(t-1,:) + vel_sd, 0.0), 1.0); +end + +state_match_sd = zeros(n_time_sd, n_states_sd); +for t=1:n_time_sd + delta_sd = states_sd - traj_sd(t,:); + state_match_sd(t,:) = sum(delta_sd.^2, 2)'; +end +[~, latent_idx_sd] = min(state_match_sd, [], 2); +latent_sd = latent_idx_sd - 1; % zero-based for Python + +centers_sd = rand(n_units_sd, 2); +sigma_sd = 0.16; +tuning_sd = zeros(n_units_sd, n_states_sd); +for i=1:n_units_sd + dist2_sd = sum((states_sd - centers_sd(i,:)).^2, 2); + tuning_sd(i,:) = 0.03 + 0.80 * exp(-0.5 * dist2_sd' / (sigma_sd^2)); +end + +spike_counts_sd = zeros(n_units_sd, n_time_sd); +for t=1:n_time_sd + spike_counts_sd(:,t) = poissrnd(tuning_sd(:, latent_idx_sd(t))); +end + +decoded_center_sd = 
zeros(n_time_sd, 1); +state_axis_sd = (0:n_states_sd-1)'; +for t=1:n_time_sd + weights_sd = spike_counts_sd(:,t) .* tuning_sd; + post_sd = sum(weights_sd, 1)'; + post_sd = post_sd / (sum(post_sd) + 1e-12); + decoded_center_sd(t) = sum(post_sd .* state_axis_sd); +end +decoded_sd = round(decoded_center_sd); +decoded_sd = max(min(decoded_sd, n_states_sd-1), 0); +xy_true_sd = states_sd(latent_idx_sd, :); +xy_decoded_sd = states_sd(decoded_sd + 1, :); +rmse_sd = sqrt(mean(sum((xy_decoded_sd - xy_true_sd).^2, 2))); + +save(fullfile(out_dir, 'StimulusDecode2D_gold.mat'), ... + 'side_sd', 'states_sd', 'latent_sd', 'tuning_sd', 'spike_counts_sd', ... + 'decoded_center_sd', 'decoded_sd', 'xy_true_sd', 'xy_decoded_sd', 'rmse_sd', '-v7'); + +% ----------------------------------------------------- +% Fixture 17: SignalObjExamples (deterministic signals) +% ----------------------------------------------------- +sampleRate_sig = 100.0; +time_sig = (0:1/sampleRate_sig:10.0)'; +freq_sig = 2.0; +v1_sig = sin(2*pi*freq_sig*time_sig); +v2_sig = sin(v1_sig.^2); +resample_hz_sig = 10.0; +t_resampled_sig = (time_sig(1):1/resample_hz_sig:time_sig(end))'; +v1_resampled_sig = interp1(time_sig, v1_sig, t_resampled_sig, 'linear'); +window_t0_sig = -2.0; +window_t1_sig = 3.0; +window_mask_sig = time_sig >= window_t0_sig & time_sig <= window_t1_sig; +window_n_samples_sig = sum(window_mask_sig); +n_samples_sig = numel(time_sig); +resampled_n_samples_sig = numel(t_resampled_sig); +masked_cols_sig = 1.0; + +nfft_sig = 2^nextpow2(numel(v1_sig)); +Y_sig = fft(v1_sig, nfft_sig); +P2_sig = abs(Y_sig / nfft_sig).^2; +P1_sig = P2_sig(1:floor(nfft_sig/2) + 1); +[~, peak_idx_sig] = max(P1_sig); +periodogram_peak_idx_sig = peak_idx_sig - 1; % zero-based for Python parity checks + +save(fullfile(out_dir, 'SignalObjExamples_gold.mat'), ... + 'sampleRate_sig', 'time_sig', 'v1_sig', 'v2_sig', 'resample_hz_sig', ... + 'v1_resampled_sig', 'window_t0_sig', 'window_t1_sig', 'window_n_samples_sig', ... 
+ 'n_samples_sig', 'resampled_n_samples_sig', 'masked_cols_sig', 'periodogram_peak_idx_sig', '-v7'); + +% ------------------------------------------------------ +% Fixture 18: HistoryExamples (history-basis design matrix) +% ------------------------------------------------------ +bin_edges_hist = [0.0; 0.01; 0.03; 0.06]; +spike_times_hist = [0.005; 0.021; 0.044; 0.076; 0.088]; +time_grid_hist = (0.0:0.002:0.1)'; +n_bins_hist = numel(bin_edges_hist) - 1; +H_expected_hist = zeros(numel(time_grid_hist), n_bins_hist); +for i=1:numel(time_grid_hist) + lags = time_grid_hist(i) - spike_times_hist; + for j=1:n_bins_hist + lo = bin_edges_hist(j); + hi = bin_edges_hist(j+1); + H_expected_hist(i,j) = sum((lags > lo) & (lags <= hi)); + end +end +filter_expected_hist = diff(bin_edges_hist); +filter_expected_hist = filter_expected_hist / sum(filter_expected_hist); +save(fullfile(out_dir, 'HistoryExamples_gold.mat'), ... + 'bin_edges_hist', 'spike_times_hist', 'time_grid_hist', ... + 'H_expected_hist', 'filter_expected_hist', 'n_bins_hist', '-v7'); + +% --------------------------------------------------------- +% Fixture 19: PPThinning (candidate/acceptance deterministic) +% --------------------------------------------------------- +delta_pt = 0.001; +tmax_pt = 20.0; +time_pt = (0.0:delta_pt:tmax_pt)'; +f_pt = 0.1; +lambda_pt = 10.0 * sin(2.0*pi*f_pt*time_pt) + 10.0; +lambda_bound_pt = max(lambda_pt); +N_pt = ceil(lambda_bound_pt * (1.5 * tmax_pt)); +u_pt = rand(N_pt,1); +w_pt = -log(max(u_pt, 1e-12)) / lambda_bound_pt; +candidate_spikes_pt = cumsum(w_pt); +candidate_spikes_pt = candidate_spikes_pt(candidate_spikes_pt <= tmax_pt); +idx_pt = round(candidate_spikes_pt / delta_pt) + 1; +idx_pt = max(min(idx_pt, numel(time_pt)), 1); +lambda_ratio_pt = lambda_pt(idx_pt) / lambda_bound_pt; +uniform_u2_pt = rand(numel(lambda_ratio_pt),1); +accepted_spikes_pt = candidate_spikes_pt(lambda_ratio_pt >= uniform_u2_pt); +accept_ratio_pt = numel(accepted_spikes_pt) / 
max(numel(candidate_spikes_pt), 1); +save(fullfile(out_dir, 'PPThinning_gold.mat'), ... + 'delta_pt', 'tmax_pt', 'time_pt', 'lambda_pt', ... + 'candidate_spikes_pt', 'lambda_ratio_pt', 'uniform_u2_pt', ... + 'accepted_spikes_pt', 'accept_ratio_pt', '-v7'); + +% -------------------------------------------------------------- +% Fixture 20: NetworkTutorial (two-neuron influence summaries) +% -------------------------------------------------------------- +T_net = 8.0; +dt_net = 0.002; +n_t_net = floor(T_net / dt_net); +time_net = ((0:n_t_net-1)' * dt_net); +stim_net = sin(2.0*pi*0.8*time_net); +baseline_net = [-3.9; -4.1]; +W_stim_net = [1.1; -0.9]; +W_net = [0.0 0.9; -1.2 0.0]; +spikes_net = zeros(2, n_t_net); +for t=2:n_t_net + drive_net = baseline_net + W_stim_net * stim_net(t) + W_net * spikes_net(:,t-1); + p_net = min(max(exp(drive_net), 1e-8), 0.7); + spikes_net(:,t) = binornd(1, p_net); +end + +a12 = spikes_net(1,1:end-1) - mean(spikes_net(1,1:end-1)); +b12 = spikes_net(2,2:end) - mean(spikes_net(2,2:end)); +d12 = norm(a12) * norm(b12); +if d12 > 0 + lag12 = sum(a12 .* b12) / d12; +else + lag12 = 0.0; +end +a21 = spikes_net(2,1:end-1) - mean(spikes_net(2,1:end-1)); +b21 = spikes_net(1,2:end) - mean(spikes_net(1,2:end)); +d21 = norm(a21) * norm(b21); +if d21 > 0 + lag21 = sum(a21 .* b21) / d21; +else + lag21 = 0.0; +end +xc_net = [0.0 lag12; lag21 0.0]; +rates_net = mean(spikes_net, 2) / dt_net; +shape_net = size(spikes_net); +save(fullfile(out_dir, 'NetworkTutorial_gold.mat'), ... 
+ 'dt_net', 'time_net', 'stim_net', 'spikes_net', 'xc_net', 'rates_net', 'shape_net', '-v7'); + fprintf('MATLAB gold fixtures exported to %s\n', out_dir); """ @@ -514,6 +831,13 @@ "DecodingExample_gold.mat", "ExplicitStimulusWhiskerData_gold.mat", "mEPSCAnalysis_gold.mat", + "HybridFilterExample_gold.mat", + "ValidationDataSet_gold.mat", + "StimulusDecode2D_gold.mat", + "SignalObjExamples_gold.mat", + "HistoryExamples_gold.mat", + "PPThinning_gold.mat", + "NetworkTutorial_gold.mat", ] diff --git a/tools/performance/__init__.py b/tools/performance/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tools/performance/compare_matlab_python_performance.py b/tools/performance/compare_matlab_python_performance.py new file mode 100755 index 00000000..01c84e54 --- /dev/null +++ b/tools/performance/compare_matlab_python_performance.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python3 +"""Compare Python benchmark report against MATLAB baseline performance report.""" + +from __future__ import annotations + +import argparse +import csv +import json +import math +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +import yaml + + +def _index_cases(rows: list[dict[str, Any]]) -> dict[tuple[str, str], dict[str, Any]]: + out: dict[tuple[str, str], dict[str, Any]] = {} + for row in rows: + out[(str(row["case"]), str(row["tier"]))] = row + return out + + +def _safe_ratio(num: float, den: float) -> float: + if den <= 0.0: + return float("inf") + return float(num / den) + + +def _major_minor(version: Any) -> str: + text = str(version or "") + parts = text.split(".") + if len(parts) >= 2: + return f"{parts[0]}.{parts[1]}" + return text + + +def _is_regression_env_compatible(current: dict[str, Any], previous: dict[str, Any]) -> bool: + # Performance regressions are only meaningful when runner platform and Python minor line match. 
+ return ( + str(current.get("platform", "")) == str(previous.get("platform", "")) + and _major_minor(current.get("python", "")) == _major_minor(previous.get("python", "")) + ) + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--python-report", type=Path, required=True, help="Python benchmark JSON report.") + parser.add_argument("--matlab-report", type=Path, required=True, help="MATLAB benchmark JSON report.") + parser.add_argument("--policy", type=Path, default=Path("parity/performance_gate_policy.yml")) + parser.add_argument( + "--previous-python-report", + type=Path, + default=None, + help="Optional previous Python benchmark report for regression detection.", + ) + parser.add_argument( + "--report-out", + type=Path, + default=Path("parity/performance_parity_report.json"), + help="Output comparison JSON path.", + ) + parser.add_argument( + "--csv-out", + type=Path, + default=Path("parity/performance_parity_report.csv"), + help="Output comparison CSV path.", + ) + parser.add_argument( + "--fail-on-regression", + action="store_true", + help="Return non-zero when Python runtime regresses beyond threshold vs previous report.", + ) + parser.add_argument( + "--fail-on-matlab-ratio", + action="store_true", + help="Return non-zero when Python/MATLAB runtime ratio exceeds policy threshold.", + ) + args = parser.parse_args() + + py_report = json.loads(args.python_report.read_text(encoding="utf-8")) + ml_report = json.loads(args.matlab_report.read_text(encoding="utf-8")) + policy = yaml.safe_load(args.policy.read_text(encoding="utf-8")) or {} + + prev_idx: dict[tuple[str, str], dict[str, Any]] = {} + regression_env_compatible = True + if args.previous_python_report and args.previous_python_report.exists(): + prev = json.loads(args.previous_python_report.read_text(encoding="utf-8")) + regression_env_compatible = _is_regression_env_compatible( + py_report.get("environment", {}) or {}, + prev.get("environment", {}) or {}, + 
) + if regression_env_compatible: + prev_idx = _index_cases(prev.get("cases", [])) + else: + print( + "Skipping regression gating: benchmark environments are not comparable " + f"(current={py_report.get('environment', {})}, previous={prev.get('environment', {})})" + ) + + py_idx = _index_cases(py_report.get("cases", [])) + ml_idx = _index_cases(ml_report.get("cases", [])) + + default_ratio = float(policy.get("default_max_matlab_ratio", 5.0)) + critical = policy.get("critical_case_max_matlab_ratio", {}) or {} + regression_limit = float(policy.get("max_python_regression_ratio", 1.35)) + min_regression_delta_ms = float(policy.get("min_python_regression_delta_ms", 0.0)) + + rows: list[dict[str, Any]] = [] + missing_matlab = 0 + ratio_fail = 0 + regression_fail = 0 + + for key, py_case in sorted(py_idx.items()): + case, tier = key + ml_case = ml_idx.get(key) + py_runtime = float(py_case.get("median_runtime_ms", float("nan"))) + py_mem = float(py_case.get("median_peak_memory_mb", float("nan"))) + + if ml_case is None: + missing_matlab += 1 + rows.append( + { + "case": case, + "tier": tier, + "python_runtime_ms": py_runtime, + "matlab_runtime_ms": float("nan"), + "python_to_matlab_ratio": float("inf"), + "max_allowed_ratio": float(critical.get(case, default_ratio)), + "ratio_pass": False, + "regression_pass": True, + "python_peak_memory_mb": py_mem, + "status": "missing_matlab_baseline", + } + ) + continue + + ml_runtime = float(ml_case.get("median_runtime_ms", float("nan"))) + ratio = _safe_ratio(py_runtime, ml_runtime) + max_allowed = float(critical.get(case, default_ratio)) + ratio_pass = bool(ratio <= max_allowed) + if not ratio_pass: + ratio_fail += 1 + + prev_case = prev_idx.get(key) + regression_pass = True + prev_runtime = float("nan") + py_vs_prev_ratio = float("nan") + py_vs_prev_delta_ms = float("nan") + if prev_case is not None: + prev_runtime = float(prev_case.get("median_runtime_ms", float("nan"))) + py_vs_prev_ratio = _safe_ratio(py_runtime, prev_runtime) + 
py_vs_prev_delta_ms = py_runtime - prev_runtime + ratio_ok = bool(py_vs_prev_ratio <= regression_limit) + delta_ok = bool( + math.isnan(py_vs_prev_delta_ms) or py_vs_prev_delta_ms <= min_regression_delta_ms + ) + regression_pass = bool(ratio_ok or delta_ok) + if not regression_pass: + regression_fail += 1 + + rows.append( + { + "case": case, + "tier": tier, + "python_runtime_ms": py_runtime, + "matlab_runtime_ms": ml_runtime, + "python_to_matlab_ratio": ratio, + "max_allowed_ratio": max_allowed, + "ratio_pass": ratio_pass, + "python_peak_memory_mb": py_mem, + "previous_python_runtime_ms": prev_runtime, + "python_vs_previous_ratio": py_vs_prev_ratio, + "python_vs_previous_delta_ms": py_vs_prev_delta_ms, + "regression_pass": regression_pass, + "status": "ok" if ratio_pass and regression_pass else "needs_attention", + } + ) + + worst = sorted( + [r for r in rows if r["python_to_matlab_ratio"] != float("inf")], + key=lambda r: float(r["python_to_matlab_ratio"]), + reverse=True, + )[:5] + + summary = { + "schema_version": 1, + "generated_at_utc": datetime.now(UTC).isoformat(timespec="seconds").replace("+00:00", "Z"), + "policy": { + "default_max_matlab_ratio": default_ratio, + "critical_case_max_matlab_ratio": critical, + "max_python_regression_ratio": regression_limit, + "min_python_regression_delta_ms": min_regression_delta_ms, + "regression_env_compatible": regression_env_compatible, + }, + "python_report": str(args.python_report), + "matlab_report": str(args.matlab_report), + "previous_python_report": str(args.previous_python_report) if args.previous_python_report else "", + "counts": { + "total_case_tiers": len(rows), + "missing_matlab_baseline": missing_matlab, + "ratio_failures": ratio_fail, + "regression_failures": regression_fail, + }, + "top_python_vs_matlab_gaps": worst, + "rows": rows, + } + + args.report_out.parent.mkdir(parents=True, exist_ok=True) + args.report_out.write_text(json.dumps(summary, indent=2), encoding="utf-8") + + 
args.csv_out.parent.mkdir(parents=True, exist_ok=True) + with args.csv_out.open("w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter( + f, + fieldnames=[ + "case", + "tier", + "python_runtime_ms", + "matlab_runtime_ms", + "python_to_matlab_ratio", + "max_allowed_ratio", + "ratio_pass", + "python_peak_memory_mb", + "previous_python_runtime_ms", + "python_vs_previous_ratio", + "python_vs_previous_delta_ms", + "regression_pass", + "status", + ], + ) + writer.writeheader() + for row in rows: + writer.writerow(row) + + print(f"Wrote performance parity JSON: {args.report_out}") + print(f"Wrote performance parity CSV: {args.csv_out}") + print( + "Counts: " + f"total={len(rows)} missing_matlab={missing_matlab} " + f"ratio_fail={ratio_fail} regression_fail={regression_fail}" + ) + + if args.fail_on_matlab_ratio and ratio_fail > 0: + return 1 + if args.fail_on_regression and regression_fail > 0: + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/performance/run_python_benchmarks.py b/tools/performance/run_python_benchmarks.py new file mode 100755 index 00000000..f7b36f6c --- /dev/null +++ b/tools/performance/run_python_benchmarks.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +"""Run deterministic Python performance benchmarks for MATLAB parity tracking.""" + +from __future__ import annotations + +import argparse +import csv +import json +import os +import platform +import statistics +import subprocess +import time +import tracemalloc +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +import matplotlib +import numpy as np +import scipy + +try: + from nstat.performance_workloads import CASE_ORDER, TIER_ORDER, run_python_workload +except ModuleNotFoundError: # pragma: no cover - fallback for non-installed local runs + import sys + + sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "src")) + from nstat.performance_workloads import CASE_ORDER, TIER_ORDER, 
run_python_workload + + +def _git_sha(repo_root: Path) -> str: + try: + return ( + subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=repo_root, + text=True, + capture_output=True, + check=True, + ) + .stdout.strip() + ) + except Exception: + return "unknown" + + +def _collect_env() -> dict[str, Any]: + return { + "python": platform.python_version(), + "platform": platform.platform(), + "numpy": np.__version__, + "scipy": scipy.__version__, + "matplotlib": matplotlib.__version__, + "omp_num_threads": os.getenv("OMP_NUM_THREADS", ""), + "mkl_num_threads": os.getenv("MKL_NUM_THREADS", ""), + "openblas_num_threads": os.getenv("OPENBLAS_NUM_THREADS", ""), + "veclib_maximum_threads": os.getenv("VECLIB_MAXIMUM_THREADS", ""), + } + + +def _median(vals: list[float]) -> float: + return float(statistics.median(vals)) if vals else float("nan") + + +def _run_case(case: str, tier: str, repeats: int, warmup: int, seed: int) -> dict[str, Any]: + runtimes_ms: list[float] = [] + peak_mem_mb: list[float] = [] + summary: dict[str, float] = {} + + for rep in range(warmup + repeats): + run_seed = int(seed + rep) + if rep >= warmup: + tracemalloc.start() + t0 = time.perf_counter() + summary = run_python_workload(case=case, tier=tier, seed=run_seed) + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + if rep >= warmup: + _, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + runtimes_ms.append(float(elapsed_ms)) + peak_mem_mb.append(float(peak / (1024.0 * 1024.0))) + + return { + "case": case, + "tier": tier, + "repeats": int(repeats), + "warmup": int(warmup), + "median_runtime_ms": _median(runtimes_ms), + "mean_runtime_ms": float(np.mean(runtimes_ms)), + "std_runtime_ms": float(np.std(runtimes_ms)), + "median_peak_memory_mb": _median(peak_mem_mb), + "summary": summary, + "samples_runtime_ms": runtimes_ms, + "samples_peak_memory_mb": peak_mem_mb, + } + + +def _write_csv(rows: list[dict[str, Any]], out_csv: Path) -> None: + out_csv.parent.mkdir(parents=True, exist_ok=True) + 
fieldnames = [ + "case", + "tier", + "repeats", + "median_runtime_ms", + "mean_runtime_ms", + "std_runtime_ms", + "median_peak_memory_mb", + "summary", + ] + with out_csv.open("w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow( + { + "case": row["case"], + "tier": row["tier"], + "repeats": row["repeats"], + "median_runtime_ms": row["median_runtime_ms"], + "mean_runtime_ms": row["mean_runtime_ms"], + "std_runtime_ms": row["std_runtime_ms"], + "median_peak_memory_mb": row["median_peak_memory_mb"], + "summary": json.dumps(row["summary"], sort_keys=True), + } + ) + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--tiers", + type=str, + default="S,M,L", + help="Comma-separated tier list from {S,M,L}.", + ) + parser.add_argument("--repeats", type=int, default=7, help="Measured repeats per case/tier.") + parser.add_argument("--warmup", type=int, default=2, help="Warmup repeats per case/tier.") + parser.add_argument("--seed", type=int, default=20260303, help="Base deterministic seed.") + parser.add_argument( + "--out-json", + type=Path, + default=Path("output/performance/python_performance_report.json"), + help="Output JSON report path.", + ) + parser.add_argument( + "--out-csv", + type=Path, + default=Path("output/performance/python_performance_report.csv"), + help="Output CSV report path.", + ) + parser.add_argument( + "--repo-root", + type=Path, + default=Path(__file__).resolve().parents[2], + help="Repository root for metadata.", + ) + args = parser.parse_args() + + tiers = [t.strip().upper() for t in args.tiers.split(",") if t.strip()] + unknown = [t for t in tiers if t not in TIER_ORDER] + if unknown: + raise ValueError(f"Unsupported tiers: {unknown}") + + rows: list[dict[str, Any]] = [] + for case in CASE_ORDER: + for tier in tiers: + rows.append(_run_case(case=case, tier=tier, repeats=args.repeats, 
warmup=args.warmup, seed=args.seed)) + + report = { + "schema_version": 1, + "generated_at_utc": datetime.now(UTC).isoformat(timespec="seconds").replace("+00:00", "Z"), + "implementation": "python", + "repo_root": str(args.repo_root.resolve()), + "git_sha": _git_sha(args.repo_root.resolve()), + "tiers": tiers, + "cases": rows, + "environment": _collect_env(), + } + + args.out_json.parent.mkdir(parents=True, exist_ok=True) + args.out_json.write_text(json.dumps(report, indent=2), encoding="utf-8") + _write_csv(rows, args.out_csv) + + print(f"Wrote Python performance JSON: {args.out_json}") + print(f"Wrote Python performance CSV: {args.out_csv}") + print(f"Benchmarked case-tier pairs: {len(rows)}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/reports/build_image_parity_pdfs.py b/tools/reports/build_image_parity_pdfs.py new file mode 100755 index 00000000..5e9c8958 --- /dev/null +++ b/tools/reports/build_image_parity_pdfs.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +"""Build paired MATLAB/Python image-sequence PDFs for page-by-page parity checks.""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +from PIL import Image +from reportlab.lib.pagesizes import letter +from reportlab.pdfgen import canvas + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--report-json", type=Path, required=True, help="Validation summary JSON from generate_validation_pdf.py") + parser.add_argument( + "--python-out", + type=Path, + default=Path("output/pdf/image_mode_parity/python_pages.pdf"), + help="Output PDF containing Python images", + ) + parser.add_argument( + "--matlab-out", + type=Path, + default=Path("output/pdf/image_mode_parity/matlab_pages.pdf"), + help="Output PDF containing MATLAB images", + ) + parser.add_argument( + "--pairs-json", + type=Path, + default=Path("output/pdf/image_mode_parity/pairs.json"), + 
help="Output JSON containing selected per-topic image pairs", + ) + return parser.parse_args() + + +def _resolve_img(path_str: str) -> Path | None: + if not path_str: + return None + p = Path(path_str) + return p if p.exists() else None + + +def _select_pair(row: dict) -> tuple[Path | None, Path | None]: + py = _resolve_img(str(row.get("matched_python_image") or "")) + mat = _resolve_img(str(row.get("matched_matlab_image") or "")) + + if py is None: + py_list = row.get("python_images") or [] + if py_list: + py = _resolve_img(str(py_list[0])) + + if mat is None: + mat_list = row.get("matlab_reference_images") or [] + if mat_list: + mat = _resolve_img(str(mat_list[0])) + + return py, mat + + +def _draw_page(pdf: canvas.Canvas, *, topic: str, image_path: Path | None, label: str) -> None: + w, h = letter + pdf.setFont("Helvetica-Bold", 13) + pdf.drawString(36, h - 44, f"{label}: {topic}") + + if image_path is None: + pdf.setFont("Helvetica", 10) + pdf.drawString(36, h - 72, "Missing image") + pdf.showPage() + return + + pdf.setFont("Helvetica", 9) + pdf.drawString(36, h - 62, str(image_path)) + + with Image.open(image_path) as img: + iw, ih = img.size + max_w = w - 72 + max_h = h - 120 + scale = min(max_w / iw, max_h / ih) + draw_w = iw * scale + draw_h = ih * scale + x = (w - draw_w) / 2.0 + y = (h - 90 - draw_h) / 2.0 + pdf.drawImage(str(image_path), x, y, width=draw_w, height=draw_h, preserveAspectRatio=True, mask="auto") + pdf.showPage() + + +def main() -> int: + args = parse_args() + payload = json.loads(args.report_json.read_text(encoding="utf-8")) + rows = payload.get("notebooks", []) + + pairs: list[dict] = [] + for row in rows: + topic = str(row.get("topic", "")) + py, mat = _select_pair(row) + pairs.append( + { + "topic": topic, + "python_image": str(py) if py is not None else "", + "matlab_image": str(mat) if mat is not None else "", + } + ) + + args.python_out.parent.mkdir(parents=True, exist_ok=True) + args.matlab_out.parent.mkdir(parents=True, 
exist_ok=True) + args.pairs_json.parent.mkdir(parents=True, exist_ok=True) + + pdf_py = canvas.Canvas(str(args.python_out), pagesize=letter) + pdf_mat = canvas.Canvas(str(args.matlab_out), pagesize=letter) + + for pair in pairs: + topic = pair["topic"] + py = Path(pair["python_image"]) if pair["python_image"] else None + mat = Path(pair["matlab_image"]) if pair["matlab_image"] else None + _draw_page(pdf_py, topic=topic, image_path=py, label="Python") + _draw_page(pdf_mat, topic=topic, image_path=mat, label="MATLAB") + + pdf_py.save() + pdf_mat.save() + args.pairs_json.write_text(json.dumps({"schema_version": 1, "pairs": pairs}, indent=2) + "\n", encoding="utf-8") + + print(f"Wrote Python PDF: {args.python_out}") + print(f"Wrote MATLAB PDF: {args.matlab_out}") + print(f"Wrote pairs JSON: {args.pairs_json}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/reports/check_pdf_image_parity.py b/tools/reports/check_pdf_image_parity.py new file mode 100755 index 00000000..c86da146 --- /dev/null +++ b/tools/reports/check_pdf_image_parity.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 +"""Page-by-page image-mode parity gate for MATLAB-vs-Python validation PDFs.""" + +from __future__ import annotations + +import argparse +import json +import os +import platform +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np +from PIL import Image + +try: # Optional dependency; workflow installs it. + import fitz # type: ignore +except Exception as exc: # pragma: no cover + raise SystemExit(f"PyMuPDF (fitz) is required: {exc}") from exc + +try: # Optional fallback handled below. 
+ from skimage.metrics import structural_similarity as skimage_ssim # type: ignore +except Exception: # pragma: no cover + skimage_ssim = None + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--python-pdf", type=Path, required=True, help="Rendered Python validation PDF") + parser.add_argument("--matlab-pdf", type=Path, required=True, help="Rendered MATLAB reference PDF") + parser.add_argument( + "--out-dir", + type=Path, + default=Path("output/pdf/image_mode_parity"), + help="Directory for parity artifacts", + ) + parser.add_argument("--dpi", type=int, default=150, help="Rasterization DPI") + parser.add_argument("--ssim-threshold", type=float, default=0.90, help="Minimum SSIM to pass") + parser.add_argument( + "--nrmse-threshold", + type=float, + default=0.20, + help="Maximum normalized RMSE when SSIM backend is unavailable", + ) + parser.add_argument( + "--max-failing-pages", + type=int, + default=0, + help="Allow up to this many failing pages before non-zero exit", + ) + parser.add_argument( + "--ignore-pages", + type=str, + default="", + help="Comma-separated 1-based page numbers to ignore, e.g. 
'1,2,10'", + ) + parser.add_argument( + "--summary-json", + type=Path, + default=None, + help="Optional summary JSON path (defaults to /summary.json)", + ) + return parser.parse_args() + + +@dataclass +class PageParity: + page: int + ignored: bool + metric: str + score: float + passed: bool + python_shape: tuple[int, int] + matlab_shape: tuple[int, int] + diff_image: str | None + + +def _parse_ignore_pages(raw: str) -> set[int]: + out: set[int] = set() + for token in raw.split(","): + token = token.strip() + if not token: + continue + out.add(int(token)) + return out + + +def _render_pdf_grayscale(pdf_path: Path, dpi: int) -> list[np.ndarray]: + if dpi <= 0: + raise ValueError("dpi must be positive") + if not pdf_path.exists(): + raise FileNotFoundError(f"PDF not found: {pdf_path}") + + scale = float(dpi) / 72.0 + matrix = fitz.Matrix(scale, scale) + doc = fitz.open(str(pdf_path)) + try: + pages: list[np.ndarray] = [] + for page in doc: + pix = page.get_pixmap(matrix=matrix, alpha=False) + arr = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n) + if pix.n >= 3: + rgb = arr[:, :, :3].astype(np.float32) + gray = (0.299 * rgb[:, :, 0] + 0.587 * rgb[:, :, 1] + 0.114 * rgb[:, :, 2]) / 255.0 + else: + gray = arr[:, :, 0].astype(np.float32) / 255.0 + pages.append(np.clip(gray, 0.0, 1.0)) + return pages + finally: + doc.close() + + +def _resize_to_match(src: np.ndarray, shape: tuple[int, int]) -> np.ndarray: + if src.shape == shape: + return src + img = Image.fromarray(np.clip(src * 255.0, 0.0, 255.0).astype(np.uint8), mode="L") + resized = img.resize((shape[1], shape[0]), resample=Image.Resampling.BILINEAR) + return np.asarray(resized, dtype=np.float32) / 255.0 + + +def _nrmse(a: np.ndarray, b: np.ndarray) -> float: + rmse = float(np.sqrt(np.mean((a - b) ** 2))) + denom = max(float(np.max([a.max() - a.min(), b.max() - b.min()])), 1e-12) + return rmse / denom + + +def _save_diff_image(py: np.ndarray, mat: np.ndarray, out_path: Path) -> None: 
+ py_u8 = np.clip(py * 255.0, 0.0, 255.0).astype(np.uint8) + mat_u8 = np.clip(mat * 255.0, 0.0, 255.0).astype(np.uint8) + diff = np.abs(py_u8.astype(np.int16) - mat_u8.astype(np.int16)).astype(np.uint8) + + py_rgb = np.stack([py_u8, py_u8, py_u8], axis=2) + mat_rgb = np.stack([mat_u8, mat_u8, mat_u8], axis=2) + diff_rgb = np.stack([diff, np.zeros_like(diff), np.zeros_like(diff)], axis=2) + panel = np.concatenate([py_rgb, mat_rgb, diff_rgb], axis=1) + out_path.parent.mkdir(parents=True, exist_ok=True) + Image.fromarray(panel, mode="RGB").save(out_path) + + +def _environment_metadata() -> dict[str, Any]: + metadata: dict[str, Any] = { + "python": sys.version.split()[0], + "platform": platform.platform(), + "numpy": np.__version__, + "omp_num_threads": os.environ.get("OMP_NUM_THREADS", ""), + "mkl_num_threads": os.environ.get("MKL_NUM_THREADS", ""), + "openblas_num_threads": os.environ.get("OPENBLAS_NUM_THREADS", ""), + "fitz": getattr(fitz, "__doc__", "").split()[1] if getattr(fitz, "__doc__", "") else "unknown", + } + try: + import scipy # type: ignore + + metadata["scipy"] = scipy.__version__ + except Exception: # pragma: no cover + metadata["scipy"] = "unavailable" + metadata["ssim_backend"] = "skimage" if skimage_ssim is not None else "nrmse" + return metadata + + +def main() -> int: + args = parse_args() + out_dir = args.out_dir.resolve() + out_dir.mkdir(parents=True, exist_ok=True) + summary_path = (args.summary_json.resolve() if args.summary_json else out_dir / "summary.json") + + ignore_pages = _parse_ignore_pages(args.ignore_pages) + py_pages = _render_pdf_grayscale(args.python_pdf.resolve(), args.dpi) + matlab_pages = _render_pdf_grayscale(args.matlab_pdf.resolve(), args.dpi) + + compare_pages = min(len(py_pages), len(matlab_pages)) + rows: list[PageParity] = [] + diff_dir = out_dir / "diff" + + for idx in range(compare_pages): + page_num = idx + 1 + py = py_pages[idx] + mat = _resize_to_match(matlab_pages[idx], py.shape) + ignored = page_num in 
ignore_pages + + if skimage_ssim is not None: + metric = "ssim" + score = float(skimage_ssim(py, mat, data_range=1.0)) + passed = (score >= args.ssim_threshold) or ignored + else: + metric = "nrmse" + score = float(_nrmse(py, mat)) + passed = (score <= args.nrmse_threshold) or ignored + + diff_path: Path | None = None + if not passed and not ignored: + diff_path = diff_dir / f"page_{page_num:03d}.png" + _save_diff_image(py, mat, diff_path) + + rows.append( + PageParity( + page=page_num, + ignored=ignored, + metric=metric, + score=score, + passed=passed, + python_shape=tuple(int(v) for v in py.shape), + matlab_shape=tuple(int(v) for v in mat.shape), + diff_image=(str(diff_path) if diff_path is not None else None), + ) + ) + + failed = [r for r in rows if not r.passed and not r.ignored] + count_mismatch = len(py_pages) != len(matlab_pages) + page_count_failure = 1 if count_mismatch else 0 + + if skimage_ssim is not None: + worst = sorted(rows, key=lambda r: r.score)[: min(10, len(rows))] + else: + worst = sorted(rows, key=lambda r: r.score, reverse=True)[: min(10, len(rows))] + + summary = { + "schema_version": 1, + "python_pdf": str(args.python_pdf.resolve()), + "matlab_pdf": str(args.matlab_pdf.resolve()), + "dpi": int(args.dpi), + "thresholds": { + "ssim_threshold": float(args.ssim_threshold), + "nrmse_threshold": float(args.nrmse_threshold), + "max_failing_pages": int(args.max_failing_pages), + }, + "environment": _environment_metadata(), + "page_counts": { + "python": len(py_pages), + "matlab": len(matlab_pages), + "compared": compare_pages, + "mismatch": bool(count_mismatch), + }, + "failed_page_count": len(failed), + "worst_pages": [ + {"page": r.page, "metric": r.metric, "score": r.score, "passed": r.passed, "ignored": r.ignored} + for r in worst + ], + "pages": [ + { + "page": r.page, + "ignored": r.ignored, + "metric": r.metric, + "score": r.score, + "passed": r.passed, + "python_shape": list(r.python_shape), + "matlab_shape": list(r.matlab_shape), + 
"diff_image": r.diff_image, + } + for r in rows + ], + } + summary_path.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8") + + print(f"Wrote image-mode parity summary: {summary_path}") + print(f"Compared pages: {compare_pages} (python={len(py_pages)} matlab={len(matlab_pages)})") + print(f"Failed pages: {len(failed)}") + if count_mismatch: + print("Page-count mismatch detected between Python and MATLAB PDFs") + + if page_count_failure > 0: + return 1 + return 0 if len(failed) <= args.max_failing_pages else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/reports/generate_validation_pdf.py b/tools/reports/generate_validation_pdf.py index ecef9f50..ab572db6 100755 --- a/tools/reports/generate_validation_pdf.py +++ b/tools/reports/generate_validation_pdf.py @@ -5,6 +5,7 @@ import argparse import base64 +import csv import functools import hashlib import json @@ -209,6 +210,18 @@ def parse_args() -> argparse.Namespace: "cross_topic_reused_hashes / total_unique_hashes must be <= this value." 
), ) + parser.add_argument( + "--summary-json", + type=Path, + default=None, + help="Machine-readable JSON summary output path (defaults beside the PDF).", + ) + parser.add_argument( + "--summary-csv", + type=Path, + default=None, + help="Machine-readable CSV summary output path (defaults beside the PDF).", + ) return parser.parse_args() @@ -726,6 +739,201 @@ def _uniqueness_violations( return violations, stats +def _topic_class_hint(topic: str) -> str: + overrides = { + "AnalysisExamples": "Analysis", + "AnalysisExamples2": "Analysis", + "ConfigCollExamples": "ConfigCollection", + "CovCollExamples": "CovariateCollection", + "CovariateExamples": "Covariate", + "DecodingExample": "DecodingAlgorithms", + "DecodingExampleWithHist": "DecodingAlgorithms", + "EventsExamples": "Events", + "FitResSummaryExamples": "FitSummary", + "FitResultExamples": "FitResult", + "FitResultReference": "FitResult", + "HistoryExamples": "HistoryBasis", + "SignalObjExamples": "Signal", + "StimulusDecode2D": "DecodingAlgorithms", + "TrialConfigExamples": "TrialConfig", + "TrialExamples": "Trial", + "nSpikeTrainExamples": "SpikeTrain", + "nstCollExamples": "SpikeTrainCollection", + } + if topic in overrides: + return overrides[topic] + if topic.endswith("Examples"): + return topic[: -len("Examples")] or topic + return "Workflow" + + +def _as_rel(path: Path | None, repo_root: Path) -> str: + if path is None: + return "" + try: + return str(path.resolve().relative_to(repo_root.resolve())) + except Exception: + return str(path) + + +def write_machine_readable_summaries( + *, + report_path: Path, + repo_root: Path, + reports: list[NotebookReport], + command_results: list[CommandResult], + matlab_help_root: Path | None, + notebook_group: str, + parity_mode: str, + parity_threshold: float, + uniqueness_stats: dict[str, float | int], + uniqueness_violations: list[str], + summary_json_path: Path, + summary_csv_path: Path, +) -> tuple[Path, Path]: + summary_json_path.parent.mkdir(parents=True, 
exist_ok=True) + summary_csv_path.parent.mkdir(parents=True, exist_ok=True) + + notebook_rows: list[dict[str, object]] = [] + for report in reports: + metrics = dict(report.parity_metrics or {}) + diff_artifacts: list[str] = [] + if report.matched_python_image is not None: + diff_artifacts.append(_as_rel(report.matched_python_image, repo_root)) + if report.matched_matlab_image is not None: + diff_artifacts.append(_as_rel(report.matched_matlab_image, repo_root)) + notebook_rows.append( + { + "topic": report.topic, + "class_hint": _topic_class_hint(report.topic), + "notebook": _as_rel(report.file, repo_root), + "run_group": report.run_group, + "executed": bool(report.executed), + "duration_s": float(report.duration_s), + "execution_pass": bool(report.executed and not bool(report.error)), + "parity_pass": report.parity_pass, + "alignment_status": report.alignment_status, + "numeric_drift_pass": metrics.get("numeric_drift_pass"), + "numeric_drift_failed_metric_count": metrics.get("numeric_drift_failed_metric_count"), + "similarity_score": report.similarity_score, + "image_count": int(report.image_count), + "unique_image_count": int(report.unique_image_count), + "duplicate_image_count": int(report.duplicate_image_count), + "error": report.error, + "matched_python_image": _as_rel(report.matched_python_image, repo_root), + "matched_matlab_image": _as_rel(report.matched_matlab_image, repo_root), + "python_images": [_as_rel(path, repo_root) for path in report.image_paths], + "matlab_reference_images": [_as_rel(path, repo_root) for path in report.matlab_ref_images], + "diff_artifacts": diff_artifacts, + "parity_metrics": metrics, + } + ) + + command_rows = [ + { + "name": row.name, + "command": " ".join(row.command), + "passed": row.passed, + "returncode": int(row.returncode), + "duration_s": float(row.duration_s), + "stdout_tail": row.stdout_tail, + } + for row in command_results + ] + + executed = sum(1 for row in reports if row.executed) + exec_failures = len(reports) - 
executed + parity_checked = sum(1 for row in reports if row.parity_pass is not None) + parity_failures = sum(1 for row in reports if row.parity_pass is False) + numeric_checked = sum( + 1 + for row in reports + if row.parity_metrics is not None and "numeric_drift_pass" in row.parity_metrics + ) + numeric_failures = sum( + 1 + for row in reports + if row.parity_metrics is not None and row.parity_metrics.get("numeric_drift_pass") is False + ) + + payload = { + "schema_version": 1, + "generated_at_utc": datetime.utcnow().isoformat(timespec="seconds") + "Z", + "repo_root": str(repo_root), + "report_pdf": str(report_path), + "matlab_help_root": str(matlab_help_root) if matlab_help_root is not None else "", + "notebook_group": notebook_group, + "parity_mode": parity_mode, + "parity_threshold": float(parity_threshold), + "aggregate": { + "total_notebooks": len(reports), + "executed": executed, + "execution_failures": exec_failures, + "parity_checked": parity_checked, + "parity_failures": parity_failures, + "numeric_drift_checked": numeric_checked, + "numeric_drift_failures": numeric_failures, + "command_checks_total": len(command_results), + "command_checks_failed": sum(1 for row in command_results if not row.passed), + "uniqueness_violations": len(uniqueness_violations), + "uniqueness": uniqueness_stats, + }, + "command_checks": command_rows, + "notebooks": notebook_rows, + } + summary_json_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + + csv_columns = [ + "topic", + "class_hint", + "notebook", + "run_group", + "executed", + "execution_pass", + "duration_s", + "parity_pass", + "alignment_status", + "numeric_drift_pass", + "numeric_drift_failed_metric_count", + "similarity_score", + "image_count", + "unique_image_count", + "duplicate_image_count", + "matched_python_image", + "matched_matlab_image", + "diff_artifacts", + "error", + ] + with summary_csv_path.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, 
fieldnames=csv_columns) + writer.writeheader() + for row in notebook_rows: + writer.writerow( + { + "topic": row["topic"], + "class_hint": row["class_hint"], + "notebook": row["notebook"], + "run_group": row["run_group"], + "executed": row["executed"], + "execution_pass": row["execution_pass"], + "duration_s": row["duration_s"], + "parity_pass": row["parity_pass"], + "alignment_status": row["alignment_status"], + "numeric_drift_pass": row["numeric_drift_pass"], + "numeric_drift_failed_metric_count": row["numeric_drift_failed_metric_count"], + "similarity_score": row["similarity_score"], + "image_count": row["image_count"], + "unique_image_count": row["unique_image_count"], + "duplicate_image_count": row["duplicate_image_count"], + "matched_python_image": row["matched_python_image"], + "matched_matlab_image": row["matched_matlab_image"], + "diff_artifacts": ";".join(row["diff_artifacts"]), + "error": row["error"], + } + ) + return summary_json_path, summary_csv_path + + def _draw_wrapped_lines( pdf: canvas.Canvas, x: float, @@ -1292,8 +1500,26 @@ def main() -> int: min_unique_images_per_topic=args.min_unique_images_per_topic, max_cross_topic_reuse_ratio=args.max_cross_topic_reuse_ratio, ) + summary_json_path = args.summary_json or output_pdf.with_suffix(".json") + summary_csv_path = args.summary_csv or output_pdf.with_suffix(".csv") + summary_json_path, summary_csv_path = write_machine_readable_summaries( + report_path=report_path, + repo_root=args.repo_root, + reports=reports, + command_results=command_results, + matlab_help_root=matlab_help_root, + notebook_group=args.notebook_group, + parity_mode=args.parity_mode, + parity_threshold=args.parity_threshold, + uniqueness_stats=uniqueness_stats, + uniqueness_violations=uniqueness_violations, + summary_json_path=summary_json_path, + summary_csv_path=summary_csv_path, + ) print(f"Generated PDF report: {report_path}") + print(f"Machine-readable summary (JSON): {summary_json_path}") + print(f"Machine-readable summary 
(CSV): {summary_csv_path}") print(f"MATLAB help root: {matlab_help_root}") print( f"Notebook results: total={len(reports)} executed={executed} exec_failures={exec_failures} "