Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 22 additions & 19 deletions notebooks/AnalysisExamples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
"cells": [
{
"cell_type": "markdown",
"id": "98fc6fc8",
"id": "4ba2084f",
"metadata": {},
"source": [
"<!-- parity-note -->\n",
"## MATLAB Parity Note\n",
"- Source MATLAB helpfile: `AnalysisExamples.mlx`\n",
"- Fidelity status: `exact`\n",
"- Remaining justified differences: Complete MATLAB standard-GLM workflow with the canonical glm_data.mat dataset and real KS/model-visualization figures. Only inherent GLM solver numerics and matplotlib styling differ."
"- Remaining justified differences: The notebook now follows the MATLAB standard-GLM workflow with the canonical `glm_data.mat` dataset and real KS/model-visualization figures; coefficient values and styling still vary modestly because the Python GLM backend and plotting defaults differ from MATLAB.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7807842d",
"id": "1bf27b7c",
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -44,12 +44,14 @@
"OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
"__tracker = FigureTracker(topic=\"AnalysisExamples\", output_root=OUTPUT_ROOT, expected_count=4)\n",
"\n",
"\n",
"def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
" fig = __tracker.new_figure(matlab_line)\n",
" fig.clear()\n",
" fig.set_size_inches(*figsize)\n",
" return fig\n",
"\n",
"\n",
"def _poisson_standard_errors(design_matrix, result):\n",
" x = np.asarray(design_matrix, dtype=float)\n",
" if x.ndim == 1:\n",
Expand All @@ -60,6 +62,7 @@
" cov = np.linalg.pinv(x_aug.T @ (lam[:, None] * x_aug))\n",
" return np.sqrt(np.clip(np.diag(cov), 0.0, None))\n",
"\n",
"\n",
"T = np.asarray(GLM_DATA[\"T\"], dtype=float).reshape(-1)\n",
"xN = np.asarray(GLM_DATA[\"xN\"], dtype=float).reshape(-1)\n",
"yN = np.asarray(GLM_DATA[\"yN\"], dtype=float).reshape(-1)\n",
Expand All @@ -68,25 +71,25 @@
"x_at_spiketimes = np.asarray(GLM_DATA[\"x_at_spiketimes\"], dtype=float).reshape(-1)\n",
"y_at_spiketimes = np.asarray(GLM_DATA[\"y_at_spiketimes\"], dtype=float).reshape(-1)\n",
"sample_rate = 1.0 / float(np.median(np.diff(T)))\n",
"nst = nspikeTrain(spiketimes, name=\"1\", minTime=float(T[0]), maxTime=float(T[-1]), makePlots=-1)"
"nst = nspikeTrain(spiketimes, name=\"1\", minTime=float(T[0]), maxTime=float(T[-1]), makePlots=-1)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "37ac20c9",
"id": "39e80ef9",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 1: Analysis Examples\n",
"plt.close(\"all\")\n",
"print({\"n_samples\": int(T.shape[0]), \"n_spikes\": int(spiketimes.shape[0]), \"sample_rate_hz\": round(sample_rate, 3)})"
"print({\"n_samples\": int(T.shape[0]), \"n_spikes\": int(spiketimes.shape[0]), \"sample_rate_hz\": round(sample_rate, 3)})\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dbdc74f9",
"id": "b2e4c2dc",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -104,13 +107,13 @@
"x_quadratic = np.column_stack([xN, yN, xN**2, yN**2, xN * yN])\n",
"linear_fit = fit_poisson_glm(x_linear, spikes_binned)\n",
"quadratic_fit = fit_poisson_glm(x_quadratic, spikes_binned)\n",
"centered_fit = fit_poisson_glm(x_quadratic_centered, spikes_binned)"
"centered_fit = fit_poisson_glm(x_quadratic_centered, spikes_binned)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c38cff1",
"id": "15ae5117",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -122,13 +125,13 @@
"ax.set_aspect(\"equal\", adjustable=\"box\")\n",
"ax.set_xlabel(\"x position (m)\")\n",
"ax.set_ylabel(\"y position (m)\")\n",
"ax.set_title(\"Rat trajectory with spike locations\")"
"ax.set_title(\"Rat trajectory with spike locations\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5af52914",
"id": "a8839159",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -141,13 +144,13 @@
"ax.errorbar(xpos, centered_beta, yerr=centered_se, fmt=\".\", color=\"tab:blue\", capsize=3)\n",
"ax.set_xticks(xpos, [\"baseline\", \"x\", \"y\", \"x^2\", \"y^2\", \"x*y\"])\n",
"ax.set_ylabel(\"coefficient value\")\n",
"ax.set_title(\"Quadratic GLM coefficients\")"
"ax.set_title(\"Quadratic GLM coefficients\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5cac7309",
"id": "6bda2bd9",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -166,13 +169,13 @@
"ax.set_xlabel(\"x position (m)\")\n",
"ax.set_ylabel(\"y position (m)\")\n",
"ax.set_zlabel(\"lambda\")\n",
"ax.set_title(\"Quadratic GLM spatial intensity\")"
"ax.set_title(\"Quadratic GLM spatial intensity\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36dd8e70",
"id": "0d526433",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -186,13 +189,13 @@
" \"linear_mean_rate_hz\": round(float(np.mean(lambda_linear_hz)), 4),\n",
" \"quadratic_mean_rate_hz\": round(float(np.mean(lambda_quadratic_hz)), 4),\n",
" }\n",
")"
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d8b81fd",
"id": "59472d54",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -217,7 +220,7 @@
"ax.set_ylabel(\"Empirical CDF of Rescaled ISIs\")\n",
"ax.set_title(\"KS Plot with 95% Confidence Intervals\")\n",
"ax.legend(loc=\"lower right\", frameon=False)\n",
"__tracker.finalize()"
"__tracker.finalize()\n"
]
}
],
Expand All @@ -234,4 +237,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
63 changes: 39 additions & 24 deletions notebooks/AnalysisExamples2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
"cells": [
{
"cell_type": "markdown",
"id": "66d56086",
"id": "04a1108d",
"metadata": {},
"source": [
"<!-- parity-note -->\n",
"## MATLAB Parity Note\n",
"- Source MATLAB helpfile: `AnalysisExamples2.mlx`\n",
"- Fidelity status: `exact`\n",
"- Remaining justified differences: Complete MATLAB toolbox workflow on the canonical glm_data.mat dataset with executable Trial, ConfigColl, and Analysis calls. Only inherent GLM solver numerics and plot styling differ."
"- Remaining justified differences: The notebook now follows the MATLAB toolbox workflow on the canonical `glm_data.mat` dataset with executable `Trial`, `ConfigColl`, and `Analysis` calls; exact coefficients and plot styling still vary modestly because the Python GLM backend differs from MATLAB.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62e21501",
"id": "f4ecd812",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -42,14 +42,16 @@
"\n",
"GLM_DATA = load_glm_data_for_notebook()\n",
"OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n",
"__tracker = FigureTracker(topic=\"AnalysisExamples2\", output_root=OUTPUT_ROOT, expected_count=4)\n",
"__tracker = FigureTracker(topic=\"AnalysisExamples2\", output_root=OUTPUT_ROOT, expected_count=5)\n",
"\n",
"\n",
"def _prepare_figure(matlab_line: str, *, figsize=(8.0, 4.5)):\n",
" fig = __tracker.new_figure(matlab_line)\n",
" fig.clear()\n",
" fig.set_size_inches(*figsize)\n",
" return fig\n",
"\n",
"\n",
"T = np.asarray(GLM_DATA[\"T\"], dtype=float).reshape(-1)\n",
"xN = np.asarray(GLM_DATA[\"xN\"], dtype=float).reshape(-1)\n",
"yN = np.asarray(GLM_DATA[\"yN\"], dtype=float).reshape(-1)\n",
Expand All @@ -65,36 +67,36 @@
"velocity = Covariate(T, np.column_stack([vxN, vyN]), \"Velocity\", \"time\", \"s\", \"m/s\", [\"v_x\", \"v_y\"])\n",
"radial = Covariate(T, np.column_stack([xN, yN, xN**2, yN**2, xN * yN]), \"Radial\", \"time\", \"s\", \"m\", [\"x\", \"y\", \"x^2\", \"y^2\", \"x*y\"])\n",
"values_at_spiketimes = position.getValueAt(spiketimes)\n",
"values_at_spiketimes_upsampled = position.resample(1.0 / np.min(np.diff(spiketimes))).getValueAt(spiketimes)"
"values_at_spiketimes_upsampled = position.resample(1.0 / np.min(np.diff(spiketimes))).getValueAt(spiketimes)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1836e297",
"id": "4dd70d35",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 1: Analysis Examples 2\n",
"plt.close(\"all\")\n",
"print({\"n_samples\": int(T.shape[0]), \"n_spikes\": int(spiketimes.shape[0]), \"analysis_sample_rate_hz\": sample_rate})"
"print({\"n_samples\": int(T.shape[0]), \"n_spikes\": int(spiketimes.shape[0]), \"analysis_sample_rate_hz\": sample_rate})\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf657cd0",
"id": "da89b88a",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 2: load the rat trajectory and spiking data\n",
"print({\"position_shape\": list(position.data.shape), \"velocity_shape\": list(velocity.data.shape), \"radial_shape\": list(radial.data.shape)})"
"print({\"position_shape\": list(position.data.shape), \"velocity_shape\": list(velocity.data.shape), \"radial_shape\": list(radial.data.shape)})\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe47aacc",
"id": "4346a9a7",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -104,13 +106,13 @@
" \"direct_spike_position_head\": np.asarray(values_at_spiketimes[:3], dtype=float).round(4).tolist(),\n",
" \"upsampled_spike_position_head\": np.asarray(values_at_spiketimes_upsampled[:3], dtype=float).round(4).tolist(),\n",
" }\n",
")"
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f28e3be",
"id": "d5785121",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -122,13 +124,13 @@
"ax.set_aspect(\"equal\", adjustable=\"box\")\n",
"ax.set_xlabel(\"x position (m)\")\n",
"ax.set_ylabel(\"y position (m)\")\n",
"ax.set_title(\"Trajectory and interpolated spike locations\")"
"ax.set_title(\"Trajectory and interpolated spike locations\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d40c40c9",
"id": "6cf72911",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -141,27 +143,27 @@
" TrialConfig([[\"Baseline\", \"mu\"], [\"Radial\", \"x\", \"y\", \"x^2\", \"y^2\", \"x*y\"]], sampleRate=sample_rate, history=[], name=\"Quadratic\"),\n",
" TrialConfig([[\"Baseline\", \"mu\"], [\"Radial\", \"x\", \"y\", \"x^2\", \"y^2\", \"x*y\"]], sampleRate=sample_rate, history=[0.0, 1.0 / sample_rate], name=\"Quadratic+Hist\"),\n",
"]\n",
"tcc = ConfigColl(tc)"
"tcc = ConfigColl(tc)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9160bdb",
"id": "e3b67df9",
"metadata": {},
"outputs": [],
"source": [
"# SECTION 6: Create our collection of configurations and run the analysis\n",
"fitResults = Analysis.RunAnalysisForAllNeurons(trial, tcc, 0)\n",
"fig = _prepare_figure(\"fitResults.plotResults\", figsize=(11.0, 8.0))\n",
"fitResults.plotResults(handle=fig)\n",
"print({\"config_names\": fitResults.configNames, \"aic\": np.asarray(fitResults.AIC, dtype=float).round(3).tolist()})"
"print({\"config_names\": fitResults.configNames, \"aic\": np.asarray(fitResults.AIC, dtype=float).round(3).tolist()})\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a2fafcc8",
"id": "79879555",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -178,21 +180,34 @@
"ax.set_xlabel(\"x position (m)\")\n",
"ax.set_ylabel(\"y position (m)\")\n",
"ax.set_zlabel(\"lambda\")\n",
"ax.set_title(\"Toolbox-model spatial intensity comparison\")"
"ax.set_title(\"Toolbox-model spatial intensity comparison\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64cdac30",
"id": "eef1351f",
"metadata": {},
"outputs": [],
"source": "# SECTION 8: Toolbox vs. Standard GLM comparison\n# Compare the nSTAT fit with a standalone glmfit using the same Quadratic covariates\n# MATLAB: [b,dev,stats] = glmfit([xN yN xN.^2 yN.^2 xN.*yN], spikes_binned, 'poisson');\n# b - fitResults.b{2} % should be close to zero\nX_quad = np.column_stack([xN, yN, xN**2, yN**2, xN * yN])\nglm_result = fit_poisson_glm(X_quad, spikes_binned)\nb = np.concatenate([[glm_result.intercept], glm_result.coefficients])\nb_diff = b - fitResults.getCoeffs(2)\nprint(\"b - fitResults.b{2} =\", b_diff)"
"source": [
"# SECTION 8: Toolbox vs. Standard GLM comparison\n",
"standard_fit = fit_poisson_glm(np.column_stack([np.ones_like(xN), xN, yN, xN**2, yN**2, xN * yN]), spikes_binned, include_intercept=False)\n",
"coeff_diff = np.asarray(standard_fit.coefficients - fitResults.getCoeffs(2)[0], dtype=float)\n",
"fig = _prepare_figure(\"b-fitResults.b{2}\", figsize=(7.0, 4.5))\n",
"ax = fig.subplots(1, 1)\n",
"labels = [\"mu\", \"x\", \"y\", \"x^2\", \"y^2\", \"x*y\"]\n",
"ax.bar(np.arange(coeff_diff.size), coeff_diff, color=\"tab:blue\")\n",
"ax.axhline(0.0, color=\"0.3\", linestyle=\"--\", linewidth=1.0)\n",
"ax.set_xticks(np.arange(coeff_diff.size), labels, rotation=20)\n",
"ax.set_ylabel(\"standard minus toolbox\")\n",
"ax.set_title(\"Coefficient agreement between workflows\")\n",
"print({\"quadratic_coeff_diff_max_abs\": round(float(np.max(np.abs(coeff_diff))), 6)})\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8782d383",
"id": "04cd0c92",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -208,7 +223,7 @@
"ax.set_ylabel(\"AIC\")\n",
"ax.set_title(\"History-lag model comparison\")\n",
"print({\"history_config_names\": histConfigs.getConfigNames(), \"summary_fit_names\": histSummary.fitNames})\n",
"__tracker.finalize()"
"__tracker.finalize()\n"
]
}
],
Expand All @@ -218,7 +233,7 @@
},
"nstat": {
"expected_figures": 5,
"run_group": "full",
"run_group": "smoke",
"style": "python-example",
"topic": "AnalysisExamples2"
}
Expand Down
Loading
Loading