From 04f4698c2bebd81d617c601dfc3a3dabc8cdc0e1 Mon Sep 17 00:00:00 2001 From: kould Date: Wed, 5 Nov 2025 01:07:43 +0800 Subject: [PATCH 1/2] feat: impl UDTF for @udf --- python/README.md | 18 ++ python/databend_udf/client.py | 78 ++++-- python/databend_udf/udf.py | 403 +++++++++++++++++++++------ python/tests/servers/basic_server.py | 31 +++ python/tests/test_simple_udf.py | 17 ++ 5 files changed, 442 insertions(+), 105 deletions(-) diff --git a/python/README.md b/python/README.md index 5a6b4e1..bf49259 100644 --- a/python/README.md +++ b/python/README.md @@ -42,6 +42,23 @@ def wait_concurrent(x): time.sleep(2) return x +# Define a table-valued function that emits multiple columns. +@udf( + input_types=["INT"], + result_type=[("num", "INT"), ("double_num", "INT")], + batch_mode=True, +) +def expand_numbers(nums: List[int]): + import pyarrow as pa + + schema = pa.schema( + [pa.field("num", pa.int32(), nullable=False), pa.field("double_num", pa.int32(), nullable=False)] + ) + return pa.RecordBatch.from_arrays( + [pa.array(nums, type=pa.int32()), pa.array([n * 2 for n in nums], type=pa.int32())], + schema=schema, + ) + if __name__ == '__main__': # create a UDF server listening at '0.0.0.0:8815' server = UDFServer("0.0.0.0:8815") @@ -50,6 +67,7 @@ if __name__ == '__main__': server.add_function(gcd) server.add_function(array_index_of) server.add_function(wait_concurrent) + server.add_function(expand_numbers) # start the UDF server server.serve() ``` diff --git a/python/databend_udf/client.py b/python/databend_udf/client.py index 51015dd..1a4a3f3 100644 --- a/python/databend_udf/client.py +++ b/python/databend_udf/client.py @@ -7,6 +7,9 @@ import pyarrow.flight as fl +_SCHEMA_METADATA_INPUT_COUNT_KEY = b"x-databend-udf-input-count" + + class UDFClient: """Simple client for calling UDF functions on a Databend UDF server.""" @@ -29,6 +32,45 @@ def _get_cached_schema(self, function_name: str) -> pa.Schema: self._schema_cache[function_name] = info.schema return self._schema_cache[function_name] + @staticmethod + def _get_input_field_count(schema: pa.Schema) -> int: + """Extract the number of input columns from schema metadata.""" + metadata = schema.metadata or {} + key = _SCHEMA_METADATA_INPUT_COUNT_KEY + if metadata and key in metadata: + try: + return int(metadata[key].decode("utf-8")) + except (ValueError, AttributeError): + pass + + # Fallback for older servers without metadata: assume final column is output. 
+ return max(len(schema) - 1, 0) + + @staticmethod + def _decode_result_batch(batch: pa.RecordBatch) -> List[Any]: + """Convert a result RecordBatch into Python values.""" + num_columns = batch.num_columns + num_rows = batch.num_rows + + if num_columns == 0: + return [{} for _ in range(num_rows)] + + if num_columns == 1: + column = batch.column(0) + return [column[i].as_py() for i in range(num_rows)] + + field_names = [ + batch.schema.field(i).name or f"col{i}" for i in range(num_columns) + ] + + rows: List[Dict[str, Any]] = [] + for row_idx in range(num_rows): + row: Dict[str, Any] = {} + for col_idx, name in enumerate(field_names): + row[name] = batch.column(col_idx)[row_idx].as_py() + rows.append(row) + return rows + def _prepare_function_call( self, function_name: str, args: tuple = (), kwargs: dict = None ) -> tuple: @@ -41,22 +83,28 @@ def _prepare_function_call( kwargs = kwargs or {} schema = self._get_cached_schema(function_name) + input_field_count = self._get_input_field_count(schema) + total_fields = len(schema) + if input_field_count > total_fields: + raise ValueError( + f"Function '{function_name}' schema metadata is invalid (input count {input_field_count} > total fields {total_fields})" + ) + + input_fields = [schema.field(i) for i in range(input_field_count)] + input_schema = pa.schema(input_fields) + # Validate arguments if args and kwargs: raise ValueError("Cannot mix positional and keyword arguments") if args: # Positional arguments - validate count first - total_fields = len(schema) - expected_input_count = total_fields - 1 # Last field is always output - - if len(args) != expected_input_count: + if len(args) != input_field_count: raise ValueError( - f"Function '{function_name}' expects {expected_input_count} arguments, got {len(args)}" + f"Function '{function_name}' expects {input_field_count} arguments, got {len(args)}" ) if len(args) == 0: - input_schema = pa.schema([]) # For zero-argument functions, create a batch with 1 row and no columns dummy_array = pa.array([None]) temp_batch = pa.RecordBatch.from_arrays( @@ -64,9 +112,6 @@ def _prepare_function_call( ) batch = temp_batch.select([]) else: - input_fields = [schema.field(i) for i in range(len(args))] - input_schema = pa.schema(input_fields) - arrays = [] for i, arg in enumerate(args): if isinstance(arg, list): @@ -85,13 +130,6 @@ def _prepare_function_call( ) batch = temp_batch.select([]) else: - # Extract only input fields (exclude the last field which is output) - # The schema contains input fields + 1 output field - total_fields = len(schema) - num_input_fields = total_fields - 1 # Last field is always output - input_fields = [schema.field(i) for i in range(num_input_fields)] - input_schema = pa.schema(input_fields) - # Validate kwargs expected_fields = {field.name for field in input_schema} provided_fields = set(kwargs.keys()) @@ -223,12 +261,10 @@ def call_function( writer.write_batch(batch) writer.done_writing() - # Get results results = [] for result_chunk in reader: result_batch = result_chunk.data - for i in range(result_batch.num_rows): - results.append(result_batch.column(0)[i].as_py()) + results.extend(self._decode_result_batch(result_batch)) return results @@ -264,12 +300,10 @@ def call_function_batch( writer.write_batch(batch) writer.done_writing() - # Get results results = [] for result_chunk in reader: result_batch = result_chunk.data - for i in range(result_batch.num_rows): - results.append(result_batch.column(0)[i].as_py()) + results.extend(self._decode_result_batch(result_batch)) return results 
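The new `_decode_result_batch` helper above shapes results by column count: zero columns yield empty dicts, a single column still decodes to plain Python scalars, and multi-column (table-function) results become one dict per row keyed by field name. A minimal standalone sketch of that multi-column shaping, written against plain pyarrow rather than the client itself, just to illustrate the mapping:

```python
# Illustration only: the row shaping applied to a two-column result batch,
# mirroring what _decode_result_batch does for multi-column output.
import pyarrow as pa

batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2], type=pa.int32()), pa.array([2, 4], type=pa.int32())],
    names=["num", "double_num"],
)

rows = [
    {batch.schema.field(i).name: batch.column(i)[r].as_py() for i in range(batch.num_columns)}
    for r in range(batch.num_rows)
]
print(rows)  # [{'num': 1, 'double_num': 2}, {'num': 2, 'double_num': 4}]
```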
diff --git a/python/databend_udf/udf.py b/python/databend_udf/udf.py index a8c839b..e4a5fba 100644 --- a/python/databend_udf/udf.py +++ b/python/databend_udf/udf.py @@ -27,6 +27,8 @@ Dict, Any, Tuple, + Sequence, + Iterable, ) from typing import get_args, get_origin from prometheus_client import Counter, Gauge, Histogram @@ -48,6 +50,7 @@ ARROW_EXT_TYPE_VARIANT = b"Variant" TIMESTAMP_UINT = "us" +_SCHEMA_METADATA_INPUT_COUNT_KEY = b"x-databend-udf-input-count" logger = logging.getLogger(__name__) @@ -378,30 +381,30 @@ def eval_batch( return iter([]) -class ScalarFunction(UserDefinedFunction): +class CallableFunction(UserDefinedFunction): """ - Base interface for user-defined scalar function. - A user-defined scalar functions maps zero, one, - or multiple scalar values to a new scalar value. + Shared implementation for callable UDF/UDTF objects. """ _func: Callable - _io_threads: Optional[int] - _executor: Optional[ThreadPoolExecutor] - _skip_null: bool - _batch_mode: bool + _headers_param: Optional[str] + _stage_ref_names: List[str] + _stage_param_to_ref: Dict[str, str] + _stage_param_set: set + _arg_order: List[str] + _data_arg_names: List[str] + _call_arg_layout: List[Tuple[str, str]] + _data_arg_indices: Dict[str, int] + _sql_parameter_defs: List[str] def __init__( self, - func, - input_types, - result_type, - name=None, + func: Callable, + input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], + result_fields: List[pa.Field], + name: Optional[str] = None, stage_refs: Optional[List[str]] = None, - io_threads=None, - skip_null=None, - batch_mode=False, - ): + ) -> None: self._func = func self._stage_ref_names = list(stage_refs or []) if len(self._stage_ref_names) != len(set(self._stage_ref_names)): @@ -409,7 +412,7 @@ def __init__( spec = inspect.getfullargspec(func) arg_names = list(spec.args) - self._headers_param: Optional[str] = None + self._headers_param = None if arg_names and arg_names[-1] == "headers": self._headers_param = arg_names.pop() @@ -446,13 +449,13 @@ def __init__( param: ref for param, ref in zip(stage_param_python_names, self._stage_ref_names) } - self._stage_param_names = stage_param_python_names - self._stage_param_set = set(self._stage_param_names) + self._stage_param_set = set(stage_param_python_names) self._arg_order = list(arg_names) self._data_arg_names = [ name for name in self._arg_order if name not in self._stage_param_set ] + input_type_list = _to_list(input_types) if len(self._data_arg_names) != len(input_type_list): raise ValueError( @@ -469,21 +472,15 @@ def __init__( name: idx for idx, name in enumerate(self._data_arg_names) } - self._result_schema = pa.schema( - [_to_arrow_field(result_type).with_name("output")] - ) + if not result_fields: + raise ValueError("result_fields must contain at least one column") + self._result_schema = pa.schema(result_fields) + self._name = name or ( func.__name__ if hasattr(func, "__name__") else func.__class__.__name__ ) - self._io_threads = io_threads - self._batch_mode = batch_mode - self._executor = ( - ThreadPoolExecutor(max_workers=self._io_threads) - if self._io_threads is not None - else None - ) - self._call_arg_layout: List[Tuple[str, str]] = [] + self._call_arg_layout = [] for parameter in self._arg_order: if parameter in self._stage_param_set: self._call_arg_layout.append(("stage", parameter)) @@ -493,7 +490,7 @@ def __init__( self._call_arg_layout.append(("headers", self._headers_param)) data_field_map = {field.name: field for field in self._input_schema} - self._sql_parameter_defs: 
List[str] = [] + self._sql_parameter_defs = [] for kind, identifier in self._call_arg_layout: if kind == "stage": stage_ref_name = self._stage_param_to_ref.get(identifier, identifier) @@ -504,19 +501,13 @@ def __init__( f"{field.name} {_arrow_field_to_string(field)}" ) - if skip_null and not self._result_schema.field(0).nullable: - raise ValueError( - f"Return type of function {self._name} must be nullable when skip_null is True" - ) - - self._skip_null = skip_null or False + self._is_table_function = False super().__init__() - def eval_batch( - self, batch: pa.RecordBatch, headers: Optional[Headers] = None - ) -> Iterator[pa.RecordBatch]: - headers_obj = headers if isinstance(headers, Headers) else Headers(headers) + def _ensure_headers(self, headers: Optional[Headers]) -> Headers: + return headers if isinstance(headers, Headers) else Headers(headers) + def _resolve_stage_locations(self, headers_obj: Headers) -> Dict[str, StageLocation]: stage_locations: Dict[str, StageLocation] = {} if self._stage_ref_names: stage_locations_by_ref = headers_obj.require_stage_locations( @@ -541,13 +532,106 @@ def eval_batch( raise ValueError( f"Stage parameter '{name}' must not use 'fs' storage" ) + return stage_locations + def _convert_inputs(self, batch: pa.RecordBatch) -> List[List[Any]]: processed_inputs: List[List[Any]] = [] for array, field in zip(batch, self._input_schema): python_values = [value.as_py() for value in array] processed_inputs.append( _input_process_func(_list_field(field))(python_values) ) + return processed_inputs + + def _assemble_args( + self, + data_inputs: List[List[Any]], + stage_locations: Dict[str, StageLocation], + headers: Headers, + row_idx: Optional[int], + ) -> List[Any]: + args: List[Any] = [] + for kind, identifier in self._call_arg_layout: + if kind == "data": + data_index = self._data_arg_indices[identifier] + values = data_inputs[data_index] + args.append(values if row_idx is None else values[row_idx]) + elif kind == "stage": + if identifier not in stage_locations: + raise ValueError( + f"Missing stage mapping for parameter '{identifier}'" + ) + args.append(stage_locations[identifier]) + elif kind == "headers": + args.append(headers) + return args + + def _row_has_null(self, data_inputs: List[List[Any]], row_idx: int) -> bool: + for values in data_inputs: + if values[row_idx] is None: + return True + return False + + def __call__(self, *args): + return self._func(*args) + + +class ScalarFunction(CallableFunction): + """ + Base interface for user-defined scalar function. + A user-defined scalar function maps zero, one, + or multiple scalar values to a new scalar value. 
+ """ + + _io_threads: Optional[int] + _executor: Optional[ThreadPoolExecutor] + _skip_null: bool + _batch_mode: bool + + def __init__( + self, + func, + input_types, + result_type, + name=None, + stage_refs: Optional[List[str]] = None, + io_threads=None, + skip_null=None, + batch_mode=False, + ): + result_fields, is_table = _normalize_result_type(result_type) + if is_table: + raise ValueError("ScalarFunction result_type must describe a single value") + + super().__init__( + func, + input_types, + result_fields, + name=name, + stage_refs=stage_refs, + ) + self._io_threads = io_threads + self._batch_mode = batch_mode + self._executor = ( + ThreadPoolExecutor(max_workers=self._io_threads) + if self._io_threads is not None + else None + ) + + if skip_null and not self._result_schema.field(0).nullable: + raise ValueError( + f"Return type of function {self._name} must be nullable when skip_null is True" + ) + + self._skip_null = skip_null or False + self._is_table_function = False + + def eval_batch( + self, batch: pa.RecordBatch, headers: Optional[Headers] = None + ) -> Iterator[pa.RecordBatch]: + headers_obj = self._ensure_headers(headers) + stage_locations = self._resolve_stage_locations(headers_obj) + processed_inputs = self._convert_inputs(batch) if self._batch_mode: call_args = self._assemble_args( @@ -586,37 +670,105 @@ def eval_batch( array = pa.array(column, type=self._result_schema.types[0]) yield pa.RecordBatch.from_arrays([array], schema=self._result_schema) - def _assemble_args( + +class TableFunction(CallableFunction): + """ + Table-valued user-defined function that returns zero or more rows. + """ + + _batch_mode: bool + + def __init__( self, - data_inputs: List[List[Any]], - stage_locations: Dict[str, StageLocation], - headers: Headers, - row_idx: Optional[int], - ) -> List[Any]: - args: List[Any] = [] - for kind, identifier in self._call_arg_layout: - if kind == "data": - data_index = self._data_arg_indices[identifier] - values = data_inputs[data_index] - args.append(values if row_idx is None else values[row_idx]) - elif kind == "stage": - if identifier not in stage_locations: - raise ValueError( - f"Missing stage mapping for parameter '{identifier}'" - ) - args.append(stage_locations[identifier]) - elif kind == "headers": - args.append(headers) - return args + func, + input_types, + result_type, + name=None, + stage_refs: Optional[List[str]] = None, + io_threads=None, + skip_null=None, + batch_mode=False, + ): + if io_threads not in (None, 0, 1): + raise ValueError("Table functions do not support io_threads > 1") + if skip_null: + raise ValueError("Table functions do not support skip_null semantics") - def _row_has_null(self, data_inputs: List[List[Any]], row_idx: int) -> bool: - for values in data_inputs: - if values[row_idx] is None: - return True - return False + result_fields, is_table = _normalize_result_type(result_type) + if not is_table: + raise ValueError( + "TableFunction result_type must describe a table result (list of columns)" + ) - def __call__(self, *args): - return self._func(*args) + super().__init__( + func, + input_types, + result_fields, + name=name, + stage_refs=stage_refs, + ) + self._batch_mode = bool(batch_mode) + self._is_table_function = True + + def eval_batch( + self, batch: pa.RecordBatch, headers: Optional[Headers] = None + ) -> Iterator[pa.RecordBatch]: + headers_obj = self._ensure_headers(headers) + stage_locations = self._resolve_stage_locations(headers_obj) + processed_inputs = self._convert_inputs(batch) + + if self._batch_mode: + call_args 
= self._assemble_args( + processed_inputs, stage_locations, headers_obj, row_idx=None + ) + yield from self._iter_output_batches(self._func(*call_args)) + else: + for row in range(batch.num_rows): + call_args = self._assemble_args( + processed_inputs, stage_locations, headers_obj, row + ) + yield from self._iter_output_batches(self._func(*call_args)) + + def _iter_output_batches( + self, result: Any + ) -> Iterator[pa.RecordBatch]: + if result is None: + return + if isinstance(result, pa.RecordBatch): + yield self._align_record_batch(result) + return + if isinstance(result, pa.Table): + for batch in result.to_batches(): + yield self._align_record_batch(batch) + return + if isinstance(result, Iterable) and not isinstance( + result, (bytes, str, dict) + ): + for item in result: + yield from self._iter_output_batches(item) + return + raise TypeError( + f"Table function {self._name} must return a pyarrow.RecordBatch, pyarrow.Table, or an iterable of them" + ) + + def _align_record_batch(self, batch: pa.RecordBatch) -> pa.RecordBatch: + if batch.schema.equals(self._result_schema, check_metadata=False): + return batch + + columns = [] + for field in self._result_schema: + index = batch.schema.get_field_index(field.name) + if index == -1: + raise ValueError( + f"Missing expected column '{field.name}' in table function result" + ) + source_field = batch.schema.field(index) + if not source_field.type.equals(field.type, check_metadata=False): + raise ValueError( + f"Column '{field.name}' type mismatch: expected {field.type}, got {source_field.type}" + ) + columns.append(batch.column(index)) + return pa.RecordBatch.from_arrays(columns, schema=self._result_schema) def udf( @@ -670,19 +822,44 @@ def gcd(x, y): ``` """ + is_table = isinstance(result_type, (pa.Schema, pa.Field)) or ( + isinstance(result_type, Sequence) and not isinstance(result_type, (str, bytes)) + ) + + if is_table: + effective_io_threads = None if io_threads in (None, 32) else io_threads + + def decorator(f): + return TableFunction( + f, + input_types, + result_type, + name, + stage_refs=stage_refs, + io_threads=effective_io_threads, + skip_null=skip_null, + batch_mode=batch_mode, + ) + + return decorator + if io_threads is not None and io_threads > 1: - return lambda f: ScalarFunction( - f, - input_types, - result_type, - name, - stage_refs=stage_refs, - io_threads=io_threads, - skip_null=skip_null, - batch_mode=batch_mode, - ) - else: - return lambda f: ScalarFunction( + def decorator(f): + return ScalarFunction( + f, + input_types, + result_type, + name, + stage_refs=stage_refs, + io_threads=io_threads, + skip_null=skip_null, + batch_mode=batch_mode, + ) + + return decorator + + def decorator(f): + return ScalarFunction( f, input_types, result_type, @@ -692,6 +869,8 @@ def gcd(x, y): batch_mode=batch_mode, ) + return decorator + class UDFServer(FlightServerBase): """ @@ -794,6 +973,11 @@ def get_flight_info(self, context, descriptor): udf = self._functions[func_name] # return the concatenation of input and output schema full_schema = pa.schema(list(udf._input_schema) + list(udf._result_schema)) + metadata = dict(full_schema.metadata.items()) if full_schema.metadata else {} + metadata[_SCHEMA_METADATA_INPUT_COUNT_KEY] = str(len(udf._input_schema)).encode( + "utf-8" + ) + full_schema = full_schema.with_metadata(metadata) return FlightInfo( schema=full_schema, descriptor=descriptor, @@ -860,10 +1044,17 @@ def add_function(self, udf: UserDefinedFunction): for field in udf._input_schema ] input_types = ", ".join(parameter_defs) - 
output_type = _arrow_field_to_string(udf._result_schema[0]) + if getattr(udf, "_is_table_function", False): + column_defs = [ + f"{field.name} {_arrow_field_to_string(field)}" + for field in udf._result_schema + ] + returns_clause = f"RETURNS TABLE ({', '.join(column_defs)})" + else: + returns_clause = f"RETURNS {_arrow_field_to_string(udf._result_schema[0])}" sql = ( f"CREATE FUNCTION {name} ({input_types}) " - f"RETURNS {output_type} LANGUAGE python " + f"{returns_clause} LANGUAGE python " f"HANDLER = '{name}' ADDRESS = 'http://{self._location}';" ) logger.info(f"added function: {name}, SQL:\n{sql}\n") @@ -985,6 +1176,52 @@ def _to_list(x): return [x] +def _normalize_result_type( + result_type: Any, default_name: str = "output" +) -> Tuple[List[pa.Field], bool]: + """ + Normalize the declared result_type into a list of Arrow fields. + + Returns: + (fields, is_table) where ``is_table`` indicates whether the definition + represents a table-valued output (multiple columns or explicit schema). + """ + if isinstance(result_type, pa.Schema): + fields = list(result_type) + if not fields: + raise ValueError("Result schema must contain at least one column") + _ensure_unique_names([field.name for field in fields]) + return fields, True + + if isinstance(result_type, pa.Field): + field = result_type if result_type.name else result_type.with_name(default_name) + return [field], True + + if isinstance(result_type, (list, tuple)): + fields: List[pa.Field] = [] + for idx, item in enumerate(result_type): + if isinstance(item, pa.Field): + field = item if item.name else item.with_name(f"col{idx}") + elif isinstance(item, tuple) and len(item) == 2: + name, type_def = item + field = _to_arrow_field(type_def).with_name(str(name)) + else: + field = _to_arrow_field(item).with_name(f"col{idx}") + fields.append(field) + if not fields: + raise ValueError("Table function result_type must contain at least one column") + _ensure_unique_names([field.name for field in fields]) + return fields, True + + field = _to_arrow_field(result_type).with_name(default_name) + return [field], False + + +def _ensure_unique_names(names: List[str]) -> None: + if len(names) != len(set(names)): + raise ValueError("Result columns must have unique names") + + def _ensure_str(x): if isinstance(x, bytes): return x.decode("utf-8") diff --git a/python/tests/servers/basic_server.py b/python/tests/servers/basic_server.py index e6824bc..01992b7 100644 --- a/python/tests/servers/basic_server.py +++ b/python/tests/servers/basic_server.py @@ -4,6 +4,14 @@ """ import logging +import os +import sys + +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if PROJECT_ROOT not in sys.path: + sys.path.insert(0, PROJECT_ROOT) + +import pyarrow as pa from databend_udf import udf, UDFServer logging.basicConfig(level=logging.INFO) @@ -21,11 +29,34 @@ def gcd(x, y): return x +@udf( + input_types=["INT"], + result_type=[("num", "INT"), ("double_num", "INT")], + batch_mode=True, +) +def expand_numbers(nums): + doubled = [value * 2 for value in nums] + schema = pa.schema( + [ + pa.field("num", pa.int32(), nullable=False), + pa.field("double_num", pa.int32(), nullable=False), + ] + ) + return pa.RecordBatch.from_arrays( + [ + pa.array(nums, type=pa.int32()), + pa.array(doubled, type=pa.int32()), + ], + schema=schema, + ) + + def create_basic_server(port=8815): """Create server with basic arithmetic functions.""" server = UDFServer(f"0.0.0.0:{port}") server.add_function(add_two_ints) server.add_function(gcd) + 
server.add_function(expand_numbers) return server diff --git a/python/tests/test_simple_udf.py b/python/tests/test_simple_udf.py index 32840a3..4b6d05a 100644 --- a/python/tests/test_simple_udf.py +++ b/python/tests/test_simple_udf.py @@ -86,3 +86,20 @@ def test_batch_array_length_validation(basic_server): assert False, "Should have raised ValueError for inconsistent array lengths" except ValueError as e: assert "All batch arrays must have the same length" in str(e) + + +def test_table_function_returns_records(basic_server): + """Test table-valued function returning multiple columns.""" + client = basic_server.get_client() + + batch_result = client.call_function_batch( + "expand_numbers", nums=[1, 2, 3] + ) + assert batch_result == [ + {"num": 1, "double_num": 2}, + {"num": 2, "double_num": 4}, + {"num": 3, "double_num": 6}, + ] + + single_result = client.call_function("expand_numbers", 5) + assert single_result == [{"num": 5, "double_num": 10}] From dc09c1e4d03d557ddc40da2715dd8a690b5ac539 Mon Sep 17 00:00:00 2001 From: kould Date: Wed, 5 Nov 2025 01:17:35 +0800 Subject: [PATCH 2/2] chore: codefmt --- python/tests/servers/basic_server.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/tests/servers/basic_server.py b/python/tests/servers/basic_server.py index 01992b7..817c9c9 100644 --- a/python/tests/servers/basic_server.py +++ b/python/tests/servers/basic_server.py @@ -7,12 +7,15 @@ import os import sys -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) -if PROJECT_ROOT not in sys.path: - sys.path.insert(0, PROJECT_ROOT) - import pyarrow as pa -from databend_udf import udf, UDFServer + +try: + from databend_udf import udf, UDFServer +except ModuleNotFoundError: # pragma: no cover - fallback for local execution + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + if project_root not in sys.path: + sys.path.insert(0, project_root) + from databend_udf import udf, UDFServer logging.basicConfig(level=logging.INFO) @@ -61,8 +64,6 @@ def create_basic_server(port=8815): if __name__ == "__main__": - import sys - port = int(sys.argv[1]) if len(sys.argv) > 1 else 8815 server = create_basic_server(port) server.serve()
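Taken together, the server and client changes let a table-valued function round-trip as rows of dicts. A short usage sketch against the `basic_server` defined in this patch; it assumes the server is already listening on port 8815, that `UDFClient` is importable from `databend_udf.client`, and that its constructor accepts a `host:port` endpoint string, none of which is shown in this diff:

```python
# Hypothetical client-side calls to the expand_numbers table function.
# Assumes `python tests/servers/basic_server.py 8815` is running and that
# UDFClient("<host:port>") is the constructor (not part of this diff).
from databend_udf.client import UDFClient

client = UDFClient("localhost:8815")

# Batch call: one list per input column (keyword form used by the new test).
rows = client.call_function_batch("expand_numbers", nums=[1, 2, 3])
print(rows)
# [{'num': 1, 'double_num': 2}, {'num': 2, 'double_num': 4}, {'num': 3, 'double_num': 6}]

# Scalar call: a single input row still yields a list of result rows.
print(client.call_function("expand_numbers", 5))
# [{'num': 5, 'double_num': 10}]
```

Because the output schema now carries the `x-databend-udf-input-count` metadata key, the client can split input and output fields without assuming the last column is the only output; older servers without the metadata fall back to the previous single-output behaviour.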