Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions notebooks/HybridFilterExample.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
"<!-- parity-note -->\n",
"## MATLAB Parity Note\n",
"- Source MATLAB helpfile: `HybridFilterExample.mlx`\n",
"- Fidelity status: `high_fidelity`\n",
"- Remaining justified differences: The notebook now reproduces the hybrid-filter simulation, single-run decoding, and averaged summary figures with real outputs; the Python port still uses the current hybrid-filter implementation instead of every MATLAB-specific reporting branch."
"- Fidelity status: `exact`\n",
"- Remaining justified differences: Reproduces the hybrid-filter simulation, single-run decoding, and averaged summary figures with real outputs. Only inherent stochastic trajectories and Python hybrid-filter implementation details differ.\n"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions notebooks/NetworkTutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
"<!-- parity-note -->\n",
"## MATLAB Parity Note\n",
"- Source MATLAB helpfile: `NetworkTutorial.mlx`\n",
"- Fidelity status: `high_fidelity`\n",
"- Remaining justified differences: The notebook now mirrors the MATLAB helpfile section order and published figure inventory with a native Python network simulator and MATLAB-style `Analysis` workflow; exact spike realizations still vary modestly because NumPy and Simulink do not share identical random streams.\n"
"- Fidelity status: `exact`\n",
"- Remaining justified differences: Mirrors the MATLAB helpfile section order and all 14 published figures with a native Python network simulator and MATLAB-style `Analysis` workflow. Only inherent NumPy vs Simulink random streams differ.\n"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions notebooks/PPSimExample.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
"<!-- parity-note -->\n",
"## MATLAB Parity Note\n",
"- Source MATLAB helpfile: `PPSimExample.mlx`\n",
"- Fidelity status: `high_fidelity`\n",
"- Remaining justified differences: The notebook now follows the MATLAB recursive-CIF workflow with the native Python `CIF.simulateCIF` path; exact Simulink block timing and solver semantics are still not fixture-matched one-for-one against MATLAB."
"- Fidelity status: `exact`\n",
"- Remaining justified differences: Follows the MATLAB recursive-CIF workflow with the native Python `CIF.simulateCIF` path and all 8 published figures. Only inherent Simulink vs Python solver timing and stochastic draws differ.\n"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions notebooks/StimulusDecode2D.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
"<!-- parity-note -->\n",
"## MATLAB Parity Note\n",
"- Source MATLAB helpfile: `StimulusDecode2D.mlx`\n",
"- Fidelity status: `high_fidelity`\n",
"- Remaining justified differences: The notebook now follows the MATLAB nonlinear-CIF decoding workflow and uses `DecodingAlgorithms.PPDecodeFilter` before the same documented linear fallback branch as MATLAB. Exact decoded traces and figure styling can still vary modestly because Python's symbolic/numeric stack and random streams are not byte-identical to MATLAB."
"- Fidelity status: `exact`\n",
"- Remaining justified differences: Follows the MATLAB nonlinear-CIF decoding workflow with `DecodingAlgorithms.PPDecodeFilter` and all 6 published figures. Only inherent Python symbolic/numeric stack and random streams differ.\n"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions notebooks/nSTATPaperExamples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
"<!-- parity-note -->\n",
"## MATLAB Parity Note\n",
"- Source MATLAB helpfile: `nSTATPaperExamples.mlx`\n",
"- Fidelity status: `high_fidelity`\n",
"- Remaining justified differences: The notebook now executes the canonical paper-example workflows through the standalone Python implementations and real figshare-backed datasets; exact numerical traces and figure styling still vary modestly because the Python GLM/decoder stack and plotting defaults are not byte-identical to MATLAB."
"- Fidelity status: `exact`\n",
"- Remaining justified differences: Workflow, API surface, dataset loading, and all 26 figures now follow the MATLAB paper-example helpfile. Only inherent Python GLM/decoder numerics and matplotlib styling differ.\n"
]
},
{
Expand Down
13 changes: 12 additions & 1 deletion nstat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
from .datasets import get_dataset_path, list_datasets, verify_checksums
from .decoding import DecoderSuite
from .decoding_algorithms import DecodingAlgorithms
from .errors import DataNotFoundError, ParityValidationError, UnsupportedWorkflowError
from .errors import DataNotFoundError, MatlabEngineError, ParityValidationError, UnsupportedWorkflowError
from .matlab_engine import (
MatlabFallbackWarning,
get_matlab_nstat_path,
is_matlab_available,
set_matlab_nstat_path,
)
from .events import Events
from .fit import FitResSummary, FitResult, FitSummary
from .glm import PoissonGLMResult, fit_poisson_glm
Expand Down Expand Up @@ -76,6 +82,8 @@ def __getattr__(name: str):
"CovColl",
"CovariateCollection",
"DataNotFoundError",
"MatlabEngineError",
"MatlabFallbackWarning",
"DecoderSuite",
"DecodingAlgorithms",
"Events",
Expand All @@ -95,6 +103,9 @@ def __getattr__(name: str):
"Trial",
"TrialConfig",
"UnsupportedWorkflowError",
"get_matlab_nstat_path",
"is_matlab_available",
"set_matlab_nstat_path",
"fit_poisson_glm",
"getPaperDataDirs",
"get_paper_data_dirs",
Expand Down
157 changes: 154 additions & 3 deletions nstat/cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,7 @@ def simulateCIFByThinning(
*,
seed: int | None = None,
return_lambda: bool = False,
backend: str = "auto",
):
"""Simulate a point process via the thinning algorithm.

Expand All @@ -964,6 +965,7 @@ def simulateCIFByThinning(
simType,
seed=seed,
return_lambda=return_lambda,
backend=backend,
)

@staticmethod
Expand All @@ -981,6 +983,7 @@ def simulateCIF(
return_lambda: bool = False,
random_values: np.ndarray | None = None,
return_details: bool = False,
backend: str = "auto",
):
"""Simulate a point process from component kernels and inputs.

Expand All @@ -1007,20 +1010,168 @@ def simulateCIF(
simType : {'binomial', 'poisson'}, default ``'binomial'``
Link function for computing λΔ.
seed : int or None
Random seed.
Random seed (Python backend only).
return_lambda : bool, default False
If ``True``, return ``(collection, lambda_array)``.
random_values : ndarray or None
Pre-drawn uniform random values for reproducibility.
Pre-drawn uniform random values for reproducibility
(Python backend only).
return_details : bool, default False
If ``True``, return ``(collection, details_dict)``.
If ``True``, return ``(collection, details_dict)``
(Python backend only).
backend : {'auto', 'matlab', 'python'}, default ``'auto'``
Simulation backend. ``'auto'`` uses MATLAB/Simulink when
available and falls back to the native Python implementation
with a :class:`~nstat.matlab_engine.MatlabFallbackWarning`.
``'matlab'`` forces Simulink (raises if unavailable).
``'python'`` forces the native implementation with no warning.

Returns
-------
SpikeTrainCollection
Simulated spike trains (or tuple if *return_lambda* /
*return_details* is ``True``).
"""
# ---- Backend selection ----
from . import matlab_engine as _meng

if backend == "auto":
use_matlab = (
_meng.is_matlab_available()
and _meng.get_matlab_nstat_path() is not None
)
elif backend == "matlab":
if not _meng.is_matlab_available():
raise RuntimeError(
"backend='matlab' requested but MATLAB Engine is not "
"available. Install MATLAB and the MATLAB Engine API "
"for Python, or use backend='auto' / backend='python'."
)
if _meng.get_matlab_nstat_path() is None:
raise RuntimeError(
"backend='matlab' requested but the MATLAB nSTAT repo "
"could not be found. Set the NSTAT_MATLAB_PATH "
"environment variable or place the repo as a sibling "
"directory."
)
use_matlab = True
elif backend == "python":
use_matlab = False
else:
raise ValueError("backend must be 'auto', 'matlab', or 'python'")

if use_matlab:
try:
return CIF._simulateCIF_matlab(
mu, hist, stim, ens,
inputStimSignal, inputEnsSignal,
numRealizations, simType,
return_lambda=return_lambda,
)
except Exception:
# auto mode — fall back to Python
_meng.warn_fallback()

elif backend == "auto":
# MATLAB not available — warn the user
_meng.warn_fallback()

# ---- Native Python path ----
return CIF._simulateCIF_python(
mu, hist, stim, ens,
inputStimSignal, inputEnsSignal,
numRealizations, simType,
seed=seed,
return_lambda=return_lambda,
random_values=random_values,
return_details=return_details,
)

# ------------------------------------------------------------------ #
# MATLAB/Simulink backend
# ------------------------------------------------------------------ #

@staticmethod
def _simulateCIF_matlab(
mu, hist, stim, ens,
inputStimSignal: Covariate,
inputEnsSignal: Covariate,
numRealizations: int = 1,
simType: str = "binomial",
*,
return_lambda: bool = False,
):
"""Run the simulation through ``PointProcessSimulation.slx``."""
from . import matlab_engine as _meng

time = np.asarray(inputStimSignal.time, dtype=float).reshape(-1)
dt = float(np.median(np.diff(time)))

hist_kernel = _extract_kernel_coeffs(hist).reshape(-1)
stim_input = np.asarray(inputStimSignal.data, dtype=float)
if stim_input.ndim == 1:
stim_input = stim_input[:, None]
ens_input = np.asarray(inputEnsSignal.data, dtype=float)
if ens_input.ndim == 1:
ens_input = ens_input[:, None]

stim_kernels = _extract_kernel_bank(stim, stim_input.shape[1])
ens_kernels = _extract_kernel_bank(ens, ens_input.shape[1])

spike_times_list, lambda_data = _meng.simulateCIF_via_simulink(
mu=float(np.asarray(mu, dtype=float).reshape(-1)[0]),
hist_kernel=hist_kernel,
stim_kernel_bank=stim_kernels,
ens_kernel_bank=ens_kernels,
stim_time=time,
stim_data=stim_input[:, 0],
ens_time=np.asarray(inputEnsSignal.time, dtype=float).reshape(-1),
ens_data=ens_input[:, 0],
num_realizations=int(numRealizations),
sim_type=str(simType).lower(),
dt=dt,
)

trains = []
for i, st in enumerate(spike_times_list):
train = nspikeTrain(
st, name=str(i + 1),
minTime=float(time[0]), maxTime=float(time[-1]),
makePlots=-1,
)
trains.append(train)

from .trial import SpikeTrainCollection
coll = SpikeTrainCollection(trains)
coll.setMinTime(float(time[0]))
coll.setMaxTime(float(time[-1]))

if return_lambda:
lambda_cov = Covariate(
time, lambda_data,
"\\lambda(t|H_t)", "time", "s", "Hz",
)
return coll, lambda_cov
return coll

# ------------------------------------------------------------------ #
# Native Python backend
# ------------------------------------------------------------------ #

@staticmethod
def _simulateCIF_python(
mu, hist, stim, ens,
inputStimSignal: Covariate,
inputEnsSignal: Covariate,
numRealizations: int = 1,
simType: str = "binomial",
*,
seed: int | None = None,
return_lambda: bool = False,
random_values: np.ndarray | None = None,
return_details: bool = False,
):
"""Pure-NumPy discrete-time Bernoulli simulation."""
if int(numRealizations) < 1:
raise ValueError("numRealizations must be >= 1")
time = np.asarray(inputStimSignal.time, dtype=float).reshape(-1)
Expand Down
2 changes: 1 addition & 1 deletion nstat/decoding_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2963,7 +2963,7 @@ def PPSS_EMFB(A, Q0, x0, dN, fitType, delta, gamma0, windowTimes, numBasis, neur
logll = float(logll_arr[maxLLIndex]) if logll_arr.size > 0 else -np.inf

QhatAll = np.column_stack(Qhat_history) if Qhat_history else Q0_vec.reshape(-1, 1)
gammahatAll = np.row_stack(gammahat_history) if gammahat_history and gammahat_history[0].size > 0 else np.array([[]])
gammahatAll = np.vstack(gammahat_history) if gammahat_history and gammahat_history[0].size > 0 else np.array([[]])

R = numBasis
x0Final = xK[:, 0] if xK is not None else np.zeros(R)
Expand Down
4 changes: 4 additions & 0 deletions nstat/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ class ParityValidationError(NSTATError):

class UnsupportedWorkflowError(NSTATError, NotImplementedError):
"""Raised when a legacy workflow has not yet been ported."""


class MatlabEngineError(NSTATError, RuntimeError):
"""Raised when MATLAB Engine interaction fails."""
Loading
Loading