From d0b02a6cea7c4da73ebd5df22d7dfc0daa449b2e Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 22 Oct 2025 13:26:25 +0800 Subject: [PATCH 01/17] Implement metadata-aware PySimpleScalarUDF Enhance scalar UDF definitions to retain Arrow Field information, including extension metadata, in DataFusion. Normalize Python UDF signatures to accept pyarrow.Field objects, ensuring metadata survives the Rust bindings roundtrip. Add a regression test for UUID-backed UDFs to verify that the second UDF correctly receives a pyarrow.ExtensionArray, preventing past metadata loss. --- python/datafusion/user_defined.py | 71 +++++++++++++++------- python/tests/test_udf.py | 55 ++++++++++++++++++ src/udf.rs | 97 ++++++++++++++++++++++++++++--- 3 files changed, 196 insertions(+), 27 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 67568e313..24dd7761f 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -22,17 +22,13 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload +from typing import Any, Callable, Optional, Protocol, Sequence, overload import pyarrow as pa import datafusion._internal as df_internal from datafusion.expr import Expr -if TYPE_CHECKING: - _R = TypeVar("_R", bound=pa.DataType) - - class Volatility(Enum): """Defines how stable or volatile a function is. @@ -77,6 +73,40 @@ def __str__(self) -> str: return self.name.lower() +def _normalize_field(value: pa.DataType | pa.Field, *, default_name: str) -> pa.Field: + if isinstance(value, pa.Field): + return value + if isinstance(value, pa.DataType): + return pa.field(default_name, value) + msg = "Expected a pyarrow.DataType or pyarrow.Field" + raise TypeError(msg) + + +def _normalize_input_fields( + values: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field], +) -> list[pa.Field]: + if isinstance(values, (pa.DataType, pa.Field)): + sequence: Sequence[pa.DataType | pa.Field] = [values] + elif isinstance(values, Sequence) and not isinstance(values, (str, bytes)): + sequence = values + else: + msg = "input_types must be a DataType, Field, or a sequence of them" + raise TypeError(msg) + + return [ + _normalize_field(value, default_name=f"arg_{idx}") for idx, value in enumerate(sequence) + ] + + +def _normalize_return_field( + value: pa.DataType | pa.Field, + *, + name: str, +) -> pa.Field: + default_name = f"{name}_result" if name else "result" + return _normalize_field(value, default_name=default_name) + + class ScalarUDFExportable(Protocol): """Type hint for object that has __datafusion_scalar_udf__ PyCapsule.""" @@ -93,9 +123,9 @@ class ScalarUDF: def __init__( self, name: str, - func: Callable[..., _R], - input_types: pa.DataType | list[pa.DataType], - return_type: _R, + func: Callable[..., Any], + input_types: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field], + return_type: pa.DataType | pa.Field, volatility: Volatility | str, ) -> None: """Instantiate a scalar user-defined function (UDF). 
@@ -105,10 +135,10 @@ def __init__( if hasattr(func, "__datafusion_scalar_udf__"): self._udf = df_internal.ScalarUDF.from_pycapsule(func) return - if isinstance(input_types, pa.DataType): - input_types = [input_types] + normalized_inputs = _normalize_input_fields(input_types) + normalized_return = _normalize_return_field(return_type, name=name) self._udf = df_internal.ScalarUDF( - name, func, input_types, return_type, str(volatility) + name, func, normalized_inputs, normalized_return, str(volatility) ) def __repr__(self) -> str: @@ -127,8 +157,8 @@ def __call__(self, *args: Expr) -> Expr: @overload @staticmethod def udf( - input_types: list[pa.DataType], - return_type: _R, + input_types: list[pa.DataType | pa.Field], + return_type: pa.DataType | pa.Field, volatility: Volatility | str, name: Optional[str] = None, ) -> Callable[..., ScalarUDF]: ... @@ -136,9 +166,9 @@ def udf( @overload @staticmethod def udf( - func: Callable[..., _R], - input_types: list[pa.DataType], - return_type: _R, + func: Callable[..., Any], + input_types: list[pa.DataType | pa.Field], + return_type: pa.DataType | pa.Field, volatility: Volatility | str, name: Optional[str] = None, ) -> ScalarUDF: ... @@ -164,10 +194,11 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417 backed ScalarUDF within a PyCapsule, you can pass this parameter and ignore the rest. They will be determined directly from the underlying function. See the online documentation for more information. - input_types (list[pa.DataType]): The data types of the arguments - to ``func``. This list must be of the same length as the number of - arguments. - return_type (_R): The data type of the return value from the function. + input_types (list[pa.DataType | pa.Field]): The argument types for ``func``. + This list must be of the same length as the number of arguments. Pass + :class:`pyarrow.Field` instances to preserve extension metadata. + return_type (pa.DataType | pa.Field): The return type of the function. Use a + :class:`pyarrow.Field` to preserve metadata on extension arrays. volatility (Volatility | str): See `Volatility` for allowed values. name (Optional[str]): A descriptive name for the function. 
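(Illustrative sketch, not part of the diff: how the normalization helpers
added above behave. `_normalize_input_fields` and `_normalize_return_field`
are the private helpers from this patch; the values are made up.)

    import pyarrow as pa

    # Bare DataTypes are wrapped in positional arg_<n> fields;
    # pa.Field inputs keep their name and metadata.
    fields = _normalize_input_fields([pa.int32(), pa.field("id", pa.string())])
    assert fields[0] == pa.field("arg_0", pa.int32())
    assert fields[1].name == "id"

    # A bare return DataType picks up a "<name>_result" placeholder name.
    assert _normalize_return_field(pa.int64(), name="double_it") == pa.field(
        "double_it_result", pa.int64()
    )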
diff --git a/python/tests/test_udf.py b/python/tests/test_udf.py
index a6c047552..6a631a49e 100644
--- a/python/tests/test_udf.py
+++ b/python/tests/test_udf.py
@@ -124,3 +124,58 @@ def udf_with_param(values: pa.Array) -> pa.Array:
     result = df2.collect()[0].column(0)
 
     assert result == pa.array([False, True, True])
+
+
+def test_uuid_extension_chain(ctx) -> None:
+    uuid_type = pa.uuid()
+    uuid_field = pa.field("uuid_col", uuid_type)
+
+    first = udf(
+        lambda values: values,
+        [uuid_field],
+        uuid_field,
+        volatility="immutable",
+        name="uuid_identity",
+    )
+
+    def ensure_extension(values: pa.Array) -> pa.Array:
+        assert isinstance(values, pa.ExtensionArray)
+        return values
+
+    second = udf(
+        ensure_extension,
+        [uuid_field],
+        uuid_field,
+        volatility="immutable",
+        name="uuid_assert",
+    )
+
+    batch = pa.RecordBatch.from_arrays(
+        [
+            pa.array(
+                [
+                    "00000000-0000-0000-0000-000000000000",
+                    "00000000-0000-0000-0000-000000000001",
+                ],
+                type=uuid_type,
+            )
+        ],
+        names=["uuid_col"],
+    )
+
+    df = ctx.create_dataframe([[batch]])
+    result = (
+        df.select(second(first(column("uuid_col"))))
+        .collect()[0]
+        .column(0)
+    )
+
+    expected = pa.array(
+        [
+            "00000000-0000-0000-0000-000000000000",
+            "00000000-0000-0000-0000-000000000001",
+        ],
+        type=uuid_type,
+    )
+
+    assert result.equals(expected)
diff --git a/src/udf.rs b/src/udf.rs
index a9249d6c8..c8e275e6c 100644
--- a/src/udf.rs
+++ b/src/udf.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::fmt;
 use std::sync::Arc;
 
 use datafusion_ffi::udf::{FFI_ScalarUDF, ForeignScalarUDF};
@@ -22,13 +23,16 @@ use pyo3::types::PyCapsule;
 use pyo3::{prelude::*, types::PyTuple};
 
 use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
-use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::datatypes::{DataType, Field};
 use datafusion::arrow::pyarrow::FromPyArrow;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::error::DataFusionError;
 use datafusion::logical_expr::function::ScalarFunctionImplementation;
-use datafusion::logical_expr::ScalarUDF;
-use datafusion::logical_expr::{create_udf, ColumnarValue};
+use datafusion::logical_expr::ptr_eq::PtrEq;
+use datafusion::logical_expr::{
+    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
+    Volatility,
+};
 
 use crate::errors::to_datafusion_err;
 use crate::errors::{py_datafusion_err, PyDataFusionResult};
@@ -80,6 +84,83 @@ fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation {
     })
 }
 
+#[derive(PartialEq, Eq, Hash)]
+struct PySimpleScalarUDF {
+    name: String,
+    signature: Signature,
+    return_field: Arc<Field>,
+    fun: PtrEq<ScalarFunctionImplementation>,
+}
+
+impl fmt::Debug for PySimpleScalarUDF {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("PySimpleScalarUDF")
+            .field("name", &self.name)
+            .field("signature", &self.signature)
+            .field("return_field", &self.return_field)
+            .finish()
+    }
+}
+
+impl PySimpleScalarUDF {
+    fn new(
+        name: impl Into<String>,
+        input_fields: Vec<Field>,
+        return_field: Field,
+        volatility: Volatility,
+        fun: ScalarFunctionImplementation,
+    ) -> Self {
+        let signature_types = input_fields
+            .into_iter()
+            .map(|field| field.data_type().clone())
+            .collect();
+        let signature = Signature::exact(signature_types, volatility);
+        Self {
+            name: name.into(),
+            signature,
+            return_field: Arc::new(return_field),
+            fun: fun.into(),
+        }
+    }
+}
+
+impl ScalarUDFImpl for PySimpleScalarUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
+        Ok(self.return_field.data_type().clone())
+    }
+
+    fn return_field_from_args(
+        &self,
+        _args: ReturnFieldArgs,
+    ) -> datafusion::error::Result<Arc<Field>> {
+        Ok(Arc::new(
+            self.return_field
+                .as_ref()
+                .clone()
+                .with_name(self.name.clone()),
+        ))
+    }
+
+    fn invoke_with_args(
+        &self,
+        args: ScalarFunctionArgs,
+    ) -> datafusion::error::Result<ColumnarValue> {
+        (self.fun)(&args.args)
+    }
+}
+
 /// Represents a PyScalarUDF
 #[pyclass(frozen, name = "ScalarUDF", module = "datafusion", subclass)]
 #[derive(Debug, Clone)]
@@ -94,17 +175,19 @@ impl PyScalarUDF {
     fn new(
         name: &str,
         func: PyObject,
-        input_types: PyArrowType<Vec<DataType>>,
-        return_type: PyArrowType<DataType>,
+        input_types: PyArrowType<Vec<Field>>,
+        return_type: PyArrowType<Field>,
         volatility: &str,
     ) -> PyResult<Self> {
-        let function = create_udf(
+        let volatility = parse_volatility(volatility)?;
+        let scalar_impl = PySimpleScalarUDF::new(
             name,
             input_types.0,
             return_type.0,
-            parse_volatility(volatility)?,
+            volatility,
             to_scalar_function_impl(func),
        );
+        let function = ScalarUDF::new_from_impl(scalar_impl);
         Ok(Self { function })
     }

From 2325993a7c987b4ab97fd8ab6e5405ba61e420e5 Mon Sep 17 00:00:00 2001
From: Siew Kam Onn
Date: Wed, 22 Oct 2025 14:28:31 +0800
Subject: [PATCH 02/17] Clone pyarrow.Field objects for FFI handoff

Wrap scalar UDF inputs/outputs to maintain extension types during
execution. Enhance UUID extension regression test to ensure metadata
retention and normalize results for accurate comparison.
---
 python/datafusion/user_defined.py | 46 +++++++++++++++++++++++++++++--
 python/tests/test_udf.py          | 32 +++++++++++----------
 2 files changed, 60 insertions(+), 18 deletions(-)

diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py
index 24dd7761f..1e627a12a 100644
--- a/python/datafusion/user_defined.py
+++ b/python/datafusion/user_defined.py
@@ -73,11 +73,17 @@ def __str__(self) -> str:
         return self.name.lower()
 
 
+def _clone_field(field: pa.Field) -> pa.Field:
+    """Return a deep copy of ``field`` including its DataType."""
+
+    return pa.schema([field]).field(0)
+
+
 def _normalize_field(value: pa.DataType | pa.Field, *, default_name: str) -> pa.Field:
     if isinstance(value, pa.Field):
-        return value
+        return _clone_field(value)
     if isinstance(value, pa.DataType):
-        return pa.field(default_name, value)
+        return _clone_field(pa.field(default_name, value))
     msg = "Expected a pyarrow.DataType or pyarrow.Field"
     raise TypeError(msg)
 
@@ -107,6 +113,39 @@ def _normalize_return_field(
     return _normalize_field(value, default_name=default_name)
 
 
+def _wrap_extension_value(value: Any, data_type: pa.DataType) -> Any:
+    storage_type = getattr(data_type, "storage_type", None)
+    wrap_array = getattr(data_type, "wrap_array", None)
+    if storage_type is None or wrap_array is None:
+        return value
+    if isinstance(value, pa.Array) and value.type.equals(storage_type):
+        return wrap_array(value)
+    if isinstance(value, pa.ChunkedArray) and value.type.equals(storage_type):
+        wrapped_chunks = [wrap_array(chunk) for chunk in value.chunks]
+        return pa.chunked_array(wrapped_chunks)
+    return value
+
+
+def _wrap_udf_function(
+    func: Callable[..., Any],
+    input_fields: Sequence[pa.Field],
+    return_field: pa.Field,
+) -> Callable[..., Any]:
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        if args:
+            converted_args = list(args)
+            for idx, field in enumerate(input_fields):
+                if idx >= len(converted_args):
break + converted_args[idx] = _wrap_extension_value(converted_args[idx], field.type) + else: + converted_args = [] + result = func(*converted_args, **kwargs) + return _wrap_extension_value(result, return_field.type) + + return wrapper + + class ScalarUDFExportable(Protocol): """Type hint for object that has __datafusion_scalar_udf__ PyCapsule.""" @@ -137,8 +176,9 @@ def __init__( return normalized_inputs = _normalize_input_fields(input_types) normalized_return = _normalize_return_field(return_type, name=name) + wrapped_func = _wrap_udf_function(func, normalized_inputs, normalized_return) self._udf = df_internal.ScalarUDF( - name, func, normalized_inputs, normalized_return, str(volatility) + name, wrapped_func, normalized_inputs, normalized_return, str(volatility) ) def __repr__(self) -> str: diff --git a/python/tests/test_udf.py b/python/tests/test_udf.py index 6a631a49e..ad830af0d 100644 --- a/python/tests/test_udf.py +++ b/python/tests/test_udf.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +import uuid + import pyarrow as pa import pytest from datafusion import column, udf @@ -150,16 +152,19 @@ def ensure_extension(values: pa.Array) -> pa.Array: name="uuid_assert", ) - batch = pa.RecordBatch.from_arrays( + # The UUID extension metadata should survive UDF registration. + assert getattr(uuid_type, "extension_name", None) == "arrow.uuid" + assert getattr(uuid_field.type, "extension_name", None) == "arrow.uuid" + + storage = pa.array( [ - pa.array( - [ - "00000000-0000-0000-0000-000000000000", - "00000000-0000-0000-0000-000000000001", - ], - type=uuid_type, - ) + uuid.UUID("00000000-0000-0000-0000-000000000000").bytes, + uuid.UUID("00000000-0000-0000-0000-000000000001").bytes, ], + type=uuid_type.storage_type, + ) + batch = pa.RecordBatch.from_arrays( + [uuid_type.wrap_array(storage)], names=["uuid_col"], ) @@ -170,12 +175,9 @@ def ensure_extension(values: pa.Array) -> pa.Array: .column(0) ) - expected = pa.array( - [ - "00000000-0000-0000-0000-000000000000", - "00000000-0000-0000-0000-000000000001", - ], - type=uuid_type, - ) + expected = uuid_type.wrap_array(storage) + + if isinstance(result, pa.Array) and result.type.equals(uuid_type.storage_type): + result = uuid_type.wrap_array(result) assert result.equals(expected) From e0bce84be019b201de6c39dbbb72cc675b960819 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 22 Oct 2025 15:36:22 +0800 Subject: [PATCH 03/17] Updated the _function and _decorator helpers in ScalarUDF.udf to use concrete Callable[..., Any] and pa.DataType | pa.Field annotations, removing the lingering references to the deleted _R type variable. 
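For context, the removed `_R` TypeVar was bound to pa.DataType, so a
return annotation of pa.Field could never satisfy it. A minimal sketch of
the call the concrete annotations are meant to admit (hypothetical field
and function names):

    import pyarrow as pa
    import pyarrow.compute as pc
    from datafusion import udf

    # return_type is a Field rather than a DataType -- outside the old
    # _R bound, but accepted by the new annotations.
    meta_field = pa.field("result", pa.int64(), metadata={"unit": "ms"})

    double_ms = udf(
        lambda values: pc.multiply(values, 2),
        [pa.int64()],
        meta_field,
        "immutable",
        name="double_ms",
    )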
--- python/datafusion/user_defined.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 1e627a12a..dd057b828 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -261,9 +261,9 @@ def double_udf(x): """ def _function( - func: Callable[..., _R], + func: Callable[..., Any], input_types: list[pa.DataType], - return_type: _R, + return_type: pa.DataType | pa.Field, volatility: Volatility | str, name: Optional[str] = None, ) -> ScalarUDF: @@ -285,7 +285,7 @@ def _function( def _decorator( input_types: list[pa.DataType], - return_type: _R, + return_type: pa.DataType | pa.Field, volatility: Volatility | str, name: Optional[str] = None, ) -> Callable: From 5fb08d6b0bda3c2b837cb5932450a3e3757d063a Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 22 Oct 2025 16:00:41 +0800 Subject: [PATCH 04/17] Updated the internal ScalarUDF.udf helper signatures so IDEs and type checkers surface support for pyarrow.Field inputs when defining UDFs --- python/datafusion/user_defined.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index dd057b828..ca7e2436a 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -262,7 +262,7 @@ def double_udf(x): def _function( func: Callable[..., Any], - input_types: list[pa.DataType], + input_types: list[pa.DataType | pa.Field], return_type: pa.DataType | pa.Field, volatility: Volatility | str, name: Optional[str] = None, @@ -284,7 +284,7 @@ def _function( ) def _decorator( - input_types: list[pa.DataType], + input_types: list[pa.DataType | pa.Field], return_type: pa.DataType | pa.Field, volatility: Volatility | str, name: Optional[str] = None, From 6a89977f9fa4066513718568c2f5d4b8997c0e17 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 22 Oct 2025 16:31:55 +0800 Subject: [PATCH 05/17] Refactor imports in udf.rs for improved organization and clarity --- src/udf.rs | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/udf.rs b/src/udf.rs index c8e275e6c..f3d764a6c 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -22,23 +22,20 @@ use datafusion_ffi::udf::{FFI_ScalarUDF, ForeignScalarUDF}; use pyo3::types::PyCapsule; use pyo3::{prelude::*, types::PyTuple}; +use crate::errors::to_datafusion_err; +use crate::errors::{py_datafusion_err, PyDataFusionResult}; +use crate::expr::PyExpr; +use crate::utils::{parse_volatility, validate_pycapsule}; use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef}; use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::arrow::pyarrow::FromPyArrow; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; use datafusion::error::DataFusionError; -use datafusion::logical_expr::function::ScalarFunctionImplementation; -use datafusion::logical_expr::ptr_eq::PtrEq; use datafusion::logical_expr::{ - ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, - Volatility, + function::ScalarFunctionImplementation, ptr_eq::PtrEq, ColumnarValue, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; -use crate::errors::to_datafusion_err; -use crate::errors::{py_datafusion_err, PyDataFusionResult}; -use crate::expr::PyExpr; -use crate::utils::{parse_volatility, validate_pycapsule}; - /// Create a Rust callable function from a python function that expects pyarrow arrays fn 
pyarrow_function_to_rust( func: PyObject, From 5aacb41da9435ecd5776fe4d0974849585581b9f Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 22 Oct 2025 16:49:15 +0800 Subject: [PATCH 06/17] Add shared PyArrowArray alias and refine ScalarUDFs Introduce a shared alias for PyArrowArray and update the extension wrapping helpers to ensure scalar UDF return types are preserved when handling PyArrow arrays. Enhance ScalarUDF signatures, overloads, and documentation to align with the PyArrow array contract for Python scalar UDFs. --- python/datafusion/user_defined.py | 37 +++++++++++++++++++------------ 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index ca7e2436a..ff605920e 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -22,13 +22,17 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import Any, Callable, Optional, Protocol, Sequence, overload +from typing import Any, Callable, Optional, Protocol, Sequence, TypeVar, cast, overload import pyarrow as pa import datafusion._internal as df_internal from datafusion.expr import Expr +PyArrowArray = pa.Array | pa.ChunkedArray +# Type alias for array batches exchanged with Python scalar UDFs. +PyArrowArrayT = TypeVar("PyArrowArrayT", pa.Array, pa.ChunkedArray) + class Volatility(Enum): """Defines how stable or volatile a function is. @@ -113,7 +117,7 @@ def _normalize_return_field( return _normalize_field(value, default_name=default_name) -def _wrap_extension_value(value: Any, data_type: pa.DataType) -> Any: +def _wrap_extension_value(value: PyArrowArrayT, data_type: pa.DataType) -> PyArrowArrayT: storage_type = getattr(data_type, "storage_type", None) wrap_array = getattr(data_type, "wrap_array", None) if storage_type is None or wrap_array is None: @@ -127,17 +131,20 @@ def _wrap_extension_value(value: Any, data_type: pa.DataType) -> Any: def _wrap_udf_function( - func: Callable[..., Any], + func: Callable[..., PyArrowArrayT], input_fields: Sequence[pa.Field], return_field: pa.Field, -) -> Callable[..., Any]: - def wrapper(*args: Any, **kwargs: Any) -> Any: +) -> Callable[..., PyArrowArrayT]: + def wrapper(*args: Any, **kwargs: Any) -> PyArrowArrayT: if args: - converted_args = list(args) + converted_args: list[Any] = list(args) for idx, field in enumerate(input_fields): if idx >= len(converted_args): break - converted_args[idx] = _wrap_extension_value(converted_args[idx], field.type) + converted_args[idx] = _wrap_extension_value( + cast(PyArrowArray, converted_args[idx]), + field.type, + ) else: converted_args = [] result = func(*converted_args, **kwargs) @@ -162,7 +169,7 @@ class ScalarUDF: def __init__( self, name: str, - func: Callable[..., Any], + func: Callable[..., PyArrowArray] | ScalarUDFExportable, input_types: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field], return_type: pa.DataType | pa.Field, volatility: Volatility | str, @@ -201,12 +208,12 @@ def udf( return_type: pa.DataType | pa.Field, volatility: Volatility | str, name: Optional[str] = None, - ) -> Callable[..., ScalarUDF]: ... + ) -> Callable[[Callable[..., PyArrowArray]], Callable[..., Expr]]: ... 
@overload @staticmethod def udf( - func: Callable[..., Any], + func: Callable[..., PyArrowArray], input_types: list[pa.DataType | pa.Field], return_type: pa.DataType | pa.Field, volatility: Volatility | str, @@ -234,6 +241,8 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417 backed ScalarUDF within a PyCapsule, you can pass this parameter and ignore the rest. They will be determined directly from the underlying function. See the online documentation for more information. + The callable should accept and return :class:`pyarrow.Array` or + :class:`pyarrow.ChunkedArray` values. input_types (list[pa.DataType | pa.Field]): The argument types for ``func``. This list must be of the same length as the number of arguments. Pass :class:`pyarrow.Field` instances to preserve extension metadata. @@ -261,7 +270,7 @@ def double_udf(x): """ def _function( - func: Callable[..., Any], + func: Callable[..., PyArrowArray], input_types: list[pa.DataType | pa.Field], return_type: pa.DataType | pa.Field, volatility: Volatility | str, @@ -288,14 +297,14 @@ def _decorator( return_type: pa.DataType | pa.Field, volatility: Volatility | str, name: Optional[str] = None, - ) -> Callable: - def decorator(func: Callable): + ) -> Callable[[Callable[..., PyArrowArray]], Callable[..., Expr]]: + def decorator(func: Callable[..., PyArrowArray]) -> Callable[..., Expr]: udf_caller = ScalarUDF.udf( func, input_types, return_type, volatility, name ) @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any): + def wrapper(*args: Any, **kwargs: Any) -> Expr: return udf_caller(*args, **kwargs) return wrapper From bae5d54b9e107912e7f0fad8fcbbbac83b0dbb1e Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 22 Oct 2025 17:04:07 +0800 Subject: [PATCH 07/17] Add feature flag for pyarrow UUID helper detection Implement a feature flag to check for UUID helper in pyarrow. Add conditional skip to the UUID extension UDF chaining test when the helper is unavailable, retaining original assertions for supported environments. --- python/tests/test_udf.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/tests/test_udf.py b/python/tests/test_udf.py index ad830af0d..cbb0022b9 100644 --- a/python/tests/test_udf.py +++ b/python/tests/test_udf.py @@ -22,6 +22,9 @@ from datafusion import column, udf +UUID_EXTENSION_AVAILABLE = hasattr(pa, "uuid") + + @pytest.fixture def df(ctx): # create a RecordBatch and a new DataFrame from it @@ -128,6 +131,10 @@ def udf_with_param(values: pa.Array) -> pa.Array: assert result == pa.array([False, True, True]) +@pytest.mark.skipif( + not UUID_EXTENSION_AVAILABLE, + reason="PyArrow uuid extension helper unavailable", +) def test_uuid_extension_chain(ctx) -> None: uuid_type = pa.uuid() uuid_field = pa.field("uuid_col", uuid_type) From 153b5f1620385e4d16cadb520b70b002c2d40d86 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 22 Oct 2025 17:27:28 +0800 Subject: [PATCH 08/17] Add assertions for UUID extension type in tests Ensure collected UUID results are extension arrays or chunked arrays with the UUID extension type before comparison to expected values, preserving end-to-end metadata validation. 
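For reference, the UUID round-trip these assertions exercise can be
sketched as follows (pa.uuid() is the canonical arrow.uuid extension
type; its storage is fixed_size_binary(16)):

    import pyarrow as pa

    uuid_type = pa.uuid()
    storage = pa.array([b"\x00" * 16], type=uuid_type.storage_type)
    ext = uuid_type.wrap_array(storage)

    assert isinstance(ext, pa.ExtensionArray)
    assert ext.type.equals(uuid_type)
    assert ext.storage.equals(storage)  # .storage recovers the raw bytes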
--- python/tests/test_udf.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tests/test_udf.py b/python/tests/test_udf.py index cbb0022b9..d6b5a131c 100644 --- a/python/tests/test_udf.py +++ b/python/tests/test_udf.py @@ -184,7 +184,10 @@ def ensure_extension(values: pa.Array) -> pa.Array: expected = uuid_type.wrap_array(storage) - if isinstance(result, pa.Array) and result.type.equals(uuid_type.storage_type): - result = uuid_type.wrap_array(result) + if isinstance(result, pa.ChunkedArray): + assert result.type.equals(uuid_type) + else: + assert isinstance(result, pa.ExtensionArray) + assert result.type.equals(uuid_type) assert result.equals(expected) From 1baa2b719d328c9f91649e2f8d93041b3ed3647f Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 22 Oct 2025 18:05:09 +0800 Subject: [PATCH 09/17] Teach _wrap_extension_value to handle empty arrays Return a wrapped empty extension array for chunked storage arrays with no chunks, preserving extension metadata. Expand UUID UDF regression to support chunked inputs, test empty chunked returns, and ensure UUID extension type remains intact through UDF chaining. --- python/datafusion/user_defined.py | 5 ++++- python/tests/test_udf.py | 37 ++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index ff605920e..898a20561 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -126,7 +126,10 @@ def _wrap_extension_value(value: PyArrowArrayT, data_type: pa.DataType) -> PyArr return wrap_array(value) if isinstance(value, pa.ChunkedArray) and value.type.equals(storage_type): wrapped_chunks = [wrap_array(chunk) for chunk in value.chunks] - return pa.chunked_array(wrapped_chunks) + if not wrapped_chunks: + empty_storage = pa.array([], type=storage_type) + return wrap_array(empty_storage) + return pa.chunked_array(wrapped_chunks, type=data_type) return value diff --git a/python/tests/test_udf.py b/python/tests/test_udf.py index d6b5a131c..f10604c56 100644 --- a/python/tests/test_udf.py +++ b/python/tests/test_udf.py @@ -147,8 +147,12 @@ def test_uuid_extension_chain(ctx) -> None: name="uuid_identity", ) - def ensure_extension(values: pa.Array) -> pa.Array: + def ensure_extension(values: pa.Array | pa.ChunkedArray) -> pa.Array: + if isinstance(values, pa.ChunkedArray): + assert values.type.equals(uuid_type) + return values.combine_chunks() assert isinstance(values, pa.ExtensionArray) + assert values.type.equals(uuid_type) return values second = udf( @@ -191,3 +195,34 @@ def ensure_extension(values: pa.Array) -> pa.Array: assert result.type.equals(uuid_type) assert result.equals(expected) + + empty_storage = pa.array([], type=uuid_type.storage_type) + empty_batch = pa.RecordBatch.from_arrays( + [uuid_type.wrap_array(empty_storage)], + names=["uuid_col"], + ) + + empty_first = udf( + lambda values: pa.chunked_array([], type=uuid_type.storage_type), + [uuid_field], + uuid_field, + volatility="immutable", + name="uuid_empty_chunk", + ) + + empty_df = ctx.create_dataframe([[empty_batch]]) + empty_result = ( + empty_df.select(second(empty_first(column("uuid_col")))) + .collect()[0] + .column(0) + ) + + expected_empty = uuid_type.wrap_array(empty_storage) + + if isinstance(empty_result, pa.ChunkedArray): + assert empty_result.type.equals(uuid_type) + assert empty_result.combine_chunks().equals(expected_empty) + else: + assert isinstance(empty_result, pa.ExtensionArray) + assert 
empty_result.type.equals(uuid_type) + assert empty_result.equals(expected_empty) From 308e77438a03a86857ae5fe4f76d472ee3821465 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 22 Oct 2025 18:30:43 +0800 Subject: [PATCH 10/17] Refactor type alias for PyArrowArray to use Union for better clarity --- python/datafusion/user_defined.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 898a20561..250d307ce 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -22,14 +22,14 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import Any, Callable, Optional, Protocol, Sequence, TypeVar, cast, overload +from typing import Any, Callable, Optional, Protocol, Sequence, TypeVar, Union, cast, overload import pyarrow as pa import datafusion._internal as df_internal from datafusion.expr import Expr -PyArrowArray = pa.Array | pa.ChunkedArray +PyArrowArray = Union[pa.Array, pa.ChunkedArray] # Type alias for array batches exchanged with Python scalar UDFs. PyArrowArrayT = TypeVar("PyArrowArrayT", pa.Array, pa.ChunkedArray) From 16224e25543cd85bc2831c28c0b8412592947ad2 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 22 Oct 2025 18:47:18 +0800 Subject: [PATCH 11/17] Fix Ruff errors --- python/datafusion/user_defined.py | 31 +++++++++++++++++++++++-------- python/tests/test_udf.py | 13 ++++--------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 250d307ce..2cbabd15b 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -22,7 +22,17 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import Any, Callable, Optional, Protocol, Sequence, TypeVar, Union, cast, overload +from typing import ( + Any, + Callable, + Optional, + Protocol, + Sequence, + TypeVar, + Union, + cast, + overload, +) import pyarrow as pa @@ -33,6 +43,7 @@ # Type alias for array batches exchanged with Python scalar UDFs. PyArrowArrayT = TypeVar("PyArrowArrayT", pa.Array, pa.ChunkedArray) + class Volatility(Enum): """Defines how stable or volatile a function is. @@ -79,7 +90,6 @@ def __str__(self) -> str: def _clone_field(field: pa.Field) -> pa.Field: """Return a deep copy of ``field`` including its DataType.""" - return pa.schema([field]).field(0) @@ -104,7 +114,8 @@ def _normalize_input_fields( raise TypeError(msg) return [ - _normalize_field(value, default_name=f"arg_{idx}") for idx, value in enumerate(sequence) + _normalize_field(value, default_name=f"arg_{idx}") + for idx, value in enumerate(sequence) ] @@ -117,7 +128,9 @@ def _normalize_return_field( return _normalize_field(value, default_name=default_name) -def _wrap_extension_value(value: PyArrowArrayT, data_type: pa.DataType) -> PyArrowArrayT: +def _wrap_extension_value( + value: PyArrowArrayT, data_type: pa.DataType +) -> PyArrowArrayT: storage_type = getattr(data_type, "storage_type", None) wrap_array = getattr(data_type, "wrap_array", None) if storage_type is None or wrap_array is None: @@ -440,10 +453,12 @@ def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901 This class allows you to define an aggregate function that can be used in data aggregation or window function calls. - Usage: - - As a function: ``udaf(accum, input_types, return_type, state_type, volatility, name)``. 
- - As a decorator: ``@udaf(input_types, return_type, state_type, volatility, name)``. - When using ``udaf`` as a decorator, do not pass ``accum`` explicitly. + Usage: + - As a function: ``udaf(accum, input_types, return_type, state_type,`` + ``volatility, name)``. + - As a decorator: ``@udaf(input_types, return_type, state_type,`` + ``volatility, name)``. + When using ``udaf`` as a decorator, do not pass ``accum`` explicitly. Function example: diff --git a/python/tests/test_udf.py b/python/tests/test_udf.py index f10604c56..ffca7b9b4 100644 --- a/python/tests/test_udf.py +++ b/python/tests/test_udf.py @@ -15,13 +15,14 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import uuid import pyarrow as pa import pytest from datafusion import column, udf - UUID_EXTENSION_AVAILABLE = hasattr(pa, "uuid") @@ -180,11 +181,7 @@ def ensure_extension(values: pa.Array | pa.ChunkedArray) -> pa.Array: ) df = ctx.create_dataframe([[batch]]) - result = ( - df.select(second(first(column("uuid_col")))) - .collect()[0] - .column(0) - ) + result = df.select(second(first(column("uuid_col")))).collect()[0].column(0) expected = uuid_type.wrap_array(storage) @@ -212,9 +209,7 @@ def ensure_extension(values: pa.Array | pa.ChunkedArray) -> pa.Array: empty_df = ctx.create_dataframe([[empty_batch]]) empty_result = ( - empty_df.select(second(empty_first(column("uuid_col")))) - .collect()[0] - .column(0) + empty_df.select(second(empty_first(column("uuid_col")))).collect()[0].column(0) ) expected_empty = uuid_type.wrap_array(empty_storage) From 7b9ced0b6960d2daa47c8ee94530dfe7076d9c41 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 22 Oct 2025 19:22:59 +0800 Subject: [PATCH 12/17] Enhance documentation for PyArrowArray type alias to clarify usage and improve type-checking --- python/datafusion/user_defined.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 2cbabd15b..d13b27e26 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -41,6 +41,19 @@ PyArrowArray = Union[pa.Array, pa.ChunkedArray] # Type alias for array batches exchanged with Python scalar UDFs. +# +# We need two related but different annotations here: +# - `PyArrowArray` is the concrete union type (pa.Array | pa.ChunkedArray) +# that is convenient for user-facing callables and casts. Use this when +# annotating or checking values that may be either an Array or +# a ChunkedArray. +# - `PyArrowArrayT` is a constrained `TypeVar` over the two concrete +# array flavors. Keeping a generic TypeVar allows helpers like +# `_wrap_extension_value` and `_wrap_udf_function` to remain generic +# and preserve the specific array "flavor" (Array vs ChunkedArray) +# flowing through them, rather than collapsing everything to the +# wide union. This improves type-checking and keeps return types +# precise in the wrapper logic. 
PyArrowArrayT = TypeVar("PyArrowArrayT", pa.Array, pa.ChunkedArray) From 0f28465ae5108625769cfba22d25bb3ccbd362d0 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 23 Oct 2025 16:37:43 +0800 Subject: [PATCH 13/17] Add dev dependency pyarrow >= 19 and remove UUID extension availability check from tests --- pyproject.toml | 52 +++++++++++++++++++++++----------------- python/tests/test_udf.py | 6 ----- 2 files changed, 30 insertions(+), 28 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 69d31ec9f..88eaca840 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,20 +69,20 @@ asyncio_default_fixture_loop_scope = "function" # Enable docstring linting using the google style guide [tool.ruff.lint] -select = ["ALL" ] +select = ["ALL"] ignore = [ - "A001", # Allow using words like min as variable names - "A002", # Allow using words like filter as variable names - "ANN401", # Allow Any for wrapper classes - "COM812", # Recommended to ignore these rules when using with ruff-format - "FIX002", # Allow TODO lines - consider removing at some point - "FBT001", # Allow boolean positional args - "FBT002", # Allow boolean positional args - "ISC001", # Recommended to ignore these rules when using with ruff-format - "SLF001", # Allow accessing private members + "A001", # Allow using words like min as variable names + "A002", # Allow using words like filter as variable names + "ANN401", # Allow Any for wrapper classes + "COM812", # Recommended to ignore these rules when using with ruff-format + "FIX002", # Allow TODO lines - consider removing at some point + "FBT001", # Allow boolean positional args + "FBT002", # Allow boolean positional args + "ISC001", # Recommended to ignore these rules when using with ruff-format + "SLF001", # Allow accessing private members "TD002", - "TD003", # Allow TODO lines - "UP007", # Disallowing Union is pedantic + "TD003", # Allow TODO lines + "UP007", # Disallowing Union is pedantic # TODO: Enable all of the following, but this PR is getting too large already "PLR0913", "TRY003", @@ -129,25 +129,33 @@ extend-allowed-calls = ["lit", "datafusion.lit"] ] "examples/*" = ["D", "W505", "E501", "T201", "S101"] "dev/*" = ["D", "E", "T", "S", "PLR", "C", "SIM", "UP", "EXE", "N817"] -"benchmarks/*" = ["D", "F", "T", "BLE", "FURB", "PLR", "E", "TD", "TRY", "S", "SIM", "EXE", "UP"] +"benchmarks/*" = [ + "D", + "F", + "T", + "BLE", + "FURB", + "PLR", + "E", + "TD", + "TRY", + "S", + "SIM", + "EXE", + "UP", +] "docs/*" = ["D"] [tool.codespell] -skip = [ - "./target", - "uv.lock", - "./python/tests/test_functions.py" -] +skip = ["./target", "uv.lock", "./python/tests/test_functions.py"] count = true -ignore-words-list = [ - "ans", - "IST" -] +ignore-words-list = ["ans", "IST"] [dependency-groups] dev = [ "maturin>=1.8.1", "numpy>1.25.0", + "pyarrow>=19.0.0", "pre-commit>=4.0.0", "pytest>=7.4.4", "pytest-asyncio>=0.23.3", diff --git a/python/tests/test_udf.py b/python/tests/test_udf.py index ffca7b9b4..313295bc8 100644 --- a/python/tests/test_udf.py +++ b/python/tests/test_udf.py @@ -23,8 +23,6 @@ import pytest from datafusion import column, udf -UUID_EXTENSION_AVAILABLE = hasattr(pa, "uuid") - @pytest.fixture def df(ctx): @@ -132,10 +130,6 @@ def udf_with_param(values: pa.Array) -> pa.Array: assert result == pa.array([False, True, True]) -@pytest.mark.skipif( - not UUID_EXTENSION_AVAILABLE, - reason="PyArrow uuid extension helper unavailable", -) def test_uuid_extension_chain(ctx) -> None: uuid_type = pa.uuid() uuid_field = pa.field("uuid_col", uuid_type) From 
f068caa72033a86d5503f1747dc3951ba26d5458 Mon Sep 17 00:00:00 2001
From: Siew Kam Onn
Date: Thu, 23 Oct 2025 16:56:27 +0800
Subject: [PATCH 14/17] Update return_type to raise error for metadata source

Ensure return_field_from_args is the only metadata source by having
PySimpleScalarUDF::return_type raise an internal error. This aligns with
DataFusion guidance. Enhance Python UDF helper documentation to clarify
how callers can declare extension metadata on both arguments and results.
---
 python/datafusion/user_defined.py | 8 +++++---
 src/udf.rs                        | 5 ++++-
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py
index d13b27e26..6f84209cc 100644
--- a/python/datafusion/user_defined.py
+++ b/python/datafusion/user_defined.py
@@ -274,9 +274,11 @@ def udf(*args: Any, **kwargs: Any):  # noqa: D417
             :class:`pyarrow.ChunkedArray` values.
             input_types (list[pa.DataType | pa.Field]): The argument types for ``func``.
                 This list must be of the same length as the number of arguments. Pass
-                :class:`pyarrow.Field` instances to preserve extension metadata.
-            return_type (pa.DataType | pa.Field): The return type of the function. Use a
-                :class:`pyarrow.Field` to preserve metadata on extension arrays.
+                :class:`pyarrow.Field` instances when you need to declare extension
+                metadata for an argument.
+            return_type (pa.DataType | pa.Field): The return type of the function. Supply
+                a :class:`pyarrow.Field` when the result should expose extension metadata
+                to downstream consumers.
             volatility (Volatility | str): See `Volatility` for allowed values.
             name (Optional[str]): A descriptive name for the function.
 
diff --git a/src/udf.rs b/src/udf.rs
index f3d764a6c..ae4e9b913 100644
--- a/src/udf.rs
+++ b/src/udf.rs
@@ -135,7 +135,10 @@ impl ScalarUDFImpl for PySimpleScalarUDF {
     }
 
     fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
-        Ok(self.return_field.data_type().clone())
+        Err(DataFusionError::Internal(
+            "return_type should be unreachable when return_field_from_args is implemented"
+                .to_string(),
+        ))
     }
 
     fn return_field_from_args(

From 1ea75fde177382492041d2624fb850b2d4a17450 Mon Sep 17 00:00:00 2001
From: Siew Kam Onn
Date: Thu, 23 Oct 2025 18:20:51 +0800
Subject: [PATCH 15/17] Fix ruff errors
---
 python/datafusion/user_defined.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py
index 6f84209cc..26f49549b 100644
--- a/python/datafusion/user_defined.py
+++ b/python/datafusion/user_defined.py
@@ -276,9 +276,9 @@ def udf(*args: Any, **kwargs: Any):  # noqa: D417
                 This list must be of the same length as the number of arguments. Pass
                 :class:`pyarrow.Field` instances when you need to declare extension
                 metadata for an argument.
-            return_type (pa.DataType | pa.Field): The return type of the function. Supply
-                a :class:`pyarrow.Field` when the result should expose extension metadata
-                to downstream consumers.
+            return_type (pa.DataType | pa.Field): The return type of the function.
+                Supply a :class:`pyarrow.Field` when the result should expose
+                extension metadata to downstream consumers.
             volatility (Volatility | str): See `Volatility` for allowed values.
             name (Optional[str]): A descriptive name for the function.
@@ -290,8 +290,13 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417 def double_func(x): return x * 2 - double_udf = udf(double_func, [pa.int32()], pa.int32(), - "volatile", "double_it") + double_udf = udf( + double_func, + [pa.int32()], + pa.int32(), + "volatile", + "double_it", + ) Example: Using ``udf`` as a decorator:: From 27bc0129fd3a53607e59886fddf7ac309183cc32 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sat, 1 Nov 2025 17:18:52 +0800 Subject: [PATCH 16/17] reverted accidental formatting --- python/datafusion/user_defined.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 26f49549b..d28bb69d5 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -473,12 +473,10 @@ def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901 This class allows you to define an aggregate function that can be used in data aggregation or window function calls. - Usage: - - As a function: ``udaf(accum, input_types, return_type, state_type,`` - ``volatility, name)``. - - As a decorator: ``@udaf(input_types, return_type, state_type,`` - ``volatility, name)``. - When using ``udaf`` as a decorator, do not pass ``accum`` explicitly. + Usage: + - As a function: ``udaf(accum, input_types, return_type, state_type, volatility, name)``. + - As a decorator: ``@udaf(input_types, return_type, state_type, volatility, name)``. + When using ``udaf`` as a decorator, do not pass ``accum`` explicitly. Function example: From 6b400524464e9bf95d09c7ce96b641da01e64201 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sat, 1 Nov 2025 17:29:43 +0800 Subject: [PATCH 17/17] fix ruff errors --- python/datafusion/user_defined.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index d28bb69d5..be181505a 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -118,9 +118,9 @@ def _normalize_field(value: pa.DataType | pa.Field, *, default_name: str) -> pa. def _normalize_input_fields( values: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field], ) -> list[pa.Field]: - if isinstance(values, (pa.DataType, pa.Field)): + if isinstance(values, pa.DataType | pa.Field): sequence: Sequence[pa.DataType | pa.Field] = [values] - elif isinstance(values, Sequence) and not isinstance(values, (str, bytes)): + elif isinstance(values, Sequence) and not isinstance(values, str | bytes): sequence = values else: msg = "input_types must be a DataType, Field, or a sequence of them"