Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
779 changes: 779 additions & 0 deletions PORTING_MAP.md

Large diffs are not rendered by default.

45 changes: 28 additions & 17 deletions examples/paper/example04_place_cells_continuous_stimulus.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,15 @@ def _load_animal_results(path, x, y, time, neurons):
return fit_results


def _compute_place_field(coeffs, grid_design, grid_shape):
"""Compute predicted firing rate on a spatial grid."""
def _compute_place_field(coeffs, grid_design, grid_shape, sample_rate=1.0):
"""Compute predicted firing rate on a spatial grid.

Matches Matlab ``FitResult.evalLambda`` which computes
``exp(X * b) * sampleRate`` to convert from conditional intensity
(per bin) to firing rate (Hz).
"""
eta = grid_design @ coeffs
rate = np.exp(eta)
rate = np.exp(eta) * sample_rate
return rate.reshape(grid_shape)


Expand Down Expand Up @@ -219,22 +224,26 @@ def run_example04(*, export_figures: bool = False, export_dir: Path | None = Non
summary1 = FitSummary(fitResults1)
summary2 = FitSummary(fitResults2)

# Delta statistics: Gaussian (index 0) minus Zernike (index 1)
dAIC1 = summary1.AIC[:, 0] - summary1.AIC[:, 1]
dBIC1 = summary1.BIC[:, 0] - summary1.BIC[:, 1]
# Delta statistics
# dKS: direct subtraction Gaussian - Zernike (matches Matlab line 81-83)
dKS1 = summary1.KSStats[:, 0] - summary1.KSStats[:, 1]

dAIC2 = summary2.AIC[:, 0] - summary2.AIC[:, 1]
dBIC2 = summary2.BIC[:, 0] - summary2.BIC[:, 1]
dKS2 = summary2.KSStats[:, 0] - summary2.KSStats[:, 1]

# dAIC/dBIC: Matlab uses getDiffAIC(1) / getDiffBIC(1) which computes
# Zernike - Gaussian (other columns minus reference column).
dAIC1 = summary1.AIC[:, 1] - summary1.AIC[:, 0]
dBIC1 = summary1.BIC[:, 1] - summary1.BIC[:, 0]

dAIC2 = summary2.AIC[:, 1] - summary2.AIC[:, 0]
dBIC2 = summary2.BIC[:, 1] - summary2.BIC[:, 0]

dAIC_all = np.concatenate([dAIC1, dAIC2])
dBIC_all = np.concatenate([dBIC1, dBIC2])
dKS_all = np.concatenate([dKS1, dKS2])

print(f" Mean dAIC (Gauss-Zern): {np.nanmean(dAIC_all):.2f}")
print(f" Mean dBIC (Gauss-Zern): {np.nanmean(dBIC_all):.2f}")
print(f" Mean dKS (Gauss-Zern): {np.nanmean(dKS_all):.4f}")
print(f" Mean dAIC (Zern-Gauss): {np.nanmean(dAIC_all):.2f}")
print(f" Mean dBIC (Zern-Gauss): {np.nanmean(dBIC_all):.2f}")

# ==================================================================
# Figure 1: Example cells — spike locations over path (2x2)
Expand Down Expand Up @@ -267,13 +276,13 @@ def run_example04(*, export_figures: bool = False, export_dir: Path | None = Non

axes2[1].boxplot([dAIC1[np.isfinite(dAIC1)], dAIC2[np.isfinite(dAIC2)]],
tick_labels=["Animal 1", "Animal 2"])
axes2[1].set_ylabel(r"$\Delta$AIC (Gaussian - Zernike)")
axes2[1].set_ylabel(r"$\Delta$AIC (Zernike - Gaussian)")
axes2[1].set_title("AIC Difference")
axes2[1].axhline(0, color="gray", linestyle="--", linewidth=0.5)

axes2[2].boxplot([dBIC1[np.isfinite(dBIC1)], dBIC2[np.isfinite(dBIC2)]],
tick_labels=["Animal 1", "Animal 2"])
axes2[2].set_ylabel(r"$\Delta$BIC (Gaussian - Zernike)")
axes2[2].set_ylabel(r"$\Delta$BIC (Zernike - Gaussian)")
axes2[2].set_title("BIC Difference")
axes2[2].axhline(0, color="gray", linestyle="--", linewidth=0.5)

Expand Down Expand Up @@ -316,13 +325,14 @@ def _plot_heatmaps(fit_results, nCells, title_prefix, design_gauss,
for i in range(nCells):
row, col = divmod(i, nCols)
fr = fit_results[i]
sr = float(fr.lambda_signal.sampleRate)
coeffs_g = np.asarray(fr.b[0], dtype=float).ravel()
coeffs_z = np.asarray(fr.b[1], dtype=float).ravel() if fr.numResults > 1 else coeffs_g

# Gaussian field
ax = axesG[row, col]
try:
field_g = _compute_place_field(coeffs_g, design_gauss[:, :coeffs_g.size], grid_shape)
field_g = _compute_place_field(coeffs_g, design_gauss[:, :coeffs_g.size], grid_shape, sr)
ax.pcolormesh(xx, yy, field_g, shading="gouraud", cmap="jet")
except Exception:
pass
Expand All @@ -334,7 +344,7 @@ def _plot_heatmaps(fit_results, nCells, title_prefix, design_gauss,
# Zernike field
ax = axesZ[row, col]
try:
field_z = _compute_place_field(coeffs_z, design_zern[:, :coeffs_z.size], grid_shape)
field_z = _compute_place_field(coeffs_z, design_zern[:, :coeffs_z.size], grid_shape, sr)
ax.pcolormesh(xx, yy, field_z, shading="gouraud", cmap="jet")
except Exception:
pass
Expand Down Expand Up @@ -372,13 +382,14 @@ def _plot_heatmaps(fit_results, nCells, title_prefix, design_gauss,
# ==================================================================
exampleCell = min(24, nCells1 - 1) # 0-indexed → cell 25 in Matlab
fr_ex = fitResults1[exampleCell]
sr_ex = float(fr_ex.lambda_signal.sampleRate)
coeffs_g = np.asarray(fr_ex.b[0], dtype=float).ravel()
coeffs_z = np.asarray(fr_ex.b[1], dtype=float).ravel()

field_g = _compute_place_field(
coeffs_g, gridDesignGauss[:, :coeffs_g.size], xx.shape)
coeffs_g, gridDesignGauss[:, :coeffs_g.size], xx.shape, sr_ex)
field_z = _compute_place_field(
coeffs_z, gridDesignZern[:, :coeffs_z.size], xx.shape)
coeffs_z, gridDesignZern[:, :coeffs_z.size], xx.shape, sr_ex)

fig7 = plt.figure(figsize=(12, 8))
ax3d = fig7.add_subplot(111, projection="3d")
Expand Down
Loading
Loading