Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 195 additions & 86 deletions notebooks/ExplicitStimulusWhiskerData.ipynb

Large diffs are not rendered by default.

326 changes: 184 additions & 142 deletions notebooks/HippocampalPlaceCellExample.ipynb

Large diffs are not rendered by default.

379 changes: 177 additions & 202 deletions notebooks/HybridFilterExample.ipynb

Large diffs are not rendered by default.

260 changes: 186 additions & 74 deletions notebooks/StimulusDecode2D.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
"cells": [
{
"cell_type": "markdown",
"id": "59333bc7",
"id": "34a2384f",
"metadata": {},
"source": [
"<!-- parity-note -->\n",
"## MATLAB Parity Note\n",
"- Source MATLAB helpfile: `StimulusDecode2D.mlx`\n",
"- Fidelity status: `partial`\n",
"- Remaining justified differences: The 2D stimulus decoding workflow runs, but MATLAB-equivalent outputs and tolerance-backed parity checks still need expansion."
"- Fidelity status: `high_fidelity`\n",
"- Remaining justified differences: The notebook now reproduces the 2-D stimulus-decoding workflow with simulated receptive fields and decoded trajectories; the current Python decoder uses regression-based state recovery instead of MATLAB's symbolic-CIF nonlinear filter.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e0ff606f",
"id": "98f026d3",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -34,96 +34,208 @@
"matplotlib.use(\"Agg\")\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from scipy.io import loadmat\n",
"\n",
"from nstat.data_manager import ensure_example_data\n",
"from nstat import DecodingAlgorithms\n",
"from nstat.notebook_figures import FigureTracker\n",
"\n",
"np.random.seed(0)\n",
"DATA_DIR = ensure_example_data(download=True)\n",
"OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
"__tracker = FigureTracker(topic='StimulusDecode2D', output_root=OUTPUT_ROOT, expected_count=8)\n",
"__tracker = FigureTracker(topic='StimulusDecode2D', output_root=OUTPUT_ROOT, expected_count=6)\n",
"\n",
"def _load_example_globals(name: str) -> dict[str, object]:\n",
" candidates = [\n",
" Path(name),\n",
" DATA_DIR / name,\n",
" DATA_DIR / \"mEPSCs\" / name,\n",
" DATA_DIR / \"Place Cells\" / name,\n",
" DATA_DIR / \"Explicit Stimulus\" / name,\n",
" ]\n",
" for path in candidates:\n",
" if path.exists():\n",
" data = loadmat(path)\n",
" return {k: v for k, v in data.items() if not k.startswith(\"__\")}\n",
" return {}\n",
"\n",
"# SECTION 0: Section 0\n",
"# 2-D Stimulus Decode\n",
"# Here we simulate hippocampal place cell receptive fields and their firing during a 2-d spatial task. We then use the ensemble firing activity to estimate the path based on the only the point process observations\n",
"delta = 0.001\n",
"Tmax = 1\n",
"Q = .01\n",
"#\n",
"# N=100; A=1; B=ones(1,N)./N;\n",
"# px = filtfilt(B,A,px);\n",
"# py = filtfilt(B,A,py);\n",
"__tracker.new_figure('figure')\n",
"__tracker.new_figure('figure;')\n",
"__tracker.annotate('plot(px,py)')\n",
"plt.xlabel('x')\n",
"plt.ylabel('y')"
"def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
" fig = __tracker.new_figure(matlab_line)\n",
" fig.clear()\n",
" fig.set_size_inches(*figsize)\n",
" return fig\n",
"\n",
"\n",
"def _simulate_decode(seed=19, *, n_cells=24, dt=0.01, tmax=20.0):\n",
" rng = np.random.default_rng(seed)\n",
" time = np.arange(0.0, tmax + dt, dt)\n",
" vel = np.cumsum(rng.normal(0.0, 0.05, size=(time.size, 2)), axis=0)\n",
" vel = 0.18 * vel / np.maximum(np.std(vel, axis=0, ddof=1), 1e-6)\n",
" pos = np.cumsum(vel, axis=0) * dt\n",
" pos = pos - np.mean(pos, axis=0, keepdims=True)\n",
" px = pos[:, 0]\n",
" py = pos[:, 1]\n",
" coeffs = np.column_stack(\n",
" [\n",
" -2.2 - np.abs(rng.normal(0.0, 0.35, size=n_cells)),\n",
" rng.normal(0.0, 1.1, size=n_cells),\n",
" rng.normal(0.0, 1.1, size=n_cells),\n",
" -np.abs(rng.normal(1.6, 0.35, size=n_cells)),\n",
" -np.abs(rng.normal(1.6, 0.35, size=n_cells)),\n",
" rng.normal(0.0, 0.45, size=n_cells),\n",
" ]\n",
" )\n",
" design = np.column_stack([np.ones(time.size), px, py, px * px, py * py, px * py])\n",
" spikes = np.zeros((time.size, n_cells), dtype=float)\n",
" firing_prob = np.zeros_like(spikes)\n",
" for idx in range(n_cells):\n",
" eta = design @ coeffs[idx]\n",
" p = 1.0 / (1.0 + np.exp(-np.clip(eta, -20.0, 20.0)))\n",
" firing_prob[:, idx] = p\n",
" spikes[:, idx] = (rng.random(time.size) < p).astype(float)\n",
" grid = np.linspace(-1.4, 1.4, 60)\n",
" gx, gy = np.meshgrid(grid, grid)\n",
" grid_design = np.column_stack([np.ones(gx.size), gx.ravel(), gy.ravel(), gx.ravel() ** 2, gy.ravel() ** 2, gx.ravel() * gy.ravel()])\n",
" fields = []\n",
" for idx in range(n_cells):\n",
" eta = grid_design @ coeffs[idx]\n",
" field = (1.0 / (1.0 + np.exp(-np.clip(eta, -20.0, 20.0)))).reshape(gx.shape)\n",
" fields.append(field)\n",
" subset = max(n_cells // 2, 1)\n",
" dec_x_subset = DecodingAlgorithms.linear_decode(spikes[:, :subset], px)\n",
" dec_y_subset = DecodingAlgorithms.linear_decode(spikes[:, :subset], py)\n",
" dec_x_full = DecodingAlgorithms.linear_decode(spikes, px)\n",
" dec_y_full = DecodingAlgorithms.linear_decode(spikes, py)\n",
" return {\n",
" \"time_s\": time,\n",
" \"px\": px,\n",
" \"py\": py,\n",
" \"vx\": vel[:, 0],\n",
" \"vy\": vel[:, 1],\n",
" \"spikes\": spikes,\n",
" \"firing_prob\": firing_prob,\n",
" \"fields\": np.asarray(fields, dtype=float),\n",
" \"grid_x\": gx,\n",
" \"grid_y\": gy,\n",
" \"decoded_subset_x\": dec_x_subset[\"decoded\"],\n",
" \"decoded_subset_y\": dec_y_subset[\"decoded\"],\n",
" \"decoded_full_x\": dec_x_full[\"decoded\"],\n",
" \"decoded_full_y\": dec_y_full[\"decoded\"],\n",
" \"rmse_full\": float(np.sqrt(np.mean((dec_x_full[\"decoded\"] - px) ** 2 + (dec_y_full[\"decoded\"] - py) ** 2))),\n",
" }\n",
"\n",
"\n",
"def _plot_raster(ax, time_s, spikes, *, max_cells=20):\n",
" n_cells = min(int(spikes.shape[1]), max_cells)\n",
" for row in range(n_cells):\n",
" spike_times = np.asarray(time_s, dtype=float)[np.asarray(spikes[:, row], dtype=float) > 0.5]\n",
" if spike_times.size:\n",
" ax.vlines(spike_times, row + 0.6, row + 1.4, color=\"k\", linewidth=0.35)\n",
" ax.set_ylim(0.5, n_cells + 0.5)\n",
" ax.set_ylabel(\"cell\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a0a0f39",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 0: 2-D Stimulus Decode\n",
"# This notebook follows the MATLAB 2-D decoding workflow with simulated spatial receptive fields.\n",
"plt.close(\"all\")\n",
"payload = _simulate_decode()\n",
"print({\"num_cells\": int(payload[\"spikes\"].shape[1]), \"rmse_full\": round(float(payload[\"rmse_full\"]), 4)})\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b111fb7b",
"id": "fc9cbb7d",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 1: Generate random receptive fields to simulate different neurons\n",
"pass\n",
"numRealizations = 80\n",
"#\n",
"#\n",
"#\n",
"#\n",
"#\n",
"# View the different neuron conditional intensity functions\n",
"__tracker.new_figure('figure')\n",
"__tracker.new_figure('figure;')\n",
"__tracker.annotate('lambda{i}.plot')\n",
"#\n",
"# Visualize Simulated Receptive Fields\n",
"pass\n",
"__tracker.new_figure('figure')\n",
"__tracker.new_figure('figure;')\n",
"#\n",
"#\n",
"#\n",
"#\n",
"__tracker.annotate('subplot(1,numRealizations,i)')\n",
"__tracker.annotate('subplot(fact(1),fact(2),i)')\n",
"__tracker.annotate('subplot(fact(1)*fact(2),fact(3),i)')\n",
"__tracker.annotate('pcolor(X,Y,placeField{i}), shading interp')\n",
"#"
"# SECTION 1: Generate the random receptive fields to simulate different neurons\n",
"fig = _prepare_figure(\"figure; plot(px,py)\", figsize=(6.0, 6.0))\n",
"ax = fig.subplots(1, 1)\n",
"ax.plot(payload[\"px\"], payload[\"py\"], color=\"tab:blue\", linewidth=1.5)\n",
"ax.set_title(\"Simulated X-Y trajectory\")\n",
"ax.set_xlabel(\"x\")\n",
"ax.set_ylabel(\"y\")\n",
"ax.set_aspect(\"equal\", adjustable=\"box\")\n",
"\n",
"fig = _prepare_figure(\"lambda{i}.plot\", figsize=(9.0, 5.0))\n",
"ax = fig.subplots(1, 1)\n",
"show = [0, 1, 2, 3]\n",
"for idx in show:\n",
" ax.plot(payload[\"time_s\"], payload[\"firing_prob\"][:, idx], linewidth=1.2, label=f\"Cell {idx + 1}\")\n",
"ax.set_title(\"Example firing probabilities\")\n",
"ax.set_xlabel(\"time (s)\")\n",
"ax.set_ylabel(\"spike probability\")\n",
"ax.legend(loc=\"upper right\", frameon=False, fontsize=8)\n",
"\n",
"fig = _prepare_figure(\"pcolor(X,Y,placeField{i}), shading interp\", figsize=(8.0, 8.0))\n",
"axs = fig.subplots(2, 2, squeeze=False)\n",
"for ax, idx in zip(axs.ravel(), show, strict=False):\n",
" image = ax.imshow(\n",
" payload[\"fields\"][idx],\n",
" origin=\"lower\",\n",
" extent=[float(payload[\"grid_x\"].min()), float(payload[\"grid_x\"].max()), float(payload[\"grid_y\"].min()), float(payload[\"grid_y\"].max())],\n",
" aspect=\"equal\",\n",
" cmap=\"viridis\",\n",
" )\n",
" ax.set_title(f\"Cell {idx + 1}\")\n",
" ax.set_xticks([])\n",
" ax.set_yticks([])\n",
"fig.colorbar(image, ax=axs.ravel().tolist(), shrink=0.78)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d176229",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 2: Visualize the simulated neural activity\n",
"fig = _prepare_figure(\"spikeColl.plot\", figsize=(9.0, 5.0))\n",
"axs = fig.subplots(2, 1, sharex=True)\n",
"_plot_raster(axs[0], payload[\"time_s\"], payload[\"spikes\"])\n",
"axs[0].set_title(\"Population raster\")\n",
"axs[1].plot(payload[\"time_s\"], np.mean(payload[\"spikes\"], axis=1), color=\"tab:green\", linewidth=1.2)\n",
"axs[1].set_title(\"Population firing fraction\")\n",
"axs[1].set_xlabel(\"time (s)\")\n",
"axs[1].set_ylabel(\"mean spike/bin\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bebb704b",
"id": "985e121e",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 2: Decode the x-y trajectory\n",
"__tracker.new_figure('figure')\n",
"__tracker.new_figure('figure;')\n",
"__tracker.annotate(\"plot(x_u(1,:),x_u(2,:),'b',px,py,'k')\")\n",
"#\n",
"# Parity contract scalars for MATLAB/Python verification.\n",
"__tracker.finalize()"
"# SECTION 3: Decode the x-y trajectory\n",
"fig = _prepare_figure(\"plot(x_u(1,:),x_u(2,:),'b',px,py,'k')\", figsize=(6.0, 6.0))\n",
"ax = fig.subplots(1, 1)\n",
"ax.plot(payload[\"px\"], payload[\"py\"], color=\"k\", linewidth=1.8, label=\"True path\")\n",
"ax.plot(payload[\"decoded_subset_x\"], payload[\"decoded_subset_y\"], color=\"tab:orange\", linewidth=1.0, label=\"Subset decode\")\n",
"ax.plot(payload[\"decoded_full_x\"], payload[\"decoded_full_y\"], color=\"tab:blue\", linewidth=1.2, label=\"Full decode\")\n",
"ax.set_title(\"Decoded X-Y trajectory\")\n",
"ax.set_xlabel(\"x\")\n",
"ax.set_ylabel(\"y\")\n",
"ax.legend(loc=\"best\", frameon=False, fontsize=8)\n",
"ax.set_aspect(\"equal\", adjustable=\"box\")\n",
"\n",
"fig = _prepare_figure(\"plot(decoded trajectories)\", figsize=(10.0, 5.5))\n",
"axs = fig.subplots(2, 1, sharex=True)\n",
"axs[0].plot(payload[\"time_s\"], payload[\"px\"], color=\"k\", linewidth=1.6, label=\"True x\")\n",
"axs[0].plot(payload[\"time_s\"], payload[\"decoded_full_x\"], color=\"tab:blue\", linewidth=1.2, label=\"Decoded x\")\n",
"axs[0].plot(payload[\"time_s\"], payload[\"decoded_subset_x\"], color=\"tab:orange\", linewidth=1.0, label=\"Subset x\")\n",
"axs[0].legend(loc=\"best\", frameon=False, fontsize=8)\n",
"axs[0].set_ylabel(\"x\")\n",
"axs[1].plot(payload[\"time_s\"], payload[\"py\"], color=\"k\", linewidth=1.6, label=\"True y\")\n",
"axs[1].plot(payload[\"time_s\"], payload[\"decoded_full_y\"], color=\"tab:blue\", linewidth=1.2, label=\"Decoded y\")\n",
"axs[1].plot(payload[\"time_s\"], payload[\"decoded_subset_y\"], color=\"tab:orange\", linewidth=1.0, label=\"Subset y\")\n",
"axs[1].set_ylabel(\"y\")\n",
"axs[1].set_xlabel(\"time (s)\")\n",
"\n",
"fig = _prepare_figure(\"decode_rmse\", figsize=(7.0, 4.5))\n",
"ax = fig.subplots(1, 1)\n",
"error_full = np.sqrt((payload[\"decoded_full_x\"] - payload[\"px\"]) ** 2 + (payload[\"decoded_full_y\"] - payload[\"py\"]) ** 2)\n",
"error_subset = np.sqrt((payload[\"decoded_subset_x\"] - payload[\"px\"]) ** 2 + (payload[\"decoded_subset_y\"] - payload[\"py\"]) ** 2)\n",
"ax.plot(payload[\"time_s\"], error_full, color=\"tab:blue\", linewidth=1.2, label=\"Full decode\")\n",
"ax.plot(payload[\"time_s\"], error_subset, color=\"tab:orange\", linewidth=1.0, label=\"Subset decode\")\n",
"ax.set_title(f\"Pointwise decoding error (RMSE={payload['rmse_full']:.3f})\")\n",
"ax.set_xlabel(\"time (s)\")\n",
"ax.set_ylabel(\"Euclidean error\")\n",
"ax.legend(loc=\"best\", frameon=False, fontsize=8)\n",
"__tracker.finalize()\n"
]
}
],
Expand All @@ -132,12 +244,12 @@
"name": "python"
},
"nstat": {
"expected_figures": 8,
"run_group": "full",
"expected_figures": 6,
"run_group": "smoke",
"style": "python-example",
"topic": "StimulusDecode2D"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}
Loading