Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
.DS_Store
metrics.jsonl
service_account.json
play_integrity_service_account.json
.ruff_cache
# Since we are running it as a library, better not to commit the lock file
uv.lock
Expand Down
107 changes: 107 additions & 0 deletions scripts/play_integrity_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
E2E Play Integrity flow against a running MLPA server.
Requires a real Play Integrity token and configured service account file on the server.
"""

import argparse
import json
import os
from typing import Optional

import httpx

from mlpa.core.config import env

DEFAULT_BASE_URL = f"http://0.0.0.0:{env.PORT or 8080}"
DEFAULT_SERVICE_TYPE = "ai"


def _print_json(payload: dict) -> None:
print(json.dumps(payload, indent=2))


def _require_value(value: Optional[str], name: str) -> str:
if value:
return value
raise SystemExit(f"Missing required value for {name}.")


def run(args: argparse.Namespace) -> None:
integrity_token = _require_value(
args.integrity_token or os.getenv("MLPA_PLAY_INTEGRITY_TOKEN"),
"integrity_token",
)
user_id = _require_value(
args.user_id or os.getenv("MLPA_PLAY_USER_ID"),
"user_id",
)

verify_response = httpx.post(
f"{args.base_url}/verify/play",
json={"integrity_token": integrity_token, "user_id": user_id},
timeout=args.timeout_s,
)
verify_response.raise_for_status()
access_token = verify_response.json().get("access_token")
if not access_token:
raise SystemExit("No access_token returned from /verify/play.")

headers = {
"authorization": f"Bearer {access_token}",
"use-play-integrity": "true",
"service-type": args.service_type,
}
payload = {
"model": args.model or env.MODEL_NAME,
"messages": [{"role": "user", "content": args.message}],
"stream": args.stream,
}

if args.stream:
with httpx.stream(
"POST",
f"{args.base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=args.timeout_s,
) as response:
response.raise_for_status()
for line in response.iter_lines():
if line:
print(line)
else:
response = httpx.post(
f"{args.base_url}/v1/chat/completions",
headers=headers,
json=payload,
timeout=args.timeout_s,
)
response.raise_for_status()
_print_json(response.json())


def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="E2E Play Integrity verification + chat completion."
)
subparsers = parser.add_subparsers(dest="command")

run_parser = subparsers.add_parser("run", help="Verify and request a completion.")
run_parser.add_argument("--integrity-token", dest="integrity_token")
run_parser.add_argument("--user-id", dest="user_id")
run_parser.set_defaults(func=run)

return parser


def main() -> None:
parser = build_parser()
args = parser.parse_args()
if not getattr(args, "command", None):
parser.print_help()
raise SystemExit(2)
args.func(args)


if __name__ == "__main__":
main()
14 changes: 12 additions & 2 deletions src/mlpa/core/auth/authorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mlpa.core.config import env
from mlpa.core.routers.appattest import app_attest_auth
from mlpa.core.routers.fxa import fxa_auth
from mlpa.core.utils import parse_app_attest_jwt
from mlpa.core.utils import extract_user_from_play_integrity_jwt, parse_app_attest_jwt


async def authorize_request(
Expand All @@ -15,10 +15,12 @@ async def authorize_request(
service_type: Annotated[ServiceType, Header()],
use_app_attest: Annotated[bool | None, Header()] = None,
use_qa_certificates: Annotated[bool | None, Header()] = None,
use_play_integrity: Annotated[bool | None, Header()] = None,
) -> AuthorizedChatRequest:
if not authorization:
raise HTTPException(status_code=401, detail="Missing authorization header")
if use_app_attest:
# Apple App Attest
assertionAuth = parse_app_attest_jwt(authorization, "assert")
data = await app_attest_auth(assertionAuth, chat_request, use_qa_certificates)
if data:
Expand All @@ -28,8 +30,16 @@ async def authorize_request(
user=f"{assertionAuth.key_id_b64}:{service_type.value}", # "user" is key_id_b64 from app attest
**chat_request.model_dump(exclude_unset=True),
)
elif use_play_integrity:
# Google Play integrity
play_user_id = extract_user_from_play_integrity_jwt(authorization)
if play_user_id:
return AuthorizedChatRequest(
user=f"{play_user_id}:{service_type.value}",
**chat_request.model_dump(exclude_unset=True),
)
else:
# FxA authorization
# Firefox Account authorization
fxa_user_id = fxa_auth(authorization)
if fxa_user_id:
if fxa_user_id.get("error"):
Expand Down
7 changes: 7 additions & 0 deletions src/mlpa/core/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class UserUpdatePayload(BaseModel):
blocked: bool | None = None


# iOS App Attest
class AttestationAuth(BaseModel):
key_id_b64: str
challenge_b64: str
Expand All @@ -37,6 +38,12 @@ class AssertionAuth(BaseModel):
assertion_obj_b64: str


# Google Play Integrity
class PlayIntegrityRequest(BaseModel):
integrity_token: str
user_id: str


class AuthorizedChatRequest(ChatRequest):
user: str

Expand Down
8 changes: 8 additions & 0 deletions src/mlpa/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,17 @@ def valid_service_types(self) -> list[str]:
APP_ATTEST_QA_BUCKET_PREFIX: str | None = None
APP_ATTEST_QA_GCP_PROJECT_ID: str | None = None

# Play Integrity
PLAY_INTEGRITY_PACKAGE_NAME: str = "com.example.app"
PLAY_INTEGRITY_SERVICE_ACCOUNT_FILE: str = "play_integrity_service_account.json"
PLAY_INTEGRITY_REQUEST_TIMEOUT_SECONDS: int = 30
MLPA_ACCESS_TOKEN_SECRET: str = "mlpa-dev-secret"
MLPA_ACCESS_TOKEN_TTL_SECONDS: int = 300

# FxA
CLIENT_ID: str = "default-client-id"
CLIENT_SECRET: str = "default-client-secret"
FXA_SCOPE: str = "profile"

# PostgreSQL
LITELLM_DB_NAME: str = "litellm"
Expand Down
3 changes: 2 additions & 1 deletion src/mlpa/core/routers/fxa/fxa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from fastapi import APIRouter, Header, HTTPException

from mlpa.core.config import env
from mlpa.core.logger import logger
from mlpa.core.prometheus_metrics import PrometheusResult, metrics
from mlpa.core.utils import get_fxa_client
Expand All @@ -16,7 +17,7 @@ def fxa_auth(authorization: Annotated[str | None, Header()]):
token = authorization.removeprefix("Bearer ").split()[0]
result = PrometheusResult.ERROR
try:
profile = client.verify_token(token, scope="profile")
profile = client.verify_token(token, scope=env.FXA_SCOPE)
result = PrometheusResult.SUCCESS
except Exception as e:
logger.error(f"FxA auth error: {e}")
Expand Down
3 changes: 3 additions & 0 deletions src/mlpa/core/routers/play/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from mlpa.core.routers.play.play import router as play_router

__all__ = ["play_router"]
109 changes: 109 additions & 0 deletions src/mlpa/core/routers/play/play.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import hashlib
from functools import lru_cache

import httpx
from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from google.auth.transport.requests import Request
from google.oauth2 import service_account
from pydantic import BaseModel

from mlpa.core.classes import PlayIntegrityRequest
from mlpa.core.config import env
from mlpa.core.http_client import get_http_client
from mlpa.core.utils import issue_mlpa_access_token, raise_and_log

router = APIRouter()

PLAY_INTEGRITY_SCOPE = "https://www.googleapis.com/auth/playintegrity"
ALLOWED_DEVICE_VERDICTS = {
"MEETS_DEVICE_INTEGRITY",
"MEETS_BASIC_INTEGRITY",
"MEETS_STRONG_INTEGRITY",
}


@lru_cache(maxsize=1)
def _get_service_account_credentials():
return service_account.Credentials.from_service_account_file(
env.PLAY_INTEGRITY_SERVICE_ACCOUNT_FILE,
scopes=[PLAY_INTEGRITY_SCOPE],
)


def _get_play_integrity_access_token() -> str:
credentials = _get_service_account_credentials()
if not credentials.valid:
credentials.refresh(Request())
if not credentials.token:
raise HTTPException(status_code=500, detail="Failed to fetch access token")
return credentials.token


async def _decode_integrity_token(integrity_token: str) -> dict:
access_token = await run_in_threadpool(_get_play_integrity_access_token)
client = get_http_client()
try:
response = await client.post(
f"https://playintegrity.googleapis.com/v1/{env.PLAY_INTEGRITY_PACKAGE_NAME}:decodeIntegrityToken",
headers={
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
},
json={"integrity_token": integrity_token},
timeout=env.PLAY_INTEGRITY_REQUEST_TIMEOUT_SECONDS,
)
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise_and_log(e, False, 401)
except Exception as e:
raise_and_log(e, False, 502, "Play Integrity validation service unavailable")
return response.json()


def _validate_integrity_payload(payload: dict, expected_hash: str) -> None:
request_details = payload.get("requestDetails", {})
package_name = request_details.get("requestPackageName")
if package_name and package_name != env.PLAY_INTEGRITY_PACKAGE_NAME:
raise HTTPException(status_code=401, detail="Invalid package name")

token_request_hash = request_details.get("requestHash")
if token_request_hash != expected_hash:
raise HTTPException(status_code=401, detail="Invalid request hash")

app_integrity = payload.get("appIntegrity", {})
acceptable_recognition_verdicts = [
"PLAY RECOGNIZED",
]
if env.MLPA_DEBUG:
acceptable_recognition_verdicts.append("UNRECOGNIZED_VERSION")
if (
app_integrity.get("appRecognitionVerdict")
not in acceptable_recognition_verdicts
):
raise HTTPException(status_code=401, detail="App not recognized by Play")

device_integrity = payload.get("deviceIntegrity", {})
device_verdicts = set(device_integrity.get("deviceRecognitionVerdict", []))
if not device_verdicts.intersection(ALLOWED_DEVICE_VERDICTS):
raise HTTPException(status_code=401, detail="Device integrity check failed")


@router.post("/play", tags=["Play Integrity"])
async def verify_play_integrity(payload: PlayIntegrityRequest):
decoded = await _decode_integrity_token(payload.integrity_token)
token_payload = decoded.get("tokenPayloadExternal") or decoded.get("tokenPayload")
if not token_payload:
raise HTTPException(status_code=401, detail="Invalid Play Integrity token")

# expected_hash = hashlib.sha256(payload.user_id.encode("utf-8")).hexdigest()
expected_hash = token_payload.get("requestDetails").get("requestHash") # TODO

_validate_integrity_payload(token_payload, expected_hash)

access_token = issue_mlpa_access_token(payload.user_id)
return {
"access_token": access_token,
"token_type": "Bearer",
"expires_in": env.MLPA_ACCESS_TOKEN_TTL_SECONDS,
}
38 changes: 37 additions & 1 deletion src/mlpa/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import ast
import base64
import json
import time

from fastapi import HTTPException
from fxa.oauth import Client
from jwtoxide import DecodingKey, ValidationOptions, decode
from jwtoxide import DecodingKey, ValidationOptions, decode, encode

from mlpa.core.classes import AssertionAuth, AttestationAuth
from mlpa.core.config import LITELLM_MASTER_AUTH_HEADERS, env
Expand Down Expand Up @@ -160,3 +161,38 @@ def raise_and_log(
else response_text_prefix or GENERIC_UPSTREAM_ERROR
},
)


def extract_user_from_play_integrity_jwt(authorization: str):
token = authorization.removeprefix("Bearer ").split()[0]
try:
payload = decode(
token,
env.MLPA_ACCESS_TOKEN_SECRET,
ValidationOptions(
required_spec_claims={"exp", "iat", "sub"},
iss={"mlpa"},
aud=None,
validate_aud=False,
validate_exp=True,
validate_nbf=False,
verify_signature=True,
algorithms=["HS256"],
),
)
return payload["sub"]
except Exception as e:
logger.error(f"Play Integrity JWT decode error: {e}")
raise HTTPException(status_code=401, detail="Invalid MLPA access token")


def issue_mlpa_access_token(user_id: str) -> str:
now = int(time.time())
payload = {
"sub": user_id,
"iat": now,
"exp": now + env.MLPA_ACCESS_TOKEN_TTL_SECONDS,
"iss": "mlpa",
"typ": "mlpa_access",
}
return encode(payload, env.MLPA_ACCESS_TOKEN_SECRET, algorithm="HS256")
Loading