diff --git a/docs/source/notebooks/BeginnersGuide.ipynb b/docs/source/notebooks/BeginnersGuide.ipynb index 6e286bf..bf39e9c 100644 --- a/docs/source/notebooks/BeginnersGuide.ipynb +++ b/docs/source/notebooks/BeginnersGuide.ipynb @@ -144,6 +144,37 @@ "# param.npvalue converts the value into numpy before returning it" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`caskade` collapses all params into a 1D array, even if the param had multiple values itself. To explore the 1D array you can use the \"finders\" as shown below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "params = secondsim.get_values()\n", + "print(params)\n", + "# Hmmm, I wonder which param goes in place 3?\n", + "# The result is a tuple (Param, index), the index tells you within the param\n", + "# where the index 3 lands. Since phi is a scalar this is just ()\n", + "print(secondsim.find_param(3))\n", + "# Notice if we get index 1 the returned index is more interesting\n", + "print(secondsim.find_param(1))\n", + "\n", + "# Hmmm, I wonder which index the q param corresponds to?\n", + "print(secondsim.find_index(secondsim.q))\n", + "# For multidimensional params, we will get a slice instead\n", + "print(secondsim.find_index(secondsim.x0))\n", + "\n", + "# You can also query lists to get a bunch at once\n", + "print(secondsim.find_param([0, 1, 2]))" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/caskade/mixins.py b/src/caskade/mixins.py index f566779..6b5723d 100644 --- a/src/caskade/mixins.py +++ b/src/caskade/mixins.py @@ -1,5 +1,6 @@ from typing import Optional, Mapping, Sequence, Union from math import prod +import numpy as np from .param import Param from .errors import ( @@ -200,6 +201,128 @@ def _recursive_build_params_dict( del params[link] return params + def _array_inspection(self, group: Optional[int] = None): + param_list = self.dynamic_params + param_list = tuple(p for p in param_list if (group is None or p.group == group)) + self._check_values(param_list, "array") + + x = [] + with Memo(self, self.name + ":semi_findidx_active"): + for param in param_list: + if param.online: + shape = param.shape + else: + depth = max(memo.count("|") for memo in param.memos) + shape = param.batch_shape[-depth:] + param.shape + if shape == (): + x.append((param, ())) + else: + for i in range(prod(shape)): + x.append((param, tuple(itm.item() for itm in np.unravel_index(i, shape)))) + return x + + # Finders + ################################################################# + def find_param( + self, idx: Union[int, tuple[int]], group: Optional[int] = None, scheme: str = "array" + ) -> tuple[Param, tuple[int]]: + """ + Identify which param is associated with the provided index in the + dynamic params array. + + Parameters + ---------- + idx: Union[int, tuple[int]] + The index in the params array at which we wish to find the + associated param. + group: Optional[int] + If the dynamic params have multiple group values, then this argument + specifies which group to check. + scheme: str + Whether to search the array (default) params or list version of + params. dict is currently unsupported. + + Returns + ------- + param_info: tuple[Param, tuple[int]] + A tuple with the Param object and the index within the Param value + associated with idx (empty tuple if scalar). If idx is a tuple then + the result is a tuple of these results. + """ + if not isinstance(idx, int): + return tuple(self.find_param(i, group, scheme) for i in idx) + + if scheme == "array": + x = self._array_inspection(group) + return x[idx] + elif scheme == "list": + param_list = tuple(p for p in self.dynamic_params if group is None or p.group == group) + return param_list[idx] + elif scheme == "dict": + raise NotImplementedError( + "find_param is not implemented for the dict scheme. The dict has the same structure as the graph and so may be inspected in a variety of other ways." + ) + else: + raise ValueError(f"unrecognized scheme: {scheme}") + + def find_index( + self, param: Union[Param, tuple[Param], "Module"], scheme: str = "array" + ) -> Union[int, slice]: + """ + Identify what index is associated with a param in the dynamic params + array. + + Parameters + ---------- + param: Union[Param, tuple[Param], Module] + The param for which to find the associated index. + scheme: str + Whether to search the array (default) params or list version of + params. dict is currently unsupported. + + Returns + ------- + param_info: Union[int, slice] + A int giving the index associated with the provided Param object. If + the param is multi-dimensional then the result will be a slice over + all indices associated with that param. + """ + # 1. Handle recursive structures + if isinstance(param, (list, tuple)): + return tuple(self.find_index(p, scheme) for p in param) + if isinstance(param, GetSetValues): + return tuple( + self.find_index(c, scheme) + for c in param.children.values() + if isinstance(c, Param) and c.dynamic + ) + + groups = self.dynamic_param_groups if len(self.dynamic_param_groups) > 1 else [None] + + for group in groups: + if scheme in ["array", "tensor"]: + inspection = self._array_inspection(group) + matches = [i for i, item in enumerate(inspection) if item[0] is param] + + if not matches: + continue + idx = matches[0] if len(matches) == 1 else slice(min(matches), max(matches) + 1) + + elif scheme == "list": + param_list = [p for p in self.dynamic_params if group is None or p.group == group] + if param not in param_list: + continue + idx = param_list.index(param) + elif scheme == "dict": + raise NotImplementedError("find_index is not implemented for the dict scheme.") + else: + raise ValueError(f"unrecognized scheme: {scheme}") + + # Return with group prefix if we are in multi-group mode + return (group, idx) if len(self.dynamic_param_groups) > 1 else idx + + raise ValueError(f"Param {param.name} could not be found in dynamic params.") + # To/From Valid ################################################################# def _transform_params(self, node, init_params, param_list, transform_attr): diff --git a/src/caskade/module.py b/src/caskade/module.py index 39f8ca7..e0980f2 100644 --- a/src/caskade/module.py +++ b/src/caskade/module.py @@ -94,9 +94,18 @@ def update_graph(self): super().update_graph() def param_order(self): - return ", ".join( - tuple(f"{next(iter(p.parents)).name}: {p.name}" for p in self.dynamic_params) - ) + res = [] + for g in self.dynamic_param_groups: + res.append( + ", ".join( + tuple( + f"{next(iter(p.parents)).name}: {p.name}" + for p in self.dynamic_params + if p.group == g + ) + ) + ) + return "\n".join(res) @property def dynamic(self) -> bool: diff --git a/tests/test_module.py b/tests/test_module.py index 90c0867..6d8944a 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -127,6 +127,74 @@ def test_input_methods_multi_hierarchical(multi_hierarchical_sim, params_type): assert np.allclose(2 * backend.to_numpy(val), sim.run_sim(20, 22, 2 * p0[0])) +def test_finders(sim): + sim.to_dynamic(False) + assert sim.find_param(0)[0] is sim.workers[0].w2 + assert sim.find_param(0)[1] == (0, 0) + assert all(a[0] is b for a, b in zip(sim.find_param([19, -1]), [sim.helper.h1, sim.s1])) + with pytest.raises(IndexError): + sim.find_param(100) + + assert sim.find_param(0, scheme="list") is sim.workers[0].w2 + with pytest.raises(NotImplementedError): + sim.find_param(0, scheme="dict") + with pytest.raises(ValueError): + sim.find_param(0, scheme="funky") + + assert sim.find_index(sim.workers[0].w2) == slice(0, 4) + assert sim.find_index((sim.helper.h1, sim.s1)) == (19, 27) + assert sim.find_index(sim.helper) == (19, slice(20, 22)) + with pytest.raises(ValueError): + sim.find_index(Param("bad_param")) + + assert sim.find_index(sim.workers[0].w2, scheme="list") == 0 + with pytest.raises(ValueError): + sim.find_index(Param("bad_param"), scheme="list") + with pytest.raises(NotImplementedError): + sim.find_index(sim.s1, scheme="dict") + with pytest.raises(ValueError): + sim.find_index(sim.s1, scheme="funky") + + sim.workers[1].w2.group = 1 + sim.helper.h1.group = 1 + sim.workers[4].w1.group = 1 + assert sim.find_param(0, 1)[0] is sim.workers[1].w2 + assert sim.find_param(0, 1)[1] == (0, 0) + assert all(a[0] is b for a, b in zip(sim.find_param([16, -1], 0), [sim.helper.h2, sim.s1])) + with pytest.raises(IndexError): + sim.find_param(25, 0) + + assert sim.find_index(sim.workers[0].w2) == (0, slice(0, 4)) + assert sim.find_index(sim.workers[1].w2) == (1, slice(0, 4)) + assert sim.find_index((sim.helper.h1, sim.s1)) == ((1, 4), (0, 21)) + with pytest.raises(ValueError): + sim.find_index(Param("bad_param")) + + assert sim.find_index(sim.workers[0].w2, scheme="list") == (0, 0) + with pytest.raises(ValueError): + sim.find_index(Param("bad_param"), scheme="list") + with pytest.raises(NotImplementedError): + sim.find_index(sim.s1, scheme="dict") + with pytest.raises(ValueError): + sim.find_index(sim.s1, scheme="funky") + + +def test_finders_hierarchical(hierarchical_sim): + sim = hierarchical_sim + sim.to_dynamic(False) + print(sim.param_order()) + assert sim.find_param(0)[0] is sim.helper.h1 + assert sim.find_param(0)[1] == () + assert all(a[0] is b for a, b in zip(sim.find_param([19, -1]), [sim.worker.w2, sim.s1])) + with pytest.raises(IndexError): + sim.find_param(100) + + assert sim.find_index(sim.worker.w2) == slice(8, 28) + assert sim.find_index((sim.helper.h1, sim.s1)) == (0, 28) + with pytest.raises(ValueError): + sim.find_index(Param("bad_param")) + + def nested_double(params): new_params = {} for param in params: