From 128f881502f9467ff6c8462f3dfe11c451fa819d Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Mon, 9 Feb 2026 12:32:32 -0500 Subject: [PATCH 1/9] add finders for params array --- src/caskade/mixins.py | 76 +++++++++++++++++++++++++++++++++++++++++++ src/caskade/module.py | 15 +++++++-- tests/test_module.py | 29 +++++++++++++++++ 3 files changed, 117 insertions(+), 3 deletions(-) diff --git a/src/caskade/mixins.py b/src/caskade/mixins.py index f566779..75ed00d 100644 --- a/src/caskade/mixins.py +++ b/src/caskade/mixins.py @@ -1,5 +1,6 @@ from typing import Optional, Mapping, Sequence, Union from math import prod +import numpy as np from .param import Param from .errors import ( @@ -200,6 +201,81 @@ def _recursive_build_params_dict( del params[link] return params + def _array_inspection(self, group: Optional[int] = None): + param_list = self.dynamic_params + param_list = tuple(p for p in param_list if (group is None or p.group == group)) + self._check_values(param_list, "array") + + x = [] + with Memo(self, self.name + ":semi_findidx_active"): + for param in param_list: + if param.online: + shape = param.shape + else: + depth = max(memo.count("|") for memo in param.memos) + shape = param.batch_shape[-depth:] + param.shape + if shape == (): + x.append((param, ())) + else: + for i in range(prod(shape)): + x.append((param, tuple(itm.item() for itm in np.unravel_index(i, shape)))) + return x + + # Finders + ################################################################# + def find_param(self, idx: Union[int, tuple[int]], group: Optional[int] = None): + """ + Identify which param is associated with the provided index in the + dynamic params array. + + Parameters + ---------- + idx: Union[int, tuple[int]] + The index in the params array at which we wish to find the + associated param. + group: Optional[int] + If the dynamic params have multiple group values, then this argument + specifies which group to check. + + Returns + ------- + param_info: tuple[Param, Optional[tuple[int]]] + A tuple with the Param object and the index within the Param value + associated with idx (empty tuple if scalar). If idx is a tuple then + the result is a tuple of these results. + """ + x = self._array_inspection(group) + if isinstance(idx, int): + return x[idx] + return tuple(x[i] for i in idx) + + def find_index(self, param: Union[Param, tuple[Param], "Module"]): + if isinstance(param, (list, tuple)): + return tuple(self.find_index(p) for p in param) + elif isinstance(param, GetSetValues): + return tuple( + self.find_index(c) for c in param.children if isinstance(c, Param) and c.dynamic + ) + + if len(self.dynamic_param_groups) > 1: + for group in self.dynamic_param_groups: + x = self._array_inspection(group) + matches = tuple(m[0] for m in filter(lambda p: p[1][0] is param, enumerate(x))) + if len(matches) == 1: + return (group, matches[0]) + elif len(matches) > 1: + return (group, slice(min(matches), max(matches) + 1)) + else: + raise ValueError(f"Param {param.name} could not be found in dynamic params.") + + x = self._array_inspection(None) + matches = tuple(m[0] for m in filter(lambda p: p[1][0] is param, enumerate(x))) + if len(matches) == 1: + return matches[0] + elif len(matches) > 1: + return slice(min(matches), max(matches) + 1) + raise ValueError(f"Param {param.name} could not be found in dynamic params.") + # To/From Valid ################################################################# def _transform_params(self, node, init_params, param_list, transform_attr): diff --git a/src/caskade/module.py b/src/caskade/module.py index 39f8ca7..e0980f2 100644 --- a/src/caskade/module.py +++ b/src/caskade/module.py @@ -94,9 +94,18 @@ def update_graph(self): super().update_graph() def param_order(self): - return ", ".join( - tuple(f"{next(iter(p.parents)).name}: {p.name}" for p in self.dynamic_params) - ) + res = [] + for g in self.dynamic_param_groups: + res.append( + ", ".join( + tuple( + f"{next(iter(p.parents)).name}: {p.name}" + for p in self.dynamic_params + if p.group == g + ) + ) + ) + return "\n".join(res) @property def dynamic(self) -> bool: diff --git a/tests/test_module.py b/tests/test_module.py index 90c0867..8cd7770 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -127,6 +127,35 @@ def test_input_methods_multi_hierarchical(multi_hierarchical_sim, params_type): assert np.allclose(2 * backend.to_numpy(val), sim.run_sim(20, 22, 2 * p0[0])) +def test_finders(sim): + sim.to_dynamic(False) + assert sim.find_param(0)[0] is sim.workers[0].w2 + assert sim.find_param(0)[1] == (0, 0) + assert all(a[0] is b for a, b in zip(sim.find_param([19, -1]), [sim.helper.h1, sim.s1])) + with pytest.raises(IndexError): + sim.find_param(100) + + assert sim.find_index(sim.workers[0].w2) == slice(0, 4) + assert sim.find_index((sim.helper.h1, sim.s1)) == (19, 27) + with pytest.raises(ValueError): + sim.find_index(Param("bad_param")) + + sim.workers[1].w2.group = 1 + sim.helper.h1.group = 1 + sim.workers[4].w1.group = 1 + assert sim.find_param(0, 1)[0] is sim.workers[1].w2 + assert sim.find_param(0, 1)[1] == (0, 0) + assert all(a[0] is b for a, b in zip(sim.find_param([16, -1], 0), [sim.helper.h2, sim.s1])) + with pytest.raises(IndexError): + sim.find_param(25, 0) + + assert sim.find_index(sim.workers[0].w2) == (0, slice(0, 4)) + assert sim.find_index(sim.workers[1].w2) == (1, slice(0, 4)) + assert sim.find_index((sim.helper.h1, sim.s1)) == ((1, 4), (0, 21)) + with pytest.raises(ValueError): + sim.find_index(Param("bad_param")) + + def nested_double(params): new_params = {} for param in params: From dff06caab9a43ba734aa52a03a74347e418514ce Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Mon, 9 Feb 2026 13:14:53 -0500 Subject: [PATCH 2/9] more complete coverage --- src/caskade/mixins.py | 4 +++- tests/test_module.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/caskade/mixins.py b/src/caskade/mixins.py index 75ed00d..e3b0d23 100644 --- a/src/caskade/mixins.py +++ b/src/caskade/mixins.py @@ -254,7 +254,9 @@ def find_index(self, param: Union[Param, tuple[Param], "Module"]): return tuple(self.find_index(p) for p in param) elif isinstance(param, GetSetValues): return tuple( - self.find_index(c) for c in param.children if isinstance(c, Param) and c.dynamic + self.find_index(c) + for c in param.children.values() + if isinstance(c, Param) and c.dynamic ) if len(self.dynamic_param_groups) > 1: diff --git a/tests/test_module.py b/tests/test_module.py index 8cd7770..4234673 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -137,6 +137,7 @@ def test_finders(sim): assert sim.find_index(sim.workers[0].w2) == slice(0, 4) assert sim.find_index((sim.helper.h1, sim.s1)) == (19, 27) + assert sim.find_index(sim.helper) == (19, slice(20, 22)) with pytest.raises(ValueError): sim.find_index(Param("bad_param")) @@ -156,6 +157,22 @@ def test_finders(sim): sim.find_index(Param("bad_param")) +def test_finders_hierarchical(hierarchical_sim): + sim = hierarchical_sim + sim.to_dynamic(False) + print(sim.param_order()) + assert sim.find_param(0)[0] is sim.helper.h1 + assert sim.find_param(0)[1] == () + assert all(a[0] is b for a, b in zip(sim.find_param([19, -1]), [sim.worker.w2, sim.s1])) + with pytest.raises(IndexError): + sim.find_param(100) + + assert sim.find_index(sim.worker.w2) == slice(8, 28) + assert sim.find_index((sim.helper.h1, sim.s1)) == (0, 28) + with pytest.raises(ValueError): + sim.find_index(Param("bad_param")) + + def nested_double(params): new_params = {} for param in params: From bed5f93a4a703d39388e7c88139250e91352f8e4 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Mon, 9 Feb 2026 14:05:11 -0500 Subject: [PATCH 3/9] add finders for list, raise errors for dict --- src/caskade/mixins.py | 73 +++++++++++++++++++++++++++++++------------ tests/test_module.py | 18 +++++++++++ 2 files changed, 71 insertions(+), 20 deletions(-) diff --git a/src/caskade/mixins.py b/src/caskade/mixins.py index e3b0d23..fed2db3 100644 --- a/src/caskade/mixins.py +++ b/src/caskade/mixins.py @@ -223,7 +223,9 @@ def _array_inspection(self, group: Optional[int] = None): # Finders ################################################################# - def find_param(self, idx: Union[int, tuple[int]], group: Optional[int] = None): + def find_param( + self, idx: Union[int, tuple[int]], group: Optional[int] = None, scheme: str = "array" + ): """ Identify which param is associated with the provided index in the dynamic params array. @@ -244,39 +246,70 @@ def find_param(self, idx: Union[int, tuple[int]], group: Optional[int] = None): associated with idx (empty tuple if scalar). If idx is a tuple then the result is a tuple of these results. """ - x = self._array_inspection(group) - if isinstance(idx, int): + if not isinstance(idx, int): + return tuple(self.find_param(i, group, scheme) for i in idx) + + if scheme == "array": + x = self._array_inspection(group) return x[idx] - return tuple(x[i] for i in idx) + elif scheme == "list": + param_list = tuple(p for p in self.dynamic_params if group is None or p.group == group) + return param_list[idx] + elif scheme == "dict": + raise NotImplementedError( + "find_param is not implemented for the dict scheme. The dict has the same structure as the graph and so may be inspected in a variety of other ways." + ) + else: + raise ValueError(f"unrecognized scheme: {scheme}") - def find_index(self, param: Union[Param, tuple[Param], "Module"]): + def find_index(self, param: Union[Param, tuple[Param], "Module"], scheme="array"): if isinstance(param, (list, tuple)): - return tuple(self.find_index(p) for p in param) + return tuple(self.find_index(p, scheme) for p in param) elif isinstance(param, GetSetValues): return tuple( - self.find_index(c) + self.find_index(c, scheme) for c in param.children.values() if isinstance(c, Param) and c.dynamic ) if len(self.dynamic_param_groups) > 1: for group in self.dynamic_param_groups: - x = self._array_inspection(group) - matches = tuple(m[0] for m in filter(lambda p: p[1][0] is param, enumerate(x))) - if len(matches) == 1: - return (group, matches[0]) - elif len(matches) > 1: - return (group, slice(min(matches), max(matches) + 1)) + if scheme == "array": + x = self._array_inspection(group) + matches = tuple(m[0] for m in filter(lambda p: p[1][0] is param, enumerate(x))) + if len(matches) == 1: + return (group, matches[0]) + elif len(matches) > 1: + return (group, slice(min(matches), max(matches) + 1)) + elif scheme == "list": + param_list = tuple(p for p in self.dynamic_params if p.group == group) + if param in param_list: + return (group, param_list.index(param)) + elif scheme == "dict": + raise NotImplementedError( + "find_index is not implemented for the dict scheme. The dict has the same structure as the graph and so may be inspected in a variety of other ways." + ) + else: + raise ValueError(f"unrecognized scheme: {scheme}") else: raise ValueError(f"Param {param.name} could not be found in dynamic params.") - x = self._array_inspection(None) - matches = tuple(m[0] for m in filter(lambda p: p[1][0] is param, enumerate(x))) - if len(matches) == 1: - return matches[0] - elif len(matches) > 1: - return slice(min(matches), max(matches) + 1) - raise ValueError(f"Param {param.name} could not be found in dynamic params.") + if scheme in ["array", "tensor"]: + x = self._array_inspection(None) + matches = tuple(m[0] for m in filter(lambda p: p[1][0] is param, enumerate(x))) + if len(matches) == 1: + return matches[0] + elif len(matches) > 1: + return slice(min(matches), max(matches) + 1) + raise ValueError(f"Param {param.name} could not be found in dynamic params.") + elif scheme == "list": + return self.dynamic_params.index(param) + elif scheme == "dict": + raise NotImplementedError( + "find_index is not implemented for the dict scheme. The dict has the same structure as the graph and so may be inspected in a variety of other ways." + ) + else: + raise ValueError(f"unrecognized scheme: {scheme}") # To/From Valid ################################################################# diff --git a/tests/test_module.py b/tests/test_module.py index 4234673..35f89aa 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -135,12 +135,24 @@ def test_finders(sim): with pytest.raises(IndexError): sim.find_param(100) + assert sim.find_param(0, scheme="list") is sim.workers[0].w2 + with pytest.raises(NotImplementedError): + sim.find_param(0, scheme="dict") + with pytest.raises(ValueError): + sim.find_param(0, scheme="funky") + assert sim.find_index(sim.workers[0].w2) == slice(0, 4) assert sim.find_index((sim.helper.h1, sim.s1)) == (19, 27) assert sim.find_index(sim.helper) == (19, slice(20, 22)) with pytest.raises(ValueError): sim.find_index(Param("bad_param")) + assert sim.find_index(sim.workers[0].w2, scheme="list") == 0 + with pytest.raises(NotImplementedError): + sim.find_index(sim.s1, scheme="dict") + with pytest.raises(ValueError): + sim.find_index(sim.s1, scheme="funky") + sim.workers[1].w2.group = 1 sim.helper.h1.group = 1 sim.workers[4].w1.group = 1 @@ -156,6 +168,12 @@ def test_finders(sim): with pytest.raises(ValueError): sim.find_index(Param("bad_param")) + assert sim.find_index(sim.workers[0].w2, scheme="list") == (0, 0) + with pytest.raises(NotImplementedError): + sim.find_index(sim.s1, scheme="dict") + with pytest.raises(ValueError): + sim.find_index(sim.s1, scheme="funky") + def test_finders_hierarchical(hierarchical_sim): sim = hierarchical_sim From d674ca4312a0e0be98f4a1c1578c7c4fbe8d34c7 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Mon, 9 Feb 2026 14:38:07 -0500 Subject: [PATCH 4/9] simplify find_index --- src/caskade/mixins.py | 64 ++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 38 deletions(-) diff --git a/src/caskade/mixins.py b/src/caskade/mixins.py index fed2db3..7efc7ff 100644 --- a/src/caskade/mixins.py +++ b/src/caskade/mixins.py @@ -263,53 +263,41 @@ def find_param( raise ValueError(f"unrecognized scheme: {scheme}") def find_index(self, param: Union[Param, tuple[Param], "Module"], scheme="array"): + # 1. Handle recursive structures if isinstance(param, (list, tuple)): return tuple(self.find_index(p, scheme) for p in param) - elif isinstance(param, GetSetValues): + if isinstance(param, GetSetValues): return tuple( self.find_index(c, scheme) for c in param.children.values() if isinstance(c, Param) and c.dynamic ) - if len(self.dynamic_param_groups) > 1: - for group in self.dynamic_param_groups: - if scheme == "array": - x = self._array_inspection(group) - matches = tuple(m[0] for m in filter(lambda p: p[1][0] is param, enumerate(x))) - if len(matches) == 1: - return (group, matches[0]) - elif len(matches) > 1: - return (group, slice(min(matches), max(matches) + 1)) - elif scheme == "list": - param_list = tuple(p for p in self.dynamic_params if p.group == group) - if param in param_list: - return (group, param_list.index(param)) - elif scheme == "dict": - raise NotImplementedError( - "find_index is not implemented for the dict scheme. The dict has the same structure as the graph and so may be inspected in a variety of other ways." - ) - else: - raise ValueError(f"unrecognized scheme: {scheme}") + groups = self.dynamic_param_groups if len(self.dynamic_param_groups) > 1 else [None] + + for group in groups: + if scheme in ["array", "tensor"]: + inspection = self._array_inspection(group) + matches = [i for i, item in enumerate(inspection) if item[0] is param] + + if not matches: + continue + idx = matches[0] if len(matches) == 1 else slice(min(matches), max(matches) + 1) + + elif scheme == "list": + param_list = [p for p in self.dynamic_params if group is None or p.group == group] + if param not in param_list: + continue + idx = param_list.index(param) + elif scheme == "dict": + raise NotImplementedError("find_index is not implemented for the dict scheme.") else: - raise ValueError(f"Param {param.name} could not be found in dynamic params.") - - if scheme in ["array", "tensor"]: - x = self._array_inspection(None) - matches = tuple(m[0] for m in filter(lambda p: p[1][0] is param, enumerate(x))) - if len(matches) == 1: - return matches[0] - elif len(matches) > 1: - return slice(min(matches), max(matches) + 1) - raise ValueError(f"Param {param.name} could not be found in dynamic params.") - elif scheme == "list": - return self.dynamic_params.index(param) - elif scheme == "dict": - raise NotImplementedError( - "find_index is not implemented for the dict scheme. The dict has the same structure as the graph and so may be inspected in a variety of other ways." - ) - else: - raise ValueError(f"unrecognized scheme: {scheme}") + raise ValueError(f"unrecognized scheme: {scheme}") + + # Return with group prefix if we are in multi-group mode + return (group, idx) if len(self.dynamic_param_groups) > 1 else idx + + raise ValueError(f"Param {param.name} could not be found in dynamic params.") # To/From Valid ################################################################# From bb2310b2f94e043afd0dede30f615372a4f86d1a Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Mon, 9 Feb 2026 14:46:42 -0500 Subject: [PATCH 5/9] add coverage for list on missing param --- tests/test_module.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_module.py b/tests/test_module.py index 35f89aa..29a52bf 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -148,6 +148,8 @@ def test_finders(sim): sim.find_index(Param("bad_param")) assert sim.find_index(sim.workers[0].w2, scheme="list") == 0 + with pytest.raises(ValueError): + sim.find_index(Param("bad_param"), scheme="list") with pytest.raises(NotImplementedError): sim.find_index(sim.s1, scheme="dict") with pytest.raises(ValueError): From f6a48bb29bf2c5ad6c7184221c71310a4b7fc6c9 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Mon, 9 Feb 2026 15:16:26 -0500 Subject: [PATCH 6/9] extra test, mostly trigger CI --- tests/test_module.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_module.py b/tests/test_module.py index 29a52bf..6d8944a 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -171,6 +171,8 @@ def test_finders(sim): sim.find_index(Param("bad_param")) assert sim.find_index(sim.workers[0].w2, scheme="list") == (0, 0) + with pytest.raises(ValueError): + sim.find_index(Param("bad_param"), scheme="list") with pytest.raises(NotImplementedError): sim.find_index(sim.s1, scheme="dict") with pytest.raises(ValueError): From dd29d7b7b45293f77400c9ed7c3b81eda614ba03 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Mon, 9 Feb 2026 15:40:31 -0500 Subject: [PATCH 7/9] Add finders to beginners guide --- docs/source/notebooks/BeginnersGuide.ipynb | 28 ++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/docs/source/notebooks/BeginnersGuide.ipynb b/docs/source/notebooks/BeginnersGuide.ipynb index 6e286bf..a9a0c71 100644 --- a/docs/source/notebooks/BeginnersGuide.ipynb +++ b/docs/source/notebooks/BeginnersGuide.ipynb @@ -144,6 +144,34 @@ "# param.npvalue converts the value into numpy before returning it" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`caskade` collapses all params into a 1D array, even if the param had multiple values itself. To explore the 1D array you can use the \"finders\" as shown below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "params = secondsim.get_values()\n", + "print(params)\n", + "# Hmmm, I wonder which param goes in place 3?\n", + "# The result is a tuple (Param, index), the index tells you within the param\n", + "# where the index 3 lands. Since phi is a scalar this is just ()\n", + "print(secondsim.find_param(3))\n", + "# Notice if we get index 1 the returned index is more interesting\n", + "print(secondsim.find_param(1))\n", + "\n", + "# Hmmm, I wonder which index the q param corresponds to?\n", + "print(secondsim.find_index(secondsim.q))\n", + "# For multidimensional params, we will get a slice instead\n", + "print(secondsim.find_index(secondsim.x0))" + ] + }, { "cell_type": "markdown", "metadata": {}, From b75ec7729622d809e90a896272844f23daeef979 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Mon, 9 Feb 2026 16:05:29 -0500 Subject: [PATCH 8/9] add finder docstring --- src/caskade/mixins.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/caskade/mixins.py b/src/caskade/mixins.py index 7efc7ff..6b5723d 100644 --- a/src/caskade/mixins.py +++ b/src/caskade/mixins.py @@ -225,7 +225,7 @@ def _array_inspection(self, group: Optional[int] = None): ################################################################# def find_param( self, idx: Union[int, tuple[int]], group: Optional[int] = None, scheme: str = "array" - ): + ) -> tuple[Param, tuple[int]]: """ Identify which param is associated with the provided index in the dynamic params array. @@ -238,10 +238,13 @@ def find_param( group: Optional[int] If the dynamic params have multiple group values, then this argument specifies which group to check. + scheme: str + Whether to search the array (default) params or list version of + params. dict is currently unsupported. Returns ------- - param_info: tuple[Param, Optional[tuple[int]]] + param_info: tuple[Param, tuple[int]] A tuple with the Param object and the index within the Param value associated with idx (empty tuple if scalar). If idx is a tuple then the result is a tuple of these results. @@ -262,7 +265,28 @@ def find_param( else: raise ValueError(f"unrecognized scheme: {scheme}") - def find_index(self, param: Union[Param, tuple[Param], "Module"], scheme="array"): + def find_index( + self, param: Union[Param, tuple[Param], "Module"], scheme: str = "array" + ) -> Union[int, slice]: + """ + Identify what index is associated with a param in the dynamic params + array. + + Parameters + ---------- + param: Union[Param, tuple[Param], Module] + The param for which to find the associated index. + scheme: str + Whether to search the array (default) params or list version of + params. dict is currently unsupported. + + Returns + ------- + param_info: Union[int, slice] + A int giving the index associated with the provided Param object. If + the param is multi-dimensional then the result will be a slice over + all indices associated with that param. + """ # 1. Handle recursive structures if isinstance(param, (list, tuple)): return tuple(self.find_index(p, scheme) for p in param) From a07f8246429064ea309c2c7b78f38c23aa4c76ef Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 10 Feb 2026 08:59:13 -0500 Subject: [PATCH 9/9] note about finder of lists --- docs/source/notebooks/BeginnersGuide.ipynb | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/source/notebooks/BeginnersGuide.ipynb b/docs/source/notebooks/BeginnersGuide.ipynb index a9a0c71..bf39e9c 100644 --- a/docs/source/notebooks/BeginnersGuide.ipynb +++ b/docs/source/notebooks/BeginnersGuide.ipynb @@ -169,7 +169,10 @@ "# Hmmm, I wonder which index the q param corresponds to?\n", "print(secondsim.find_index(secondsim.q))\n", "# For multidimensional params, we will get a slice instead\n", - "print(secondsim.find_index(secondsim.x0))" + "print(secondsim.find_index(secondsim.x0))\n", + "\n", + "# You can also query lists to get a bunch at once\n", + "print(secondsim.find_param([0, 1, 2]))" ] }, {