diff --git a/docs/source/notebooks/AdvancedGuide.ipynb b/docs/source/notebooks/AdvancedGuide.ipynb
index 4c9ba9e..7948b69 100644
--- a/docs/source/notebooks/AdvancedGuide.ipynb
+++ b/docs/source/notebooks/AdvancedGuide.ipynb
@@ -606,6 +606,43 @@
     "plt.show()"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Param shape with `None`\n",
+    "\n",
+    "Sometimes we know that a Param ought to have a particular number of dimensions, but we don't know ahead of time what size each dimension will be. For example, a Param may represent a 2D image of any size. We can enforce this by setting the Param shape to `(None, None)`. Now when we set a value for this Param, it will need to conform to that expectation: a scalar value will raise an error, while a 3D value will be assumed to be batched. Here is a basic example:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "p = ck.Param(\"image\", shape=(None, None))\n",
+    "print(\"Not very interesting, just returns exactly what we set: \", p.shape)\n",
+    "try:\n",
+    "    p.value = 5\n",
+    "except Exception as e:\n",
+    "    print(\"Got an error: \", e)\n",
+    "\n",
+    "p.value = np.ones((10, 15))\n",
+    "print(\"Now p shape reflects the value that was set: \", p.shape)\n",
+    "\n",
+    "p.value = np.ones((5, 20, 25))\n",
+    "print(\"Now p shape reflects the appropriate 2 dimensions:\", p.shape)\n",
+    "print(\"And the extra dimension is assumed to be a batch dimension: \", p.batch_shape)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note that the `None` must be in a tuple for it to work this way. `param.shape = None` means no defined shape at all; the Param will simply report `param.value.shape`. A 1D vector of unspecified length must therefore be written as `param.shape = (None,)`. You may also mix and match: `param.shape = (None, 2)`, for example, means an unspecified number of length-`2` vectors."
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
diff --git a/src/caskade/param.py b/src/caskade/param.py
index 2e3d997..7fe79a4 100644
--- a/src/caskade/param.py
+++ b/src/caskade/param.py
@@ -12,15 +12,22 @@
 
 
 def valid_shape(batch_shape, shape, value_shape):
-    if shape is None:  # no shape to compare
+    # No shape to compare
+    if shape is None:
         return True
+
+    # Determine what to compare
     if batch_shape is None:
-        if value_shape == shape:  # shapes match
-            return True
-        if value_shape[len(value_shape) - len(shape) :] == shape:  # endswith
-            return True
+        value_shape = value_shape[len(value_shape) - len(shape) :]
+    else:
+        shape = batch_shape + shape
+
+    # Definitely don't match: wrong lengths
+    if len(value_shape) != len(shape):
         return False
-    return value_shape == (batch_shape + shape)
+
+    # Check each dimension, treating None as a wildcard
+    return all(s is None or v == s for v, s in zip(value_shape, shape))
 
 
 NULL = object()
@@ -265,13 +272,20 @@ def to_pointer(self, value, link=()):
         self.node_type = "pointer"
 
     @property
-    def shape(self) -> Optional[tuple[int, ...]]:
-        if self._shape is not None:
-            return self._shape
+    def shape(self) -> tuple[Optional[int], ...]:
         value = self.value
-        if value is not None:
-            return tuple(value.shape)
-        return ()
+        # 1. Handle cases where no shape template is defined
+        if self._shape is None:
+            return tuple(value.shape) if value is not None else ()
+
+        # 2. If value is missing, return the template as-is
+        if value is None:
+            return self._shape
+
+        # 3. Fill wildcards (None) in _shape using the trailing dimensions of value.
+        #    Negative indexing handles the alignment automatically.
+        n = len(self._shape)
+        return tuple(v if s is None else s for s, v in zip(self._shape, value.shape[-n:]))
 
     @shape.setter
     def shape(self, shape: Optional[Iterable]):
@@ -286,7 +300,9 @@ def shape(self, shape: Optional[Iterable]):
         try:
             shape = tuple(shape)
         except TypeError:
-            raise ParamConfigurationError(f"Param shape must be iterable of ints ({self.name})")
+            raise ParamConfigurationError(
+                f"Param shape must be an iterable of ints/None, not {type(shape)} ({self.name})"
+            )
         if value is None or valid_shape(self._batch_shape, shape, tuple(value.shape)):
             self._shape = shape
             return
diff --git a/tests/test_param.py b/tests/test_param.py
index 354edf8..0adbb5b 100644
--- a/tests/test_param.py
+++ b/tests/test_param.py
@@ -7,9 +7,7 @@
     ActiveStateError,
     ParamConfigurationError,
     ParamTypeError,
-    GraphError,
     InvalidValueWarning,
-    LinkToAttributeError,
     ActiveContext,
     backend,
 )
@@ -62,6 +60,35 @@ def test_param_creation(many_param, capsys):
     assert captured.out == ""
 
 
+def test_none_shape_param(capsys):
+    p = Param("p", np.ones((3, 4, 5)), shape=(None, 5))
+    assert p.shape == (4, 5)
+    assert p.batch_shape == (3,)
+    p.value = np.ones((2, 3, 5))
+    assert p.shape == (3, 5)
+    assert p.batch_shape == (2,)
+
+    p.value = None
+
+    assert p.shape == (None, 5)
+
+    with pytest.raises(ParamConfigurationError):
+        p.value = np.ones(5)
+    with pytest.raises(ParamConfigurationError):
+        p.value = np.ones((2, 2))
+
+    p.value = np.ones((4, 5))
+
+    with pytest.raises(ValueError):
+        p.shape = (4, 2)
+    with pytest.raises(ValueError):
+        p.shape = (3, 4, 5)
+
+    # Ensure no spurious output
+    captured = capsys.readouterr()
+    assert captured.out == ""
+
+
 @pytest.mark.parametrize("value", [None, 1, [1, 2]])
 @pytest.mark.parametrize("dynamic", [True, False])
 def test_active_state(value, dynamic):
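
For reference, here is a small standalone sketch (not part of the patch, and not the caskade API) of the wildcard rule that the new `valid_shape` and `Param.shape` implement: `None` entries in a shape template match any size, extra leading dimensions are treated as batch dimensions, and wildcards are filled in from the trailing dimensions of the value. The helper names `matches_template` and `resolve_shape` are illustrative only.

```python
# Standalone illustration of the wildcard shape-matching logic in this patch.
# matches_template / resolve_shape are hypothetical helpers, not caskade names.
from typing import Optional, Tuple

import numpy as np

Template = Tuple[Optional[int], ...]


def matches_template(template: Template, value_shape: Tuple[int, ...]) -> bool:
    """True if the trailing dims of value_shape fit template; None is a wildcard."""
    trailing = value_shape[len(value_shape) - len(template):]
    if len(trailing) != len(template):
        return False
    return all(t is None or v == t for v, t in zip(trailing, template))


def resolve_shape(template: Template, value_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Fill each None in template from the corresponding trailing dim of value_shape."""
    n = len(template)
    return tuple(v if t is None else t for t, v in zip(template, value_shape[-n:]))


value = np.ones((7, 3, 2))    # leading dimension of 7 acts as a batch dimension
template = (None, 2)          # "an unspecified number of length-2 vectors"

assert matches_template(template, value.shape)         # trailing dims (3, 2) fit
assert resolve_shape(template, value.shape) == (3, 2)  # wildcard filled from value
assert not matches_template(template, (4,))            # too few dimensions
assert not matches_template(template, (3, 5))          # second dim must be 2
```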