diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3ef9625..972ffd3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -139,7 +139,7 @@ from fxa.tools.bearer import get_bearer_token fxa_token: str = get_bearer_token( your_mozilla_account_email, your_mozilla_account_password, - scopes=["profile"], + scopes=["profile:uid"], client_id="5882386c6d801776" # a common client_id for the dev environment, account_server_url="https://api.accounts.firefox.com", oauth_server_url="https://oauth.accounts.firefox.com", @@ -148,7 +148,7 @@ fxa_token: str = get_bearer_token( ## How to update static docs/index.html from redoc -Make sure you have Node installed +Ensure Node is installed 1. `make install` 2. `mlpa` diff --git a/src/mlpa/core/config.py b/src/mlpa/core/config.py index a72f0c2..f154c0f 100644 --- a/src/mlpa/core/config.py +++ b/src/mlpa/core/config.py @@ -96,6 +96,9 @@ def valid_service_types(self) -> list[str]: # FxA CLIENT_ID: str = "default-client-id" CLIENT_SECRET: str = "default-client-secret" + ADDITIONAL_FXA_SCOPE_1: str | None = None + ADDITIONAL_FXA_SCOPE_2: str | None = None + ADDITIONAL_FXA_SCOPE_3: str | None = None # PostgreSQL LITELLM_DB_NAME: str = "litellm" diff --git a/src/mlpa/core/routers/fxa/fxa.py b/src/mlpa/core/routers/fxa/fxa.py index 11af14e..2608c71 100644 --- a/src/mlpa/core/routers/fxa/fxa.py +++ b/src/mlpa/core/routers/fxa/fxa.py @@ -1,9 +1,11 @@ +import asyncio import time from typing import Annotated from fastapi import APIRouter, Header, HTTPException from fastapi.concurrency import run_in_threadpool +from mlpa.core.config import env from mlpa.core.logger import logger from mlpa.core.prometheus_metrics import PrometheusResult, metrics from mlpa.core.utils import get_fxa_client @@ -11,19 +13,45 @@ router = APIRouter() client = get_fxa_client() +FXA_DEFAULT_SCOPE = "profile:uid" +FXA_SCOPES = tuple( + scope + for scope in ( + FXA_DEFAULT_SCOPE, + env.ADDITIONAL_FXA_SCOPE_1, + env.ADDITIONAL_FXA_SCOPE_2, + env.ADDITIONAL_FXA_SCOPE_3, + ) + if scope +) async def fxa_auth(authorization: Annotated[str | None, Header()]): start_time = time.perf_counter() token = authorization.removeprefix("Bearer ").split()[0] result = PrometheusResult.ERROR - + errors = [] try: - profile = await run_in_threadpool(client.verify_token, token, scope="profile") - result = PrometheusResult.SUCCESS - return profile - except Exception as e: - logger.error(f"FxA auth error: {e}") + tasks = [ + asyncio.create_task( + run_in_threadpool(client.verify_token, token, scope=scope) + ) + for scope in FXA_SCOPES + ] + try: + for task in asyncio.as_completed(tasks): + try: + profile = await task + result = PrometheusResult.SUCCESS + return profile + except Exception as e: + errors.append(e) + finally: + for task in tasks: + if not task.done(): + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + logger.error(f"FxA auth error: {errors}") raise HTTPException(status_code=401, detail="Invalid FxA auth") finally: metrics.validate_fxa_latency.labels(result=result).observe( diff --git a/src/tests/consts.py b/src/tests/consts.py index 661d81e..3d165e4 100644 --- a/src/tests/consts.py +++ b/src/tests/consts.py @@ -65,7 +65,7 @@ MOCK_FXA_USER_DATA = { "user": TEST_USER_ID, "client_id": "test-client-id", - "scope": ["profile"], + "scope": ["profile:uid"], "generation": 1, "profile_changed_at": 1234567890, } diff --git a/src/tests/integration/test_mock_router_integration.py b/src/tests/integration/test_mock_router_integration.py index 2b99b59..9732ef3 100644 --- a/src/tests/integration/test_mock_router_integration.py +++ b/src/tests/integration/test_mock_router_integration.py @@ -232,7 +232,7 @@ def test_mock_chat_completions_no_auth_missing_user_in_token( mock_fxa_client._verify_jwt_token.return_value = { "client_id": "test-client-id", - "scope": ["profile"], + "scope": ["profile:uid"], } response = mocked_client_integration.post( @@ -259,7 +259,7 @@ def test_mock_chat_completions_no_auth_blocked_user( mock_fxa_client._verify_jwt_token.return_value = { "user": "blocked-user-id", "client_id": "test-client-id", - "scope": ["profile"], + "scope": ["profile:uid"], } with patch( @@ -295,7 +295,7 @@ def test_mock_chat_completions_latency_simulation(self, mocked_client_integratio mock_fxa_client._verify_jwt_token.return_value = { "user": TEST_USER_ID, "client_id": "test-client-id", - "scope": ["profile"], + "scope": ["profile:uid"], } with patch.dict("os.environ", {"MOCK_LATENCY_MS": "100"}): diff --git a/src/tests/mocks.py b/src/tests/mocks.py index 9ab731c..1230b2f 100644 --- a/src/tests/mocks.py +++ b/src/tests/mocks.py @@ -195,7 +195,7 @@ def __init__(self, client_id: str, client_secret: str, fxa_url: str): self.client_secret = client_secret self.fxa_url = fxa_url - def verify_token(self, token: str, scope: str = "profile"): + def verify_token(self, token: str, scope: str = "profile:uid"): if token == TEST_FXA_TOKEN: return {"user": TEST_USER_ID} raise Exception("Invalid token") diff --git a/src/tests/unit/test_fxa_auth.py b/src/tests/unit/test_fxa_auth.py new file mode 100644 index 0000000..790fd65 --- /dev/null +++ b/src/tests/unit/test_fxa_auth.py @@ -0,0 +1,51 @@ +import asyncio + +import pytest +from fastapi import HTTPException + +from mlpa.core.prometheus_metrics import PrometheusResult +from mlpa.core.routers.fxa import fxa as fxa_module + + +async def test_fxa_auth_returns_first_successful_scope(mocker): + scopes = ("profile:uid", "scope-a", "scope-b") + mocker.patch.object(fxa_module, "FXA_SCOPES", scopes) + + async def fake_run_in_threadpool(_fn, _token, *, scope): + if scope == "scope-b": + await asyncio.sleep(0.01) + return {"user": "ok"} + await asyncio.sleep(0.02) + raise Exception(f"invalid-{scope}") + + mocker.patch.object(fxa_module, "run_in_threadpool", new=fake_run_in_threadpool) + mock_metrics = mocker.patch.object(fxa_module, "metrics") + + profile = await fxa_module.fxa_auth("Bearer test-token") + + assert profile == {"user": "ok"} + mock_metrics.validate_fxa_latency.labels.assert_called_once_with( + result=PrometheusResult.SUCCESS + ) + mock_metrics.validate_fxa_latency.labels().observe.assert_called_once() + + +async def test_fxa_auth_raises_when_all_scopes_fail(mocker): + scopes = ("profile:uid", "scope-a") + mocker.patch.object(fxa_module, "FXA_SCOPES", scopes) + + async def fake_run_in_threadpool(_fn, _token, *, scope): + await asyncio.sleep(0.01) + raise Exception(f"invalid-{scope}") + + mocker.patch.object(fxa_module, "run_in_threadpool", new=fake_run_in_threadpool) + mock_metrics = mocker.patch.object(fxa_module, "metrics") + + with pytest.raises(HTTPException) as exc_info: + await fxa_module.fxa_auth("Bearer test-token") + + assert exc_info.value.status_code == 401 + mock_metrics.validate_fxa_latency.labels.assert_called_once_with( + result=PrometheusResult.ERROR + ) + mock_metrics.validate_fxa_latency.labels().observe.assert_called_once()