Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
.DS_Store
metrics.jsonl
service_account.json
play_integrity_service_account.json
.ruff_cache
# Since we are running it as a library, better not to commit the lock file
uv.lock
Expand Down
107 changes: 107 additions & 0 deletions scripts/play_integrity_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
E2E Play Integrity flow against a running MLPA server.
Requires a real Play Integrity token and configured service account file on the server.
"""

import argparse
import json
import os
from typing import Optional

import httpx

from mlpa.core.config import env

DEFAULT_BASE_URL = f"http://0.0.0.0:{env.PORT or 8080}"
DEFAULT_SERVICE_TYPE = "ai"


def _print_json(payload: dict) -> None:
print(json.dumps(payload, indent=2))


def _require_value(value: Optional[str], name: str) -> str:
if value:
return value
raise SystemExit(f"Missing required value for {name}.")


def run(args: argparse.Namespace) -> None:
integrity_token = _require_value(
args.integrity_token or os.getenv("MLPA_PLAY_INTEGRITY_TOKEN"),
"integrity_token",
)
user_id = _require_value(
args.user_id or os.getenv("MLPA_PLAY_USER_ID"),
"user_id",
)

verify_response = httpx.post(
f"{args.base_url}/verify/play",
json={"integrity_token": integrity_token, "user_id": user_id},
timeout=args.timeout_s,
)
verify_response.raise_for_status()
access_token = verify_response.json().get("access_token")
if not access_token:
raise SystemExit("No access_token returned from /verify/play.")

headers = {
"authorization": f"Bearer {access_token}",
"use-play-integrity": "true",
"service-type": args.service_type,
}
payload = {
"model": args.model or env.MODEL_NAME,
"messages": [{"role": "user", "content": args.message}],
"stream": args.stream,
}

if args.stream:
with httpx.stream(
"POST",
f"{args.base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=args.timeout_s,
) as response:
response.raise_for_status()
for line in response.iter_lines():
if line:
print(line)
else:
response = httpx.post(
f"{args.base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=args.timeout_s,
)
response.raise_for_status()
_print_json(response.json())


def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="E2E Play Integrity verification + chat completion."
)
subparsers = parser.add_subparsers(dest="command")

run_parser = subparsers.add_parser("run", help="Verify and request a completion.")
run_parser.add_argument("--integrity-token", dest="integrity_token")
run_parser.add_argument("--user-id", dest="user_id")
run_parser.set_defaults(func=run)

return parser


def main() -> None:
parser = build_parser()
args = parser.parse_args()
if not getattr(args, "command", None):
parser.print_help()
raise SystemExit(2)
args.func(args)


if __name__ == "__main__":
main()
13 changes: 12 additions & 1 deletion src/mlpa/core/auth/authorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mlpa.core.config import env
from mlpa.core.routers.appattest import app_attest_auth
from mlpa.core.routers.fxa import fxa_auth
from mlpa.core.utils import parse_app_attest_jwt
from mlpa.core.utils import extract_user_from_play_integrity_jwt, parse_app_attest_jwt


async def authorize_request(
Expand All @@ -15,10 +15,12 @@ async def authorize_request(
service_type: Annotated[ServiceType, Header()],
use_app_attest: Annotated[bool | None, Header()] = None,
use_qa_certificates: Annotated[bool | None, Header()] = None,
use_play_integrity: Annotated[bool | None, Header()] = None,
) -> AuthorizedChatRequest:
if not authorization:
raise HTTPException(status_code=401, detail="Missing authorization header")
if use_app_attest:
# Apple App Attest
assertionAuth = parse_app_attest_jwt(authorization, "assert")
data = await app_attest_auth(assertionAuth, chat_request, use_qa_certificates)
if data:
Expand All @@ -29,6 +31,15 @@ async def authorize_request(
service_type=service_type.value,
**chat_request.model_dump(exclude_unset=True),
)
elif use_play_integrity:
# Google Play integrity
play_user_id = extract_user_from_play_integrity_jwt(authorization)
if play_user_id:
return AuthorizedChatRequest(
user=f"{play_user_id}:{service_type.value}",
service_type=service_type.value,
**chat_request.model_dump(exclude_unset=True),
)
else:
fxa_user_id = await fxa_auth(authorization)
if fxa_user_id:
Expand Down
7 changes: 7 additions & 0 deletions src/mlpa/core/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class UserUpdatePayload(BaseModel):
blocked: bool | None = None


# iOS App Attest
class AttestationAuth(BaseModel):
key_id_b64: str
challenge_b64: str
Expand All @@ -37,6 +38,12 @@ class AssertionAuth(BaseModel):
assertion_obj_b64: str


# Google Play Integrity
class PlayIntegrityRequest(BaseModel):
integrity_token: str
user_id: str


class AuthorizedChatRequest(ChatRequest):
user: str
service_type: str
Expand Down
7 changes: 7 additions & 0 deletions src/mlpa/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ def valid_service_types(self) -> list[str]:
APP_ATTEST_QA_BUCKET_PREFIX: str | None = None
APP_ATTEST_QA_GCP_PROJECT_ID: str | None = None

# Play Integrity
PLAY_INTEGRITY_PACKAGE_NAME: str = "com.example.app"
PLAY_INTEGRITY_SERVICE_ACCOUNT_FILE: str = "play_integrity_service_account.json"
PLAY_INTEGRITY_REQUEST_TIMEOUT_SECONDS: int = 30
MLPA_ACCESS_TOKEN_SECRET: str = "mlpa-dev-secret"
MLPA_ACCESS_TOKEN_TTL_SECONDS: int = 300

# FxA
CLIENT_ID: str = "default-client-id"
CLIENT_SECRET: str = "default-client-secret"
Expand Down
3 changes: 3 additions & 0 deletions src/mlpa/core/routers/play/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from mlpa.core.routers.play.play import router as play_router

__all__ = ["play_router"]
108 changes: 108 additions & 0 deletions src/mlpa/core/routers/play/play.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import hashlib
from functools import lru_cache

import httpx
from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from google.auth.transport.requests import Request
from google.oauth2 import service_account
from pydantic import BaseModel

from mlpa.core.classes import PlayIntegrityRequest
from mlpa.core.config import env
from mlpa.core.http_client import get_http_client
from mlpa.core.utils import issue_mlpa_access_token, raise_and_log

router = APIRouter()

PLAY_INTEGRITY_SCOPE = "https://www.googleapis.com/auth/playintegrity"
ALLOWED_DEVICE_VERDICTS = {
"MEETS_DEVICE_INTEGRITY",
"MEETS_BASIC_INTEGRITY",
"MEETS_STRONG_INTEGRITY",
}


@lru_cache(maxsize=1)
def _get_service_account_credentials():
return service_account.Credentials.from_service_account_file(
env.PLAY_INTEGRITY_SERVICE_ACCOUNT_FILE,
scopes=[PLAY_INTEGRITY_SCOPE],
)


def _get_play_integrity_access_token() -> str:
credentials = _get_service_account_credentials()
if not credentials.valid:
credentials.refresh(Request())
if not credentials.token:
raise HTTPException(status_code=500, detail="Failed to fetch access token")
return credentials.token


async def _decode_integrity_token(integrity_token: str) -> dict:
access_token = await run_in_threadpool(_get_play_integrity_access_token)
client = get_http_client()
try:
response = await client.post(
f"https://playintegrity.googleapis.com/v1/{env.PLAY_INTEGRITY_PACKAGE_NAME}:decodeIntegrityToken",
headers={
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
},
json={"integrity_token": integrity_token},
timeout=env.PLAY_INTEGRITY_REQUEST_TIMEOUT_SECONDS,
)
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise_and_log(e, False, 401)
except Exception as e:
raise_and_log(e, False, 502, "Play Integrity validation service unavailable")
return response.json()


def _validate_integrity_payload(payload: dict, expected_hash: str) -> None:
request_details = payload.get("requestDetails", {})
package_name = request_details.get("requestPackageName")
if package_name and package_name != env.PLAY_INTEGRITY_PACKAGE_NAME:
raise HTTPException(status_code=401, detail="Invalid package name")

token_request_hash = request_details.get("requestHash")
if token_request_hash != expected_hash:
raise HTTPException(status_code=401, detail="Invalid request hash")

app_integrity = payload.get("appIntegrity", {})
acceptable_recognition_verdicts = [
"PLAY_RECOGNIZED",
]
if env.MLPA_DEBUG:
acceptable_recognition_verdicts.append("UNRECOGNIZED_VERSION")
if (
app_integrity.get("appRecognitionVerdict")
not in acceptable_recognition_verdicts
):
raise HTTPException(status_code=401, detail="App not recognized by Play")

device_integrity = payload.get("deviceIntegrity", {})
device_verdicts = set(device_integrity.get("deviceRecognitionVerdict", []))
if not device_verdicts.intersection(ALLOWED_DEVICE_VERDICTS):
raise HTTPException(status_code=401, detail="Device integrity check failed")


@router.post("/play", tags=["Play Integrity"])
async def verify_play_integrity(payload: PlayIntegrityRequest):
decoded = await _decode_integrity_token(payload.integrity_token)
token_payload = decoded.get("tokenPayloadExternal") or decoded.get("tokenPayload")
if not token_payload:
raise HTTPException(status_code=401, detail="Invalid Play Integrity token")

expected_hash = hashlib.sha256(payload.user_id.encode("utf-8")).hexdigest()

_validate_integrity_payload(token_payload, expected_hash)

access_token = issue_mlpa_access_token(payload.user_id)
return {
"access_token": access_token,
"token_type": "Bearer",
"expires_in": env.MLPA_ACCESS_TOKEN_TTL_SECONDS,
}
38 changes: 37 additions & 1 deletion src/mlpa/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import ast
import base64
import json
import time

from fastapi import HTTPException
from fxa.oauth import Client
from jwtoxide import DecodingKey, ValidationOptions, decode
from jwtoxide import DecodingKey, ValidationOptions, decode, encode

from mlpa.core.classes import AssertionAuth, AttestationAuth
from mlpa.core.config import LITELLM_MASTER_AUTH_HEADERS, env
Expand Down Expand Up @@ -160,3 +161,38 @@ def raise_and_log(
else response_text_prefix or GENERIC_UPSTREAM_ERROR
},
)


def extract_user_from_play_integrity_jwt(authorization: str):
token = authorization.removeprefix("Bearer ").split()[0]
try:
payload = decode(
token,
env.MLPA_ACCESS_TOKEN_SECRET,
ValidationOptions(
required_spec_claims={"exp", "iat", "sub"},
iss={"mlpa"},
aud=None,
validate_aud=False,
validate_exp=True,
validate_nbf=False,
verify_signature=True,
algorithms=["HS256"],
),
)
return payload["sub"]
except Exception as e:
logger.error(f"Play Integrity JWT decode error: {e}")
raise HTTPException(status_code=401, detail="Invalid MLPA access token")


def issue_mlpa_access_token(user_id: str) -> str:
now = int(time.time())
payload = {
"sub": user_id,
"iat": now,
"exp": now + env.MLPA_ACCESS_TOKEN_TTL_SECONDS,
"iss": "mlpa",
"typ": "mlpa_access",
}
return encode(payload, env.MLPA_ACCESS_TOKEN_SECRET, algorithm="HS256")
6 changes: 6 additions & 0 deletions src/mlpa/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from mlpa.core.routers.fxa import fxa_router
from mlpa.core.routers.health import health_router
from mlpa.core.routers.mock import mock_router
from mlpa.core.routers.play import play_router
from mlpa.core.routers.user import user_router
from mlpa.core.utils import get_or_create_user

Expand All @@ -35,6 +36,10 @@
"name": "App Attest",
"description": "Endpoints for verifying App Attest payloads.",
},
{
"name": "Play Integrity",
"description": "Endpoints for verifying Play Integrity payloads.",
},
{"name": "LiteLLM", "description": "Endpoints for interacting with LiteLLM."},
{"name": "Mock", "description": "Mock endpoints for testing purposes."},
{
Expand Down Expand Up @@ -112,6 +117,7 @@ async def get_metrics():

app.include_router(health_router, prefix="/health")
app.include_router(appattest_router, prefix="/verify")
app.include_router(play_router, prefix="/verify")
app.include_router(fxa_router, prefix="/fxa")
app.include_router(user_router, prefix="/user")
app.include_router(mock_router, prefix="/mock")
Expand Down
12 changes: 12 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import pytest

from mlpa.core.config import env


@pytest.fixture(autouse=True, scope="session")
def _force_mlpa_debug_false():
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setenv("MLPA_DEBUG", "false")
env.MLPA_DEBUG = False
yield
monkeypatch.undo()
Loading