From 7ca0eaebe9c4c8f4c4e63b577080cc5353cc8bf2 Mon Sep 17 00:00:00 2001
From: sundyli <543950155@qq.com>
Date: Tue, 25 Feb 2025 14:58:17 +0800
Subject: [PATCH 1/5] feat: enable http udf server

---
 .gitignore                 |  1 +
 python/databend_udf/udf.py | 63 +++++++++++++++++++++++++++++++++++---
 python/example/server.py   | 10 ++++--
 python/pyproject.toml      |  8 ++++-
 4 files changed, 74 insertions(+), 8 deletions(-)

diff --git a/.gitignore b/.gitignore
index 2a5aa9a..a7f514f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,7 @@ venv/
 dist/
 __pycache__/
 Pipfile.lock
+uv.lock
 .ruff_cache/
 .vscode
 python/example/test.py
diff --git a/python/databend_udf/udf.py b/python/databend_udf/udf.py
index ee13ac8..3fc6a3e 100644
--- a/python/databend_udf/udf.py
+++ b/python/databend_udf/udf.py
@@ -21,6 +21,13 @@
 from prometheus_client import start_http_server
 import threading
 
+from fastapi import FastAPI, Request, Response
+from typing import Any, Dict
+from uvicorn import run
+import pyarrow as pa
+from pyarrow import ipc
+from io import BytesIO
+
 import pyarrow as pa
 from pyarrow.flight import FlightServerBase, FlightInfo
 
@@ -232,11 +239,19 @@ class UDFServer(FlightServerBase):
     _location: str
     _functions: Dict[str, UserDefinedFunction]
 
-    def __init__(self, location="0.0.0.0:8815", metric_location=None, **kwargs):
+    def __init__(
+        self,
+        location="0.0.0.0:8815",
+        metric_location=None,
+        http_location=None,
+        **kwargs,
+    ):
         super(UDFServer, self).__init__("grpc://" + location, **kwargs)
         self._location = location
         self._metric_location = metric_location
+        self._http_location = http_location
         self._functions = {}
+        self.app = FastAPI()
 
         # Initialize Prometheus metrics
         self.requests_count = Counter(
@@ -296,9 +311,6 @@ def _start_metrics_server(self):
 
             def start_server():
                 start_http_server(port, host)
-                logger.info(
-                    f"Prometheus metrics server started on {self._metric_location}"
-                )
 
             metrics_thread = threading.Thread(target=start_server, daemon=True)
             metrics_thread.start()
@@ -306,6 +318,26 @@ def start_server():
             logger.error(f"Failed to start metrics server: {e}")
             raise
 
+    def _start_httpudf_server(self):
+        """Start UDF HTTP server if http_location is provided"""
+        try:
+            host, port = self._http_location.split(":")
+            port = int(port)
+            app = self.app
+
+            @app.get("/_is_http")
+            async def is_http():
+                return 1
+
+            def start_server():
+                run(app, host=host, port=port)
+
+            http_thread = threading.Thread(target=start_server, daemon=True)
+            http_thread.start()
+        except Exception as e:
+            logger.error(f"Failed to start http udf server: {e}")
+            raise
+
     def get_flight_info(self, context, descriptor):
         """Return the result schema of a function."""
         func_name = descriptor.path[0].decode("utf-8")
@@ -377,6 +409,25 @@ def add_function(self, udf: UserDefinedFunction):
             f"RETURNS {output_type} LANGUAGE python "
             f"HANDLER = '{name}' ADDRESS = 'http://{self._location}';"
         )
+
+        ## http router register
+        @self.app.post("/" + name)
+        async def handle(request: Request):
+            # Deserialize the RecordBatch from the input data
+            body = await request.body()
+            reader = pa.ipc.open_stream(BytesIO(body))
+            batches = [b for b in reader]
+            # Apply the UDF to the data
+            result_batches = [udf.eval_batch(batch) for batch in batches]
+            # Serialize the result to send it back
+            buf = BytesIO()
+            writer = pa.ipc.new_stream(buf, udf._result_schema)
+            for batch in result_batches:
+                for b in batch:
+                    writer.write_batch(b)
+            writer.close()
+            return Response(content=buf.getvalue(), media_type="text/plain")
+
         logger.info(f"added function: {name}, SQL:\n{sql}\n")
 
     def serve(self):
@@ -387,7 +438,9 @@ def serve(self):
             logger.info(
                 f"Prometheus metrics available at http://{self._metric_location}/metrics"
             )
-
+        if self._http_location:
+            self._start_httpudf_server()
+            logger.info(f"UDF HTTP SERVER available at http://{self._http_location}")
         super(UDFServer, self).serve()
 
 
diff --git a/python/example/server.py b/python/example/server.py
index 0820241..ad7cf2d 100644
--- a/python/example/server.py
+++ b/python/example/server.py
@@ -18,8 +18,12 @@
 import time
 from typing import List, Dict, Any, Tuple, Optional
 
+import sys, os
+
+cwd = os.getcwd()
+sys.path.append(cwd)
+
 from databend_udf import udf, UDFServer
-# from test import udf, UDFServer
 
 logging.basicConfig(level=logging.INFO)
 
@@ -314,7 +318,9 @@ def wait_concurrent(x):
 
 
 if __name__ == "__main__":
-    udf_server = UDFServer("0.0.0.0:8815", metric_location="0.0.0.0:8816")
+    udf_server = UDFServer(
+        "0.0.0.0:8815", metric_location="0.0.0.0:8816", http_location="0.0.0.0:8818"
+    )
     udf_server.add_function(add_signed)
     udf_server.add_function(add_unsigned)
     udf_server.add_function(add_float)
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 7449e9a..d3647c2 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -12,8 +12,14 @@ readme = "README.md"
 requires-python = ">=3.7"
 dependencies = [
     "pyarrow",
-    "prometheus-client>=0.17.0"
+    "prometheus-client>=0.17.0",
+    "fastapi>=0.103.2",
+    "uvicorn>=0.22.0",
 ]
+
+[project.dev-dependencies]
+requests = ">=2.31.0"
+
 [project.optional-dependencies]
 lint = ["ruff"]
 

From a463d9e7ce6edf067ee08c1a25aa4a51f27971cb Mon Sep 17 00:00:00 2001
From: sundyli <543950155@qq.com>
Date: Tue, 25 Feb 2025 14:58:22 +0800
Subject: [PATCH 2/5] feat: enable http udf server

---
 python/example/http_client.py | 41 +++++++++++++++++++++++++++++++++++
 1 file changed, 41 insertions(+)
 create mode 100644 python/example/http_client.py

diff --git a/python/example/http_client.py b/python/example/http_client.py
new file mode 100644
index 0000000..5d4eda2
--- /dev/null
+++ b/python/example/http_client.py
@@ -0,0 +1,41 @@
+import requests
+import pyarrow as pa
+from pyarrow import ipc
+from io import BytesIO
+
+
+def main():
+    # Create a RecordBatch
+    data = [
+        pa.array([1, 2, 3, 4]),
+        pa.array([5, 6, 7, 8]),
+    ]
+    schema = pa.schema(
+        [
+            ("a", pa.int32()),
+            ("b", pa.int32()),
+        ]
+    )
+    batch = pa.RecordBatch.from_arrays(data, schema=schema)
+
+    # Serialize the RecordBatch
+    buf = BytesIO()
+    writer = pa.ipc.new_stream(buf, batch.schema)
+    writer.write_batch(batch)
+    writer.close()
+    serialized_batch = buf.getvalue()
+
+    # Send the serialized RecordBatch to the server
+    response = requests.post("http://localhost:8818/gcd", data=serialized_batch)
+
+    # Deserialize the response
+    reader = pa.ipc.open_stream(BytesIO(response.content))
+    result_batches = [b for b in reader]
+
+    # Print the result
+    for batch in result_batches:
+        print("res \n", batch)
+
+
+if __name__ == "__main__":
+    main()

From a625d7845c5fdbe74afb0ce801a17e22aad7573d Mon Sep 17 00:00:00 2001
From: sundyli <543950155@qq.com>
Date: Tue, 25 Feb 2025 17:53:40 +0800
Subject: [PATCH 3/5] feat: enable http udf server

---
 python/pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyproject.toml b/python/pyproject.toml
index d3647c2..d6160d9 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -17,7 +17,7 @@ dependencies = [
     "uvicorn>=0.22.0",
 ]
 
-[project.dev-dependencies]
+[dev-dependencies]
 requests = ">=2.31.0"
 
 [project.optional-dependencies]
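Note on the wire format used above: both the new POST route in `udf.py` and `python/example/http_client.py` frame request and response bodies as Arrow IPC streams. As a quick, server-independent sanity check of that framing, a sketch like the following (all names are local to the snippet and not part of the patch) round-trips a RecordBatch through the same pyarrow calls the patch relies on:

```python
import pyarrow as pa
from io import BytesIO

# Build a small batch shaped like the example client's input ("a", "b" int32 columns).
batch = pa.RecordBatch.from_arrays(
    [pa.array([4, 6], type=pa.int32()), pa.array([2, 9], type=pa.int32())],
    names=["a", "b"],
)

# Encode it the same way handle()/http_client.py do: an in-memory Arrow IPC stream.
buf = BytesIO()
writer = pa.ipc.new_stream(buf, batch.schema)
writer.write_batch(batch)
writer.close()

# Decode it back and confirm the round trip is lossless.
reader = pa.ipc.open_stream(BytesIO(buf.getvalue()))
decoded = [b for b in reader]
assert decoded[0].equals(batch)
print(decoded[0])
```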
From fd673b4a3dac8bc0311815a89a08584b10ef8624 Mon Sep 17 00:00:00 2001
From: sundyli <543950155@qq.com>
Date: Wed, 26 Feb 2025 14:08:33 +0800
Subject: [PATCH 4/5] feat(query): update

---
 python/databend_udf/udf.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/databend_udf/udf.py b/python/databend_udf/udf.py
index 3fc6a3e..e92c643 100644
--- a/python/databend_udf/udf.py
+++ b/python/databend_udf/udf.py
@@ -325,9 +325,9 @@ def _start_httpudf_server(self):
             port = int(port)
             app = self.app
 
-            @app.get("/_is_http")
-            async def is_http():
-                return 1
+            @app.get("/")
+            async def root():
+                return {"protocol" : "http", "description": "databend-udf-server"}
 
             def start_server():
                 run(app, host=host, port=port)
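With the root route above, a caller can check whether an address serves the HTTP flavor of the UDF protocol before sending any Arrow payloads. A minimal probe using `requests` might look like this (the port matches the example server's `http_location`; the check itself is an illustrative assumption, not something the patch ships):

```python
import requests

# Probe the root route added in this patch; the example server listens on 0.0.0.0:8818.
resp = requests.get("http://localhost:8818/", timeout=5)
info = resp.json()

# The handler returns {"protocol": "http", "description": "databend-udf-server"}.
if info.get("protocol") == "http":
    print("HTTP UDF endpoint detected:", info.get("description"))
else:
    print("Unexpected response:", info)
```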
From 48fada2372b19555c83e4d6e7236ad6d1980df74 Mon Sep 17 00:00:00 2001
From: sundyli <543950155@qq.com>
Date: Thu, 27 Feb 2025 09:16:46 +0800
Subject: [PATCH 5/5] feat(query): update

---
 python/README.md              |  6 +++++
 python/databend_udf/udf.py    | 46 ++++++++++++++++++++++++++---------
 python/example/http_client.py |  9 +++++--
 python/example/server.py      |  7 +++++-
 4 files changed, 54 insertions(+), 14 deletions(-)

diff --git a/python/README.md b/python/README.md
index 772b9d0..7ed5cc6 100644
--- a/python/README.md
+++ b/python/README.md
@@ -1,5 +1,11 @@
 ### Usage
 
+#### Features
+
+- Arrow flight server (defaults to port 8815)
+- HTTP server (disabled by default)
+- Prometheus metrics server (disabled by default)
+
 #### 1. Define your functions in a Python file
 ```python
 from databend_udf import *
diff --git a/python/databend_udf/udf.py b/python/databend_udf/udf.py
index e92c643..ecaac51 100644
--- a/python/databend_udf/udf.py
+++ b/python/databend_udf/udf.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import json
+import time
 import logging
 import inspect
 from concurrent.futures import ThreadPoolExecutor
@@ -22,6 +23,7 @@
 import threading
 
 from fastapi import FastAPI, Request, Response
+from fastapi.responses import StreamingResponse
 from typing import Any, Dict
 from uvicorn import run
 import pyarrow as pa
@@ -223,7 +225,6 @@ def gcd(x, y):
         batch_mode=batch_mode,
     )
 
-
 class UDFServer(FlightServerBase):
     """
     A server that provides user-defined functions to clients.
@@ -325,6 +326,15 @@ def _start_httpudf_server(self):
             port = int(port)
             app = self.app
 
+            # Middleware to measure elapsed time
+            @app.middleware("http")
+            async def log_elapsed_time(request: Request, call_next):
+                start_time = time.time()
+                response = await call_next(request)
+                elapsed_time = time.time() - start_time
+                logger.info(f"{request.method} {request.url.path} - From {start_time}, Elapsed time: {elapsed_time:.4f} seconds")
+                return response
+
             @app.get("/")
             async def root():
                 return {"protocol" : "http", "description": "databend-udf-server"}
@@ -411,22 +421,36 @@ def add_function(self, udf: UserDefinedFunction):
         )
 
         ## http router register
+        @self.app.get("/" + name)
+        async def flight_info():
+            # Return the flight info schema of the function
+            full_schema = pa.schema(list(udf._input_schema) + list(udf._result_schema))
+            # Serialize the schema using PyArrow's IPC
+            buf = BytesIO()
+            writer = pa.ipc.new_stream(buf, full_schema)
+            writer.close()
+            # Return the serialized schema as a streaming response
+            return Response(content=buf.getvalue(), media_type="application/octet-stream")
+
         @self.app.post("/" + name)
         async def handle(request: Request):
             # Deserialize the RecordBatch from the input data
             body = await request.body()
             reader = pa.ipc.open_stream(BytesIO(body))
-            batches = [b for b in reader]
-            # Apply the UDF to the data
-            result_batches = [udf.eval_batch(batch) for batch in batches]
-            # Serialize the result to send it back
-            buf = BytesIO()
-            writer = pa.ipc.new_stream(buf, udf._result_schema)
-            for batch in result_batches:
-                for b in batch:
+            batch = next(reader)
+
+            # Create a generator that applies the UDF to the data and yields the results
+            async def generate_batches():
+                # Get the first batch from the reader
+                result_batch = udf.eval_batch(batch)
+                buf = BytesIO()
+                writer = pa.ipc.new_stream(buf, udf._result_schema)
+                for b in result_batch:
                     writer.write_batch(b)
-            writer.close()
-            return Response(content=buf.getvalue(), media_type="text/plain")
+                writer.close()
+                yield buf.getvalue()
+
+            return StreamingResponse(generate_batches(), media_type="application/octet-stream")
 
         logger.info(f"added function: {name}, SQL:\n{sql}\n")
 
diff --git a/python/example/http_client.py b/python/example/http_client.py
index 5d4eda2..fcbc560 100644
--- a/python/example/http_client.py
+++ b/python/example/http_client.py
@@ -18,6 +18,13 @@ def main():
     )
     batch = pa.RecordBatch.from_arrays(data, schema=schema)
 
+
+    # fetch result schema
+    response = requests.get("http://localhost:8818/gcd")
+    reader = pa.ipc.open_stream(BytesIO(response.content))
+    schema = reader.schema
+    print("schema \n\n", schema)
+
     # Serialize the RecordBatch
     buf = BytesIO()
     writer = pa.ipc.new_stream(buf, batch.schema)
@@ -27,11 +34,9 @@ def main():
 
     # Send the serialized RecordBatch to the server
     response = requests.post("http://localhost:8818/gcd", data=serialized_batch)
 
-    # Deserialize the response
     reader = pa.ipc.open_stream(BytesIO(response.content))
     result_batches = [b for b in reader]
 
-    # Print the result
     for batch in result_batches:
         print("res \n", batch)
diff --git a/python/example/server.py b/python/example/server.py
index ad7cf2d..8f141d1 100644
--- a/python/example/server.py
+++ b/python/example/server.py
@@ -52,6 +52,7 @@ def bool_select(condition, a, b):
     name="gcd",
     input_types=["INT", "INT"],
     result_type="INT",
+    io_threads=1,
     skip_null=True,
 )
 def gcd(x: int, y: int) -> int:
@@ -304,7 +305,6 @@ def return_all_non_nullable(
         json,
     )
 
-
 @udf(input_types=["INT"], result_type="INT")
 def wait(x):
     time.sleep(0.1)
@@ -316,6 +316,10 @@ def wait_concurrent(x):
     time.sleep(0.1)
     return x
 
+@udf(input_types=["INT"], result_type="INT", batch_mode=True)
+def wait_batch(x: List[int]) -> List[int]:
+    time.sleep(0.1)
+    return x
 
 if __name__ == "__main__":
     udf_server = UDFServer(
@@ -344,4 +348,5 @@ def wait_concurrent(x):
     udf_server.add_function(return_all_non_nullable)
     udf_server.add_function(wait)
     udf_server.add_function(wait_concurrent)
+    udf_server.add_function(wait_batch)
     udf_server.serve()
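Taken together, the series exposes two HTTP routes per registered function: `GET /<name>` returns the combined input/result schema as an Arrow IPC stream, and `POST /<name>` evaluates one RecordBatch per request and streams the result back. A client-side sketch exercising both routes against the example `gcd` function (it assumes `python/example/server.py` from this series is running locally on port 8818; column names and test values are arbitrary):

```python
import requests
import pyarrow as pa
from io import BytesIO
from math import gcd

BASE = "http://localhost:8818"  # http_location used by python/example/server.py

# One input batch for the example gcd(INT, INT) UDF.
a, b = [12, 18, 7], [8, 27, 14]
batch = pa.RecordBatch.from_arrays(
    [pa.array(a, type=pa.int32()), pa.array(b, type=pa.int32())],
    names=["a", "b"],
)

# GET /gcd returns the serialized input+result schema.
schema = pa.ipc.open_stream(BytesIO(requests.get(f"{BASE}/gcd").content)).schema
print("server schema:", schema)

# POST the batch as an Arrow IPC stream and read the streamed reply.
sink = BytesIO()
writer = pa.ipc.new_stream(sink, batch.schema)
writer.write_batch(batch)
writer.close()
reply = requests.post(f"{BASE}/gcd", data=sink.getvalue())
result = pa.ipc.open_stream(BytesIO(reply.content)).read_all()

# The single result column should match Python's math.gcd.
assert result.column(0).to_pylist() == [gcd(x, y) for x, y in zip(a, b)]
print(result)
```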