diff --git a/docs/figures/example02/fig01_data_overview.png b/docs/figures/example02/fig01_data_overview.png index f6aede37..968f5b40 100644 Binary files a/docs/figures/example02/fig01_data_overview.png and b/docs/figures/example02/fig01_data_overview.png differ diff --git a/docs/figures/example02/fig02_lag_and_model_comparison.png b/docs/figures/example02/fig02_lag_and_model_comparison.png index f43947ff..d0c44b43 100644 Binary files a/docs/figures/example02/fig02_lag_and_model_comparison.png and b/docs/figures/example02/fig02_lag_and_model_comparison.png differ diff --git a/docs/figures/example03/fig05_stimulus_effect_surfaces.png b/docs/figures/example03/fig05_stimulus_effect_surfaces.png index 24274b27..4e106d80 100644 Binary files a/docs/figures/example03/fig05_stimulus_effect_surfaces.png and b/docs/figures/example03/fig05_stimulus_effect_surfaces.png differ diff --git a/examples/paper/example01_mepsc_poisson.py b/examples/paper/example01_mepsc_poisson.py index a74456f4..885cb484 100644 --- a/examples/paper/example01_mepsc_poisson.py +++ b/examples/paper/example01_mepsc_poisson.py @@ -83,7 +83,7 @@ def _maybe_export(fig, export_dir: Path | None, name: str, dpi: int = 250): if export_dir is not None: export_dir.mkdir(parents=True, exist_ok=True) png_path = export_dir / f"{name}.png" - fig.savefig(png_path, dpi=dpi, bbox_inches="tight") + fig.savefig(png_path, dpi=dpi, facecolor="w", edgecolor="none") saved.append(png_path) print(f" Saved {png_path}") return saved @@ -149,8 +149,8 @@ def run_example01(*, export_figures: bool = False, export_dir: Path | None = Non spikeCollConst.plot(handle=ax) ax.set_title("Neural Raster with constant Mg$^{2+}$ Concentration", fontweight="bold", fontsize=12) - ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") - ax.set_ylabel("mEPSCs", fontsize=12, fontweight="bold") + ax.set_xlabel("time [s]", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_ylabel("mEPSCs", fontname="Arial", fontsize=12, fontweight="bold") ax.set_yticks([0, 1]) # (2,2,2): Inverse Gaussian transform (ACF) @@ -165,13 +165,10 @@ def run_example01(*, export_figures: bool = False, export_dir: Path | None = Non ax.plot(np.asarray(lam.time, dtype=float), np.asarray(lam.data[:, 0], dtype=float), "b", linewidth=2) - ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") - ax.set_ylabel(lam.ylabel if hasattr(lam, "ylabel") else "spikes/sec", - fontsize=12, fontweight="bold") - ax.legend(["$\\lambda_{const}$"], loc="upper right") + ax.set_xlabel("time [s]", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_ylabel(r"$\lambda(t)$ [Hz]", fontname="Arial", fontsize=12, fontweight="bold") + ax.legend(["$\\lambda_{const}$"], loc="upper right", fontsize=14) - fig1.suptitle("Example 01 — Figure 1: Constant Mg$^{2+}$ Summary", - fontsize=14, fontweight="bold") fig1.tight_layout() figure_files.extend(_maybe_export(fig1, export_dir, "fig01_constant_mg_summary")) @@ -194,18 +191,17 @@ def run_example01(*, export_figures: bool = False, export_dir: Path | None = Non ax = axes2[0] nstConst.plot(handle=ax) ax.set_yticks([0, 1]) - ax.set_ylabel("mEPSCs", fontsize=12, fontweight="bold") + ax.set_ylabel("mEPSCs", fontname="Arial", fontsize=12, fontweight="bold") ax.set_title("Neural Raster with constant Mg$^{2+}$ Concentration", fontweight="bold", fontsize=12) ax = axes2[1] nstWashout.plot(handle=ax) ax.set_yticks([0, 1]) - ax.set_ylabel("mEPSCs", fontsize=12, fontweight="bold") + ax.set_ylabel("mEPSCs", fontname="Arial", fontsize=12, fontweight="bold") ax.set_title("Neural Raster with decreasing Mg$^{2+}$ Concentration", fontweight="bold", fontsize=12) - fig2.suptitle("Example 01 — Figure 2: Constant vs Decreasing Mg$^{2+}$", fontsize=14, fontweight="bold") fig2.tight_layout() figure_files.extend(_maybe_export(fig2, export_dir, "fig02_washout_raster_overview")) @@ -281,13 +277,11 @@ def run_example01(*, export_figures: bool = False, export_dir: Path | None = Non if lam.data.shape[1] > 1: ax.plot(t, np.asarray(lam.data[:, 1], dtype=float), "g", linewidth=2) ax.set_ylim(0, 5) - ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") - ax.set_ylabel(lam.ylabel if hasattr(lam, "ylabel") else "spikes/sec", - fontsize=12, fontweight="bold") - ax.legend(["$\\lambda_{const}$", "$\\lambda_{const-epoch}$"], loc="upper right") + ax.set_xlabel("time [s]", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_ylabel(r"$\lambda(t)$ [Hz]", fontname="Arial", fontsize=12, fontweight="bold") + ax.legend(["$\\lambda_{const}$", "$\\lambda_{const-epoch}$"], + loc="upper right", fontsize=14) - fig3.suptitle("Example 01 — Figure 3: Piecewise Baseline Comparison", - fontsize=14, fontweight="bold") fig3.tight_layout() figure_files.extend(_maybe_export(fig3, export_dir, "fig03_piecewise_baseline_comparison")) diff --git a/examples/paper/example02_whisker_stimulus_thalamus.py b/examples/paper/example02_whisker_stimulus_thalamus.py index c7267f51..99447003 100644 --- a/examples/paper/example02_whisker_stimulus_thalamus.py +++ b/examples/paper/example02_whisker_stimulus_thalamus.py @@ -60,7 +60,7 @@ def _maybe_export(fig, export_dir: Path | None, name: str, dpi: int = 250): if export_dir is not None: export_dir.mkdir(parents=True, exist_ok=True) png_path = export_dir / f"{name}.png" - fig.savefig(png_path, dpi=dpi, bbox_inches="tight") + fig.savefig(png_path, dpi=dpi, facecolor="w", edgecolor="none") saved.append(png_path) print(f" Saved {png_path}") return saved @@ -150,28 +150,32 @@ def run_example02(*, export_figures: bool = False, export_dir: Path | None = Non nstView.setMaxTime(viewWindow) nstView.plot(handle=ax) ax.set_yticks([0, 1]) - ax.set_title("Neural Raster", fontweight="bold", fontsize=12) - ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") - ax.set_ylabel("Spikes", fontsize=12, fontweight="bold") + ax.set_title("Neural Raster", fontweight="bold", fontsize=16, fontname="Arial") + ax.set_xlabel("") + ax.set_xticklabels([]) + ax.set_ylabel("spikes", fontname="Arial", fontsize=12, fontweight="bold") - # Subplot 2: Stimulus displacement (first 21 s) + # Subplot 2: Stimulus displacement (first 21 s, black line matching MATLAB) ax = axes1[1] stimView = stim.getSigInTimeWindow(0, viewWindow) - stimView.plot(handle=ax) - ax.set_ylabel("Displacement [mm]", fontsize=12, fontweight="bold") - ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") - - # Subplot 3: Stimulus velocity (derivative, first 21 s) + stimView.plot(handle=ax, plotPropsIn=[["k"]]) + ax.get_legend().remove() if ax.get_legend() else None + ax.set_ylabel("Displacement [mm]", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_xlabel("") + ax.set_xticklabels([]) + ax.set_title("Stimulus - Whisker Displacement", fontweight="bold", fontsize=16, fontname="Arial") + + # Subplot 3: Stimulus velocity (derivative, first 21 s, black line matching MATLAB) ax = axes1[2] stimDeriv = stim.derivative stimDerivView = stimDeriv.getSigInTimeWindow(0, viewWindow) - stimDerivView.plot(handle=ax) + stimDerivView.plot(handle=ax, plotPropsIn=[["k"]]) + ax.get_legend().remove() if ax.get_legend() else None ax.set_ylim(-80, 80) - ax.set_ylabel("Velocity", fontsize=12, fontweight="bold") - ax.set_xlabel("time [s]", fontsize=12, fontweight="bold") + ax.set_ylabel("Displacement Velocity [mm/s]", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_xlabel("time [s]", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_title("Displacement Velocity", fontweight="bold", fontsize=16, fontname="Arial") - fig1.suptitle("Example 02 — Figure 1: Data Overview", - fontsize=14, fontweight="bold") fig1.tight_layout() figure_files.extend(_maybe_export( fig1, export_dir, "fig01_data_overview")) @@ -380,8 +384,6 @@ def run_example02(*, export_figures: bool = False, export_dir: Path | None = Non ax_coeff = fig2.add_subplot(gs[4:7, 1]) modelCompare.plotCoeffs(handle=ax_coeff) - fig2.suptitle("Example 02 — Figure 2: Lag & History Selection", - fontsize=14, fontweight="bold") figure_files.extend(_maybe_export( fig2, export_dir, "fig02_lag_and_model_comparison")) diff --git a/examples/paper/example03_psth_and_ssglm.py b/examples/paper/example03_psth_and_ssglm.py index 34c39a80..2dea60f5 100644 --- a/examples/paper/example03_psth_and_ssglm.py +++ b/examples/paper/example03_psth_and_ssglm.py @@ -199,36 +199,38 @@ def run_part_a(data_dir, export_dir=None): # Top-left: CIF ax = axes1[0, 0] ax.plot(time, lambdaData, "b", linewidth=2) - ax.set_title("Simulated CIF", fontweight="bold", fontsize=14) - ax.set_xlabel("time [s]") - ax.set_ylabel("spikes/sec") + ax.set_title("Simulated Conditional Intensity Function (CIF)", + fontweight="bold", fontsize=14, fontname="Arial") + ax.set_xlabel("time [s]", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_ylabel(r"$\lambda(t)$ [spikes/sec]", fontname="Arial", fontsize=12, fontweight="bold") + ax.legend([r"$\lambda_1$"], loc="upper right", fontsize=14) # Bottom-left: simulated raster ax = axes1[1, 0] spikeCollSim.plot(handle=ax) ax.set_yticks(range(0, numRealizations + 1, 5)) - ax.set_title(f"{numRealizations} Simulated Sample Paths", - fontweight="bold", fontsize=14) - ax.set_xlabel("time [s]") - ax.set_ylabel("Trial [k]") + ax.set_title(f"{numRealizations} Simulated Point Process Sample Paths", + fontweight="bold", fontsize=14, fontname="Arial") + ax.set_xlabel("time [s]", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_ylabel("Trial [k]", fontname="Arial", fontsize=12, fontweight="bold") # Top-right: real cell 6 raster ax = axes1[0, 1] spikeCollReal1.plot(handle=ax) ax.set_yticks(range(0, numTrials + 1, 2)) ax.set_title("Response to Moving Visual Stimulus (Neuron 6)", - fontweight="bold", fontsize=14) - ax.set_xlabel("time [s]") - ax.set_ylabel("Trial [k]") + fontweight="bold", fontsize=14, fontname="Arial") + ax.set_xlabel("time [s]", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_ylabel("Trial [k]", fontname="Arial", fontsize=12, fontweight="bold") # Bottom-right: real cell 1 raster ax = axes1[1, 1] spikeCollReal2.plot(handle=ax) ax.set_yticks(range(0, numTrials + 1, 2)) ax.set_title("Response to Moving Visual Stimulus (Neuron 1)", - fontweight="bold", fontsize=14) - ax.set_xlabel("time [s]") - ax.set_ylabel("Trial [k]") + fontweight="bold", fontsize=14, fontname="Arial") + ax.set_xlabel("time [s]", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_ylabel("Trial [k]", fontname="Arial", fontsize=12, fontweight="bold") fig1.tight_layout() @@ -496,7 +498,8 @@ def run_part_b(data_dir, export_dir=None): # ------------------------------------------------------------------ # Figure 5: True/PSTH/SSGLM stimulus effect surfaces # Match MATLAB: mesh(trial, time, data) with view([90 -90]) → top-down - # Using pcolormesh for clean 2D rendering matching MATLAB's top-down mesh + # MATLAB orientation: time [s] on x-axis, Trial [k] on y-axis + # (matches fig03 bottom-panel "True Conditional Intensity Function") # ------------------------------------------------------------------ fig5, axes5 = plt.subplots(3, 1, figsize=(14, 9)) trial_axis = np.arange(1, numRealizations + 1) @@ -509,12 +512,17 @@ def run_part_b(data_dir, export_dir=None): ] for ax, (data, title_str) in zip(axes5, surfaces): # data is (T, K) — time on rows, trials on columns - # MATLAB: imagesc shows trial on x, time on y - ax.pcolormesh(trial_axis, basis_time[:T_act], data, cmap="viridis", + # Transpose so trials are on y-axis and time on x-axis, + # matching MATLAB nSTATPaperExamples_15.png orientation. + ax.pcolormesh(basis_time[:T_act], trial_axis, data.T, cmap="viridis", shading="auto") - ax.set_ylabel("time [s]") + ax.set_xlabel("time [s]") + ax.set_ylabel("Trial [k]") ax.set_title(title_str, fontweight="bold", fontsize=14) - axes5[-1].set_xlabel("Trial [k]") + # Remove redundant per-subplot x-labels except the bottom one + for ax in axes5[:-1]: + ax.set_xlabel("") + ax.set_xticklabels([]) fig5.tight_layout() print(" Figure 5: Stimulus effect surfaces (top-down mesh)") @@ -613,7 +621,7 @@ def run_example03(*, export_figures: bool = False, export_dir: Path | None = Non export_dir.mkdir(parents=True, exist_ok=True) for name, fig in all_figs.items(): path = export_dir / f"{name}.png" - fig.savefig(str(path), dpi=150, bbox_inches="tight") + fig.savefig(str(path), dpi=250, facecolor="w", edgecolor="none") print(f" Saved {path}") plt.show() diff --git a/examples/paper/example04_place_cells_continuous_stimulus.py b/examples/paper/example04_place_cells_continuous_stimulus.py index c5062e23..085c0fd7 100644 --- a/examples/paper/example04_place_cells_continuous_stimulus.py +++ b/examples/paper/example04_place_cells_continuous_stimulus.py @@ -252,15 +252,19 @@ def run_example04(*, export_figures: bool = False, export_dir: Path | None = Non fig1, axes1 = plt.subplots(2, 2, figsize=(12, 10)) for i, cidx in enumerate(exampleCells): ax = axes1.flat[i] - ax.plot(x1, y1, "b-", linewidth=0.5, alpha=0.5) + h1, = ax.plot(x1, y1, "b", linewidth=0.5) n = neurons1[min(cidx, nCells1 - 1)] xn = np.asarray(n["xN"].item(), dtype=float).ravel() yn = np.asarray(n["yN"].item(), dtype=float).ravel() - ax.plot(xn, yn, "r.", markersize=7) - ax.set_title(f"Cell {cidx + 1}", fontweight="bold", fontsize=12) + h2, = ax.plot(xn, yn, "r.", markersize=7) + ax.set_title(f"Cell#{cidx + 1}", fontweight="bold", fontsize=12, fontname="Arial") + ax.set_xlabel("X Position") + ax.set_ylabel("Y Position") + ax.set_xticks([-1, -0.5, 0, 0.5, 1]) + ax.set_yticks([-1, -0.5, 0, 0.5, 1]) ax.set_aspect("equal") - fig1.suptitle("Animal 1 — Example Place Cells", fontweight="bold", - fontsize=14) + if i == 3: + ax.legend([h1, h2], ["Animal Path", "Location at time of spike"]) fig1.tight_layout() # ================================================================== @@ -359,10 +363,12 @@ def _plot_heatmaps(fit_results, nCells, title_prefix, design_gauss, axesG[row, col].set_visible(False) axesZ[row, col].set_visible(False) - figG.suptitle(f"{title_prefix} — Gaussian Place Fields", - fontweight="bold", fontsize=14) - figZ.suptitle(f"{title_prefix} — Zernike Place Fields", - fontweight="bold", fontsize=14) + # Match MATLAB sgtitle format: "Gaussian Place Fields - Animal#N" + animal_num = title_prefix.replace("Animal ", "") + figG.suptitle(f"Gaussian Place Fields - Animal#{animal_num}", + fontweight="bold", fontsize=12) + figZ.suptitle(f"Zernike Place Fields - Animal#{animal_num}", + fontweight="bold", fontsize=12) figG.tight_layout() figZ.tight_layout() return figG, figZ @@ -430,7 +436,7 @@ def _plot_heatmaps(fit_results, nCells, title_prefix, design_gauss, export_dir.mkdir(parents=True, exist_ok=True) for name, fig in all_figs.items(): path = export_dir / f"{name}.png" - fig.savefig(str(path), dpi=150, bbox_inches="tight") + fig.savefig(str(path), dpi=250, facecolor="w", edgecolor="none") print(f" Saved {path}") plt.show() diff --git a/examples/paper/example05_decoding_ppaf_pphf.py b/examples/paper/example05_decoding_ppaf_pphf.py index 112f0426..8430129c 100644 --- a/examples/paper/example05_decoding_ppaf_pphf.py +++ b/examples/paper/example05_decoding_ppaf_pphf.py @@ -896,7 +896,7 @@ def run_example05(*, export_figures=False, export_dir=None, show=False): ] for i, fig in enumerate(figures): path = export_dir / f"{fig_names[i]}.png" - fig.savefig(path, dpi=150, bbox_inches="tight") + fig.savefig(path, dpi=250, facecolor="w", edgecolor="none") print(f" Saved: {path}") if show: diff --git a/nstat/fit.py b/nstat/fit.py index 8e6bc3a3..6cb51e2e 100644 --- a/nstat/fit.py +++ b/nstat/fit.py @@ -1208,9 +1208,9 @@ def KSPlot(self, fit_num: int | list[int] | None = None, handle=None): ideal_ref = np.asarray(first_diag["ks_ideal"], dtype=float) ci_ref = np.asarray(first_diag["ks_ci"], dtype=float) if ideal_ref.size: - ax.plot([0.0, 1.0], [0.0, 1.0], color="0.3", linewidth=1.0, linestyle="-.") - ax.plot(ideal_ref, np.clip(ideal_ref + ci_ref, 0.0, 1.0), color="tab:red", linewidth=1.0) - ax.plot(ideal_ref, np.clip(ideal_ref - ci_ref, 0.0, 1.0), color="tab:red", linewidth=1.0) + ax.plot([0.0, 1.0], [0.0, 1.0], "k-.", linewidth=1.0) + ax.plot(ideal_ref, np.clip(ideal_ref + ci_ref, 0.0, 1.0), "r", linewidth=1.0) + ax.plot(ideal_ref, np.clip(ideal_ref - ci_ref, 0.0, 1.0), "r", linewidth=1.0) # Plot each model's empirical CDF (matching MATLAB colour cycle) labels_for_legend: list[str] = [] @@ -1229,13 +1229,19 @@ def KSPlot(self, fit_num: int | list[int] | None = None, handle=None): labels_for_legend.append(label) if handles_for_legend: - ax.legend(handles_for_legend, labels_for_legend, loc="lower right", fontsize=10) + ax.legend(handles_for_legend, labels_for_legend, loc="lower right", fontsize=14) ax.set_xlim(0.0, 1.0) ax.set_ylim(0.0, 1.0) - ax.set_xlabel("Ideal Uniform CDF") - ax.set_ylabel("Empirical CDF") - ax.set_title("KS Plot of Rescaled ISIs\nwith 95% Confidence Intervals", fontweight="bold", fontsize=11) + ax.set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) + ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) + ax.set_xlabel("Ideal Uniform CDF", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_ylabel("Empirical CDF", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_title("KS Plot of Rescaled ISIs\nwith 95% Confidence Intervals", + fontweight="bold", fontsize=11, fontname="Arial") + ax.tick_params(length=6, width=1) + for spine in ax.spines.values(): + spine.set_linewidth(1.0) return ax def plotResidual(self, fit_num: int | list[int] | None = None, handle=None): @@ -1271,14 +1277,18 @@ def plotResidual(self, fit_num: int | list[int] | None = None, handle=None): ) ax.axhline(0.0, color="0.4", linewidth=1.0, linestyle="--") if len(fit_nums) > 1: - ax.legend(loc="upper right", fontsize=8) - ax.set_xlabel("time [s]") - ax.set_ylabel("count residual") - ax.set_title("Point Process Residual", fontweight="bold", fontsize=11) + ax.legend(loc="upper right", fontsize=14) + ax.set_xlabel("time [s]", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_ylabel("count residual", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_title("Point Process Residual", + fontweight="bold", fontsize=11, fontname="Arial") # Match MATLAB: symmetric y-axis with 10% margin ylims = ax.get_ylim() max_y = max(abs(ylims[0]), abs(ylims[1])) * 1.1 ax.set_ylim(-max_y, max_y) + ax.tick_params(length=6, width=1) + for spine in ax.spines.values(): + spine.set_linewidth(1.0) return ax def plotInvGausTrans(self, fit_num: int | list[int] | None = None, handle=None): @@ -1321,17 +1331,20 @@ def plotInvGausTrans(self, fit_num: int | list[int] | None = None, handle=None): if ci_val is None: ci_val = float(diag["acf_ci"]) - # Plot 95% CI lines without legend entries + # Plot 95% CI lines (solid red, matching MATLAB) if ci_val is not None: - ax.axhline(ci_val, color="0.4", linewidth=0.8, linestyle="--") - ax.axhline(-ci_val, color="0.4", linewidth=0.8, linestyle="--") - ax.axhline(0.0, color="0.4", linewidth=0.8) + ax.axhline(ci_val, color="r", linewidth=1.0) + ax.axhline(-ci_val, color="r", linewidth=1.0) if legend_handles: - ax.legend(legend_handles, legend_labels, loc="upper right", fontsize=8) - ax.set_xlabel("lag") - ax.set_ylabel("autocorrelation") - ax.set_title("Autocorrelation Function\nof Rescaled ISIs\nwith 95% CIs") + ax.legend(legend_handles, legend_labels, loc="upper right", fontsize=14) + ax.set_xlabel(r"$\Delta\tau$ [sec]", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_ylabel(r"$ACF(\Phi^{-1}(u_n))$", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_title("Autocorrelation Function\nof Rescaled ISIs\nwith 95% CIs", + fontweight="bold", fontsize=11, fontname="Arial") + ax.tick_params(length=6, width=1) + for spine in ax.spines.values(): + spine.set_linewidth(1.0) return ax def plotSeqCorr(self, fit_num: int | list[int] | None = None, handle=None): @@ -1390,13 +1403,19 @@ def plotSeqCorr(self, fit_num: int | list[int] | None = None, handle=None): legend_labels.append(label) if legend_handles: - ax.legend(legend_handles, legend_labels, loc="upper right", fontsize=10) + ax.legend(legend_handles, legend_labels, loc="upper right", fontsize=14) - ax.set_title("Sequential Correlation of\nRescaled ISIs", fontweight="bold", fontsize=11) - ax.set_xlabel("$U_j$") - ax.set_ylabel("$U_{j+1}$") + ax.set_title("Sequential Correlation of\nRescaled ISIs", + fontweight="bold", fontsize=11, fontname="Arial") + ax.set_xlabel("$u_j$", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_ylabel("$u_{j+1}$", fontname="Arial", fontsize=12, fontweight="bold") ax.set_xlim(0, 1) ax.set_ylim(0, 1) + ax.set_xticks([0.0, 0.25, 0.5, 0.75, 1.0]) + ax.set_yticks([0.0, 0.25, 0.5, 0.75, 1.0]) + ax.tick_params(length=6, width=1) + for spine in ax.spines.values(): + spine.set_linewidth(1.0) return ax def plotCoeffs(self, fit_num: int | list[int] | None = None, handle=None, plotSignificance: int = 1): @@ -1405,7 +1424,7 @@ def plotCoeffs(self, fit_num: int | list[int] | None = None, handle=None, plotSi Matches Matlab FitResult.plotCoeffs: when *fit_num* is ``None`` (default) all fits are overlaid with per-fit colours, errorbar plots with ±1 SE, and asterisks (*) above significant coefficients - (p < 0.05). + (p < 0.05). Includes a legend matching Matlab's lambda dataLabels. """ if fit_num is None: fit_nums = list(range(1, self.numResults + 1)) @@ -1428,6 +1447,10 @@ def plotCoeffs(self, fit_num: int | list[int] | None = None, handle=None, plotSi label_to_x = {lbl: float(j + 1) for j, lbl in enumerate(all_labels)} xpos_all = np.arange(1, len(all_labels) + 1, dtype=float) + # Build legend labels from lambda dataLabels (matching MATLAB) + lambda_labels = list(self.lambda_signal.dataLabels) if getattr(self.lambda_signal, "dataLabels", None) else [] + errorbar_handles = [] + for i, fn in enumerate(fit_nums): diag = self._compute_diagnostics(fn) coeffs = np.asarray(diag["coefficients"], dtype=float) @@ -1437,18 +1460,31 @@ def plotCoeffs(self, fit_num: int | list[int] | None = None, handle=None, plotSi xpos = np.array([label_to_x[lbl] for lbl in fit_labels]) color = _SEQ_COLORS[i % len(_SEQ_COLORS)] valid_se = np.where(np.isfinite(se), se, 0.0) - ax.errorbar(xpos, coeffs, yerr=valid_se, fmt=".", color=color, - linewidth=1.0, markersize=8.0, capsize=3.0) + # Larger markers and thicker error bars to match MATLAB visibility + h = ax.errorbar(xpos, coeffs, yerr=valid_se, fmt=".", color=color, + linewidth=1.5, markersize=12.0, capsize=5.0, + markeredgecolor=color, markerfacecolor=color) + errorbar_handles.append(h) if plotSignificance and np.any(sig > 0): ylims = ax.get_ylim() y_star = 0.8 * ylims[1] - i * 0.1 sig_idx = xpos[sig.astype(bool)] - ax.plot(sig_idx, np.full(sig_idx.size, y_star), "*", color=color, markersize=10.0) + ax.plot(sig_idx, np.full(sig_idx.size, y_star), "*", + color=color, markersize=14.0) ax.set_xticks(xpos_all) - ax.set_xticklabels(all_labels, rotation=45, ha="right", fontsize=6) - ax.set_ylabel("GLM Fit Coefficients") - ax.set_title("GLM Coefficients", fontweight="bold", fontsize=11) + ax.set_xticklabels(all_labels, rotation=90, ha="center", fontsize=6) + ax.set_ylabel("GLM Fit Coefficients", fontname="Arial", fontsize=12, fontweight="bold") + ax.set_title("GLM Coefficients with 95% CIs (* p<0.05)", + fontweight="bold", fontsize=11, fontname="Arial") + ax.grid(axis="y", alpha=0.3) + ax.tick_params(length=6, width=1) + for spine in ax.spines.values(): + spine.set_linewidth(1.0) + # Add legend matching MATLAB: uses lambda dataLabels with NorthEast placement + if errorbar_handles and lambda_labels: + legend_labels = [lambda_labels[min(fn - 1, len(lambda_labels) - 1)] for fn in fit_nums] + ax.legend(errorbar_handles, legend_labels, loc="upper right", fontsize=10) return ax def plotCoeffsWithoutHistory(self, fit_num: int = 1, sortByEpoch: int = 0, plotSignificance: int = 1, handle=None): diff --git a/nstat/plot_style.py b/nstat/plot_style.py index 923925f3..5f9bc2f5 100644 --- a/nstat/plot_style.py +++ b/nstat/plot_style.py @@ -66,9 +66,18 @@ def apply_plot_style(target=None, *, style: str = ""): figure.set_facecolor("white") for ax in axes_list: - ax.tick_params(direction="out", top=True, right=True) + # Match MATLAB nstat.applyPlotStyle: Helvetica 10pt, ticks out, layer top + ax.tick_params(direction="out", top=True, right=True, length=6, width=1) + ax.set_axisbelow(False) # layer = 'top' equivalent for spine in ax.spines.values(): spine.set_linewidth(1.0) + # Set font on existing labels and title + for label in [ax.title, ax.xaxis.label, ax.yaxis.label]: + if label.get_fontfamily() == ["sans-serif"] or not label.get_text(): + label.set_fontfamily("Helvetica") + for label in ax.get_xticklabels() + ax.get_yticklabels(): + label.set_fontsize(10) + label.set_fontfamily("Helvetica") for line in ax.get_lines(): if isinstance(line, matplotlib.lines.Line2D) and float(line.get_linewidth()) < 1.25: line.set_linewidth(1.25) @@ -79,6 +88,13 @@ def apply_plot_style(target=None, *, style: str = ""): sizes = coll.get_sizes() if sizes.size: coll.set_sizes(sizes.clip(min=30.0)) + # Fix per-axes legends + leg = ax.get_legend() + if leg is not None: + leg.set_frame_on(False) + for text in leg.get_texts(): + text.set_fontsize(10) + # Fix figure-level legends legend = None if figure is None else figure.legends for item in legend or []: item.set_frame_on(False)