Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,18 @@ Source pages:
- [docs/ClassDefinitions.md](docs/ClassDefinitions.md)
- [docs/DocumentationSetup.md](docs/DocumentationSetup.md)

## Plot Style

```python
from nstat.plot_style import set_plot_style

# Modern readability-focused plots (default)
set_plot_style('modern')

# Legacy visual style for strict reproduction
set_plot_style('legacy')
```

## Paper-Aligned Toolbox Map

To keep terminology and workflows consistent with the 2012 toolbox paper,
Expand Down Expand Up @@ -201,6 +213,36 @@ pytest -q
sphinx-build -b html docs docs/_build
```

## Code audit (2026-03-11)

The Python port was verified against the MATLAB reference through a comprehensive
5-phase audit covering all 16 classes and 484 methods. **466 methods found in
Python, 6 nominal (MATLAB-infrastructure) gaps.** Full class-level and behavioral
parity verified.

**Python bugs fixed during the port:**

- `SignalObj.std()` used `ddof=0`; MATLAB uses `ddof=1` (N-1 normalization)
- `CovariateCollection.isCovPresent()` off-by-one in boundary check
- `SpikeTrainCollection.psthGLM()` was a stub; now wired to the full GLM path
- `SpikeTrainCollection.getNSTnames()` / `getUniqueNSTnames()` ignored the
`selectorArray` filter parameter
- `nspikeTrain.getNST()` missing resample check on retrieval

**MATLAB bugs discovered (13 total, filed as GitHub issues):**

- `FitResult.m` — KS test used `sampleRate` as bin width instead of
`1/sampleRate`, invalidating goodness-of-fit for any sampleRate != 1
- `CIF.m` — `symvar()` reordered variables alphabetically, causing silent
argument mismatch for non-alphabetical variable names
- `SignalObj.m` — `findPeaks('minima')` returned maxima; `findGlobalPeak('minima')`
crashed; handle aliasing mutated input signals in arithmetic
- `DecodingAlgorithms.m` — `isa(condNum,'nan')` always false; `ExplambdaDeltaCubed`
used `.^2` instead of `.^3`
- `Analysis.m` — Granger causality mask zeroed all columns instead of column `i`

See [parity/parity_report.md](parity/parity_report.md) for the full audit.

## License

nSTAT is protected by the GPL v2 Open Source License.
Expand Down
19 changes: 3 additions & 16 deletions notebooks/AnalysisExamples2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"\n",
"GLM_DATA = load_glm_data_for_notebook()\n",
"OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
"__tracker = FigureTracker(topic=\"AnalysisExamples2\", output_root=OUTPUT_ROOT, expected_count=5)\n",
"__tracker = FigureTracker(topic=\"AnalysisExamples2\", output_root=OUTPUT_ROOT, expected_count=4)\n",
"\n",
"def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
" fig = __tracker.new_figure(matlab_line)\n",
Expand Down Expand Up @@ -187,20 +187,7 @@
"id": "64cdac30",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 8: Toolbox vs. Standard GLM comparison\n",
"standard_fit = fit_poisson_glm(np.column_stack([np.ones_like(xN), xN, yN, xN**2, yN**2, xN * yN]), spikes_binned, include_intercept=False)\n",
"coeff_diff = np.asarray(standard_fit.coefficients - fitResults.getCoeffs(2), dtype=float)\n",
"fig = _prepare_figure(\"b-fitResults.b{2}\", figsize=(7.0, 4.5))\n",
"ax = fig.subplots(1, 1)\n",
"labels = [\"mu\", \"x\", \"y\", \"x^2\", \"y^2\", \"x*y\"]\n",
"ax.bar(np.arange(coeff_diff.size), coeff_diff, color=\"tab:blue\")\n",
"ax.axhline(0.0, color=\"0.3\", linestyle=\"--\", linewidth=1.0)\n",
"ax.set_xticks(np.arange(coeff_diff.size), labels, rotation=20)\n",
"ax.set_ylabel(\"standard minus toolbox\")\n",
"ax.set_title(\"Coefficient agreement between workflows\")\n",
"print({\"quadratic_coeff_diff_max_abs\": round(float(np.max(np.abs(coeff_diff))), 6)})"
]
"source": "# SECTION 8: Toolbox vs. Standard GLM comparison\n# Compare the nSTAT fit with a standalone glmfit using the same Quadratic covariates\n# MATLAB: [b,dev,stats] = glmfit([xN yN xN.^2 yN.^2 xN.*yN], spikes_binned, 'poisson');\n# b - fitResults.b{2} % should be close to zero\nX_quad = np.column_stack([xN, yN, xN**2, yN**2, xN * yN])\nglm_result = fit_poisson_glm(X_quad, spikes_binned)\nb = np.concatenate([[glm_result.intercept], glm_result.coefficients])\nb_diff = b - fitResults.getCoeffs(2)\nprint(\"b - fitResults.b{2} =\", b_diff)"
},
{
"cell_type": "code",
Expand Down Expand Up @@ -238,4 +225,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
12 changes: 6 additions & 6 deletions notebooks/ExplicitStimulusWhiskerData.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"np.random.seed(0)\n",
"DATA_DIR = notebook_example_data_dir(allow_synthetic=True)\n",
"OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
"__tracker = FigureTracker(topic='ExplicitStimulusWhiskerData', output_root=OUTPUT_ROOT, expected_count=9)\n",
"__tracker = FigureTracker(topic='ExplicitStimulusWhiskerData', output_root=OUTPUT_ROOT, expected_count=10)\n",
"\n",
"def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
" fig = __tracker.new_figure(matlab_line)\n",
Expand All @@ -61,7 +61,7 @@
" ideal_arr = np.asarray(ideal, dtype=float)\n",
" empirical_arr = np.asarray(empirical, dtype=float)\n",
" ci_arr = np.asarray(ci, dtype=float)\n",
" ax.plot(ideal_arr, ideal_arr, color=\"0.2\", linewidth=1.0, linestyle=\"--\", label=\"45° line\")\n",
" ax.plot(ideal_arr, ideal_arr, color=\"0.2\", linewidth=1.0, linestyle=\"--\", label=\"45\u00b0 line\")\n",
" ax.plot(ideal_arr, empirical_arr, color=color, linewidth=1.5, label=label)\n",
" ax.fill_between(\n",
" ideal_arr,\n",
Expand Down Expand Up @@ -209,10 +209,10 @@
"axs[0].set_title(\"History-window scan\")\n",
"axs[1].plot(history_windows, payload[\"delta_aic\"], marker=\"o\", color=\"tab:green\", linewidth=1.2)\n",
"axs[1].scatter([history_windows[best_history_idx]], [payload[\"delta_aic\"][best_history_idx]], color=\"tab:red\", zorder=3)\n",
"axs[1].set_ylabel(\"ΔAIC\")\n",
"axs[1].set_ylabel(\"\u0394AIC\")\n",
"axs[2].plot(history_windows, payload[\"delta_bic\"], marker=\"o\", color=\"tab:brown\", linewidth=1.2)\n",
"axs[2].scatter([history_windows[best_history_idx]], [payload[\"delta_bic\"][best_history_idx]], color=\"tab:red\", zorder=3)\n",
"axs[2].set_ylabel(\"ΔBIC\")\n",
"axs[2].set_ylabel(\"\u0394BIC\")\n",
"axs[2].set_xlabel(\"history window count\")\n",
"\n",
"fig = _prepare_figure(\"plot(x,dBIC,'.')\", figsize=(8.0, 4.5))\n",
Expand All @@ -221,7 +221,7 @@
"ax.axvline(history_windows[best_history_idx], color=\"tab:red\", linestyle=\"--\", linewidth=1.0)\n",
"ax.set_title(\"BIC improvement across history-window choices\")\n",
"ax.set_xlabel(\"history window count\")\n",
"ax.set_ylabel(\"ΔBIC relative to first history model\")"
"ax.set_ylabel(\"\u0394BIC relative to first history model\")"
]
},
{
Expand Down Expand Up @@ -287,4 +287,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
132 changes: 27 additions & 105 deletions notebooks/HippocampalPlaceCellExample.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"np.random.seed(0)\n",
"DATA_DIR = notebook_example_data_dir(allow_synthetic=True)\n",
"OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
"__tracker = FigureTracker(topic='HippocampalPlaceCellExample', output_root=OUTPUT_ROOT, expected_count=11)\n",
"__tracker = FigureTracker(topic='HippocampalPlaceCellExample', output_root=OUTPUT_ROOT, expected_count=10)\n",
"\n",
"def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
" fig = __tracker.new_figure(matlab_line)\n",
Expand Down Expand Up @@ -127,23 +127,20 @@
"outputs": [],
"source": [
"# SECTION 2: Analyze All Cells\n",
"fig = _prepare_figure(\"Summary.plotSummary\", figsize=(7.5, 4.5))\n",
"ax = fig.subplots(1, 1)\n",
"fig = _prepare_figure(\"Summary.plotSummary\", figsize=(12.0, 4.5))\n",
"axs = fig.subplots(1, 2)\n",
"animal1 = payload[\"animal1\"]\n",
"labels = [f\"Cell {int(idx) + 1}\" for idx in np.asarray(animal1[\"selected_indices\"], dtype=int)]\n",
"ax.bar(np.arange(len(labels)), animal1[\"delta_aic\"], color=\"tab:purple\")\n",
"ax.axhline(0.0, color=\"0.2\", linewidth=1.0)\n",
"ax.set_xticks(np.arange(len(labels)), labels, rotation=20)\n",
"ax.set_ylabel(\"Gaussian - Zernike AIC\")\n",
"ax.set_title(\"Animal 1 model comparison\")\n",
"\n",
"fig = _prepare_figure(\"Summary.plotSummary\", figsize=(7.5, 4.5))\n",
"ax = fig.subplots(1, 1)\n",
"ax.bar(np.arange(len(labels)), animal1[\"delta_bic\"], color=\"tab:green\")\n",
"ax.axhline(0.0, color=\"0.2\", linewidth=1.0)\n",
"ax.set_xticks(np.arange(len(labels)), labels, rotation=20)\n",
"ax.set_ylabel(\"Gaussian - Zernike BIC\")\n",
"ax.set_title(\"Animal 1 model comparison\")"
"axs[0].bar(np.arange(len(labels)), animal1[\"delta_aic\"], color=\"tab:purple\")\n",
"axs[0].axhline(0.0, color=\"0.2\", linewidth=1.0)\n",
"axs[0].set_xticks(np.arange(len(labels)), labels, rotation=20)\n",
"axs[0].set_ylabel(\"Gaussian - Zernike AIC\")\n",
"axs[0].set_title(\"Animal 1 AIC\")\n",
"axs[1].bar(np.arange(len(labels)), animal1[\"delta_bic\"], color=\"tab:green\")\n",
"axs[1].axhline(0.0, color=\"0.2\", linewidth=1.0)\n",
"axs[1].set_xticks(np.arange(len(labels)), labels, rotation=20)\n",
"axs[1].set_ylabel(\"Gaussian - Zernike BIC\")\n",
"axs[1].set_title(\"Animal 1 BIC\")\n"
]
},
{
Expand All @@ -154,23 +151,20 @@
"outputs": [],
"source": [
"# SECTION 3: View Summary Statistics\n",
"fig = _prepare_figure(\"Summary.plotSummary\", figsize=(7.5, 4.5))\n",
"ax = fig.subplots(1, 1)\n",
"fig = _prepare_figure(\"Summary.plotSummary\", figsize=(12.0, 4.5))\n",
"axs = fig.subplots(1, 2)\n",
"animal2 = payload[\"animal2\"]\n",
"labels = [f\"Cell {int(idx) + 1}\" for idx in np.asarray(animal2[\"selected_indices\"], dtype=int)]\n",
"ax.bar(np.arange(len(labels)), animal2[\"delta_aic\"], color=\"tab:purple\")\n",
"ax.axhline(0.0, color=\"0.2\", linewidth=1.0)\n",
"ax.set_xticks(np.arange(len(labels)), labels, rotation=20)\n",
"ax.set_ylabel(\"Gaussian - Zernike AIC\")\n",
"ax.set_title(\"Animal 2 model comparison\")\n",
"\n",
"fig = _prepare_figure(\"Summary.plotSummary\", figsize=(7.5, 4.5))\n",
"ax = fig.subplots(1, 1)\n",
"ax.bar(np.arange(len(labels)), animal2[\"delta_bic\"], color=\"tab:green\")\n",
"ax.axhline(0.0, color=\"0.2\", linewidth=1.0)\n",
"ax.set_xticks(np.arange(len(labels)), labels, rotation=20)\n",
"ax.set_ylabel(\"Gaussian - Zernike BIC\")\n",
"ax.set_title(\"Animal 2 model comparison\")"
"axs[0].bar(np.arange(len(labels)), animal2[\"delta_aic\"], color=\"tab:purple\")\n",
"axs[0].axhline(0.0, color=\"0.2\", linewidth=1.0)\n",
"axs[0].set_xticks(np.arange(len(labels)), labels, rotation=20)\n",
"axs[0].set_ylabel(\"Gaussian - Zernike AIC\")\n",
"axs[0].set_title(\"Animal 2 AIC\")\n",
"axs[1].bar(np.arange(len(labels)), animal2[\"delta_bic\"], color=\"tab:green\")\n",
"axs[1].axhline(0.0, color=\"0.2\", linewidth=1.0)\n",
"axs[1].set_xticks(np.arange(len(labels)), labels, rotation=20)\n",
"axs[1].set_ylabel(\"Gaussian - Zernike BIC\")\n",
"axs[1].set_title(\"Animal 2 BIC\")\n"
]
},
{
Expand All @@ -179,79 +173,7 @@
"id": "26aafec5",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 4: Visualize the results\n",
"fig = _prepare_figure(\"h4=figure(4)\", figsize=(8.5, 8.0))\n",
"_plot_field_grid(fig, \"animal1\", \"gaussian_fields\", \"Gaussian place fields - Animal 1\")\n",
"\n",
"fig = _prepare_figure(\"h5=figure(5)\", figsize=(8.5, 8.0))\n",
"_plot_field_grid(fig, \"animal1\", \"zernike_fields\", \"Zernike place fields - Animal 1\")\n",
"\n",
"fig = _prepare_figure(\"h6=figure(6)\", figsize=(8.5, 8.0))\n",
"_plot_field_grid(fig, \"animal2\", \"gaussian_fields\", \"Gaussian place fields - Animal 2\")\n",
"\n",
"fig = _prepare_figure(\"h7=figure(7)\", figsize=(8.5, 8.0))\n",
"_plot_field_grid(fig, \"animal2\", \"zernike_fields\", \"Zernike place fields - Animal 2\")\n",
"\n",
"fig = _prepare_figure(\"figure(8)\", figsize=(7.0, 5.5))\n",
"ax = fig.subplots(1, 1)\n",
"ax.imshow(\n",
" mesh[\"gaussian_field\"],\n",
" origin=\"lower\",\n",
" extent=[float(np.min(mesh[\"grid_x\"])), float(np.max(mesh[\"grid_x\"])), float(np.min(mesh[\"grid_y\"])), float(np.max(mesh[\"grid_y\"]))],\n",
" aspect=\"equal\",\n",
" cmap=\"viridis\",\n",
")\n",
"ax.plot(mesh[\"x_pos\"], mesh[\"y_pos\"], color=\"white\", linewidth=0.5, alpha=0.35)\n",
"ax.scatter(spike_x, spike_y, s=8, color=\"tab:red\", alpha=0.7)\n",
"ax.set_title(f\"Gaussian receptive field - Cell {int(mesh['cell_index']) + 1}\")\n",
"ax.set_xlabel(\"x\")\n",
"ax.set_ylabel(\"y\")\n",
"\n",
"fig = _prepare_figure(\"figure(9)\", figsize=(7.0, 5.5))\n",
"ax = fig.subplots(1, 1)\n",
"ax.imshow(\n",
" mesh[\"zernike_field\"],\n",
" origin=\"lower\",\n",
" extent=[float(np.min(mesh[\"grid_x\"])), float(np.max(mesh[\"grid_x\"])), float(np.min(mesh[\"grid_y\"])), float(np.max(mesh[\"grid_y\"]))],\n",
" aspect=\"equal\",\n",
" cmap=\"viridis\",\n",
")\n",
"ax.plot(mesh[\"x_pos\"], mesh[\"y_pos\"], color=\"white\", linewidth=0.5, alpha=0.35)\n",
"ax.scatter(spike_x, spike_y, s=8, color=\"tab:red\", alpha=0.7)\n",
"ax.set_title(f\"Zernike receptive field - Cell {int(mesh['cell_index']) + 1}\")\n",
"ax.set_xlabel(\"x\")\n",
"ax.set_ylabel(\"y\")\n",
"\n",
"fig = _prepare_figure(\"figure(10)\", figsize=(9.0, 4.5))\n",
"axs = fig.subplots(1, 2)\n",
"axs[0].hist(np.concatenate([payload[\"animal1\"][\"delta_aic\"], payload[\"animal2\"][\"delta_aic\"]]), bins=8, color=\"tab:purple\", alpha=0.8)\n",
"axs[0].axvline(0.0, color=\"0.2\", linewidth=1.0)\n",
"axs[0].set_title(\"Distribution of ΔAIC\")\n",
"axs[1].hist(np.concatenate([payload[\"animal1\"][\"delta_bic\"], payload[\"animal2\"][\"delta_bic\"]]), bins=8, color=\"tab:green\", alpha=0.8)\n",
"axs[1].axvline(0.0, color=\"0.2\", linewidth=1.0)\n",
"axs[1].set_title(\"Distribution of ΔBIC\")\n",
"\n",
"fig = _prepare_figure(\"figure(11)\", figsize=(6.5, 4.5))\n",
"ax = fig.subplots(1, 1)\n",
"ax.axis(\"off\")\n",
"ax.text(\n",
" 0.0,\n",
" 0.95,\n",
" \"\\n\".join(\n",
" [\n",
" f\"Cells analyzed: {int(summary['num_cells_fit'])}\",\n",
" f\"Mean Gaussian-Zernike ΔAIC: {summary['mean_delta_aic_gaussian_minus_zernike']:.2f}\",\n",
" f\"Mean Gaussian-Zernike ΔBIC: {summary['mean_delta_bic_gaussian_minus_zernike']:.2f}\",\n",
" \"Negative values favor the Zernike model.\",\n",
" ]\n",
" ),\n",
" va=\"top\",\n",
" family=\"monospace\",\n",
" fontsize=10,\n",
")\n",
"__tracker.finalize()"
]
"source": "# SECTION 4: Visualize the results\nfig = _prepare_figure(\"h4=figure(4)\", figsize=(8.5, 8.0))\n_plot_field_grid(fig, \"animal1\", \"gaussian_fields\", \"Gaussian place fields - Animal 1\")\n\nfig = _prepare_figure(\"h5=figure(5)\", figsize=(8.5, 8.0))\n_plot_field_grid(fig, \"animal1\", \"zernike_fields\", \"Zernike place fields - Animal 1\")\n\nfig = _prepare_figure(\"h6=figure(6)\", figsize=(8.5, 8.0))\n_plot_field_grid(fig, \"animal2\", \"gaussian_fields\", \"Gaussian place fields - Animal 2\")\n\nfig = _prepare_figure(\"h7=figure(7)\", figsize=(8.5, 8.0))\n_plot_field_grid(fig, \"animal2\", \"zernike_fields\", \"Zernike place fields - Animal 2\")\n\nfig = _prepare_figure(\"figure(8)\", figsize=(7.0, 5.5))\nax = fig.subplots(1, 1)\nax.imshow(\n mesh[\"gaussian_field\"],\n origin=\"lower\",\n extent=[float(np.min(mesh[\"grid_x\"])), float(np.max(mesh[\"grid_x\"])), float(np.min(mesh[\"grid_y\"])), float(np.max(mesh[\"grid_y\"]))],\n aspect=\"equal\",\n cmap=\"viridis\",\n)\nax.plot(mesh[\"x_pos\"], mesh[\"y_pos\"], color=\"white\", linewidth=0.5, alpha=0.35)\nax.scatter(spike_x, spike_y, s=8, color=\"tab:red\", alpha=0.7)\nax.set_title(f\"Gaussian receptive field - Cell {int(mesh['cell_index']) + 1}\")\nax.set_xlabel(\"x\")\nax.set_ylabel(\"y\")\n\nfig = _prepare_figure(\"figure(9)\", figsize=(7.0, 5.5))\nax = fig.subplots(1, 1)\nax.imshow(\n mesh[\"zernike_field\"],\n origin=\"lower\",\n extent=[float(np.min(mesh[\"grid_x\"])), float(np.max(mesh[\"grid_x\"])), float(np.min(mesh[\"grid_y\"])), float(np.max(mesh[\"grid_y\"]))],\n aspect=\"equal\",\n cmap=\"viridis\",\n)\nax.plot(mesh[\"x_pos\"], mesh[\"y_pos\"], color=\"white\", linewidth=0.5, alpha=0.35)\nax.scatter(spike_x, spike_y, s=8, color=\"tab:red\", alpha=0.7)\nax.set_title(f\"Zernike receptive field - Cell {int(mesh['cell_index']) + 1}\")\nax.set_xlabel(\"x\")\nax.set_ylabel(\"y\")\n\n# figure(9) overlay — matches MATLAB hold-on composite (published as _10.png)\nfig = _prepare_figure(\"figure(9) overlay\", figsize=(10.0, 5.0))\naxs = fig.subplots(1, 2)\next = [float(np.min(mesh[\"grid_x\"])), float(np.max(mesh[\"grid_x\"])), float(np.min(mesh[\"grid_y\"])), float(np.max(mesh[\"grid_y\"]))]\nfor ax, field, label in zip(axs, [mesh[\"gaussian_field\"], mesh[\"zernike_field\"]], [\"Gaussian\", \"Zernike\"]):\n ax.imshow(field, origin=\"lower\", extent=ext, aspect=\"equal\", cmap=\"viridis\")\n ax.plot(mesh[\"x_pos\"], mesh[\"y_pos\"], color=\"white\", linewidth=0.5, alpha=0.35)\n ax.scatter(spike_x, spike_y, s=8, color=\"tab:red\", alpha=0.7)\n ax.set_title(f\"{label} - Cell {int(mesh['cell_index']) + 1}\")\n ax.set_xlabel(\"x\")\n ax.set_ylabel(\"y\")\n\n__tracker.finalize()"
}
],
"metadata": {
Expand All @@ -267,4 +189,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
Loading
Loading