From 8db79eb8e8abe4cd00b4547ac7e82810a30364a6 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Mon, 9 Feb 2026 09:11:30 -0500 Subject: [PATCH 1/5] allow None in param shape --- src/caskade/param.py | 28 ++++++++++++++++++++-------- tests/test_param.py | 31 +++++++++++++++++++++++++++++-- 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/src/caskade/param.py b/src/caskade/param.py index 2e3d997..bb1f2c0 100644 --- a/src/caskade/param.py +++ b/src/caskade/param.py @@ -15,12 +15,17 @@ def valid_shape(batch_shape, shape, value_shape): if shape is None: # no shape to compare return True if batch_shape is None: - if value_shape == shape: # shapes match - return True - if value_shape[len(value_shape) - len(shape) :] == shape: # endswith - return True + value_shape = value_shape[len(value_shape) - len(shape) :] + else: + shape = batch_shape + shape + if value_shape == shape: # shapes match + return True + if len(value_shape) != len(shape): # definitely dont match return False - return value_shape == (batch_shape + shape) + for v, s in zip(value_shape, shape): + if s is not None and v != s: # dont match, skip Nones + return False + return True NULL = object() @@ -266,9 +271,14 @@ def to_pointer(self, value, link=()): @property def shape(self) -> Optional[tuple[int, ...]]: - if self._shape is not None: - return self._shape value = self.value + if self._shape is not None: + Ns = len(self._shape) + if value is None: + return self._shape + return tuple( + value.shape[s_i - Ns] if s is None else s for s_i, s in enumerate(self._shape) + ) if value is not None: return tuple(value.shape) return () @@ -286,7 +296,9 @@ def shape(self, shape: Optional[Iterable]): try: shape = tuple(shape) except TypeError: - raise ParamConfigurationError(f"Param shape must be iterable of ints ({self.name})") + raise ParamConfigurationError( + f"Param shape must be iterable of ints/None, not: {type(shape)}. ({self.name})" + ) if value is None or valid_shape(self._batch_shape, shape, tuple(value.shape)): self._shape = shape return diff --git a/tests/test_param.py b/tests/test_param.py index 354edf8..0adbb5b 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -7,9 +7,7 @@ ActiveStateError, ParamConfigurationError, ParamTypeError, - GraphError, InvalidValueWarning, - LinkToAttributeError, ActiveContext, backend, ) @@ -62,6 +60,35 @@ def test_param_creation(many_param, capsys): assert captured.out == "" +def test_none_shape_param(capsys): + p = Param("p", np.ones((3, 4, 5)), shape=(None, 5)) + assert p.shape == (4, 5) + assert p.batch_shape == (3,) + p.value = np.ones((2, 3, 5)) + assert p.shape == (3, 5) + assert p.batch_shape == (2,) + + p.value = None + + assert p.shape == (None, 5) + + with pytest.raises(ParamConfigurationError): + p.value = np.ones(5) + with pytest.raises(ParamConfigurationError): + p.value = np.ones((2, 2)) + + p.value = np.ones((4, 5)) + + with pytest.raises(ValueError): + p.shape = (4, 2) + with pytest.raises(ValueError): + p.shape = (3, 4, 5) + + # Ensure no spurious output + captured = capsys.readouterr() + assert captured.out == "" + + @pytest.mark.parametrize("value", [None, 1, [1, 2]]) @pytest.mark.parametrize("dynamic", [True, False]) def test_active_state(value, dynamic): From 2d1c7f36eea3ec4ce8559ddd2b8284e7b15da664 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Mon, 9 Feb 2026 09:28:16 -0500 Subject: [PATCH 2/5] simplify valid shape logic --- src/caskade/param.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/caskade/param.py b/src/caskade/param.py index bb1f2c0..b5035bd 100644 --- a/src/caskade/param.py +++ b/src/caskade/param.py @@ -12,20 +12,22 @@ def valid_shape(batch_shape, shape, value_shape): - if shape is None: # no shape to compare + # No shape to compare + if shape is None: return True + + # Determine what to compare if batch_shape is None: value_shape = value_shape[len(value_shape) - len(shape) :] else: shape = batch_shape + shape - if value_shape == shape: # shapes match - return True - if len(value_shape) != len(shape): # definitely dont match + + # Definitely dont match, wrong lengths + if len(value_shape) != len(shape): return False - for v, s in zip(value_shape, shape): - if s is not None and v != s: # dont match, skip Nones - return False - return True + + # Check for match or None + return all(s is None or v == s for v, s in zip(value_shape, shape)) NULL = object() From 5ea6ce8690e6ced9aeadc39919ab4c969ad53052 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Mon, 9 Feb 2026 09:33:26 -0500 Subject: [PATCH 3/5] simplify shape logic --- src/caskade/param.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/caskade/param.py b/src/caskade/param.py index b5035bd..7fe79a4 100644 --- a/src/caskade/param.py +++ b/src/caskade/param.py @@ -272,18 +272,20 @@ def to_pointer(self, value, link=()): self.node_type = "pointer" @property - def shape(self) -> Optional[tuple[int, ...]]: + def shape(self) -> tuple[int, ...]: value = self.value - if self._shape is not None: - Ns = len(self._shape) - if value is None: - return self._shape - return tuple( - value.shape[s_i - Ns] if s is None else s for s_i, s in enumerate(self._shape) - ) - if value is not None: - return tuple(value.shape) - return () + # 1. Handle cases where no shape template is defined + if self._shape is None: + return tuple(value.shape) if value is not None else () + + # 2. If value is missing, return the template as-is + if value is None: + return self._shape + + # 3. Fill wildcards (None) in _shape using the trailing dimensions of value + # Negative indexing handles the alignment automatically + n = len(self._shape) + return tuple(v if s is None else s for s, v in zip(self._shape, value.shape[-n:])) @shape.setter def shape(self, shape: Optional[Iterable]): From 2edb23ecae18732b4a5a6619d670de4ae3034798 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Mon, 9 Feb 2026 15:29:30 -0500 Subject: [PATCH 4/5] Add none shape to advanced guide --- docs/source/notebooks/AdvancedGuide.ipynb | 30 +++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/docs/source/notebooks/AdvancedGuide.ipynb b/docs/source/notebooks/AdvancedGuide.ipynb index 4c9ba9e..8586b4e 100644 --- a/docs/source/notebooks/AdvancedGuide.ipynb +++ b/docs/source/notebooks/AdvancedGuide.ipynb @@ -606,6 +606,36 @@ "plt.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Param shape with `None`\n", + "\n", + "Sometimes we know that a Param ought to have a particular number of dimensions, but we don't know ahead of time what size they will be. For example a Param may represent a 2D image, but of any size. We can enforce this by setting the Param shape to `(None, None)`. Now when we set a value for this param, it will need to conform to that expectation. A scalar value will raise an error, and a 3D value will be assumed to be batched. Here is a basic example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "p = ck.Param(\"image\", shape=(None, None))\n", + "print(\"Not very interesting, just returns exactly what we set: \", p.shape)\n", + "try:\n", + " p.value = 5\n", + "except Exception as e:\n", + " print(\"Got an error: \", e)\n", + "\n", + "p.value = np.ones((10, 15))\n", + "print(\"Now p shape reflects the value that was set: \", p.shape)\n", + "\n", + "p.value = np.ones((5, 20, 25))\n", + "print(\"Now p shape reflects the appropriate 2 dimensions:\", p.shape)\n", + "print(\"And the extra dimension is assumed to be a batch dimension: \", p.batch_shape)" + ] + }, { "cell_type": "markdown", "metadata": {}, From dba72858602ed7a618145e06f792b6843eab09e1 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 10 Feb 2026 08:54:38 -0500 Subject: [PATCH 5/5] further explanation of None shape --- docs/source/notebooks/AdvancedGuide.ipynb | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/source/notebooks/AdvancedGuide.ipynb b/docs/source/notebooks/AdvancedGuide.ipynb index 8586b4e..7948b69 100644 --- a/docs/source/notebooks/AdvancedGuide.ipynb +++ b/docs/source/notebooks/AdvancedGuide.ipynb @@ -636,6 +636,13 @@ "print(\"And the extra dimension is assumed to be a batch dimension: \", p.batch_shape)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that the `None` must be in a tuple for it to work this way. So `param.shape = None` means no defined shape, it will just take the shape of `param.value.shape`; a 1D vector of unspecified length must be written as `param.shape = (None,)`. You may also mix and match, so `param.shape = (None, 2)` would be an unspecified number of size `2` vectors, for example." + ] + }, { "cell_type": "markdown", "metadata": {},