Commit c4e7486

Add PyCapsule Type Support and Type Hint Enhancements for AggregateUDF in DataFusion Python Bindings (#1277)
- Added a TypeGuard function _is_pycapsule() for lightweight PyCapsule type validation.
- Introduced a _PyCapsule type alias (typeshed CapsuleType, imported under TYPE_CHECKING) for static typing of capsule arguments.
- Extended the overloads of AggregateUDF.__init__ and AggregateUDF.udaf() to accept AggregateUDFExportable | _PyCapsule arguments.
- Added stricter constructor argument validation for callable accumulators.
- Updated AggregateUDF.from_pycapsule() to support direct PyCapsule initialization.
- Refactored the Rust PyAggregateUDF::from_pycapsule() logic to delegate PyCapsule validation to a new helper function aggregate_udf_from_capsule() for cleaner handling.
1 parent e97ed57 commit c4e7486
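
For orientation, a minimal sketch (not part of the commit) of how the two code paths touched here are exercised from Python. It assumes a hypothetical FFI-backed provider object my_udaf_provider whose __datafusion_aggregate_udf__() returns a PyCapsule named "datafusion_aggregate_udf".

# Sketch only: `my_udaf_provider` is hypothetical.
from datafusion.user_defined import AggregateUDF

# Existing path: an object exposing __datafusion_aggregate_udf__() is accepted as before.
agg = AggregateUDF.udaf(my_udaf_provider)

# New path in this commit: the raw PyCapsule itself is accepted directly,
# both by the udaf() dispatcher and by AggregateUDF.from_pycapsule().
capsule = my_udaf_provider.__datafusion_aggregate_udf__()
agg_from_capsule = AggregateUDF.from_pycapsule(capsule)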

File tree: 3 files changed (82 additions, 19 deletions)

python/datafusion/user_defined.py

Lines changed: 60 additions & 5 deletions
@@ -22,14 +22,16 @@
 import functools
 from abc import ABCMeta, abstractmethod
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Protocol, TypeGuard, TypeVar, cast, overload
 
 import pyarrow as pa
 
 import datafusion._internal as df_internal
 from datafusion.expr import Expr
 
 if TYPE_CHECKING:
+    from _typeshed import CapsuleType as _PyCapsule
+
     _R = TypeVar("_R", bound=pa.DataType)
     from collections.abc import Callable
 
@@ -84,6 +86,11 @@ class ScalarUDFExportable(Protocol):
     def __datafusion_scalar_udf__(self) -> object: ... # noqa: D105
 
 
+def _is_pycapsule(value: object) -> TypeGuard[_PyCapsule]:
+    """Return ``True`` when ``value`` is a CPython ``PyCapsule``."""
+    return value.__class__.__name__ == "PyCapsule"
+
+
 class ScalarUDF:
     """Class for performing scalar user-defined functions (UDF).
 
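
The helper is a duck-typed check (it inspects the class name rather than importing a capsule type at runtime), while the TypeGuard return annotation still lets static checkers narrow the value. A small sketch of its behavior, assuming a pyarrow version (14 or newer) that implements the Arrow PyCapsule interface as a convenient source of a real capsule; the private import is shown only for illustration.

# Illustration only: any real PyCapsule behaves the same way.
import pyarrow as pa

from datafusion.user_defined import _is_pycapsule  # private helper added above

schema_capsule = pa.schema([("x", pa.int64())]).__arrow_c_schema__()

print(_is_pycapsule(schema_capsule))  # True  -> narrowed to the capsule type
print(_is_pycapsule(pa.int64()))      # False -> value left unchanged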
@@ -291,6 +298,7 @@ class AggregateUDF:
     also :py:class:`ScalarUDF` for operating on a row by row basis.
     """
 
+    @overload
     def __init__(
         self,
         name: str,
@@ -299,6 +307,27 @@ def __init__(
         return_type: pa.DataType,
         state_type: list[pa.DataType],
         volatility: Volatility | str,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        name: str,
+        accumulator: AggregateUDFExportable,
+        input_types: None = ...,
+        return_type: None = ...,
+        state_type: None = ...,
+        volatility: None = ...,
+    ) -> None: ...
+
+    def __init__(
+        self,
+        name: str,
+        accumulator: Callable[[], Accumulator] | AggregateUDFExportable,
+        input_types: list[pa.DataType] | None,
+        return_type: pa.DataType | None,
+        state_type: list[pa.DataType] | None,
+        volatility: Volatility | str | None,
     ) -> None:
         """Instantiate a user-defined aggregate function (UDAF).
 
@@ -308,6 +337,18 @@ def __init__(
         if hasattr(accumulator, "__datafusion_aggregate_udf__"):
             self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator)
             return
+        if (
+            input_types is None
+            or return_type is None
+            or state_type is None
+            or volatility is None
+        ):
+            msg = (
+                "`input_types`, `return_type`, `state_type`, and `volatility` "
+                "must be provided when `accumulator` is callable."
+            )
+            raise TypeError(msg)
+
         self._udaf = df_internal.AggregateUDF(
             name,
             accumulator,
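
To make the new validation concrete, a minimal sketch of the callable-accumulator form, assuming the documented Accumulator interface (update/merge/state/evaluate) is importable from the same module; MySum is a hypothetical accumulator used only for illustration.

# Sketch only: `MySum` is hypothetical.
import pyarrow as pa
import pyarrow.compute as pc

from datafusion.user_defined import Accumulator, AggregateUDF


class MySum(Accumulator):
    """Toy sum accumulator over a float64 column."""

    def __init__(self) -> None:
        self._sum = 0.0

    def update(self, values: pa.Array) -> None:
        self._sum += pc.sum(values).as_py() or 0.0

    def merge(self, states: list[pa.Array]) -> None:
        self._sum += pc.sum(states[0]).as_py() or 0.0

    def state(self) -> list[pa.Scalar]:
        return [pa.scalar(self._sum)]

    def evaluate(self) -> pa.Scalar:
        return pa.scalar(self._sum)


# Callable-accumulator form: all type information is required, matching the
# first overload above.
my_sum = AggregateUDF(
    "my_sum",
    MySum,
    input_types=[pa.float64()],
    return_type=pa.float64(),
    state_type=[pa.float64()],
    volatility="immutable",
)

# Passing a callable while omitting the type arguments now fails fast with the
# TypeError added above instead of surfacing an error from the Rust side:
#     AggregateUDF("my_sum", MySum, None, None, None, None)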
@@ -351,6 +392,14 @@ def udaf(
         name: str | None = None,
     ) -> AggregateUDF: ...
 
+    @overload
+    @staticmethod
+    def udaf(accum: AggregateUDFExportable) -> AggregateUDF: ...
+
+    @overload
+    @staticmethod
+    def udaf(accum: _PyCapsule) -> AggregateUDF: ...
+
     @staticmethod
     def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901
         """Create a new User-Defined Aggregate Function (UDAF).
@@ -471,7 +520,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:
 
             return decorator
 
-        if hasattr(args[0], "__datafusion_aggregate_udf__"):
+        if hasattr(args[0], "__datafusion_aggregate_udf__") or _is_pycapsule(args[0]):
            return AggregateUDF.from_pycapsule(args[0])
 
         if args and callable(args[0]):
@@ -481,16 +530,22 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:
             return _decorator(*args, **kwargs)
 
     @staticmethod
-    def from_pycapsule(func: AggregateUDFExportable) -> AggregateUDF:
+    def from_pycapsule(func: AggregateUDFExportable | _PyCapsule) -> AggregateUDF:
         """Create an Aggregate UDF from AggregateUDF PyCapsule object.
 
         This function will instantiate a Aggregate UDF that uses a DataFusion
         AggregateUDF that is exported via the FFI bindings.
         """
-        name = str(func.__class__)
+        if _is_pycapsule(func):
+            aggregate = cast(AggregateUDF, object.__new__(AggregateUDF))
+            aggregate._udaf = df_internal.AggregateUDF.from_pycapsule(func)
+            return aggregate
+
+        capsule = cast(AggregateUDFExportable, func)
+        name = str(capsule.__class__)
         return AggregateUDF(
             name=name,
-            accumulator=func,
+            accumulator=capsule,
             input_types=None,
             return_type=None,
             state_type=None,
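
One design note on the capsule branch above: a bare capsule carries no type information to forward, so the wrapper is allocated with object.__new__(AggregateUDF) and _udaf is assigned directly, skipping __init__ and its argument validation entirely. A generic sketch of that pattern (not datafusion-specific):

# Generic illustration of constructing an instance without running __init__.
class Wrapper:
    def __init__(self, payload: str) -> None:
        # Normal construction expects a payload to process...
        self.payload = payload.upper()


raw = object.__new__(Wrapper)     # ...but this allocates the object bare,
raw.payload = "already prepared"  # and pre-built state is attached as-is.

assert isinstance(raw, Wrapper)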

python/tests/test_pyclass_frozen.py

Lines changed: 1 addition & 2 deletions
@@ -35,8 +35,7 @@
     r"(?P<key>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*\"(?P<value>[^\"]+)\"",
 )
 STRUCT_NAME_RE = re.compile(
-    r"\b(?:pub\s+)?(?:struct|enum)\s+"
-    r"(?P<name>[A-Za-z_][A-Za-z0-9_]*)",
+    r"\b(?:pub\s+)?(?:struct|enum)\s+" r"(?P<name>[A-Za-z_][A-Za-z0-9_]*)",
 )
 
 
src/udaf.rs

Lines changed: 21 additions & 12 deletions
@@ -154,6 +154,15 @@ pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction {
     })
 }
 
+fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
+    validate_pycapsule(capsule, "datafusion_aggregate_udf")?;
+
+    let udaf = unsafe { capsule.reference::<FFI_AggregateUDF>() };
+    let udaf: ForeignAggregateUDF = udaf.try_into()?;
+
+    Ok(udaf.into())
+}
+
 /// Represents an AggregateUDF
 #[pyclass(frozen, name = "AggregateUDF", module = "datafusion", subclass)]
 #[derive(Debug, Clone)]
@@ -186,22 +195,22 @@ impl PyAggregateUDF {
 
     #[staticmethod]
     pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
+        if func.is_instance_of::<PyCapsule>() {
+            let capsule = func.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
+            let function = aggregate_udf_from_capsule(capsule)?;
+            return Ok(Self { function });
+        }
+
         if func.hasattr("__datafusion_aggregate_udf__")? {
             let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?;
             let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
-            validate_pycapsule(capsule, "datafusion_aggregate_udf")?;
-
-            let udaf = unsafe { capsule.reference::<FFI_AggregateUDF>() };
-            let udaf: ForeignAggregateUDF = udaf.try_into()?;
-
-            Ok(Self {
-                function: udaf.into(),
-            })
-        } else {
-            Err(crate::errors::PyDataFusionError::Common(
-                "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
-            ))
+            let function = aggregate_udf_from_capsule(capsule)?;
+            return Ok(Self { function });
         }
+
+        Err(crate::errors::PyDataFusionError::Common(
+            "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
+        ))
     }
 
     /// creates a new PyExpr with the call of the udf
