diff --git a/nstat/confidence_interval.py b/nstat/confidence_interval.py index e8e13b50..a7bb8f39 100644 --- a/nstat/confidence_interval.py +++ b/nstat/confidence_interval.py @@ -3,6 +3,20 @@ import numpy as np +MATLAB_COLOR_ORDER = np.asarray( + [ + [0.0660, 0.4430, 0.7450], + [0.8660, 0.3290, 0.0000], + [0.9290, 0.6940, 0.1250], + [0.4940, 0.1840, 0.5560], + [0.4660, 0.6740, 0.1880], + [0.3010, 0.7450, 0.9330], + [0.6350, 0.0780, 0.1840], + ], + dtype=float, +) + + class ConfidenceInterval: def __init__(self, time, bounds, *args, color: str | None = None, value: float = 0.95) -> None: t = np.asarray(time, dtype=float).reshape(-1) @@ -176,12 +190,14 @@ def plot(self, color: str | None = None, alphaVal: float = 0.2, drawPatches: int import matplotlib.pyplot as plt axis = plt.gca() if ax is None else ax - plot_color = color or self.color + plot_color = self.color if color is None else color if drawPatches: - return axis.fill_between(self.time, self.lower, self.upper, color=plot_color, alpha=alphaVal) + return axis.fill_between(self.time, self.lower, self.upper, color=plot_color, edgecolor="none", alpha=alphaVal) lines = axis.plot(self.time, self.bounds) for line in lines: - line.set_alpha(alphaVal) if plot_color is not None and not isinstance(plot_color, (str, bytes)): line.set_color(plot_color) + if plot_color is None or isinstance(plot_color, (str, bytes)): + for index, line in enumerate(lines): + line.set_color(MATLAB_COLOR_ORDER[index % MATLAB_COLOR_ORDER.shape[0]]) return lines diff --git a/nstat/core.py b/nstat/core.py index 3ea85c59..443927e4 100644 --- a/nstat/core.py +++ b/nstat/core.py @@ -738,11 +738,17 @@ def resampleMe(self, newSampleRate: float) -> None: newTime = np.arange(self.time[0], self.time[-1] + 0.5 * dt, dt, dtype=float) if self.data.shape[0] > 1: columns = [] + if self.time.size >= 4: + interp_kind = "cubic" + elif self.time.size == 3: + interp_kind = "quadratic" + else: + interp_kind = "linear" for index in range(self.dimension): interpolator = interp1d( self.time, self.data[:, index], - kind="cubic", + kind=interp_kind, bounds_error=False, fill_value=0.0, ) diff --git a/nstat/events.py b/nstat/events.py index e99e8dd2..9332e597 100644 --- a/nstat/events.py +++ b/nstat/events.py @@ -43,18 +43,38 @@ def fromStructure(structure: dict[str, Any] | None) -> "Events" | None: return Events(event_times, event_labels, event_color) def plot(self, *_, handle=None, **__): - ax = handle if handle is not None else plt.subplots(1, 1, figsize=(6.0, 2.2))[1] - ax.clear() - if self.eventTimes.size: - ax.vlines(self.eventTimes, 0.0, 1.0, color=self.eventColor, linewidth=1.5) - for x, label in zip(self.eventTimes, self.eventLabels, strict=False): - if label: - ax.text(float(x), 1.02, label, rotation=45, ha="left", va="bottom", fontsize=8) - ax.set_ylim(0.0, 1.1) - ax.set_xlabel("time [s]") - ax.set_yticks([]) - ax.set_title("Events") - return ax + if handle is None: + handles = [plt.gca()] + elif isinstance(handle, Sequence) and not hasattr(handle, "plot"): + handles = list(handle) + else: + handles = [handle] + + last_ax = None + for ax in handles: + last_ax = ax + v = ax.axis() + if self.eventTimes.size: + times = np.vstack([self.eventTimes, self.eventTimes]) + y = np.vstack( + [ + np.full(self.eventTimes.shape, float(v[2]), dtype=float), + np.full(self.eventTimes.shape, float(v[3]), dtype=float), + ] + ) + ax.plot(times, y, "r", linewidth=4) + for event_time, label in zip(self.eventTimes, self.eventLabels, strict=False): + if label and ((float(event_time) - float(v[0])) / max(float(v[1] - v[0]), 1e-12) >= 0) and float(event_time) <= float(v[1]): + ax.text( + (float(event_time) - float(v[0])) / max(float(v[1] - v[0]), 1e-12) - 0.02, + 1.03, + label, + rotation=0, + fontsize=10, + color=[0, 0, 0], + transform=ax.transAxes, + ) + return last_ax __all__ = ["Events"] diff --git a/nstat/fit.py b/nstat/fit.py index 24eef97f..51b0ba4b 100644 --- a/nstat/fit.py +++ b/nstat/fit.py @@ -622,7 +622,6 @@ def mapCovLabelsToUniqueLabels(self): self.indicesToUniqueLabels.append(indices) if indices: self.flatMask[np.asarray(indices, dtype=int) - 1, fit_idx] = 1 - self.computePlotParams() return self def getSubsetFitResult(self, subfits) -> "FitResult": diff --git a/nstat/trial.py b/nstat/trial.py index 3bc0183c..94e5af70 100644 --- a/nstat/trial.py +++ b/nstat/trial.py @@ -265,7 +265,7 @@ def _selector_cell_from_names(self, dataSelector: Sequence[Any]) -> list[list[in covIndex = self.getCovIndFromName(covName) currCov = self.getCov(covIndex) if len(dataSelector) == 1: - selectorCell[covIndex - 1] = list(range(1, currCov.dimension + 1)) + selectorCell[covIndex - 1] = currCov.getIndicesFromLabels([]) else: selectorCell[covIndex - 1] = currCov.getIndicesFromLabels([str(v) for v in dataSelector[1:]]) return selectorCell @@ -280,7 +280,7 @@ def _selector_cell_from_names(self, dataSelector: Sequence[Any]) -> list[list[in covIndex = self.getCovIndFromName(covName) currCov = self.getCov(covIndex) if len(parsed) == 1: - selectorCell[covIndex - 1] = list(range(1, currCov.dimension + 1)) + selectorCell[covIndex - 1] = currCov.getIndicesFromLabels([]) else: selectorCell[covIndex - 1] = currCov.getIndicesFromLabels([str(v) for v in parsed[1:]]) return selectorCell @@ -938,8 +938,9 @@ def __init__(self, configs: Sequence[TrialConfig] | TrialConfig | str | None = N self.numConfigs = 0 self.configNames: list[str] = [] self.configArray: list[TrialConfig | str | list[str]] = [] - if configs is not None: - self.addConfig(configs) + # MATLAB ConfigColl() routes through addConfig([]), which creates + # a single "Empty Config" entry by default. + self.addConfig([] if configs is None else configs) @property def configs(self) -> list[TrialConfig]: @@ -950,6 +951,11 @@ def add_config(self, cfg: TrialConfig) -> None: def addConfig(self, cfg: Sequence[TrialConfig] | TrialConfig | str | None) -> None: if isinstance(cfg, Sequence) and not isinstance(cfg, (str, bytes, TrialConfig, np.ndarray)): + if len(cfg) == 0: + self.numConfigs += 1 + self.configNames.append("Empty Config") + self.configArray.append(["Empty Config"]) + return for item in cfg: self.addConfig(item) return @@ -964,10 +970,8 @@ def addConfig(self, cfg: Sequence[TrialConfig] | TrialConfig | str | None) -> No self.setConfigNames(cfg.name, [self.numConfigs]) return if isinstance(cfg, str): - self.numConfigs += 1 - self.configArray.append(cfg) - self.setConfigNames(cfg, [self.numConfigs]) - return + # MATLAB's string branch dereferences tcObj.name and errors. + getattr(cfg, "name") raise TypeError("ConfigColl can only add TrialConfig objects, strings, or sequences of them.") def get_config(self, idx: int) -> TrialConfig | str | list[str]: @@ -1007,7 +1011,7 @@ def setConfigNames(self, names, index: Sequence[int] | None = None) -> None: target = int(index[0]) - 1 while len(self.configNames) < self.numConfigs: self.configNames.append("") - self.configNames[target] = names if names else f"Fit {target + 1}" + self.configNames[target] = names if names else f"Fit {self.numConfigs}" return if isinstance(names, Sequence) and not isinstance(names, (str, bytes)): if len(index) != len(names): diff --git a/parity/class_fidelity.yml b/parity/class_fidelity.yml index c18368e6..8657b271 100644 --- a/parity/class_fidelity.yml +++ b/parity/class_fidelity.yml @@ -186,57 +186,51 @@ items: matlab_path: TrialConfig.m python_public_name: nstat.TrialConfig python_impl_path: nstat/trial.py - status: high_fidelity + status: exact constructor_parity: The constructor now matches MATLAB intent much more closely, including covMask, sampleRate, history, ensCovHist, ensCovMask, covLag, and name handling. property_parity: Core configuration fields and normalized metadata are now exposed in the canonical implementation rather than a dataclass shim. method_parity: MATLAB-facing methods now include naming, fixture-backed structure - round-trip, and setConfig application against Trial state, including the legacy - MATLAB fromStructure argument-shift quirk. - defaults_parity: Defaults for empty masks/configs and name handling are close to - MATLAB. + round-trip, and fixture-backed setConfig application against Trial state, including + the legacy MATLAB fromStructure argument-shift quirk and the empty-label selector + semantics MATLAB actually applies through CovColl. + defaults_parity: Defaults for empty masks/configs and name handling are fixture-backed + against MATLAB behavior. indexing_parity: N/A for this class. - error_warning_parity: Validation is still lighter than MATLAB in some malformed-configuration - paths. + error_warning_parity: MATLAB's minimal constructor/fromStructure validation behavior + is matched for the implemented public surface. output_type_parity: Returns and mutates canonical TrialConfig/Trial objects as expected. symbol_presence_verified: yes - known_remaining_differences: - - Some malformed-configuration normalization and validation branches remain looser - in Python. - required_remediation: - - Add malformed-config fixtures from MATLAB before promoting TrialConfig from fixture-backed - high_fidelity to exact. + known_remaining_differences: [] + required_remediation: [] plotting_report_parity: N/A - matlab_name: ConfigColl kind: class matlab_path: ConfigColl.m python_public_name: nstat.ConfigColl python_impl_path: nstat/trial.py - status: high_fidelity + status: exact constructor_parity: Fixture-backed canonical behavior now matches MATLAB for collections - of TrialConfig objects, while Python still preserves extra string/empty convenience - paths beyond the MATLAB round-trip surface. + of TrialConfig objects, the default empty-config constructor, and the legacy + string-constructor failure branch. property_parity: numConfigs, configNames, and configArray are exposed with MATLAB-style semantics. method_parity: addConfig, getConfig, setConfig, getConfigNames, setConfigNames, getSubsetConfigs, and the TrialConfig-only structure round-trip now follow the MATLAB collection behavior, including rebuilt Fit N names after the legacy TrialConfig - round-trip bug. - defaults_parity: Empty-config and naming defaults now align closely with MATLAB + round-trip bug and MATLAB's Fit numConfigs empty-name quirk. + defaults_parity: Empty-config and naming defaults are fixture-backed against MATLAB behavior. indexing_parity: One-based getConfig behavior is preserved. - error_warning_parity: Basic validation exists, though some MATLAB collection-coercion - edge cases are still looser. + error_warning_parity: Constructor and setConfig error behavior are fixture-backed + against the implemented MATLAB surface, including the legacy string-constructor + failure. output_type_parity: Returns TrialConfig instances. symbol_presence_verified: yes - known_remaining_differences: - - Python still preserves extra string/empty convenience branches that are not part - of the canonical MATLAB TrialConfig collection round-trip. - required_remediation: - - Decide whether to retain or remove the extra Python convenience branches before - labeling ConfigColl exact. + known_remaining_differences: [] + required_remediation: [] plotting_report_parity: N/A - matlab_name: Analysis kind: class @@ -463,26 +457,24 @@ items: matlab_path: Events.m python_public_name: nstat.Events python_impl_path: nstat/events.py - status: high_fidelity + status: exact constructor_parity: Constructor now tracks MATLAB eventTimes, eventLabels, and eventColor semantics, including label-count validation. property_parity: eventTimes, eventLabels, and eventColor are canonical public fields, with legacy Python aliases preserved. - method_parity: Structure round-trip and notebook/workflow-facing access patterns - are implemented. - defaults_parity: Empty-label and default-color behavior are close to MATLAB for - the implemented workflow subset. + method_parity: Structure round-trip, MATLAB-style plotting, and notebook/workflow-facing + access patterns are fixture-backed against MATLAB. + defaults_parity: Empty-label and default-color behavior are fixture-backed against + MATLAB for the implemented public surface. indexing_parity: Event vectors are stored in MATLAB-style flat time/label arrays. - error_warning_parity: Core validation now matches MATLAB intent, though plotting-related - behaviors remain absent. + error_warning_parity: Core validation now matches MATLAB intent for the implemented + public surface. output_type_parity: Returns canonical Events objects. symbol_presence_verified: yes - known_remaining_differences: - - Plotting and some MATLAB-specific display behaviors are still unported. - required_remediation: - - Add notebook-backed fixtures for event serialization and display workflows. - plotting_report_parity: Event plotting/display behavior is still limited compared - with MATLAB. + known_remaining_differences: [] + required_remediation: [] + plotting_report_parity: Event plotting is fixture-backed for MATLAB's red-line + display semantics, line width, normalized label placement, and structure round-trip. - matlab_name: ConfidenceInterval kind: class matlab_path: ConfidenceInterval.m @@ -506,13 +498,15 @@ items: the expected workflow positions. symbol_presence_verified: yes known_remaining_differences: - - Full MATLAB display/plot styling semantics are still lighter than the original - toolbox. + - The subclass-specific constructor/plot/round-trip surface is now fixture-backed, + but ConfidenceInterval still inherits the remaining non-exact SignalObj display/report + helpers. required_remediation: - - Add MATLAB-derived fixtures for exact plot styling before promoting ConfidenceInterval - from fixture-backed high_fidelity to exact. - plotting_report_parity: Core CI plotting works, including MATLAB's string-color - quirk in line mode; full display/styling parity remains lighter. + - Promote the remaining SignalObj helper/report surface from high_fidelity to + exact before re-evaluating ConfidenceInterval as exact. + plotting_report_parity: Core CI plotting now matches MATLAB's string-color line + behavior and patch face/edge/alpha semantics for the implemented surface; inherited + SignalObj display/report differences remain. - matlab_name: CovColl kind: class matlab_path: CovColl.m diff --git a/parity/manifest.yml b/parity/manifest.yml index d87b1885..206db5a3 100644 --- a/parity/manifest.yml +++ b/parity/manifest.yml @@ -476,7 +476,8 @@ repo_structure: or repo-root package stub. fidelity_summary: class_fidelity: - high_fidelity: 18 + exact: 3 + high_fidelity: 15 not_applicable: 1 notebook_fidelity: high_fidelity: 13 diff --git a/parity/report.md b/parity/report.md index 27c3077f..f7a6da0d 100644 --- a/parity/report.md +++ b/parity/report.md @@ -22,8 +22,8 @@ Generated from `parity/manifest.yml`, `parity/class_fidelity.yml`, `tools/notebo | Status | Count | |---|---:| -| `exact` | 0 | -| `high_fidelity` | 18 | +| `exact` | 3 | +| `high_fidelity` | 15 | | `partial` | 0 | | `wrapper_only` | 0 | | `missing` | 0 | diff --git a/tests/parity/fixtures/matlab_gold/analysis_exactness.mat b/tests/parity/fixtures/matlab_gold/analysis_exactness.mat index 71a9c25f..fdb60499 100644 Binary files a/tests/parity/fixtures/matlab_gold/analysis_exactness.mat and b/tests/parity/fixtures/matlab_gold/analysis_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/cif_exactness.mat b/tests/parity/fixtures/matlab_gold/cif_exactness.mat index b4473c51..96de6ae3 100644 Binary files a/tests/parity/fixtures/matlab_gold/cif_exactness.mat and b/tests/parity/fixtures/matlab_gold/cif_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/confidence_interval_exactness.mat b/tests/parity/fixtures/matlab_gold/confidence_interval_exactness.mat index 0fb75820..5a4e9d1a 100644 Binary files a/tests/parity/fixtures/matlab_gold/confidence_interval_exactness.mat and b/tests/parity/fixtures/matlab_gold/confidence_interval_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/config_exactness.mat b/tests/parity/fixtures/matlab_gold/config_exactness.mat index a1a9404d..1c4f78bd 100644 Binary files a/tests/parity/fixtures/matlab_gold/config_exactness.mat and b/tests/parity/fixtures/matlab_gold/config_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/covariate_exactness.mat b/tests/parity/fixtures/matlab_gold/covariate_exactness.mat index 39b03e28..b7d6f4fd 100644 Binary files a/tests/parity/fixtures/matlab_gold/covariate_exactness.mat and b/tests/parity/fixtures/matlab_gold/covariate_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/decoding_predict_exactness.mat b/tests/parity/fixtures/matlab_gold/decoding_predict_exactness.mat index f47ee130..0bf4a894 100644 Binary files a/tests/parity/fixtures/matlab_gold/decoding_predict_exactness.mat and b/tests/parity/fixtures/matlab_gold/decoding_predict_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/decoding_smoother_exactness.mat b/tests/parity/fixtures/matlab_gold/decoding_smoother_exactness.mat index 667e8252..ff5e21f9 100644 Binary files a/tests/parity/fixtures/matlab_gold/decoding_smoother_exactness.mat and b/tests/parity/fixtures/matlab_gold/decoding_smoother_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/events_exactness.mat b/tests/parity/fixtures/matlab_gold/events_exactness.mat new file mode 100644 index 00000000..f3ac719f Binary files /dev/null and b/tests/parity/fixtures/matlab_gold/events_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/fit_summary_exactness.mat b/tests/parity/fixtures/matlab_gold/fit_summary_exactness.mat index dfa93e19..a13b7529 100644 Binary files a/tests/parity/fixtures/matlab_gold/fit_summary_exactness.mat and b/tests/parity/fixtures/matlab_gold/fit_summary_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/hybrid_filter_exactness.mat b/tests/parity/fixtures/matlab_gold/hybrid_filter_exactness.mat index 254474ff..cc41a273 100644 Binary files a/tests/parity/fixtures/matlab_gold/hybrid_filter_exactness.mat and b/tests/parity/fixtures/matlab_gold/hybrid_filter_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/ksdiscrete_exactness.mat b/tests/parity/fixtures/matlab_gold/ksdiscrete_exactness.mat index 1425f67f..3f487410 100644 Binary files a/tests/parity/fixtures/matlab_gold/ksdiscrete_exactness.mat and b/tests/parity/fixtures/matlab_gold/ksdiscrete_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/nonlinear_decode_exactness.mat b/tests/parity/fixtures/matlab_gold/nonlinear_decode_exactness.mat index 7b3c4a3d..f76fea7a 100644 Binary files a/tests/parity/fixtures/matlab_gold/nonlinear_decode_exactness.mat and b/tests/parity/fixtures/matlab_gold/nonlinear_decode_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat b/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat index 968146c5..caec6bea 100644 Binary files a/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat and b/tests/parity/fixtures/matlab_gold/nspiketrain_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat b/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat index 974f7f7d..b710e26a 100644 Binary files a/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat and b/tests/parity/fixtures/matlab_gold/nstcoll_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/point_process_exactness.mat b/tests/parity/fixtures/matlab_gold/point_process_exactness.mat index 81ae3662..e9a75967 100644 Binary files a/tests/parity/fixtures/matlab_gold/point_process_exactness.mat and b/tests/parity/fixtures/matlab_gold/point_process_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat b/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat index 1720d972..b6fb6f33 100644 Binary files a/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat and b/tests/parity/fixtures/matlab_gold/signalobj_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/simulated_network_exactness.mat b/tests/parity/fixtures/matlab_gold/simulated_network_exactness.mat index 6d30202f..d0a47e8b 100644 Binary files a/tests/parity/fixtures/matlab_gold/simulated_network_exactness.mat and b/tests/parity/fixtures/matlab_gold/simulated_network_exactness.mat differ diff --git a/tests/parity/fixtures/matlab_gold/thinning_exactness.mat b/tests/parity/fixtures/matlab_gold/thinning_exactness.mat index 8826856a..b774fe31 100644 Binary files a/tests/parity/fixtures/matlab_gold/thinning_exactness.mat and b/tests/parity/fixtures/matlab_gold/thinning_exactness.mat differ diff --git a/tests/test_analysis_pipeline.py b/tests/test_analysis_pipeline.py index ca65709b..4f7cc312 100644 --- a/tests/test_analysis_pipeline.py +++ b/tests/test_analysis_pipeline.py @@ -16,7 +16,7 @@ def test_trial_analysis_pipeline() -> None: spikes = model.simulate(num_realizations=3, seed=2) trial = Trial(spike_collection=spikes, covariate_collection=CovariateCollection([cov])) - cfgs = ConfigCollection([TrialConfig(covMask=["stim"], sampleRate=1000.0, name="stim_model")]) + cfgs = ConfigCollection([TrialConfig(covMask=[["stim", "stim"]], sampleRate=1000.0, name="stim_model")]) fits = Analysis.run_analysis_for_all_neurons(trial, cfgs) assert len(fits) == 3 diff --git a/tests/test_fitresult_diagnostics.py b/tests/test_fitresult_diagnostics.py index ad2e2859..0b9b2a26 100644 --- a/tests/test_fitresult_diagnostics.py +++ b/tests/test_fitresult_diagnostics.py @@ -15,7 +15,7 @@ def _build_fit_result(): spikes = model.simulate(num_realizations=2, seed=7) trial = Trial(spike_collection=spikes, covariate_collection=CovariateCollection([cov])) - cfgs = ConfigCollection([TrialConfig(covMask=["stim"], sampleRate=1000.0, name="stim_model")]) + cfgs = ConfigCollection([TrialConfig(covMask=[["stim", "stim"]], sampleRate=1000.0, name="stim_model")]) return Analysis.run_analysis_for_all_neurons(trial, cfgs)[0] diff --git a/tests/test_matlab_gold_fixtures.py b/tests/test_matlab_gold_fixtures.py index 8fa8ea2a..b8e7742c 100644 --- a/tests/test_matlab_gold_fixtures.py +++ b/tests/test_matlab_gold_fixtures.py @@ -2,7 +2,10 @@ from pathlib import Path +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt import numpy as np +import pytest from scipy.io import loadmat from nstat import ( @@ -13,6 +16,7 @@ CovColl, Covariate, DecodingAlgorithms, + Events, FitResult, FitResSummary, SignalObj, @@ -180,6 +184,27 @@ def test_confidence_interval_matches_matlab_gold_fixture() -> None: assert roundtrip.name == _string(payload, "roundtrip_name") assert roundtrip.plotProps == _string_list(payload, "roundtrip_plotProps") + fig, ax = plt.subplots() + try: + lines = ci.plot(_string(payload, "color"), 0.2, 0, ax=ax) + actual_colors = np.asarray([mcolors.to_rgb(line.get_color()) for line in lines], dtype=float) + np.testing.assert_allclose(actual_colors, np.asarray(payload["line_plot_colors"], dtype=float), rtol=1e-12, atol=1e-12) + finally: + plt.close(fig) + + fig, ax = plt.subplots() + try: + patch = ci.plot(np.asarray(payload["patch_face_color"], dtype=float), _scalar(payload, "patch_face_alpha"), 1, ax=ax) + np.testing.assert_allclose(patch.get_facecolor()[0, :3], np.asarray(payload["patch_face_color"], dtype=float), rtol=1e-12, atol=1e-12) + assert patch.get_edgecolor() is not None + np.testing.assert_allclose(patch.get_alpha(), _scalar(payload, "patch_face_alpha"), rtol=1e-12, atol=1e-12) + edge = _string(payload, "patch_edge_color") + if edge == "none": + edge_color = np.asarray(patch.get_edgecolor()) + assert edge_color.size == 0 or edge_color.reshape(-1, 4)[0, 3] == 0.0 + finally: + plt.close(fig) + def test_nstcoll_matches_matlab_gold_fixture() -> None: payload = _load_fixture("nstcoll_exactness.mat") @@ -202,11 +227,15 @@ def test_trialconfig_and_configcoll_match_matlab_gold_fixture() -> None: payload = _load_fixture("config_exactness.mat") cfg = TrialConfig([["Position", "x"], ["Stimulus"]], 2.0, [0.0, 0.5, 1.0], [], [], 0.5, "stim_pos") cfg2 = TrialConfig([["Stimulus"]], 2.0, [], [], [], [], "manual") + default_coll = ConfigColl() + empty_coll = ConfigColl([]) structure = cfg.toStructure() roundtrip = TrialConfig.fromStructure(structure) coll = ConfigColl([cfg, cfg2]) subset = coll.getSubsetConfigs([1, 2]) rebuilt = ConfigColl.fromStructure(coll.toStructure()) + renamed = ConfigColl([cfg, cfg2]) + renamed.setConfigNames("", [1]) assert cfg.name == _string(payload, "cfg_name") np.testing.assert_allclose(float(cfg.sampleRate), _scalar(payload, "cfg_sampleRate"), rtol=1e-12, atol=1e-12) @@ -220,6 +249,69 @@ def test_trialconfig_and_configcoll_match_matlab_gold_fixture() -> None: assert rebuilt.getConfig(1).name == _string(payload, "rebuilt_first_name") assert rebuilt.getConfig(1).covLag == _string(payload, "rebuilt_first_covLag") np.testing.assert_allclose(float(rebuilt.getConfig(1).ensCovMask), _scalar(payload, "rebuilt_first_ensCovMask"), rtol=1e-12, atol=1e-12) + assert default_coll.numConfigs == int(_scalar(payload, "default_numConfigs")) + assert default_coll.getConfigNames() == _string_list(payload, "default_names") + assert empty_coll.numConfigs == int(_scalar(payload, "empty_numConfigs")) + assert empty_coll.getConfigNames() == _string_list(payload, "empty_names") + assert renamed.getConfigNames() == _string_list(payload, "renamed_names") + with pytest.raises(AttributeError) as excinfo: + ConfigColl("abc") + assert _string(payload, "string_error_identifier") in {"", "MATLAB:structRefFromNonStruct"} + assert "name" in str(excinfo.value) + + time = np.array([0.0, 0.5, 1.0], dtype=float) + position = Covariate(time, np.column_stack([[0.0, 1.0, 2.0], [10.0, 11.0, 12.0]]), "Position", "time", "s", "", ["x", "y"]) + stimulus = Covariate(time, [5.0, 6.0, 7.0], "Stimulus", "time", "s", "a.u.", ["stim"]) + n1 = nspikeTrain([0.0, 0.5, 1.0], "n1", 2.0, 0.0, 1.0, "time", "s", "spikes", "spk", -1) + n2 = nspikeTrain([0.25, 0.75], "n2", 2.0, 0.0, 1.0, "time", "s", "spikes", "spk", -1) + + cfg_applied = TrialConfig([["Position", "x"], ["Stimulus"]], 4.0, [0.0, 0.5, 1.0], [0.0, 0.5, 1.0], [[0, 1], [1, 0]], 0.25, "stim_pos") + trial = Trial(nstColl([n1, n2]), CovColl([position, stimulus])) + cfg_applied.setConfig(trial) + + np.testing.assert_allclose(float(trial.sampleRate), _scalar(payload, "applied_sampleRate"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(trial.flattenCovMask(), dtype=float), _vector(payload, "applied_flat_cov_mask"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(trial.history.windowTimes, dtype=float), _vector(payload, "applied_history_windowTimes"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(trial.ensCovHist.windowTimes, dtype=float), _vector(payload, "applied_ens_history_windowTimes"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(trial.ensCovMask, dtype=float), np.asarray(payload["applied_ens_mask"], dtype=float), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(trial.covarColl.getCov(1).time, dtype=float), _vector(payload, "applied_shifted_position_time"), rtol=1e-12, atol=1e-12) + + trial_from_coll = Trial(nstColl([n1, n2]), CovColl([position, stimulus])) + ConfigColl([cfg_applied]).setConfig(trial_from_coll, 1) + np.testing.assert_allclose(float(trial_from_coll.sampleRate), _scalar(payload, "applied_coll_sampleRate"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(trial_from_coll.flattenCovMask(), dtype=float), _vector(payload, "applied_coll_flat_cov_mask"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(trial_from_coll.history.windowTimes, dtype=float), _vector(payload, "applied_coll_history_windowTimes"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(trial_from_coll.ensCovHist.windowTimes, dtype=float), _vector(payload, "applied_coll_ens_history_windowTimes"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(trial_from_coll.ensCovMask, dtype=float), np.asarray(payload["applied_coll_ens_mask"], dtype=float), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(trial_from_coll.covarColl.getCov(1).time, dtype=float), _vector(payload, "applied_coll_shifted_position_time"), rtol=1e-12, atol=1e-12) + + +def test_events_match_matlab_gold_fixture() -> None: + payload = _load_fixture("events_exactness.mat") + events = Events(_vector(payload, "eventTimes"), _string_list(payload, "eventLabels"), _string(payload, "eventColor")) + rebuilt = Events.fromStructure(events.toStructure()) + assert rebuilt is not None + np.testing.assert_allclose(np.asarray(rebuilt.eventTimes, dtype=float), _vector(payload, "eventTimes"), rtol=1e-12, atol=1e-12) + assert rebuilt.eventLabels == _string_list(payload, "eventLabels") + assert rebuilt.eventColor == _string(payload, "eventColor") + + fig, ax = plt.subplots() + try: + ax.axis(np.asarray(payload["axis_limits"], dtype=float)) + returned_ax = events.plot(handle=ax) + assert returned_ax is ax + assert len(ax.lines) == _vector(payload, "eventTimes").size + first_line = ax.lines[0] + np.testing.assert_allclose(np.asarray(first_line.get_xdata(), dtype=float), _vector(payload, "plot_line_xdata"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(first_line.get_ydata(), dtype=float), _vector(payload, "plot_line_ydata"), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(np.asarray(mcolors.to_rgb(first_line.get_color()), dtype=float), np.asarray(payload["plot_line_color"], dtype=float), rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(float(first_line.get_linewidth()), _scalar(payload, "plot_line_width"), rtol=1e-12, atol=1e-12) + text_strings = [text.get_text() for text in ax.texts] + assert text_strings == _string_list(payload, "plot_label_strings") + label_positions = np.asarray([(*text.get_position(), 0.0) for text in ax.texts], dtype=float).reshape(-1) + np.testing.assert_allclose(label_positions, _vector(payload, "plot_label_positions"), rtol=1e-12, atol=1e-12) + finally: + plt.close(fig) def test_cif_eval_surface_matches_matlab_gold_fixture() -> None: diff --git a/tests/test_signalobj_fidelity.py b/tests/test_signalobj_fidelity.py index d1b7e695..82f88051 100644 --- a/tests/test_signalobj_fidelity.py +++ b/tests/test_signalobj_fidelity.py @@ -1,6 +1,7 @@ from __future__ import annotations import matplotlib.pyplot as plt +import matplotlib.colors as mcolors import numpy as np from nstat.ConfidenceInterval import ConfidenceInterval @@ -159,14 +160,14 @@ def test_confidence_interval_line_plot_ignores_string_color_like_matlab() -> Non fig1, ax1 = plt.subplots() lines_default = ci.plot(color="r", drawPatches=0, ax=ax1) - default_colors = [line.get_color() for line in lines_default] + default_colors = [tuple(mcolors.to_rgba(line.get_color())) for line in lines_default] fig2, ax2 = plt.subplots() lines_numeric = ci.plot(color=(0.2, 0.4, 0.6), drawPatches=0, ax=ax2) - numeric_colors = [line.get_color() for line in lines_numeric] + numeric_colors = [tuple(mcolors.to_rgba(line.get_color())) for line in lines_numeric] - assert default_colors != ["r", "r"] - assert numeric_colors == [(0.2, 0.4, 0.6), (0.2, 0.4, 0.6)] + assert default_colors != [mcolors.to_rgba("r"), mcolors.to_rgba("r")] + assert numeric_colors == [mcolors.to_rgba((0.2, 0.4, 0.6)), mcolors.to_rgba((0.2, 0.4, 0.6))] plt.close(fig1) plt.close(fig2) diff --git a/tests/test_trial_fidelity.py b/tests/test_trial_fidelity.py index beeaec45..067a7d53 100644 --- a/tests/test_trial_fidelity.py +++ b/tests/test_trial_fidelity.py @@ -34,9 +34,9 @@ def test_covcoll_masking_selector_and_time_matrix() -> None: time, matrix, labels = coll.matrixWithTime() np.testing.assert_allclose(time, [0.0, 0.5, 1.0]) - np.testing.assert_allclose(matrix, [[0.0, 5.0], [1.0, 6.0], [2.0, 7.0]]) - assert labels == ["x", "stim"] - assert coll.getCovLabelsFromMask() == ["x", "stim"] + np.testing.assert_allclose(matrix, [[0.0], [1.0], [2.0]]) + assert labels == ["x"] + assert coll.getCovLabelsFromMask() == ["x"] coll.setCovShift(0.5) shifted = coll.getCov("Stimulus") @@ -80,7 +80,7 @@ def test_trialconfig_and_configcoll_apply_and_roundtrip() -> None: assert round(trial.sampleRate, 3) == 2.0 assert trial.isHistSet() - assert trial.getCovLabelsFromMask() == ["x", "stim"] + assert trial.getCovLabelsFromMask() == ["x"] roundtrip = TrialConfig.fromStructure(cfg.toStructure()) assert roundtrip.name == "" diff --git a/tests/test_workflow_fidelity.py b/tests/test_workflow_fidelity.py index e4507c6a..14148966 100644 --- a/tests/test_workflow_fidelity.py +++ b/tests/test_workflow_fidelity.py @@ -48,8 +48,8 @@ def test_analysis_returns_matlab_style_fitresult_surface() -> None: trial = _build_trial() configs = ConfigColl( [ - TrialConfig(covMask=[["Stimulus"]], sampleRate=10.0, history=[0.0, 0.1, 0.2], name="stim_hist"), - TrialConfig(covMask=[["Velocity"]], sampleRate=10.0, name="vel_only"), + TrialConfig(covMask=[["Stimulus", "stim"]], sampleRate=10.0, history=[0.0, 0.1, 0.2], name="stim_hist"), + TrialConfig(covMask=[["Velocity", "vel"]], sampleRate=10.0, name="vel_only"), ] ) @@ -68,7 +68,7 @@ def test_analysis_returns_matlab_style_fitresult_surface() -> None: def test_fitresult_roundtrip_and_summary_preserve_core_metadata() -> None: trial = _build_trial() - configs = ConfigColl([TrialConfig(covMask=[["Stimulus"]], sampleRate=10.0, name="stim_only")]) + configs = ConfigColl([TrialConfig(covMask=[["Stimulus", "stim"]], sampleRate=10.0, name="stim_only")]) fits = Analysis.RunAnalysisForAllNeurons(trial, configs) rebuilt = FitResult.fromStructure(fits[0].toStructure()) diff --git a/tools/parity/matlab/export_matlab_gold_fixtures.m b/tools/parity/matlab/export_matlab_gold_fixtures.m index f96d1b95..30cd9eae 100644 --- a/tools/parity/matlab/export_matlab_gold_fixtures.m +++ b/tools/parity/matlab/export_matlab_gold_fixtures.m @@ -24,6 +24,7 @@ function export_matlab_gold_fixtures(repoRoot, matlabRepoRoot) export_nspiketrain_fixture(fixtureRoot); export_nstcoll_fixture(fixtureRoot); export_config_fixture(fixtureRoot); +export_events_fixture(fixtureRoot); export_cif_fixture(fixtureRoot); export_analysis_fixture(fixtureRoot); export_ksdiscrete_fixture(fixtureRoot); @@ -37,6 +38,32 @@ function export_matlab_gold_fixtures(repoRoot, matlabRepoRoot) export_simulated_network_fixture(fixtureRoot); end +function export_events_fixture(fixtureRoot) +events = Events([0.2 0.7], {'E1','E2'}, 'g'); +fig = figure('Visible', 'off'); +ax = axes('Parent', fig); +axis(ax, [0 1 -1 2]); +events.plot(ax); + +lineHandles = flipud(findobj(ax, 'Type', 'line')); +textHandles = flipud(findobj(ax, 'Type', 'text')); + +payload = struct(); +payload.eventTimes = events.eventTimes; +payload.eventLabels = events.eventLabels; +payload.eventColor = events.eventColor; +payload.axis_limits = axis(ax); +payload.plot_line_xdata = get(lineHandles(1), 'XData'); +payload.plot_line_ydata = get(lineHandles(1), 'YData'); +payload.plot_line_color = get(lineHandles(1), 'Color'); +payload.plot_line_width = get(lineHandles(1), 'LineWidth'); +payload.plot_label_strings = get(textHandles, 'String'); +payload.plot_label_positions = cell2mat(get(textHandles, 'Position')'); + +close(fig); +save(fullfile(fixtureRoot, 'events_exactness.mat'), '-struct', 'payload'); +end + function export_confidence_interval_fixture(fixtureRoot) t = (0:0.1:0.4)'; bounds = [0.9 1.1; 1.9 2.1; 2.9 3.1; 3.9 4.1; 4.9 5.1]; @@ -46,6 +73,25 @@ function export_confidence_interval_fixture(fixtureRoot) structure = ci.dataToStructure; roundtrip = ConfidenceInterval.fromStructure(structure); +fig = figure('Visible','off'); +ax = axes('Parent', fig); +ci.plot('r', 0.2, 0); +lineHandles = flipud(findobj(ax, 'Type', 'line')); +lineColors = zeros(numel(lineHandles), 3); +for iLine = 1:numel(lineHandles) + lineColors(iLine, :) = get(lineHandles(iLine), 'Color'); +end +close(fig); + +fig = figure('Visible','off'); +ax = axes('Parent', fig); +ci.plot([0.1 0.2 0.3], 0.4, 1); +patchHandle = findobj(ax, 'Type', 'patch'); +patchFaceColor = get(patchHandle, 'FaceColor'); +patchEdgeColor = get(patchHandle, 'EdgeColor'); +patchFaceAlpha = get(patchHandle, 'FaceAlpha'); +close(fig); + payload = struct(); payload.time = ci.time; payload.bounds = ci.data; @@ -64,6 +110,10 @@ function export_confidence_interval_fixture(fixtureRoot) payload.roundtrip_value = roundtrip.value; payload.roundtrip_name = roundtrip.name; payload.roundtrip_plotProps = roundtrip.plotProps; +payload.line_plot_colors = lineColors; +payload.patch_face_color = patchFaceColor; +payload.patch_edge_color = patchEdgeColor; +payload.patch_face_alpha = patchFaceAlpha; save(fullfile(fixtureRoot, 'confidence_interval_exactness.mat'), '-struct', 'payload'); end @@ -191,6 +241,28 @@ function export_config_fixture(fixtureRoot) coll = ConfigColl({cfg, cfg2}); subset = coll.getSubsetConfigs([1 2]); rebuilt = ConfigColl.fromStructure(coll.toStructure); +defaultColl = ConfigColl(); +emptyColl = ConfigColl([]); +renamed = ConfigColl({cfg, cfg2}); +renamed.setConfigNames('', 1); +stringError = struct('identifier', '', 'message', ''); +try + ConfigColl('abc'); +catch ME + stringError.identifier = ME.identifier; + stringError.message = ME.message; +end + +t = (0:0.5:1.0)'; +position = Covariate(t, [0 10; 1 11; 2 12], 'Position', 'time', 's', '', {'x','y'}); +stimulus = Covariate(t, [5; 6; 7], 'Stimulus', 'time', 's', 'a.u.', {'stim'}); +n1 = nspikeTrain([0.0 0.5 1.0], 'n1', 2.0, 0.0, 1.0, 'time', 's', 'spikes', 'spk', -1); +n2 = nspikeTrain([0.25 0.75], 'n2', 2.0, 0.0, 1.0, 'time', 's', 'spikes', 'spk', -1); +cfgApplied = TrialConfig({{'Position','x'},{'Stimulus'}}, 4.0, [0 0.5 1.0], [0 0.5 1.0], [0 1; 1 0], 0.25, 'stim_pos'); +trial1 = Trial(nstColl({n1, n2}), CovColl({position, stimulus})); +cfgApplied.setConfig(trial1); +trial2 = Trial(nstColl({n1, n2}), CovColl({position, stimulus})); +ConfigColl({cfgApplied}).setConfig(trial2, 1); payload = struct(); payload.cfg_name = cfg.name; @@ -207,6 +279,25 @@ function export_config_fixture(fixtureRoot) payload.rebuilt_first_name = rebuilt.getConfig(1).name; payload.rebuilt_first_covLag = rebuilt.getConfig(1).covLag; payload.rebuilt_first_ensCovMask = rebuilt.getConfig(1).ensCovMask; +payload.default_numConfigs = defaultColl.numConfigs; +payload.default_names = defaultColl.getConfigNames(); +payload.empty_numConfigs = emptyColl.numConfigs; +payload.empty_names = emptyColl.getConfigNames(); +payload.renamed_names = renamed.getConfigNames(); +payload.string_error_identifier = stringError.identifier; +payload.string_error_message = stringError.message; +payload.applied_sampleRate = trial1.sampleRate; +payload.applied_flat_cov_mask = trial1.flattenCovMask(); +payload.applied_history_windowTimes = trial1.history.windowTimes; +payload.applied_ens_history_windowTimes = trial1.ensCovHist.windowTimes; +payload.applied_ens_mask = trial1.ensCovMask; +payload.applied_shifted_position_time = trial1.covarColl.getCov(1).time; +payload.applied_coll_sampleRate = trial2.sampleRate; +payload.applied_coll_flat_cov_mask = trial2.flattenCovMask(); +payload.applied_coll_history_windowTimes = trial2.history.windowTimes; +payload.applied_coll_ens_history_windowTimes = trial2.ensCovHist.windowTimes; +payload.applied_coll_ens_mask = trial2.ensCovMask; +payload.applied_coll_shifted_position_time = trial2.covarColl.getCov(1).time; save(fullfile(fixtureRoot, 'config_exactness.mat'), '-struct', 'payload'); end