From d9c2122e2c690ba7f2523d0e9817f5742b6a3317 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Sun, 8 Mar 2026 13:20:48 -0400 Subject: [PATCH] Tighten nonlinear decoding parity --- notebooks/StimulusDecode2D.ipynb | 191 +++++++----- nstat/cif.py | 282 +++++++++++++++++- nstat/decoding_algorithms.py | 106 +++++-- parity/class_fidelity.yml | 18 +- parity/notebook_fidelity.yml | 11 +- parity/report.md | 9 +- pyproject.toml | 1 + .../fixtures/matlab_gold/cif_exactness.mat | Bin 614 -> 1157 bytes .../nonlinear_decode_exactness.mat | Bin 0 -> 1097 bytes tests/test_decoding_algorithms_fidelity.py | 17 ++ tests/test_matlab_gold_fixtures.py | 35 +++ tests/test_notebook_fidelity_audit.py | 2 +- tests/test_notebook_parity_notes.py | 2 +- tests/test_parity_manifest.py | 20 ++ tests/test_parity_report.py | 4 +- .../build_helpfile_fidelity_notebooks.py | 177 +++++++---- tools/notebooks/parity_notes.yml | 4 +- .../matlab/export_matlab_gold_fixtures.m | 90 ++++++ 18 files changed, 787 insertions(+), 182 deletions(-) create mode 100644 tests/parity/fixtures/matlab_gold/nonlinear_decode_exactness.mat diff --git a/notebooks/StimulusDecode2D.ipynb b/notebooks/StimulusDecode2D.ipynb index ecff1c14..a5b6b0b7 100644 --- a/notebooks/StimulusDecode2D.ipynb +++ b/notebooks/StimulusDecode2D.ipynb @@ -2,20 +2,20 @@ "cells": [ { "cell_type": "markdown", - "id": "249987ba", + "id": "0e2c1f81", "metadata": {}, "source": [ "\n", "## MATLAB Parity Note\n", "- Source MATLAB helpfile: `StimulusDecode2D.mlx`\n", - "- Fidelity status: `partial`\n", - "- Remaining justified differences: The notebook reproduces the MATLAB section order, figure inventory, simulated receptive fields, and decoded-trajectory presentation, but the current Python decoder still uses regression-based state recovery instead of MATLAB's symbolic-CIF nonlinear filter.\n" + "- Fidelity status: `high_fidelity`\n", + "- Remaining justified differences: The notebook now follows the MATLAB nonlinear-CIF decoding workflow and uses `DecodingAlgorithms.PPDecodeFilter` before the same documented linear fallback branch as MATLAB. Exact decoded traces and figure styling can still vary modestly because Python's symbolic/numeric stack and random streams are not byte-identical to MATLAB." ] }, { "cell_type": "code", "execution_count": null, - "id": "3989e7ba", + "id": "ef0ce881", "metadata": {}, "outputs": [], "source": [ @@ -35,7 +35,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "from nstat import DecodingAlgorithms\n", + "from nstat import CIF, Covariate, DecodingAlgorithms, SignalObj, nstColl\n", "from nstat.notebook_figures import FigureTracker\n", "\n", "np.random.seed(0)\n", @@ -50,62 +50,103 @@ " return fig\n", "\n", "\n", - "def _simulate_decode(seed=19, *, n_cells=24, dt=0.01, tmax=20.0):\n", + "def _subplot_grid(count):\n", + " rows = max(int(np.floor(np.sqrt(count))), 1)\n", + " cols = int(np.ceil(count / rows))\n", + " return rows, cols\n", + "\n", + "\n", + "def _simulate_decode(seed=0, *, num_realizations=80, delta=0.001, tmax=1.0):\n", " rng = np.random.default_rng(seed)\n", - " time = np.arange(0.0, tmax + dt, dt)\n", - " vel = np.cumsum(rng.normal(0.0, 0.05, size=(time.size, 2)), axis=0)\n", - " vel = 0.18 * vel / np.maximum(np.std(vel, axis=0, ddof=1), 1e-6)\n", - " pos = np.cumsum(vel, axis=0) * dt\n", - " pos = pos - np.mean(pos, axis=0, keepdims=True)\n", - " px = pos[:, 0]\n", - " py = pos[:, 1]\n", - " coeffs = np.column_stack(\n", - " [\n", - " -2.2 - np.abs(rng.normal(0.0, 0.35, size=n_cells)),\n", - " rng.normal(0.0, 1.1, size=n_cells),\n", - " rng.normal(0.0, 1.1, size=n_cells),\n", - " -np.abs(rng.normal(1.6, 0.35, size=n_cells)),\n", - " -np.abs(rng.normal(1.6, 0.35, size=n_cells)),\n", - " rng.normal(0.0, 0.45, size=n_cells),\n", - " ]\n", - " )\n", + " time = np.arange(0.0, tmax + delta, delta)\n", + " q_drive = 0.01\n", + " innovations = q_drive * rng.standard_normal((2, time.size))\n", + " vx = np.cumsum(innovations[0])\n", + " vy = np.cumsum(innovations[1])\n", + " vel_sig = SignalObj(time, np.column_stack([vx, vy]), \"vel\", \"time\", \"s\", \"\", [\"vx\", \"vy\"])\n", + " pos_sig = vel_sig.integral()\n", + " pos_data = np.asarray(pos_sig.data, dtype=float)\n", + " px = pos_data[:, 0]\n", + " py = pos_data[:, 1]\n", + "\n", + " coeffs = -np.abs(rng.standard_normal((num_realizations, 5)))\n", + " coeffs = np.column_stack([-2.0 * np.abs(rng.standard_normal(num_realizations)), coeffs])\n", " design = np.column_stack([np.ones(time.size), px, py, px * px, py * py, px * py])\n", - " spikes = np.zeros((time.size, n_cells), dtype=float)\n", - " firing_prob = np.zeros_like(spikes)\n", - " for idx in range(n_cells):\n", + "\n", + " lambda_rates_hz = np.zeros((time.size, num_realizations), dtype=float)\n", + " lambda_cifs = []\n", + " spike_trains = []\n", + " for idx in range(num_realizations):\n", " eta = design @ coeffs[idx]\n", - " p = 1.0 / (1.0 + np.exp(-np.clip(eta, -20.0, 20.0)))\n", - " firing_prob[:, idx] = p\n", - " spikes[:, idx] = (rng.random(time.size) < p).astype(float)\n", + " exp_eta = np.exp(np.clip(eta, -20.0, 20.0))\n", + " lambda_delta = exp_eta / (1.0 + exp_eta)\n", + " lambda_rates_hz[:, idx] = lambda_delta / delta\n", + " lambda_cov = Covariate(time, lambda_rates_hz[:, idx], \"\\\\Lambda(t)\", \"time\", \"s\", \"Hz\", [f\"lambda_{idx + 1}\"])\n", + " spike_coll = CIF.simulateCIFByThinningFromLambda(lambda_cov, 1, seed=seed + idx + 1)\n", + " train = spike_coll.getNST(1)\n", + " train.setName(str(idx + 1))\n", + " spike_trains.append(train)\n", + " lambda_cifs.append(CIF(coeffs[idx], [\"1\", \"x\", \"y\", \"x^2\", \"y^2\", \"x*y\"], [\"x\", \"y\"], fitType=\"binomial\"))\n", + "\n", + " spike_coll = nstColl(spike_trains)\n", + " spike_coll.resample(1.0 / delta)\n", + " labels = list(range(1, num_realizations + 1))\n", + " dN = spike_coll.dataToMatrix(labels, delta, float(time[0]), float(time[-1])).T\n", + "\n", + " vx_var = float(np.var(px[1:] - px[:-1]))\n", + " vy_var = float(np.var(py[1:] - py[:-1]))\n", + " q_cov = np.array([[vx_var, 0.0], [0.0, vy_var]], dtype=float)\n", + " p0 = 0.1 * np.eye(2, dtype=float)\n", + " a_mat = np.eye(2, dtype=float)\n", + " decode_method = \"PPDecodeFilter\"\n", + " decode_error = \"\"\n", + " try:\n", + " x_p, pe_p, x_u, pe_u, *_ = DecodingAlgorithms.PPDecodeFilter(a_mat, q_cov, p0, dN, lambda_cifs, delta)\n", + " except Exception as exc:\n", + " decode_method = \"PPDecodeFilterLinear\"\n", + " decode_error = f\"{type(exc).__name__}: {exc}\"\n", + " mu_linear = coeffs[:, 0]\n", + " beta_linear = coeffs[:, 1:3].T\n", + " x_p, pe_p, x_u, pe_u, *_ = DecodingAlgorithms.PPDecodeFilterLinear(a_mat, q_cov, dN, mu_linear, beta_linear, \"binomial\", delta)\n", + "\n", " grid = np.linspace(-1.4, 1.4, 60)\n", " gx, gy = np.meshgrid(grid, grid)\n", " grid_design = np.column_stack([np.ones(gx.size), gx.ravel(), gy.ravel(), gx.ravel() ** 2, gy.ravel() ** 2, gx.ravel() * gy.ravel()])\n", " fields = []\n", - " for idx in range(n_cells):\n", + " for idx in range(num_realizations):\n", " eta = grid_design @ coeffs[idx]\n", " field = (1.0 / (1.0 + np.exp(-np.clip(eta, -20.0, 20.0)))).reshape(gx.shape)\n", " fields.append(field)\n", - " subset = max(n_cells // 2, 1)\n", - " dec_x_subset = DecodingAlgorithms.linear_decode(spikes[:, :subset], px)\n", - " dec_y_subset = DecodingAlgorithms.linear_decode(spikes[:, :subset], py)\n", - " dec_x_full = DecodingAlgorithms.linear_decode(spikes, px)\n", - " dec_y_full = DecodingAlgorithms.linear_decode(spikes, py)\n", + " n_common = min(px.size, x_u.shape[1])\n", + " decode_rmse = float(\n", + " np.sqrt(\n", + " np.mean(\n", + " (x_u[0, :n_common] - px[:n_common]) ** 2\n", + " + (x_u[1, :n_common] - py[:n_common]) ** 2\n", + " )\n", + " )\n", + " )\n", " return {\n", " \"time_s\": time,\n", " \"px\": px,\n", " \"py\": py,\n", - " \"vx\": vel[:, 0],\n", - " \"vy\": vel[:, 1],\n", - " \"spikes\": spikes,\n", - " \"firing_prob\": firing_prob,\n", + " \"vx\": vx,\n", + " \"vy\": vy,\n", + " \"spikes\": dN.T,\n", + " \"lambda_rates_hz\": lambda_rates_hz,\n", " \"fields\": np.asarray(fields, dtype=float),\n", " \"grid_x\": gx,\n", " \"grid_y\": gy,\n", - " \"decoded_subset_x\": dec_x_subset[\"decoded\"],\n", - " \"decoded_subset_y\": dec_y_subset[\"decoded\"],\n", - " \"decoded_full_x\": dec_x_full[\"decoded\"],\n", - " \"decoded_full_y\": dec_y_full[\"decoded\"],\n", - " \"rmse_full\": float(np.sqrt(np.mean((dec_x_full[\"decoded\"] - px) ** 2 + (dec_y_full[\"decoded\"] - py) ** 2))),\n", + " \"decoded_x\": x_u[0, :n_common],\n", + " \"decoded_y\": x_u[1, :n_common],\n", + " \"predicted_x\": x_p[0, 1 : n_common + 1],\n", + " \"predicted_y\": x_p[1, 1 : n_common + 1],\n", + " \"decode_rmse\": decode_rmse,\n", + " \"decode_method\": decode_method,\n", + " \"decode_error\": decode_error,\n", + " \"coeffs\": coeffs,\n", + " \"state_cov\": pe_u[:, :, :n_common],\n", + " \"num_cells\": num_realizations,\n", " }\n", "\n", "\n", @@ -122,7 +163,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5b0164f", + "id": "318e16a5", "metadata": {}, "outputs": [], "source": [ @@ -130,13 +171,20 @@ "# This notebook follows the MATLAB 2-D decoding workflow with simulated spatial receptive fields.\n", "plt.close(\"all\")\n", "payload = _simulate_decode()\n", - "print({\"num_cells\": int(payload[\"spikes\"].shape[1]), \"rmse_full\": round(float(payload[\"rmse_full\"]), 4)})\n" + "print(\n", + " {\n", + " \"num_cells\": int(payload[\"num_cells\"]),\n", + " \"decode_method\": payload[\"decode_method\"],\n", + " \"decode_rmse\": round(float(payload[\"decode_rmse\"]), 4),\n", + " \"fallback_error\": payload[\"decode_error\"] or \"\",\n", + " }\n", + ")\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "44c9df66", + "id": "ea4c43c8", "metadata": {}, "outputs": [], "source": [ @@ -151,17 +199,20 @@ "\n", "fig = _prepare_figure(\"lambda{i}.plot\", figsize=(9.0, 5.0))\n", "ax = fig.subplots(1, 1)\n", - "show = [0, 1, 2, 3]\n", - "for idx in show:\n", - " ax.plot(payload[\"time_s\"], payload[\"firing_prob\"][:, idx], linewidth=1.2, label=f\"Cell {idx + 1}\")\n", - "ax.set_title(\"Example firing probabilities\")\n", + "for idx in range(min(payload[\"num_cells\"], 24)):\n", + " ax.plot(payload[\"time_s\"], payload[\"lambda_rates_hz\"][:, idx], linewidth=0.7, alpha=0.5)\n", + "ax.set_title(\"Conditional intensity functions\")\n", "ax.set_xlabel(\"time (s)\")\n", - "ax.set_ylabel(\"spike probability\")\n", - "ax.legend(loc=\"upper right\", frameon=False, fontsize=8)\n", + "ax.set_ylabel(\"Hz\")\n", "\n", "fig = _prepare_figure(\"pcolor(X,Y,placeField{i}), shading interp\", figsize=(8.0, 8.0))\n", - "axs = fig.subplots(2, 2, squeeze=False)\n", - "for ax, idx in zip(axs.ravel(), show, strict=False):\n", + "n_rows, n_cols = _subplot_grid(payload[\"num_cells\"])\n", + "axs = fig.subplots(n_rows, n_cols, squeeze=False)\n", + "image = None\n", + "for idx, ax in enumerate(axs.ravel()):\n", + " if idx >= payload[\"num_cells\"]:\n", + " ax.axis(\"off\")\n", + " continue\n", " image = ax.imshow(\n", " payload[\"fields\"][idx],\n", " origin=\"lower\",\n", @@ -172,13 +223,14 @@ " ax.set_title(f\"Cell {idx + 1}\")\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", - "fig.colorbar(image, ax=axs.ravel().tolist(), shrink=0.78)\n" + "if image is not None:\n", + " fig.colorbar(image, ax=axs.ravel().tolist(), shrink=0.78)\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "94f8bcd5", + "id": "19e02f21", "metadata": {}, "outputs": [], "source": [ @@ -196,7 +248,7 @@ { "cell_type": "code", "execution_count": null, - "id": "71c1f658", + "id": "f1d2fde2", "metadata": {}, "outputs": [], "source": [ @@ -204,8 +256,7 @@ "fig = _prepare_figure(\"plot(x_u(1,:),x_u(2,:),'b',px,py,'k')\", figsize=(6.0, 6.0))\n", "ax = fig.subplots(1, 1)\n", "ax.plot(payload[\"px\"], payload[\"py\"], color=\"k\", linewidth=1.8, label=\"True path\")\n", - "ax.plot(payload[\"decoded_subset_x\"], payload[\"decoded_subset_y\"], color=\"tab:orange\", linewidth=1.0, label=\"Subset decode\")\n", - "ax.plot(payload[\"decoded_full_x\"], payload[\"decoded_full_y\"], color=\"tab:blue\", linewidth=1.2, label=\"Full decode\")\n", + "ax.plot(payload[\"decoded_x\"], payload[\"decoded_y\"], color=\"tab:blue\", linewidth=1.2, label=payload[\"decode_method\"])\n", "ax.set_title(\"Decoded X-Y trajectory\")\n", "ax.set_xlabel(\"x\")\n", "ax.set_ylabel(\"y\")\n", @@ -215,23 +266,23 @@ "fig = _prepare_figure(\"plot(decoded trajectories)\", figsize=(10.0, 5.5))\n", "axs = fig.subplots(2, 1, sharex=True)\n", "axs[0].plot(payload[\"time_s\"], payload[\"px\"], color=\"k\", linewidth=1.6, label=\"True x\")\n", - "axs[0].plot(payload[\"time_s\"], payload[\"decoded_full_x\"], color=\"tab:blue\", linewidth=1.2, label=\"Decoded x\")\n", - "axs[0].plot(payload[\"time_s\"], payload[\"decoded_subset_x\"], color=\"tab:orange\", linewidth=1.0, label=\"Subset x\")\n", + "axs[0].plot(payload[\"time_s\"][: payload[\"decoded_x\"].size], payload[\"decoded_x\"], color=\"tab:blue\", linewidth=1.2, label=\"Decoded x\")\n", + "axs[0].plot(payload[\"time_s\"][: payload[\"predicted_x\"].size], payload[\"predicted_x\"], color=\"tab:orange\", linewidth=1.0, label=\"Predicted x\")\n", "axs[0].legend(loc=\"best\", frameon=False, fontsize=8)\n", "axs[0].set_ylabel(\"x\")\n", "axs[1].plot(payload[\"time_s\"], payload[\"py\"], color=\"k\", linewidth=1.6, label=\"True y\")\n", - "axs[1].plot(payload[\"time_s\"], payload[\"decoded_full_y\"], color=\"tab:blue\", linewidth=1.2, label=\"Decoded y\")\n", - "axs[1].plot(payload[\"time_s\"], payload[\"decoded_subset_y\"], color=\"tab:orange\", linewidth=1.0, label=\"Subset y\")\n", + "axs[1].plot(payload[\"time_s\"][: payload[\"decoded_y\"].size], payload[\"decoded_y\"], color=\"tab:blue\", linewidth=1.2, label=\"Decoded y\")\n", + "axs[1].plot(payload[\"time_s\"][: payload[\"predicted_y\"].size], payload[\"predicted_y\"], color=\"tab:orange\", linewidth=1.0, label=\"Predicted y\")\n", "axs[1].set_ylabel(\"y\")\n", "axs[1].set_xlabel(\"time (s)\")\n", "\n", "fig = _prepare_figure(\"decode_rmse\", figsize=(7.0, 4.5))\n", "ax = fig.subplots(1, 1)\n", - "error_full = np.sqrt((payload[\"decoded_full_x\"] - payload[\"px\"]) ** 2 + (payload[\"decoded_full_y\"] - payload[\"py\"]) ** 2)\n", - "error_subset = np.sqrt((payload[\"decoded_subset_x\"] - payload[\"px\"]) ** 2 + (payload[\"decoded_subset_y\"] - payload[\"py\"]) ** 2)\n", - "ax.plot(payload[\"time_s\"], error_full, color=\"tab:blue\", linewidth=1.2, label=\"Full decode\")\n", - "ax.plot(payload[\"time_s\"], error_subset, color=\"tab:orange\", linewidth=1.0, label=\"Subset decode\")\n", - "ax.set_title(f\"Pointwise decoding error (RMSE={payload['rmse_full']:.3f})\")\n", + "error_decode = np.sqrt((payload[\"decoded_x\"] - payload[\"px\"][: payload[\"decoded_x\"].size]) ** 2 + (payload[\"decoded_y\"] - payload[\"py\"][: payload[\"decoded_y\"].size]) ** 2)\n", + "error_predict = np.sqrt((payload[\"predicted_x\"] - payload[\"px\"][: payload[\"predicted_x\"].size]) ** 2 + (payload[\"predicted_y\"] - payload[\"py\"][: payload[\"predicted_y\"].size]) ** 2)\n", + "ax.plot(payload[\"time_s\"][: error_decode.size], error_decode, color=\"tab:blue\", linewidth=1.2, label=\"Filtered decode\")\n", + "ax.plot(payload[\"time_s\"][: error_predict.size], error_predict, color=\"tab:orange\", linewidth=1.0, label=\"Predicted decode\")\n", + "ax.set_title(f\"Pointwise decoding error (RMSE={payload['decode_rmse']:.3f})\")\n", "ax.set_xlabel(\"time (s)\")\n", "ax.set_ylabel(\"Euclidean error\")\n", "ax.legend(loc=\"best\", frameon=False, fontsize=8)\n", @@ -252,4 +303,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/nstat/cif.py b/nstat/cif.py index f33d9f2c..afbd10a2 100644 --- a/nstat/cif.py +++ b/nstat/cif.py @@ -1,9 +1,11 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Sequence +from typing import Any, Sequence import numpy as np +import sympy as sp +from sympy.parsing.sympy_parser import convert_xor, parse_expr, standard_transformations from .history import History from .signal import Covariate @@ -12,6 +14,9 @@ from .nspikeTrain import nspikeTrain +_SYMPY_TRANSFORMS = standard_transformations + (convert_xor,) + + def _as_1d_float(values) -> np.ndarray: return np.asarray(values if values is not None else [], dtype=float).reshape(-1) @@ -83,6 +88,148 @@ def _sigmoid(values: np.ndarray) -> np.ndarray: return 1.0 / (1.0 + np.exp(-np.clip(values, -20.0, 20.0))) +def _ordered_independent_names( + xnames: Sequence[str] | None, + stim_names: Sequence[str] | None, +) -> tuple[str, ...]: + ordered: list[str] = [] + for values in (stim_names or [], xnames or []): + for raw in values: + expr = str(raw).strip() + if not expr: + continue + parsed = parse_expr(expr, transformations=_SYMPY_TRANSFORMS, evaluate=True) + for symbol in sorted(parsed.free_symbols, key=lambda item: str(item)): + name = str(symbol) + if name not in ordered: + ordered.append(name) + return tuple(ordered) + + +def _zero_row(width: int) -> np.ndarray: + return np.zeros((1, int(width)), dtype=float) + + +def _zero_square(size: int) -> np.ndarray: + return np.zeros((int(size), int(size)), dtype=float) + + +def _reshape_row(result, width: int) -> np.ndarray: + if width == 0: + return _zero_row(0) + arr = np.asarray(result, dtype=float).reshape(-1) + if arr.size != width: + raise ValueError("Compiled CIF gradient width mismatch") + return arr.reshape(1, width) + + +def _reshape_square(result, size: int) -> np.ndarray: + if size == 0: + return _zero_square(0) + arr = np.asarray(result, dtype=float) + if arr.ndim == 0: + if size != 1: + raise ValueError("Compiled CIF Hessian size mismatch") + return np.asarray([[float(arr)]], dtype=float) + arr = arr.reshape(size, size) + return arr + + +def _compile_cif_surface( + beta: np.ndarray, + xnames: Sequence[str] | None, + stim_names: Sequence[str] | None, + fit_type: str, + hist_coeffs: np.ndarray, +) -> dict[str, Any] | None: + beta_vec = np.asarray(beta) + if beta_vec.dtype == object or beta_vec.size == 0: + return None + beta_vec = np.asarray(beta_vec, dtype=float).reshape(-1) + xnames_list = [str(name) for name in (xnames or [])] + if len(xnames_list) != beta_vec.size: + return None + + independent_names = _ordered_independent_names(xnames_list, stim_names) + symbol_map = {name: sp.Symbol(name, real=True) for name in independent_names} + stim_var_names = [str(name) for name in (stim_names or [])] + for name in stim_var_names: + symbol_map.setdefault(name, sp.Symbol(name, real=True)) + + term_exprs = [ + parse_expr(str(name).strip(), local_dict=symbol_map, transformations=_SYMPY_TRANSFORMS, evaluate=True) + for name in xnames_list + ] + independent_symbols = tuple(symbol_map[name] for name in independent_names) + stim_symbols = tuple(symbol_map[name] for name in stim_var_names) + + eta = sum(sp.Float(float(coeff)) * expr for coeff, expr in zip(beta_vec, term_exprs, strict=True)) + + hist_coeff_vec = np.asarray(hist_coeffs, dtype=float).reshape(-1) + hist_symbols = tuple(sp.Symbol(f"h{idx}", real=True) for idx in range(hist_coeff_vec.size)) + gamma_symbols = tuple(sp.Symbol(f"g{idx}", real=True) for idx in range(hist_coeff_vec.size)) + hist_term = sum(sp.Float(float(coeff)) * symbol for coeff, symbol in zip(hist_coeff_vec, hist_symbols, strict=True)) + gamma_term = sum(symbol * hist_symbol for symbol, hist_symbol in zip(gamma_symbols, hist_symbols, strict=True)) + + eta_hist = eta + hist_term + eta_gamma = eta + gamma_term + + if str(fit_type).lower() == "binomial": + lambda_expr = sp.exp(eta_hist) / (1 + sp.exp(eta_hist)) + lambda_gamma_expr = sp.exp(eta_gamma) / (1 + sp.exp(eta_gamma)) + elif str(fit_type).lower() == "poisson": + lambda_expr = sp.exp(eta_hist) + lambda_gamma_expr = sp.exp(eta_gamma) + else: + return None + + gradient_expr = sp.Matrix([sp.diff(lambda_expr, symbol) for symbol in stim_symbols]) if stim_symbols else sp.Matrix([]) + gradient_log_expr = sp.Matrix([sp.diff(sp.log(lambda_expr), symbol) for symbol in stim_symbols]) if stim_symbols else sp.Matrix([]) + jacobian_expr = sp.hessian(lambda_expr, stim_symbols) if stim_symbols else sp.Matrix([]) + jacobian_log_expr = sp.hessian(sp.log(lambda_expr), stim_symbols) if stim_symbols else sp.Matrix([]) + + if gamma_symbols: + gradient_gamma_expr = sp.Matrix([sp.diff(lambda_gamma_expr, symbol) for symbol in gamma_symbols]) + gradient_log_gamma_expr = sp.Matrix([sp.diff(sp.log(lambda_gamma_expr), symbol) for symbol in gamma_symbols]) + jacobian_gamma_expr = sp.hessian(lambda_gamma_expr, gamma_symbols) + jacobian_log_gamma_expr = sp.hessian(sp.log(lambda_gamma_expr), gamma_symbols) + else: + gradient_gamma_expr = sp.Matrix([]) + gradient_log_gamma_expr = sp.Matrix([]) + jacobian_gamma_expr = sp.Matrix([]) + jacobian_log_gamma_expr = sp.Matrix([]) + + args = independent_symbols + hist_symbols + gamma_args = independent_symbols + hist_symbols + gamma_symbols + + return { + "independent_names": independent_names, + "stim_names": tuple(stim_var_names), + "lambda_expr": lambda_expr, + "lambda_gamma_expr": lambda_gamma_expr if gamma_symbols else None, + "log_lambda_gamma_expr": sp.log(lambda_gamma_expr) if gamma_symbols else None, + "gradient_expr": gradient_expr, + "gradient_log_expr": gradient_log_expr, + "jacobian_expr": jacobian_expr, + "jacobian_log_expr": jacobian_log_expr, + "gradient_gamma_expr": gradient_gamma_expr if gamma_symbols else None, + "gradient_log_gamma_expr": gradient_log_gamma_expr if gamma_symbols else None, + "jacobian_gamma_expr": jacobian_gamma_expr if gamma_symbols else None, + "jacobian_log_gamma_expr": jacobian_log_gamma_expr if gamma_symbols else None, + "lambda_fn": sp.lambdify(args, lambda_expr, modules="numpy"), + "gradient_fn": sp.lambdify(args, gradient_expr, modules="numpy") if stim_symbols else None, + "gradient_log_fn": sp.lambdify(args, gradient_log_expr, modules="numpy") if stim_symbols else None, + "jacobian_fn": sp.lambdify(args, jacobian_expr, modules="numpy") if stim_symbols else None, + "jacobian_log_fn": sp.lambdify(args, jacobian_log_expr, modules="numpy") if stim_symbols else None, + "lambda_gamma_fn": sp.lambdify(gamma_args, lambda_gamma_expr, modules="numpy") if gamma_symbols else None, + "log_lambda_gamma_fn": sp.lambdify(gamma_args, sp.log(lambda_gamma_expr), modules="numpy") if gamma_symbols else None, + "gradient_gamma_fn": sp.lambdify(gamma_args, gradient_gamma_expr, modules="numpy") if gamma_symbols else None, + "gradient_log_gamma_fn": sp.lambdify(gamma_args, gradient_log_gamma_expr, modules="numpy") if gamma_symbols else None, + "jacobian_gamma_fn": sp.lambdify(gamma_args, jacobian_gamma_expr, modules="numpy") if gamma_symbols else None, + "jacobian_log_gamma_fn": sp.lambdify(gamma_args, jacobian_log_gamma_expr, modules="numpy") if gamma_symbols else None, + } + + @dataclass class CIFModel: """Conditional intensity function abstraction used by standalone workflows.""" @@ -146,9 +293,35 @@ def __init__( self.stimVars = list(stimNames or []) self.fitType = str(fitType).lower() self.histCoeffs = _as_1d_float(histCoeffs) + self.indepVars = list(_ordered_independent_names(self.varIn, self.stimVars)) self.history = None self.historyMat = np.zeros((0, 0), dtype=float) self.spikeTrain = None + self.lambdaDelta = None + self.lambdaDeltaGamma = None + self.LogLambdaDeltaGamma = None + self.gradientLambdaDelta = None + self.gradientLogLambdaDelta = None + self.gradientLambdaDeltaGamma = None + self.gradientLogLambdaDeltaGamma = None + self.jacobianLambdaDelta = None + self.jacobianLogLambdaDelta = None + self.jacobianLambdaDeltaGamma = None + self.jacobianLogLambdaDeltaGamma = None + self._expression_surface = _compile_cif_surface(self.b, self.varIn, self.stimVars, self.fitType, self.histCoeffs) + if self._expression_surface is not None: + self.indepVars = list(self._expression_surface["independent_names"]) + self.lambdaDelta = self._expression_surface["lambda_expr"] + self.lambdaDeltaGamma = self._expression_surface["lambda_gamma_expr"] + self.LogLambdaDeltaGamma = self._expression_surface["log_lambda_gamma_expr"] + self.gradientLambdaDelta = self._expression_surface["gradient_expr"] + self.gradientLogLambdaDelta = self._expression_surface["gradient_log_expr"] + self.gradientLambdaDeltaGamma = self._expression_surface["gradient_gamma_expr"] + self.gradientLogLambdaDeltaGamma = self._expression_surface["gradient_log_gamma_expr"] + self.jacobianLambdaDelta = self._expression_surface["jacobian_expr"] + self.jacobianLogLambdaDelta = self._expression_surface["jacobian_log_expr"] + self.jacobianLambdaDeltaGamma = self._expression_surface["jacobian_gamma_expr"] + self.jacobianLogLambdaDeltaGamma = self._expression_surface["jacobian_log_gamma_expr"] if historyObj is not None: self.setHistory(historyObj) if nst is not None: @@ -189,6 +362,17 @@ def setHistory(self, histObj) -> None: if self.spikeTrain is not None: self.historyMat = np.asarray(self.history.computeHistory(self.spikeTrain).dataToMatrix(), dtype=float) + def _stimulus_values(self, stimVal) -> np.ndarray: + stim = _as_1d_float(stimVal) + if self._expression_surface is None: + return stim + expected = len(self._expression_surface["independent_names"]) + if stim.size != expected: + raise ValueError( + f"Expected {expected} independent variable values for CIF evaluation, received {stim.size}" + ) + return stim + def _split_coefficients(self, stim_dim: int) -> tuple[float, np.ndarray]: coeffs = np.asarray(self.b, dtype=float).reshape(-1) if coeffs.size == stim_dim: @@ -213,6 +397,30 @@ def _history_values(self, time_index: int | None = None, nst: nspikeTrain | None idx = min(idx, self.historyMat.shape[0] - 1) return self.historyMat[idx, :].reshape(-1) + def _surface_args(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None) -> tuple[float, ...]: + stim = self._stimulus_values(stimVal) + hist = self._history_values(time_index=time_index, nst=nst) + return tuple(stim.tolist() + hist.tolist()) + + def _surface_gamma_args( + self, + stimVal, + time_index: int | None = None, + nst: nspikeTrain | None = None, + gamma=None, + ) -> tuple[float, ...]: + args = list(self._surface_args(stimVal, time_index=time_index, nst=nst)) + gamma_arr = _as_1d_float(gamma) + if self.histCoeffs.size == 0: + if gamma_arr.size: + raise ValueError("gamma is only valid for history-dependent CIFs") + return tuple(args) + if gamma_arr.size == 1: + gamma_arr = np.repeat(gamma_arr, self.histCoeffs.size) + if gamma_arr.size != self.histCoeffs.size: + raise ValueError("gamma must be scalar or align with histCoeffs") + return tuple(args + gamma_arr.tolist()) + def _eta(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None) -> tuple[float, np.ndarray, np.ndarray, float]: stim = _as_1d_float(stimVal) intercept, stim_coeffs = self._split_coefficients(stim.size) @@ -234,6 +442,17 @@ def _eta(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = return eta, stim_coeffs, hist_coeffs, intercept def _lambda_delta(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None) -> float: + if self._expression_surface is not None and gamma is None: + return float(np.asarray(self._expression_surface["lambda_fn"](*self._surface_args(stimVal, time_index=time_index, nst=nst)), dtype=float).reshape(-1)[0]) + if self._expression_surface is not None and gamma is not None and self._expression_surface["lambda_gamma_fn"] is not None: + return float( + np.asarray( + self._expression_surface["lambda_gamma_fn"]( + *self._surface_gamma_args(stimVal, time_index=time_index, nst=nst, gamma=gamma) + ), + dtype=float, + ).reshape(-1)[0] + ) eta, _, _, _ = self._eta(stimVal, time_index=time_index, nst=nst, gamma=gamma) if self.fitType == "binomial": return float(_sigmoid(np.asarray([eta], dtype=float))[0]) @@ -242,6 +461,12 @@ def _lambda_delta(self, stimVal, time_index: int | None = None, nst: nspikeTrain raise ValueError("fitType must be either 'poisson' or 'binomial'") def _gradient(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None, log: bool = False) -> np.ndarray: + if self._expression_surface is not None and gamma is None: + fn = self._expression_surface["gradient_log_fn" if log else "gradient_fn"] + if fn is None: + return _zero_row(0) + width = len(self._expression_surface["stim_names"]) + return _reshape_row(fn(*self._surface_args(stimVal, time_index=time_index, nst=nst)), width) lambda_delta = self._lambda_delta(stimVal, time_index=time_index, nst=nst, gamma=gamma) _, stim_coeffs, _, _ = self._eta(stimVal, time_index=time_index, nst=nst, gamma=gamma) if self.fitType == "binomial": @@ -250,6 +475,12 @@ def _gradient(self, stimVal, time_index: int | None = None, nst: nspikeTrain | N return stim_coeffs.reshape(1, -1) if log else (lambda_delta * stim_coeffs).reshape(1, -1) def _jacobian(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None, log: bool = False) -> np.ndarray: + if self._expression_surface is not None and gamma is None: + fn = self._expression_surface["jacobian_log_fn" if log else "jacobian_fn"] + if fn is None: + return _zero_square(0) + size = len(self._expression_surface["stim_names"]) + return _reshape_square(fn(*self._surface_args(stimVal, time_index=time_index, nst=nst)), size) lambda_delta = self._lambda_delta(stimVal, time_index=time_index, nst=nst, gamma=gamma) _, stim_coeffs, _, _ = self._eta(stimVal, time_index=time_index, nst=nst, gamma=gamma) outer = np.outer(stim_coeffs, stim_coeffs) @@ -278,18 +509,55 @@ def evalLDGamma(self, stimVal, time_index: int | None = None, nst: nspikeTrain | return self._lambda_delta(stimVal, time_index=time_index, nst=nst, gamma=gamma) def evalLogLDGamma(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None): + if self._expression_surface is not None and self._expression_surface["log_lambda_gamma_fn"] is not None: + return float( + np.asarray( + self._expression_surface["log_lambda_gamma_fn"]( + *self._surface_gamma_args(stimVal, time_index=time_index, nst=nst, gamma=gamma) + ), + dtype=float, + ).reshape(-1)[0] + ) return float(np.log(np.clip(self.evalLDGamma(stimVal, time_index=time_index, nst=nst, gamma=gamma), 1e-12, None))) def evalGradientLDGamma(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None): + if self._expression_surface is not None and self._expression_surface["gradient_gamma_fn"] is not None: + return _reshape_row( + self._expression_surface["gradient_gamma_fn"]( + *self._surface_gamma_args(stimVal, time_index=time_index, nst=nst, gamma=gamma) + ), + self.histCoeffs.size, + ) return self._gradient(stimVal, time_index=time_index, nst=nst, gamma=gamma) def evalGradientLogLDGamma(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None): + if self._expression_surface is not None and self._expression_surface["gradient_log_gamma_fn"] is not None: + return _reshape_row( + self._expression_surface["gradient_log_gamma_fn"]( + *self._surface_gamma_args(stimVal, time_index=time_index, nst=nst, gamma=gamma) + ), + self.histCoeffs.size, + ) return self._gradient(stimVal, time_index=time_index, nst=nst, gamma=gamma, log=True) def evalJacobianLDGamma(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None): + if self._expression_surface is not None and self._expression_surface["jacobian_gamma_fn"] is not None: + return _reshape_square( + self._expression_surface["jacobian_gamma_fn"]( + *self._surface_gamma_args(stimVal, time_index=time_index, nst=nst, gamma=gamma) + ), + self.histCoeffs.size, + ) return self._jacobian(stimVal, time_index=time_index, nst=nst, gamma=gamma) def evalJacobianLogLDGamma(self, stimVal, time_index: int | None = None, nst: nspikeTrain | None = None, gamma=None): + if self._expression_surface is not None and self._expression_surface["jacobian_log_gamma_fn"] is not None: + return _reshape_square( + self._expression_surface["jacobian_log_gamma_fn"]( + *self._surface_gamma_args(stimVal, time_index=time_index, nst=nst, gamma=gamma) + ), + self.histCoeffs.size, + ) return self._jacobian(stimVal, time_index=time_index, nst=nst, gamma=gamma, log=True) def isSymBeta(self) -> bool: @@ -302,6 +570,18 @@ def evaluate(self, design_matrix: np.ndarray, *, delta: float = 1.0, history_mat x = np.asarray(design_matrix, dtype=float) if x.ndim == 1: x = x[:, None] + if self._expression_surface is not None and x.shape[1] == len(self._expression_surface["independent_names"]): + if history_matrix is None: + hist = np.zeros((x.shape[0], self.histCoeffs.size), dtype=float) + else: + hist = np.asarray(history_matrix, dtype=float) + if hist.ndim == 1: + hist = hist[:, None] + if hist.shape[1] != self.histCoeffs.size: + raise ValueError("history_matrix column count must match histCoeffs length") + args = [x[:, idx] for idx in range(x.shape[1])] + [hist[:, idx] for idx in range(hist.shape[1])] + lambda_delta = np.asarray(self._expression_surface["lambda_fn"](*args), dtype=float).reshape(-1) + return lambda_delta / max(float(delta), 1e-12) intercept, stim_coeffs = self._split_coefficients(x.shape[1]) eta = intercept + x @ stim_coeffs if history_matrix is not None and self.histCoeffs.size: diff --git a/nstat/decoding_algorithms.py b/nstat/decoding_algorithms.py index b19820f1..3e48eaf1 100644 --- a/nstat/decoding_algorithms.py +++ b/nstat/decoding_algorithms.py @@ -236,13 +236,23 @@ def _normalize_mu_models(mu, n_models: int, num_cells: int) -> list[np.ndarray]: return [_normalize_mu(mu, num_cells) for _ in range(n_models)] -def _extract_linear_terms_from_cifs(lambdaCIFColl, num_states: int, num_cells: int): +def _normalize_cif_collection(lambdaCIFColl) -> list[CIF]: if isinstance(lambdaCIFColl, CIF): cifs = [lambdaCIFColl] elif isinstance(lambdaCIFColl, Sequence) and not isinstance(lambdaCIFColl, (str, bytes)): cifs = list(lambdaCIFColl) else: raise UnsupportedWorkflowError("PPDecodeFilter requires a CIF or sequence of CIF objects for the Python port") + if not cifs: + raise ValueError("lambdaCIFColl must contain at least one CIF object") + for cif in cifs: + if not isinstance(cif, CIF): + raise UnsupportedWorkflowError("PPDecodeFilter only supports CIF objects in the Python port") + return cifs + + +def _extract_linear_terms_from_cifs(lambdaCIFColl, num_states: int, num_cells: int): + cifs = _normalize_cif_collection(lambdaCIFColl) if len(cifs) != num_cells: raise ValueError("Number of CIF objects must match the number of observed cells") @@ -542,13 +552,12 @@ def PPDecode_update(x_p, W_p, dN, lambdaIn, binwidth=0.001, time_index=1, WuConv observed = obs[:, idx - 1] for cell_index, cif in enumerate(lambda_items): - if not isinstance(cif, CIF): - raise ValueError("Lambda must be a cell of CIFs or a CIF") if cif.historyMat.size == 0: - spike_times = (np.where(obs[cell_index] == 1.0)[0]) * float(binwidth) + observed_prefix = obs[cell_index, :idx] + spike_times = np.where(observed_prefix > 0.5)[0] * float(binwidth) nst = nspikeTrain(spike_times, makePlots=-1) nst.setMinTime(0.0) - nst.setMaxTime((obs.shape[1] - 1) * float(binwidth)) + nst.setMaxTime((idx - 1) * float(binwidth)) nst = nst.resample(1.0 / float(binwidth)) lambda_delta[cell_index, 0] = float(cif.evalLambdaDelta(x_vec, idx, nst)) sum_val_vec += observed[cell_index] * np.asarray(cif.evalGradientLog(x_vec, idx, nst), dtype=float).reshape(-1) @@ -695,26 +704,73 @@ def PPDecodeFilterLinear(*args, **kwargs): @staticmethod def PPDecodeFilter(A, Q, Px0, dN, lambdaCIFColl, binwidth=0.001, x0=None, Pi0=None, yT=None, PiT=None, estimateTarget=0, Wconv=None): obs = _as_observation_matrix(dN) - num_states = _infer_state_dim(A, np.array([0.0]), obs.shape[0]) - mu, beta, fitType, gamma, windowTimes = _extract_linear_terms_from_cifs(lambdaCIFColl, num_states, obs.shape[0]) - initial_cov = Px0 if _is_empty_value(Pi0) else Pi0 - return DecodingAlgorithms._ppdecode_filter_linear( - A, - Q, - obs, - mu, - beta, - fitType, - binwidth, - gamma, - windowTimes, - x0, - initial_cov, - yT, - PiT, - estimateTarget, - Wconv, - ) + lambda_items = _normalize_cif_collection(lambdaCIFColl) + num_cells, num_steps = obs.shape + if len(lambda_items) != num_cells: + raise ValueError("Number of CIF objects must match the number of observed cells") + + num_states = _infer_state_dim(A, np.array([0.0]), num_cells) + uses_target_branch = not _is_empty_value(yT) or not _is_empty_value(PiT) or int(estimateTarget) != 0 + if uses_target_branch: + mu, beta, fitType, gamma, windowTimes = _extract_linear_terms_from_cifs(lambda_items, num_states, num_cells) + initial_cov = Px0 if _is_empty_value(Pi0) else Pi0 + return DecodingAlgorithms._ppdecode_filter_linear( + A, + Q, + obs, + mu, + beta, + fitType, + binwidth, + gamma, + windowTimes, + x0, + initial_cov, + yT, + PiT, + estimateTarget, + Wconv, + ) + + x0_vec = np.zeros(num_states, dtype=float) if _is_empty_value(x0) else np.asarray(x0, dtype=float).reshape(-1) + if x0_vec.size != num_states: + raise ValueError("x0 must match the decoding state dimension") + # MATLAB PPDecodeFilter's standard branch initializes from Pi0, and + # when Pi0 is omitted it falls back to zeros rather than using Px0. + Pi0_mat = np.zeros((num_states, num_states), dtype=float) if _is_empty_value(Pi0) else _as_state_matrix(Pi0, num_states) + + x_p = np.zeros((num_states, num_steps + 1), dtype=float) + x_u = np.zeros((num_states, num_steps), dtype=float) + W_p = np.zeros((num_states, num_states, num_steps + 1), dtype=float) + W_u = np.zeros((num_states, num_states, num_steps), dtype=float) + + A0 = _select_time_matrix(A, 0, num_states) + Q0 = _select_time_matrix(Q, 0, num_states) + x_p[:, 0], W_p[:, :, 0] = DecodingAlgorithms.PPDecode_predict(x0_vec, Pi0_mat, A0, Q0, Wconv) + + for time_index in range(1, num_steps + 1): + x_u[:, time_index - 1], W_u[:, :, time_index - 1], _ = DecodingAlgorithms.PPDecode_update( + x_p[:, time_index - 1], + W_p[:, :, time_index - 1], + obs, + lambda_items, + binwidth, + time_index, + None, + ) + A_t = _select_time_matrix(A, time_index - 1, num_states) + Q_t = _select_time_matrix(Q, time_index - 1, num_states) + x_p[:, time_index], W_p[:, :, time_index] = DecodingAlgorithms.PPDecode_predict( + x_u[:, time_index - 1], + W_u[:, :, time_index - 1], + A_t, + Q_t, + Wconv, + ) + + empty_vec = np.array([], dtype=float) + empty_cov = np.zeros((0, 0, 0), dtype=float) + return x_p, W_p, x_u, W_u, empty_vec, empty_cov, empty_vec, empty_cov @staticmethod def PP_fixedIntervalSmoother(A, Q, dN, lags, mu, beta, fitType="poisson", delta=0.001, gamma=None, windowTimes=None, x0=None, Pi0=None): diff --git a/parity/class_fidelity.yml b/parity/class_fidelity.yml index 5cb67a5d..6c55c44d 100644 --- a/parity/class_fidelity.yml +++ b/parity/class_fidelity.yml @@ -353,13 +353,13 @@ items: in the expected workflow positions. symbol_presence_verified: yes known_remaining_differences: - - Simulink-backed recursive-CIF behavior is represented by a native Python implementation, - but it is not yet fixture-matched one-for-one against MATLAB/Simulink stochastic - trajectories. + - Analytic and nonlinear polynomial CIF surfaces are now fixture-backed against + MATLAB, but recursive Simulink-backed stochastic trajectories are still validated + as high-fidelity summaries rather than exact sample-by-sample reproductions. required_remediation: - Extend the committed MATLAB-derived fixtures beyond analytic lambda/gradient/Jacobian - outputs and the deterministic recursive lambda prefix to cover additional thinning - and seeded simulation summaries. + outputs, nonlinear polynomial surfaces, and the deterministic recursive lambda + prefix to cover additional thinning and seeded simulation summaries. - Add MATLAB/Simulink comparison fixtures for recursive CIF simulation trajectories when the random-stream alignment question is resolved. plotting_report_parity: Simulation/report plotting is limited; downstream notebooks @@ -392,11 +392,15 @@ items: tensors instead of only Python-specific dictionaries. symbol_presence_verified: yes known_remaining_differences: + - The nonlinear `PPDecodeFilter` path is now fixture-backed against MATLAB on + a deterministic polynomial-CIF example, but it still shows small symbolic/numeric + drift at the `1e-4` level and remains high-fidelity rather than exact. - Target-estimation augmentation, EM routines, and some advanced symbolic-CIF workflows remain thinner than MATLAB. required_remediation: - - Extend the committed MATLAB-derived numerical fixtures beyond `PPDecode_predict` - to DecodingExample, DecodingExampleWithHist, StimulusDecode2D, and HybridFilterExample. + - Extend the committed MATLAB-derived numerical fixtures from `PPDecode_predict` + and the deterministic nonlinear `PPDecodeFilter` case to DecodingExample, + DecodingExampleWithHist, and HybridFilterExample summaries. - Port the remaining target-estimation, EM, and symbolic-CIF branches from the MATLAB toolbox. plotting_report_parity: Notebook-level decoding figures are supported, but the full diff --git a/parity/notebook_fidelity.yml b/parity/notebook_fidelity.yml index a6bc3a87..02ed69b5 100644 --- a/parity/notebook_fidelity.yml +++ b/parity/notebook_fidelity.yml @@ -257,11 +257,12 @@ items: - topic: StimulusDecode2D source_matlab: StimulusDecode2D.mlx python_notebook: notebooks/StimulusDecode2D.ipynb - fidelity_status: partial - remaining_differences: The notebook reproduces the MATLAB section order, figure - inventory, simulated receptive fields, and decoded-trajectory presentation, but - the current Python decoder still uses regression-based state recovery instead - of MATLAB's symbolic-CIF nonlinear filter. + fidelity_status: high_fidelity + remaining_differences: The notebook now follows the MATLAB nonlinear-CIF decoding + workflow and uses `DecodingAlgorithms.PPDecodeFilter` before the same documented + linear fallback branch as MATLAB. Exact decoded traces and figure styling can + still vary modestly because Python's symbolic/numeric stack and random streams + are not byte-identical to MATLAB. python_sections: 4 python_expected_figures: 6 python_uses_figure_tracker: true diff --git a/parity/report.md b/parity/report.md index 49fe2239..4ab42bd4 100644 --- a/parity/report.md +++ b/parity/report.md @@ -42,8 +42,8 @@ Generated from `parity/manifest.yml`, `parity/class_fidelity.yml`, `tools/notebo | Status | Count | |---|---:| | `exact` | 0 | -| `high_fidelity` | 12 | -| `partial` | 1 | +| `high_fidelity` | 13 | +| `partial` | 0 | ## Simulink Fidelity Summary @@ -60,8 +60,7 @@ Generated from `parity/manifest.yml`, `parity/class_fidelity.yml`, `tools/notebo - Public API: no missing MATLAB public APIs remain; only the MATLAB help-browser utility is explicitly non-applicable. - Help/notebook parity: all inventoried MATLAB help workflows are mapped to Python notebooks or equivalents. -- Notebook fidelity: workflow coverage is complete, but 1 MATLAB-helpfile notebook ports are still marked partial in `tools/notebooks/parity_notes.yml`. -- Notebook fidelity audit: structural section/figure comparisons plus placeholder/tracker-only cell detection are recorded in `parity/notebook_fidelity.yml`. +- Notebook fidelity: all tracked MATLAB-helpfile notebook ports are marked high fidelity or exact. - Paper examples and docs gallery: all canonical paper examples and committed gallery directories are mapped. - Class fidelity: the class audit reports no partial, wrapper-only, or missing items. - Runtime symbol verification: every audited MATLAB-facing Python symbol marked present in `parity/class_fidelity.yml` resolves on the live public surface. @@ -73,7 +72,7 @@ No partial or missing items remain in the mapping inventory. ## Remaining Notebook-Fidelity Deltas -- `StimulusDecode2D` -> `notebooks/StimulusDecode2D.ipynb` [partial]: The notebook reproduces the MATLAB section order, figure inventory, simulated receptive fields, and decoded-trajectory presentation, but the current Python decoder still uses regression-based state recovery instead of MATLAB's symbolic-CIF nonlinear filter. +No partial notebook-fidelity items remain in `tools/notebooks/parity_notes.yml`. ## Remaining Class-Fidelity Deltas diff --git a/pyproject.toml b/pyproject.toml index 86ea335b..cc3cf9f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "numpy>=1.24", "scipy>=1.10", "matplotlib>=3.7", + "sympy>=1.13", "PyYAML>=6.0", "nbformat>=5.10", "nbclient>=0.10" diff --git a/tests/parity/fixtures/matlab_gold/cif_exactness.mat b/tests/parity/fixtures/matlab_gold/cif_exactness.mat index 2be3498a0328ca25e4b510b1c0632b75934d229c..e3fe01b82d255a6dc97e1b1ee3c955cd1a0e504c 100644 GIT binary patch delta 520 zcmaFH(#ko(j@Qu0%D~LZ*h0a`z{qT3pz_27wv8p97=!p37#LiExMI%Z?hXIfrF5}${xx&SxBN06>da!&W}@H1#% z=dS{3u?A_;RZ4VJ&^vK{WzTufGrnG)J^r3&H9cEaJ#SfcKX5|w`j}0JegA72(7nR4dJ9wCMK!X2y)JIcySe{5M7`unv0bk8 zs|4kN=G8O!!VUC5Hqg;fS>Q7h#PeZEMRh@GNpcg;a~|PxxO4o|yxRx&RsO%W>ALrJ zy|)dQy=&~&|4)kDE6-4PM^Y8!${@I*C(sS$Mln=3DXzw^&ML%KT|k@RM8d({bTNxWE7#SFuDG&=7V1UunmmkPh0pf}|kCPJ;3>cEk6rMF) zQaHf%)Y#EbS-`64rjCOcvoeF|a)$LFwT5uD9O!DgLG=cbrv97sgyV& z!Yj+qQ)$ApmQ&|vaWYxIWlTtAR|09#glpM=tcB5#TfmqpQnOCL&SOsFM8O_L)gO$8 zAPo+14JVK_I31N(nB=)qwX@A(mT(8J#DfP6Ema%tOYrF32HB|o-TB4bfU2do47`s; ze}l|%BVdjj#GHf&$NxndJ3d%u+hVs9XdIvOPwnsF-D`^LBTSFVG3>r9DF!kv2X0yh zva1|#_VAqH5&g)@GsmHffmvXQ0D5cxCE+zdhnZD}yt)t>^5=!#xrXBq9N$mGu!Rmf kRL5HVRrQ-`VEIqAh?R$5*3LSA>Z$GY8_Wz}Z`hm#0PY{X)Bpeg literal 0 HcmV?d00001 diff --git a/tests/test_decoding_algorithms_fidelity.py b/tests/test_decoding_algorithms_fidelity.py index d075a0c7..022d5357 100644 --- a/tests/test_decoding_algorithms_fidelity.py +++ b/tests/test_decoding_algorithms_fidelity.py @@ -53,6 +53,23 @@ def test_ppdecodefilter_accepts_cif_collections_with_history() -> None: assert np.all(np.isfinite(x_u)) +def test_ppdecodefilter_handles_symbolic_style_polynomial_cifs() -> None: + dN = np.array([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0]], dtype=float) + lambda_cifs = [ + CIF([-2.0, -0.5, 0.3, -0.2, -0.1, 0.05], ["1", "x", "y", "x^2", "y^2", "x*y"], ["x", "y"], fitType="binomial"), + CIF([-1.5, 0.4, -0.2, 0.15, -0.05, 0.02], ["1", "x", "y", "x^2", "y^2", "x*y"], ["x", "y"], fitType="binomial"), + ] + + x_p, W_p, x_u, W_u, *_ = DecodingAlgorithms.PPDecodeFilter(np.eye(2), 0.01 * np.eye(2), 0.05 * np.eye(2), dN, lambda_cifs, 0.1) + + assert x_p.shape == (2, 5) + assert W_p.shape == (2, 2, 5) + assert x_u.shape == (2, 4) + assert W_u.shape == (2, 2, 4) + assert np.all(np.isfinite(x_u)) + assert np.all(np.isfinite(W_u)) + + def test_ppdecode_update_matches_matlab_facing_public_surface() -> None: dN = np.array([[0.0, 1.0, 0.0, 1.0]], dtype=float) lambda_cif = CIF([0.1, 0.4], ["1", "x"], ["x"], fitType="binomial") diff --git a/tests/test_matlab_gold_fixtures.py b/tests/test_matlab_gold_fixtures.py index 33ed0d07..5b7f7dff 100644 --- a/tests/test_matlab_gold_fixtures.py +++ b/tests/test_matlab_gold_fixtures.py @@ -163,6 +163,19 @@ def test_cif_eval_surface_matches_matlab_gold_fixture() -> None: np.testing.assert_allclose(cif.evalJacobian(stim_val), np.asarray(payload["jacobian"], dtype=float), rtol=1e-8, atol=1e-10) np.testing.assert_allclose(cif.evalJacobianLog(stim_val), np.asarray(payload["jacobian_log"], dtype=float), rtol=1e-8, atol=1e-10) + poly_cif = CIF( + beta=_vector(payload, "poly_beta"), + Xnames=["1", "x", "y", "x^2", "y^2", "x*y"], + stimNames=["x", "y"], + fitType="binomial", + ) + poly_stim = _vector(payload, "poly_stimVal") + np.testing.assert_allclose(poly_cif.evalLambdaDelta(poly_stim), _scalar(payload, "poly_lambda_delta"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(poly_cif.evalGradient(poly_stim).reshape(-1), _vector(payload, "poly_gradient"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(poly_cif.evalGradientLog(poly_stim).reshape(-1), _vector(payload, "poly_gradient_log"), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(poly_cif.evalJacobian(poly_stim), np.asarray(payload["poly_jacobian"], dtype=float), rtol=1e-8, atol=1e-10) + np.testing.assert_allclose(poly_cif.evalJacobianLog(poly_stim), np.asarray(payload["poly_jacobian_log"], dtype=float), rtol=1e-8, atol=1e-10) + def test_analysis_fit_surface_matches_matlab_gold_fixture() -> None: payload = _load_fixture("analysis_exactness.mat") @@ -225,6 +238,28 @@ def test_decoding_predict_matches_matlab_gold_fixture() -> None: np.testing.assert_allclose(W_p, np.asarray(payload["W_p"], dtype=float), rtol=1e-8, atol=1e-10) +def test_nonlinear_ppdecodefilter_matches_matlab_gold_fixture() -> None: + payload = _load_fixture("nonlinear_decode_exactness.mat") + lambda_cifs = [ + CIF(_vector(payload, "beta1"), ["1", "x", "y", "x^2", "y^2", "x*y"], ["x", "y"], fitType="binomial"), + CIF(_vector(payload, "beta2"), ["1", "x", "y", "x^2", "y^2", "x*y"], ["x", "y"], fitType="binomial"), + ] + + x_p, W_p, x_u, W_u, *_ = DecodingAlgorithms.PPDecodeFilter( + np.asarray(payload["A"], dtype=float), + np.asarray(payload["Q"], dtype=float), + np.asarray(payload["Px0"], dtype=float), + np.asarray(payload["dN"], dtype=float), + lambda_cifs, + _scalar(payload, "delta"), + ) + + np.testing.assert_allclose(x_p, np.asarray(payload["x_p"], dtype=float), rtol=1e-3, atol=5e-4) + np.testing.assert_allclose(W_p, np.asarray(payload["W_p"], dtype=float), rtol=1e-3, atol=5e-4) + np.testing.assert_allclose(x_u, np.asarray(payload["x_u"], dtype=float), rtol=1e-3, atol=5e-4) + np.testing.assert_allclose(W_u, np.asarray(payload["W_u"], dtype=float), rtol=1e-3, atol=5e-4) + + def test_simulated_network_matches_matlab_gold_fixture() -> None: payload = _load_fixture("simulated_network_exactness.mat") native = simulate_two_neuron_network(seed=4) diff --git a/tests/test_notebook_fidelity_audit.py b/tests/test_notebook_fidelity_audit.py index 40aac0d8..9e3cb3ef 100644 --- a/tests/test_notebook_fidelity_audit.py +++ b/tests/test_notebook_fidelity_audit.py @@ -48,7 +48,7 @@ def test_notebook_fidelity_audit_marks_upgraded_ports_as_high_fidelity() -> None def test_notebook_fidelity_audit_tracks_only_known_partial_notebooks() -> None: audit = yaml.safe_load(AUDIT_PATH.read_text(encoding="utf-8")) or {} partial_topics = {row["topic"] for row in audit.get("items", []) if row["fidelity_status"] in {"partial", "placeholder", "missing"}} - assert partial_topics == {"StimulusDecode2D"} + assert partial_topics == set() def test_high_fidelity_notebooks_have_no_placeholder_or_tracker_only_cells() -> None: diff --git a/tests/test_notebook_parity_notes.py b/tests/test_notebook_parity_notes.py index f49ca12f..49060399 100644 --- a/tests/test_notebook_parity_notes.py +++ b/tests/test_notebook_parity_notes.py @@ -40,7 +40,7 @@ def test_target_notebooks_start_with_machine_readable_parity_note() -> None: def test_notebook_parity_notes_track_only_known_partial_statuses() -> None: partial = [row["topic"] for row in _load_notes() if row["fidelity_status"] == "partial"] - assert partial == ["StimulusDecode2D"] + assert partial == [] def test_high_fidelity_parity_notes_do_not_admit_placeholder_or_tracker_only_status() -> None: diff --git a/tests/test_parity_manifest.py b/tests/test_parity_manifest.py index d6549d06..94e31b5f 100644 --- a/tests/test_parity_manifest.py +++ b/tests/test_parity_manifest.py @@ -7,6 +7,7 @@ REPO_ROOT = Path(__file__).resolve().parents[1] MANIFEST_PATH = REPO_ROOT / "parity" / "manifest.yml" +NOTEBOOK_AUDIT_PATH = REPO_ROOT / "parity" / "notebook_fidelity.yml" EXPECTED_MATLAB_PUBLIC_API = { "Analysis", @@ -99,3 +100,22 @@ def test_parity_manifest_statuses_and_mapped_targets_are_valid() -> None: target = row.get("python_target") if status == "mapped": assert target, f"Mapped item in {section_name} is missing a python_target: {row}" + + +def test_manifest_help_workflows_align_with_notebook_fidelity_audit() -> None: + manifest = _load_manifest() + notebook_audit = yaml.safe_load(NOTEBOOK_AUDIT_PATH.read_text(encoding="utf-8")) or {} + audit_rows = {row["topic"]: row for row in notebook_audit.get("items", [])} + + manifest_help_rows = { + row["matlab"]: row + for row in manifest["help_workflows"] + if str(row.get("python_target", "")).startswith("notebooks/") + } + + assert set(audit_rows) <= set(manifest_help_rows) + for topic, audit_row in audit_rows.items(): + manifest_row = manifest_help_rows[topic] + audit_row = audit_rows[topic] + if manifest_row["status"] == "mapped": + assert audit_row["fidelity_status"] in {"high_fidelity", "exact"} diff --git a/tests/test_parity_report.py b/tests/test_parity_report.py index ca0c9a14..6400baf5 100644 --- a/tests/test_parity_report.py +++ b/tests/test_parity_report.py @@ -24,8 +24,8 @@ def test_parity_report_highlights_current_constraints() -> None: assert "Notebook Fidelity Summary" in text assert "Simulink Fidelity Summary" in text assert "Remaining Notebook-Fidelity Deltas" in text - assert "workflow coverage is complete, but 1 MATLAB-helpfile notebook ports are still marked partial" in text - assert "`StimulusDecode2D` -> `notebooks/StimulusDecode2D.ipynb` [partial]" in text + assert "all tracked MATLAB-helpfile notebook ports are marked high fidelity or exact" in text + assert "No partial notebook-fidelity items remain in `tools/notebooks/parity_notes.yml`." in text assert "No partial or missing items remain in the mapping inventory." in text assert "Remaining Class-Fidelity Deltas" in text assert "the class audit reports no partial, wrapper-only, or missing items" in text diff --git a/tools/notebooks/build_helpfile_fidelity_notebooks.py b/tools/notebooks/build_helpfile_fidelity_notebooks.py index a4be05ce..e5595d10 100644 --- a/tools/notebooks/build_helpfile_fidelity_notebooks.py +++ b/tools/notebooks/build_helpfile_fidelity_notebooks.py @@ -968,8 +968,8 @@ def _plot_raster(ax, time_s, spikes, *, max_cells=18): ## MATLAB Parity Note - Source MATLAB helpfile: `StimulusDecode2D.mlx` -- Fidelity status: `partial` -- Remaining justified differences: The notebook reproduces the MATLAB section order, figure inventory, simulated receptive fields, and decoded-trajectory presentation, but the current Python decoder still uses regression-based state recovery instead of MATLAB's symbolic-CIF nonlinear filter. +- Fidelity status: `high_fidelity` +- Remaining justified differences: The notebook now follows the MATLAB nonlinear-CIF decoding workflow and uses `DecodingAlgorithms.PPDecodeFilter` before the same documented linear fallback branch as MATLAB. Exact decoded traces and figure styling can still vary modestly because Python's symbolic/numeric stack and random streams are not byte-identical to MATLAB. """ @@ -991,7 +991,7 @@ def _plot_raster(ax, time_s, spikes, *, max_cells=18): import matplotlib.pyplot as plt import numpy as np - from nstat import DecodingAlgorithms + from nstat import CIF, Covariate, DecodingAlgorithms, SignalObj, nstColl from nstat.notebook_figures import FigureTracker np.random.seed(0) @@ -1006,62 +1006,103 @@ def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)): return fig - def _simulate_decode(seed=19, *, n_cells=24, dt=0.01, tmax=20.0): + def _subplot_grid(count): + rows = max(int(np.floor(np.sqrt(count))), 1) + cols = int(np.ceil(count / rows)) + return rows, cols + + + def _simulate_decode(seed=0, *, num_realizations=80, delta=0.001, tmax=1.0): rng = np.random.default_rng(seed) - time = np.arange(0.0, tmax + dt, dt) - vel = np.cumsum(rng.normal(0.0, 0.05, size=(time.size, 2)), axis=0) - vel = 0.18 * vel / np.maximum(np.std(vel, axis=0, ddof=1), 1e-6) - pos = np.cumsum(vel, axis=0) * dt - pos = pos - np.mean(pos, axis=0, keepdims=True) - px = pos[:, 0] - py = pos[:, 1] - coeffs = np.column_stack( - [ - -2.2 - np.abs(rng.normal(0.0, 0.35, size=n_cells)), - rng.normal(0.0, 1.1, size=n_cells), - rng.normal(0.0, 1.1, size=n_cells), - -np.abs(rng.normal(1.6, 0.35, size=n_cells)), - -np.abs(rng.normal(1.6, 0.35, size=n_cells)), - rng.normal(0.0, 0.45, size=n_cells), - ] - ) + time = np.arange(0.0, tmax + delta, delta) + q_drive = 0.01 + innovations = q_drive * rng.standard_normal((2, time.size)) + vx = np.cumsum(innovations[0]) + vy = np.cumsum(innovations[1]) + vel_sig = SignalObj(time, np.column_stack([vx, vy]), "vel", "time", "s", "", ["vx", "vy"]) + pos_sig = vel_sig.integral() + pos_data = np.asarray(pos_sig.data, dtype=float) + px = pos_data[:, 0] + py = pos_data[:, 1] + + coeffs = -np.abs(rng.standard_normal((num_realizations, 5))) + coeffs = np.column_stack([-2.0 * np.abs(rng.standard_normal(num_realizations)), coeffs]) design = np.column_stack([np.ones(time.size), px, py, px * px, py * py, px * py]) - spikes = np.zeros((time.size, n_cells), dtype=float) - firing_prob = np.zeros_like(spikes) - for idx in range(n_cells): + + lambda_rates_hz = np.zeros((time.size, num_realizations), dtype=float) + lambda_cifs = [] + spike_trains = [] + for idx in range(num_realizations): eta = design @ coeffs[idx] - p = 1.0 / (1.0 + np.exp(-np.clip(eta, -20.0, 20.0))) - firing_prob[:, idx] = p - spikes[:, idx] = (rng.random(time.size) < p).astype(float) + exp_eta = np.exp(np.clip(eta, -20.0, 20.0)) + lambda_delta = exp_eta / (1.0 + exp_eta) + lambda_rates_hz[:, idx] = lambda_delta / delta + lambda_cov = Covariate(time, lambda_rates_hz[:, idx], "\\\\Lambda(t)", "time", "s", "Hz", [f"lambda_{idx + 1}"]) + spike_coll = CIF.simulateCIFByThinningFromLambda(lambda_cov, 1, seed=seed + idx + 1) + train = spike_coll.getNST(1) + train.setName(str(idx + 1)) + spike_trains.append(train) + lambda_cifs.append(CIF(coeffs[idx], ["1", "x", "y", "x^2", "y^2", "x*y"], ["x", "y"], fitType="binomial")) + + spike_coll = nstColl(spike_trains) + spike_coll.resample(1.0 / delta) + labels = list(range(1, num_realizations + 1)) + dN = spike_coll.dataToMatrix(labels, delta, float(time[0]), float(time[-1])).T + + vx_var = float(np.var(px[1:] - px[:-1])) + vy_var = float(np.var(py[1:] - py[:-1])) + q_cov = np.array([[vx_var, 0.0], [0.0, vy_var]], dtype=float) + p0 = 0.1 * np.eye(2, dtype=float) + a_mat = np.eye(2, dtype=float) + decode_method = "PPDecodeFilter" + decode_error = "" + try: + x_p, pe_p, x_u, pe_u, *_ = DecodingAlgorithms.PPDecodeFilter(a_mat, q_cov, p0, dN, lambda_cifs, delta) + except Exception as exc: + decode_method = "PPDecodeFilterLinear" + decode_error = f"{type(exc).__name__}: {exc}" + mu_linear = coeffs[:, 0] + beta_linear = coeffs[:, 1:3].T + x_p, pe_p, x_u, pe_u, *_ = DecodingAlgorithms.PPDecodeFilterLinear(a_mat, q_cov, dN, mu_linear, beta_linear, "binomial", delta) + grid = np.linspace(-1.4, 1.4, 60) gx, gy = np.meshgrid(grid, grid) grid_design = np.column_stack([np.ones(gx.size), gx.ravel(), gy.ravel(), gx.ravel() ** 2, gy.ravel() ** 2, gx.ravel() * gy.ravel()]) fields = [] - for idx in range(n_cells): + for idx in range(num_realizations): eta = grid_design @ coeffs[idx] field = (1.0 / (1.0 + np.exp(-np.clip(eta, -20.0, 20.0)))).reshape(gx.shape) fields.append(field) - subset = max(n_cells // 2, 1) - dec_x_subset = DecodingAlgorithms.linear_decode(spikes[:, :subset], px) - dec_y_subset = DecodingAlgorithms.linear_decode(spikes[:, :subset], py) - dec_x_full = DecodingAlgorithms.linear_decode(spikes, px) - dec_y_full = DecodingAlgorithms.linear_decode(spikes, py) + n_common = min(px.size, x_u.shape[1]) + decode_rmse = float( + np.sqrt( + np.mean( + (x_u[0, :n_common] - px[:n_common]) ** 2 + + (x_u[1, :n_common] - py[:n_common]) ** 2 + ) + ) + ) return { "time_s": time, "px": px, "py": py, - "vx": vel[:, 0], - "vy": vel[:, 1], - "spikes": spikes, - "firing_prob": firing_prob, + "vx": vx, + "vy": vy, + "spikes": dN.T, + "lambda_rates_hz": lambda_rates_hz, "fields": np.asarray(fields, dtype=float), "grid_x": gx, "grid_y": gy, - "decoded_subset_x": dec_x_subset["decoded"], - "decoded_subset_y": dec_y_subset["decoded"], - "decoded_full_x": dec_x_full["decoded"], - "decoded_full_y": dec_y_full["decoded"], - "rmse_full": float(np.sqrt(np.mean((dec_x_full["decoded"] - px) ** 2 + (dec_y_full["decoded"] - py) ** 2))), + "decoded_x": x_u[0, :n_common], + "decoded_y": x_u[1, :n_common], + "predicted_x": x_p[0, 1 : n_common + 1], + "predicted_y": x_p[1, 1 : n_common + 1], + "decode_rmse": decode_rmse, + "decode_method": decode_method, + "decode_error": decode_error, + "coeffs": coeffs, + "state_cov": pe_u[:, :, :n_common], + "num_cells": num_realizations, } @@ -1080,7 +1121,14 @@ def _plot_raster(ax, time_s, spikes, *, max_cells=20): # This notebook follows the MATLAB 2-D decoding workflow with simulated spatial receptive fields. plt.close("all") payload = _simulate_decode() - print({"num_cells": int(payload["spikes"].shape[1]), "rmse_full": round(float(payload["rmse_full"]), 4)}) + print( + { + "num_cells": int(payload["num_cells"]), + "decode_method": payload["decode_method"], + "decode_rmse": round(float(payload["decode_rmse"]), 4), + "fallback_error": payload["decode_error"] or "", + } + ) """, """ # SECTION 1: Generate the random receptive fields to simulate different neurons @@ -1094,17 +1142,20 @@ def _plot_raster(ax, time_s, spikes, *, max_cells=20): fig = _prepare_figure("lambda{i}.plot", figsize=(9.0, 5.0)) ax = fig.subplots(1, 1) - show = [0, 1, 2, 3] - for idx in show: - ax.plot(payload["time_s"], payload["firing_prob"][:, idx], linewidth=1.2, label=f"Cell {idx + 1}") - ax.set_title("Example firing probabilities") + for idx in range(min(payload["num_cells"], 24)): + ax.plot(payload["time_s"], payload["lambda_rates_hz"][:, idx], linewidth=0.7, alpha=0.5) + ax.set_title("Conditional intensity functions") ax.set_xlabel("time (s)") - ax.set_ylabel("spike probability") - ax.legend(loc="upper right", frameon=False, fontsize=8) + ax.set_ylabel("Hz") fig = _prepare_figure("pcolor(X,Y,placeField{i}), shading interp", figsize=(8.0, 8.0)) - axs = fig.subplots(2, 2, squeeze=False) - for ax, idx in zip(axs.ravel(), show, strict=False): + n_rows, n_cols = _subplot_grid(payload["num_cells"]) + axs = fig.subplots(n_rows, n_cols, squeeze=False) + image = None + for idx, ax in enumerate(axs.ravel()): + if idx >= payload["num_cells"]: + ax.axis("off") + continue image = ax.imshow( payload["fields"][idx], origin="lower", @@ -1115,7 +1166,8 @@ def _plot_raster(ax, time_s, spikes, *, max_cells=20): ax.set_title(f"Cell {idx + 1}") ax.set_xticks([]) ax.set_yticks([]) - fig.colorbar(image, ax=axs.ravel().tolist(), shrink=0.78) + if image is not None: + fig.colorbar(image, ax=axs.ravel().tolist(), shrink=0.78) """, """ # SECTION 2: Visualize the simulated neural activity @@ -1133,8 +1185,7 @@ def _plot_raster(ax, time_s, spikes, *, max_cells=20): fig = _prepare_figure("plot(x_u(1,:),x_u(2,:),'b',px,py,'k')", figsize=(6.0, 6.0)) ax = fig.subplots(1, 1) ax.plot(payload["px"], payload["py"], color="k", linewidth=1.8, label="True path") - ax.plot(payload["decoded_subset_x"], payload["decoded_subset_y"], color="tab:orange", linewidth=1.0, label="Subset decode") - ax.plot(payload["decoded_full_x"], payload["decoded_full_y"], color="tab:blue", linewidth=1.2, label="Full decode") + ax.plot(payload["decoded_x"], payload["decoded_y"], color="tab:blue", linewidth=1.2, label=payload["decode_method"]) ax.set_title("Decoded X-Y trajectory") ax.set_xlabel("x") ax.set_ylabel("y") @@ -1144,23 +1195,23 @@ def _plot_raster(ax, time_s, spikes, *, max_cells=20): fig = _prepare_figure("plot(decoded trajectories)", figsize=(10.0, 5.5)) axs = fig.subplots(2, 1, sharex=True) axs[0].plot(payload["time_s"], payload["px"], color="k", linewidth=1.6, label="True x") - axs[0].plot(payload["time_s"], payload["decoded_full_x"], color="tab:blue", linewidth=1.2, label="Decoded x") - axs[0].plot(payload["time_s"], payload["decoded_subset_x"], color="tab:orange", linewidth=1.0, label="Subset x") + axs[0].plot(payload["time_s"][: payload["decoded_x"].size], payload["decoded_x"], color="tab:blue", linewidth=1.2, label="Decoded x") + axs[0].plot(payload["time_s"][: payload["predicted_x"].size], payload["predicted_x"], color="tab:orange", linewidth=1.0, label="Predicted x") axs[0].legend(loc="best", frameon=False, fontsize=8) axs[0].set_ylabel("x") axs[1].plot(payload["time_s"], payload["py"], color="k", linewidth=1.6, label="True y") - axs[1].plot(payload["time_s"], payload["decoded_full_y"], color="tab:blue", linewidth=1.2, label="Decoded y") - axs[1].plot(payload["time_s"], payload["decoded_subset_y"], color="tab:orange", linewidth=1.0, label="Subset y") + axs[1].plot(payload["time_s"][: payload["decoded_y"].size], payload["decoded_y"], color="tab:blue", linewidth=1.2, label="Decoded y") + axs[1].plot(payload["time_s"][: payload["predicted_y"].size], payload["predicted_y"], color="tab:orange", linewidth=1.0, label="Predicted y") axs[1].set_ylabel("y") axs[1].set_xlabel("time (s)") fig = _prepare_figure("decode_rmse", figsize=(7.0, 4.5)) ax = fig.subplots(1, 1) - error_full = np.sqrt((payload["decoded_full_x"] - payload["px"]) ** 2 + (payload["decoded_full_y"] - payload["py"]) ** 2) - error_subset = np.sqrt((payload["decoded_subset_x"] - payload["px"]) ** 2 + (payload["decoded_subset_y"] - payload["py"]) ** 2) - ax.plot(payload["time_s"], error_full, color="tab:blue", linewidth=1.2, label="Full decode") - ax.plot(payload["time_s"], error_subset, color="tab:orange", linewidth=1.0, label="Subset decode") - ax.set_title(f"Pointwise decoding error (RMSE={payload['rmse_full']:.3f})") + error_decode = np.sqrt((payload["decoded_x"] - payload["px"][: payload["decoded_x"].size]) ** 2 + (payload["decoded_y"] - payload["py"][: payload["decoded_y"].size]) ** 2) + error_predict = np.sqrt((payload["predicted_x"] - payload["px"][: payload["predicted_x"].size]) ** 2 + (payload["predicted_y"] - payload["py"][: payload["predicted_y"].size]) ** 2) + ax.plot(payload["time_s"][: error_decode.size], error_decode, color="tab:blue", linewidth=1.2, label="Filtered decode") + ax.plot(payload["time_s"][: error_predict.size], error_predict, color="tab:orange", linewidth=1.0, label="Predicted decode") + ax.set_title(f"Pointwise decoding error (RMSE={payload['decode_rmse']:.3f})") ax.set_xlabel("time (s)") ax.set_ylabel("Euclidean error") ax.legend(loc="best", frameon=False, fontsize=8) diff --git a/tools/notebooks/parity_notes.yml b/tools/notebooks/parity_notes.yml index fa1b4c8d..525750be 100644 --- a/tools/notebooks/parity_notes.yml +++ b/tools/notebooks/parity_notes.yml @@ -63,5 +63,5 @@ notes: - topic: StimulusDecode2D file: notebooks/StimulusDecode2D.ipynb source_matlab: StimulusDecode2D.mlx - fidelity_status: partial - remaining_differences: The notebook reproduces the MATLAB section order, figure inventory, simulated receptive fields, and decoded-trajectory presentation, but the current Python decoder still uses regression-based state recovery instead of MATLAB's symbolic-CIF nonlinear filter. + fidelity_status: high_fidelity + remaining_differences: The notebook now follows the MATLAB nonlinear-CIF decoding workflow and uses `DecodingAlgorithms.PPDecodeFilter` before the same documented linear fallback branch as MATLAB. Exact decoded traces and figure styling can still vary modestly because Python's symbolic/numeric stack and random streams are not byte-identical to MATLAB. diff --git a/tools/parity/matlab/export_matlab_gold_fixtures.m b/tools/parity/matlab/export_matlab_gold_fixtures.m index bbae80e3..a62e99ac 100644 --- a/tools/parity/matlab/export_matlab_gold_fixtures.m +++ b/tools/parity/matlab/export_matlab_gold_fixtures.m @@ -26,6 +26,7 @@ function export_matlab_gold_fixtures(repoRoot, matlabRepoRoot) export_analysis_fixture(fixtureRoot); export_point_process_fixture(fixtureRoot); export_decoding_predict_fixture(fixtureRoot); +export_nonlinear_decode_fixture(fixtureRoot); export_simulated_network_fixture(fixtureRoot); end @@ -141,6 +142,8 @@ function export_nstcoll_fixture(fixtureRoot) function export_cif_fixture(fixtureRoot) cif = CIF([0.1 0.5], {'stim1', 'stim2'}, {'stim1', 'stim2'}, 'binomial'); stimVal = [0.6; -0.2]; +polyCif = build_polynomial_binomial_cif([-2.0 -0.5 0.3 -0.2 -0.1 0.05]); +polyStim = [0.2; -0.4]; payload = struct(); payload.beta = [0.1 0.5]; @@ -150,6 +153,13 @@ function export_cif_fixture(fixtureRoot) payload.gradient_log = cif.evalGradientLog(stimVal); payload.jacobian = cif.evalJacobian(stimVal); payload.jacobian_log = cif.evalJacobianLog(stimVal); +payload.poly_beta = [-2.0 -0.5 0.3 -0.2 -0.1 0.05]; +payload.poly_stimVal = polyStim; +payload.poly_lambda_delta = polyCif.evalLambdaDelta(polyStim); +payload.poly_gradient = polyCif.evalGradient(polyStim); +payload.poly_gradient_log = polyCif.evalGradientLog(polyStim); +payload.poly_jacobian = polyCif.evalJacobian(polyStim); +payload.poly_jacobian_log = polyCif.evalJacobianLog(polyStim); save(fullfile(fixtureRoot, 'cif_exactness.mat'), '-struct', 'payload'); end @@ -227,6 +237,35 @@ function export_decoding_predict_fixture(fixtureRoot) save(fullfile(fixtureRoot, 'decoding_predict_exactness.mat'), '-struct', 'payload'); end +function export_nonlinear_decode_fixture(fixtureRoot) +A = eye(2); +Q = 0.01 * eye(2); +Px0 = 0.05 * eye(2); +delta = 0.1; +dN = [0 1 0 1; + 1 0 1 0]; +lambdaCIF = { + build_polynomial_binomial_cif([-2.0 -0.5 0.3 -0.2 -0.1 0.05]), ... + build_polynomial_binomial_cif([-1.5 0.4 -0.2 0.15 -0.05 0.02]) +}; +[x_p, W_p, x_u, W_u] = DecodingAlgorithms.PPDecodeFilter(A, Q, Px0, dN, lambdaCIF, delta); + +payload = struct(); +payload.A = A; +payload.Q = Q; +payload.Px0 = Px0; +payload.delta = delta; +payload.dN = dN; +payload.beta1 = [-2.0 -0.5 0.3 -0.2 -0.1 0.05]; +payload.beta2 = [-1.5 0.4 -0.2 0.15 -0.05 0.02]; +payload.x_p = x_p; +payload.W_p = W_p; +payload.x_u = x_u; +payload.W_u = W_u; + +save(fullfile(fixtureRoot, 'nonlinear_decode_exactness.mat'), '-struct', 'payload'); +end + function export_simulated_network_fixture(fixtureRoot) rng(4); Ts = .001; @@ -278,3 +317,54 @@ function export_simulated_network_fixture(fixtureRoot) save(fullfile(fixtureRoot, 'simulated_network_exactness.mat'), '-struct', 'payload'); end + +function cifObj = build_polynomial_binomial_cif(beta) +beta = beta(:)'; +syms x y real +cifObj = CIF(beta(1:3), {'1', 'x', 'y'}, {'x', 'y'}, 'binomial'); +cifObj.b = beta; +cifObj.varIn = [sym(1); x; y; x^2; y^2; x * y]; +cifObj.stimVars = [x; y]; +cifObj.fitType = 'binomial'; +cifObj.history = []; +cifObj.histCoeffs = []; +cifObj.histVars = {}; +cifObj.histCoeffVars = {}; +cifObj.spikeTrain = []; +cifObj.historyMat = []; +cifObj.lambdaDelta = simplify(exp(beta * cifObj.varIn) ./ (1 + exp(beta * cifObj.varIn))); +cifObj.lambdaDeltaFunction = matlabFunction(cifObj.lambdaDelta, 'vars', symvar(cifObj.varIn)); +cifObj.gradientLambdaDelta = simplify(jacobian(cifObj.lambdaDelta, cifObj.stimVars)); +cifObj.gradientLogLambdaDelta = simplify(jacobian(log(cifObj.lambdaDelta), cifObj.stimVars)); +cifObj.gradientFunction = matlabFunction(cifObj.gradientLambdaDelta, 'vars', symvar(cifObj.varIn)); +cifObj.gradientLogFunction = matlabFunction(cifObj.gradientLogLambdaDelta, 'vars', symvar(cifObj.varIn)); +cifObj.jacobianLambdaDelta = simplify(jacobian(cifObj.gradientLambdaDelta, cifObj.stimVars)); +cifObj.jacobianFunction = matlabFunction(cifObj.jacobianLambdaDelta, 'vars', symvar(cifObj.varIn)); +cifObj.jacobianLogLambdaDelta = simplify(jacobian(cifObj.gradientLogLambdaDelta, cifObj.stimVars)); +cifObj.jacobianLogFunction = matlabFunction(cifObj.jacobianLogLambdaDelta, 'vars', symvar(cifObj.varIn)); +cifObj.lambdaDeltaGamma = []; +cifObj.LogLambdaDeltaGamma = []; +cifObj.gradientLambdaDeltaGamma = []; +cifObj.gradientLogLambdaDeltaGamma = []; +cifObj.jacobianLambdaDeltaGamma = []; +cifObj.jacobianLogLambdaDeltaGamma = []; +cifObj.lambdaDeltaGammaFunction = []; +cifObj.LogLambdaDeltaGammaFunction = []; +cifObj.gradientFunctionGamma = []; +cifObj.gradientLogFunctionGamma = []; +cifObj.jacobianFunctionGamma = []; +cifObj.jacobianLogFunctionGamma = []; +cifObj.indepVars = symvar(cifObj.lambdaDelta); + +vars = symvar(cifObj.varIn); +if length(vars) == 1 + cifObj.argstr = 'val'; +else + argstr = 'val(1)'; + for i = 2:length(vars) + argstr = strcat(argstr, [',val(' num2str(i) ')']); + end + cifObj.argstr = argstr; +end +cifObj.argstrLDGamma = ''; +end