From f157c378c236fd57795b046e1672d46d0806149f Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 15 Oct 2025 11:58:01 +0800 Subject: [PATCH 01/12] Add fallback overload for AggregateUDF.udaf Implement fallback for PyCapsule-backed providers, ensuring type checkers are satisfied without protocol-aware stubs. Update typing imports and cast PyCapsule inputs in AggregateUDF.from_pycapsule for precise constructor typing. --- python/datafusion/user_defined.py | 48 +++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 67568e313..20e30f4f2 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -22,7 +22,16 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Optional, + Protocol, + TypeVar, + cast, + overload, +) import pyarrow as pa @@ -293,11 +302,11 @@ class AggregateUDF: def __init__( self, name: str, - accumulator: Callable[[], Accumulator], - input_types: list[pa.DataType], - return_type: pa.DataType, - state_type: list[pa.DataType], - volatility: Volatility | str, + accumulator: Callable[[], Accumulator] | AggregateUDFExportable, + input_types: list[pa.DataType] | None, + return_type: pa.DataType | None, + state_type: list[pa.DataType] | None, + volatility: Volatility | str | None, ) -> None: """Instantiate a user-defined aggregate function (UDAF). @@ -307,6 +316,18 @@ def __init__( if hasattr(accumulator, "__datafusion_aggregate_udf__"): self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator) return + if ( + input_types is None + or return_type is None + or state_type is None + or volatility is None + ): + msg = ( + "`input_types`, `return_type`, `state_type`, and `volatility` " + "must be provided when `accumulator` is callable." + ) + raise TypeError(msg) + self._udaf = df_internal.AggregateUDF( name, accumulator, @@ -350,6 +371,14 @@ def udaf( name: Optional[str] = None, ) -> AggregateUDF: ... + @overload + @staticmethod + def udaf(accum: AggregateUDFExportable) -> AggregateUDF: ... + + @overload + @staticmethod + def udaf(accum: object) -> AggregateUDF: ... + @staticmethod def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901 """Create a new User-Defined Aggregate Function (UDAF). @@ -480,16 +509,17 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr: return _decorator(*args, **kwargs) @staticmethod - def from_pycapsule(func: AggregateUDFExportable) -> AggregateUDF: + def from_pycapsule(func: AggregateUDFExportable | object) -> AggregateUDF: """Create an Aggregate UDF from AggregateUDF PyCapsule object. This function will instantiate a Aggregate UDF that uses a DataFusion AggregateUDF that is exported via the FFI bindings. """ - name = str(func.__class__) + capsule = cast(AggregateUDFExportable, func) + name = str(capsule.__class__) return AggregateUDF( name=name, - accumulator=func, + accumulator=capsule, input_types=None, return_type=None, state_type=None, From da824e6400c348afff88e929607115f824cf010a Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 15 Oct 2025 12:20:13 +0800 Subject: [PATCH 02/12] Add overloads for AggregateUDF.__init__ to support different initialization signatures --- python/datafusion/user_defined.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 20e30f4f2..9a37b3257 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -299,6 +299,30 @@ class AggregateUDF: also :py:class:`ScalarUDF` for operating on a row by row basis. """ + @overload + def __init__( + self, + name: str, + accumulator: Callable[[], Accumulator], + input_types: list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], + volatility: Volatility | str, + ) -> None: + ... + + @overload + def __init__( + self, + name: str, + accumulator: AggregateUDFExportable, + input_types: None = ..., + return_type: None = ..., + state_type: None = ..., + volatility: None = ..., + ) -> None: + ... + def __init__( self, name: str, From d275ee8bf4b6fb0f809cf990dc93448d20232e40 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 15 Oct 2025 13:25:24 +0800 Subject: [PATCH 03/12] Add PyCapsule typing protocol and helper detection Introduce a _PyCapsule typing protocol to enable type checkers to recognize PyCapsule-based registrations. Restrict the AggregateUDF udaf overload to the PyCapsule protocol and update from_pycapsule to wrap raw capsule inputs using the internal binding directly. --- python/datafusion/user_defined.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 9a37b3257..62ed25913 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -28,6 +28,7 @@ Callable, Optional, Protocol, + TypeGuard, TypeVar, cast, overload, @@ -92,6 +93,16 @@ class ScalarUDFExportable(Protocol): def __datafusion_scalar_udf__(self) -> object: ... # noqa: D105 +class _PyCapsule(Protocol): + """Lightweight typing proxy for CPython ``PyCapsule`` objects.""" + + +def _is_pycapsule(value: object) -> TypeGuard[_PyCapsule]: + """Return ``True`` when ``value`` is a CPython ``PyCapsule``.""" + + return value.__class__.__name__ == "PyCapsule" + + class ScalarUDF: """Class for performing scalar user-defined functions (UDF). @@ -401,7 +412,7 @@ def udaf(accum: AggregateUDFExportable) -> AggregateUDF: ... @overload @staticmethod - def udaf(accum: object) -> AggregateUDF: ... + def udaf(accum: _PyCapsule) -> AggregateUDF: ... @staticmethod def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901 @@ -523,7 +534,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr: return decorator - if hasattr(args[0], "__datafusion_aggregate_udf__"): + if hasattr(args[0], "__datafusion_aggregate_udf__") or _is_pycapsule(args[0]): return AggregateUDF.from_pycapsule(args[0]) if args and callable(args[0]): @@ -533,12 +544,17 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr: return _decorator(*args, **kwargs) @staticmethod - def from_pycapsule(func: AggregateUDFExportable | object) -> AggregateUDF: + def from_pycapsule(func: AggregateUDFExportable | _PyCapsule) -> AggregateUDF: """Create an Aggregate UDF from AggregateUDF PyCapsule object. This function will instantiate a Aggregate UDF that uses a DataFusion AggregateUDF that is exported via the FFI bindings. """ + if _is_pycapsule(func): + aggregate = cast(AggregateUDF, object.__new__(AggregateUDF)) + aggregate._udaf = df_internal.AggregateUDF.from_pycapsule(func) + return aggregate + capsule = cast(AggregateUDFExportable, func) name = str(capsule.__class__) return AggregateUDF( From 9d0b1911ce9b5d058d6f537dbd5bb84356288ccb Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 15 Oct 2025 14:13:43 +0800 Subject: [PATCH 04/12] Move TypeGuard import to the correct location in user_defined.py --- python/datafusion/user_defined.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 62ed25913..36d4a35e4 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -28,12 +28,13 @@ Callable, Optional, Protocol, - TypeGuard, TypeVar, cast, overload, ) +from typing_extensions import TypeGuard + import pyarrow as pa import datafusion._internal as df_internal From ac5c16facf7b9018ddf150d7bc995ee737bcf6b3 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 15 Oct 2025 15:00:10 +0800 Subject: [PATCH 05/12] Add aggregate_udf_from_capsule helper for UDFs Introduce a utility to validate PyCapsules and convert them into reusable DataFusion aggregate UDFs. Update PyAggregateUDF.from_pycapsule to handle raw PyCapsule inputs, leverage the new helper, and maintain existing provider fallback and error handling. --- src/udaf.rs | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/src/udaf.rs b/src/udaf.rs index eab4581df..0155b0309 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -154,6 +154,15 @@ pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction { }) } +fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult { + validate_pycapsule(capsule, "datafusion_aggregate_udf")?; + + let udaf = unsafe { capsule.reference::() }; + let udaf: ForeignAggregateUDF = udaf.try_into()?; + + Ok(udaf.into()) +} + /// Represents an AggregateUDF #[pyclass(frozen, name = "AggregateUDF", module = "datafusion", subclass)] #[derive(Debug, Clone)] @@ -186,22 +195,24 @@ impl PyAggregateUDF { #[staticmethod] pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { + if func.is_instance_of::() { + let capsule = func.downcast::().map_err(py_datafusion_err)?; + let capsule: &Bound<'_, PyCapsule> = capsule.into(); + let function = aggregate_udf_from_capsule(capsule)?; + return Ok(Self { function }); + } + if func.hasattr("__datafusion_aggregate_udf__")? { let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?; let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_aggregate_udf")?; - - let udaf = unsafe { capsule.reference::() }; - let udaf: ForeignAggregateUDF = udaf.try_into()?; - - Ok(Self { - function: udaf.into(), - }) - } else { - Err(crate::errors::PyDataFusionError::Common( - "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(), - )) + let capsule: &Bound<'_, PyCapsule> = capsule.into(); + let function = aggregate_udf_from_capsule(capsule)?; + return Ok(Self { function }); } + + Err(crate::errors::PyDataFusionError::Common( + "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(), + )) } /// creates a new PyExpr with the call of the udf From 9fb0349fdad6afb84e3a73f5f7452f4617d2ba80 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 15 Oct 2025 15:26:17 +0800 Subject: [PATCH 06/12] Refactor from_pycapsule method to simplify capsule handling --- src/udaf.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/udaf.rs b/src/udaf.rs index 0155b0309..da4031276 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -197,16 +197,14 @@ impl PyAggregateUDF { pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { if func.is_instance_of::() { let capsule = func.downcast::().map_err(py_datafusion_err)?; - let capsule: &Bound<'_, PyCapsule> = capsule.into(); - let function = aggregate_udf_from_capsule(capsule)?; + let function = aggregate_udf_from_capsule(&capsule)?; return Ok(Self { function }); } if func.hasattr("__datafusion_aggregate_udf__")? { let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?; let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - let capsule: &Bound<'_, PyCapsule> = capsule.into(); - let function = aggregate_udf_from_capsule(capsule)?; + let function = aggregate_udf_from_capsule(&capsule)?; return Ok(Self { function }); } From 52505a8ed3ec1bfb36faee5b507cceac62484198 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 15 Oct 2025 15:52:21 +0800 Subject: [PATCH 07/12] Refactor _PyCapsule definition for clarity and organization --- python/datafusion/user_defined.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 36d4a35e4..aee832832 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -41,7 +41,13 @@ from datafusion.expr import Expr if TYPE_CHECKING: + from _typeshed import CapsuleType as _PyCapsule + _R = TypeVar("_R", bound=pa.DataType) +else: + + class _PyCapsule: + """Lightweight typing proxy for CPython ``PyCapsule`` objects.""" class Volatility(Enum): @@ -94,10 +100,6 @@ class ScalarUDFExportable(Protocol): def __datafusion_scalar_udf__(self) -> object: ... # noqa: D105 -class _PyCapsule(Protocol): - """Lightweight typing proxy for CPython ``PyCapsule`` objects.""" - - def _is_pycapsule(value: object) -> TypeGuard[_PyCapsule]: """Return ``True`` when ``value`` is a CPython ``PyCapsule``.""" From 4143b8aabd432e8f7fdd4f616ae65baab3e7e74e Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 15 Oct 2025 16:07:26 +0800 Subject: [PATCH 08/12] Remove unnecessary blank line in TYPE_CHECKING block --- python/datafusion/user_defined.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index aee832832..c87973e1a 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -42,7 +42,6 @@ if TYPE_CHECKING: from _typeshed import CapsuleType as _PyCapsule - _R = TypeVar("_R", bound=pa.DataType) else: From 0d7f413b07b76dbf5be908ced92a77b14be633dd Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 15 Oct 2025 17:16:09 +0800 Subject: [PATCH 09/12] Fix ruff errors --- python/datafusion/user_defined.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index c87973e1a..299cd492f 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -33,9 +33,8 @@ overload, ) -from typing_extensions import TypeGuard - import pyarrow as pa +from typing_extensions import TypeGuard import datafusion._internal as df_internal from datafusion.expr import Expr @@ -101,7 +100,6 @@ def __datafusion_scalar_udf__(self) -> object: ... # noqa: D105 def _is_pycapsule(value: object) -> TypeGuard[_PyCapsule]: """Return ``True`` when ``value`` is a CPython ``PyCapsule``.""" - return value.__class__.__name__ == "PyCapsule" From 1f5b093ee986f113d7dcdf3507fcd3ffc8a34137 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 15 Oct 2025 17:18:14 +0800 Subject: [PATCH 10/12] fix ruff format --- python/datafusion/user_defined.py | 7 +++---- python/tests/test_pyclass_frozen.py | 3 +-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 299cd492f..767878978 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -41,6 +41,7 @@ if TYPE_CHECKING: from _typeshed import CapsuleType as _PyCapsule + _R = TypeVar("_R", bound=pa.DataType) else: @@ -319,8 +320,7 @@ def __init__( return_type: pa.DataType, state_type: list[pa.DataType], volatility: Volatility | str, - ) -> None: - ... + ) -> None: ... @overload def __init__( @@ -331,8 +331,7 @@ def __init__( return_type: None = ..., state_type: None = ..., volatility: None = ..., - ) -> None: - ... + ) -> None: ... def __init__( self, diff --git a/python/tests/test_pyclass_frozen.py b/python/tests/test_pyclass_frozen.py index 189ea8dec..428e5e98b 100644 --- a/python/tests/test_pyclass_frozen.py +++ b/python/tests/test_pyclass_frozen.py @@ -32,8 +32,7 @@ r"(?P[A-Za-z_][A-Za-z0-9_]*)\s*=\s*\"(?P[^\"]+)\"", ) STRUCT_NAME_RE = re.compile( - r"\b(?:pub\s+)?(?:struct|enum)\s+" - r"(?P[A-Za-z_][A-Za-z0-9_]*)", + r"\b(?:pub\s+)?(?:struct|enum)\s+" r"(?P[A-Za-z_][A-Za-z0-9_]*)", ) From 86e8caae441471c8fec8df0621a915ed5143ce5f Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 15 Oct 2025 17:48:14 +0800 Subject: [PATCH 11/12] Fix clippy error --- src/udaf.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/udaf.rs b/src/udaf.rs index da4031276..e48e35f8d 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -197,14 +197,14 @@ impl PyAggregateUDF { pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { if func.is_instance_of::() { let capsule = func.downcast::().map_err(py_datafusion_err)?; - let function = aggregate_udf_from_capsule(&capsule)?; + let function = aggregate_udf_from_capsule(capsule)?; return Ok(Self { function }); } if func.hasattr("__datafusion_aggregate_udf__")? { let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?; let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - let function = aggregate_udf_from_capsule(&capsule)?; + let function = aggregate_udf_from_capsule(capsule)?; return Ok(Self { function }); } From 8c78f383d4892e2f0c839a0323cae3d72729ff02 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 5 Nov 2025 11:52:48 +0800 Subject: [PATCH 12/12] Remove unused _PyCapsule class definition when TYPE_CHECKING is not enabled --- python/datafusion/user_defined.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 767878978..c79eb436b 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -43,10 +43,6 @@ from _typeshed import CapsuleType as _PyCapsule _R = TypeVar("_R", bound=pa.DataType) -else: - - class _PyCapsule: - """Lightweight typing proxy for CPython ``PyCapsule`` objects.""" class Volatility(Enum):