diff --git a/docs/api/core.md b/docs/api/core.md index c7e40914..f16ff1b8 100644 --- a/docs/api/core.md +++ b/docs/api/core.md @@ -19,6 +19,7 @@ onnx_ir.to_proto onnx_ir.to_onnx_text onnx_ir.tensor + onnx_ir.val onnx_ir.node ``` diff --git a/src/onnx_ir/__init__.py b/src/onnx_ir/__init__.py index 5e712ffc..7df9f119 100644 --- a/src/onnx_ir/__init__.py +++ b/src/onnx_ir/__init__.py @@ -78,6 +78,7 @@ # Convenience constructors "tensor", "node", + "val", # Pass infrastructure "passes", # IO @@ -90,7 +91,7 @@ import types from onnx_ir import convenience, external_data, passes, serde, tape, traversal -from onnx_ir._convenience._constructors import node, tensor +from onnx_ir._convenience._constructors import node, tensor, val from onnx_ir._core import ( Attr, AttrFloat32, diff --git a/src/onnx_ir/_convenience/_constructors.py b/src/onnx_ir/_convenience/_constructors.py index 01665dd8..29210376 100644 --- a/src/onnx_ir/_convenience/_constructors.py +++ b/src/onnx_ir/_convenience/_constructors.py @@ -25,7 +25,7 @@ def tensor( value: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible, - dtype: _enums.DataType | None = None, + dtype: ir.DataType | None = None, name: str | None = None, doc_string: str | None = None, ) -> _protocols.TensorProtocol: @@ -215,3 +215,74 @@ def node( doc_string=doc_string, metadata_props=metadata_props, ) + + +def val( + name: str, + dtype: ir.DataType | None = None, + shape: ir.Shape | Sequence[int | str | None] | None = None, + *, + type: ir.TypeProtocol | None = None, + const_value: ir.TensorProtocol | None = None, +) -> ir.Value: + """Create a :class:`~onnx_ir.Value` with the given name and type. + + This is a convenience constructor for creating a Value that allows you to specify + dtype and shape in a more relaxed manner. Whereas to create a Value directly, you + need to create a :class:`~onnx_ir.TypeProtocol` and :class:`~onnx_ir.Shape` object + first, this function allows you to specify dtype as a :class:`~onnx_ir.DataType` + and shape as a sequence of integers or symbolic dimensions. + + Example:: + + >>> import onnx_ir as ir + >>> t = ir.val("x", ir.DataType.FLOAT, ["N", 42, 3]) + >>> t.name + 'x' + >>> t.type + Tensor(FLOAT) + >>> t.shape + Shape([SymbolicDim(N), 42, 3]) + + .. versionadded:: 0.1.9 + + Args: + name: The name of the value. + dtype: The data type of the TensorType of the value. This is used only when type is None. + shape: The shape of the value. + type: The type of the value. Only one of dtype and type can be specified. + const_value: The constant tensor that initializes the value. Supply this argument + when you want to create an initializer. The type and shape can be obtained from the tensor. + + Returns: + A value with the given name and type. + """ + if const_value is not None: + const_tensor_type = _core.TensorType(const_value.dtype) + if type is not None and type != const_tensor_type: + raise ValueError( + f"The type does not match the const_value. type={type} but const_value has type {const_tensor_type}. " + "You do not have to specify the type when const_value is provided." + ) + if dtype is not None and dtype != const_value.dtype: + raise ValueError( + f"The dtype does not match the const_value. dtype={dtype} but const_value has dtype {const_value.dtype}. " + "You do not have to specify the dtype when const_value is provided." + ) + if shape is not None and _core.Shape(shape) != const_value.shape: + raise ValueError( + f"The shape does not match the const_value. shape={shape} but const_value has shape {const_value.shape}. " + "You do not have to specify the shape when const_value is provided." + ) + return _core.Value( + name=name, + type=const_tensor_type, + shape=_core.Shape(const_value.shape), # type: ignore + const_value=const_value, + ) + + if type is None and dtype is not None: + type = _core.TensorType(dtype) + if shape is not None and not isinstance(shape, _core.Shape): + shape = _core.Shape(shape) + return _core.Value(name=name, type=type, shape=shape) diff --git a/src/onnx_ir/_convenience/_constructors_test.py b/src/onnx_ir/_convenience/_constructors_test.py index ab5e00e1..6c4d6a23 100644 --- a/src/onnx_ir/_convenience/_constructors_test.py +++ b/src/onnx_ir/_convenience/_constructors_test.py @@ -27,5 +27,96 @@ def test_tensor_handles_empty_sequence_with_dtype(self): np.testing.assert_array_equal(tensor.numpy(), np.array([], dtype=np.float32)) +class ValueConstructorTest(unittest.TestCase): + def test_value_minimal_creation(self): + """Test creating a value with just a name.""" + value = _constructors.val("minimal") + + self.assertEqual(value.name, "minimal") + self.assertIsNone(value.type) + self.assertIsNone(value.shape) + self.assertIsNone(value.const_value) + + def test_value_creation_with_sequence_shape(self): + """Test that shape is correctly converted from sequence to Shape object.""" + value = _constructors.val("test", ir.DataType.INT32, [1, 2, 3]) + + self.assertEqual(value.name, "test") + self.assertIsInstance(value.shape, ir.Shape) + self.assertEqual(value.shape, ir.Shape([1, 2, 3])) + + def test_value_creation_with_explicit_type(self): + """Test value creation with explicit type parameter.""" + tensor_type = ir.TensorType(ir.DataType.DOUBLE) + value = _constructors.val("y", type=tensor_type, shape=[10]) + + self.assertEqual(value.name, "y") + self.assertEqual(value.type, tensor_type) + self.assertEqual(value.shape, ir.Shape([10])) + + def test_value_creation_with_const_value(self): + """Test value creation with const_value (initializer).""" + const_tensor = ir.Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32), name="const") + value = _constructors.val("initializer", const_value=const_tensor) + + self.assertEqual(value.name, "initializer") + self.assertEqual(value.type, ir.TensorType(ir.DataType.FLOAT)) + self.assertEqual(value.shape, ir.Shape([3])) + self.assertEqual(value.const_value, const_tensor) + + def test_value_creation_with_dtype_only(self): + """Test value creation with only dtype specified.""" + value = _constructors.val("float_value", dtype=ir.DataType.FLOAT) + + self.assertEqual(value.name, "float_value") + self.assertEqual(value.type, ir.TensorType(ir.DataType.FLOAT)) + self.assertIsNone(value.shape) + + def test_value_const_value_type_mismatch_error(self): + """Test that providing mismatched type with const_value raises ValueError.""" + const_tensor = ir.tensor([1, 2, 3], dtype=ir.DataType.INT32) + wrong_type = ir.TensorType(ir.DataType.FLOAT) + + with self.assertRaisesRegex(ValueError, "The type does not match the const_value"): + _constructors.val("test", type=wrong_type, const_value=const_tensor) + + def test_value_const_value_dtype_mismatch_error(self): + """Test that providing mismatched dtype with const_value raises ValueError.""" + const_tensor = ir.tensor([1.0, 2.0], dtype=ir.DataType.FLOAT) + + with self.assertRaisesRegex(ValueError, "The dtype does not match the const_value"): + _constructors.val("test", dtype=ir.DataType.INT32, const_value=const_tensor) + + def test_value_const_value_shape_mismatch_error(self): + """Test that providing mismatched shape with const_value raises ValueError.""" + const_tensor = ir.tensor([[1, 2], [3, 4]], dtype=ir.DataType.INT32) # Shape: [2, 2] + + with self.assertRaisesRegex(ValueError, "The shape does not match the const_value"): + _constructors.val("test", shape=[3, 3], const_value=const_tensor) + + def test_value_initialize_with_const_value(self): + const_tensor = ir.tensor(np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float64)) + value = _constructors.val("test", const_value=const_tensor) + + self.assertEqual(value.name, "test") + self.assertEqual(value.type, ir.TensorType(ir.DataType.DOUBLE)) + self.assertEqual(value.shape, ir.Shape([2, 2])) + self.assertEqual(value.const_value, const_tensor) + + def test_value_creation_with_string_dimensions(self): + """Test value creation with string dimensions in shape.""" + value = _constructors.val("dynamic", ir.DataType.FLOAT, ["batch", "seq_len", 768]) + + self.assertEqual(value.name, "dynamic") + self.assertEqual(value.shape, ir.Shape(["batch", "seq_len", 768])) + + def test_value_creation_with_none_dimensions(self): + """Test value creation with None dimensions in shape.""" + value = _constructors.val("unknown", ir.DataType.INT64, [None, 10, None]) + + self.assertEqual(value.name, "unknown") + self.assertEqual(value.shape, ir.Shape([None, 10, None])) + + if __name__ == "__main__": unittest.main()