diff --git a/README.md b/README.md index f6836411..9e6310e6 100644 --- a/README.md +++ b/README.md @@ -182,6 +182,11 @@ If you have enabled two-factor authentications (2FA) or [two-step authentication (2SA)](https://support.apple.com/en-us/HT204152) for the account you will have to do some extra work: +For HSA2 accounts, `request_2fa_code()` now starts Apple's active delivery +route for code-based challenges. Depending on the account and session, that may +be a trusted-device prompt or an SMS code. Security-key challenges are handled +separately via `security_key_names` / `confirm_security_key()`. + ```python import sys @@ -216,6 +221,7 @@ if api.requires_2fa: else: print("Two-factor authentication required.") + api.request_2fa_code() code = input( "Enter the code you received of one of your approved devices: " ) @@ -1254,8 +1260,6 @@ Notes caveats: - `api.notes.raw` is available for advanced/debug workflows, but it is not the primary Notes API surface. -### Notes CLI Example - [`examples/notes_cli.py`](examples/notes_cli.py) is a local developer utility built on top of `api.notes`. It is useful for searching notes, inspecting the rendering pipeline, and exporting HTML, but its selection heuristics and debug diff --git a/pyicloud/base.py b/pyicloud/base.py index 47897c70..572fdda6 100644 --- a/pyicloud/base.py +++ b/pyicloud/base.py @@ -5,9 +5,10 @@ import json import logging import time +from dataclasses import dataclass from os import chmod, environ, makedirs, path, umask from tempfile import gettempdir -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Mapping, Optional from uuid import uuid1 import srp @@ -34,6 +35,14 @@ PyiCloudPasswordException, PyiCloudServiceNotActivatedException, PyiCloudServiceUnavailable, + PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, +) +from pyicloud.hsa2_bridge import ( + Hsa2BootContext, + TrustedDeviceBridgeBootstrapper, + TrustedDeviceBridgeState, + parse_boot_args_html, ) from pyicloud.services import ( AccountService, @@ -106,6 +115,92 @@ def resolve_cookie_directory(cookie_directory: Optional[str] = None) -> str: return path.join(topdir, getpass.getuser()) +@dataclass(frozen=True) +class TrustedPhoneNumber: + """Typed view of Apple's trusted-phone metadata.""" + + device_id: int | str + non_fteu: Optional[bool] = None + push_mode: Optional[str] = None + + @classmethod + def from_mapping( + cls, value: Optional[Mapping[str, Any]] + ) -> Optional["TrustedPhoneNumber"]: + """Return a typed phone record when Apple's payload includes one.""" + + if not isinstance(value, Mapping): + return None + device_id = value.get("id") + if not isinstance(device_id, (int, str)): + return None + + non_fteu = value.get("nonFTEU") + if not isinstance(non_fteu, bool): + non_fteu = None + + push_mode = value.get("pushMode") + if push_mode is not None: + push_mode = str(push_mode) + + return cls( + device_id=device_id, + non_fteu=non_fteu, + push_mode=push_mode, + ) + + def as_phone_number_payload(self) -> dict[str, Any]: + """Return the nested phoneNumber payload expected by Apple's SMS endpoints.""" + + payload: dict[str, Any] = {"id": self.device_id} + if self.non_fteu is not None: + payload["nonFTEU"] = self.non_fteu + return payload + + +@dataclass(frozen=True) +class PhoneNumberVerification: + """Typed view of Apple's phone verification wrapper payload.""" + + trusted_phone_number: Optional[TrustedPhoneNumber] = None + trusted_phone_numbers: tuple[TrustedPhoneNumber, ...] = () + + @classmethod + def from_mapping( + cls, value: Optional[Mapping[str, Any]] + ) -> "PhoneNumberVerification": + """Return the parsed phone verification payload when Apple exposes one.""" + + if not isinstance(value, Mapping): + return cls() + + trusted_phone_number = TrustedPhoneNumber.from_mapping( + value.get("trustedPhoneNumber") + ) + + trusted_phone_numbers_raw = value.get("trustedPhoneNumbers") + trusted_phone_numbers: list[TrustedPhoneNumber] = [] + if isinstance(trusted_phone_numbers_raw, list): + for entry in trusted_phone_numbers_raw: + phone_number = TrustedPhoneNumber.from_mapping(entry) + if phone_number is not None: + trusted_phone_numbers.append(phone_number) + + return cls( + trusted_phone_number=trusted_phone_number, + trusted_phone_numbers=tuple(trusted_phone_numbers), + ) + + def best_trusted_phone_number(self) -> Optional[TrustedPhoneNumber]: + """Return the first usable trusted phone number from Apple's payload.""" + + if self.trusted_phone_number is not None: + return self.trusted_phone_number + if self.trusted_phone_numbers: + return self.trusted_phone_numbers[0] + return None + + class PyiCloudService: """ A base authentication class for the iCloud service. Handles the @@ -157,6 +252,7 @@ def __init__( authenticate: bool = True, cloudkit_validation_extra: Optional[CloudKitExtraMode] = None, ) -> None: + """Initialize a service session for one Apple ID account.""" self._is_china_mainland: bool = ( environ.get("icloud_china", "0") == "1" if china_mainland is None @@ -175,6 +271,11 @@ def __init__( self.data: dict[str, Any] = {} self._auth_data: dict[str, Any] = {} + self._hsa2_boot_context: Optional[Hsa2BootContext] = None + self._trusted_device_bridge_state: Optional[TrustedDeviceBridgeState] = None + self._trusted_device_bridge = TrustedDeviceBridgeBootstrapper() + self._two_factor_delivery_method: str = "unknown" + self._two_factor_delivery_notice: Optional[str] = None self.params: dict[str, Any] = {} self._client_id: str = client_id or str(uuid1()).lower() @@ -326,6 +427,10 @@ def _clear_authenticated_state(self) -> None: self.data = {} self._auth_data = {} + self._hsa2_boot_context = None + self._clear_trusted_device_bridge_state() + self._two_factor_delivery_method = "unknown" + self._two_factor_delivery_notice = None self._webservices = None self._account = None self._calendar = None @@ -339,6 +444,13 @@ def _clear_authenticated_state(self) -> None: self._requires_mfa = False self.params.pop("dsid", None) + def _clear_trusted_device_bridge_state(self) -> None: + """Close any active trusted-device bridge session and clear in-memory state.""" + + if self._trusted_device_bridge_state is not None: + self._trusted_device_bridge.close(self._trusted_device_bridge_state) + self._trusted_device_bridge_state = None + def get_auth_status(self) -> dict[str, Any]: """Probe current authentication state without prompting for login.""" @@ -416,6 +528,7 @@ def logout( } def _authenticate(self) -> None: + """Authenticate with either the cached session token or fresh credentials.""" LOGGER.debug("Authenticating as %s", self.account_name) try: @@ -542,6 +655,11 @@ def _authenticate_with_token(self) -> None: if not self.is_trusted_session: raise PyiCloud2FARequiredException(self.account_name, resp) + + self._auth_data = {} + self._hsa2_boot_context = None + self._clear_trusted_device_bridge_state() + self._set_two_factor_delivery_state("unknown") except (PyiCloudAPIResponseException, HTTPError) as error: msg = "Invalid authentication token." raise PyiCloudFailedLoginException(msg, error) from error @@ -584,6 +702,7 @@ def _validate_token(self) -> Any: def _get_auth_headers( self, overrides: Optional[dict[str, Any]] = None ) -> dict[str, Any]: + """Build Apple auth headers for IDMS, bridge, and verification requests.""" headers: dict[str, Any] = _AUTH_HEADERS_JSON.copy() headers.update( { @@ -612,6 +731,7 @@ def session(self) -> PyiCloudSession: return self._session def _is_mfa_required(self) -> bool: + """Return whether the current auth state still requires MFA completion.""" return ( self.data.get("hsaChallengeRequired", False) or not self.is_trusted_session @@ -679,9 +799,111 @@ def validate_verification_code(self, device: dict[str, Any], code: str) -> bool: def _get_mfa_auth_options(self) -> Dict: """Retrieve auth request options for assertion.""" + # Apple exposes the HSA2 bridge bootstrap in the HTML auth shell. + # Requesting JSON here tends to collapse the response to the SMS-oriented shape. + headers = self._get_auth_headers({"Accept": "text/html"}) + response = self.session.get(self._auth_endpoint, headers=headers) + + auth_options: dict[str, Any] = {} + try: + response_json = response.json() + except (AttributeError, TypeError, ValueError): + response_json = None + + if isinstance(response_json, dict): + auth_options.update(response_json) + boot_context = Hsa2BootContext.from_auth_options(auth_options) + else: + boot_context = parse_boot_args_html(getattr(response, "text", "")) + + boot_auth_data = boot_context.as_auth_data() + auth_options.update(boot_auth_data) + self._hsa2_boot_context = boot_context + self._clear_trusted_device_bridge_state() + self._set_two_factor_delivery_state("unknown") + return auth_options + + def _set_two_factor_delivery_state( + self, method: str, notice: Optional[str] = None + ) -> None: + """Track the active MFA delivery route for the current auth challenge.""" + + self._two_factor_delivery_method = method + self._two_factor_delivery_notice = notice + + def _current_hsa2_boot_context(self) -> Hsa2BootContext: + """Return the best available HSA2 boot context for the active challenge.""" + + if self._hsa2_boot_context is not None: + return self._hsa2_boot_context + + boot_context = Hsa2BootContext.from_auth_options(self._auth_data) + self._hsa2_boot_context = boot_context + return boot_context + + def _supports_trusted_device_bridge(self) -> bool: + """Return whether Apple's HSA2 boot data prefers the bridge flow.""" + + boot_context = self._current_hsa2_boot_context() + return ( + boot_context.auth_initial_route == "auth/bridge/step" + and boot_context.has_trusted_devices + and bool(boot_context.bridge_initiate_data) + ) + + def _can_request_sms_2fa_code(self) -> bool: + """Return whether SMS delivery is currently available.""" + + return ( + self._two_factor_mode() == "sms" + and self._trusted_phone_number() is not None + ) + + def _request_sms_2fa_code(self, notice: Optional[str] = None) -> bool: + """Trigger SMS delivery for the current HSA2 challenge.""" + + trusted_phone_number = self._trusted_phone_number() + if not trusted_phone_number: + raise PyiCloudNoTrustedNumberAvailable() + + data: dict[str, Any] = { + "phoneNumber": trusted_phone_number.as_phone_number_payload(), + "mode": "sms", + } headers = self._get_auth_headers({"Accept": CONTENT_TYPE_JSON}) - return self.session.get(self._auth_endpoint, headers=headers).json() + self.session.put( + f"{self._auth_endpoint}/verify/phone", + json=data, + headers=headers, + ) + self._clear_trusted_device_bridge_state() + self._set_two_factor_delivery_state("sms", notice) + return True + + @property + def two_factor_delivery_method(self) -> str: + """Return the current HSA2 delivery method without exposing auth internals.""" + + if self._two_factor_delivery_method != "unknown": + return self._two_factor_delivery_method + + if self._auth_data.get("fsaChallenge") or self.security_key_names: + return "security_key" + + if self._supports_trusted_device_bridge(): + return "trusted_device" + + if self._two_factor_mode() == "sms": + return "sms" + + return "unknown" + + @property + def two_factor_delivery_notice(self) -> Optional[str]: + """Return an optional user-facing note about the active 2FA delivery path.""" + + return self._two_factor_delivery_notice @property def security_key_names(self) -> Optional[List[str]]: @@ -696,6 +918,74 @@ def _submit_webauthn_assertion_response(self, data: Dict) -> None: f"{self._auth_endpoint}/verify/security/key", json=data, headers=headers ) + def _phone_number_verification(self) -> PhoneNumberVerification: + """Return Apple's nested phone verification payload when present.""" + + phone_verification = self._auth_data.get("phoneNumberVerification") + return PhoneNumberVerification.from_mapping(phone_verification) + + def _trusted_phone_number(self) -> Optional[TrustedPhoneNumber]: + """Return the best available trusted phone number description.""" + + trusted_phone_number = TrustedPhoneNumber.from_mapping( + self._auth_data.get("trustedPhoneNumber") + ) + if trusted_phone_number is not None: + return trusted_phone_number + + return self._phone_number_verification().best_trusted_phone_number() + + def _two_factor_mode(self) -> Optional[str]: + """Return the current 2FA delivery mode reported by Apple.""" + + mode = self._auth_data.get("mode") + if isinstance(mode, str): + return mode + + trusted_phone_number = self._trusted_phone_number() + if trusted_phone_number is None: + return None + + return trusted_phone_number.push_mode + + def request_2fa_code(self) -> bool: + """Trigger the active HSA2 delivery route for the current challenge.""" + + if self._auth_data.get("fsaChallenge") or self.security_key_names: + self._set_two_factor_delivery_state("security_key") + return False + + self._clear_trusted_device_bridge_state() + + if self._supports_trusted_device_bridge(): + try: + self._trusted_device_bridge_state = self._trusted_device_bridge.start( + session=self.session, + auth_endpoint=self._auth_endpoint, + headers=self._get_auth_headers({"Accept": CONTENT_TYPE_JSON}), + boot_context=self._current_hsa2_boot_context(), + user_agent=self.session.headers.get( + "User-Agent", _HEADERS["User-Agent"] + ), + ) + self._set_two_factor_delivery_state("trusted_device") + return True + except PyiCloudTrustedDevicePromptException: + LOGGER.debug( + "Trusted-device bridge bootstrap failed; falling back to SMS when available.", + exc_info=True, + ) + if self._can_request_sms_2fa_code(): + return self._request_sms_2fa_code( + notice="Trusted-device prompt failed; falling back to SMS." + ) + raise + + if self._can_request_sms_2fa_code(): + return self._request_sms_2fa_code() + + return False + @property def fido2_devices(self) -> List[CtapHidDevice]: """List the available FIDO2 devices.""" @@ -831,45 +1121,61 @@ def _request_pcs_for_service(self, app_name: str) -> None: def validate_2fa_code(self, code: str) -> bool: """Verifies a verification code received via Apple's 2FA system (HSA2).""" + bridge_state = self._trusted_device_bridge_state try: - if self._auth_data.get("mode") == "sms": + if self.two_factor_delivery_method == "sms": self._validate_sms_code(code) + elif ( + bridge_state is not None + and not bridge_state.uses_legacy_trusted_device_verifier + ): + if not self._trusted_device_bridge.validate_code( + session=self.session, + auth_endpoint=self._auth_endpoint, + headers=self._get_auth_headers({"Accept": CONTENT_TYPE_JSON}), + bridge_state=bridge_state, + code=code, + ): + LOGGER.error("Code verification failed.") + return False else: - data: dict[str, Any] = {"securityCode": {"code": code}} - headers: dict[str, Any] = self._get_auth_headers( - {"Accept": CONTENT_TYPE_JSON} - ) - self.session.post( - f"{self._auth_endpoint}/verify/trusteddevice/securitycode", - json=data, - headers=headers, - ) + self._validate_trusted_device_code(code) + except PyiCloudTrustedDeviceVerificationException: + raise except PyiCloudAPIResponseException: # Wrong verification code LOGGER.error("Code verification failed.") return False + finally: + if bridge_state is not None: + self._clear_trusted_device_bridge_state() LOGGER.debug("Code verification successful.") self.trust_session() return not self.requires_2sa + def _validate_trusted_device_code(self, code: str) -> None: + """Verifies a verification code received via Apple's legacy device endpoint.""" + + data: dict[str, Any] = {"securityCode": {"code": code}} + headers: dict[str, Any] = self._get_auth_headers({"Accept": CONTENT_TYPE_JSON}) + self.session.post( + f"{self._auth_endpoint}/verify/trusteddevice/securitycode", + json=data, + headers=headers, + ) + def _validate_sms_code(self, code: str) -> None: """Verifies a verification code received via Apple's SMS system.""" - trusted_phone_number: dict[str, Any] | None = self._auth_data.get( - "trustedPhoneNumber" - ) + trusted_phone_number = self._trusted_phone_number() if not trusted_phone_number: raise PyiCloudNoTrustedNumberAvailable() - device_id: int | None = trusted_phone_number.get("id") - non_fteu: bool | None = trusted_phone_number.get("nonFTEU") - mode: str | None = trusted_phone_number.get("pushMode") - data: dict[str, Any] = { - "phoneNumber": {"id": device_id, "nonFTEU": non_fteu}, + "phoneNumber": trusted_phone_number.as_phone_number_payload(), "securityCode": {"code": code}, - "mode": mode, + "mode": trusted_phone_number.push_mode or "sms", } headers: dict[str, Any] = self._get_auth_headers( {"Accept": f"{CONTENT_TYPE_JSON}, {CONTENT_TYPE_TEXT}"} @@ -1112,7 +1418,9 @@ def account_name(self) -> str: return self._apple_id def __str__(self) -> str: + """Return a concise human-readable service description.""" return f"iCloud API: {self.account_name}" def __repr__(self) -> str: + """Mirror ``__str__`` for interactive inspection.""" return f"<{self}>" diff --git a/pyicloud/cli/context.py b/pyicloud/cli/context.py index ab22fccb..84d29791 100644 --- a/pyicloud/cli/context.py +++ b/pyicloud/cli/context.py @@ -17,9 +17,13 @@ from pyicloud import PyiCloudService, utils from pyicloud.base import resolve_cookie_directory from pyicloud.exceptions import ( + PyiCloudAPIResponseException, PyiCloudAuthRequiredException, PyiCloudFailedLoginException, + PyiCloudNoTrustedNumberAvailable, PyiCloudServiceUnavailable, + PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, ) from pyicloud.ssl_context import configurable_ssl_verification @@ -95,6 +99,7 @@ def __init__( log_level: LogLevel, output_format: OutputFormat, ) -> None: + """Capture the CLI options and shared runtime state for one invocation.""" self.username = (username or "").strip() self.password = password self.china_mainland = china_mainland @@ -231,6 +236,7 @@ def remember_account(self, api: PyiCloudService, *, select: bool = True) -> None self._resolved_username = api.account_name def _resolve_username(self) -> str: + """Resolve the Apple ID to use for the current CLI command.""" if self._resolved_username: return self._resolved_username @@ -276,6 +282,7 @@ def multiple_logged_in_accounts_message(usernames: list[str]) -> str: ) def _password_for_login(self, username: str) -> tuple[Optional[str], Optional[str]]: + """Return the password and its source for an interactive login flow.""" if self.password: return self.password, "explicit" @@ -289,6 +296,7 @@ def _password_for_login(self, username: str) -> tuple[Optional[str], Optional[st return utils.get_password(username, interactive=True), "prompt" def _configure_logging(self) -> None: + """Apply the requested log level once for the current CLI process.""" if self._logging_configured: return logging.basicConfig(level=self.log_level.logging_level()) @@ -302,6 +310,7 @@ def _stored_password_for_session(self, username: str) -> Optional[str]: return utils.get_password_from_keyring(username) def _prompt_index(self, prompt: str, count: int) -> int: + """Prompt for a zero-based selection index when multiple choices exist.""" if count <= 1 or not self.interactive: return 0 raw = typer.prompt(prompt, default="0") @@ -314,6 +323,7 @@ def _prompt_index(self, prompt: str, count: int) -> int: return idx def _handle_2fa(self, api: PyiCloudService) -> None: + """Complete Apple's HSA2 flow using a security key or code-based challenge.""" fido2_devices = list(getattr(api, "fido2_devices", []) or []) if fido2_devices: self.console.print("Security key verification required.") @@ -332,13 +342,56 @@ def _handle_2fa(self, api: PyiCloudService) -> None: raise CLIAbort( "Two-factor authentication is required, but interactive prompts are disabled." ) - code = typer.prompt("Enter 2FA code") - if not api.validate_2fa_code(code): - raise CLIAbort("Failed to verify the 2FA code.") + try: + if not api.request_2fa_code(): + raise CLIAbort( + "This 2FA challenge requires a security key. Connect one and retry." + ) + + notice = getattr(api, "two_factor_delivery_notice", None) + if notice: + self.console.print(notice) + + delivery_method = getattr(api, "two_factor_delivery_method", "unknown") + if delivery_method == "trusted_device": + self.console.print( + "Requested a 2FA prompt on your trusted Apple devices." + ) + elif delivery_method == "sms": + self.console.print("Requested a 2FA code by SMS.") + except PyiCloudNoTrustedNumberAvailable as exc: + raise CLIAbort( + "Two-factor authentication requires a trusted phone number, " + "but none was returned." + ) from exc + except PyiCloudTrustedDevicePromptException as exc: + raise CLIAbort( + "Failed to request the 2FA trusted-device prompt." + ) from exc + except PyiCloudAPIResponseException as exc: + raise CLIAbort("Failed to request the 2FA SMS code.") from exc + max_attempts = 3 + for attempt in range(max_attempts): + code = typer.prompt("Enter 2FA code") + try: + is_valid = api.validate_2fa_code(code) + except PyiCloudTrustedDeviceVerificationException as exc: + raise CLIAbort( + "Failed to verify the 2FA trusted-device code." + ) from exc + if is_valid: + break + remaining_attempts = max_attempts - attempt - 1 + if remaining_attempts <= 0: + raise CLIAbort("Failed to verify the 2FA code.") + self.console.print( + f"Invalid 2FA code. {remaining_attempts} attempt(s) remaining." + ) if not api.is_trusted_session: api.trust_session() def _handle_2sa(self, api: PyiCloudService) -> None: + """Complete Apple's legacy two-step authentication flow.""" devices = list(api.trusted_devices or []) if not devices: raise CLIAbort( diff --git a/pyicloud/exceptions.py b/pyicloud/exceptions.py index 57f2a7de..b0772743 100644 --- a/pyicloud/exceptions.py +++ b/pyicloud/exceptions.py @@ -31,6 +31,7 @@ def __init__( code: Optional[Union[int, str]] = None, response: Optional[Response] = None, ) -> None: + """Capture a normalized API error and the optional HTTP context.""" self.reason: str = reason self.code: Optional[Union[int, str]] = code self.response: Optional[Response] = response @@ -58,6 +59,7 @@ def __init__( *args, response: Optional[Response] = None, ) -> None: + """Initialize a login failure with optional HTTP response details.""" self.response: Optional[Response] = response message: str = msg or "Failed login to iCloud" if response is not None and response.text: @@ -73,6 +75,7 @@ class PyiCloud2FARequiredException(PyiCloudException): """iCloud 2FA required exception.""" def __init__(self, apple_id: str, response: Response) -> None: + """Initialize a 2FA-required error for an HSA2 login challenge.""" message: str = f"2FA authentication required for account: {apple_id} (HSA2)" super().__init__(message) self.response: Response = response @@ -82,6 +85,7 @@ class PyiCloud2SARequiredException(PyiCloudException): """iCloud 2SA required exception.""" def __init__(self, apple_id: str) -> None: + """Initialize a 2SA-required error for a legacy login challenge.""" message: str = f"Two-step authentication required for account: {apple_id}" super().__init__(message) @@ -90,6 +94,7 @@ class PyiCloudAuthRequiredException(PyiCloudException): """iCloud re-authentication required exception.""" def __init__(self, apple_id: str, response: Response) -> None: + """Initialize a reauthentication-required error with the triggering response.""" message: str = f"Re-authentication required for account: {apple_id}" super().__init__(message) self.response: Response = response @@ -99,6 +104,14 @@ class PyiCloudNoTrustedNumberAvailable(PyiCloudException): """iCloud no trusted number exception.""" +class PyiCloudTrustedDevicePromptException(PyiCloudAPIResponseException): + """Trusted-device prompt bootstrap exception.""" + + +class PyiCloudTrustedDeviceVerificationException(PyiCloudAPIResponseException): + """Trusted-device bridge verification exception.""" + + class PyiCloudNoStoredPasswordAvailableException(PyiCloudException): """iCloud no stored password exception.""" diff --git a/pyicloud/hsa2_bridge.py b/pyicloud/hsa2_bridge.py new file mode 100644 index 00000000..d9b8209d --- /dev/null +++ b/pyicloud/hsa2_bridge.py @@ -0,0 +1,1690 @@ +"""Internal helpers for Apple's HSA2 trusted-device bridge flow.""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import logging +import os +import socket +import ssl +import struct +import time +import uuid +from binascii import Error as BinasciiError +from dataclasses import dataclass, field +from html.parser import HTMLParser +from typing import Any, Callable, Mapping, Optional, Protocol +from urllib.parse import urlparse + +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec +from pydantic import ( + BaseModel, + ConfigDict, + Field, + StrictInt, + StrictStr, + ValidationError, + field_validator, +) + +from pyicloud.exceptions import ( + PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, +) +from pyicloud.hsa2_bridge_prover import TrustedDeviceBridgeProver + +LOGGER = logging.getLogger(__name__) + +BRIDGE_STEP_PATH = "/bridge/step/0" +BRIDGE_STEP_PATH_TEMPLATE = "/bridge/step/{step}" +BRIDGE_CODE_VALIDATE_PATH = "/bridge/code/validate" +NEW_CONNECTION_EXPIRATION_SECONDS = 86400 +OPCODE_BINARY = 0x2 +OPCODE_CLOSE = 0x8 +OPCODE_PING = 0x9 +OPCODE_PONG = 0xA +SERVER_MESSAGE_CONNECTION_RESPONSE = 1 +SERVER_MESSAGE_PUSH = 2 +SERVER_MESSAGE_CHANNEL_SUBSCRIPTION_RESPONSE = 3 +SERVER_MESSAGE_PUSH_ACK = 7 +STATUS_OK = 0 +STATUS_INVALID_NONCE = 2 +BRIDGE_SIGNATURE_PREFIX = b"\x01\x03" +BRIDGE_DONE_DATA_B64 = base64.b64encode(b"done").decode("ascii") +WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" +WEBSOCKET_TIMEOUT_SECONDS = 30.0 +WEBSOCKET_ENVIRONMENT_HOSTS: dict[str, str] = { + "prod": "websocket.push.apple.com", + "sandbox": "websocket.sandbox.push.apple.com", +} +HTTP_STATUS_OK = 200 +HTTP_STATUS_NO_CONTENT = 204 +HTTP_STATUS_CONFLICT = 409 +HTTP_STATUS_PRECONDITION_FAILED = 412 + + +@dataclass(frozen=True) +class Hsa2BootContext: + """Bridge-related HSA2 boot data parsed from Apple's HTML bootstrap.""" + + auth_initial_route: str = "" + has_trusted_devices: bool = False + auth_factors: tuple[str, ...] = () + bridge_initiate_data: dict[str, Any] = field(default_factory=dict) + phone_number_verification: dict[str, Any] = field(default_factory=dict) + source_app_id: Optional[str] = None + + @classmethod + def from_auth_options(cls, auth_options: Mapping[str, Any]) -> "Hsa2BootContext": + """Build a normalized boot context from Apple's auth-options payload.""" + bridge_initiate_data = auth_options.get("bridgeInitiateData") + if not isinstance(bridge_initiate_data, dict): + bridge_initiate_data = {} + + phone_number_verification = auth_options.get("phoneNumberVerification") + if not isinstance(phone_number_verification, dict): + phone_number_verification = bridge_initiate_data.get( + "phoneNumberVerification" + ) + if not isinstance(phone_number_verification, dict): + phone_number_verification = {} + + auth_factors = auth_options.get("authFactors") + if not isinstance(auth_factors, list): + auth_factors = [] + + source_app_id = auth_options.get("sourceAppId") + if source_app_id is not None: + source_app_id = str(source_app_id) + + return cls( + auth_initial_route=str(auth_options.get("authInitialRoute") or ""), + has_trusted_devices=bool(auth_options.get("hasTrustedDevices")), + auth_factors=tuple( + factor for factor in auth_factors if isinstance(factor, str) + ), + bridge_initiate_data=dict(bridge_initiate_data), + phone_number_verification=dict(phone_number_verification), + source_app_id=source_app_id, + ) + + def as_auth_data(self) -> dict[str, Any]: + """Return parsed boot data in the shape expected by the auth flow.""" + + auth_data: dict[str, Any] = { + "authInitialRoute": self.auth_initial_route, + "hasTrustedDevices": self.has_trusted_devices, + "authFactors": list(self.auth_factors), + } + if self.bridge_initiate_data: + auth_data["bridgeInitiateData"] = dict(self.bridge_initiate_data) + if self.phone_number_verification: + auth_data["phoneNumberVerification"] = dict(self.phone_number_verification) + trusted_phone_number = self.phone_number_verification.get( + "trustedPhoneNumber" + ) + if isinstance(trusted_phone_number, dict): + auth_data["trustedPhoneNumber"] = dict(trusted_phone_number) + if self.source_app_id is not None: + auth_data["sourceAppId"] = self.source_app_id + return auth_data + + +class _BridgePushPayloadModel(BaseModel): + """Strict validator for Apple's bridge push JSON envelope.""" + + model_config = ConfigDict( + extra="allow", + populate_by_name=True, + arbitrary_types_allowed=True, + ) + + session_uuid: StrictStr = Field(alias="sessionUUID") + next_step: Optional[StrictStr | StrictInt] = Field(default=None, alias="nextStep") + rui_url_key: Optional[str] = Field(default=None, alias="ruiURLKey") + txnid: Optional[StrictStr] = None + salt: Optional[StrictStr] = None + mid: Optional[StrictStr] = None + idmsdata: Optional[StrictStr] = None + akdata: Any = None + data: Optional[StrictStr] = None + encrypted_code: Optional[StrictStr] = Field(default=None, alias="encryptedCode") + error_code: Optional[StrictInt] = Field(default=None, alias="ec") + + @field_validator("session_uuid") + @classmethod + def _validate_session_uuid(cls, value: str) -> str: + """Reject blank bridge session identifiers.""" + if not value.strip(): + raise ValueError("sessionUUID must not be blank") + return value + + @field_validator( + "txnid", + "salt", + "mid", + "idmsdata", + "data", + "encrypted_code", + ) + @classmethod + def _validate_optional_non_empty_strings( + cls, value: Optional[str] + ) -> Optional[str]: + """Reject present-but-blank optional bridge string fields.""" + if value is not None and not value.strip(): + raise ValueError("Bridge payload strings must not be blank") + return value + + @field_validator("next_step") + @classmethod + def _validate_next_step(cls, value: Optional[str | int]) -> Optional[str | int]: + """Reject blank next-step markers while allowing ints or strings.""" + if isinstance(value, str) and not value.strip(): + raise ValueError("nextStep must not be blank") + return value + + +@dataclass(frozen=True) +class BridgePushPayload: + """Decoded bridge push metadata needed to bootstrap trusted-device prompts.""" + + payload: dict[str, Any] + session_uuid: str + next_step: Optional[str] = None + rui_url_key: Optional[str] = None + txnid: Optional[str] = None + salt: Optional[str] = None + mid: Optional[str] = None + idmsdata: Optional[str] = None + akdata: Any = None + data: Optional[str] = None + encrypted_code: Optional[str] = None + error_code: Optional[int] = None + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> "BridgePushPayload": + """Validate and normalize one decoded bridge push payload.""" + try: + validated = _BridgePushPayloadModel.model_validate(payload) + except ValidationError as exc: + raise PyiCloudTrustedDevicePromptException( + "Malformed trusted-device bridge push payload." + ) from exc + + if not validated.session_uuid: + raise PyiCloudTrustedDevicePromptException( + "Trusted-device bridge push payload is missing sessionUUID." + ) + + return cls( + payload=payload, + session_uuid=validated.session_uuid, + next_step=( + str(validated.next_step) if validated.next_step is not None else None + ), + rui_url_key=validated.rui_url_key, + txnid=validated.txnid, + salt=validated.salt, + mid=validated.mid, + idmsdata=validated.idmsdata, + akdata=validated.akdata, + data=validated.data, + encrypted_code=validated.encrypted_code, + error_code=validated.error_code, + ) + + +@dataclass +class TrustedDeviceBridgeState: + """Ephemeral trusted-device bridge state.""" + + connection_path: str + push_token: str + session_uuid: str + websocket: Optional[_WebSocketLike] + topic: str + topics_by_hash: dict[str, str] + source_app_id: Optional[str] = None + next_step: Optional[str] = None + rui_url_key: Optional[str] = None + push_payload: dict[str, Any] = field(default_factory=dict) + txnid: Optional[str] = None + salt: Optional[str] = None + mid: Optional[str] = None + idmsdata: Optional[str] = None + akdata: Any = None + data: Optional[str] = None + encrypted_code: Optional[str] = None + error_code: Optional[int] = None + + def apply_push_payload(self, push_payload: BridgePushPayload) -> None: + """Persist the latest bridge push metadata in the live bridge session.""" + + self.push_payload = dict(push_payload.payload) + self.session_uuid = push_payload.session_uuid + self.next_step = push_payload.next_step + self.rui_url_key = push_payload.rui_url_key + self.txnid = push_payload.txnid + self.salt = push_payload.salt + self.mid = push_payload.mid + self.idmsdata = push_payload.idmsdata + self.akdata = push_payload.akdata + self.data = push_payload.data + self.encrypted_code = push_payload.encrypted_code + self.error_code = push_payload.error_code + + @property + def uses_legacy_trusted_device_verifier(self) -> bool: + """Return whether Apple routed this bridge challenge to the legacy verifier.""" + + return bool(self.txnid and self.txnid.endswith("_W")) + + +@dataclass(frozen=True) +class BridgeStepRequest: + """Typed request body for Apple's bridge step endpoints.""" + + session_uuid: str + data: str + push_token: str + next_step: int + idmsdata: Optional[str] = None + akdata: Any = None + + def as_json(self) -> dict[str, Any]: + """Serialize the step request into Apple's JSON envelope.""" + payload: dict[str, Any] = { + "sessionUUID": self.session_uuid, + "data": self.data, + "ptkn": self.push_token, + "nextStep": self.next_step, + } + if self.idmsdata is not None: + payload["idmsdata"] = self.idmsdata + if self.akdata is not None: + payload["akdata"] = ( + json.dumps(self.akdata, separators=(",", ":")) + if isinstance(self.akdata, dict) + else self.akdata + ) + return payload + + +@dataclass(frozen=True) +class BridgeCodeValidateRequest: + """Typed request body for Apple's final bridge code validation endpoint.""" + + session_uuid: str + code: str + + def as_json(self) -> dict[str, str]: + """Serialize the final bridge code-validation request body.""" + return { + "sessionUUID": self.session_uuid, + "code": self.code, + } + + +@dataclass(frozen=True) +class _ConnectionResponse: + """Decoded server response for the initial websocket bootstrap.""" + + push_token_b64: str = "" + status: int = 0 + server_timestamp_seconds: Optional[int] = None + + +@dataclass(frozen=True) +class _PushMessage: + """Decoded APNS-style push frame from the bridge websocket.""" + + topic: bytes + message_id: int + payload: bytes + + +@dataclass(frozen=True) +class _ChannelSubscriptionResponse: + """Decoded response to the bridge topic subscription request.""" + + message_id: int = 0 + status: int = 0 + retry_interval_seconds: int = 0 + topics: tuple[str, ...] = () + + +@dataclass(frozen=True) +class _AcknowledgementMessage: + """Decoded acknowledgment frame emitted by Apple's bridge service.""" + + topic: bytes + message_id: int + delivery_status: int = 0 + + +@dataclass(frozen=True) +class _ServerMessage: + """One websocket frame decoded into its known top-level message variants.""" + + connection_response: Optional[_ConnectionResponse] = None + push_message: Optional[_PushMessage] = None + channel_subscription_response: Optional[_ChannelSubscriptionResponse] = None + push_acknowledgment: Optional[_AcknowledgementMessage] = None + field_numbers: tuple[int, ...] = () + + +class _WebSocketLike(Protocol): + """Protocol for the minimal websocket operations used by the bridge flow.""" + + def send_binary(self, payload: bytes) -> None: + """Send one binary websocket message.""" + + def read_message(self) -> bytes: + """Read one complete websocket message payload.""" + + def close(self) -> None: + """Close the websocket transport.""" + + +class _InvalidNonceError(Exception): + """Signal Apple's INVALID_NONCE response along with the server timestamp.""" + + def __init__(self, server_timestamp_ms: int) -> None: + """Capture the server timestamp returned with INVALID_NONCE.""" + super().__init__("Invalid nonce from bridge server.") + self.server_timestamp_ms = server_timestamp_ms + + +class _BootArgsHTMLParser(HTMLParser): + """Extract the JSON body from Apple's boot_args script tag.""" + + def __init__(self) -> None: + """Initialize parser state for the first matching boot_args script tag.""" + super().__init__() + self._collecting = False + self._found = False + self._chunks: list[str] = [] + + @property + def payload(self) -> str: + """Return the collected boot_args JSON text.""" + return "".join(self._chunks).strip() + + def handle_starttag(self, tag: str, attrs: list[tuple[str, Optional[str]]]) -> None: + """Start collecting data when the boot_args script tag is found.""" + if tag != "script" or self._found: + return + attr_map = {key: value for key, value in attrs} + classes = (attr_map.get("class") or "").split() + if "boot_args" in classes: + self._collecting = True + self._found = True + + def handle_endtag(self, tag: str) -> None: + """Stop collecting when the current script tag closes.""" + if tag == "script" and self._collecting: + self._collecting = False + + def handle_data(self, data: str) -> None: + """Append script contents while the boot_args tag is active.""" + if self._collecting: + self._chunks.append(data) + + +def parse_boot_args_html(html_text: str) -> Hsa2BootContext: + """Extract HSA2 boot args from the HTML returned by GET /appleauth/auth.""" + + parser = _BootArgsHTMLParser() + parser.feed(html_text) + parser.close() + + payload_text = parser.payload + if not payload_text: + raise PyiCloudTrustedDevicePromptException("Missing HSA2 boot args payload.") + + try: + payload = json.loads(payload_text) + except json.JSONDecodeError as exc: + raise PyiCloudTrustedDevicePromptException( + "Malformed HSA2 boot args payload." + ) from exc + direct = payload.get("direct") + if not isinstance(direct, dict): + raise PyiCloudTrustedDevicePromptException("Missing HSA2 direct boot data.") + + two_sv = direct.get("twoSV") + if not isinstance(two_sv, dict): + two_sv = {} + + bridge_initiate_data = two_sv.get("bridgeInitiateData") + if not isinstance(bridge_initiate_data, dict): + bridge_initiate_data = {} + + phone_number_verification = bridge_initiate_data.get("phoneNumberVerification") + if not isinstance(phone_number_verification, dict): + phone_number_verification = {} + + auth_factors = two_sv.get("authFactors") + if not isinstance(auth_factors, list): + auth_factors = [] + + source_app_id = two_sv.get("sourceAppId") + if source_app_id is not None: + source_app_id = str(source_app_id) + + return Hsa2BootContext( + auth_initial_route=str(direct.get("authInitialRoute") or ""), + has_trusted_devices=bool(direct.get("hasTrustedDevices")), + auth_factors=tuple( + factor for factor in auth_factors if isinstance(factor, str) + ), + bridge_initiate_data=dict(bridge_initiate_data), + phone_number_verification=dict(phone_number_verification), + source_app_id=source_app_id, + ) + + +def _encode_varint(value: int) -> bytes: + """Encode an unsigned protobuf varint.""" + if value < 0: + raise ValueError("Negative varints are not supported.") + parts = bytearray() + while True: + to_write = value & 0x7F + value >>= 7 + if value: + parts.append(to_write | 0x80) + else: + parts.append(to_write) + return bytes(parts) + + +def _read_varint(data: bytes, offset: int) -> tuple[int, int]: + """Decode one protobuf varint from a byte string and return the new offset.""" + value = 0 + shift = 0 + start_offset = offset + while True: + if offset >= len(data): + raise PyiCloudTrustedDevicePromptException("Truncated protobuf varint.") + byte = data[offset] + offset += 1 + value |= (byte & 0x7F) << shift + if not (byte & 0x80): + return value, offset + shift += 7 + # Guard against malformed wire data rather than silently accepting an + # overlong varint from Apple's private bridge protocol. + if shift > 63 or offset - start_offset >= 10: + raise PyiCloudTrustedDevicePromptException("Malformed protobuf varint.") + + +def _encode_field(field_number: int, wire_type: int, value: bytes) -> bytes: + """Encode one protobuf field header and payload.""" + return _encode_varint((field_number << 3) | wire_type) + value + + +def _encode_bytes_field(field_number: int, value: bytes) -> bytes: + """Encode a length-delimited protobuf field.""" + return _encode_field(field_number, 2, _encode_varint(len(value)) + value) + + +def _encode_string_field(field_number: int, value: str) -> bytes: + """Encode a UTF-8 string protobuf field.""" + return _encode_bytes_field(field_number, value.encode("utf-8")) + + +def _encode_uint32_field(field_number: int, value: int) -> bytes: + """Encode an unsigned integer protobuf field.""" + return _encode_field(field_number, 0, _encode_varint(value)) + + +def _decode_fields(data: bytes) -> dict[int, list[Any]]: + """Decode a minimal subset of protobuf wire types into field lists.""" + offset = 0 + fields: dict[int, list[Any]] = {} + while offset < len(data): + key, offset = _read_varint(data, offset) + field_number = key >> 3 + wire_type = key & 0x07 + + if wire_type == 0: + value, offset = _read_varint(data, offset) + elif wire_type == 2: + length, offset = _read_varint(data, offset) + # Length-delimited fields must stay within the current message + # bounds; otherwise the bridge frame is truncated or malformed. + end_offset = offset + length + if end_offset > len(data): + raise PyiCloudTrustedDevicePromptException("Truncated protobuf field.") + value = data[offset:end_offset] + offset = end_offset + else: + raise PyiCloudTrustedDevicePromptException( + f"Unsupported protobuf wire type: {wire_type}" + ) + + fields.setdefault(field_number, []).append(value) + return fields + + +def _decode_connection_response(message: bytes) -> _ConnectionResponse: + """Decode the server's websocket bootstrap response.""" + fields = _decode_fields(message) + push_token_b64 = "" + if fields.get(1): + try: + push_token_b64 = fields[1][0].decode("ascii") + except UnicodeDecodeError as exc: + raise PyiCloudTrustedDevicePromptException( + "Malformed bridge connection response push token." + ) from exc + status = int(fields.get(2, [0])[0]) + server_timestamp_seconds = None + if fields.get(3): + server_timestamp_seconds = int(fields[3][0]) + return _ConnectionResponse( + push_token_b64=push_token_b64, + status=status, + server_timestamp_seconds=server_timestamp_seconds, + ) + + +def _decode_push_message(message: bytes) -> _PushMessage: + """Decode one push-delivery frame from the bridge websocket.""" + fields = _decode_fields(message) + topic = bytes(fields.get(1, [b""])[0]) + message_id = int(fields.get(2, [0])[0]) + payload = bytes(fields.get(4, [b""])[0]) + return _PushMessage(topic=topic, message_id=message_id, payload=payload) + + +def _decode_channel_subscription_response( + message: bytes, +) -> _ChannelSubscriptionResponse: + """Decode the server's response to the topic subscription message.""" + fields = _decode_fields(message) + topics: list[str] = [] + + payload_values = fields.get(1) + if payload_values: + payload_fields = _decode_fields(bytes(payload_values[0])) + for app_response_value in payload_fields.get(1, []): + app_response_fields = _decode_fields(bytes(app_response_value)) + topic_value = app_response_fields.get(1, [b""])[0] + if isinstance(topic_value, bytes): + topics.append(topic_value.decode("utf-8", "ignore")) + + return _ChannelSubscriptionResponse( + message_id=int(fields.get(2, [0])[0]), + status=int(fields.get(3, [0])[0]), + retry_interval_seconds=int(fields.get(4, [0])[0]), + topics=tuple(topic for topic in topics if topic), + ) + + +def _decode_acknowledgement_message(message: bytes) -> _AcknowledgementMessage: + """Decode a push acknowledgment frame from the bridge websocket.""" + fields = _decode_fields(message) + topic = bytes(fields.get(1, [b""])[0]) + message_id = int(fields.get(2, [0])[0]) + delivery_status = int(fields.get(3, [0])[0]) + return _AcknowledgementMessage( + topic=topic, + message_id=message_id, + delivery_status=delivery_status, + ) + + +def _decode_server_message(message: bytes) -> _ServerMessage: + """Decode all known top-level messages embedded in one websocket frame.""" + fields = _decode_fields(message) + + connection_response = None + if fields.get(SERVER_MESSAGE_CONNECTION_RESPONSE): + connection_response = _decode_connection_response( + bytes(fields[SERVER_MESSAGE_CONNECTION_RESPONSE][0]) + ) + + push_message = None + if fields.get(SERVER_MESSAGE_PUSH): + push_message = _decode_push_message(bytes(fields[SERVER_MESSAGE_PUSH][0])) + + channel_subscription_response = None + if fields.get(SERVER_MESSAGE_CHANNEL_SUBSCRIPTION_RESPONSE): + channel_subscription_response = _decode_channel_subscription_response( + bytes(fields[SERVER_MESSAGE_CHANNEL_SUBSCRIPTION_RESPONSE][0]) + ) + + push_acknowledgment = None + if fields.get(SERVER_MESSAGE_PUSH_ACK): + push_acknowledgment = _decode_acknowledgement_message( + bytes(fields[SERVER_MESSAGE_PUSH_ACK][0]) + ) + + return _ServerMessage( + connection_response=connection_response, + push_message=push_message, + channel_subscription_response=channel_subscription_response, + push_acknowledgment=push_acknowledgment, + field_numbers=tuple(sorted(fields)), + ) + + +def _encode_connection_message( + public_key: bytes, nonce: bytes, signature: bytes +) -> bytes: + """Encode the initial bridge websocket bootstrap message.""" + connection_message = b"".join( + [ + _encode_bytes_field(1, public_key), + _encode_bytes_field(2, nonce), + _encode_bytes_field(3, _encode_bridge_signature(signature)), + _encode_bytes_field( + 5, _encode_uint32_field(1, NEW_CONNECTION_EXPIRATION_SECONDS) + ), + ] + ) + return _encode_bytes_field(1, connection_message) + + +def _encode_bridge_signature(signature: bytes) -> bytes: + """Wrap the DER ECDSA signature using Apple's bridge signature envelope.""" + + if signature.startswith(BRIDGE_SIGNATURE_PREFIX): + return signature + return BRIDGE_SIGNATURE_PREFIX + signature + + +def _encode_web_filter_message(allowed_topics: list[str]) -> bytes: + """Encode the topic subscription message sent after bridge connect.""" + filter_payload = b"".join( + _encode_string_field(1, topic) for topic in allowed_topics + ) + return _encode_bytes_field(3, filter_payload) + + +def _encode_ack_message(topic: bytes, message_id: int) -> bytes: + """Encode the acknowledgment frame for one delivered push message.""" + ack_payload = b"".join( + [ + _encode_bytes_field(1, topic), + _encode_uint32_field(2, message_id), + ] + ) + return _encode_bytes_field(2, ack_payload) + + +def _topic_hash(topic: str) -> str: + """Return Apple's websocket topic hash for a named APNS topic.""" + return hashlib.sha1(topic.encode("utf-8")).hexdigest() + + +def _topic_name(topic_bytes: bytes, topics_by_hash: Mapping[str, str]) -> str: + """Resolve a hashed topic payload back to a readable topic name.""" + return topics_by_hash.get(topic_bytes.hex(), topic_bytes.decode("utf-8", "ignore")) + + +def _extract_json_payload(payload: bytes) -> dict[str, Any]: + """Extract the JSON object embedded in one bridge push payload.""" + try: + return json.loads(payload.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + text = payload.decode("utf-8", "ignore") + + start = text.find("{") + while start >= 0: + depth = 0 + in_string = False + escaped = False + for index, character in enumerate(text[start:], start=start): + if in_string: + if escaped: + escaped = False + elif character == "\\": + escaped = True + elif character == '"': + in_string = False + continue + if character == '"': + in_string = True + elif character == "{": + depth += 1 + elif character == "}": + depth -= 1 + if depth == 0: + try: + return json.loads(text[start : index + 1]) + except json.JSONDecodeError: + break + start = text.find("{", start + 1) + + raise PyiCloudTrustedDevicePromptException( + "Could not decode the trusted-device bridge push payload." + ) + + +def _b64_to_hex(value: str) -> str: + """Decode base64 bridge data and return it as lowercase hex.""" + try: + return base64.b64decode(value.encode("ascii"), validate=True).hex() + except (ValueError, BinasciiError) as exc: + raise ValueError("Malformed base64-encoded bridge payload.") from exc + + +def _hex_to_b64(value: str) -> str: + """Encode hex bridge data as standard base64 text.""" + return base64.b64encode(bytes.fromhex(value)).decode("ascii") + + +def _build_nonce(timestamp_ms: int) -> bytes: + """Build the nonce format expected by Apple's bridge bootstrap.""" + return b"\x00" + timestamp_ms.to_bytes(8, "big", signed=False) + os.urandom(8) + + +def _summarize_identifier( + value: Optional[str], *, prefix: int = 8, empty: str = "" +) -> str: + """Shorten sensitive identifiers before logging them at debug level.""" + if not value: + return empty + if len(value) <= prefix: + return value + return f"{value[:prefix]}..." + + +def _resolve_websocket_host(boot_context: Hsa2BootContext) -> str: + """Resolve the websocket host Apple expects for the bridge session.""" + bridge_data = boot_context.bridge_initiate_data + web_socket_url = bridge_data.get("webSocketUrl") + if isinstance(web_socket_url, str) and web_socket_url: + if "://" in web_socket_url: + parsed = urlparse(web_socket_url) + if parsed.hostname: + return parsed.hostname + return web_socket_url.split("/", 1)[0] + + environment = bridge_data.get("apnsEnvironment") + if isinstance(environment, str) and environment in WEBSOCKET_ENVIRONMENT_HOSTS: + return WEBSOCKET_ENVIRONMENT_HOSTS[environment] + + raise PyiCloudTrustedDevicePromptException( + "Missing HSA2 websocket host for the trusted-device bridge." + ) + + +def _resolve_apns_topic(boot_context: Hsa2BootContext) -> str: + """Resolve the APNS topic Apple uses for trusted-device pushes.""" + topic = boot_context.bridge_initiate_data.get("apnsTopic") + if isinstance(topic, str) and topic: + return topic + + raise PyiCloudTrustedDevicePromptException( + "Missing HSA2 APNS topic for the trusted-device bridge." + ) + + +def _derive_origin(auth_endpoint: str) -> str: + """Derive the websocket Origin header from the auth endpoint URL.""" + parsed = urlparse(auth_endpoint) + if not parsed.scheme or not parsed.hostname: + raise PyiCloudTrustedDevicePromptException( + "Invalid auth endpoint for trusted-device bridge." + ) + return f"{parsed.scheme}://{parsed.hostname}" + + +class _RawWebSocketClient: + """Minimal websocket client for Apple's webcourier bridge.""" + + def __init__( + self, + url: str, + timeout: float, + origin: str, + user_agent: str, + ) -> None: + """Open a websocket connection and prepare buffered frame reads.""" + self._url = url + self._timeout = timeout + self._origin = origin + self._user_agent = user_agent + self._buffer = bytearray() + self._socket = self._open() + + def _open(self) -> ssl.SSLSocket: + """Perform the websocket HTTP upgrade and return the TLS socket.""" + parsed = urlparse(self._url) + if parsed.scheme != "wss" or not parsed.hostname: + raise PyiCloudTrustedDevicePromptException( + f"Unsupported websocket URL: {self._url}" + ) + + port = parsed.port or 443 + resource = parsed.path or "/" + if parsed.query: + resource = f"{resource}?{parsed.query}" + + raw_socket = socket.create_connection((parsed.hostname, port), self._timeout) + context = ssl.create_default_context() + secure_socket = context.wrap_socket(raw_socket, server_hostname=parsed.hostname) + secure_socket.settimeout(self._timeout) + + websocket_key = base64.b64encode(os.urandom(16)).decode("ascii") + request_headers = [ + f"GET {resource} HTTP/1.1", + f"Host: {parsed.hostname}", + "Upgrade: websocket", + "Connection: Upgrade", + f"Origin: {self._origin}", + f"User-Agent: {self._user_agent}", + "Sec-WebSocket-Version: 13", + f"Sec-WebSocket-Key: {websocket_key}", + "\r\n", + ] + secure_socket.sendall("\r\n".join(request_headers).encode("ascii")) + + response = self._read_http_response(secure_socket) + status_line, _, headers_text = response.partition("\r\n") + if " 101 " not in status_line: + raise PyiCloudTrustedDevicePromptException( + f"Websocket upgrade failed: {status_line}" + ) + + headers: dict[str, str] = {} + for line in headers_text.split("\r\n"): + if not line or ":" not in line: + continue + key, value = line.split(":", 1) + headers[key.strip().lower()] = value.strip() + + expected_accept = base64.b64encode( + hashlib.sha1((websocket_key + WEBSOCKET_GUID).encode("ascii")).digest() + ).decode("ascii") + if headers.get("sec-websocket-accept") != expected_accept: + raise PyiCloudTrustedDevicePromptException( + "Invalid websocket accept header from bridge server." + ) + + return secure_socket + + def _read_http_response(self, sock: ssl.SSLSocket) -> str: + """Read the HTTP upgrade response headers from the websocket socket.""" + while b"\r\n\r\n" not in self._buffer: + chunk = sock.recv(4096) + if not chunk: + raise PyiCloudTrustedDevicePromptException( + "Unexpected EOF during websocket handshake." + ) + self._buffer.extend(chunk) + + marker = self._buffer.find(b"\r\n\r\n") + 4 + data = bytes(self._buffer[:marker]).decode("iso-8859-1") + del self._buffer[:marker] + return data + + def _read_exact(self, size: int) -> bytes: + """Read exactly ``size`` buffered bytes from the websocket socket.""" + while len(self._buffer) < size: + chunk = self._socket.recv(max(4096, size - len(self._buffer))) + if not chunk: + raise PyiCloudTrustedDevicePromptException( + "Unexpected EOF while reading websocket frame." + ) + self._buffer.extend(chunk) + + data = bytes(self._buffer[:size]) + del self._buffer[:size] + return data + + def _send_frame(self, opcode: int, payload: bytes) -> None: + """Send one masked websocket frame to Apple's bridge server.""" + first_byte = 0x80 | opcode + mask_key = os.urandom(4) + length = len(payload) + + header = bytearray([first_byte]) + if length < 126: + header.append(0x80 | length) + elif length < 65536: + header.append(0x80 | 126) + header.extend(struct.pack("!H", length)) + else: + header.append(0x80 | 127) + header.extend(struct.pack("!Q", length)) + + masked_payload = bytes( + byte ^ mask_key[index % 4] for index, byte in enumerate(payload) + ) + self._socket.sendall(bytes(header) + mask_key + masked_payload) + + def send_binary(self, payload: bytes) -> None: + """Send one binary websocket message payload.""" + self._send_frame(OPCODE_BINARY, payload) + + def read_message(self) -> bytes: + """Read one complete websocket message, handling control frames inline.""" + fragments: list[bytes] = [] + opcode: Optional[int] = None + + while True: + first_byte, second_byte = self._read_exact(2) + frame_opcode = first_byte & 0x0F + finished = bool(first_byte & 0x80) + masked = bool(second_byte & 0x80) + payload_length = second_byte & 0x7F + + if payload_length == 126: + payload_length = struct.unpack("!H", self._read_exact(2))[0] + elif payload_length == 127: + payload_length = struct.unpack("!Q", self._read_exact(8))[0] + + mask_key = self._read_exact(4) if masked else b"" + payload = self._read_exact(payload_length) + if masked: + payload = bytes( + byte ^ mask_key[index % 4] for index, byte in enumerate(payload) + ) + + if frame_opcode == OPCODE_CLOSE: + raise PyiCloudTrustedDevicePromptException( + "Bridge websocket closed before delivering a prompt." + ) + if frame_opcode == OPCODE_PING: + self._send_frame(OPCODE_PONG, payload) + continue + if frame_opcode == OPCODE_PONG: + continue + + if frame_opcode != 0: + opcode = frame_opcode + fragments.append(payload) + if finished: + if opcode not in (0x1, OPCODE_BINARY): + raise PyiCloudTrustedDevicePromptException( + f"Unsupported websocket opcode: {opcode}" + ) + return b"".join(fragments) + + def close(self) -> None: + """Attempt a clean websocket close and always close the socket object.""" + if getattr(self, "_socket", None) is None: + return + try: + self._send_frame(OPCODE_CLOSE, b"") + except OSError: + pass + finally: + try: + self._socket.close() + except OSError: + pass + + +class TrustedDeviceBridgeBootstrapper: + """Bootstrap the trusted-device bridge flow captured in Apple's browser client.""" + + def __init__( + self, + *, + timeout: float = WEBSOCKET_TIMEOUT_SECONDS, + websocket_factory: Optional[ + Callable[[str, float, str, str], _WebSocketLike] + ] = None, + prover_factory: Optional[Callable[[], TrustedDeviceBridgeProver]] = None, + ) -> None: + """Configure websocket and prover factories for bridge operations.""" + self.timeout = timeout + self._websocket_factory = websocket_factory or _RawWebSocketClient + self._prover_factory = prover_factory or TrustedDeviceBridgeProver + + def start( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + boot_context: Hsa2BootContext, + user_agent: str, + ) -> TrustedDeviceBridgeState: + """Bootstrap Apple's trusted-device bridge until the first prompt payload arrives.""" + topic = _resolve_apns_topic(boot_context) + websocket_host = _resolve_websocket_host(boot_context) + origin = _derive_origin(auth_endpoint) + topics_by_hash = {_topic_hash(topic): topic} + source_app_id = boot_context.source_app_id + public_key, private_key = self._generate_keypair() + + LOGGER.debug( + "Bootstrapping trusted-device bridge: auth_endpoint=%s websocket_host=%s topic=%s source_app_id=%s", + auth_endpoint, + websocket_host, + topic, + source_app_id, + ) + + timestamp_ms: Optional[int] = None + last_error: Optional[Exception] = None + for _ in range(2): + nonce = _build_nonce(timestamp_ms or int(time.time() * 1000)) + signature = private_key.sign(nonce, ec.ECDSA(hashes.SHA256())) + connection_message = _encode_connection_message( + public_key, nonce, signature + ) + connection_path = connection_message.hex() + websocket_url = f"wss://{websocket_host}/v2/{connection_path}" + LOGGER.debug( + "Opening trusted-device websocket: host=%s bootstrapPayloadLen=%d", + websocket_host, + len(connection_path), + ) + websocket = self._websocket_factory( + websocket_url, + self.timeout, + origin, + user_agent, + ) + keep_websocket_open = False + + try: + push_token = self._wait_for_push_token(websocket) + push_token_hex = push_token.hex() + LOGGER.debug( + "Trusted-device bridge connected; received push token (%d bytes)", + len(push_token), + ) + websocket.send_binary(_encode_web_filter_message([topic])) + LOGGER.debug("Sent trusted-device webFilterMessage for topic=%s", topic) + + session_uuid = self._generate_session_uuid() + bridge_headers = dict(headers) + if source_app_id: + bridge_headers["X-Apple-App-Id"] = source_app_id + + LOGGER.debug( + "Posting trusted-device bridge step 0 with sessionUUID=%s ptknLen=%d", + _summarize_identifier(session_uuid), + len(push_token_hex), + ) + # Apple's browser posts step 0 immediately after obtaining the push + # token. Waiting for the first push before posting step 0 causes the + # bridge flow to stall. + self._post_bridge_step0( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + session_uuid=session_uuid, + push_token=push_token_hex, + ) + + push_payload = self._wait_for_bridge_push( + websocket, topic, topics_by_hash + ) + LOGGER.debug( + "Received trusted-device bridge payload: sessionUUID=%s nextStep=%s ruiURLKey=%s", + _summarize_identifier(push_payload.session_uuid), + push_payload.next_step, + push_payload.rui_url_key, + ) + if push_payload.session_uuid != session_uuid: + raise PyiCloudTrustedDevicePromptException( + "Trusted-device bridge returned a mismatched session UUID." + ) + + bridge_state = TrustedDeviceBridgeState( + connection_path=connection_path, + push_token=push_token_hex, + session_uuid=session_uuid, + websocket=websocket, + topic=topic, + topics_by_hash=dict(topics_by_hash), + source_app_id=source_app_id, + ) + bridge_state.apply_push_payload(push_payload) + keep_websocket_open = True + return bridge_state + except _InvalidNonceError as exc: + timestamp_ms = exc.server_timestamp_ms + last_error = exc + LOGGER.debug( + "Trusted-device bridge received INVALID_NONCE; retrying with server timestamp %s", + timestamp_ms, + ) + except (OSError, socket.timeout, ssl.SSLError) as exc: + last_error = exc + LOGGER.debug( + "Trusted-device websocket transport error during bootstrap.", + exc_info=True, + ) + break + except PyiCloudTrustedDevicePromptException as exc: + last_error = exc + LOGGER.debug( + "Trusted-device bridge bootstrap failed before completion.", + exc_info=True, + ) + break + finally: + if not keep_websocket_open: + websocket.close() + + raise PyiCloudTrustedDevicePromptException( + "Failed to bootstrap the trusted-device bridge prompt." + ) from last_error + + def _generate_keypair(self) -> tuple[bytes, ec.EllipticCurvePrivateKey]: + """Generate the ephemeral P-256 keypair used for websocket bootstrap.""" + private_key = ec.generate_private_key(ec.SECP256R1()) + public_key = private_key.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint, + ) + return public_key, private_key + + def _generate_session_uuid(self) -> str: + """Generate the browser-style bridge session UUID string.""" + return f"{uuid.uuid4()}-{int(time.time())}" + + def close(self, bridge_state: Optional[TrustedDeviceBridgeState]) -> None: + """Close and detach the websocket associated with an active bridge session.""" + + if bridge_state is None: + return + websocket = bridge_state.websocket + bridge_state.websocket = None + if websocket is None: + return + try: + websocket.close() + except OSError: + LOGGER.debug( + "Trusted-device bridge websocket close failed.", + exc_info=True, + ) + + def validate_code( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + bridge_state: TrustedDeviceBridgeState, + code: str, + ) -> bool: + """Run Apple's bridge-specific trusted-device verification flow.""" + + websocket = bridge_state.websocket + if websocket is None: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge session is not active." + ) + if bridge_state.uses_legacy_trusted_device_verifier: + raise PyiCloudTrustedDeviceVerificationException( + "Legacy trusted-device verification should bypass the bridge verifier." + ) + if bridge_state.next_step not in {"2", 2}: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge is not ready for step 2 verification." + ) + if not bridge_state.salt: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge payload is missing the step-2 salt." + ) + + prover = self._prover_factory() + bridge_headers = self._bridge_headers(headers, bridge_state) + + try: + LOGGER.debug( + "Starting trusted-device bridge code verification: sessionUUID=%s nextStep=%s txnid=%s", + _summarize_identifier(bridge_state.session_uuid), + bridge_state.next_step, + _summarize_identifier(bridge_state.txnid, prefix=12), + ) + + prover.init_with_salt(bridge_state.salt, code) + message1 = prover.get_message1() + LOGGER.debug( + "Posting trusted-device bridge step 2 with sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + self._post_bridge_step( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + bridge_state=bridge_state, + next_step=2, + data=_hex_to_b64(message1), + idmsdata=bridge_state.idmsdata, + akdata=bridge_state.akdata, + ) + + step4_payload = self._wait_for_bridge_push( + websocket, + bridge_state.topic, + bridge_state.topics_by_hash, + ) + self._apply_expected_step4_push(bridge_state, step4_payload) + + if not bridge_state.data: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge step 4 payload is missing prover data." + ) + try: + step4_data = base64.b64decode( + bridge_state.data.encode("ascii"), validate=True + ).decode("utf-8") + bridge_message1_b64, bridge_message2_b64 = step4_data.split("_", 1) + bridge_message1_hex = _b64_to_hex(bridge_message1_b64) + bridge_message2_hex = _b64_to_hex(bridge_message2_b64) + except (ValueError, UnicodeDecodeError, BinasciiError) as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge step 4 payload is malformed." + ) from exc + + LOGGER.debug( + "Processing trusted-device bridge step 4 payload for sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + try: + message2 = prover.process_message1(bridge_message1_hex) + except ValueError as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge step 4 payload is malformed." + ) from exc + try: + prover.process_message2(bridge_message2_hex) + except ValueError: + LOGGER.debug( + "Trusted-device bridge prover rejected the step-4 confirmation for sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + return False + + LOGGER.debug( + "Posting trusted-device bridge step 4 with sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + self._post_bridge_step( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + bridge_state=bridge_state, + next_step=4, + data=_hex_to_b64(message2), + idmsdata=bridge_state.idmsdata, + akdata=bridge_state.akdata, + ) + + final_payload = self._wait_for_bridge_push( + websocket, + bridge_state.topic, + bridge_state.topics_by_hash, + ) + self._apply_final_bridge_push(bridge_state, final_payload) + + if not bridge_state.encrypted_code: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge final payload is missing encryptedCode." + ) + + LOGGER.debug( + "Decrypting trusted-device bridge code for sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + try: + derived_code = prover.decrypt_message(bridge_state.encrypted_code) + except ValueError as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Failed to decrypt the trusted-device bridge validation code." + ) from exc + + verify_response = self._post_bridge_code_validate( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + bridge_state=bridge_state, + code=derived_code, + ) + verification_succeeded = ( + verify_response.status_code != HTTP_STATUS_PRECONDITION_FAILED + ) + + completion_step = 6 if bridge_state.next_step in {"6", 6} else 4 + LOGGER.debug( + "Posting trusted-device bridge completion step %s with sessionUUID=%s verifyStatus=%s", + completion_step, + _summarize_identifier(bridge_state.session_uuid), + verify_response.status_code, + ) + self._post_bridge_step( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + bridge_state=bridge_state, + next_step=completion_step, + data=BRIDGE_DONE_DATA_B64, + idmsdata=bridge_state.idmsdata, + akdata=bridge_state.akdata, + ) + return verification_succeeded + except PyiCloudTrustedDevicePromptException as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge verification failed while waiting for the next bridge push." + ) from exc + except (OSError, socket.timeout, ssl.SSLError) as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge verification failed due to a websocket transport error." + ) from exc + finally: + self.close(bridge_state) + + def _wait_for_push_token(self, websocket: _WebSocketLike) -> bytes: + """Wait for the bridge connection response that carries the push token.""" + deadline = time.monotonic() + self.timeout + while time.monotonic() < deadline: + message = websocket.read_message() + server_message = _decode_server_message(message) + connection_response = server_message.connection_response + if connection_response is None: + LOGGER.debug( + "Ignoring non-connection websocket frame while waiting for push token; fields=%s", + server_message.field_numbers, + ) + continue + + if ( + connection_response.status == STATUS_OK + and connection_response.push_token_b64 + ): + try: + return base64.b64decode( + connection_response.push_token_b64.encode("ascii"), + validate=True, + ) + except (ValueError, BinasciiError) as exc: + raise PyiCloudTrustedDevicePromptException( + "Malformed bridge push token." + ) from exc + + if ( + connection_response.status == STATUS_INVALID_NONCE + and connection_response.server_timestamp_seconds is not None + ): + raise _InvalidNonceError( + connection_response.server_timestamp_seconds * 1000 + ) + + LOGGER.debug( + "Trusted-device bridge connection response returned status=%s", + connection_response.status, + ) + raise PyiCloudTrustedDevicePromptException( + f"Bridge server returned status {connection_response.status}." + ) + + raise PyiCloudTrustedDevicePromptException( + "Timed out waiting for the bridge push token." + ) + + def _wait_for_bridge_push( + self, + websocket: _WebSocketLike, + topic: str, + topics_by_hash: Mapping[str, str], + ) -> BridgePushPayload: + """Wait for, acknowledge, and decode the next relevant bridge push.""" + deadline = time.monotonic() + self.timeout + while time.monotonic() < deadline: + message = websocket.read_message() + server_message = _decode_server_message(message) + if server_message.channel_subscription_response is not None: + channel_response = server_message.channel_subscription_response + LOGGER.debug( + "Received channel subscription response during bridge bootstrap: messageId=%s status=%s retryIntervalSeconds=%s topics=%s", + channel_response.message_id, + channel_response.status, + channel_response.retry_interval_seconds, + channel_response.topics, + ) + if channel_response.status != STATUS_OK: + raise PyiCloudTrustedDevicePromptException( + "Trusted-device bridge topic subscription failed " + f"(status {channel_response.status})." + ) + + if server_message.push_acknowledgment is not None: + push_ack = server_message.push_acknowledgment + LOGGER.debug( + "Received bridge push acknowledgment during bootstrap: messageId=%s deliveryStatus=%s topic=%s", + push_ack.message_id, + push_ack.delivery_status, + _topic_name(push_ack.topic, topics_by_hash), + ) + + push_message = server_message.push_message + if push_message is None: + LOGGER.debug( + "Ignoring non-push websocket frame during trusted-device bootstrap; fields=%s", + server_message.field_numbers, + ) + continue + + websocket.send_binary( + _encode_ack_message(push_message.topic, push_message.message_id) + ) + LOGGER.debug( + "Acknowledged trusted-device push message id=%s topic=%s", + push_message.message_id, + _topic_name(push_message.topic, topics_by_hash), + ) + + if _topic_name(push_message.topic, topics_by_hash) != topic: + continue + + payload = _extract_json_payload(push_message.payload) + return BridgePushPayload.from_payload(payload) + + raise PyiCloudTrustedDevicePromptException( + "Timed out waiting for the trusted-device bridge payload." + ) + + def _apply_bridge_push( + self, + bridge_state: TrustedDeviceBridgeState, + push_payload: BridgePushPayload, + ) -> None: + """Validate a generic bridge push and merge it into the active state.""" + if push_payload.session_uuid != bridge_state.session_uuid: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge returned a mismatched session UUID." + ) + LOGGER.debug( + "Decoded trusted-device bridge payload: sessionUUID=%s nextStep=%s txnid=%s ec=%s has_data=%s has_encryptedCode=%s", + _summarize_identifier(push_payload.session_uuid), + push_payload.next_step, + _summarize_identifier(push_payload.txnid, prefix=12), + push_payload.error_code, + bool(push_payload.data), + bool(push_payload.encrypted_code), + ) + if push_payload.error_code not in (None, 0): + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge returned an error push " + f"(nextStep={push_payload.next_step!r}, ec={push_payload.error_code})." + ) + bridge_state.apply_push_payload(push_payload) + + def _apply_expected_step4_push( + self, + bridge_state: TrustedDeviceBridgeState, + push_payload: BridgePushPayload, + ) -> None: + """Require the post-step-2 bridge push to contain step-4 prover data.""" + self._apply_bridge_push(bridge_state, push_payload) + if bridge_state.next_step != "4" or not bridge_state.data: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge returned an unexpected post-step-2 payload." + ) + LOGGER.debug( + "Received trusted-device bridge payload: sessionUUID=%s nextStep=%s txnid=%s", + _summarize_identifier(bridge_state.session_uuid), + bridge_state.next_step, + _summarize_identifier(bridge_state.txnid, prefix=12), + ) + + def _apply_final_bridge_push( + self, + bridge_state: TrustedDeviceBridgeState, + push_payload: BridgePushPayload, + ) -> None: + """Require the final bridge push to contain the encrypted validation code.""" + self._apply_bridge_push(bridge_state, push_payload) + # Apple's bridge can finish with either: + # - nextStep=6 plus encryptedCode + # - nextStep=4 plus encryptedCode + # The browser routes both shapes into final code validation. + if ( + bridge_state.next_step not in {"4", "6", 4, 6} + or not bridge_state.encrypted_code + ): + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge returned an unexpected final payload." + ) + LOGGER.debug( + "Received trusted-device bridge final payload: sessionUUID=%s nextStep=%s txnid=%s", + _summarize_identifier(bridge_state.session_uuid), + bridge_state.next_step, + _summarize_identifier(bridge_state.txnid, prefix=12), + ) + + def _bridge_headers( + self, + headers: Mapping[str, str], + bridge_state: TrustedDeviceBridgeState, + ) -> dict[str, str]: + """Build the auth headers used for bridge-specific HTTP requests.""" + bridge_headers = dict(headers) + if bridge_state.source_app_id: + bridge_headers["X-Apple-App-Id"] = bridge_state.source_app_id + return bridge_headers + + def _bridge_step_json( + self, + *, + bridge_state: TrustedDeviceBridgeState, + next_step: int, + data: str, + idmsdata: Optional[str], + akdata: Any, + ) -> dict[str, Any]: + """Build the JSON payload for one bridge step POST.""" + return BridgeStepRequest( + session_uuid=bridge_state.session_uuid, + data=data, + push_token=bridge_state.push_token, + next_step=next_step, + idmsdata=idmsdata, + akdata=akdata, + ).as_json() + + def _post_bridge_step( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + bridge_state: TrustedDeviceBridgeState, + next_step: int, + data: str, + idmsdata: Optional[str], + akdata: Any, + ) -> Any: + """POST one bridge step and enforce the small set of valid statuses.""" + response = session.request_raw( + "POST", + f"{auth_endpoint}{BRIDGE_STEP_PATH_TEMPLATE.format(step=next_step)}", + json=self._bridge_step_json( + bridge_state=bridge_state, + next_step=next_step, + data=data, + idmsdata=idmsdata, + akdata=akdata, + ), + headers=headers, + ) + if response.status_code not in { + HTTP_STATUS_OK, + HTTP_STATUS_NO_CONTENT, + HTTP_STATUS_CONFLICT, + }: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge step " + f"{next_step} failed with status {response.status_code}." + ) + return response + + def _post_bridge_step0( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + session_uuid: str, + push_token: str, + ) -> Any: + """POST bridge step 0 immediately after obtaining the push token.""" + response = session.request_raw( + "POST", + f"{auth_endpoint}{BRIDGE_STEP_PATH}", + json={ + "sessionUUID": session_uuid, + "ptkn": push_token, + }, + headers=headers, + ) + if response.status_code not in { + HTTP_STATUS_OK, + HTTP_STATUS_NO_CONTENT, + HTTP_STATUS_CONFLICT, + }: + raise PyiCloudTrustedDevicePromptException( + "Trusted-device bridge step 0 failed with status " + f"{response.status_code}." + ) + return response + + def _post_bridge_code_validate( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + bridge_state: TrustedDeviceBridgeState, + code: str, + ) -> Any: + """POST the decrypted bridge code to Apple's final validation endpoint.""" + response = session.request_raw( + "POST", + f"{auth_endpoint}{BRIDGE_CODE_VALIDATE_PATH}", + json=BridgeCodeValidateRequest( + session_uuid=bridge_state.session_uuid, + code=code, + ).as_json(), + headers=headers, + ) + if response.status_code not in { + HTTP_STATUS_OK, + HTTP_STATUS_NO_CONTENT, + HTTP_STATUS_CONFLICT, + HTTP_STATUS_PRECONDITION_FAILED, + }: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge code validation failed with status " + f"{response.status_code}." + ) + return response diff --git a/pyicloud/hsa2_bridge_prover.py b/pyicloud/hsa2_bridge_prover.py new file mode 100644 index 00000000..7e6ef8bd --- /dev/null +++ b/pyicloud/hsa2_bridge_prover.py @@ -0,0 +1,582 @@ +"""Pure-Python bridge prover for Apple's trusted-device HSA2 flow.""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import secrets +from dataclasses import dataclass +from typing import Optional + +from cryptography.exceptions import InvalidTag +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +_SCRYPT_PARAMS = { + "n": 16384, + "r": 8, + "p": 1, + "dklen": 64, +} +_CLIENT_IDENTITY = b"com.apple.security.webprover" +_SERVER_IDENTITY = b"com.apple.security.webverifier" +_SPAKE2_CONTEXT = b"SPAKE2Web" +_KEY_LENGTH = 32 +_VERIFIER_KEY_INFO = b"webVerifier" +_PROVER_KEY_INFO = b"webProver" + +_P256_P = int("FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF", 16) +_P256_A = (_P256_P - 3) % _P256_P +_P256_B = int("5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B", 16) +_P256_ORDER = int( + "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551", 16 +) +_P256_GX = int("6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296", 16) +_P256_GY = int("4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5", 16) +_SPAKE2_M = "02886e2f97ace46e55ba9dd7242579f2993b64e16ef3dcab95afd497333d8fa12f" +_SPAKE2_N = "03d8bbd6c639c62937b04d997f38c3770719c629d7014d49a24b4f98baa1292b49" +_AES_GCM_LAYOUTS = {0: (12, 16)} + + +@dataclass(frozen=True) +class _Point: + """Affine P-256 point used by the bridge SPAKE2 math helpers.""" + + x: Optional[int] + y: Optional[int] + + @property + def is_infinity(self) -> bool: + """Return whether this point is the point at infinity.""" + return self.x is None or self.y is None + + +_INFINITY = _Point(None, None) +_GENERATOR = _Point(_P256_GX, _P256_GY) + + +def _int_to_bytes(value: int, length: Optional[int] = None) -> bytes: + """Encode an integer using big-endian bytes.""" + if length is None: + length = max(1, (value.bit_length() + 7) // 8) + return value.to_bytes(length, "big") + + +def _b64_to_bytes(value: str) -> bytes: + """Decode a base64 string into raw bytes.""" + return base64.b64decode(value.encode("ascii")) + + +def _bytes_to_b64(value: bytes) -> str: + """Encode raw bytes as an ASCII base64 string.""" + return base64.b64encode(value).decode("ascii") + + +def _encode_point(point: _Point) -> str: + """Encode a P-256 point using SEC1 uncompressed point format.""" + if point.is_infinity: + raise ValueError("Cannot encode the point at infinity.") + return "04" + _int_to_bytes(point.x, 32).hex() + _int_to_bytes(point.y, 32).hex() + + +def _decode_point(value: str) -> _Point: + """Decode a compressed or uncompressed SEC1 point into affine coordinates.""" + raw = bytes.fromhex(value) + if len(raw) == 65 and raw[0] == 0x04: + point = _Point( + int.from_bytes(raw[1:33], "big"), + int.from_bytes(raw[33:65], "big"), + ) + elif len(raw) == 33 and raw[0] in (0x02, 0x03): + x_coord = int.from_bytes(raw[1:], "big") + rhs = (pow(x_coord, 3, _P256_P) + _P256_A * x_coord + _P256_B) % _P256_P + y_coord = pow(rhs, (_P256_P + 1) // 4, _P256_P) + if y_coord & 1 != raw[0] & 1: + y_coord = (-y_coord) % _P256_P + point = _Point(x_coord, y_coord) + else: + raise ValueError("Unsupported P-256 point encoding.") + + if not _is_on_curve(point): + raise ValueError("Invalid P-256 point.") + return point + + +def _is_on_curve(point: _Point) -> bool: + """Return whether a point lies on the configured P-256 curve.""" + if point.is_infinity: + return False + assert point.x is not None and point.y is not None + return ( + pow(point.y, 2, _P256_P) + - (pow(point.x, 3, _P256_P) + _P256_A * point.x + _P256_B) + ) % _P256_P == 0 + + +def _negate(point: _Point) -> _Point: + """Return the additive inverse of a P-256 point.""" + if point.is_infinity: + return point + assert point.x is not None and point.y is not None + return _Point(point.x, (-point.y) % _P256_P) + + +def _add_points(left: _Point, right: _Point) -> _Point: + """Add two affine P-256 points.""" + if left.is_infinity: + return right + if right.is_infinity: + return left + + assert left.x is not None and left.y is not None + assert right.x is not None and right.y is not None + + if left.x == right.x and (left.y + right.y) % _P256_P == 0: + return _INFINITY + + if left.x == right.x and left.y == right.y: + if left.y == 0: + return _INFINITY + slope = ( + (3 * left.x * left.x + _P256_A) * pow(2 * left.y, -1, _P256_P) + ) % _P256_P + else: + slope = ((right.y - left.y) * pow(right.x - left.x, -1, _P256_P)) % _P256_P + + x_coord = (slope * slope - left.x - right.x) % _P256_P + y_coord = (slope * (left.x - x_coord) - left.y) % _P256_P + return _Point(x_coord, y_coord) + + +def _multiply_point(point: _Point, scalar: int) -> _Point: + """Multiply a P-256 point by a scalar using double-and-add.""" + scalar %= _P256_ORDER + result = _INFINITY + addend = point + while scalar: + if scalar & 1: + result = _add_points(result, addend) + addend = _add_points(addend, addend) + scalar >>= 1 + return result + + +def _concat_length_prefixed(*parts: bytes) -> bytes: + """Concatenate transcript parts using the bridge's length-prefixed format.""" + output = bytearray() + for part in parts: + output.extend(len(part).to_bytes(8, "little")) + output.extend(part) + return bytes(output) + + +def _hkdf_like(ikm: bytes, salt: bytes, info: bytes, length: int) -> bytes: + """Derive key material using the bridge worker's HKDF-like expansion.""" + hash_len = hashlib.sha256().digest_size + if not salt: + salt = b"\x00" * hash_len + prk = hmac.new(salt, ikm, hashlib.sha256).digest() + blocks = bytearray() + previous = b"" + counter = 1 + while len(blocks) < length: + previous = hmac.new( + prk, + previous + info + bytes([counter]), + hashlib.sha256, + ).digest() + blocks.extend(previous) + counter += 1 + return bytes(blocks[:length]) + + +def _confirmation_key_length(info: bytes, requested_length: int) -> int: + """Return the bridge-specific output length for a given HKDF info label.""" + if b"ConfirmationKeys" in info: + return 64 + return requested_length + + +def _derive_key(ikm: bytes, info: bytes, length: int = 64) -> bytes: + """Derive one bridge sub-key from raw shared-secret material.""" + return _hkdf_like( + ikm=ikm, + salt=b"", + info=info, + length=_confirmation_key_length(info, length), + ) + + +def _derive_prover_and_verifier_keys(raw_key_hex: str) -> tuple[str, str]: + """Split the raw bridge key into prover and verifier AES/HMAC keys.""" + raw_key = bytes.fromhex(raw_key_hex) + verifier_key = _derive_key(raw_key, _VERIFIER_KEY_INFO, _KEY_LENGTH) + prover_key = _derive_key(raw_key, _PROVER_KEY_INFO, _KEY_LENGTH) + return verifier_key.hex(), prover_key.hex() + + +@dataclass(frozen=True) +class _ClientSharedSecret: + """Client-side shared-secret transcript and derived confirmation keys.""" + + transcript: bytes + share_p: str + share_v: str + + def __post_init__(self) -> None: + """Derive confirmation keys and the final shared key from the transcript.""" + digest = hashlib.sha256(self.transcript).digest() + object.__setattr__(self, "_hash_transcript", digest) + confirmations = _derive_key(digest, b"ConfirmationKeys", 64) + object.__setattr__(self, "_confirm_client", confirmations[:32]) + object.__setattr__(self, "_confirm_server", confirmations[32:]) + shared_key = _derive_key(digest, b"SharedKey", _KEY_LENGTH) + object.__setattr__(self, "_shared_key", shared_key) + + def get_confirmation(self) -> str: + """Return the prover's HMAC confirmation message.""" + return hmac.new( + self._confirm_client, + bytes.fromhex(self.share_v), + hashlib.sha256, + ).hexdigest() + + def verify(self, message_hex: str) -> bytes: + """Verify the server confirmation and return the shared key bytes.""" + expected = hmac.new( + self._confirm_server, + bytes.fromhex(self.share_p), + hashlib.sha256, + ).hexdigest() + if expected != message_hex: + raise ValueError("invalid confirmation from server") + return self._shared_key + + +@dataclass(frozen=True) +class _ServerSharedSecret: + """Server-side shared-secret transcript and derived confirmation keys.""" + + transcript: bytes + share_p: str + share_v: str + + def __post_init__(self) -> None: + """Derive confirmation keys and the final shared key from the transcript.""" + digest = hashlib.sha256(self.transcript).digest() + confirmations = _derive_key(digest, b"ConfirmationKeys", 64) + object.__setattr__(self, "_confirm_client", confirmations[:32]) + object.__setattr__(self, "_confirm_server", confirmations[32:]) + object.__setattr__( + self, + "_shared_key", + _derive_key(digest, b"SharedKey", _KEY_LENGTH), + ) + + def get_confirmation(self) -> str: + """Return the verifier's HMAC confirmation message.""" + return hmac.new( + self._confirm_server, + bytes.fromhex(self.share_p), + hashlib.sha256, + ).hexdigest() + + def verify(self, message_hex: str) -> bytes: + """Verify the prover confirmation and return the shared key bytes.""" + expected = hmac.new( + self._confirm_client, + bytes.fromhex(self.share_v), + hashlib.sha256, + ).hexdigest() + if expected != message_hex: + raise ValueError("invalid confirmation from client") + return self._shared_key + + +class _ClientHandshake: + """Client-side SPAKE2 handshake state for Apple's bridge prover.""" + + def __init__( + self, + *, + x_scalar: int, + w0: int, + w1: int, + ) -> None: + """Initialize the prover handshake with the derived SPAKE2 scalars.""" + self._x = x_scalar + self._w0 = w0 + self._w1 = w1 + self._message1_point: Optional[_Point] = None + self.share_p: Optional[str] = None + + def get_message(self) -> str: + """Return the prover's first SPAKE2 message.""" + point = _add_points( + _multiply_point(_GENERATOR, self._x), + _multiply_point(_decode_point(_SPAKE2_M), self._w0), + ) + self._message1_point = point + self.share_p = _encode_point(point) + return self.share_p + + def finish(self, server_message_hex: str) -> _ClientSharedSecret: + """Finish the handshake using the verifier's first message.""" + if self._message1_point is None or self.share_p is None: + raise ValueError("get_message must be called before finish") + + server_point = _decode_point(server_message_hex) + if server_point.is_infinity: + raise ValueError("invalid curve point") + + adjusted = _add_points( + server_point, + _negate(_multiply_point(_decode_point(_SPAKE2_N), self._w0)), + ) + y_point = _multiply_point(adjusted, self._x) + v_point = _multiply_point(adjusted, self._w1) + transcript = _concat_length_prefixed( + _SPAKE2_CONTEXT, + _CLIENT_IDENTITY, + _SERVER_IDENTITY, + bytes.fromhex(_encode_point(_decode_point(_SPAKE2_M))), + bytes.fromhex(_encode_point(_decode_point(_SPAKE2_N))), + bytes.fromhex(_encode_point(self._message1_point)), + bytes.fromhex(_encode_point(server_point)), + bytes.fromhex(_encode_point(y_point)), + bytes.fromhex(_encode_point(v_point)), + _int_to_bytes(self._w0), + ) + return _ClientSharedSecret( + transcript=transcript, + share_p=self.share_p, + share_v=server_message_hex, + ) + + +class _ServerHandshake: + """Server-side SPAKE2 handshake state used by the local test helper.""" + + def __init__( + self, + *, + y_scalar: int, + w0: int, + verifier_point: _Point, + ) -> None: + """Initialize the verifier handshake with its scalar and verifier point.""" + self._y = y_scalar + self._w0 = w0 + self._verifier_point = verifier_point + self._message1_point: Optional[_Point] = None + self.share_v: Optional[str] = None + + def get_message(self) -> str: + """Return the verifier's first SPAKE2 message.""" + point = _add_points( + _multiply_point(_GENERATOR, self._y), + _multiply_point(_decode_point(_SPAKE2_N), self._w0), + ) + self._message1_point = point + self.share_v = _encode_point(point) + return self.share_v + + def finish(self, client_message_hex: str) -> _ServerSharedSecret: + """Finish the verifier handshake using the prover's first message.""" + if self._message1_point is None or self.share_v is None: + raise ValueError("get_message must be called before finish") + + client_point = _decode_point(client_message_hex) + if client_point.is_infinity: + raise ValueError("invalid curve point") + + adjusted = _add_points( + client_point, + _negate(_multiply_point(_decode_point(_SPAKE2_M), self._w0)), + ) + y_point = _multiply_point(adjusted, self._y) + verifier_share = _multiply_point(self._verifier_point, self._y) + transcript = _concat_length_prefixed( + _SPAKE2_CONTEXT, + _CLIENT_IDENTITY, + _SERVER_IDENTITY, + bytes.fromhex(_encode_point(_decode_point(_SPAKE2_M))), + bytes.fromhex(_encode_point(_decode_point(_SPAKE2_N))), + bytes.fromhex(_encode_point(client_point)), + bytes.fromhex(_encode_point(self._message1_point)), + bytes.fromhex(_encode_point(y_point)), + bytes.fromhex(_encode_point(verifier_share)), + _int_to_bytes(self._w0), + ) + return _ServerSharedSecret( + transcript=transcript, + share_p=client_message_hex, + share_v=self.share_v, + ) + + +def _compute_w0_w1(password: str, salt_b64: str) -> tuple[int, int]: + """Derive the SPAKE2 scalars from the user code and bridge salt.""" + derived = hashlib.scrypt( + password.encode("utf-8"), + salt=_b64_to_bytes(salt_b64), + **_SCRYPT_PARAMS, + ) + midpoint = len(derived) // 2 + return ( + int.from_bytes(derived[:midpoint], "big"), + int.from_bytes(derived[midpoint:], "big"), + ) + + +def _random_nonzero_scalar() -> int: + """Return a random scalar in the non-zero P-256 subgroup range.""" + scalar = 0 + while scalar == 0: + scalar = secrets.randbelow(_P256_ORDER) + return scalar + + +class TrustedDeviceBridgeProver: + """Client-side prover mirroring Apple's prover worker.""" + + def __init__(self) -> None: + """Initialize empty prover state for one bridge verification attempt.""" + self._client: Optional[_ClientHandshake] = None + self._shared_secret: Optional[_ClientSharedSecret] = None + self._raw_key: Optional[str] = None + self._verified = False + self._verifier_key: Optional[str] = None + self._prover_key: Optional[str] = None + + def init_with_salt(self, salt_b64: str, code: str) -> None: + """Initialize the prover with Apple's salt and the user-entered code.""" + w0, w1 = _compute_w0_w1(code, salt_b64) + self._client = _ClientHandshake( + x_scalar=_random_nonzero_scalar(), + w0=w0, + w1=w1, + ) + self._shared_secret = None + self._raw_key = None + self._verified = False + self._verifier_key = None + self._prover_key = None + + def get_message1(self) -> str: + """Return the prover's first bridge message.""" + if self._client is None: + raise ValueError("init_with_salt must be called before get_message1") + return self._client.get_message() + + def process_message1(self, message_hex: str) -> str: + """Process Apple's first bridge message and return the prover confirmation.""" + if self._client is None: + raise ValueError("init_with_salt must be called before process_message1") + self._shared_secret = self._client.finish(message_hex) + return self.get_message2() + + def get_message2(self) -> str: + """Return the prover confirmation generated from the shared transcript.""" + if self._shared_secret is None: + raise ValueError("process_message1 must be called before get_message2") + return self._shared_secret.get_confirmation() + + def process_message2(self, message_hex: str) -> dict[str, object]: + """Verify Apple's confirmation and persist the derived bridge keys.""" + if self._shared_secret is None: + raise ValueError("process_message1 must be called before process_message2") + raw_key = self._shared_secret.verify(message_hex).hex() + self._raw_key = raw_key + self._verifier_key, self._prover_key = _derive_prover_and_verifier_keys(raw_key) + self._verified = True + return {"isVerified": True, "key": raw_key} + + def is_verified(self) -> bool: + """Return whether the bridge confirmation exchange has completed.""" + return self._verified + + def get_key(self) -> str: + """Return the raw shared bridge key as hexadecimal.""" + if self._raw_key is None: + raise ValueError("No bridge key is available yet.") + return self._raw_key + + def decrypt_message(self, ciphertext_b64: str) -> str: + """Decrypt Apple's final encrypted validation code.""" + if self._verifier_key is None: + raise ValueError("Bridge verifier key is not available.") + try: + payload = _b64_to_bytes(ciphertext_b64) + version = payload[0] + iv_length, tag_length = _AES_GCM_LAYOUTS[version] + iv = payload[1 : 1 + iv_length] + tag = payload[1 + iv_length : 1 + iv_length + tag_length] + ciphertext = payload[1 + iv_length + tag_length :] + plaintext = AESGCM(bytes.fromhex(self._verifier_key)).decrypt( + iv, + ciphertext + tag, + bytes([version]), + ) + return plaintext.decode("utf-8") + except (IndexError, KeyError, InvalidTag, UnicodeDecodeError) as exc: + raise ValueError("Malformed bridge payload") from exc + + +class _TrustedDeviceBridgeServerProver: + """Internal test helper mirroring Apple's server-side bridge flow.""" + + def __init__(self, *, password: str, salt_b64: str) -> None: + """Initialize the local verifier helper with the same password and salt.""" + w0, w1 = _compute_w0_w1(password, salt_b64) + verifier_point = _multiply_point(_GENERATOR, w1) + self._server = _ServerHandshake( + y_scalar=_random_nonzero_scalar(), + w0=w0, + verifier_point=verifier_point, + ) + self._shared_secret: Optional[_ServerSharedSecret] = None + self._raw_key: Optional[str] = None + self._verifier_key: Optional[str] = None + self._prover_key: Optional[str] = None + + def get_message1(self) -> str: + """Return the verifier's first bridge message.""" + return self._server.get_message() + + def process_message1(self, client_message_hex: str) -> str: + """Process the prover message and return the verifier confirmation.""" + self._shared_secret = self._server.finish(client_message_hex) + return self.get_message2() + + def get_message2(self) -> str: + """Return the verifier confirmation generated from the shared transcript.""" + if self._shared_secret is None: + raise ValueError("process_message1 must be called before get_message2") + return self._shared_secret.get_confirmation() + + def verify_message2(self, message_hex: str) -> str: + """Verify the prover confirmation and persist the derived bridge keys.""" + if self._shared_secret is None: + raise ValueError("process_message1 must be called before verify_message2") + raw_key = self._shared_secret.verify(message_hex).hex() + self._raw_key = raw_key + self._verifier_key, self._prover_key = _derive_prover_and_verifier_keys(raw_key) + return raw_key + + def encrypt_message(self, plaintext: str) -> str: + """Encrypt a plaintext validation code using Apple's AES-GCM payload layout.""" + if self._verifier_key is None: + raise ValueError("Bridge verifier key is not available.") + version = 0 + iv_length, tag_length = _AES_GCM_LAYOUTS[version] + iv = secrets.token_bytes(iv_length) + encrypted = AESGCM(bytes.fromhex(self._verifier_key)).encrypt( + iv, + plaintext.encode("utf-8"), + bytes([version]), + ) + ciphertext = encrypted[:-tag_length] + tag = encrypted[-tag_length:] + payload = bytes([version]) + iv + tag + ciphertext + return _bytes_to_b64(payload) diff --git a/pyicloud/session.py b/pyicloud/session.py index b696bfd9..c134d25b 100644 --- a/pyicloud/session.py +++ b/pyicloud/session.py @@ -33,6 +33,30 @@ from pyicloud.base import PyiCloudService +NON_PERSISTED_SESSION_KEYS = frozenset( + { + "akdata", + "connection_path", + "data", + "encryptedCode", + "encrypted_code", + "idmsdata", + "mid", + "nextStep", + "next_step", + "ptkn", + "push_token", + "salt", + "sessionUUID", + "session_uuid", + "source_app_id", + "topic", + "topics_by_hash", + "txnid", + } +) + + class PyiCloudSession(requests.Session): """iCloud session.""" @@ -44,6 +68,7 @@ def __init__( verify: bool = False, headers: Optional[dict[str, str]] = None, ) -> None: + """Initialize the persisted requests session used by the service.""" super().__init__() self._service: PyiCloudService = service @@ -102,7 +127,14 @@ def _save_session_data(self) -> None: os.makedirs(self._cookie_directory, exist_ok=True) with open(self.session_path, "w", encoding="utf-8") as outfile: # Copy to avoid dict mutation during concurrent access - dump(dict(self._data), outfile) + dump( + { + key: value + for key, value in dict(self._data).items() + if key not in NON_PERSISTED_SESSION_KEYS + }, + outfile, + ) self.logger.debug("Saved session data to file: %s", self.session_path) try: @@ -143,6 +175,7 @@ def _update_session_data(self, response: Response) -> None: self._data.update({session_arg: response.headers.get(header)}) def _is_json_response(self, response: Response) -> bool: + """Return whether a response advertises one of the accepted JSON mimetypes.""" content_type: str = response.headers.get(CONTENT_TYPE, "") json_mimetypes: list[str] = [ CONTENT_TYPE_JSON, @@ -169,6 +202,7 @@ def request( cert=None, json=None, ) -> Response: + """Dispatch a request through the normalized session request pipeline.""" return self._request( method, url, @@ -188,6 +222,71 @@ def request( json=json, ) + def request_raw( + self, + method, + url, + params=None, + data=None, + headers=None, + cookies=None, + files=None, + auth=None, + timeout=None, + allow_redirects=True, + proxies=None, + hooks=None, + stream=None, + verify=None, + cert=None, + json=None, + ) -> Response: + """Dispatch a request without response-status normalization.""" + + return self._request_raw( + method, + url, + params=params, + data=data, + headers=headers, + cookies=cookies, + files=files, + auth=auth, + timeout=timeout, + allow_redirects=allow_redirects, + proxies=proxies, + hooks=hooks, + stream=stream, + verify=verify, + cert=cert, + json=json, + ) + + def _request_raw( + self, + method, + url, + **kwargs, + ) -> Response: + """Perform a request and persist cookies/session data without raising.""" + + self.logger.debug( + "%s %s", + method, + url, + ) + try: + response: Response = super().request( + method=method, + url=url, + **kwargs, + ) + except requests.exceptions.RequestException as err: + self._raise_request_exception(err) + self._update_session_data(response) + self._save_session_data() + return response + def _request( self, method, @@ -236,13 +335,19 @@ def _request( self._decode_json_response(response) return response - except requests.HTTPError as err: + except requests.exceptions.RequestException as err: + self._raise_request_exception(err) + + @staticmethod + def _raise_request_exception(err: requests.exceptions.RequestException) -> NoReturn: + """Normalize low-level requests failures into the session's public error type.""" + + if isinstance(err, requests.HTTPError) and err.response is not None: raise PyiCloudAPIResponseException( reason=err.response.text, code=err.response.status_code, ) from err - except requests.exceptions.RequestException as err: - raise PyiCloudAPIResponseException("Request failed to iCloud") from err + raise PyiCloudAPIResponseException("Request failed to iCloud") from err def _handle_request_error( self, @@ -297,6 +402,7 @@ def _decode_json_response(self, response: Response) -> None: def _raise_error( self, response: Response, code: Optional[Union[int, str]], reason: str ) -> NoReturn: + """Raise the session's public exception for a parsed iCloud error payload.""" if ( self.service.requires_2sa and reason == "Missing X-APPLE-WEBAUTH-TOKEN cookie" diff --git a/requirements.txt b/requirements.txt index 06f2a422..6e0bc574 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ certifi>=2024.12.14 click>=8.1.8 +cryptography>=44.0.0 fido2>=2.0.0 keyring>=25.6.0 keyrings.alt>=5.0.2 diff --git a/tests/test_base.py b/tests/test_base.py index 01f8bdfe..0bebaf05 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -5,10 +5,13 @@ import json import secrets +import tempfile +from pathlib import Path from typing import Any, List from unittest.mock import MagicMock, mock_open, patch import pytest +import requests from fido2.hid import CtapHidDevice from requests import HTTPError, Response @@ -21,6 +24,8 @@ PyiCloudFailedLoginException, PyiCloudServiceNotActivatedException, PyiCloudServiceUnavailable, + PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, ) from pyicloud.services.calendar import CalendarService from pyicloud.services.contacts import ContactsService @@ -257,6 +262,427 @@ def test_validate_2fa_code(pyicloud_service: PyiCloudService) -> None: assert pyicloud_service.validate_2fa_code("123456") +def test_validate_2fa_code_uses_bridge_verifier_for_step2_state( + pyicloud_service: PyiCloudService, +) -> None: + """Bridge-backed trusted-device prompts should use the bridge verifier instead of the legacy endpoint.""" + + pyicloud_service.data = {"dsInfo": {"hsaVersion": 2}, "hsaChallengeRequired": False} + pyicloud_service._two_factor_delivery_method = "trusted_device" + bridge_state = MagicMock(uses_legacy_trusted_device_verifier=False) + pyicloud_service._trusted_device_bridge_state = bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.validate_code.return_value = True + pyicloud_service.trust_session = MagicMock( + side_effect=lambda: pyicloud_service.data.update({"hsaTrustedBrowser": True}) + or True + ) + pyicloud_service._session = MagicMock() + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.validate_2fa_code("123456") is True + + pyicloud_service._trusted_device_bridge.validate_code.assert_called_once() + pyicloud_service.session.post.assert_not_called() + pyicloud_service._trusted_device_bridge.close.assert_called_once_with(bridge_state) + pyicloud_service.trust_session.assert_called_once_with() + + +def test_validate_2fa_code_keeps_legacy_endpoint_for_bridge_w_subtype( + pyicloud_service: PyiCloudService, +) -> None: + """Apple's `_W` bridge subtype should keep using the legacy trusted-device verifier.""" + + pyicloud_service.data = {"dsInfo": {"hsaVersion": 2}, "hsaChallengeRequired": False} + pyicloud_service._two_factor_delivery_method = "trusted_device" + bridge_state = MagicMock(uses_legacy_trusted_device_verifier=True) + pyicloud_service._trusted_device_bridge_state = bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service.trust_session = MagicMock( + side_effect=lambda: pyicloud_service.data.update({"hsaTrustedBrowser": True}) + or True + ) + pyicloud_service._session = MagicMock() + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + pyicloud_service.session.post.return_value = MagicMock(status_code=200) + + assert pyicloud_service.validate_2fa_code("123456") is True + + pyicloud_service._trusted_device_bridge.validate_code.assert_not_called() + args = pyicloud_service.session.post.call_args.args + assert args[0] == ( + f"{pyicloud_service._auth_endpoint}/verify/trusteddevice/securitycode" + ) + pyicloud_service._trusted_device_bridge.close.assert_called_once_with(bridge_state) + + +def test_validate_2fa_code_bridge_verification_exception_propagates( + pyicloud_service: PyiCloudService, +) -> None: + """Bridge verification failures should not be downgraded to generic invalid-code results.""" + + pyicloud_service._two_factor_delivery_method = "trusted_device" + bridge_state = MagicMock(uses_legacy_trusted_device_verifier=False) + pyicloud_service._trusted_device_bridge_state = bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.validate_code.side_effect = ( + PyiCloudTrustedDeviceVerificationException("bridge verification failed") + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="bridge verification failed", + ): + pyicloud_service.validate_2fa_code("123456") + + pyicloud_service._trusted_device_bridge.close.assert_called_once_with(bridge_state) + + +def test_request_2fa_code_requests_sms_delivery( + pyicloud_service: PyiCloudService, +) -> None: + """Nested phone verification data should trigger SMS delivery.""" + + pyicloud_service._auth_data = { + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + } + } + + with patch("pyicloud.base.PyiCloudSession") as mock_session: + pyicloud_service._session = mock_session + mock_session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.request_2fa_code() is True + + args = mock_session.put.call_args.args + kwargs = mock_session.put.call_args.kwargs + assert args[0] == f"{pyicloud_service._auth_endpoint}/verify/phone" + assert kwargs["json"] == { + "phoneNumber": {"id": 3, "nonFTEU": False}, + "mode": "sms", + } + assert kwargs["headers"]["Accept"] == "application/json" + + +def test_get_mfa_auth_options_parses_hsa2_boot_html( + pyicloud_service: PyiCloudService, +) -> None: + """GET /appleauth/auth HTML should populate the HSA2 boot context.""" + + response = MagicMock() + response.json.side_effect = ValueError("not json") + response.text = """ + + + + """ + pyicloud_service._session = MagicMock() + pyicloud_service.session.get.return_value = response + + auth_options = pyicloud_service._get_mfa_auth_options() + + _, kwargs = pyicloud_service.session.get.call_args + assert kwargs["headers"]["Accept"] == "text/html" + assert auth_options["authInitialRoute"] == "auth/bridge/step" + assert auth_options["hasTrustedDevices"] is True + assert auth_options["authFactors"] == ["web_piggybacking", "sms"] + assert auth_options["bridgeInitiateData"]["webSocketUrl"] == ( + "websocket.push.apple.com" + ) + assert auth_options["phoneNumberVerification"]["trustedPhoneNumber"]["id"] == 3 + assert auth_options["sourceAppId"] == "1159" + assert pyicloud_service._hsa2_boot_context is not None + assert pyicloud_service._hsa2_boot_context.auth_initial_route == ( + "auth/bridge/step" + ) + assert pyicloud_service._hsa2_boot_context.has_trusted_devices is True + + +def test_request_2fa_code_prefers_trusted_device_bridge( + pyicloud_service: PyiCloudService, +) -> None: + """Request-7 style HSA2 challenges should start the bridge before SMS.""" + + pyicloud_service.data = { + "dsInfo": {"hsaVersion": 2}, + "hsaChallengeRequired": True, + "hsaTrustedBrowser": False, + } + pyicloud_service._auth_data = { + "authInitialRoute": "auth/bridge/step", + "hasTrustedDevices": True, + "authFactors": ["web_piggybacking", "sms"], + "bridgeInitiateData": { + "apnsTopic": "com.apple.idmsauthwidget", + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + }, + } + + bridge_state = MagicMock() + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.start.return_value = bridge_state + pyicloud_service._session = MagicMock() + pyicloud_service.session.headers = {"User-Agent": "test-agent"} + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.request_2fa_code() is True + + pyicloud_service._trusted_device_bridge.start.assert_called_once() + pyicloud_service.session.put.assert_not_called() + assert pyicloud_service.two_factor_delivery_method == "trusted_device" + assert pyicloud_service._trusted_device_bridge_state is bridge_state + + +def test_request_2fa_code_replaces_existing_bridge_state_before_restart( + pyicloud_service: PyiCloudService, +) -> None: + """Starting a new bridge prompt should close any previous in-memory bridge session.""" + + pyicloud_service._auth_data = { + "authInitialRoute": "auth/bridge/step", + "hasTrustedDevices": True, + "bridgeInitiateData": { + "apnsTopic": "com.apple.idmsauthwidget", + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + } + + previous_bridge_state = MagicMock() + next_bridge_state = MagicMock() + pyicloud_service._trusted_device_bridge_state = previous_bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.start.return_value = next_bridge_state + pyicloud_service._session = MagicMock() + pyicloud_service.session.headers = {"User-Agent": "test-agent"} + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.request_2fa_code() is True + + pyicloud_service._trusted_device_bridge.close.assert_called_once_with( + previous_bridge_state + ) + assert pyicloud_service._trusted_device_bridge_state is next_bridge_state + + +def test_request_2fa_code_falls_back_to_sms_when_bridge_fails( + pyicloud_service: PyiCloudService, +) -> None: + """Bridge bootstrap failures should fall back to SMS when Apple exposes it.""" + + pyicloud_service._auth_data = { + "authInitialRoute": "auth/bridge/step", + "hasTrustedDevices": True, + "bridgeInitiateData": { + "apnsTopic": "com.apple.idmsauthwidget", + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + }, + } + + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.start.side_effect = ( + PyiCloudTrustedDevicePromptException("bridge failed") + ) + pyicloud_service._session = MagicMock() + pyicloud_service.session.headers = {"User-Agent": "test-agent"} + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.request_2fa_code() is True + + args = pyicloud_service.session.put.call_args.args + kwargs = pyicloud_service.session.put.call_args.kwargs + assert args[0] == f"{pyicloud_service._auth_endpoint}/verify/phone" + assert kwargs["json"] == { + "phoneNumber": {"id": 3, "nonFTEU": False}, + "mode": "sms", + } + assert pyicloud_service.two_factor_delivery_method == "sms" + assert pyicloud_service.two_factor_delivery_notice == ( + "Trusted-device prompt failed; falling back to SMS." + ) + + +def test_request_2fa_code_keeps_security_key_path_separate( + pyicloud_service: PyiCloudService, +) -> None: + """Security-key challenges should not start the bridge or SMS flows.""" + + pyicloud_service._auth_data = { + "fsaChallenge": {"challenge": "abc"}, + "authInitialRoute": "auth/bridge/step", + "hasTrustedDevices": True, + "bridgeInitiateData": { + "apnsTopic": "com.apple.idmsauthwidget", + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + }, + } + + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._session = MagicMock() + pyicloud_service.session.headers = {"User-Agent": "test-agent"} + + assert pyicloud_service.request_2fa_code() is False + + pyicloud_service._trusted_device_bridge.start.assert_not_called() + pyicloud_service.session.put.assert_not_called() + assert pyicloud_service.two_factor_delivery_method == "security_key" + + +def test_validate_2fa_code_uses_nested_sms_phone_number( + pyicloud_service: PyiCloudService, +) -> None: + """Nested phone verification data should validate via the SMS endpoint.""" + + pyicloud_service.data = {"dsInfo": {"hsaVersion": 1}, "hsaChallengeRequired": False} + pyicloud_service._auth_data = { + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + } + } + pyicloud_service.trust_session = MagicMock( + side_effect=lambda: pyicloud_service.data.update({"hsaTrustedBrowser": True}) + or True + ) + + with patch("pyicloud.base.PyiCloudSession") as mock_session: + pyicloud_service._session = mock_session + mock_session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + "session_token": "test_session_token", + } + + mock_post_response = MagicMock() + mock_post_response.status_code = 200 + mock_post_response.json.return_value = {"success": True} + mock_session.post.return_value = mock_post_response + + assert pyicloud_service.validate_2fa_code("123456") + + args = mock_session.post.call_args.args + kwargs = mock_session.post.call_args.kwargs + assert args[0] == f"{pyicloud_service._auth_endpoint}/verify/phone/securitycode" + assert kwargs["json"] == { + "phoneNumber": {"id": 3, "nonFTEU": False}, + "securityCode": {"code": "123456"}, + "mode": "sms", + } + + +def test_validate_2fa_code_defaults_sms_mode_when_push_mode_missing( + pyicloud_service: PyiCloudService, +) -> None: + """Missing SMS pushMode should still validate using the delivery mode used to trigger SMS.""" + + pyicloud_service.data = {"dsInfo": {"hsaVersion": 1}, "hsaChallengeRequired": False} + pyicloud_service._auth_data = { + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": None, + } + } + } + pyicloud_service._two_factor_delivery_method = "sms" + pyicloud_service.trust_session = MagicMock( + side_effect=lambda: pyicloud_service.data.update({"hsaTrustedBrowser": True}) + or True + ) + + with patch("pyicloud.base.PyiCloudSession") as mock_session: + pyicloud_service._session = mock_session + mock_session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + "session_token": "test_session_token", + } + + mock_post_response = MagicMock() + mock_post_response.status_code = 200 + mock_post_response.json.return_value = {"success": True} + mock_session.post.return_value = mock_post_response + + assert pyicloud_service.validate_2fa_code("123456") + + kwargs = mock_session.post.call_args.kwargs + assert kwargs["json"]["mode"] == "sms" + + def test_validate_2fa_code_failure(pyicloud_service: PyiCloudService) -> None: """Test the validate_2fa_code method with an invalid code.""" exception = PyiCloudAPIResponseException("Invalid code") @@ -431,6 +857,24 @@ def test_logout_clears_authenticated_state( assert pyicloud_service._devices is None +def test_logout_closes_active_trusted_device_bridge_state( + pyicloud_service: PyiCloudService, +) -> None: + """Logout should close any active trusted-device bridge session before clearing state.""" + + bridge_state = MagicMock() + pyicloud_service._trusted_device_bridge_state = bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service.session.cookies = MagicMock() + pyicloud_service.session.cookies.get.return_value = None + pyicloud_service.session.clear_persistence = MagicMock() + + pyicloud_service.logout() + + pyicloud_service._trusted_device_bridge.close.assert_called_once_with(bridge_state) + assert pyicloud_service._trusted_device_bridge_state is None + + def test_cookiejar_path_property(pyicloud_session: PyiCloudSession) -> None: """Test the cookiejar_path property.""" path: str = pyicloud_session.cookiejar_path @@ -557,6 +1001,59 @@ def test_request_success(pyicloud_service_working: PyiCloudService) -> None: ) +def test_session_persistence_excludes_trusted_device_bridge_state( + pyicloud_service_working: PyiCloudService, +) -> None: + """Bridge-only state should remain in memory and never be written to persisted session files.""" + + test_base = Path(tempfile.gettempdir()) / "python-test-results" + test_base.mkdir(parents=True, exist_ok=True) + temp_root = Path(tempfile.mkdtemp(prefix="bridge-auth-", dir=test_base)) + session = PyiCloudSession( + service=pyicloud_service_working, + client_id="", + cookie_directory=str(temp_root), + ) + pyicloud_service_working._session = session + bridge_state = MagicMock( + push_token="bridge-ptkn", + session_uuid="bridge-session-uuid", + idmsdata="bridge-idmsdata", + encrypted_code="bridge-encrypted-code", + ) + pyicloud_service_working._trusted_device_bridge_state = bridge_state + session._data = { + "session_token": "valid-token", + "session_id": "persisted-session-id", + "push_token": bridge_state.push_token, + "session_uuid": bridge_state.session_uuid, + "idmsdata": bridge_state.idmsdata, + "encrypted_code": bridge_state.encrypted_code, + } + + session._save_session_data() + + persisted_session = Path(session.session_path).read_text(encoding="utf-8") + for secret_value in ( + "bridge-ptkn", + "bridge-session-uuid", + "bridge-idmsdata", + "bridge-encrypted-code", + ): + assert secret_value not in persisted_session + + cookiejar_path = Path(session.cookiejar_path) + if cookiejar_path.exists(): + persisted_cookiejar = cookiejar_path.read_text(encoding="utf-8") + for secret_value in ( + "bridge-ptkn", + "bridge-session-uuid", + "bridge-idmsdata", + "bridge-encrypted-code", + ): + assert secret_value not in persisted_cookiejar + + def test_request_failure(pyicloud_service_working: PyiCloudService) -> None: """Test the request method with a failure response.""" @@ -601,6 +1098,26 @@ def test_request_failure(pyicloud_service_working: PyiCloudService) -> None: assert open_mock.call_count == 2 +def test_request_raw_normalizes_transport_failure( + pyicloud_service_working: PyiCloudService, +) -> None: + """Raw requests should keep the session's normalized transport failure contract.""" + + with patch("requests.Session.request") as mock_request: + mock_request.side_effect = requests.exceptions.Timeout("timed out") + test_base = Path(tempfile.gettempdir()) / "python-test-results" + test_base.mkdir(parents=True, exist_ok=True) + temp_root = Path(tempfile.mkdtemp(prefix="request-raw-", dir=test_base)) + pyicloud_session = PyiCloudSession( + pyicloud_service_working, "", cookie_directory=str(temp_root) + ) + + with pytest.raises( + PyiCloudAPIResponseException, match="Request failed to iCloud" + ): + pyicloud_session.request_raw("GET", "https://example.com") + + def test_request_with_custom_headers(pyicloud_service_working: PyiCloudService) -> None: """Test the request method with custom headers.""" with ( diff --git a/tests/test_cmdline.py b/tests/test_cmdline.py index 964a107b..5b5598fe 100644 --- a/tests/test_cmdline.py +++ b/tests/test_cmdline.py @@ -10,7 +10,7 @@ from pathlib import Path from types import SimpleNamespace from typing import Any, Optional -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch from uuid import uuid4 import click @@ -211,6 +211,9 @@ def __init__( self.is_china_mainland = china_mainland self.fido2_devices: list[dict[str, Any]] = [] self.trusted_devices: list[dict[str, Any]] = [] + self.two_factor_delivery_method = "unknown" + self.two_factor_delivery_notice = None + self.request_2fa_code = MagicMock(return_value=False) self.validate_2fa_code = MagicMock(return_value=True) self.confirm_security_key = MagicMock(return_value=True) self.send_verification_code = MagicMock(return_value=True) @@ -1679,6 +1682,196 @@ def test_trusted_device_2sa_flow() -> None: ) +def test_sms_2fa_flow_requests_sms_before_prompt() -> None: + """Auth login should request SMS delivery before prompting for the code.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.two_factor_delivery_method = "sms" + fake_api.request_2fa_code.return_value = True + with patch.object(context_module.typer, "prompt", return_value="123456"): + result = _invoke(fake_api, "auth", "login", interactive=True) + assert result.exit_code == 0 + assert "Requested a 2FA code by SMS." in result.stdout + fake_api.request_2fa_code.assert_called_once_with() + fake_api.validate_2fa_code.assert_called_once_with("123456") + + +def test_trusted_device_2fa_flow_reports_device_prompt() -> None: + """Auth login should report trusted-device prompt delivery when bridge succeeds.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + + def request_prompt() -> bool: + fake_api.two_factor_delivery_method = "trusted_device" + return True + + fake_api.request_2fa_code.side_effect = request_prompt + + with patch.object(context_module.typer, "prompt", return_value="123456"): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code == 0 + assert "Requested a 2FA prompt on your trusted Apple devices." in result.stdout + fake_api.validate_2fa_code.assert_called_once_with("123456") + + +def test_code_prompt_aborts_when_request_2fa_code_requires_security_key() -> None: + """Auth login should not enter the numeric 2FA prompt loop for key-only challenges.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.request_2fa_code.return_value = False + + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == ( + "This 2FA challenge requires a security key. Connect one and retry." + ) + fake_api.validate_2fa_code.assert_not_called() + + +def test_trusted_device_2fa_retries_invalid_codes_before_success() -> None: + """Auth login should allow up to three trusted-device 2FA attempts.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + + def request_prompt() -> bool: + fake_api.two_factor_delivery_method = "trusted_device" + return True + + fake_api.request_2fa_code.side_effect = request_prompt + fake_api.validate_2fa_code.side_effect = [False, False, True] + + with patch.object( + context_module.typer, + "prompt", + side_effect=["111111", "222222", "333333"], + ): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code == 0 + assert "Invalid 2FA code. 2 attempt(s) remaining." in result.stdout + assert "Invalid 2FA code. 1 attempt(s) remaining." in result.stdout + assert fake_api.validate_2fa_code.call_args_list == [ + call("111111"), + call("222222"), + call("333333"), + ] + + +def test_sms_2fa_aborts_after_three_invalid_codes() -> None: + """Auth login should stop after three invalid 2FA attempts.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.two_factor_delivery_method = "sms" + fake_api.request_2fa_code.return_value = True + fake_api.validate_2fa_code.side_effect = [False, False, False] + + with patch.object( + context_module.typer, + "prompt", + side_effect=["111111", "222222", "333333"], + ): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == "Failed to verify the 2FA code." + assert "Invalid 2FA code. 2 attempt(s) remaining." in result.stdout + assert "Invalid 2FA code. 1 attempt(s) remaining." in result.stdout + assert fake_api.validate_2fa_code.call_args_list == [ + call("111111"), + call("222222"), + call("333333"), + ] + + +def test_trusted_device_2fa_bridge_fallback_reports_notice() -> None: + """Auth login should print the bridge fallback notice before the SMS message.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + + def request_sms_fallback() -> bool: + fake_api.two_factor_delivery_method = "sms" + fake_api.two_factor_delivery_notice = ( + "Trusted-device prompt failed; falling back to SMS." + ) + return True + + fake_api.request_2fa_code.side_effect = request_sms_fallback + + with patch.object(context_module.typer, "prompt", return_value="123456"): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code == 0 + assert "Trusted-device prompt failed; falling back to SMS." in result.stdout + assert "Requested a 2FA code by SMS." in result.stdout + fake_api.validate_2fa_code.assert_called_once_with("123456") + + +def test_sms_2fa_request_failure_aborts() -> None: + """Auth login should surface SMS delivery request failures clearly.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.request_2fa_code.side_effect = context_module.PyiCloudAPIResponseException( + "sms request failed" + ) + + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == "Failed to request the 2FA SMS code." + fake_api.validate_2fa_code.assert_not_called() + + +def test_trusted_device_2fa_request_failure_aborts() -> None: + """Auth login should surface bridge delivery failures clearly.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.request_2fa_code.side_effect = ( + context_module.PyiCloudTrustedDevicePromptException("bridge failed") + ) + + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == ( + "Failed to request the 2FA trusted-device prompt." + ) + fake_api.validate_2fa_code.assert_not_called() + + +def test_trusted_device_2fa_verification_failure_aborts() -> None: + """Auth login should surface bridge verification failures clearly.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + + def request_prompt() -> bool: + fake_api.two_factor_delivery_method = "trusted_device" + return True + + fake_api.request_2fa_code.side_effect = request_prompt + fake_api.validate_2fa_code.side_effect = ( + context_module.PyiCloudTrustedDeviceVerificationException( + "bridge verification failed" + ) + ) + + with patch.object(context_module.typer, "prompt", return_value="123456"): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == ("Failed to verify the 2FA trusted-device code.") + + def test_non_interactive_2sa_does_not_send_verification_code() -> None: """Non-interactive 2SA should fail before sending a verification code.""" diff --git a/tests/test_hsa2_bridge.py b/tests/test_hsa2_bridge.py new file mode 100644 index 00000000..2e64f398 --- /dev/null +++ b/tests/test_hsa2_bridge.py @@ -0,0 +1,1357 @@ +"""Tests for the HSA2 trusted-device bridge helpers.""" + +from __future__ import annotations + +import base64 +import json +import socket +from binascii import unhexlify +from typing import Callable +from unittest.mock import MagicMock, call + +import pytest + +import pyicloud.hsa2_bridge as bridge_module +from pyicloud.exceptions import ( + PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, +) +from pyicloud.hsa2_bridge import ( + BRIDGE_DONE_DATA_B64, + BridgePushPayload, + Hsa2BootContext, + TrustedDeviceBridgeBootstrapper, + _encode_ack_message, + _encode_bytes_field, + _encode_string_field, + _encode_uint32_field, + _encode_web_filter_message, + _extract_json_payload, + _hex_to_b64, + _topic_hash, + parse_boot_args_html, +) +from pyicloud.hsa2_bridge_prover import ( + TrustedDeviceBridgeProver, + _TrustedDeviceBridgeServerProver, +) + + +class _FakeWebSocket: + def __init__( + self, + messages: list[bytes | Exception], + *, + on_read: Callable[[int], None] | None = None, + ) -> None: + self._messages = list(messages) + self._on_read = on_read + self.sent_messages: list[bytes] = [] + self.closed = False + self.read_count = 0 + + def send_binary(self, payload: bytes) -> None: + self.sent_messages.append(payload) + + def read_message(self) -> bytes: + self.read_count += 1 + if self._on_read is not None: + self._on_read(self.read_count) + message = self._messages.pop(0) + if isinstance(message, Exception): + raise message + return message + + def close(self) -> None: + self.closed = True + + +class _FakePrivateKey: + def sign(self, nonce: bytes, _algorithm: object) -> bytes: + return b"signature-for-" + nonce[:4] + + +def _encode_connection_response(push_token: bytes) -> bytes: + payload = b"".join( + [ + _encode_string_field(1, base64.b64encode(push_token).decode("ascii")), + _encode_uint32_field(2, 0), + ] + ) + return _encode_bytes_field(1, payload) + + +def _encode_connection_response_with_token_b64(push_token_b64: str) -> bytes: + payload = b"".join( + [ + _encode_string_field(1, push_token_b64), + _encode_uint32_field(2, 0), + ] + ) + return _encode_bytes_field(1, payload) + + +def _encode_push_message( + topic: str, payload: dict[str, object], message_id: int +) -> bytes: + topic_bytes = bytes.fromhex(_topic_hash(topic)) + body = b"".join( + [ + _encode_bytes_field(1, topic_bytes), + _encode_uint32_field(2, message_id), + _encode_bytes_field(4, json.dumps(payload).encode("utf-8")), + ] + ) + return _encode_bytes_field(2, body) + + +def _encode_channel_subscription_response(topic: str, message_id: int = 1) -> bytes: + channel_response = b"".join( + [ + _encode_string_field(1, topic), + _encode_bytes_field(2, _encode_bytes_field(1, b"channel-id")), + ] + ) + payload = _encode_bytes_field(1, channel_response) + body = b"".join( + [ + _encode_bytes_field(1, payload), + _encode_uint32_field(2, message_id), + _encode_uint32_field(3, 0), + ] + ) + return _encode_bytes_field(3, body) + + +def _read_varint(data: bytes, offset: int) -> tuple[int, int]: + value = 0 + shift = 0 + while True: + byte = data[offset] + offset += 1 + value |= (byte & 0x7F) << shift + if not (byte & 0x80): + return value, offset + shift += 7 + + +def _decode_fields(data: bytes) -> dict[int, list[int | bytes]]: + offset = 0 + fields: dict[int, list[int | bytes]] = {} + while offset < len(data): + key, offset = _read_varint(data, offset) + field_number = key >> 3 + wire_type = key & 0x07 + + if wire_type == 0: + value, offset = _read_varint(data, offset) + elif wire_type == 2: + length, offset = _read_varint(data, offset) + value = data[offset : offset + length] + offset += length + else: + raise AssertionError(f"Unexpected wire type {wire_type}") + + fields.setdefault(field_number, []).append(value) + return fields + + +def _boot_context(topic: str = "com.apple.idmsauthwidget") -> Hsa2BootContext: + return Hsa2BootContext( + auth_initial_route="auth/bridge/step", + has_trusted_devices=True, + auth_factors=("web_piggybacking", "sms"), + bridge_initiate_data={ + "apnsTopic": topic, + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + source_app_id="1159", + ) + + +def _response(status_code: int) -> MagicMock: + response = MagicMock() + response.status_code = status_code + response.text = "" + return response + + +def test_parse_boot_args_html_extracts_bridge_context() -> None: + """Request-5 style boot args should yield the bridge routing metadata.""" + + html = """ + + + + """ + + boot_context = parse_boot_args_html(html) + + assert boot_context.auth_initial_route == "auth/bridge/step" + assert boot_context.has_trusted_devices is True + assert boot_context.auth_factors == ( + "web_piggybacking", + "robocall", + "sms", + "generatedcode", + ) + assert boot_context.bridge_initiate_data["webSocketUrl"] == ( + "websocket.push.apple.com" + ) + assert boot_context.phone_number_verification["trustedPhoneNumber"]["id"] == 3 + assert boot_context.source_app_id == "1159" + + +def test_parse_boot_args_html_accepts_reordered_script_attributes() -> None: + """boot_args extraction should not depend on one exact script tag string.""" + + html = """ + + + + """ + + boot_context = parse_boot_args_html(html) + + assert boot_context.auth_initial_route == "auth/bridge/step" + assert boot_context.has_trusted_devices is True + assert boot_context.bridge_initiate_data["webSocketUrl"] == ( + "websocket.push.apple.com" + ) + + +def test_read_varint_rejects_malformed_overlong_varint() -> None: + """Malformed bridge varints should fail immediately instead of reading forever.""" + + with pytest.raises( + PyiCloudTrustedDevicePromptException, + match="Malformed protobuf varint", + ): + bridge_module._read_varint(b"\x80" * 10, 0) + + +def test_decode_fields_rejects_truncated_length_delimited_field() -> None: + """Length-delimited bridge fields must fit inside the current frame.""" + + with pytest.raises( + PyiCloudTrustedDevicePromptException, + match="Truncated protobuf field", + ): + bridge_module._decode_fields(b"\x0a\x05abc") + + +@pytest.mark.parametrize( + ("payload", "message"), + [ + ({"sessionUUID": 123}, "Malformed trusted-device bridge push payload"), + ({"sessionUUID": " "}, "Malformed trusted-device bridge push payload"), + ( + {"sessionUUID": "bridge-session", "nextStep": " "}, + "Malformed trusted-device bridge push payload", + ), + ( + {"sessionUUID": "bridge-session", "encryptedCode": " "}, + "Malformed trusted-device bridge push payload", + ), + ( + {"sessionUUID": "bridge-session", "ec": "oops"}, + "Malformed trusted-device bridge push payload", + ), + ], +) +def test_bridge_push_payload_rejects_malformed_fields( + payload: dict[str, object], message: str +) -> None: + """Bridge push validation should reject coerced or blank protocol fields.""" + + with pytest.raises(PyiCloudTrustedDevicePromptException, match=message): + BridgePushPayload.from_payload(payload) + + +def test_bridge_push_payload_preserves_unknown_extra_fields() -> None: + """Unknown Apple bridge fields should survive strict validation unchanged.""" + + payload = BridgePushPayload.from_payload( + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "extraField": {"foo": "bar"}, + } + ) + + assert payload.session_uuid == "bridge-session" + assert payload.payload["extraField"] == {"foo": "bar"} + + +def test_extract_json_payload_finds_embedded_json() -> None: + """Request-8 style binary payloads should yield the embedded JSON envelope.""" + + expected_payload = { + "sessionUUID": "bridge-session", + "nextStep": "2", + "ruiURLKey": "hsa2TwoFactorAuthApprovalFlowUrl", + } + noisy_payload = ( + b"\x12\xa8\x07\x00" + + json.dumps(expected_payload).encode("utf-8") + + b"\x18\x00\x01" + ) + + assert _extract_json_payload(noisy_payload) == expected_payload + + +def test_trusted_device_bridge_prover_roundtrip() -> None: + """The Python prover port should match the worker's SPAKE2+/AES-GCM flow.""" + + salt_b64 = base64.b64encode(b"0123456789abcdef").decode("ascii") + prover = TrustedDeviceBridgeProver() + server = _TrustedDeviceBridgeServerProver(password="050044", salt_b64=salt_b64) + + prover.init_with_salt(salt_b64, "050044") + client_message1 = prover.get_message1() + server_message1 = server.get_message1() + server_message2 = server.process_message1(client_message1) + client_message2 = prover.process_message1(server_message1) + server_key = server.verify_message2(client_message2) + client_key = prover.process_message2(server_message2)["key"] + + assert prover.is_verified() is True + assert client_key == server_key + encrypted_code = server.encrypt_message("derived-device-code") + assert prover.decrypt_message(encrypted_code) == "derived-device-code" + + +def test_trusted_device_bridge_prover_retries_zero_ephemeral_scalars( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Ephemeral prover scalars must stay in the non-zero subgroup range.""" + + draws = iter([0, 7, 0, 9]) + monkeypatch.setattr( + "pyicloud.hsa2_bridge_prover.secrets.randbelow", + lambda _limit: next(draws), + ) + + salt_b64 = base64.b64encode(b"0123456789abcdef").decode("ascii") + prover = TrustedDeviceBridgeProver() + prover.init_with_salt(salt_b64, "050044") + server = _TrustedDeviceBridgeServerProver(password="050044", salt_b64=salt_b64) + + assert prover._client is not None + assert prover._client._x == 7 + assert server._server._y == 9 + + +def test_trusted_device_bridge_prover_normalizes_malformed_bridge_payloads() -> None: + """Malformed encrypted payloads should surface as ValueError.""" + + prover = TrustedDeviceBridgeProver() + prover._verifier_key = "00" * 32 + + with pytest.raises(ValueError, match="Malformed bridge payload"): + prover.decrypt_message(base64.b64encode(b"").decode("ascii")) + + with pytest.raises(ValueError, match="Malformed bridge payload"): + prover.decrypt_message(base64.b64encode(b"\x01truncated").decode("ascii")) + + +def test_trusted_device_bridge_bootstrap_keeps_websocket_open_and_persists_step2() -> ( + None +): + """The bridge bootstrap should keep the websocket alive after step 0 succeeds.""" + + topic = "com.apple.idmsauthwidget" + bridge_payload = { + "sessionUUID": "bridge-session", + "nextStep": "2", + "ruiURLKey": "hsa2TwoFactorAuthApprovalFlowUrl", + "txnid": "2300_282820214_S", + "salt": "c2FsdA==", + "mid": "bridge-mid", + "idmsdata": "idms-data", + "akdata": {"lat": 49.52, "lng": 6.1}, + } + websocket_urls: list[tuple[str, float, str, str]] = [] + session = MagicMock() + session.request_raw.return_value = _response(200) + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message(topic, bridge_payload, 2300), + ], + on_read=lambda read_count: ( + read_count == 1 or session.request_raw.call_count == 1 + ) + or (_ for _ in ()).throw( + AssertionError("Bridge step 0 should be posted before waiting for push") + ), + ) + + def websocket_factory( + url: str, timeout: float, origin: str, user_agent: str + ) -> _FakeWebSocket: + websocket_urls.append((url, timeout, origin, user_agent)) + return websocket + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=websocket_factory, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + websocket_url = websocket_urls[0][0] + assert websocket_url.startswith("wss://websocket.push.apple.com/v2/") + assert state.connection_path == websocket_url.rsplit("/", 1)[1] + connection_message = unhexlify(state.connection_path) + outer_fields = _decode_fields(connection_message) + inner_fields = _decode_fields(outer_fields[1][0]) + assert inner_fields[1][0] == b"\x04public-key" + assert bytes(inner_fields[3][0]).startswith(b"\x01\x03signature-for-") + assert state.push_token == b"push-token".hex() + assert state.session_uuid == "bridge-session" + assert state.next_step == "2" + assert state.rui_url_key == "hsa2TwoFactorAuthApprovalFlowUrl" + assert state.txnid == "2300_282820214_S" + assert state.salt == "c2FsdA==" + assert state.mid == "bridge-mid" + assert state.idmsdata == "idms-data" + assert state.akdata == {"lat": 49.52, "lng": 6.1} + assert state.websocket is websocket + assert websocket.sent_messages[0] == _encode_web_filter_message([topic]) + assert websocket.sent_messages[1] == _encode_ack_message( + bytes.fromhex(_topic_hash(topic)), + 2300, + ) + session.request_raw.assert_called_once_with( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/0", + json={ + "sessionUUID": "bridge-session", + "ptkn": b"push-token".hex(), + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ) + assert websocket.closed is False + bootstrapper.close(state) + assert websocket.closed is True + assert state.websocket is None + + +def test_trusted_device_bridge_rejects_malformed_push_token() -> None: + """Malformed push tokens should surface as bridge prompt failures.""" + + websocket = _FakeWebSocket( + [_encode_connection_response_with_token_b64("%%%not-base64%%%")] + ) + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + + with pytest.raises( + PyiCloudTrustedDevicePromptException, + match="Failed to bootstrap the trusted-device bridge prompt.", + ) as exc_info: + bootstrapper.start( + session=MagicMock(), + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(), + user_agent="test-agent", + ) + + assert isinstance(exc_info.value.__cause__, PyiCloudTrustedDevicePromptException) + assert "Malformed bridge push token" in str(exc_info.value.__cause__) + assert websocket.closed is True + + +def test_trusted_device_bridge_rejects_mismatched_session_uuid() -> None: + """The first bridge push should match the session UUID used for step 0.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "different-session", + "nextStep": "2", + }, + 2300, + ), + ] + ) + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.return_value = _response(200) + + with pytest.raises( + PyiCloudTrustedDevicePromptException, + match="Failed to bootstrap the trusted-device bridge prompt.", + ) as exc_info: + bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + assert isinstance(exc_info.value.__cause__, PyiCloudTrustedDevicePromptException) + assert "mismatched session UUID" in str(exc_info.value.__cause__) + assert websocket.closed is True + + +def test_trusted_device_bridge_start_propagates_unexpected_exception() -> None: + """Unexpected bootstrap bugs should surface directly instead of being wrapped.""" + + websocket = _FakeWebSocket([TypeError("boom")]) + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + + with pytest.raises(TypeError, match="boom"): + bootstrapper.start( + session=MagicMock(), + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(), + user_agent="test-agent", + ) + + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_runs_step2_step4_step6_sequence() -> None: + """Bridge-backed trusted-device verification should follow Apple's step 2/4/6 flow.""" + + topic = "com.apple.idmsauthwidget" + initial_push = { + "sessionUUID": "bridge-session", + "nextStep": "2", + "ruiURLKey": "hsa2TwoFactorAuthApprovalFlowUrl", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "mid": "bridge-mid", + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + } + server_message1_hex = "aa01" + server_message2_hex = "bb02" + step4_data = base64.b64encode( + ( + _hex_to_b64(server_message1_hex) + "_" + _hex_to_b64(server_message2_hex) + ).encode("utf-8") + ).decode("ascii") + step4_push = { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": step4_data, + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + } + step6_push = { + "sessionUUID": "bridge-session", + "nextStep": "6", + "txnid": "2300_282820214_S", + "encryptedCode": "ciphertext", + "idmsdata": "step6-idms", + "akdata": {"step": 6}, + "mid": "bridge-mid", + } + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message(topic, initial_push, 2300), + _encode_push_message(topic, step4_push, 2301), + _encode_push_message(topic, step6_push, 2302), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + prover.get_key.return_value = "deadbeef" + prover.decrypt_message.return_value = "derived-device-code" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + + session = MagicMock() + session.request_raw.side_effect = [ + _response(200), + _response(200), + _response(200), + _response(409), + _response(204), + ] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + assert ( + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + is True + ) + + prover.init_with_salt.assert_called_once_with(initial_push["salt"], "050044") + prover.process_message1.assert_called_once_with(server_message1_hex) + prover.process_message2.assert_called_once_with(server_message2_hex) + prover.decrypt_message.assert_called_once_with("ciphertext") + assert session.request_raw.call_args_list == [ + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/0", + json={ + "sessionUUID": "bridge-session", + "ptkn": b"push-token".hex(), + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/2", + json={ + "sessionUUID": "bridge-session", + "data": _hex_to_b64("abcd"), + "ptkn": b"push-token".hex(), + "nextStep": 2, + "idmsdata": "initial-idms", + "akdata": '{"lat":49.52}', + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/4", + json={ + "sessionUUID": "bridge-session", + "data": _hex_to_b64("ef01"), + "ptkn": b"push-token".hex(), + "nextStep": 4, + "idmsdata": "step4-idms", + "akdata": '{"step":4}', + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/code/validate", + json={ + "sessionUUID": "bridge-session", + "code": "derived-device-code", + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/6", + json={ + "sessionUUID": "bridge-session", + "data": BRIDGE_DONE_DATA_B64, + "ptkn": b"push-token".hex(), + "nextStep": 6, + "idmsdata": "step6-idms", + "akdata": '{"step":6}', + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + ] + assert websocket.sent_messages[2] == _encode_ack_message( + bytes.fromhex(_topic_hash(topic)), + 2301, + ) + assert websocket.sent_messages[3] == _encode_ack_message( + bytes.fromhex(_topic_hash(topic)), + 2302, + ) + assert websocket.closed is True + assert state.websocket is None + + +def test_trusted_device_bridge_validate_code_accepts_step4_encrypted_code_final_push() -> ( + None +): + """Apple can finish the bridge flow with nextStep=4 when encryptedCode is present.""" + + topic = "com.apple.idmsauthwidget" + initial_push = { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + } + step4_data = base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode("utf-8") + ).decode("ascii") + prover_push = { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": step4_data, + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + } + final_push = { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "encryptedCode": "ciphertext", + "idmsdata": "final-idms", + "akdata": {"step": "final"}, + } + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message(topic, initial_push, 2300), + _encode_push_message(topic, prover_push, 2301), + _encode_push_message(topic, final_push, 2302), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + prover.decrypt_message.return_value = "derived-device-code" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [ + _response(200), + _response(200), + _response(200), + _response(200), + _response(204), + ] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + assert ( + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + is True + ) + assert session.request_raw.call_args_list[-1] == call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/4", + json={ + "sessionUUID": "bridge-session", + "data": BRIDGE_DONE_DATA_B64, + "ptkn": b"push-token".hex(), + "nextStep": 4, + "idmsdata": "final-idms", + "akdata": '{"step":"final"}', + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_returns_false_on_412() -> None: + """A bridge code-validate 412 should be treated as an invalid code, not a transport failure.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + }, + 2301, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "6", + "txnid": "2300_282820214_S", + "encryptedCode": "ciphertext", + "idmsdata": "step6-idms", + "akdata": {"step": 6}, + }, + 2302, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + prover.decrypt_message.return_value = "derived-device-code" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [ + _response(200), + _response(200), + _response(200), + _response(412), + _response(204), + ] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + assert ( + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + is False + ) + assert session.request_raw.call_args_list[-1].args[1].endswith("/bridge/step/6") + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_rejects_error_push() -> None: + """Bridge error pushes should surface as verification exceptions.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + }, + 2301, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "6", + "txnid": "2300_282820214_S", + "ec": 7, + }, + 2302, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="error push", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_rejects_malformed_final_push() -> None: + """Final bridge pushes must include encryptedCode once the prover flow is complete.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + }, + 2301, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + }, + 2302, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="unexpected final payload", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_rejects_mismatched_followup_push() -> None: + """Follow-up bridge pushes must stay on the same bridge session.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "different-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + }, + 2301, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="mismatched session UUID", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_closes_on_timeout() -> None: + """Timeouts after prompt delivery should surface as bridge verification failures.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + socket.timeout("timed out"), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="websocket transport error", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_wraps_step4_prover_message1_failure() -> ( + None +): + """Malformed step-4 prover data should surface as bridge verification failures.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + }, + 2301, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.side_effect = ValueError("bad point") + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="step 4 payload is malformed", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True