diff --git a/cognite_toolkit/_cdf_tk/apps/__init__.py b/cognite_toolkit/_cdf_tk/apps/__init__.py index 04a6711886..7867158469 100644 --- a/cognite_toolkit/_cdf_tk/apps/__init__.py +++ b/cognite_toolkit/_cdf_tk/apps/__init__.py @@ -2,6 +2,7 @@ from ._core_app import CoreApp from ._data_app import DataApp from ._dev_app import DevApp +from ._dev_function_app import DevFunctionApp from ._download_app import DownloadApp from ._dump_app import DumpApp from ._import_app import ImportApp @@ -20,6 +21,7 @@ "CoreApp", "DataApp", "DevApp", + "DevFunctionApp", "DownloadApp", "DumpApp", "ImportApp", diff --git a/cognite_toolkit/_cdf_tk/apps/_dev_app.py b/cognite_toolkit/_cdf_tk/apps/_dev_app.py index b952465f9e..b0d04d2e2f 100644 --- a/cognite_toolkit/_cdf_tk/apps/_dev_app.py +++ b/cognite_toolkit/_cdf_tk/apps/_dev_app.py @@ -11,6 +11,7 @@ from cognite_toolkit._cdf_tk.feature_flags import FeatureFlag, Flags from cognite_toolkit._cdf_tk.utils.auth import EnvironmentVariables +from ._dev_function_app import DevFunctionApp from ._run import RunApp CDF_TOML = CDFToml.load(Path.cwd()) @@ -21,6 +22,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.callback(invoke_without_command=True)(self.main) self.add_typer(RunApp(*args, **kwargs), name="run") + self.add_typer(DevFunctionApp(*args, **kwargs), name="function") if FeatureFlag.is_enabled(Flags.CREATE): self.command("create")(self.create) diff --git a/cognite_toolkit/_cdf_tk/apps/_dev_function_app.py b/cognite_toolkit/_cdf_tk/apps/_dev_function_app.py new file mode 100644 index 0000000000..6efddb7fcb --- /dev/null +++ b/cognite_toolkit/_cdf_tk/apps/_dev_function_app.py @@ -0,0 +1,45 @@ +from pathlib import Path +from typing import Annotated, Any + +import typer +from rich import print + +from cognite_toolkit._cdf_tk.commands import ServeFunctionCommand + + +class DevFunctionApp(typer.Typer): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.callback(invoke_without_command=True)(self.main) + self.command("serve")(self.serve) + + @staticmethod + def main(ctx: typer.Context) -> None: + """Commands for function app development.""" + if ctx.invoked_subcommand is None: + print("Use [bold yellow]cdf dev function --help[/] for more information.") + + @staticmethod + def serve( + path: Annotated[ + Path | None, + typer.Argument(help="Path to the directory containing handler.py. If omitted, discovers and prompts."), + ] = None, + host: Annotated[str, typer.Option("--host", help="Host to bind to")] = "127.0.0.1", + port: Annotated[int, typer.Option("--port", help="Port to bind to")] = 8000, + reload: Annotated[bool, typer.Option("--reload/--no-reload", help="Enable auto-reload on code changes")] = True, + log_level: Annotated[ + str, + typer.Option("--log-level", help="Log level for the server"), + ] = "info", + ) -> None: + """Start a local development server for testing a function app handler. + + Loads your handler.py and starts a uvicorn dev server with Swagger UI. + + Example: + cdf dev function serve my_function/ + cdf dev function serve . --port 3000 --no-reload + """ + cmd = ServeFunctionCommand(client=None, skip_tracking=True) + cmd.run(lambda: cmd.serve(path, host, port, reload, log_level)) diff --git a/cognite_toolkit/_cdf_tk/commands/__init__.py b/cognite_toolkit/_cdf_tk/commands/__init__.py index 8d598690f5..4b79a8feca 100644 --- a/cognite_toolkit/_cdf_tk/commands/__init__.py +++ b/cognite_toolkit/_cdf_tk/commands/__init__.py @@ -21,6 +21,7 @@ from .repo import RepoCommand from .resources import ResourcesCommand from .run import RunFunctionCommand, RunTransformationCommand, RunWorkflowCommand +from .serve import ServeFunctionCommand __all__ = [ "AboutCommand", @@ -48,5 +49,6 @@ "RunFunctionCommand", "RunTransformationCommand", "RunWorkflowCommand", + "ServeFunctionCommand", "UploadCommand", ] diff --git a/cognite_toolkit/_cdf_tk/commands/_landing_page.py b/cognite_toolkit/_cdf_tk/commands/_landing_page.py new file mode 100644 index 0000000000..dd2089236d --- /dev/null +++ b/cognite_toolkit/_cdf_tk/commands/_landing_page.py @@ -0,0 +1,383 @@ +"""Landing page ASGI middleware for the function dev server. + +Provides: +- GET / → HTML landing page with function info, CDF warning, and live logs +- GET /api/logs → SSE endpoint streaming log lines +- GET /api/status → JSON with function info and last-reload timestamp +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +from collections import deque +from collections.abc import Awaitable, Callable +from datetime import datetime, timezone +from typing import Any + +# ASGI type aliases (same convention as cognite_function_apps.devserver.asgi) +Scope = dict[str, Any] +Receive = Callable[[], Awaitable[dict[str, Any]]] +Send = Callable[[dict[str, Any]], Awaitable[None]] +ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]] + + +class _LogCollector(logging.Handler): + """Logging handler that appends formatted records to a deque.""" + + def __init__(self, buffer: deque[dict[str, Any]], *, maxlen: int = 500) -> None: + super().__init__() + self.buffer = buffer + self._seq = 0 + + def emit(self, record: logging.LogRecord) -> None: + try: + self._seq += 1 + self.buffer.append( + { + "seq": self._seq, + "ts": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(), + "level": record.levelname, + "name": record.name, + "message": self.format(record), + } + ) + except Exception: + self.handleError(record) + + +class LandingPageMiddleware: + """ASGI middleware that adds a landing page, status API, and log streaming.""" + + def __init__( + self, + app: ASGIApp, + *, + handler_name: str, + handler_path: str, + cdf_project: str, + cdf_cluster: str, + tracing_enabled: bool = False, + tracing_endpoint: str = "", + ) -> None: + self.app = app + self.handler_name = handler_name + self.handler_path = handler_path + self.cdf_project = cdf_project + self.cdf_cluster = cdf_cluster + self.tracing_enabled = tracing_enabled + self.tracing_endpoint = tracing_endpoint + self.last_reload = datetime.now(tz=timezone.utc) + self._start_time = time.monotonic() + + # Log collection — attach to root and key loggers that may have propagate=False + self._log_buffer: deque[dict[str, Any]] = deque(maxlen=500) + self._log_handler = _LogCollector(self._log_buffer) + self._log_handler.setFormatter(logging.Formatter("%(name)s: %(message)s")) + logging.getLogger().addHandler(self._log_handler) + for name in ("uvicorn", "uvicorn.access", "uvicorn.error", "cognite_function_apps"): + logging.getLogger(name).addHandler(self._log_handler) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + path = scope.get("path", "") + method = scope.get("method", "GET") + + if method == "GET" and path == "/": + await self._serve_landing_page(send) + elif method == "GET" and path == "/api/logs": + await self._serve_sse_logs(send) + elif method == "GET" and path == "/api/status": + await self._serve_status(send) + else: + await self.app(scope, receive, send) + + async def _send_response(self, send: Send, *, status: int, content_type: str, body: bytes) -> None: + await send( + { + "type": "http.response.start", + "status": status, + "headers": [ + (b"content-type", content_type.encode()), + (b"content-length", str(len(body)).encode()), + ], + "trailers": False, + } + ) + await send({"type": "http.response.body", "body": body, "more_body": False}) + + async def _serve_status(self, send: Send) -> None: + uptime_s = round(time.monotonic() - self._start_time, 1) + payload = { + "handler_name": self.handler_name, + "handler_path": self.handler_path, + "cdf_project": self.cdf_project, + "cdf_cluster": self.cdf_cluster, + "last_reload": self.last_reload.isoformat(), + "uptime_seconds": uptime_s, + } + body = json.dumps(payload).encode() + await self._send_response(send, status=200, content_type="application/json", body=body) + + async def _serve_sse_logs(self, send: Send) -> None: + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"text/event-stream"), + (b"cache-control", b"no-cache"), + (b"connection", b"keep-alive"), + (b"x-accel-buffering", b"no"), + ], + "trailers": False, + } + ) + + # Send all buffered lines first + for entry in list(self._log_buffer): + chunk = f"data: {json.dumps(entry)}\n\n".encode() + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + + last_seq = self._log_buffer[-1]["seq"] if self._log_buffer else 0 + + # Tail new entries + try: + while True: + await asyncio.sleep(0.5) + new_entries = [e for e in self._log_buffer if e["seq"] > last_seq] + for entry in new_entries: + chunk = f"data: {json.dumps(entry)}\n\n".encode() + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + last_seq = entry["seq"] + except (asyncio.CancelledError, Exception): + # Client disconnected + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + async def _serve_landing_page(self, send: Send) -> None: + html = _build_landing_html( + handler_name=self.handler_name, + handler_path=self.handler_path, + cdf_project=self.cdf_project, + cdf_cluster=self.cdf_cluster, + last_reload=self.last_reload.strftime("%Y-%m-%d %H:%M:%S UTC"), + tracing_enabled=self.tracing_enabled, + tracing_endpoint=self.tracing_endpoint, + ) + await self._send_response(send, status=200, content_type="text/html; charset=utf-8", body=html.encode()) + + +def _build_landing_html( + *, + handler_name: str, + handler_path: str, + cdf_project: str, + cdf_cluster: str, + last_reload: str, + tracing_enabled: bool = False, + tracing_endpoint: str = "", +) -> str: + # Escape for safe HTML embedding + def esc(s: str) -> str: + return s.replace("&", "&").replace("<", "<").replace(">", ">").replace('"', """) + + if tracing_enabled: + if tracing_endpoint: + tracing_value = f'Configured — {esc(tracing_endpoint)}' + else: + tracing_value = 'Configured' + else: + tracing_value = 'Not configured' + + tracing_row = f"""
+ Tracing + {tracing_value} +
""" + + return f"""\ + + + + + +{esc(handler_name)} — CDF Dev Server + + + +
+ ⚠️ Connected to CDF project {esc(cdf_project)} + ({esc(cdf_cluster)}). + Handlers have full read/write access. +
+
+

{esc(handler_name)}

+

CDF Function Dev Server

+ +
+

Function

+
+ Handler + {esc(handler_path)} +
+
+ CDF Project + {esc(cdf_project)} +
+
+ Cluster + {esc(cdf_cluster)} +
+
+ Last Reload + {esc(last_reload)} +
+{tracing_row} +
+ Open API Docs +
+
+ +
+
+

Server Logs

+
+ + +
+
+
+
+
+ + + +""" diff --git a/cognite_toolkit/_cdf_tk/commands/serve.py b/cognite_toolkit/_cdf_tk/commands/serve.py new file mode 100644 index 0000000000..d8978c6110 --- /dev/null +++ b/cognite_toolkit/_cdf_tk/commands/serve.py @@ -0,0 +1,461 @@ +"""Serve command for running a local development server for Function Apps.""" + +import os +import re +import shutil +import sys +import tempfile +import threading +import webbrowser +from pathlib import Path +from typing import Any + +from rich import print + +from ._base import ToolkitCommand + + +class ServeFunctionCommand(ToolkitCommand): + def serve( + self, + path: Path | None, + host: str = "127.0.0.1", + port: int = 8000, + reload: bool = True, + log_level: str = "info", + ) -> None: + """Start a local development server for testing a function app handler.""" + try: + import uvicorn + except ImportError: + print( + "[bold red]Error:[/] Missing dependencies for serve command.\n" + "Install with: [bold]pip install cognite-toolkit\\[serve][/]\n" + "Or with uv: [bold]uv pip install cognite-toolkit\\[serve][/]" + ) + raise SystemExit(1) + + try: + from cognite_function_apps.cli import _load_handler_from_path # noqa: F401 + from cognite_function_apps.devserver import create_asgi_app # noqa: F401 + except ImportError: + print( + "[bold red]Error:[/] Missing [bold]cognite-function-apps[/] package.\n" + "Install with: [bold]pip install cognite-function-apps\\[cli][/]" + ) + raise SystemExit(1) + + # If no path given, discover and prompt + if path is None: + path = self._prompt_function_selection() + + handler_path = path.resolve() + self._validate_handler_directory(handler_path) + self._validate_handler_is_function_app(handler_path) + + # Authenticate via the toolkit's standard auth path + from cognite_toolkit._cdf_tk.utils.auth import EnvironmentVariables + + env_vars = EnvironmentVariables.create_from_environment() + cdf_project = env_vars.CDF_PROJECT + cdf_cluster = env_vars.CDF_CLUSTER + + # Check build type — block prod + validation_type = self._load_validation_type() + self._check_build_type(cdf_project, cdf_cluster, validation_type) + + url = f"http://{host}:{port}" + print(f"\n[bold green]Starting server at {url}[/]") + if reload: + print("[yellow]Auto-reload enabled — watching for changes...[/]") + print("[yellow]Press CTRL+C to quit[/]\n") + + # Open browser after a short delay so uvicorn has time to bind + threading.Timer(1.5, webbrowser.open, args=(url,)).start() + + handler_name = handler_path.name + + if reload: + self._run_server_with_reload( + uvicorn, handler_path, host, port, log_level, handler_name, cdf_project, cdf_cluster + ) + else: + self._run_server_without_reload( + uvicorn, handler_path, host, port, log_level, handler_name, cdf_project, cdf_cluster + ) + + # ── Function discovery ── + + @staticmethod + def _discover_function_dirs(organization_dir: Path | None = None) -> list[Path]: + """Discover function app directories by scanning for handler.py files + that import cognite_function_apps.""" + from cognite_toolkit._cdf_tk.utils.modules import iterate_modules + + root = organization_dir or Path.cwd() + function_dirs: list[Path] = [] + + for _module_dir, files in iterate_modules(root): + for f in files: + if f.name != "handler.py": + continue + # Check if parent is inside a functions/ folder + if f.parent.parent.name != "functions": + continue + # Quick check: is this a Function App handler? + try: + source = f.read_text() + except OSError: + continue + if "cognite_function_apps" in source or "create_function_service" in source: + function_dirs.append(f.parent) + + function_dirs.sort(key=lambda p: p.name) + return function_dirs + + @staticmethod + def _prompt_function_selection() -> Path: + """Discover function apps and prompt the user to pick one.""" + from cognite_toolkit._cdf_tk.cdf_toml import CDFToml + + toml = CDFToml.load(Path.cwd()) + org_dir = toml.cdf.default_organization_dir + + dirs = ServeFunctionCommand._discover_function_dirs(org_dir) + + if not dirs: + print( + "[bold red]Error:[/] No Function App handlers found.\n" + "Looked for handler.py files importing cognite_function_apps\n" + f"under [bold]{org_dir}[/]." + ) + raise SystemExit(1) + + if len(dirs) == 1: + print(f"[blue]Found one function app:[/] [bold]{dirs[0].name}[/] ({dirs[0]})") + return dirs[0] + + if not sys.stdin.isatty(): + print( + "[bold red]Error:[/] Multiple function apps found but stdin is not interactive.\n" + "Specify the function path explicitly: [bold]cdf dev function serve [/]" + ) + raise SystemExit(1) + + print("[bold]Available function apps:[/]\n") + for i, d in enumerate(dirs, 1): + # Show relative path from cwd for readability + try: + rel = d.relative_to(Path.cwd()) + except ValueError: + rel = d + print(f" [bold cyan]{i}[/] {d.name} [dim]({rel})[/]") + + print() + try: + answer = input(f"Select function [1-{len(dirs)}]: ").strip() + except (EOFError, KeyboardInterrupt): + print() + raise SystemExit(1) + + try: + idx = int(answer) - 1 + if not 0 <= idx < len(dirs): + raise ValueError + except ValueError: + print(f"[bold red]Error:[/] Invalid selection: {answer!r}") + raise SystemExit(1) + + return dirs[idx] + + # ── Config & validation ── + + @staticmethod + def _load_validation_type() -> str: + """Load the validation-type from the project's config YAML. + + Uses the same config file resolution as build/deploy: + reads config.{env}.yaml from the organization directory. + Falls back to CDF_BUILD_TYPE env var, then 'dev'. + """ + try: + from cognite_toolkit._cdf_tk.cdf_toml import CDFToml + from cognite_toolkit._cdf_tk.data_classes._config_yaml import BuildConfigYAML + + toml = CDFToml.load(Path.cwd()) + build_env = toml.cdf.default_env + org_dir = toml.cdf.default_organization_dir + config = BuildConfigYAML.load_from_directory(org_dir, build_env) + return config.environment.validation_type + except Exception: + # Fall back to env var + return os.environ.get("CDF_BUILD_TYPE", "").strip().lower() or "dev" + + @staticmethod + def _check_build_type(cdf_project: str, cdf_cluster: str, validation_type: str) -> None: + """Check validation type and prompt the user to acknowledge the risk.""" + validation_type = validation_type.strip().lower() + + if validation_type == "prod": + print( + "[bold red]Error:[/] The dev server cannot run against a production configuration.\n" + f" validation-type = [bold]prod[/]\n" + f" CDF_PROJECT = [bold]{cdf_project}[/]\n" + f" CDF_CLUSTER = [bold]{cdf_cluster}[/]\n\n" + "The dev server gives handlers [bold]full read/write access[/] to CDF.\n" + "Running against production risks accidental data mutation.\n\n" + "Use a [bold]dev[/] or [bold]staging[/] configuration instead." + ) + raise SystemExit(1) + + print( + f"[bold yellow]⚠ The dev server will authenticate to CDF with full read/write access.[/]\n" + f" validation-type = [bold]{validation_type}[/]\n" + f" CDF_PROJECT = [bold]{cdf_project}[/]\n" + f" CDF_CLUSTER = [bold]{cdf_cluster}[/]\n" + ) + + # Non-interactive environments (CI, piped stdin) skip the prompt + if not sys.stdin.isatty(): + return + + try: + answer = input("Continue? [y/N] ").strip().lower() + except (EOFError, KeyboardInterrupt): + print() + raise SystemExit(1) + + if answer not in ("y", "yes"): + raise SystemExit(0) + + @staticmethod + def _validate_handler_directory(handler_path: Path) -> None: + """Validate that the handler directory exists and is valid.""" + if not handler_path.exists(): + print(f"[bold red]Error:[/] Path does not exist: {handler_path}") + raise SystemExit(1) + + if not handler_path.is_dir(): + print(f"[bold red]Error:[/] Path is not a directory: {handler_path}") + raise SystemExit(1) + + dir_name = handler_path.name + if not dir_name.isidentifier(): + suggested_name = re.sub(r"\W|^(?=\d)", "_", dir_name) + print( + f"[bold red]Error:[/] Directory name '{dir_name}' is not a valid Python module name.\n" + f"[yellow]Suggested name:[/] [green]{suggested_name}[/]" + ) + raise SystemExit(1) + + if dir_name in sys.stdlib_module_names: + print( + f"[bold red]Error:[/] Directory name '{dir_name}' shadows a standard library module.\n" + "[yellow]Please rename the directory to avoid import conflicts.[/]" + ) + raise SystemExit(1) + + @staticmethod + def _validate_handler_is_function_app(handler_path: Path) -> None: + """Check that handler.py defines a FunctionService handle (not a classical function).""" + handler_file = handler_path / "handler.py" + if not handler_file.is_file(): + print(f"[bold red]Error:[/] handler.py not found in {handler_path}") + raise SystemExit(1) + + source = handler_file.read_text() + + # Quick heuristic: FunctionApp-based handlers import from cognite_function_apps + # and call create_function_service(). Classical handlers define `def handle(client, data)`. + has_function_apps_import = "cognite_function_apps" in source or "create_function_service" in source + if not has_function_apps_import: + print( + "[bold red]Error:[/] This handler appears to be a classical Cognite Function, " + "not a Function App.\n\n" + "The dev server only supports Function Apps that use [bold]cognite-function-apps[/].\n" + "Classical functions (with [bold]def handle(client, data)[/]) are not supported.\n\n" + "See the Function Apps documentation for how to migrate." + ) + raise SystemExit(1) + + @staticmethod + def _detect_tracing(handle: object) -> tuple[bool, str]: + """Detect if the loaded FunctionService uses tracing. + + Returns (tracing_enabled, backend_endpoint). + """ + try: + from cognite_function_apps.tracer import TracingApp + + # Walk the ASGI app chain looking for a TracingApp + app = getattr(handle, "asgi_app", None) + while app is not None: + if isinstance(app, TracingApp): + # Try to get the endpoint from the exporter provider closure + endpoint = "" + try: + from cognite_function_apps.tracer import OTLP_BACKENDS + + for _name, config in OTLP_BACKENDS.items(): + if config.endpoint and hasattr(app, "_exporter_provider"): + closure = getattr(app._exporter_provider, "__closure__", None) + if closure: + for cell in closure: + cell_val = cell.cell_contents + if hasattr(cell_val, "endpoint") and cell_val.endpoint == config.endpoint: + endpoint = config.endpoint + break + if endpoint: + break + except Exception: + pass + return True, endpoint + app = getattr(app, "next_app", None) + except ImportError: + pass + return False, "" + + @staticmethod + def _patch_cognite_client_factory() -> None: + """Monkey-patch cognite_function_apps to use the toolkit's auth path.""" + import importlib + + from cognite_toolkit._cdf_tk.utils.auth import EnvironmentVariables + + asgi_module = importlib.import_module("cognite_function_apps.devserver.asgi") + + def _toolkit_get_client() -> object: + env_vars = EnvironmentVariables.create_from_environment() + return env_vars.get_client(is_strict_validation=False) + + asgi_module.get_cognite_client_from_env = _toolkit_get_client # type: ignore[attr-defined] + + # ── Server startup ── + + @staticmethod + def _run_server_with_reload( + uvicorn: Any, + handler_path: Path, + host: str, + port: int, + log_level: str, + handler_name: str, + cdf_project: str, + cdf_cluster: str, + ) -> None: + """Run the development server with auto-reload enabled.""" + package_root = handler_path.parent + package_name = re.sub(r"\W|^(?=\d)", "_", handler_path.name) + + temp_dir = tempfile.mkdtemp(prefix="cdf_serve_") + temp_app_file = Path(temp_dir) / "_cdf_serve_asgi_app.py" + temp_app_file.write_text( + f'''"""Temporary ASGI app for cdf dev function serve with reload support.""" +import sys +from pathlib import Path +import importlib + +package_root = Path({str(package_root)!r}) +if str(package_root) not in sys.path: + sys.path.insert(0, str(package_root)) + +# Patch cognite_function_apps to use the toolkit's auth path +from cognite_toolkit._cdf_tk.commands.serve import ServeFunctionCommand +ServeFunctionCommand._patch_cognite_client_factory() + +from cognite_function_apps.devserver import create_asgi_app +from cognite_toolkit._cdf_tk.commands._landing_page import LandingPageMiddleware + +handler_module = importlib.import_module("{package_name}.handler") +_inner_app = create_asgi_app(handler_module.handle) + +# Detect tracing from the loaded handler +_tracing_enabled, _tracing_endpoint = ServeFunctionCommand._detect_tracing(handler_module.handle) + +app = LandingPageMiddleware( + _inner_app, + handler_name={handler_name!r}, + handler_path={str(handler_path / "handler.py")!r}, + cdf_project={cdf_project!r}, + cdf_cluster={cdf_cluster!r}, + tracing_enabled=_tracing_enabled, + tracing_endpoint=_tracing_endpoint, +) +''' + ) + + temp_dir_added = temp_dir not in sys.path + if temp_dir_added: + sys.path.insert(0, temp_dir) + + try: + uvicorn.run( + "_cdf_serve_asgi_app:app", + host=host, + port=port, + reload=True, + reload_dirs=[str(handler_path)], + log_level=log_level, + timeout_graceful_shutdown=1, + ) + finally: + if temp_dir_added and temp_dir in sys.path: + sys.path.remove(temp_dir) + shutil.rmtree(temp_dir, ignore_errors=True) + + @staticmethod + def _run_server_without_reload( + uvicorn: Any, + handler_path: Path, + host: str, + port: int, + log_level: str, + handler_name: str, + cdf_project: str, + cdf_cluster: str, + ) -> None: + """Run the development server without auto-reload.""" + from cognite_function_apps.cli import _load_handler_from_path + from cognite_function_apps.devserver import create_asgi_app + + from ._landing_page import LandingPageMiddleware + + package_root = str(handler_path.parent) + package_root_added = package_root not in sys.path + if package_root_added: + sys.path.insert(0, package_root) + + try: + # Patch before create_asgi_app calls get_cognite_client_from_env + ServeFunctionCommand._patch_cognite_client_factory() + + print(f"[blue]Loading handler from {handler_path}/handler.py...[/]") + handle = _load_handler_from_path(handler_path) + print("[green]Handler loaded successfully[/]") + + # Detect tracing + tracing_enabled, tracing_endpoint = ServeFunctionCommand._detect_tracing(handle) + + print("[blue]Creating ASGI app...[/]") + inner_app = create_asgi_app(handle) + asgi_app = LandingPageMiddleware( + inner_app, # type: ignore[arg-type] + handler_name=handler_name, + handler_path=str(handler_path / "handler.py"), + cdf_project=cdf_project, + cdf_cluster=cdf_cluster, + tracing_enabled=tracing_enabled, + tracing_endpoint=tracing_endpoint, + ) + print("[green]ASGI app created[/]") + + uvicorn.run( + asgi_app, + host=host, + port=port, + reload=False, + log_level=log_level, + ) + finally: + if package_root_added and package_root in sys.path: + sys.path.remove(package_root) diff --git a/pyproject.toml b/pyproject.toml index 593038959a..c4e9efc6f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,9 +82,11 @@ sql = [ "sqlparse >=0.5.3", ] v08 = [ - "cognite-neat >=1.0.43" ] +serve = [ + "cognite-function-apps[cli]", +] [project.scripts] cdf-tk = "cognite_toolkit._cdf:app" diff --git a/tests/test_unit/test_cdf_tk/test_commands/test_serve.py b/tests/test_unit/test_cdf_tk/test_commands/test_serve.py new file mode 100644 index 0000000000..9fee2d55d5 --- /dev/null +++ b/tests/test_unit/test_cdf_tk/test_commands/test_serve.py @@ -0,0 +1,332 @@ +"""Tests for the serve command and landing page middleware.""" + +from __future__ import annotations + +import asyncio +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from cognite_toolkit._cdf_tk.commands._landing_page import LandingPageMiddleware +from cognite_toolkit._cdf_tk.commands.serve import ServeFunctionCommand + +# ── ServeFunctionCommand._validate_handler_directory ── + + +class TestValidateHandlerDirectory: + def test_nonexistent_path(self, tmp_path: Path) -> None: + with pytest.raises(SystemExit): + ServeFunctionCommand._validate_handler_directory(tmp_path / "nope") + + def test_file_not_directory(self, tmp_path: Path) -> None: + f = tmp_path / "handler.py" + f.write_text("") + with pytest.raises(SystemExit): + ServeFunctionCommand._validate_handler_directory(f) + + def test_invalid_module_name(self, tmp_path: Path) -> None: + d = tmp_path / "not-valid-ident" + d.mkdir() + with pytest.raises(SystemExit): + ServeFunctionCommand._validate_handler_directory(d) + + def test_stdlib_shadow(self, tmp_path: Path) -> None: + d = tmp_path / "json" + d.mkdir() + with pytest.raises(SystemExit): + ServeFunctionCommand._validate_handler_directory(d) + + def test_valid_directory(self, tmp_path: Path) -> None: + d = tmp_path / "my_handler" + d.mkdir() + # Should not raise + ServeFunctionCommand._validate_handler_directory(d) + + +# ── ServeFunctionCommand._validate_handler_is_function_app ── + + +class TestValidateHandlerIsFunctionApp: + def test_rejects_classical_handler(self, tmp_path: Path) -> None: + d = tmp_path / "my_func" + d.mkdir() + (d / "handler.py").write_text("def handle(client, data):\n return {}\n") + with pytest.raises(SystemExit): + ServeFunctionCommand._validate_handler_is_function_app(d) + + def test_accepts_function_app_handler(self, tmp_path: Path) -> None: + d = tmp_path / "my_func" + d.mkdir() + (d / "handler.py").write_text( + "from cognite_function_apps import FunctionApp, create_function_service\n" + "app = FunctionApp('test', '1.0')\n" + "handle = create_function_service(app)\n" + ) + # Should not raise + ServeFunctionCommand._validate_handler_is_function_app(d) + + def test_missing_handler_file(self, tmp_path: Path) -> None: + d = tmp_path / "my_func" + d.mkdir() + with pytest.raises(SystemExit): + ServeFunctionCommand._validate_handler_is_function_app(d) + + +# ── ServeFunctionCommand._check_build_type ── + + +class TestCheckBuildType: + def test_blocks_prod(self) -> None: + with pytest.raises(SystemExit): + ServeFunctionCommand._check_build_type("my-project", "westeurope-1", "prod") + + def test_blocks_prod_case_insensitive(self) -> None: + with pytest.raises(SystemExit): + ServeFunctionCommand._check_build_type("my-project", "westeurope-1", "Prod") + + def test_prompts_for_dev(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("builtins.input", lambda _: "y") + with patch("sys.stdin") as mock_stdin: + mock_stdin.isatty.return_value = True + # Should not raise when user says "y" + ServeFunctionCommand._check_build_type("my-project", "westeurope-1", "dev") + + def test_aborts_on_no(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("builtins.input", lambda _: "n") + with patch("sys.stdin") as mock_stdin: + mock_stdin.isatty.return_value = True + with pytest.raises(SystemExit): + ServeFunctionCommand._check_build_type("my-project", "westeurope-1", "dev") + + def test_skips_prompt_in_non_tty(self) -> None: + with patch("sys.stdin") as mock_stdin: + mock_stdin.isatty.return_value = False + # Should not raise or prompt + ServeFunctionCommand._check_build_type("my-project", "westeurope-1", "dev") + + +# ── ServeFunctionCommand._patch_cognite_client_factory ── + + +class TestPatchCogniteClientFactory: + def test_patches_get_cognite_client_from_env(self) -> None: + import types + + mock_client = MagicMock() + mock_env_vars = MagicMock() + mock_env_vars.get_client.return_value = mock_client + + # Create a real module object so importlib.import_module can find it + fake_asgi = types.ModuleType("cognite_function_apps.devserver.asgi") + fake_asgi.get_cognite_client_from_env = lambda: None # type: ignore[attr-defined] + + # Also need fake parent packages for the import chain + import sys + + saved_keys = ["cognite_function_apps.devserver.asgi", "cognite_function_apps.devserver"] + saved = {k: sys.modules.get(k) for k in saved_keys} + + fake_devserver = types.ModuleType("cognite_function_apps.devserver") + fake_devserver.asgi = fake_asgi # type: ignore[attr-defined] + sys.modules["cognite_function_apps.devserver"] = fake_devserver + sys.modules["cognite_function_apps.devserver.asgi"] = fake_asgi + try: + with patch( + "cognite_toolkit._cdf_tk.utils.auth.EnvironmentVariables.create_from_environment", + return_value=mock_env_vars, + ): + ServeFunctionCommand._patch_cognite_client_factory() + + # Call the patched factory — must be inside the patch context + # because it calls EnvironmentVariables.create_from_environment() + result = fake_asgi.get_cognite_client_from_env() # type: ignore[attr-defined] + + mock_env_vars.get_client.assert_called_with(is_strict_validation=False) + assert result is mock_client + finally: + for k in saved_keys: + if saved[k] is None: + sys.modules.pop(k, None) + else: + sys.modules[k] = saved[k] + + +# ── LandingPageMiddleware ── + + +def _run_async(coro): + """Helper to run async code in tests.""" + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + +class _ResponseCollector: + """Collects ASGI send() calls for inspection.""" + + def __init__(self) -> None: + self.messages: list[dict] = [] + + async def __call__(self, message: dict) -> None: + self.messages.append(message) + + @property + def status(self) -> int: + return self.messages[0]["status"] + + @property + def headers_dict(self) -> dict[str, str]: + return {k.decode(): v.decode() for k, v in self.messages[0].get("headers", [])} + + @property + def body(self) -> bytes: + return b"".join(m.get("body", b"") for m in self.messages if m["type"] == "http.response.body") + + @property + def body_text(self) -> str: + return self.body.decode() + + +def _make_scope(path: str = "/", method: str = "GET") -> dict: + return { + "type": "http", + "method": method, + "path": path, + "query_string": b"", + "headers": [], + } + + +async def _noop_receive(): + return {"type": "http.request", "body": b"", "more_body": False} + + +def _make_middleware(inner_app=None, **kwargs): + if inner_app is None: + + async def inner_app(scope, receive, send): + await send({"type": "http.response.start", "status": 200, "headers": [], "trailers": False}) + await send({"type": "http.response.body", "body": b"inner", "more_body": False}) + + defaults = dict( + handler_name="my_func", + handler_path="/path/to/my_func/handler.py", + cdf_project="test-project", + cdf_cluster="westeurope-1", + ) + defaults.update(kwargs) + + return LandingPageMiddleware(inner_app, **defaults) + + +class TestLandingPageMiddleware: + def test_landing_page_returns_html(self) -> None: + mw = _make_middleware() + collector = _ResponseCollector() + _run_async(mw(_make_scope("/"), _noop_receive, collector)) + + assert collector.status == 200 + assert "text/html" in collector.headers_dict["content-type"] + body = collector.body_text + assert "my_func" in body + assert "test-project" in body + assert "westeurope-1" in body + assert "/docs" in body + assert "read/write" in body.lower() or "read AND WRITE" in body or "read/write access" in body.lower() + + def test_landing_page_shows_tracing_not_configured(self) -> None: + mw = _make_middleware(tracing_enabled=False) + collector = _ResponseCollector() + _run_async(mw(_make_scope("/"), _noop_receive, collector)) + + body = collector.body_text + assert "Tracing" in body + assert "Not configured" in body + + def test_landing_page_shows_tracing_configured(self) -> None: + mw = _make_middleware(tracing_enabled=True, tracing_endpoint="https://api.eu1.honeycomb.io:443") + collector = _ResponseCollector() + _run_async(mw(_make_scope("/"), _noop_receive, collector)) + + body = collector.body_text + assert "Tracing" in body + assert "Configured" in body + assert "honeycomb" in body + + def test_status_endpoint_returns_json(self) -> None: + mw = _make_middleware() + collector = _ResponseCollector() + _run_async(mw(_make_scope("/api/status"), _noop_receive, collector)) + + assert collector.status == 200 + assert "application/json" in collector.headers_dict["content-type"] + data = json.loads(collector.body) + assert data["handler_name"] == "my_func" + assert data["handler_path"] == "/path/to/my_func/handler.py" + assert data["cdf_project"] == "test-project" + assert data["cdf_cluster"] == "westeurope-1" + assert "last_reload" in data + assert "uptime_seconds" in data + + def test_passthrough_to_inner_app(self) -> None: + mw = _make_middleware() + collector = _ResponseCollector() + _run_async(mw(_make_scope("/docs"), _noop_receive, collector)) + + assert collector.status == 200 + assert collector.body == b"inner" + + def test_passthrough_non_http(self) -> None: + called = [] + + async def inner_app(scope, receive, send): + called.append(scope["type"]) + + mw = _make_middleware(inner_app) + scope = {"type": "lifespan"} + _run_async(mw(scope, _noop_receive, lambda msg: asyncio.sleep(0))) + + assert called == ["lifespan"] + + def test_post_request_passes_through(self) -> None: + mw = _make_middleware() + collector = _ResponseCollector() + _run_async(mw(_make_scope("/", method="POST"), _noop_receive, collector)) + + # POST / should pass through to inner app, not serve landing page + assert collector.body == b"inner" + + def test_sse_logs_endpoint_streams(self) -> None: + """Test that /api/logs starts an SSE stream with correct headers.""" + import logging + + mw = _make_middleware() + + # Add a log entry so the buffer has content + logger = logging.getLogger("test.serve") + logger.setLevel(logging.DEBUG) + logger.info("test log message") + + collector = _ResponseCollector() + + async def run_sse(): + # The SSE endpoint loops forever; we cancel after getting the initial burst + task = asyncio.ensure_future(mw(_make_scope("/api/logs"), _noop_receive, collector)) + # Give it time to send buffered logs + await asyncio.sleep(0.2) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + _run_async(run_sse()) + + assert collector.messages[0]["status"] == 200 + headers = {k.decode(): v.decode() for k, v in collector.messages[0]["headers"]} + assert headers["content-type"] == "text/event-stream" + assert headers["cache-control"] == "no-cache"