Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/core.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
onnx_ir.to_proto
onnx_ir.to_onnx_text
onnx_ir.tensor
onnx_ir.val
onnx_ir.node
```

Expand Down
3 changes: 2 additions & 1 deletion src/onnx_ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
# Convenience constructors
"tensor",
"node",
"val",
# Pass infrastructure
"passes",
# IO
Expand All @@ -90,7 +91,7 @@
import types

from onnx_ir import convenience, external_data, passes, serde, tape, traversal
from onnx_ir._convenience._constructors import node, tensor
from onnx_ir._convenience._constructors import node, tensor, val
from onnx_ir._core import (
Attr,
AttrFloat32,
Expand Down
73 changes: 72 additions & 1 deletion src/onnx_ir/_convenience/_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

def tensor(
value: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible,
dtype: _enums.DataType | None = None,
dtype: ir.DataType | None = None,
name: str | None = None,
doc_string: str | None = None,
) -> _protocols.TensorProtocol:
Expand Down Expand Up @@ -215,3 +215,74 @@ def node(
doc_string=doc_string,
metadata_props=metadata_props,
)


def val(
name: str,
dtype: ir.DataType | None = None,
shape: ir.Shape | Sequence[int | str | None] | None = None,
*,
type: ir.TypeProtocol | None = None,
const_value: ir.TensorProtocol | None = None,
) -> ir.Value:
"""Create a :class:`~onnx_ir.Value` with the given name and type.

This is a convenience constructor for creating a Value that allows you to specify
dtype and shape in a more relaxed manner. Whereas to create a Value directly, you
need to create a :class:`~onnx_ir.TypeProtocol` and :class:`~onnx_ir.Shape` object
first, this function allows you to specify dtype as a :class:`~onnx_ir.DataType`
and shape as a sequence of integers or symbolic dimensions.

Example::

>>> import onnx_ir as ir
>>> t = ir.val("x", ir.DataType.FLOAT, ["N", 42, 3])
>>> t.name
'x'
>>> t.type
Tensor(FLOAT)
>>> t.shape
Shape([SymbolicDim(N), 42, 3])

.. versionadded:: 0.1.9

Args:
name: The name of the value.
dtype: The data type of the TensorType of the value. This is used only when type is None.
shape: The shape of the value.
type: The type of the value. Only one of dtype and type can be specified.
const_value: The constant tensor that initializes the value. Supply this argument
when you want to create an initializer. The type and shape can be obtained from the tensor.

Returns:
A value with the given name and type.
"""
if const_value is not None:
const_tensor_type = _core.TensorType(const_value.dtype)
if type is not None and type != const_tensor_type:
raise ValueError(
f"The type does not match the const_value. type={type} but const_value has type {const_tensor_type}. "
"You do not have to specify the type when const_value is provided."
)
if dtype is not None and dtype != const_value.dtype:
raise ValueError(
f"The dtype does not match the const_value. dtype={dtype} but const_value has dtype {const_value.dtype}. "
"You do not have to specify the dtype when const_value is provided."
)
if shape is not None and _core.Shape(shape) != const_value.shape:
raise ValueError(
f"The shape does not match the const_value. shape={shape} but const_value has shape {const_value.shape}. "
"You do not have to specify the shape when const_value is provided."
)
return _core.Value(
name=name,
type=const_tensor_type,
shape=_core.Shape(const_value.shape), # type: ignore
const_value=const_value,
)

if type is None and dtype is not None:
type = _core.TensorType(dtype)
if shape is not None and not isinstance(shape, _core.Shape):
shape = _core.Shape(shape)
return _core.Value(name=name, type=type, shape=shape)
91 changes: 91 additions & 0 deletions src/onnx_ir/_convenience/_constructors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,96 @@ def test_tensor_handles_empty_sequence_with_dtype(self):
np.testing.assert_array_equal(tensor.numpy(), np.array([], dtype=np.float32))


class ValueConstructorTest(unittest.TestCase):
def test_value_minimal_creation(self):
"""Test creating a value with just a name."""
value = _constructors.val("minimal")

self.assertEqual(value.name, "minimal")
self.assertIsNone(value.type)
self.assertIsNone(value.shape)
self.assertIsNone(value.const_value)

def test_value_creation_with_sequence_shape(self):
"""Test that shape is correctly converted from sequence to Shape object."""
value = _constructors.val("test", ir.DataType.INT32, [1, 2, 3])

self.assertEqual(value.name, "test")
self.assertIsInstance(value.shape, ir.Shape)
self.assertEqual(value.shape, ir.Shape([1, 2, 3]))

def test_value_creation_with_explicit_type(self):
"""Test value creation with explicit type parameter."""
tensor_type = ir.TensorType(ir.DataType.DOUBLE)
value = _constructors.val("y", type=tensor_type, shape=[10])

self.assertEqual(value.name, "y")
self.assertEqual(value.type, tensor_type)
self.assertEqual(value.shape, ir.Shape([10]))

def test_value_creation_with_const_value(self):
"""Test value creation with const_value (initializer)."""
const_tensor = ir.Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32), name="const")
value = _constructors.val("initializer", const_value=const_tensor)

self.assertEqual(value.name, "initializer")
self.assertEqual(value.type, ir.TensorType(ir.DataType.FLOAT))
self.assertEqual(value.shape, ir.Shape([3]))
self.assertEqual(value.const_value, const_tensor)

def test_value_creation_with_dtype_only(self):
"""Test value creation with only dtype specified."""
value = _constructors.val("float_value", dtype=ir.DataType.FLOAT)

self.assertEqual(value.name, "float_value")
self.assertEqual(value.type, ir.TensorType(ir.DataType.FLOAT))
self.assertIsNone(value.shape)

def test_value_const_value_type_mismatch_error(self):
"""Test that providing mismatched type with const_value raises ValueError."""
const_tensor = ir.tensor([1, 2, 3], dtype=ir.DataType.INT32)
wrong_type = ir.TensorType(ir.DataType.FLOAT)

with self.assertRaisesRegex(ValueError, "The type does not match the const_value"):
_constructors.val("test", type=wrong_type, const_value=const_tensor)

def test_value_const_value_dtype_mismatch_error(self):
"""Test that providing mismatched dtype with const_value raises ValueError."""
const_tensor = ir.tensor([1.0, 2.0], dtype=ir.DataType.FLOAT)

with self.assertRaisesRegex(ValueError, "The dtype does not match the const_value"):
_constructors.val("test", dtype=ir.DataType.INT32, const_value=const_tensor)

def test_value_const_value_shape_mismatch_error(self):
"""Test that providing mismatched shape with const_value raises ValueError."""
const_tensor = ir.tensor([[1, 2], [3, 4]], dtype=ir.DataType.INT32) # Shape: [2, 2]

with self.assertRaisesRegex(ValueError, "The shape does not match the const_value"):
_constructors.val("test", shape=[3, 3], const_value=const_tensor)

def test_value_initialize_with_const_value(self):
const_tensor = ir.tensor(np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float64))
value = _constructors.val("test", const_value=const_tensor)

self.assertEqual(value.name, "test")
self.assertEqual(value.type, ir.TensorType(ir.DataType.DOUBLE))
self.assertEqual(value.shape, ir.Shape([2, 2]))
self.assertEqual(value.const_value, const_tensor)

def test_value_creation_with_string_dimensions(self):
"""Test value creation with string dimensions in shape."""
value = _constructors.val("dynamic", ir.DataType.FLOAT, ["batch", "seq_len", 768])

self.assertEqual(value.name, "dynamic")
self.assertEqual(value.shape, ir.Shape(["batch", "seq_len", 768]))

def test_value_creation_with_none_dimensions(self):
"""Test value creation with None dimensions in shape."""
value = _constructors.val("unknown", ir.DataType.INT64, [None, 10, None])

self.assertEqual(value.name, "unknown")
self.assertEqual(value.shape, ir.Shape([None, 10, None]))


if __name__ == "__main__":
unittest.main()
Loading