Merged
Binary file modified docs/figures/example01/fig01_constant_mg_summary.png
Binary file modified docs/figures/example01/fig02_washout_raster_overview.png
Binary file modified docs/figures/example01/fig03_piecewise_baseline_comparison.png
Binary file modified docs/figures/example02/fig01_data_overview.png
Binary file modified docs/figures/example02/fig02_lag_and_model_comparison.png
Binary file modified docs/figures/example03/fig01_simulated_and_real_rasters.png
Binary file modified docs/figures/example03/fig02_psth_comparison.png
Binary file modified docs/figures/example03/fig03_ssglm_simulation_summary.png
Binary file modified docs/figures/example03/fig04_ssglm_fit_diagnostics.png
Binary file modified docs/figures/example03/fig05_stimulus_effect_surfaces.png
Binary file modified docs/figures/example03/fig06_learning_trial_comparison.png
Binary file modified docs/figures/example04/fig01_example_cells_path_overlay.png
Binary file modified docs/figures/example04/fig02_model_summary_statistics.png
Binary file modified docs/figures/example04/fig03_gaussian_place_fields_animal1.png
Binary file modified docs/figures/example04/fig04_zernike_place_fields_animal1.png
Binary file modified docs/figures/example04/fig05_gaussian_place_fields_animal2.png
Binary file modified docs/figures/example04/fig06_zernike_place_fields_animal2.png
Binary file modified docs/figures/example04/fig07_example_cell_mesh_comparison.png
Binary file modified docs/figures/example05/fig01_univariate_setup.png
Binary file modified docs/figures/example05/fig02_univariate_decoding.png
Binary file modified docs/figures/example05/fig03_reach_and_population_setup.png
Binary file modified docs/figures/example05/fig04_ppaf_goal_vs_free.png
Binary file modified docs/figures/example05/fig05_hybrid_setup.png
Binary file modified docs/figures/example05/fig06_hybrid_decoding_summary.png
27 changes: 19 additions & 8 deletions examples/paper/example01_mepsc_poisson.py
@@ -62,6 +62,18 @@ def _load_mepsc_times_seconds(path: Path) -> np.ndarray:
return times_ms / 1000.0


def _matlab_colon(start: float, step: float, stop: float) -> np.ndarray:
"""Replicate MATLAB ``start:step:stop`` exactly.

``np.arange`` accumulates floating-point error over many steps and can
produce off-by-one length mismatches. This function computes the element
count first (like MATLAB's colon operator), then multiplies by integer
indices — giving bit-exact parity.
"""
n = int(np.floor((stop - start) / step)) + 1
return start + np.arange(n) * step


# =========================================================================
# Helper: export figure
# =========================================================================
@@ -102,7 +114,7 @@ def run_example01(*, export_figures: bool = False, export_dir: Path | None = Non

# Create spike train and time vector
nstConst = nspikeTrain(epsc2)
timeConst = np.arange(0, nstConst.maxTime + 1.0 / sampleRate, 1.0 / sampleRate)
timeConst = _matlab_colon(0, 1.0 / sampleRate, nstConst.maxTime)

# Create baseline covariate
baseline = Covariate(
@@ -150,7 +162,7 @@ def run_example01(*, export_figures: bool = False, export_dir: Path | None = Non
spikeTimes1 = 260.0 + washout1
spikeTimes2 = np.sort(washout2) + 745.0
nstWashout = nspikeTrain(np.concatenate([spikeTimes1, spikeTimes2]))
timeWashout = np.arange(260.0, nstWashout.maxTime + 1.0 / sampleRate, 1.0 / sampleRate)
timeWashout = _matlab_colon(260.0, 1.0 / sampleRate, nstWashout.maxTime)

# --- Figure 2: Constant vs Decreasing Mg2+ rasters ---
fig2, axes2 = plt.subplots(2, 1, figsize=(14, 9))
@@ -179,12 +191,11 @@ def run_example01(*, export_figures: bool = False, export_dir: Path | None = Non
print("\n=== Part 3: Piecewise Baseline Model Comparison ===")

# Build piecewise indicator covariates
# Matlab: find(time<495,1,'last') — last index strictly before 495
# np.searchsorted gives first index >= 495, so subtract 1 isn't needed
# because Python slice [:idx] is exclusive. But Matlab's 1:timeInd1 is
# inclusive, so we need searchsorted(..., side='right') to include 495.
timeInd1 = np.searchsorted(timeWashout, 495.0, side="right")
timeInd2 = np.searchsorted(timeWashout, 765.0, side="right")
# Matlab: timeInd1 = find(time < 495, 1, 'last') → last 1-based index < 495
# Equivalent Python: first 0-based index >= 495 (searchsorted side='left'),
# so rate1[:idx] covers [260, 494.999] and rate2[idx:] starts at 495.
timeInd1 = np.searchsorted(timeWashout, 495.0, side="left")
timeInd2 = np.searchsorted(timeWashout, 765.0, side="left")
N = len(timeWashout)

constantRate = np.ones((N, 1))
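The `_matlab_colon` helper added in this file guards against `np.arange` endpoint drift. A minimal self-contained sketch of why the element-count-first approach matters (the helper is re-declared here so the snippet runs on its own):

```python
import numpy as np

def matlab_colon(start: float, step: float, stop: float) -> np.ndarray:
    # Count elements first (as MATLAB's colon operator does), then scale
    # integer indices: the step error is never accumulated across samples.
    n = int(np.floor((stop - start) / step)) + 1
    return start + np.arange(n) * step

# 0:0.001:2.5 has exactly 2501 samples in MATLAB; matlab_colon agrees,
# whereas np.arange(0, 2.5 + 0.001, 0.001) can gain or lose a sample.
t = matlab_colon(0.0, 1.0 / 1000.0, 2.5)
assert t.size == 2501
assert abs(t[-1] - 2.5) < 1e-9
```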
4 changes: 3 additions & 1 deletion examples/paper/example02_whisker_stimulus_thalamus.py
@@ -280,8 +280,10 @@ def run_example02(*, export_figures: bool = False, export_dir: Path | None = Non
windowIndex = ksIdx

# Extract selected history windows
# windowIndex here is 0-based; the MATLAB code's windowTimes(1:windowIndex) uses
# the corresponding 1-based index and keeps every element up to and including
# that position, so the equivalent Python slice is [:windowIndex + 1].
if windowIndex > 1:
selectedHistory = list(windowTimes[:windowIndex])
selectedHistory = list(windowTimes[:windowIndex + 1])
else:
selectedHistory = []

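The off-by-one fixed above is the standard 1-based/0-based slice translation. A toy illustration (the window values are made up):

```python
# MATLAB: windowTimes(1:k) keeps k elements (1-based, inclusive upper bound).
# With a 0-based windowIndex = k - 1, the matching slice is [:windowIndex + 1];
# the original [:windowIndex] silently dropped the last selected window.
windowTimes = [0.002, 0.004, 0.008, 0.016, 0.032]  # hypothetical history windows
windowIndex = 2                                     # 0-based selection
selected = windowTimes[:windowIndex + 1]
assert selected == [0.002, 0.004, 0.008]            # 3 elements, like MATLAB 1:3
```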
4 changes: 3 additions & 1 deletion examples/paper/example03_psth_and_ssglm.py
@@ -424,7 +424,9 @@ def run_part_b(data_dir, export_dir=None):
spikeColl.resample(1 / delta)
spikeColl.setMaxTime(tmax)

dN = spikeColl.dataToMatrix()
# MATLAB: dN = spikeColl.dataToMatrix' → (K, T)
# Python dataToMatrix() returns (T, K), so transpose to match.
dN = spikeColl.dataToMatrix().T # (K, T)
if dN.ndim == 1:
dN = dN.reshape(1, -1)
dN = np.asarray(dN, dtype=float)
16 changes: 13 additions & 3 deletions nstat/analysis.py
@@ -253,8 +253,12 @@ def GLMFit(
lambda_time = np.asarray(tObj.getCov(1).time, dtype=float).reshape(-1)
sample_rate = float(tObj.sampleRate)
dt = 1.0 / max(sample_rate, 1e-12)
bin_edges = np.concatenate([lambda_time, [lambda_time[-1] + dt]])
y = np.asarray(tObj.nspikeColl.getNST(index).to_binned_counts(bin_edges), dtype=float).reshape(-1)
# Use getSpikeVector (via getSigRep) to match MATLAB's GLMFit,
# which calls tObj.getSpikeVector(neuronIndex). The alternative
# to_binned_counts uses np.histogram bin edges that can assign
# spikes to adjacent bins when spike times fall on floating-point
# boundary values, causing small but systematic deviance offsets.
y = np.asarray(tObj.getSpikeVector(index), dtype=float).reshape(-1)

n_obs = min(x.shape[0], y.shape[0], lambda_time.shape[0])
x = x[:n_obs, :]
@@ -472,7 +476,13 @@ def run_analysis_for_neuron(
)
# MATLAB returns fits with KS diagnostics already populated, and
# downstream summary classes read those cached fields directly.
fit_result.computeKSStats()
# Compute KS stats for ALL fits (not just fit 1) so that history
# sweeps and multi-model comparisons have correct KS statistics.
for _fit_i in range(1, fit_result.numResults + 1):
try:
fit_result.computeKSStats(fit_num=_fit_i)
except Exception:
pass # some configs may fail KS (e.g. degenerate lambda)

# Compute the conditional intensity on validation data when a
# validation partition is present (mirrors Matlab behaviour).
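The binning caveat cited in the `GLMFit` change is easy to reproduce: when bin edges are built by accumulating a float step, a spike that should sit exactly on a boundary can land in the neighbouring bin.

```python
import numpy as np

# np.arange accumulates the float step, so the "0.3" edge is not exactly 0.3.
edges = np.arange(0.0, 0.4, 0.1)
assert edges[-1] != 0.3            # it is 0.30000000000000004

# A spike at exactly t = 0.3 therefore falls in the [0.2, 0.300...04) bin,
# one bin earlier than intended — a small but systematic count shift.
counts, _ = np.histogram([0.3], bins=edges)
assert counts.tolist() == [0, 0, 1]
```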
2 changes: 1 addition & 1 deletion nstat/core.py
@@ -1513,7 +1513,7 @@ def plotVariability(self, selectorArray=None, ax=None):
# Cross-covariance (match Matlab SignalObj.xcov)
# ------------------------------------------------------------------
def xcov(self, other: "SignalObj | None" = None, maxlag: int | None = None,
scaleOpt: str = "biased") -> "SignalObj":
scaleOpt: str = "none") -> "SignalObj":
"""Cross-covariance (mean-removed xcorr). Matches Matlab ``xcov``.

When called with no *other* argument (auto-covariance), only
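The `scaleOpt` default change matters because MATLAB's `xcov` defaults to `'none'` (raw sums of lagged products), not `'biased'` (divide by N). A 1-D sketch of the two scalings; note `np.correlate`'s lag ordering differs from MATLAB's, so only the scaling relationship is shown here:

```python
import numpy as np

def xcov_1d(x, y, scale="none"):
    # Mean-removed cross-correlation; 'biased' divides the raw sums by N.
    xm = np.asarray(x, float) - np.mean(x)
    ym = np.asarray(y, float) - np.mean(y)
    c = np.correlate(xm, ym, mode="full")
    return c / len(xm) if scale == "biased" else c

x = [1.0, 2.0, 3.0, 4.0]
raw = xcov_1d(x, x)                       # MATLAB default: 'none'
biased = xcov_1d(x, x, scale="biased")
assert np.allclose(biased, raw / 4.0)
assert np.isclose(raw.max(), 5.0)         # zero-lag sum of squared deviations
```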
33 changes: 29 additions & 4 deletions nstat/decoding_algorithms.py
@@ -1507,13 +1507,32 @@ def _ppdecode_filter_linear(
W_p = np.zeros((ns, ns, N + 1), dtype=float)
W_u = np.zeros((ns, ns, N), dtype=float)

# Fuse initial state with backward target information
# (Srinivasan et al. 2006 — prior update step)
if np.linalg.det(Pi0_mat) != 0:
invPi0 = np.linalg.pinv(Pi0_mat)
invPitT_0_fuse = np.linalg.pinv(PitT[:, :, 0])
Pi0New = np.linalg.pinv(invPi0 + invPitT_0_fuse)
Pi0New = np.where(np.isnan(Pi0New), 0.0, Pi0New)
x0New = Pi0New @ (invPi0 @ x0_vec + invPitT_0_fuse @ PhitT[:, :, 0] @ yT_vec)
x0_vec = x0New
Pi0_mat = Pi0New

# Initial predict with target correction
# NOTE: MATLAB uses n=N (leftover from ft loop) for the initial
# PPDecode_predict call, so Amat(:,:,min(N,N)) = B(:,:,N) and
# Qmat(:,:,min(N)) = QT(:,:,N). We replicate this for parity.
invPitT_0 = np.linalg.pinv(PitT[:, :, 0])
invA1 = np.linalg.pinv(A1)
invPhi0T = np.linalg.pinv(invA1 @ PhitT[:, :, 0])
ut[:, 0] = (Q1 @ invPitT_0) @ PhitT[:, :, 0] @ (yT_vec - invPhi0T @ x0_vec)
x_p[:, 0] = Amat[:, :, 0] @ x0_vec + ut[:, 0]
W_p[:, :, 0] = Amat[:, :, 0] @ Pi0_mat @ Amat[:, :, 0].T + Qmat_arr[:, :, 0]
x_p[:, 0], W_p[:, :, 0] = DecodingAlgorithms.PPDecode_predict(
x0_vec, Pi0_mat,
Amat[:, :, N - 1],
Qmat_arr[:, :, N - 1],
)
x_p[:, 0] += ut[:, 0]
W_p[:, :, 0] += (Q1 @ invPitT_0) @ A1 @ Pi0_mat @ A1.T @ (Q1 @ invPitT_0).T

for time_index in range(1, N + 1):
x_u[:, time_index - 1], W_u[:, :, time_index - 1], _ = DecodingAlgorithms.PPDecode_updateLinear(
@@ -1536,8 +1555,14 @@
ut[:, time_index] = (Qn @ invPitT_n1) @ PhitT[:, :, time_index] @ (
yT_vec - invPhitm1T @ x_u[:, time_index - 1]
)
A_t = Amat[:, :, min(time_index - 1, N - 1)]
Q_t = Qmat_arr[:, :, min(time_index - 1, N - 1)]
# MATLAB PPDecode_predict in non-augmented target branch
# uses Amat(:,:,min(size(A,3),n)) and Qmat(:,:,min(size(Qmat,3))).
# size(A,3) = number of A pages (1 if time-invariant), so
# min(1,n) = 1 → always B[:,:,0].
# min(size(Qmat,3)) = min(N) = N → always QT[:,:,N-1].
A_dim3 = A.shape[2] if A.ndim == 3 else 1
A_t = Amat[:, :, min(A_dim3 - 1, time_index - 1)]
Q_t = Qmat_arr[:, :, N - 1]
x_p[:, time_index], W_p[:, :, time_index] = DecodingAlgorithms.PPDecode_predict(
x_u[:, time_index - 1],
W_u[:, :, time_index - 1],
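MATLAB's clamped page index `A(:,:,min(size(A,3),n))`, which the change above replicates, reduces to "always the single page" for a time-invariant model. A small sketch of the translation (with `n` 1-based, as in MATLAB):

```python
import numpy as np

def page(A: np.ndarray, n: int) -> np.ndarray:
    # MATLAB A(:,:,min(size(A,3), n)) with a 1-based page index n:
    # clamp to the last page, so a one-page (time-invariant) array
    # always yields that single page regardless of n.
    A3 = A if A.ndim == 3 else A[:, :, None]
    return A3[:, :, min(A3.shape[2], n) - 1]

A_ti = np.eye(2)[:, :, None]                                 # one page
A_tv = np.stack([np.eye(2) * k for k in (1, 2, 3)], axis=2)  # three pages

assert np.array_equal(page(A_ti, 7), np.eye(2))       # clamped to page 1
assert np.array_equal(page(A_tv, 2), 2 * np.eye(2))   # ordinary paged access
assert np.array_equal(page(A_tv, 9), 3 * np.eye(2))   # clamped to last page
```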
83 changes: 69 additions & 14 deletions nstat/fit.py
@@ -915,6 +915,9 @@ def _compute_diagnostics(self, fit_num: int = 1, *, dt_correction: int = 1) -> d
ideal = np.asarray(xAxis[:, 0], dtype=float).reshape(-1) if np.asarray(xAxis).size else np.asarray([], dtype=float)
empirical = np.asarray(KSSorted[:, 0], dtype=float).reshape(-1) if np.asarray(KSSorted).size else np.asarray([], dtype=float)
ci = np.full(ideal.size, 1.36 / np.sqrt(float(ideal.size)), dtype=float) if ideal.size else np.asarray([], dtype=float)
# MATLAB's setKSStats (FitResult.m:1434) recomputes the KS stat
# via kstest2(xAxis, KSSorted) — a two-sample KS test. The
# curve-level max deviation is kept separately for plotting.
ks_curve_stat = float(np.max(np.abs(empirical - ideal))) if ideal.size else 1.0
if ideal.size:
different, ks_pvalue, ks_stat = _matlab_kstest2(ideal, empirical)
@@ -962,7 +965,20 @@ def _compute_diagnostics(self, fit_num: int = 1, *, dt_correction: int = 1) -> d
"coeff_labels": np.asarray(coeff_labels, dtype=object),
}
self._diagnostic_cache[fit_num] = diagnostics
self.setKSStats(z, uniforms, ideal, empirical, np.asarray([ks_stat], dtype=float))
# Write KS stat to the correct index (fit_num is 1-based).
# We avoid calling setKSStats here because it overwrites the
# multi-column Z/U/KSXAxis/KSSorted arrays and always writes
# the ks_stat scalar to index 0. Instead, write directly to
# the correct row so that multi-fit sweeps accumulate properly.
idx = fit_num - 1
if idx < self.KSStats.shape[0]:
self.KSStats[idx, 0] = ks_stat
# For the last fit, store Z/U/etc. so legacy callers that
# expect those arrays still see something useful.
self.Z = np.asarray(z, dtype=float)[:, None] if z.size else np.array([], dtype=float)
self.U = np.asarray(uniforms, dtype=float)[:, None] if uniforms.size else np.array([], dtype=float)
self.KSXAxis = np.asarray(ideal, dtype=float)[:, None] if ideal.size else np.array([], dtype=float)
self.KSSorted = np.asarray(empirical, dtype=float)[:, None] if empirical.size else np.array([], dtype=float)
self.KSPvalues[fit_num - 1] = ks_pvalue
self.withinConfInt[fit_num - 1] = within
self.X = gaussianized
@@ -1130,7 +1146,7 @@ def plotResults(self, fit_num: int = 1, handle=None):
ax_co = fig.add_subplot(gs[1, 0:2])
ax_re = fig.add_subplot(gs[1, 2:4])

self.KSPlot(fit_num=fit_num, handle=ax_ks)
self.KSPlot(fit_num=None, handle=ax_ks)
# Add neuron number label (matching Matlab)
ax_ks.text(
0.45, 0.95, f"Neuron: {self.neuronNumber}",
Expand All @@ -1144,23 +1160,62 @@ def plotResults(self, fit_num: int = 1, handle=None):
fig.tight_layout()
return fig

def KSPlot(self, fit_num: int = 1, handle=None):
"""KS goodness-of-fit plot with 95 % confidence bands (Matlab ``KSPlot``)."""
diag = self._compute_diagnostics(fit_num)
# MATLAB color cycle used by Analysis.colors: b, g, r, c, m, y, k
_MATLAB_KS_COLORS = ["tab:blue", "tab:green", "tab:red", "tab:cyan", "tab:purple", "tab:olive", "k"]

def KSPlot(self, fit_num: int | list[int] | None = None, handle=None):
"""KS goodness-of-fit plot with 95 % confidence bands (Matlab ``KSPlot``).

Parameters
----------
fit_num : int, list of int, or None
Which model(s) to plot. ``None`` (default) plots all models
(``1:numResults``), matching the MATLAB default behaviour.
A single int plots one model; a list plots the specified subset.
handle : matplotlib Axes, optional
Axes to draw on. A new figure is created when *None*.
"""
if fit_num is None:
fit_nums = list(range(1, self.numResults + 1))
elif isinstance(fit_num, int):
fit_nums = [fit_num]
else:
fit_nums = list(fit_num)

ax = handle if handle is not None else plt.subplots(1, 1, figsize=(5.0, 4.0))[1]
ideal = np.asarray(diag["ks_ideal"], dtype=float)
empirical = np.asarray(diag["ks_empirical"], dtype=float)
ci = np.asarray(diag["ks_ci"], dtype=float)
if ideal.size:
ax.plot(ideal, empirical, color="tab:blue", linewidth=1.5)
ax.plot([0.0, 1.0], [0.0, 1.0], color="0.3", linewidth=1.0, linestyle="--")
ax.plot(ideal, np.clip(ideal + ci, 0.0, 1.0), color="tab:red", linewidth=1.0)
ax.plot(ideal, np.clip(ideal - ci, 0.0, 1.0), color="tab:red", linewidth=1.0)

# Draw reference diagonal and confidence bands from the first model
first_diag = self._compute_diagnostics(fit_nums[0])
ideal_ref = np.asarray(first_diag["ks_ideal"], dtype=float)
ci_ref = np.asarray(first_diag["ks_ci"], dtype=float)
if ideal_ref.size:
ax.plot([0.0, 1.0], [0.0, 1.0], color="0.3", linewidth=1.0, linestyle="-.")
ax.plot(ideal_ref, np.clip(ideal_ref + ci_ref, 0.0, 1.0), color="tab:red", linewidth=1.0)
ax.plot(ideal_ref, np.clip(ideal_ref - ci_ref, 0.0, 1.0), color="tab:red", linewidth=1.0)

# Plot each model's empirical CDF (matching MATLAB colour cycle)
labels_for_legend: list[str] = []
handles_for_legend: list[object] = []
data_labels = list(self.lambda_signal.dataLabels) if getattr(self.lambda_signal, "dataLabels", None) else []
for i, fn in enumerate(fit_nums):
diag = self._compute_diagnostics(fn)
ideal = np.asarray(diag["ks_ideal"], dtype=float)
empirical = np.asarray(diag["ks_empirical"], dtype=float)
color = self._MATLAB_KS_COLORS[i % len(self._MATLAB_KS_COLORS)]
label = data_labels[fn - 1] if fn - 1 < len(data_labels) else f"Model {fn}"
if ideal.size:
h, = ax.plot(ideal, empirical, color=color, linewidth=2.0)
handles_for_legend.append(h)
labels_for_legend.append(label)

if handles_for_legend:
ax.legend(handles_for_legend, labels_for_legend, loc="lower right", fontsize=10)

ax.set_xlim(0.0, 1.0)
ax.set_ylim(0.0, 1.0)
ax.set_xlabel("Ideal Uniform CDF")
ax.set_ylabel("Empirical CDF")
ax.set_title("KS Plot")
ax.set_title("KS Plot of Rescaled ISIs\nwith 95% Confidence Intervals", fontweight="bold", fontsize=11)
return ax

def plotResidual(self, fit_num: int = 1, handle=None):
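The KS machinery in `_compute_diagnostics`/`KSPlot` follows the time-rescaling theorem: rescaled ISIs are Exp(1) under a correct model, so their CDF transform is Uniform(0, 1). A deterministic sketch using exact exponential quantiles, so the curve-level KS statistic is zero by construction:

```python
import numpy as np

n = 500
ideal = (np.arange(1, n + 1) - 0.5) / n   # uniform quantiles on (0, 1)
z = -np.log(1.0 - ideal)                  # exact Exp(1) quantiles ("rescaled ISIs")
u = np.sort(1.0 - np.exp(-z))             # back to uniform via the Exp(1) CDF
ks_stat = np.max(np.abs(u - ideal))
band = 1.36 / np.sqrt(n)                  # 95% confidence band half-width

assert ks_stat < 1e-12                    # perfect model: curve on the diagonal
assert np.isclose(band, 0.0608, atol=1e-3)
```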
58 changes: 52 additions & 6 deletions nstat/trial.py
@@ -1379,6 +1379,7 @@ def psthGLM(self, binwidth: float, windowTimes=None, fitType: str = "poisson",
"""
from .analysis import Analysis
from .confidence_interval import ConfidenceInterval
from .glm import fit_poisson_glm

basis = self.generateUnitImpulseBasis(
float(binwidth), float(self.minTime), float(self.maxTime), float(self.sampleRate)
@@ -1393,13 +1394,58 @@
cfg.setName("GLM-PSTH+Hist" if np.asarray(hist).size else "GLM-PSTH")
cfgColl = ConfigCollection([cfg])
algorithm = "GLM" if str(fitType or "poisson").lower() == "poisson" else "BNLRCG"
psth_result = Analysis.RunAnalysisForAllNeurons(trial, cfgColl, 0, algorithm, [], 1)
fit = psth_result[0] if isinstance(psth_result, list) else psth_result

# Extract coefficients and standard errors
coeffs_all, _labels, se_all = fit.getCoeffsWithLabels(1)
raw_coeffs = np.asarray(coeffs_all, dtype=float).reshape(-1)
se_vec = np.asarray(se_all, dtype=float).reshape(-1)
# ---- Matlab batchMode=1: concatenate Y and X across ALL trials ----
# Matlab nstColl.psthGLM (line 1003-1004) calls
# RunAnalysisForAllNeurons(trial, cfgColl, 0, Algorithm, [], 1)
# with batchMode=1, which pools all trials of the same neuron into
# a single GLM fit. Python's RunAnalysisForAllNeurons previously
# ignored batchMode, fitting each trial separately — producing
# single-trial coefficients instead of across-trial pooled ones.
cfgColl.setConfig(trial, 1)
stacked_x: list[np.ndarray] = []
stacked_y: list[np.ndarray] = []
for idx in range(1, trial.nspikeColl.num_spike_trains + 1):
x_i = np.asarray(trial.getDesignMatrix(idx), dtype=float)
y_i = np.asarray(trial.getSpikeVector(idx), dtype=float).reshape(-1)
n_obs = min(x_i.shape[0], y_i.shape[0])
stacked_x.append(x_i[:n_obs])
stacked_y.append(y_i[:n_obs])
X = np.vstack(stacked_x)
y = np.concatenate(stacked_y)

if algorithm == "GLM":
glm_res = fit_poisson_glm(X, y, include_intercept=False)
raw_coeffs = np.asarray(glm_res.coefficients, dtype=float).reshape(-1)
lambda_hat = glm_res.predict_rate(X)
W = np.maximum(lambda_hat, 1e-12)
else:
from .glm import fit_binomial_glm
glm_res = fit_binomial_glm(X, y, include_intercept=False)
raw_coeffs = np.asarray(glm_res.coefficients, dtype=float).reshape(-1)
lambda_hat = np.clip(glm_res.predict_probability(X), 1e-12, 1.0 - 1e-9)
W = lambda_hat * (1.0 - lambda_hat)
W = np.maximum(W, 1e-12)

# Standard errors from Fisher information (Hessian inverse)
try:
XtWX = X.T @ (X * W[:, None]) + 1e-6 * np.eye(X.shape[1])
covb = np.linalg.inv(XtWX)
se_vec = np.sqrt(np.maximum(np.diag(covb), 0.0))
except np.linalg.LinAlgError:
se_vec = np.full(raw_coeffs.size, np.nan, dtype=float)

# Build a proper FitResult for the third return value by fitting just
# the first spike train (fast), then override its coefficients with
# the batch-fit values.
fit = Analysis.RunAnalysisForNeuron(trial, 1, cfgColl, 0, algorithm)
if isinstance(fit, list):
fit = fit[0]
# Override with batch-fit coefficients and standard errors
fit.b[0] = raw_coeffs.copy()
if fit.stats and isinstance(fit.stats[0], dict):
fit.stats[0]["se"] = se_vec.copy()

numBasis = basis.dimension

if raw_coeffs.size < numBasis:
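The standard-error computation added in `psthGLM` uses the Poisson GLM Fisher information `X' W X` with `W = diag(lambda_hat)`. A self-contained sketch of that step (the design matrix and coefficients below are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
X = np.column_stack([np.ones(200), rng.normal(size=200)])  # toy design matrix
beta = np.array([0.5, -0.3])                               # hypothetical coefficients
lam = np.exp(X @ beta)                                     # Poisson rates, log link

# Fisher information for a Poisson GLM: I(beta) = X' diag(lam) X.
# Standard errors are the square roots of the inverse's diagonal.
XtWX = X.T @ (X * lam[:, None])
covb = np.linalg.inv(XtWX)
se = np.sqrt(np.diag(covb))

assert se.shape == (2,)
assert np.all(se > 0)
```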