18 changes: 18 additions & 0 deletions python/README.md
@@ -42,6 +42,23 @@ def wait_concurrent(x):
time.sleep(2)
return x

# Define a table-valued function that emits multiple columns.
@udf(
input_types=["INT"],
result_type=[("num", "INT"), ("double_num", "INT")],
batch_mode=True,
)
def expand_numbers(nums: List[int]):
import pyarrow as pa

schema = pa.schema(
[pa.field("num", pa.int32(), nullable=False), pa.field("double_num", pa.int32(), nullable=False)]
)
return pa.RecordBatch.from_arrays(
[pa.array(nums, type=pa.int32()), pa.array([n * 2 for n in nums], type=pa.int32())],
schema=schema,
)

if __name__ == '__main__':
# create a UDF server listening at '0.0.0.0:8815'
server = UDFServer("0.0.0.0:8815")
@@ -50,6 +67,7 @@ if __name__ == '__main__':
server.add_function(gcd)
server.add_function(array_index_of)
server.add_function(wait_concurrent)
server.add_function(expand_numbers)
# start the UDF server
server.serve()
```
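
For reference, a table-valued function registered this way can be exercised from the bundled `UDFClient`. A minimal sketch, not part of this PR; the address is hypothetical and the `call_function` behavior mirrors the client changes below:

```python
from databend_udf.client import UDFClient

# Hypothetical address: match wherever your UDF server is listening.
client = UDFClient("localhost:8815")

# Multi-column results come back as one dict per row, keyed by the
# declared result columns.
rows = client.call_function("expand_numbers", [1, 2, 3])
# e.g. [{'num': 1, 'double_num': 2}, {'num': 2, 'double_num': 4},
#       {'num': 3, 'double_num': 6}]
```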
78 changes: 56 additions & 22 deletions python/databend_udf/client.py
@@ -7,6 +7,9 @@
import pyarrow.flight as fl


_SCHEMA_METADATA_INPUT_COUNT_KEY = b"x-databend-udf-input-count"


class UDFClient:
"""Simple client for calling UDF functions on a Databend UDF server."""

@@ -29,6 +32,45 @@ def _get_cached_schema(self, function_name: str) -> pa.Schema:
self._schema_cache[function_name] = info.schema
return self._schema_cache[function_name]

@staticmethod
def _get_input_field_count(schema: pa.Schema) -> int:
"""Extract the number of input columns from schema metadata."""
metadata = schema.metadata or {}
key = _SCHEMA_METADATA_INPUT_COUNT_KEY
if metadata and key in metadata:
try:
return int(metadata[key].decode("utf-8"))
except (ValueError, AttributeError):
pass

# Fallback for older servers without metadata: assume final column is output.
return max(len(schema) - 1, 0)

@staticmethod
def _decode_result_batch(batch: pa.RecordBatch) -> List[Any]:
"""Convert a result RecordBatch into Python values."""
num_columns = batch.num_columns
num_rows = batch.num_rows

if num_columns == 0:
return [{} for _ in range(num_rows)]

if num_columns == 1:
column = batch.column(0)
return [column[i].as_py() for i in range(num_rows)]

field_names = [
batch.schema.field(i).name or f"col{i}" for i in range(num_columns)
]

rows: List[Dict[str, Any]] = []
for row_idx in range(num_rows):
row: Dict[str, Any] = {}
for col_idx, name in enumerate(field_names):
row[name] = batch.column(col_idx)[row_idx].as_py()
rows.append(row)
return rows

def _prepare_function_call(
self, function_name: str, args: tuple = (), kwargs: dict = None
) -> tuple:
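
Since both new helpers are static and operate on plain Arrow data, they can be sanity-checked without a running server. A quick sketch, not part of this PR; it just mirrors the metadata key and the "last column is output" fallback above:

```python
import pyarrow as pa

from databend_udf.client import UDFClient

# Two input columns, two result columns, input count carried in metadata.
schema = pa.schema(
    [
        pa.field("a", pa.int32()),
        pa.field("b", pa.int32()),
        pa.field("num", pa.int32()),
        pa.field("double_num", pa.int32()),
    ],
    metadata={b"x-databend-udf-input-count": b"2"},
)
assert UDFClient._get_input_field_count(schema) == 2

# Without the metadata, the fallback assumes the final column is the output.
assert UDFClient._get_input_field_count(schema.remove_metadata()) == 3

# Multi-column result batches decode to one dict per row.
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2], type=pa.int32()), pa.array([2, 4], type=pa.int32())],
    names=["num", "double_num"],
)
assert UDFClient._decode_result_batch(batch) == [
    {"num": 1, "double_num": 2},
    {"num": 2, "double_num": 4},
]
```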
@@ -41,32 +83,35 @@ def _prepare_function_call(
kwargs = kwargs or {}
schema = self._get_cached_schema(function_name)

input_field_count = self._get_input_field_count(schema)
total_fields = len(schema)
if input_field_count > total_fields:
raise ValueError(
f"Function '{function_name}' schema metadata is invalid (input count {input_field_count} > total fields {total_fields})"
)

input_fields = [schema.field(i) for i in range(input_field_count)]
input_schema = pa.schema(input_fields)

# Validate arguments
if args and kwargs:
raise ValueError("Cannot mix positional and keyword arguments")

if args:
# Positional arguments - validate count first
- total_fields = len(schema)
- expected_input_count = total_fields - 1  # Last field is always output

- if len(args) != expected_input_count:
+ if len(args) != input_field_count:
raise ValueError(
- f"Function '{function_name}' expects {expected_input_count} arguments, got {len(args)}"
+ f"Function '{function_name}' expects {input_field_count} arguments, got {len(args)}"
)

if len(args) == 0:
- input_schema = pa.schema([])
# For zero-argument functions, create a batch with 1 row and no columns
dummy_array = pa.array([None])
temp_batch = pa.RecordBatch.from_arrays(
[dummy_array], schema=pa.schema([pa.field("dummy", pa.null())])
)
batch = temp_batch.select([])
else:
- input_fields = [schema.field(i) for i in range(len(args))]
- input_schema = pa.schema(input_fields)

arrays = []
for i, arg in enumerate(args):
if isinstance(arg, list):
Expand All @@ -85,13 +130,6 @@ def _prepare_function_call(
)
batch = temp_batch.select([])
else:
- # Extract only input fields (exclude the last field which is output)
- # The schema contains input fields + 1 output field
- total_fields = len(schema)
- num_input_fields = total_fields - 1  # Last field is always output
- input_fields = [schema.field(i) for i in range(num_input_fields)]
- input_schema = pa.schema(input_fields)

# Validate kwargs
expected_fields = {field.name for field in input_schema}
provided_fields = set(kwargs.keys())
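
One subtlety kept from the original code: a zero-argument call still has to ship a batch with a row count of 1, which is done by building a one-row dummy batch and then selecting zero columns. The trick in isolation, as a standalone pyarrow sketch:

```python
import pyarrow as pa

# A one-row batch whose only column is null-typed...
temp_batch = pa.RecordBatch.from_arrays(
    [pa.array([None])], schema=pa.schema([pa.field("dummy", pa.null())])
)
# ...then drop every column; the row count survives the projection.
batch = temp_batch.select([])
assert batch.num_columns == 0
assert batch.num_rows == 1
```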
@@ -223,12 +261,10 @@ def call_function(
writer.write_batch(batch)
writer.done_writing()

- # Get results
results = []
for result_chunk in reader:
result_batch = result_chunk.data
- for i in range(result_batch.num_rows):
- results.append(result_batch.column(0)[i].as_py())
+ results.extend(self._decode_result_batch(result_batch))

return results

@@ -264,12 +300,10 @@ def call_function_batch(
writer.write_batch(batch)
writer.done_writing()

- # Get results
results = []
for result_chunk in reader:
result_batch = result_chunk.data
- for i in range(result_batch.num_rows):
- results.append(result_batch.column(0)[i].as_py())
+ results.extend(self._decode_result_batch(result_batch))

return results
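
With both call paths now funneling through `_decode_result_batch`, the remaining two arms (zero and one column) decode as follows. A standalone sketch, not part of the PR; the column names are illustrative:

```python
import pyarrow as pa

from databend_udf.client import UDFClient

# One column: plain Python scalars, one per row.
single = pa.RecordBatch.from_arrays([pa.array([6, 9])], names=["gcd"])
assert UDFClient._decode_result_batch(single) == [6, 9]

# Zero columns: one empty dict per row (the row count is preserved).
empty = pa.RecordBatch.from_arrays(
    [pa.array([None, None])], schema=pa.schema([pa.field("dummy", pa.null())])
).select([])
assert UDFClient._decode_result_batch(empty) == [{}, {}]
```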
