Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,38 @@ jobs:
- name: Execute Python notebook parity-core group
run: python tools/notebooks/run_notebooks.py --group parity_core --timeout 900

notebook-changed-pr:
if: github.event_name == 'pull_request'
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -e .[dev]
python -m pip install ipykernel
python -m ipykernel install --user --name python3 --display-name "Python 3"
- name: Resolve changed notebook topics
id: changed_topics
run: |
TOPICS=$(python tools/notebooks/changed_topics.py \
--base-sha "${{ github.event.pull_request.base.sha }}" \
--head-sha "${{ github.sha }}")
echo "topics=${TOPICS}" >> "$GITHUB_OUTPUT"
echo "Resolved notebook topics: ${TOPICS:-<none>}"
- name: Execute changed notebook topics
if: steps.changed_topics.outputs.topics != ''
run: python tools/notebooks/run_notebooks.py --group full --topics "${{ steps.changed_topics.outputs.topics }}" --timeout 1200
- name: No notebook changes detected
if: steps.changed_topics.outputs.topics == ''
run: echo "No changed notebook topics to execute."

cleanroom-compliance:
runs-on: ubuntu-latest

Expand Down
266 changes: 176 additions & 90 deletions notebooks/AnalysisExamples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
"cells": [
{
"cell_type": "markdown",
"id": "0f7ae1f5",
"id": "667810f9",
"metadata": {},
"source": [
"<!-- parity-note -->\n",
"## MATLAB Parity Note\n",
"- Source MATLAB helpfile: `AnalysisExamples.mlx`\n",
"- Fidelity status: `partial`\n",
"- Remaining justified differences: Advanced MATLAB algorithm-selection branches and report plots remain lighter in Python, and the notebook still contains tracker-only visualization sections rather than a fully executable MATLAB-equivalent workflow."
"- Fidelity status: `high_fidelity`\n",
"- Remaining justified differences: The notebook now follows the MATLAB standard-GLM workflow with the canonical `glm_data.mat` dataset and real KS/model-visualization figures; coefficient values and styling still vary modestly because the Python GLM backend and plotting defaults differ from MATLAB.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b518f479",
"id": "1f210963",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -36,107 +36,193 @@
"import numpy as np\n",
"from scipy.io import loadmat\n",
"\n",
"from nstat import Analysis, Covariate, nspikeTrain\n",
"from nstat.data_manager import ensure_example_data\n",
"from nstat.glm import fit_poisson_glm\n",
"from nstat.notebook_figures import FigureTracker\n",
"\n",
"np.random.seed(0)\n",
"DATA_DIR = ensure_example_data(download=True)\n",
"GLM_DATA = loadmat(DATA_DIR / \"glm_data.mat\", squeeze_me=True, struct_as_record=False)\n",
"OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
"__tracker = FigureTracker(topic='AnalysisExamples', output_root=OUTPUT_ROOT, expected_count=4)\n",
"__tracker = FigureTracker(topic=\"AnalysisExamples\", output_root=OUTPUT_ROOT, expected_count=4)\n",
"\n",
"def _load_example_globals(name: str) -> dict[str, object]:\n",
" candidates = [\n",
" Path(name),\n",
" DATA_DIR / name,\n",
" DATA_DIR / \"mEPSCs\" / name,\n",
" DATA_DIR / \"Place Cells\" / name,\n",
" DATA_DIR / \"Explicit Stimulus\" / name,\n",
" ]\n",
" for path in candidates:\n",
" if path.exists():\n",
" data = loadmat(path)\n",
" return {k: v for k, v in data.items() if not k.startswith(\"__\")}\n",
" return {}\n",
"\n",
"# SECTION 0: Section 0\n",
"# Analysis Examples\n",
"# This is an example of the standard approach to fitting GLM models to spike train data. This data set was obtained at the Society For Neuroscience '08 Workshop on Neural Signal Processing. Compare to analysis with the Neural Spike Analysis Toolbox."
"def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
" fig = __tracker.new_figure(matlab_line)\n",
" fig.clear()\n",
" fig.set_size_inches(*figsize)\n",
" return fig\n",
"\n",
"\n",
"def _poisson_standard_errors(design_matrix, result):\n",
" x = np.asarray(design_matrix, dtype=float)\n",
" if x.ndim == 1:\n",
" x = x[:, None]\n",
" x_aug = np.column_stack([np.ones(x.shape[0]), x])\n",
" beta = np.concatenate([[result.intercept], np.asarray(result.coefficients, dtype=float)])\n",
" lam = np.exp(np.clip(x_aug @ beta, -20.0, 20.0))\n",
" cov = np.linalg.pinv(x_aug.T @ (lam[:, None] * x_aug))\n",
" return np.sqrt(np.clip(np.diag(cov), 0.0, None))\n",
"\n",
"\n",
"T = np.asarray(GLM_DATA[\"T\"], dtype=float).reshape(-1)\n",
"xN = np.asarray(GLM_DATA[\"xN\"], dtype=float).reshape(-1)\n",
"yN = np.asarray(GLM_DATA[\"yN\"], dtype=float).reshape(-1)\n",
"spikes_binned = np.asarray(GLM_DATA[\"spikes_binned\"], dtype=float).reshape(-1)\n",
"spiketimes = np.asarray(GLM_DATA[\"spiketimes\"], dtype=float).reshape(-1)\n",
"x_at_spiketimes = np.asarray(GLM_DATA[\"x_at_spiketimes\"], dtype=float).reshape(-1)\n",
"y_at_spiketimes = np.asarray(GLM_DATA[\"y_at_spiketimes\"], dtype=float).reshape(-1)\n",
"sample_rate = 1.0 / float(np.median(np.diff(T)))\n",
"nst = nspikeTrain(spiketimes, name=\"1\", minTime=float(T[0]), maxTime=float(T[-1]), makePlots=-1)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "575fbbc6",
"id": "401a35e2",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 1: Example 1: Traditional Preliminary Analysis\n",
"# Script glm_part1.m\n",
"# MATLAB code to visualize data, fit a GLM model of the relation between\n",
"# spiking and the rat's position, and visualize this model for the\n",
"# Neuroinformatics GLM problem set.\n",
"# The code is initialized with an overly simple GLM model construction.\n",
"# Please improve it!\n",
"#\n",
"# load the rat trajectory and spiking data;\n",
"# SECTION 1: Analysis Examples\n",
"plt.close(\"all\")\n",
"globals().update(_load_example_globals('glm_data.mat'))\n",
"# visualize the raw data\n",
"__tracker.new_figure('figure')\n",
"__tracker.new_figure('figure;')\n",
"__tracker.annotate(\"plot(xN,yN,x_at_spiketimes,y_at_spiketimes,'r.')\")\n",
"ax = plt.gca()\n",
"ax.cla()\n",
"plt.gcf().set_size_inches(8.0, 8.0, forward=True)\n",
"ax.plot(np.ravel(xN), np.ravel(yN), color=(0.0, 0.4470, 0.7410), linewidth=0.6)\n",
"ax.plot(np.ravel(x_at_spiketimes), np.ravel(y_at_spiketimes), 'r.', markersize=2.5)\n",
"ax = plt.gca()\n",
"ax.relim()\n",
"ax.autoscale_view(tight=True)\n",
"ax.set_aspect('equal', adjustable='box')\n",
"ax.tick_params(top=True, right=True, direction='in')\n",
"plt.xlabel('x position (m)')\n",
"plt.ylabel('y position (m)')\n",
"# fit a GLM model to the x and y positions.\n",
"__tracker.new_figure('figure')\n",
"__tracker.new_figure('figure;')\n",
"# visualize your model construct a grid of positions to plot the model against...\n",
"__tracker.new_figure('figure')\n",
"__tracker.new_figure('figure;')\n",
"#\n",
"# compute lambda for each point on this grid using the GLM model\n",
"#\n",
"# plot lambda as a function position over this grid\n",
"__tracker.annotate('plot3(cos(-pi:1e-2:pi),sin(-pi:1e-2:pi),zeros(size(-pi:1e-2:pi)))')\n",
"__tracker.annotate(\"plot(xN,yN,x_at_spiketimes,y_at_spiketimes,'r.')\")\n",
"ax = plt.gca()\n",
"ax.relim()\n",
"ax.autoscale_view(tight=True)\n",
"ax.set_aspect('equal', adjustable='box')\n",
"ax.tick_params(top=True, right=True, direction='in')\n",
"plt.xlabel('x position (m)')\n",
"plt.ylabel('y position (m)')\n",
"# Compare a linear model versus a Gaussian GLM model.\n",
"#\n",
"# Make the KS Plot\n",
"# ******* K-S Plot *******************\n",
"# graph the K-S plot and confidence intervals for the K-S statistic\n",
"#\n",
"# first generate the conditional intensity at each timestep\n",
"# ** Adjust the below line according to your choice of model.\n",
"# remember to include a column of ones to multiply the default constant GLM parameter beta_0**\n",
"#\n",
"# Use your parameter estimates (b) from glmfit along\n",
"# with the covariates you used (xN, yN, ...)\n",
"#\n",
"timestep = 1\n",
"lambdaInt = 0\n",
"j = 0\n",
"__tracker.new_figure('figure')\n",
"__tracker.new_figure('figure;')\n",
"__tracker.annotate(\"plot( ([1:N]-.5)/N, KSSorted, 0:.01:1,0:.01:1, 'g',0:.01:1, [0:.01:1]+1.36/sqrt(N), 'r', 0:.01:1,[0:.01:1]-1.36/sqrt(N), 'r' )\")\n",
"__tracker.annotate(\"title('KS Plot with 95% Confidence Intervals')\")\n",
"__tracker.finalize()"
"print({\"n_samples\": int(T.shape[0]), \"n_spikes\": int(spiketimes.shape[0]), \"sample_rate_hz\": round(sample_rate, 3)})\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "501a6470",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 2: Example 1: Traditional Preliminary Analysis\n",
"x_linear = np.column_stack([xN, yN])\n",
"x_quadratic_centered = np.column_stack(\n",
" [\n",
" xN,\n",
" yN,\n",
" xN**2 - np.mean(xN**2),\n",
" yN**2 - np.mean(yN**2),\n",
" xN * yN - np.mean(xN * yN),\n",
" ]\n",
")\n",
"x_quadratic = np.column_stack([xN, yN, xN**2, yN**2, xN * yN])\n",
"linear_fit = fit_poisson_glm(x_linear, spikes_binned)\n",
"quadratic_fit = fit_poisson_glm(x_quadratic, spikes_binned)\n",
"centered_fit = fit_poisson_glm(x_quadratic_centered, spikes_binned)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee81820c",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 3: visualize the raw data\n",
"fig = _prepare_figure(\"figure; plot(xN,yN,x_at_spiketimes,y_at_spiketimes,'r.')\", figsize=(6.5, 6.0))\n",
"ax = fig.subplots(1, 1)\n",
"ax.plot(xN, yN, color=\"0.65\", linewidth=1.0)\n",
"ax.plot(x_at_spiketimes, y_at_spiketimes, \"r.\", markersize=3.0)\n",
"ax.set_aspect(\"equal\", adjustable=\"box\")\n",
"ax.set_xlabel(\"x position (m)\")\n",
"ax.set_ylabel(\"y position (m)\")\n",
"ax.set_title(\"Rat trajectory with spike locations\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1056f293",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 4: fit a GLM model to the x and y positions\n",
"fig = _prepare_figure(\"figure; errorbar(1:length(b), b, stats.se,'.')\", figsize=(7.0, 4.5))\n",
"ax = fig.subplots(1, 1)\n",
"centered_beta = np.concatenate([[centered_fit.intercept], np.asarray(centered_fit.coefficients, dtype=float)])\n",
"centered_se = _poisson_standard_errors(x_quadratic_centered, centered_fit)\n",
"xpos = np.arange(centered_beta.size)\n",
"ax.errorbar(xpos, centered_beta, yerr=centered_se, fmt=\".\", color=\"tab:blue\", capsize=3)\n",
"ax.set_xticks(xpos, [\"baseline\", \"x\", \"y\", \"x^2\", \"y^2\", \"x*y\"])\n",
"ax.set_ylabel(\"coefficient value\")\n",
"ax.set_title(\"Quadratic GLM coefficients\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98e73438",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 5: visualize your model\n",
"fig = _prepare_figure(\"figure; mesh(x_new,y_new,lambda,'AlphaData',0)\", figsize=(8.0, 6.5))\n",
"ax = fig.add_subplot(111, projection=\"3d\")\n",
"grid = np.arange(-1.0, 1.01, 0.1)\n",
"x_new, y_new = np.meshgrid(grid, grid)\n",
"X_grid = np.column_stack([x_new.ravel(), y_new.ravel(), x_new.ravel() ** 2, y_new.ravel() ** 2, x_new.ravel() * y_new.ravel()])\n",
"lam_grid = quadratic_fit.predict_rate(X_grid).reshape(x_new.shape)\n",
"lam_grid = np.where((x_new**2 + y_new**2) <= 1.0, lam_grid, np.nan)\n",
"ax.plot_wireframe(x_new, y_new, lam_grid, rstride=1, cstride=1, color=\"tab:blue\", linewidth=0.7)\n",
"theta = np.linspace(-np.pi, np.pi, 400)\n",
"ax.plot(np.cos(theta), np.sin(theta), np.zeros_like(theta), color=\"k\", linewidth=1.0)\n",
"ax.plot(x_at_spiketimes, y_at_spiketimes, np.zeros_like(x_at_spiketimes), \"r.\", markersize=2.0)\n",
"ax.set_xlabel(\"x position (m)\")\n",
"ax.set_ylabel(\"y position (m)\")\n",
"ax.set_zlabel(\"lambda\")\n",
"ax.set_title(\"Quadratic GLM spatial intensity\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "46235a71",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 6: Compare a linear model versus a Gaussian GLM model\n",
"lambda_linear_hz = linear_fit.predict_rate(x_linear) * sample_rate\n",
"lambda_quadratic_hz = quadratic_fit.predict_rate(x_quadratic) * sample_rate\n",
"lambda_linear = Covariate(T, lambda_linear_hz, \"lambda_linear\", \"time\", \"s\", \"Hz\", [\"Linear\"])\n",
"lambda_quadratic = Covariate(T, lambda_quadratic_hz, \"lambda_quadratic\", \"time\", \"s\", \"Hz\", [\"Quadratic\"])\n",
"print(\n",
" {\n",
" \"linear_mean_rate_hz\": round(float(np.mean(lambda_linear_hz)), 4),\n",
" \"quadratic_mean_rate_hz\": round(float(np.mean(lambda_quadratic_hz)), 4),\n",
" }\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5a09608",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 7: Make the KS Plot\n",
"_, _, x_linear_ks, ks_linear, _ = Analysis.computeKSStats(nst, lambda_linear)\n",
"_, _, x_quadratic_ks, ks_quadratic, _ = Analysis.computeKSStats(nst, lambda_quadratic)\n",
"fig = _prepare_figure(\"figure; plot(([1:N]-.5)/N, KSSorted, ...)\", figsize=(6.5, 5.0))\n",
"ax = fig.subplots(1, 1)\n",
"x_axis = np.asarray(x_linear_ks, dtype=float).reshape(-1)\n",
"ks_linear_arr = np.asarray(ks_linear, dtype=float).reshape(-1)\n",
"ks_quadratic_arr = np.asarray(ks_quadratic, dtype=float).reshape(-1)\n",
"if x_axis.size:\n",
" ci = 1.36 / np.sqrt(x_axis.size)\n",
" ax.plot(x_axis, ks_linear_arr, color=\"tab:blue\", linewidth=1.5, label=\"Linear\")\n",
" ax.plot(x_axis, ks_quadratic_arr, color=\"tab:orange\", linewidth=1.5, label=\"Quadratic\")\n",
" ax.plot([0.0, 1.0], [0.0, 1.0], \"g\", linewidth=1.0)\n",
" ax.plot(x_axis, np.clip(x_axis + ci, 0.0, 1.0), \"r\", linewidth=1.0)\n",
" ax.plot(x_axis, np.clip(x_axis - ci, 0.0, 1.0), \"r\", linewidth=1.0)\n",
"ax.set_xlim(0.0, 1.0)\n",
"ax.set_ylim(0.0, 1.0)\n",
"ax.set_xlabel(\"Uniform CDF\")\n",
"ax.set_ylabel(\"Empirical CDF of Rescaled ISIs\")\n",
"ax.set_title(\"KS Plot with 95% Confidence Intervals\")\n",
"ax.legend(loc=\"lower right\", frameon=False)\n",
"__tracker.finalize()\n"
]
}
],
Expand All @@ -153,4 +239,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
Loading