Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
648 changes: 391 additions & 257 deletions notebooks/NetworkTutorial.ipynb

Large diffs are not rendered by default.

57 changes: 57 additions & 0 deletions nstat/decoding_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .cif import CIF
from .errors import UnsupportedWorkflowError
from .nspikeTrain import nspikeTrain


def _as_observation_matrix(dN) -> np.ndarray:
Expand Down Expand Up @@ -519,6 +520,60 @@ def PPDecode_predict(x_u, W_u, A, Q, Wconv=None):
x_p = A_mat @ x_vec
return x_p, W_p

@staticmethod
def PPDecode_update(x_p, W_p, dN, lambdaIn, binwidth=0.001, time_index=1, WuConv=None):
    """Point-process filter measurement update (port of MATLAB ``PPDecode_update``).

    Corrects the one-step prediction ``(x_p, W_p)`` using the spikes observed in
    bin ``time_index`` of ``dN``, with one conditional-intensity model per cell.

    Parameters
    ----------
    x_p : array-like
        Predicted state mean; flattened to a 1-D vector.
    W_p : array-like
        Predicted state covariance (coerced via ``_as_state_matrix``).
    dN : array-like
        Observation matrix; one row per cell, one column per time bin.
    lambdaIn : CIF or sequence of CIF
        Conditional-intensity model(s), one per row of ``dN``.
    binwidth : float, optional
        Bin width in seconds, used to rebuild spike trains for history terms.
    time_index : int, optional
        1-based bin index to update on; clamped to ``[1, dN.shape[1]]``.
    WuConv : array-like, optional
        Pre-converged posterior covariance; when provided it is used directly
        instead of recomputing the covariance update.

    Returns
    -------
    tuple
        ``(x_u, W_u, lambda_delta)``: posterior mean, posterior covariance,
        and the per-cell ``lambda*delta`` column vector.

    Raises
    ------
    ValueError
        If ``lambdaIn`` is not a ``CIF``, is an empty sequence, or contains a
        non-``CIF`` entry.
    """
    x_vec = np.asarray(x_p, dtype=float).reshape(-1)
    W_mat = _as_state_matrix(W_p, x_vec.size)
    obs = _as_observation_matrix(dN)
    # Clamp the MATLAB-style 1-based time index into the valid column range.
    idx = max(1, min(int(time_index), obs.shape[1]))

    # Normalize lambdaIn into a list of CIFs (a "cell array" in MATLAB terms).
    if isinstance(lambdaIn, CIF):
        lambda_items = [lambdaIn]
    elif isinstance(lambdaIn, Sequence) and not isinstance(lambdaIn, (str, bytes)):
        lambda_items = list(lambdaIn)
    else:
        raise ValueError("Lambda must be a cell of CIFs or a CIF")
    if not lambda_items:
        raise ValueError("Lambda must be a non-empty cell of CIFs or a CIF")

    lambda_delta = np.zeros((len(lambda_items), 1), dtype=float)
    # Accumulators for the gradient (score) and curvature contributions
    # summed over all cells.
    sum_val_vec = np.zeros(x_vec.size, dtype=float)
    sum_val_mat = np.zeros((x_vec.size, x_vec.size), dtype=float)
    observed = obs[:, idx - 1]

    for cell_index, cif in enumerate(lambda_items):
        if not isinstance(cif, CIF):
            raise ValueError("Lambda must be a cell of CIFs or a CIF")
        if cif.historyMat.size == 0:
            # No precomputed history matrix: rebuild a spike train from the
            # raw observations so history-dependent CIF terms can be evaluated.
            spike_times = (np.where(obs[cell_index] == 1.0)[0]) * float(binwidth)
            nst = nspikeTrain(spike_times, makePlots=-1)
            nst.setMinTime(0.0)
            nst.setMaxTime((obs.shape[1] - 1) * float(binwidth))
            nst = nst.resample(1.0 / float(binwidth))
            lambda_delta[cell_index, 0] = float(cif.evalLambdaDelta(x_vec, idx, nst))
            sum_val_vec += observed[cell_index] * np.asarray(cif.evalGradientLog(x_vec, idx, nst), dtype=float).reshape(-1)
            sum_val_vec -= np.asarray(cif.evalGradient(x_vec, idx, nst), dtype=float).reshape(-1)
            # NOTE(review): the log-Jacobian term below is NOT weighted by the
            # observed spike count, while the log-gradient above is — confirm
            # this matches the MATLAB PPDecode_update weighting convention.
            sum_val_mat -= np.asarray(cif.evalJacobianLog(x_vec, idx, nst), dtype=float)
            sum_val_mat += np.asarray(cif.evalJacobian(x_vec, idx, nst), dtype=float)
        else:
            # Precomputed history available: evaluate CIF terms without a
            # reconstructed spike train.
            lambda_delta[cell_index, 0] = float(cif.evalLambdaDelta(x_vec, idx))
            sum_val_vec += observed[cell_index] * np.asarray(cif.evalGradientLog(x_vec, idx), dtype=float).reshape(-1)
            sum_val_vec -= np.asarray(cif.evalGradient(x_vec, idx), dtype=float).reshape(-1)
            sum_val_mat -= np.asarray(cif.evalJacobianLog(x_vec, idx), dtype=float)
            sum_val_mat += np.asarray(cif.evalJacobian(x_vec, idx), dtype=float)

    if _is_empty_value(WuConv):
        identity = np.eye(W_mat.shape[0], dtype=float)
        try:
            # Covariance update via np.linalg.solve, avoiding an explicit
            # inverse of W_p (presumably a matrix-inversion-lemma rearrangement
            # of (W_p^-1 + sum_val_mat)^-1 — verify against the MATLAB source).
            W_u = W_mat @ (identity - np.linalg.solve(identity + sum_val_mat @ W_mat, sum_val_mat @ W_mat))
        except np.linalg.LinAlgError:
            # Singular update system: fall back to the prediction covariance.
            W_u = W_mat.copy()
        W_u = _symmetrize(W_u)
    else:
        # Caller supplied a converged covariance (steady-state shortcut).
        W_u = _symmetrize(_as_state_matrix(WuConv, x_vec.size))
    x_u = x_vec + W_u @ sum_val_vec
    return x_u, W_u, lambda_delta

@staticmethod
def PPDecode_updateLinear(x_p, W_p, dN, mu, beta, fitType="poisson", gamma=None, HkAll=None, time_index=1, WuConv=None):
x_vec = np.asarray(x_p, dtype=float).reshape(-1)
Expand Down Expand Up @@ -856,6 +911,7 @@ def PPHybridFilter(A, Q, p_ij, Mu0, dN, lambdaCIFColl, binwidth=0.001, x0=None,
PPDecodeFilter = DecodingAlgorithms.PPDecodeFilter
PPDecodeFilterLinear = DecodingAlgorithms.PPDecodeFilterLinear
PPDecode_predict = DecodingAlgorithms.PPDecode_predict
PPDecode_update = DecodingAlgorithms.PPDecode_update
PPDecode_updateLinear = DecodingAlgorithms.PPDecode_updateLinear
PPHybridFilter = DecodingAlgorithms.PPHybridFilter
PPHybridFilterLinear = DecodingAlgorithms.PPHybridFilterLinear
Expand All @@ -874,6 +930,7 @@ def PPHybridFilter(A, Q, p_ij, Mu0, dN, lambdaCIFColl, binwidth=0.001, x0=None,
"PPDecodeFilter",
"PPDecodeFilterLinear",
"PPDecode_predict",
"PPDecode_update",
"PPDecode_updateLinear",
"PPHybridFilter",
"PPHybridFilterLinear",
Expand Down
5 changes: 4 additions & 1 deletion nstat/paper_example_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,21 @@ def run_named_paper_example(
example_id: str, repo_root: Path, *, return_payload: bool = False
) -> dict[str, dict[str, float]] | tuple[dict[str, dict[str, float]], dict[str, dict[str, object]]]:
repo_root = repo_root.resolve()
data_dir = ensure_example_data(download=True)

if example_id == "example01":
data_dir = ensure_example_data(download=True)
if not return_payload:
return {"experiment1": run_experiment1(data_dir)}
summary, payload = run_experiment1(data_dir, return_payload=True)
return {"experiment1": summary}, {"experiment1": payload}
if example_id == "example02":
data_dir = ensure_example_data(download=True)
if not return_payload:
return {"experiment2": run_experiment2(data_dir)}
summary, payload = run_experiment2(data_dir, return_payload=True)
return {"experiment2": summary}, {"experiment2": payload}
if example_id == "example03":
data_dir = ensure_example_data(download=True)
if not return_payload:
return {
"experiment3": run_experiment3(),
Expand All @@ -52,6 +54,7 @@ def run_named_paper_example(
"experiment3b": payload3b,
}
if example_id == "example04":
data_dir = ensure_example_data(download=True)
if not return_payload:
return {"experiment4": run_experiment4(data_dir)}
summary4, payload4 = run_experiment4(data_dir, return_payload=True)
Expand Down
21 changes: 21 additions & 0 deletions parity/notebook_fidelity.yml
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,27 @@ items:
matlab_published_figures: 8
section_delta: 0
figure_delta: 0
- topic: NetworkTutorial
source_matlab: NetworkTutorial.mlx
python_notebook: notebooks/NetworkTutorial.ipynb
fidelity_status: high_fidelity
remaining_differences: The notebook now mirrors the MATLAB helpfile section order
and published figure inventory with a native Python network simulator and MATLAB-style
`Analysis` workflow; exact spike realizations still vary modestly because NumPy
and Simulink do not share identical random streams.
python_sections: 21
python_expected_figures: 14
python_uses_figure_tracker: true
python_has_finalize_call: true
python_placeholder_cells: 0
python_tracker_only_cells: 0
python_contains_placeholders: false
python_contains_tracker_only_cells: false
matlab_repo_root: /Users/iahncajigas/Library/CloudStorage/Dropbox/Codex/nSTAT
matlab_sections: 21
matlab_published_figures: 14
section_delta: 0
figure_delta: 0
- topic: ValidationDataSet
source_matlab: ValidationDataSet.mlx
python_notebook: notebooks/ValidationDataSet.ipynb
Expand Down
2 changes: 1 addition & 1 deletion parity/report.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Generated from `parity/manifest.yml`, `parity/class_fidelity.yml`, and `tools/no
| Status | Count |
|---|---:|
| `exact` | 0 |
| `high_fidelity` | 12 |
| `high_fidelity` | 13 |
| `partial` | 0 |

## Simulink Fidelity Summary
Expand Down
21 changes: 21 additions & 0 deletions tests/test_decoding_algorithms_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,27 @@ def test_ppdecodefilter_accepts_cif_collections_with_history() -> None:
assert np.all(np.isfinite(x_u))


def test_ppdecode_update_matches_matlab_facing_public_surface() -> None:
    """A bare CIF plus MATLAB-style positional arguments must be accepted."""
    observations = np.array([[0.0, 1.0, 0.0, 1.0]], dtype=float)
    cif = CIF([0.1, 0.4], ["1", "x"], ["x"], fitType="binomial")

    prior_mean = np.array([0.0], dtype=float)
    prior_cov = np.array([[1.0]], dtype=float)
    x_u, W_u, lambda_delta = DecodingAlgorithms.PPDecode_update(
        prior_mean, prior_cov, observations, cif, 0.1, 2
    )

    # Shapes follow the single-cell, 1-D-state inputs.
    assert x_u.shape == (1,)
    assert W_u.shape == (1, 1)
    assert lambda_delta.shape == (1, 1)
    # All outputs must be finite and the intensity strictly positive.
    for result in (x_u, W_u):
        assert np.all(np.isfinite(result))
    assert np.all(lambda_delta > 0.0)


def test_pphybridfilterlinear_returns_model_probabilities_and_state_banks() -> None:
a = [np.array([[1.0]], dtype=float), np.array([[0.9]], dtype=float)]
q = [np.array([[0.02]], dtype=float), np.array([[0.05]], dtype=float)]
Expand Down
56 changes: 56 additions & 0 deletions tests/test_matlab_symbol_surface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

import inspect

from nstat import Analysis, CIF, DecodingAlgorithms


# MATLAB-facing API surface the Python port promises to expose.
# Keys are the ported classes; values are the method names that MATLAB nSTAT
# users call and that must therefore remain callable on the Python side.
EXPECTED_SYMBOLS = {
    Analysis: {
        "RunAnalysisForNeuron",
        "RunAnalysisForAllNeurons",
        "GLMFit",
        "KSPlot",
        "plotFitResidual",
        "computeFitResidual",
        "computeKSStats",
        "plotInvGausTrans",
        "plotSeqCorr",
        "plotCoeffs",
    },
    CIF: {
        "setSpikeTrain",
        "setHistory",
        "simulateCIFByThinningFromLambda",
        "simulateCIF",
        "evalGradient",
        "evalGradientLog",
        "evalJacobian",
        "evalJacobianLog",
        "evalGradientLDGamma",
        "evalJacobianLDGamma",
    },
    DecodingAlgorithms: {
        "PPDecode_predict",
        "PPDecode_update",
        "PPDecode_updateLinear",
        "PPDecodeFilterLinear",
        "PPDecodeFilter",
        "PP_fixedIntervalSmoother",
        "PPHybridFilterLinear",
        "PPHybridFilter",
    },
}


def test_expected_matlab_symbol_surface_exists_and_is_callable() -> None:
    """Every MATLAB-facing name must resolve to a callable attribute."""
    for target, required_names in EXPECTED_SYMBOLS.items():
        missing = sorted(
            symbol
            for symbol in required_names
            if not callable(getattr(target, symbol, None))
        )
        assert not missing, f"{target.__name__} is missing MATLAB-facing callables: {missing}"


def test_expected_symbol_surface_has_python_runtime_signatures() -> None:
    """Each MATLAB-facing callable must expose an introspectable signature.

    ``inspect.signature`` raises ``ValueError``/``TypeError`` when a callable
    has no retrievable signature, so obtaining one without an exception is the
    real check. The previous ``assert signature is not None`` was vacuous:
    ``inspect.signature`` never returns ``None``.
    """
    for obj, expected in EXPECTED_SYMBOLS.items():
        for name in expected:
            signature = inspect.signature(getattr(obj, name))
            # A Signature instance proves this is a genuine Python callable
            # with parameter metadata rather than an opaque stub.
            assert isinstance(signature, inspect.Signature), (
                f"{obj.__name__}.{name} lacks an introspectable signature"
            )
32 changes: 32 additions & 0 deletions tests/test_network_tutorial_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

from pathlib import Path

import nbformat

from tools.notebooks.build_network_tutorial_notebook import build_notebook


# Repository root inferred from this test file's location (tests/ -> repo root).
REPO_ROOT = Path(__file__).resolve().parents[1]
# Committed notebook that the builder script must reproduce exactly.
NOTEBOOK_PATH = REPO_ROOT / "notebooks" / "NetworkTutorial.ipynb"


def _normalize_notebook(notebook) -> None:
for cell in notebook.cells:
cell["id"] = "normalized"
cell["execution_count"] = None
cell["outputs"] = []


def _cell_payload(cell) -> tuple[str, str, dict]:
return cell.cell_type, "".join(cell.get("source", "")), dict(cell.get("metadata", {}))


def test_network_tutorial_builder_matches_committed_notebook() -> None:
    """The committed notebook must match the builder output modulo run state."""
    on_disk = nbformat.read(NOTEBOOK_PATH, as_version=4)
    rebuilt = build_notebook()
    for notebook in (on_disk, rebuilt):
        _normalize_notebook(notebook)
    assert on_disk.metadata == rebuilt.metadata
    assert len(on_disk.cells) == len(rebuilt.cells)
    disk_payloads = [_cell_payload(item) for item in on_disk.cells]
    built_payloads = [_cell_payload(item) for item in rebuilt.cells]
    assert disk_payloads == built_payloads
2 changes: 2 additions & 0 deletions tests/test_notebook_ci_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"ExplicitStimulusWhiskerData",
"HippocampalPlaceCellExample",
"HybridFilterExample",
"NetworkTutorial",
"PPSimExample",
"SignalObjExamples",
"StimulusDecode2D",
Expand All @@ -35,6 +36,7 @@
"ExplicitStimulusWhiskerData",
"HippocampalPlaceCellExample",
"HybridFilterExample",
"NetworkTutorial",
"PPSimExample",
"StimulusDecode2D",
"TrialExamples",
Expand Down
12 changes: 12 additions & 0 deletions tests/test_notebook_fidelity_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_notebook_fidelity_audit_marks_upgraded_ports_as_high_fidelity() -> None
assert {
"AnalysisExamples",
"AnalysisExamples2",
"NetworkTutorial",
"PPSimExample",
"nSTATPaperExamples",
} <= high_fidelity_topics
Expand All @@ -59,6 +60,17 @@ def test_high_fidelity_notebooks_have_no_placeholder_or_tracker_only_cells() ->
assert not row["python_contains_tracker_only_cells"], f"{row['topic']} still contains tracker-only cells"


def test_high_fidelity_notebooks_have_near_matlab_structural_counts() -> None:
    """High-fidelity ports may deviate from MATLAB by at most one section/figure."""
    audit_doc = yaml.safe_load(AUDIT_PATH.read_text(encoding="utf-8")) or {}
    for entry in audit_doc.get("items", []):
        if entry["fidelity_status"] not in {"high_fidelity", "exact"}:
            continue
        section_delta = entry.get("section_delta")
        figure_delta = entry.get("figure_delta")
        # Rows without recorded deltas (no MATLAB baseline) are exempt.
        if section_delta is None or figure_delta is None:
            continue
        assert abs(int(section_delta)) <= 1, f"{entry['topic']} has a large MATLAB section delta"
        assert abs(int(figure_delta)) <= 1, f"{entry['topic']} has a large MATLAB figure delta"


def test_notebook_fidelity_audit_matches_generator_when_matlab_repo_is_available() -> None:
matlab_repo = default_matlab_repo_root(REPO_ROOT)
if not matlab_repo.exists():
Expand Down
Loading