Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ from fxa.tools.bearer import get_bearer_token
fxa_token: str = get_bearer_token(
your_mozilla_account_email,
your_mozilla_account_password,
scopes=["profile"],
scopes=["profile:uid"],
client_id="5882386c6d801776" # a common client_id for the dev environment,
account_server_url="https://api.accounts.firefox.com",
oauth_server_url="https://oauth.accounts.firefox.com",
Expand All @@ -148,7 +148,7 @@ fxa_token: str = get_bearer_token(

## How to update static docs/index.html from redoc

Make sure you have Node installed
Ensure Node is installed

1. `make install`
2. `mlpa`
Expand Down
3 changes: 3 additions & 0 deletions src/mlpa/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def valid_service_types(self) -> list[str]:
# FxA
CLIENT_ID: str = "default-client-id"
CLIENT_SECRET: str = "default-client-secret"
ADDITIONAL_FXA_SCOPE_1: str | None = None
ADDITIONAL_FXA_SCOPE_2: str | None = None
ADDITIONAL_FXA_SCOPE_3: str | None = None

# PostgreSQL
LITELLM_DB_NAME: str = "litellm"
Expand Down
40 changes: 34 additions & 6 deletions src/mlpa/core/routers/fxa/fxa.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,57 @@
import asyncio
import time
from typing import Annotated

from fastapi import APIRouter, Header, HTTPException
from fastapi.concurrency import run_in_threadpool

from mlpa.core.config import env
from mlpa.core.logger import logger
from mlpa.core.prometheus_metrics import PrometheusResult, metrics
from mlpa.core.utils import get_fxa_client

router = APIRouter()

client = get_fxa_client()
FXA_DEFAULT_SCOPE = "profile:uid"
FXA_SCOPES = tuple(
scope
for scope in (
FXA_DEFAULT_SCOPE,
env.ADDITIONAL_FXA_SCOPE_1,
env.ADDITIONAL_FXA_SCOPE_2,
env.ADDITIONAL_FXA_SCOPE_3,
)
if scope
)


async def fxa_auth(authorization: Annotated[str | None, Header()]):
start_time = time.perf_counter()
token = authorization.removeprefix("Bearer ").split()[0]
result = PrometheusResult.ERROR

errors = []
try:
profile = await run_in_threadpool(client.verify_token, token, scope="profile")
result = PrometheusResult.SUCCESS
return profile
except Exception as e:
logger.error(f"FxA auth error: {e}")
tasks = [
asyncio.create_task(
run_in_threadpool(client.verify_token, token, scope=scope)
)
for scope in FXA_SCOPES
]
try:
for task in asyncio.as_completed(tasks):
try:
profile = await task
result = PrometheusResult.SUCCESS
return profile
except Exception as e:
errors.append(e)
finally:
for task in tasks:
if not task.done():
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
logger.error(f"FxA auth error: {errors}")
raise HTTPException(status_code=401, detail="Invalid FxA auth")
finally:
metrics.validate_fxa_latency.labels(result=result).observe(
Expand Down
2 changes: 1 addition & 1 deletion src/tests/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
MOCK_FXA_USER_DATA = {
"user": TEST_USER_ID,
"client_id": "test-client-id",
"scope": ["profile"],
"scope": ["profile:uid"],
"generation": 1,
"profile_changed_at": 1234567890,
}
6 changes: 3 additions & 3 deletions src/tests/integration/test_mock_router_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_mock_chat_completions_no_auth_missing_user_in_token(

mock_fxa_client._verify_jwt_token.return_value = {
"client_id": "test-client-id",
"scope": ["profile"],
"scope": ["profile:uid"],
}

response = mocked_client_integration.post(
Expand All @@ -259,7 +259,7 @@ def test_mock_chat_completions_no_auth_blocked_user(
mock_fxa_client._verify_jwt_token.return_value = {
"user": "blocked-user-id",
"client_id": "test-client-id",
"scope": ["profile"],
"scope": ["profile:uid"],
}

with patch(
Expand Down Expand Up @@ -295,7 +295,7 @@ def test_mock_chat_completions_latency_simulation(self, mocked_client_integratio
mock_fxa_client._verify_jwt_token.return_value = {
"user": TEST_USER_ID,
"client_id": "test-client-id",
"scope": ["profile"],
"scope": ["profile:uid"],
}

with patch.dict("os.environ", {"MOCK_LATENCY_MS": "100"}):
Expand Down
2 changes: 1 addition & 1 deletion src/tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def __init__(self, client_id: str, client_secret: str, fxa_url: str):
self.client_secret = client_secret
self.fxa_url = fxa_url

def verify_token(self, token: str, scope: str = "profile"):
def verify_token(self, token: str, scope: str = "profile:uid"):
if token == TEST_FXA_TOKEN:
return {"user": TEST_USER_ID}
raise Exception("Invalid token")
Expand Down
51 changes: 51 additions & 0 deletions src/tests/unit/test_fxa_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import asyncio

import pytest
from fastapi import HTTPException

from mlpa.core.prometheus_metrics import PrometheusResult
from mlpa.core.routers.fxa import fxa as fxa_module


async def test_fxa_auth_returns_first_successful_scope(mocker):
scopes = ("profile:uid", "scope-a", "scope-b")
mocker.patch.object(fxa_module, "FXA_SCOPES", scopes)

async def fake_run_in_threadpool(_fn, _token, *, scope):
if scope == "scope-b":
await asyncio.sleep(0.01)
return {"user": "ok"}
await asyncio.sleep(0.02)
raise Exception(f"invalid-{scope}")

mocker.patch.object(fxa_module, "run_in_threadpool", new=fake_run_in_threadpool)
mock_metrics = mocker.patch.object(fxa_module, "metrics")

profile = await fxa_module.fxa_auth("Bearer test-token")

assert profile == {"user": "ok"}
mock_metrics.validate_fxa_latency.labels.assert_called_once_with(
result=PrometheusResult.SUCCESS
)
mock_metrics.validate_fxa_latency.labels().observe.assert_called_once()


async def test_fxa_auth_raises_when_all_scopes_fail(mocker):
scopes = ("profile:uid", "scope-a")
mocker.patch.object(fxa_module, "FXA_SCOPES", scopes)

async def fake_run_in_threadpool(_fn, _token, *, scope):
await asyncio.sleep(0.01)
raise Exception(f"invalid-{scope}")

mocker.patch.object(fxa_module, "run_in_threadpool", new=fake_run_in_threadpool)
mock_metrics = mocker.patch.object(fxa_module, "metrics")

with pytest.raises(HTTPException) as exc_info:
await fxa_module.fxa_auth("Bearer test-token")

assert exc_info.value.status_code == 401
mock_metrics.validate_fxa_latency.labels.assert_called_once_with(
result=PrometheusResult.ERROR
)
mock_metrics.validate_fxa_latency.labels().observe.assert_called_once()