diff --git a/notebooks/AnalysisExamples.ipynb b/notebooks/AnalysisExamples.ipynb index 31faeedb..89257604 100644 --- a/notebooks/AnalysisExamples.ipynb +++ b/notebooks/AnalysisExamples.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "667810f9", + "id": "fc7be361", "metadata": {}, "source": [ "\n", @@ -15,7 +15,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1f210963", + "id": "45d93add", "metadata": {}, "outputs": [], "source": [ @@ -34,15 +34,13 @@ "matplotlib.use(\"Agg\")\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "from scipy.io import loadmat\n", "\n", "from nstat import Analysis, Covariate, nspikeTrain\n", - "from nstat.data_manager import ensure_example_data\n", "from nstat.glm import fit_poisson_glm\n", + "from nstat.notebook_data import load_glm_data_for_notebook\n", "from nstat.notebook_figures import FigureTracker\n", "\n", - "DATA_DIR = ensure_example_data(download=True)\n", - "GLM_DATA = loadmat(DATA_DIR / \"glm_data.mat\", squeeze_me=True, struct_as_record=False)\n", + "GLM_DATA = load_glm_data_for_notebook()\n", "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic=\"AnalysisExamples\", output_root=OUTPUT_ROOT, expected_count=4)\n", "\n", @@ -79,7 +77,7 @@ { "cell_type": "code", "execution_count": null, - "id": "401a35e2", + "id": "3c621348", "metadata": {}, "outputs": [], "source": [ @@ -91,7 +89,7 @@ { "cell_type": "code", "execution_count": null, - "id": "501a6470", + "id": "c1d9b5e4", "metadata": {}, "outputs": [], "source": [ @@ -115,7 +113,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ee81820c", + "id": "b5f3a818", "metadata": {}, "outputs": [], "source": [ @@ -133,7 +131,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1056f293", + "id": "396cb183", "metadata": {}, "outputs": [], "source": [ @@ -152,7 +150,7 @@ { "cell_type": "code", "execution_count": null, - "id": "98e73438", + "id": "49d54a88", "metadata": {}, "outputs": [], "source": [ @@ -177,7 +175,7 @@ { "cell_type": "code", "execution_count": null, - "id": "46235a71", + "id": "8b700118", "metadata": {}, "outputs": [], "source": [ @@ -197,7 +195,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a5a09608", + "id": "9bd202c8", "metadata": {}, "outputs": [], "source": [ diff --git a/notebooks/AnalysisExamples2.ipynb b/notebooks/AnalysisExamples2.ipynb index fd852da4..e0e11f4e 100644 --- a/notebooks/AnalysisExamples2.ipynb +++ b/notebooks/AnalysisExamples2.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "4acc696f", + "id": "2468a9e7", "metadata": {}, "source": [ "\n", @@ -15,7 +15,7 @@ { "cell_type": "code", "execution_count": null, - "id": "06139fcd", + "id": "5e1d1998", "metadata": {}, "outputs": [], "source": [ @@ -34,15 +34,13 @@ "matplotlib.use(\"Agg\")\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "from scipy.io import loadmat\n", "\n", "from nstat import Analysis, ConfigColl, CovColl, Covariate, FitResSummary, Trial, TrialConfig, nspikeTrain, nstColl\n", - "from nstat.data_manager import ensure_example_data\n", "from nstat.glm import fit_poisson_glm\n", + "from nstat.notebook_data import load_glm_data_for_notebook\n", "from nstat.notebook_figures import FigureTracker\n", "\n", - "DATA_DIR = ensure_example_data(download=True)\n", - "GLM_DATA = loadmat(DATA_DIR / \"glm_data.mat\", squeeze_me=True, struct_as_record=False)\n", + "GLM_DATA = load_glm_data_for_notebook()\n", "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic=\"AnalysisExamples2\", output_root=OUTPUT_ROOT, expected_count=5)\n", "\n", @@ -75,7 +73,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dec86050", + "id": "45dc365a", "metadata": {}, "outputs": [], "source": [ @@ -87,7 +85,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d1533de9", + "id": "2a9182fe", "metadata": {}, "outputs": [], "source": [ @@ -98,7 +96,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e35b302e", + "id": "126391f1", "metadata": {}, "outputs": [], "source": [ @@ -114,7 +112,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ec23b9f", + "id": "8aaea418", "metadata": {}, "outputs": [], "source": [ @@ -132,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d2d28f3d", + "id": "d17e023e", "metadata": {}, "outputs": [], "source": [ @@ -151,7 +149,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0a79790c", + "id": "4ab39635", "metadata": {}, "outputs": [], "source": [ @@ -165,7 +163,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cf575ee5", + "id": "db6c7107", "metadata": {}, "outputs": [], "source": [ @@ -188,7 +186,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fd4cd7e9", + "id": "5a1dbe4c", "metadata": {}, "outputs": [], "source": [ @@ -209,7 +207,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9198965", + "id": "85a9d741", "metadata": {}, "outputs": [], "source": [ diff --git a/notebooks/ConfigCollExamples.ipynb b/notebooks/ConfigCollExamples.ipynb index ecc87b98..56dffa58 100644 --- a/notebooks/ConfigCollExamples.ipynb +++ b/notebooks/ConfigCollExamples.ipynb @@ -22,34 +22,17 @@ "matplotlib.use(\"Agg\")\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "from scipy.io import loadmat\n", "\n", - "from nstat.data_manager import ensure_example_data\n", "from nstat.notebook_figures import FigureTracker\n", "\n", "np.random.seed(0)\n", - "DATA_DIR = ensure_example_data(download=True)\n", "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic='ConfigCollExamples', output_root=OUTPUT_ROOT, expected_count=0)\n", "\n", - "def _load_example_globals(name: str) -> dict[str, object]:\n", - " candidates = [\n", - " Path(name),\n", - " DATA_DIR / name,\n", - " DATA_DIR / \"mEPSCs\" / name,\n", - " DATA_DIR / \"Place Cells\" / name,\n", - " DATA_DIR / \"Explicit Stimulus\" / name,\n", - " ]\n", - " for path in candidates:\n", - " if path.exists():\n", - " data = loadmat(path)\n", - " return {k: v for k, v in data.items() if not k.startswith(\"__\")}\n", - " return {}\n", - "\n", "# SECTION 0: Section 0\n", "# ConfigColl Examples\n", "# tcObj=TrialConfig(covMask,sampleRate, history,minTime,maxTime)\n", - "__tracker.finalize()" + "__tracker.finalize()\n" ] } ], diff --git a/notebooks/CovariateExamples.ipynb b/notebooks/CovariateExamples.ipynb index f67a09dd..edc5cc69 100644 --- a/notebooks/CovariateExamples.ipynb +++ b/notebooks/CovariateExamples.ipynb @@ -22,33 +22,16 @@ "matplotlib.use(\"Agg\")\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "from scipy.io import loadmat\n", "\n", - "from nstat.data_manager import ensure_example_data\n", "from nstat.notebook_figures import FigureTracker\n", "\n", "np.random.seed(0)\n", - "DATA_DIR = ensure_example_data(download=True)\n", "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic='CovariateExamples', output_root=OUTPUT_ROOT, expected_count=2)\n", "\n", - "def _load_example_globals(name: str) -> dict[str, object]:\n", - " candidates = [\n", - " Path(name),\n", - " DATA_DIR / name,\n", - " DATA_DIR / \"mEPSCs\" / name,\n", - " DATA_DIR / \"Place Cells\" / name,\n", - " DATA_DIR / \"Explicit Stimulus\" / name,\n", - " ]\n", - " for path in candidates:\n", - " if path.exists():\n", - " data = loadmat(path)\n", - " return {k: v for k, v in data.items() if not k.startswith(\"__\")}\n", - " return {}\n", - "\n", "# SECTION 0: Section 0\n", "# Test the Cov class\n", - "# Covariates are just like signals with a mean and a standard deviation They have two representations, the default (original representation) and a zero-mean representation" + "# Covariates are just like signals with a mean and a standard deviation They have two representations, the default (original representation) and a zero-mean representation\n" ] }, { diff --git a/notebooks/ExplicitStimulusWhiskerData.ipynb b/notebooks/ExplicitStimulusWhiskerData.ipynb index 7938e8f9..325ba9d7 100644 --- a/notebooks/ExplicitStimulusWhiskerData.ipynb +++ b/notebooks/ExplicitStimulusWhiskerData.ipynb @@ -2,20 +2,20 @@ "cells": [ { "cell_type": "markdown", - "id": "1978cb81", + "id": "8b8aa493", "metadata": {}, "source": [ "\n", "## MATLAB Parity Note\n", "- Source MATLAB helpfile: `ExplicitStimulusWhiskerData.mlx`\n", "- Fidelity status: `high_fidelity`\n", - "- Remaining justified differences: The notebook now reproduces the dataset-backed lag search, stimulus-effect, and history-effect workflow with real figures; exact KS traces and coefficient values still vary modestly from MATLAB because the Python GLM backend and plotting defaults are different." + "- Remaining justified differences: The notebook now reproduces the dataset-backed lag search, stimulus-effect, and history-effect workflow with real figures; exact KS traces and coefficient values still vary modestly from MATLAB because the Python GLM backend and plotting defaults are different.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "9787cf60", + "id": "dbf8e486", "metadata": {}, "outputs": [], "source": [ @@ -35,12 +35,12 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "from nstat.data_manager import ensure_example_data\n", + "from nstat.notebook_data import notebook_example_data_dir\n", "from nstat.notebook_figures import FigureTracker\n", "from nstat.paper_examples_full import run_experiment2\n", "\n", "np.random.seed(0)\n", - "DATA_DIR = ensure_example_data(download=True)\n", + "DATA_DIR = notebook_example_data_dir(allow_synthetic=True)\n", "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic='ExplicitStimulusWhiskerData', output_root=OUTPUT_ROOT, expected_count=9)\n", "\n", @@ -83,7 +83,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a60a7a6d", + "id": "e296bd4d", "metadata": {}, "outputs": [], "source": [ @@ -106,7 +106,7 @@ { "cell_type": "code", "execution_count": null, - "id": "862db342", + "id": "6a18af75", "metadata": {}, "outputs": [], "source": [ @@ -134,7 +134,7 @@ { "cell_type": "code", "execution_count": null, - "id": "98d686d5", + "id": "fd207a34", "metadata": {}, "outputs": [], "source": [ @@ -149,7 +149,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2b5d9da0", + "id": "2afb535d", "metadata": {}, "outputs": [], "source": [ @@ -170,7 +170,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4a033d5c", + "id": "7a48f375", "metadata": {}, "outputs": [], "source": [ @@ -198,7 +198,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0a73d5e1", + "id": "043ef33a", "metadata": {}, "outputs": [], "source": [ @@ -230,7 +230,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5aa18805", + "id": "d66ac872", "metadata": {}, "outputs": [], "source": [ @@ -290,4 +290,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/notebooks/HippocampalPlaceCellExample.ipynb b/notebooks/HippocampalPlaceCellExample.ipynb index 6a81718a..04e3ad7d 100644 --- a/notebooks/HippocampalPlaceCellExample.ipynb +++ b/notebooks/HippocampalPlaceCellExample.ipynb @@ -2,20 +2,20 @@ "cells": [ { "cell_type": "markdown", - "id": "e0aeece6", + "id": "4110300d", "metadata": {}, "source": [ "\n", "## MATLAB Parity Note\n", "- Source MATLAB helpfile: `HippocampalPlaceCellExample.mlx`\n", "- Fidelity status: `high_fidelity`\n", - "- Remaining justified differences: The notebook now reproduces the dataset-backed place-cell model-comparison and field-visualization workflow with real figures; the Python port still uses an approximate Zernike-like basis rather than the original MATLAB toolbox implementation." + "- Remaining justified differences: The notebook now reproduces the dataset-backed place-cell model-comparison and field-visualization workflow with real figures; the Python port still uses an approximate Zernike-like basis rather than the original MATLAB toolbox implementation.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "cadf7961", + "id": "8c6412bd", "metadata": {}, "outputs": [], "source": [ @@ -35,12 +35,12 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "from nstat.data_manager import ensure_example_data\n", + "from nstat.notebook_data import notebook_example_data_dir\n", "from nstat.notebook_figures import FigureTracker\n", "from nstat.paper_examples_full import run_experiment4\n", "\n", "np.random.seed(0)\n", - "DATA_DIR = ensure_example_data(download=True)\n", + "DATA_DIR = notebook_example_data_dir(allow_synthetic=True)\n", "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic='HippocampalPlaceCellExample', output_root=OUTPUT_ROOT, expected_count=11)\n", "\n", @@ -85,7 +85,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8c9b854e", + "id": "07fa2765", "metadata": {}, "outputs": [], "source": [ @@ -105,7 +105,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bca6b3c3", + "id": "1a5876df", "metadata": {}, "outputs": [], "source": [ @@ -125,7 +125,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7ad76d69", + "id": "db6754b1", "metadata": {}, "outputs": [], "source": [ @@ -152,7 +152,7 @@ { "cell_type": "code", "execution_count": null, - "id": "516eb14e", + "id": "5c98b75c", "metadata": {}, "outputs": [], "source": [ @@ -179,7 +179,7 @@ { "cell_type": "code", "execution_count": null, - "id": "711a2d08", + "id": "e4dee3d3", "metadata": {}, "outputs": [], "source": [ @@ -270,4 +270,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/notebooks/SignalObjExamples.ipynb b/notebooks/SignalObjExamples.ipynb index 684f6b29..76e0a497 100644 --- a/notebooks/SignalObjExamples.ipynb +++ b/notebooks/SignalObjExamples.ipynb @@ -22,33 +22,16 @@ "matplotlib.use(\"Agg\")\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "from scipy.io import loadmat\n", "\n", - "from nstat.data_manager import ensure_example_data\n", "from nstat.notebook_figures import FigureTracker\n", "\n", "np.random.seed(0)\n", - "DATA_DIR = ensure_example_data(download=True)\n", "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic='SignalObjExamples', output_root=OUTPUT_ROOT, expected_count=16)\n", "\n", - "def _load_example_globals(name: str) -> dict[str, object]:\n", - " candidates = [\n", - " Path(name),\n", - " DATA_DIR / name,\n", - " DATA_DIR / \"mEPSCs\" / name,\n", - " DATA_DIR / \"Place Cells\" / name,\n", - " DATA_DIR / \"Explicit Stimulus\" / name,\n", - " ]\n", - " for path in candidates:\n", - " if path.exists():\n", - " data = loadmat(path)\n", - " return {k: v for k, v in data.items() if not k.startswith(\"__\")}\n", - " return {}\n", - "\n", "# SECTION 0: Section 0\n", "# Using the SignalObj Class\n", - "# In this file we will give several examples of how the SignalObj can be used. A description of all of the properties of SignalObj can be found at: SignalObj Class Definition" + "# In this file we will give several examples of how the SignalObj can be used. A description of all of the properties of SignalObj can be found at: SignalObj Class Definition\n" ] }, { diff --git a/notebooks/TrialConfigExamples.ipynb b/notebooks/TrialConfigExamples.ipynb index c88fd06e..525a79cc 100644 --- a/notebooks/TrialConfigExamples.ipynb +++ b/notebooks/TrialConfigExamples.ipynb @@ -22,34 +22,17 @@ "matplotlib.use(\"Agg\")\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "from scipy.io import loadmat\n", "\n", - "from nstat.data_manager import ensure_example_data\n", "from nstat.notebook_figures import FigureTracker\n", "\n", "np.random.seed(0)\n", - "DATA_DIR = ensure_example_data(download=True)\n", "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic='TrialConfigExamples', output_root=OUTPUT_ROOT, expected_count=0)\n", "\n", - "def _load_example_globals(name: str) -> dict[str, object]:\n", - " candidates = [\n", - " Path(name),\n", - " DATA_DIR / name,\n", - " DATA_DIR / \"mEPSCs\" / name,\n", - " DATA_DIR / \"Place Cells\" / name,\n", - " DATA_DIR / \"Explicit Stimulus\" / name,\n", - " ]\n", - " for path in candidates:\n", - " if path.exists():\n", - " data = loadmat(path)\n", - " return {k: v for k, v in data.items() if not k.startswith(\"__\")}\n", - " return {}\n", - "\n", "# SECTION 0: Section 0\n", "# TrialConfig Examples\n", "# tcObj=TrialConfig(covMask,sampleRate, history,minTime,maxTime)\n", - "__tracker.finalize()" + "__tracker.finalize()\n" ] } ], diff --git a/notebooks/nSTATPaperExamples.ipynb b/notebooks/nSTATPaperExamples.ipynb index ce5da90b..39f762b9 100644 --- a/notebooks/nSTATPaperExamples.ipynb +++ b/notebooks/nSTATPaperExamples.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "8fdec003", + "id": "f7aeaead", "metadata": {}, "source": [ "\n", @@ -15,7 +15,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b6ea5834", + "id": "5224c304", "metadata": {}, "outputs": [], "source": [ @@ -35,7 +35,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "from nstat.data_manager import ensure_example_data\n", + "from nstat.notebook_data import notebook_example_data_dir\n", "from nstat.notebook_figures import FigureTracker\n", "from nstat.paper_examples_full import (\n", " run_experiment1,\n", @@ -48,7 +48,7 @@ " run_experiment6,\n", ")\n", "\n", - "DATA_DIR = ensure_example_data(download=True)\n", + "DATA_DIR = notebook_example_data_dir(allow_synthetic=True)\n", "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic=\"nSTATPaperExamples\", output_root=OUTPUT_ROOT, expected_count=26)\n", "\n", @@ -75,7 +75,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84e87421", + "id": "7479e214", "metadata": {}, "outputs": [], "source": [ @@ -86,7 +86,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5af707f9", + "id": "8aad9697", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ { "cell_type": "code", "execution_count": null, - "id": "85ba4a89", + "id": "7190f56b", "metadata": {}, "outputs": [], "source": [ @@ -113,7 +113,7 @@ { "cell_type": "code", "execution_count": null, - "id": "95ab58dd", + "id": "1c1ee772", "metadata": {}, "outputs": [], "source": [ @@ -138,7 +138,7 @@ { "cell_type": "code", "execution_count": null, - "id": "977351c8", + "id": "be2211f0", "metadata": {}, "outputs": [], "source": [ @@ -158,7 +158,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79b8695f", + "id": "33028dc0", "metadata": {}, "outputs": [], "source": [ @@ -176,7 +176,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a8b3bde", + "id": "8afc80fe", "metadata": {}, "outputs": [], "source": [ @@ -194,7 +194,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c0761b59", + "id": "52f1ba93", "metadata": {}, "outputs": [], "source": [ @@ -205,7 +205,7 @@ { "cell_type": "code", "execution_count": null, - "id": "75738bbd", + "id": "6b5263c4", "metadata": {}, "outputs": [], "source": [ @@ -224,7 +224,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fa81dd6f", + "id": "45a8e34e", "metadata": {}, "outputs": [], "source": [ @@ -240,7 +240,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0dd24c21", + "id": "6e33b0e3", "metadata": {}, "outputs": [], "source": [ @@ -259,7 +259,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b8733445", + "id": "6376eee2", "metadata": {}, "outputs": [], "source": [ @@ -281,7 +281,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6c8f49c6", + "id": "d3e45147", "metadata": {}, "outputs": [], "source": [ @@ -301,7 +301,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1e66d43b", + "id": "d859424b", "metadata": {}, "outputs": [], "source": [ @@ -321,7 +321,7 @@ { "cell_type": "code", "execution_count": null, - "id": "36d0fb8e", + "id": "cbf7d0c0", "metadata": {}, "outputs": [], "source": [ @@ -332,7 +332,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5bacb83a", + "id": "1765fe2c", "metadata": {}, "outputs": [], "source": [ @@ -348,7 +348,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6b296651", + "id": "2b42cc95", "metadata": {}, "outputs": [], "source": [ @@ -366,7 +366,7 @@ { "cell_type": "code", "execution_count": null, - "id": "75110902", + "id": "5806ab16", "metadata": {}, "outputs": [], "source": [ @@ -377,7 +377,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5af99ac3", + "id": "5354140f", "metadata": {}, "outputs": [], "source": [ @@ -394,7 +394,7 @@ { "cell_type": "code", "execution_count": null, - "id": "de6d1c17", + "id": "cec81566", "metadata": {}, "outputs": [], "source": [ @@ -410,7 +410,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f3c3d433", + "id": "f9f22950", "metadata": {}, "outputs": [], "source": [ @@ -426,7 +426,7 @@ { "cell_type": "code", "execution_count": null, - "id": "60437712", + "id": "92b61189", "metadata": {}, "outputs": [], "source": [ @@ -437,7 +437,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4f7bdfe0", + "id": "0b46271e", "metadata": {}, "outputs": [], "source": [ @@ -453,7 +453,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9cf5befa", + "id": "ab1e1e3b", "metadata": {}, "outputs": [], "source": [ @@ -469,7 +469,7 @@ { "cell_type": "code", "execution_count": null, - "id": "572987ca", + "id": "71966f40", "metadata": {}, "outputs": [], "source": [ @@ -483,7 +483,7 @@ { "cell_type": "code", "execution_count": null, - "id": "408e0541", + "id": "449c2417", "metadata": {}, "outputs": [], "source": [ @@ -497,7 +497,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a7f697d6", + "id": "198fdb9f", "metadata": {}, "outputs": [], "source": [ @@ -508,7 +508,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b679dda9", + "id": "8a6f7766", "metadata": {}, "outputs": [], "source": [ @@ -526,7 +526,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7a209772", + "id": "5d51d0c1", "metadata": {}, "outputs": [], "source": [ @@ -537,7 +537,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e003f38e", + "id": "cffe4e46", "metadata": {}, "outputs": [], "source": [ @@ -556,7 +556,7 @@ { "cell_type": "code", "execution_count": null, - "id": "049bfc62", + "id": "17a6d341", "metadata": {}, "outputs": [], "source": [ @@ -575,7 +575,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0deb3318", + "id": "a38816f8", "metadata": {}, "outputs": [], "source": [ @@ -586,7 +586,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0962e40e", + "id": "24daef12", "metadata": {}, "outputs": [], "source": [ @@ -604,7 +604,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b16c44b6", + "id": "42056ccb", "metadata": {}, "outputs": [], "source": [ @@ -623,7 +623,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c5624d21", + "id": "a0dcaad0", "metadata": {}, "outputs": [], "source": [ @@ -641,7 +641,7 @@ { "cell_type": "code", "execution_count": null, - "id": "161659b6", + "id": "61804aaf", "metadata": {}, "outputs": [], "source": [ @@ -664,7 +664,7 @@ { "cell_type": "code", "execution_count": null, - "id": "57ff85a2", + "id": "3f0a28f1", "metadata": {}, "outputs": [], "source": [ diff --git a/notebooks/nSpikeTrainExamples.ipynb b/notebooks/nSpikeTrainExamples.ipynb index 5238c4c6..71075a0e 100644 --- a/notebooks/nSpikeTrainExamples.ipynb +++ b/notebooks/nSpikeTrainExamples.ipynb @@ -22,32 +22,15 @@ "matplotlib.use(\"Agg\")\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "from scipy.io import loadmat\n", "\n", - "from nstat.data_manager import ensure_example_data\n", "from nstat.notebook_figures import FigureTracker\n", "\n", "np.random.seed(0)\n", - "DATA_DIR = ensure_example_data(download=True)\n", "OUTPUT_ROOT = REPO_ROOT / \"output\" / \"notebook_images\"\n", "__tracker = FigureTracker(topic='nSpikeTrainExamples', output_root=OUTPUT_ROOT, expected_count=4)\n", "\n", - "def _load_example_globals(name: str) -> dict[str, object]:\n", - " candidates = [\n", - " Path(name),\n", - " DATA_DIR / name,\n", - " DATA_DIR / \"mEPSCs\" / name,\n", - " DATA_DIR / \"Place Cells\" / name,\n", - " DATA_DIR / \"Explicit Stimulus\" / name,\n", - " ]\n", - " for path in candidates:\n", - " if path.exists():\n", - " data = loadmat(path)\n", - " return {k: v for k, v in data.items() if not k.startswith(\"__\")}\n", - " return {}\n", - "\n", "# SECTION 0: Section 0\n", - "# Test the nspikeTrain Class" + "# Test the nspikeTrain Class\n" ] }, { diff --git a/nstat/core.py b/nstat/core.py index 509cddab..2190da94 100644 --- a/nstat/core.py +++ b/nstat/core.py @@ -54,6 +54,22 @@ def _matlab_mode_1d(values: Sequence[float] | np.ndarray) -> float: return float(unique[int(best[0])]) +def _nearest_sample_matrix(target_time: np.ndarray, source_time: np.ndarray, source_data: np.ndarray) -> np.ndarray: + target = np.asarray(target_time, dtype=float).reshape(-1) + source_t = np.asarray(source_time, dtype=float).reshape(-1) + source = np.asarray(source_data, dtype=float) + if source.ndim == 1: + source = source[:, None] + if source_t.size == 0: + return np.zeros((target.size, source.shape[1]), dtype=float) + right = np.searchsorted(source_t, target, side="left") + right = np.clip(right, 0, source_t.size - 1) + left = np.clip(right - 1, 0, source_t.size - 1) + choose_right = np.abs(source_t[right] - target) <= np.abs(source_t[left] - target) + indices = np.where(choose_right, right, left) + return source[indices] + + class SignalObj: """Closer MATLAB-style signal abstraction used throughout the Python port.""" @@ -596,6 +612,69 @@ def findIndFromDataMask(self) -> list[int]: def isMaskSet(self) -> bool: return bool(np.any(self.dataMask == 0)) + def abs(self) -> "SignalObj": + labels = [f"|{label}|" if label else "" for label in self.dataLabels] + return self._spawn(self.time, np.abs(self.data), data_labels=labels).with_metadata( + name=f"|{self.name}|", + yunits=self.yunits, + ) + + def __abs__(self) -> "SignalObj": + return self.abs() + + def log(self) -> "SignalObj": + labels = [f"ln({label})" if label else "" for label in self.dataLabels] + yunits = f"ln({self.yunits})" if self.yunits else "" + return self._spawn(self.time, np.log(self.data), data_labels=labels).with_metadata( + name=f"ln({self.name})", + yunits=yunits, + ) + + def with_metadata(self, *, name: str | None = None, xlabelval: str | None = None, xunits: str | None = None, yunits: str | None = None) -> "SignalObj": + out = self.copySignal() + if name is not None: + out.name = str(name) + if xlabelval is not None: + out.xlabelval = str(xlabelval) + if xunits is not None: + out.xunits = str(xunits) + if yunits is not None: + out.yunits = str(yunits) + return out + + def median(self, axis: int | None = None) -> "SignalObj": + axis_arg = 0 if axis is None else axis + median_data = np.median(self.data, axis=axis_arg) + array = np.asarray(median_data, dtype=float) + if array.ndim == 1 and array.size == self.dimension: + labels = [f"median({label})" if label else "" for label in self.dataLabels] + return self._spawn( + np.asarray([self.time[0], self.time[-1]], dtype=float), + np.vstack([array, array]), + data_labels=labels, + ).with_metadata(name=f"median({self.name})") + reshaped = array.reshape(-1, 1) + return self._spawn(self.time, reshaped, data_labels=[f"median({self.name})"]).with_metadata(name=f"median({self.name})") + + def mode(self, axis: int | None = None) -> "SignalObj": + axis_arg = 0 if axis is None else axis + if axis_arg == 0: + mode_data = np.asarray([_matlab_mode_1d(self.data[:, i]) for i in range(self.dimension)], dtype=float) + elif axis_arg == 1: + mode_data = np.asarray([_matlab_mode_1d(row) for row in self.data], dtype=float) + else: + raise ValueError("axis must be 0, 1, or None") + array = np.asarray(mode_data, dtype=float) + if array.ndim == 1 and array.size == self.dimension: + labels = [f"mode({label})" if label else "" for label in self.dataLabels] + return self._spawn( + np.asarray([self.time[0], self.time[-1]], dtype=float), + np.vstack([array, array]), + data_labels=labels, + ).with_metadata(name=f"mode({self.name})") + reshaped = array.reshape(-1, 1) + return self._spawn(self.time, reshaped, data_labels=[f"mode({self.name})"]).with_metadata(name=f"mode({self.name})") + def mean(self, axis: int | None = None) -> "SignalObj": axis_arg = 0 if axis is None else axis mean_data = np.mean(self.data, axis=axis_arg) @@ -624,6 +703,20 @@ def std(self, axis: int | None = None) -> "SignalObj": reshaped = array.reshape(-1, 1) return self._spawn(self.time, reshaped, data_labels=[f"\\sigma({self.name})"]) + def max(self, axis: int | None = None): + axis_arg = 0 if axis is None else axis + values = np.max(self.data, axis=axis_arg) + indices = np.argmax(self.data, axis=axis_arg) + time = self.time[np.asarray(indices, dtype=int)] + return values, indices, time + + def min(self, axis: int | None = None): + axis_arg = 0 if axis is None else axis + values = np.min(self.data, axis=axis_arg) + indices = np.argmin(self.data, axis=axis_arg) + time = self.time[np.asarray(indices, dtype=int)] + return values, indices, time + def resample(self, sample_rate: float) -> "SignalObj": copied = self.copySignal() copied.resampleMe(sample_rate) @@ -654,6 +747,29 @@ def derivativeAt(self, x0: Sequence[float] | float): values = deriv.getValueAt(x0) return values + def integral(self, t0: float | None = None, tf: float | None = None) -> "SignalObj": + start = self.minTime if t0 is None else float(t0) + stop = self.maxTime if tf is None else float(tf) + integrated = self.getSigInTimeWindow(start, stop) + dt = 1.0 / max(float(integrated.sampleRate), 1e-12) + integrated = integrated.filter([dt], [1.0, -1.0]) + if integrated.yunits and integrated.xunits: + integrated.setYUnits(f"{integrated.yunits}*{integrated.xunits}") + elif integrated.xunits: + integrated.setYUnits(integrated.xunits) + dtstr = " d\\tau" + integrated.setName(f"\\int_{integrated.minTime:g}^{integrated.xlabelval[:1]}\\!\\!{{{integrated.name}{dtstr}}}") + labels_empty = all(not str(label) for label in integrated.dataLabels) + if not labels_empty: + updated_labels: list[str] = [] + for label in self.dataLabels: + if label: + updated_labels.append(f"\\int_{integrated.minTime:g}^{integrated.xlabelval[:1]}\\!\\!{{{label}{dtstr}}}") + else: + updated_labels.append("") + integrated.setDataLabels(updated_labels) + return integrated + def filter(self, B, A=1) -> "SignalObj": try: from scipy.signal import lfilter @@ -676,6 +792,120 @@ def filtfilt(self, B, A=1) -> "SignalObj": filtered = np.column_stack([filtfilt(b, a, self.data[:, index]) for index in range(self.dimension)]) return self._spawn(self.time, filtered, data_labels=list(self.dataLabels)) + def makeCompatible(self, other: "SignalObj", holdVals: int = 0) -> tuple["SignalObj", "SignalObj"]: + if ( + self.minTime == other.minTime + and self.maxTime == other.maxTime + and round(float(self.sampleRate), 9) == round(float(other.sampleRate), 9) + and self.time.shape == other.time.shape + and np.max(np.abs(self.time - other.time)) <= 1e-9 + ): + return self, other + + s1c = self.copySignal() + s2c = other.copySignal() + min_time = min(s1c.minTime, s2c.minTime) + max_time = max(s1c.maxTime, s2c.maxTime) + sample_rate = max(float(s1c.sampleRate), float(s2c.sampleRate)) + s1c.setSampleRate(sample_rate) + s2c.setSampleRate(sample_rate) + s1c.setMinTime(min_time, holdVals) + s2c.setMinTime(min_time, holdVals) + s1c.setMaxTime(max_time, holdVals) + s2c.setMaxTime(max_time, holdVals) + s2c.data = _nearest_sample_matrix(s1c.time, s2c.time, s2c.data) + s2c.time = s1c.time.copy() + s2c.minTime = float(np.min(s2c.time)) + s2c.maxTime = float(np.max(s2c.time)) + return s1c, s2c + + def autocorrelation(self) -> "SignalObj": + centered = self.data - np.mean(self.data, axis=0, keepdims=True) + columns: list[np.ndarray] = [] + lags: np.ndarray | None = None + for index in range(self.dimension): + series = centered[:, index] + denom = float(np.dot(series, series)) + corr = np.correlate(series, series, mode="full") + if denom > 0: + corr = corr / denom + else: + corr = np.zeros_like(corr, dtype=float) + if lags is None: + lags = np.arange(-series.size + 1, series.size, dtype=float) / max(float(self.sampleRate), 1e-12) + columns.append(np.asarray(corr, dtype=float)) + data = np.column_stack(columns) if columns else np.zeros((0, 0), dtype=float) + return self.__class__( + lags if lags is not None else np.array([], dtype=float), + data, + f"ACF({self.name})", + "Lag", + self.xunits, + f"{self.yunits}^2" if self.yunits else "", + list(self.dataLabels), + list(self.plotProps), + ) + + def crosscorrelation(self, other: "SignalObj") -> "SignalObj": + if self.dimension != 1 or other.dimension != 1: + raise ValueError("crosscorrelation only supports one-dimensional signals") + s1c, s2c = self.makeCompatible(other) + x = s1c.data[:, 0] - float(np.mean(s1c.data[:, 0])) + y = s2c.data[:, 0] - float(np.mean(s2c.data[:, 0])) + denom = float(np.sqrt(np.dot(x, x) * np.dot(y, y))) + corr = np.correlate(x, y, mode="full") + if denom > 0: + corr = corr / denom + else: + corr = np.zeros_like(corr, dtype=float) + lags = np.arange(-x.size + 1, x.size, dtype=float) / max(float(s1c.sampleRate), 1e-12) + return self.__class__( + lags, + corr, + f"XCORF({self.name})", + "Lag", + self.xunits, + f"{self.yunits}^2" if self.yunits else "", + list(self.dataLabels[:1]), + list(self.plotProps[:1]), + ) + + def xcorr(self, other: "SignalObj" | None = None, maxlag: int | None = None) -> "SignalObj": + s2 = self if other is None else other + s1c, s2c = self.makeCompatible(s2) + data_columns: list[np.ndarray] = [] + data_labels: list[str] = [] + lag_index: np.ndarray | None = None + for left_index in range(s1c.dimension): + for right_index in range(s2c.dimension): + corr = np.correlate(s1c.data[:, left_index], s2c.data[:, right_index], mode="full") + lags = np.arange(-s1c.data.shape[0] + 1, s1c.data.shape[0], dtype=int) + if maxlag is not None: + keep = np.abs(lags) <= int(maxlag) + corr = corr[keep] + lags = lags[keep] + if other is None: + keep = lags >= 0 + corr = corr[keep] + lags = lags[keep] + if lag_index is None: + lag_index = lags.astype(float) / max(float(s1c.sampleRate), 1e-12) + data_columns.append(np.asarray(corr, dtype=float)) + left_label = s1c.dataLabels[left_index] if left_index < len(s1c.dataLabels) else str(left_index + 1) + right_label = s2c.dataLabels[right_index] if right_index < len(s2c.dataLabels) else str(right_index + 1) + data_labels.append(f"corr({left_label},{right_label})") + data = np.column_stack(data_columns) if data_columns else np.zeros((0, 0), dtype=float) + name = f"corr({self.name},{s2.name})" + return self.__class__( + lag_index if lag_index is not None else np.array([], dtype=float), + data, + name, + "\\Delta \\tau", + self.xunits, + f"{self.yunits}^2" if self.yunits else "", + data_labels, + ) + def setConfInterval(self, bounds: tuple[np.ndarray, np.ndarray]) -> None: low, high = bounds low_arr = np.asarray(low, dtype=float) @@ -1014,6 +1244,7 @@ def setName(self, name: str) -> None: def computeStatistics(self, makePlots: int = 0) -> None: self.avgFiringRate = self.firing_rate_hz isi = self.getISIs() + spike_times = self.spikeTimes mode_isi = _matlab_mode_1d(isi) self.burstIndex = float(1.0 / mode_isi / self.avgFiringRate) if np.isfinite(mode_isi) and self.avgFiringRate > 0 else np.nan self.B = np.nan @@ -1026,7 +1257,53 @@ def computeStatistics(self, makePlots: int = 0) -> None: self.numSpikesPerBurst = np.array([], dtype=float) self.avgSpikesPerBurst = np.nan self.stdSpikesPerBurst = np.nan + self.Lstatistic = np.nan + + if isi.size: + sigma = float(np.std(isi)) + mu = float(np.mean(isi)) + if np.isfinite(mu) and mu > 0: + r = sigma / mu + self.B = float((r - 1.0) / (r + 1.0)) + n = float(spike_times.size) + self.An = float((np.sqrt(n + 2.0) * r - np.sqrt(n)) / (((np.sqrt(n + 2.0) - 2.0) * r) + np.sqrt(n))) + + ln = isi[isi < mu] + ml = float(np.mean(ln)) if ln.size else np.nan + if np.isfinite(ml): + burst_isi = (isi < ml).astype(float) + shifted = np.concatenate([burst_isi[1:], [0.0]]) if burst_isi.size else np.array([], dtype=float) + y = (burst_isi + shifted) > 1.0 + diff_sig = np.concatenate([[0.0], np.diff(y.astype(float))]) if y.size else np.array([], dtype=float) + burst_start = np.flatnonzero(diff_sig == 1.0) + burst_end = np.flatnonzero(diff_sig == -1.0) + 1 + if burst_start.size == 0: + burst_end = np.array([], dtype=int) + if burst_end.size > burst_start.size and burst_end.size: + first = np.flatnonzero(y[: burst_end[0]] == 1) + if first.size: + burst_start = np.concatenate([[int(first[0])], burst_start]) + if burst_start.size > burst_end.size and burst_start.size: + last = np.flatnonzero(y[burst_start[-1] :] == 1) + if last.size: + burst_end = np.concatenate([burst_end, [int(last[-1])]]) + if burst_start.size and burst_end.size: + burst_data = np.zeros(spike_times.size, dtype=float) + for start, end in zip(burst_start, burst_end, strict=False): + burst_data[int(start) : int(end) + 1] = 1.0 + self.burstDuration = spike_times[burst_end] - spike_times[burst_start] + self.burstSig = SignalObj(spike_times, burst_data, "Burst Signal") + self.burstTimes = spike_times[burst_start] + self.numBursts = int(burst_start.size) + duration = self.maxTime - self.minTime + self.burstRate = float(self.numBursts / duration) if duration > 0 else np.nan + self.numSpikesPerBurst = (burst_end - burst_start + 1).astype(float) + self.avgSpikesPerBurst = float(np.mean(self.numSpikesPerBurst + 1.0)) + self.stdSpikesPerBurst = float(np.std(self.numSpikesPerBurst + 1.0)) + self.Lstatistic = self.getLStatistic() + if makePlots == 1: + self.plot() def getLStatistic(self) -> float: isi = self.getISIs() @@ -1181,8 +1458,8 @@ def computeRate(self) -> SignalObj: def restoreToOriginal(self) -> None: self.spikeTimes = self.originalSpikeTimes.copy() self.sampleRate = float(self.originalSampleRate) - self.minTime = float(np.min(self.spikeTimes)) if self.spikeTimes.size else 0.0 - self.maxTime = float(np.max(self.spikeTimes)) if self.spikeTimes.size else 0.0 + self.minTime = float(self.originalMinTime) + self.maxTime = float(self.originalMaxTime) self.clearSigRep() def partitionNST( @@ -1203,6 +1480,8 @@ def partitionNST( normalize = bool(normalizeTime) if normalizeTime is not None else False partitions: list[nspikeTrain] = [] for index, (window_start, window_stop) in enumerate(zip(windows[:-1], windows[1:]), start=1): + window_start = round(float(window_start) * self.sampleRate) / self.sampleRate + window_stop = round(float(window_stop) * self.sampleRate) / self.sampleRate duration = float(window_stop - window_start) if lbound is not None and ubound is not None and not (float(lbound) <= abs(duration) <= float(ubound)): continue @@ -1213,7 +1492,7 @@ def partitionNST( subset = subset - float(window_start) if normalize and duration != 0: subset = subset / duration - partitions.append(nspikeTrain(subset, self.name, 1.0 / self.sampleRate if self.sampleRate > 0 else 0.001, makePlots=-1)) + partitions.append(nspikeTrain(subset, self.name, makePlots=-1)) coll = nstColl(partitions) if normalize: @@ -1224,6 +1503,97 @@ def partitionNST( def getFieldVal(self, fieldName: str): return getattr(self, fieldName, []) + def plotISISpectrumFunction(self): + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(1, 1, figsize=(6.0, 3.5)) + isi = self.getISIs() + if isi.size: + (line,) = ax.plot(self.spikeTimes[1:], isi, ".") + else: + (line,) = ax.plot([], [], ".") + ax.set_xlabel("time [s]") + ax.set_ylabel("ISI [s]") + return line + + def plotJointISIHistogram(self): + import matplotlib.pyplot as plt + + ax = plt.subplots(1, 1, figsize=(4.5, 4.0))[1] + isi = self.getISIs() + if isi.size >= 2: + ax.loglog(isi[:-1], isi[1:], ".") + mean_isi = float(np.mean(isi)) + ln = isi[isi < mean_isi] + ml = float(np.mean(ln)) if ln.size else np.nan + if np.isfinite(ml) and ml > 0: + v = ax.axis() + ax.loglog([ml, ml], [v[2], v[3]], "k--") + ax.loglog([v[0], v[1]], [ml, ml], "k--") + ax.set_xlabel("ISI(t) [s]") + ax.set_ylabel("ISI(t+1) [s]") + return ax + + def plotISIHistogram(self, minTime: float | None = None, maxTime: float | None = None, numBins: int | None = None, handle=None): + import matplotlib.pyplot as plt + + del numBins + ax = plt.gca() if handle is None else handle + if maxTime is None: + maxTime = self.maxTime + if minTime is None: + minTime = self.minTime + isi = self.getISIs(minTime, maxTime) + counts = np.array([], dtype=float) + bins = np.array([], dtype=float) + if isi.size: + bin_width = 0.001 + bins = np.arange(0.0, float(np.max(isi)) + bin_width, bin_width, dtype=float) + if bins.size < 2: + bins = np.array([0.0, bin_width], dtype=float) + counts, edges = np.histogram(isi, bins=bins) + centers = edges[:-1] + ax.bar( + centers, + counts, + width=bin_width, + align="edge", + edgecolor=(0.0, 0.0, 0.0), + linewidth=2.0, + color=(0.831372559070587, 0.815686285495758, 0.7843137383461), + ) + ax.set_xlabel("ISI [sec]") + ax.set_ylabel("Spike Counts") + ax.autoscale(enable=True, axis="x", tight=True) + return counts + + def plotProbPlot(self, minTime: float | None = None, maxTime: float | None = None, handle=None): + import matplotlib.pyplot as plt + from scipy import stats + + ax = plt.gca() if handle is None else handle + if maxTime is None: + maxTime = self.maxTime + if minTime is None: + minTime = self.minTime + isi = self.getISIs(minTime, maxTime) + ax.clear() + if isi.size: + stats.probplot(isi, dist=stats.expon, plot=ax) + ax.set_title(ax.get_title() or "Probability Plot") + return ax + + def plotExponentialFit(self, minTime: float | None = None, maxTime: float | None = None, numBins: int | None = None, handle=None): + import matplotlib.pyplot as plt + + fig = handle if handle is not None else plt.figure(figsize=(10.0, 4.0)) + fig.clear() + axes = fig.subplots(1, 2) + self.plotISIHistogram(minTime, maxTime, numBins, axes[0]) + self.plotProbPlot(minTime, maxTime, axes[1]) + fig.tight_layout() + return fig + def plot(self, dHeight: float = 1.0, yOffset: float = 0.5, currentHandle=None): import matplotlib.pyplot as plt diff --git a/nstat/notebook_data.py b/nstat/notebook_data.py new file mode 100644 index 00000000..d82dd134 --- /dev/null +++ b/nstat/notebook_data.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import os +from pathlib import Path + +import numpy as np +from scipy.io import loadmat + +from . import data_manager + + +def _download_policy() -> bool: + policy = os.environ.get("NSTAT_NOTEBOOK_DOWNLOAD_EXAMPLE_DATA", "").strip().lower() + if policy in {"1", "true", "yes", "on", "always"}: + return True + if policy in {"0", "false", "no", "off", "never"}: + return False + return os.environ.get("CI", "").strip().lower() not in {"1", "true", "yes", "on"} + + +def notebook_example_data_dir(*, allow_synthetic: bool = False) -> Path: + """Return a notebook-safe example-data root. + + Local runs still auto-download data when needed. CI defaults to cached-only + behavior and can fall back to the synthetic dataset paths used by the paper + example helpers. + """ + + try: + return data_manager.ensure_example_data(download=_download_policy()) + except FileNotFoundError: + if not allow_synthetic: + raise + os.environ.setdefault("NSTAT_ALLOW_SYNTHETIC_DATA", "1") + return data_manager.get_data_dir() + + +def _is_lfs_pointer(path: Path) -> bool: + try: + head = path.read_bytes()[:200] + except OSError: + return False + return head.startswith(b"version https://git-lfs.github.com/spec/v1") + + +def _synthetic_glm_data() -> dict[str, np.ndarray]: + rng = np.random.default_rng(1202) + sample_rate = 1000.0 + dt = 1.0 / sample_rate + duration_s = 6.0 + t = np.arange(0.0, duration_s + dt, dt, dtype=float) + theta = 2.0 * np.pi * 0.08 * t + x_n = 0.75 * np.cos(theta) + 0.15 * np.sin(0.35 * theta) + y_n = 0.65 * np.sin(theta + 0.2) + 0.10 * np.cos(0.5 * theta) + vx_n = np.gradient(x_n, dt) + vy_n = np.gradient(y_n, dt) + + eta = 1.4 + 1.0 * x_n - 0.8 * y_n - 0.55 * x_n * x_n - 0.45 * y_n * y_n + 0.35 * x_n * y_n + lam_per_bin = np.clip(np.exp(np.clip(eta, -8.0, 4.5)) * dt, 1e-6, 0.2) + spikes_binned = (rng.random(t.shape[0]) < lam_per_bin).astype(float) + spike_idx = np.flatnonzero(spikes_binned > 0.5) + spiketimes = t[spike_idx] + x_at_spiketimes = x_n[spike_idx] + y_at_spiketimes = y_n[spike_idx] + + return { + "T": t, + "xN": x_n, + "yN": y_n, + "vxN": vx_n, + "vyN": vy_n, + "spikes_binned": spikes_binned, + "spiketimes": spiketimes, + "x_at_spiketimes": x_at_spiketimes, + "y_at_spiketimes": y_at_spiketimes, + } + + +def load_glm_data_for_notebook() -> dict[str, np.ndarray]: + """Return the canonical GLM dataset or a deterministic synthetic fallback.""" + + data_dir = notebook_example_data_dir(allow_synthetic=True) + path = data_dir / "glm_data.mat" + if path.exists() and not _is_lfs_pointer(path): + payload = loadmat(path, squeeze_me=True, struct_as_record=False) + return {key: value for key, value in payload.items() if not key.startswith("__")} + return _synthetic_glm_data() diff --git a/nstat/paper_examples_full.py b/nstat/paper_examples_full.py index 815f57e4..0cb94845 100644 --- a/nstat/paper_examples_full.py +++ b/nstat/paper_examples_full.py @@ -138,6 +138,26 @@ def _coefficient_intervals(x: np.ndarray, result, offset: np.ndarray) -> tuple[n def _load_mepsc_times_seconds(path: Path) -> np.ndarray: + if not path.exists(): + if _allow_synthetic_data(): + name = path.name + if name == "epsc2.txt": + rng = np.random.default_rng(1001) + time = np.arange(0.0, 220.0, 0.05, dtype=float) + rate_hz = np.full(time.shape, 0.55, dtype=float) + elif name == "washout1.txt": + rng = np.random.default_rng(1002) + time = np.arange(0.0, 500.0, 0.05, dtype=float) + rate_hz = np.where(time < 235.0, 0.70, 1.25) + elif name == "washout2.txt": + rng = np.random.default_rng(1003) + time = np.arange(0.0, 320.0, 0.05, dtype=float) + rate_hz = 1.75 + 0.20 * np.sin(0.01 * time) + else: + raise FileNotFoundError(f"Missing mEPSC file: {path}") + keep = rng.random(time.shape[0]) < np.clip(rate_hz * 0.05, 1e-6, 0.25) + return time[keep] + raise FileNotFoundError(f"Missing mEPSC file: {path}") arr = np.loadtxt(path, skiprows=1) return np.asarray(arr[:, 1], dtype=float).reshape(-1) / 1000.0 diff --git a/parity/class_fidelity.yml b/parity/class_fidelity.yml index a07c2c14..10db0f1e 100644 --- a/parity/class_fidelity.yml +++ b/parity/class_fidelity.yml @@ -24,7 +24,9 @@ items: confidence-interval storage. method_parity: MATLAB-facing methods now cover labels, masking, sub-signals, nearest-time lookup, time-window extraction, merge, arithmetic operators, derivative/derivativeAt, - filtering, plotting, restore/reset, mean/std, resampling, and structure export. + integral, filtering, compatibility alignment, autocorrelation/crosscorrelation/xcorr, + abs/log, mean/median/mode/std, min/max summaries, plotting, restore/reset, resampling, + and structure export. defaults_parity: Defaults for labels, units, and sample-rate fallback now match MATLAB closely, including the 1 kHz fallback when sample spacing is ill-conditioned. indexing_parity: Signals use time-by-dimension storage and one-based selector behavior @@ -34,14 +36,15 @@ items: output_type_parity: MATLAB-facing methods return SignalObj/Covariate instances where expected. known_remaining_differences: - - Some specialized MATLAB utilities, plotting options, and correlation helpers remain + - Some specialized MATLAB spectral utilities and report-style plotting options remain unported. - Structure serialization is close but not exhaustive for every MATLAB-only field. required_remediation: - - Add MATLAB-derived fixtures for filter outputs, plotting selectors, and any remaining - specialized utility methods. - plotting_report_parity: Core plotting is implemented; some MATLAB-only plot selectors, - spectral utilities, and report-style helpers remain lighter. + - Add MATLAB-derived fixtures for filter outputs, xcorr/autocorrelation traces, plotting + selectors, and the remaining spectral utility methods. + plotting_report_parity: Core plotting and correlation helpers are implemented; some + MATLAB-only plot selectors, spectral utilities, and report-style helpers remain + lighter. - matlab_name: Covariate kind: class matlab_path: Covariate.m @@ -84,7 +87,9 @@ items: and label metadata. method_parity: MATLAB-facing methods now cover setSigRep, setMinTime, setMaxTime, resample, getSigRep, getSpikeTimes, getISIs, getMinISI, getMaxBinSizeBinary, partitionNST, - getFieldVal, computeRate, restoreToOriginal, nstCopy, plot, and structure round-trip. + getFieldVal, computeRate, restoreToOriginal, nstCopy, burst/statistics computation, + ISI histogram/probability plotting, joint ISI plotting, raster plotting, and structure + round-trip. defaults_parity: Defaults, cache behavior, and restore/resample semantics now track MATLAB much more closely than the earlier simplified implementation. indexing_parity: Spike vectors remain one-dimensional and time-window filtering @@ -94,13 +99,13 @@ items: output_type_parity: Signal representation returns SignalObj and rate conversion returns SignalObj as expected. known_remaining_differences: - - Several ISI-plot helper methods remain unported or lighter than MATLAB. - - Burst metrics remain approximated rather than fully MATLAB-equivalent. + - Some MATLAB visual styling and distribution-fit detail in the ISI plotting helpers + remains lighter than MATLAB. required_remediation: - - Port the remaining ISI plotting helpers and burst-detection detail from MATLAB. - - Add MATLAB-derived fixtures for partitionNST and burst/statistics outputs. - plotting_report_parity: Raster/basic plotting works; ISI, burst, and reporting helpers - remain thinner than MATLAB. + - Add MATLAB-derived fixtures for partitionNST, burst/statistics outputs, and ISI + plotting traces. + plotting_report_parity: Raster, ISI, and burst-oriented plotting helpers now execute + on the canonical class, though visual detail remains lighter than MATLAB. - matlab_name: nstColl kind: class matlab_path: nstColl.m diff --git a/tests/test_notebook_data.py b/tests/test_notebook_data.py new file mode 100644 index 00000000..9aac26d6 --- /dev/null +++ b/tests/test_notebook_data.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np + +import nstat.data_manager as data_manager +import nstat.notebook_data as notebook_data + + +def test_notebook_example_data_dir_disables_download_in_ci(monkeypatch) -> None: + calls: list[bool] = [] + + def fake_ensure_example_data(*, download: bool = True) -> Path: + calls.append(download) + raise FileNotFoundError("missing") + + monkeypatch.setattr(data_manager, "ensure_example_data", fake_ensure_example_data) + monkeypatch.setattr(data_manager, "get_data_dir", lambda: Path("/tmp/nstat-synthetic")) + monkeypatch.setenv("CI", "true") + monkeypatch.delenv("NSTAT_NOTEBOOK_DOWNLOAD_EXAMPLE_DATA", raising=False) + monkeypatch.delenv("NSTAT_ALLOW_SYNTHETIC_DATA", raising=False) + + path = notebook_data.notebook_example_data_dir(allow_synthetic=True) + + assert path == Path("/tmp/nstat-synthetic") + assert calls == [False] + assert notebook_data.os.environ["NSTAT_ALLOW_SYNTHETIC_DATA"] == "1" + + +def test_load_glm_data_for_notebook_uses_synthetic_fallback(monkeypatch) -> None: + monkeypatch.setattr(notebook_data, "notebook_example_data_dir", lambda *, allow_synthetic=False: Path("/tmp/missing-data")) + + payload = notebook_data.load_glm_data_for_notebook() + + expected = { + "T", + "xN", + "yN", + "vxN", + "vyN", + "spikes_binned", + "spiketimes", + "x_at_spiketimes", + "y_at_spiketimes", + } + assert expected <= set(payload) + assert payload["T"].ndim == 1 + assert payload["spikes_binned"].shape == payload["T"].shape + assert payload["spiketimes"].ndim == 1 + assert np.all(np.diff(payload["T"]) > 0.0) diff --git a/tests/test_notebook_surface.py b/tests/test_notebook_surface.py index f7975ab8..10accd4f 100644 --- a/tests/test_notebook_surface.py +++ b/tests/test_notebook_surface.py @@ -7,6 +7,7 @@ REPO_ROOT = Path(__file__).resolve().parents[1] +TOPIC_GROUPS = yaml.safe_load((REPO_ROOT / "tools" / "notebooks" / "topic_groups.yml").read_text(encoding="utf-8")) or {} def test_notebooks_are_python_facing() -> None: @@ -45,3 +46,17 @@ def test_hybrid_filter_notebook_does_not_require_example_data_download() -> None assert "ensure_example_data(download=True)" not in text assert "from nstat.data_manager import ensure_example_data" not in text + + +def test_parity_core_notebooks_do_not_require_live_example_data_download() -> None: + topics = TOPIC_GROUPS.get("groups", {}).get("parity_core", []) + for topic in topics: + notebook = nbformat.read(REPO_ROOT / "notebooks" / f"{topic}.ipynb", as_version=4) + text = "\n".join(cell.source for cell in notebook.cells) + assert "ensure_example_data(download=True)" not in text, f"{topic} still hard-requires remote example-data download" + + +def test_notebook_builder_sources_do_not_hard_require_live_example_data_download() -> None: + for path in sorted((REPO_ROOT / "tools" / "notebooks").glob("*.py")): + text = path.read_text(encoding="utf-8") + assert "ensure_example_data(download=True)" not in text, f"{path.name} still hardcodes live example-data download" diff --git a/tests/test_nspiketrain_fidelity.py b/tests/test_nspiketrain_fidelity.py index 34ee3981..c8430566 100644 --- a/tests/test_nspiketrain_fidelity.py +++ b/tests/test_nspiketrain_fidelity.py @@ -1,5 +1,6 @@ from __future__ import annotations +import matplotlib.pyplot as plt import numpy as np from nstat.nspikeTrain import nspikeTrain @@ -61,4 +62,44 @@ def test_nspiketrain_setsigrep_restore_and_field_access_match_matlab_surface() - train.restoreToOriginal() assert train.sampleRate == 5.0 - np.testing.assert_allclose([train.minTime, train.maxTime], [0.2, 0.6]) + np.testing.assert_allclose([train.minTime, train.maxTime], [0.0, 1.0]) + + +def test_nspiketrain_compute_statistics_matches_matlab_style_burst_metrics() -> None: + train = nspikeTrain([0.0, 0.001, 0.002, 0.007, 0.507, 0.508, 0.509, 0.514], "bursting", 0.001, 0.0, 0.6, makePlots=0) + + assert np.isfinite(train.B) + assert np.isfinite(train.An) + assert np.isfinite(train.burstIndex) + assert train.numBursts >= 1 + assert train.burstSig is not None + assert train.burstTimes.size == train.numBursts + assert train.numSpikesPerBurst.size == train.numBursts + + +def test_nspiketrain_partition_rounds_windows_and_uses_matlab_constructor_defaults() -> None: + train = nspikeTrain([0.0004, 0.0014, 0.0096], "neuron", 0.001, 0.0, 0.01, makePlots=-1) + + parts = train.partitionNST([0.00049, 0.00151, 0.0101], normalizeTime=0) + + assert parts.numSpikeTrains == 2 + np.testing.assert_allclose(parts.getNST(1).spikeTimes, [0.0004, 0.0014]) + np.testing.assert_allclose(parts.getNST(2).spikeTimes, [0.0076]) + assert parts.getNST(1).sampleRate == 1000.0 + + +def test_nspiketrain_isi_plot_helpers_execute_and_return_matplotlib_objects() -> None: + train = nspikeTrain([0.1, 0.12, 0.15, 0.5, 0.8], "neuron", 0.001, 0.0, 1.0, makePlots=0) + + line = train.plotISISpectrumFunction() + joint_ax = train.plotJointISIHistogram() + counts = train.plotISIHistogram() + prob_ax = train.plotProbPlot() + fig = train.plotExponentialFit() + + assert hasattr(line, "get_xdata") + assert hasattr(joint_ax, "loglog") + assert counts.sum() == train.getISIs().size + assert hasattr(prob_ax, "plot") + assert len(fig.axes) == 2 + plt.close("all") diff --git a/tests/test_signalobj_fidelity.py b/tests/test_signalobj_fidelity.py index 86684543..c59e30f5 100644 --- a/tests/test_signalobj_fidelity.py +++ b/tests/test_signalobj_fidelity.py @@ -1,5 +1,6 @@ from __future__ import annotations +import matplotlib.pyplot as plt import numpy as np from nstat.ConfidenceInterval import ConfidenceInterval @@ -89,3 +90,65 @@ def test_covariate_plus_minus_propagate_confidence_intervals() -> None: assert subtracted.isConfIntervalSet() np.testing.assert_allclose(added.ci[0].bounds, [[1.2, 1.8], [3.2, 3.8]]) np.testing.assert_allclose(subtracted.ci[0].bounds, [[0.2, 0.8], [0.2, 0.8]]) + + +def test_signalobj_integral_matches_matlab_style_accumulator_and_labels() -> None: + sig = SignalObj([0.0, 1.0, 2.0], [1.0, 2.0, 3.0], "stim", "time", "s", "a.u.", ["x"]) + + integrated = sig.integral() + + np.testing.assert_allclose(integrated.data[:, 0], [1.0, 3.0, 6.0]) + assert integrated.yunits == "a.u.*s" + assert integrated.name.startswith("\\int_") + assert integrated.dataLabels[0].startswith("\\int_") + + +def test_signalobj_makecompatible_and_correlation_helpers_follow_matlab_surface() -> None: + s1 = SignalObj([0.0, 1.0, 2.0], [1.0, 0.0, -1.0], "s1", dataLabels=["x"]) + s2 = SignalObj([0.5, 1.5], [2.0, 4.0], "s2", dataLabels=["y"]) + + s1c, s2c = s1.makeCompatible(s2, holdVals=1) + + np.testing.assert_allclose(s1c.time, s2c.time) + assert s1c.sampleRate == s2c.sampleRate + assert s1c.minTime == 0.0 + assert s1c.maxTime == 2.0 + np.testing.assert_allclose(s2c.data[:, 0], [2.0, 4.0, 4.0]) + + acf = s1.autocorrelation() + assert acf.name == "ACF(s1)" + np.testing.assert_allclose(acf.time[acf.time.size // 2], 0.0) + + xcf = s1.crosscorrelation(SignalObj([0.0, 1.0, 2.0], [0.0, 1.0, 0.0], "s3", dataLabels=["z"])) + assert xcf.dimension == 1 + assert xcf.xlabelval == "Lag" + + xcorr_sig = s1.xcorr() + assert xcorr_sig.xlabelval == "\\Delta \\tau" + assert np.all(xcorr_sig.time >= 0.0) + assert xcorr_sig.dimension == 1 + plt.close("all") + + +def test_signalobj_math_and_summary_methods_match_matlab_surface() -> None: + sig = SignalObj([0.0, 1.0, 2.0], [[1.0, 2.0], [3.0, 1.0], [2.0, 4.0]], "stim", dataLabels=["x", "y"], yunits="a.u.") + + abs_sig = abs(SignalObj([0.0, 1.0], [-1.0, 2.0], "signed", yunits="a.u.", dataLabels=["x"])) + log_sig = SignalObj([1.0, 2.0], [1.0, np.e], "positive", yunits="Hz", dataLabels=["x"]).log() + med = sig.median() + mod = sig.mode() + max_vals, max_idx, max_time = sig.max() + min_vals, min_idx, min_time = sig.min() + + np.testing.assert_allclose(abs_sig.data[:, 0], [1.0, 2.0]) + assert abs_sig.name == "|signed|" + np.testing.assert_allclose(log_sig.data[:, 0], [0.0, 1.0]) + assert log_sig.yunits == "ln(Hz)" + np.testing.assert_allclose(med.data[0], [2.0, 2.0]) + np.testing.assert_allclose(mod.data[0], [1.0, 1.0]) + np.testing.assert_allclose(max_vals, [3.0, 4.0]) + np.testing.assert_array_equal(max_idx, [1, 2]) + np.testing.assert_allclose(max_time, [1.0, 2.0]) + np.testing.assert_allclose(min_vals, [1.0, 1.0]) + np.testing.assert_array_equal(min_idx, [0, 1]) + np.testing.assert_allclose(min_time, [0.0, 1.0]) diff --git a/tests/test_workflow_fidelity.py b/tests/test_workflow_fidelity.py index d4b0deab..05d9fcb7 100644 --- a/tests/test_workflow_fidelity.py +++ b/tests/test_workflow_fidelity.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +from pathlib import Path from nstat import Analysis, CIF, CIFModel, DecodingAlgorithms, FitResSummary, Trial, TrialConfig from nstat.ConfigColl import ConfigColl @@ -12,6 +13,7 @@ from nstat.analysis import compHistEnsCoeff, compHistEnsCoeffForAll, computeGrangerCausalityMatrix, computeNeighbors, spikeTrigAvg from nstat.nstColl import nstColl from nstat.nspikeTrain import nspikeTrain +from nstat.paper_examples_full import run_experiment1 def _build_trial() -> Trial: @@ -245,3 +247,14 @@ def test_history_and_events_roundtrip_in_workflow_context() -> None: assert rebuilt_events is not None assert rebuilt_events.eventColor == "m" assert rebuilt_events.eventLabels == ["start", "stop"] + + +def test_paper_example_one_supports_synthetic_fallback(monkeypatch, tmp_path) -> None: + monkeypatch.setenv("NSTAT_ALLOW_SYNTHETIC_DATA", "1") + + summary, payload = run_experiment1(Path(tmp_path), return_payload=True) + + assert summary["const_condition_spikes"] > 0.0 + assert summary["decreasing_condition_spikes"] > 0.0 + assert payload["constant_spike_times_s"].size > 0 + assert payload["washout_spike_times_s"].size > 0 diff --git a/tools/notebooks/build_analysis_help_notebooks.py b/tools/notebooks/build_analysis_help_notebooks.py index c6667170..659dc666 100644 --- a/tools/notebooks/build_analysis_help_notebooks.py +++ b/tools/notebooks/build_analysis_help_notebooks.py @@ -71,15 +71,13 @@ def _write_notebook( matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np - from scipy.io import loadmat from nstat import Analysis, Covariate, nspikeTrain - from nstat.data_manager import ensure_example_data from nstat.glm import fit_poisson_glm + from nstat.notebook_data import load_glm_data_for_notebook from nstat.notebook_figures import FigureTracker - DATA_DIR = ensure_example_data(download=True) - GLM_DATA = loadmat(DATA_DIR / "glm_data.mat", squeeze_me=True, struct_as_record=False) + GLM_DATA = load_glm_data_for_notebook() OUTPUT_ROOT = REPO_ROOT / "output" / "notebook_images" __tracker = FigureTracker(topic="AnalysisExamples", output_root=OUTPUT_ROOT, expected_count=4) @@ -241,15 +239,13 @@ def _poisson_standard_errors(design_matrix, result): matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np - from scipy.io import loadmat from nstat import Analysis, ConfigColl, CovColl, Covariate, FitResSummary, Trial, TrialConfig, nspikeTrain, nstColl - from nstat.data_manager import ensure_example_data from nstat.glm import fit_poisson_glm + from nstat.notebook_data import load_glm_data_for_notebook from nstat.notebook_figures import FigureTracker - DATA_DIR = ensure_example_data(download=True) - GLM_DATA = loadmat(DATA_DIR / "glm_data.mat", squeeze_me=True, struct_as_record=False) + GLM_DATA = load_glm_data_for_notebook() OUTPUT_ROOT = REPO_ROOT / "output" / "notebook_images" __tracker = FigureTracker(topic="AnalysisExamples2", output_root=OUTPUT_ROOT, expected_count=5) diff --git a/tools/notebooks/build_helpfile_fidelity_notebooks.py b/tools/notebooks/build_helpfile_fidelity_notebooks.py index a490f1e0..da1260c3 100644 --- a/tools/notebooks/build_helpfile_fidelity_notebooks.py +++ b/tools/notebooks/build_helpfile_fidelity_notebooks.py @@ -72,12 +72,12 @@ def _write_notebook( import matplotlib.pyplot as plt import numpy as np - from nstat.data_manager import ensure_example_data + from nstat.notebook_data import notebook_example_data_dir from nstat.notebook_figures import FigureTracker from nstat.paper_examples_full import run_experiment2 np.random.seed(0) - DATA_DIR = ensure_example_data(download=True) + DATA_DIR = notebook_example_data_dir(allow_synthetic=True) OUTPUT_ROOT = REPO_ROOT / "output" / "notebook_images" __tracker = FigureTracker(topic='ExplicitStimulusWhiskerData', output_root=OUTPUT_ROOT, expected_count=9) @@ -576,12 +576,12 @@ def _plot_isi_hist(ax, train, lambda_hz, *, title): import matplotlib.pyplot as plt import numpy as np - from nstat.data_manager import ensure_example_data + from nstat.notebook_data import notebook_example_data_dir from nstat.notebook_figures import FigureTracker from nstat.paper_examples_full import run_experiment4 np.random.seed(0) - DATA_DIR = ensure_example_data(download=True) + DATA_DIR = notebook_example_data_dir(allow_synthetic=True) OUTPUT_ROOT = REPO_ROOT / "output" / "notebook_images" __tracker = FigureTracker(topic='HippocampalPlaceCellExample', output_root=OUTPUT_ROOT, expected_count=11) diff --git a/tools/notebooks/build_nstat_paper_notebook.py b/tools/notebooks/build_nstat_paper_notebook.py index 343bc22a..7805adf7 100644 --- a/tools/notebooks/build_nstat_paper_notebook.py +++ b/tools/notebooks/build_nstat_paper_notebook.py @@ -65,7 +65,7 @@ def _write_notebook(path: Path, *, topic: str, expected_figures: int, markdown_n import matplotlib.pyplot as plt import numpy as np - from nstat.data_manager import ensure_example_data + from nstat.notebook_data import notebook_example_data_dir from nstat.notebook_figures import FigureTracker from nstat.paper_examples_full import ( run_experiment1, @@ -78,7 +78,7 @@ def _write_notebook(path: Path, *, topic: str, expected_figures: int, markdown_n run_experiment6, ) - DATA_DIR = ensure_example_data(download=True) + DATA_DIR = notebook_example_data_dir(allow_synthetic=True) OUTPUT_ROOT = REPO_ROOT / "output" / "notebook_images" __tracker = FigureTracker(topic="nSTATPaperExamples", output_root=OUTPUT_ROOT, expected_count=26)