Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions docs/source/notebooks/AdvancedGuide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,43 @@
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Param shape with `None`\n",
"\n",
"Sometimes we know that a Param ought to have a particular number of dimensions, but we don't know ahead of time what size they will be. For example a Param may represent a 2D image, but of any size. We can enforce this by setting the Param shape to `(None, None)`. Now when we set a value for this param, it will need to conform to that expectation. A scalar value will raise an error, and a 3D value will be assumed to be batched. Here is a basic example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"p = ck.Param(\"image\", shape=(None, None))\n",
"print(\"Not very interesting, just returns exactly what we set: \", p.shape)\n",
"try:\n",
" p.value = 5\n",
"except Exception as e:\n",
" print(\"Got an error: \", e)\n",
"\n",
"p.value = np.ones((10, 15))\n",
"print(\"Now p shape reflects the value that was set: \", p.shape)\n",
"\n",
"p.value = np.ones((5, 20, 25))\n",
"print(\"Now p shape reflects the appropriate 2 dimensions:\", p.shape)\n",
"print(\"And the extra dimension is assumed to be a batch dimension: \", p.batch_shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the `None` must be in a tuple for it to work this way. So `param.shape = None` means no defined shape, it will just take the shape of `param.value.shape`; a 1D vector of unspecified length must be written as `param.shape = (None,)`. You may also mix and match, so `param.shape = (None, 2)` would be an unspecified number of size `2` vectors, for example."
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
42 changes: 29 additions & 13 deletions src/caskade/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,22 @@


def valid_shape(batch_shape, shape, value_shape):
if shape is None: # no shape to compare
# No shape to compare
if shape is None:
return True

# Determine what to compare
if batch_shape is None:
if value_shape == shape: # shapes match
return True
if value_shape[len(value_shape) - len(shape) :] == shape: # endswith
return True
value_shape = value_shape[len(value_shape) - len(shape) :]
else:
shape = batch_shape + shape

# Definitely dont match, wrong lengths
if len(value_shape) != len(shape):
return False
return value_shape == (batch_shape + shape)

# Check for match or None
return all(s is None or v == s for v, s in zip(value_shape, shape))


NULL = object()
Expand Down Expand Up @@ -265,13 +272,20 @@ def to_pointer(self, value, link=()):
self.node_type = "pointer"

@property
def shape(self) -> Optional[tuple[int, ...]]:
if self._shape is not None:
return self._shape
def shape(self) -> tuple[int, ...]:
value = self.value
if value is not None:
return tuple(value.shape)
return ()
# 1. Handle cases where no shape template is defined
if self._shape is None:
return tuple(value.shape) if value is not None else ()

# 2. If value is missing, return the template as-is
if value is None:
return self._shape

# 3. Fill wildcards (None) in _shape using the trailing dimensions of value
# Negative indexing handles the alignment automatically
n = len(self._shape)
return tuple(v if s is None else s for s, v in zip(self._shape, value.shape[-n:]))

@shape.setter
def shape(self, shape: Optional[Iterable]):
Expand All @@ -286,7 +300,9 @@ def shape(self, shape: Optional[Iterable]):
try:
shape = tuple(shape)
except TypeError:
raise ParamConfigurationError(f"Param shape must be iterable of ints ({self.name})")
raise ParamConfigurationError(
f"Param shape must be iterable of ints/None, not: {type(shape)}. ({self.name})"
)
if value is None or valid_shape(self._batch_shape, shape, tuple(value.shape)):
self._shape = shape
return
Expand Down
31 changes: 29 additions & 2 deletions tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
ActiveStateError,
ParamConfigurationError,
ParamTypeError,
GraphError,
InvalidValueWarning,
LinkToAttributeError,
ActiveContext,
backend,
)
Expand Down Expand Up @@ -62,6 +60,35 @@ def test_param_creation(many_param, capsys):
assert captured.out == ""


def test_none_shape_param(capsys):
p = Param("p", np.ones((3, 4, 5)), shape=(None, 5))
assert p.shape == (4, 5)
assert p.batch_shape == (3,)
p.value = np.ones((2, 3, 5))
assert p.shape == (3, 5)
assert p.batch_shape == (2,)

p.value = None

assert p.shape == (None, 5)

with pytest.raises(ParamConfigurationError):
p.value = np.ones(5)
with pytest.raises(ParamConfigurationError):
p.value = np.ones((2, 2))

p.value = np.ones((4, 5))

with pytest.raises(ValueError):
p.shape = (4, 2)
with pytest.raises(ValueError):
p.shape = (3, 4, 5)

# Ensure no spurious output
captured = capsys.readouterr()
assert captured.out == ""


@pytest.mark.parametrize("value", [None, 1, [1, 2]])
@pytest.mark.parametrize("dynamic", [True, False])
def test_active_state(value, dynamic):
Expand Down
Loading