diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b4dbcb1..54f7158 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,10 +1,6 @@ name: Tests -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] +on: push jobs: api-unit-tests: diff --git a/LLSP-API/api.py b/LLSP-API/api.py index a102afa..6b3b476 100644 --- a/LLSP-API/api.py +++ b/LLSP-API/api.py @@ -7,10 +7,12 @@ import logging import os -from typing import Any +import sys +from typing import Annotated, Any from celery import Celery # type: ignore -from fastapi import FastAPI +from fastapi import Depends, FastAPI, HTTPException +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from model import ExecIn, Task, TaskState from utils import map_state @@ -19,6 +21,8 @@ BACKEND = os.getenv("CELERY_RESULT_BACKEND", "rpc://") TASK_NAME = os.getenv("EXEC_TASK_NAME", "celery_app.exec_script") +logger = logging.getLogger(__name__) + class EndpointFilter(logging.Filter): """Filter out log messages containing /healthz or /ready.""" @@ -36,15 +40,29 @@ def filter(self, record: logging.LogRecord) -> bool: app = FastAPI(title="Exec API") +security = HTTPBearer() + + +def verify_api_key(credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)]) -> None: + """ + Validate the API key from the request credentials. + + :param credentials: The HTTP authorization credentials from the request. + :return: True if the API key matches the environment variable, False otherwise. + """ + if credentials.credentials != os.environ.get("LLSP_API_KEY"): + raise HTTPException(status_code=401, detail="Invalid API key") + @app.post("/execute") -def execute(payload: ExecIn) -> Task: +def execute(payload: ExecIn, credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)]) -> Task: """ Submit a script for execution. :param payload: The script submission payload. :return: A `Task` object containing the task ID and initial status. """ + verify_api_key(credentials) async_result = celery.send_task( TASK_NAME, args=[payload.script], @@ -53,13 +71,14 @@ def execute(payload: ExecIn) -> Task: @app.get("/status/{task_id}") -def status(task_id: str) -> Task: +def status(task_id: str, credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)]) -> Task: """ Check the status of a submitted task. :param task_id: The ID of the task to check. :return: A `Task` object with the current state and result (if ready). """ + verify_api_key(credentials) res = celery.AsyncResult(task_id) body: Any = None try: @@ -96,4 +115,8 @@ def ready() -> dict[str, str]: :return: A dict indicating the service is ready. """ + api_key = os.environ.get("LLSP_API_KEY", None) + if api_key is None: + logger.critical("The LLSP_API_KEY environment variable is not set.") + sys.exit(1) return {"status": "ready"} diff --git a/LLSP-Worker/Dockerfile b/LLSP-Worker/Dockerfile index 79a6dd4..51c03f1 100644 --- a/LLSP-Worker/Dockerfile +++ b/LLSP-Worker/Dockerfile @@ -1,3 +1,4 @@ + FROM ghcr.io/fiaisis/mantid:6.15.0 ENV PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1 WORKDIR /app diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 6907bdf..058d080 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -19,6 +19,7 @@ services: CELERY_BROKER_URL: amqp://llsp:llsp@rabbitmq:5672/llspvhost CELERY_RESULT_BACKEND: rpc:// EXEC_TASK_NAME: celery_app.exec_script + LLSP_API_KEY: "secret-token" depends_on: rabbitmq: condition: service_healthy @@ -50,6 +51,7 @@ services: dockerfile: tests/Dockerfile environment: API_BASE_URL: http://api:8000 + LLSP_API_KEY: "secret-token" depends_on: api: condition: service_healthy diff --git a/docker-compose.yml b/docker-compose.yml index 2a984d3..d7b5808 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -19,6 +19,7 @@ services: CELERY_BROKER_URL: amqp://llsp:llsp@rabbitmq:5672/llspvhost CELERY_RESULT_BACKEND: rpc:// EXEC_TASK_NAME: celery_app.exec_script + LLSP_API_KEY: "secret-token" depends_on: rabbitmq: condition: service_healthy diff --git a/tests/e2e/test_submission.py b/tests/e2e/test_submission.py index cf6f407..5065bf6 100644 --- a/tests/e2e/test_submission.py +++ b/tests/e2e/test_submission.py @@ -6,14 +6,15 @@ import os import time +from http import HTTPStatus import requests -from tenacity import retry, stop_after_attempt, wait_fixed API_URL = os.getenv("API_BASE_URL", "http://localhost:8000") +API_KEY = os.getenv("LLSP_API_KEY", "secret-token") +HEADERS = {"Authorization": f"Bearer {API_KEY}"} -@retry(stop=stop_after_attempt(5), wait=wait_fixed(2)) def wait_for_api(): """ Wait for the API to become responsive. @@ -45,7 +46,12 @@ def test_workflow(): print('Hello from stdout') print('Hello from stderr', file=sys.stderr) """ - response = requests.post(f"{API_URL}/execute", json={"script": script}, timeout=10) + response = requests.post( + f"{API_URL}/execute", + json={"script": script}, + headers=HEADERS, + timeout=10, + ) response.raise_for_status() data = response.json() assert "task_id" in data @@ -53,7 +59,11 @@ def test_workflow(): # 2. Poll for Status for _ in range(30): - response = requests.get(f"{API_URL}/status/{task_id}", timeout=10) + response = requests.get( + f"{API_URL}/status/{task_id}", + headers=HEADERS, + timeout=10, + ) response.raise_for_status() state = response.json() if state["state"] in ["success", "error"]: @@ -70,5 +80,23 @@ def test_workflow(): assert "Hello from stderr" in result["stderr"] -if __name__ == "__main__": - test_workflow() +def test_unauthorized_access(): + """Verify that requests without a valid API key are rejected.""" + wait_for_api() + + # Case 1: No Authorization header + response = requests.post( + f"{API_URL}/execute", + json={"script": "print('fail')"}, + timeout=10, + ) + assert response.status_code == HTTPStatus.FORBIDDEN, f"Expected 403, got {response.status_code}" + + # Case 2: Invalid API key + response = requests.post( + f"{API_URL}/execute", + json={"script": "print('fail')"}, + headers={"Authorization": "Bearer invalid-token"}, + timeout=10, + ) + assert response.status_code == HTTPStatus.UNAUTHORIZED, f"Expected 401, got {response.status_code}"