Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 121 additions & 70 deletions notebooks/StimulusDecode2D.ipynb

Large diffs are not rendered by default.

282 changes: 281 additions & 1 deletion nstat/cif.py

Large diffs are not rendered by default.

106 changes: 81 additions & 25 deletions nstat/decoding_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,23 @@ def _normalize_mu_models(mu, n_models: int, num_cells: int) -> list[np.ndarray]:
return [_normalize_mu(mu, num_cells) for _ in range(n_models)]


def _extract_linear_terms_from_cifs(lambdaCIFColl, num_states: int, num_cells: int):
def _normalize_cif_collection(lambdaCIFColl) -> list[CIF]:
if isinstance(lambdaCIFColl, CIF):
cifs = [lambdaCIFColl]
elif isinstance(lambdaCIFColl, Sequence) and not isinstance(lambdaCIFColl, (str, bytes)):
cifs = list(lambdaCIFColl)
else:
raise UnsupportedWorkflowError("PPDecodeFilter requires a CIF or sequence of CIF objects for the Python port")
if not cifs:
raise ValueError("lambdaCIFColl must contain at least one CIF object")
for cif in cifs:
if not isinstance(cif, CIF):
raise UnsupportedWorkflowError("PPDecodeFilter only supports CIF objects in the Python port")
return cifs


def _extract_linear_terms_from_cifs(lambdaCIFColl, num_states: int, num_cells: int):
cifs = _normalize_cif_collection(lambdaCIFColl)

if len(cifs) != num_cells:
raise ValueError("Number of CIF objects must match the number of observed cells")
Expand Down Expand Up @@ -542,13 +552,12 @@ def PPDecode_update(x_p, W_p, dN, lambdaIn, binwidth=0.001, time_index=1, WuConv
observed = obs[:, idx - 1]

for cell_index, cif in enumerate(lambda_items):
if not isinstance(cif, CIF):
raise ValueError("Lambda must be a cell of CIFs or a CIF")
if cif.historyMat.size == 0:
spike_times = (np.where(obs[cell_index] == 1.0)[0]) * float(binwidth)
observed_prefix = obs[cell_index, :idx]
spike_times = np.where(observed_prefix > 0.5)[0] * float(binwidth)
nst = nspikeTrain(spike_times, makePlots=-1)
nst.setMinTime(0.0)
nst.setMaxTime((obs.shape[1] - 1) * float(binwidth))
nst.setMaxTime((idx - 1) * float(binwidth))
nst = nst.resample(1.0 / float(binwidth))
lambda_delta[cell_index, 0] = float(cif.evalLambdaDelta(x_vec, idx, nst))
sum_val_vec += observed[cell_index] * np.asarray(cif.evalGradientLog(x_vec, idx, nst), dtype=float).reshape(-1)
Expand Down Expand Up @@ -695,26 +704,73 @@ def PPDecodeFilterLinear(*args, **kwargs):
@staticmethod
def PPDecodeFilter(A, Q, Px0, dN, lambdaCIFColl, binwidth=0.001, x0=None, Pi0=None, yT=None, PiT=None, estimateTarget=0, Wconv=None):
obs = _as_observation_matrix(dN)
num_states = _infer_state_dim(A, np.array([0.0]), obs.shape[0])
mu, beta, fitType, gamma, windowTimes = _extract_linear_terms_from_cifs(lambdaCIFColl, num_states, obs.shape[0])
initial_cov = Px0 if _is_empty_value(Pi0) else Pi0
return DecodingAlgorithms._ppdecode_filter_linear(
A,
Q,
obs,
mu,
beta,
fitType,
binwidth,
gamma,
windowTimes,
x0,
initial_cov,
yT,
PiT,
estimateTarget,
Wconv,
)
lambda_items = _normalize_cif_collection(lambdaCIFColl)
num_cells, num_steps = obs.shape
if len(lambda_items) != num_cells:
raise ValueError("Number of CIF objects must match the number of observed cells")

num_states = _infer_state_dim(A, np.array([0.0]), num_cells)
uses_target_branch = not _is_empty_value(yT) or not _is_empty_value(PiT) or int(estimateTarget) != 0
if uses_target_branch:
mu, beta, fitType, gamma, windowTimes = _extract_linear_terms_from_cifs(lambda_items, num_states, num_cells)
initial_cov = Px0 if _is_empty_value(Pi0) else Pi0
return DecodingAlgorithms._ppdecode_filter_linear(
A,
Q,
obs,
mu,
beta,
fitType,
binwidth,
gamma,
windowTimes,
x0,
initial_cov,
yT,
PiT,
estimateTarget,
Wconv,
)

x0_vec = np.zeros(num_states, dtype=float) if _is_empty_value(x0) else np.asarray(x0, dtype=float).reshape(-1)
if x0_vec.size != num_states:
raise ValueError("x0 must match the decoding state dimension")
# MATLAB PPDecodeFilter's standard branch initializes from Pi0, and
# when Pi0 is omitted it falls back to zeros rather than using Px0.
Pi0_mat = np.zeros((num_states, num_states), dtype=float) if _is_empty_value(Pi0) else _as_state_matrix(Pi0, num_states)

x_p = np.zeros((num_states, num_steps + 1), dtype=float)
x_u = np.zeros((num_states, num_steps), dtype=float)
W_p = np.zeros((num_states, num_states, num_steps + 1), dtype=float)
W_u = np.zeros((num_states, num_states, num_steps), dtype=float)

A0 = _select_time_matrix(A, 0, num_states)
Q0 = _select_time_matrix(Q, 0, num_states)
x_p[:, 0], W_p[:, :, 0] = DecodingAlgorithms.PPDecode_predict(x0_vec, Pi0_mat, A0, Q0, Wconv)

for time_index in range(1, num_steps + 1):
x_u[:, time_index - 1], W_u[:, :, time_index - 1], _ = DecodingAlgorithms.PPDecode_update(
x_p[:, time_index - 1],
W_p[:, :, time_index - 1],
obs,
lambda_items,
binwidth,
time_index,
None,
)
A_t = _select_time_matrix(A, time_index - 1, num_states)
Q_t = _select_time_matrix(Q, time_index - 1, num_states)
x_p[:, time_index], W_p[:, :, time_index] = DecodingAlgorithms.PPDecode_predict(
x_u[:, time_index - 1],
W_u[:, :, time_index - 1],
A_t,
Q_t,
Wconv,
)

empty_vec = np.array([], dtype=float)
empty_cov = np.zeros((0, 0, 0), dtype=float)
return x_p, W_p, x_u, W_u, empty_vec, empty_cov, empty_vec, empty_cov

@staticmethod
def PP_fixedIntervalSmoother(A, Q, dN, lags, mu, beta, fitType="poisson", delta=0.001, gamma=None, windowTimes=None, x0=None, Pi0=None):
Expand Down
18 changes: 11 additions & 7 deletions parity/class_fidelity.yml
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,13 @@ items:
in the expected workflow positions.
symbol_presence_verified: yes
known_remaining_differences:
- Simulink-backed recursive-CIF behavior is represented by a native Python implementation,
but it is not yet fixture-matched one-for-one against MATLAB/Simulink stochastic
trajectories.
- Analytic and nonlinear polynomial CIF surfaces are now fixture-backed against
MATLAB, but recursive Simulink-backed stochastic trajectories are still validated
as high-fidelity summaries rather than exact sample-by-sample reproductions.
required_remediation:
- Extend the committed MATLAB-derived fixtures beyond analytic lambda/gradient/Jacobian
outputs and the deterministic recursive lambda prefix to cover additional thinning
and seeded simulation summaries.
outputs, nonlinear polynomial surfaces, and the deterministic recursive lambda
prefix to cover additional thinning and seeded simulation summaries.
- Add MATLAB/Simulink comparison fixtures for recursive CIF simulation trajectories
when the random-stream alignment question is resolved.
plotting_report_parity: Simulation/report plotting is limited; downstream notebooks
Expand Down Expand Up @@ -392,11 +392,15 @@ items:
tensors instead of only Python-specific dictionaries.
symbol_presence_verified: yes
known_remaining_differences:
- The nonlinear `PPDecodeFilter` path is now fixture-backed against MATLAB on
a deterministic polynomial-CIF example, but it still shows small symbolic/numeric
drift at the `1e-4` level and remains high-fidelity rather than exact.
- Target-estimation augmentation, EM routines, and some advanced symbolic-CIF workflows
remain thinner than MATLAB.
required_remediation:
- Extend the committed MATLAB-derived numerical fixtures beyond `PPDecode_predict`
to DecodingExample, DecodingExampleWithHist, StimulusDecode2D, and HybridFilterExample.
- Extend the committed MATLAB-derived numerical fixtures from `PPDecode_predict`
and the deterministic nonlinear `PPDecodeFilter` case to DecodingExample,
DecodingExampleWithHist, and HybridFilterExample summaries.
- Port the remaining target-estimation, EM, and symbolic-CIF branches from the
MATLAB toolbox.
plotting_report_parity: Notebook-level decoding figures are supported, but the full
Expand Down
11 changes: 6 additions & 5 deletions parity/notebook_fidelity.yml
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,12 @@ items:
- topic: StimulusDecode2D
source_matlab: StimulusDecode2D.mlx
python_notebook: notebooks/StimulusDecode2D.ipynb
fidelity_status: partial
remaining_differences: The notebook reproduces the MATLAB section order, figure
inventory, simulated receptive fields, and decoded-trajectory presentation, but
the current Python decoder still uses regression-based state recovery instead
of MATLAB's symbolic-CIF nonlinear filter.
fidelity_status: high_fidelity
remaining_differences: The notebook now follows the MATLAB nonlinear-CIF decoding
workflow and uses `DecodingAlgorithms.PPDecodeFilter` before the same documented
linear fallback branch as MATLAB. Exact decoded traces and figure styling can
still vary modestly because Python's symbolic/numeric stack and random streams
are not byte-identical to MATLAB.
python_sections: 4
python_expected_figures: 6
python_uses_figure_tracker: true
Expand Down
9 changes: 4 additions & 5 deletions parity/report.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ Generated from `parity/manifest.yml`, `parity/class_fidelity.yml`, `tools/notebo
| Status | Count |
|---|---:|
| `exact` | 0 |
| `high_fidelity` | 12 |
| `partial` | 1 |
| `high_fidelity` | 13 |
| `partial` | 0 |

## Simulink Fidelity Summary

Expand All @@ -60,8 +60,7 @@ Generated from `parity/manifest.yml`, `parity/class_fidelity.yml`, `tools/notebo

- Public API: no missing MATLAB public APIs remain; only the MATLAB help-browser utility is explicitly non-applicable.
- Help/notebook parity: all inventoried MATLAB help workflows are mapped to Python notebooks or equivalents.
- Notebook fidelity: workflow coverage is complete, but 1 MATLAB-helpfile notebook ports are still marked partial in `tools/notebooks/parity_notes.yml`.
- Notebook fidelity audit: structural section/figure comparisons plus placeholder/tracker-only cell detection are recorded in `parity/notebook_fidelity.yml`.
- Notebook fidelity: all tracked MATLAB-helpfile notebook ports are marked high fidelity or exact.
- Paper examples and docs gallery: all canonical paper examples and committed gallery directories are mapped.
- Class fidelity: the class audit reports no partial, wrapper-only, or missing items.
- Runtime symbol verification: every audited MATLAB-facing Python symbol marked present in `parity/class_fidelity.yml` resolves on the live public surface.
Expand All @@ -73,7 +72,7 @@ No partial or missing items remain in the mapping inventory.

## Remaining Notebook-Fidelity Deltas

- `StimulusDecode2D` -> `notebooks/StimulusDecode2D.ipynb` [partial]: The notebook reproduces the MATLAB section order, figure inventory, simulated receptive fields, and decoded-trajectory presentation, but the current Python decoder still uses regression-based state recovery instead of MATLAB's symbolic-CIF nonlinear filter.
No partial notebook-fidelity items remain in `tools/notebooks/parity_notes.yml`.

## Remaining Class-Fidelity Deltas

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies = [
"numpy>=1.24",
"scipy>=1.10",
"matplotlib>=3.7",
"sympy>=1.13",
"PyYAML>=6.0",
"nbformat>=5.10",
"nbclient>=0.10"
Expand Down
Binary file modified tests/parity/fixtures/matlab_gold/cif_exactness.mat
Binary file not shown.
Binary file not shown.
17 changes: 17 additions & 0 deletions tests/test_decoding_algorithms_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,23 @@ def test_ppdecodefilter_accepts_cif_collections_with_history() -> None:
assert np.all(np.isfinite(x_u))


def test_ppdecodefilter_handles_symbolic_style_polynomial_cifs() -> None:
dN = np.array([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0]], dtype=float)
lambda_cifs = [
CIF([-2.0, -0.5, 0.3, -0.2, -0.1, 0.05], ["1", "x", "y", "x^2", "y^2", "x*y"], ["x", "y"], fitType="binomial"),
CIF([-1.5, 0.4, -0.2, 0.15, -0.05, 0.02], ["1", "x", "y", "x^2", "y^2", "x*y"], ["x", "y"], fitType="binomial"),
]

x_p, W_p, x_u, W_u, *_ = DecodingAlgorithms.PPDecodeFilter(np.eye(2), 0.01 * np.eye(2), 0.05 * np.eye(2), dN, lambda_cifs, 0.1)

assert x_p.shape == (2, 5)
assert W_p.shape == (2, 2, 5)
assert x_u.shape == (2, 4)
assert W_u.shape == (2, 2, 4)
assert np.all(np.isfinite(x_u))
assert np.all(np.isfinite(W_u))


def test_ppdecode_update_matches_matlab_facing_public_surface() -> None:
dN = np.array([[0.0, 1.0, 0.0, 1.0]], dtype=float)
lambda_cif = CIF([0.1, 0.4], ["1", "x"], ["x"], fitType="binomial")
Expand Down
35 changes: 35 additions & 0 deletions tests/test_matlab_gold_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,19 @@ def test_cif_eval_surface_matches_matlab_gold_fixture() -> None:
np.testing.assert_allclose(cif.evalJacobian(stim_val), np.asarray(payload["jacobian"], dtype=float), rtol=1e-8, atol=1e-10)
np.testing.assert_allclose(cif.evalJacobianLog(stim_val), np.asarray(payload["jacobian_log"], dtype=float), rtol=1e-8, atol=1e-10)

poly_cif = CIF(
beta=_vector(payload, "poly_beta"),
Xnames=["1", "x", "y", "x^2", "y^2", "x*y"],
stimNames=["x", "y"],
fitType="binomial",
)
poly_stim = _vector(payload, "poly_stimVal")
np.testing.assert_allclose(poly_cif.evalLambdaDelta(poly_stim), _scalar(payload, "poly_lambda_delta"), rtol=1e-8, atol=1e-10)
np.testing.assert_allclose(poly_cif.evalGradient(poly_stim).reshape(-1), _vector(payload, "poly_gradient"), rtol=1e-8, atol=1e-10)
np.testing.assert_allclose(poly_cif.evalGradientLog(poly_stim).reshape(-1), _vector(payload, "poly_gradient_log"), rtol=1e-8, atol=1e-10)
np.testing.assert_allclose(poly_cif.evalJacobian(poly_stim), np.asarray(payload["poly_jacobian"], dtype=float), rtol=1e-8, atol=1e-10)
np.testing.assert_allclose(poly_cif.evalJacobianLog(poly_stim), np.asarray(payload["poly_jacobian_log"], dtype=float), rtol=1e-8, atol=1e-10)


def test_analysis_fit_surface_matches_matlab_gold_fixture() -> None:
payload = _load_fixture("analysis_exactness.mat")
Expand Down Expand Up @@ -225,6 +238,28 @@ def test_decoding_predict_matches_matlab_gold_fixture() -> None:
np.testing.assert_allclose(W_p, np.asarray(payload["W_p"], dtype=float), rtol=1e-8, atol=1e-10)


def test_nonlinear_ppdecodefilter_matches_matlab_gold_fixture() -> None:
payload = _load_fixture("nonlinear_decode_exactness.mat")
lambda_cifs = [
CIF(_vector(payload, "beta1"), ["1", "x", "y", "x^2", "y^2", "x*y"], ["x", "y"], fitType="binomial"),
CIF(_vector(payload, "beta2"), ["1", "x", "y", "x^2", "y^2", "x*y"], ["x", "y"], fitType="binomial"),
]

x_p, W_p, x_u, W_u, *_ = DecodingAlgorithms.PPDecodeFilter(
np.asarray(payload["A"], dtype=float),
np.asarray(payload["Q"], dtype=float),
np.asarray(payload["Px0"], dtype=float),
np.asarray(payload["dN"], dtype=float),
lambda_cifs,
_scalar(payload, "delta"),
)

np.testing.assert_allclose(x_p, np.asarray(payload["x_p"], dtype=float), rtol=1e-3, atol=5e-4)
np.testing.assert_allclose(W_p, np.asarray(payload["W_p"], dtype=float), rtol=1e-3, atol=5e-4)
np.testing.assert_allclose(x_u, np.asarray(payload["x_u"], dtype=float), rtol=1e-3, atol=5e-4)
np.testing.assert_allclose(W_u, np.asarray(payload["W_u"], dtype=float), rtol=1e-3, atol=5e-4)


def test_simulated_network_matches_matlab_gold_fixture() -> None:
payload = _load_fixture("simulated_network_exactness.mat")
native = simulate_two_neuron_network(seed=4)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_notebook_fidelity_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_notebook_fidelity_audit_marks_upgraded_ports_as_high_fidelity() -> None
def test_notebook_fidelity_audit_tracks_only_known_partial_notebooks() -> None:
audit = yaml.safe_load(AUDIT_PATH.read_text(encoding="utf-8")) or {}
partial_topics = {row["topic"] for row in audit.get("items", []) if row["fidelity_status"] in {"partial", "placeholder", "missing"}}
assert partial_topics == {"StimulusDecode2D"}
assert partial_topics == set()


def test_high_fidelity_notebooks_have_no_placeholder_or_tracker_only_cells() -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_notebook_parity_notes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_target_notebooks_start_with_machine_readable_parity_note() -> None:

def test_notebook_parity_notes_track_only_known_partial_statuses() -> None:
partial = [row["topic"] for row in _load_notes() if row["fidelity_status"] == "partial"]
assert partial == ["StimulusDecode2D"]
assert partial == []


def test_high_fidelity_parity_notes_do_not_admit_placeholder_or_tracker_only_status() -> None:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_parity_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

REPO_ROOT = Path(__file__).resolve().parents[1]
MANIFEST_PATH = REPO_ROOT / "parity" / "manifest.yml"
NOTEBOOK_AUDIT_PATH = REPO_ROOT / "parity" / "notebook_fidelity.yml"

EXPECTED_MATLAB_PUBLIC_API = {
"Analysis",
Expand Down Expand Up @@ -99,3 +100,22 @@ def test_parity_manifest_statuses_and_mapped_targets_are_valid() -> None:
target = row.get("python_target")
if status == "mapped":
assert target, f"Mapped item in {section_name} is missing a python_target: {row}"


def test_manifest_help_workflows_align_with_notebook_fidelity_audit() -> None:
manifest = _load_manifest()
notebook_audit = yaml.safe_load(NOTEBOOK_AUDIT_PATH.read_text(encoding="utf-8")) or {}
audit_rows = {row["topic"]: row for row in notebook_audit.get("items", [])}

manifest_help_rows = {
row["matlab"]: row
for row in manifest["help_workflows"]
if str(row.get("python_target", "")).startswith("notebooks/")
}

assert set(audit_rows) <= set(manifest_help_rows)
for topic, audit_row in audit_rows.items():
manifest_row = manifest_help_rows[topic]
audit_row = audit_rows[topic]
if manifest_row["status"] == "mapped":
assert audit_row["fidelity_status"] in {"high_fidelity", "exact"}
4 changes: 2 additions & 2 deletions tests/test_parity_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def test_parity_report_highlights_current_constraints() -> None:
assert "Notebook Fidelity Summary" in text
assert "Simulink Fidelity Summary" in text
assert "Remaining Notebook-Fidelity Deltas" in text
assert "workflow coverage is complete, but 1 MATLAB-helpfile notebook ports are still marked partial" in text
assert "`StimulusDecode2D` -> `notebooks/StimulusDecode2D.ipynb` [partial]" in text
assert "all tracked MATLAB-helpfile notebook ports are marked high fidelity or exact" in text
assert "No partial notebook-fidelity items remain in `tools/notebooks/parity_notes.yml`." in text
assert "No partial or missing items remain in the mapping inventory." in text
assert "Remaining Class-Fidelity Deltas" in text
assert "the class audit reports no partial, wrapper-only, or missing items" in text
Expand Down
Loading