diff --git a/.gitignore b/.gitignore index 2a5aa9a..a7f514f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ venv/ dist/ __pycache__/ Pipfile.lock +uv.lock .ruff_cache/ .vscode python/example/test.py diff --git a/python/README.md b/python/README.md index 772b9d0..7ed5cc6 100644 --- a/python/README.md +++ b/python/README.md @@ -1,5 +1,11 @@ ### Usage +#### Features + +- Arrow flight server (defaults to port 8815) +- HTTP server (disabled by default) +- Prometheus metrics server (disabled by default) + #### 1. Define your functions in a Python file ```python from databend_udf import * diff --git a/python/databend_udf/udf.py b/python/databend_udf/udf.py index ee13ac8..ecaac51 100644 --- a/python/databend_udf/udf.py +++ b/python/databend_udf/udf.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import time import logging import inspect from concurrent.futures import ThreadPoolExecutor @@ -21,6 +22,14 @@ from prometheus_client import start_http_server import threading +from fastapi import FastAPI, Request, Response +from fastapi.responses import StreamingResponse +from typing import Any, Dict +from uvicorn import run +import pyarrow as pa +from pyarrow import ipc +from io import BytesIO + import pyarrow as pa from pyarrow.flight import FlightServerBase, FlightInfo @@ -216,7 +225,6 @@ def gcd(x, y): batch_mode=batch_mode, ) - class UDFServer(FlightServerBase): """ A server that provides user-defined functions to clients. @@ -232,11 +240,19 @@ class UDFServer(FlightServerBase): _location: str _functions: Dict[str, UserDefinedFunction] - def __init__(self, location="0.0.0.0:8815", metric_location=None, **kwargs): + def __init__( + self, + location="0.0.0.0:8815", + metric_location=None, + http_location=None, + **kwargs, + ): super(UDFServer, self).__init__("grpc://" + location, **kwargs) self._location = location self._metric_location = metric_location + self._http_location = http_location self._functions = {} + self.app = FastAPI() # Initialize Prometheus metrics self.requests_count = Counter( @@ -296,9 +312,6 @@ def _start_metrics_server(self): def start_server(): start_http_server(port, host) - logger.info( - f"Prometheus metrics server started on {self._metric_location}" - ) metrics_thread = threading.Thread(target=start_server, daemon=True) metrics_thread.start() @@ -306,6 +319,35 @@ def start_server(): logger.error(f"Failed to start metrics server: {e}") raise + def _start_httpudf_server(self): + """Start UDF HTTP server if http_location is provided""" + try: + host, port = self._http_location.split(":") + port = int(port) + app = self.app + + # Middleware to measure elapsed time + @app.middleware("http") + async def log_elapsed_time(request: Request, call_next): + start_time = time.time() + response = await call_next(request) + elapsed_time = time.time() - start_time + logger.info(f"{request.method} {request.url.path} - From {start_time}, Elapsed time: {elapsed_time:.4f} seconds") + return response + + @app.get("/") + async def root(): + return {"protocol" : "http", "description": "databend-udf-server"} + + def start_server(): + run(app, host=host, port=port) + + http_thread = threading.Thread(target=start_server, daemon=True) + http_thread.start() + except Exception as e: + logger.error(f"Failed to start http udf server: {e}") + raise + def get_flight_info(self, context, descriptor): """Return the result schema of a function.""" func_name = descriptor.path[0].decode("utf-8") @@ -377,6 +419,39 @@ def add_function(self, udf: UserDefinedFunction): f"RETURNS {output_type} LANGUAGE python " f"HANDLER = '{name}' ADDRESS = 'http://{self._location}';" ) + + ## http router register + @self.app.get("/" + name) + async def flight_info(): + # Return the flight info schema of the function + full_schema = pa.schema(list(udf._input_schema) + list(udf._result_schema)) + # Serialize the schema using PyArrow's IPC + buf = BytesIO() + writer = pa.ipc.new_stream(buf, full_schema) + writer.close() + # Return the serialized schema as a streaming response + return Response(content=buf.getvalue(), media_type="application/octet-stream") + + @self.app.post("/" + name) + async def handle(request: Request): + # Deserialize the RecordBatch from the input data + body = await request.body() + reader = pa.ipc.open_stream(BytesIO(body)) + batch = next(reader) + + # Create a generator that applies the UDF to the data and yields the results + async def generate_batches(): + # Get the first batch from the reader + result_batch = udf.eval_batch(batch) + buf = BytesIO() + writer = pa.ipc.new_stream(buf, udf._result_schema) + for b in result_batch: + writer.write_batch(b) + writer.close() + yield buf.getvalue() + + return StreamingResponse(generate_batches(), media_type="application/octet-stream") + logger.info(f"added function: {name}, SQL:\n{sql}\n") def serve(self): @@ -387,7 +462,9 @@ def serve(self): logger.info( f"Prometheus metrics available at http://{self._metric_location}/metrics" ) - + if self._http_location: + self._start_httpudf_server() + logger.info(f"UDF HTTP SERVER available at http://{self._http_location}") super(UDFServer, self).serve() diff --git a/python/example/http_client.py b/python/example/http_client.py new file mode 100644 index 0000000..fcbc560 --- /dev/null +++ b/python/example/http_client.py @@ -0,0 +1,46 @@ +import requests +import pyarrow as pa +from pyarrow import ipc +from io import BytesIO + + +def main(): + # Create a RecordBatch + data = [ + pa.array([1, 2, 3, 4]), + pa.array([5, 6, 7, 8]), + ] + schema = pa.schema( + [ + ("a", pa.int32()), + ("b", pa.int32()), + ] + ) + batch = pa.RecordBatch.from_arrays(data, schema=schema) + + + # fetch result schema + response = requests.get("http://localhost:8818/gcd") + reader = pa.ipc.open_stream(BytesIO(response.content)) + schema = reader.schema + print("schema \n\n", schema) + + # Serialize the RecordBatch + buf = BytesIO() + writer = pa.ipc.new_stream(buf, batch.schema) + writer.write_batch(batch) + writer.close() + serialized_batch = buf.getvalue() + + # Send the serialized RecordBatch to the server + response = requests.post("http://localhost:8818/gcd", data=serialized_batch) + # Deserialize the response + reader = pa.ipc.open_stream(BytesIO(response.content)) + result_batches = [b for b in reader] + # Print the result + for batch in result_batches: + print("res \n", batch) + + +if __name__ == "__main__": + main() diff --git a/python/example/server.py b/python/example/server.py index 0820241..8f141d1 100644 --- a/python/example/server.py +++ b/python/example/server.py @@ -18,8 +18,12 @@ import time from typing import List, Dict, Any, Tuple, Optional +import sys, os + +cwd = os.getcwd() +sys.path.append(cwd) + from databend_udf import udf, UDFServer -# from test import udf, UDFServer logging.basicConfig(level=logging.INFO) @@ -48,6 +52,7 @@ def bool_select(condition, a, b): name="gcd", input_types=["INT", "INT"], result_type="INT", + io_threads=1, skip_null=True, ) def gcd(x: int, y: int) -> int: @@ -300,7 +305,6 @@ def return_all_non_nullable( json, ) - @udf(input_types=["INT"], result_type="INT") def wait(x): time.sleep(0.1) @@ -312,9 +316,15 @@ def wait_concurrent(x): time.sleep(0.1) return x +@udf(input_types=["INT"], result_type="INT", batch_mode=True) +def wait_batch(x: List[int]) -> List[int]: + time.sleep(0.1) + return x if __name__ == "__main__": - udf_server = UDFServer("0.0.0.0:8815", metric_location="0.0.0.0:8816") + udf_server = UDFServer( + "0.0.0.0:8815", metric_location="0.0.0.0:8816", http_location="0.0.0.0:8818" + ) udf_server.add_function(add_signed) udf_server.add_function(add_unsigned) udf_server.add_function(add_float) @@ -338,4 +348,5 @@ def wait_concurrent(x): udf_server.add_function(return_all_non_nullable) udf_server.add_function(wait) udf_server.add_function(wait_concurrent) + udf_server.add_function(wait_batch) udf_server.serve() diff --git a/python/pyproject.toml b/python/pyproject.toml index 7449e9a..d6160d9 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -12,8 +12,14 @@ readme = "README.md" requires-python = ">=3.7" dependencies = [ "pyarrow", - "prometheus-client>=0.17.0" + "prometheus-client>=0.17.0", + "fastapi>=0.103.2", + "uvicorn>=0.22.0", ] + +[dev-dependencies] +requests = ">=2.31.0" + [project.optional-dependencies] lint = ["ruff"]