diff --git a/.gitignore b/.gitignore index bdbdb20..29391c1 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ .DS_Store metrics.jsonl service_account.json +play_integrity_service_account.json .ruff_cache # Since we are running it as a library, better not to commit the lock file uv.lock diff --git a/scripts/play_integrity_e2e.py b/scripts/play_integrity_e2e.py new file mode 100644 index 0000000..9b37c88 --- /dev/null +++ b/scripts/play_integrity_e2e.py @@ -0,0 +1,107 @@ +""" +E2E Play Integrity flow against a running MLPA server. +Requires a real Play Integrity token and configured service account file on the server. +""" + +import argparse +import json +import os +from typing import Optional + +import httpx + +from mlpa.core.config import env + +DEFAULT_BASE_URL = f"http://0.0.0.0:{env.PORT or 8080}" +DEFAULT_SERVICE_TYPE = "ai" + + +def _print_json(payload: dict) -> None: + print(json.dumps(payload, indent=2)) + + +def _require_value(value: Optional[str], name: str) -> str: + if value: + return value + raise SystemExit(f"Missing required value for {name}.") + + +def run(args: argparse.Namespace) -> None: + integrity_token = _require_value( + args.integrity_token or os.getenv("MLPA_PLAY_INTEGRITY_TOKEN"), + "integrity_token", + ) + user_id = _require_value( + args.user_id or os.getenv("MLPA_PLAY_USER_ID"), + "user_id", + ) + + verify_response = httpx.post( + f"{args.base_url}/verify/play", + json={"integrity_token": integrity_token, "user_id": user_id}, + timeout=args.timeout_s, + ) + verify_response.raise_for_status() + access_token = verify_response.json().get("access_token") + if not access_token: + raise SystemExit("No access_token returned from /verify/play.") + + headers = { + "authorization": f"Bearer {access_token}", + "use-play-integrity": "true", + "service-type": args.service_type, + } + payload = { + "model": args.model or env.MODEL_NAME, + "messages": [{"role": "user", "content": args.message}], + "stream": args.stream, + } + + if args.stream: + with httpx.stream( + "POST", + f"{args.base_url}/v1/chat/completions", + headers=headers, + json=payload, + timeout=args.timeout_s, + ) as response: + response.raise_for_status() + for line in response.iter_lines(): + if line: + print(line) + else: + response = httpx.post( + f"{args.base_url}/v1/chat/completions", + headers=headers, + json=payload, + timeout=args.timeout_s, + ) + response.raise_for_status() + _print_json(response.json()) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="E2E Play Integrity verification + chat completion." + ) + subparsers = parser.add_subparsers(dest="command") + + run_parser = subparsers.add_parser("run", help="Verify and request a completion.") + run_parser.add_argument("--integrity-token", dest="integrity_token") + run_parser.add_argument("--user-id", dest="user_id") + run_parser.set_defaults(func=run) + + return parser + + +def main() -> None: + parser = build_parser() + args = parser.parse_args() + if not getattr(args, "command", None): + parser.print_help() + raise SystemExit(2) + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/src/mlpa/core/auth/authorize.py b/src/mlpa/core/auth/authorize.py index b4b0ba6..8d94368 100644 --- a/src/mlpa/core/auth/authorize.py +++ b/src/mlpa/core/auth/authorize.py @@ -6,7 +6,7 @@ from mlpa.core.config import env from mlpa.core.routers.appattest import app_attest_auth from mlpa.core.routers.fxa import fxa_auth -from mlpa.core.utils import parse_app_attest_jwt +from mlpa.core.utils import extract_user_from_play_integrity_jwt, parse_app_attest_jwt async def authorize_request( @@ -15,10 +15,12 @@ async def authorize_request( service_type: Annotated[ServiceType, Header()], use_app_attest: Annotated[bool | None, Header()] = None, use_qa_certificates: Annotated[bool | None, Header()] = None, + use_play_integrity: Annotated[bool | None, Header()] = None, ) -> AuthorizedChatRequest: if not authorization: raise HTTPException(status_code=401, detail="Missing authorization header") if use_app_attest: + # Apple App Attest assertionAuth = parse_app_attest_jwt(authorization, "assert") data = await app_attest_auth(assertionAuth, chat_request, use_qa_certificates) if data: @@ -29,6 +31,15 @@ async def authorize_request( service_type=service_type.value, **chat_request.model_dump(exclude_unset=True), ) + elif use_play_integrity: + # Google Play integrity + play_user_id = extract_user_from_play_integrity_jwt(authorization) + if play_user_id: + return AuthorizedChatRequest( + user=f"{play_user_id}:{service_type.value}", + service_type=service_type.value, + **chat_request.model_dump(exclude_unset=True), + ) else: fxa_user_id = await fxa_auth(authorization) if fxa_user_id: diff --git a/src/mlpa/core/classes.py b/src/mlpa/core/classes.py index d751ab2..36ca47b 100644 --- a/src/mlpa/core/classes.py +++ b/src/mlpa/core/classes.py @@ -25,6 +25,7 @@ class UserUpdatePayload(BaseModel): blocked: bool | None = None +# iOS App Attest class AttestationAuth(BaseModel): key_id_b64: str challenge_b64: str @@ -37,6 +38,12 @@ class AssertionAuth(BaseModel): assertion_obj_b64: str +# Google Play Integrity +class PlayIntegrityRequest(BaseModel): + integrity_token: str + user_id: str + + class AuthorizedChatRequest(ChatRequest): user: str service_type: str diff --git a/src/mlpa/core/config.py b/src/mlpa/core/config.py index f154c0f..dbf9964 100644 --- a/src/mlpa/core/config.py +++ b/src/mlpa/core/config.py @@ -93,6 +93,13 @@ def valid_service_types(self) -> list[str]: APP_ATTEST_QA_BUCKET_PREFIX: str | None = None APP_ATTEST_QA_GCP_PROJECT_ID: str | None = None + # Play Integrity + PLAY_INTEGRITY_PACKAGE_NAME: str = "com.example.app" + PLAY_INTEGRITY_SERVICE_ACCOUNT_FILE: str = "play_integrity_service_account.json" + PLAY_INTEGRITY_REQUEST_TIMEOUT_SECONDS: int = 30 + MLPA_ACCESS_TOKEN_SECRET: str = "mlpa-dev-secret" + MLPA_ACCESS_TOKEN_TTL_SECONDS: int = 300 + # FxA CLIENT_ID: str = "default-client-id" CLIENT_SECRET: str = "default-client-secret" diff --git a/src/mlpa/core/routers/play/__init__.py b/src/mlpa/core/routers/play/__init__.py new file mode 100644 index 0000000..c12648c --- /dev/null +++ b/src/mlpa/core/routers/play/__init__.py @@ -0,0 +1,3 @@ +from mlpa.core.routers.play.play import router as play_router + +__all__ = ["play_router"] diff --git a/src/mlpa/core/routers/play/play.py b/src/mlpa/core/routers/play/play.py new file mode 100644 index 0000000..82760f2 --- /dev/null +++ b/src/mlpa/core/routers/play/play.py @@ -0,0 +1,108 @@ +import hashlib +from functools import lru_cache + +import httpx +from fastapi import APIRouter, HTTPException +from fastapi.concurrency import run_in_threadpool +from google.auth.transport.requests import Request +from google.oauth2 import service_account +from pydantic import BaseModel + +from mlpa.core.classes import PlayIntegrityRequest +from mlpa.core.config import env +from mlpa.core.http_client import get_http_client +from mlpa.core.utils import issue_mlpa_access_token, raise_and_log + +router = APIRouter() + +PLAY_INTEGRITY_SCOPE = "https://www.googleapis.com/auth/playintegrity" +ALLOWED_DEVICE_VERDICTS = { + "MEETS_DEVICE_INTEGRITY", + "MEETS_BASIC_INTEGRITY", + "MEETS_STRONG_INTEGRITY", +} + + +@lru_cache(maxsize=1) +def _get_service_account_credentials(): + return service_account.Credentials.from_service_account_file( + env.PLAY_INTEGRITY_SERVICE_ACCOUNT_FILE, + scopes=[PLAY_INTEGRITY_SCOPE], + ) + + +def _get_play_integrity_access_token() -> str: + credentials = _get_service_account_credentials() + if not credentials.valid: + credentials.refresh(Request()) + if not credentials.token: + raise HTTPException(status_code=500, detail="Failed to fetch access token") + return credentials.token + + +async def _decode_integrity_token(integrity_token: str) -> dict: + access_token = await run_in_threadpool(_get_play_integrity_access_token) + client = get_http_client() + try: + response = await client.post( + f"https://playintegrity.googleapis.com/v1/{env.PLAY_INTEGRITY_PACKAGE_NAME}:decodeIntegrityToken", + headers={ + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + }, + json={"integrity_token": integrity_token}, + timeout=env.PLAY_INTEGRITY_REQUEST_TIMEOUT_SECONDS, + ) + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise_and_log(e, False, 401) + except Exception as e: + raise_and_log(e, False, 502, "Play Integrity validation service unavailable") + return response.json() + + +def _validate_integrity_payload(payload: dict, expected_hash: str) -> None: + request_details = payload.get("requestDetails", {}) + package_name = request_details.get("requestPackageName") + if package_name and package_name != env.PLAY_INTEGRITY_PACKAGE_NAME: + raise HTTPException(status_code=401, detail="Invalid package name") + + token_request_hash = request_details.get("requestHash") + if token_request_hash != expected_hash: + raise HTTPException(status_code=401, detail="Invalid request hash") + + app_integrity = payload.get("appIntegrity", {}) + acceptable_recognition_verdicts = [ + "PLAY_RECOGNIZED", + ] + if env.MLPA_DEBUG: + acceptable_recognition_verdicts.append("UNRECOGNIZED_VERSION") + if ( + app_integrity.get("appRecognitionVerdict") + not in acceptable_recognition_verdicts + ): + raise HTTPException(status_code=401, detail="App not recognized by Play") + + device_integrity = payload.get("deviceIntegrity", {}) + device_verdicts = set(device_integrity.get("deviceRecognitionVerdict", [])) + if not device_verdicts.intersection(ALLOWED_DEVICE_VERDICTS): + raise HTTPException(status_code=401, detail="Device integrity check failed") + + +@router.post("/play", tags=["Play Integrity"]) +async def verify_play_integrity(payload: PlayIntegrityRequest): + decoded = await _decode_integrity_token(payload.integrity_token) + token_payload = decoded.get("tokenPayloadExternal") or decoded.get("tokenPayload") + if not token_payload: + raise HTTPException(status_code=401, detail="Invalid Play Integrity token") + + expected_hash = hashlib.sha256(payload.user_id.encode("utf-8")).hexdigest() + + _validate_integrity_payload(token_payload, expected_hash) + + access_token = issue_mlpa_access_token(payload.user_id) + return { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": env.MLPA_ACCESS_TOKEN_TTL_SECONDS, + } diff --git a/src/mlpa/core/utils.py b/src/mlpa/core/utils.py index 494ae91..1cbd751 100644 --- a/src/mlpa/core/utils.py +++ b/src/mlpa/core/utils.py @@ -1,10 +1,11 @@ import ast import base64 import json +import time from fastapi import HTTPException from fxa.oauth import Client -from jwtoxide import DecodingKey, ValidationOptions, decode +from jwtoxide import DecodingKey, ValidationOptions, decode, encode from mlpa.core.classes import AssertionAuth, AttestationAuth from mlpa.core.config import LITELLM_MASTER_AUTH_HEADERS, env @@ -160,3 +161,38 @@ def raise_and_log( else response_text_prefix or GENERIC_UPSTREAM_ERROR }, ) + + +def extract_user_from_play_integrity_jwt(authorization: str): + token = authorization.removeprefix("Bearer ").split()[0] + try: + payload = decode( + token, + env.MLPA_ACCESS_TOKEN_SECRET, + ValidationOptions( + required_spec_claims={"exp", "iat", "sub"}, + iss={"mlpa"}, + aud=None, + validate_aud=False, + validate_exp=True, + validate_nbf=False, + verify_signature=True, + algorithms=["HS256"], + ), + ) + return payload["sub"] + except Exception as e: + logger.error(f"Play Integrity JWT decode error: {e}") + raise HTTPException(status_code=401, detail="Invalid MLPA access token") + + +def issue_mlpa_access_token(user_id: str) -> str: + now = int(time.time()) + payload = { + "sub": user_id, + "iat": now, + "exp": now + env.MLPA_ACCESS_TOKEN_TTL_SECONDS, + "iss": "mlpa", + "typ": "mlpa_access", + } + return encode(payload, env.MLPA_ACCESS_TOKEN_SECRET, algorithm="HS256") diff --git a/src/mlpa/run.py b/src/mlpa/run.py index 02e4d75..ad1f05f 100644 --- a/src/mlpa/run.py +++ b/src/mlpa/run.py @@ -25,6 +25,7 @@ from mlpa.core.routers.fxa import fxa_router from mlpa.core.routers.health import health_router from mlpa.core.routers.mock import mock_router +from mlpa.core.routers.play import play_router from mlpa.core.routers.user import user_router from mlpa.core.utils import get_or_create_user @@ -35,6 +36,10 @@ "name": "App Attest", "description": "Endpoints for verifying App Attest payloads.", }, + { + "name": "Play Integrity", + "description": "Endpoints for verifying Play Integrity payloads.", + }, {"name": "LiteLLM", "description": "Endpoints for interacting with LiteLLM."}, {"name": "Mock", "description": "Mock endpoints for testing purposes."}, { @@ -112,6 +117,7 @@ async def get_metrics(): app.include_router(health_router, prefix="/health") app.include_router(appattest_router, prefix="/verify") +app.include_router(play_router, prefix="/verify") app.include_router(fxa_router, prefix="/fxa") app.include_router(user_router, prefix="/user") app.include_router(mock_router, prefix="/mock") diff --git a/src/tests/conftest.py b/src/tests/conftest.py new file mode 100644 index 0000000..5d7f4b7 --- /dev/null +++ b/src/tests/conftest.py @@ -0,0 +1,12 @@ +import pytest + +from mlpa.core.config import env + + +@pytest.fixture(autouse=True, scope="session") +def _force_mlpa_debug_false(): + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setenv("MLPA_DEBUG", "false") + env.MLPA_DEBUG = False + yield + monkeypatch.undo() diff --git a/src/tests/integration/test_play_integrity.py b/src/tests/integration/test_play_integrity.py new file mode 100644 index 0000000..afd6dc9 --- /dev/null +++ b/src/tests/integration/test_play_integrity.py @@ -0,0 +1,83 @@ +import hashlib + +from mlpa.core.config import env +from mlpa.core.utils import issue_mlpa_access_token +from tests.consts import SAMPLE_CHAT_REQUEST, SUCCESSFUL_CHAT_RESPONSE, TEST_USER_ID + + +def _mock_decode_payload(request_hash: str) -> dict: + return { + "tokenPayloadExternal": { + "requestDetails": { + "requestPackageName": env.PLAY_INTEGRITY_PACKAGE_NAME, + "requestHash": request_hash, + }, + "appIntegrity": {"appRecognitionVerdict": "PLAY_RECOGNIZED"}, + "deviceIntegrity": {"deviceRecognitionVerdict": ["MEETS_DEVICE_INTEGRITY"]}, + } + } + + +def test_verify_play_integrity_success(mocked_client_integration, mocker): + request_hash = hashlib.sha256(TEST_USER_ID.encode("utf-8")).hexdigest() + mocker.patch( + "mlpa.core.routers.play.play._decode_integrity_token", + return_value=_mock_decode_payload(request_hash), + ) + + response = mocked_client_integration.post( + "/verify/play", + json={"integrity_token": "test-token", "user_id": TEST_USER_ID}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["access_token"] + assert data["token_type"] == "Bearer" + assert data["expires_in"] == env.MLPA_ACCESS_TOKEN_TTL_SECONDS + + +def test_verify_play_integrity_invalid_hash(mocked_client_integration, mocker): + mocker.patch( + "mlpa.core.routers.play.play._decode_integrity_token", + return_value=_mock_decode_payload("bad-hash"), + ) + + response = mocked_client_integration.post( + "/verify/play", + json={"integrity_token": "test-token", "user_id": TEST_USER_ID}, + ) + + assert response.status_code == 401 + assert response.json()["detail"] == "Invalid request hash" + + +def test_verify_play_integrity_missing_payload(mocked_client_integration, mocker): + mocker.patch( + "mlpa.core.routers.play.play._decode_integrity_token", + return_value={}, + ) + + response = mocked_client_integration.post( + "/verify/play", + json={"integrity_token": "test-token", "user_id": TEST_USER_ID}, + ) + + assert response.status_code == 401 + assert response.json()["detail"] == "Invalid Play Integrity token" + + +def test_chat_with_play_integrity_token_success(mocked_client_integration): + access_token = issue_mlpa_access_token(TEST_USER_ID) + response = mocked_client_integration.post( + "/v1/chat/completions", + headers={ + "authorization": f"Bearer {access_token}", + "use-play-integrity": "true", + "service-type": "ai", + }, + json=SAMPLE_CHAT_REQUEST.model_dump(exclude_unset=True), + ) + + assert response.status_code == 200 + assert response.json() == SUCCESSFUL_CHAT_RESPONSE