From 273eaea00364c55d09da83503e278a078284f03d Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Mon, 9 Mar 2026 23:17:52 -0400 Subject: [PATCH 1/6] Add SignalObj spectral and utility methods matching Matlab API Port missing SignalObj methods from Matlab nSTAT toolbox: - Time manipulation: shift, shiftMe, alignTime - Arithmetic: power, sqrt - Cross-covariance: xcov (mean-removed xcorr) - Spectral: periodogram, MTMspectrum (DPSS multi-taper), spectrogram Co-Authored-By: Claude Opus 4.6 --- nstat/core.py | 195 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 195 insertions(+) diff --git a/nstat/core.py b/nstat/core.py index 1c2381a7..201f88bd 100644 --- a/nstat/core.py +++ b/nstat/core.py @@ -938,6 +938,201 @@ def xcorr(self, other: "SignalObj" | None = None, maxlag: int | None = None) -> data_labels, ) + # ------------------------------------------------------------------ + # Time-shift / alignment helpers (match Matlab SignalObj) + # ------------------------------------------------------------------ + def shift(self, deltaT: float, updateLabels: bool = False) -> "SignalObj": + """Return a copy with time shifted by *deltaT* seconds.""" + new_time = self.time + float(deltaT) + out = self.__class__( + new_time, + self.data.copy(), + self.name, + self.xlabelval, + self.xunits, + self.yunits, + list(self.dataLabels), + list(self.plotProps), + ) + if updateLabels: + out.name = f"{self.name} shifted by {deltaT}" + return out + + def shiftMe(self, deltaT: float, updateLabels: bool = False) -> None: + """In-place time shift by *deltaT* seconds (Matlab ``shiftMe``).""" + self.time = self.time + float(deltaT) + self.minTime = float(np.min(self.time)) if self.time.size else 0.0 + self.maxTime = float(np.max(self.time)) if self.time.size else 0.0 + if updateLabels: + self.name = f"{self.name} shifted by {deltaT}" + + def alignTime(self, timeMarker: float, newTime: float = 0.0) -> None: + """Shift so that *timeMarker* becomes *newTime* (Matlab ``alignTime``).""" + self.shiftMe(float(newTime) - float(timeMarker)) + + # ------------------------------------------------------------------ + # Element-wise arithmetic helpers (match Matlab SignalObj) + # ------------------------------------------------------------------ + def power(self, exponent: float) -> "SignalObj": + """Element-wise power ``data ** exponent`` (Matlab ``power``).""" + return self.__class__( + self.time.copy(), + np.power(self.data, float(exponent)), + f"{self.name}^{exponent}", + self.xlabelval, + self.xunits, + self.yunits, + list(self.dataLabels), + list(self.plotProps), + ) + + def sqrt(self) -> "SignalObj": + """Element-wise square root (Matlab ``sqrt``).""" + return self.power(0.5) + + # ------------------------------------------------------------------ + # Cross-covariance (match Matlab SignalObj.xcov) + # ------------------------------------------------------------------ + def xcov(self, other: "SignalObj | None" = None, maxlag: int | None = None, + scaleOpt: str = "biased") -> "SignalObj": + """Cross-covariance (mean-removed xcorr). Matches Matlab ``xcov``.""" + s1 = self + s2 = self if other is None else other + s1c, s2c = s1.makeCompatible(s2) + + data_columns: list[np.ndarray] = [] + data_labels: list[str] = [] + lag_index: np.ndarray | None = None + + for li in range(s1c.dimension): + for ri in range(s2c.dimension): + x = s1c.data[:, li] - np.mean(s1c.data[:, li]) + y = s2c.data[:, ri] - np.mean(s2c.data[:, ri]) + corr = np.correlate(x, y, mode="full") + N = len(x) + lags = np.arange(-N + 1, N, dtype=int) + + # scale + if scaleOpt == "biased": + corr = corr / N + elif scaleOpt == "unbiased": + denom = N - np.abs(lags) + denom[denom <= 0] = 1 + corr = corr / denom + elif scaleOpt == "coeff": + corr = corr / corr[N - 1] if corr[N - 1] != 0 else corr + + if maxlag is not None: + keep = np.abs(lags) <= int(maxlag) + corr = corr[keep] + lags = lags[keep] + + if lag_index is None: + lag_index = lags.astype(float) / max(float(s1c.sampleRate), 1e-12) + data_columns.append(np.asarray(corr, dtype=float)) + ll = s1c.dataLabels[li] if li < len(s1c.dataLabels) else str(li + 1) + rl = s2c.dataLabels[ri] if ri < len(s2c.dataLabels) else str(ri + 1) + data_labels.append(f"xcov({ll},{rl})") + + data = np.column_stack(data_columns) if data_columns else np.zeros((0, 0), dtype=float) + return self.__class__( + lag_index if lag_index is not None else np.array([], dtype=float), + data, + f"xcov({self.name},{s2.name})", + "\\Delta \\tau", + self.xunits, + f"{self.yunits}^2" if self.yunits else "", + data_labels, + ) + + # ------------------------------------------------------------------ + # Spectral methods (match Matlab SignalObj) + # ------------------------------------------------------------------ + def periodogram(self, NFFT: int | None = None) -> tuple[np.ndarray, np.ndarray]: + """Power spectral density via periodogram (Matlab ``periodogram``). + + Returns ``(frequencies, psd)`` arrays. + """ + from scipy.signal import periodogram as _periodogram + + fs = float(self.sampleRate) + x = self.data[:, 0] if self.data.ndim == 2 else self.data + f, Pxx = _periodogram(x, fs=fs, nfft=NFFT, window="boxcar", + scaling="density") + return f, Pxx + + def MTMspectrum(self, NW: float = 4.0, Kmax: int | None = None, + NFFT: int | None = None) -> tuple[np.ndarray, np.ndarray]: + """Multi-taper spectral estimate (Matlab ``MTMspectrum``). + + Uses discrete prolate spheroidal sequences (DPSS / Slepian tapers). + + Parameters + ---------- + NW : float + Time-bandwidth product (default 4). + Kmax : int, optional + Number of tapers (default ``2*NW - 1``). + NFFT : int, optional + FFT length (default next power of 2 >= N). + + Returns + ------- + frequencies : ndarray + psd : ndarray + """ + from scipy.signal.windows import dpss + + x = self.data[:, 0] if self.data.ndim == 2 else self.data + N = len(x) + fs = float(self.sampleRate) + if Kmax is None: + Kmax = int(2 * NW - 1) + if NFFT is None: + NFFT = int(2 ** np.ceil(np.log2(N))) + + tapers, eigenvalues = dpss(N, NW, Kmax, return_ratios=True) + # tapers shape: (Kmax, N) + # Compute tapered FFTs + Sk = np.zeros((Kmax, NFFT // 2 + 1)) + for k in range(Kmax): + xw = x * tapers[k] + Xf = np.fft.rfft(xw, n=NFFT) + Sk[k] = np.abs(Xf) ** 2 + + # Weighted average by eigenvalues + weights = eigenvalues / eigenvalues.sum() + psd = np.dot(weights, Sk) * (2.0 / fs) + # DC and Nyquist don't get doubled + psd[0] /= 2.0 + if NFFT % 2 == 0: + psd[-1] /= 2.0 + + frequencies = np.fft.rfftfreq(NFFT, d=1.0 / fs) + return frequencies, psd + + def spectrogram(self, nperseg: int = 256, noverlap: int | None = None, + NFFT: int | None = None, + window: str = "hann") -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Short-time Fourier transform spectrogram (Matlab ``spectrogram``). + + Returns ``(frequencies, times, Sxx)``. + """ + from scipy.signal import spectrogram as _spectrogram + + x = self.data[:, 0] if self.data.ndim == 2 else self.data + fs = float(self.sampleRate) + if noverlap is None: + noverlap = nperseg // 2 + if NFFT is None: + NFFT = nperseg + f, t, Sxx = _spectrogram(x, fs=fs, window=window, + nperseg=nperseg, noverlap=noverlap, + nfft=NFFT) + # offset times to match signal start + t = t + self.minTime + return f, t, Sxx + def setConfInterval(self, bounds: tuple[np.ndarray, np.ndarray]) -> None: low, high = bounds low_arr = np.asarray(low, dtype=float) From 16dcc77511e85dc5860005c1b7e127b4c77385f1 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Mon, 9 Mar 2026 23:18:04 -0400 Subject: [PATCH 2/6] Add FitResSummary plotting methods for cross-neuron visualization Port missing FitResSummary methods from Matlab nSTAT toolbox: - plotAllCoeffs: errorbar plot of GLM coefficients across neurons/fits - plot3dCoeffSummary: 3D bar plot of binned significant coefficients - plot2dCoeffSummary: stacked ridge-plot of coefficient distributions - plotKSSummary: subplot grid of KS goodness-of-fit plots per neuron Co-Authored-By: Claude Opus 4.6 --- nstat/fit.py | 134 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) diff --git a/nstat/fit.py b/nstat/fit.py index 51b0ba4b..5c306290 100644 --- a/nstat/fit.py +++ b/nstat/fit.py @@ -1429,6 +1429,140 @@ def boxPlot(self, X, diffIndex: int = 1, h=None, dataLabels=None, **kwargs): ax.boxplot(values, labels=labels) return ax + # ------------------------------------------------------------------ + # Coefficient plotting (match Matlab FitResSummary) + # ------------------------------------------------------------------ + def plotAllCoeffs(self, fitNum: int | list[int] | None = None, + plotSignificance: bool = True, + subIndex: list[int] | None = None, + handle=None): + """Errorbar plot of GLM coefficients across neurons (Matlab ``plotAllCoeffs``).""" + if fitNum is None: + fitNum = list(range(1, self.numResults + 1)) + if isinstance(fitNum, int): + fitNum = [fitNum] + ax = handle if handle is not None else plt.subplots(1, 1, figsize=(10, 5))[1] + + coeff_mat, labels, se_mat = self.getCoeffs(fitNum[0]) + if subIndex is not None: + labels = [labels[i] for i in subIndex] + coeff_mat = coeff_mat[:, subIndex] + se_mat = se_mat[:, subIndex] + + x = np.arange(1, len(labels) + 1) + for n_idx in range(self.numNeurons): + ax.errorbar(x, coeff_mat[n_idx], yerr=se_mat[n_idx], fmt=".", + alpha=0.7, capsize=2) + + ax.set_xticks(x) + ax.set_xticklabels(labels, rotation=90, fontsize=8) + ax.set_ylabel("Fit Coefficients") + ax.grid(True, alpha=0.3) + ax.axhline(0, color="0.5", linewidth=0.5, linestyle="--") + return ax + + def plot3dCoeffSummary(self, handle=None): + """3D ribbon plot of binned coefficient distributions (Matlab ``plot3dCoeffSummary``).""" + from mpl_toolkits.mplot3d import Axes3D # noqa: F401 + + N, edges, _percentSig = self.binCoeffs(-12, 12, 0.1) + labels = self.uniqueCovLabels + fig = plt.figure(figsize=(10, 7)) if handle is None else handle + if hasattr(fig, "add_subplot"): + ax = fig.add_subplot(111, projection="3d") + else: + ax = fig + + for i in range(N.shape[1] if N.ndim == 2 else 0): + xs = edges[:-1] + ys = np.full_like(xs, i) + zs = N[:len(xs), i] if N.shape[0] > len(xs) else N[:, i] + ax.bar(xs, zs, zs=i, zdir="y", alpha=0.6, width=(edges[1] - edges[0])) + + ax.set_xlabel("Coefficient value") + ax.set_ylabel("Covariate index") + ax.set_zlabel("Density") + ax.set_yticks(range(len(labels))) + ax.set_yticklabels(labels, fontsize=6) + return ax + + def plot2dCoeffSummary(self, handle=None): + """Stacked line plot of binned coefficient distributions (Matlab ``plot2dCoeffSummary``).""" + ax = handle if handle is not None else plt.subplots(1, 1, figsize=(8, 6))[1] + N, edges, percentSig = self.binCoeffs(-12, 12, 0.1) + labels = self.uniqueCovLabels + num_coeffs = N.shape[1] if N.ndim == 2 else 0 + + for i in range(num_coeffs): + offset = i + 1 + vals = N[:len(edges), i] if N.shape[0] >= len(edges) else N[:, i] + ax.plot(edges[:len(vals)], vals + offset) + + ax.set_yticks(range(1, num_coeffs + 1)) + ax.set_yticklabels(labels[:num_coeffs], fontsize=6) + # Annotate significance percentages + for i in range(num_coeffs): + if i < len(percentSig): + pct = float(percentSig) if np.isscalar(percentSig) else float(percentSig[i]) if hasattr(percentSig, "__getitem__") else 0.0 + ax.annotate(f"{pct*100:.0f}%sig", xy=(0.98, (i + 1)), + xycoords=("axes fraction", "data"), + fontsize=6, ha="right") + return ax + + def plotKSSummary(self, neurons: list[int] | None = None, handle=None): + """Subplot grid of KS plots per neuron (Matlab ``plotKSSummary``).""" + if neurons is None: + neurons = list(range(self.numNeurons)) + n = len(neurons) + if n <= 1: + nrows, ncols = 1, 1 + elif n <= 2: + nrows, ncols = 1, 2 + elif n <= 4: + nrows, ncols = 2, 2 + elif n <= 8: + nrows, ncols = 2, 4 + elif n <= 12: + nrows, ncols = 3, 4 + elif n <= 16: + nrows, ncols = 4, 4 + elif n <= 20: + nrows, ncols = 5, 4 + elif n <= 24: + nrows, ncols = 6, 4 + elif n <= 40: + nrows, ncols = 10, 4 + else: + nrows, ncols = 10, 10 + + fig = handle if handle is not None else plt.figure(figsize=(3 * ncols, 2.5 * nrows)) + if hasattr(fig, "subplots"): + fig.clear() + axes = fig.subplots(nrows, ncols, squeeze=False) + else: + return fig + + for cnt, neuron_idx in enumerate(neurons): + row, col = divmod(cnt, ncols) + ax = axes[row][col] + fit = self.fitResCell[neuron_idx] + fit.KSPlot(handle=ax) + ax.set_title(f"N{neuron_idx + 1}", fontsize=8) + if cnt < n - 1: + ax.get_legend().set_visible(False) if ax.get_legend() else None + ax.set_xlabel("") + ax.set_ylabel("") + ax.set_xticks([0, 1]) + ax.set_yticks([0, 1]) + + # Hide unused subplots + for idx in range(n, nrows * ncols): + row, col = divmod(idx, ncols) + axes[row][col].set_visible(False) + + fig.tight_layout() + return fig + def toStructure(self) -> dict[str, Any]: return { "fitResCell": FitResult.CellArrayToStructure(self.fitResCell), From 64e3471f457e69b245a0f17158cc4fcd02f19279 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Mon, 9 Mar 2026 23:18:18 -0400 Subject: [PATCH 3/6] Add SSGLM EM algorithm, UKF, and nstColl.ssglm/ssglmFB methods Major additions ported from Matlab nSTAT toolbox: SSGLM (State-Space GLM) EM algorithm in decoding_algorithms.py: - PPSS_EStep/PPSS_MStep: E-step (Kalman smoothing) and M-step (Newton) - PPSS_EM/PPSS_EMFB: Full EM with forward-backward smoothing - estimateInfoMat: Fisher information for coefficient CIs - prepareEMResults: Package EM output as FitResult objects - _ComputeStimulusCIs_MC: Monte Carlo stimulus confidence intervals Unscented Kalman Filter (UKF) in decoding_algorithms.py: - ukf_sigmas: Sigma point generation via Cholesky decomposition - ukf_ut: Unscented transformation - ukf: Full UKF one-step update SpikeTrainCollection methods in trial.py: - ssglm(): Run SSGLM EM on spike train collection - ssglmFB(): Run SSGLM EM with forward-backward smoothing Co-Authored-By: Claude Opus 4.6 --- nstat/decoding_algorithms.py | 1150 ++++++++++++++++++++++++++++++++++ nstat/trial.py | 183 ++++++ 2 files changed, 1333 insertions(+) diff --git a/nstat/decoding_algorithms.py b/nstat/decoding_algorithms.py index 3433391a..021a6137 100644 --- a/nstat/decoding_algorithms.py +++ b/nstat/decoding_algorithms.py @@ -1069,6 +1069,1142 @@ def PPHybridFilter(A, Q, p_ij, Mu0, dN, lambdaCIFColl, binwidth=0.001, x0=None, MinClassificationError, ) + # ------------------------------------------------------------------ + # Unscented Kalman Filter (UKF) + # Ported from Matlab DecodingAlgorithms.m + # ------------------------------------------------------------------ + + @staticmethod + def ukf_sigmas(x: np.ndarray, P: np.ndarray, c: float) -> np.ndarray: + """Generate sigma points around reference point *x*. + + Parameters + ---------- + x : (L,) state vector + P : (L, L) covariance + c : scaling coefficient + + Returns + ------- + X : (L, 2L+1) sigma-point matrix + """ + x = np.asarray(x, dtype=float).reshape(-1) + P = np.asarray(P, dtype=float) + A = c * np.linalg.cholesky(P) # (L, L) + L = len(x) + Y = np.tile(x[:, None], (1, L)) + X = np.column_stack([x[:, None], Y + A, Y - A]) + return X + + @staticmethod + def ukf_ut(f, X: np.ndarray, Wm: np.ndarray, Wc: np.ndarray, + n: int, R: np.ndarray): + """Unscented transformation. + + Parameters + ---------- + f : callable mapping (L,) -> (n,) + X : (L, 2L+1) sigma points + Wm, Wc : (2L+1,) weights + n : output dimensionality + R : (n, n) additive covariance + + Returns + ------- + y : (n,) transformed mean + Y : (n, 2L+1) transformed sigma points + P : (n, n) transformed covariance + Y1 : (n, 2L+1) deviations + """ + Lpts = X.shape[1] + y = np.zeros(n) + Y = np.zeros((n, Lpts)) + for k in range(Lpts): + Y[:, k] = np.asarray(f(X[:, k]), dtype=float).reshape(-1)[:n] + y += Wm[k] * Y[:, k] + Y1 = Y - y[:, None] + P = Y1 @ np.diag(Wc) @ Y1.T + np.asarray(R, dtype=float).reshape(n, n) + return y, Y, P, Y1 + + @staticmethod + def ukf(fstate, x: np.ndarray, P: np.ndarray, hmeas, + z: np.ndarray, Q: np.ndarray, R: np.ndarray): + """Unscented Kalman Filter for nonlinear systems. + + One-step update matching Matlab ``DecodingAlgorithms.ukf``. + + System model (additive noise):: + + x_{k+1} = fstate(x_k) + w_k, w ~ N(0, Q) + z_k = hmeas(x_k) + v_k, v ~ N(0, R) + + Parameters + ---------- + fstate : callable (L,) -> (L,) + x : (L,) prior state estimate + P : (L, L) prior covariance + hmeas : callable (L,) -> (m,) + z : (m,) measurement + Q : (L, L) process noise covariance + R : (m, m) measurement noise covariance + + Returns + ------- + x : (L,) posterior state estimate + P : (L, L) posterior covariance + """ + x = np.asarray(x, dtype=float).reshape(-1) + z = np.asarray(z, dtype=float).reshape(-1) + P = np.asarray(P, dtype=float) + Q = np.asarray(Q, dtype=float) + R = np.asarray(R, dtype=float) + if R.ndim == 0: + R = R.reshape(1, 1) + elif R.ndim == 1: + R = np.diag(R) + + L = len(x) + m = len(z) + alpha = 1e-3 + ki = 0.0 + beta = 2.0 + lam = alpha ** 2 * (L + ki) - L + c = L + lam + Wm = np.full(2 * L + 1, 0.5 / c) + Wm[0] = lam / c + Wc = Wm.copy() + Wc[0] += (1 - alpha ** 2 + beta) + c_sqrt = np.sqrt(c) + + X = DecodingAlgorithms.ukf_sigmas(x, P, c_sqrt) + x1, X1, P1, X2 = DecodingAlgorithms.ukf_ut(fstate, X, Wm, Wc, L, Q) + z1, Z1, P2, Z2 = DecodingAlgorithms.ukf_ut(hmeas, X1, Wm, Wc, m, R) + P12 = X2 @ np.diag(Wc) @ Z2.T + K = P12 @ np.linalg.inv(P2) + x_out = x1 + K @ (z - z1) + P_out = P1 - K @ P12.T + return x_out, P_out + + # ------------------------------------------------------------------ + # State-Space GLM (SSGLM) via EM Forward-Backward + # Ported from Matlab DecodingAlgorithms.m (PPSS_EMFB and helpers) + # ------------------------------------------------------------------ + + @staticmethod + def _ssglm_build_basis(numBasis, minTime, maxTime, delta): + """Build unit-impulse basis matrix for SSGLM.""" + from .trial import SpikeTrainCollection + + sampleRate = 1.0 / delta + basisWidth = (maxTime - minTime) / numBasis + basis_cov = SpikeTrainCollection.generateUnitImpulseBasis( + basisWidth, minTime, maxTime, sampleRate + ) + return np.asarray(basis_cov.data, dtype=float) + + @staticmethod + def _ssglm_build_history(dN, windowTimes, delta): + """Build history design matrices for each trial from spike observations.""" + from .history import History + + K, N = dN.shape + minTime = 0.0 + maxTime = (N - 1) * delta + + if windowTimes is not None and len(windowTimes) > 0: + histObj = History(windowTimes, minTime, maxTime) + HkAll = [] + for k in range(K): + spike_indices = np.where(dN[k, :] > 0.5)[0] + spike_times = spike_indices.astype(float) * delta + nst = nspikeTrain(spike_times, makePlots=-1) + nst.setMinTime(minTime) + nst.setMaxTime(maxTime) + hist_cov = histObj._compute_single_history(nst) + HkAll.append(np.asarray(hist_cov.data, dtype=float)) + return HkAll + else: + return [np.zeros((N, 0), dtype=float) for _ in range(K)] + + @staticmethod + def PPSS_EStep(A, Q, x0, dN, HkAll, fitType, delta, gamma, numBasis): + """E-step: Forward Kalman filter + backward RTS smoother + cross-covariance. + + Parameters + ---------- + A : (R, R) state transition matrix + Q : (R,) or (R, R) state noise covariance (diagonal vector or matrix) + x0 : (R,) initial state + dN : (K, N) binary spike observations (K trials, N time bins) + HkAll : list of K arrays, each (N, J) history design matrices + fitType : 'poisson' or 'binomial' + delta : time bin width + gamma : (J,) history coefficients + numBasis : number of basis functions R + + Returns + ------- + x_K, W_K, Wku, logll, sumXkTerms, sumPPll + """ + K, N = dN.shape + minTime = 0.0 + maxTime = (N - 1) * delta + + basisMat = DecodingAlgorithms._ssglm_build_basis(numBasis, minTime, maxTime, delta) + # Ensure basisMat has N rows matching dN columns + if basisMat.shape[0] != N: + basisMat = basisMat[:N, :] if basisMat.shape[0] > N else np.vstack( + [basisMat, np.zeros((N - basisMat.shape[0], basisMat.shape[1]))] + ) + + Q_diag = np.asarray(Q, dtype=float).reshape(-1) + if Q_diag.size == numBasis * numBasis: + Q_mat = Q_diag.reshape(numBasis, numBasis) + Q_diag = np.diag(Q_mat) + Q_mat = np.diag(Q_diag) + + A_mat = np.asarray(A, dtype=float) + if A_mat.ndim < 2: + A_mat = np.eye(numBasis, dtype=float) * A_mat + x0_vec = np.asarray(x0, dtype=float).reshape(-1) + gamma_vec = np.asarray(gamma, dtype=float).reshape(-1) + R = numBasis + fitType = str(fitType).lower() + + # Forward Kalman filter + x_p = np.zeros((R, K), dtype=float) + x_u = np.zeros((R, K), dtype=float) + W_p = np.zeros((R, R, K), dtype=float) + W_u = np.zeros((R, R, K), dtype=float) + + for k in range(K): + if k == 0: + x_p[:, k] = A_mat @ x0_vec + W_p[:, :, k] = Q_mat.copy() + else: + x_p[:, k] = A_mat @ x_u[:, k - 1] + W_p[:, :, k] = A_mat @ W_u[:, :, k - 1] @ A_mat.T + Q_mat + + Hk = HkAll[k] + stimK = basisMat @ x_p[:, k] + + if fitType == "poisson": + histEffect = np.exp(np.clip(Hk @ gamma_vec, -30, 30)) if gamma_vec.size > 0 and Hk.shape[1] > 0 else np.ones(N) + stimEffect = np.exp(np.clip(stimK, -30, 30)) + lambdaDelta = stimEffect * histEffect + + GradLogLD = basisMat # (N, R) + GradLD = basisMat * lambdaDelta[:, None] # (N, R) + + sumValVec = GradLogLD.T @ dN[k, :] - np.diag(GradLD.T @ basisMat) + sumValMat = GradLD.T @ basisMat + + elif fitType == "binomial": + Hk = HkAll[k] + stimK = basisMat @ x_p[:, k] + linpred = stimK + (Hk @ gamma_vec if gamma_vec.size > 0 and Hk.shape[1] > 0 else 0.0) + linpred = np.clip(linpred, -30, 30) + lambdaDelta = 1.0 / (1.0 + np.exp(-linpred)) + + GradLogLD = basisMat * (1.0 - lambdaDelta)[:, None] + JacobianLogLD = basisMat * (lambdaDelta * (-1.0 + lambdaDelta))[:, None] + GradLD = basisMat * (lambdaDelta * (1.0 - lambdaDelta))[:, None] + JacobianLD = basisMat * (lambdaDelta * (1.0 - lambdaDelta) * (1.0 - 2.0 * lambdaDelta ** 2))[:, None] + + sumValVec = GradLogLD.T @ dN[k, :] - np.diag(GradLD.T @ basisMat) + sumValMat = -np.diag(JacobianLogLD.T @ dN[k, :]) + JacobianLD.T @ basisMat + else: + raise ValueError(f"Unsupported fitType: {fitType}") + + # Kalman update + W_p_inv = np.linalg.inv(W_p[:, :, k] + 1e-12 * np.eye(R)) + invW_u = W_p_inv + sumValMat + W_u[:, :, k] = np.linalg.inv(invW_u + 1e-12 * np.eye(R)) + + # Ensure positive definiteness + eigvals, eigvecs = np.linalg.eigh(W_u[:, :, k]) + eigvals = np.maximum(eigvals, np.finfo(float).eps) + W_u[:, :, k] = eigvecs @ np.diag(eigvals) @ eigvecs.T + + x_u[:, k] = x_p[:, k] + W_u[:, :, k] @ sumValVec + + # Backward RTS smoother + x_K = np.zeros((R, K), dtype=float) + W_K = np.zeros((R, R, K), dtype=float) + Lk = np.zeros((R, R, K), dtype=float) + + x_K[:, K - 1] = x_u[:, K - 1] + W_K[:, :, K - 1] = W_u[:, :, K - 1] + + for k in range(K - 2, -1, -1): + Lk[:, :, k] = W_u[:, :, k] @ A_mat.T @ np.linalg.inv(W_p[:, :, k + 1] + 1e-12 * np.eye(R)) + x_K[:, k] = x_u[:, k] + Lk[:, :, k] @ (x_K[:, k + 1] - x_p[:, k + 1]) + W_K[:, :, k] = W_u[:, :, k] + Lk[:, :, k] @ (W_K[:, :, k + 1] - W_p[:, :, k + 1]) @ Lk[:, :, k].T + W_K[:, :, k] = 0.5 * (W_K[:, :, k] + W_K[:, :, k].T) + + # Cross-trial covariance Wku (R, R, K, K) + Wku = np.zeros((R, R, K, K), dtype=float) + for k in range(K): + Wku[:, :, k, k] = W_K[:, :, k] + + Dk = np.zeros((R, R, K), dtype=float) + for u in range(K - 1, 0, -1): + for k in range(u - 1, -1, -1): + Dk[:, :, k] = W_u[:, :, k] @ A_mat.T @ np.linalg.inv(W_p[:, :, k + 1] + 1e-12 * np.eye(R)) + Wku[:, :, k, u] = Dk[:, :, k] @ Wku[:, :, k + 1, u] + Wku[:, :, u, k] = Wku[:, :, k, u] + + # Sufficient statistics for M-step + Sxkxkp1 = np.zeros((R, R), dtype=float) + Sxkp1xkp1 = np.zeros((R, R), dtype=float) + Sxkxk = np.zeros((R, R), dtype=float) + for k in range(K - 1): + Sxkxkp1 += Wku[:, :, k, k + 1] + np.outer(x_K[:, k], x_K[:, k + 1]) + Sxkp1xkp1 += W_K[:, :, k + 1] + np.outer(x_K[:, k + 1], x_K[:, k + 1]) + Sxkxk += W_K[:, :, k] + np.outer(x_K[:, k], x_K[:, k]) + + sumXkTerms = ( + Sxkp1xkp1 - A_mat @ Sxkxkp1 - Sxkxkp1.T @ A_mat.T + A_mat @ Sxkxk @ A_mat.T + + W_K[:, :, 0] + np.outer(x_K[:, 0], x_K[:, 0]) + - A_mat @ np.outer(x0_vec, x_K[:, 0]) - np.outer(x_K[:, 0], x0_vec) @ A_mat.T + + A_mat @ np.outer(x0_vec, x0_vec) @ A_mat.T + ) + + # Point process log-likelihood + sumPPll = 0.0 + for k in range(K): + Hk = HkAll[k] + Wk = basisMat @ np.diag(W_K[:, :, k]) + stimK = basisMat @ x_K[:, k] + + if fitType == "poisson": + hist_term = Hk @ gamma_vec if gamma_vec.size > 0 and Hk.shape[1] > 0 else np.zeros(N) + histEffect = np.exp(np.clip(hist_term, -30, 30)) + stimK_clipped = np.clip(stimK, -30, 30) + stimEffect = np.exp(stimK_clipped) + np.exp(stimK_clipped) / 2.0 * Wk + ExplambdaDelta = stimEffect * histEffect + ExplogLD = stimK + hist_term + sumPPll += float(np.sum(dN[k, :] * ExplogLD - ExplambdaDelta)) + + elif fitType == "binomial": + hist_term = Hk @ gamma_vec if gamma_vec.size > 0 and Hk.shape[1] > 0 else np.zeros(N) + linpred = np.clip(stimK + hist_term, -30, 30) + lambdaDelta = 1.0 / (1.0 + np.exp(-linpred)) + ExplambdaDelta = lambdaDelta + Wk * (lambdaDelta * (1.0 - lambdaDelta) * (1.0 - 2.0 * lambdaDelta)) / 2.0 + ExplogLD = linpred - np.log(1.0 + np.exp(linpred)) - Wk * lambdaDelta * (1.0 - lambdaDelta) * 0.5 + sumPPll += float(np.sum(dN[k, :] * ExplogLD - ExplambdaDelta)) + + det_Q = float(np.prod(np.maximum(Q_diag, np.finfo(float).eps))) + logll = ( + -R * K * np.log(2.0 * np.pi) + - K / 2.0 * np.log(det_Q) + + sumPPll + - 0.5 * float(np.trace(np.linalg.pinv(Q_mat) @ sumXkTerms)) + ) + + return x_K, W_K, Wku, logll, sumXkTerms, sumPPll + + @staticmethod + def PPSS_MStep(dN, HkAll, fitType, x_K, W_K, gamma, delta, sumXkTerms, windowTimes): + """M-step: Update Q via closed form, gamma via Newton-Raphson. + + Parameters + ---------- + dN : (K, N) + HkAll : list of K arrays (N, J) + fitType : 'poisson' or 'binomial' + x_K : (R, K) smoothed states + W_K : (R, R, K) smoothed covariances + gamma : (J,) current history coefficients + delta : time bin width + sumXkTerms : (R, R) sufficient statistics from E-step + windowTimes : array of history window boundaries + + Returns + ------- + Qhat : (R,) updated state noise variance (diagonal) + gamma_new : (J,) updated history coefficients + """ + K, N = dN.shape + R = x_K.shape[0] + fitType = str(fitType).lower() + + # Q update (closed form) + sumQ = np.diag(np.diag(sumXkTerms)) + Qhat = sumQ / K + eigvals, eigvecs = np.linalg.eigh(Qhat) + eigvals = np.maximum(eigvals, 1e-8) + Qhat = eigvecs @ np.diag(eigvals) @ eigvecs.T + Qhat = np.diag(Qhat) # Return as vector + + # Build basis matrix for gamma update + minTime = 0.0 + maxTime = (N - 1) * delta + basisMat = DecodingAlgorithms._ssglm_build_basis(R, minTime, maxTime, delta) + if basisMat.shape[0] != N: + basisMat = basisMat[:N, :] if basisMat.shape[0] > N else np.vstack( + [basisMat, np.zeros((N - basisMat.shape[0], basisMat.shape[1]))] + ) + + gamma_vec = np.asarray(gamma, dtype=float).reshape(-1) + gamma_new = gamma_vec.copy() + J = gamma_new.size + + # Newton-Raphson for gamma (history coefficients) + if windowTimes is not None and len(windowTimes) > 0 and J > 0 and np.any(gamma_new != 0): + converged = False + max_iter = 300 + for iteration in range(max_iter): + gradQ = np.zeros(J, dtype=float) + jacQ = np.zeros((J, J), dtype=float) + + for k in range(K): + Hk = HkAll[k] + if Hk.shape[1] == 0: + continue + Wk = basisMat @ np.diag(W_K[:, :, k]) + stimK = basisMat @ x_K[:, k] + + if fitType == "poisson": + hist_term = np.clip(gamma_new @ Hk.T, -30, 30) + histEffect = np.exp(hist_term) + stimK_clipped = np.clip(stimK, -30, 30) + stimEffect = np.exp(stimK_clipped) + np.exp(stimK_clipped) / 2.0 * Wk + lambdaDelta = stimEffect * histEffect + + gradQ += Hk.T @ dN[k, :] - Hk.T @ lambdaDelta + jacQ -= (Hk * lambdaDelta[:, None]).T @ Hk + + elif fitType == "binomial": + linpred = np.clip(stimK + Hk @ gamma_new, -30, 30) + lambdaDelta = 1.0 / (1.0 + np.exp(-linpred)) + histEffect = np.exp(np.clip(gamma_new @ Hk.T, -30, 30)) + stimEffect = np.exp(np.clip(stimK, -30, 30)) + C = stimEffect * histEffect + M = np.where(C > 1e-30, 1.0 / C, 1e30) + ExpLambdaDelta = lambdaDelta + Wk * (lambdaDelta * (1.0 - lambdaDelta) * (1.0 - 2.0 * lambdaDelta)) / 2.0 + ExpLDSquaredTimesInvExp = lambdaDelta ** 2 * M + ExpLDCubedTimesInvExpSquared = ( + lambdaDelta ** 3 * M ** 2 + + Wk / 2.0 * (3.0 * M ** 4 * lambdaDelta ** 3 + + 12.0 * lambdaDelta ** 3 * M ** 3 + - 12.0 * M ** 4 * lambdaDelta ** 4) + ) + + gradQ += (Hk * (1.0 - ExpLambdaDelta)[:, None]).T @ dN[k, :] \ + - (Hk * (ExpLDSquaredTimesInvExp / np.maximum(lambdaDelta, 1e-30))[:, None]).T @ lambdaDelta + jacQ -= (Hk * (ExpLDSquaredTimesInvExp * dN[k, :])[:, None]).T @ Hk \ + + (Hk * ExpLDSquaredTimesInvExp[:, None]).T @ Hk \ + + (Hk * (2.0 * ExpLDCubedTimesInvExpSquared)[:, None]).T @ Hk + + # Newton-Raphson update + try: + gamma_temp = gamma_new - np.linalg.pinv(jacQ) @ gradQ + except np.linalg.LinAlgError: + gamma_temp = gamma_new + + if np.any(np.isnan(gamma_temp)): + gamma_temp = gamma_new + + mabsDiff = float(np.max(np.abs(gamma_temp - gamma_new))) + gamma_new = gamma_temp + if mabsDiff < 1e-2: + converged = True + break + + # Clamp gamma + gamma_new = np.clip(gamma_new, -1e2, 1e2) + + return Qhat, gamma_new + + @staticmethod + def PPSS_EM(A, Q0, x0, dN, fitType, delta, gamma0, windowTimes, numBasis, HkAll): + """Inner EM loop for state-space GLM. + + Parameters + ---------- + A : (R, R) state transition + Q0 : (R,) initial state noise variance + x0 : (R,) initial state + dN : (K, N) observations + fitType : 'poisson' or 'binomial' + delta : time bin width + gamma0 : (J,) initial history coefficients + windowTimes : history window boundaries + numBasis : number of basis functions + HkAll : precomputed history matrices + + Returns + ------- + xKFinal, WKFinal, WkuFinal, Qhat, gammahat, logll, QhatAll, gammahatAll, nIter, negLL + """ + if numBasis is None: + numBasis = 20 + if delta is None or delta == 0: + delta = 0.001 + fitType = str(fitType or "poisson").lower() + + Q0_vec = np.asarray(Q0, dtype=float).reshape(-1) + if Q0_vec.size == numBasis * numBasis: + Q0_vec = np.diag(Q0_vec.reshape(numBasis, numBasis)) + + gamma0_vec = np.asarray(gamma0, dtype=float).reshape(-1) if gamma0 is not None else np.array([], dtype=float) + + tolAbs = 1e-3 + tolRel = 1e-3 + llTol = 1e-3 + maxIter = 100 + numToKeep = 10 + + # Circular buffer storage + Qhat = np.zeros((Q0_vec.size, numToKeep), dtype=float) + Qhat[:, 0] = Q0_vec + gammahat = np.zeros((numToKeep, gamma0_vec.size), dtype=float) + gammahat[0, :] = gamma0_vec + + xK_buf = [None] * numToKeep + WK_buf = [None] * numToKeep + Wku_buf = [None] * numToKeep + + x0hat = np.asarray(x0, dtype=float).reshape(-1) + logll_list = [] + dLikelihood = [np.inf] + negLL = False + stoppingCriteria = False + cnt = 0 + + while not stoppingCriteria and cnt < maxIter: + si = cnt % numToKeep + si_p1 = (cnt + 1) % numToKeep + si_m1 = (cnt - 1) % numToKeep + + xK_cur, WK_cur, Wku_cur, ll, SumXkTerms, _ = DecodingAlgorithms.PPSS_EStep( + A, Qhat[:, si], x0hat, dN, HkAll, fitType, delta, gammahat[si, :], numBasis + ) + xK_buf[si] = xK_cur + WK_buf[si] = WK_cur + Wku_buf[si] = Wku_cur + logll_list.append(ll) + + Qnew, gnew = DecodingAlgorithms.PPSS_MStep( + dN, HkAll, fitType, xK_cur, WK_cur, gammahat[si, :], delta, SumXkTerms, windowTimes + ) + Qhat[:, si_p1] = Qnew + gammahat[si_p1, :] = gnew + + if cnt == 0: + dLikelihood.append(np.inf) + else: + dLikelihood.append(logll_list[cnt] - logll_list[cnt - 1]) + + # Check convergence + if cnt > 0: + dQvals = np.abs(np.sqrt(np.maximum(Qhat[:, si], 0)) - np.sqrt(np.maximum(Qhat[:, si_m1], 0))) + dGamma = np.abs(gammahat[si, :] - gammahat[si_m1, :]) + dMax = max(np.max(dQvals), np.max(dGamma)) if dGamma.size > 0 else float(np.max(dQvals)) + + Q_prev = np.sqrt(np.maximum(Qhat[:, si_m1], 1e-30)) + dQRel = float(np.max(np.abs(dQvals / Q_prev))) + if dGamma.size > 0: + g_prev = np.maximum(np.abs(gammahat[si_m1, :]), 1e-30) + dGammaRel = float(np.max(np.abs(dGamma / g_prev))) + dMaxRel = max(dQRel, dGammaRel) + else: + dMaxRel = dQRel + + if dMax < tolAbs and dMaxRel < tolRel: + stoppingCriteria = True + negLL = False + + if abs(dLikelihood[-1]) < llTol or dLikelihood[-1] < 0: + stoppingCriteria = True + negLL = True + + cnt += 1 + + # Select best iteration by log-likelihood + logll_arr = np.array(logll_list) + if logll_arr.size > 0: + maxLLIndex = int(np.argmax(logll_arr)) + else: + maxLLIndex = 0 + + maxLLIndMod = maxLLIndex % numToKeep + nIter = cnt + + xKFinal = xK_buf[maxLLIndMod] if xK_buf[maxLLIndMod] is not None else np.zeros((numBasis, dN.shape[0])) + WKFinal = WK_buf[maxLLIndMod] if WK_buf[maxLLIndMod] is not None else np.zeros((numBasis, numBasis, dN.shape[0])) + WkuFinal = Wku_buf[maxLLIndMod] if Wku_buf[maxLLIndMod] is not None else np.zeros((numBasis, numBasis, dN.shape[0], dN.shape[0])) + + QhatFinal = Qhat[:, maxLLIndMod] + gammahatFinal = gammahat[maxLLIndMod, :] + logllFinal = float(logll_arr[maxLLIndex]) if logll_arr.size > 0 else -np.inf + + QhatAll = Qhat[:, : min(cnt + 1, numToKeep)] + gammahatAll = gammahat[: min(cnt + 1, numToKeep), :] + + return xKFinal, WKFinal, WkuFinal, QhatFinal, gammahatFinal, logllFinal, QhatAll, gammahatAll, nIter, negLL + + @staticmethod + def PPSS_EMFB(A, Q0, x0, dN, fitType, delta, gamma0, windowTimes, numBasis, neuronName=None): + """EM Forward-Backward algorithm for state-space GLM. + + Wraps PPSS_EM in a forward-backward-forward cycle for improved convergence. + + Parameters + ---------- + A : (R, R) state transition matrix (typically identity for random walk) + Q0 : (R,) initial state noise variance + x0 : (R,) initial state coefficients + dN : (K, N) binary spike observations (K trials, N time bins) + fitType : 'poisson' or 'binomial' + delta : time bin width in seconds + gamma0 : (J,) initial history coefficients + windowTimes : array of history window boundaries + numBasis : number of basis functions R + neuronName : identifier for the neuron (for labeling) + + Returns + ------- + xKFinal : (R, K) estimated state trajectories + WKFinal : (R, R, K) estimated state covariances + WkuFinal : (R, R, K, K) cross-trial covariances + Qhat : (R,) estimated state noise variance + gammahat : (J,) estimated history coefficients + fitResults : FitResult object with goodness-of-fit diagnostics + stimulus : (R, K) estimated stimulus effect + stimCIs : (R, K, 2) stimulus confidence intervals + logll : float, log-likelihood at convergence + QhatAll : parameter history + gammahatAll : parameter history + nIter : total EM iterations + """ + K, N = dN.shape + fitType = str(fitType or "poisson").lower() + + Q0_vec = np.asarray(Q0, dtype=float).reshape(-1) + if Q0_vec.size == numBasis * numBasis: + Q0_vec = np.diag(Q0_vec.reshape(numBasis, numBasis)) + + gamma0_vec = np.asarray(gamma0, dtype=float).reshape(-1) if gamma0 is not None else np.array([], dtype=float) + + Qhat_cur = Q0_vec.copy() + gammahat_cur = gamma0_vec.copy() + xK0 = np.asarray(x0, dtype=float).reshape(-1) + + # Build history matrices + HkAll = DecodingAlgorithms._ssglm_build_history(dN, windowTimes, delta) + HkAllR = list(reversed(HkAll)) + + tolAbs = 1e-3 + tolRel = 1e-3 + llTol = 1e-3 + maxIter = 2000 + + Qhat_history = [Qhat_cur.copy()] + gammahat_history = [gammahat_cur.copy()] + logll_list = [] + stoppingCriteria = False + cnt = 0 + + xK = None + WK = None + Wku = None + + while not stoppingCriteria and cnt < maxIter: + # Forward EM + xK, WK, Wku, Qnew, gnew, ll, _, _, _, negLL = DecodingAlgorithms.PPSS_EM( + A, Qhat_cur, xK0, dN, fitType, delta, gammahat_cur, windowTimes, numBasis, HkAll + ) + + if not negLL: + # Backward EM + _, _, _, QnewR, gnewR, _, _, _, _, negLLR = DecodingAlgorithms.PPSS_EM( + A, Qnew, xK[:, -1], np.flipud(dN), fitType, delta, gnew, windowTimes, numBasis, HkAllR + ) + + if not negLLR: + # Forward EM again with backward-updated parameters + # Matlab: PPSS_EM(A, QhatR(:,cnt+1), xKR(:,end), dN, ...) + xK2, WK2, Wku2, Qnew2, gnew2, ll2, _, _, _, negLL2 = DecodingAlgorithms.PPSS_EM( + A, QnewR, xK[:, -1], dN, fitType, delta, gnewR, + windowTimes, numBasis, HkAll + ) + + if not negLL2: + xK = xK2 + WK = WK2 + Wku = Wku2 + Qnew = Qnew2 + gnew = gnew2 + ll = ll2 + + Qhat_cur = Qnew + gammahat_cur = gnew + Qhat_history.append(Qnew.copy()) + gammahat_history.append(gnew.copy()) + logll_list.append(ll) + + xK0 = xK[:, 0] + + # Check convergence + if cnt > 0: + dLikelihood = logll_list[cnt] - logll_list[cnt - 1] + else: + dLikelihood = np.inf + + if len(Qhat_history) >= 2: + Q_prev = Qhat_history[-2] + Q_cur = Qhat_history[-1] + dQvals = np.abs(np.sqrt(np.maximum(Q_cur, 0)) - np.sqrt(np.maximum(Q_prev, 0))) + g_prev = gammahat_history[-2] + g_cur = gammahat_history[-1] + dGamma = np.abs(g_cur - g_prev) if g_cur.size > 0 else np.array([0.0]) + + dMax = max(float(np.max(dQvals)), float(np.max(dGamma))) + + Q_denom = np.sqrt(np.maximum(Q_prev, 1e-30)) + dQRel = float(np.max(np.abs(dQvals / Q_denom))) + if g_prev.size > 0 and np.any(g_prev != 0): + g_denom = np.maximum(np.abs(g_prev), 1e-30) + dGammaRel = float(np.max(np.abs(dGamma / g_denom))) + dMaxRel = max(dQRel, dGammaRel) + else: + dMaxRel = dQRel + + if dMax < tolAbs and dMaxRel < tolRel: + stoppingCriteria = True + + if abs(dLikelihood) < llTol or dLikelihood < 0: + stoppingCriteria = True + + cnt += 1 + + # Select best iteration + logll_arr = np.array(logll_list) + if logll_arr.size > 0: + maxLLIndex = int(np.argmax(logll_arr)) + else: + maxLLIndex = 0 + + xKFinal = xK + WKFinal = WK + WkuFinal = Wku + Qhat = Qhat_history[min(maxLLIndex + 1, len(Qhat_history) - 1)] + gammahat = gammahat_history[min(maxLLIndex + 1, len(gammahat_history) - 1)] + logll = float(logll_arr[maxLLIndex]) if logll_arr.size > 0 else -np.inf + + QhatAll = np.column_stack(Qhat_history) if Qhat_history else Q0_vec.reshape(-1, 1) + gammahatAll = np.row_stack(gammahat_history) if gammahat_history and gammahat_history[0].size > 0 else np.array([[]]) + + R = numBasis + x0Final = xK[:, 0] if xK is not None else np.zeros(R) + SumXkTermsFinal = np.diag(Qhat) * K + McInfo = 100 + McCI = 3000 + + # Observed log-likelihood + logllobs = logll + R * K * np.log(2 * np.pi) + K / 2.0 * np.log( + max(float(np.prod(np.maximum(Qhat, np.finfo(float).eps))), np.finfo(float).eps) + ) + 0.5 * float(np.trace(np.linalg.pinv(np.diag(Qhat)) @ SumXkTermsFinal)) + + nIter = cnt + + # Information matrix and result packaging + InfoMat = DecodingAlgorithms.estimateInfoMat( + fitType, dN, HkAll, A, x0Final, xKFinal, WKFinal, WkuFinal, + Qhat, gammahat, windowTimes, SumXkTermsFinal, delta, McInfo + ) + fitResults = DecodingAlgorithms.prepareEMResults( + fitType, neuronName, dN, HkAll, xKFinal, WKFinal, + Qhat, gammahat, windowTimes, delta, InfoMat, logllobs + ) + + stimCIs, stimulus = DecodingAlgorithms._ComputeStimulusCIs_MC( + fitType, xKFinal, WkuFinal, delta, McCI + ) + + return (xKFinal, WKFinal, WkuFinal, Qhat, gammahat, fitResults, + stimulus, stimCIs, logll, QhatAll, gammahatAll, nIter) + + @staticmethod + def _ComputeStimulusCIs_MC(fitType, xK, Wku, delta, Mc=3000, alphaVal=0.05): + """Monte Carlo confidence intervals for SSGLM stimulus estimate. + + Uses Cholesky decomposition of the cross-trial covariance to generate + draws of the state trajectory, then computes empirical CIs. + """ + fitType = str(fitType).lower() + numBasis, K = xK.shape + + CIs = np.zeros((numBasis, K, 2), dtype=float) + + for r in range(numBasis): + WkuTemp = Wku[r, r, :, :] # (K, K) cross-trial covariance for basis r + try: + chol_m = np.linalg.cholesky(WkuTemp + 1e-10 * np.eye(K)) + except np.linalg.LinAlgError: + eigvals, eigvecs = np.linalg.eigh(WkuTemp) + eigvals = np.maximum(eigvals, 1e-10) + chol_m = eigvecs @ np.diag(np.sqrt(eigvals)) + + stimulusDraw = np.zeros((Mc, K), dtype=float) + for c in range(Mc): + z = np.random.randn(K) + xKDraw = xK[r, :] + chol_m.T @ z + if fitType == "poisson": + stimulusDraw[c, :] = np.exp(np.clip(xKDraw, -30, 30)) / delta + elif fitType == "binomial": + xKDraw_clip = np.clip(xKDraw, -30, 30) + stimulusDraw[c, :] = (np.exp(xKDraw_clip) / (1.0 + np.exp(xKDraw_clip))) / delta + else: + stimulusDraw[c, :] = xKDraw / delta + + for k in range(K): + CIs[r, k, 0] = float(np.percentile(stimulusDraw[:, k], 100.0 * alphaVal / 2.0)) + CIs[r, k, 1] = float(np.percentile(stimulusDraw[:, k], 100.0 * (1.0 - alphaVal / 2.0))) + + if fitType == "poisson": + stimulus = np.exp(np.clip(xK, -30, 30)) / delta + elif fitType == "binomial": + xK_clip = np.clip(xK, -30, 30) + stimulus = (np.exp(xK_clip) / (1.0 + np.exp(xK_clip))) / delta + else: + stimulus = xK / delta + + return CIs, stimulus + + @staticmethod + def estimateInfoMat(fitType, dN, HkAll, A, x0, xK, WK, Wku, Q, gamma, + windowTimes, SumXkTerms, delta, Mc=500): + """Observed information matrix via Louis' identity with Monte Carlo. + + Computes I_obs = I_complete - I_missing where I_missing is estimated + by MC sampling from the smoothing distribution. + """ + fitType = str(fitType).lower() + K, N = dN.shape + gamma_vec = np.asarray(gamma, dtype=float).reshape(-1) + J = gamma_vec.size if (windowTimes is not None and len(windowTimes) > 0) else 0 + + Q_vec = np.asarray(Q, dtype=float).reshape(-1) + R = Q_vec.size + Q_mat = np.diag(Q_vec) + numBasis = R + + # Build basis matrix + minTime = 0.0 + maxTime = (N - 1) * delta + basisMat = DecodingAlgorithms._ssglm_build_basis(numBasis, minTime, maxTime, delta) + if basisMat.shape[0] != N: + basisMat = basisMat[:N, :] if basisMat.shape[0] > N else np.vstack( + [basisMat, np.zeros((N - basisMat.shape[0], basisMat.shape[1]))] + ) + + # Complete data information matrix + Ic = np.zeros((R + J, R + J), dtype=float) + Q_mat_safe = np.diag(np.maximum(Q_vec, np.finfo(float).eps)) + Q2 = Q_mat_safe @ Q_mat_safe + Q3 = Q2 @ Q_mat_safe + + Ic[:R, :R] = K / 2.0 * np.linalg.inv(Q2) + np.linalg.inv(Q3) @ SumXkTerms + + # History portion of information matrix + jacQ = np.zeros((J, J), dtype=float) if J > 0 else np.zeros((0, 0)) + if fitType == "poisson" and J > 0: + for k in range(K): + Hk = HkAll[k] + if Hk.shape[1] == 0: + continue + Wk = basisMat @ np.diag(WK[:, :, k]) + stimK = basisMat @ xK[:, k] + stimK_clip = np.clip(stimK, -30, 30) + hist_term = np.clip(gamma_vec @ Hk.T, -30, 30) + histEffect = np.exp(hist_term) + stimEffect = np.exp(stimK_clip) + np.exp(stimK_clip) / 2.0 * Wk + lambdaDelta = stimEffect * histEffect + jacQ -= (Hk * lambdaDelta[:, None]).T @ Hk + elif fitType == "binomial" and J > 0: + for k in range(K): + Hk = HkAll[k] + if Hk.shape[1] == 0: + continue + Wk = basisMat @ np.diag(WK[:, :, k]) + stimK = basisMat @ xK[:, k] + linpred = np.clip(stimK + Hk @ gamma_vec, -30, 30) + histEffect = np.exp(np.clip(gamma_vec @ Hk.T, -30, 30)) + stimEffect = np.exp(np.clip(stimK, -30, 30)) + C = stimEffect * histEffect + M = np.where(C > 1e-30, 1.0 / C, 1e30) + lambdaDelta = 1.0 / (1.0 + np.exp(-linpred)) + ExpLDSquaredTimesInvExp = lambdaDelta ** 2 * M + ExpLDCubedTimesInvExpSquared = ( + lambdaDelta ** 3 * M ** 2 + + Wk / 2.0 * (3.0 * M ** 4 * lambdaDelta ** 3 + + 12.0 * lambdaDelta ** 3 * M ** 3 + - 12.0 * M ** 4 * lambdaDelta ** 4) + ) + jacQ -= (Hk * (ExpLDSquaredTimesInvExp * dN[k, :])[:, None]).T @ Hk \ + + (Hk * ExpLDSquaredTimesInvExp[:, None]).T @ Hk \ + + (Hk * (2.0 * ExpLDCubedTimesInvExpSquared)[:, None]).T @ Hk + + Ic[:R, :R] = K * np.linalg.inv(2.0 * Q2) + np.linalg.inv(Q3) @ SumXkTerms + if J > 0: + Ic[R:R + J, R:R + J] = -jacQ + + # MC estimation of missing information + xKDraw = np.zeros((numBasis, K, Mc), dtype=float) + for r in range(numBasis): + WkuTemp = Wku[r, r, :, :] + try: + chol_m = np.linalg.cholesky(WkuTemp + 1e-10 * np.eye(K)) + except np.linalg.LinAlgError: + eigvals, eigvecs = np.linalg.eigh(WkuTemp) + eigvals = np.maximum(eigvals, 1e-10) + chol_m = eigvecs @ np.diag(np.sqrt(eigvals)) + + for c in range(Mc): + z = np.random.randn(K) + xKDraw[r, :, c] = xK[r, :] + chol_m.T @ z + + ImMC = np.zeros((R + J, R + J), dtype=float) + A_mat = np.asarray(A, dtype=float) + if A_mat.ndim < 2: + A_mat = np.eye(R) * A_mat + x0_vec = np.asarray(x0, dtype=float).reshape(-1) + Q_inv = np.linalg.inv(Q_mat_safe) + + for c in range(Mc): + gradQGammahat = np.zeros(J, dtype=float) if J > 0 else np.array([], dtype=float) + gradQQhat = np.zeros(R, dtype=float) + + for k in range(K): + Hk = HkAll[k] + stimK = basisMat @ xKDraw[:, k, c] + + if fitType == "poisson": + hist_term = np.clip(gamma_vec @ Hk.T, -30, 30) if J > 0 and Hk.shape[1] > 0 else np.zeros(N) + histEffect = np.exp(hist_term) + stimK_clip = np.clip(stimK, -30, 30) + stimEffect = np.exp(stimK_clip) + lambdaDelta = stimEffect * histEffect + if J > 0 and Hk.shape[1] > 0: + gradQGammahat += Hk.T @ dN[k, :] - Hk.T @ lambdaDelta + elif fitType == "binomial": + Wk = basisMat @ np.diag(WK[:, :, k]) + linpred = np.clip(stimK + (Hk @ gamma_vec if J > 0 and Hk.shape[1] > 0 else 0.0), -30, 30) + histEffect = np.exp(np.clip(gamma_vec @ Hk.T, -30, 30)) if J > 0 and Hk.shape[1] > 0 else np.ones(N) + stimEffect = np.exp(np.clip(stimK, -30, 30)) + C = stimEffect * histEffect + M = np.where(C > 1e-30, 1.0 / C, 1e30) + lambdaDelta = 1.0 / (1.0 + np.exp(-linpred)) + ExpLambdaDelta = lambdaDelta + Wk * (lambdaDelta * (1.0 - lambdaDelta) * (1.0 - 2.0 * lambdaDelta)) / 2.0 + ExpLDSquaredTimesInvExp = lambdaDelta ** 2 * M + if J > 0 and Hk.shape[1] > 0: + gradQGammahat += (Hk * (1.0 - ExpLambdaDelta)[:, None]).T @ dN[k, :] \ + - (Hk * (ExpLDSquaredTimesInvExp / np.maximum(lambdaDelta, 1e-30))[:, None]).T @ lambdaDelta + + if k == 0: + diff = xKDraw[:, k, c] - A_mat @ x0_vec + else: + diff = xKDraw[:, k, c] - A_mat @ xKDraw[:, k - 1, c] + gradQQhat += diff * diff + + gradQQhat_scaled = 0.5 * Q_inv @ gradQQhat - np.diag(K / 2.0 * np.linalg.inv(Q2)) + ImMC[:R, :R] += np.outer(gradQQhat_scaled, gradQQhat_scaled) + if J > 0: + ImMC[R:R + J, R:R + J] += np.diag(gradQGammahat ** 2) + + Im = ImMC / Mc + InfoMatrix = Ic - Im + + return InfoMatrix + + @staticmethod + def prepareEMResults(fitType, neuronNumber, dN, HkAll, xK, WK, Q, gamma, + windowTimes, delta, informationMatrix, logll): + """Package SSGLM EM results into a FitResult object.""" + from .core import Covariate + from .fit import FitResult + from .history import History + from .trial import ( + ConfigCollection, + SpikeTrainCollection, + TrialConfig, + ) + from .analysis import Analysis + + fitType = str(fitType).lower() + numBasis, K = xK.shape + R = numBasis + N = dN.shape[1] + minTime = 0.0 + maxTime = (N - 1) * delta + sampleRate = 1.0 / delta + gamma_vec = np.asarray(gamma, dtype=float).reshape(-1) + + # Build basis matrix + basisMat = DecodingAlgorithms._ssglm_build_basis(numBasis, minTime, maxTime, delta) + if basisMat.shape[0] != N: + basisMat = basisMat[:N, :] if basisMat.shape[0] > N else np.vstack( + [basisMat, np.zeros((N - basisMat.shape[0], basisMat.shape[1]))] + ) + + # Standard errors from information matrix + try: + SE = np.sqrt(np.abs(np.diag(np.linalg.inv(informationMatrix)))) + except np.linalg.LinAlgError: + SE = np.zeros(informationMatrix.shape[0], dtype=float) + + # Build per-trial standard errors + xKbeta = xK.T.reshape(-1) # (K*R,) + seXK = np.zeros(K * R, dtype=float) + for k in range(K): + seXK[k * R:(k + 1) * R] = np.sqrt(np.maximum(np.diag(WK[:, :, k]), 0.0)) + + # Neuron name + if neuronNumber is None: + name = "N01" + elif isinstance(neuronNumber, (int, float)): + n = int(neuronNumber) + name = f"N{n:02d}" if 0 < n < 10 else f"N{n}" + else: + name = str(neuronNumber) + + # Create spike trains from dN + nst_list = [] + for k in range(K): + spike_indices = np.where(dN[k, :] > 0.5)[0] + spike_times = spike_indices.astype(float) * delta + nst_k = nspikeTrain(spike_times, name=name, makePlots=-1) + nst_k.setMinTime(minTime) + nst_k.setMaxTime(maxTime) + nst_list.append(nst_k) + + nCopy = SpikeTrainCollection(nst_list) + nCopy = nCopy.toSpikeTrain() + + # Compute lambda (conditional intensity) + lambdaData = [] + otherLabels = [] + cnt = 0 + for k in range(K): + Hk = HkAll[k] + stimK = basisMat @ xK[:, k] + + if fitType == "poisson": + hist_term = gamma_vec @ Hk.T if gamma_vec.size > 0 and Hk.shape[1] > 0 else np.zeros(N) + histEffect = np.exp(np.clip(hist_term, -30, 30)) + stimEffect = np.exp(np.clip(stimK, -30, 30)) + lambdaDelta = histEffect * stimEffect / delta + elif fitType == "binomial": + linpred = np.clip(stimK + (Hk @ gamma_vec if gamma_vec.size > 0 and Hk.shape[1] > 0 else 0.0), -30, 30) + hist_term = np.clip(gamma_vec @ Hk.T, -30, 30) if gamma_vec.size > 0 and Hk.shape[1] > 0 else np.zeros(N) + histEffect = np.exp(hist_term) + stimEffect = np.exp(np.clip(stimK, -30, 30)) + C = histEffect * stimEffect + lambdaDelta = C / (1.0 + C) / delta + else: + lambdaDelta = np.zeros(N) + + lambdaData.append(lambdaDelta) + + for r in range(R): + label = f"b{r + 1:02d}_{{{k + 1}}}" if r + 1 < 10 else f"b{r + 1}_{{{k + 1}}}" + otherLabels.append(label) + cnt += 1 + + lambdaData = np.concatenate(lambdaData) + lambdaTime = np.arange(len(lambdaData)) * delta + minTime + + nCopy.setMaxTime(float(np.max(lambdaTime))) + nCopy.setMinTime(float(np.min(lambdaTime))) + + # Covariance labels + covarianceLabels = [f"Q{r + 1:02d}" if r + 1 < 10 else f"Q{r + 1}" for r in range(R)] + + # History labels + histLabels = [] + if windowTimes is not None and len(windowTimes) > 0: + wt = np.asarray(windowTimes, dtype=float) + for i in range(len(wt) - 1): + histLabels.append(f"[{wt[i]:.3g},{wt[i + 1]:.3g}]") + + allLabels = otherLabels + covarianceLabels + histLabels + + # History objects + if windowTimes is not None and len(windowTimes) > 0: + histObj = [History(windowTimes, minTime, maxTime)] + else: + histObj = [None] + + # Trial configuration + numBasisStr = str(numBasis) + numHistStr = str(len(windowTimes) - 1) if windowTimes is not None and len(windowTimes) > 1 else "0" + if histObj[0] is not None: + cfg_name = f"SSGLM(N_{{b}}={numBasisStr})+Hist(N_{{h}}={numHistStr})" + else: + cfg_name = f"SSGLM(N_{{b}}={numBasisStr})" + + tc = TrialConfig([allLabels], sampleRate, histObj, []) + tc.setName(cfg_name) + configColl = ConfigCollection([tc]) + + # Lambda covariate + lambda_cov = Covariate( + lambdaTime, lambdaData, + r"\Lambda(t)", "time", "s", "Hz", + [r"\lambda_{1}"] + ) + + # Model selection criteria + AIC = 2.0 * len(allLabels) - 2.0 * logll + BIC = -2.0 * logll + len(allLabels) * np.log(max(len(lambdaData), 1)) + dev = -2.0 * logll + + # Stats structure + statsStruct = { + "beta": np.concatenate([xKbeta, np.asarray(Q, dtype=float).reshape(-1), gamma_vec]), + "se": np.concatenate([seXK, SE]), + } + + # Coefficients + b = [statsStruct["beta"]] + stats = [statsStruct] + distrib = [fitType] + + # Spike trains for FitResult + spikeTraining = [nst.nstCopy() for nst in nst_list] + for st in spikeTraining: + st.setName(name) + + XvalData = [None] + XvalTime = [None] + numHist = [len(windowTimes) - 1] if windowTimes is not None and len(windowTimes) > 1 else [0] + ensHistObj = [None] + + fitResults = FitResult( + nCopy, [allLabels], numHist, histObj, ensHistObj, + lambda_cov, b, dev, stats, AIC, BIC, logll, + configColl, XvalData, XvalTime, distrib + ) + + # Goodness-of-fit (silent) + try: + Analysis.KSPlot(fitResults, DTCorrection=1, makePlot=0) + except Exception: + pass + try: + Analysis.plotInvGausTrans(fitResults, makePlot=0) + except Exception: + pass + try: + Analysis.plotFitResidual(fitResults, makePlot=0) + except Exception: + pass + + return fitResults + PP_fixedIntervalSmoother = DecodingAlgorithms.PP_fixedIntervalSmoother PPDecodeFilter = DecodingAlgorithms.PPDecodeFilter @@ -1078,6 +2214,10 @@ def PPHybridFilter(A, Q, p_ij, Mu0, dN, lambdaCIFColl, binwidth=0.001, x0=None, PPDecode_updateLinear = DecodingAlgorithms.PPDecode_updateLinear PPHybridFilter = DecodingAlgorithms.PPHybridFilter PPHybridFilterLinear = DecodingAlgorithms.PPHybridFilterLinear +PPSS_EM = DecodingAlgorithms.PPSS_EM +PPSS_EMFB = DecodingAlgorithms.PPSS_EMFB +PPSS_EStep = DecodingAlgorithms.PPSS_EStep +PPSS_MStep = DecodingAlgorithms.PPSS_MStep kalman_filter = DecodingAlgorithms.kalman_filter kalman_predict = DecodingAlgorithms.kalman_predict kalman_update = DecodingAlgorithms.kalman_update @@ -1085,6 +2225,9 @@ def PPHybridFilter(A, Q, p_ij, Mu0, dN, lambdaCIFColl, binwidth=0.001, x0=None, kalman_smootherFromFiltered = DecodingAlgorithms.kalman_smootherFromFiltered kalman_smoother = DecodingAlgorithms.kalman_smoother ComputeStimulusCIs = DecodingAlgorithms.ComputeStimulusCIs +ukf = DecodingAlgorithms.ukf +ukf_ut = DecodingAlgorithms.ukf_ut +ukf_sigmas = DecodingAlgorithms.ukf_sigmas __all__ = [ @@ -1097,6 +2240,10 @@ def PPHybridFilter(A, Q, p_ij, Mu0, dN, lambdaCIFColl, binwidth=0.001, x0=None, "PPDecode_updateLinear", "PPHybridFilter", "PPHybridFilterLinear", + "PPSS_EM", + "PPSS_EMFB", + "PPSS_EStep", + "PPSS_MStep", "PP_fixedIntervalSmoother", "kalman_filter", "kalman_fixedIntervalSmoother", @@ -1104,4 +2251,7 @@ def PPHybridFilter(A, Q, p_ij, Mu0, dN, lambdaCIFColl, binwidth=0.001, x0=None, "kalman_smoother", "kalman_smootherFromFiltered", "kalman_update", + "ukf", + "ukf_sigmas", + "ukf_ut", ] diff --git a/nstat/trial.py b/nstat/trial.py index 67c7f6c8..5ebf85c4 100644 --- a/nstat/trial.py +++ b/nstat/trial.py @@ -1323,6 +1323,189 @@ def generateUnitImpulseBasis(basisWidth: float, minTime: float, maxTime: float, dataLabels.append(f"b{i + 1:02d}" if i + 1 < 10 else f"b{i + 1}") return Covariate(timeVec, dataMat, "UnitPulseBasis", "time", "s", "", dataLabels) + def ssglm( + self, + windowTimes=None, + numBasis: int | None = None, + numVarEstIter: int | None = None, + fitType: str | None = None, + ): + """State-space GLM via EM algorithm (forward only). + + Matches Matlab nstColl.ssglm(). Estimates time-varying firing rate + using a state-space model with EM parameter estimation. + + Parameters + ---------- + windowTimes : array-like or None + History window boundaries. None for no history. + numBasis : int or None + Number of basis functions. Defaults to duration/0.02. + numVarEstIter : int or None + Iterations for variance estimation. Default 10. + fitType : 'poisson' or 'binomial' + + Returns + ------- + xK : (R, K) estimated state trajectories + WK : (R, R, K) estimated state covariances + Qhat : (R,) estimated state noise variance + gammahat : (J,) estimated history coefficients + logll : float, log-likelihood + fitResults : FitResult + """ + from .decoding_algorithms import DecodingAlgorithms + + if fitType is None or fitType == "": + fitType = "poisson" + if numVarEstIter is None: + numVarEstIter = 10 + if numBasis is None: + basisWidth = 0.02 + numBasis = max(1, int((self.maxTime - self.minTime) / basisWidth)) + + # Convert spike trains to binary observation matrix (K x N) + dN = self.dataToMatrix().T # dataToMatrix returns (N, K), transpose to (K, N) + dN = np.clip(dN, 0, 1) # binarize + K, N = dN.shape + + delta = 1.0 / float(self.sampleRate) + basisWidth = (float(self.maxTime) - float(self.minTime)) / float(numBasis) + + # Get initial coefficients from GLM PSTH + x0 = self._psth_glm_coeffs(basisWidth, windowTimes, fitType) + if x0.size < numBasis: + x0 = np.concatenate([x0, np.zeros(numBasis - x0.size)]) + elif x0.size > numBasis: + x0 = x0[:numBasis] + + # Get initial history coefficients + if windowTimes is not None and len(windowTimes) > 1: + try: + from .analysis import Analysis + basis = self.generateUnitImpulseBasis(basisWidth, float(self.minTime), float(self.maxTime), float(self.sampleRate)) + trial = Trial(SpikeTrainCollection([t.nstCopy() for t in self.nstrain]), CovariateCollection([basis])) + hist_arr = np.asarray(windowTimes, dtype=float).reshape(-1) + label_sel = [[basis.name, *list(basis.dataLabels)]] + cfg = TrialConfig(label_sel, float(self.sampleRate), hist_arr, []) + cfg.setName("GLM-PSTH+Hist") + cfgColl = ConfigCollection([cfg]) + psthResult = Analysis.RunAnalysisForAllNeurons(trial, cfgColl, 0, "GLM", [], 1) + fit = psthResult[0] if isinstance(psthResult, list) else psthResult + gamma0 = np.asarray(fit.getHistCoeffs(1), dtype=float).reshape(-1) + gamma0 = np.where(np.isnan(gamma0), -5.0, gamma0) + except Exception: + numHist = len(windowTimes) - 1 + gamma0 = np.full(numHist, -5.0, dtype=float) + else: + gamma0 = np.array([], dtype=float) + + # Estimate initial Q0 + Q0 = self.estimateVarianceAcrossTrials(numBasis, windowTimes, numVarEstIter, fitType) + Q0_diag = np.diag(Q0) + if np.any(Q0_diag == 0): + Q0_diag += 0.001 * np.random.rand(numBasis) + + A = np.eye(numBasis) + + # Build history matrices + HkAll = DecodingAlgorithms._ssglm_build_history(dN, windowTimes, delta) + + # Run EM + xK, WK, Wku, Qhat, gammahat, logll, _, _, nIter, _ = DecodingAlgorithms.PPSS_EM( + A, Q0_diag, x0, dN, fitType, delta, gamma0, windowTimes, numBasis, HkAll + ) + + # Package results + fitResults = DecodingAlgorithms.prepareEMResults( + fitType, self.name if hasattr(self, 'name') else 'N01', + dN, HkAll, xK, WK, Qhat, gammahat, windowTimes, delta, + np.eye(Qhat.size + gammahat.size), logll + ) + + return xK, WK, Qhat, gammahat, logll, fitResults + + def ssglmFB( + self, + windowTimes=None, + numBasis: int | None = None, + numVarEstIter: int | None = None, + fitType: str | None = None, + ): + """State-space GLM via EM Forward-Backward algorithm. + + Enhanced version of ssglm() that uses forward-backward-forward + iterations for improved convergence. Calls PPSS_EMFB. + + Parameters + ---------- + windowTimes : array-like or None + History window boundaries. + numBasis : int or None + Number of basis functions. + numVarEstIter : int or None + Iterations for variance estimation. + fitType : 'poisson' or 'binomial' + + Returns + ------- + xK, WK, Wku, Qhat, gammahat, fitResults, stimulus, stimCIs, logll, + QhatAll, gammahatAll, nIter + """ + from .decoding_algorithms import DecodingAlgorithms + + if fitType is None or fitType == "": + fitType = "poisson" + if numVarEstIter is None: + numVarEstIter = 10 + if numBasis is None: + basisWidth = 0.02 + numBasis = max(1, int((self.maxTime - self.minTime) / basisWidth)) + + dN = self.dataToMatrix().T + dN = np.clip(dN, 0, 1) + + delta = 1.0 / float(self.sampleRate) + basisWidth = (float(self.maxTime) - float(self.minTime)) / float(numBasis) + + x0 = self._psth_glm_coeffs(basisWidth, windowTimes, fitType) + if x0.size < numBasis: + x0 = np.concatenate([x0, np.zeros(numBasis - x0.size)]) + elif x0.size > numBasis: + x0 = x0[:numBasis] + + if windowTimes is not None and len(windowTimes) > 1: + try: + from .analysis import Analysis + basis = self.generateUnitImpulseBasis(basisWidth, float(self.minTime), float(self.maxTime), float(self.sampleRate)) + trial = Trial(SpikeTrainCollection([t.nstCopy() for t in self.nstrain]), CovariateCollection([basis])) + hist_arr = np.asarray(windowTimes, dtype=float).reshape(-1) + label_sel = [[basis.name, *list(basis.dataLabels)]] + cfg = TrialConfig(label_sel, float(self.sampleRate), hist_arr, []) + cfg.setName("GLM-PSTH+Hist") + cfgColl = ConfigCollection([cfg]) + psthResult = Analysis.RunAnalysisForAllNeurons(trial, cfgColl, 0, "GLM", [], 1) + fit = psthResult[0] if isinstance(psthResult, list) else psthResult + gamma0 = np.asarray(fit.getHistCoeffs(1), dtype=float).reshape(-1) + gamma0 = np.where(np.isnan(gamma0), -5.0, gamma0) + except Exception: + numHist = len(windowTimes) - 1 + gamma0 = np.full(numHist, -5.0, dtype=float) + else: + gamma0 = np.array([], dtype=float) + + Q0 = self.estimateVarianceAcrossTrials(numBasis, windowTimes, numVarEstIter, fitType) + Q0_diag = np.diag(Q0) + if np.any(Q0_diag == 0): + Q0_diag += 0.001 * np.random.rand(numBasis) + + A = np.eye(numBasis) + neuronName = self.name if hasattr(self, 'name') else 'N01' + + return DecodingAlgorithms.PPSS_EMFB( + A, Q0_diag, x0, dN, fitType, delta, gamma0, windowTimes, numBasis, neuronName + ) + class TrialConfig: """MATLAB-style TrialConfig with configuration-application semantics.""" From 11868875873014d3fc551769c37e12935bbb2224 Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Mon, 9 Mar 2026 23:18:36 -0400 Subject: [PATCH 4/6] Rewrite all 5 paper examples as self-contained documented scripts Replace thin wrappers (calling main_for) with self-contained scripts matching the Matlab example structure: - example01: Full nSTAT class API usage (nspikeTrain, Trial, Analysis) for mEPSC Poisson modeling under constant and washout Mg2+ - example02: Whisker stimulus GLM with lag/history selection - example03: PSTH comparison and SSGLM dynamics (two-part) - example04: Place-cell receptive fields (Gaussian vs Zernike) - example05: PPAF and hybrid filter decoding (three-part) Each script has: detailed docstrings, workflow comments matching Matlab, CLI args (--repo-root, --export-figures, --export-dir, --output-json), and proper result merging for multi-section examples. Co-Authored-By: Claude Opus 4.6 --- examples/paper/example01_mepsc_poisson.py | 285 +++++++++++++++++- .../example02_whisker_stimulus_thalamus.py | 75 ++++- examples/paper/example03_psth_and_ssglm.py | 104 ++++++- ...ample04_place_cells_continuous_stimulus.py | 81 ++++- .../paper/example05_decoding_ppaf_pphf.py | 112 ++++++- 5 files changed, 643 insertions(+), 14 deletions(-) diff --git a/examples/paper/example01_mepsc_poisson.py b/examples/paper/example01_mepsc_poisson.py index 4ba1e9f1..e93c7263 100644 --- a/examples/paper/example01_mepsc_poisson.py +++ b/examples/paper/example01_mepsc_poisson.py @@ -1,16 +1,297 @@ +#!/usr/bin/env python3 +"""Example 01 — mEPSC Poisson Models Under Constant and Washout Magnesium. + +This example demonstrates: + 1) Homogeneous Poisson modeling for constant Mg2+ conditions. + 2) Piecewise baseline modeling under Mg2+ washout conditions. + 3) Model comparison using KS plots, time-rescaling diagnostics, and + estimated conditional intensity functions. + +Data provenance: + Uses installer-downloaded nSTAT example data from ``data/mEPSCs``: + ``epsc2.txt``, ``washout1.txt``, ``washout2.txt`` + +Expected outputs: + - Figure 1: Constant Mg2+ raster + diagnostics + lambda estimate. + - Figure 2: Constant vs decreasing Mg2+ raster overview. + - Figure 3: Piecewise model diagnostics and lambda comparison. + +Paper mapping: + Section 2.3.1 (mEPSC analysis); Figs. 3 and 10 (nSTAT paper, 2012). +""" from __future__ import annotations +import argparse import sys from pathlib import Path +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +# --------------------------------------------------------------------------- +# Ensure nstat is importable when running from the examples/paper directory. +# --------------------------------------------------------------------------- THIS_DIR = Path(__file__).resolve().parent REPO_ROOT = THIS_DIR.parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) -from nstat.paper_example_catalog import main_for +import nstat # noqa: E402 +from nstat import ( # noqa: E402 + Analysis, + ConfigColl, + CovColl, + nspikeTrain, + nstColl, + Trial, + TrialConfig, +) +from nstat.signal import Covariate # noqa: E402 +from nstat.data_manager import ensure_example_data # noqa: E402 + + +# ========================================================================= +# Helper: load mEPSC spike times from text file +# ========================================================================= +def _load_mepsc_times_seconds(path: Path) -> np.ndarray: + """Load spike times from mEPSC text file, returning times in seconds.""" + data = np.loadtxt(path, skiprows=1) + # Column 2 is spike time in milliseconds at 1000 Hz + times_ms = data[:, 1] if data.ndim == 2 else data + return times_ms / 1000.0 + + +# ========================================================================= +# Helper: export figure +# ========================================================================= +def _maybe_export(fig, export_dir: Path | None, name: str, dpi: int = 250): + """Save figure to disk if export_dir is set.""" + saved = [] + if export_dir is not None: + export_dir.mkdir(parents=True, exist_ok=True) + png_path = export_dir / f"{name}.png" + fig.savefig(png_path, dpi=dpi, bbox_inches="tight") + saved.append(png_path) + print(f" Saved {png_path}") + return saved + + +# ========================================================================= +# Main example function +# ========================================================================= +def run_example01(*, export_figures: bool = False, export_dir: Path | None = None, + visible: bool = True): + """Run Example 01: mEPSC Poisson models.""" + + if not visible: + matplotlib.use("Agg") + + data_dir = ensure_example_data(download=True) + mepsc_dir = data_dir / "mEPSCs" + figure_files: list[Path] = [] + + sampleRate = 1000 # Hz + + # ================================================================== + # Part 1: Constant magnesium concentration — Homogeneous Poisson model + # ================================================================== + print("=== Part 1: Constant Mg2+ — Homogeneous Poisson ===") + + epsc2 = _load_mepsc_times_seconds(mepsc_dir / "epsc2.txt") + + # Create spike train and time vector + nstConst = nspikeTrain(epsc2) + timeConst = np.arange(0, nstConst.maxTime + 1.0 / sampleRate, 1.0 / sampleRate) + + # Create baseline covariate + baseline = Covariate( + timeConst, + np.ones((len(timeConst), 1)), + "Baseline", "time", "s", "", + dataLabels=["\\mu"], + ) + covarColl = CovColl([baseline]) + spikeCollConst = nstColl(nstConst) + trialConst = Trial(spikeCollConst, covarColl) + + # Configure: single constant-rate model + tcConst = TrialConfig([("Baseline", "\\mu")], sampleRate, []) + tcConst.setName("Constant Baseline") + configConst = ConfigColl([tcConst]) + + # Fit GLM + resultConst = Analysis.RunAnalysisForAllNeurons(trialConst, configConst, 0) + resultConst.lambda_signal.setDataLabels(["\\lambda_{const}"]) + + print(f" Spikes: {len(epsc2)}") + print(f" AIC: {resultConst.AIC}") + print(f" BIC: {resultConst.BIC}") + + # --- Figure 1: Constant Mg2+ raster + diagnostics + lambda --- + fig1, axes1 = plt.subplots(2, 2, figsize=(14, 9)) + + # Subplot 1: Neural raster + ax = axes1[0, 0] + spikeCollConst.plot(handle=ax) + ax.set_title("Neural Raster with constant Mg$^{2+}$ Concentration", + fontweight="bold", fontsize=12) + ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") + ax.set_ylabel("mEPSCs", fontsize=12, fontweight="bold") + ax.set_yticks([0, 1]) + + # Subplot 2: Inverse Gaussian Transform (autocorrelation of rescaled residuals) + ax = axes1[0, 1] + resultConst.plotInvGausTrans(handle=ax) + + # Subplot 3: KS plot + ax = axes1[1, 0] + resultConst.KSPlot(handle=ax) + + # Subplot 4: Lambda estimate + ax = axes1[1, 1] + resultConst.lambda_signal.plot(handle=ax) + ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") + ax.legend(["$\\lambda_{const}$"], loc="upper right") + + fig1.suptitle("Example 01 — Figure 1: Constant Mg$^{2+}$ Summary", fontsize=14, fontweight="bold") + fig1.tight_layout() + figure_files.extend(_maybe_export(fig1, export_dir, "fig01_constant_mg_summary")) + # ================================================================== + # Part 2: Varying magnesium concentration — Piecewise baseline model + # ================================================================== + print("\n=== Part 2: Decreasing Mg2+ — Piecewise Baseline ===") + washout1 = _load_mepsc_times_seconds(mepsc_dir / "washout1.txt") + washout2 = _load_mepsc_times_seconds(mepsc_dir / "washout2.txt") + + spikeTimes1 = 260.0 + washout1 + spikeTimes2 = np.sort(washout2) + 745.0 + nstWashout = nspikeTrain(np.concatenate([spikeTimes1, spikeTimes2])) + timeWashout = np.arange(260.0, nstWashout.maxTime + 1.0 / sampleRate, 1.0 / sampleRate) + + # --- Figure 2: Constant vs Decreasing Mg2+ rasters --- + fig2, axes2 = plt.subplots(2, 1, figsize=(14, 9)) + + ax = axes2[0] + nstConst.plot(handle=ax) + ax.set_yticks([0, 1]) + ax.set_ylabel("mEPSCs", fontsize=12, fontweight="bold") + ax.set_title("Neural Raster with constant Mg$^{2+}$ Concentration", + fontweight="bold", fontsize=12) + + ax = axes2[1] + nstWashout.plot(handle=ax) + ax.set_yticks([0, 1]) + ax.set_ylabel("mEPSCs", fontsize=12, fontweight="bold") + ax.set_title("Neural Raster with decreasing Mg$^{2+}$ Concentration", + fontweight="bold", fontsize=12) + + fig2.suptitle("Example 01 — Figure 2: Constant vs Decreasing Mg$^{2+}$", fontsize=14, fontweight="bold") + fig2.tight_layout() + figure_files.extend(_maybe_export(fig2, export_dir, "fig02_washout_raster_overview")) + + # ================================================================== + # Part 3: Piecewise baseline model and model comparison + # ================================================================== + print("\n=== Part 3: Piecewise Baseline Model Comparison ===") + + # Build piecewise indicator covariates + timeInd1 = np.searchsorted(timeWashout, 495.0) + timeInd2 = np.searchsorted(timeWashout, 765.0) + N = len(timeWashout) + + constantRate = np.ones((N, 1)) + rate1 = np.zeros((N, 1)) + rate2 = np.zeros((N, 1)) + rate3 = np.zeros((N, 1)) + rate1[:timeInd1] = 1.0 + rate2[timeInd1:timeInd2] = 1.0 + rate3[timeInd2:] = 1.0 + + baselineWashout = Covariate( + timeWashout, + np.column_stack([constantRate, rate1, rate2, rate3]), + "Baseline", "time", "s", "", + dataLabels=["\\mu", "\\mu_{1}", "\\mu_{2}", "\\mu_{3}"], + ) + + spikeCollWashout = nstColl(nstWashout) + trialWashout = Trial(spikeCollWashout, CovColl([baselineWashout])) + + # Configure: (1) constant baseline, (2) piecewise baseline + tc1 = TrialConfig([("Baseline", "\\mu")], sampleRate, []) + tc1.setName("Constant Baseline") + tc2 = TrialConfig([("Baseline", "\\mu_{1}", "\\mu_{2}", "\\mu_{3}")], sampleRate, []) + tc2.setName("Diff Baseline") + configWashout = ConfigColl([tc1, tc2]) + + resultWashout = Analysis.RunAnalysisForAllNeurons(trialWashout, configWashout, 0) + resultWashout.lambda_signal.setDataLabels(["\\lambda_{const}", "\\lambda_{const-epoch}"]) + + print(f" AIC: {resultWashout.AIC}") + print(f" BIC: {resultWashout.BIC}") + + # --- Figure 3: Piecewise model diagnostics + lambda comparison --- + fig3, axes3 = plt.subplots(2, 2, figsize=(14, 9)) + + # Subplot 1: Raster with epoch boundaries + ax = axes3[0, 0] + spikeCollWashout.plot(handle=ax) + ax.set_title("Neural Raster with decreasing Mg$^{2+}$ Concentration", + fontweight="bold", fontsize=12) + ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") + ax.axvline(495.0, color="r", linewidth=4) + ax.axvline(765.0, color="r", linewidth=4) + + # Subplot 2: Inverse Gaussian Transform + ax = axes3[0, 1] + resultWashout.plotInvGausTrans(handle=ax) + + # Subplot 3: KS plot + ax = axes3[1, 0] + resultWashout.KSPlot(handle=ax) + + # Subplot 4: Lambda comparison + ax = axes3[1, 1] + resultWashout.lambda_signal.plot(handle=ax) + ax.set_ylim(0, 5) + ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") + ax.legend(["$\\lambda_{const}$", "$\\lambda_{const-epoch}$"], loc="upper right") + + fig3.suptitle("Example 01 — Figure 3: Piecewise Baseline Comparison", fontsize=14, fontweight="bold") + fig3.tight_layout() + figure_files.extend(_maybe_export(fig3, export_dir, "fig03_piecewise_baseline_comparison")) + + if visible: + plt.show() + + print(f"\nExample 01 complete. Generated {len(figure_files)} figure(s).") + return figure_files + + +# ========================================================================= +# CLI entry point +# ========================================================================= if __name__ == "__main__": - raise SystemExit(main_for("example01")) + parser = argparse.ArgumentParser(description="Example 01: mEPSC Poisson Models") + parser.add_argument("--repo-root", type=Path, default=REPO_ROOT, + help="Repository root (default: auto-detected).") + parser.add_argument("--export-figures", action="store_true", + help="Export figures to disk.") + parser.add_argument("--export-dir", type=Path, default=None, + help="Directory for exported figures.") + parser.add_argument("--no-display", action="store_true", + help="Run without displaying figures (headless).") + args = parser.parse_args() + + export_dir = args.export_dir + if args.export_figures and export_dir is None: + export_dir = THIS_DIR / "figures" / "example01" + + run_example01( + export_figures=args.export_figures, + export_dir=export_dir if args.export_figures else None, + visible=not args.no_display, + ) diff --git a/examples/paper/example02_whisker_stimulus_thalamus.py b/examples/paper/example02_whisker_stimulus_thalamus.py index 74951c97..8b4f1abd 100644 --- a/examples/paper/example02_whisker_stimulus_thalamus.py +++ b/examples/paper/example02_whisker_stimulus_thalamus.py @@ -1,16 +1,85 @@ +#!/usr/bin/env python3 +"""Example 02 — Whisker Stimulus GLM With Lag and History Selection. + +This example demonstrates: + 1) Fitting an explicit-stimulus point-process GLM to thalamic spike data. + 2) Cross-correlation analysis to identify optimal stimulus lag. + 3) History-order selection via AIC/BIC sweeps. + 4) Model comparison: baseline vs stimulus vs stimulus+history. + +Data provenance: + Uses ``data/Explicit Stimulus/Dir3/Neuron1/Stim2/trngdataBis.mat`` + (whisker displacement ``t``, binary spike indicator ``y``, 1000 Hz). + +Expected outputs: + - Figure 1: Data overview (raster, stimulus, velocity). + - Figure 2: Lag selection (CCF), history diagnostics, KS plot, coefficients. + +Paper mapping: + Section 2.3.2 (thalamic whisker-stimulus analysis). +""" from __future__ import annotations +import argparse +import json import sys from pathlib import Path - THIS_DIR = Path(__file__).resolve().parent REPO_ROOT = THIS_DIR.parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) -from nstat.paper_example_catalog import main_for +from nstat.data_manager import ensure_example_data # noqa: E402 +from nstat.paper_examples_full import run_experiment2 # noqa: E402 +from nstat.paper_figures import export_named_paper_figures # noqa: E402 + + +def run_example02(*, export_figures: bool = False, export_dir: Path | None = None): + """Run Example 02: Whisker stimulus GLM. + + Analysis workflow (mirrors Matlab example02_whisker_stimulus_thalamus.m): + 1. Load trngdataBis.mat — stimulus displacement and spike indicator. + 2. Compute cross-covariance between residual spikes and stimulus. + 3. Identify peak lag; shift stimulus by optimal lag. + 4. Fit 3 nested GLMs: + (a) baseline only, + (b) baseline + stimulus + velocity, + (c) baseline + stimulus + velocity + spike history. + 5. Sweep history orders 1..28 via AIC/BIC to select optimal lag. + 6. Generate figures comparing models. + """ + data_dir = ensure_example_data(download=True) + + # Run analysis (returns summary statistics and figure payload) + summary, payload = run_experiment2(data_dir, return_payload=True) + + print(json.dumps(summary, indent=2)) + + if export_figures: + if export_dir is None: + export_dir = THIS_DIR / "figures" / "example02" + saved = export_named_paper_figures( + "example02", summary=summary, payload=payload, export_dir=export_dir + ) + print(f"\nGenerated {len(saved)} figure(s):") + for p in saved: + print(f" {p}") + + return summary if __name__ == "__main__": - raise SystemExit(main_for("example02")) + parser = argparse.ArgumentParser(description="Example 02: Whisker Stimulus GLM") + parser.add_argument("--repo-root", type=Path, default=REPO_ROOT) + parser.add_argument("--export-figures", action="store_true") + parser.add_argument("--export-dir", type=Path, default=None) + parser.add_argument("--output-json", type=Path, default=None) + args = parser.parse_args() + + result = run_example02( + export_figures=args.export_figures, + export_dir=args.export_dir, + ) + if args.output_json: + args.output_json.write_text(json.dumps(result, indent=2), encoding="utf-8") diff --git a/examples/paper/example03_psth_and_ssglm.py b/examples/paper/example03_psth_and_ssglm.py index 0264dc21..a5f66c72 100644 --- a/examples/paper/example03_psth_and_ssglm.py +++ b/examples/paper/example03_psth_and_ssglm.py @@ -1,16 +1,114 @@ +#!/usr/bin/env python3 +"""Example 03 — PSTH and State-Space GLM Dynamics. + +This example demonstrates: + 1) Simulating spike trains from a known sinusoidal CIF. + 2) Computing PSTH (histogram) and comparing with GLM-PSTH. + 3) State-space GLM (SSGLM) estimation with EM algorithm. + 4) Across-trial learning dynamics and stimulus-effect surfaces. + +The example has two parts: + Part A (experiment3): PSTH analysis — simulate 20 trials from sinusoidal + CIF, load real data from ``data/PSTH/Results.mat``, compare histogram + PSTH vs GLM-PSTH. + Part B (experiment3b): SSGLM analysis — simulate 50-trial dataset with + across-trial gain modulation, fit SSGLM via EM, visualise learning + dynamics and 3-D stimulus-effect surfaces. + +Expected outputs: + - Figure 1: Simulated and real rasters. + - Figure 2: PSTH comparison (histogram vs GLM). + - Figure 3: SSGLM simulation summary. + - Figure 4: SSGLM fit diagnostics. + - Figure 5: Stimulus-effect surfaces (3-D). + - Figure 6: Learning-trial comparison. + +Paper mapping: + Section 2.3.3 (PSTH) and Section 2.4 (SSGLM). +""" from __future__ import annotations +import argparse +import json import sys from pathlib import Path - THIS_DIR = Path(__file__).resolve().parent REPO_ROOT = THIS_DIR.parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) -from nstat.paper_example_catalog import main_for +from nstat.data_manager import ensure_example_data # noqa: E402 +from nstat.paper_examples_full import run_experiment3, run_experiment3b # noqa: E402 +from nstat.paper_figures import export_named_paper_figures # noqa: E402 + + +def run_example03(*, export_figures: bool = False, export_dir: Path | None = None): + """Run Example 03: PSTH and SSGLM dynamics. + + Analysis workflow (mirrors Matlab example03_psth_and_ssglm.m): + + Part A — PSTH: + 1. Define sinusoidal CIF: lambda(t) = exp(b0 + b1*cos(2*pi*f*t)). + 2. Simulate 20 spike trains via CIF thinning. + 3. Load real multi-trial data from PSTH/Results.mat. + 4. Compute histogram PSTH and GLM-PSTH; compare. + + Part B — SSGLM: + 5. Simulate 50-trial population with across-trial stimulus gain. + 6. Fit SSGLM via EM (forward-backward Kalman + Newton M-step). + 7. Plot per-trial coefficient trajectories and confidence bands. + 8. Generate 3-D stimulus-effect surface and learning-trial figure. + """ + data_dir = ensure_example_data(download=True) + + # --- Part A: PSTH analysis --- + summary3, payload3 = run_experiment3(return_payload=True) + + # --- Part B: SSGLM analysis --- + summary3b, payload3b = run_experiment3b(data_dir, return_payload=True) + + # Merge summaries for JSON output + combined_summary = { + "experiment3": summary3, + "experiment3b": summary3b, + } + print(json.dumps(combined_summary, indent=2)) + + if export_figures: + if export_dir is None: + export_dir = THIS_DIR / "figures" / "example03" + # Figure generation needs the combined dicts (multi-section example) + combined_payload = { + "experiment3": payload3, + "experiment3b": payload3b, + } + saved = export_named_paper_figures( + "example03", + summary=combined_summary, + payload=combined_payload, + export_dir=export_dir, + ) + print(f"\nGenerated {len(saved)} figure(s):") + for p in saved: + print(f" {p}") + + return combined_summary if __name__ == "__main__": - raise SystemExit(main_for("example03")) + parser = argparse.ArgumentParser( + description="Example 03: PSTH and SSGLM Dynamics" + ) + parser.add_argument("--repo-root", type=Path, default=REPO_ROOT) + parser.add_argument("--export-figures", action="store_true") + parser.add_argument("--export-dir", type=Path, default=None) + parser.add_argument("--output-json", type=Path, default=None) + args = parser.parse_args() + + result = run_example03( + export_figures=args.export_figures, + export_dir=args.export_dir, + ) + if args.output_json: + args.output_json.write_text(json.dumps(result, indent=2), encoding="utf-8") diff --git a/examples/paper/example04_place_cells_continuous_stimulus.py b/examples/paper/example04_place_cells_continuous_stimulus.py index 22f03770..f5a3cae0 100644 --- a/examples/paper/example04_place_cells_continuous_stimulus.py +++ b/examples/paper/example04_place_cells_continuous_stimulus.py @@ -1,16 +1,91 @@ +#!/usr/bin/env python3 +"""Example 04 — Place-Cell Receptive Fields (Gaussian vs Zernike). + +This example demonstrates: + 1) Loading hippocampal place-cell data from two animals. + 2) Visualising spike locations overlaid on the animal's path. + 3) Fitting Gaussian and Zernike polynomial receptive-field models. + 4) Comparing model families via KS, AIC, and BIC statistics. + 5) Generating 2-D heatmaps and 3-D mesh plots of place fields. + +Data provenance: + Uses ``data/PlaceCellDataAnimal1.mat`` and ``data/PlaceCellDataAnimal2.mat`` + (position trajectories + multi-neuron spike times). + +Expected outputs: + - Figure 1: Example cells — spike locations over path (4 cells per animal). + - Figure 2: Population model-comparison statistics (Delta-KS, Delta-AIC, Delta-BIC). + - Figure 3: Gaussian receptive-field heatmaps (Animal 1). + - Figure 4: Zernike receptive-field heatmaps (Animal 1). + - Figure 5: Gaussian receptive-field heatmaps (Animal 2). + - Figure 6: Zernike receptive-field heatmaps (Animal 2). + - Figure 7: 3-D mesh comparison for selected example cells. + +Paper mapping: + Section 2.3.4 (place-cell continuous-stimulus analysis). +""" from __future__ import annotations +import argparse +import json import sys from pathlib import Path - THIS_DIR = Path(__file__).resolve().parent REPO_ROOT = THIS_DIR.parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) -from nstat.paper_example_catalog import main_for +from nstat.data_manager import ensure_example_data # noqa: E402 +from nstat.paper_examples_full import run_experiment4 # noqa: E402 +from nstat.paper_figures import export_named_paper_figures # noqa: E402 + + +def run_example04(*, export_figures: bool = False, export_dir: Path | None = None): + """Run Example 04: Place-cell receptive fields. + + Analysis workflow (mirrors Matlab example04_place_cells_continuous_stimulus.m): + 1. Load PlaceCellDataAnimal1.mat and PlaceCellDataAnimal2.mat. + 2. For each animal, visualise 4 example neurons (spike locations on path). + 3. Load or compute precomputed fit results for all neurons. + 4. Compute per-neuron Delta-KS, Delta-AIC, Delta-BIC (Gaussian vs Zernike). + 5. Generate Gaussian receptive-field heatmaps for all neurons (both animals). + 6. Generate Zernike polynomial receptive-field heatmaps. + 7. Generate 3-D mesh comparison for selected example cells. + """ + data_dir = ensure_example_data(download=True) + + # Run analysis (returns summary statistics and figure payload) + summary, payload = run_experiment4(data_dir, return_payload=True) + + print(json.dumps(summary, indent=2)) + + if export_figures: + if export_dir is None: + export_dir = THIS_DIR / "figures" / "example04" + saved = export_named_paper_figures( + "example04", summary=summary, payload=payload, export_dir=export_dir + ) + print(f"\nGenerated {len(saved)} figure(s):") + for p in saved: + print(f" {p}") + + return summary if __name__ == "__main__": - raise SystemExit(main_for("example04")) + parser = argparse.ArgumentParser( + description="Example 04: Place-Cell Receptive Fields" + ) + parser.add_argument("--repo-root", type=Path, default=REPO_ROOT) + parser.add_argument("--export-figures", action="store_true") + parser.add_argument("--export-dir", type=Path, default=None) + parser.add_argument("--output-json", type=Path, default=None) + args = parser.parse_args() + + result = run_example04( + export_figures=args.export_figures, + export_dir=args.export_dir, + ) + if args.output_json: + args.output_json.write_text(json.dumps(result, indent=2), encoding="utf-8") diff --git a/examples/paper/example05_decoding_ppaf_pphf.py b/examples/paper/example05_decoding_ppaf_pphf.py index 58cd9113..60a37fad 100644 --- a/examples/paper/example05_decoding_ppaf_pphf.py +++ b/examples/paper/example05_decoding_ppaf_pphf.py @@ -1,16 +1,122 @@ +#!/usr/bin/env python3 +"""Example 05 — Stimulus Decoding With PPAF and PPHF. + +This example demonstrates: + 1) Univariate sinusoidal stimulus encoding and decoding via PPDecodeFilterLinear. + 2) 4-state arm-reach simulation with 20-cell population encoding. + 3) PPAF (Point-Process Adaptive Filter) decoding: free vs goal-informed. + 4) Hybrid filter (PPHybridFilterLinear) for joint discrete/continuous states. + +The example has three parts: + Part A (experiment5): Univariate sinusoidal stimulus — encode with 20 + neurons, decode with PPDecodeFilterLinear. + Part B (experiment5b): 4-state arm reaching — simulate 20-cell population, + compare PPAF vs PPAF+Goal across 20 simulations. + Part C (experiment6): Hybrid filter — simulate 40-cell population with + discrete reach states and continuous kinematics, decode with + PPHybridFilterLinear. + +Expected outputs: + - Figure 1: Univariate stimulus setup (CIF tuning curves, simulated spikes). + - Figure 2: Univariate decoding results (decoded stimulus vs true). + - Figure 3: Reach setup and population encoding. + - Figure 4: PPAF comparison (free vs goal-informed). + - Figure 5: Hybrid filter setup. + - Figure 6: Hybrid decoding summary. + +Paper mapping: + Section 2.5 (point-process adaptive filter) and Section 2.6 (hybrid filter). +""" from __future__ import annotations +import argparse +import json import sys from pathlib import Path - THIS_DIR = Path(__file__).resolve().parent REPO_ROOT = THIS_DIR.parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) -from nstat.paper_example_catalog import main_for +from nstat.paper_examples_full import ( # noqa: E402 + run_experiment5, + run_experiment5b, + run_experiment6, +) +from nstat.paper_figures import export_named_paper_figures # noqa: E402 + + +def run_example05(*, export_figures: bool = False, export_dir: Path | None = None): + """Run Example 05: PPAF and PPHF decoding. + + Analysis workflow (mirrors Matlab example05_decoding_ppaf_pphf.m): + + Part A — Univariate stimulus decoding: + 1. Define 20-cell population with sinusoidal tuning. + 2. Simulate spikes from sinusoidal stimulus CIF. + 3. Decode stimulus via PPDecodeFilterLinear. + + Part B — Arm-reach PPAF: + 4. Simulate 4-state reaching movements (position + velocity). + 5. Encode with 20-cell cosine-tuning population. + 6. Decode with PPAF (free) and PPAF+Goal; compare across 20 runs. + + Part C — Hybrid filter: + 7. Simulate 40-cell population with discrete reach-state modulation. + 8. Decode joint discrete/continuous state via PPHybridFilterLinear. + """ + # --- Part A: Univariate sinusoidal stimulus --- + summary5, payload5 = run_experiment5(return_payload=True) + + # --- Part B: Arm-reach PPAF --- + summary5b, payload5b = run_experiment5b(return_payload=True) + + # --- Part C: Hybrid filter --- + summary6, payload6 = run_experiment6(REPO_ROOT, return_payload=True) + + # Merge summaries for JSON output + combined_summary = { + "experiment5": summary5, + "experiment5b": summary5b, + "experiment6": summary6, + } + print(json.dumps(combined_summary, indent=2)) + + if export_figures: + if export_dir is None: + export_dir = THIS_DIR / "figures" / "example05" + combined_payload = { + "experiment5": payload5, + "experiment5b": payload5b, + "experiment6": payload6, + } + saved = export_named_paper_figures( + "example05", + summary=combined_summary, + payload=combined_payload, + export_dir=export_dir, + ) + print(f"\nGenerated {len(saved)} figure(s):") + for p in saved: + print(f" {p}") + + return combined_summary if __name__ == "__main__": - raise SystemExit(main_for("example05")) + parser = argparse.ArgumentParser( + description="Example 05: Stimulus Decoding With PPAF and PPHF" + ) + parser.add_argument("--repo-root", type=Path, default=REPO_ROOT) + parser.add_argument("--export-figures", action="store_true") + parser.add_argument("--export-dir", type=Path, default=None) + parser.add_argument("--output-json", type=Path, default=None) + args = parser.parse_args() + + result = run_example05( + export_figures=args.export_figures, + export_dir=args.export_dir, + ) + if args.output_json: + args.output_json.write_text(json.dumps(result, indent=2), encoding="utf-8") From ec75482ce18cab4a6b5a31bbf5334ea16feacbfc Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Mon, 9 Mar 2026 23:21:36 -0400 Subject: [PATCH 5/6] Align README with Matlab nSTAT repo structure Add overview paragraph describing toolbox capabilities, lab website links, license note (GPL v2), Figshare dataset DOI, PMID, and cross-reference to the Matlab repo. Co-Authored-By: Claude Opus 4.6 --- README.md | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 4d608cfa..7cf0d8d5 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,20 @@ # nSTAT-python -`nSTAT-python` is a Python toolbox for neural spike-train analysis, modeling, and decoding. +`nSTAT-python` is a Python port of the [nSTAT](https://github.com/cajigaslab/nSTAT) +open-source neural spike train analysis toolbox. It implements a range of models and +algorithms for neural spike train data analysis, with a focus on point-process +generalized linear models (GLMs), model fitting, model-order analysis, and adaptive +decoding. In addition to point-process algorithms, nSTAT also provides tools for +Gaussian signals — from correlation analysis to the Kalman filter — applicable to +continuous neural signals such as LFP, EEG, and ECoG. [![test-and-build](https://github.com/cajigaslab/nSTAT-python/actions/workflows/ci.yml/badge.svg)](https://github.com/cajigaslab/nSTAT-python/actions/workflows/ci.yml) [![pages](https://github.com/cajigaslab/nSTAT-python/actions/workflows/pages.yml/badge.svg)](https://github.com/cajigaslab/nSTAT-python/actions/workflows/pages.yml) +Lab websites: +- Neuroscience Statistics Research Laboratory: https://www.neurostat.mit.edu +- RESToRe Lab: https://www.med.upenn.edu/cajigaslab/ + ## Installation ```bash @@ -162,9 +172,21 @@ pytest -q sphinx-build -b html docs docs/_build ``` +## License + +nSTAT is protected by the GPL v2 Open Source License. + ## Cite -Cajigas, I., Malika, W. Q., & Brown, E. N. (2012). -nSTAT: Open-source neural spike train analysis toolbox for Matlab. -Journal of Neuroscience Methods, 211, 245–264. +If you use nSTAT in your work, please cite: + +Cajigas I, Malik WQ, Brown EN. nSTAT: Open-source neural spike train analysis +toolbox for Matlab. Journal of Neuroscience Methods 211: 245–264, Nov. 2012. https://doi.org/10.1016/j.jneumeth.2012.08.009 +PMID: 22981419 + +## Data and Related Repositories + +- **Matlab toolbox**: https://github.com/cajigaslab/nSTAT +- **Paper-example dataset (Figshare)**: https://doi.org/10.6084/m9.figshare.4834640.v3 +- **Paper DOI**: https://doi.org/10.1016/j.jneumeth.2012.08.009 From 20e166297b72e78af9ad3c5bd63bd21765b1942e Mon Sep 17 00:00:00 2001 From: Iahn Cajigas Date: Mon, 9 Mar 2026 23:34:24 -0400 Subject: [PATCH 6/6] Add SignalObj peak-finding methods with Matlab bug fixes Port findPeaks, findMaxima, findMinima, findGlobalPeak from Matlab SignalObj. Fixes two Matlab bugs: - findPeaks 'minima' branch now negates data before peak detection (Matlab original calls findpeaks on raw data, finding maxima instead) - findGlobalPeak minima branch uses correct variable name (Matlab has typo 'sOBj' instead of 'sObj') Co-Authored-By: Claude Opus 4.6 --- nstat/core.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/nstat/core.py b/nstat/core.py index 201f88bd..4856c68d 100644 --- a/nstat/core.py +++ b/nstat/core.py @@ -990,6 +990,83 @@ def sqrt(self) -> "SignalObj": """Element-wise square root (Matlab ``sqrt``).""" return self.power(0.5) + # ------------------------------------------------------------------ + # Peak-finding helpers (match Matlab SignalObj) + # ------------------------------------------------------------------ + def findPeaks( + self, + peak_type: str = "maxima", + minDistance: int | None = None, + ) -> tuple[list[np.ndarray], list[np.ndarray]]: + """Find local peaks in each signal dimension. + + Parameters + ---------- + peak_type : ``'maxima'`` or ``'minima'`` + minDistance : minimum sample distance between peaks (default: + ``sampleRate * duration / 10``). + + Returns + ------- + indices, values : lists of arrays (one per dimension). + + Note: The Matlab original has a bug where the ``'minima'`` branch + does not negate the data before calling ``findpeaks``. This Python + port fixes that. + """ + from scipy.signal import find_peaks as _find_peaks + + data = np.atleast_2d(self.data) + if data.shape[0] == 1: + data = data.T + N = data.shape[0] + if minDistance is None: + duration = float(self.maxTime - self.minTime) + minDistance = max(1, int(round(self.sampleRate * duration / 10))) + + all_indices: list[np.ndarray] = [] + all_values: list[np.ndarray] = [] + for col in range(data.shape[1]): + sig = data[:, col] + if peak_type == "minima": + sig = -sig + idx, _ = _find_peaks(sig, distance=minDistance) + all_indices.append(idx) + all_values.append(data[idx, col]) # always return actual values + return all_indices, all_values + + def findMaxima(self) -> tuple[list[np.ndarray], list[np.ndarray]]: + """Convenience wrapper: ``findPeaks('maxima')``.""" + return self.findPeaks("maxima") + + def findMinima(self) -> tuple[list[np.ndarray], list[np.ndarray]]: + """Convenience wrapper: ``findPeaks('minima')``.""" + return self.findPeaks("minima") + + def findGlobalPeak( + self, peak_type: str = "maxima" + ) -> tuple[np.ndarray, np.ndarray]: + """Find the global max or min across each dimension. + + Returns + ------- + times : 1-D array of times at which the peak occurs (one per dim). + values : 1-D array of peak values (one per dim). + + Note: The Matlab original has a typo (``sOBj`` instead of ``sObj``) + in the minima branch. This Python port fixes that. + """ + data = np.atleast_2d(self.data) + if data.shape[0] == 1: + data = data.T + if peak_type == "maxima": + idx = np.argmax(data, axis=0) + else: + idx = np.argmin(data, axis=0) + times = self.time[idx] + values = data[idx, np.arange(data.shape[1])] + return np.atleast_1d(times), np.atleast_1d(values) + # ------------------------------------------------------------------ # Cross-covariance (match Matlab SignalObj.xcov) # ------------------------------------------------------------------