From 333cbc41281a2dc432165c9b3ad55f7d39deddcd Mon Sep 17 00:00:00 2001 From: mrjarnould Date: Tue, 31 Mar 2026 02:21:31 +0200 Subject: [PATCH 1/5] feat: handle Apple's HSA2 trusted-device prompts --- README.md | 55 ++ pyicloud/base.py | 341 ++++++- pyicloud/cli/context.py | 50 +- pyicloud/exceptions.py | 8 + pyicloud/hsa2_bridge.py | 1613 ++++++++++++++++++++++++++++++++ pyicloud/hsa2_bridge_prover.py | 517 ++++++++++ pyicloud/session.py | 62 ++ requirements.txt | 1 + tests/test_base.py | 450 +++++++++ tests/test_cmdline.py | 811 +++++++++++++++- tests/test_hsa2_bridge.py | 1323 ++++++++++++++++++++++++++ 11 files changed, 5206 insertions(+), 25 deletions(-) create mode 100644 pyicloud/hsa2_bridge.py create mode 100644 pyicloud/hsa2_bridge_prover.py create mode 100644 tests/test_hsa2_bridge.py diff --git a/README.md b/README.md index f6836411..2646530e 100644 --- a/README.md +++ b/README.md @@ -182,6 +182,10 @@ If you have enabled two-factor authentications (2FA) or [two-step authentication (2SA)](https://support.apple.com/en-us/HT204152) for the account you will have to do some extra work: +For HSA2 accounts, `request_2fa_code()` now starts Apple's active delivery +route for the current challenge. Depending on the account and session, that may +be a trusted-device prompt, an SMS code, or a security-key flow. + ```python import sys @@ -216,6 +220,7 @@ if api.requires_2fa: else: print("Two-factor authentication required.") + api.request_2fa_code() code = input( "Enter the code you received of one of your approved devices: " ) @@ -1254,6 +1259,56 @@ Notes caveats: - `api.notes.raw` is available for advanced/debug workflows, but it is not the primary Notes API surface. +### Notes CLI + +### Notes CLI + +The official Typer CLI exposes `icloud notes ...` for recent-note inspection, +folder browsing, title-based search, HTML rendering, and note-id-based export. + +_List recent notes, folders, or one folder’s notes:_ + +```bash +uv run icloud notes recent --username you@example.com +uv run icloud notes folders --username you@example.com +uv run icloud notes list --username you@example.com --folder-id FOLDER_ID +uv run icloud notes list --username you@example.com --all --since PREVIOUS_CURSOR +``` + +_Search notes by title:_ + +```bash +uv run icloud notes search --username you@example.com --title "Daily Plan" +uv run icloud notes search --username you@example.com --title-contains "meeting" +``` + +`icloud notes search` is the official title-filter workflow. It uses a +recents-first search strategy and falls back to a full feed scan when needed. + +_Fetch, render, and export one note by id:_ + +```bash +uv run icloud notes get NOTE_ID --username you@example.com --with-attachments +uv run icloud notes render NOTE_ID --username you@example.com --preview-appearance dark +uv run icloud notes export NOTE_ID \ + --username you@example.com \ + --output-dir ./exports/notes_html \ + --export-mode archival \ + --assets-dir ./exports/assets +``` + +`icloud notes export` stays explicit by note id. Title filters are intentionally +handled by `icloud notes search` rather than by bulk export flags. + +_Inspect incremental changes:_ + +```bash +uv run icloud notes changes --username you@example.com --since PREVIOUS_CURSOR +uv run icloud notes sync-cursor --username you@example.com +``` + +### Notes CLI Example + ### Notes CLI Example [`examples/notes_cli.py`](examples/notes_cli.py) is a local developer utility diff --git a/pyicloud/base.py b/pyicloud/base.py index 47897c70..f26bd2d4 100644 --- a/pyicloud/base.py +++ b/pyicloud/base.py @@ -5,9 +5,10 @@ import json import logging import time +from dataclasses import dataclass from os import chmod, environ, makedirs, path, umask from tempfile import gettempdir -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Mapping, Optional from uuid import uuid1 import srp @@ -34,6 +35,13 @@ PyiCloudPasswordException, PyiCloudServiceNotActivatedException, PyiCloudServiceUnavailable, + PyiCloudTrustedDevicePromptException, +) +from pyicloud.hsa2_bridge import ( + Hsa2BootContext, + TrustedDeviceBridgeBootstrapper, + TrustedDeviceBridgeState, + parse_boot_args_html, ) from pyicloud.services import ( AccountService, @@ -106,6 +114,92 @@ def resolve_cookie_directory(cookie_directory: Optional[str] = None) -> str: return path.join(topdir, getpass.getuser()) +@dataclass(frozen=True) +class TrustedPhoneNumber: + """Typed view of Apple's trusted-phone metadata.""" + + device_id: int | str + non_fteu: Optional[bool] = None + push_mode: Optional[str] = None + + @classmethod + def from_mapping( + cls, value: Optional[Mapping[str, Any]] + ) -> Optional["TrustedPhoneNumber"]: + """Return a typed phone record when Apple's payload includes one.""" + + if not isinstance(value, Mapping): + return None + device_id = value.get("id") + if not isinstance(device_id, (int, str)): + return None + + non_fteu = value.get("nonFTEU") + if not isinstance(non_fteu, bool): + non_fteu = None + + push_mode = value.get("pushMode") + if push_mode is not None: + push_mode = str(push_mode) + + return cls( + device_id=device_id, + non_fteu=non_fteu, + push_mode=push_mode, + ) + + def as_phone_number_payload(self) -> dict[str, Any]: + """Return the nested phoneNumber payload expected by Apple's SMS endpoints.""" + + payload: dict[str, Any] = {"id": self.device_id} + if self.non_fteu is not None: + payload["nonFTEU"] = self.non_fteu + return payload + + +@dataclass(frozen=True) +class PhoneNumberVerification: + """Typed view of Apple's phone verification wrapper payload.""" + + trusted_phone_number: Optional[TrustedPhoneNumber] = None + trusted_phone_numbers: tuple[TrustedPhoneNumber, ...] = () + + @classmethod + def from_mapping( + cls, value: Optional[Mapping[str, Any]] + ) -> "PhoneNumberVerification": + """Return the parsed phone verification payload when Apple exposes one.""" + + if not isinstance(value, Mapping): + return cls() + + trusted_phone_number = TrustedPhoneNumber.from_mapping( + value.get("trustedPhoneNumber") + ) + + trusted_phone_numbers_raw = value.get("trustedPhoneNumbers") + trusted_phone_numbers: list[TrustedPhoneNumber] = [] + if isinstance(trusted_phone_numbers_raw, list): + for entry in trusted_phone_numbers_raw: + phone_number = TrustedPhoneNumber.from_mapping(entry) + if phone_number is not None: + trusted_phone_numbers.append(phone_number) + + return cls( + trusted_phone_number=trusted_phone_number, + trusted_phone_numbers=tuple(trusted_phone_numbers), + ) + + def best_trusted_phone_number(self) -> Optional[TrustedPhoneNumber]: + """Return the first usable trusted phone number from Apple's payload.""" + + if self.trusted_phone_number is not None: + return self.trusted_phone_number + if self.trusted_phone_numbers: + return self.trusted_phone_numbers[0] + return None + + class PyiCloudService: """ A base authentication class for the iCloud service. Handles the @@ -175,6 +269,11 @@ def __init__( self.data: dict[str, Any] = {} self._auth_data: dict[str, Any] = {} + self._hsa2_boot_context: Optional[Hsa2BootContext] = None + self._trusted_device_bridge_state: Optional[TrustedDeviceBridgeState] = None + self._trusted_device_bridge = TrustedDeviceBridgeBootstrapper() + self._two_factor_delivery_method: str = "unknown" + self._two_factor_delivery_notice: Optional[str] = None self.params: dict[str, Any] = {} self._client_id: str = client_id or str(uuid1()).lower() @@ -326,6 +425,10 @@ def _clear_authenticated_state(self) -> None: self.data = {} self._auth_data = {} + self._hsa2_boot_context = None + self._clear_trusted_device_bridge_state() + self._two_factor_delivery_method = "unknown" + self._two_factor_delivery_notice = None self._webservices = None self._account = None self._calendar = None @@ -339,6 +442,13 @@ def _clear_authenticated_state(self) -> None: self._requires_mfa = False self.params.pop("dsid", None) + def _clear_trusted_device_bridge_state(self) -> None: + """Close any active trusted-device bridge session and clear in-memory state.""" + + if self._trusted_device_bridge_state is not None: + self._trusted_device_bridge.close(self._trusted_device_bridge_state) + self._trusted_device_bridge_state = None + def get_auth_status(self) -> dict[str, Any]: """Probe current authentication state without prompting for login.""" @@ -542,6 +652,11 @@ def _authenticate_with_token(self) -> None: if not self.is_trusted_session: raise PyiCloud2FARequiredException(self.account_name, resp) + + self._auth_data = {} + self._hsa2_boot_context = None + self._clear_trusted_device_bridge_state() + self._set_two_factor_delivery_state("unknown") except (PyiCloudAPIResponseException, HTTPError) as error: msg = "Invalid authentication token." raise PyiCloudFailedLoginException(msg, error) from error @@ -679,9 +794,111 @@ def validate_verification_code(self, device: dict[str, Any], code: str) -> bool: def _get_mfa_auth_options(self) -> Dict: """Retrieve auth request options for assertion.""" + # Apple exposes the HSA2 bridge bootstrap in the HTML auth shell. + # Requesting JSON here tends to collapse the response to the SMS-oriented shape. + headers = self._get_auth_headers({"Accept": "text/html"}) + response = self.session.get(self._auth_endpoint, headers=headers) + + auth_options: dict[str, Any] = {} + try: + response_json = response.json() + except (AttributeError, TypeError, ValueError): + response_json = None + + if isinstance(response_json, dict): + auth_options.update(response_json) + boot_context = Hsa2BootContext.from_auth_options(auth_options) + else: + boot_context = parse_boot_args_html(getattr(response, "text", "")) + + boot_auth_data = boot_context.as_auth_data() + auth_options.update(boot_auth_data) + self._hsa2_boot_context = boot_context + self._clear_trusted_device_bridge_state() + self._set_two_factor_delivery_state("unknown") + return auth_options + + def _set_two_factor_delivery_state( + self, method: str, notice: Optional[str] = None + ) -> None: + """Track the active MFA delivery route for the current auth challenge.""" + + self._two_factor_delivery_method = method + self._two_factor_delivery_notice = notice + + def _current_hsa2_boot_context(self) -> Hsa2BootContext: + """Return the best available HSA2 boot context for the active challenge.""" + + if self._hsa2_boot_context is not None: + return self._hsa2_boot_context + + boot_context = Hsa2BootContext.from_auth_options(self._auth_data) + self._hsa2_boot_context = boot_context + return boot_context + + def _supports_trusted_device_bridge(self) -> bool: + """Return whether Apple's HSA2 boot data prefers the bridge flow.""" + + boot_context = self._current_hsa2_boot_context() + return ( + boot_context.auth_initial_route == "auth/bridge/step" + and boot_context.has_trusted_devices + and bool(boot_context.bridge_initiate_data) + ) + + def _can_request_sms_2fa_code(self) -> bool: + """Return whether SMS delivery is currently available.""" + + return ( + self._two_factor_mode() == "sms" + and self._trusted_phone_number() is not None + ) + + def _request_sms_2fa_code(self, notice: Optional[str] = None) -> bool: + """Trigger SMS delivery for the current HSA2 challenge.""" + + trusted_phone_number = self._trusted_phone_number() + if not trusted_phone_number: + raise PyiCloudNoTrustedNumberAvailable() + + data: dict[str, Any] = { + "phoneNumber": trusted_phone_number.as_phone_number_payload(), + "mode": "sms", + } headers = self._get_auth_headers({"Accept": CONTENT_TYPE_JSON}) - return self.session.get(self._auth_endpoint, headers=headers).json() + self.session.put( + f"{self._auth_endpoint}/verify/phone", + json=data, + headers=headers, + ) + self._clear_trusted_device_bridge_state() + self._set_two_factor_delivery_state("sms", notice) + return True + + @property + def two_factor_delivery_method(self) -> str: + """Return the current HSA2 delivery method without exposing auth internals.""" + + if self._two_factor_delivery_method != "unknown": + return self._two_factor_delivery_method + + if self._auth_data.get("fsaChallenge") or self.security_key_names: + return "security_key" + + if self._supports_trusted_device_bridge(): + return "trusted_device" + + if self._two_factor_mode() == "sms": + return "sms" + + return "unknown" + + @property + def two_factor_delivery_notice(self) -> Optional[str]: + """Return an optional user-facing note about the active 2FA delivery path.""" + + return self._two_factor_delivery_notice @property def security_key_names(self) -> Optional[List[str]]: @@ -696,6 +913,74 @@ def _submit_webauthn_assertion_response(self, data: Dict) -> None: f"{self._auth_endpoint}/verify/security/key", json=data, headers=headers ) + def _phone_number_verification(self) -> PhoneNumberVerification: + """Return Apple's nested phone verification payload when present.""" + + phone_verification = self._auth_data.get("phoneNumberVerification") + return PhoneNumberVerification.from_mapping(phone_verification) + + def _trusted_phone_number(self) -> Optional[TrustedPhoneNumber]: + """Return the best available trusted phone number description.""" + + trusted_phone_number = TrustedPhoneNumber.from_mapping( + self._auth_data.get("trustedPhoneNumber") + ) + if trusted_phone_number is not None: + return trusted_phone_number + + return self._phone_number_verification().best_trusted_phone_number() + + def _two_factor_mode(self) -> Optional[str]: + """Return the current 2FA delivery mode reported by Apple.""" + + mode = self._auth_data.get("mode") + if isinstance(mode, str): + return mode + + trusted_phone_number = self._trusted_phone_number() + if trusted_phone_number is None: + return None + + return trusted_phone_number.push_mode + + def request_2fa_code(self) -> bool: + """Trigger the active HSA2 delivery route for the current challenge.""" + + if self._auth_data.get("fsaChallenge") or self.security_key_names: + self._set_two_factor_delivery_state("security_key") + return False + + self._clear_trusted_device_bridge_state() + + if self._supports_trusted_device_bridge(): + try: + self._trusted_device_bridge_state = self._trusted_device_bridge.start( + session=self.session, + auth_endpoint=self._auth_endpoint, + headers=self._get_auth_headers({"Accept": CONTENT_TYPE_JSON}), + boot_context=self._current_hsa2_boot_context(), + user_agent=self.session.headers.get( + "User-Agent", _HEADERS["User-Agent"] + ), + ) + self._set_two_factor_delivery_state("trusted_device") + return True + except PyiCloudTrustedDevicePromptException: + LOGGER.debug( + "Trusted-device bridge bootstrap failed; falling back to SMS when available.", + exc_info=True, + ) + if self._can_request_sms_2fa_code(): + return self._request_sms_2fa_code( + notice="Trusted-device prompt failed; falling back to SMS." + ) + raise + + if self._can_request_sms_2fa_code(): + return self._request_sms_2fa_code() + + return False + @property def fido2_devices(self) -> List[CtapHidDevice]: """List the available FIDO2 devices.""" @@ -831,45 +1116,59 @@ def _request_pcs_for_service(self, app_name: str) -> None: def validate_2fa_code(self, code: str) -> bool: """Verifies a verification code received via Apple's 2FA system (HSA2).""" + bridge_state = self._trusted_device_bridge_state try: - if self._auth_data.get("mode") == "sms": + if self.two_factor_delivery_method == "sms": self._validate_sms_code(code) + elif ( + bridge_state is not None + and not bridge_state.uses_legacy_trusted_device_verifier + ): + if not self._trusted_device_bridge.validate_code( + session=self.session, + auth_endpoint=self._auth_endpoint, + headers=self._get_auth_headers({"Accept": CONTENT_TYPE_JSON}), + bridge_state=bridge_state, + code=code, + ): + LOGGER.error("Code verification failed.") + return False else: - data: dict[str, Any] = {"securityCode": {"code": code}} - headers: dict[str, Any] = self._get_auth_headers( - {"Accept": CONTENT_TYPE_JSON} - ) - self.session.post( - f"{self._auth_endpoint}/verify/trusteddevice/securitycode", - json=data, - headers=headers, - ) + self._validate_trusted_device_code(code) except PyiCloudAPIResponseException: # Wrong verification code LOGGER.error("Code verification failed.") return False + finally: + if bridge_state is not None: + self._clear_trusted_device_bridge_state() LOGGER.debug("Code verification successful.") self.trust_session() return not self.requires_2sa + def _validate_trusted_device_code(self, code: str) -> None: + """Verifies a verification code received via Apple's legacy device endpoint.""" + + data: dict[str, Any] = {"securityCode": {"code": code}} + headers: dict[str, Any] = self._get_auth_headers({"Accept": CONTENT_TYPE_JSON}) + self.session.post( + f"{self._auth_endpoint}/verify/trusteddevice/securitycode", + json=data, + headers=headers, + ) + def _validate_sms_code(self, code: str) -> None: """Verifies a verification code received via Apple's SMS system.""" - trusted_phone_number: dict[str, Any] | None = self._auth_data.get( - "trustedPhoneNumber" - ) + trusted_phone_number = self._trusted_phone_number() if not trusted_phone_number: raise PyiCloudNoTrustedNumberAvailable() - device_id: int | None = trusted_phone_number.get("id") - non_fteu: bool | None = trusted_phone_number.get("nonFTEU") - mode: str | None = trusted_phone_number.get("pushMode") - data: dict[str, Any] = { - "phoneNumber": {"id": device_id, "nonFTEU": non_fteu}, + "phoneNumber": trusted_phone_number.as_phone_number_payload(), "securityCode": {"code": code}, - "mode": mode, + "mode": trusted_phone_number.push_mode, } headers: dict[str, Any] = self._get_auth_headers( {"Accept": f"{CONTENT_TYPE_JSON}, {CONTENT_TYPE_TEXT}"} diff --git a/pyicloud/cli/context.py b/pyicloud/cli/context.py index ab22fccb..fd780738 100644 --- a/pyicloud/cli/context.py +++ b/pyicloud/cli/context.py @@ -17,9 +17,13 @@ from pyicloud import PyiCloudService, utils from pyicloud.base import resolve_cookie_directory from pyicloud.exceptions import ( + PyiCloudAPIResponseException, PyiCloudAuthRequiredException, PyiCloudFailedLoginException, + PyiCloudNoTrustedNumberAvailable, PyiCloudServiceUnavailable, + PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, ) from pyicloud.ssl_context import configurable_ssl_verification @@ -332,9 +336,49 @@ def _handle_2fa(self, api: PyiCloudService) -> None: raise CLIAbort( "Two-factor authentication is required, but interactive prompts are disabled." ) - code = typer.prompt("Enter 2FA code") - if not api.validate_2fa_code(code): - raise CLIAbort("Failed to verify the 2FA code.") + try: + if api.request_2fa_code(): + notice = getattr(api, "two_factor_delivery_notice", None) + if notice: + self.console.print(notice) + + delivery_method = getattr( + api, "two_factor_delivery_method", "unknown" + ) + if delivery_method == "trusted_device": + self.console.print( + "Requested a 2FA prompt on your trusted Apple devices." + ) + elif delivery_method == "sms": + self.console.print("Requested a 2FA code by SMS.") + except PyiCloudNoTrustedNumberAvailable as exc: + raise CLIAbort( + "Two-factor authentication requires a trusted phone number, " + "but none was returned." + ) from exc + except PyiCloudTrustedDevicePromptException as exc: + raise CLIAbort( + "Failed to request the 2FA trusted-device prompt." + ) from exc + except PyiCloudAPIResponseException as exc: + raise CLIAbort("Failed to request the 2FA SMS code.") from exc + max_attempts = 3 + for attempt in range(max_attempts): + code = typer.prompt("Enter 2FA code") + try: + is_valid = api.validate_2fa_code(code) + except PyiCloudTrustedDeviceVerificationException as exc: + raise CLIAbort( + "Failed to verify the 2FA trusted-device code." + ) from exc + if is_valid: + break + remaining_attempts = max_attempts - attempt - 1 + if remaining_attempts <= 0: + raise CLIAbort("Failed to verify the 2FA code.") + self.console.print( + f"Invalid 2FA code. {remaining_attempts} attempt(s) remaining." + ) if not api.is_trusted_session: api.trust_session() diff --git a/pyicloud/exceptions.py b/pyicloud/exceptions.py index 57f2a7de..d5383faf 100644 --- a/pyicloud/exceptions.py +++ b/pyicloud/exceptions.py @@ -99,6 +99,14 @@ class PyiCloudNoTrustedNumberAvailable(PyiCloudException): """iCloud no trusted number exception.""" +class PyiCloudTrustedDevicePromptException(PyiCloudException): + """Trusted-device prompt bootstrap exception.""" + + +class PyiCloudTrustedDeviceVerificationException(PyiCloudException): + """Trusted-device bridge verification exception.""" + + class PyiCloudNoStoredPasswordAvailableException(PyiCloudException): """iCloud no stored password exception.""" diff --git a/pyicloud/hsa2_bridge.py b/pyicloud/hsa2_bridge.py new file mode 100644 index 00000000..9dded537 --- /dev/null +++ b/pyicloud/hsa2_bridge.py @@ -0,0 +1,1613 @@ +"""Internal helpers for Apple's HSA2 trusted-device bridge flow.""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import logging +import os +import socket +import ssl +import struct +import time +import uuid +from binascii import Error as BinasciiError +from dataclasses import dataclass, field +from html.parser import HTMLParser +from typing import Any, Callable, Mapping, Optional, Protocol +from urllib.parse import urlparse + +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec +from pydantic import ( + BaseModel, + ConfigDict, + Field, + StrictInt, + StrictStr, + ValidationError, + field_validator, +) + +from pyicloud.exceptions import ( + PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, +) +from pyicloud.hsa2_bridge_prover import TrustedDeviceBridgeProver + +LOGGER = logging.getLogger(__name__) + +BRIDGE_STEP_PATH = "/bridge/step/0" +BRIDGE_STEP_PATH_TEMPLATE = "/bridge/step/{step}" +BRIDGE_CODE_VALIDATE_PATH = "/bridge/code/validate" +NEW_CONNECTION_EXPIRATION_SECONDS = 86400 +OPCODE_BINARY = 0x2 +OPCODE_CLOSE = 0x8 +OPCODE_PING = 0x9 +OPCODE_PONG = 0xA +SERVER_MESSAGE_CONNECTION_RESPONSE = 1 +SERVER_MESSAGE_PUSH = 2 +SERVER_MESSAGE_CHANNEL_SUBSCRIPTION_RESPONSE = 3 +SERVER_MESSAGE_PUSH_ACK = 7 +STATUS_OK = 0 +STATUS_INVALID_NONCE = 2 +BRIDGE_SIGNATURE_PREFIX = b"\x01\x03" +BRIDGE_DONE_DATA_B64 = base64.b64encode(b"done").decode("ascii") +WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" +WEBSOCKET_TIMEOUT_SECONDS = 30.0 +WEBSOCKET_ENVIRONMENT_HOSTS: dict[str, str] = { + "prod": "websocket.push.apple.com", + "sandbox": "websocket.sandbox.push.apple.com", +} +HTTP_STATUS_OK = 200 +HTTP_STATUS_NO_CONTENT = 204 +HTTP_STATUS_CONFLICT = 409 +HTTP_STATUS_PRECONDITION_FAILED = 412 + + +@dataclass(frozen=True) +class Hsa2BootContext: + """Bridge-related HSA2 boot data parsed from Apple's HTML bootstrap.""" + + auth_initial_route: str = "" + has_trusted_devices: bool = False + auth_factors: tuple[str, ...] = () + bridge_initiate_data: dict[str, Any] = field(default_factory=dict) + phone_number_verification: dict[str, Any] = field(default_factory=dict) + source_app_id: Optional[str] = None + + @classmethod + def from_auth_options(cls, auth_options: Mapping[str, Any]) -> "Hsa2BootContext": + bridge_initiate_data = auth_options.get("bridgeInitiateData") + if not isinstance(bridge_initiate_data, dict): + bridge_initiate_data = {} + + phone_number_verification = auth_options.get("phoneNumberVerification") + if not isinstance(phone_number_verification, dict): + phone_number_verification = bridge_initiate_data.get( + "phoneNumberVerification" + ) + if not isinstance(phone_number_verification, dict): + phone_number_verification = {} + + auth_factors = auth_options.get("authFactors") + if not isinstance(auth_factors, list): + auth_factors = [] + + source_app_id = auth_options.get("sourceAppId") + if source_app_id is not None: + source_app_id = str(source_app_id) + + return cls( + auth_initial_route=str(auth_options.get("authInitialRoute") or ""), + has_trusted_devices=bool(auth_options.get("hasTrustedDevices")), + auth_factors=tuple( + factor for factor in auth_factors if isinstance(factor, str) + ), + bridge_initiate_data=dict(bridge_initiate_data), + phone_number_verification=dict(phone_number_verification), + source_app_id=source_app_id, + ) + + def as_auth_data(self) -> dict[str, Any]: + """Return parsed boot data in the shape expected by the auth flow.""" + + auth_data: dict[str, Any] = { + "authInitialRoute": self.auth_initial_route, + "hasTrustedDevices": self.has_trusted_devices, + "authFactors": list(self.auth_factors), + } + if self.bridge_initiate_data: + auth_data["bridgeInitiateData"] = dict(self.bridge_initiate_data) + if self.phone_number_verification: + auth_data["phoneNumberVerification"] = dict(self.phone_number_verification) + trusted_phone_number = self.phone_number_verification.get( + "trustedPhoneNumber" + ) + if isinstance(trusted_phone_number, dict): + auth_data["trustedPhoneNumber"] = dict(trusted_phone_number) + if self.source_app_id is not None: + auth_data["sourceAppId"] = self.source_app_id + return auth_data + + +class _BridgePushPayloadModel(BaseModel): + """Strict validator for Apple's bridge push JSON envelope.""" + + model_config = ConfigDict( + extra="allow", + populate_by_name=True, + arbitrary_types_allowed=True, + ) + + session_uuid: StrictStr = Field(alias="sessionUUID") + next_step: Optional[StrictStr | StrictInt] = Field(default=None, alias="nextStep") + rui_url_key: Optional[str] = Field(default=None, alias="ruiURLKey") + txnid: Optional[StrictStr] = None + salt: Optional[StrictStr] = None + mid: Optional[StrictStr] = None + idmsdata: Optional[StrictStr] = None + akdata: Any = None + data: Optional[StrictStr] = None + encrypted_code: Optional[StrictStr] = Field(default=None, alias="encryptedCode") + error_code: Optional[StrictInt] = Field(default=None, alias="ec") + + @field_validator("session_uuid") + @classmethod + def _validate_session_uuid(cls, value: str) -> str: + if not value.strip(): + raise ValueError("sessionUUID must not be blank") + return value + + @field_validator( + "txnid", + "salt", + "mid", + "idmsdata", + "data", + "encrypted_code", + ) + @classmethod + def _validate_optional_non_empty_strings( + cls, value: Optional[str] + ) -> Optional[str]: + if value is not None and not value.strip(): + raise ValueError("Bridge payload strings must not be blank") + return value + + @field_validator("next_step") + @classmethod + def _validate_next_step(cls, value: Optional[str | int]) -> Optional[str | int]: + if isinstance(value, str) and not value.strip(): + raise ValueError("nextStep must not be blank") + return value + + +@dataclass(frozen=True) +class BridgePushPayload: + """Decoded bridge push metadata needed to bootstrap trusted-device prompts.""" + + payload: dict[str, Any] + session_uuid: str + next_step: Optional[str] = None + rui_url_key: Optional[str] = None + txnid: Optional[str] = None + salt: Optional[str] = None + mid: Optional[str] = None + idmsdata: Optional[str] = None + akdata: Any = None + data: Optional[str] = None + encrypted_code: Optional[str] = None + error_code: Optional[int] = None + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> "BridgePushPayload": + try: + validated = _BridgePushPayloadModel.model_validate(payload) + except ValidationError as exc: + raise PyiCloudTrustedDevicePromptException( + "Malformed trusted-device bridge push payload." + ) from exc + + if not validated.session_uuid: + raise PyiCloudTrustedDevicePromptException( + "Trusted-device bridge push payload is missing sessionUUID." + ) + + return cls( + payload=payload, + session_uuid=validated.session_uuid, + next_step=( + str(validated.next_step) if validated.next_step is not None else None + ), + rui_url_key=validated.rui_url_key, + txnid=validated.txnid, + salt=validated.salt, + mid=validated.mid, + idmsdata=validated.idmsdata, + akdata=validated.akdata, + data=validated.data, + encrypted_code=validated.encrypted_code, + error_code=validated.error_code, + ) + + +@dataclass +class TrustedDeviceBridgeState: + """Ephemeral trusted-device bridge state.""" + + connection_path: str + push_token: str + session_uuid: str + websocket: Optional[_WebSocketLike] + topic: str + topics_by_hash: dict[str, str] + source_app_id: Optional[str] = None + next_step: Optional[str] = None + rui_url_key: Optional[str] = None + push_payload: dict[str, Any] = field(default_factory=dict) + txnid: Optional[str] = None + salt: Optional[str] = None + mid: Optional[str] = None + idmsdata: Optional[str] = None + akdata: Any = None + data: Optional[str] = None + encrypted_code: Optional[str] = None + error_code: Optional[int] = None + + def apply_push_payload(self, push_payload: BridgePushPayload) -> None: + """Persist the latest bridge push metadata in the live bridge session.""" + + self.push_payload = dict(push_payload.payload) + self.session_uuid = push_payload.session_uuid + self.next_step = push_payload.next_step + self.rui_url_key = push_payload.rui_url_key + self.txnid = push_payload.txnid + self.salt = push_payload.salt + self.mid = push_payload.mid + self.idmsdata = push_payload.idmsdata + self.akdata = push_payload.akdata + self.data = push_payload.data + self.encrypted_code = push_payload.encrypted_code + self.error_code = push_payload.error_code + + @property + def uses_legacy_trusted_device_verifier(self) -> bool: + """Return whether Apple routed this bridge challenge to the legacy verifier.""" + + return bool(self.txnid and self.txnid.endswith("_W")) + + +@dataclass(frozen=True) +class BridgeStepRequest: + """Typed request body for Apple's bridge step endpoints.""" + + session_uuid: str + data: str + push_token: str + next_step: int + idmsdata: Optional[str] = None + akdata: Any = None + + def as_json(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "sessionUUID": self.session_uuid, + "data": self.data, + "ptkn": self.push_token, + "nextStep": self.next_step, + } + if self.idmsdata is not None: + payload["idmsdata"] = self.idmsdata + if self.akdata is not None: + payload["akdata"] = ( + json.dumps(self.akdata, separators=(",", ":")) + if isinstance(self.akdata, dict) + else self.akdata + ) + return payload + + +@dataclass(frozen=True) +class BridgeCodeValidateRequest: + """Typed request body for Apple's final bridge code validation endpoint.""" + + session_uuid: str + code: str + + def as_json(self) -> dict[str, str]: + return { + "sessionUUID": self.session_uuid, + "code": self.code, + } + + +@dataclass(frozen=True) +class _ConnectionResponse: + push_token_b64: str = "" + status: int = 0 + server_timestamp_seconds: Optional[int] = None + + +@dataclass(frozen=True) +class _PushMessage: + topic: bytes + message_id: int + payload: bytes + + +@dataclass(frozen=True) +class _ChannelSubscriptionResponse: + message_id: int = 0 + status: int = 0 + retry_interval_seconds: int = 0 + topics: tuple[str, ...] = () + + +@dataclass(frozen=True) +class _AcknowledgementMessage: + topic: bytes + message_id: int + delivery_status: int = 0 + + +@dataclass(frozen=True) +class _ServerMessage: + connection_response: Optional[_ConnectionResponse] = None + push_message: Optional[_PushMessage] = None + channel_subscription_response: Optional[_ChannelSubscriptionResponse] = None + push_acknowledgment: Optional[_AcknowledgementMessage] = None + field_numbers: tuple[int, ...] = () + + +class _WebSocketLike(Protocol): + def send_binary(self, payload: bytes) -> None: ... + + def read_message(self) -> bytes: ... + + def close(self) -> None: ... + + +class _InvalidNonceError(Exception): + def __init__(self, server_timestamp_ms: int) -> None: + super().__init__("Invalid nonce from bridge server.") + self.server_timestamp_ms = server_timestamp_ms + + +class _BootArgsHTMLParser(HTMLParser): + """Extract the JSON body from Apple's boot_args script tag.""" + + def __init__(self) -> None: + super().__init__() + self._collecting = False + self._found = False + self._chunks: list[str] = [] + + @property + def payload(self) -> str: + return "".join(self._chunks).strip() + + def handle_starttag(self, tag: str, attrs: list[tuple[str, Optional[str]]]) -> None: + if tag != "script" or self._found: + return + attr_map = {key: value for key, value in attrs} + classes = (attr_map.get("class") or "").split() + if "boot_args" in classes: + self._collecting = True + self._found = True + + def handle_endtag(self, tag: str) -> None: + if tag == "script" and self._collecting: + self._collecting = False + + def handle_data(self, data: str) -> None: + if self._collecting: + self._chunks.append(data) + + +def parse_boot_args_html(html_text: str) -> Hsa2BootContext: + """Extract HSA2 boot args from the HTML returned by GET /appleauth/auth.""" + + parser = _BootArgsHTMLParser() + parser.feed(html_text) + parser.close() + + payload_text = parser.payload + if not payload_text: + raise PyiCloudTrustedDevicePromptException("Missing HSA2 boot args payload.") + + try: + payload = json.loads(payload_text) + except json.JSONDecodeError as exc: + raise PyiCloudTrustedDevicePromptException( + "Malformed HSA2 boot args payload." + ) from exc + direct = payload.get("direct") + if not isinstance(direct, dict): + raise PyiCloudTrustedDevicePromptException("Missing HSA2 direct boot data.") + + two_sv = direct.get("twoSV") + if not isinstance(two_sv, dict): + two_sv = {} + + bridge_initiate_data = two_sv.get("bridgeInitiateData") + if not isinstance(bridge_initiate_data, dict): + bridge_initiate_data = {} + + phone_number_verification = bridge_initiate_data.get("phoneNumberVerification") + if not isinstance(phone_number_verification, dict): + phone_number_verification = {} + + auth_factors = two_sv.get("authFactors") + if not isinstance(auth_factors, list): + auth_factors = [] + + source_app_id = two_sv.get("sourceAppId") + if source_app_id is not None: + source_app_id = str(source_app_id) + + return Hsa2BootContext( + auth_initial_route=str(direct.get("authInitialRoute") or ""), + has_trusted_devices=bool(direct.get("hasTrustedDevices")), + auth_factors=tuple( + factor for factor in auth_factors if isinstance(factor, str) + ), + bridge_initiate_data=dict(bridge_initiate_data), + phone_number_verification=dict(phone_number_verification), + source_app_id=source_app_id, + ) + + +def _encode_varint(value: int) -> bytes: + if value < 0: + raise ValueError("Negative varints are not supported.") + parts = bytearray() + while True: + to_write = value & 0x7F + value >>= 7 + if value: + parts.append(to_write | 0x80) + else: + parts.append(to_write) + return bytes(parts) + + +def _read_varint(data: bytes, offset: int) -> tuple[int, int]: + value = 0 + shift = 0 + start_offset = offset + while True: + if offset >= len(data): + raise PyiCloudTrustedDevicePromptException("Truncated protobuf varint.") + byte = data[offset] + offset += 1 + value |= (byte & 0x7F) << shift + if not (byte & 0x80): + return value, offset + shift += 7 + # Guard against malformed wire data rather than silently accepting an + # overlong varint from Apple's private bridge protocol. + if shift > 63 or offset - start_offset >= 10: + raise PyiCloudTrustedDevicePromptException("Malformed protobuf varint.") + + +def _encode_field(field_number: int, wire_type: int, value: bytes) -> bytes: + return _encode_varint((field_number << 3) | wire_type) + value + + +def _encode_bytes_field(field_number: int, value: bytes) -> bytes: + return _encode_field(field_number, 2, _encode_varint(len(value)) + value) + + +def _encode_string_field(field_number: int, value: str) -> bytes: + return _encode_bytes_field(field_number, value.encode("utf-8")) + + +def _encode_uint32_field(field_number: int, value: int) -> bytes: + return _encode_field(field_number, 0, _encode_varint(value)) + + +def _decode_fields(data: bytes) -> dict[int, list[Any]]: + offset = 0 + fields: dict[int, list[Any]] = {} + while offset < len(data): + key, offset = _read_varint(data, offset) + field_number = key >> 3 + wire_type = key & 0x07 + + if wire_type == 0: + value, offset = _read_varint(data, offset) + elif wire_type == 2: + length, offset = _read_varint(data, offset) + # Length-delimited fields must stay within the current message + # bounds; otherwise the bridge frame is truncated or malformed. + end_offset = offset + length + if end_offset > len(data): + raise PyiCloudTrustedDevicePromptException("Truncated protobuf field.") + value = data[offset:end_offset] + offset = end_offset + else: + raise PyiCloudTrustedDevicePromptException( + f"Unsupported protobuf wire type: {wire_type}" + ) + + fields.setdefault(field_number, []).append(value) + return fields + + +def _decode_connection_response(message: bytes) -> _ConnectionResponse: + fields = _decode_fields(message) + push_token_b64 = "" + if fields.get(1): + try: + push_token_b64 = fields[1][0].decode("ascii") + except UnicodeDecodeError as exc: + raise PyiCloudTrustedDevicePromptException( + "Malformed bridge connection response push token." + ) from exc + status = int(fields.get(2, [0])[0]) + server_timestamp_seconds = None + if fields.get(3): + server_timestamp_seconds = int(fields[3][0]) + return _ConnectionResponse( + push_token_b64=push_token_b64, + status=status, + server_timestamp_seconds=server_timestamp_seconds, + ) + + +def _decode_push_message(message: bytes) -> _PushMessage: + fields = _decode_fields(message) + topic = bytes(fields.get(1, [b""])[0]) + message_id = int(fields.get(2, [0])[0]) + payload = bytes(fields.get(4, [b""])[0]) + return _PushMessage(topic=topic, message_id=message_id, payload=payload) + + +def _decode_channel_subscription_response( + message: bytes, +) -> _ChannelSubscriptionResponse: + fields = _decode_fields(message) + topics: list[str] = [] + + payload_values = fields.get(1) + if payload_values: + payload_fields = _decode_fields(bytes(payload_values[0])) + for app_response_value in payload_fields.get(1, []): + app_response_fields = _decode_fields(bytes(app_response_value)) + topic_value = app_response_fields.get(1, [b""])[0] + if isinstance(topic_value, bytes): + topics.append(topic_value.decode("utf-8", "ignore")) + + return _ChannelSubscriptionResponse( + message_id=int(fields.get(2, [0])[0]), + status=int(fields.get(3, [0])[0]), + retry_interval_seconds=int(fields.get(4, [0])[0]), + topics=tuple(topic for topic in topics if topic), + ) + + +def _decode_acknowledgement_message(message: bytes) -> _AcknowledgementMessage: + fields = _decode_fields(message) + topic = bytes(fields.get(1, [b""])[0]) + message_id = int(fields.get(2, [0])[0]) + delivery_status = int(fields.get(3, [0])[0]) + return _AcknowledgementMessage( + topic=topic, + message_id=message_id, + delivery_status=delivery_status, + ) + + +def _decode_server_message(message: bytes) -> _ServerMessage: + fields = _decode_fields(message) + + connection_response = None + if fields.get(SERVER_MESSAGE_CONNECTION_RESPONSE): + connection_response = _decode_connection_response( + bytes(fields[SERVER_MESSAGE_CONNECTION_RESPONSE][0]) + ) + + push_message = None + if fields.get(SERVER_MESSAGE_PUSH): + push_message = _decode_push_message(bytes(fields[SERVER_MESSAGE_PUSH][0])) + + channel_subscription_response = None + if fields.get(SERVER_MESSAGE_CHANNEL_SUBSCRIPTION_RESPONSE): + channel_subscription_response = _decode_channel_subscription_response( + bytes(fields[SERVER_MESSAGE_CHANNEL_SUBSCRIPTION_RESPONSE][0]) + ) + + push_acknowledgment = None + if fields.get(SERVER_MESSAGE_PUSH_ACK): + push_acknowledgment = _decode_acknowledgement_message( + bytes(fields[SERVER_MESSAGE_PUSH_ACK][0]) + ) + + return _ServerMessage( + connection_response=connection_response, + push_message=push_message, + channel_subscription_response=channel_subscription_response, + push_acknowledgment=push_acknowledgment, + field_numbers=tuple(sorted(fields)), + ) + + +def _encode_connection_message( + public_key: bytes, nonce: bytes, signature: bytes +) -> bytes: + connection_message = b"".join( + [ + _encode_bytes_field(1, public_key), + _encode_bytes_field(2, nonce), + _encode_bytes_field(3, _encode_bridge_signature(signature)), + _encode_bytes_field( + 5, _encode_uint32_field(1, NEW_CONNECTION_EXPIRATION_SECONDS) + ), + ] + ) + return _encode_bytes_field(1, connection_message) + + +def _encode_bridge_signature(signature: bytes) -> bytes: + """Wrap the DER ECDSA signature using Apple's bridge signature envelope.""" + + if signature.startswith(BRIDGE_SIGNATURE_PREFIX): + return signature + return BRIDGE_SIGNATURE_PREFIX + signature + + +def _encode_web_filter_message(allowed_topics: list[str]) -> bytes: + filter_payload = b"".join( + _encode_string_field(1, topic) for topic in allowed_topics + ) + return _encode_bytes_field(3, filter_payload) + + +def _encode_ack_message(topic: bytes, message_id: int) -> bytes: + ack_payload = b"".join( + [ + _encode_bytes_field(1, topic), + _encode_uint32_field(2, message_id), + ] + ) + return _encode_bytes_field(2, ack_payload) + + +def _topic_hash(topic: str) -> str: + return hashlib.sha1(topic.encode("utf-8")).hexdigest() + + +def _topic_name(topic_bytes: bytes, topics_by_hash: Mapping[str, str]) -> str: + return topics_by_hash.get(topic_bytes.hex(), topic_bytes.decode("utf-8", "ignore")) + + +def _extract_json_payload(payload: bytes) -> dict[str, Any]: + try: + return json.loads(payload.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + text = payload.decode("utf-8", "ignore") + + start = text.find("{") + while start >= 0: + depth = 0 + in_string = False + escaped = False + for index, character in enumerate(text[start:], start=start): + if in_string: + if escaped: + escaped = False + elif character == "\\": + escaped = True + elif character == '"': + in_string = False + continue + if character == '"': + in_string = True + elif character == "{": + depth += 1 + elif character == "}": + depth -= 1 + if depth == 0: + try: + return json.loads(text[start : index + 1]) + except json.JSONDecodeError: + break + start = text.find("{", start + 1) + + raise PyiCloudTrustedDevicePromptException( + "Could not decode the trusted-device bridge push payload." + ) + + +def _b64_to_hex(value: str) -> str: + try: + return base64.b64decode(value.encode("ascii"), validate=True).hex() + except (ValueError, BinasciiError) as exc: + raise ValueError("Malformed base64-encoded bridge payload.") from exc + + +def _hex_to_b64(value: str) -> str: + return base64.b64encode(bytes.fromhex(value)).decode("ascii") + + +def _build_nonce(timestamp_ms: int) -> bytes: + return b"\x00" + timestamp_ms.to_bytes(8, "big", signed=False) + os.urandom(8) + + +def _summarize_identifier( + value: Optional[str], *, prefix: int = 8, empty: str = "" +) -> str: + if not value: + return empty + if len(value) <= prefix: + return value + return f"{value[:prefix]}..." + + +def _resolve_websocket_host(boot_context: Hsa2BootContext) -> str: + bridge_data = boot_context.bridge_initiate_data + web_socket_url = bridge_data.get("webSocketUrl") + if isinstance(web_socket_url, str) and web_socket_url: + if "://" in web_socket_url: + parsed = urlparse(web_socket_url) + if parsed.hostname: + return parsed.hostname + return web_socket_url.split("/", 1)[0] + + environment = bridge_data.get("apnsEnvironment") + if isinstance(environment, str) and environment in WEBSOCKET_ENVIRONMENT_HOSTS: + return WEBSOCKET_ENVIRONMENT_HOSTS[environment] + + raise PyiCloudTrustedDevicePromptException( + "Missing HSA2 websocket host for the trusted-device bridge." + ) + + +def _resolve_apns_topic(boot_context: Hsa2BootContext) -> str: + topic = boot_context.bridge_initiate_data.get("apnsTopic") + if isinstance(topic, str) and topic: + return topic + + raise PyiCloudTrustedDevicePromptException( + "Missing HSA2 APNS topic for the trusted-device bridge." + ) + + +def _derive_origin(auth_endpoint: str) -> str: + parsed = urlparse(auth_endpoint) + if not parsed.scheme or not parsed.hostname: + raise PyiCloudTrustedDevicePromptException( + "Invalid auth endpoint for trusted-device bridge." + ) + return f"{parsed.scheme}://{parsed.hostname}" + + +class _RawWebSocketClient: + """Minimal websocket client for Apple's webcourier bridge.""" + + def __init__( + self, + url: str, + timeout: float, + origin: str, + user_agent: str, + ) -> None: + self._url = url + self._timeout = timeout + self._origin = origin + self._user_agent = user_agent + self._buffer = bytearray() + self._socket = self._open() + + def _open(self) -> ssl.SSLSocket: + parsed = urlparse(self._url) + if parsed.scheme != "wss" or not parsed.hostname: + raise PyiCloudTrustedDevicePromptException( + f"Unsupported websocket URL: {self._url}" + ) + + port = parsed.port or 443 + resource = parsed.path or "/" + if parsed.query: + resource = f"{resource}?{parsed.query}" + + raw_socket = socket.create_connection((parsed.hostname, port), self._timeout) + context = ssl.create_default_context() + secure_socket = context.wrap_socket(raw_socket, server_hostname=parsed.hostname) + secure_socket.settimeout(self._timeout) + + websocket_key = base64.b64encode(os.urandom(16)).decode("ascii") + request_headers = [ + f"GET {resource} HTTP/1.1", + f"Host: {parsed.hostname}", + "Upgrade: websocket", + "Connection: Upgrade", + f"Origin: {self._origin}", + f"User-Agent: {self._user_agent}", + "Sec-WebSocket-Version: 13", + f"Sec-WebSocket-Key: {websocket_key}", + "\r\n", + ] + secure_socket.sendall("\r\n".join(request_headers).encode("ascii")) + + response = self._read_http_response(secure_socket) + status_line, _, headers_text = response.partition("\r\n") + if " 101 " not in status_line: + raise PyiCloudTrustedDevicePromptException( + f"Websocket upgrade failed: {status_line}" + ) + + headers: dict[str, str] = {} + for line in headers_text.split("\r\n"): + if not line or ":" not in line: + continue + key, value = line.split(":", 1) + headers[key.strip().lower()] = value.strip() + + expected_accept = base64.b64encode( + hashlib.sha1((websocket_key + WEBSOCKET_GUID).encode("ascii")).digest() + ).decode("ascii") + if headers.get("sec-websocket-accept") != expected_accept: + raise PyiCloudTrustedDevicePromptException( + "Invalid websocket accept header from bridge server." + ) + + return secure_socket + + def _read_http_response(self, sock: ssl.SSLSocket) -> str: + while b"\r\n\r\n" not in self._buffer: + chunk = sock.recv(4096) + if not chunk: + raise PyiCloudTrustedDevicePromptException( + "Unexpected EOF during websocket handshake." + ) + self._buffer.extend(chunk) + + marker = self._buffer.find(b"\r\n\r\n") + 4 + data = bytes(self._buffer[:marker]).decode("iso-8859-1") + del self._buffer[:marker] + return data + + def _read_exact(self, size: int) -> bytes: + while len(self._buffer) < size: + chunk = self._socket.recv(max(4096, size - len(self._buffer))) + if not chunk: + raise PyiCloudTrustedDevicePromptException( + "Unexpected EOF while reading websocket frame." + ) + self._buffer.extend(chunk) + + data = bytes(self._buffer[:size]) + del self._buffer[:size] + return data + + def _send_frame(self, opcode: int, payload: bytes) -> None: + first_byte = 0x80 | opcode + mask_key = os.urandom(4) + length = len(payload) + + header = bytearray([first_byte]) + if length < 126: + header.append(0x80 | length) + elif length < 65536: + header.append(0x80 | 126) + header.extend(struct.pack("!H", length)) + else: + header.append(0x80 | 127) + header.extend(struct.pack("!Q", length)) + + masked_payload = bytes( + byte ^ mask_key[index % 4] for index, byte in enumerate(payload) + ) + self._socket.sendall(bytes(header) + mask_key + masked_payload) + + def send_binary(self, payload: bytes) -> None: + self._send_frame(OPCODE_BINARY, payload) + + def read_message(self) -> bytes: + fragments: list[bytes] = [] + opcode: Optional[int] = None + + while True: + first_byte, second_byte = self._read_exact(2) + frame_opcode = first_byte & 0x0F + finished = bool(first_byte & 0x80) + masked = bool(second_byte & 0x80) + payload_length = second_byte & 0x7F + + if payload_length == 126: + payload_length = struct.unpack("!H", self._read_exact(2))[0] + elif payload_length == 127: + payload_length = struct.unpack("!Q", self._read_exact(8))[0] + + mask_key = self._read_exact(4) if masked else b"" + payload = self._read_exact(payload_length) + if masked: + payload = bytes( + byte ^ mask_key[index % 4] for index, byte in enumerate(payload) + ) + + if frame_opcode == OPCODE_CLOSE: + raise PyiCloudTrustedDevicePromptException( + "Bridge websocket closed before delivering a prompt." + ) + if frame_opcode == OPCODE_PING: + self._send_frame(OPCODE_PONG, payload) + continue + if frame_opcode == OPCODE_PONG: + continue + + if frame_opcode != 0: + opcode = frame_opcode + fragments.append(payload) + if finished: + if opcode not in (0x1, OPCODE_BINARY): + raise PyiCloudTrustedDevicePromptException( + f"Unsupported websocket opcode: {opcode}" + ) + return b"".join(fragments) + + def close(self) -> None: + if getattr(self, "_socket", None) is None: + return + try: + self._send_frame(OPCODE_CLOSE, b"") + except OSError: + pass + finally: + try: + self._socket.close() + except OSError: + pass + + +class TrustedDeviceBridgeBootstrapper: + """Bootstrap the trusted-device bridge flow captured in Apple's browser client.""" + + def __init__( + self, + *, + timeout: float = WEBSOCKET_TIMEOUT_SECONDS, + websocket_factory: Optional[ + Callable[[str, float, str, str], _WebSocketLike] + ] = None, + prover_factory: Optional[Callable[[], TrustedDeviceBridgeProver]] = None, + ) -> None: + self.timeout = timeout + self._websocket_factory = websocket_factory or _RawWebSocketClient + self._prover_factory = prover_factory or TrustedDeviceBridgeProver + + def start( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + boot_context: Hsa2BootContext, + user_agent: str, + ) -> TrustedDeviceBridgeState: + topic = _resolve_apns_topic(boot_context) + websocket_host = _resolve_websocket_host(boot_context) + origin = _derive_origin(auth_endpoint) + topics_by_hash = {_topic_hash(topic): topic} + source_app_id = boot_context.source_app_id + public_key, private_key = self._generate_keypair() + + LOGGER.debug( + "Bootstrapping trusted-device bridge: auth_endpoint=%s websocket_host=%s topic=%s source_app_id=%s", + auth_endpoint, + websocket_host, + topic, + source_app_id, + ) + + timestamp_ms: Optional[int] = None + last_error: Optional[Exception] = None + for _ in range(2): + nonce = _build_nonce(timestamp_ms or int(time.time() * 1000)) + signature = private_key.sign(nonce, ec.ECDSA(hashes.SHA256())) + connection_message = _encode_connection_message( + public_key, nonce, signature + ) + connection_path = connection_message.hex() + websocket_url = f"wss://{websocket_host}/v2/{connection_path}" + LOGGER.debug( + "Opening trusted-device websocket: host=%s bootstrapPayloadLen=%d", + websocket_host, + len(connection_path), + ) + websocket = self._websocket_factory( + websocket_url, + self.timeout, + origin, + user_agent, + ) + keep_websocket_open = False + + try: + push_token = self._wait_for_push_token(websocket) + push_token_hex = push_token.hex() + LOGGER.debug( + "Trusted-device bridge connected; received push token (%d bytes)", + len(push_token), + ) + websocket.send_binary(_encode_web_filter_message([topic])) + LOGGER.debug("Sent trusted-device webFilterMessage for topic=%s", topic) + + session_uuid = self._generate_session_uuid() + bridge_headers = dict(headers) + if source_app_id: + bridge_headers["X-Apple-App-Id"] = source_app_id + + LOGGER.debug( + "Posting trusted-device bridge step 0 with sessionUUID=%s ptknLen=%d", + _summarize_identifier(session_uuid), + len(push_token_hex), + ) + # Apple's browser posts step 0 immediately after obtaining the push + # token. Waiting for the first push before posting step 0 causes the + # bridge flow to stall. + self._post_bridge_step0( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + session_uuid=session_uuid, + push_token=push_token_hex, + ) + + push_payload = self._wait_for_bridge_push( + websocket, topic, topics_by_hash + ) + LOGGER.debug( + "Received trusted-device bridge payload: sessionUUID=%s nextStep=%s ruiURLKey=%s", + _summarize_identifier(push_payload.session_uuid), + push_payload.next_step, + push_payload.rui_url_key, + ) + if push_payload.session_uuid != session_uuid: + raise PyiCloudTrustedDevicePromptException( + "Trusted-device bridge returned a mismatched session UUID." + ) + + bridge_state = TrustedDeviceBridgeState( + connection_path=connection_path, + push_token=push_token_hex, + session_uuid=session_uuid, + websocket=websocket, + topic=topic, + topics_by_hash=dict(topics_by_hash), + source_app_id=source_app_id, + ) + bridge_state.apply_push_payload(push_payload) + keep_websocket_open = True + return bridge_state + except _InvalidNonceError as exc: + timestamp_ms = exc.server_timestamp_ms + last_error = exc + LOGGER.debug( + "Trusted-device bridge received INVALID_NONCE; retrying with server timestamp %s", + timestamp_ms, + ) + except (OSError, socket.timeout, ssl.SSLError) as exc: + last_error = exc + LOGGER.debug( + "Trusted-device websocket transport error during bootstrap.", + exc_info=True, + ) + break + except PyiCloudTrustedDevicePromptException as exc: + last_error = exc + LOGGER.debug( + "Trusted-device bridge bootstrap failed before completion.", + exc_info=True, + ) + break + finally: + if not keep_websocket_open: + websocket.close() + + raise PyiCloudTrustedDevicePromptException( + "Failed to bootstrap the trusted-device bridge prompt." + ) from last_error + + def _generate_keypair(self) -> tuple[bytes, ec.EllipticCurvePrivateKey]: + private_key = ec.generate_private_key(ec.SECP256R1()) + public_key = private_key.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint, + ) + return public_key, private_key + + def _generate_session_uuid(self) -> str: + return f"{uuid.uuid4()}-{int(time.time())}" + + def close(self, bridge_state: Optional[TrustedDeviceBridgeState]) -> None: + """Close and detach the websocket associated with an active bridge session.""" + + if bridge_state is None: + return + websocket = bridge_state.websocket + bridge_state.websocket = None + if websocket is None: + return + try: + websocket.close() + except OSError: + LOGGER.debug( + "Trusted-device bridge websocket close failed.", + exc_info=True, + ) + + def validate_code( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + bridge_state: TrustedDeviceBridgeState, + code: str, + ) -> bool: + """Run Apple's bridge-specific trusted-device verification flow.""" + + websocket = bridge_state.websocket + if websocket is None: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge session is not active." + ) + if bridge_state.uses_legacy_trusted_device_verifier: + raise PyiCloudTrustedDeviceVerificationException( + "Legacy trusted-device verification should bypass the bridge verifier." + ) + if bridge_state.next_step not in {"2", 2}: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge is not ready for step 2 verification." + ) + if not bridge_state.salt: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge payload is missing the step-2 salt." + ) + + prover = self._prover_factory() + bridge_headers = self._bridge_headers(headers, bridge_state) + + try: + LOGGER.debug( + "Starting trusted-device bridge code verification: sessionUUID=%s nextStep=%s txnid=%s", + _summarize_identifier(bridge_state.session_uuid), + bridge_state.next_step, + _summarize_identifier(bridge_state.txnid, prefix=12), + ) + + prover.init_with_salt(bridge_state.salt, code) + message1 = prover.get_message1() + LOGGER.debug( + "Posting trusted-device bridge step 2 with sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + self._post_bridge_step( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + bridge_state=bridge_state, + next_step=2, + data=_hex_to_b64(message1), + idmsdata=bridge_state.idmsdata, + akdata=bridge_state.akdata, + ) + + step4_payload = self._wait_for_bridge_push( + websocket, + bridge_state.topic, + bridge_state.topics_by_hash, + ) + self._apply_expected_step4_push(bridge_state, step4_payload) + + if not bridge_state.data: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge step 4 payload is missing prover data." + ) + try: + step4_data = base64.b64decode( + bridge_state.data.encode("ascii"), validate=True + ).decode("utf-8") + bridge_message1_b64, bridge_message2_b64 = step4_data.split("_", 1) + bridge_message1_hex = _b64_to_hex(bridge_message1_b64) + bridge_message2_hex = _b64_to_hex(bridge_message2_b64) + except (ValueError, UnicodeDecodeError, BinasciiError) as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge step 4 payload is malformed." + ) from exc + + LOGGER.debug( + "Processing trusted-device bridge step 4 payload for sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + try: + message2 = prover.process_message1(bridge_message1_hex) + except ValueError as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge step 4 payload is malformed." + ) from exc + try: + prover.process_message2(bridge_message2_hex) + except ValueError: + LOGGER.debug( + "Trusted-device bridge prover rejected the step-4 confirmation for sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + return False + + LOGGER.debug( + "Posting trusted-device bridge step 4 with sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + self._post_bridge_step( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + bridge_state=bridge_state, + next_step=4, + data=_hex_to_b64(message2), + idmsdata=bridge_state.idmsdata, + akdata=bridge_state.akdata, + ) + + final_payload = self._wait_for_bridge_push( + websocket, + bridge_state.topic, + bridge_state.topics_by_hash, + ) + self._apply_final_bridge_push(bridge_state, final_payload) + + if not bridge_state.encrypted_code: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge final payload is missing encryptedCode." + ) + + LOGGER.debug( + "Decrypting trusted-device bridge code for sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + try: + derived_code = prover.decrypt_message(bridge_state.encrypted_code) + except ValueError as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Failed to decrypt the trusted-device bridge validation code." + ) from exc + + verify_response = self._post_bridge_code_validate( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + bridge_state=bridge_state, + code=derived_code, + ) + verification_succeeded = ( + verify_response.status_code != HTTP_STATUS_PRECONDITION_FAILED + ) + + completion_step = 6 if bridge_state.next_step in {"6", 6} else 4 + LOGGER.debug( + "Posting trusted-device bridge completion step %s with sessionUUID=%s verifyStatus=%s", + completion_step, + _summarize_identifier(bridge_state.session_uuid), + verify_response.status_code, + ) + self._post_bridge_step( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + bridge_state=bridge_state, + next_step=completion_step, + data=BRIDGE_DONE_DATA_B64, + idmsdata=bridge_state.idmsdata, + akdata=bridge_state.akdata, + ) + return verification_succeeded + except PyiCloudTrustedDevicePromptException as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge verification failed while waiting for the next bridge push." + ) from exc + except (OSError, socket.timeout, ssl.SSLError) as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge verification failed due to a websocket transport error." + ) from exc + finally: + self.close(bridge_state) + + def _wait_for_push_token(self, websocket: _WebSocketLike) -> bytes: + deadline = time.monotonic() + self.timeout + while time.monotonic() < deadline: + message = websocket.read_message() + server_message = _decode_server_message(message) + connection_response = server_message.connection_response + if connection_response is None: + LOGGER.debug( + "Ignoring non-connection websocket frame while waiting for push token; fields=%s", + server_message.field_numbers, + ) + continue + + if ( + connection_response.status == STATUS_OK + and connection_response.push_token_b64 + ): + try: + return base64.b64decode( + connection_response.push_token_b64.encode("ascii"), + validate=True, + ) + except (ValueError, BinasciiError) as exc: + raise PyiCloudTrustedDevicePromptException( + "Malformed bridge push token." + ) from exc + + if ( + connection_response.status == STATUS_INVALID_NONCE + and connection_response.server_timestamp_seconds is not None + ): + raise _InvalidNonceError( + connection_response.server_timestamp_seconds * 1000 + ) + + LOGGER.debug( + "Trusted-device bridge connection response returned status=%s", + connection_response.status, + ) + raise PyiCloudTrustedDevicePromptException( + f"Bridge server returned status {connection_response.status}." + ) + + raise PyiCloudTrustedDevicePromptException( + "Timed out waiting for the bridge push token." + ) + + def _wait_for_bridge_push( + self, + websocket: _WebSocketLike, + topic: str, + topics_by_hash: Mapping[str, str], + ) -> BridgePushPayload: + deadline = time.monotonic() + self.timeout + while time.monotonic() < deadline: + message = websocket.read_message() + server_message = _decode_server_message(message) + if server_message.channel_subscription_response is not None: + channel_response = server_message.channel_subscription_response + LOGGER.debug( + "Received channel subscription response during bridge bootstrap: messageId=%s status=%s retryIntervalSeconds=%s topics=%s", + channel_response.message_id, + channel_response.status, + channel_response.retry_interval_seconds, + channel_response.topics, + ) + if channel_response.status != STATUS_OK: + raise PyiCloudTrustedDevicePromptException( + "Trusted-device bridge topic subscription failed " + f"(status {channel_response.status})." + ) + + if server_message.push_acknowledgment is not None: + push_ack = server_message.push_acknowledgment + LOGGER.debug( + "Received bridge push acknowledgment during bootstrap: messageId=%s deliveryStatus=%s topic=%s", + push_ack.message_id, + push_ack.delivery_status, + _topic_name(push_ack.topic, topics_by_hash), + ) + + push_message = server_message.push_message + if push_message is None: + LOGGER.debug( + "Ignoring non-push websocket frame during trusted-device bootstrap; fields=%s", + server_message.field_numbers, + ) + continue + + websocket.send_binary( + _encode_ack_message(push_message.topic, push_message.message_id) + ) + LOGGER.debug( + "Acknowledged trusted-device push message id=%s topic=%s", + push_message.message_id, + _topic_name(push_message.topic, topics_by_hash), + ) + + if _topic_name(push_message.topic, topics_by_hash) != topic: + continue + + payload = _extract_json_payload(push_message.payload) + return BridgePushPayload.from_payload(payload) + + raise PyiCloudTrustedDevicePromptException( + "Timed out waiting for the trusted-device bridge payload." + ) + + def _apply_bridge_push( + self, + bridge_state: TrustedDeviceBridgeState, + push_payload: BridgePushPayload, + ) -> None: + if push_payload.session_uuid != bridge_state.session_uuid: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge returned a mismatched session UUID." + ) + LOGGER.debug( + "Decoded trusted-device bridge payload: sessionUUID=%s nextStep=%s txnid=%s ec=%s has_data=%s has_encryptedCode=%s", + _summarize_identifier(push_payload.session_uuid), + push_payload.next_step, + _summarize_identifier(push_payload.txnid, prefix=12), + push_payload.error_code, + bool(push_payload.data), + bool(push_payload.encrypted_code), + ) + if push_payload.error_code not in (None, 0): + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge returned an error push " + f"(nextStep={push_payload.next_step!r}, ec={push_payload.error_code})." + ) + bridge_state.apply_push_payload(push_payload) + + def _apply_expected_step4_push( + self, + bridge_state: TrustedDeviceBridgeState, + push_payload: BridgePushPayload, + ) -> None: + self._apply_bridge_push(bridge_state, push_payload) + if bridge_state.next_step != "4" or not bridge_state.data: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge returned an unexpected post-step-2 payload." + ) + LOGGER.debug( + "Received trusted-device bridge payload: sessionUUID=%s nextStep=%s txnid=%s", + _summarize_identifier(bridge_state.session_uuid), + bridge_state.next_step, + _summarize_identifier(bridge_state.txnid, prefix=12), + ) + + def _apply_final_bridge_push( + self, + bridge_state: TrustedDeviceBridgeState, + push_payload: BridgePushPayload, + ) -> None: + self._apply_bridge_push(bridge_state, push_payload) + # Apple's bridge can finish with either: + # - nextStep=6 plus encryptedCode + # - nextStep=4 plus encryptedCode + # The browser routes both shapes into final code validation. + if ( + bridge_state.next_step not in {"4", "6", 4, 6} + or not bridge_state.encrypted_code + ): + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge returned an unexpected final payload." + ) + LOGGER.debug( + "Received trusted-device bridge final payload: sessionUUID=%s nextStep=%s txnid=%s", + _summarize_identifier(bridge_state.session_uuid), + bridge_state.next_step, + _summarize_identifier(bridge_state.txnid, prefix=12), + ) + + def _bridge_headers( + self, + headers: Mapping[str, str], + bridge_state: TrustedDeviceBridgeState, + ) -> dict[str, str]: + bridge_headers = dict(headers) + if bridge_state.source_app_id: + bridge_headers["X-Apple-App-Id"] = bridge_state.source_app_id + return bridge_headers + + def _bridge_step_json( + self, + *, + bridge_state: TrustedDeviceBridgeState, + next_step: int, + data: str, + idmsdata: Optional[str], + akdata: Any, + ) -> dict[str, Any]: + return BridgeStepRequest( + session_uuid=bridge_state.session_uuid, + data=data, + push_token=bridge_state.push_token, + next_step=next_step, + idmsdata=idmsdata, + akdata=akdata, + ).as_json() + + def _post_bridge_step( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + bridge_state: TrustedDeviceBridgeState, + next_step: int, + data: str, + idmsdata: Optional[str], + akdata: Any, + ) -> Any: + response = session.request_raw( + "POST", + f"{auth_endpoint}{BRIDGE_STEP_PATH_TEMPLATE.format(step=next_step)}", + json=self._bridge_step_json( + bridge_state=bridge_state, + next_step=next_step, + data=data, + idmsdata=idmsdata, + akdata=akdata, + ), + headers=headers, + ) + if response.status_code not in { + HTTP_STATUS_OK, + HTTP_STATUS_NO_CONTENT, + HTTP_STATUS_CONFLICT, + }: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge step " + f"{next_step} failed with status {response.status_code}." + ) + return response + + def _post_bridge_step0( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + session_uuid: str, + push_token: str, + ) -> Any: + response = session.request_raw( + "POST", + f"{auth_endpoint}{BRIDGE_STEP_PATH}", + json={ + "sessionUUID": session_uuid, + "ptkn": push_token, + }, + headers=headers, + ) + if response.status_code not in { + HTTP_STATUS_OK, + HTTP_STATUS_NO_CONTENT, + HTTP_STATUS_CONFLICT, + }: + raise PyiCloudTrustedDevicePromptException( + "Trusted-device bridge step 0 failed with status " + f"{response.status_code}." + ) + return response + + def _post_bridge_code_validate( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + bridge_state: TrustedDeviceBridgeState, + code: str, + ) -> Any: + response = session.request_raw( + "POST", + f"{auth_endpoint}{BRIDGE_CODE_VALIDATE_PATH}", + json=BridgeCodeValidateRequest( + session_uuid=bridge_state.session_uuid, + code=code, + ).as_json(), + headers=headers, + ) + if response.status_code not in { + HTTP_STATUS_OK, + HTTP_STATUS_NO_CONTENT, + HTTP_STATUS_CONFLICT, + HTTP_STATUS_PRECONDITION_FAILED, + }: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge code validation failed with status " + f"{response.status_code}." + ) + return response diff --git a/pyicloud/hsa2_bridge_prover.py b/pyicloud/hsa2_bridge_prover.py new file mode 100644 index 00000000..ed5a8432 --- /dev/null +++ b/pyicloud/hsa2_bridge_prover.py @@ -0,0 +1,517 @@ +"""Pure-Python bridge prover for Apple's trusted-device HSA2 flow.""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import secrets +from dataclasses import dataclass +from typing import Optional + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +_SCRYPT_PARAMS = { + "n": 16384, + "r": 8, + "p": 1, + "dklen": 64, +} +_CLIENT_IDENTITY = b"com.apple.security.webprover" +_SERVER_IDENTITY = b"com.apple.security.webverifier" +_SPAKE2_CONTEXT = b"SPAKE2Web" +_KEY_LENGTH = 32 +_VERIFIER_KEY_INFO = b"webVerifier" +_PROVER_KEY_INFO = b"webProver" + +_P256_P = int("FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF", 16) +_P256_A = (_P256_P - 3) % _P256_P +_P256_B = int("5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B", 16) +_P256_ORDER = int( + "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551", 16 +) +_P256_GX = int("6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296", 16) +_P256_GY = int("4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5", 16) +_SPAKE2_M = "02886e2f97ace46e55ba9dd7242579f2993b64e16ef3dcab95afd497333d8fa12f" +_SPAKE2_N = "03d8bbd6c639c62937b04d997f38c3770719c629d7014d49a24b4f98baa1292b49" +_AES_GCM_LAYOUTS = {0: (12, 16)} + + +@dataclass(frozen=True) +class _Point: + x: Optional[int] + y: Optional[int] + + @property + def is_infinity(self) -> bool: + return self.x is None or self.y is None + + +_INFINITY = _Point(None, None) +_GENERATOR = _Point(_P256_GX, _P256_GY) + + +def _int_to_bytes(value: int, length: Optional[int] = None) -> bytes: + if length is None: + length = max(1, (value.bit_length() + 7) // 8) + return value.to_bytes(length, "big") + + +def _b64_to_bytes(value: str) -> bytes: + return base64.b64decode(value.encode("ascii")) + + +def _bytes_to_b64(value: bytes) -> str: + return base64.b64encode(value).decode("ascii") + + +def _encode_point(point: _Point) -> str: + if point.is_infinity: + raise ValueError("Cannot encode the point at infinity.") + return "04" + _int_to_bytes(point.x, 32).hex() + _int_to_bytes(point.y, 32).hex() + + +def _decode_point(value: str) -> _Point: + raw = bytes.fromhex(value) + if len(raw) == 65 and raw[0] == 0x04: + point = _Point( + int.from_bytes(raw[1:33], "big"), + int.from_bytes(raw[33:65], "big"), + ) + elif len(raw) == 33 and raw[0] in (0x02, 0x03): + x_coord = int.from_bytes(raw[1:], "big") + rhs = (pow(x_coord, 3, _P256_P) + _P256_A * x_coord + _P256_B) % _P256_P + y_coord = pow(rhs, (_P256_P + 1) // 4, _P256_P) + if y_coord & 1 != raw[0] & 1: + y_coord = (-y_coord) % _P256_P + point = _Point(x_coord, y_coord) + else: + raise ValueError("Unsupported P-256 point encoding.") + + if not _is_on_curve(point): + raise ValueError("Invalid P-256 point.") + return point + + +def _is_on_curve(point: _Point) -> bool: + if point.is_infinity: + return False + assert point.x is not None and point.y is not None + return ( + pow(point.y, 2, _P256_P) + - (pow(point.x, 3, _P256_P) + _P256_A * point.x + _P256_B) + ) % _P256_P == 0 + + +def _negate(point: _Point) -> _Point: + if point.is_infinity: + return point + assert point.x is not None and point.y is not None + return _Point(point.x, (-point.y) % _P256_P) + + +def _add_points(left: _Point, right: _Point) -> _Point: + if left.is_infinity: + return right + if right.is_infinity: + return left + + assert left.x is not None and left.y is not None + assert right.x is not None and right.y is not None + + if left.x == right.x and (left.y + right.y) % _P256_P == 0: + return _INFINITY + + if left.x == right.x and left.y == right.y: + if left.y == 0: + return _INFINITY + slope = ( + (3 * left.x * left.x + _P256_A) * pow(2 * left.y, -1, _P256_P) + ) % _P256_P + else: + slope = ((right.y - left.y) * pow(right.x - left.x, -1, _P256_P)) % _P256_P + + x_coord = (slope * slope - left.x - right.x) % _P256_P + y_coord = (slope * (left.x - x_coord) - left.y) % _P256_P + return _Point(x_coord, y_coord) + + +def _multiply_point(point: _Point, scalar: int) -> _Point: + scalar %= _P256_ORDER + result = _INFINITY + addend = point + while scalar: + if scalar & 1: + result = _add_points(result, addend) + addend = _add_points(addend, addend) + scalar >>= 1 + return result + + +def _concat_length_prefixed(*parts: bytes) -> bytes: + output = bytearray() + for part in parts: + output.extend(len(part).to_bytes(8, "little")) + output.extend(part) + return bytes(output) + + +def _hkdf_like(ikm: bytes, salt: bytes, info: bytes, length: int) -> bytes: + hash_len = hashlib.sha256().digest_size + if not salt: + salt = b"\x00" * hash_len + prk = hmac.new(salt, ikm, hashlib.sha256).digest() + blocks = bytearray() + previous = b"" + counter = 1 + while len(blocks) < length: + previous = hmac.new( + prk, + previous + info + bytes([counter]), + hashlib.sha256, + ).digest() + blocks.extend(previous) + counter += 1 + return bytes(blocks[:length]) + + +def _confirmation_key_length(info: bytes, requested_length: int) -> int: + if b"ConfirmationKeys" in info: + return 64 + return requested_length + + +def _derive_key(ikm: bytes, info: bytes, length: int = 64) -> bytes: + return _hkdf_like( + ikm=ikm, + salt=b"", + info=info, + length=_confirmation_key_length(info, length), + ) + + +def _derive_prover_and_verifier_keys(raw_key_hex: str) -> tuple[str, str]: + raw_key = bytes.fromhex(raw_key_hex) + verifier_key = _derive_key(raw_key, _VERIFIER_KEY_INFO, _KEY_LENGTH) + prover_key = _derive_key(raw_key, _PROVER_KEY_INFO, _KEY_LENGTH) + return verifier_key.hex(), prover_key.hex() + + +@dataclass(frozen=True) +class _ClientSharedSecret: + transcript: bytes + share_p: str + share_v: str + + def __post_init__(self) -> None: + digest = hashlib.sha256(self.transcript).digest() + object.__setattr__(self, "_hash_transcript", digest) + confirmations = _derive_key(digest, b"ConfirmationKeys", 64) + object.__setattr__(self, "_confirm_client", confirmations[:32]) + object.__setattr__(self, "_confirm_server", confirmations[32:]) + shared_key = _derive_key(digest, b"SharedKey", _KEY_LENGTH) + object.__setattr__(self, "_shared_key", shared_key) + + def get_confirmation(self) -> str: + return hmac.new( + self._confirm_client, + bytes.fromhex(self.share_v), + hashlib.sha256, + ).hexdigest() + + def verify(self, message_hex: str) -> bytes: + expected = hmac.new( + self._confirm_server, + bytes.fromhex(self.share_p), + hashlib.sha256, + ).hexdigest() + if expected != message_hex: + raise ValueError("invalid confirmation from server") + return self._shared_key + + +@dataclass(frozen=True) +class _ServerSharedSecret: + transcript: bytes + share_p: str + share_v: str + + def __post_init__(self) -> None: + digest = hashlib.sha256(self.transcript).digest() + confirmations = _derive_key(digest, b"ConfirmationKeys", 64) + object.__setattr__(self, "_confirm_client", confirmations[:32]) + object.__setattr__(self, "_confirm_server", confirmations[32:]) + object.__setattr__( + self, + "_shared_key", + _derive_key(digest, b"SharedKey", _KEY_LENGTH), + ) + + def get_confirmation(self) -> str: + return hmac.new( + self._confirm_server, + bytes.fromhex(self.share_p), + hashlib.sha256, + ).hexdigest() + + def verify(self, message_hex: str) -> bytes: + expected = hmac.new( + self._confirm_client, + bytes.fromhex(self.share_v), + hashlib.sha256, + ).hexdigest() + if expected != message_hex: + raise ValueError("invalid confirmation from client") + return self._shared_key + + +class _ClientHandshake: + def __init__( + self, + *, + x_scalar: int, + w0: int, + w1: int, + ) -> None: + self._x = x_scalar + self._w0 = w0 + self._w1 = w1 + self._message1_point: Optional[_Point] = None + self.share_p: Optional[str] = None + + def get_message(self) -> str: + point = _add_points( + _multiply_point(_GENERATOR, self._x), + _multiply_point(_decode_point(_SPAKE2_M), self._w0), + ) + self._message1_point = point + self.share_p = _encode_point(point) + return self.share_p + + def finish(self, server_message_hex: str) -> _ClientSharedSecret: + if self._message1_point is None or self.share_p is None: + raise ValueError("get_message must be called before finish") + + server_point = _decode_point(server_message_hex) + if server_point.is_infinity: + raise ValueError("invalid curve point") + + adjusted = _add_points( + server_point, + _negate(_multiply_point(_decode_point(_SPAKE2_N), self._w0)), + ) + y_point = _multiply_point(adjusted, self._x) + v_point = _multiply_point(adjusted, self._w1) + transcript = _concat_length_prefixed( + _SPAKE2_CONTEXT, + _CLIENT_IDENTITY, + _SERVER_IDENTITY, + bytes.fromhex(_encode_point(_decode_point(_SPAKE2_M))), + bytes.fromhex(_encode_point(_decode_point(_SPAKE2_N))), + bytes.fromhex(_encode_point(self._message1_point)), + bytes.fromhex(_encode_point(server_point)), + bytes.fromhex(_encode_point(y_point)), + bytes.fromhex(_encode_point(v_point)), + _int_to_bytes(self._w0), + ) + return _ClientSharedSecret( + transcript=transcript, + share_p=self.share_p, + share_v=server_message_hex, + ) + + +class _ServerHandshake: + def __init__( + self, + *, + y_scalar: int, + w0: int, + verifier_point: _Point, + ) -> None: + self._y = y_scalar + self._w0 = w0 + self._verifier_point = verifier_point + self._message1_point: Optional[_Point] = None + self.share_v: Optional[str] = None + + def get_message(self) -> str: + point = _add_points( + _multiply_point(_GENERATOR, self._y), + _multiply_point(_decode_point(_SPAKE2_N), self._w0), + ) + self._message1_point = point + self.share_v = _encode_point(point) + return self.share_v + + def finish(self, client_message_hex: str) -> _ServerSharedSecret: + if self._message1_point is None or self.share_v is None: + raise ValueError("get_message must be called before finish") + + client_point = _decode_point(client_message_hex) + if client_point.is_infinity: + raise ValueError("invalid curve point") + + adjusted = _add_points( + client_point, + _negate(_multiply_point(_decode_point(_SPAKE2_M), self._w0)), + ) + y_point = _multiply_point(adjusted, self._y) + verifier_share = _multiply_point(self._verifier_point, self._y) + transcript = _concat_length_prefixed( + _SPAKE2_CONTEXT, + _CLIENT_IDENTITY, + _SERVER_IDENTITY, + bytes.fromhex(_encode_point(_decode_point(_SPAKE2_M))), + bytes.fromhex(_encode_point(_decode_point(_SPAKE2_N))), + bytes.fromhex(_encode_point(client_point)), + bytes.fromhex(_encode_point(self._message1_point)), + bytes.fromhex(_encode_point(y_point)), + bytes.fromhex(_encode_point(verifier_share)), + _int_to_bytes(self._w0), + ) + return _ServerSharedSecret( + transcript=transcript, + share_p=client_message_hex, + share_v=self.share_v, + ) + + +def _compute_w0_w1(password: str, salt_b64: str) -> tuple[int, int]: + derived = hashlib.scrypt( + password.encode("utf-8"), + salt=_b64_to_bytes(salt_b64), + **_SCRYPT_PARAMS, + ) + midpoint = len(derived) // 2 + return ( + int.from_bytes(derived[:midpoint], "big"), + int.from_bytes(derived[midpoint:], "big"), + ) + + +class TrustedDeviceBridgeProver: + """Client-side prover mirroring Apple's prover worker.""" + + def __init__(self) -> None: + self._client: Optional[_ClientHandshake] = None + self._shared_secret: Optional[_ClientSharedSecret] = None + self._raw_key: Optional[str] = None + self._verified = False + self._verifier_key: Optional[str] = None + self._prover_key: Optional[str] = None + + def init_with_salt(self, salt_b64: str, code: str) -> None: + w0, w1 = _compute_w0_w1(code, salt_b64) + self._client = _ClientHandshake( + x_scalar=secrets.randbelow(_P256_ORDER), + w0=w0, + w1=w1, + ) + self._shared_secret = None + self._raw_key = None + self._verified = False + self._verifier_key = None + self._prover_key = None + + def get_message1(self) -> str: + if self._client is None: + raise ValueError("init_with_salt must be called before get_message1") + return self._client.get_message() + + def process_message1(self, message_hex: str) -> str: + if self._client is None: + raise ValueError("init_with_salt must be called before process_message1") + self._shared_secret = self._client.finish(message_hex) + return self.get_message2() + + def get_message2(self) -> str: + if self._shared_secret is None: + raise ValueError("process_message1 must be called before get_message2") + return self._shared_secret.get_confirmation() + + def process_message2(self, message_hex: str) -> dict[str, object]: + if self._shared_secret is None: + raise ValueError("process_message1 must be called before process_message2") + raw_key = self._shared_secret.verify(message_hex).hex() + self._raw_key = raw_key + self._verifier_key, self._prover_key = _derive_prover_and_verifier_keys(raw_key) + self._verified = True + return {"isVerified": True, "key": raw_key} + + def is_verified(self) -> bool: + return self._verified + + def get_key(self) -> str: + if self._raw_key is None: + raise ValueError("No bridge key is available yet.") + return self._raw_key + + def decrypt_message(self, ciphertext_b64: str) -> str: + if self._verifier_key is None: + raise ValueError("Bridge verifier key is not available.") + payload = _b64_to_bytes(ciphertext_b64) + version = payload[0] + iv_length, tag_length = _AES_GCM_LAYOUTS[version] + iv = payload[1 : 1 + iv_length] + tag = payload[1 + iv_length : 1 + iv_length + tag_length] + ciphertext = payload[1 + iv_length + tag_length :] + plaintext = AESGCM(bytes.fromhex(self._verifier_key)).decrypt( + iv, + ciphertext + tag, + bytes([version]), + ) + return plaintext.decode("utf-8") + + +class _TrustedDeviceBridgeServerProver: + """Internal test helper mirroring Apple's server-side bridge flow.""" + + def __init__(self, *, password: str, salt_b64: str) -> None: + w0, w1 = _compute_w0_w1(password, salt_b64) + verifier_point = _multiply_point(_GENERATOR, w1) + self._server = _ServerHandshake( + y_scalar=secrets.randbelow(_P256_ORDER), + w0=w0, + verifier_point=verifier_point, + ) + self._shared_secret: Optional[_ServerSharedSecret] = None + self._raw_key: Optional[str] = None + self._verifier_key: Optional[str] = None + self._prover_key: Optional[str] = None + + def get_message1(self) -> str: + return self._server.get_message() + + def process_message1(self, client_message_hex: str) -> str: + self._shared_secret = self._server.finish(client_message_hex) + return self.get_message2() + + def get_message2(self) -> str: + if self._shared_secret is None: + raise ValueError("process_message1 must be called before get_message2") + return self._shared_secret.get_confirmation() + + def verify_message2(self, message_hex: str) -> str: + if self._shared_secret is None: + raise ValueError("process_message1 must be called before verify_message2") + raw_key = self._shared_secret.verify(message_hex).hex() + self._raw_key = raw_key + self._verifier_key, self._prover_key = _derive_prover_and_verifier_keys(raw_key) + return raw_key + + def encrypt_message(self, plaintext: str) -> str: + if self._verifier_key is None: + raise ValueError("Bridge verifier key is not available.") + version = 0 + iv_length, tag_length = _AES_GCM_LAYOUTS[version] + iv = secrets.token_bytes(iv_length) + encrypted = AESGCM(bytes.fromhex(self._verifier_key)).encrypt( + iv, + plaintext.encode("utf-8"), + bytes([version]), + ) + ciphertext = encrypted[:-tag_length] + tag = encrypted[-tag_length:] + payload = bytes([version]) + iv + tag + ciphertext + return _bytes_to_b64(payload) diff --git a/pyicloud/session.py b/pyicloud/session.py index b696bfd9..9d8adfd9 100644 --- a/pyicloud/session.py +++ b/pyicloud/session.py @@ -188,6 +188,68 @@ def request( json=json, ) + def request_raw( + self, + method, + url, + params=None, + data=None, + headers=None, + cookies=None, + files=None, + auth=None, + timeout=None, + allow_redirects=True, + proxies=None, + hooks=None, + stream=None, + verify=None, + cert=None, + json=None, + ) -> Response: + """Dispatch a request without response-status normalization.""" + + return self._request_raw( + method, + url, + params=params, + data=data, + headers=headers, + cookies=cookies, + files=files, + auth=auth, + timeout=timeout, + allow_redirects=allow_redirects, + proxies=proxies, + hooks=hooks, + stream=stream, + verify=verify, + cert=cert, + json=json, + ) + + def _request_raw( + self, + method, + url, + **kwargs, + ) -> Response: + """Perform a request and persist cookies/session data without raising.""" + + self.logger.debug( + "%s %s", + method, + url, + ) + response: Response = super().request( + method=method, + url=url, + **kwargs, + ) + self._update_session_data(response) + self._save_session_data() + return response + def _request( self, method, diff --git a/requirements.txt b/requirements.txt index 06f2a422..6e0bc574 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ certifi>=2024.12.14 click>=8.1.8 +cryptography>=44.0.0 fido2>=2.0.0 keyring>=25.6.0 keyrings.alt>=5.0.2 diff --git a/tests/test_base.py b/tests/test_base.py index 01f8bdfe..7791b75f 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -5,6 +5,8 @@ import json import secrets +import tempfile +from pathlib import Path from typing import Any, List from unittest.mock import MagicMock, mock_open, patch @@ -21,6 +23,8 @@ PyiCloudFailedLoginException, PyiCloudServiceNotActivatedException, PyiCloudServiceUnavailable, + PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, ) from pyicloud.services.calendar import CalendarService from pyicloud.services.contacts import ContactsService @@ -257,6 +261,387 @@ def test_validate_2fa_code(pyicloud_service: PyiCloudService) -> None: assert pyicloud_service.validate_2fa_code("123456") +def test_validate_2fa_code_uses_bridge_verifier_for_step2_state( + pyicloud_service: PyiCloudService, +) -> None: + """Bridge-backed trusted-device prompts should use the bridge verifier instead of the legacy endpoint.""" + + pyicloud_service.data = {"dsInfo": {"hsaVersion": 2}, "hsaChallengeRequired": False} + pyicloud_service._two_factor_delivery_method = "trusted_device" + bridge_state = MagicMock(uses_legacy_trusted_device_verifier=False) + pyicloud_service._trusted_device_bridge_state = bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.validate_code.return_value = True + pyicloud_service.trust_session = MagicMock( + side_effect=lambda: pyicloud_service.data.update({"hsaTrustedBrowser": True}) + or True + ) + pyicloud_service._session = MagicMock() + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.validate_2fa_code("123456") is True + + pyicloud_service._trusted_device_bridge.validate_code.assert_called_once() + pyicloud_service.session.post.assert_not_called() + pyicloud_service._trusted_device_bridge.close.assert_called_once_with(bridge_state) + pyicloud_service.trust_session.assert_called_once_with() + + +def test_validate_2fa_code_keeps_legacy_endpoint_for_bridge_w_subtype( + pyicloud_service: PyiCloudService, +) -> None: + """Apple's `_W` bridge subtype should keep using the legacy trusted-device verifier.""" + + pyicloud_service.data = {"dsInfo": {"hsaVersion": 2}, "hsaChallengeRequired": False} + pyicloud_service._two_factor_delivery_method = "trusted_device" + bridge_state = MagicMock(uses_legacy_trusted_device_verifier=True) + pyicloud_service._trusted_device_bridge_state = bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service.trust_session = MagicMock( + side_effect=lambda: pyicloud_service.data.update({"hsaTrustedBrowser": True}) + or True + ) + pyicloud_service._session = MagicMock() + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + pyicloud_service.session.post.return_value = MagicMock(status_code=200) + + assert pyicloud_service.validate_2fa_code("123456") is True + + pyicloud_service._trusted_device_bridge.validate_code.assert_not_called() + args = pyicloud_service.session.post.call_args.args + assert args[0] == ( + f"{pyicloud_service._auth_endpoint}/verify/trusteddevice/securitycode" + ) + pyicloud_service._trusted_device_bridge.close.assert_called_once_with(bridge_state) + + +def test_validate_2fa_code_bridge_verification_exception_propagates( + pyicloud_service: PyiCloudService, +) -> None: + """Bridge verification failures should not be downgraded to generic invalid-code results.""" + + pyicloud_service._two_factor_delivery_method = "trusted_device" + bridge_state = MagicMock(uses_legacy_trusted_device_verifier=False) + pyicloud_service._trusted_device_bridge_state = bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.validate_code.side_effect = ( + PyiCloudTrustedDeviceVerificationException("bridge verification failed") + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="bridge verification failed", + ): + pyicloud_service.validate_2fa_code("123456") + + pyicloud_service._trusted_device_bridge.close.assert_called_once_with(bridge_state) + + +def test_request_2fa_code_requests_sms_delivery( + pyicloud_service: PyiCloudService, +) -> None: + """Nested phone verification data should trigger SMS delivery.""" + + pyicloud_service._auth_data = { + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + } + } + + with patch("pyicloud.base.PyiCloudSession") as mock_session: + pyicloud_service._session = mock_session + mock_session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.request_2fa_code() is True + + args = mock_session.put.call_args.args + kwargs = mock_session.put.call_args.kwargs + assert args[0] == f"{pyicloud_service._auth_endpoint}/verify/phone" + assert kwargs["json"] == { + "phoneNumber": {"id": 3, "nonFTEU": False}, + "mode": "sms", + } + assert kwargs["headers"]["Accept"] == "application/json" + + +def test_get_mfa_auth_options_parses_hsa2_boot_html( + pyicloud_service: PyiCloudService, +) -> None: + """GET /appleauth/auth HTML should populate the HSA2 boot context.""" + + response = MagicMock() + response.json.side_effect = ValueError("not json") + response.text = """ + + + + """ + pyicloud_service._session = MagicMock() + pyicloud_service.session.get.return_value = response + + auth_options = pyicloud_service._get_mfa_auth_options() + + _, kwargs = pyicloud_service.session.get.call_args + assert kwargs["headers"]["Accept"] == "text/html" + assert auth_options["authInitialRoute"] == "auth/bridge/step" + assert auth_options["hasTrustedDevices"] is True + assert auth_options["authFactors"] == ["web_piggybacking", "sms"] + assert auth_options["bridgeInitiateData"]["webSocketUrl"] == ( + "websocket.push.apple.com" + ) + assert auth_options["phoneNumberVerification"]["trustedPhoneNumber"]["id"] == 3 + assert auth_options["sourceAppId"] == "1159" + assert pyicloud_service._hsa2_boot_context is not None + assert pyicloud_service._hsa2_boot_context.auth_initial_route == ( + "auth/bridge/step" + ) + assert pyicloud_service._hsa2_boot_context.has_trusted_devices is True + + +def test_request_2fa_code_prefers_trusted_device_bridge( + pyicloud_service: PyiCloudService, +) -> None: + """Request-7 style HSA2 challenges should start the bridge before SMS.""" + + pyicloud_service.data = { + "dsInfo": {"hsaVersion": 2}, + "hsaChallengeRequired": True, + "hsaTrustedBrowser": False, + } + pyicloud_service._auth_data = { + "authInitialRoute": "auth/bridge/step", + "hasTrustedDevices": True, + "authFactors": ["web_piggybacking", "sms"], + "bridgeInitiateData": { + "apnsTopic": "com.apple.idmsauthwidget", + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + }, + } + + bridge_state = MagicMock() + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.start.return_value = bridge_state + pyicloud_service._session = MagicMock() + pyicloud_service.session.headers = {"User-Agent": "test-agent"} + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.request_2fa_code() is True + + pyicloud_service._trusted_device_bridge.start.assert_called_once() + pyicloud_service.session.put.assert_not_called() + assert pyicloud_service.two_factor_delivery_method == "trusted_device" + assert pyicloud_service._trusted_device_bridge_state is bridge_state + + +def test_request_2fa_code_replaces_existing_bridge_state_before_restart( + pyicloud_service: PyiCloudService, +) -> None: + """Starting a new bridge prompt should close any previous in-memory bridge session.""" + + pyicloud_service._auth_data = { + "authInitialRoute": "auth/bridge/step", + "hasTrustedDevices": True, + "bridgeInitiateData": { + "apnsTopic": "com.apple.idmsauthwidget", + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + } + + previous_bridge_state = MagicMock() + next_bridge_state = MagicMock() + pyicloud_service._trusted_device_bridge_state = previous_bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.start.return_value = next_bridge_state + pyicloud_service._session = MagicMock() + pyicloud_service.session.headers = {"User-Agent": "test-agent"} + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.request_2fa_code() is True + + pyicloud_service._trusted_device_bridge.close.assert_called_once_with( + previous_bridge_state + ) + assert pyicloud_service._trusted_device_bridge_state is next_bridge_state + + +def test_request_2fa_code_falls_back_to_sms_when_bridge_fails( + pyicloud_service: PyiCloudService, +) -> None: + """Bridge bootstrap failures should fall back to SMS when Apple exposes it.""" + + pyicloud_service._auth_data = { + "authInitialRoute": "auth/bridge/step", + "hasTrustedDevices": True, + "bridgeInitiateData": { + "apnsTopic": "com.apple.idmsauthwidget", + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + }, + } + + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.start.side_effect = ( + PyiCloudTrustedDevicePromptException("bridge failed") + ) + pyicloud_service._session = MagicMock() + pyicloud_service.session.headers = {"User-Agent": "test-agent"} + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.request_2fa_code() is True + + args = pyicloud_service.session.put.call_args.args + kwargs = pyicloud_service.session.put.call_args.kwargs + assert args[0] == f"{pyicloud_service._auth_endpoint}/verify/phone" + assert kwargs["json"] == { + "phoneNumber": {"id": 3, "nonFTEU": False}, + "mode": "sms", + } + assert pyicloud_service.two_factor_delivery_method == "sms" + assert pyicloud_service.two_factor_delivery_notice == ( + "Trusted-device prompt failed; falling back to SMS." + ) + + +def test_request_2fa_code_keeps_security_key_path_separate( + pyicloud_service: PyiCloudService, +) -> None: + """Security-key challenges should not start the bridge or SMS flows.""" + + pyicloud_service._auth_data = { + "fsaChallenge": {"challenge": "abc"}, + "authInitialRoute": "auth/bridge/step", + "hasTrustedDevices": True, + "bridgeInitiateData": { + "apnsTopic": "com.apple.idmsauthwidget", + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + }, + } + + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._session = MagicMock() + pyicloud_service.session.headers = {"User-Agent": "test-agent"} + + assert pyicloud_service.request_2fa_code() is False + + pyicloud_service._trusted_device_bridge.start.assert_not_called() + pyicloud_service.session.put.assert_not_called() + assert pyicloud_service.two_factor_delivery_method == "security_key" + + +def test_validate_2fa_code_uses_nested_sms_phone_number( + pyicloud_service: PyiCloudService, +) -> None: + """Nested phone verification data should validate via the SMS endpoint.""" + + pyicloud_service.data = {"dsInfo": {"hsaVersion": 1}, "hsaChallengeRequired": False} + pyicloud_service._auth_data = { + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + } + } + pyicloud_service.trust_session = MagicMock( + side_effect=lambda: pyicloud_service.data.update({"hsaTrustedBrowser": True}) + or True + ) + + with patch("pyicloud.base.PyiCloudSession") as mock_session: + pyicloud_service._session = mock_session + mock_session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + "session_token": "test_session_token", + } + + mock_post_response = MagicMock() + mock_post_response.status_code = 200 + mock_post_response.json.return_value = {"success": True} + mock_session.post.return_value = mock_post_response + + assert pyicloud_service.validate_2fa_code("123456") + + args = mock_session.post.call_args.args + kwargs = mock_session.post.call_args.kwargs + assert args[0] == f"{pyicloud_service._auth_endpoint}/verify/phone/securitycode" + assert kwargs["json"] == { + "phoneNumber": {"id": 3, "nonFTEU": False}, + "securityCode": {"code": "123456"}, + "mode": "sms", + } + + def test_validate_2fa_code_failure(pyicloud_service: PyiCloudService) -> None: """Test the validate_2fa_code method with an invalid code.""" exception = PyiCloudAPIResponseException("Invalid code") @@ -431,6 +816,24 @@ def test_logout_clears_authenticated_state( assert pyicloud_service._devices is None +def test_logout_closes_active_trusted_device_bridge_state( + pyicloud_service: PyiCloudService, +) -> None: + """Logout should close any active trusted-device bridge session before clearing state.""" + + bridge_state = MagicMock() + pyicloud_service._trusted_device_bridge_state = bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service.session.cookies = MagicMock() + pyicloud_service.session.cookies.get.return_value = None + pyicloud_service.session.clear_persistence = MagicMock() + + pyicloud_service.logout() + + pyicloud_service._trusted_device_bridge.close.assert_called_once_with(bridge_state) + assert pyicloud_service._trusted_device_bridge_state is None + + def test_cookiejar_path_property(pyicloud_session: PyiCloudSession) -> None: """Test the cookiejar_path property.""" path: str = pyicloud_session.cookiejar_path @@ -557,6 +960,53 @@ def test_request_success(pyicloud_service_working: PyiCloudService) -> None: ) +def test_session_persistence_excludes_trusted_device_bridge_state( + pyicloud_service_working: PyiCloudService, +) -> None: + """Bridge-only state should remain in memory and never be written to persisted session files.""" + + temp_root = Path(tempfile.gettempdir()) / "python-test-results" / "bridge-auth" + temp_root.mkdir(parents=True, exist_ok=True) + session = PyiCloudSession( + service=pyicloud_service_working, + client_id="", + cookie_directory=str(temp_root), + ) + pyicloud_service_working._session = session + pyicloud_service_working._trusted_device_bridge_state = MagicMock( + push_token="bridge-ptkn", + session_uuid="bridge-session-uuid", + idmsdata="bridge-idmsdata", + encrypted_code="bridge-encrypted-code", + ) + session._data = { + "session_token": "valid-token", + "session_id": "persisted-session-id", + } + + session._save_session_data() + + persisted_session = Path(session.session_path).read_text(encoding="utf-8") + for secret_value in ( + "bridge-ptkn", + "bridge-session-uuid", + "bridge-idmsdata", + "bridge-encrypted-code", + ): + assert secret_value not in persisted_session + + cookiejar_path = Path(session.cookiejar_path) + if cookiejar_path.exists(): + persisted_cookiejar = cookiejar_path.read_text(encoding="utf-8") + for secret_value in ( + "bridge-ptkn", + "bridge-session-uuid", + "bridge-idmsdata", + "bridge-encrypted-code", + ): + assert secret_value not in persisted_cookiejar + + def test_request_failure(pyicloud_service_working: PyiCloudService) -> None: """Test the request method with a failure response.""" diff --git a/tests/test_cmdline.py b/tests/test_cmdline.py index 964a107b..f1b68c69 100644 --- a/tests/test_cmdline.py +++ b/tests/test_cmdline.py @@ -10,12 +10,35 @@ from pathlib import Path from types import SimpleNamespace from typing import Any, Optional -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch from uuid import uuid4 import click from typer.testing import CliRunner +from pyicloud.services.notes.models import Attachment as NoteAttachment +from pyicloud.services.notes.models import ChangeEvent as NoteChangeEvent +from pyicloud.services.notes.models import ( + Note, + NoteFolder, + NoteSummary, +) +from pyicloud.services.notes.service import NoteLockedError, NoteNotFound +from pyicloud.services.reminders.models import ( + Alarm, + AlarmWithTrigger, + Hashtag, + ListRemindersResult, + LocationTrigger, + Proximity, + RecurrenceFrequency, + RecurrenceRule, + Reminder, + ReminderChangeEvent, + RemindersList, + URLAttachment, +) + account_index_module = importlib.import_module("pyicloud.cli.account_index") cli_module = importlib.import_module("pyicloud.cli.app") context_module = importlib.import_module("pyicloud.cli.context") @@ -195,6 +218,613 @@ def delete(self, anonymous_id: str) -> dict[str, Any]: return {"anonymousId": anonymous_id, "deleted": True} +class FakeNotes: + """Notes service fixture.""" + + def __init__(self) -> None: + attachment = NoteAttachment( + id="Attachment/PDF", + filename="agenda.pdf", + uti="com.adobe.pdf", + size=12, + download_url="https://example.com/agenda.pdf", + preview_url="https://example.com/agenda-preview.pdf", + thumbnail_url="https://example.com/agenda-thumb.png", + ) + self.recent_requests: list[int] = [] + self.iter_all_requests: list[str | None] = [] + self.folder_requests: list[tuple[str, int | None]] = [] + self.render_calls: list[dict[str, Any]] = [] + self.export_calls: list[dict[str, Any]] = [] + self.change_requests: list[str | None] = [] + self.folder_rows = [ + NoteFolder( + id="Folder/NOTES", + name="Notes", + has_subfolders=False, + count=1, + ), + NoteFolder( + id="Folder/WORK", + name="Work", + has_subfolders=True, + count=3, + ), + ] + self.recent_rows = [ + NoteSummary( + id="Note/DELETED", + title="Deleted Note", + snippet="Old note", + modified_at=datetime(2026, 3, 5, tzinfo=timezone.utc), + folder_id="Folder/DELETED", + folder_name="Recently Deleted", + is_deleted=True, + is_locked=False, + ), + NoteSummary( + id="Note/DAILY", + title="Daily Plan", + snippet="Ship CLI", + modified_at=datetime(2026, 3, 4, tzinfo=timezone.utc), + folder_id="Folder/NOTES", + folder_name="Notes", + is_deleted=False, + is_locked=False, + ), + NoteSummary( + id="Note/MEETING", + title="Meeting Notes", + snippet="Discuss roadmap", + modified_at=datetime(2026, 3, 3, tzinfo=timezone.utc), + folder_id="Folder/WORK", + folder_name="Work", + is_deleted=False, + is_locked=False, + ), + ] + self.all_rows = [ + self.recent_rows[2], + NoteSummary( + id="Note/FOLLOWUP", + title="Meeting Follow-up", + snippet="Send recap", + modified_at=datetime(2026, 3, 2, tzinfo=timezone.utc), + folder_id="Folder/WORK", + folder_name="Work", + is_deleted=False, + is_locked=False, + ), + self.recent_rows[1], + self.recent_rows[2], + ] + self.notes = { + "Note/DAILY": Note( + id="Note/DAILY", + title="Daily Plan", + snippet="Ship CLI", + modified_at=datetime(2026, 3, 4, tzinfo=timezone.utc), + folder_id="Folder/NOTES", + folder_name="Notes", + is_deleted=False, + is_locked=False, + text="Ship CLI", + html="

Ship CLI

", + attachments=[attachment], + ), + "Note/MEETING": Note( + id="Note/MEETING", + title="Meeting Notes", + snippet="Discuss roadmap", + modified_at=datetime(2026, 3, 3, tzinfo=timezone.utc), + folder_id="Folder/WORK", + folder_name="Work", + is_deleted=False, + is_locked=False, + text="Discuss roadmap", + html="

Discuss roadmap

", + attachments=[attachment], + ), + "Note/FOLLOWUP": Note( + id="Note/FOLLOWUP", + title="Meeting Follow-up", + snippet="Send recap", + modified_at=datetime(2026, 3, 2, tzinfo=timezone.utc), + folder_id="Folder/WORK", + folder_name="Work", + is_deleted=False, + is_locked=False, + text="Send recap", + html="

Send recap

", + attachments=None, + ), + } + self.change_rows = [ + NoteChangeEvent(type="updated", note=self.recent_rows[1]), + NoteChangeEvent(type="deleted", note=self.recent_rows[0]), + ] + self.cursor = "notes-cursor-1" + + @staticmethod + def _matches_id(note_id: str, query: str) -> bool: + return note_id == query or note_id.split("/", 1)[-1] == query + + def recents(self, *, limit: int = 50): + self.recent_requests.append(limit) + return list(self.recent_rows[:limit]) + + def folders(self): + return list(self.folder_rows) + + def in_folder(self, folder_id: str, limit: int | None = None): + self.folder_requests.append((folder_id, limit)) + rows = [row for row in self.all_rows if row.folder_id == folder_id] + return list(rows[:limit] if limit is not None else rows) + + def iter_all(self, *, since: Optional[str] = None): + self.iter_all_requests.append(since) + return iter(self.all_rows) + + def get(self, note_id: str, *, with_attachments: bool = False): + if self._matches_id("Note/LOCKED", note_id): + raise NoteLockedError(f"Note is locked: {note_id}") + for candidate_id, note in self.notes.items(): + if self._matches_id(candidate_id, note_id): + attachments = note.attachments if with_attachments else None + return note.model_copy(update={"attachments": attachments}) + raise NoteNotFound(f"Note not found: {note_id}") + + def render_note(self, note_id: str, **kwargs: Any) -> str: + note = self.get(note_id, with_attachments=False) + self.render_calls.append({"note_id": note.id, **kwargs}) + return note.html or f"

{note.id}

" + + def export_note(self, note_id: str, output_dir: str, **kwargs: Any) -> str: + note = self.get(note_id, with_attachments=False) + path = Path(output_dir) / f"{note.id.split('/', 1)[-1].lower()}.html" + self.export_calls.append( + {"note_id": note.id, "output_dir": output_dir, **kwargs} + ) + return str(path) + + def iter_changes(self, *, since: Optional[str] = None): + self.change_requests.append(since) + return iter(self.change_rows) + + def sync_cursor(self) -> str: + return self.cursor + + +class FakeReminders: + """Reminders service fixture.""" + + def __init__(self) -> None: + self.list_rows = { + "List/INBOX": RemindersList( + id="List/INBOX", + title="Inbox", + color='{"daHexString":"#007AFF","ckSymbolicColorName":"blue"}', + count=0, + ), + "List/WORK": RemindersList( + id="List/WORK", + title="Work", + color='{"daHexString":"#34C759","ckSymbolicColorName":"green"}', + count=0, + ), + } + self.reminder_rows = { + "Reminder/A": Reminder( + id="Reminder/A", + list_id="List/INBOX", + title="Buy milk", + desc="2 percent", + completed=False, + due_date=datetime(2026, 3, 31, 9, 0, tzinfo=timezone.utc), + priority=1, + flagged=True, + all_day=False, + time_zone="Europe/Luxembourg", + alarm_ids=["Alarm/A"], + hashtag_ids=["Hashtag/ERRANDS"], + attachment_ids=["Attachment/LINK"], + recurrence_rule_ids=["Recurrence/WEEKLY"], + parent_reminder_id="Reminder/PARENT", + created=datetime(2026, 3, 1, tzinfo=timezone.utc), + modified=datetime(2026, 3, 4, tzinfo=timezone.utc), + ), + "Reminder/B": Reminder( + id="Reminder/B", + list_id="List/INBOX", + title="Pay rent", + desc="", + completed=True, + completed_date=datetime(2026, 3, 2, tzinfo=timezone.utc), + priority=0, + flagged=False, + all_day=False, + created=datetime(2026, 3, 1, tzinfo=timezone.utc), + modified=datetime(2026, 3, 2, tzinfo=timezone.utc), + ), + "Reminder/C": Reminder( + id="Reminder/C", + list_id="List/WORK", + title="Prepare deck", + desc="Slides for review", + completed=False, + priority=5, + flagged=False, + all_day=False, + created=datetime(2026, 3, 3, tzinfo=timezone.utc), + modified=datetime(2026, 3, 4, tzinfo=timezone.utc), + ), + } + self.alarm_rows = { + "Alarm/A": Alarm( + id="Alarm/A", + alarm_uid="alarm-a", + reminder_id="Reminder/A", + trigger_id="Trigger/A", + ) + } + self.trigger_rows = { + "Trigger/A": LocationTrigger( + id="Trigger/A", + alarm_id="Alarm/A", + title="Office", + address="1 Infinite Loop", + latitude=37.3318, + longitude=-122.0312, + radius=150.0, + proximity=Proximity.ARRIVING, + location_uid="office", + ) + } + self.hashtag_rows = { + "Hashtag/ERRANDS": Hashtag( + id="Hashtag/ERRANDS", + name="errands", + reminder_id="Reminder/A", + created=datetime(2026, 3, 1, tzinfo=timezone.utc), + ) + } + self.attachment_rows = { + "Attachment/LINK": URLAttachment( + id="Attachment/LINK", + reminder_id="Reminder/A", + url="https://example.com/checklist", + uti="public.url", + ) + } + self.recurrence_rows = { + "Recurrence/WEEKLY": RecurrenceRule( + id="Recurrence/WEEKLY", + reminder_id="Reminder/A", + frequency=RecurrenceFrequency.WEEKLY, + interval=1, + occurrence_count=0, + first_day_of_week=1, + ) + } + self.snapshot_requests: list[dict[str, Any]] = [] + self.change_requests: list[str | None] = [] + self.cursor = "reminders-cursor-1" + + @staticmethod + def _matches_id(record_id: str, query: str) -> bool: + return record_id == query or record_id.split("/", 1)[-1] == query + + def _find_reminder(self, reminder_id: str) -> Reminder: + for candidate_id, reminder in self.reminder_rows.items(): + if self._matches_id(candidate_id, reminder_id): + return reminder + raise LookupError(f"Reminder not found: {reminder_id}") + + def lists(self): + for row in self.list_rows.values(): + row.count = sum( + 1 + for reminder in self.reminder_rows.values() + if reminder.list_id == row.id and not reminder.deleted + ) + return list(self.list_rows.values()) + + def reminders(self, list_id: Optional[str] = None): + rows = [ + reminder + for reminder in self.reminder_rows.values() + if not reminder.deleted and (list_id is None or reminder.list_id == list_id) + ] + return list(rows) + + def list_reminders( + self, + list_id: str, + include_completed: bool = False, + results_limit: int = 200, + ) -> ListRemindersResult: + normalized = list_id if list_id.startswith("List/") else f"List/{list_id}" + self.snapshot_requests.append( + { + "list_id": normalized, + "include_completed": include_completed, + "results_limit": results_limit, + } + ) + reminders = [ + reminder + for reminder in self.reminder_rows.values() + if reminder.list_id == normalized + and not reminder.deleted + and (include_completed or not reminder.completed) + ][:results_limit] + reminder_ids = {reminder.id for reminder in reminders} + return ListRemindersResult( + reminders=reminders, + alarms={ + alarm_id: alarm + for alarm_id, alarm in self.alarm_rows.items() + if alarm.reminder_id in reminder_ids + }, + triggers={ + trigger_id: trigger + for trigger_id, trigger in self.trigger_rows.items() + if any( + alarm.trigger_id == trigger_id + for alarm in self.alarm_rows.values() + if alarm.reminder_id in reminder_ids + ) + }, + attachments={ + attachment_id: attachment + for attachment_id, attachment in self.attachment_rows.items() + if attachment.reminder_id in reminder_ids + }, + hashtags={ + hashtag_id: hashtag + for hashtag_id, hashtag in self.hashtag_rows.items() + if hashtag.reminder_id in reminder_ids + }, + recurrence_rules={ + rule_id: rule + for rule_id, rule in self.recurrence_rows.items() + if rule.reminder_id in reminder_ids + }, + ) + + def get(self, reminder_id: str) -> Reminder: + return self._find_reminder(reminder_id) + + def create( + self, + list_id: str, + title: str, + desc: str = "", + completed: bool = False, + due_date: Optional[datetime] = None, + priority: int = 0, + flagged: bool = False, + all_day: bool = False, + time_zone: Optional[str] = None, + parent_reminder_id: Optional[str] = None, + ) -> Reminder: + next_id = f"Reminder/CREATED-{len(self.reminder_rows) + 1}" + reminder = Reminder( + id=next_id, + list_id=list_id, + title=title, + desc=desc, + completed=completed, + due_date=due_date, + priority=priority, + flagged=flagged, + all_day=all_day, + time_zone=time_zone, + parent_reminder_id=parent_reminder_id, + created=datetime(2026, 3, 30, tzinfo=timezone.utc), + modified=datetime(2026, 3, 30, tzinfo=timezone.utc), + ) + self.reminder_rows[reminder.id] = reminder + return reminder + + def update(self, reminder: Reminder) -> None: + self.reminder_rows[reminder.id] = reminder + + def delete(self, reminder: Reminder) -> None: + reminder.deleted = True + self.reminder_rows[reminder.id] = reminder + + def add_location_trigger( + self, + reminder: Reminder, + title: str = "", + address: str = "", + latitude: float = 0.0, + longitude: float = 0.0, + radius: float = 100.0, + proximity: Proximity = Proximity.ARRIVING, + ) -> tuple[Alarm, LocationTrigger]: + index = len(self.alarm_rows) + 1 + alarm = Alarm( + id=f"Alarm/{index}", + alarm_uid=f"alarm-{index}", + reminder_id=reminder.id, + trigger_id=f"Trigger/{index}", + ) + trigger = LocationTrigger( + id=f"Trigger/{index}", + alarm_id=alarm.id, + title=title, + address=address, + latitude=latitude, + longitude=longitude, + radius=radius, + proximity=proximity, + location_uid=f"location-{index}", + ) + self.alarm_rows[alarm.id] = alarm + self.trigger_rows[trigger.id] = trigger + reminder.alarm_ids.append(alarm.id) + return alarm, trigger + + def create_hashtag(self, reminder: Reminder, name: str) -> Hashtag: + hashtag = Hashtag( + id=f"Hashtag/{name.upper()}", + name=name, + reminder_id=reminder.id, + created=datetime(2026, 3, 30, tzinfo=timezone.utc), + ) + self.hashtag_rows[hashtag.id] = hashtag + reminder.hashtag_ids.append(hashtag.id) + return hashtag + + def update_hashtag(self, hashtag: Hashtag, name: str) -> None: + hashtag.name = name + + def delete_hashtag(self, reminder: Reminder, hashtag: Hashtag) -> None: + reminder.hashtag_ids = [ + row_id for row_id in reminder.hashtag_ids if row_id != hashtag.id + ] + self.hashtag_rows.pop(hashtag.id, None) + + def create_url_attachment( + self, reminder: Reminder, url: str, uti: str = "public.url" + ) -> URLAttachment: + attachment = URLAttachment( + id=f"Attachment/{len(self.attachment_rows) + 1}", + reminder_id=reminder.id, + url=url, + uti=uti, + ) + self.attachment_rows[attachment.id] = attachment + reminder.attachment_ids.append(attachment.id) + return attachment + + def update_attachment( + self, + attachment: URLAttachment, + *, + url: Optional[str] = None, + uti: Optional[str] = None, + filename: Optional[str] = None, + file_size: Optional[int] = None, + width: Optional[int] = None, + height: Optional[int] = None, + ) -> None: + if url is not None: + attachment.url = url + if uti is not None: + attachment.uti = uti + + def delete_attachment(self, reminder: Reminder, attachment: URLAttachment) -> None: + reminder.attachment_ids = [ + row_id for row_id in reminder.attachment_ids if row_id != attachment.id + ] + self.attachment_rows.pop(attachment.id, None) + + def create_recurrence_rule( + self, + reminder: Reminder, + *, + frequency: RecurrenceFrequency = RecurrenceFrequency.DAILY, + interval: int = 1, + occurrence_count: int = 0, + first_day_of_week: int = 0, + ) -> RecurrenceRule: + rule = RecurrenceRule( + id=f"Recurrence/{len(self.recurrence_rows) + 1}", + reminder_id=reminder.id, + frequency=frequency, + interval=interval, + occurrence_count=occurrence_count, + first_day_of_week=first_day_of_week, + ) + self.recurrence_rows[rule.id] = rule + reminder.recurrence_rule_ids.append(rule.id) + return rule + + def update_recurrence_rule( + self, + recurrence_rule: RecurrenceRule, + *, + frequency: Optional[RecurrenceFrequency] = None, + interval: Optional[int] = None, + occurrence_count: Optional[int] = None, + first_day_of_week: Optional[int] = None, + ) -> None: + if frequency is not None: + recurrence_rule.frequency = frequency + if interval is not None: + recurrence_rule.interval = interval + if occurrence_count is not None: + recurrence_rule.occurrence_count = occurrence_count + if first_day_of_week is not None: + recurrence_rule.first_day_of_week = first_day_of_week + + def delete_recurrence_rule( + self, reminder: Reminder, recurrence_rule: RecurrenceRule + ) -> None: + reminder.recurrence_rule_ids = [ + row_id + for row_id in reminder.recurrence_rule_ids + if row_id != recurrence_rule.id + ] + self.recurrence_rows.pop(recurrence_rule.id, None) + + def alarms_for(self, reminder: Reminder) -> list[AlarmWithTrigger]: + rows = [] + for alarm_id in reminder.alarm_ids: + alarm = self.alarm_rows[alarm_id] + rows.append( + AlarmWithTrigger( + alarm=alarm, + trigger=self.trigger_rows.get(alarm.trigger_id), + ) + ) + return rows + + def tags_for(self, reminder: Reminder) -> list[Hashtag]: + return [ + self.hashtag_rows[row_id] + for row_id in reminder.hashtag_ids + if row_id in self.hashtag_rows + ] + + def attachments_for(self, reminder: Reminder) -> list[URLAttachment]: + return [ + self.attachment_rows[row_id] + for row_id in reminder.attachment_ids + if row_id in self.attachment_rows + ] + + def recurrence_rules_for(self, reminder: Reminder) -> list[RecurrenceRule]: + return [ + self.recurrence_rows[row_id] + for row_id in reminder.recurrence_rule_ids + if row_id in self.recurrence_rows + ] + + def iter_changes(self, *, since: Optional[str] = None): + self.change_requests.append(since) + return iter( + [ + ReminderChangeEvent( + type="updated", + reminder_id="Reminder/A", + reminder=self.reminder_rows["Reminder/A"], + ), + ReminderChangeEvent( + type="deleted", + reminder_id="Reminder/Z", + reminder=None, + ), + ] + ) + + def sync_cursor(self) -> str: + return self.cursor + + class FakeAPI: """Authenticated API fixture.""" @@ -211,6 +841,9 @@ def __init__( self.is_china_mainland = china_mainland self.fido2_devices: list[dict[str, Any]] = [] self.trusted_devices: list[dict[str, Any]] = [] + self.two_factor_delivery_method = "unknown" + self.two_factor_delivery_notice = None + self.request_2fa_code = MagicMock(return_value=False) self.validate_2fa_code = MagicMock(return_value=True) self.confirm_security_key = MagicMock(return_value=True) self.send_verification_code = MagicMock(return_value=True) @@ -322,6 +955,8 @@ def __init__( all=photo_album, ) self.hidemyemail = FakeHideMyEmail() + self.notes = FakeNotes() + self.reminders = FakeReminders() def _logout( self, @@ -1679,6 +2314,180 @@ def test_trusted_device_2sa_flow() -> None: ) +def test_sms_2fa_flow_requests_sms_before_prompt() -> None: + """Auth login should request SMS delivery before prompting for the code.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.two_factor_delivery_method = "sms" + fake_api.request_2fa_code.return_value = True + with patch.object(context_module.typer, "prompt", return_value="123456"): + result = _invoke(fake_api, "auth", "login", interactive=True) + assert result.exit_code == 0 + assert "Requested a 2FA code by SMS." in result.stdout + fake_api.request_2fa_code.assert_called_once_with() + fake_api.validate_2fa_code.assert_called_once_with("123456") + + +def test_trusted_device_2fa_flow_reports_device_prompt() -> None: + """Auth login should report trusted-device prompt delivery when bridge succeeds.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + + def request_prompt() -> bool: + fake_api.two_factor_delivery_method = "trusted_device" + return True + + fake_api.request_2fa_code.side_effect = request_prompt + + with patch.object(context_module.typer, "prompt", return_value="123456"): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code == 0 + assert "Requested a 2FA prompt on your trusted Apple devices." in result.stdout + fake_api.validate_2fa_code.assert_called_once_with("123456") + + +def test_trusted_device_2fa_retries_invalid_codes_before_success() -> None: + """Auth login should allow up to three trusted-device 2FA attempts.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + + def request_prompt() -> bool: + fake_api.two_factor_delivery_method = "trusted_device" + return True + + fake_api.request_2fa_code.side_effect = request_prompt + fake_api.validate_2fa_code.side_effect = [False, False, True] + + with patch.object( + context_module.typer, + "prompt", + side_effect=["111111", "222222", "333333"], + ): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code == 0 + assert "Invalid 2FA code. 2 attempt(s) remaining." in result.stdout + assert "Invalid 2FA code. 1 attempt(s) remaining." in result.stdout + assert fake_api.validate_2fa_code.call_args_list == [ + call("111111"), + call("222222"), + call("333333"), + ] + + +def test_sms_2fa_aborts_after_three_invalid_codes() -> None: + """Auth login should stop after three invalid 2FA attempts.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.two_factor_delivery_method = "sms" + fake_api.request_2fa_code.return_value = True + fake_api.validate_2fa_code.side_effect = [False, False, False] + + with patch.object( + context_module.typer, + "prompt", + side_effect=["111111", "222222", "333333"], + ): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == "Failed to verify the 2FA code." + assert "Invalid 2FA code. 2 attempt(s) remaining." in result.stdout + assert "Invalid 2FA code. 1 attempt(s) remaining." in result.stdout + assert fake_api.validate_2fa_code.call_args_list == [ + call("111111"), + call("222222"), + call("333333"), + ] + + +def test_trusted_device_2fa_bridge_fallback_reports_notice() -> None: + """Auth login should print the bridge fallback notice before the SMS message.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + + def request_sms_fallback() -> bool: + fake_api.two_factor_delivery_method = "sms" + fake_api.two_factor_delivery_notice = ( + "Trusted-device prompt failed; falling back to SMS." + ) + return True + + fake_api.request_2fa_code.side_effect = request_sms_fallback + + with patch.object(context_module.typer, "prompt", return_value="123456"): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code == 0 + assert "Trusted-device prompt failed; falling back to SMS." in result.stdout + assert "Requested a 2FA code by SMS." in result.stdout + fake_api.validate_2fa_code.assert_called_once_with("123456") + + +def test_sms_2fa_request_failure_aborts() -> None: + """Auth login should surface SMS delivery request failures clearly.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.request_2fa_code.side_effect = context_module.PyiCloudAPIResponseException( + "sms request failed" + ) + + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == "Failed to request the 2FA SMS code." + fake_api.validate_2fa_code.assert_not_called() + + +def test_trusted_device_2fa_request_failure_aborts() -> None: + """Auth login should surface bridge delivery failures clearly.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.request_2fa_code.side_effect = ( + context_module.PyiCloudTrustedDevicePromptException("bridge failed") + ) + + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == ( + "Failed to request the 2FA trusted-device prompt." + ) + fake_api.validate_2fa_code.assert_not_called() + + +def test_trusted_device_2fa_verification_failure_aborts() -> None: + """Auth login should surface bridge verification failures clearly.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + + def request_prompt() -> bool: + fake_api.two_factor_delivery_method = "trusted_device" + return True + + fake_api.request_2fa_code.side_effect = request_prompt + fake_api.validate_2fa_code.side_effect = ( + context_module.PyiCloudTrustedDeviceVerificationException( + "bridge verification failed" + ) + ) + + with patch.object(context_module.typer, "prompt", return_value="123456"): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == ("Failed to verify the 2FA trusted-device code.") + + def test_non_interactive_2sa_does_not_send_verification_code() -> None: """Non-interactive 2SA should fail before sending a verification code.""" diff --git a/tests/test_hsa2_bridge.py b/tests/test_hsa2_bridge.py new file mode 100644 index 00000000..f21d558f --- /dev/null +++ b/tests/test_hsa2_bridge.py @@ -0,0 +1,1323 @@ +"""Tests for the HSA2 trusted-device bridge helpers.""" + +from __future__ import annotations + +import base64 +import json +import socket +from binascii import unhexlify +from typing import Callable +from unittest.mock import MagicMock, call + +import pytest + +import pyicloud.hsa2_bridge as bridge_module +from pyicloud.exceptions import ( + PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, +) +from pyicloud.hsa2_bridge import ( + BRIDGE_DONE_DATA_B64, + BridgePushPayload, + Hsa2BootContext, + TrustedDeviceBridgeBootstrapper, + _encode_ack_message, + _encode_bytes_field, + _encode_string_field, + _encode_uint32_field, + _encode_web_filter_message, + _extract_json_payload, + _hex_to_b64, + _topic_hash, + parse_boot_args_html, +) +from pyicloud.hsa2_bridge_prover import ( + TrustedDeviceBridgeProver, + _TrustedDeviceBridgeServerProver, +) + + +class _FakeWebSocket: + def __init__( + self, + messages: list[bytes | Exception], + *, + on_read: Callable[[int], None] | None = None, + ) -> None: + self._messages = list(messages) + self._on_read = on_read + self.sent_messages: list[bytes] = [] + self.closed = False + self.read_count = 0 + + def send_binary(self, payload: bytes) -> None: + self.sent_messages.append(payload) + + def read_message(self) -> bytes: + self.read_count += 1 + if self._on_read is not None: + self._on_read(self.read_count) + message = self._messages.pop(0) + if isinstance(message, Exception): + raise message + return message + + def close(self) -> None: + self.closed = True + + +class _FakePrivateKey: + def sign(self, nonce: bytes, _algorithm: object) -> bytes: + return b"signature-for-" + nonce[:4] + + +def _encode_connection_response(push_token: bytes) -> bytes: + payload = b"".join( + [ + _encode_string_field(1, base64.b64encode(push_token).decode("ascii")), + _encode_uint32_field(2, 0), + ] + ) + return _encode_bytes_field(1, payload) + + +def _encode_connection_response_with_token_b64(push_token_b64: str) -> bytes: + payload = b"".join( + [ + _encode_string_field(1, push_token_b64), + _encode_uint32_field(2, 0), + ] + ) + return _encode_bytes_field(1, payload) + + +def _encode_push_message( + topic: str, payload: dict[str, object], message_id: int +) -> bytes: + topic_bytes = bytes.fromhex(_topic_hash(topic)) + body = b"".join( + [ + _encode_bytes_field(1, topic_bytes), + _encode_uint32_field(2, message_id), + _encode_bytes_field(4, json.dumps(payload).encode("utf-8")), + ] + ) + return _encode_bytes_field(2, body) + + +def _encode_channel_subscription_response(topic: str, message_id: int = 1) -> bytes: + channel_response = b"".join( + [ + _encode_string_field(1, topic), + _encode_bytes_field(2, _encode_bytes_field(1, b"channel-id")), + ] + ) + payload = _encode_bytes_field(1, channel_response) + body = b"".join( + [ + _encode_bytes_field(1, payload), + _encode_uint32_field(2, message_id), + _encode_uint32_field(3, 0), + ] + ) + return _encode_bytes_field(3, body) + + +def _read_varint(data: bytes, offset: int) -> tuple[int, int]: + value = 0 + shift = 0 + while True: + byte = data[offset] + offset += 1 + value |= (byte & 0x7F) << shift + if not (byte & 0x80): + return value, offset + shift += 7 + + +def _decode_fields(data: bytes) -> dict[int, list[int | bytes]]: + offset = 0 + fields: dict[int, list[int | bytes]] = {} + while offset < len(data): + key, offset = _read_varint(data, offset) + field_number = key >> 3 + wire_type = key & 0x07 + + if wire_type == 0: + value, offset = _read_varint(data, offset) + elif wire_type == 2: + length, offset = _read_varint(data, offset) + value = data[offset : offset + length] + offset += length + else: + raise AssertionError(f"Unexpected wire type {wire_type}") + + fields.setdefault(field_number, []).append(value) + return fields + + +def _boot_context(topic: str = "com.apple.idmsauthwidget") -> Hsa2BootContext: + return Hsa2BootContext( + auth_initial_route="auth/bridge/step", + has_trusted_devices=True, + auth_factors=("web_piggybacking", "sms"), + bridge_initiate_data={ + "apnsTopic": topic, + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + source_app_id="1159", + ) + + +def _response(status_code: int) -> MagicMock: + response = MagicMock() + response.status_code = status_code + response.text = "" + return response + + +def test_parse_boot_args_html_extracts_bridge_context() -> None: + """Request-5 style boot args should yield the bridge routing metadata.""" + + html = """ + + + + """ + + boot_context = parse_boot_args_html(html) + + assert boot_context.auth_initial_route == "auth/bridge/step" + assert boot_context.has_trusted_devices is True + assert boot_context.auth_factors == ( + "web_piggybacking", + "robocall", + "sms", + "generatedcode", + ) + assert boot_context.bridge_initiate_data["webSocketUrl"] == ( + "websocket.push.apple.com" + ) + assert boot_context.phone_number_verification["trustedPhoneNumber"]["id"] == 3 + assert boot_context.source_app_id == "1159" + + +def test_parse_boot_args_html_accepts_reordered_script_attributes() -> None: + """boot_args extraction should not depend on one exact script tag string.""" + + html = """ + + + + """ + + boot_context = parse_boot_args_html(html) + + assert boot_context.auth_initial_route == "auth/bridge/step" + assert boot_context.has_trusted_devices is True + assert boot_context.bridge_initiate_data["webSocketUrl"] == ( + "websocket.push.apple.com" + ) + + +def test_read_varint_rejects_malformed_overlong_varint() -> None: + """Malformed bridge varints should fail immediately instead of reading forever.""" + + with pytest.raises( + PyiCloudTrustedDevicePromptException, + match="Malformed protobuf varint", + ): + bridge_module._read_varint(b"\x80" * 10, 0) + + +def test_decode_fields_rejects_truncated_length_delimited_field() -> None: + """Length-delimited bridge fields must fit inside the current frame.""" + + with pytest.raises( + PyiCloudTrustedDevicePromptException, + match="Truncated protobuf field", + ): + bridge_module._decode_fields(b"\x0a\x05abc") + + +@pytest.mark.parametrize( + ("payload", "message"), + [ + ({"sessionUUID": 123}, "Malformed trusted-device bridge push payload"), + ({"sessionUUID": " "}, "Malformed trusted-device bridge push payload"), + ( + {"sessionUUID": "bridge-session", "nextStep": " "}, + "Malformed trusted-device bridge push payload", + ), + ( + {"sessionUUID": "bridge-session", "encryptedCode": " "}, + "Malformed trusted-device bridge push payload", + ), + ( + {"sessionUUID": "bridge-session", "ec": "oops"}, + "Malformed trusted-device bridge push payload", + ), + ], +) +def test_bridge_push_payload_rejects_malformed_fields( + payload: dict[str, object], message: str +) -> None: + """Bridge push validation should reject coerced or blank protocol fields.""" + + with pytest.raises(PyiCloudTrustedDevicePromptException, match=message): + BridgePushPayload.from_payload(payload) + + +def test_bridge_push_payload_preserves_unknown_extra_fields() -> None: + """Unknown Apple bridge fields should survive strict validation unchanged.""" + + payload = BridgePushPayload.from_payload( + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "extraField": {"foo": "bar"}, + } + ) + + assert payload.session_uuid == "bridge-session" + assert payload.payload["extraField"] == {"foo": "bar"} + + +def test_extract_json_payload_finds_embedded_json() -> None: + """Request-8 style binary payloads should yield the embedded JSON envelope.""" + + expected_payload = { + "sessionUUID": "bridge-session", + "nextStep": "2", + "ruiURLKey": "hsa2TwoFactorAuthApprovalFlowUrl", + } + noisy_payload = ( + b"\x12\xa8\x07\x00" + + json.dumps(expected_payload).encode("utf-8") + + b"\x18\x00\x01" + ) + + assert _extract_json_payload(noisy_payload) == expected_payload + + +def test_trusted_device_bridge_prover_roundtrip() -> None: + """The Python prover port should match the worker's SPAKE2+/AES-GCM flow.""" + + salt_b64 = base64.b64encode(b"0123456789abcdef").decode("ascii") + prover = TrustedDeviceBridgeProver() + server = _TrustedDeviceBridgeServerProver(password="050044", salt_b64=salt_b64) + + prover.init_with_salt(salt_b64, "050044") + client_message1 = prover.get_message1() + server_message1 = server.get_message1() + server_message2 = server.process_message1(client_message1) + client_message2 = prover.process_message1(server_message1) + server_key = server.verify_message2(client_message2) + client_key = prover.process_message2(server_message2)["key"] + + assert prover.is_verified() is True + assert client_key == server_key + encrypted_code = server.encrypt_message("derived-device-code") + assert prover.decrypt_message(encrypted_code) == "derived-device-code" + + +def test_trusted_device_bridge_bootstrap_keeps_websocket_open_and_persists_step2() -> ( + None +): + """The bridge bootstrap should keep the websocket alive after step 0 succeeds.""" + + topic = "com.apple.idmsauthwidget" + bridge_payload = { + "sessionUUID": "bridge-session", + "nextStep": "2", + "ruiURLKey": "hsa2TwoFactorAuthApprovalFlowUrl", + "txnid": "2300_282820214_S", + "salt": "c2FsdA==", + "mid": "bridge-mid", + "idmsdata": "idms-data", + "akdata": {"lat": 49.52, "lng": 6.1}, + } + websocket_urls: list[tuple[str, float, str, str]] = [] + session = MagicMock() + session.request_raw.return_value = _response(200) + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message(topic, bridge_payload, 2300), + ], + on_read=lambda read_count: ( + read_count == 1 or session.request_raw.call_count == 1 + ) + or (_ for _ in ()).throw( + AssertionError("Bridge step 0 should be posted before waiting for push") + ), + ) + + def websocket_factory( + url: str, timeout: float, origin: str, user_agent: str + ) -> _FakeWebSocket: + websocket_urls.append((url, timeout, origin, user_agent)) + return websocket + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=websocket_factory, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + websocket_url = websocket_urls[0][0] + assert websocket_url.startswith("wss://websocket.push.apple.com/v2/") + assert state.connection_path == websocket_url.rsplit("/", 1)[1] + connection_message = unhexlify(state.connection_path) + outer_fields = _decode_fields(connection_message) + inner_fields = _decode_fields(outer_fields[1][0]) + assert inner_fields[1][0] == b"\x04public-key" + assert bytes(inner_fields[3][0]).startswith(b"\x01\x03signature-for-") + assert state.push_token == b"push-token".hex() + assert state.session_uuid == "bridge-session" + assert state.next_step == "2" + assert state.rui_url_key == "hsa2TwoFactorAuthApprovalFlowUrl" + assert state.txnid == "2300_282820214_S" + assert state.salt == "c2FsdA==" + assert state.mid == "bridge-mid" + assert state.idmsdata == "idms-data" + assert state.akdata == {"lat": 49.52, "lng": 6.1} + assert state.websocket is websocket + assert websocket.sent_messages[0] == _encode_web_filter_message([topic]) + assert websocket.sent_messages[1] == _encode_ack_message( + bytes.fromhex(_topic_hash(topic)), + 2300, + ) + session.request_raw.assert_called_once_with( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/0", + json={ + "sessionUUID": "bridge-session", + "ptkn": b"push-token".hex(), + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ) + assert websocket.closed is False + bootstrapper.close(state) + assert websocket.closed is True + assert state.websocket is None + + +def test_trusted_device_bridge_rejects_malformed_push_token() -> None: + """Malformed push tokens should surface as bridge prompt failures.""" + + websocket = _FakeWebSocket( + [_encode_connection_response_with_token_b64("%%%not-base64%%%")] + ) + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + + with pytest.raises( + PyiCloudTrustedDevicePromptException, + match="Failed to bootstrap the trusted-device bridge prompt.", + ) as exc_info: + bootstrapper.start( + session=MagicMock(), + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(), + user_agent="test-agent", + ) + + assert isinstance(exc_info.value.__cause__, PyiCloudTrustedDevicePromptException) + assert "Malformed bridge push token" in str(exc_info.value.__cause__) + assert websocket.closed is True + + +def test_trusted_device_bridge_rejects_mismatched_session_uuid() -> None: + """The first bridge push should match the session UUID used for step 0.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "different-session", + "nextStep": "2", + }, + 2300, + ), + ] + ) + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.return_value = _response(200) + + with pytest.raises( + PyiCloudTrustedDevicePromptException, + match="Failed to bootstrap the trusted-device bridge prompt.", + ) as exc_info: + bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + assert isinstance(exc_info.value.__cause__, PyiCloudTrustedDevicePromptException) + assert "mismatched session UUID" in str(exc_info.value.__cause__) + assert websocket.closed is True + + +def test_trusted_device_bridge_start_propagates_unexpected_exception() -> None: + """Unexpected bootstrap bugs should surface directly instead of being wrapped.""" + + websocket = _FakeWebSocket([TypeError("boom")]) + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + + with pytest.raises(TypeError, match="boom"): + bootstrapper.start( + session=MagicMock(), + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(), + user_agent="test-agent", + ) + + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_runs_step2_step4_step6_sequence() -> None: + """Bridge-backed trusted-device verification should follow Apple's step 2/4/6 flow.""" + + topic = "com.apple.idmsauthwidget" + initial_push = { + "sessionUUID": "bridge-session", + "nextStep": "2", + "ruiURLKey": "hsa2TwoFactorAuthApprovalFlowUrl", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "mid": "bridge-mid", + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + } + server_message1_hex = "aa01" + server_message2_hex = "bb02" + step4_data = base64.b64encode( + ( + _hex_to_b64(server_message1_hex) + "_" + _hex_to_b64(server_message2_hex) + ).encode("utf-8") + ).decode("ascii") + step4_push = { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": step4_data, + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + } + step6_push = { + "sessionUUID": "bridge-session", + "nextStep": "6", + "txnid": "2300_282820214_S", + "encryptedCode": "ciphertext", + "idmsdata": "step6-idms", + "akdata": {"step": 6}, + "mid": "bridge-mid", + } + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message(topic, initial_push, 2300), + _encode_push_message(topic, step4_push, 2301), + _encode_push_message(topic, step6_push, 2302), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + prover.get_key.return_value = "deadbeef" + prover.decrypt_message.return_value = "derived-device-code" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + + session = MagicMock() + session.request_raw.side_effect = [ + _response(200), + _response(200), + _response(200), + _response(409), + _response(204), + ] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + assert ( + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + is True + ) + + prover.init_with_salt.assert_called_once_with(initial_push["salt"], "050044") + prover.process_message1.assert_called_once_with(server_message1_hex) + prover.process_message2.assert_called_once_with(server_message2_hex) + prover.decrypt_message.assert_called_once_with("ciphertext") + assert session.request_raw.call_args_list == [ + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/0", + json={ + "sessionUUID": "bridge-session", + "ptkn": b"push-token".hex(), + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/2", + json={ + "sessionUUID": "bridge-session", + "data": _hex_to_b64("abcd"), + "ptkn": b"push-token".hex(), + "nextStep": 2, + "idmsdata": "initial-idms", + "akdata": '{"lat":49.52}', + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/4", + json={ + "sessionUUID": "bridge-session", + "data": _hex_to_b64("ef01"), + "ptkn": b"push-token".hex(), + "nextStep": 4, + "idmsdata": "step4-idms", + "akdata": '{"step":4}', + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/code/validate", + json={ + "sessionUUID": "bridge-session", + "code": "derived-device-code", + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/6", + json={ + "sessionUUID": "bridge-session", + "data": BRIDGE_DONE_DATA_B64, + "ptkn": b"push-token".hex(), + "nextStep": 6, + "idmsdata": "step6-idms", + "akdata": '{"step":6}', + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + ] + assert websocket.sent_messages[2] == _encode_ack_message( + bytes.fromhex(_topic_hash(topic)), + 2301, + ) + assert websocket.sent_messages[3] == _encode_ack_message( + bytes.fromhex(_topic_hash(topic)), + 2302, + ) + assert websocket.closed is True + assert state.websocket is None + + +def test_trusted_device_bridge_validate_code_accepts_step4_encrypted_code_final_push() -> ( + None +): + """Apple can finish the bridge flow with nextStep=4 when encryptedCode is present.""" + + topic = "com.apple.idmsauthwidget" + initial_push = { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + } + step4_data = base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode("utf-8") + ).decode("ascii") + prover_push = { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": step4_data, + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + } + final_push = { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "encryptedCode": "ciphertext", + "idmsdata": "final-idms", + "akdata": {"step": "final"}, + } + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message(topic, initial_push, 2300), + _encode_push_message(topic, prover_push, 2301), + _encode_push_message(topic, final_push, 2302), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + prover.decrypt_message.return_value = "derived-device-code" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [ + _response(200), + _response(200), + _response(200), + _response(200), + _response(204), + ] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + assert ( + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + is True + ) + assert session.request_raw.call_args_list[-1] == call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/4", + json={ + "sessionUUID": "bridge-session", + "data": BRIDGE_DONE_DATA_B64, + "ptkn": b"push-token".hex(), + "nextStep": 4, + "idmsdata": "final-idms", + "akdata": '{"step":"final"}', + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_returns_false_on_412() -> None: + """A bridge code-validate 412 should be treated as an invalid code, not a transport failure.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + }, + 2301, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "6", + "txnid": "2300_282820214_S", + "encryptedCode": "ciphertext", + "idmsdata": "step6-idms", + "akdata": {"step": 6}, + }, + 2302, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + prover.decrypt_message.return_value = "derived-device-code" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [ + _response(200), + _response(200), + _response(200), + _response(412), + _response(204), + ] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + assert ( + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + is False + ) + assert session.request_raw.call_args_list[-1].args[1].endswith("/bridge/step/6") + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_rejects_error_push() -> None: + """Bridge error pushes should surface as verification exceptions.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + }, + 2301, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "6", + "txnid": "2300_282820214_S", + "ec": 7, + }, + 2302, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="error push", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_rejects_malformed_final_push() -> None: + """Final bridge pushes must include encryptedCode once the prover flow is complete.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + }, + 2301, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + }, + 2302, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="unexpected final payload", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_rejects_mismatched_followup_push() -> None: + """Follow-up bridge pushes must stay on the same bridge session.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "different-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + }, + 2301, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="mismatched session UUID", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_closes_on_timeout() -> None: + """Timeouts after prompt delivery should surface as bridge verification failures.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + socket.timeout("timed out"), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="websocket transport error", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_wraps_step4_prover_message1_failure() -> ( + None +): + """Malformed step-4 prover data should surface as bridge verification failures.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + }, + 2301, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.side_effect = ValueError("bad point") + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="step 4 payload is malformed", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True From cc74e91d0598374e6f27b7dc63eca913b800c33f Mon Sep 17 00:00:00 2001 From: mrjarnould Date: Tue, 31 Mar 2026 03:08:03 +0200 Subject: [PATCH 2/5] Trim unrelated Notes PR scope --- README.md | 52 ---- tests/test_cmdline.py | 632 ------------------------------------------ 2 files changed, 684 deletions(-) diff --git a/README.md b/README.md index 2646530e..fa2df8ad 100644 --- a/README.md +++ b/README.md @@ -1259,58 +1259,6 @@ Notes caveats: - `api.notes.raw` is available for advanced/debug workflows, but it is not the primary Notes API surface. -### Notes CLI - -### Notes CLI - -The official Typer CLI exposes `icloud notes ...` for recent-note inspection, -folder browsing, title-based search, HTML rendering, and note-id-based export. - -_List recent notes, folders, or one folder’s notes:_ - -```bash -uv run icloud notes recent --username you@example.com -uv run icloud notes folders --username you@example.com -uv run icloud notes list --username you@example.com --folder-id FOLDER_ID -uv run icloud notes list --username you@example.com --all --since PREVIOUS_CURSOR -``` - -_Search notes by title:_ - -```bash -uv run icloud notes search --username you@example.com --title "Daily Plan" -uv run icloud notes search --username you@example.com --title-contains "meeting" -``` - -`icloud notes search` is the official title-filter workflow. It uses a -recents-first search strategy and falls back to a full feed scan when needed. - -_Fetch, render, and export one note by id:_ - -```bash -uv run icloud notes get NOTE_ID --username you@example.com --with-attachments -uv run icloud notes render NOTE_ID --username you@example.com --preview-appearance dark -uv run icloud notes export NOTE_ID \ - --username you@example.com \ - --output-dir ./exports/notes_html \ - --export-mode archival \ - --assets-dir ./exports/assets -``` - -`icloud notes export` stays explicit by note id. Title filters are intentionally -handled by `icloud notes search` rather than by bulk export flags. - -_Inspect incremental changes:_ - -```bash -uv run icloud notes changes --username you@example.com --since PREVIOUS_CURSOR -uv run icloud notes sync-cursor --username you@example.com -``` - -### Notes CLI Example - -### Notes CLI Example - [`examples/notes_cli.py`](examples/notes_cli.py) is a local developer utility built on top of `api.notes`. It is useful for searching notes, inspecting the rendering pipeline, and exporting HTML, but its selection heuristics and debug diff --git a/tests/test_cmdline.py b/tests/test_cmdline.py index f1b68c69..25198b38 100644 --- a/tests/test_cmdline.py +++ b/tests/test_cmdline.py @@ -16,29 +16,6 @@ import click from typer.testing import CliRunner -from pyicloud.services.notes.models import Attachment as NoteAttachment -from pyicloud.services.notes.models import ChangeEvent as NoteChangeEvent -from pyicloud.services.notes.models import ( - Note, - NoteFolder, - NoteSummary, -) -from pyicloud.services.notes.service import NoteLockedError, NoteNotFound -from pyicloud.services.reminders.models import ( - Alarm, - AlarmWithTrigger, - Hashtag, - ListRemindersResult, - LocationTrigger, - Proximity, - RecurrenceFrequency, - RecurrenceRule, - Reminder, - ReminderChangeEvent, - RemindersList, - URLAttachment, -) - account_index_module = importlib.import_module("pyicloud.cli.account_index") cli_module = importlib.import_module("pyicloud.cli.app") context_module = importlib.import_module("pyicloud.cli.context") @@ -218,613 +195,6 @@ def delete(self, anonymous_id: str) -> dict[str, Any]: return {"anonymousId": anonymous_id, "deleted": True} -class FakeNotes: - """Notes service fixture.""" - - def __init__(self) -> None: - attachment = NoteAttachment( - id="Attachment/PDF", - filename="agenda.pdf", - uti="com.adobe.pdf", - size=12, - download_url="https://example.com/agenda.pdf", - preview_url="https://example.com/agenda-preview.pdf", - thumbnail_url="https://example.com/agenda-thumb.png", - ) - self.recent_requests: list[int] = [] - self.iter_all_requests: list[str | None] = [] - self.folder_requests: list[tuple[str, int | None]] = [] - self.render_calls: list[dict[str, Any]] = [] - self.export_calls: list[dict[str, Any]] = [] - self.change_requests: list[str | None] = [] - self.folder_rows = [ - NoteFolder( - id="Folder/NOTES", - name="Notes", - has_subfolders=False, - count=1, - ), - NoteFolder( - id="Folder/WORK", - name="Work", - has_subfolders=True, - count=3, - ), - ] - self.recent_rows = [ - NoteSummary( - id="Note/DELETED", - title="Deleted Note", - snippet="Old note", - modified_at=datetime(2026, 3, 5, tzinfo=timezone.utc), - folder_id="Folder/DELETED", - folder_name="Recently Deleted", - is_deleted=True, - is_locked=False, - ), - NoteSummary( - id="Note/DAILY", - title="Daily Plan", - snippet="Ship CLI", - modified_at=datetime(2026, 3, 4, tzinfo=timezone.utc), - folder_id="Folder/NOTES", - folder_name="Notes", - is_deleted=False, - is_locked=False, - ), - NoteSummary( - id="Note/MEETING", - title="Meeting Notes", - snippet="Discuss roadmap", - modified_at=datetime(2026, 3, 3, tzinfo=timezone.utc), - folder_id="Folder/WORK", - folder_name="Work", - is_deleted=False, - is_locked=False, - ), - ] - self.all_rows = [ - self.recent_rows[2], - NoteSummary( - id="Note/FOLLOWUP", - title="Meeting Follow-up", - snippet="Send recap", - modified_at=datetime(2026, 3, 2, tzinfo=timezone.utc), - folder_id="Folder/WORK", - folder_name="Work", - is_deleted=False, - is_locked=False, - ), - self.recent_rows[1], - self.recent_rows[2], - ] - self.notes = { - "Note/DAILY": Note( - id="Note/DAILY", - title="Daily Plan", - snippet="Ship CLI", - modified_at=datetime(2026, 3, 4, tzinfo=timezone.utc), - folder_id="Folder/NOTES", - folder_name="Notes", - is_deleted=False, - is_locked=False, - text="Ship CLI", - html="

Ship CLI

", - attachments=[attachment], - ), - "Note/MEETING": Note( - id="Note/MEETING", - title="Meeting Notes", - snippet="Discuss roadmap", - modified_at=datetime(2026, 3, 3, tzinfo=timezone.utc), - folder_id="Folder/WORK", - folder_name="Work", - is_deleted=False, - is_locked=False, - text="Discuss roadmap", - html="

Discuss roadmap

", - attachments=[attachment], - ), - "Note/FOLLOWUP": Note( - id="Note/FOLLOWUP", - title="Meeting Follow-up", - snippet="Send recap", - modified_at=datetime(2026, 3, 2, tzinfo=timezone.utc), - folder_id="Folder/WORK", - folder_name="Work", - is_deleted=False, - is_locked=False, - text="Send recap", - html="

Send recap

", - attachments=None, - ), - } - self.change_rows = [ - NoteChangeEvent(type="updated", note=self.recent_rows[1]), - NoteChangeEvent(type="deleted", note=self.recent_rows[0]), - ] - self.cursor = "notes-cursor-1" - - @staticmethod - def _matches_id(note_id: str, query: str) -> bool: - return note_id == query or note_id.split("/", 1)[-1] == query - - def recents(self, *, limit: int = 50): - self.recent_requests.append(limit) - return list(self.recent_rows[:limit]) - - def folders(self): - return list(self.folder_rows) - - def in_folder(self, folder_id: str, limit: int | None = None): - self.folder_requests.append((folder_id, limit)) - rows = [row for row in self.all_rows if row.folder_id == folder_id] - return list(rows[:limit] if limit is not None else rows) - - def iter_all(self, *, since: Optional[str] = None): - self.iter_all_requests.append(since) - return iter(self.all_rows) - - def get(self, note_id: str, *, with_attachments: bool = False): - if self._matches_id("Note/LOCKED", note_id): - raise NoteLockedError(f"Note is locked: {note_id}") - for candidate_id, note in self.notes.items(): - if self._matches_id(candidate_id, note_id): - attachments = note.attachments if with_attachments else None - return note.model_copy(update={"attachments": attachments}) - raise NoteNotFound(f"Note not found: {note_id}") - - def render_note(self, note_id: str, **kwargs: Any) -> str: - note = self.get(note_id, with_attachments=False) - self.render_calls.append({"note_id": note.id, **kwargs}) - return note.html or f"

{note.id}

" - - def export_note(self, note_id: str, output_dir: str, **kwargs: Any) -> str: - note = self.get(note_id, with_attachments=False) - path = Path(output_dir) / f"{note.id.split('/', 1)[-1].lower()}.html" - self.export_calls.append( - {"note_id": note.id, "output_dir": output_dir, **kwargs} - ) - return str(path) - - def iter_changes(self, *, since: Optional[str] = None): - self.change_requests.append(since) - return iter(self.change_rows) - - def sync_cursor(self) -> str: - return self.cursor - - -class FakeReminders: - """Reminders service fixture.""" - - def __init__(self) -> None: - self.list_rows = { - "List/INBOX": RemindersList( - id="List/INBOX", - title="Inbox", - color='{"daHexString":"#007AFF","ckSymbolicColorName":"blue"}', - count=0, - ), - "List/WORK": RemindersList( - id="List/WORK", - title="Work", - color='{"daHexString":"#34C759","ckSymbolicColorName":"green"}', - count=0, - ), - } - self.reminder_rows = { - "Reminder/A": Reminder( - id="Reminder/A", - list_id="List/INBOX", - title="Buy milk", - desc="2 percent", - completed=False, - due_date=datetime(2026, 3, 31, 9, 0, tzinfo=timezone.utc), - priority=1, - flagged=True, - all_day=False, - time_zone="Europe/Luxembourg", - alarm_ids=["Alarm/A"], - hashtag_ids=["Hashtag/ERRANDS"], - attachment_ids=["Attachment/LINK"], - recurrence_rule_ids=["Recurrence/WEEKLY"], - parent_reminder_id="Reminder/PARENT", - created=datetime(2026, 3, 1, tzinfo=timezone.utc), - modified=datetime(2026, 3, 4, tzinfo=timezone.utc), - ), - "Reminder/B": Reminder( - id="Reminder/B", - list_id="List/INBOX", - title="Pay rent", - desc="", - completed=True, - completed_date=datetime(2026, 3, 2, tzinfo=timezone.utc), - priority=0, - flagged=False, - all_day=False, - created=datetime(2026, 3, 1, tzinfo=timezone.utc), - modified=datetime(2026, 3, 2, tzinfo=timezone.utc), - ), - "Reminder/C": Reminder( - id="Reminder/C", - list_id="List/WORK", - title="Prepare deck", - desc="Slides for review", - completed=False, - priority=5, - flagged=False, - all_day=False, - created=datetime(2026, 3, 3, tzinfo=timezone.utc), - modified=datetime(2026, 3, 4, tzinfo=timezone.utc), - ), - } - self.alarm_rows = { - "Alarm/A": Alarm( - id="Alarm/A", - alarm_uid="alarm-a", - reminder_id="Reminder/A", - trigger_id="Trigger/A", - ) - } - self.trigger_rows = { - "Trigger/A": LocationTrigger( - id="Trigger/A", - alarm_id="Alarm/A", - title="Office", - address="1 Infinite Loop", - latitude=37.3318, - longitude=-122.0312, - radius=150.0, - proximity=Proximity.ARRIVING, - location_uid="office", - ) - } - self.hashtag_rows = { - "Hashtag/ERRANDS": Hashtag( - id="Hashtag/ERRANDS", - name="errands", - reminder_id="Reminder/A", - created=datetime(2026, 3, 1, tzinfo=timezone.utc), - ) - } - self.attachment_rows = { - "Attachment/LINK": URLAttachment( - id="Attachment/LINK", - reminder_id="Reminder/A", - url="https://example.com/checklist", - uti="public.url", - ) - } - self.recurrence_rows = { - "Recurrence/WEEKLY": RecurrenceRule( - id="Recurrence/WEEKLY", - reminder_id="Reminder/A", - frequency=RecurrenceFrequency.WEEKLY, - interval=1, - occurrence_count=0, - first_day_of_week=1, - ) - } - self.snapshot_requests: list[dict[str, Any]] = [] - self.change_requests: list[str | None] = [] - self.cursor = "reminders-cursor-1" - - @staticmethod - def _matches_id(record_id: str, query: str) -> bool: - return record_id == query or record_id.split("/", 1)[-1] == query - - def _find_reminder(self, reminder_id: str) -> Reminder: - for candidate_id, reminder in self.reminder_rows.items(): - if self._matches_id(candidate_id, reminder_id): - return reminder - raise LookupError(f"Reminder not found: {reminder_id}") - - def lists(self): - for row in self.list_rows.values(): - row.count = sum( - 1 - for reminder in self.reminder_rows.values() - if reminder.list_id == row.id and not reminder.deleted - ) - return list(self.list_rows.values()) - - def reminders(self, list_id: Optional[str] = None): - rows = [ - reminder - for reminder in self.reminder_rows.values() - if not reminder.deleted and (list_id is None or reminder.list_id == list_id) - ] - return list(rows) - - def list_reminders( - self, - list_id: str, - include_completed: bool = False, - results_limit: int = 200, - ) -> ListRemindersResult: - normalized = list_id if list_id.startswith("List/") else f"List/{list_id}" - self.snapshot_requests.append( - { - "list_id": normalized, - "include_completed": include_completed, - "results_limit": results_limit, - } - ) - reminders = [ - reminder - for reminder in self.reminder_rows.values() - if reminder.list_id == normalized - and not reminder.deleted - and (include_completed or not reminder.completed) - ][:results_limit] - reminder_ids = {reminder.id for reminder in reminders} - return ListRemindersResult( - reminders=reminders, - alarms={ - alarm_id: alarm - for alarm_id, alarm in self.alarm_rows.items() - if alarm.reminder_id in reminder_ids - }, - triggers={ - trigger_id: trigger - for trigger_id, trigger in self.trigger_rows.items() - if any( - alarm.trigger_id == trigger_id - for alarm in self.alarm_rows.values() - if alarm.reminder_id in reminder_ids - ) - }, - attachments={ - attachment_id: attachment - for attachment_id, attachment in self.attachment_rows.items() - if attachment.reminder_id in reminder_ids - }, - hashtags={ - hashtag_id: hashtag - for hashtag_id, hashtag in self.hashtag_rows.items() - if hashtag.reminder_id in reminder_ids - }, - recurrence_rules={ - rule_id: rule - for rule_id, rule in self.recurrence_rows.items() - if rule.reminder_id in reminder_ids - }, - ) - - def get(self, reminder_id: str) -> Reminder: - return self._find_reminder(reminder_id) - - def create( - self, - list_id: str, - title: str, - desc: str = "", - completed: bool = False, - due_date: Optional[datetime] = None, - priority: int = 0, - flagged: bool = False, - all_day: bool = False, - time_zone: Optional[str] = None, - parent_reminder_id: Optional[str] = None, - ) -> Reminder: - next_id = f"Reminder/CREATED-{len(self.reminder_rows) + 1}" - reminder = Reminder( - id=next_id, - list_id=list_id, - title=title, - desc=desc, - completed=completed, - due_date=due_date, - priority=priority, - flagged=flagged, - all_day=all_day, - time_zone=time_zone, - parent_reminder_id=parent_reminder_id, - created=datetime(2026, 3, 30, tzinfo=timezone.utc), - modified=datetime(2026, 3, 30, tzinfo=timezone.utc), - ) - self.reminder_rows[reminder.id] = reminder - return reminder - - def update(self, reminder: Reminder) -> None: - self.reminder_rows[reminder.id] = reminder - - def delete(self, reminder: Reminder) -> None: - reminder.deleted = True - self.reminder_rows[reminder.id] = reminder - - def add_location_trigger( - self, - reminder: Reminder, - title: str = "", - address: str = "", - latitude: float = 0.0, - longitude: float = 0.0, - radius: float = 100.0, - proximity: Proximity = Proximity.ARRIVING, - ) -> tuple[Alarm, LocationTrigger]: - index = len(self.alarm_rows) + 1 - alarm = Alarm( - id=f"Alarm/{index}", - alarm_uid=f"alarm-{index}", - reminder_id=reminder.id, - trigger_id=f"Trigger/{index}", - ) - trigger = LocationTrigger( - id=f"Trigger/{index}", - alarm_id=alarm.id, - title=title, - address=address, - latitude=latitude, - longitude=longitude, - radius=radius, - proximity=proximity, - location_uid=f"location-{index}", - ) - self.alarm_rows[alarm.id] = alarm - self.trigger_rows[trigger.id] = trigger - reminder.alarm_ids.append(alarm.id) - return alarm, trigger - - def create_hashtag(self, reminder: Reminder, name: str) -> Hashtag: - hashtag = Hashtag( - id=f"Hashtag/{name.upper()}", - name=name, - reminder_id=reminder.id, - created=datetime(2026, 3, 30, tzinfo=timezone.utc), - ) - self.hashtag_rows[hashtag.id] = hashtag - reminder.hashtag_ids.append(hashtag.id) - return hashtag - - def update_hashtag(self, hashtag: Hashtag, name: str) -> None: - hashtag.name = name - - def delete_hashtag(self, reminder: Reminder, hashtag: Hashtag) -> None: - reminder.hashtag_ids = [ - row_id for row_id in reminder.hashtag_ids if row_id != hashtag.id - ] - self.hashtag_rows.pop(hashtag.id, None) - - def create_url_attachment( - self, reminder: Reminder, url: str, uti: str = "public.url" - ) -> URLAttachment: - attachment = URLAttachment( - id=f"Attachment/{len(self.attachment_rows) + 1}", - reminder_id=reminder.id, - url=url, - uti=uti, - ) - self.attachment_rows[attachment.id] = attachment - reminder.attachment_ids.append(attachment.id) - return attachment - - def update_attachment( - self, - attachment: URLAttachment, - *, - url: Optional[str] = None, - uti: Optional[str] = None, - filename: Optional[str] = None, - file_size: Optional[int] = None, - width: Optional[int] = None, - height: Optional[int] = None, - ) -> None: - if url is not None: - attachment.url = url - if uti is not None: - attachment.uti = uti - - def delete_attachment(self, reminder: Reminder, attachment: URLAttachment) -> None: - reminder.attachment_ids = [ - row_id for row_id in reminder.attachment_ids if row_id != attachment.id - ] - self.attachment_rows.pop(attachment.id, None) - - def create_recurrence_rule( - self, - reminder: Reminder, - *, - frequency: RecurrenceFrequency = RecurrenceFrequency.DAILY, - interval: int = 1, - occurrence_count: int = 0, - first_day_of_week: int = 0, - ) -> RecurrenceRule: - rule = RecurrenceRule( - id=f"Recurrence/{len(self.recurrence_rows) + 1}", - reminder_id=reminder.id, - frequency=frequency, - interval=interval, - occurrence_count=occurrence_count, - first_day_of_week=first_day_of_week, - ) - self.recurrence_rows[rule.id] = rule - reminder.recurrence_rule_ids.append(rule.id) - return rule - - def update_recurrence_rule( - self, - recurrence_rule: RecurrenceRule, - *, - frequency: Optional[RecurrenceFrequency] = None, - interval: Optional[int] = None, - occurrence_count: Optional[int] = None, - first_day_of_week: Optional[int] = None, - ) -> None: - if frequency is not None: - recurrence_rule.frequency = frequency - if interval is not None: - recurrence_rule.interval = interval - if occurrence_count is not None: - recurrence_rule.occurrence_count = occurrence_count - if first_day_of_week is not None: - recurrence_rule.first_day_of_week = first_day_of_week - - def delete_recurrence_rule( - self, reminder: Reminder, recurrence_rule: RecurrenceRule - ) -> None: - reminder.recurrence_rule_ids = [ - row_id - for row_id in reminder.recurrence_rule_ids - if row_id != recurrence_rule.id - ] - self.recurrence_rows.pop(recurrence_rule.id, None) - - def alarms_for(self, reminder: Reminder) -> list[AlarmWithTrigger]: - rows = [] - for alarm_id in reminder.alarm_ids: - alarm = self.alarm_rows[alarm_id] - rows.append( - AlarmWithTrigger( - alarm=alarm, - trigger=self.trigger_rows.get(alarm.trigger_id), - ) - ) - return rows - - def tags_for(self, reminder: Reminder) -> list[Hashtag]: - return [ - self.hashtag_rows[row_id] - for row_id in reminder.hashtag_ids - if row_id in self.hashtag_rows - ] - - def attachments_for(self, reminder: Reminder) -> list[URLAttachment]: - return [ - self.attachment_rows[row_id] - for row_id in reminder.attachment_ids - if row_id in self.attachment_rows - ] - - def recurrence_rules_for(self, reminder: Reminder) -> list[RecurrenceRule]: - return [ - self.recurrence_rows[row_id] - for row_id in reminder.recurrence_rule_ids - if row_id in self.recurrence_rows - ] - - def iter_changes(self, *, since: Optional[str] = None): - self.change_requests.append(since) - return iter( - [ - ReminderChangeEvent( - type="updated", - reminder_id="Reminder/A", - reminder=self.reminder_rows["Reminder/A"], - ), - ReminderChangeEvent( - type="deleted", - reminder_id="Reminder/Z", - reminder=None, - ), - ] - ) - - def sync_cursor(self) -> str: - return self.cursor - - class FakeAPI: """Authenticated API fixture.""" @@ -955,8 +325,6 @@ def __init__( all=photo_album, ) self.hidemyemail = FakeHideMyEmail() - self.notes = FakeNotes() - self.reminders = FakeReminders() def _logout( self, From d8c78a85ca6785cd9f2b7400aba4554113847dc6 Mon Sep 17 00:00:00 2001 From: mrjarnould Date: Tue, 31 Mar 2026 03:14:24 +0200 Subject: [PATCH 3/5] Address CodeRabbit review comments --- README.md | 5 ++-- pyicloud/base.py | 5 +++- pyicloud/cli/context.py | 26 ++++++++-------- pyicloud/exceptions.py | 4 +-- pyicloud/session.py | 25 +++++++++++----- tests/test_base.py | 66 +++++++++++++++++++++++++++++++++++++++-- tests/test_cmdline.py | 16 ++++++++++ 7 files changed, 120 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index fa2df8ad..9e6310e6 100644 --- a/README.md +++ b/README.md @@ -183,8 +183,9 @@ authentication (2SA)](https://support.apple.com/en-us/HT204152) for the account you will have to do some extra work: For HSA2 accounts, `request_2fa_code()` now starts Apple's active delivery -route for the current challenge. Depending on the account and session, that may -be a trusted-device prompt, an SMS code, or a security-key flow. +route for code-based challenges. Depending on the account and session, that may +be a trusted-device prompt or an SMS code. Security-key challenges are handled +separately via `security_key_names` / `confirm_security_key()`. ```python import sys diff --git a/pyicloud/base.py b/pyicloud/base.py index f26bd2d4..b4f854f9 100644 --- a/pyicloud/base.py +++ b/pyicloud/base.py @@ -36,6 +36,7 @@ PyiCloudServiceNotActivatedException, PyiCloudServiceUnavailable, PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, ) from pyicloud.hsa2_bridge import ( Hsa2BootContext, @@ -1135,6 +1136,8 @@ def validate_2fa_code(self, code: str) -> bool: return False else: self._validate_trusted_device_code(code) + except PyiCloudTrustedDeviceVerificationException: + raise except PyiCloudAPIResponseException: # Wrong verification code LOGGER.error("Code verification failed.") @@ -1168,7 +1171,7 @@ def _validate_sms_code(self, code: str) -> None: data: dict[str, Any] = { "phoneNumber": trusted_phone_number.as_phone_number_payload(), "securityCode": {"code": code}, - "mode": trusted_phone_number.push_mode, + "mode": trusted_phone_number.push_mode or "sms", } headers: dict[str, Any] = self._get_auth_headers( {"Accept": f"{CONTENT_TYPE_JSON}, {CONTENT_TYPE_TEXT}"} diff --git a/pyicloud/cli/context.py b/pyicloud/cli/context.py index fd780738..319fbc1c 100644 --- a/pyicloud/cli/context.py +++ b/pyicloud/cli/context.py @@ -337,20 +337,22 @@ def _handle_2fa(self, api: PyiCloudService) -> None: "Two-factor authentication is required, but interactive prompts are disabled." ) try: - if api.request_2fa_code(): - notice = getattr(api, "two_factor_delivery_notice", None) - if notice: - self.console.print(notice) + if not api.request_2fa_code(): + raise CLIAbort( + "This 2FA challenge requires a security key. Connect one and retry." + ) + + notice = getattr(api, "two_factor_delivery_notice", None) + if notice: + self.console.print(notice) - delivery_method = getattr( - api, "two_factor_delivery_method", "unknown" + delivery_method = getattr(api, "two_factor_delivery_method", "unknown") + if delivery_method == "trusted_device": + self.console.print( + "Requested a 2FA prompt on your trusted Apple devices." ) - if delivery_method == "trusted_device": - self.console.print( - "Requested a 2FA prompt on your trusted Apple devices." - ) - elif delivery_method == "sms": - self.console.print("Requested a 2FA code by SMS.") + elif delivery_method == "sms": + self.console.print("Requested a 2FA code by SMS.") except PyiCloudNoTrustedNumberAvailable as exc: raise CLIAbort( "Two-factor authentication requires a trusted phone number, " diff --git a/pyicloud/exceptions.py b/pyicloud/exceptions.py index d5383faf..1a0c02a5 100644 --- a/pyicloud/exceptions.py +++ b/pyicloud/exceptions.py @@ -99,11 +99,11 @@ class PyiCloudNoTrustedNumberAvailable(PyiCloudException): """iCloud no trusted number exception.""" -class PyiCloudTrustedDevicePromptException(PyiCloudException): +class PyiCloudTrustedDevicePromptException(PyiCloudAPIResponseException): """Trusted-device prompt bootstrap exception.""" -class PyiCloudTrustedDeviceVerificationException(PyiCloudException): +class PyiCloudTrustedDeviceVerificationException(PyiCloudAPIResponseException): """Trusted-device bridge verification exception.""" diff --git a/pyicloud/session.py b/pyicloud/session.py index 9d8adfd9..51b63948 100644 --- a/pyicloud/session.py +++ b/pyicloud/session.py @@ -241,11 +241,14 @@ def _request_raw( method, url, ) - response: Response = super().request( - method=method, - url=url, - **kwargs, - ) + try: + response: Response = super().request( + method=method, + url=url, + **kwargs, + ) + except requests.exceptions.RequestException as err: + self._raise_request_exception(err) self._update_session_data(response) self._save_session_data() return response @@ -298,13 +301,19 @@ def _request( self._decode_json_response(response) return response - except requests.HTTPError as err: + except requests.exceptions.RequestException as err: + self._raise_request_exception(err) + + @staticmethod + def _raise_request_exception(err: requests.exceptions.RequestException) -> NoReturn: + """Normalize low-level requests failures into the session's public error type.""" + + if isinstance(err, requests.HTTPError) and err.response is not None: raise PyiCloudAPIResponseException( reason=err.response.text, code=err.response.status_code, ) from err - except requests.exceptions.RequestException as err: - raise PyiCloudAPIResponseException("Request failed to iCloud") from err + raise PyiCloudAPIResponseException("Request failed to iCloud") from err def _handle_request_error( self, diff --git a/tests/test_base.py b/tests/test_base.py index 7791b75f..0f65442e 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, mock_open, patch import pytest +import requests from fido2.hid import CtapHidDevice from requests import HTTPError, Response @@ -642,6 +643,46 @@ def test_validate_2fa_code_uses_nested_sms_phone_number( } +def test_validate_2fa_code_defaults_sms_mode_when_push_mode_missing( + pyicloud_service: PyiCloudService, +) -> None: + """Missing SMS pushMode should still validate using the delivery mode used to trigger SMS.""" + + pyicloud_service.data = {"dsInfo": {"hsaVersion": 1}, "hsaChallengeRequired": False} + pyicloud_service._auth_data = { + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": None, + } + } + } + pyicloud_service._two_factor_delivery_method = "sms" + pyicloud_service.trust_session = MagicMock( + side_effect=lambda: pyicloud_service.data.update({"hsaTrustedBrowser": True}) + or True + ) + + with patch("pyicloud.base.PyiCloudSession") as mock_session: + pyicloud_service._session = mock_session + mock_session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + "session_token": "test_session_token", + } + + mock_post_response = MagicMock() + mock_post_response.status_code = 200 + mock_post_response.json.return_value = {"success": True} + mock_session.post.return_value = mock_post_response + + assert pyicloud_service.validate_2fa_code("123456") + + kwargs = mock_session.post.call_args.kwargs + assert kwargs["json"]["mode"] == "sms" + + def test_validate_2fa_code_failure(pyicloud_service: PyiCloudService) -> None: """Test the validate_2fa_code method with an invalid code.""" exception = PyiCloudAPIResponseException("Invalid code") @@ -965,8 +1006,9 @@ def test_session_persistence_excludes_trusted_device_bridge_state( ) -> None: """Bridge-only state should remain in memory and never be written to persisted session files.""" - temp_root = Path(tempfile.gettempdir()) / "python-test-results" / "bridge-auth" - temp_root.mkdir(parents=True, exist_ok=True) + test_base = Path(tempfile.gettempdir()) / "python-test-results" + test_base.mkdir(parents=True, exist_ok=True) + temp_root = Path(tempfile.mkdtemp(prefix="bridge-auth-", dir=test_base)) session = PyiCloudSession( service=pyicloud_service_working, client_id="", @@ -1051,6 +1093,26 @@ def test_request_failure(pyicloud_service_working: PyiCloudService) -> None: assert open_mock.call_count == 2 +def test_request_raw_normalizes_transport_failure( + pyicloud_service_working: PyiCloudService, +) -> None: + """Raw requests should keep the session's normalized transport failure contract.""" + + with patch("requests.Session.request") as mock_request: + mock_request.side_effect = requests.exceptions.Timeout("timed out") + test_base = Path(tempfile.gettempdir()) / "python-test-results" + test_base.mkdir(parents=True, exist_ok=True) + temp_root = Path(tempfile.mkdtemp(prefix="request-raw-", dir=test_base)) + pyicloud_session = PyiCloudSession( + pyicloud_service_working, "", cookie_directory=str(temp_root) + ) + + with pytest.raises( + PyiCloudAPIResponseException, match="Request failed to iCloud" + ): + pyicloud_session.request_raw("GET", "https://example.com") + + def test_request_with_custom_headers(pyicloud_service_working: PyiCloudService) -> None: """Test the request method with custom headers.""" with ( diff --git a/tests/test_cmdline.py b/tests/test_cmdline.py index 25198b38..5b5598fe 100644 --- a/tests/test_cmdline.py +++ b/tests/test_cmdline.py @@ -1717,6 +1717,22 @@ def request_prompt() -> bool: fake_api.validate_2fa_code.assert_called_once_with("123456") +def test_code_prompt_aborts_when_request_2fa_code_requires_security_key() -> None: + """Auth login should not enter the numeric 2FA prompt loop for key-only challenges.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.request_2fa_code.return_value = False + + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == ( + "This 2FA challenge requires a security key. Connect one and retry." + ) + fake_api.validate_2fa_code.assert_not_called() + + def test_trusted_device_2fa_retries_invalid_codes_before_success() -> None: """Auth login should allow up to three trusted-device 2FA attempts.""" From df7330903ad71080a1bcd08498bdc59c3400147d Mon Sep 17 00:00:00 2001 From: mrjarnould Date: Tue, 31 Mar 2026 03:28:31 +0200 Subject: [PATCH 4/5] Add docstrings for auth bridge PR scope --- pyicloud/base.py | 6 +++ pyicloud/cli/context.py | 7 +++ pyicloud/exceptions.py | 5 ++ pyicloud/hsa2_bridge.py | 83 ++++++++++++++++++++++++++++++++-- pyicloud/hsa2_bridge_prover.py | 53 ++++++++++++++++++++++ pyicloud/session.py | 4 ++ 6 files changed, 155 insertions(+), 3 deletions(-) diff --git a/pyicloud/base.py b/pyicloud/base.py index b4f854f9..572fdda6 100644 --- a/pyicloud/base.py +++ b/pyicloud/base.py @@ -252,6 +252,7 @@ def __init__( authenticate: bool = True, cloudkit_validation_extra: Optional[CloudKitExtraMode] = None, ) -> None: + """Initialize a service session for one Apple ID account.""" self._is_china_mainland: bool = ( environ.get("icloud_china", "0") == "1" if china_mainland is None @@ -527,6 +528,7 @@ def logout( } def _authenticate(self) -> None: + """Authenticate with either the cached session token or fresh credentials.""" LOGGER.debug("Authenticating as %s", self.account_name) try: @@ -700,6 +702,7 @@ def _validate_token(self) -> Any: def _get_auth_headers( self, overrides: Optional[dict[str, Any]] = None ) -> dict[str, Any]: + """Build Apple auth headers for IDMS, bridge, and verification requests.""" headers: dict[str, Any] = _AUTH_HEADERS_JSON.copy() headers.update( { @@ -728,6 +731,7 @@ def session(self) -> PyiCloudSession: return self._session def _is_mfa_required(self) -> bool: + """Return whether the current auth state still requires MFA completion.""" return ( self.data.get("hsaChallengeRequired", False) or not self.is_trusted_session @@ -1414,7 +1418,9 @@ def account_name(self) -> str: return self._apple_id def __str__(self) -> str: + """Return a concise human-readable service description.""" return f"iCloud API: {self.account_name}" def __repr__(self) -> str: + """Mirror ``__str__`` for interactive inspection.""" return f"<{self}>" diff --git a/pyicloud/cli/context.py b/pyicloud/cli/context.py index 319fbc1c..84d29791 100644 --- a/pyicloud/cli/context.py +++ b/pyicloud/cli/context.py @@ -99,6 +99,7 @@ def __init__( log_level: LogLevel, output_format: OutputFormat, ) -> None: + """Capture the CLI options and shared runtime state for one invocation.""" self.username = (username or "").strip() self.password = password self.china_mainland = china_mainland @@ -235,6 +236,7 @@ def remember_account(self, api: PyiCloudService, *, select: bool = True) -> None self._resolved_username = api.account_name def _resolve_username(self) -> str: + """Resolve the Apple ID to use for the current CLI command.""" if self._resolved_username: return self._resolved_username @@ -280,6 +282,7 @@ def multiple_logged_in_accounts_message(usernames: list[str]) -> str: ) def _password_for_login(self, username: str) -> tuple[Optional[str], Optional[str]]: + """Return the password and its source for an interactive login flow.""" if self.password: return self.password, "explicit" @@ -293,6 +296,7 @@ def _password_for_login(self, username: str) -> tuple[Optional[str], Optional[st return utils.get_password(username, interactive=True), "prompt" def _configure_logging(self) -> None: + """Apply the requested log level once for the current CLI process.""" if self._logging_configured: return logging.basicConfig(level=self.log_level.logging_level()) @@ -306,6 +310,7 @@ def _stored_password_for_session(self, username: str) -> Optional[str]: return utils.get_password_from_keyring(username) def _prompt_index(self, prompt: str, count: int) -> int: + """Prompt for a zero-based selection index when multiple choices exist.""" if count <= 1 or not self.interactive: return 0 raw = typer.prompt(prompt, default="0") @@ -318,6 +323,7 @@ def _prompt_index(self, prompt: str, count: int) -> int: return idx def _handle_2fa(self, api: PyiCloudService) -> None: + """Complete Apple's HSA2 flow using a security key or code-based challenge.""" fido2_devices = list(getattr(api, "fido2_devices", []) or []) if fido2_devices: self.console.print("Security key verification required.") @@ -385,6 +391,7 @@ def _handle_2fa(self, api: PyiCloudService) -> None: api.trust_session() def _handle_2sa(self, api: PyiCloudService) -> None: + """Complete Apple's legacy two-step authentication flow.""" devices = list(api.trusted_devices or []) if not devices: raise CLIAbort( diff --git a/pyicloud/exceptions.py b/pyicloud/exceptions.py index 1a0c02a5..b0772743 100644 --- a/pyicloud/exceptions.py +++ b/pyicloud/exceptions.py @@ -31,6 +31,7 @@ def __init__( code: Optional[Union[int, str]] = None, response: Optional[Response] = None, ) -> None: + """Capture a normalized API error and the optional HTTP context.""" self.reason: str = reason self.code: Optional[Union[int, str]] = code self.response: Optional[Response] = response @@ -58,6 +59,7 @@ def __init__( *args, response: Optional[Response] = None, ) -> None: + """Initialize a login failure with optional HTTP response details.""" self.response: Optional[Response] = response message: str = msg or "Failed login to iCloud" if response is not None and response.text: @@ -73,6 +75,7 @@ class PyiCloud2FARequiredException(PyiCloudException): """iCloud 2FA required exception.""" def __init__(self, apple_id: str, response: Response) -> None: + """Initialize a 2FA-required error for an HSA2 login challenge.""" message: str = f"2FA authentication required for account: {apple_id} (HSA2)" super().__init__(message) self.response: Response = response @@ -82,6 +85,7 @@ class PyiCloud2SARequiredException(PyiCloudException): """iCloud 2SA required exception.""" def __init__(self, apple_id: str) -> None: + """Initialize a 2SA-required error for a legacy login challenge.""" message: str = f"Two-step authentication required for account: {apple_id}" super().__init__(message) @@ -90,6 +94,7 @@ class PyiCloudAuthRequiredException(PyiCloudException): """iCloud re-authentication required exception.""" def __init__(self, apple_id: str, response: Response) -> None: + """Initialize a reauthentication-required error with the triggering response.""" message: str = f"Re-authentication required for account: {apple_id}" super().__init__(message) self.response: Response = response diff --git a/pyicloud/hsa2_bridge.py b/pyicloud/hsa2_bridge.py index 9dded537..d9b8209d 100644 --- a/pyicloud/hsa2_bridge.py +++ b/pyicloud/hsa2_bridge.py @@ -79,6 +79,7 @@ class Hsa2BootContext: @classmethod def from_auth_options(cls, auth_options: Mapping[str, Any]) -> "Hsa2BootContext": + """Build a normalized boot context from Apple's auth-options payload.""" bridge_initiate_data = auth_options.get("bridgeInitiateData") if not isinstance(bridge_initiate_data, dict): bridge_initiate_data = {} @@ -156,6 +157,7 @@ class _BridgePushPayloadModel(BaseModel): @field_validator("session_uuid") @classmethod def _validate_session_uuid(cls, value: str) -> str: + """Reject blank bridge session identifiers.""" if not value.strip(): raise ValueError("sessionUUID must not be blank") return value @@ -172,6 +174,7 @@ def _validate_session_uuid(cls, value: str) -> str: def _validate_optional_non_empty_strings( cls, value: Optional[str] ) -> Optional[str]: + """Reject present-but-blank optional bridge string fields.""" if value is not None and not value.strip(): raise ValueError("Bridge payload strings must not be blank") return value @@ -179,6 +182,7 @@ def _validate_optional_non_empty_strings( @field_validator("next_step") @classmethod def _validate_next_step(cls, value: Optional[str | int]) -> Optional[str | int]: + """Reject blank next-step markers while allowing ints or strings.""" if isinstance(value, str) and not value.strip(): raise ValueError("nextStep must not be blank") return value @@ -203,6 +207,7 @@ class BridgePushPayload: @classmethod def from_payload(cls, payload: dict[str, Any]) -> "BridgePushPayload": + """Validate and normalize one decoded bridge push payload.""" try: validated = _BridgePushPayloadModel.model_validate(payload) except ValidationError as exc: @@ -291,6 +296,7 @@ class BridgeStepRequest: akdata: Any = None def as_json(self) -> dict[str, Any]: + """Serialize the step request into Apple's JSON envelope.""" payload: dict[str, Any] = { "sessionUUID": self.session_uuid, "data": self.data, @@ -316,6 +322,7 @@ class BridgeCodeValidateRequest: code: str def as_json(self) -> dict[str, str]: + """Serialize the final bridge code-validation request body.""" return { "sessionUUID": self.session_uuid, "code": self.code, @@ -324,6 +331,8 @@ def as_json(self) -> dict[str, str]: @dataclass(frozen=True) class _ConnectionResponse: + """Decoded server response for the initial websocket bootstrap.""" + push_token_b64: str = "" status: int = 0 server_timestamp_seconds: Optional[int] = None @@ -331,6 +340,8 @@ class _ConnectionResponse: @dataclass(frozen=True) class _PushMessage: + """Decoded APNS-style push frame from the bridge websocket.""" + topic: bytes message_id: int payload: bytes @@ -338,6 +349,8 @@ class _PushMessage: @dataclass(frozen=True) class _ChannelSubscriptionResponse: + """Decoded response to the bridge topic subscription request.""" + message_id: int = 0 status: int = 0 retry_interval_seconds: int = 0 @@ -346,6 +359,8 @@ class _ChannelSubscriptionResponse: @dataclass(frozen=True) class _AcknowledgementMessage: + """Decoded acknowledgment frame emitted by Apple's bridge service.""" + topic: bytes message_id: int delivery_status: int = 0 @@ -353,6 +368,8 @@ class _AcknowledgementMessage: @dataclass(frozen=True) class _ServerMessage: + """One websocket frame decoded into its known top-level message variants.""" + connection_response: Optional[_ConnectionResponse] = None push_message: Optional[_PushMessage] = None channel_subscription_response: Optional[_ChannelSubscriptionResponse] = None @@ -361,15 +378,23 @@ class _ServerMessage: class _WebSocketLike(Protocol): - def send_binary(self, payload: bytes) -> None: ... + """Protocol for the minimal websocket operations used by the bridge flow.""" - def read_message(self) -> bytes: ... + def send_binary(self, payload: bytes) -> None: + """Send one binary websocket message.""" - def close(self) -> None: ... + def read_message(self) -> bytes: + """Read one complete websocket message payload.""" + + def close(self) -> None: + """Close the websocket transport.""" class _InvalidNonceError(Exception): + """Signal Apple's INVALID_NONCE response along with the server timestamp.""" + def __init__(self, server_timestamp_ms: int) -> None: + """Capture the server timestamp returned with INVALID_NONCE.""" super().__init__("Invalid nonce from bridge server.") self.server_timestamp_ms = server_timestamp_ms @@ -378,6 +403,7 @@ class _BootArgsHTMLParser(HTMLParser): """Extract the JSON body from Apple's boot_args script tag.""" def __init__(self) -> None: + """Initialize parser state for the first matching boot_args script tag.""" super().__init__() self._collecting = False self._found = False @@ -385,9 +411,11 @@ def __init__(self) -> None: @property def payload(self) -> str: + """Return the collected boot_args JSON text.""" return "".join(self._chunks).strip() def handle_starttag(self, tag: str, attrs: list[tuple[str, Optional[str]]]) -> None: + """Start collecting data when the boot_args script tag is found.""" if tag != "script" or self._found: return attr_map = {key: value for key, value in attrs} @@ -397,10 +425,12 @@ def handle_starttag(self, tag: str, attrs: list[tuple[str, Optional[str]]]) -> N self._found = True def handle_endtag(self, tag: str) -> None: + """Stop collecting when the current script tag closes.""" if tag == "script" and self._collecting: self._collecting = False def handle_data(self, data: str) -> None: + """Append script contents while the boot_args tag is active.""" if self._collecting: self._chunks.append(data) @@ -459,6 +489,7 @@ def parse_boot_args_html(html_text: str) -> Hsa2BootContext: def _encode_varint(value: int) -> bytes: + """Encode an unsigned protobuf varint.""" if value < 0: raise ValueError("Negative varints are not supported.") parts = bytearray() @@ -473,6 +504,7 @@ def _encode_varint(value: int) -> bytes: def _read_varint(data: bytes, offset: int) -> tuple[int, int]: + """Decode one protobuf varint from a byte string and return the new offset.""" value = 0 shift = 0 start_offset = offset @@ -492,22 +524,27 @@ def _read_varint(data: bytes, offset: int) -> tuple[int, int]: def _encode_field(field_number: int, wire_type: int, value: bytes) -> bytes: + """Encode one protobuf field header and payload.""" return _encode_varint((field_number << 3) | wire_type) + value def _encode_bytes_field(field_number: int, value: bytes) -> bytes: + """Encode a length-delimited protobuf field.""" return _encode_field(field_number, 2, _encode_varint(len(value)) + value) def _encode_string_field(field_number: int, value: str) -> bytes: + """Encode a UTF-8 string protobuf field.""" return _encode_bytes_field(field_number, value.encode("utf-8")) def _encode_uint32_field(field_number: int, value: int) -> bytes: + """Encode an unsigned integer protobuf field.""" return _encode_field(field_number, 0, _encode_varint(value)) def _decode_fields(data: bytes) -> dict[int, list[Any]]: + """Decode a minimal subset of protobuf wire types into field lists.""" offset = 0 fields: dict[int, list[Any]] = {} while offset < len(data): @@ -536,6 +573,7 @@ def _decode_fields(data: bytes) -> dict[int, list[Any]]: def _decode_connection_response(message: bytes) -> _ConnectionResponse: + """Decode the server's websocket bootstrap response.""" fields = _decode_fields(message) push_token_b64 = "" if fields.get(1): @@ -557,6 +595,7 @@ def _decode_connection_response(message: bytes) -> _ConnectionResponse: def _decode_push_message(message: bytes) -> _PushMessage: + """Decode one push-delivery frame from the bridge websocket.""" fields = _decode_fields(message) topic = bytes(fields.get(1, [b""])[0]) message_id = int(fields.get(2, [0])[0]) @@ -567,6 +606,7 @@ def _decode_push_message(message: bytes) -> _PushMessage: def _decode_channel_subscription_response( message: bytes, ) -> _ChannelSubscriptionResponse: + """Decode the server's response to the topic subscription message.""" fields = _decode_fields(message) topics: list[str] = [] @@ -588,6 +628,7 @@ def _decode_channel_subscription_response( def _decode_acknowledgement_message(message: bytes) -> _AcknowledgementMessage: + """Decode a push acknowledgment frame from the bridge websocket.""" fields = _decode_fields(message) topic = bytes(fields.get(1, [b""])[0]) message_id = int(fields.get(2, [0])[0]) @@ -600,6 +641,7 @@ def _decode_acknowledgement_message(message: bytes) -> _AcknowledgementMessage: def _decode_server_message(message: bytes) -> _ServerMessage: + """Decode all known top-level messages embedded in one websocket frame.""" fields = _decode_fields(message) connection_response = None @@ -636,6 +678,7 @@ def _decode_server_message(message: bytes) -> _ServerMessage: def _encode_connection_message( public_key: bytes, nonce: bytes, signature: bytes ) -> bytes: + """Encode the initial bridge websocket bootstrap message.""" connection_message = b"".join( [ _encode_bytes_field(1, public_key), @@ -658,6 +701,7 @@ def _encode_bridge_signature(signature: bytes) -> bytes: def _encode_web_filter_message(allowed_topics: list[str]) -> bytes: + """Encode the topic subscription message sent after bridge connect.""" filter_payload = b"".join( _encode_string_field(1, topic) for topic in allowed_topics ) @@ -665,6 +709,7 @@ def _encode_web_filter_message(allowed_topics: list[str]) -> bytes: def _encode_ack_message(topic: bytes, message_id: int) -> bytes: + """Encode the acknowledgment frame for one delivered push message.""" ack_payload = b"".join( [ _encode_bytes_field(1, topic), @@ -675,14 +720,17 @@ def _encode_ack_message(topic: bytes, message_id: int) -> bytes: def _topic_hash(topic: str) -> str: + """Return Apple's websocket topic hash for a named APNS topic.""" return hashlib.sha1(topic.encode("utf-8")).hexdigest() def _topic_name(topic_bytes: bytes, topics_by_hash: Mapping[str, str]) -> str: + """Resolve a hashed topic payload back to a readable topic name.""" return topics_by_hash.get(topic_bytes.hex(), topic_bytes.decode("utf-8", "ignore")) def _extract_json_payload(payload: bytes) -> dict[str, Any]: + """Extract the JSON object embedded in one bridge push payload.""" try: return json.loads(payload.decode("utf-8")) except (UnicodeDecodeError, json.JSONDecodeError): @@ -721,6 +769,7 @@ def _extract_json_payload(payload: bytes) -> dict[str, Any]: def _b64_to_hex(value: str) -> str: + """Decode base64 bridge data and return it as lowercase hex.""" try: return base64.b64decode(value.encode("ascii"), validate=True).hex() except (ValueError, BinasciiError) as exc: @@ -728,16 +777,19 @@ def _b64_to_hex(value: str) -> str: def _hex_to_b64(value: str) -> str: + """Encode hex bridge data as standard base64 text.""" return base64.b64encode(bytes.fromhex(value)).decode("ascii") def _build_nonce(timestamp_ms: int) -> bytes: + """Build the nonce format expected by Apple's bridge bootstrap.""" return b"\x00" + timestamp_ms.to_bytes(8, "big", signed=False) + os.urandom(8) def _summarize_identifier( value: Optional[str], *, prefix: int = 8, empty: str = "" ) -> str: + """Shorten sensitive identifiers before logging them at debug level.""" if not value: return empty if len(value) <= prefix: @@ -746,6 +798,7 @@ def _summarize_identifier( def _resolve_websocket_host(boot_context: Hsa2BootContext) -> str: + """Resolve the websocket host Apple expects for the bridge session.""" bridge_data = boot_context.bridge_initiate_data web_socket_url = bridge_data.get("webSocketUrl") if isinstance(web_socket_url, str) and web_socket_url: @@ -765,6 +818,7 @@ def _resolve_websocket_host(boot_context: Hsa2BootContext) -> str: def _resolve_apns_topic(boot_context: Hsa2BootContext) -> str: + """Resolve the APNS topic Apple uses for trusted-device pushes.""" topic = boot_context.bridge_initiate_data.get("apnsTopic") if isinstance(topic, str) and topic: return topic @@ -775,6 +829,7 @@ def _resolve_apns_topic(boot_context: Hsa2BootContext) -> str: def _derive_origin(auth_endpoint: str) -> str: + """Derive the websocket Origin header from the auth endpoint URL.""" parsed = urlparse(auth_endpoint) if not parsed.scheme or not parsed.hostname: raise PyiCloudTrustedDevicePromptException( @@ -793,6 +848,7 @@ def __init__( origin: str, user_agent: str, ) -> None: + """Open a websocket connection and prepare buffered frame reads.""" self._url = url self._timeout = timeout self._origin = origin @@ -801,6 +857,7 @@ def __init__( self._socket = self._open() def _open(self) -> ssl.SSLSocket: + """Perform the websocket HTTP upgrade and return the TLS socket.""" parsed = urlparse(self._url) if parsed.scheme != "wss" or not parsed.hostname: raise PyiCloudTrustedDevicePromptException( @@ -856,6 +913,7 @@ def _open(self) -> ssl.SSLSocket: return secure_socket def _read_http_response(self, sock: ssl.SSLSocket) -> str: + """Read the HTTP upgrade response headers from the websocket socket.""" while b"\r\n\r\n" not in self._buffer: chunk = sock.recv(4096) if not chunk: @@ -870,6 +928,7 @@ def _read_http_response(self, sock: ssl.SSLSocket) -> str: return data def _read_exact(self, size: int) -> bytes: + """Read exactly ``size`` buffered bytes from the websocket socket.""" while len(self._buffer) < size: chunk = self._socket.recv(max(4096, size - len(self._buffer))) if not chunk: @@ -883,6 +942,7 @@ def _read_exact(self, size: int) -> bytes: return data def _send_frame(self, opcode: int, payload: bytes) -> None: + """Send one masked websocket frame to Apple's bridge server.""" first_byte = 0x80 | opcode mask_key = os.urandom(4) length = len(payload) @@ -903,9 +963,11 @@ def _send_frame(self, opcode: int, payload: bytes) -> None: self._socket.sendall(bytes(header) + mask_key + masked_payload) def send_binary(self, payload: bytes) -> None: + """Send one binary websocket message payload.""" self._send_frame(OPCODE_BINARY, payload) def read_message(self) -> bytes: + """Read one complete websocket message, handling control frames inline.""" fragments: list[bytes] = [] opcode: Optional[int] = None @@ -949,6 +1011,7 @@ def read_message(self) -> bytes: return b"".join(fragments) def close(self) -> None: + """Attempt a clean websocket close and always close the socket object.""" if getattr(self, "_socket", None) is None: return try: @@ -974,6 +1037,7 @@ def __init__( ] = None, prover_factory: Optional[Callable[[], TrustedDeviceBridgeProver]] = None, ) -> None: + """Configure websocket and prover factories for bridge operations.""" self.timeout = timeout self._websocket_factory = websocket_factory or _RawWebSocketClient self._prover_factory = prover_factory or TrustedDeviceBridgeProver @@ -987,6 +1051,7 @@ def start( boot_context: Hsa2BootContext, user_agent: str, ) -> TrustedDeviceBridgeState: + """Bootstrap Apple's trusted-device bridge until the first prompt payload arrives.""" topic = _resolve_apns_topic(boot_context) websocket_host = _resolve_websocket_host(boot_context) origin = _derive_origin(auth_endpoint) @@ -1112,6 +1177,7 @@ def start( ) from last_error def _generate_keypair(self) -> tuple[bytes, ec.EllipticCurvePrivateKey]: + """Generate the ephemeral P-256 keypair used for websocket bootstrap.""" private_key = ec.generate_private_key(ec.SECP256R1()) public_key = private_key.public_key().public_bytes( encoding=serialization.Encoding.X962, @@ -1120,6 +1186,7 @@ def _generate_keypair(self) -> tuple[bytes, ec.EllipticCurvePrivateKey]: return public_key, private_key def _generate_session_uuid(self) -> str: + """Generate the browser-style bridge session UUID string.""" return f"{uuid.uuid4()}-{int(time.time())}" def close(self, bridge_state: Optional[TrustedDeviceBridgeState]) -> None: @@ -1317,6 +1384,7 @@ def validate_code( self.close(bridge_state) def _wait_for_push_token(self, websocket: _WebSocketLike) -> bytes: + """Wait for the bridge connection response that carries the push token.""" deadline = time.monotonic() + self.timeout while time.monotonic() < deadline: message = websocket.read_message() @@ -1369,6 +1437,7 @@ def _wait_for_bridge_push( topic: str, topics_by_hash: Mapping[str, str], ) -> BridgePushPayload: + """Wait for, acknowledge, and decode the next relevant bridge push.""" deadline = time.monotonic() + self.timeout while time.monotonic() < deadline: message = websocket.read_message() @@ -1429,6 +1498,7 @@ def _apply_bridge_push( bridge_state: TrustedDeviceBridgeState, push_payload: BridgePushPayload, ) -> None: + """Validate a generic bridge push and merge it into the active state.""" if push_payload.session_uuid != bridge_state.session_uuid: raise PyiCloudTrustedDeviceVerificationException( "Trusted-device bridge returned a mismatched session UUID." @@ -1454,6 +1524,7 @@ def _apply_expected_step4_push( bridge_state: TrustedDeviceBridgeState, push_payload: BridgePushPayload, ) -> None: + """Require the post-step-2 bridge push to contain step-4 prover data.""" self._apply_bridge_push(bridge_state, push_payload) if bridge_state.next_step != "4" or not bridge_state.data: raise PyiCloudTrustedDeviceVerificationException( @@ -1471,6 +1542,7 @@ def _apply_final_bridge_push( bridge_state: TrustedDeviceBridgeState, push_payload: BridgePushPayload, ) -> None: + """Require the final bridge push to contain the encrypted validation code.""" self._apply_bridge_push(bridge_state, push_payload) # Apple's bridge can finish with either: # - nextStep=6 plus encryptedCode @@ -1495,6 +1567,7 @@ def _bridge_headers( headers: Mapping[str, str], bridge_state: TrustedDeviceBridgeState, ) -> dict[str, str]: + """Build the auth headers used for bridge-specific HTTP requests.""" bridge_headers = dict(headers) if bridge_state.source_app_id: bridge_headers["X-Apple-App-Id"] = bridge_state.source_app_id @@ -1509,6 +1582,7 @@ def _bridge_step_json( idmsdata: Optional[str], akdata: Any, ) -> dict[str, Any]: + """Build the JSON payload for one bridge step POST.""" return BridgeStepRequest( session_uuid=bridge_state.session_uuid, data=data, @@ -1530,6 +1604,7 @@ def _post_bridge_step( idmsdata: Optional[str], akdata: Any, ) -> Any: + """POST one bridge step and enforce the small set of valid statuses.""" response = session.request_raw( "POST", f"{auth_endpoint}{BRIDGE_STEP_PATH_TEMPLATE.format(step=next_step)}", @@ -1562,6 +1637,7 @@ def _post_bridge_step0( session_uuid: str, push_token: str, ) -> Any: + """POST bridge step 0 immediately after obtaining the push token.""" response = session.request_raw( "POST", f"{auth_endpoint}{BRIDGE_STEP_PATH}", @@ -1591,6 +1667,7 @@ def _post_bridge_code_validate( bridge_state: TrustedDeviceBridgeState, code: str, ) -> Any: + """POST the decrypted bridge code to Apple's final validation endpoint.""" response = session.request_raw( "POST", f"{auth_endpoint}{BRIDGE_CODE_VALIDATE_PATH}", diff --git a/pyicloud/hsa2_bridge_prover.py b/pyicloud/hsa2_bridge_prover.py index ed5a8432..8a86bdb6 100644 --- a/pyicloud/hsa2_bridge_prover.py +++ b/pyicloud/hsa2_bridge_prover.py @@ -39,11 +39,14 @@ @dataclass(frozen=True) class _Point: + """Affine P-256 point used by the bridge SPAKE2 math helpers.""" + x: Optional[int] y: Optional[int] @property def is_infinity(self) -> bool: + """Return whether this point is the point at infinity.""" return self.x is None or self.y is None @@ -52,26 +55,31 @@ def is_infinity(self) -> bool: def _int_to_bytes(value: int, length: Optional[int] = None) -> bytes: + """Encode an integer using big-endian bytes.""" if length is None: length = max(1, (value.bit_length() + 7) // 8) return value.to_bytes(length, "big") def _b64_to_bytes(value: str) -> bytes: + """Decode a base64 string into raw bytes.""" return base64.b64decode(value.encode("ascii")) def _bytes_to_b64(value: bytes) -> str: + """Encode raw bytes as an ASCII base64 string.""" return base64.b64encode(value).decode("ascii") def _encode_point(point: _Point) -> str: + """Encode a P-256 point using SEC1 uncompressed point format.""" if point.is_infinity: raise ValueError("Cannot encode the point at infinity.") return "04" + _int_to_bytes(point.x, 32).hex() + _int_to_bytes(point.y, 32).hex() def _decode_point(value: str) -> _Point: + """Decode a compressed or uncompressed SEC1 point into affine coordinates.""" raw = bytes.fromhex(value) if len(raw) == 65 and raw[0] == 0x04: point = _Point( @@ -94,6 +102,7 @@ def _decode_point(value: str) -> _Point: def _is_on_curve(point: _Point) -> bool: + """Return whether a point lies on the configured P-256 curve.""" if point.is_infinity: return False assert point.x is not None and point.y is not None @@ -104,6 +113,7 @@ def _is_on_curve(point: _Point) -> bool: def _negate(point: _Point) -> _Point: + """Return the additive inverse of a P-256 point.""" if point.is_infinity: return point assert point.x is not None and point.y is not None @@ -111,6 +121,7 @@ def _negate(point: _Point) -> _Point: def _add_points(left: _Point, right: _Point) -> _Point: + """Add two affine P-256 points.""" if left.is_infinity: return right if right.is_infinity: @@ -137,6 +148,7 @@ def _add_points(left: _Point, right: _Point) -> _Point: def _multiply_point(point: _Point, scalar: int) -> _Point: + """Multiply a P-256 point by a scalar using double-and-add.""" scalar %= _P256_ORDER result = _INFINITY addend = point @@ -149,6 +161,7 @@ def _multiply_point(point: _Point, scalar: int) -> _Point: def _concat_length_prefixed(*parts: bytes) -> bytes: + """Concatenate transcript parts using the bridge's length-prefixed format.""" output = bytearray() for part in parts: output.extend(len(part).to_bytes(8, "little")) @@ -157,6 +170,7 @@ def _concat_length_prefixed(*parts: bytes) -> bytes: def _hkdf_like(ikm: bytes, salt: bytes, info: bytes, length: int) -> bytes: + """Derive key material using the bridge worker's HKDF-like expansion.""" hash_len = hashlib.sha256().digest_size if not salt: salt = b"\x00" * hash_len @@ -176,12 +190,14 @@ def _hkdf_like(ikm: bytes, salt: bytes, info: bytes, length: int) -> bytes: def _confirmation_key_length(info: bytes, requested_length: int) -> int: + """Return the bridge-specific output length for a given HKDF info label.""" if b"ConfirmationKeys" in info: return 64 return requested_length def _derive_key(ikm: bytes, info: bytes, length: int = 64) -> bytes: + """Derive one bridge sub-key from raw shared-secret material.""" return _hkdf_like( ikm=ikm, salt=b"", @@ -191,6 +207,7 @@ def _derive_key(ikm: bytes, info: bytes, length: int = 64) -> bytes: def _derive_prover_and_verifier_keys(raw_key_hex: str) -> tuple[str, str]: + """Split the raw bridge key into prover and verifier AES/HMAC keys.""" raw_key = bytes.fromhex(raw_key_hex) verifier_key = _derive_key(raw_key, _VERIFIER_KEY_INFO, _KEY_LENGTH) prover_key = _derive_key(raw_key, _PROVER_KEY_INFO, _KEY_LENGTH) @@ -199,11 +216,14 @@ def _derive_prover_and_verifier_keys(raw_key_hex: str) -> tuple[str, str]: @dataclass(frozen=True) class _ClientSharedSecret: + """Client-side shared-secret transcript and derived confirmation keys.""" + transcript: bytes share_p: str share_v: str def __post_init__(self) -> None: + """Derive confirmation keys and the final shared key from the transcript.""" digest = hashlib.sha256(self.transcript).digest() object.__setattr__(self, "_hash_transcript", digest) confirmations = _derive_key(digest, b"ConfirmationKeys", 64) @@ -213,6 +233,7 @@ def __post_init__(self) -> None: object.__setattr__(self, "_shared_key", shared_key) def get_confirmation(self) -> str: + """Return the prover's HMAC confirmation message.""" return hmac.new( self._confirm_client, bytes.fromhex(self.share_v), @@ -220,6 +241,7 @@ def get_confirmation(self) -> str: ).hexdigest() def verify(self, message_hex: str) -> bytes: + """Verify the server confirmation and return the shared key bytes.""" expected = hmac.new( self._confirm_server, bytes.fromhex(self.share_p), @@ -232,11 +254,14 @@ def verify(self, message_hex: str) -> bytes: @dataclass(frozen=True) class _ServerSharedSecret: + """Server-side shared-secret transcript and derived confirmation keys.""" + transcript: bytes share_p: str share_v: str def __post_init__(self) -> None: + """Derive confirmation keys and the final shared key from the transcript.""" digest = hashlib.sha256(self.transcript).digest() confirmations = _derive_key(digest, b"ConfirmationKeys", 64) object.__setattr__(self, "_confirm_client", confirmations[:32]) @@ -248,6 +273,7 @@ def __post_init__(self) -> None: ) def get_confirmation(self) -> str: + """Return the verifier's HMAC confirmation message.""" return hmac.new( self._confirm_server, bytes.fromhex(self.share_p), @@ -255,6 +281,7 @@ def get_confirmation(self) -> str: ).hexdigest() def verify(self, message_hex: str) -> bytes: + """Verify the prover confirmation and return the shared key bytes.""" expected = hmac.new( self._confirm_client, bytes.fromhex(self.share_v), @@ -266,6 +293,8 @@ def verify(self, message_hex: str) -> bytes: class _ClientHandshake: + """Client-side SPAKE2 handshake state for Apple's bridge prover.""" + def __init__( self, *, @@ -273,6 +302,7 @@ def __init__( w0: int, w1: int, ) -> None: + """Initialize the prover handshake with the derived SPAKE2 scalars.""" self._x = x_scalar self._w0 = w0 self._w1 = w1 @@ -280,6 +310,7 @@ def __init__( self.share_p: Optional[str] = None def get_message(self) -> str: + """Return the prover's first SPAKE2 message.""" point = _add_points( _multiply_point(_GENERATOR, self._x), _multiply_point(_decode_point(_SPAKE2_M), self._w0), @@ -289,6 +320,7 @@ def get_message(self) -> str: return self.share_p def finish(self, server_message_hex: str) -> _ClientSharedSecret: + """Finish the handshake using the verifier's first message.""" if self._message1_point is None or self.share_p is None: raise ValueError("get_message must be called before finish") @@ -322,6 +354,8 @@ def finish(self, server_message_hex: str) -> _ClientSharedSecret: class _ServerHandshake: + """Server-side SPAKE2 handshake state used by the local test helper.""" + def __init__( self, *, @@ -329,6 +363,7 @@ def __init__( w0: int, verifier_point: _Point, ) -> None: + """Initialize the verifier handshake with its scalar and verifier point.""" self._y = y_scalar self._w0 = w0 self._verifier_point = verifier_point @@ -336,6 +371,7 @@ def __init__( self.share_v: Optional[str] = None def get_message(self) -> str: + """Return the verifier's first SPAKE2 message.""" point = _add_points( _multiply_point(_GENERATOR, self._y), _multiply_point(_decode_point(_SPAKE2_N), self._w0), @@ -345,6 +381,7 @@ def get_message(self) -> str: return self.share_v def finish(self, client_message_hex: str) -> _ServerSharedSecret: + """Finish the verifier handshake using the prover's first message.""" if self._message1_point is None or self.share_v is None: raise ValueError("get_message must be called before finish") @@ -378,6 +415,7 @@ def finish(self, client_message_hex: str) -> _ServerSharedSecret: def _compute_w0_w1(password: str, salt_b64: str) -> tuple[int, int]: + """Derive the SPAKE2 scalars from the user code and bridge salt.""" derived = hashlib.scrypt( password.encode("utf-8"), salt=_b64_to_bytes(salt_b64), @@ -394,6 +432,7 @@ class TrustedDeviceBridgeProver: """Client-side prover mirroring Apple's prover worker.""" def __init__(self) -> None: + """Initialize empty prover state for one bridge verification attempt.""" self._client: Optional[_ClientHandshake] = None self._shared_secret: Optional[_ClientSharedSecret] = None self._raw_key: Optional[str] = None @@ -402,6 +441,7 @@ def __init__(self) -> None: self._prover_key: Optional[str] = None def init_with_salt(self, salt_b64: str, code: str) -> None: + """Initialize the prover with Apple's salt and the user-entered code.""" w0, w1 = _compute_w0_w1(code, salt_b64) self._client = _ClientHandshake( x_scalar=secrets.randbelow(_P256_ORDER), @@ -415,22 +455,26 @@ def init_with_salt(self, salt_b64: str, code: str) -> None: self._prover_key = None def get_message1(self) -> str: + """Return the prover's first bridge message.""" if self._client is None: raise ValueError("init_with_salt must be called before get_message1") return self._client.get_message() def process_message1(self, message_hex: str) -> str: + """Process Apple's first bridge message and return the prover confirmation.""" if self._client is None: raise ValueError("init_with_salt must be called before process_message1") self._shared_secret = self._client.finish(message_hex) return self.get_message2() def get_message2(self) -> str: + """Return the prover confirmation generated from the shared transcript.""" if self._shared_secret is None: raise ValueError("process_message1 must be called before get_message2") return self._shared_secret.get_confirmation() def process_message2(self, message_hex: str) -> dict[str, object]: + """Verify Apple's confirmation and persist the derived bridge keys.""" if self._shared_secret is None: raise ValueError("process_message1 must be called before process_message2") raw_key = self._shared_secret.verify(message_hex).hex() @@ -440,14 +484,17 @@ def process_message2(self, message_hex: str) -> dict[str, object]: return {"isVerified": True, "key": raw_key} def is_verified(self) -> bool: + """Return whether the bridge confirmation exchange has completed.""" return self._verified def get_key(self) -> str: + """Return the raw shared bridge key as hexadecimal.""" if self._raw_key is None: raise ValueError("No bridge key is available yet.") return self._raw_key def decrypt_message(self, ciphertext_b64: str) -> str: + """Decrypt Apple's final encrypted validation code.""" if self._verifier_key is None: raise ValueError("Bridge verifier key is not available.") payload = _b64_to_bytes(ciphertext_b64) @@ -468,6 +515,7 @@ class _TrustedDeviceBridgeServerProver: """Internal test helper mirroring Apple's server-side bridge flow.""" def __init__(self, *, password: str, salt_b64: str) -> None: + """Initialize the local verifier helper with the same password and salt.""" w0, w1 = _compute_w0_w1(password, salt_b64) verifier_point = _multiply_point(_GENERATOR, w1) self._server = _ServerHandshake( @@ -481,18 +529,22 @@ def __init__(self, *, password: str, salt_b64: str) -> None: self._prover_key: Optional[str] = None def get_message1(self) -> str: + """Return the verifier's first bridge message.""" return self._server.get_message() def process_message1(self, client_message_hex: str) -> str: + """Process the prover message and return the verifier confirmation.""" self._shared_secret = self._server.finish(client_message_hex) return self.get_message2() def get_message2(self) -> str: + """Return the verifier confirmation generated from the shared transcript.""" if self._shared_secret is None: raise ValueError("process_message1 must be called before get_message2") return self._shared_secret.get_confirmation() def verify_message2(self, message_hex: str) -> str: + """Verify the prover confirmation and persist the derived bridge keys.""" if self._shared_secret is None: raise ValueError("process_message1 must be called before verify_message2") raw_key = self._shared_secret.verify(message_hex).hex() @@ -501,6 +553,7 @@ def verify_message2(self, message_hex: str) -> str: return raw_key def encrypt_message(self, plaintext: str) -> str: + """Encrypt a plaintext validation code using Apple's AES-GCM payload layout.""" if self._verifier_key is None: raise ValueError("Bridge verifier key is not available.") version = 0 diff --git a/pyicloud/session.py b/pyicloud/session.py index 51b63948..067f73ac 100644 --- a/pyicloud/session.py +++ b/pyicloud/session.py @@ -44,6 +44,7 @@ def __init__( verify: bool = False, headers: Optional[dict[str, str]] = None, ) -> None: + """Initialize the persisted requests session used by the service.""" super().__init__() self._service: PyiCloudService = service @@ -143,6 +144,7 @@ def _update_session_data(self, response: Response) -> None: self._data.update({session_arg: response.headers.get(header)}) def _is_json_response(self, response: Response) -> bool: + """Return whether a response advertises one of the accepted JSON mimetypes.""" content_type: str = response.headers.get(CONTENT_TYPE, "") json_mimetypes: list[str] = [ CONTENT_TYPE_JSON, @@ -169,6 +171,7 @@ def request( cert=None, json=None, ) -> Response: + """Dispatch a request through the normalized session request pipeline.""" return self._request( method, url, @@ -368,6 +371,7 @@ def _decode_json_response(self, response: Response) -> None: def _raise_error( self, response: Response, code: Optional[Union[int, str]], reason: str ) -> NoReturn: + """Raise the session's public exception for a parsed iCloud error payload.""" if ( self.service.requires_2sa and reason == "Missing X-APPLE-WEBAUTH-TOKEN cookie" From ecbdc3ec4f8c7dd23bd99da448e6f7c3f785e75d Mon Sep 17 00:00:00 2001 From: mrjarnould Date: Tue, 31 Mar 2026 03:45:33 +0200 Subject: [PATCH 5/5] Harden bridge prover and persistence tests --- pyicloud/hsa2_bridge_prover.py | 40 ++++++++++++++++++++++------------ pyicloud/session.py | 33 +++++++++++++++++++++++++++- tests/test_base.py | 7 +++++- tests/test_hsa2_bridge.py | 34 +++++++++++++++++++++++++++++ 4 files changed, 98 insertions(+), 16 deletions(-) diff --git a/pyicloud/hsa2_bridge_prover.py b/pyicloud/hsa2_bridge_prover.py index 8a86bdb6..7e6ef8bd 100644 --- a/pyicloud/hsa2_bridge_prover.py +++ b/pyicloud/hsa2_bridge_prover.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from typing import Optional +from cryptography.exceptions import InvalidTag from cryptography.hazmat.primitives.ciphers.aead import AESGCM _SCRYPT_PARAMS = { @@ -428,6 +429,14 @@ def _compute_w0_w1(password: str, salt_b64: str) -> tuple[int, int]: ) +def _random_nonzero_scalar() -> int: + """Return a random scalar in the non-zero P-256 subgroup range.""" + scalar = 0 + while scalar == 0: + scalar = secrets.randbelow(_P256_ORDER) + return scalar + + class TrustedDeviceBridgeProver: """Client-side prover mirroring Apple's prover worker.""" @@ -444,7 +453,7 @@ def init_with_salt(self, salt_b64: str, code: str) -> None: """Initialize the prover with Apple's salt and the user-entered code.""" w0, w1 = _compute_w0_w1(code, salt_b64) self._client = _ClientHandshake( - x_scalar=secrets.randbelow(_P256_ORDER), + x_scalar=_random_nonzero_scalar(), w0=w0, w1=w1, ) @@ -497,18 +506,21 @@ def decrypt_message(self, ciphertext_b64: str) -> str: """Decrypt Apple's final encrypted validation code.""" if self._verifier_key is None: raise ValueError("Bridge verifier key is not available.") - payload = _b64_to_bytes(ciphertext_b64) - version = payload[0] - iv_length, tag_length = _AES_GCM_LAYOUTS[version] - iv = payload[1 : 1 + iv_length] - tag = payload[1 + iv_length : 1 + iv_length + tag_length] - ciphertext = payload[1 + iv_length + tag_length :] - plaintext = AESGCM(bytes.fromhex(self._verifier_key)).decrypt( - iv, - ciphertext + tag, - bytes([version]), - ) - return plaintext.decode("utf-8") + try: + payload = _b64_to_bytes(ciphertext_b64) + version = payload[0] + iv_length, tag_length = _AES_GCM_LAYOUTS[version] + iv = payload[1 : 1 + iv_length] + tag = payload[1 + iv_length : 1 + iv_length + tag_length] + ciphertext = payload[1 + iv_length + tag_length :] + plaintext = AESGCM(bytes.fromhex(self._verifier_key)).decrypt( + iv, + ciphertext + tag, + bytes([version]), + ) + return plaintext.decode("utf-8") + except (IndexError, KeyError, InvalidTag, UnicodeDecodeError) as exc: + raise ValueError("Malformed bridge payload") from exc class _TrustedDeviceBridgeServerProver: @@ -519,7 +531,7 @@ def __init__(self, *, password: str, salt_b64: str) -> None: w0, w1 = _compute_w0_w1(password, salt_b64) verifier_point = _multiply_point(_GENERATOR, w1) self._server = _ServerHandshake( - y_scalar=secrets.randbelow(_P256_ORDER), + y_scalar=_random_nonzero_scalar(), w0=w0, verifier_point=verifier_point, ) diff --git a/pyicloud/session.py b/pyicloud/session.py index 067f73ac..c134d25b 100644 --- a/pyicloud/session.py +++ b/pyicloud/session.py @@ -33,6 +33,30 @@ from pyicloud.base import PyiCloudService +NON_PERSISTED_SESSION_KEYS = frozenset( + { + "akdata", + "connection_path", + "data", + "encryptedCode", + "encrypted_code", + "idmsdata", + "mid", + "nextStep", + "next_step", + "ptkn", + "push_token", + "salt", + "sessionUUID", + "session_uuid", + "source_app_id", + "topic", + "topics_by_hash", + "txnid", + } +) + + class PyiCloudSession(requests.Session): """iCloud session.""" @@ -103,7 +127,14 @@ def _save_session_data(self) -> None: os.makedirs(self._cookie_directory, exist_ok=True) with open(self.session_path, "w", encoding="utf-8") as outfile: # Copy to avoid dict mutation during concurrent access - dump(dict(self._data), outfile) + dump( + { + key: value + for key, value in dict(self._data).items() + if key not in NON_PERSISTED_SESSION_KEYS + }, + outfile, + ) self.logger.debug("Saved session data to file: %s", self.session_path) try: diff --git a/tests/test_base.py b/tests/test_base.py index 0f65442e..0bebaf05 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1015,15 +1015,20 @@ def test_session_persistence_excludes_trusted_device_bridge_state( cookie_directory=str(temp_root), ) pyicloud_service_working._session = session - pyicloud_service_working._trusted_device_bridge_state = MagicMock( + bridge_state = MagicMock( push_token="bridge-ptkn", session_uuid="bridge-session-uuid", idmsdata="bridge-idmsdata", encrypted_code="bridge-encrypted-code", ) + pyicloud_service_working._trusted_device_bridge_state = bridge_state session._data = { "session_token": "valid-token", "session_id": "persisted-session-id", + "push_token": bridge_state.push_token, + "session_uuid": bridge_state.session_uuid, + "idmsdata": bridge_state.idmsdata, + "encrypted_code": bridge_state.encrypted_code, } session._save_session_data() diff --git a/tests/test_hsa2_bridge.py b/tests/test_hsa2_bridge.py index f21d558f..2e64f398 100644 --- a/tests/test_hsa2_bridge.py +++ b/tests/test_hsa2_bridge.py @@ -361,6 +361,40 @@ def test_trusted_device_bridge_prover_roundtrip() -> None: assert prover.decrypt_message(encrypted_code) == "derived-device-code" +def test_trusted_device_bridge_prover_retries_zero_ephemeral_scalars( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Ephemeral prover scalars must stay in the non-zero subgroup range.""" + + draws = iter([0, 7, 0, 9]) + monkeypatch.setattr( + "pyicloud.hsa2_bridge_prover.secrets.randbelow", + lambda _limit: next(draws), + ) + + salt_b64 = base64.b64encode(b"0123456789abcdef").decode("ascii") + prover = TrustedDeviceBridgeProver() + prover.init_with_salt(salt_b64, "050044") + server = _TrustedDeviceBridgeServerProver(password="050044", salt_b64=salt_b64) + + assert prover._client is not None + assert prover._client._x == 7 + assert server._server._y == 9 + + +def test_trusted_device_bridge_prover_normalizes_malformed_bridge_payloads() -> None: + """Malformed encrypted payloads should surface as ValueError.""" + + prover = TrustedDeviceBridgeProver() + prover._verifier_key = "00" * 32 + + with pytest.raises(ValueError, match="Malformed bridge payload"): + prover.decrypt_message(base64.b64encode(b"").decode("ascii")) + + with pytest.raises(ValueError, match="Malformed bridge payload"): + prover.decrypt_message(base64.b64encode(b"\x01truncated").decode("ascii")) + + def test_trusted_device_bridge_bootstrap_keeps_websocket_open_and_persists_step2() -> ( None ):