From 83a55d974deb7b6a66aa8233c29cb2589d1466c8 Mon Sep 17 00:00:00 2001 From: mrjarnould Date: Tue, 31 Mar 2026 02:34:28 +0200 Subject: [PATCH 01/13] feat: add CLI support for Notes and Reminders --- README.md | 106 ++ pyicloud/cli/app.py | 8 + pyicloud/cli/commands/notes.py | 542 ++++++++++ pyicloud/cli/commands/reminders.py | 1604 ++++++++++++++++++++++++++++ pyicloud/cli/normalize.py | 95 ++ pyicloud/cli/output.py | 34 + tests/test_cmdline.py | 1533 ++++++++++++++++++++------ tests/test_output.py | 18 + 8 files changed, 3594 insertions(+), 346 deletions(-) create mode 100644 pyicloud/cli/commands/notes.py create mode 100644 pyicloud/cli/commands/reminders.py diff --git a/README.md b/README.md index f6836411..edef5673 100644 --- a/README.md +++ b/README.md @@ -1128,6 +1128,66 @@ rules, and delete flows against a real iCloud account. [`example_reminders_delta.py`](example_reminders_delta.py) is a smaller live validator focused on `sync_cursor()` and `iter_changes(since=...)`. +### Reminders CLI + +The official Typer CLI exposes `icloud reminders ...` for common read and write +flows. + +_List reminder lists and open reminders:_ + +```bash +uv run icloud reminders lists --username you@example.com +uv run icloud reminders list --username you@example.com +uv run icloud reminders list --username you@example.com --list-id INBOX --include-completed +``` + +`icloud reminders list` defaults to open reminders only. Use +`--include-completed` to include completed reminders, and `--list-id` to scope +the query to one list. + +_Get, create, update, and delete reminders:_ + +```bash +uv run icloud reminders get REMINDER_ID --username you@example.com +uv run icloud reminders create --username you@example.com --list-id INBOX --title "Buy milk" +uv run icloud reminders update REMINDER_ID --username you@example.com --title "Buy oat milk" +uv run icloud reminders set-status REMINDER_ID --username you@example.com --completed +uv run icloud reminders delete REMINDER_ID --username you@example.com +``` + +_Inspect snapshots and incremental changes:_ + +```bash +uv run icloud reminders snapshot --username you@example.com --list-id INBOX +uv run icloud reminders changes --username you@example.com --since PREVIOUS_CURSOR +uv run icloud reminders sync-cursor --username you@example.com +``` + +_Work with reminder sub-records:_ + +```bash +uv run icloud reminders alarm add-location REMINDER_ID \ + --username you@example.com \ + --title "Office" \ + --address "1 Infinite Loop, Cupertino, CA" \ + --latitude 37.3318 \ + --longitude -122.0312 + +uv run icloud reminders hashtag create REMINDER_ID errands --username you@example.com +uv run icloud reminders attachment create-url REMINDER_ID \ + --username you@example.com \ + --url https://example.com/checklist +uv run icloud reminders recurrence create REMINDER_ID \ + --username you@example.com \ + --frequency weekly \ + --interval 1 +``` + +The reminder CLI is organized as core commands plus `alarm`, `hashtag`, +`attachment`, and `recurrence` subgroups. Hashtag rename is exposed through +`icloud reminders hashtag update`, but Apple’s web app may still treat hashtag +names as effectively read-only in some live flows. + ## Notes You can access your iCloud Notes through the `notes` property: @@ -1254,6 +1314,52 @@ Notes caveats: - `api.notes.raw` is available for advanced/debug workflows, but it is not the primary Notes API surface. +### Notes CLI + +The official Typer CLI exposes `icloud notes ...` for recent-note inspection, +folder browsing, title-based search, HTML rendering, and note-id-based export. + +_List recent notes, folders, or one folder’s notes:_ + +```bash +uv run icloud notes recent --username you@example.com +uv run icloud notes folders --username you@example.com +uv run icloud notes list --username you@example.com --folder-id FOLDER_ID +uv run icloud notes list --username you@example.com --all --since PREVIOUS_CURSOR +``` + +_Search notes by title:_ + +```bash +uv run icloud notes search --username you@example.com --title "Daily Plan" +uv run icloud notes search --username you@example.com --title-contains "meeting" +``` + +`icloud notes search` is the official title-filter workflow. It uses a +recents-first search strategy and falls back to a full feed scan when needed. + +_Fetch, render, and export one note by id:_ + +```bash +uv run icloud notes get NOTE_ID --username you@example.com --with-attachments +uv run icloud notes render NOTE_ID --username you@example.com --preview-appearance dark +uv run icloud notes export NOTE_ID \ + --username you@example.com \ + --output-dir ./exports/notes_html \ + --export-mode archival \ + --assets-dir ./exports/assets +``` + +`icloud notes export` stays explicit by note id. Title filters are intentionally +handled by `icloud notes search` rather than by bulk export flags. + +_Inspect incremental changes:_ + +```bash +uv run icloud notes changes --username you@example.com --since PREVIOUS_CURSOR +uv run icloud notes sync-cursor --username you@example.com +``` + ### Notes CLI Example [`examples/notes_cli.py`](examples/notes_cli.py) is a local developer utility diff --git a/pyicloud/cli/app.py b/pyicloud/cli/app.py index 9bb5ae39..c478b00e 100644 --- a/pyicloud/cli/app.py +++ b/pyicloud/cli/app.py @@ -11,7 +11,9 @@ from pyicloud.cli.commands.devices import app as devices_app from pyicloud.cli.commands.drive import app as drive_app from pyicloud.cli.commands.hidemyemail import app as hidemyemail_app +from pyicloud.cli.commands.notes import app as notes_app from pyicloud.cli.commands.photos import app as photos_app +from pyicloud.cli.commands.reminders import app as reminders_app from pyicloud.cli.context import CLIAbort app = typer.Typer( @@ -54,6 +56,12 @@ def _group_root(ctx: typer.Context) -> None: invoke_without_command=True, callback=_group_root, ) +app.add_typer( + reminders_app, name="reminders", invoke_without_command=True, callback=_group_root +) +app.add_typer( + notes_app, name="notes", invoke_without_command=True, callback=_group_root +) def main() -> int: diff --git a/pyicloud/cli/commands/notes.py b/pyicloud/cli/commands/notes.py new file mode 100644 index 00000000..ff6748b7 --- /dev/null +++ b/pyicloud/cli/commands/notes.py @@ -0,0 +1,542 @@ +"""Notes commands.""" + +from __future__ import annotations + +from enum import Enum +from itertools import islice +from pathlib import Path +from typing import Optional + +import typer + +from pyicloud.cli.context import CLIAbort, get_state, service_call +from pyicloud.cli.normalize import ( + search_notes_by_title, + select_recent_notes, +) +from pyicloud.cli.options import ( + DEFAULT_LOG_LEVEL, + DEFAULT_OUTPUT_FORMAT, + HttpProxyOption, + HttpsProxyOption, + LogLevelOption, + NoVerifySslOption, + OutputFormatOption, + SessionDirOption, + UsernameOption, + store_command_options, +) +from pyicloud.cli.output import console_table +from pyicloud.services.notes.service import NoteLockedError, NoteNotFound + +app = typer.Typer(help="Inspect, render, and export Notes.") + +NOTES = "Notes" + + +class PreviewAppearance(str, Enum): + """Supported Notes preview appearances.""" + + LIGHT = "light" + DARK = "dark" + + +class ExportMode(str, Enum): + """Supported Notes export modes.""" + + ARCHIVAL = "archival" + LIGHTWEIGHT = "lightweight" + + +def _notes_service(api): + """Return the Notes service with reauthentication handling.""" + + return service_call(NOTES, lambda: api.notes, account_name=api.account_name) + + +def _notes_call(api, fn): + """Wrap Notes service calls with note-specific user-facing errors.""" + + try: + return service_call(NOTES, fn, account_name=api.account_name) + except (NoteNotFound, NoteLockedError) as err: + raise CLIAbort(str(err)) from err + + +def _print_note_rows(state, title: str, rows) -> None: + """Render note summary rows in text mode.""" + + state.console.print( + console_table( + title, + ["ID", "Title", "Folder", "Modified", "Deleted"], + [ + ( + row.id, + row.title, + row.folder_name, + row.modified_at, + getattr(row, "is_deleted", False), + ) + for row in rows + ], + ) + ) + + +@app.command("recent") +def notes_recent( + ctx: typer.Context, + limit: int = typer.Option(10, "--limit", min=1, help="Maximum notes to show."), + include_deleted: bool = typer.Option( + False, + "--include-deleted", + help="Include notes from Recently Deleted.", + ), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """List recent notes.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + notes = _notes_service(api) + payload = _notes_call( + api, + lambda: select_recent_notes( + notes, + limit=limit, + include_deleted=include_deleted, + ), + ) + if state.json_output: + state.write_json(payload) + return + _print_note_rows(state, "Recent Notes", payload) + + +@app.command("folders") +def notes_folders( + ctx: typer.Context, + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """List note folders.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + notes = _notes_service(api) + payload = _notes_call(api, lambda: list(notes.folders())) + if state.json_output: + state.write_json(payload) + return + state.console.print( + console_table( + "Note Folders", + ["ID", "Name", "Has Subfolders", "Count"], + [(row.id, row.name, row.has_subfolders, row.count) for row in payload], + ) + ) + + +@app.command("list") +def notes_list( + ctx: typer.Context, + folder_id: Optional[str] = typer.Option(None, "--folder-id", help="Folder id."), + all_notes: bool = typer.Option(False, "--all", help="Iterate all notes."), + since: Optional[str] = typer.Option( + None, + "--since", + help="Incremental sync cursor for --all.", + ), + limit: int = typer.Option(50, "--limit", min=1, help="Maximum notes to show."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """List notes.""" + + if folder_id and all_notes: + raise typer.BadParameter("Choose either --folder-id or --all, not both.") + if since and not all_notes: + raise typer.BadParameter("The --since option requires --all.") + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + notes = _notes_service(api) + if folder_id: + payload = _notes_call( + api, lambda: list(notes.in_folder(folder_id, limit=limit)) + ) + elif all_notes: + payload = _notes_call( + api, + lambda: list(islice(notes.iter_all(since=since), limit)), + ) + else: + payload = _notes_call( + api, + lambda: select_recent_notes( + notes, + limit=limit, + include_deleted=False, + ), + ) + if state.json_output: + state.write_json(payload) + return + _print_note_rows(state, "Notes", payload) + + +@app.command("search") +def notes_search( + ctx: typer.Context, + title: str = typer.Option("", "--title", help="Exact note title."), + title_contains: str = typer.Option( + "", + "--title-contains", + help="Case-insensitive note title substring.", + ), + limit: int = typer.Option(10, "--limit", min=1, help="Maximum notes to show."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Search notes by title.""" + + if not title.strip() and not title_contains.strip(): + raise CLIAbort("Pass --title or --title-contains to search notes.") + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + notes = _notes_service(api) + payload = _notes_call( + api, + lambda: search_notes_by_title( + notes, + title=title, + title_contains=title_contains, + limit=limit, + ), + ) + if state.json_output: + state.write_json(payload) + return + _print_note_rows(state, "Matching Notes", payload) + + +@app.command("get") +def notes_get( + ctx: typer.Context, + note_id: str = typer.Argument(..., help="Note id."), + with_attachments: bool = typer.Option( + False, + "--with-attachments", + help="Include attachment metadata.", + ), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Get one note.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + notes = _notes_service(api) + note = _notes_call( + api, lambda: notes.get(note_id, with_attachments=with_attachments) + ) + if state.json_output: + state.write_json(note) + return + state.console.print(f"{note.title} [{note.id}]") + if note.folder_name: + state.console.print(f"Folder: {note.folder_name}") + if note.modified_at: + state.console.print(f"Modified: {note.modified_at}") + if note.text: + state.console.print(note.text) + if with_attachments and note.attachments: + state.console.print( + console_table( + "Attachments", + ["ID", "Filename", "UTI", "Size"], + [(att.id, att.filename, att.uti, att.size) for att in note.attachments], + ) + ) + + +@app.command("render") +def notes_render( + ctx: typer.Context, + note_id: str = typer.Argument(..., help="Note id."), + preview_appearance: PreviewAppearance = typer.Option( + PreviewAppearance.LIGHT, + "--preview-appearance", + help="Preview appearance preference.", + ), + pdf_height: int = typer.Option( + 600, + "--pdf-height", + min=1, + help="Embedded PDF height in pixels.", + ), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Render a note to HTML.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + notes = _notes_service(api) + html = _notes_call( + api, + lambda: notes.render_note( + note_id, + preview_appearance=preview_appearance.value, + pdf_object_height=pdf_height, + ), + ) + if state.json_output: + state.write_json({"note_id": note_id, "html": html}) + return + state.console.print(html, soft_wrap=True) + + +@app.command("export") +def notes_export( + ctx: typer.Context, + note_id: str = typer.Argument(..., help="Note id."), + output_dir: Path = typer.Option(..., "--output-dir", help="Destination directory."), + export_mode: ExportMode = typer.Option( + ExportMode.ARCHIVAL, + "--export-mode", + help="Export mode.", + ), + assets_dir: Path | None = typer.Option( + None, + "--assets-dir", + help="Directory for downloaded assets in archival mode.", + ), + full_page: bool = typer.Option( + True, + "--full-page/--fragment", + help="Wrap exported output in a full HTML page.", + ), + preview_appearance: PreviewAppearance = typer.Option( + PreviewAppearance.LIGHT, + "--preview-appearance", + help="Preview appearance preference.", + ), + pdf_height: int = typer.Option( + 600, + "--pdf-height", + min=1, + help="Embedded PDF height in pixels.", + ), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Export a note to disk.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + notes = _notes_service(api) + path = _notes_call( + api, + lambda: notes.export_note( + note_id, + str(output_dir), + export_mode=export_mode.value, + assets_dir=str(assets_dir) if assets_dir else None, + full_page=full_page, + preview_appearance=preview_appearance.value, + pdf_object_height=pdf_height, + ), + ) + if state.json_output: + state.write_json({"note_id": note_id, "path": path}) + return + state.console.print(path) + + +@app.command("changes") +def notes_changes( + ctx: typer.Context, + since: str | None = typer.Option(None, "--since", help="Sync cursor."), + limit: int = typer.Option(50, "--limit", min=1, help="Maximum changes to show."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """List note changes since a cursor.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + notes = _notes_service(api) + payload = _notes_call( + api, + lambda: list(islice(notes.iter_changes(since=since), limit)), + ) + if state.json_output: + state.write_json(payload) + return + state.console.print( + console_table( + "Note Changes", + ["Type", "Note ID", "Folder", "Modified"], + [ + (row.type, row.note.id, row.note.folder_name, row.note.modified_at) + for row in payload + ], + ) + ) + + +@app.command("sync-cursor") +def notes_sync_cursor( + ctx: typer.Context, + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Print the current Notes sync cursor.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + notes = _notes_service(api) + cursor = _notes_call(api, lambda: notes.sync_cursor()) + if state.json_output: + state.write_json({"cursor": cursor}) + return + state.console.print(cursor) diff --git a/pyicloud/cli/commands/reminders.py b/pyicloud/cli/commands/reminders.py new file mode 100644 index 00000000..a9df0805 --- /dev/null +++ b/pyicloud/cli/commands/reminders.py @@ -0,0 +1,1604 @@ +"""Reminders commands.""" + +from __future__ import annotations + +from enum import Enum +from typing import Callable, TypeVar + +import typer +from pydantic import ValidationError + +from pyicloud.cli.context import CLIAbort, get_state, parse_datetime, service_call +from pyicloud.cli.options import ( + DEFAULT_LOG_LEVEL, + DEFAULT_OUTPUT_FORMAT, + HttpProxyOption, + HttpsProxyOption, + LogLevelOption, + NoVerifySslOption, + OutputFormatOption, + SessionDirOption, + UsernameOption, + store_command_options, +) +from pyicloud.cli.output import console_kv_table, console_table, format_color_value +from pyicloud.services.reminders.models import ( + AlarmWithTrigger, + ImageAttachment, + RecurrenceFrequency, + Reminder, + URLAttachment, +) +from pyicloud.services.reminders.service import Attachment, Proximity + +app = typer.Typer(help="Inspect and mutate Reminders.") +alarm_app = typer.Typer(help="Work with reminder alarms.") +attachment_app = typer.Typer(help="Work with reminder attachments.") +hashtag_app = typer.Typer(help="Work with reminder hashtags.") +recurrence_app = typer.Typer(help="Work with reminder recurrence rules.") + +REMINDERS = "Reminders" +TRelated = TypeVar("TRelated") + + +class ProximityChoice(str, Enum): + """CLI-facing proximity choice.""" + + ARRIVING = "arriving" + LEAVING = "leaving" + + +class RecurrenceFrequencyChoice(str, Enum): + """CLI-facing recurrence frequency.""" + + DAILY = "daily" + WEEKLY = "weekly" + MONTHLY = "monthly" + YEARLY = "yearly" + + +PROXIMITY_MAP = { + ProximityChoice.ARRIVING: Proximity.ARRIVING, + ProximityChoice.LEAVING: Proximity.LEAVING, +} +RECURRENCE_FREQUENCY_MAP = { + RecurrenceFrequencyChoice.DAILY: RecurrenceFrequency.DAILY, + RecurrenceFrequencyChoice.WEEKLY: RecurrenceFrequency.WEEKLY, + RecurrenceFrequencyChoice.MONTHLY: RecurrenceFrequency.MONTHLY, + RecurrenceFrequencyChoice.YEARLY: RecurrenceFrequency.YEARLY, +} + + +def _group_root(ctx: typer.Context) -> None: + """Show subgroup help when invoked without a subcommand.""" + + if ctx.invoked_subcommand is None: + typer.echo(ctx.get_help()) + raise typer.Exit() + + +app.add_typer( + alarm_app, name="alarm", invoke_without_command=True, callback=_group_root +) +app.add_typer( + hashtag_app, name="hashtag", invoke_without_command=True, callback=_group_root +) +app.add_typer( + attachment_app, + name="attachment", + invoke_without_command=True, + callback=_group_root, +) +app.add_typer( + recurrence_app, + name="recurrence", + invoke_without_command=True, + callback=_group_root, +) + + +def _normalize_prefixed_id(value: str, prefix: str) -> str: + """Return an identifier with the expected record prefix.""" + + normalized = str(value).strip() + if not normalized: + return normalized + token = f"{prefix}/" + if normalized.startswith(token): + return normalized + return f"{token}{normalized}" + + +def _id_matches(record_id: str, query: str) -> bool: + """Return whether a record id matches a full or shorthand query.""" + + normalized = str(query).strip() + if not normalized: + return False + if record_id == normalized: + return True + if "/" in record_id and record_id.split("/", 1)[1] == normalized: + return True + return False + + +def _reminders_service(api): + """Return the Reminders service with reauthentication handling.""" + + return service_call(REMINDERS, lambda: api.reminders, account_name=api.account_name) + + +def _reminders_call(api, fn): + """Wrap reminder calls with reminder-specific user-facing errors.""" + + try: + return service_call(REMINDERS, fn, account_name=api.account_name) + except LookupError as err: + raise CLIAbort(str(err)) from err + except ValidationError as err: + raise CLIAbort(str(err)) from err + + +def _resolve_reminder(api, reminder_id: str) -> Reminder: + """Return one reminder by id.""" + + reminders = _reminders_service(api) + return _reminders_call(api, lambda: reminders.get(reminder_id)) + + +def _list_reminder_rows( + api, + *, + list_id: str | None = None, + include_completed: bool, + limit: int, +) -> list[Reminder]: + """Return reminder rows using compound snapshots to preserve completion filtering.""" + + reminders = _reminders_service(api) + results_limit = max(limit, 200) + if list_id: + snapshot = _reminders_call( + api, + lambda: reminders.list_reminders( + list_id=_normalize_prefixed_id(list_id, "List"), + include_completed=include_completed, + results_limit=results_limit, + ), + ) + return snapshot.reminders[:limit] + + rows: list[Reminder] = [] + seen_ids: set[str] = set() + for reminder_list in _reminders_call(api, lambda: list(reminders.lists())): + snapshot = _reminders_call( + api, + lambda lid=reminder_list.id: reminders.list_reminders( + list_id=lid, + include_completed=include_completed, + results_limit=results_limit, + ), + ) + for reminder in snapshot.reminders: + if reminder.id in seen_ids: + continue + seen_ids.add(reminder.id) + rows.append(reminder) + if len(rows) >= limit: + return rows + return rows + + +def _resolve_related_record( + api, + reminder_id: str, + query: str, + *, + label: str, + fetch_rows: Callable[[Reminder], list[TRelated]], +) -> tuple[Reminder, TRelated]: + """Return one reminder child record matched by full or shorthand id.""" + + reminder = _resolve_reminder(api, reminder_id) + rows = _reminders_call(api, lambda: fetch_rows(reminder)) + for row in rows: + row_id = getattr(row, "id", "") + if _id_matches(row_id, query): + return reminder, row + raise CLIAbort(f"No {label} matched '{query}' for reminder {reminder.id}.") + + +def _attachment_kind(attachment: Attachment) -> str: + """Return a compact attachment type label.""" + + if isinstance(attachment, URLAttachment): + return "url" + if isinstance(attachment, ImageAttachment): + return "image" + return type(attachment).__name__.lower() + + +def _proximity_label(proximity: Proximity | None) -> str | None: + """Return a human-readable proximity label.""" + + if proximity is None: + return None + return proximity.name.lower() + + +def _frequency_label(frequency: RecurrenceFrequency | None) -> str | None: + """Return a human-readable recurrence frequency label.""" + + if frequency is None: + return None + return frequency.name.lower() + + +def _sync_cursor_payload(state, cursor: str) -> None: + """Render a sync cursor in JSON or text mode.""" + + if state.json_output: + state.write_json({"cursor": cursor}) + return + state.console.print(cursor) + + +@app.command("lists") +def reminders_lists( + ctx: typer.Context, + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """List reminder lists.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + payload = _reminders_call(api, lambda: list(reminders.lists())) + if state.json_output: + state.write_json(payload) + return + state.console.print( + console_table( + "Reminder Lists", + ["ID", "Title", "Color", "Count"], + [ + ( + row.id, + row.title, + format_color_value(row.color), + row.count, + ) + for row in payload + ], + ) + ) + + +@app.command("list") +def reminders_list( + ctx: typer.Context, + list_id: str | None = typer.Option(None, "--list-id", help="Reminder list id."), + include_completed: bool = typer.Option( + False, + "--include-completed", + help="Include completed reminders.", + ), + limit: int = typer.Option(50, "--limit", min=1, help="Maximum reminders to show."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """List reminders.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + payload = _list_reminder_rows( + api, + list_id=list_id, + include_completed=include_completed, + limit=limit, + ) + if state.json_output: + state.write_json(payload) + return + state.console.print( + console_table( + "Reminders", + ["ID", "Title", "Completed", "Due", "Priority"], + [ + ( + reminder.id, + reminder.title, + reminder.completed, + reminder.due_date, + reminder.priority, + ) + for reminder in payload + ], + ) + ) + + +@app.command("get") +def reminders_get( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Get one reminder.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminder = _resolve_reminder(api, reminder_id) + if state.json_output: + state.write_json(reminder) + return + state.console.print( + console_kv_table( + f"Reminder: {reminder.title}", + [ + ("ID", reminder.id), + ("List ID", reminder.list_id), + ("Description", reminder.desc), + ("Completed", reminder.completed), + ("Due Date", reminder.due_date), + ("Priority", reminder.priority), + ("Flagged", reminder.flagged), + ("All Day", reminder.all_day), + ("Time Zone", reminder.time_zone), + ("Parent Reminder", reminder.parent_reminder_id), + ], + ) + ) + + +@app.command("create") +def reminders_create( + ctx: typer.Context, + list_id: str = typer.Option(..., "--list-id", help="Target list id."), + title: str = typer.Option(..., "--title", help="Reminder title."), + desc: str = typer.Option("", "--desc", help="Reminder description."), + completed: bool = typer.Option( + False, + "--completed/--not-completed", + help="Create the reminder as completed or incomplete.", + ), + due_date: str | None = typer.Option(None, "--due-date", help="Due datetime."), + priority: int = typer.Option(0, "--priority", help="Apple priority number."), + flagged: bool = typer.Option( + False, + "--flagged/--not-flagged", + help="Create the reminder flagged or unflagged.", + ), + all_day: bool = typer.Option( + False, + "--all-day/--not-all-day", + help="Create the reminder as all-day or timed.", + ), + time_zone: str | None = typer.Option(None, "--time-zone", help="IANA time zone."), + parent_reminder_id: str | None = typer.Option( + None, + "--parent-reminder-id", + help="Parent reminder id for a subtask.", + ), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Create a reminder.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder = _reminders_call( + api, + lambda: reminders.create( + list_id=_normalize_prefixed_id(list_id, "List"), + title=title, + desc=desc, + completed=completed, + due_date=parse_datetime(due_date), + priority=priority, + flagged=flagged, + all_day=all_day, + time_zone=time_zone, + parent_reminder_id=parent_reminder_id, + ), + ) + if state.json_output: + state.write_json(reminder) + return + state.console.print(reminder.id) + + +@app.command("update") +def reminders_update( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + title: str | None = typer.Option(None, "--title", help="Reminder title."), + desc: str | None = typer.Option(None, "--desc", help="Reminder description."), + completed: bool | None = typer.Option( + None, + "--completed/--not-completed", + help="Mark the reminder completed or incomplete.", + ), + due_date: str | None = typer.Option(None, "--due-date", help="Due datetime."), + clear_due_date: bool = typer.Option( + False, + "--clear-due-date", + help="Clear the due date.", + ), + priority: int | None = typer.Option(None, "--priority", help="Apple priority."), + flagged: bool | None = typer.Option( + None, + "--flagged/--not-flagged", + help="Flag or unflag the reminder.", + ), + all_day: bool | None = typer.Option( + None, + "--all-day/--not-all-day", + help="Mark as all-day or timed.", + ), + time_zone: str | None = typer.Option(None, "--time-zone", help="IANA time zone."), + clear_time_zone: bool = typer.Option( + False, + "--clear-time-zone", + help="Clear the time zone.", + ), + parent_reminder_id: str | None = typer.Option( + None, + "--parent-reminder-id", + help="Set the parent reminder id.", + ), + clear_parent_reminder: bool = typer.Option( + False, + "--clear-parent-reminder", + help="Clear the parent reminder id.", + ), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Update one reminder.""" + + if due_date and clear_due_date: + raise typer.BadParameter( + "Choose either --due-date or --clear-due-date, not both." + ) + if time_zone and clear_time_zone: + raise typer.BadParameter( + "Choose either --time-zone or --clear-time-zone, not both." + ) + if parent_reminder_id and clear_parent_reminder: + raise typer.BadParameter( + "Choose either --parent-reminder-id or --clear-parent-reminder, not both." + ) + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder = _resolve_reminder(api, reminder_id) + changed = False + + if title is not None: + reminder.title = title + changed = True + if desc is not None: + reminder.desc = desc + changed = True + if completed is not None: + reminder.completed = completed + changed = True + if due_date is not None: + reminder.due_date = parse_datetime(due_date) + changed = True + elif clear_due_date: + reminder.due_date = None + changed = True + if priority is not None: + reminder.priority = priority + changed = True + if flagged is not None: + reminder.flagged = flagged + changed = True + if all_day is not None: + reminder.all_day = all_day + changed = True + if time_zone is not None: + reminder.time_zone = time_zone + changed = True + elif clear_time_zone: + reminder.time_zone = None + changed = True + if parent_reminder_id is not None: + reminder.parent_reminder_id = parent_reminder_id + changed = True + elif clear_parent_reminder: + reminder.parent_reminder_id = None + changed = True + + if not changed: + raise CLIAbort("No reminder updates were requested.") + + _reminders_call(api, lambda: reminders.update(reminder)) + if state.json_output: + state.write_json(reminder) + return + state.console.print(f"Updated {reminder.id}") + + +@app.command("set-status") +def reminders_set_status( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + completed: bool = typer.Option( + True, + "--completed/--not-completed", + help="Mark the reminder completed or incomplete.", + ), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Mark a reminder completed or incomplete.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder = _resolve_reminder(api, reminder_id) + reminder.completed = completed + _reminders_call(api, lambda: reminders.update(reminder)) + if state.json_output: + state.write_json(reminder) + return + state.console.print(f"Updated {reminder.id}: completed={completed}") + + +@app.command("delete") +def reminders_delete( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Delete a reminder.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder = _resolve_reminder(api, reminder_id) + _reminders_call(api, lambda: reminders.delete(reminder)) + if state.json_output: + state.write_json({"reminder_id": reminder.id, "deleted": True}) + return + state.console.print(f"Deleted {reminder.id}") + + +@app.command("snapshot") +def reminders_snapshot( + ctx: typer.Context, + list_id: str = typer.Option(..., "--list-id", help="Reminder list id."), + include_completed: bool = typer.Option( + False, + "--include-completed", + help="Include completed reminders.", + ), + results_limit: int = typer.Option( + 200, + "--results-limit", + min=1, + help="Maximum reminders to request from the compound query.", + ), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Fetch a compound reminder snapshot for one list.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + payload = _reminders_call( + api, + lambda: reminders.list_reminders( + list_id=_normalize_prefixed_id(list_id, "List"), + include_completed=include_completed, + results_limit=results_limit, + ), + ) + if state.json_output: + state.write_json(payload) + return + state.console.print( + console_kv_table( + "Reminder Snapshot", + [ + ("List ID", _normalize_prefixed_id(list_id, "List")), + ("Reminders", len(payload.reminders)), + ("Alarms", len(payload.alarms)), + ("Triggers", len(payload.triggers)), + ("Attachments", len(payload.attachments)), + ("Hashtags", len(payload.hashtags)), + ("Recurrence Rules", len(payload.recurrence_rules)), + ], + ) + ) + state.console.print( + console_table( + "Snapshot Reminders", + ["ID", "Title", "Completed", "Due", "Priority"], + [ + ( + reminder.id, + reminder.title, + reminder.completed, + reminder.due_date, + reminder.priority, + ) + for reminder in payload.reminders + ], + ) + ) + + +@app.command("changes") +def reminders_changes( + ctx: typer.Context, + since: str | None = typer.Option(None, "--since", help="Sync cursor."), + limit: int = typer.Option(50, "--limit", min=1, help="Maximum changes to show."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """List reminder changes since a cursor.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + payload = _reminders_call( + api, + lambda: list(reminders.iter_changes(since=since))[:limit], + ) + if state.json_output: + state.write_json(payload) + return + state.console.print( + console_table( + "Reminder Changes", + ["Type", "Reminder ID", "Title", "Completed"], + [ + ( + event.type, + event.reminder_id, + event.reminder.title if event.reminder else None, + event.reminder.completed if event.reminder else None, + ) + for event in payload + ], + ) + ) + + +@app.command("sync-cursor") +def reminders_sync_cursor( + ctx: typer.Context, + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Print the current Reminders sync cursor.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + cursor = _reminders_call(api, lambda: reminders.sync_cursor()) + _sync_cursor_payload(state, cursor) + + +@alarm_app.command("list") +def reminders_alarm_list( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """List alarms for one reminder.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder = _resolve_reminder(api, reminder_id) + payload = _reminders_call(api, lambda: reminders.alarms_for(reminder)) + if state.json_output: + state.write_json(payload) + return + state.console.print( + console_table( + "Reminder Alarms", + [ + "Alarm ID", + "Trigger ID", + "Title", + "Address", + "Radius", + "Proximity", + ], + [ + ( + row.alarm.id, + row.trigger.id if row.trigger else None, + row.trigger.title if row.trigger else None, + row.trigger.address if row.trigger else None, + row.trigger.radius if row.trigger else None, + _proximity_label(row.trigger.proximity if row.trigger else None), + ) + for row in payload + ], + ) + ) + + +@alarm_app.command("add-location") +def reminders_alarm_add_location( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + title: str = typer.Option(..., "--title", help="Location title."), + address: str = typer.Option(..., "--address", help="Location address."), + latitude: float = typer.Option(..., "--latitude", help="Location latitude."), + longitude: float = typer.Option(..., "--longitude", help="Location longitude."), + radius: float = typer.Option(100.0, "--radius", min=0.0, help="Radius in meters."), + proximity: ProximityChoice = typer.Option( + ProximityChoice.ARRIVING, + "--proximity", + help="Trigger direction.", + ), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Add a location alarm to a reminder.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder = _resolve_reminder(api, reminder_id) + alarm, trigger = _reminders_call( + api, + lambda: reminders.add_location_trigger( + reminder, + title=title, + address=address, + latitude=latitude, + longitude=longitude, + radius=radius, + proximity=PROXIMITY_MAP[proximity], + ), + ) + payload = AlarmWithTrigger(alarm=alarm, trigger=trigger) + if state.json_output: + state.write_json(payload) + return + state.console.print(f"Created {alarm.id} with trigger {trigger.id}") + + +@hashtag_app.command("list") +def reminders_hashtag_list( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """List hashtags for one reminder.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder = _resolve_reminder(api, reminder_id) + payload = _reminders_call(api, lambda: reminders.tags_for(reminder)) + if state.json_output: + state.write_json(payload) + return + state.console.print( + console_table( + "Reminder Hashtags", + ["ID", "Name", "Reminder ID"], + [(row.id, row.name, row.reminder_id) for row in payload], + ) + ) + + +@hashtag_app.command("create") +def reminders_hashtag_create( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + name: str = typer.Argument(..., help="Hashtag name."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Create a hashtag on one reminder.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder = _resolve_reminder(api, reminder_id) + payload = _reminders_call(api, lambda: reminders.create_hashtag(reminder, name)) + if state.json_output: + state.write_json(payload) + return + state.console.print(payload.id) + + +@hashtag_app.command("update") +def reminders_hashtag_update( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + hashtag_id: str = typer.Argument(..., help="Hashtag id."), + name: str = typer.Option(..., "--name", help="Updated hashtag name."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Update a hashtag name.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + _reminder, hashtag = _resolve_related_record( + api, + reminder_id, + hashtag_id, + label="hashtag", + fetch_rows=lambda reminder: reminders.tags_for(reminder), + ) + _reminders_call(api, lambda: reminders.update_hashtag(hashtag, name)) + hashtag.name = name + if state.json_output: + state.write_json(hashtag) + return + state.console.print(f"Updated {hashtag.id}") + + +@hashtag_app.command("delete") +def reminders_hashtag_delete( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + hashtag_id: str = typer.Argument(..., help="Hashtag id."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Delete a hashtag from one reminder.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder, hashtag = _resolve_related_record( + api, + reminder_id, + hashtag_id, + label="hashtag", + fetch_rows=lambda row: reminders.tags_for(row), + ) + _reminders_call(api, lambda: reminders.delete_hashtag(reminder, hashtag)) + payload = {"reminder_id": reminder.id, "hashtag_id": hashtag.id, "deleted": True} + if state.json_output: + state.write_json(payload) + return + state.console.print(f"Deleted {hashtag.id}") + + +@attachment_app.command("list") +def reminders_attachment_list( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """List attachments for one reminder.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder = _resolve_reminder(api, reminder_id) + payload = _reminders_call(api, lambda: reminders.attachments_for(reminder)) + if state.json_output: + state.write_json(payload) + return + state.console.print( + console_table( + "Reminder Attachments", + ["ID", "Type", "URL", "Filename", "UTI", "Size"], + [ + ( + row.id, + _attachment_kind(row), + getattr(row, "url", None), + getattr(row, "filename", None), + row.uti, + getattr(row, "file_size", None), + ) + for row in payload + ], + ) + ) + + +@attachment_app.command("create-url") +def reminders_attachment_create_url( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + url: str = typer.Option(..., "--url", help="Attachment URL."), + uti: str = typer.Option("public.url", "--uti", help="Attachment UTI."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Create a URL attachment on one reminder.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder = _resolve_reminder(api, reminder_id) + payload = _reminders_call( + api, + lambda: reminders.create_url_attachment(reminder, url=url, uti=uti), + ) + if state.json_output: + state.write_json(payload) + return + state.console.print(payload.id) + + +@attachment_app.command("update") +def reminders_attachment_update( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + attachment_id: str = typer.Argument(..., help="Attachment id."), + url: str | None = typer.Option(None, "--url", help="Updated attachment URL."), + uti: str | None = typer.Option(None, "--uti", help="Updated attachment UTI."), + filename: str | None = typer.Option( + None, + "--filename", + help="Updated attachment filename.", + ), + file_size: int | None = typer.Option( + None, + "--file-size", + min=0, + help="Updated attachment size.", + ), + width: int | None = typer.Option( + None, + "--width", + min=0, + help="Updated attachment width.", + ), + height: int | None = typer.Option( + None, + "--height", + min=0, + help="Updated attachment height.", + ), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Update one attachment.""" + + if all(value is None for value in (url, uti, filename, file_size, width, height)): + raise CLIAbort("No attachment updates were requested.") + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + _reminder, attachment = _resolve_related_record( + api, + reminder_id, + attachment_id, + label="attachment", + fetch_rows=lambda reminder: reminders.attachments_for(reminder), + ) + _reminders_call( + api, + lambda: reminders.update_attachment( + attachment, + url=url, + uti=uti, + filename=filename, + file_size=file_size, + width=width, + height=height, + ), + ) + if url is not None and hasattr(attachment, "url"): + attachment.url = url + if uti is not None: + attachment.uti = uti + if filename is not None and hasattr(attachment, "filename"): + attachment.filename = filename + if file_size is not None and hasattr(attachment, "file_size"): + attachment.file_size = file_size + if width is not None and hasattr(attachment, "width"): + attachment.width = width + if height is not None and hasattr(attachment, "height"): + attachment.height = height + if state.json_output: + state.write_json(attachment) + return + state.console.print(f"Updated {attachment.id}") + + +@attachment_app.command("delete") +def reminders_attachment_delete( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + attachment_id: str = typer.Argument(..., help="Attachment id."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Delete one attachment.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder, attachment = _resolve_related_record( + api, + reminder_id, + attachment_id, + label="attachment", + fetch_rows=lambda row: reminders.attachments_for(row), + ) + _reminders_call(api, lambda: reminders.delete_attachment(reminder, attachment)) + payload = { + "reminder_id": reminder.id, + "attachment_id": attachment.id, + "deleted": True, + } + if state.json_output: + state.write_json(payload) + return + state.console.print(f"Deleted {attachment.id}") + + +@recurrence_app.command("list") +def reminders_recurrence_list( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """List recurrence rules for one reminder.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder = _resolve_reminder(api, reminder_id) + payload = _reminders_call(api, lambda: reminders.recurrence_rules_for(reminder)) + if state.json_output: + state.write_json(payload) + return + state.console.print( + console_table( + "Reminder Recurrence Rules", + ["ID", "Frequency", "Interval", "Occurrence Count", "First Day"], + [ + ( + row.id, + _frequency_label(row.frequency), + row.interval, + row.occurrence_count, + row.first_day_of_week, + ) + for row in payload + ], + ) + ) + + +@recurrence_app.command("create") +def reminders_recurrence_create( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + frequency: RecurrenceFrequencyChoice = typer.Option( + RecurrenceFrequencyChoice.DAILY, + "--frequency", + help="Recurrence frequency.", + ), + interval: int = typer.Option(1, "--interval", min=1, help="Recurrence interval."), + occurrence_count: int = typer.Option( + 0, + "--occurrence-count", + min=0, + help="Occurrence count; 0 means unlimited.", + ), + first_day_of_week: int = typer.Option( + 0, + "--first-day-of-week", + min=0, + max=6, + help="First day of week; 0 is Sunday.", + ), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Create a recurrence rule on one reminder.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder = _resolve_reminder(api, reminder_id) + payload = _reminders_call( + api, + lambda: reminders.create_recurrence_rule( + reminder, + frequency=RECURRENCE_FREQUENCY_MAP[frequency], + interval=interval, + occurrence_count=occurrence_count, + first_day_of_week=first_day_of_week, + ), + ) + if state.json_output: + state.write_json(payload) + return + state.console.print(payload.id) + + +@recurrence_app.command("update") +def reminders_recurrence_update( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + rule_id: str = typer.Argument(..., help="Recurrence rule id."), + frequency: RecurrenceFrequencyChoice | None = typer.Option( + None, + "--frequency", + help="Recurrence frequency.", + ), + interval: int | None = typer.Option( + None, + "--interval", + min=1, + help="Recurrence interval.", + ), + occurrence_count: int | None = typer.Option( + None, + "--occurrence-count", + min=0, + help="Occurrence count; 0 means unlimited.", + ), + first_day_of_week: int | None = typer.Option( + None, + "--first-day-of-week", + min=0, + max=6, + help="First day of week; 0 is Sunday.", + ), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Update one recurrence rule.""" + + if all( + value is None + for value in (frequency, interval, occurrence_count, first_day_of_week) + ): + raise CLIAbort("No recurrence updates were requested.") + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + _reminder, recurrence_rule = _resolve_related_record( + api, + reminder_id, + rule_id, + label="recurrence rule", + fetch_rows=lambda reminder: reminders.recurrence_rules_for(reminder), + ) + _reminders_call( + api, + lambda: reminders.update_recurrence_rule( + recurrence_rule, + frequency=( + RECURRENCE_FREQUENCY_MAP[frequency] if frequency is not None else None + ), + interval=interval, + occurrence_count=occurrence_count, + first_day_of_week=first_day_of_week, + ), + ) + if frequency is not None: + recurrence_rule.frequency = RECURRENCE_FREQUENCY_MAP[frequency] + if interval is not None: + recurrence_rule.interval = interval + if occurrence_count is not None: + recurrence_rule.occurrence_count = occurrence_count + if first_day_of_week is not None: + recurrence_rule.first_day_of_week = first_day_of_week + if state.json_output: + state.write_json(recurrence_rule) + return + state.console.print(f"Updated {recurrence_rule.id}") + + +@recurrence_app.command("delete") +def reminders_recurrence_delete( + ctx: typer.Context, + reminder_id: str = typer.Argument(..., help="Reminder id."), + rule_id: str = typer.Argument(..., help="Recurrence rule id."), + username: UsernameOption = None, + session_dir: SessionDirOption = None, + http_proxy: HttpProxyOption = None, + https_proxy: HttpsProxyOption = None, + no_verify_ssl: NoVerifySslOption = False, + output_format: OutputFormatOption = DEFAULT_OUTPUT_FORMAT, + log_level: LogLevelOption = DEFAULT_LOG_LEVEL, +) -> None: + """Delete one recurrence rule.""" + + store_command_options( + ctx, + username=username, + session_dir=session_dir, + http_proxy=http_proxy, + https_proxy=https_proxy, + no_verify_ssl=no_verify_ssl, + output_format=output_format, + log_level=log_level, + ) + state = get_state(ctx) + api = state.get_api() + reminders = _reminders_service(api) + reminder, recurrence_rule = _resolve_related_record( + api, + reminder_id, + rule_id, + label="recurrence rule", + fetch_rows=lambda row: reminders.recurrence_rules_for(row), + ) + _reminders_call( + api, + lambda: reminders.delete_recurrence_rule(reminder, recurrence_rule), + ) + payload = { + "reminder_id": reminder.id, + "recurrence_rule_id": recurrence_rule.id, + "deleted": True, + } + if state.json_output: + state.write_json(payload) + return + state.console.print(f"Deleted {recurrence_rule.id}") diff --git a/pyicloud/cli/normalize.py b/pyicloud/cli/normalize.py index acc65557..4e8209b4 100644 --- a/pyicloud/cli/normalize.py +++ b/pyicloud/cli/normalize.py @@ -2,6 +2,7 @@ from __future__ import annotations +from datetime import datetime, timezone from typing import Any @@ -174,3 +175,97 @@ def normalize_alias(alias: dict[str, Any]) -> dict[str, Any]: "label": alias.get("label"), "anonymous_id": alias.get("anonymousId"), } + + +def select_recent_notes( + notes_service: Any, *, limit: int, include_deleted: bool +) -> list[Any]: + """Return recent notes, excluding deleted notes by default.""" + + if limit <= 0: + return [] + if include_deleted: + return list(notes_service.recents(limit=limit)) + + probe_limit = limit + max_probe = min(max(limit, 10) * 8, 500) + while True: + rows = list(notes_service.recents(limit=probe_limit)) + filtered = [row for row in rows if not getattr(row, "is_deleted", False)] + if ( + len(filtered) >= limit + or len(rows) < probe_limit + or probe_limit >= max_probe + ): + return filtered[:limit] + probe_limit = min(probe_limit * 2, max_probe) + + +def search_notes_by_title( + notes_service: Any, + *, + title: str | None = None, + title_contains: str | None = None, + limit: int, +) -> list[Any]: + """Return title-matched notes using recents-first search with full-scan fallback.""" + + if limit <= 0: + return [] + + exact = (title or "").strip() + contains = (title_contains or "").strip().lower() + if not exact and not contains: + return [] + + def matches(note_title: str | None) -> bool: + if not note_title: + return False + if exact and note_title == exact: + return True + if contains and contains in note_title.lower(): + return True + return False + + def dedupe_key(item: Any) -> Any: + return getattr(item, "id", None) or id(item) + + candidates: list[Any] = [] + seen: set[Any] = set() + window = max(500, limit * 50) + + for note in notes_service.recents(limit=window): + if not matches(getattr(note, "title", None)): + continue + key = dedupe_key(note) + if key in seen: + continue + seen.add(key) + candidates.append(note) + if len(candidates) >= limit: + break + + if len(candidates) < limit: + for note in notes_service.iter_all(): + if not matches(getattr(note, "title", None)): + continue + key = dedupe_key(note) + if key in seen: + continue + seen.add(key) + candidates.append(note) + if len(candidates) >= limit: + break + + epoch = datetime(1970, 1, 1, tzinfo=timezone.utc) + + def sort_key(item: Any) -> datetime: + modified_at = getattr(item, "modified_at", None) + if modified_at is None: + return epoch + if modified_at.tzinfo is None: + return modified_at.replace(tzinfo=timezone.utc) + return modified_at + + candidates.sort(key=sort_key, reverse=True) + return candidates[:limit] diff --git a/pyicloud/cli/output.py b/pyicloud/cli/output.py index 9d816da4..77256147 100644 --- a/pyicloud/cli/output.py +++ b/pyicloud/cli/output.py @@ -119,3 +119,37 @@ def print_json_text(console: Console, payload: Any) -> None: """Pretty-print a JSON object in text mode.""" console.print_json(json=to_json_string(payload, indent=2)) + + +def format_color_value(value: Any) -> str: + """Return a compact human-friendly representation of reminder colors.""" + + if not value: + return "" + + payload = value + if isinstance(value, str): + stripped = value.strip() + if not stripped: + return "" + if not stripped.startswith("{"): + return stripped + try: + payload = json.loads(stripped) + except json.JSONDecodeError: + return stripped + + if isinstance(payload, dict): + hex_value = payload.get("daHexString") + symbolic = payload.get("ckSymbolicColorName") or payload.get( + "daSymbolicColorName" + ) + if hex_value and symbolic and symbolic != "custom": + return f"{symbolic} ({hex_value})" + if hex_value: + return str(hex_value) + if symbolic: + return str(symbolic) + return to_json_string(payload) + + return str(payload) diff --git a/tests/test_cmdline.py b/tests/test_cmdline.py index 964a107b..8a1cf75c 100644 --- a/tests/test_cmdline.py +++ b/tests/test_cmdline.py @@ -16,6 +16,29 @@ import click from typer.testing import CliRunner +from pyicloud.services.notes.models import Attachment as NoteAttachment +from pyicloud.services.notes.models import ChangeEvent as NoteChangeEvent +from pyicloud.services.notes.models import ( + Note, + NoteFolder, + NoteSummary, +) +from pyicloud.services.notes.service import NoteLockedError, NoteNotFound +from pyicloud.services.reminders.models import ( + Alarm, + AlarmWithTrigger, + Hashtag, + ListRemindersResult, + LocationTrigger, + Proximity, + RecurrenceFrequency, + RecurrenceRule, + Reminder, + ReminderChangeEvent, + RemindersList, + URLAttachment, +) + account_index_module = importlib.import_module("pyicloud.cli.account_index") cli_module = importlib.import_module("pyicloud.cli.app") context_module = importlib.import_module("pyicloud.cli.context") @@ -195,6 +218,613 @@ def delete(self, anonymous_id: str) -> dict[str, Any]: return {"anonymousId": anonymous_id, "deleted": True} +class FakeNotes: + """Notes service fixture.""" + + def __init__(self) -> None: + attachment = NoteAttachment( + id="Attachment/PDF", + filename="agenda.pdf", + uti="com.adobe.pdf", + size=12, + download_url="https://example.com/agenda.pdf", + preview_url="https://example.com/agenda-preview.pdf", + thumbnail_url="https://example.com/agenda-thumb.png", + ) + self.recent_requests: list[int] = [] + self.iter_all_requests: list[str | None] = [] + self.folder_requests: list[tuple[str, int | None]] = [] + self.render_calls: list[dict[str, Any]] = [] + self.export_calls: list[dict[str, Any]] = [] + self.change_requests: list[str | None] = [] + self.folder_rows = [ + NoteFolder( + id="Folder/NOTES", + name="Notes", + has_subfolders=False, + count=1, + ), + NoteFolder( + id="Folder/WORK", + name="Work", + has_subfolders=True, + count=3, + ), + ] + self.recent_rows = [ + NoteSummary( + id="Note/DELETED", + title="Deleted Note", + snippet="Old note", + modified_at=datetime(2026, 3, 5, tzinfo=timezone.utc), + folder_id="Folder/DELETED", + folder_name="Recently Deleted", + is_deleted=True, + is_locked=False, + ), + NoteSummary( + id="Note/DAILY", + title="Daily Plan", + snippet="Ship CLI", + modified_at=datetime(2026, 3, 4, tzinfo=timezone.utc), + folder_id="Folder/NOTES", + folder_name="Notes", + is_deleted=False, + is_locked=False, + ), + NoteSummary( + id="Note/MEETING", + title="Meeting Notes", + snippet="Discuss roadmap", + modified_at=datetime(2026, 3, 3, tzinfo=timezone.utc), + folder_id="Folder/WORK", + folder_name="Work", + is_deleted=False, + is_locked=False, + ), + ] + self.all_rows = [ + self.recent_rows[2], + NoteSummary( + id="Note/FOLLOWUP", + title="Meeting Follow-up", + snippet="Send recap", + modified_at=datetime(2026, 3, 2, tzinfo=timezone.utc), + folder_id="Folder/WORK", + folder_name="Work", + is_deleted=False, + is_locked=False, + ), + self.recent_rows[1], + self.recent_rows[2], + ] + self.notes = { + "Note/DAILY": Note( + id="Note/DAILY", + title="Daily Plan", + snippet="Ship CLI", + modified_at=datetime(2026, 3, 4, tzinfo=timezone.utc), + folder_id="Folder/NOTES", + folder_name="Notes", + is_deleted=False, + is_locked=False, + text="Ship CLI", + html="

Ship CLI

", + attachments=[attachment], + ), + "Note/MEETING": Note( + id="Note/MEETING", + title="Meeting Notes", + snippet="Discuss roadmap", + modified_at=datetime(2026, 3, 3, tzinfo=timezone.utc), + folder_id="Folder/WORK", + folder_name="Work", + is_deleted=False, + is_locked=False, + text="Discuss roadmap", + html="

Discuss roadmap

", + attachments=[attachment], + ), + "Note/FOLLOWUP": Note( + id="Note/FOLLOWUP", + title="Meeting Follow-up", + snippet="Send recap", + modified_at=datetime(2026, 3, 2, tzinfo=timezone.utc), + folder_id="Folder/WORK", + folder_name="Work", + is_deleted=False, + is_locked=False, + text="Send recap", + html="

Send recap

", + attachments=None, + ), + } + self.change_rows = [ + NoteChangeEvent(type="updated", note=self.recent_rows[1]), + NoteChangeEvent(type="deleted", note=self.recent_rows[0]), + ] + self.cursor = "notes-cursor-1" + + @staticmethod + def _matches_id(note_id: str, query: str) -> bool: + return note_id == query or note_id.split("/", 1)[-1] == query + + def recents(self, *, limit: int = 50): + self.recent_requests.append(limit) + return list(self.recent_rows[:limit]) + + def folders(self): + return list(self.folder_rows) + + def in_folder(self, folder_id: str, limit: int | None = None): + self.folder_requests.append((folder_id, limit)) + rows = [row for row in self.all_rows if row.folder_id == folder_id] + return list(rows[:limit] if limit is not None else rows) + + def iter_all(self, *, since: Optional[str] = None): + self.iter_all_requests.append(since) + return iter(self.all_rows) + + def get(self, note_id: str, *, with_attachments: bool = False): + if self._matches_id("Note/LOCKED", note_id): + raise NoteLockedError(f"Note is locked: {note_id}") + for candidate_id, note in self.notes.items(): + if self._matches_id(candidate_id, note_id): + attachments = note.attachments if with_attachments else None + return note.model_copy(update={"attachments": attachments}) + raise NoteNotFound(f"Note not found: {note_id}") + + def render_note(self, note_id: str, **kwargs: Any) -> str: + note = self.get(note_id, with_attachments=False) + self.render_calls.append({"note_id": note.id, **kwargs}) + return note.html or f"

{note.id}

" + + def export_note(self, note_id: str, output_dir: str, **kwargs: Any) -> str: + note = self.get(note_id, with_attachments=False) + path = Path(output_dir) / f"{note.id.split('/', 1)[-1].lower()}.html" + self.export_calls.append( + {"note_id": note.id, "output_dir": output_dir, **kwargs} + ) + return str(path) + + def iter_changes(self, *, since: Optional[str] = None): + self.change_requests.append(since) + return iter(self.change_rows) + + def sync_cursor(self) -> str: + return self.cursor + + +class FakeReminders: + """Reminders service fixture.""" + + def __init__(self) -> None: + self.list_rows = { + "List/INBOX": RemindersList( + id="List/INBOX", + title="Inbox", + color='{"daHexString":"#007AFF","ckSymbolicColorName":"blue"}', + count=0, + ), + "List/WORK": RemindersList( + id="List/WORK", + title="Work", + color='{"daHexString":"#34C759","ckSymbolicColorName":"green"}', + count=0, + ), + } + self.reminder_rows = { + "Reminder/A": Reminder( + id="Reminder/A", + list_id="List/INBOX", + title="Buy milk", + desc="2 percent", + completed=False, + due_date=datetime(2026, 3, 31, 9, 0, tzinfo=timezone.utc), + priority=1, + flagged=True, + all_day=False, + time_zone="Europe/Luxembourg", + alarm_ids=["Alarm/A"], + hashtag_ids=["Hashtag/ERRANDS"], + attachment_ids=["Attachment/LINK"], + recurrence_rule_ids=["Recurrence/WEEKLY"], + parent_reminder_id="Reminder/PARENT", + created=datetime(2026, 3, 1, tzinfo=timezone.utc), + modified=datetime(2026, 3, 4, tzinfo=timezone.utc), + ), + "Reminder/B": Reminder( + id="Reminder/B", + list_id="List/INBOX", + title="Pay rent", + desc="", + completed=True, + completed_date=datetime(2026, 3, 2, tzinfo=timezone.utc), + priority=0, + flagged=False, + all_day=False, + created=datetime(2026, 3, 1, tzinfo=timezone.utc), + modified=datetime(2026, 3, 2, tzinfo=timezone.utc), + ), + "Reminder/C": Reminder( + id="Reminder/C", + list_id="List/WORK", + title="Prepare deck", + desc="Slides for review", + completed=False, + priority=5, + flagged=False, + all_day=False, + created=datetime(2026, 3, 3, tzinfo=timezone.utc), + modified=datetime(2026, 3, 4, tzinfo=timezone.utc), + ), + } + self.alarm_rows = { + "Alarm/A": Alarm( + id="Alarm/A", + alarm_uid="alarm-a", + reminder_id="Reminder/A", + trigger_id="Trigger/A", + ) + } + self.trigger_rows = { + "Trigger/A": LocationTrigger( + id="Trigger/A", + alarm_id="Alarm/A", + title="Office", + address="1 Infinite Loop", + latitude=37.3318, + longitude=-122.0312, + radius=150.0, + proximity=Proximity.ARRIVING, + location_uid="office", + ) + } + self.hashtag_rows = { + "Hashtag/ERRANDS": Hashtag( + id="Hashtag/ERRANDS", + name="errands", + reminder_id="Reminder/A", + created=datetime(2026, 3, 1, tzinfo=timezone.utc), + ) + } + self.attachment_rows = { + "Attachment/LINK": URLAttachment( + id="Attachment/LINK", + reminder_id="Reminder/A", + url="https://example.com/checklist", + uti="public.url", + ) + } + self.recurrence_rows = { + "Recurrence/WEEKLY": RecurrenceRule( + id="Recurrence/WEEKLY", + reminder_id="Reminder/A", + frequency=RecurrenceFrequency.WEEKLY, + interval=1, + occurrence_count=0, + first_day_of_week=1, + ) + } + self.snapshot_requests: list[dict[str, Any]] = [] + self.change_requests: list[str | None] = [] + self.cursor = "reminders-cursor-1" + + @staticmethod + def _matches_id(record_id: str, query: str) -> bool: + return record_id == query or record_id.split("/", 1)[-1] == query + + def _find_reminder(self, reminder_id: str) -> Reminder: + for candidate_id, reminder in self.reminder_rows.items(): + if self._matches_id(candidate_id, reminder_id): + return reminder + raise LookupError(f"Reminder not found: {reminder_id}") + + def lists(self): + for row in self.list_rows.values(): + row.count = sum( + 1 + for reminder in self.reminder_rows.values() + if reminder.list_id == row.id and not reminder.deleted + ) + return list(self.list_rows.values()) + + def reminders(self, list_id: Optional[str] = None): + rows = [ + reminder + for reminder in self.reminder_rows.values() + if not reminder.deleted and (list_id is None or reminder.list_id == list_id) + ] + return list(rows) + + def list_reminders( + self, + list_id: str, + include_completed: bool = False, + results_limit: int = 200, + ) -> ListRemindersResult: + normalized = list_id if list_id.startswith("List/") else f"List/{list_id}" + self.snapshot_requests.append( + { + "list_id": normalized, + "include_completed": include_completed, + "results_limit": results_limit, + } + ) + reminders = [ + reminder + for reminder in self.reminder_rows.values() + if reminder.list_id == normalized + and not reminder.deleted + and (include_completed or not reminder.completed) + ][:results_limit] + reminder_ids = {reminder.id for reminder in reminders} + return ListRemindersResult( + reminders=reminders, + alarms={ + alarm_id: alarm + for alarm_id, alarm in self.alarm_rows.items() + if alarm.reminder_id in reminder_ids + }, + triggers={ + trigger_id: trigger + for trigger_id, trigger in self.trigger_rows.items() + if any( + alarm.trigger_id == trigger_id + for alarm in self.alarm_rows.values() + if alarm.reminder_id in reminder_ids + ) + }, + attachments={ + attachment_id: attachment + for attachment_id, attachment in self.attachment_rows.items() + if attachment.reminder_id in reminder_ids + }, + hashtags={ + hashtag_id: hashtag + for hashtag_id, hashtag in self.hashtag_rows.items() + if hashtag.reminder_id in reminder_ids + }, + recurrence_rules={ + rule_id: rule + for rule_id, rule in self.recurrence_rows.items() + if rule.reminder_id in reminder_ids + }, + ) + + def get(self, reminder_id: str) -> Reminder: + return self._find_reminder(reminder_id) + + def create( + self, + list_id: str, + title: str, + desc: str = "", + completed: bool = False, + due_date: Optional[datetime] = None, + priority: int = 0, + flagged: bool = False, + all_day: bool = False, + time_zone: Optional[str] = None, + parent_reminder_id: Optional[str] = None, + ) -> Reminder: + next_id = f"Reminder/CREATED-{len(self.reminder_rows) + 1}" + reminder = Reminder( + id=next_id, + list_id=list_id, + title=title, + desc=desc, + completed=completed, + due_date=due_date, + priority=priority, + flagged=flagged, + all_day=all_day, + time_zone=time_zone, + parent_reminder_id=parent_reminder_id, + created=datetime(2026, 3, 30, tzinfo=timezone.utc), + modified=datetime(2026, 3, 30, tzinfo=timezone.utc), + ) + self.reminder_rows[reminder.id] = reminder + return reminder + + def update(self, reminder: Reminder) -> None: + self.reminder_rows[reminder.id] = reminder + + def delete(self, reminder: Reminder) -> None: + reminder.deleted = True + self.reminder_rows[reminder.id] = reminder + + def add_location_trigger( + self, + reminder: Reminder, + title: str = "", + address: str = "", + latitude: float = 0.0, + longitude: float = 0.0, + radius: float = 100.0, + proximity: Proximity = Proximity.ARRIVING, + ) -> tuple[Alarm, LocationTrigger]: + index = len(self.alarm_rows) + 1 + alarm = Alarm( + id=f"Alarm/{index}", + alarm_uid=f"alarm-{index}", + reminder_id=reminder.id, + trigger_id=f"Trigger/{index}", + ) + trigger = LocationTrigger( + id=f"Trigger/{index}", + alarm_id=alarm.id, + title=title, + address=address, + latitude=latitude, + longitude=longitude, + radius=radius, + proximity=proximity, + location_uid=f"location-{index}", + ) + self.alarm_rows[alarm.id] = alarm + self.trigger_rows[trigger.id] = trigger + reminder.alarm_ids.append(alarm.id) + return alarm, trigger + + def create_hashtag(self, reminder: Reminder, name: str) -> Hashtag: + hashtag = Hashtag( + id=f"Hashtag/{name.upper()}", + name=name, + reminder_id=reminder.id, + created=datetime(2026, 3, 30, tzinfo=timezone.utc), + ) + self.hashtag_rows[hashtag.id] = hashtag + reminder.hashtag_ids.append(hashtag.id) + return hashtag + + def update_hashtag(self, hashtag: Hashtag, name: str) -> None: + hashtag.name = name + + def delete_hashtag(self, reminder: Reminder, hashtag: Hashtag) -> None: + reminder.hashtag_ids = [ + row_id for row_id in reminder.hashtag_ids if row_id != hashtag.id + ] + self.hashtag_rows.pop(hashtag.id, None) + + def create_url_attachment( + self, reminder: Reminder, url: str, uti: str = "public.url" + ) -> URLAttachment: + attachment = URLAttachment( + id=f"Attachment/{len(self.attachment_rows) + 1}", + reminder_id=reminder.id, + url=url, + uti=uti, + ) + self.attachment_rows[attachment.id] = attachment + reminder.attachment_ids.append(attachment.id) + return attachment + + def update_attachment( + self, + attachment: URLAttachment, + *, + url: Optional[str] = None, + uti: Optional[str] = None, + filename: Optional[str] = None, + file_size: Optional[int] = None, + width: Optional[int] = None, + height: Optional[int] = None, + ) -> None: + if url is not None: + attachment.url = url + if uti is not None: + attachment.uti = uti + + def delete_attachment(self, reminder: Reminder, attachment: URLAttachment) -> None: + reminder.attachment_ids = [ + row_id for row_id in reminder.attachment_ids if row_id != attachment.id + ] + self.attachment_rows.pop(attachment.id, None) + + def create_recurrence_rule( + self, + reminder: Reminder, + *, + frequency: RecurrenceFrequency = RecurrenceFrequency.DAILY, + interval: int = 1, + occurrence_count: int = 0, + first_day_of_week: int = 0, + ) -> RecurrenceRule: + rule = RecurrenceRule( + id=f"Recurrence/{len(self.recurrence_rows) + 1}", + reminder_id=reminder.id, + frequency=frequency, + interval=interval, + occurrence_count=occurrence_count, + first_day_of_week=first_day_of_week, + ) + self.recurrence_rows[rule.id] = rule + reminder.recurrence_rule_ids.append(rule.id) + return rule + + def update_recurrence_rule( + self, + recurrence_rule: RecurrenceRule, + *, + frequency: Optional[RecurrenceFrequency] = None, + interval: Optional[int] = None, + occurrence_count: Optional[int] = None, + first_day_of_week: Optional[int] = None, + ) -> None: + if frequency is not None: + recurrence_rule.frequency = frequency + if interval is not None: + recurrence_rule.interval = interval + if occurrence_count is not None: + recurrence_rule.occurrence_count = occurrence_count + if first_day_of_week is not None: + recurrence_rule.first_day_of_week = first_day_of_week + + def delete_recurrence_rule( + self, reminder: Reminder, recurrence_rule: RecurrenceRule + ) -> None: + reminder.recurrence_rule_ids = [ + row_id + for row_id in reminder.recurrence_rule_ids + if row_id != recurrence_rule.id + ] + self.recurrence_rows.pop(recurrence_rule.id, None) + + def alarms_for(self, reminder: Reminder) -> list[AlarmWithTrigger]: + rows = [] + for alarm_id in reminder.alarm_ids: + alarm = self.alarm_rows[alarm_id] + rows.append( + AlarmWithTrigger( + alarm=alarm, + trigger=self.trigger_rows.get(alarm.trigger_id), + ) + ) + return rows + + def tags_for(self, reminder: Reminder) -> list[Hashtag]: + return [ + self.hashtag_rows[row_id] + for row_id in reminder.hashtag_ids + if row_id in self.hashtag_rows + ] + + def attachments_for(self, reminder: Reminder) -> list[URLAttachment]: + return [ + self.attachment_rows[row_id] + for row_id in reminder.attachment_ids + if row_id in self.attachment_rows + ] + + def recurrence_rules_for(self, reminder: Reminder) -> list[RecurrenceRule]: + return [ + self.recurrence_rows[row_id] + for row_id in reminder.recurrence_rule_ids + if row_id in self.recurrence_rows + ] + + def iter_changes(self, *, since: Optional[str] = None): + self.change_requests.append(since) + return iter( + [ + ReminderChangeEvent( + type="updated", + reminder_id="Reminder/A", + reminder=self.reminder_rows["Reminder/A"], + ), + ReminderChangeEvent( + type="deleted", + reminder_id="Reminder/Z", + reminder=None, + ), + ] + ) + + def sync_cursor(self) -> str: + return self.cursor + + class FakeAPI: """Authenticated API fixture.""" @@ -211,6 +841,9 @@ def __init__( self.is_china_mainland = china_mainland self.fido2_devices: list[dict[str, Any]] = [] self.trusted_devices: list[dict[str, Any]] = [] + self.two_factor_delivery_method = "unknown" + self.two_factor_delivery_notice = None + self.request_2fa_code = MagicMock(return_value=False) self.validate_2fa_code = MagicMock(return_value=True) self.confirm_security_key = MagicMock(return_value=True) self.send_verification_code = MagicMock(return_value=True) @@ -322,6 +955,8 @@ def __init__( all=photo_album, ) self.hidemyemail = FakeHideMyEmail() + self.notes = FakeNotes() + self.reminders = FakeReminders() def _logout( self, @@ -521,6 +1156,8 @@ def test_root_help() -> None: "drive", "photos", "hidemyemail", + "notes", + "reminders", ): assert command in text @@ -537,6 +1174,8 @@ def test_group_help() -> None: "drive", "photos", "hidemyemail", + "notes", + "reminders", ): result = _runner().invoke(app, [command, "--help"]) assert result.exit_code == 0 @@ -554,6 +1193,8 @@ def test_bare_group_invocation_shows_help() -> None: "drive", "photos", "hidemyemail", + "notes", + "reminders", ): result = _runner().invoke(app, [command]) text = _plain_output(result) @@ -562,6 +1203,22 @@ def test_bare_group_invocation_shows_help() -> None: assert "Missing command" not in text +def test_notes_and_reminders_leaf_help() -> None: + """New service groups and reminder subgroups should expose leaf help.""" + + for cli_args in ( + ["notes", "search", "--help"], + ["reminders", "create", "--help"], + ["reminders", "alarm", "--help"], + ["reminders", "alarm", "add-location", "--help"], + ["reminders", "hashtag", "--help"], + ["reminders", "attachment", "--help"], + ["reminders", "recurrence", "--help"], + ): + result = _runner().invoke(app, cli_args) + assert result.exit_code == 0 + + def test_leaf_help_includes_execution_context_options() -> None: """Leaf command help should show the command-local options it supports.""" @@ -1679,462 +2336,646 @@ def test_trusted_device_2sa_flow() -> None: ) -def test_non_interactive_2sa_does_not_send_verification_code() -> None: - """Non-interactive 2SA should fail before sending a verification code.""" +def test_notes_commands() -> None: + """Notes commands should expose list, detail, render, export, and sync flows.""" fake_api = FakeAPI() - fake_api.requires_2sa = True - fake_api.trusted_devices = [{"deviceName": "Trusted Device", "phoneNumber": "+1"}] - - result = _invoke(fake_api, "auth", "login", interactive=False) - - assert result.exit_code != 0 - assert result.exception.args[0] == ( - "Two-step authentication is required, but interactive prompts are disabled." - ) - fake_api.send_verification_code.assert_not_called() - -def test_devices_list_and_show_commands() -> None: - """Devices list and show should expose summary and detailed views.""" + recent_result = _invoke(fake_api, "notes", "recent") + assert recent_result.exit_code == 0 + assert "Daily Plan" in recent_result.stdout + assert "Deleted Note" not in recent_result.stdout - fake_api = FakeAPI() - list_result = _invoke(fake_api, "devices", "list", "--locate") - show_result = _invoke(fake_api, "devices", "show", "device-1") - raw_result = _invoke( + recent_json_result = _invoke( fake_api, - "devices", - "show", - "device-1", - "--raw", + "notes", + "recent", + "--include-deleted", output_format="json", ) - assert list_result.exit_code == 0 - assert "Example iPhone" in list_result.stdout - assert show_result.exit_code == 0 - assert "Battery Status" in show_result.stdout - assert raw_result.exit_code == 0 - assert json.loads(raw_result.stdout)["deviceDisplayName"] == "iPhone" - - -def test_devices_show_reports_reauthentication_requirement() -> None: - """Device resolution should collapse reauth failures into a CLIAbort.""" - - session_dir = _unique_session_dir("devices-show-reauth") - - class ReauthAPI: - def __init__(self) -> None: - self.account_name = "user@example.com" - self.is_china_mainland = False - self.session = SimpleNamespace( - session_path=str(session_dir / "userexamplecom.session"), - cookiejar_path=str(session_dir / "userexamplecom.cookiejar"), - ) - self.get_auth_status = MagicMock( - return_value={ - "authenticated": True, - "trusted_session": True, - "requires_2fa": False, - "requires_2sa": False, - } - ) - - @property - def devices(self): - raise context_module.PyiCloudFailedLoginException("No password set") - - result = _invoke( - ReauthAPI(), - "devices", - "show", - "Example iPhone", - session_dir=session_dir, - ) - - assert result.exit_code != 0 - assert result.exception.args[0] == ( - "Find My requires re-authentication for user@example.com. " - "Run: icloud auth login --username user@example.com" - ) + recent_payload = json.loads(recent_json_result.stdout) + assert recent_json_result.exit_code == 0 + assert [row["id"] for row in recent_payload] == [ + "Note/DELETED", + "Note/DAILY", + "Note/MEETING", + ] + folders_result = _invoke(fake_api, "notes", "folders") + assert folders_result.exit_code == 0 + assert "Work" in folders_result.stdout -def test_account_summary_reports_reauthentication_requirement() -> None: - """Account commands should collapse reauth failures into a CLIAbort.""" - - session_dir = _unique_session_dir("account-summary-reauth") - - class ReauthAPI: - def __init__(self) -> None: - self.account_name = "user@example.com" - self.is_china_mainland = False - self.session = SimpleNamespace( - session_path=str(session_dir / "userexamplecom.session"), - cookiejar_path=str(session_dir / "userexamplecom.cookiejar"), - ) - self.get_auth_status = MagicMock( - return_value={ - "authenticated": True, - "trusted_session": True, - "requires_2fa": False, - "requires_2sa": False, - } - ) - - @property - def account(self): - raise context_module.PyiCloudFailedLoginException("No password set") - - result = _invoke( - ReauthAPI(), - "account", - "summary", - session_dir=session_dir, - ) - - assert result.exit_code != 0 - assert result.exception.args[0] == ( - "Account requires re-authentication for user@example.com. " - "Run: icloud auth login --username user@example.com" + folder_list_result = _invoke( + fake_api, + "notes", + "list", + "--folder-id", + "Folder/WORK", + "--limit", + "2", + output_format="json", ) + folder_payload = json.loads(folder_list_result.stdout) + assert folder_list_result.exit_code == 0 + assert [row["id"] for row in folder_payload] == ["Note/MEETING", "Note/FOLLOWUP"] - -def test_devices_mutations_and_export() -> None: - """Device actions should map to the Find My device methods.""" - - fake_api = FakeAPI() - export_path = TEST_ROOT / "device.json" - export_path.parent.mkdir(parents=True, exist_ok=True) - sound_result = _invoke( + all_notes_result = _invoke( fake_api, - "devices", - "sound", - "device-1", - "--subject", - "Ping", + "notes", + "list", + "--all", + "--since", + "notes-prev", + "--limit", + "2", output_format="json", ) - silent_result = _invoke( + all_payload = json.loads(all_notes_result.stdout) + assert all_notes_result.exit_code == 0 + assert fake_api.notes.iter_all_requests[-1] == "notes-prev" + assert [row["id"] for row in all_payload] == ["Note/MEETING", "Note/FOLLOWUP"] + + get_result = _invoke( fake_api, - "devices", - "message", - "device-1", - "Hello", - "--silent", + "notes", + "get", + "Note/DAILY", + "--with-attachments", + output_format="json", ) - lost_result = _invoke( + get_payload = json.loads(get_result.stdout) + assert get_result.exit_code == 0 + assert get_payload["attachments"][0]["id"] == "Attachment/PDF" + + render_result = _invoke( fake_api, - "devices", - "lost-mode", - "device-1", - "--phone", - "123", - "--message", - "Lost", - "--passcode", - "4567", + "notes", + "render", + "Note/DAILY", + "--preview-appearance", + "dark", + "--pdf-height", + "720", + output_format="json", ) + render_payload = json.loads(render_result.stdout) + assert render_result.exit_code == 0 + assert render_payload["html"] == "

Ship CLI

" + assert fake_api.notes.render_calls[-1]["preview_appearance"] == "dark" + assert fake_api.notes.render_calls[-1]["pdf_object_height"] == 720 + export_result = _invoke( fake_api, - "devices", + "notes", "export", - "device-1", - "--output", - str(export_path), + "Note/DAILY", + "--output-dir", + str(TEST_ROOT / "notes-export"), + "--export-mode", + "lightweight", + "--fragment", + "--preview-appearance", + "dark", + "--pdf-height", + "480", output_format="json", ) - assert sound_result.exit_code == 0 - assert json.loads(sound_result.stdout)["subject"] == "Ping" - assert fake_api.devices[0].sound_subject == "Ping" - assert silent_result.exit_code == 0 - assert fake_api.devices[0].messages[-1]["sounds"] is False - assert lost_result.exit_code == 0 - assert fake_api.devices[0].lost_mode == { - "number": "123", - "text": "Lost", - "newpasscode": "4567", - } - assert export_result.exit_code == 0 export_payload = json.loads(export_result.stdout) - written_payload = json.loads(export_path.read_text(encoding="utf-8")) - assert export_payload["path"] == str(export_path) - assert export_payload["raw"] is False - assert written_payload["name"] == "Example iPhone" - assert written_payload["display_name"] == "iPhone" - assert "raw_data" in written_payload - assert "deviceDisplayName" not in written_payload - - raw_export_path = TEST_ROOT / "device-raw.json" - raw_export_result = _invoke( + assert export_result.exit_code == 0 + assert export_payload["path"].endswith("daily.html") + assert fake_api.notes.export_calls[-1]["export_mode"] == "lightweight" + assert fake_api.notes.export_calls[-1]["full_page"] is False + assert fake_api.notes.export_calls[-1]["preview_appearance"] == "dark" + assert fake_api.notes.export_calls[-1]["pdf_object_height"] == 480 + + changes_result = _invoke( fake_api, - "devices", - "export", - "device-1", - "--output", - str(raw_export_path), - "--raw", + "notes", + "changes", + "--since", + "notes-prev", + "--limit", + "1", output_format="json", ) - no_raw_export_path = TEST_ROOT / "device-no-raw.json" - no_raw_export_result = _invoke( + changes_payload = json.loads(changes_result.stdout) + assert changes_result.exit_code == 0 + assert fake_api.notes.change_requests[-1] == "notes-prev" + assert changes_payload[0]["type"] == "updated" + + cursor_result = _invoke(fake_api, "notes", "sync-cursor") + assert cursor_result.exit_code == 0 + assert cursor_result.stdout.strip() == "notes-cursor-1" + + +def test_notes_search_uses_recents_first_and_fallback() -> None: + """Notes search should probe recents first, fall back to iter_all, and dedupe.""" + + fake_api = FakeAPI() + + result = _invoke( fake_api, - "devices", - "export", - "device-1", - "--output", - str(no_raw_export_path), - "--no-raw", + "notes", + "search", + "--title-contains", + "Meeting", + "--limit", + "2", output_format="json", ) - assert raw_export_result.exit_code == 0 - assert json.loads(raw_export_result.stdout)["raw"] is True - assert "deviceDisplayName" in json.loads( - raw_export_path.read_text(encoding="utf-8") - ) - assert no_raw_export_result.exit_code == 0 - assert json.loads(no_raw_export_result.stdout)["raw"] is False - assert "display_name" in json.loads(no_raw_export_path.read_text(encoding="utf-8")) + payload = json.loads(result.stdout) + assert result.exit_code == 0 + assert [row["id"] for row in payload] == ["Note/MEETING", "Note/FOLLOWUP"] + assert fake_api.notes.recent_requests[-1] == 500 + assert fake_api.notes.iter_all_requests == [None] -def test_device_mutation_reports_reauthentication_requirement() -> None: - """Mutating Find My commands should surface a clean reauthentication message.""" + +def test_notes_commands_report_errors() -> None: + """Notes commands should surface clean selection and note-specific errors.""" fake_api = FakeAPI() - fake_api.devices[0].play_sound = MagicMock( - side_effect=context_module.PyiCloudFailedLoginException("No password set") + + search_result = _invoke(fake_api, "notes", "search") + assert search_result.exit_code != 0 + assert search_result.exception.args[0] == ( + "Pass --title or --title-contains to search notes." ) - result = _invoke(fake_api, "devices", "sound", "device-1") + missing_result = _invoke(fake_api, "notes", "get", "Note/MISSING") + assert missing_result.exit_code != 0 + assert missing_result.exception.args[0] == "Note not found: Note/MISSING" - assert result.exit_code != 0 - assert result.exception.args[0] == ( - "Find My requires re-authentication for user@example.com. " - "Run: icloud auth login --username user@example.com" - ) + locked_result = _invoke(fake_api, "notes", "get", "Note/LOCKED") + assert locked_result.exit_code != 0 + assert locked_result.exception.args[0] == "Note is locked: Note/LOCKED" -def test_destructive_device_commands_require_unique_match() -> None: - """Lost mode should require an unambiguous device name or an explicit device id.""" +def test_notes_commands_report_reauthentication_and_unavailability() -> None: + """Notes commands should wrap service reauth and service-unavailable failures.""" - fake_api = FakeAPI() - duplicate = FakeDevice() - duplicate.id = "device-2" - duplicate.data["id"] = duplicate.id - fake_api.devices = [fake_api.devices[0], duplicate] + class ReauthNotes: + def recents(self, *, limit: int = 50): + raise context_module.PyiCloudFailedLoginException("No password set") - result = _invoke(fake_api, "devices", "lost-mode", "Example iPhone") + class UnavailableNotes: + def sync_cursor(self) -> str: + raise context_module.PyiCloudServiceUnavailable("temporarily unavailable") - assert result.exit_code != 0 - assert result.exception.args[0] == ( - "Multiple devices matched 'Example iPhone'. Use a device id instead.\n" - " - device-1 (Example iPhone / iPhone)\n" - " - device-2 (Example iPhone / iPhone)" + fake_api = FakeAPI() + fake_api.notes = ReauthNotes() + reauth_result = _invoke(fake_api, "notes", "recent") + assert reauth_result.exit_code != 0 + assert reauth_result.exception.args[0] == ( + "Notes requires re-authentication for user@example.com. " + "Run: icloud auth login --username user@example.com" ) + fake_api = FakeAPI() + fake_api.notes = UnavailableNotes() + unavailable_result = _invoke(fake_api, "notes", "sync-cursor") + assert unavailable_result.exit_code != 0 + assert unavailable_result.exception.args[0] == ( + "Notes service unavailable: temporarily unavailable" + ) -def test_calendar_and_contacts_commands() -> None: - """Calendar and contacts groups should expose read commands.""" + +def test_reminders_core_commands() -> None: + """Reminders core commands should expose list, detail, mutation, and sync flows.""" fake_api = FakeAPI() - calendars = _invoke(fake_api, "calendar", "calendars") - contacts = _invoke(fake_api, "contacts", "me") - assert calendars.exit_code == 0 - assert "Home" in calendars.stdout - assert contacts.exit_code == 0 - assert "John Appleseed" in contacts.stdout + lists_result = _invoke(fake_api, "reminders", "lists") + assert lists_result.exit_code == 0 + assert "Inbox" in lists_result.stdout + assert "blue (#007AFF)" in lists_result.stdout -def test_drive_and_photos_commands() -> None: - """Drive and photos commands should expose listing and download flows.""" + list_result = _invoke(fake_api, "reminders", "list", output_format="json") + list_payload = json.loads(list_result.stdout) + assert list_result.exit_code == 0 + assert [row["id"] for row in list_payload] == ["Reminder/A", "Reminder/C"] + assert all(not row["completed"] for row in list_payload) - fake_api = FakeAPI() - output_path = TEST_ROOT / "photo.bin" - json_output_path = TEST_ROOT / "report.txt" - output_path.parent.mkdir(parents=True, exist_ok=True) - drive_result = _invoke(fake_api, "drive", "list", "/") - photo_result = _invoke( + completed_result = _invoke( fake_api, - "photos", - "download", - "photo-1", - "--output", - str(output_path), + "reminders", + "list", + "--list-id", + "INBOX", + "--include-completed", + output_format="json", ) - json_drive_result = _invoke( + completed_payload = json.loads(completed_result.stdout) + assert completed_result.exit_code == 0 + assert [row["id"] for row in completed_payload] == ["Reminder/A", "Reminder/B"] + assert fake_api.reminders.snapshot_requests[-1]["list_id"] == "List/INBOX" + + get_result = _invoke(fake_api, "reminders", "get", "Reminder/A") + assert get_result.exit_code == 0 + assert "Parent Reminder" in get_result.stdout + + create_result = _invoke( fake_api, - "drive", - "download", - "/report.txt", - "--output", - str(json_output_path), + "reminders", + "create", + "--list-id", + "INBOX", + "--title", + "Call mom", + "--desc", + "Saturday", + "--priority", + "9", + "--flagged", + "--all-day", output_format="json", ) - assert drive_result.exit_code == 0 - assert "report.txt" in drive_result.stdout - assert photo_result.exit_code == 0 - assert output_path.read_bytes() == b"photo-1:original" - assert json_drive_result.exit_code == 0 - assert json.loads(json_drive_result.stdout)["path"] == str(json_output_path) + create_payload = json.loads(create_result.stdout) + created_id = create_payload["id"] + assert create_result.exit_code == 0 + assert create_payload["list_id"] == "List/INBOX" + assert create_payload["flagged"] is True + assert create_payload["all_day"] is True + update_result = _invoke( + fake_api, + "reminders", + "update", + "Reminder/A", + "--title", + "Buy oat milk", + "--not-flagged", + "--clear-time-zone", + "--clear-parent-reminder", + output_format="json", + ) + update_payload = json.loads(update_result.stdout) + assert update_result.exit_code == 0 + assert update_payload["title"] == "Buy oat milk" + assert update_payload["flagged"] is False + assert update_payload["time_zone"] is None + assert update_payload["parent_reminder_id"] is None -def test_drive_missing_paths_report_cli_abort() -> None: - """Drive commands should collapse missing path lookups into CLIAbort errors.""" + status_result = _invoke( + fake_api, + "reminders", + "set-status", + "Reminder/A", + "--completed", + output_format="json", + ) + status_payload = json.loads(status_result.stdout) + assert status_result.exit_code == 0 + assert status_payload["completed"] is True - fake_api = FakeAPI() - output_path = TEST_ROOT / "missing.txt" + snapshot_result = _invoke( + fake_api, + "reminders", + "snapshot", + "--list-id", + "INBOX", + output_format="json", + ) + snapshot_payload = json.loads(snapshot_result.stdout) + assert snapshot_result.exit_code == 0 + assert set(snapshot_payload) == { + "alarms", + "attachments", + "hashtags", + "recurrence_rules", + "reminders", + "triggers", + } - list_result = _invoke(fake_api, "drive", "list", "/missing") - download_result = _invoke( + changes_result = _invoke( fake_api, - "drive", - "download", - "/missing", - "--output", - str(output_path), + "reminders", + "changes", + "--since", + "reminders-prev", + "--limit", + "1", + output_format="json", ) + changes_payload = json.loads(changes_result.stdout) + assert changes_result.exit_code == 0 + assert fake_api.reminders.change_requests[-1] == "reminders-prev" + assert changes_payload[0]["type"] == "updated" - assert list_result.exit_code != 0 - assert list_result.exception.args[0] == "Path not found: /missing" - assert download_result.exit_code != 0 - assert download_result.exception.args[0] == "Path not found: /missing" + cursor_result = _invoke(fake_api, "reminders", "sync-cursor") + assert cursor_result.exit_code == 0 + assert cursor_result.stdout.strip() == "reminders-cursor-1" + delete_result = _invoke( + fake_api, + "reminders", + "delete", + created_id, + output_format="json", + ) + delete_payload = json.loads(delete_result.stdout) + assert delete_result.exit_code == 0 + assert delete_payload["deleted"] is True + assert fake_api.reminders.reminder_rows[created_id].deleted is True -def test_photos_commands_report_reauthentication_requirement() -> None: - """Photos commands should wrap nested service operations in service_call.""" - class ReauthAlbums: - @property - def albums(self): - raise context_module.PyiCloudFailedLoginException("No password set") +def test_reminders_subgroup_commands() -> None: + """Reminder subgroup commands should expose alarm, hashtag, attachment, and recurrence flows.""" fake_api = FakeAPI() - fake_api.photos = ReauthAlbums() - albums_result = _invoke(fake_api, "photos", "albums") + alarm_list_result = _invoke( + fake_api, + "reminders", + "alarm", + "list", + "Reminder/A", + output_format="json", + ) + alarm_list_payload = json.loads(alarm_list_result.stdout) + assert alarm_list_result.exit_code == 0 + assert alarm_list_payload[0]["alarm"]["id"] == "Alarm/A" - assert albums_result.exit_code != 0 - assert albums_result.exception.args[0] == ( - "Photos requires re-authentication for user@example.com. " - "Run: icloud auth login --username user@example.com" + alarm_create_result = _invoke( + fake_api, + "reminders", + "alarm", + "add-location", + "Reminder/C", + "--title", + "Home", + "--address", + "Rue de Example", + "--latitude", + "49.61", + "--longitude", + "6.13", + "--radius", + "75", + "--proximity", + "leaving", + output_format="json", + ) + alarm_create_payload = json.loads(alarm_create_result.stdout) + assert alarm_create_result.exit_code == 0 + assert alarm_create_payload["trigger"]["title"] == "Home" + assert ( + fake_api.reminders.trigger_rows[alarm_create_payload["trigger"]["id"]].proximity + == Proximity.LEAVING ) - class BrokenPhoto(FakePhoto): - def download(self, version: str = "original") -> bytes: - raise context_module.PyiCloudFailedLoginException("No password set") + hashtag_list_result = _invoke( + fake_api, + "reminders", + "hashtag", + "list", + "Reminder/A", + output_format="json", + ) + hashtag_list_payload = json.loads(hashtag_list_result.stdout) + assert hashtag_list_result.exit_code == 0 + assert hashtag_list_payload[0]["id"] == "Hashtag/ERRANDS" - photo_album = FakePhotoAlbum("All Photos", [BrokenPhoto("photo-1", "img.jpg")]) - fake_api = FakeAPI() - fake_api.photos = SimpleNamespace( - albums=FakeAlbumContainer([photo_album]), - all=photo_album, + hashtag_create_result = _invoke( + fake_api, + "reminders", + "hashtag", + "create", + "Reminder/C", + "home", + output_format="json", ) - output_path = TEST_ROOT / "photo-reauth.bin" + hashtag_create_payload = json.loads(hashtag_create_result.stdout) + hashtag_suffix = hashtag_create_payload["id"].split("/", 1)[1] + assert hashtag_create_result.exit_code == 0 - download_result = _invoke( + hashtag_update_result = _invoke( fake_api, - "photos", - "download", - "photo-1", - "--output", - str(output_path), + "reminders", + "hashtag", + "update", + "Reminder/C", + hashtag_suffix, + "--name", + "chores", + output_format="json", ) + hashtag_update_payload = json.loads(hashtag_update_result.stdout) + assert hashtag_update_result.exit_code == 0 + assert hashtag_update_payload["name"] == "chores" - assert download_result.exit_code != 0 - assert download_result.exception.args[0] == ( - "Photos requires re-authentication for user@example.com. " - "Run: icloud auth login --username user@example.com" + hashtag_delete_result = _invoke( + fake_api, + "reminders", + "hashtag", + "delete", + "Reminder/C", + hashtag_suffix, + output_format="json", ) + hashtag_delete_payload = json.loads(hashtag_delete_result.stdout) + assert hashtag_delete_result.exit_code == 0 + assert hashtag_delete_payload["deleted"] is True + attachment_list_result = _invoke( + fake_api, + "reminders", + "attachment", + "list", + "Reminder/A", + output_format="json", + ) + attachment_list_payload = json.loads(attachment_list_result.stdout) + assert attachment_list_result.exit_code == 0 + assert attachment_list_payload[0]["id"] == "Attachment/LINK" -def test_hidemyemail_commands() -> None: - """Hide My Email commands should expose list and generate.""" + attachment_create_result = _invoke( + fake_api, + "reminders", + "attachment", + "create-url", + "Reminder/C", + "--url", + "https://example.com/new", + output_format="json", + ) + attachment_create_payload = json.loads(attachment_create_result.stdout) + attachment_suffix = attachment_create_payload["id"].split("/", 1)[1] + assert attachment_create_result.exit_code == 0 - fake_api = FakeAPI() - list_result = _invoke(fake_api, "hidemyemail", "list") - generate_result = _invoke(fake_api, "hidemyemail", "generate") - assert list_result.exit_code == 0 - assert "Shopping" in list_result.stdout - assert generate_result.exit_code == 0 - assert "generated@privaterelay.appleid.com" in generate_result.stdout + attachment_update_result = _invoke( + fake_api, + "reminders", + "attachment", + "update", + "Reminder/C", + attachment_suffix, + "--url", + "https://example.org/new", + "--uti", + "public.url", + output_format="json", + ) + attachment_update_payload = json.loads(attachment_update_result.stdout) + assert attachment_update_result.exit_code == 0 + assert attachment_update_payload["url"] == "https://example.org/new" + attachment_delete_result = _invoke( + fake_api, + "reminders", + "attachment", + "delete", + "Reminder/C", + attachment_suffix, + output_format="json", + ) + attachment_delete_payload = json.loads(attachment_delete_result.stdout) + assert attachment_delete_result.exit_code == 0 + assert attachment_delete_payload["deleted"] is True -def test_hidemyemail_generate_requires_alias() -> None: - """Generate should fail when the backend returns an empty alias.""" + recurrence_list_result = _invoke( + fake_api, + "reminders", + "recurrence", + "list", + "Reminder/A", + output_format="json", + ) + recurrence_list_payload = json.loads(recurrence_list_result.stdout) + assert recurrence_list_result.exit_code == 0 + assert recurrence_list_payload[0]["id"] == "Recurrence/WEEKLY" - fake_api = FakeAPI() - fake_api.hidemyemail.generate = MagicMock(return_value=None) + recurrence_create_result = _invoke( + fake_api, + "reminders", + "recurrence", + "create", + "Reminder/C", + "--frequency", + "monthly", + "--interval", + "2", + output_format="json", + ) + recurrence_create_payload = json.loads(recurrence_create_result.stdout) + recurrence_suffix = recurrence_create_payload["id"].split("/", 1)[1] + assert recurrence_create_result.exit_code == 0 - result = _invoke(fake_api, "hidemyemail", "generate") + recurrence_update_result = _invoke( + fake_api, + "reminders", + "recurrence", + "update", + "Reminder/C", + recurrence_suffix, + "--frequency", + "yearly", + "--interval", + "3", + "--occurrence-count", + "4", + output_format="json", + ) + recurrence_update_payload = json.loads(recurrence_update_result.stdout) + assert recurrence_update_result.exit_code == 0 + assert recurrence_update_payload["interval"] == 3 + assert recurrence_update_payload["occurrence_count"] == 4 - assert result.exit_code != 0 - assert result.exception.args[0] == ( - "Hide My Email generate returned an empty alias." + recurrence_delete_result = _invoke( + fake_api, + "reminders", + "recurrence", + "delete", + "Reminder/C", + recurrence_suffix, + output_format="json", ) + recurrence_delete_payload = json.loads(recurrence_delete_result.stdout) + assert recurrence_delete_result.exit_code == 0 + assert recurrence_delete_payload["deleted"] is True -def test_hidemyemail_update_omits_note_when_not_provided() -> None: - """Label-only updates should not overwrite notes with a synthetic default.""" +def test_reminders_commands_report_errors() -> None: + """Reminders commands should surface clean validation and lookup errors.""" fake_api = FakeAPI() - update_metadata = MagicMock(return_value={"anonymousId": "alias-1", "label": "New"}) - fake_api.hidemyemail.update_metadata = update_metadata - result = _invoke(fake_api, "hidemyemail", "update", "alias-1", "New") + missing_result = _invoke(fake_api, "reminders", "get", "Reminder/MISSING") + assert missing_result.exit_code != 0 + assert missing_result.exception.args[0] == "Reminder not found: Reminder/MISSING" - assert result.exit_code == 0 - update_metadata.assert_called_once_with("alias-1", "New", None) - - -def test_hidemyemail_mutations_require_valid_payload() -> None: - """Hide My Email mutators should reject empty success payloads.""" - - fake_api = FakeAPI() - fake_api.hidemyemail.delete = MagicMock(return_value={}) + update_result = _invoke(fake_api, "reminders", "update", "Reminder/A") + assert update_result.exit_code != 0 + assert update_result.exception.args[0] == "No reminder updates were requested." - result = _invoke(fake_api, "hidemyemail", "delete", "alias-1") - - assert result.exit_code != 0 - assert result.exception.args[0] == ( - "Hide My Email delete returned an invalid response: {}" + hashtag_result = _invoke( + fake_api, + "reminders", + "hashtag", + "delete", + "Reminder/A", + "missing", + ) + assert hashtag_result.exit_code != 0 + assert hashtag_result.exception.args[0] == ( + "No hashtag matched 'missing' for reminder Reminder/A." ) - fake_api = FakeAPI() - fake_api.hidemyemail.reserve = MagicMock(return_value={}) - - result = _invoke( + attachment_result = _invoke( fake_api, - "hidemyemail", - "reserve", - "alias@example.com", - "Shopping", + "reminders", + "attachment", + "update", + "Reminder/A", + "LINK", + ) + assert attachment_result.exit_code != 0 + assert attachment_result.exception.args[0] == ( + "No attachment updates were requested." ) - assert result.exit_code != 0 - assert result.exception.args[0] == ( - "Hide My Email reserve returned an invalid response: {}" + recurrence_result = _invoke( + fake_api, + "reminders", + "recurrence", + "update", + "Reminder/A", + "WEEKLY", + ) + assert recurrence_result.exit_code != 0 + assert recurrence_result.exception.args[0] == ( + "No recurrence updates were requested." ) -def test_hidemyemail_list_reports_reauthentication_requirement() -> None: - """Hide My Email iteration errors should be wrapped in a CLIAbort.""" +def test_reminders_commands_report_reauthentication_and_unavailability() -> None: + """Reminders commands should wrap service reauth and service-unavailable failures.""" - class ReauthHideMyEmail: - def __iter__(self): + class ReauthReminders: + def lists(self): raise context_module.PyiCloudFailedLoginException("No password set") - def generate(self) -> str: # pragma: no cover - not used in this test - return "ignored" + class UnavailableReminders: + def sync_cursor(self) -> str: + raise context_module.PyiCloudServiceUnavailable("temporarily unavailable") fake_api = FakeAPI() - fake_api.hidemyemail = ReauthHideMyEmail() - - result = _invoke(fake_api, "hidemyemail", "list") - - assert result.exit_code != 0 - assert result.exception.args[0] == ( - "Hide My Email requires re-authentication for user@example.com. " + fake_api.reminders = ReauthReminders() + reauth_result = _invoke(fake_api, "reminders", "lists") + assert reauth_result.exit_code != 0 + assert reauth_result.exception.args[0] == ( + "Reminders requires re-authentication for user@example.com. " "Run: icloud auth login --username user@example.com" ) + fake_api = FakeAPI() + fake_api.reminders = UnavailableReminders() + unavailable_result = _invoke(fake_api, "reminders", "sync-cursor") + assert unavailable_result.exit_code != 0 + assert unavailable_result.exception.args[0] == ( + "Reminders service unavailable: temporarily unavailable" + ) + def test_main_returns_clean_error_for_user_abort(capsys) -> None: """The entrypoint should not emit a traceback for expected CLI errors.""" diff --git a/tests/test_output.py b/tests/test_output.py index 915c11be..34b7b6a2 100644 --- a/tests/test_output.py +++ b/tests/test_output.py @@ -8,6 +8,7 @@ TABLE_TITLE_STYLE, console_kv_table, console_table, + format_color_value, ) @@ -34,3 +35,20 @@ def test_console_kv_table_styles_key_column() -> None: assert table.border_style == TABLE_BORDER_STYLE assert tuple(table.row_styles) == TABLE_ROW_STYLES assert table.columns[0].style == TABLE_KEY_STYLE + + +def test_format_color_value_handles_symbolic_payloads() -> None: + """Reminder color payloads should render as a compact symbolic label.""" + + assert ( + format_color_value('{"daHexString":"#007AFF","ckSymbolicColorName":"blue"}') + == "blue (#007AFF)" + ) + + +def test_format_color_value_handles_plain_values() -> None: + """Plain, empty, and malformed color values should degrade gracefully.""" + + assert format_color_value("#34C759") == "#34C759" + assert format_color_value("") == "" + assert format_color_value("{not-json}") == "{not-json}" From 6996b93a569d111c2f1393e76157bd6727a84358 Mon Sep 17 00:00:00 2001 From: mrjarnould Date: Tue, 31 Mar 2026 17:19:41 +0200 Subject: [PATCH 02/13] fix: repair Notes folder desired keys --- pyicloud/services/notes/service.py | 15 ++++---- tests/test_notes.py | 58 ++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 8 deletions(-) diff --git a/pyicloud/services/notes/service.py b/pyicloud/services/notes/service.py index 213e3275..354a626c 100644 --- a/pyicloud/services/notes/service.py +++ b/pyicloud/services/notes/service.py @@ -56,6 +56,7 @@ from .models.dto import ChangeEvent, NoteFolder LOGGER = logging.getLogger(__name__) +_HAS_SUBFOLDER_FIELD = "HasSubfolder" class NoteNotFound(NotesError): @@ -248,8 +249,7 @@ def folders(self) -> Iterable[NoteFolder]: """ desired_keys = [ NotesDesiredKey.TITLE_ENCRYPTED, - NotesDesiredKey.TITLE_MODIFICATION_DATE, - NotesDesiredKey.HAS_SUBFOLDER, + _HAS_SUBFOLDER_FIELD, ] query = CKQueryObject( recordType="SearchIndexes", @@ -277,13 +277,12 @@ def folders(self) -> Iterable[NoteFolder]: name = self._decode_encrypted( rec.fields.get_value("TitleEncrypted") ) - has_sub = bool( - getattr( - rec.fields.get_field(NotesDesiredKey.HAS_SUBFOLDER) or (), - "value", - False, - ) + has_sub_value = getattr( + rec.fields.get_field(_HAS_SUBFOLDER_FIELD) or (), + "value", + None, ) + has_sub = None if has_sub_value is None else bool(has_sub_value) yield NoteFolder( id=folder_id, name=name, has_subfolders=has_sub, count=None ) diff --git a/tests/test_notes.py b/tests/test_notes.py index b8294fa7..3b4cd327 100644 --- a/tests/test_notes.py +++ b/tests/test_notes.py @@ -306,6 +306,64 @@ def test_notes_service_attachment_lookup_prefers_canonical_record_names(self): self.assertEqual(attachments[0].id, "Attachment/CANONICAL") self.assertIs(self.service._attachment_meta_cache["ALIAS-1"], attachments[0]) + def test_notes_service_folders_uses_supported_desired_keys(self): + """Folder listing should not depend on nonexistent Notes desired-key enums.""" + + folder_record = CKRecord.model_validate( + { + "recordName": "Folder/1", + "recordType": "SearchIndexes", + "fields": { + "TitleEncrypted": { + "type": "STRING", + "value": "Work", + "isEncrypted": True, + }, + "HasSubfolder": {"type": "INT64", "value": 1}, + }, + } + ) + self.service.raw.query = MagicMock( + return_value=MagicMock(records=[folder_record], continuationMarker=None) + ) + + folders = list(self.service.folders()) + + self.assertEqual( + self.service.raw.query.call_args.kwargs["desired_keys"], + ["TitleEncrypted", "HasSubfolder"], + ) + self.assertEqual(len(folders), 1) + self.assertEqual(folders[0].id, "Folder/1") + self.assertEqual(folders[0].name, "Work") + self.assertTrue(folders[0].has_subfolders) + + def test_notes_service_folders_treats_subfolder_flag_as_optional(self): + """Folder listing should still work when Apple omits the subfolder flag.""" + + folder_record = CKRecord.model_validate( + { + "recordName": "Folder/2", + "recordType": "SearchIndexes", + "fields": { + "TitleEncrypted": { + "type": "STRING", + "value": "Personal", + "isEncrypted": True, + }, + }, + } + ) + self.service.raw.query = MagicMock( + return_value=MagicMock(records=[folder_record], continuationMarker=None) + ) + + folders = list(self.service.folders()) + + self.assertEqual(len(folders), 1) + self.assertEqual(folders[0].name, "Personal") + self.assertIsNone(folders[0].has_subfolders) + def test_write_html_rejects_filename_escape(self): out_dir = os.path.join( tempfile.gettempdir(), From 7200a47a7906252b23d396234545360e57641981 Mon Sep 17 00:00:00 2001 From: mrjarnould Date: Tue, 31 Mar 2026 17:56:18 +0200 Subject: [PATCH 03/13] fix: derive reminder list counts from membership --- pyicloud/services/reminders/_mappers.py | 11 +++++-- tests/services/test_reminders_cloudkit.py | 36 +++++++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/pyicloud/services/reminders/_mappers.py b/pyicloud/services/reminders/_mappers.py index 0b04e1d8..1380178b 100644 --- a/pyicloud/services/reminders/_mappers.py +++ b/pyicloud/services/reminders/_mappers.py @@ -137,16 +137,23 @@ def record_to_list(self, rec: CKRecord) -> RemindersList: fields = rec.fields title = fields.get_value("Name") color = fields.get_value("Color") + reminder_ids = self._reminder_ids_for_list_record(rec) + raw_count = fields.get_value("Count") + count = int(raw_count) if raw_count is not None else 0 + if count == 0 and reminder_ids: + # Live list records can carry complete reminder membership while the + # Count field stays at zero. Prefer the membership size in that case. + count = len(reminder_ids) return RemindersList( id=rec.recordName, title=str(title) if title else "Untitled", color=str(color) if color else None, - count=int(fields.get_value("Count") or 0), + count=count, badge_emblem=fields.get_value("BadgeEmblem"), sorting_style=fields.get_value("SortingStyle"), is_group=bool(fields.get_value("IsGroup") or 0), - reminder_ids=self._reminder_ids_for_list_record(rec), + reminder_ids=reminder_ids, record_change_tag=rec.recordChangeTag, ) diff --git a/tests/services/test_reminders_cloudkit.py b/tests/services/test_reminders_cloudkit.py index 2ef8b0e9..be20ad14 100644 --- a/tests/services/test_reminders_cloudkit.py +++ b/tests/services/test_reminders_cloudkit.py @@ -634,6 +634,40 @@ def test_list_parses_inline_reminder_ids_json(self, service): lst = service._record_to_list(rec) assert lst.reminder_ids == ["REM-1", "REM-2"] + assert lst.count == 2 + + def test_list_falls_back_to_reminder_ids_length_when_count_missing(self, service): + rec = _ck_record( + "List", + "LIST-003A", + { + "ReminderIDs": { + "type": "STRING", + "value": '["REM-1","Reminder/REM-2","REM-3"]', + } + }, + ) + + lst = service._record_to_list(rec) + assert lst.reminder_ids == ["REM-1", "REM-2", "REM-3"] + assert lst.count == 3 + + def test_list_falls_back_to_reminder_ids_length_when_count_is_zero(self, service): + rec = _ck_record( + "List", + "LIST-003B", + { + "Count": {"type": "INT64", "value": 0}, + "ReminderIDs": { + "type": "STRING", + "value": '["REM-1","Reminder/REM-2"]', + }, + }, + ) + + lst = service._record_to_list(rec) + assert lst.reminder_ids == ["REM-1", "REM-2"] + assert lst.count == 2 def test_list_parses_asset_backed_reminder_ids_from_downloaded_data(self, service): payload = base64.b64encode(b'["REM-1","Reminder/REM-2"]').decode("ascii") @@ -650,6 +684,7 @@ def test_list_parses_asset_backed_reminder_ids_from_downloaded_data(self, servic lst = service._record_to_list(rec) assert lst.reminder_ids == ["REM-1", "REM-2"] + assert lst.count == 2 service._raw.download_asset_bytes.assert_not_called() def test_list_parses_asset_backed_reminder_ids_from_download_url(self, service): @@ -667,6 +702,7 @@ def test_list_parses_asset_backed_reminder_ids_from_download_url(self, service): lst = service._record_to_list(rec) assert lst.reminder_ids == ["REM-3", "REM-4"] + assert lst.count == 2 service._raw.download_asset_bytes.assert_called_once_with( "https://example.com/reminder-ids.json" ) From 0c2e9cfffc7d9207c7f34b34601ad487e77ecb2f Mon Sep 17 00:00:00 2001 From: mrjarnould Date: Tue, 31 Mar 2026 19:00:40 +0200 Subject: [PATCH 04/13] fix: speed up Notes sync and fix Reminders hashtag round-trips --- pyicloud/services/notes/service.py | 19 ++++ pyicloud/services/reminders/_mappers.py | 25 ++--- pyicloud/services/reminders/_protocol.py | 17 +++ pyicloud/services/reminders/_reads.py | 4 + pyicloud/services/reminders/_writes.py | 5 +- pyicloud/services/reminders/client.py | 23 ++++ pyicloud/services/reminders/service.py | 1 + tests/services/test_reminders_cloudkit.py | 122 +++++++++++++++++++++- tests/test_notes.py | 31 ++++++ 9 files changed, 229 insertions(+), 18 deletions(-) diff --git a/pyicloud/services/notes/service.py b/pyicloud/services/notes/service.py index 354a626c..ad7728a3 100644 --- a/pyicloud/services/notes/service.py +++ b/pyicloud/services/notes/service.py @@ -216,6 +216,10 @@ def iter_all(self, *, since: Optional[str] = None) -> Iterable[NoteSummary]: ``NoteSummary`` instances for full exports, indexing jobs, or local cache refreshes. """ + if self._matches_current_sync_cursor(since): + LOGGER.debug("Skipping Notes full scan because sync token is current") + return + LOGGER.debug("Iterating all notes%s", f" since={since}" if since else "") for zone in self._raw.changes( zone_req=CKZoneChangesZoneReq( @@ -523,6 +527,10 @@ def iter_changes(self, *, since: Optional[str] = None) -> Iterable[ChangeEvent]: Pass a sync token from ``sync_cursor()`` to process only new changes since a previous run. """ + if self._matches_current_sync_cursor(since): + LOGGER.debug("Skipping Notes change scan because sync token is current") + return + LOGGER.debug("Iterating changes%s", f" since={since}" if since else "") for zone in self._raw.changes( zone_req=CKZoneChangesZoneReq( @@ -608,6 +616,17 @@ def raw(self) -> CloudKitNotesClient: # -------------------------- Internal helpers ----------------------------- + def _matches_current_sync_cursor(self, since: Optional[str]) -> bool: + """Return whether an incremental Notes cursor is already current.""" + if not since: + return False + + try: + return self._raw.current_sync_token(zone_name="Notes") == since + except NotesApiError as exc: + LOGGER.warning("Failed to preflight Notes sync token: %s", exc) + return False + @staticmethod def _coerce_keys(keys: Optional[Iterable[object]]) -> Optional[List[str]]: if keys is None: diff --git a/pyicloud/services/reminders/_mappers.py b/pyicloud/services/reminders/_mappers.py index 1380178b..7c67480e 100644 --- a/pyicloud/services/reminders/_mappers.py +++ b/pyicloud/services/reminders/_mappers.py @@ -11,6 +11,7 @@ from ._protocol import ( _as_raw_id, _decode_attachment_url, + _decode_cloudkit_text_value, _decode_crdt_document, _ref_name, ) @@ -117,21 +118,15 @@ def _reminder_ids_for_list_record(self, rec: CKRecord) -> list[str]: def _coerce_text(self, value: Any, *, field_name: str, record_name: str) -> str: """Normalize CloudKit text-like values into ``str`` for domain models.""" - if value is None: - return "" - if isinstance(value, str): - return value - if isinstance(value, bytes): - try: - return value.decode("utf-8") - except UnicodeDecodeError: - self._logger.warning( - "Field %s on %s was undecodable bytes; replacing invalid UTF-8", - field_name, - record_name, - ) - return value.decode("utf-8", errors="replace") - return str(value) + try: + return _decode_cloudkit_text_value(value) + except UnicodeDecodeError: + self._logger.warning( + "Field %s on %s was undecodable bytes; replacing invalid UTF-8", + field_name, + record_name, + ) + return value.decode("utf-8", errors="replace") def record_to_list(self, rec: CKRecord) -> RemindersList: fields = rec.fields diff --git a/pyicloud/services/reminders/_protocol.py b/pyicloud/services/reminders/_protocol.py index b5b0b31b..1ec726c3 100644 --- a/pyicloud/services/reminders/_protocol.py +++ b/pyicloud/services/reminders/_protocol.py @@ -84,6 +84,23 @@ def _decode_attachment_url(value: str) -> str: return value +def _encode_cloudkit_text_field(value: str) -> dict[str, str]: + """Encode text for CloudKit fields that store UTF-8 payload bytes.""" + encoded = base64.b64encode((value or "").encode("utf-8")).decode("ascii") + return {"type": "ENCRYPTED_BYTES", "value": encoded} + + +def _decode_cloudkit_text_value(value: object) -> str: + """Decode a CloudKit text field value into plain ``str``.""" + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, bytes): + return value.decode("utf-8") + return str(value) + + def _decode_crdt_document(encrypted_value: str | bytes) -> str: """Decode a CRDT document (TitleDocument or NotesDocument).""" data = encrypted_value diff --git a/pyicloud/services/reminders/_reads.py b/pyicloud/services/reminders/_reads.py index 57027827..2ddd6ad5 100644 --- a/pyicloud/services/reminders/_reads.py +++ b/pyicloud/services/reminders/_reads.py @@ -115,6 +115,10 @@ def reminders(self, list_id: Optional[str] = None) -> Iterable[Reminder]: def sync_cursor(self) -> str: """Return the latest usable sync token for the Reminders zone.""" + query_token = self._get_raw().current_sync_token(zone_id=_REMINDERS_ZONE_REQ) + if query_token: + return query_token + sync_token: Optional[str] = None for zone in self._iter_zone_change_pages( desired_record_types=[], diff --git a/pyicloud/services/reminders/_writes.py b/pyicloud/services/reminders/_writes.py index 6a354cda..c8bf9f2b 100644 --- a/pyicloud/services/reminders/_writes.py +++ b/pyicloud/services/reminders/_writes.py @@ -21,6 +21,7 @@ from ._protocol import ( _as_raw_id, _as_record_name, + _encode_cloudkit_text_field, _encode_crdt_document, _generate_resolution_token_map, ) @@ -736,7 +737,7 @@ def create_hashtag(self, reminder: Reminder, name: str) -> Hashtag: field_name="HashtagIDs", token_field_name="hashtagIDs", child_fields={ - "Name": {"type": "STRING", "value": name}, + "Name": _encode_cloudkit_text_field(name), "Deleted": {"type": "INT64", "value": 0}, "Reminder": { "type": "REFERENCE", @@ -762,7 +763,7 @@ def create_hashtag(self, reminder: Reminder, name: str) -> Hashtag: def update_hashtag(self, hashtag: Hashtag, name: str) -> None: """Update an existing hashtag name.""" fields: dict[str, Any] = { - "Name": {"type": "STRING", "value": name}, + "Name": _encode_cloudkit_text_field(name), } if hashtag.reminder_id: fields["Reminder"] = { diff --git a/pyicloud/services/reminders/client.py b/pyicloud/services/reminders/client.py index 8dd5a2c2..6ac294d6 100644 --- a/pyicloud/services/reminders/client.py +++ b/pyicloud/services/reminders/client.py @@ -187,6 +187,29 @@ def query( "Query response validation failed", payload=data ) from e + def current_sync_token( + self, + *, + zone_id: CKZoneIDReq, + record_type: str = "reminderList", + ) -> str | None: + """Fetch the current zone sync token using a lightweight query first.""" + payload = CKQueryRequest( + query=CKQueryObject(recordType=record_type), + zoneID=zone_id, + resultsLimit=1, + ).model_dump(mode="json", exclude_none=True) + + try: + data = self._http.post("/records/query", payload) + response = self._validate_response(CKQueryResponse, data) + except (RemindersApiError, ValidationError): + return None + + if getattr(response, "syncToken", None): + return str(response.syncToken) + return None + def changes( self, *, diff --git a/pyicloud/services/reminders/service.py b/pyicloud/services/reminders/service.py index d6218846..67e087b6 100644 --- a/pyicloud/services/reminders/service.py +++ b/pyicloud/services/reminders/service.py @@ -92,6 +92,7 @@ def __init__( ) base_params = { "remapEnums": True, + "getCurrentSyncToken": True, **(params or {}), } self._raw = CloudKitRemindersClient( diff --git a/tests/services/test_reminders_cloudkit.py b/tests/services/test_reminders_cloudkit.py index be20ad14..6d964871 100644 --- a/tests/services/test_reminders_cloudkit.py +++ b/tests/services/test_reminders_cloudkit.py @@ -319,6 +319,35 @@ def test_reminders_client_strict_mode_wraps_validation_error(): assert isinstance(excinfo.value.__cause__, ValidationError) +def test_reminders_client_current_sync_token_uses_query_sync_token(): + session = MagicMock() + session.post.return_value = MagicMock( + status_code=200, + json=lambda: {"records": [], "syncToken": "tok-query"}, + ) + client = CloudKitRemindersClient("https://example.com", session, {}) + + token = client.current_sync_token(zone_id=CKZoneIDReq(zoneName="Reminders")) + + assert token == "tok-query" + query_payload = session.post.call_args.kwargs["json"] + assert query_payload["query"]["recordType"] == "reminderList" + assert query_payload["resultsLimit"] == 1 + + +def test_reminders_client_current_sync_token_returns_none_when_missing(): + session = MagicMock() + session.post.return_value = MagicMock( + status_code=200, + json=lambda: {"records": []}, + ) + client = CloudKitRemindersClient("https://example.com", session, {}) + + token = client.current_sync_token(zone_id=CKZoneIDReq(zoneName="Reminders")) + + assert token is None + + def test_reminders_service_passes_through_validation_override(): service = RemindersService( "https://example.com", @@ -1429,7 +1458,9 @@ def test_create_and_delete_hashtag(self): create_ops = svc._raw.modify.call_args.kwargs["operations"] assert len(create_ops) == 2 assert create_ops[1].record.recordType == "Hashtag" - assert create_ops[1].record.fields["Name"].value == "travel" + name_field = create_ops[1].record.fields["Name"].root + assert name_field.type == "ENCRYPTED_BYTES" + assert name_field.value == b"travel" assert svc._raw.modify.call_args.kwargs["atomic"] is True svc._raw.modify.reset_mock() @@ -1913,6 +1944,46 @@ def _side_effect(**kwargs): assert reminder.record_change_tag == "ctag-rem-new" assert hashtag.record_change_tag == "ctag-hash-new" + def test_create_hashtag_name_round_trips_via_mapper(self): + svc = RemindersService("https://ckdatabasews.icloud.com", MagicMock(), {}) + svc._raw = MagicMock() + svc._raw.modify.return_value = self._ok_modify() + + reminder = Reminder( + id="Reminder/REM-TAG-ROUNDTRIP", + list_id="List/LIST-001", + title="Hashtag reminder", + record_change_tag="ctag-rem-old", + hashtag_ids=[], + ) + + svc.create_hashtag(reminder, "travel") + + name_field = ( + svc._raw.modify.call_args.kwargs["operations"][1].record.fields["Name"].root + ) + parsed = svc._record_to_hashtag( + _ck_record( + "Hashtag", + "Hashtag/HASH-ROUNDTRIP", + { + "Name": { + "type": name_field.type, + "value": base64.b64encode(name_field.value).decode("ascii"), + }, + "Reminder": { + "type": "REFERENCE", + "value": { + "recordName": "Reminder/REM-TAG-ROUNDTRIP", + "action": "VALIDATE", + }, + }, + }, + ) + ) + + assert parsed.name == "travel" + def test_create_url_attachment_hydrates_record_change_tags(self): svc = RemindersService("https://ckdatabasews.icloud.com", MagicMock(), {}) svc._raw = MagicMock() @@ -2028,6 +2099,45 @@ def _side_effect(**kwargs): svc.update_recurrence_rule(recurrence_rule, interval=2) assert recurrence_rule.record_change_tag == "new-recurrencerule-tag" + def test_update_hashtag_writes_encoded_name_field(self): + svc = RemindersService("https://ckdatabasews.icloud.com", MagicMock(), {}) + svc._raw = MagicMock() + svc._raw.modify.return_value = self._ok_modify() + + hashtag = Hashtag( + id="Hashtag/H-UPD-ENC", + name="old", + reminder_id="Reminder/REM-UPD-ENC", + record_change_tag="old-hashtag-tag", + ) + svc.update_hashtag(hashtag, "chores") + + operation = svc._raw.modify.call_args.kwargs["operations"][0] + name_field = operation.record.fields["Name"].root + assert name_field.type == "ENCRYPTED_BYTES" + assert name_field.value == b"chores" + + parsed = svc._record_to_hashtag( + _ck_record( + "Hashtag", + "Hashtag/H-UPD-ENC", + { + "Name": { + "type": name_field.type, + "value": base64.b64encode(name_field.value).decode("ascii"), + }, + "Reminder": { + "type": "REFERENCE", + "value": { + "recordName": "Reminder/REM-UPD-ENC", + "action": "VALIDATE", + }, + }, + }, + ) + ) + assert parsed.name == "chores" + class TestReminderReadPaths: """Validate reminders() and list_reminders() query behavior.""" @@ -2697,6 +2807,7 @@ def _changes_response( def test_sync_cursor_returns_final_paged_token(self): svc = RemindersService("https://ckdatabasews.icloud.com", MagicMock(), {}) svc._raw = MagicMock() + svc._raw.current_sync_token.return_value = None svc._raw.changes.side_effect = [ self._changes_response([], sync_token="tok-1", more_coming=True), self._changes_response([], sync_token="tok-2", more_coming=False), @@ -2711,6 +2822,15 @@ def test_sync_cursor_returns_final_paged_token(self): assert first_zone_req.desiredRecordTypes == [] assert first_zone_req.desiredKeys == [] + def test_sync_cursor_prefers_query_sync_token(self): + svc = RemindersService("https://ckdatabasews.icloud.com", MagicMock(), {}) + svc._raw = MagicMock() + svc._raw.current_sync_token.return_value = "tok-query" + + assert svc.sync_cursor() == "tok-query" + svc._raw.current_sync_token.assert_called_once() + svc._raw.changes.assert_not_called() + def test_iter_changes_emits_updated_deleted_and_tombstone_events(self): svc = RemindersService("https://ckdatabasews.icloud.com", MagicMock(), {}) svc._raw = MagicMock() diff --git a/tests/test_notes.py b/tests/test_notes.py index 3b4cd327..6155cb03 100644 --- a/tests/test_notes.py +++ b/tests/test_notes.py @@ -257,6 +257,37 @@ def test_notes_service_export_note_uses_lazy_importer(self): self.assertEqual(exported, output_path) mock_export.assert_called_once() + def test_iter_all_skips_changes_when_sync_cursor_is_current(self): + self.service._raw = MagicMock() + self.service._raw.current_sync_token.return_value = "tok-current" + + rows = list(self.service.iter_all(since="tok-current")) + + self.assertEqual(rows, []) + self.service._raw.current_sync_token.assert_called_once_with(zone_name="Notes") + self.service._raw.changes.assert_not_called() + + def test_iter_changes_skips_changes_when_sync_cursor_is_current(self): + self.service._raw = MagicMock() + self.service._raw.current_sync_token.return_value = "tok-current" + + rows = list(self.service.iter_changes(since="tok-current")) + + self.assertEqual(rows, []) + self.service._raw.current_sync_token.assert_called_once_with(zone_name="Notes") + self.service._raw.changes.assert_not_called() + + def test_iter_all_uses_changes_when_sync_cursor_is_not_current(self): + self.service._raw = MagicMock() + self.service._raw.current_sync_token.return_value = "tok-other" + self.service._raw.changes.return_value = [] + + rows = list(self.service.iter_all(since="tok-stale")) + + self.assertEqual(rows, []) + self.service._raw.current_sync_token.assert_called_once_with(zone_name="Notes") + self.service._raw.changes.assert_called_once() + def test_notes_service_attachment_lookup_prefers_canonical_record_names(self): note_record = CKRecord.model_validate( { From 3dae36963455bf53bf8bc48f60b4169dc3bb7d7f Mon Sep 17 00:00:00 2001 From: mrjarnould Date: Tue, 31 Mar 2026 19:10:00 +0200 Subject: [PATCH 05/13] docs: use installed icloud CLI in README examples --- README.md | 52 ++++++++++++++++++++++++++-------------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index edef5673..83a6760f 100644 --- a/README.md +++ b/README.md @@ -1136,9 +1136,9 @@ flows. _List reminder lists and open reminders:_ ```bash -uv run icloud reminders lists --username you@example.com -uv run icloud reminders list --username you@example.com -uv run icloud reminders list --username you@example.com --list-id INBOX --include-completed +icloud reminders lists --username you@example.com +icloud reminders list --username you@example.com +icloud reminders list --username you@example.com --list-id INBOX --include-completed ``` `icloud reminders list` defaults to open reminders only. Use @@ -1148,36 +1148,36 @@ the query to one list. _Get, create, update, and delete reminders:_ ```bash -uv run icloud reminders get REMINDER_ID --username you@example.com -uv run icloud reminders create --username you@example.com --list-id INBOX --title "Buy milk" -uv run icloud reminders update REMINDER_ID --username you@example.com --title "Buy oat milk" -uv run icloud reminders set-status REMINDER_ID --username you@example.com --completed -uv run icloud reminders delete REMINDER_ID --username you@example.com +icloud reminders get REMINDER_ID --username you@example.com +icloud reminders create --username you@example.com --list-id INBOX --title "Buy milk" +icloud reminders update REMINDER_ID --username you@example.com --title "Buy oat milk" +icloud reminders set-status REMINDER_ID --username you@example.com --completed +icloud reminders delete REMINDER_ID --username you@example.com ``` _Inspect snapshots and incremental changes:_ ```bash -uv run icloud reminders snapshot --username you@example.com --list-id INBOX -uv run icloud reminders changes --username you@example.com --since PREVIOUS_CURSOR -uv run icloud reminders sync-cursor --username you@example.com +icloud reminders snapshot --username you@example.com --list-id INBOX +icloud reminders changes --username you@example.com --since PREVIOUS_CURSOR +icloud reminders sync-cursor --username you@example.com ``` _Work with reminder sub-records:_ ```bash -uv run icloud reminders alarm add-location REMINDER_ID \ +icloud reminders alarm add-location REMINDER_ID \ --username you@example.com \ --title "Office" \ --address "1 Infinite Loop, Cupertino, CA" \ --latitude 37.3318 \ --longitude -122.0312 -uv run icloud reminders hashtag create REMINDER_ID errands --username you@example.com -uv run icloud reminders attachment create-url REMINDER_ID \ +icloud reminders hashtag create REMINDER_ID errands --username you@example.com +icloud reminders attachment create-url REMINDER_ID \ --username you@example.com \ --url https://example.com/checklist -uv run icloud reminders recurrence create REMINDER_ID \ +icloud reminders recurrence create REMINDER_ID \ --username you@example.com \ --frequency weekly \ --interval 1 @@ -1322,17 +1322,17 @@ folder browsing, title-based search, HTML rendering, and note-id-based export. _List recent notes, folders, or one folder’s notes:_ ```bash -uv run icloud notes recent --username you@example.com -uv run icloud notes folders --username you@example.com -uv run icloud notes list --username you@example.com --folder-id FOLDER_ID -uv run icloud notes list --username you@example.com --all --since PREVIOUS_CURSOR +icloud notes recent --username you@example.com +icloud notes folders --username you@example.com +icloud notes list --username you@example.com --folder-id FOLDER_ID +icloud notes list --username you@example.com --all --since PREVIOUS_CURSOR ``` _Search notes by title:_ ```bash -uv run icloud notes search --username you@example.com --title "Daily Plan" -uv run icloud notes search --username you@example.com --title-contains "meeting" +icloud notes search --username you@example.com --title "Daily Plan" +icloud notes search --username you@example.com --title-contains "meeting" ``` `icloud notes search` is the official title-filter workflow. It uses a @@ -1341,9 +1341,9 @@ recents-first search strategy and falls back to a full feed scan when needed. _Fetch, render, and export one note by id:_ ```bash -uv run icloud notes get NOTE_ID --username you@example.com --with-attachments -uv run icloud notes render NOTE_ID --username you@example.com --preview-appearance dark -uv run icloud notes export NOTE_ID \ +icloud notes get NOTE_ID --username you@example.com --with-attachments +icloud notes render NOTE_ID --username you@example.com --preview-appearance dark +icloud notes export NOTE_ID \ --username you@example.com \ --output-dir ./exports/notes_html \ --export-mode archival \ @@ -1356,8 +1356,8 @@ handled by `icloud notes search` rather than by bulk export flags. _Inspect incremental changes:_ ```bash -uv run icloud notes changes --username you@example.com --since PREVIOUS_CURSOR -uv run icloud notes sync-cursor --username you@example.com +icloud notes changes --username you@example.com --since PREVIOUS_CURSOR +icloud notes sync-cursor --username you@example.com ``` ### Notes CLI Example From 32efc4796d44e5d022b911e45f155792c6d75473 Mon Sep 17 00:00:00 2001 From: mrjarnould Date: Tue, 31 Mar 2026 19:20:49 +0200 Subject: [PATCH 06/13] fix: address Notes and Reminders CLI review nits --- README.md | 2 +- pyicloud/cli/commands/reminders.py | 10 +++++++--- pyicloud/cli/normalize.py | 4 +++- tests/test_cmdline.py | 22 ++++++++++++++++++++++ 4 files changed, 33 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 83a6760f..db8cd416 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,7 @@ api.devices The `icloud` command line interface is organized around top-level subcommands such as `auth`, `account`, `devices`, `calendar`, -`contacts`, `drive`, `photos`, and `hidemyemail`. +`contacts`, `drive`, `photos`, `hidemyemail`, `notes`, and `reminders`. Command options belong on the final command that uses them. For example: diff --git a/pyicloud/cli/commands/reminders.py b/pyicloud/cli/commands/reminders.py index a9df0805..6020bff0 100644 --- a/pyicloud/cli/commands/reminders.py +++ b/pyicloud/cli/commands/reminders.py @@ -22,6 +22,7 @@ store_command_options, ) from pyicloud.cli.output import console_kv_table, console_table, format_color_value +from pyicloud.services.reminders.client import RemindersApiError, RemindersAuthError from pyicloud.services.reminders.models import ( AlarmWithTrigger, ImageAttachment, @@ -133,9 +134,12 @@ def _reminders_call(api, fn): try: return service_call(REMINDERS, fn, account_name=api.account_name) - except LookupError as err: - raise CLIAbort(str(err)) from err - except ValidationError as err: + except ( + LookupError, + ValidationError, + RemindersApiError, + RemindersAuthError, + ) as err: raise CLIAbort(str(err)) from err diff --git a/pyicloud/cli/normalize.py b/pyicloud/cli/normalize.py index 4e8209b4..c087227a 100644 --- a/pyicloud/cli/normalize.py +++ b/pyicloud/cli/normalize.py @@ -5,6 +5,8 @@ from datetime import datetime, timezone from typing import Any +MAX_NOTES_SEARCH_WINDOW = 5_000 + def normalize_account_summary(api, account) -> dict[str, Any]: """Normalize account summary data.""" @@ -232,7 +234,7 @@ def dedupe_key(item: Any) -> Any: candidates: list[Any] = [] seen: set[Any] = set() - window = max(500, limit * 50) + window = min(MAX_NOTES_SEARCH_WINDOW, max(500, limit * 50)) for note in notes_service.recents(limit=window): if not matches(getattr(note, "title", None)): diff --git a/tests/test_cmdline.py b/tests/test_cmdline.py index 8a1cf75c..6152916d 100644 --- a/tests/test_cmdline.py +++ b/tests/test_cmdline.py @@ -24,6 +24,7 @@ NoteSummary, ) from pyicloud.services.notes.service import NoteLockedError, NoteNotFound +from pyicloud.services.reminders.client import RemindersApiError, RemindersAuthError from pyicloud.services.reminders.models import ( Alarm, AlarmWithTrigger, @@ -296,6 +297,7 @@ def __init__(self) -> None: is_locked=False, ), self.recent_rows[1], + # Duplicate entry to verify deduplication in search_notes_by_title. self.recent_rows[2], ] self.notes = { @@ -2947,6 +2949,26 @@ def test_reminders_commands_report_errors() -> None: "No recurrence updates were requested." ) + class ApiErrorReminders: + def sync_cursor(self) -> str: + raise RemindersApiError("sync failed") + + class AuthErrorReminders: + def sync_cursor(self) -> str: + raise RemindersAuthError("token expired") + + fake_api = FakeAPI() + fake_api.reminders = ApiErrorReminders() + api_error_result = _invoke(fake_api, "reminders", "sync-cursor") + assert api_error_result.exit_code != 0 + assert api_error_result.exception.args[0] == "sync failed" + + fake_api = FakeAPI() + fake_api.reminders = AuthErrorReminders() + auth_error_result = _invoke(fake_api, "reminders", "sync-cursor") + assert auth_error_result.exit_code != 0 + assert auth_error_result.exception.args[0] == "token expired" + def test_reminders_commands_report_reauthentication_and_unavailability() -> None: """Reminders commands should wrap service reauth and service-unavailable failures.""" From 8c453922d7ebacd0176dd579f071641bd23259d7 Mon Sep 17 00:00:00 2001 From: Tim Laing <11019084+timlaing@users.noreply.github.com> Date: Fri, 3 Apr 2026 18:01:41 +0100 Subject: [PATCH 07/13] Update permissions for autolabeler workflow Signed-off-by: Tim Laing <11019084+timlaing@users.noreply.github.com> --- .github/workflows/autolabeler.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/autolabeler.yml b/.github/workflows/autolabeler.yml index eec213c3..74b9cc86 100644 --- a/.github/workflows/autolabeler.yml +++ b/.github/workflows/autolabeler.yml @@ -5,11 +5,11 @@ on: permissions: contents: read - + pull-requests: write + issues: write + jobs: auto_label: - permissions: - pull-requests: write runs-on: ubuntu-latest steps: # runs autolabeler From 0b2424cb4125c42a9585338b389f23058396c724 Mon Sep 17 00:00:00 2001 From: Tim Laing <11019084+timlaing@users.noreply.github.com> Date: Fri, 3 Apr 2026 17:05:07 +0000 Subject: [PATCH 08/13] Fix formatting in autolabeler workflow by removing unnecessary blank line --- .github/workflows/autolabeler.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/autolabeler.yml b/.github/workflows/autolabeler.yml index 74b9cc86..e754d6d5 100644 --- a/.github/workflows/autolabeler.yml +++ b/.github/workflows/autolabeler.yml @@ -7,7 +7,7 @@ permissions: contents: read pull-requests: write issues: write - + jobs: auto_label: runs-on: ubuntu-latest From ec01e9394b584e1d1b33e02742a87778581e78e3 Mon Sep 17 00:00:00 2001 From: Jacob Arnould Date: Fri, 3 Apr 2026 19:05:32 +0200 Subject: [PATCH 09/13] feat: add a root --version CLI flag (#214) Co-authored-by: Tim Laing <11019084+timlaing@users.noreply.github.com> --- pyicloud/cli/app.py | 33 +++++++++++++++++++++++++++++++++ tests/test_cmdline.py | 10 ++++++++++ 2 files changed, 43 insertions(+) diff --git a/pyicloud/cli/app.py b/pyicloud/cli/app.py index c478b00e..b8626327 100644 --- a/pyicloud/cli/app.py +++ b/pyicloud/cli/app.py @@ -2,6 +2,9 @@ from __future__ import annotations +from importlib.metadata import PackageNotFoundError +from importlib.metadata import version as package_version + import typer from pyicloud.cli.commands.account import app as account_app @@ -23,6 +26,23 @@ ) +def _installed_version() -> str: + """Return the installed pyicloud package version.""" + + try: + return package_version("pyicloud") + except PackageNotFoundError: + return "unknown" + + +def _version_callback(value: bool) -> None: + """Print the installed pyicloud version and exit.""" + + if value: + typer.echo(_installed_version()) + raise typer.Exit() + + def _group_root(ctx: typer.Context) -> None: """Show mounted group help when invoked without a subcommand.""" @@ -31,6 +51,19 @@ def _group_root(ctx: typer.Context) -> None: raise typer.Exit() +@app.callback() +def root_callback( + version: bool = typer.Option( + False, + "--version", + help="Show the installed pyicloud version and exit.", + callback=_version_callback, + is_eager=True, + ), +) -> None: + """Handle root CLI options before subcommand dispatch.""" + + app.add_typer( account_app, name="account", invoke_without_command=True, callback=_group_root ) diff --git a/tests/test_cmdline.py b/tests/test_cmdline.py index 6152916d..64dbc87a 100644 --- a/tests/test_cmdline.py +++ b/tests/test_cmdline.py @@ -1164,6 +1164,16 @@ def test_root_help() -> None: assert command in text +def test_root_version_prints_installed_package_version() -> None: + """The root --version flag should print the installed pyicloud version.""" + + with patch.object(cli_module, "_installed_version", return_value="9.9.9"): + result = _runner().invoke(app, ["--version"]) + + assert result.exit_code == 0 + assert result.stdout.strip() == "9.9.9" + + def test_group_help() -> None: """Each command group should expose help.""" From b230450accc0e9b28d6e182d6115753481689054 Mon Sep 17 00:00:00 2001 From: Jacob Arnould Date: Fri, 3 Apr 2026 19:07:41 +0200 Subject: [PATCH 10/13] fix: restore SMS and trusted-device 2FA auth flows (#210) * feat: handle Apple's HSA2 trusted-device prompts * Trim unrelated Notes PR scope * Address CodeRabbit review comments * Add docstrings for auth bridge PR scope * Harden bridge prover and persistence tests --- README.md | 6 + pyicloud/base.py | 350 ++++++- pyicloud/cli/context.py | 59 +- pyicloud/exceptions.py | 13 + pyicloud/hsa2_bridge.py | 1690 ++++++++++++++++++++++++++++++++ pyicloud/hsa2_bridge_prover.py | 582 +++++++++++ pyicloud/session.py | 114 ++- requirements.txt | 1 + tests/test_base.py | 517 ++++++++++ tests/test_cmdline.py | 191 +++- tests/test_hsa2_bridge.py | 1357 +++++++++++++++++++++++++ 11 files changed, 4851 insertions(+), 29 deletions(-) create mode 100644 pyicloud/hsa2_bridge.py create mode 100644 pyicloud/hsa2_bridge_prover.py create mode 100644 tests/test_hsa2_bridge.py diff --git a/README.md b/README.md index db8cd416..916fef73 100644 --- a/README.md +++ b/README.md @@ -182,6 +182,11 @@ If you have enabled two-factor authentications (2FA) or [two-step authentication (2SA)](https://support.apple.com/en-us/HT204152) for the account you will have to do some extra work: +For HSA2 accounts, `request_2fa_code()` now starts Apple's active delivery +route for code-based challenges. Depending on the account and session, that may +be a trusted-device prompt or an SMS code. Security-key challenges are handled +separately via `security_key_names` / `confirm_security_key()`. + ```python import sys @@ -216,6 +221,7 @@ if api.requires_2fa: else: print("Two-factor authentication required.") + api.request_2fa_code() code = input( "Enter the code you received of one of your approved devices: " ) diff --git a/pyicloud/base.py b/pyicloud/base.py index 47897c70..572fdda6 100644 --- a/pyicloud/base.py +++ b/pyicloud/base.py @@ -5,9 +5,10 @@ import json import logging import time +from dataclasses import dataclass from os import chmod, environ, makedirs, path, umask from tempfile import gettempdir -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Mapping, Optional from uuid import uuid1 import srp @@ -34,6 +35,14 @@ PyiCloudPasswordException, PyiCloudServiceNotActivatedException, PyiCloudServiceUnavailable, + PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, +) +from pyicloud.hsa2_bridge import ( + Hsa2BootContext, + TrustedDeviceBridgeBootstrapper, + TrustedDeviceBridgeState, + parse_boot_args_html, ) from pyicloud.services import ( AccountService, @@ -106,6 +115,92 @@ def resolve_cookie_directory(cookie_directory: Optional[str] = None) -> str: return path.join(topdir, getpass.getuser()) +@dataclass(frozen=True) +class TrustedPhoneNumber: + """Typed view of Apple's trusted-phone metadata.""" + + device_id: int | str + non_fteu: Optional[bool] = None + push_mode: Optional[str] = None + + @classmethod + def from_mapping( + cls, value: Optional[Mapping[str, Any]] + ) -> Optional["TrustedPhoneNumber"]: + """Return a typed phone record when Apple's payload includes one.""" + + if not isinstance(value, Mapping): + return None + device_id = value.get("id") + if not isinstance(device_id, (int, str)): + return None + + non_fteu = value.get("nonFTEU") + if not isinstance(non_fteu, bool): + non_fteu = None + + push_mode = value.get("pushMode") + if push_mode is not None: + push_mode = str(push_mode) + + return cls( + device_id=device_id, + non_fteu=non_fteu, + push_mode=push_mode, + ) + + def as_phone_number_payload(self) -> dict[str, Any]: + """Return the nested phoneNumber payload expected by Apple's SMS endpoints.""" + + payload: dict[str, Any] = {"id": self.device_id} + if self.non_fteu is not None: + payload["nonFTEU"] = self.non_fteu + return payload + + +@dataclass(frozen=True) +class PhoneNumberVerification: + """Typed view of Apple's phone verification wrapper payload.""" + + trusted_phone_number: Optional[TrustedPhoneNumber] = None + trusted_phone_numbers: tuple[TrustedPhoneNumber, ...] = () + + @classmethod + def from_mapping( + cls, value: Optional[Mapping[str, Any]] + ) -> "PhoneNumberVerification": + """Return the parsed phone verification payload when Apple exposes one.""" + + if not isinstance(value, Mapping): + return cls() + + trusted_phone_number = TrustedPhoneNumber.from_mapping( + value.get("trustedPhoneNumber") + ) + + trusted_phone_numbers_raw = value.get("trustedPhoneNumbers") + trusted_phone_numbers: list[TrustedPhoneNumber] = [] + if isinstance(trusted_phone_numbers_raw, list): + for entry in trusted_phone_numbers_raw: + phone_number = TrustedPhoneNumber.from_mapping(entry) + if phone_number is not None: + trusted_phone_numbers.append(phone_number) + + return cls( + trusted_phone_number=trusted_phone_number, + trusted_phone_numbers=tuple(trusted_phone_numbers), + ) + + def best_trusted_phone_number(self) -> Optional[TrustedPhoneNumber]: + """Return the first usable trusted phone number from Apple's payload.""" + + if self.trusted_phone_number is not None: + return self.trusted_phone_number + if self.trusted_phone_numbers: + return self.trusted_phone_numbers[0] + return None + + class PyiCloudService: """ A base authentication class for the iCloud service. Handles the @@ -157,6 +252,7 @@ def __init__( authenticate: bool = True, cloudkit_validation_extra: Optional[CloudKitExtraMode] = None, ) -> None: + """Initialize a service session for one Apple ID account.""" self._is_china_mainland: bool = ( environ.get("icloud_china", "0") == "1" if china_mainland is None @@ -175,6 +271,11 @@ def __init__( self.data: dict[str, Any] = {} self._auth_data: dict[str, Any] = {} + self._hsa2_boot_context: Optional[Hsa2BootContext] = None + self._trusted_device_bridge_state: Optional[TrustedDeviceBridgeState] = None + self._trusted_device_bridge = TrustedDeviceBridgeBootstrapper() + self._two_factor_delivery_method: str = "unknown" + self._two_factor_delivery_notice: Optional[str] = None self.params: dict[str, Any] = {} self._client_id: str = client_id or str(uuid1()).lower() @@ -326,6 +427,10 @@ def _clear_authenticated_state(self) -> None: self.data = {} self._auth_data = {} + self._hsa2_boot_context = None + self._clear_trusted_device_bridge_state() + self._two_factor_delivery_method = "unknown" + self._two_factor_delivery_notice = None self._webservices = None self._account = None self._calendar = None @@ -339,6 +444,13 @@ def _clear_authenticated_state(self) -> None: self._requires_mfa = False self.params.pop("dsid", None) + def _clear_trusted_device_bridge_state(self) -> None: + """Close any active trusted-device bridge session and clear in-memory state.""" + + if self._trusted_device_bridge_state is not None: + self._trusted_device_bridge.close(self._trusted_device_bridge_state) + self._trusted_device_bridge_state = None + def get_auth_status(self) -> dict[str, Any]: """Probe current authentication state without prompting for login.""" @@ -416,6 +528,7 @@ def logout( } def _authenticate(self) -> None: + """Authenticate with either the cached session token or fresh credentials.""" LOGGER.debug("Authenticating as %s", self.account_name) try: @@ -542,6 +655,11 @@ def _authenticate_with_token(self) -> None: if not self.is_trusted_session: raise PyiCloud2FARequiredException(self.account_name, resp) + + self._auth_data = {} + self._hsa2_boot_context = None + self._clear_trusted_device_bridge_state() + self._set_two_factor_delivery_state("unknown") except (PyiCloudAPIResponseException, HTTPError) as error: msg = "Invalid authentication token." raise PyiCloudFailedLoginException(msg, error) from error @@ -584,6 +702,7 @@ def _validate_token(self) -> Any: def _get_auth_headers( self, overrides: Optional[dict[str, Any]] = None ) -> dict[str, Any]: + """Build Apple auth headers for IDMS, bridge, and verification requests.""" headers: dict[str, Any] = _AUTH_HEADERS_JSON.copy() headers.update( { @@ -612,6 +731,7 @@ def session(self) -> PyiCloudSession: return self._session def _is_mfa_required(self) -> bool: + """Return whether the current auth state still requires MFA completion.""" return ( self.data.get("hsaChallengeRequired", False) or not self.is_trusted_session @@ -679,9 +799,111 @@ def validate_verification_code(self, device: dict[str, Any], code: str) -> bool: def _get_mfa_auth_options(self) -> Dict: """Retrieve auth request options for assertion.""" + # Apple exposes the HSA2 bridge bootstrap in the HTML auth shell. + # Requesting JSON here tends to collapse the response to the SMS-oriented shape. + headers = self._get_auth_headers({"Accept": "text/html"}) + response = self.session.get(self._auth_endpoint, headers=headers) + + auth_options: dict[str, Any] = {} + try: + response_json = response.json() + except (AttributeError, TypeError, ValueError): + response_json = None + + if isinstance(response_json, dict): + auth_options.update(response_json) + boot_context = Hsa2BootContext.from_auth_options(auth_options) + else: + boot_context = parse_boot_args_html(getattr(response, "text", "")) + + boot_auth_data = boot_context.as_auth_data() + auth_options.update(boot_auth_data) + self._hsa2_boot_context = boot_context + self._clear_trusted_device_bridge_state() + self._set_two_factor_delivery_state("unknown") + return auth_options + + def _set_two_factor_delivery_state( + self, method: str, notice: Optional[str] = None + ) -> None: + """Track the active MFA delivery route for the current auth challenge.""" + + self._two_factor_delivery_method = method + self._two_factor_delivery_notice = notice + + def _current_hsa2_boot_context(self) -> Hsa2BootContext: + """Return the best available HSA2 boot context for the active challenge.""" + + if self._hsa2_boot_context is not None: + return self._hsa2_boot_context + + boot_context = Hsa2BootContext.from_auth_options(self._auth_data) + self._hsa2_boot_context = boot_context + return boot_context + + def _supports_trusted_device_bridge(self) -> bool: + """Return whether Apple's HSA2 boot data prefers the bridge flow.""" + + boot_context = self._current_hsa2_boot_context() + return ( + boot_context.auth_initial_route == "auth/bridge/step" + and boot_context.has_trusted_devices + and bool(boot_context.bridge_initiate_data) + ) + + def _can_request_sms_2fa_code(self) -> bool: + """Return whether SMS delivery is currently available.""" + + return ( + self._two_factor_mode() == "sms" + and self._trusted_phone_number() is not None + ) + + def _request_sms_2fa_code(self, notice: Optional[str] = None) -> bool: + """Trigger SMS delivery for the current HSA2 challenge.""" + + trusted_phone_number = self._trusted_phone_number() + if not trusted_phone_number: + raise PyiCloudNoTrustedNumberAvailable() + + data: dict[str, Any] = { + "phoneNumber": trusted_phone_number.as_phone_number_payload(), + "mode": "sms", + } headers = self._get_auth_headers({"Accept": CONTENT_TYPE_JSON}) - return self.session.get(self._auth_endpoint, headers=headers).json() + self.session.put( + f"{self._auth_endpoint}/verify/phone", + json=data, + headers=headers, + ) + self._clear_trusted_device_bridge_state() + self._set_two_factor_delivery_state("sms", notice) + return True + + @property + def two_factor_delivery_method(self) -> str: + """Return the current HSA2 delivery method without exposing auth internals.""" + + if self._two_factor_delivery_method != "unknown": + return self._two_factor_delivery_method + + if self._auth_data.get("fsaChallenge") or self.security_key_names: + return "security_key" + + if self._supports_trusted_device_bridge(): + return "trusted_device" + + if self._two_factor_mode() == "sms": + return "sms" + + return "unknown" + + @property + def two_factor_delivery_notice(self) -> Optional[str]: + """Return an optional user-facing note about the active 2FA delivery path.""" + + return self._two_factor_delivery_notice @property def security_key_names(self) -> Optional[List[str]]: @@ -696,6 +918,74 @@ def _submit_webauthn_assertion_response(self, data: Dict) -> None: f"{self._auth_endpoint}/verify/security/key", json=data, headers=headers ) + def _phone_number_verification(self) -> PhoneNumberVerification: + """Return Apple's nested phone verification payload when present.""" + + phone_verification = self._auth_data.get("phoneNumberVerification") + return PhoneNumberVerification.from_mapping(phone_verification) + + def _trusted_phone_number(self) -> Optional[TrustedPhoneNumber]: + """Return the best available trusted phone number description.""" + + trusted_phone_number = TrustedPhoneNumber.from_mapping( + self._auth_data.get("trustedPhoneNumber") + ) + if trusted_phone_number is not None: + return trusted_phone_number + + return self._phone_number_verification().best_trusted_phone_number() + + def _two_factor_mode(self) -> Optional[str]: + """Return the current 2FA delivery mode reported by Apple.""" + + mode = self._auth_data.get("mode") + if isinstance(mode, str): + return mode + + trusted_phone_number = self._trusted_phone_number() + if trusted_phone_number is None: + return None + + return trusted_phone_number.push_mode + + def request_2fa_code(self) -> bool: + """Trigger the active HSA2 delivery route for the current challenge.""" + + if self._auth_data.get("fsaChallenge") or self.security_key_names: + self._set_two_factor_delivery_state("security_key") + return False + + self._clear_trusted_device_bridge_state() + + if self._supports_trusted_device_bridge(): + try: + self._trusted_device_bridge_state = self._trusted_device_bridge.start( + session=self.session, + auth_endpoint=self._auth_endpoint, + headers=self._get_auth_headers({"Accept": CONTENT_TYPE_JSON}), + boot_context=self._current_hsa2_boot_context(), + user_agent=self.session.headers.get( + "User-Agent", _HEADERS["User-Agent"] + ), + ) + self._set_two_factor_delivery_state("trusted_device") + return True + except PyiCloudTrustedDevicePromptException: + LOGGER.debug( + "Trusted-device bridge bootstrap failed; falling back to SMS when available.", + exc_info=True, + ) + if self._can_request_sms_2fa_code(): + return self._request_sms_2fa_code( + notice="Trusted-device prompt failed; falling back to SMS." + ) + raise + + if self._can_request_sms_2fa_code(): + return self._request_sms_2fa_code() + + return False + @property def fido2_devices(self) -> List[CtapHidDevice]: """List the available FIDO2 devices.""" @@ -831,45 +1121,61 @@ def _request_pcs_for_service(self, app_name: str) -> None: def validate_2fa_code(self, code: str) -> bool: """Verifies a verification code received via Apple's 2FA system (HSA2).""" + bridge_state = self._trusted_device_bridge_state try: - if self._auth_data.get("mode") == "sms": + if self.two_factor_delivery_method == "sms": self._validate_sms_code(code) + elif ( + bridge_state is not None + and not bridge_state.uses_legacy_trusted_device_verifier + ): + if not self._trusted_device_bridge.validate_code( + session=self.session, + auth_endpoint=self._auth_endpoint, + headers=self._get_auth_headers({"Accept": CONTENT_TYPE_JSON}), + bridge_state=bridge_state, + code=code, + ): + LOGGER.error("Code verification failed.") + return False else: - data: dict[str, Any] = {"securityCode": {"code": code}} - headers: dict[str, Any] = self._get_auth_headers( - {"Accept": CONTENT_TYPE_JSON} - ) - self.session.post( - f"{self._auth_endpoint}/verify/trusteddevice/securitycode", - json=data, - headers=headers, - ) + self._validate_trusted_device_code(code) + except PyiCloudTrustedDeviceVerificationException: + raise except PyiCloudAPIResponseException: # Wrong verification code LOGGER.error("Code verification failed.") return False + finally: + if bridge_state is not None: + self._clear_trusted_device_bridge_state() LOGGER.debug("Code verification successful.") self.trust_session() return not self.requires_2sa + def _validate_trusted_device_code(self, code: str) -> None: + """Verifies a verification code received via Apple's legacy device endpoint.""" + + data: dict[str, Any] = {"securityCode": {"code": code}} + headers: dict[str, Any] = self._get_auth_headers({"Accept": CONTENT_TYPE_JSON}) + self.session.post( + f"{self._auth_endpoint}/verify/trusteddevice/securitycode", + json=data, + headers=headers, + ) + def _validate_sms_code(self, code: str) -> None: """Verifies a verification code received via Apple's SMS system.""" - trusted_phone_number: dict[str, Any] | None = self._auth_data.get( - "trustedPhoneNumber" - ) + trusted_phone_number = self._trusted_phone_number() if not trusted_phone_number: raise PyiCloudNoTrustedNumberAvailable() - device_id: int | None = trusted_phone_number.get("id") - non_fteu: bool | None = trusted_phone_number.get("nonFTEU") - mode: str | None = trusted_phone_number.get("pushMode") - data: dict[str, Any] = { - "phoneNumber": {"id": device_id, "nonFTEU": non_fteu}, + "phoneNumber": trusted_phone_number.as_phone_number_payload(), "securityCode": {"code": code}, - "mode": mode, + "mode": trusted_phone_number.push_mode or "sms", } headers: dict[str, Any] = self._get_auth_headers( {"Accept": f"{CONTENT_TYPE_JSON}, {CONTENT_TYPE_TEXT}"} @@ -1112,7 +1418,9 @@ def account_name(self) -> str: return self._apple_id def __str__(self) -> str: + """Return a concise human-readable service description.""" return f"iCloud API: {self.account_name}" def __repr__(self) -> str: + """Mirror ``__str__`` for interactive inspection.""" return f"<{self}>" diff --git a/pyicloud/cli/context.py b/pyicloud/cli/context.py index ab22fccb..84d29791 100644 --- a/pyicloud/cli/context.py +++ b/pyicloud/cli/context.py @@ -17,9 +17,13 @@ from pyicloud import PyiCloudService, utils from pyicloud.base import resolve_cookie_directory from pyicloud.exceptions import ( + PyiCloudAPIResponseException, PyiCloudAuthRequiredException, PyiCloudFailedLoginException, + PyiCloudNoTrustedNumberAvailable, PyiCloudServiceUnavailable, + PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, ) from pyicloud.ssl_context import configurable_ssl_verification @@ -95,6 +99,7 @@ def __init__( log_level: LogLevel, output_format: OutputFormat, ) -> None: + """Capture the CLI options and shared runtime state for one invocation.""" self.username = (username or "").strip() self.password = password self.china_mainland = china_mainland @@ -231,6 +236,7 @@ def remember_account(self, api: PyiCloudService, *, select: bool = True) -> None self._resolved_username = api.account_name def _resolve_username(self) -> str: + """Resolve the Apple ID to use for the current CLI command.""" if self._resolved_username: return self._resolved_username @@ -276,6 +282,7 @@ def multiple_logged_in_accounts_message(usernames: list[str]) -> str: ) def _password_for_login(self, username: str) -> tuple[Optional[str], Optional[str]]: + """Return the password and its source for an interactive login flow.""" if self.password: return self.password, "explicit" @@ -289,6 +296,7 @@ def _password_for_login(self, username: str) -> tuple[Optional[str], Optional[st return utils.get_password(username, interactive=True), "prompt" def _configure_logging(self) -> None: + """Apply the requested log level once for the current CLI process.""" if self._logging_configured: return logging.basicConfig(level=self.log_level.logging_level()) @@ -302,6 +310,7 @@ def _stored_password_for_session(self, username: str) -> Optional[str]: return utils.get_password_from_keyring(username) def _prompt_index(self, prompt: str, count: int) -> int: + """Prompt for a zero-based selection index when multiple choices exist.""" if count <= 1 or not self.interactive: return 0 raw = typer.prompt(prompt, default="0") @@ -314,6 +323,7 @@ def _prompt_index(self, prompt: str, count: int) -> int: return idx def _handle_2fa(self, api: PyiCloudService) -> None: + """Complete Apple's HSA2 flow using a security key or code-based challenge.""" fido2_devices = list(getattr(api, "fido2_devices", []) or []) if fido2_devices: self.console.print("Security key verification required.") @@ -332,13 +342,56 @@ def _handle_2fa(self, api: PyiCloudService) -> None: raise CLIAbort( "Two-factor authentication is required, but interactive prompts are disabled." ) - code = typer.prompt("Enter 2FA code") - if not api.validate_2fa_code(code): - raise CLIAbort("Failed to verify the 2FA code.") + try: + if not api.request_2fa_code(): + raise CLIAbort( + "This 2FA challenge requires a security key. Connect one and retry." + ) + + notice = getattr(api, "two_factor_delivery_notice", None) + if notice: + self.console.print(notice) + + delivery_method = getattr(api, "two_factor_delivery_method", "unknown") + if delivery_method == "trusted_device": + self.console.print( + "Requested a 2FA prompt on your trusted Apple devices." + ) + elif delivery_method == "sms": + self.console.print("Requested a 2FA code by SMS.") + except PyiCloudNoTrustedNumberAvailable as exc: + raise CLIAbort( + "Two-factor authentication requires a trusted phone number, " + "but none was returned." + ) from exc + except PyiCloudTrustedDevicePromptException as exc: + raise CLIAbort( + "Failed to request the 2FA trusted-device prompt." + ) from exc + except PyiCloudAPIResponseException as exc: + raise CLIAbort("Failed to request the 2FA SMS code.") from exc + max_attempts = 3 + for attempt in range(max_attempts): + code = typer.prompt("Enter 2FA code") + try: + is_valid = api.validate_2fa_code(code) + except PyiCloudTrustedDeviceVerificationException as exc: + raise CLIAbort( + "Failed to verify the 2FA trusted-device code." + ) from exc + if is_valid: + break + remaining_attempts = max_attempts - attempt - 1 + if remaining_attempts <= 0: + raise CLIAbort("Failed to verify the 2FA code.") + self.console.print( + f"Invalid 2FA code. {remaining_attempts} attempt(s) remaining." + ) if not api.is_trusted_session: api.trust_session() def _handle_2sa(self, api: PyiCloudService) -> None: + """Complete Apple's legacy two-step authentication flow.""" devices = list(api.trusted_devices or []) if not devices: raise CLIAbort( diff --git a/pyicloud/exceptions.py b/pyicloud/exceptions.py index 57f2a7de..b0772743 100644 --- a/pyicloud/exceptions.py +++ b/pyicloud/exceptions.py @@ -31,6 +31,7 @@ def __init__( code: Optional[Union[int, str]] = None, response: Optional[Response] = None, ) -> None: + """Capture a normalized API error and the optional HTTP context.""" self.reason: str = reason self.code: Optional[Union[int, str]] = code self.response: Optional[Response] = response @@ -58,6 +59,7 @@ def __init__( *args, response: Optional[Response] = None, ) -> None: + """Initialize a login failure with optional HTTP response details.""" self.response: Optional[Response] = response message: str = msg or "Failed login to iCloud" if response is not None and response.text: @@ -73,6 +75,7 @@ class PyiCloud2FARequiredException(PyiCloudException): """iCloud 2FA required exception.""" def __init__(self, apple_id: str, response: Response) -> None: + """Initialize a 2FA-required error for an HSA2 login challenge.""" message: str = f"2FA authentication required for account: {apple_id} (HSA2)" super().__init__(message) self.response: Response = response @@ -82,6 +85,7 @@ class PyiCloud2SARequiredException(PyiCloudException): """iCloud 2SA required exception.""" def __init__(self, apple_id: str) -> None: + """Initialize a 2SA-required error for a legacy login challenge.""" message: str = f"Two-step authentication required for account: {apple_id}" super().__init__(message) @@ -90,6 +94,7 @@ class PyiCloudAuthRequiredException(PyiCloudException): """iCloud re-authentication required exception.""" def __init__(self, apple_id: str, response: Response) -> None: + """Initialize a reauthentication-required error with the triggering response.""" message: str = f"Re-authentication required for account: {apple_id}" super().__init__(message) self.response: Response = response @@ -99,6 +104,14 @@ class PyiCloudNoTrustedNumberAvailable(PyiCloudException): """iCloud no trusted number exception.""" +class PyiCloudTrustedDevicePromptException(PyiCloudAPIResponseException): + """Trusted-device prompt bootstrap exception.""" + + +class PyiCloudTrustedDeviceVerificationException(PyiCloudAPIResponseException): + """Trusted-device bridge verification exception.""" + + class PyiCloudNoStoredPasswordAvailableException(PyiCloudException): """iCloud no stored password exception.""" diff --git a/pyicloud/hsa2_bridge.py b/pyicloud/hsa2_bridge.py new file mode 100644 index 00000000..d9b8209d --- /dev/null +++ b/pyicloud/hsa2_bridge.py @@ -0,0 +1,1690 @@ +"""Internal helpers for Apple's HSA2 trusted-device bridge flow.""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import logging +import os +import socket +import ssl +import struct +import time +import uuid +from binascii import Error as BinasciiError +from dataclasses import dataclass, field +from html.parser import HTMLParser +from typing import Any, Callable, Mapping, Optional, Protocol +from urllib.parse import urlparse + +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec +from pydantic import ( + BaseModel, + ConfigDict, + Field, + StrictInt, + StrictStr, + ValidationError, + field_validator, +) + +from pyicloud.exceptions import ( + PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, +) +from pyicloud.hsa2_bridge_prover import TrustedDeviceBridgeProver + +LOGGER = logging.getLogger(__name__) + +BRIDGE_STEP_PATH = "/bridge/step/0" +BRIDGE_STEP_PATH_TEMPLATE = "/bridge/step/{step}" +BRIDGE_CODE_VALIDATE_PATH = "/bridge/code/validate" +NEW_CONNECTION_EXPIRATION_SECONDS = 86400 +OPCODE_BINARY = 0x2 +OPCODE_CLOSE = 0x8 +OPCODE_PING = 0x9 +OPCODE_PONG = 0xA +SERVER_MESSAGE_CONNECTION_RESPONSE = 1 +SERVER_MESSAGE_PUSH = 2 +SERVER_MESSAGE_CHANNEL_SUBSCRIPTION_RESPONSE = 3 +SERVER_MESSAGE_PUSH_ACK = 7 +STATUS_OK = 0 +STATUS_INVALID_NONCE = 2 +BRIDGE_SIGNATURE_PREFIX = b"\x01\x03" +BRIDGE_DONE_DATA_B64 = base64.b64encode(b"done").decode("ascii") +WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" +WEBSOCKET_TIMEOUT_SECONDS = 30.0 +WEBSOCKET_ENVIRONMENT_HOSTS: dict[str, str] = { + "prod": "websocket.push.apple.com", + "sandbox": "websocket.sandbox.push.apple.com", +} +HTTP_STATUS_OK = 200 +HTTP_STATUS_NO_CONTENT = 204 +HTTP_STATUS_CONFLICT = 409 +HTTP_STATUS_PRECONDITION_FAILED = 412 + + +@dataclass(frozen=True) +class Hsa2BootContext: + """Bridge-related HSA2 boot data parsed from Apple's HTML bootstrap.""" + + auth_initial_route: str = "" + has_trusted_devices: bool = False + auth_factors: tuple[str, ...] = () + bridge_initiate_data: dict[str, Any] = field(default_factory=dict) + phone_number_verification: dict[str, Any] = field(default_factory=dict) + source_app_id: Optional[str] = None + + @classmethod + def from_auth_options(cls, auth_options: Mapping[str, Any]) -> "Hsa2BootContext": + """Build a normalized boot context from Apple's auth-options payload.""" + bridge_initiate_data = auth_options.get("bridgeInitiateData") + if not isinstance(bridge_initiate_data, dict): + bridge_initiate_data = {} + + phone_number_verification = auth_options.get("phoneNumberVerification") + if not isinstance(phone_number_verification, dict): + phone_number_verification = bridge_initiate_data.get( + "phoneNumberVerification" + ) + if not isinstance(phone_number_verification, dict): + phone_number_verification = {} + + auth_factors = auth_options.get("authFactors") + if not isinstance(auth_factors, list): + auth_factors = [] + + source_app_id = auth_options.get("sourceAppId") + if source_app_id is not None: + source_app_id = str(source_app_id) + + return cls( + auth_initial_route=str(auth_options.get("authInitialRoute") or ""), + has_trusted_devices=bool(auth_options.get("hasTrustedDevices")), + auth_factors=tuple( + factor for factor in auth_factors if isinstance(factor, str) + ), + bridge_initiate_data=dict(bridge_initiate_data), + phone_number_verification=dict(phone_number_verification), + source_app_id=source_app_id, + ) + + def as_auth_data(self) -> dict[str, Any]: + """Return parsed boot data in the shape expected by the auth flow.""" + + auth_data: dict[str, Any] = { + "authInitialRoute": self.auth_initial_route, + "hasTrustedDevices": self.has_trusted_devices, + "authFactors": list(self.auth_factors), + } + if self.bridge_initiate_data: + auth_data["bridgeInitiateData"] = dict(self.bridge_initiate_data) + if self.phone_number_verification: + auth_data["phoneNumberVerification"] = dict(self.phone_number_verification) + trusted_phone_number = self.phone_number_verification.get( + "trustedPhoneNumber" + ) + if isinstance(trusted_phone_number, dict): + auth_data["trustedPhoneNumber"] = dict(trusted_phone_number) + if self.source_app_id is not None: + auth_data["sourceAppId"] = self.source_app_id + return auth_data + + +class _BridgePushPayloadModel(BaseModel): + """Strict validator for Apple's bridge push JSON envelope.""" + + model_config = ConfigDict( + extra="allow", + populate_by_name=True, + arbitrary_types_allowed=True, + ) + + session_uuid: StrictStr = Field(alias="sessionUUID") + next_step: Optional[StrictStr | StrictInt] = Field(default=None, alias="nextStep") + rui_url_key: Optional[str] = Field(default=None, alias="ruiURLKey") + txnid: Optional[StrictStr] = None + salt: Optional[StrictStr] = None + mid: Optional[StrictStr] = None + idmsdata: Optional[StrictStr] = None + akdata: Any = None + data: Optional[StrictStr] = None + encrypted_code: Optional[StrictStr] = Field(default=None, alias="encryptedCode") + error_code: Optional[StrictInt] = Field(default=None, alias="ec") + + @field_validator("session_uuid") + @classmethod + def _validate_session_uuid(cls, value: str) -> str: + """Reject blank bridge session identifiers.""" + if not value.strip(): + raise ValueError("sessionUUID must not be blank") + return value + + @field_validator( + "txnid", + "salt", + "mid", + "idmsdata", + "data", + "encrypted_code", + ) + @classmethod + def _validate_optional_non_empty_strings( + cls, value: Optional[str] + ) -> Optional[str]: + """Reject present-but-blank optional bridge string fields.""" + if value is not None and not value.strip(): + raise ValueError("Bridge payload strings must not be blank") + return value + + @field_validator("next_step") + @classmethod + def _validate_next_step(cls, value: Optional[str | int]) -> Optional[str | int]: + """Reject blank next-step markers while allowing ints or strings.""" + if isinstance(value, str) and not value.strip(): + raise ValueError("nextStep must not be blank") + return value + + +@dataclass(frozen=True) +class BridgePushPayload: + """Decoded bridge push metadata needed to bootstrap trusted-device prompts.""" + + payload: dict[str, Any] + session_uuid: str + next_step: Optional[str] = None + rui_url_key: Optional[str] = None + txnid: Optional[str] = None + salt: Optional[str] = None + mid: Optional[str] = None + idmsdata: Optional[str] = None + akdata: Any = None + data: Optional[str] = None + encrypted_code: Optional[str] = None + error_code: Optional[int] = None + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> "BridgePushPayload": + """Validate and normalize one decoded bridge push payload.""" + try: + validated = _BridgePushPayloadModel.model_validate(payload) + except ValidationError as exc: + raise PyiCloudTrustedDevicePromptException( + "Malformed trusted-device bridge push payload." + ) from exc + + if not validated.session_uuid: + raise PyiCloudTrustedDevicePromptException( + "Trusted-device bridge push payload is missing sessionUUID." + ) + + return cls( + payload=payload, + session_uuid=validated.session_uuid, + next_step=( + str(validated.next_step) if validated.next_step is not None else None + ), + rui_url_key=validated.rui_url_key, + txnid=validated.txnid, + salt=validated.salt, + mid=validated.mid, + idmsdata=validated.idmsdata, + akdata=validated.akdata, + data=validated.data, + encrypted_code=validated.encrypted_code, + error_code=validated.error_code, + ) + + +@dataclass +class TrustedDeviceBridgeState: + """Ephemeral trusted-device bridge state.""" + + connection_path: str + push_token: str + session_uuid: str + websocket: Optional[_WebSocketLike] + topic: str + topics_by_hash: dict[str, str] + source_app_id: Optional[str] = None + next_step: Optional[str] = None + rui_url_key: Optional[str] = None + push_payload: dict[str, Any] = field(default_factory=dict) + txnid: Optional[str] = None + salt: Optional[str] = None + mid: Optional[str] = None + idmsdata: Optional[str] = None + akdata: Any = None + data: Optional[str] = None + encrypted_code: Optional[str] = None + error_code: Optional[int] = None + + def apply_push_payload(self, push_payload: BridgePushPayload) -> None: + """Persist the latest bridge push metadata in the live bridge session.""" + + self.push_payload = dict(push_payload.payload) + self.session_uuid = push_payload.session_uuid + self.next_step = push_payload.next_step + self.rui_url_key = push_payload.rui_url_key + self.txnid = push_payload.txnid + self.salt = push_payload.salt + self.mid = push_payload.mid + self.idmsdata = push_payload.idmsdata + self.akdata = push_payload.akdata + self.data = push_payload.data + self.encrypted_code = push_payload.encrypted_code + self.error_code = push_payload.error_code + + @property + def uses_legacy_trusted_device_verifier(self) -> bool: + """Return whether Apple routed this bridge challenge to the legacy verifier.""" + + return bool(self.txnid and self.txnid.endswith("_W")) + + +@dataclass(frozen=True) +class BridgeStepRequest: + """Typed request body for Apple's bridge step endpoints.""" + + session_uuid: str + data: str + push_token: str + next_step: int + idmsdata: Optional[str] = None + akdata: Any = None + + def as_json(self) -> dict[str, Any]: + """Serialize the step request into Apple's JSON envelope.""" + payload: dict[str, Any] = { + "sessionUUID": self.session_uuid, + "data": self.data, + "ptkn": self.push_token, + "nextStep": self.next_step, + } + if self.idmsdata is not None: + payload["idmsdata"] = self.idmsdata + if self.akdata is not None: + payload["akdata"] = ( + json.dumps(self.akdata, separators=(",", ":")) + if isinstance(self.akdata, dict) + else self.akdata + ) + return payload + + +@dataclass(frozen=True) +class BridgeCodeValidateRequest: + """Typed request body for Apple's final bridge code validation endpoint.""" + + session_uuid: str + code: str + + def as_json(self) -> dict[str, str]: + """Serialize the final bridge code-validation request body.""" + return { + "sessionUUID": self.session_uuid, + "code": self.code, + } + + +@dataclass(frozen=True) +class _ConnectionResponse: + """Decoded server response for the initial websocket bootstrap.""" + + push_token_b64: str = "" + status: int = 0 + server_timestamp_seconds: Optional[int] = None + + +@dataclass(frozen=True) +class _PushMessage: + """Decoded APNS-style push frame from the bridge websocket.""" + + topic: bytes + message_id: int + payload: bytes + + +@dataclass(frozen=True) +class _ChannelSubscriptionResponse: + """Decoded response to the bridge topic subscription request.""" + + message_id: int = 0 + status: int = 0 + retry_interval_seconds: int = 0 + topics: tuple[str, ...] = () + + +@dataclass(frozen=True) +class _AcknowledgementMessage: + """Decoded acknowledgment frame emitted by Apple's bridge service.""" + + topic: bytes + message_id: int + delivery_status: int = 0 + + +@dataclass(frozen=True) +class _ServerMessage: + """One websocket frame decoded into its known top-level message variants.""" + + connection_response: Optional[_ConnectionResponse] = None + push_message: Optional[_PushMessage] = None + channel_subscription_response: Optional[_ChannelSubscriptionResponse] = None + push_acknowledgment: Optional[_AcknowledgementMessage] = None + field_numbers: tuple[int, ...] = () + + +class _WebSocketLike(Protocol): + """Protocol for the minimal websocket operations used by the bridge flow.""" + + def send_binary(self, payload: bytes) -> None: + """Send one binary websocket message.""" + + def read_message(self) -> bytes: + """Read one complete websocket message payload.""" + + def close(self) -> None: + """Close the websocket transport.""" + + +class _InvalidNonceError(Exception): + """Signal Apple's INVALID_NONCE response along with the server timestamp.""" + + def __init__(self, server_timestamp_ms: int) -> None: + """Capture the server timestamp returned with INVALID_NONCE.""" + super().__init__("Invalid nonce from bridge server.") + self.server_timestamp_ms = server_timestamp_ms + + +class _BootArgsHTMLParser(HTMLParser): + """Extract the JSON body from Apple's boot_args script tag.""" + + def __init__(self) -> None: + """Initialize parser state for the first matching boot_args script tag.""" + super().__init__() + self._collecting = False + self._found = False + self._chunks: list[str] = [] + + @property + def payload(self) -> str: + """Return the collected boot_args JSON text.""" + return "".join(self._chunks).strip() + + def handle_starttag(self, tag: str, attrs: list[tuple[str, Optional[str]]]) -> None: + """Start collecting data when the boot_args script tag is found.""" + if tag != "script" or self._found: + return + attr_map = {key: value for key, value in attrs} + classes = (attr_map.get("class") or "").split() + if "boot_args" in classes: + self._collecting = True + self._found = True + + def handle_endtag(self, tag: str) -> None: + """Stop collecting when the current script tag closes.""" + if tag == "script" and self._collecting: + self._collecting = False + + def handle_data(self, data: str) -> None: + """Append script contents while the boot_args tag is active.""" + if self._collecting: + self._chunks.append(data) + + +def parse_boot_args_html(html_text: str) -> Hsa2BootContext: + """Extract HSA2 boot args from the HTML returned by GET /appleauth/auth.""" + + parser = _BootArgsHTMLParser() + parser.feed(html_text) + parser.close() + + payload_text = parser.payload + if not payload_text: + raise PyiCloudTrustedDevicePromptException("Missing HSA2 boot args payload.") + + try: + payload = json.loads(payload_text) + except json.JSONDecodeError as exc: + raise PyiCloudTrustedDevicePromptException( + "Malformed HSA2 boot args payload." + ) from exc + direct = payload.get("direct") + if not isinstance(direct, dict): + raise PyiCloudTrustedDevicePromptException("Missing HSA2 direct boot data.") + + two_sv = direct.get("twoSV") + if not isinstance(two_sv, dict): + two_sv = {} + + bridge_initiate_data = two_sv.get("bridgeInitiateData") + if not isinstance(bridge_initiate_data, dict): + bridge_initiate_data = {} + + phone_number_verification = bridge_initiate_data.get("phoneNumberVerification") + if not isinstance(phone_number_verification, dict): + phone_number_verification = {} + + auth_factors = two_sv.get("authFactors") + if not isinstance(auth_factors, list): + auth_factors = [] + + source_app_id = two_sv.get("sourceAppId") + if source_app_id is not None: + source_app_id = str(source_app_id) + + return Hsa2BootContext( + auth_initial_route=str(direct.get("authInitialRoute") or ""), + has_trusted_devices=bool(direct.get("hasTrustedDevices")), + auth_factors=tuple( + factor for factor in auth_factors if isinstance(factor, str) + ), + bridge_initiate_data=dict(bridge_initiate_data), + phone_number_verification=dict(phone_number_verification), + source_app_id=source_app_id, + ) + + +def _encode_varint(value: int) -> bytes: + """Encode an unsigned protobuf varint.""" + if value < 0: + raise ValueError("Negative varints are not supported.") + parts = bytearray() + while True: + to_write = value & 0x7F + value >>= 7 + if value: + parts.append(to_write | 0x80) + else: + parts.append(to_write) + return bytes(parts) + + +def _read_varint(data: bytes, offset: int) -> tuple[int, int]: + """Decode one protobuf varint from a byte string and return the new offset.""" + value = 0 + shift = 0 + start_offset = offset + while True: + if offset >= len(data): + raise PyiCloudTrustedDevicePromptException("Truncated protobuf varint.") + byte = data[offset] + offset += 1 + value |= (byte & 0x7F) << shift + if not (byte & 0x80): + return value, offset + shift += 7 + # Guard against malformed wire data rather than silently accepting an + # overlong varint from Apple's private bridge protocol. + if shift > 63 or offset - start_offset >= 10: + raise PyiCloudTrustedDevicePromptException("Malformed protobuf varint.") + + +def _encode_field(field_number: int, wire_type: int, value: bytes) -> bytes: + """Encode one protobuf field header and payload.""" + return _encode_varint((field_number << 3) | wire_type) + value + + +def _encode_bytes_field(field_number: int, value: bytes) -> bytes: + """Encode a length-delimited protobuf field.""" + return _encode_field(field_number, 2, _encode_varint(len(value)) + value) + + +def _encode_string_field(field_number: int, value: str) -> bytes: + """Encode a UTF-8 string protobuf field.""" + return _encode_bytes_field(field_number, value.encode("utf-8")) + + +def _encode_uint32_field(field_number: int, value: int) -> bytes: + """Encode an unsigned integer protobuf field.""" + return _encode_field(field_number, 0, _encode_varint(value)) + + +def _decode_fields(data: bytes) -> dict[int, list[Any]]: + """Decode a minimal subset of protobuf wire types into field lists.""" + offset = 0 + fields: dict[int, list[Any]] = {} + while offset < len(data): + key, offset = _read_varint(data, offset) + field_number = key >> 3 + wire_type = key & 0x07 + + if wire_type == 0: + value, offset = _read_varint(data, offset) + elif wire_type == 2: + length, offset = _read_varint(data, offset) + # Length-delimited fields must stay within the current message + # bounds; otherwise the bridge frame is truncated or malformed. + end_offset = offset + length + if end_offset > len(data): + raise PyiCloudTrustedDevicePromptException("Truncated protobuf field.") + value = data[offset:end_offset] + offset = end_offset + else: + raise PyiCloudTrustedDevicePromptException( + f"Unsupported protobuf wire type: {wire_type}" + ) + + fields.setdefault(field_number, []).append(value) + return fields + + +def _decode_connection_response(message: bytes) -> _ConnectionResponse: + """Decode the server's websocket bootstrap response.""" + fields = _decode_fields(message) + push_token_b64 = "" + if fields.get(1): + try: + push_token_b64 = fields[1][0].decode("ascii") + except UnicodeDecodeError as exc: + raise PyiCloudTrustedDevicePromptException( + "Malformed bridge connection response push token." + ) from exc + status = int(fields.get(2, [0])[0]) + server_timestamp_seconds = None + if fields.get(3): + server_timestamp_seconds = int(fields[3][0]) + return _ConnectionResponse( + push_token_b64=push_token_b64, + status=status, + server_timestamp_seconds=server_timestamp_seconds, + ) + + +def _decode_push_message(message: bytes) -> _PushMessage: + """Decode one push-delivery frame from the bridge websocket.""" + fields = _decode_fields(message) + topic = bytes(fields.get(1, [b""])[0]) + message_id = int(fields.get(2, [0])[0]) + payload = bytes(fields.get(4, [b""])[0]) + return _PushMessage(topic=topic, message_id=message_id, payload=payload) + + +def _decode_channel_subscription_response( + message: bytes, +) -> _ChannelSubscriptionResponse: + """Decode the server's response to the topic subscription message.""" + fields = _decode_fields(message) + topics: list[str] = [] + + payload_values = fields.get(1) + if payload_values: + payload_fields = _decode_fields(bytes(payload_values[0])) + for app_response_value in payload_fields.get(1, []): + app_response_fields = _decode_fields(bytes(app_response_value)) + topic_value = app_response_fields.get(1, [b""])[0] + if isinstance(topic_value, bytes): + topics.append(topic_value.decode("utf-8", "ignore")) + + return _ChannelSubscriptionResponse( + message_id=int(fields.get(2, [0])[0]), + status=int(fields.get(3, [0])[0]), + retry_interval_seconds=int(fields.get(4, [0])[0]), + topics=tuple(topic for topic in topics if topic), + ) + + +def _decode_acknowledgement_message(message: bytes) -> _AcknowledgementMessage: + """Decode a push acknowledgment frame from the bridge websocket.""" + fields = _decode_fields(message) + topic = bytes(fields.get(1, [b""])[0]) + message_id = int(fields.get(2, [0])[0]) + delivery_status = int(fields.get(3, [0])[0]) + return _AcknowledgementMessage( + topic=topic, + message_id=message_id, + delivery_status=delivery_status, + ) + + +def _decode_server_message(message: bytes) -> _ServerMessage: + """Decode all known top-level messages embedded in one websocket frame.""" + fields = _decode_fields(message) + + connection_response = None + if fields.get(SERVER_MESSAGE_CONNECTION_RESPONSE): + connection_response = _decode_connection_response( + bytes(fields[SERVER_MESSAGE_CONNECTION_RESPONSE][0]) + ) + + push_message = None + if fields.get(SERVER_MESSAGE_PUSH): + push_message = _decode_push_message(bytes(fields[SERVER_MESSAGE_PUSH][0])) + + channel_subscription_response = None + if fields.get(SERVER_MESSAGE_CHANNEL_SUBSCRIPTION_RESPONSE): + channel_subscription_response = _decode_channel_subscription_response( + bytes(fields[SERVER_MESSAGE_CHANNEL_SUBSCRIPTION_RESPONSE][0]) + ) + + push_acknowledgment = None + if fields.get(SERVER_MESSAGE_PUSH_ACK): + push_acknowledgment = _decode_acknowledgement_message( + bytes(fields[SERVER_MESSAGE_PUSH_ACK][0]) + ) + + return _ServerMessage( + connection_response=connection_response, + push_message=push_message, + channel_subscription_response=channel_subscription_response, + push_acknowledgment=push_acknowledgment, + field_numbers=tuple(sorted(fields)), + ) + + +def _encode_connection_message( + public_key: bytes, nonce: bytes, signature: bytes +) -> bytes: + """Encode the initial bridge websocket bootstrap message.""" + connection_message = b"".join( + [ + _encode_bytes_field(1, public_key), + _encode_bytes_field(2, nonce), + _encode_bytes_field(3, _encode_bridge_signature(signature)), + _encode_bytes_field( + 5, _encode_uint32_field(1, NEW_CONNECTION_EXPIRATION_SECONDS) + ), + ] + ) + return _encode_bytes_field(1, connection_message) + + +def _encode_bridge_signature(signature: bytes) -> bytes: + """Wrap the DER ECDSA signature using Apple's bridge signature envelope.""" + + if signature.startswith(BRIDGE_SIGNATURE_PREFIX): + return signature + return BRIDGE_SIGNATURE_PREFIX + signature + + +def _encode_web_filter_message(allowed_topics: list[str]) -> bytes: + """Encode the topic subscription message sent after bridge connect.""" + filter_payload = b"".join( + _encode_string_field(1, topic) for topic in allowed_topics + ) + return _encode_bytes_field(3, filter_payload) + + +def _encode_ack_message(topic: bytes, message_id: int) -> bytes: + """Encode the acknowledgment frame for one delivered push message.""" + ack_payload = b"".join( + [ + _encode_bytes_field(1, topic), + _encode_uint32_field(2, message_id), + ] + ) + return _encode_bytes_field(2, ack_payload) + + +def _topic_hash(topic: str) -> str: + """Return Apple's websocket topic hash for a named APNS topic.""" + return hashlib.sha1(topic.encode("utf-8")).hexdigest() + + +def _topic_name(topic_bytes: bytes, topics_by_hash: Mapping[str, str]) -> str: + """Resolve a hashed topic payload back to a readable topic name.""" + return topics_by_hash.get(topic_bytes.hex(), topic_bytes.decode("utf-8", "ignore")) + + +def _extract_json_payload(payload: bytes) -> dict[str, Any]: + """Extract the JSON object embedded in one bridge push payload.""" + try: + return json.loads(payload.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + text = payload.decode("utf-8", "ignore") + + start = text.find("{") + while start >= 0: + depth = 0 + in_string = False + escaped = False + for index, character in enumerate(text[start:], start=start): + if in_string: + if escaped: + escaped = False + elif character == "\\": + escaped = True + elif character == '"': + in_string = False + continue + if character == '"': + in_string = True + elif character == "{": + depth += 1 + elif character == "}": + depth -= 1 + if depth == 0: + try: + return json.loads(text[start : index + 1]) + except json.JSONDecodeError: + break + start = text.find("{", start + 1) + + raise PyiCloudTrustedDevicePromptException( + "Could not decode the trusted-device bridge push payload." + ) + + +def _b64_to_hex(value: str) -> str: + """Decode base64 bridge data and return it as lowercase hex.""" + try: + return base64.b64decode(value.encode("ascii"), validate=True).hex() + except (ValueError, BinasciiError) as exc: + raise ValueError("Malformed base64-encoded bridge payload.") from exc + + +def _hex_to_b64(value: str) -> str: + """Encode hex bridge data as standard base64 text.""" + return base64.b64encode(bytes.fromhex(value)).decode("ascii") + + +def _build_nonce(timestamp_ms: int) -> bytes: + """Build the nonce format expected by Apple's bridge bootstrap.""" + return b"\x00" + timestamp_ms.to_bytes(8, "big", signed=False) + os.urandom(8) + + +def _summarize_identifier( + value: Optional[str], *, prefix: int = 8, empty: str = "" +) -> str: + """Shorten sensitive identifiers before logging them at debug level.""" + if not value: + return empty + if len(value) <= prefix: + return value + return f"{value[:prefix]}..." + + +def _resolve_websocket_host(boot_context: Hsa2BootContext) -> str: + """Resolve the websocket host Apple expects for the bridge session.""" + bridge_data = boot_context.bridge_initiate_data + web_socket_url = bridge_data.get("webSocketUrl") + if isinstance(web_socket_url, str) and web_socket_url: + if "://" in web_socket_url: + parsed = urlparse(web_socket_url) + if parsed.hostname: + return parsed.hostname + return web_socket_url.split("/", 1)[0] + + environment = bridge_data.get("apnsEnvironment") + if isinstance(environment, str) and environment in WEBSOCKET_ENVIRONMENT_HOSTS: + return WEBSOCKET_ENVIRONMENT_HOSTS[environment] + + raise PyiCloudTrustedDevicePromptException( + "Missing HSA2 websocket host for the trusted-device bridge." + ) + + +def _resolve_apns_topic(boot_context: Hsa2BootContext) -> str: + """Resolve the APNS topic Apple uses for trusted-device pushes.""" + topic = boot_context.bridge_initiate_data.get("apnsTopic") + if isinstance(topic, str) and topic: + return topic + + raise PyiCloudTrustedDevicePromptException( + "Missing HSA2 APNS topic for the trusted-device bridge." + ) + + +def _derive_origin(auth_endpoint: str) -> str: + """Derive the websocket Origin header from the auth endpoint URL.""" + parsed = urlparse(auth_endpoint) + if not parsed.scheme or not parsed.hostname: + raise PyiCloudTrustedDevicePromptException( + "Invalid auth endpoint for trusted-device bridge." + ) + return f"{parsed.scheme}://{parsed.hostname}" + + +class _RawWebSocketClient: + """Minimal websocket client for Apple's webcourier bridge.""" + + def __init__( + self, + url: str, + timeout: float, + origin: str, + user_agent: str, + ) -> None: + """Open a websocket connection and prepare buffered frame reads.""" + self._url = url + self._timeout = timeout + self._origin = origin + self._user_agent = user_agent + self._buffer = bytearray() + self._socket = self._open() + + def _open(self) -> ssl.SSLSocket: + """Perform the websocket HTTP upgrade and return the TLS socket.""" + parsed = urlparse(self._url) + if parsed.scheme != "wss" or not parsed.hostname: + raise PyiCloudTrustedDevicePromptException( + f"Unsupported websocket URL: {self._url}" + ) + + port = parsed.port or 443 + resource = parsed.path or "/" + if parsed.query: + resource = f"{resource}?{parsed.query}" + + raw_socket = socket.create_connection((parsed.hostname, port), self._timeout) + context = ssl.create_default_context() + secure_socket = context.wrap_socket(raw_socket, server_hostname=parsed.hostname) + secure_socket.settimeout(self._timeout) + + websocket_key = base64.b64encode(os.urandom(16)).decode("ascii") + request_headers = [ + f"GET {resource} HTTP/1.1", + f"Host: {parsed.hostname}", + "Upgrade: websocket", + "Connection: Upgrade", + f"Origin: {self._origin}", + f"User-Agent: {self._user_agent}", + "Sec-WebSocket-Version: 13", + f"Sec-WebSocket-Key: {websocket_key}", + "\r\n", + ] + secure_socket.sendall("\r\n".join(request_headers).encode("ascii")) + + response = self._read_http_response(secure_socket) + status_line, _, headers_text = response.partition("\r\n") + if " 101 " not in status_line: + raise PyiCloudTrustedDevicePromptException( + f"Websocket upgrade failed: {status_line}" + ) + + headers: dict[str, str] = {} + for line in headers_text.split("\r\n"): + if not line or ":" not in line: + continue + key, value = line.split(":", 1) + headers[key.strip().lower()] = value.strip() + + expected_accept = base64.b64encode( + hashlib.sha1((websocket_key + WEBSOCKET_GUID).encode("ascii")).digest() + ).decode("ascii") + if headers.get("sec-websocket-accept") != expected_accept: + raise PyiCloudTrustedDevicePromptException( + "Invalid websocket accept header from bridge server." + ) + + return secure_socket + + def _read_http_response(self, sock: ssl.SSLSocket) -> str: + """Read the HTTP upgrade response headers from the websocket socket.""" + while b"\r\n\r\n" not in self._buffer: + chunk = sock.recv(4096) + if not chunk: + raise PyiCloudTrustedDevicePromptException( + "Unexpected EOF during websocket handshake." + ) + self._buffer.extend(chunk) + + marker = self._buffer.find(b"\r\n\r\n") + 4 + data = bytes(self._buffer[:marker]).decode("iso-8859-1") + del self._buffer[:marker] + return data + + def _read_exact(self, size: int) -> bytes: + """Read exactly ``size`` buffered bytes from the websocket socket.""" + while len(self._buffer) < size: + chunk = self._socket.recv(max(4096, size - len(self._buffer))) + if not chunk: + raise PyiCloudTrustedDevicePromptException( + "Unexpected EOF while reading websocket frame." + ) + self._buffer.extend(chunk) + + data = bytes(self._buffer[:size]) + del self._buffer[:size] + return data + + def _send_frame(self, opcode: int, payload: bytes) -> None: + """Send one masked websocket frame to Apple's bridge server.""" + first_byte = 0x80 | opcode + mask_key = os.urandom(4) + length = len(payload) + + header = bytearray([first_byte]) + if length < 126: + header.append(0x80 | length) + elif length < 65536: + header.append(0x80 | 126) + header.extend(struct.pack("!H", length)) + else: + header.append(0x80 | 127) + header.extend(struct.pack("!Q", length)) + + masked_payload = bytes( + byte ^ mask_key[index % 4] for index, byte in enumerate(payload) + ) + self._socket.sendall(bytes(header) + mask_key + masked_payload) + + def send_binary(self, payload: bytes) -> None: + """Send one binary websocket message payload.""" + self._send_frame(OPCODE_BINARY, payload) + + def read_message(self) -> bytes: + """Read one complete websocket message, handling control frames inline.""" + fragments: list[bytes] = [] + opcode: Optional[int] = None + + while True: + first_byte, second_byte = self._read_exact(2) + frame_opcode = first_byte & 0x0F + finished = bool(first_byte & 0x80) + masked = bool(second_byte & 0x80) + payload_length = second_byte & 0x7F + + if payload_length == 126: + payload_length = struct.unpack("!H", self._read_exact(2))[0] + elif payload_length == 127: + payload_length = struct.unpack("!Q", self._read_exact(8))[0] + + mask_key = self._read_exact(4) if masked else b"" + payload = self._read_exact(payload_length) + if masked: + payload = bytes( + byte ^ mask_key[index % 4] for index, byte in enumerate(payload) + ) + + if frame_opcode == OPCODE_CLOSE: + raise PyiCloudTrustedDevicePromptException( + "Bridge websocket closed before delivering a prompt." + ) + if frame_opcode == OPCODE_PING: + self._send_frame(OPCODE_PONG, payload) + continue + if frame_opcode == OPCODE_PONG: + continue + + if frame_opcode != 0: + opcode = frame_opcode + fragments.append(payload) + if finished: + if opcode not in (0x1, OPCODE_BINARY): + raise PyiCloudTrustedDevicePromptException( + f"Unsupported websocket opcode: {opcode}" + ) + return b"".join(fragments) + + def close(self) -> None: + """Attempt a clean websocket close and always close the socket object.""" + if getattr(self, "_socket", None) is None: + return + try: + self._send_frame(OPCODE_CLOSE, b"") + except OSError: + pass + finally: + try: + self._socket.close() + except OSError: + pass + + +class TrustedDeviceBridgeBootstrapper: + """Bootstrap the trusted-device bridge flow captured in Apple's browser client.""" + + def __init__( + self, + *, + timeout: float = WEBSOCKET_TIMEOUT_SECONDS, + websocket_factory: Optional[ + Callable[[str, float, str, str], _WebSocketLike] + ] = None, + prover_factory: Optional[Callable[[], TrustedDeviceBridgeProver]] = None, + ) -> None: + """Configure websocket and prover factories for bridge operations.""" + self.timeout = timeout + self._websocket_factory = websocket_factory or _RawWebSocketClient + self._prover_factory = prover_factory or TrustedDeviceBridgeProver + + def start( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + boot_context: Hsa2BootContext, + user_agent: str, + ) -> TrustedDeviceBridgeState: + """Bootstrap Apple's trusted-device bridge until the first prompt payload arrives.""" + topic = _resolve_apns_topic(boot_context) + websocket_host = _resolve_websocket_host(boot_context) + origin = _derive_origin(auth_endpoint) + topics_by_hash = {_topic_hash(topic): topic} + source_app_id = boot_context.source_app_id + public_key, private_key = self._generate_keypair() + + LOGGER.debug( + "Bootstrapping trusted-device bridge: auth_endpoint=%s websocket_host=%s topic=%s source_app_id=%s", + auth_endpoint, + websocket_host, + topic, + source_app_id, + ) + + timestamp_ms: Optional[int] = None + last_error: Optional[Exception] = None + for _ in range(2): + nonce = _build_nonce(timestamp_ms or int(time.time() * 1000)) + signature = private_key.sign(nonce, ec.ECDSA(hashes.SHA256())) + connection_message = _encode_connection_message( + public_key, nonce, signature + ) + connection_path = connection_message.hex() + websocket_url = f"wss://{websocket_host}/v2/{connection_path}" + LOGGER.debug( + "Opening trusted-device websocket: host=%s bootstrapPayloadLen=%d", + websocket_host, + len(connection_path), + ) + websocket = self._websocket_factory( + websocket_url, + self.timeout, + origin, + user_agent, + ) + keep_websocket_open = False + + try: + push_token = self._wait_for_push_token(websocket) + push_token_hex = push_token.hex() + LOGGER.debug( + "Trusted-device bridge connected; received push token (%d bytes)", + len(push_token), + ) + websocket.send_binary(_encode_web_filter_message([topic])) + LOGGER.debug("Sent trusted-device webFilterMessage for topic=%s", topic) + + session_uuid = self._generate_session_uuid() + bridge_headers = dict(headers) + if source_app_id: + bridge_headers["X-Apple-App-Id"] = source_app_id + + LOGGER.debug( + "Posting trusted-device bridge step 0 with sessionUUID=%s ptknLen=%d", + _summarize_identifier(session_uuid), + len(push_token_hex), + ) + # Apple's browser posts step 0 immediately after obtaining the push + # token. Waiting for the first push before posting step 0 causes the + # bridge flow to stall. + self._post_bridge_step0( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + session_uuid=session_uuid, + push_token=push_token_hex, + ) + + push_payload = self._wait_for_bridge_push( + websocket, topic, topics_by_hash + ) + LOGGER.debug( + "Received trusted-device bridge payload: sessionUUID=%s nextStep=%s ruiURLKey=%s", + _summarize_identifier(push_payload.session_uuid), + push_payload.next_step, + push_payload.rui_url_key, + ) + if push_payload.session_uuid != session_uuid: + raise PyiCloudTrustedDevicePromptException( + "Trusted-device bridge returned a mismatched session UUID." + ) + + bridge_state = TrustedDeviceBridgeState( + connection_path=connection_path, + push_token=push_token_hex, + session_uuid=session_uuid, + websocket=websocket, + topic=topic, + topics_by_hash=dict(topics_by_hash), + source_app_id=source_app_id, + ) + bridge_state.apply_push_payload(push_payload) + keep_websocket_open = True + return bridge_state + except _InvalidNonceError as exc: + timestamp_ms = exc.server_timestamp_ms + last_error = exc + LOGGER.debug( + "Trusted-device bridge received INVALID_NONCE; retrying with server timestamp %s", + timestamp_ms, + ) + except (OSError, socket.timeout, ssl.SSLError) as exc: + last_error = exc + LOGGER.debug( + "Trusted-device websocket transport error during bootstrap.", + exc_info=True, + ) + break + except PyiCloudTrustedDevicePromptException as exc: + last_error = exc + LOGGER.debug( + "Trusted-device bridge bootstrap failed before completion.", + exc_info=True, + ) + break + finally: + if not keep_websocket_open: + websocket.close() + + raise PyiCloudTrustedDevicePromptException( + "Failed to bootstrap the trusted-device bridge prompt." + ) from last_error + + def _generate_keypair(self) -> tuple[bytes, ec.EllipticCurvePrivateKey]: + """Generate the ephemeral P-256 keypair used for websocket bootstrap.""" + private_key = ec.generate_private_key(ec.SECP256R1()) + public_key = private_key.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint, + ) + return public_key, private_key + + def _generate_session_uuid(self) -> str: + """Generate the browser-style bridge session UUID string.""" + return f"{uuid.uuid4()}-{int(time.time())}" + + def close(self, bridge_state: Optional[TrustedDeviceBridgeState]) -> None: + """Close and detach the websocket associated with an active bridge session.""" + + if bridge_state is None: + return + websocket = bridge_state.websocket + bridge_state.websocket = None + if websocket is None: + return + try: + websocket.close() + except OSError: + LOGGER.debug( + "Trusted-device bridge websocket close failed.", + exc_info=True, + ) + + def validate_code( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + bridge_state: TrustedDeviceBridgeState, + code: str, + ) -> bool: + """Run Apple's bridge-specific trusted-device verification flow.""" + + websocket = bridge_state.websocket + if websocket is None: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge session is not active." + ) + if bridge_state.uses_legacy_trusted_device_verifier: + raise PyiCloudTrustedDeviceVerificationException( + "Legacy trusted-device verification should bypass the bridge verifier." + ) + if bridge_state.next_step not in {"2", 2}: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge is not ready for step 2 verification." + ) + if not bridge_state.salt: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge payload is missing the step-2 salt." + ) + + prover = self._prover_factory() + bridge_headers = self._bridge_headers(headers, bridge_state) + + try: + LOGGER.debug( + "Starting trusted-device bridge code verification: sessionUUID=%s nextStep=%s txnid=%s", + _summarize_identifier(bridge_state.session_uuid), + bridge_state.next_step, + _summarize_identifier(bridge_state.txnid, prefix=12), + ) + + prover.init_with_salt(bridge_state.salt, code) + message1 = prover.get_message1() + LOGGER.debug( + "Posting trusted-device bridge step 2 with sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + self._post_bridge_step( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + bridge_state=bridge_state, + next_step=2, + data=_hex_to_b64(message1), + idmsdata=bridge_state.idmsdata, + akdata=bridge_state.akdata, + ) + + step4_payload = self._wait_for_bridge_push( + websocket, + bridge_state.topic, + bridge_state.topics_by_hash, + ) + self._apply_expected_step4_push(bridge_state, step4_payload) + + if not bridge_state.data: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge step 4 payload is missing prover data." + ) + try: + step4_data = base64.b64decode( + bridge_state.data.encode("ascii"), validate=True + ).decode("utf-8") + bridge_message1_b64, bridge_message2_b64 = step4_data.split("_", 1) + bridge_message1_hex = _b64_to_hex(bridge_message1_b64) + bridge_message2_hex = _b64_to_hex(bridge_message2_b64) + except (ValueError, UnicodeDecodeError, BinasciiError) as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge step 4 payload is malformed." + ) from exc + + LOGGER.debug( + "Processing trusted-device bridge step 4 payload for sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + try: + message2 = prover.process_message1(bridge_message1_hex) + except ValueError as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge step 4 payload is malformed." + ) from exc + try: + prover.process_message2(bridge_message2_hex) + except ValueError: + LOGGER.debug( + "Trusted-device bridge prover rejected the step-4 confirmation for sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + return False + + LOGGER.debug( + "Posting trusted-device bridge step 4 with sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + self._post_bridge_step( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + bridge_state=bridge_state, + next_step=4, + data=_hex_to_b64(message2), + idmsdata=bridge_state.idmsdata, + akdata=bridge_state.akdata, + ) + + final_payload = self._wait_for_bridge_push( + websocket, + bridge_state.topic, + bridge_state.topics_by_hash, + ) + self._apply_final_bridge_push(bridge_state, final_payload) + + if not bridge_state.encrypted_code: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge final payload is missing encryptedCode." + ) + + LOGGER.debug( + "Decrypting trusted-device bridge code for sessionUUID=%s", + _summarize_identifier(bridge_state.session_uuid), + ) + try: + derived_code = prover.decrypt_message(bridge_state.encrypted_code) + except ValueError as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Failed to decrypt the trusted-device bridge validation code." + ) from exc + + verify_response = self._post_bridge_code_validate( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + bridge_state=bridge_state, + code=derived_code, + ) + verification_succeeded = ( + verify_response.status_code != HTTP_STATUS_PRECONDITION_FAILED + ) + + completion_step = 6 if bridge_state.next_step in {"6", 6} else 4 + LOGGER.debug( + "Posting trusted-device bridge completion step %s with sessionUUID=%s verifyStatus=%s", + completion_step, + _summarize_identifier(bridge_state.session_uuid), + verify_response.status_code, + ) + self._post_bridge_step( + session=session, + auth_endpoint=auth_endpoint, + headers=bridge_headers, + bridge_state=bridge_state, + next_step=completion_step, + data=BRIDGE_DONE_DATA_B64, + idmsdata=bridge_state.idmsdata, + akdata=bridge_state.akdata, + ) + return verification_succeeded + except PyiCloudTrustedDevicePromptException as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge verification failed while waiting for the next bridge push." + ) from exc + except (OSError, socket.timeout, ssl.SSLError) as exc: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge verification failed due to a websocket transport error." + ) from exc + finally: + self.close(bridge_state) + + def _wait_for_push_token(self, websocket: _WebSocketLike) -> bytes: + """Wait for the bridge connection response that carries the push token.""" + deadline = time.monotonic() + self.timeout + while time.monotonic() < deadline: + message = websocket.read_message() + server_message = _decode_server_message(message) + connection_response = server_message.connection_response + if connection_response is None: + LOGGER.debug( + "Ignoring non-connection websocket frame while waiting for push token; fields=%s", + server_message.field_numbers, + ) + continue + + if ( + connection_response.status == STATUS_OK + and connection_response.push_token_b64 + ): + try: + return base64.b64decode( + connection_response.push_token_b64.encode("ascii"), + validate=True, + ) + except (ValueError, BinasciiError) as exc: + raise PyiCloudTrustedDevicePromptException( + "Malformed bridge push token." + ) from exc + + if ( + connection_response.status == STATUS_INVALID_NONCE + and connection_response.server_timestamp_seconds is not None + ): + raise _InvalidNonceError( + connection_response.server_timestamp_seconds * 1000 + ) + + LOGGER.debug( + "Trusted-device bridge connection response returned status=%s", + connection_response.status, + ) + raise PyiCloudTrustedDevicePromptException( + f"Bridge server returned status {connection_response.status}." + ) + + raise PyiCloudTrustedDevicePromptException( + "Timed out waiting for the bridge push token." + ) + + def _wait_for_bridge_push( + self, + websocket: _WebSocketLike, + topic: str, + topics_by_hash: Mapping[str, str], + ) -> BridgePushPayload: + """Wait for, acknowledge, and decode the next relevant bridge push.""" + deadline = time.monotonic() + self.timeout + while time.monotonic() < deadline: + message = websocket.read_message() + server_message = _decode_server_message(message) + if server_message.channel_subscription_response is not None: + channel_response = server_message.channel_subscription_response + LOGGER.debug( + "Received channel subscription response during bridge bootstrap: messageId=%s status=%s retryIntervalSeconds=%s topics=%s", + channel_response.message_id, + channel_response.status, + channel_response.retry_interval_seconds, + channel_response.topics, + ) + if channel_response.status != STATUS_OK: + raise PyiCloudTrustedDevicePromptException( + "Trusted-device bridge topic subscription failed " + f"(status {channel_response.status})." + ) + + if server_message.push_acknowledgment is not None: + push_ack = server_message.push_acknowledgment + LOGGER.debug( + "Received bridge push acknowledgment during bootstrap: messageId=%s deliveryStatus=%s topic=%s", + push_ack.message_id, + push_ack.delivery_status, + _topic_name(push_ack.topic, topics_by_hash), + ) + + push_message = server_message.push_message + if push_message is None: + LOGGER.debug( + "Ignoring non-push websocket frame during trusted-device bootstrap; fields=%s", + server_message.field_numbers, + ) + continue + + websocket.send_binary( + _encode_ack_message(push_message.topic, push_message.message_id) + ) + LOGGER.debug( + "Acknowledged trusted-device push message id=%s topic=%s", + push_message.message_id, + _topic_name(push_message.topic, topics_by_hash), + ) + + if _topic_name(push_message.topic, topics_by_hash) != topic: + continue + + payload = _extract_json_payload(push_message.payload) + return BridgePushPayload.from_payload(payload) + + raise PyiCloudTrustedDevicePromptException( + "Timed out waiting for the trusted-device bridge payload." + ) + + def _apply_bridge_push( + self, + bridge_state: TrustedDeviceBridgeState, + push_payload: BridgePushPayload, + ) -> None: + """Validate a generic bridge push and merge it into the active state.""" + if push_payload.session_uuid != bridge_state.session_uuid: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge returned a mismatched session UUID." + ) + LOGGER.debug( + "Decoded trusted-device bridge payload: sessionUUID=%s nextStep=%s txnid=%s ec=%s has_data=%s has_encryptedCode=%s", + _summarize_identifier(push_payload.session_uuid), + push_payload.next_step, + _summarize_identifier(push_payload.txnid, prefix=12), + push_payload.error_code, + bool(push_payload.data), + bool(push_payload.encrypted_code), + ) + if push_payload.error_code not in (None, 0): + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge returned an error push " + f"(nextStep={push_payload.next_step!r}, ec={push_payload.error_code})." + ) + bridge_state.apply_push_payload(push_payload) + + def _apply_expected_step4_push( + self, + bridge_state: TrustedDeviceBridgeState, + push_payload: BridgePushPayload, + ) -> None: + """Require the post-step-2 bridge push to contain step-4 prover data.""" + self._apply_bridge_push(bridge_state, push_payload) + if bridge_state.next_step != "4" or not bridge_state.data: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge returned an unexpected post-step-2 payload." + ) + LOGGER.debug( + "Received trusted-device bridge payload: sessionUUID=%s nextStep=%s txnid=%s", + _summarize_identifier(bridge_state.session_uuid), + bridge_state.next_step, + _summarize_identifier(bridge_state.txnid, prefix=12), + ) + + def _apply_final_bridge_push( + self, + bridge_state: TrustedDeviceBridgeState, + push_payload: BridgePushPayload, + ) -> None: + """Require the final bridge push to contain the encrypted validation code.""" + self._apply_bridge_push(bridge_state, push_payload) + # Apple's bridge can finish with either: + # - nextStep=6 plus encryptedCode + # - nextStep=4 plus encryptedCode + # The browser routes both shapes into final code validation. + if ( + bridge_state.next_step not in {"4", "6", 4, 6} + or not bridge_state.encrypted_code + ): + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge returned an unexpected final payload." + ) + LOGGER.debug( + "Received trusted-device bridge final payload: sessionUUID=%s nextStep=%s txnid=%s", + _summarize_identifier(bridge_state.session_uuid), + bridge_state.next_step, + _summarize_identifier(bridge_state.txnid, prefix=12), + ) + + def _bridge_headers( + self, + headers: Mapping[str, str], + bridge_state: TrustedDeviceBridgeState, + ) -> dict[str, str]: + """Build the auth headers used for bridge-specific HTTP requests.""" + bridge_headers = dict(headers) + if bridge_state.source_app_id: + bridge_headers["X-Apple-App-Id"] = bridge_state.source_app_id + return bridge_headers + + def _bridge_step_json( + self, + *, + bridge_state: TrustedDeviceBridgeState, + next_step: int, + data: str, + idmsdata: Optional[str], + akdata: Any, + ) -> dict[str, Any]: + """Build the JSON payload for one bridge step POST.""" + return BridgeStepRequest( + session_uuid=bridge_state.session_uuid, + data=data, + push_token=bridge_state.push_token, + next_step=next_step, + idmsdata=idmsdata, + akdata=akdata, + ).as_json() + + def _post_bridge_step( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + bridge_state: TrustedDeviceBridgeState, + next_step: int, + data: str, + idmsdata: Optional[str], + akdata: Any, + ) -> Any: + """POST one bridge step and enforce the small set of valid statuses.""" + response = session.request_raw( + "POST", + f"{auth_endpoint}{BRIDGE_STEP_PATH_TEMPLATE.format(step=next_step)}", + json=self._bridge_step_json( + bridge_state=bridge_state, + next_step=next_step, + data=data, + idmsdata=idmsdata, + akdata=akdata, + ), + headers=headers, + ) + if response.status_code not in { + HTTP_STATUS_OK, + HTTP_STATUS_NO_CONTENT, + HTTP_STATUS_CONFLICT, + }: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge step " + f"{next_step} failed with status {response.status_code}." + ) + return response + + def _post_bridge_step0( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + session_uuid: str, + push_token: str, + ) -> Any: + """POST bridge step 0 immediately after obtaining the push token.""" + response = session.request_raw( + "POST", + f"{auth_endpoint}{BRIDGE_STEP_PATH}", + json={ + "sessionUUID": session_uuid, + "ptkn": push_token, + }, + headers=headers, + ) + if response.status_code not in { + HTTP_STATUS_OK, + HTTP_STATUS_NO_CONTENT, + HTTP_STATUS_CONFLICT, + }: + raise PyiCloudTrustedDevicePromptException( + "Trusted-device bridge step 0 failed with status " + f"{response.status_code}." + ) + return response + + def _post_bridge_code_validate( + self, + *, + session: Any, + auth_endpoint: str, + headers: Mapping[str, str], + bridge_state: TrustedDeviceBridgeState, + code: str, + ) -> Any: + """POST the decrypted bridge code to Apple's final validation endpoint.""" + response = session.request_raw( + "POST", + f"{auth_endpoint}{BRIDGE_CODE_VALIDATE_PATH}", + json=BridgeCodeValidateRequest( + session_uuid=bridge_state.session_uuid, + code=code, + ).as_json(), + headers=headers, + ) + if response.status_code not in { + HTTP_STATUS_OK, + HTTP_STATUS_NO_CONTENT, + HTTP_STATUS_CONFLICT, + HTTP_STATUS_PRECONDITION_FAILED, + }: + raise PyiCloudTrustedDeviceVerificationException( + "Trusted-device bridge code validation failed with status " + f"{response.status_code}." + ) + return response diff --git a/pyicloud/hsa2_bridge_prover.py b/pyicloud/hsa2_bridge_prover.py new file mode 100644 index 00000000..7e6ef8bd --- /dev/null +++ b/pyicloud/hsa2_bridge_prover.py @@ -0,0 +1,582 @@ +"""Pure-Python bridge prover for Apple's trusted-device HSA2 flow.""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import secrets +from dataclasses import dataclass +from typing import Optional + +from cryptography.exceptions import InvalidTag +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +_SCRYPT_PARAMS = { + "n": 16384, + "r": 8, + "p": 1, + "dklen": 64, +} +_CLIENT_IDENTITY = b"com.apple.security.webprover" +_SERVER_IDENTITY = b"com.apple.security.webverifier" +_SPAKE2_CONTEXT = b"SPAKE2Web" +_KEY_LENGTH = 32 +_VERIFIER_KEY_INFO = b"webVerifier" +_PROVER_KEY_INFO = b"webProver" + +_P256_P = int("FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF", 16) +_P256_A = (_P256_P - 3) % _P256_P +_P256_B = int("5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B", 16) +_P256_ORDER = int( + "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551", 16 +) +_P256_GX = int("6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296", 16) +_P256_GY = int("4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5", 16) +_SPAKE2_M = "02886e2f97ace46e55ba9dd7242579f2993b64e16ef3dcab95afd497333d8fa12f" +_SPAKE2_N = "03d8bbd6c639c62937b04d997f38c3770719c629d7014d49a24b4f98baa1292b49" +_AES_GCM_LAYOUTS = {0: (12, 16)} + + +@dataclass(frozen=True) +class _Point: + """Affine P-256 point used by the bridge SPAKE2 math helpers.""" + + x: Optional[int] + y: Optional[int] + + @property + def is_infinity(self) -> bool: + """Return whether this point is the point at infinity.""" + return self.x is None or self.y is None + + +_INFINITY = _Point(None, None) +_GENERATOR = _Point(_P256_GX, _P256_GY) + + +def _int_to_bytes(value: int, length: Optional[int] = None) -> bytes: + """Encode an integer using big-endian bytes.""" + if length is None: + length = max(1, (value.bit_length() + 7) // 8) + return value.to_bytes(length, "big") + + +def _b64_to_bytes(value: str) -> bytes: + """Decode a base64 string into raw bytes.""" + return base64.b64decode(value.encode("ascii")) + + +def _bytes_to_b64(value: bytes) -> str: + """Encode raw bytes as an ASCII base64 string.""" + return base64.b64encode(value).decode("ascii") + + +def _encode_point(point: _Point) -> str: + """Encode a P-256 point using SEC1 uncompressed point format.""" + if point.is_infinity: + raise ValueError("Cannot encode the point at infinity.") + return "04" + _int_to_bytes(point.x, 32).hex() + _int_to_bytes(point.y, 32).hex() + + +def _decode_point(value: str) -> _Point: + """Decode a compressed or uncompressed SEC1 point into affine coordinates.""" + raw = bytes.fromhex(value) + if len(raw) == 65 and raw[0] == 0x04: + point = _Point( + int.from_bytes(raw[1:33], "big"), + int.from_bytes(raw[33:65], "big"), + ) + elif len(raw) == 33 and raw[0] in (0x02, 0x03): + x_coord = int.from_bytes(raw[1:], "big") + rhs = (pow(x_coord, 3, _P256_P) + _P256_A * x_coord + _P256_B) % _P256_P + y_coord = pow(rhs, (_P256_P + 1) // 4, _P256_P) + if y_coord & 1 != raw[0] & 1: + y_coord = (-y_coord) % _P256_P + point = _Point(x_coord, y_coord) + else: + raise ValueError("Unsupported P-256 point encoding.") + + if not _is_on_curve(point): + raise ValueError("Invalid P-256 point.") + return point + + +def _is_on_curve(point: _Point) -> bool: + """Return whether a point lies on the configured P-256 curve.""" + if point.is_infinity: + return False + assert point.x is not None and point.y is not None + return ( + pow(point.y, 2, _P256_P) + - (pow(point.x, 3, _P256_P) + _P256_A * point.x + _P256_B) + ) % _P256_P == 0 + + +def _negate(point: _Point) -> _Point: + """Return the additive inverse of a P-256 point.""" + if point.is_infinity: + return point + assert point.x is not None and point.y is not None + return _Point(point.x, (-point.y) % _P256_P) + + +def _add_points(left: _Point, right: _Point) -> _Point: + """Add two affine P-256 points.""" + if left.is_infinity: + return right + if right.is_infinity: + return left + + assert left.x is not None and left.y is not None + assert right.x is not None and right.y is not None + + if left.x == right.x and (left.y + right.y) % _P256_P == 0: + return _INFINITY + + if left.x == right.x and left.y == right.y: + if left.y == 0: + return _INFINITY + slope = ( + (3 * left.x * left.x + _P256_A) * pow(2 * left.y, -1, _P256_P) + ) % _P256_P + else: + slope = ((right.y - left.y) * pow(right.x - left.x, -1, _P256_P)) % _P256_P + + x_coord = (slope * slope - left.x - right.x) % _P256_P + y_coord = (slope * (left.x - x_coord) - left.y) % _P256_P + return _Point(x_coord, y_coord) + + +def _multiply_point(point: _Point, scalar: int) -> _Point: + """Multiply a P-256 point by a scalar using double-and-add.""" + scalar %= _P256_ORDER + result = _INFINITY + addend = point + while scalar: + if scalar & 1: + result = _add_points(result, addend) + addend = _add_points(addend, addend) + scalar >>= 1 + return result + + +def _concat_length_prefixed(*parts: bytes) -> bytes: + """Concatenate transcript parts using the bridge's length-prefixed format.""" + output = bytearray() + for part in parts: + output.extend(len(part).to_bytes(8, "little")) + output.extend(part) + return bytes(output) + + +def _hkdf_like(ikm: bytes, salt: bytes, info: bytes, length: int) -> bytes: + """Derive key material using the bridge worker's HKDF-like expansion.""" + hash_len = hashlib.sha256().digest_size + if not salt: + salt = b"\x00" * hash_len + prk = hmac.new(salt, ikm, hashlib.sha256).digest() + blocks = bytearray() + previous = b"" + counter = 1 + while len(blocks) < length: + previous = hmac.new( + prk, + previous + info + bytes([counter]), + hashlib.sha256, + ).digest() + blocks.extend(previous) + counter += 1 + return bytes(blocks[:length]) + + +def _confirmation_key_length(info: bytes, requested_length: int) -> int: + """Return the bridge-specific output length for a given HKDF info label.""" + if b"ConfirmationKeys" in info: + return 64 + return requested_length + + +def _derive_key(ikm: bytes, info: bytes, length: int = 64) -> bytes: + """Derive one bridge sub-key from raw shared-secret material.""" + return _hkdf_like( + ikm=ikm, + salt=b"", + info=info, + length=_confirmation_key_length(info, length), + ) + + +def _derive_prover_and_verifier_keys(raw_key_hex: str) -> tuple[str, str]: + """Split the raw bridge key into prover and verifier AES/HMAC keys.""" + raw_key = bytes.fromhex(raw_key_hex) + verifier_key = _derive_key(raw_key, _VERIFIER_KEY_INFO, _KEY_LENGTH) + prover_key = _derive_key(raw_key, _PROVER_KEY_INFO, _KEY_LENGTH) + return verifier_key.hex(), prover_key.hex() + + +@dataclass(frozen=True) +class _ClientSharedSecret: + """Client-side shared-secret transcript and derived confirmation keys.""" + + transcript: bytes + share_p: str + share_v: str + + def __post_init__(self) -> None: + """Derive confirmation keys and the final shared key from the transcript.""" + digest = hashlib.sha256(self.transcript).digest() + object.__setattr__(self, "_hash_transcript", digest) + confirmations = _derive_key(digest, b"ConfirmationKeys", 64) + object.__setattr__(self, "_confirm_client", confirmations[:32]) + object.__setattr__(self, "_confirm_server", confirmations[32:]) + shared_key = _derive_key(digest, b"SharedKey", _KEY_LENGTH) + object.__setattr__(self, "_shared_key", shared_key) + + def get_confirmation(self) -> str: + """Return the prover's HMAC confirmation message.""" + return hmac.new( + self._confirm_client, + bytes.fromhex(self.share_v), + hashlib.sha256, + ).hexdigest() + + def verify(self, message_hex: str) -> bytes: + """Verify the server confirmation and return the shared key bytes.""" + expected = hmac.new( + self._confirm_server, + bytes.fromhex(self.share_p), + hashlib.sha256, + ).hexdigest() + if expected != message_hex: + raise ValueError("invalid confirmation from server") + return self._shared_key + + +@dataclass(frozen=True) +class _ServerSharedSecret: + """Server-side shared-secret transcript and derived confirmation keys.""" + + transcript: bytes + share_p: str + share_v: str + + def __post_init__(self) -> None: + """Derive confirmation keys and the final shared key from the transcript.""" + digest = hashlib.sha256(self.transcript).digest() + confirmations = _derive_key(digest, b"ConfirmationKeys", 64) + object.__setattr__(self, "_confirm_client", confirmations[:32]) + object.__setattr__(self, "_confirm_server", confirmations[32:]) + object.__setattr__( + self, + "_shared_key", + _derive_key(digest, b"SharedKey", _KEY_LENGTH), + ) + + def get_confirmation(self) -> str: + """Return the verifier's HMAC confirmation message.""" + return hmac.new( + self._confirm_server, + bytes.fromhex(self.share_p), + hashlib.sha256, + ).hexdigest() + + def verify(self, message_hex: str) -> bytes: + """Verify the prover confirmation and return the shared key bytes.""" + expected = hmac.new( + self._confirm_client, + bytes.fromhex(self.share_v), + hashlib.sha256, + ).hexdigest() + if expected != message_hex: + raise ValueError("invalid confirmation from client") + return self._shared_key + + +class _ClientHandshake: + """Client-side SPAKE2 handshake state for Apple's bridge prover.""" + + def __init__( + self, + *, + x_scalar: int, + w0: int, + w1: int, + ) -> None: + """Initialize the prover handshake with the derived SPAKE2 scalars.""" + self._x = x_scalar + self._w0 = w0 + self._w1 = w1 + self._message1_point: Optional[_Point] = None + self.share_p: Optional[str] = None + + def get_message(self) -> str: + """Return the prover's first SPAKE2 message.""" + point = _add_points( + _multiply_point(_GENERATOR, self._x), + _multiply_point(_decode_point(_SPAKE2_M), self._w0), + ) + self._message1_point = point + self.share_p = _encode_point(point) + return self.share_p + + def finish(self, server_message_hex: str) -> _ClientSharedSecret: + """Finish the handshake using the verifier's first message.""" + if self._message1_point is None or self.share_p is None: + raise ValueError("get_message must be called before finish") + + server_point = _decode_point(server_message_hex) + if server_point.is_infinity: + raise ValueError("invalid curve point") + + adjusted = _add_points( + server_point, + _negate(_multiply_point(_decode_point(_SPAKE2_N), self._w0)), + ) + y_point = _multiply_point(adjusted, self._x) + v_point = _multiply_point(adjusted, self._w1) + transcript = _concat_length_prefixed( + _SPAKE2_CONTEXT, + _CLIENT_IDENTITY, + _SERVER_IDENTITY, + bytes.fromhex(_encode_point(_decode_point(_SPAKE2_M))), + bytes.fromhex(_encode_point(_decode_point(_SPAKE2_N))), + bytes.fromhex(_encode_point(self._message1_point)), + bytes.fromhex(_encode_point(server_point)), + bytes.fromhex(_encode_point(y_point)), + bytes.fromhex(_encode_point(v_point)), + _int_to_bytes(self._w0), + ) + return _ClientSharedSecret( + transcript=transcript, + share_p=self.share_p, + share_v=server_message_hex, + ) + + +class _ServerHandshake: + """Server-side SPAKE2 handshake state used by the local test helper.""" + + def __init__( + self, + *, + y_scalar: int, + w0: int, + verifier_point: _Point, + ) -> None: + """Initialize the verifier handshake with its scalar and verifier point.""" + self._y = y_scalar + self._w0 = w0 + self._verifier_point = verifier_point + self._message1_point: Optional[_Point] = None + self.share_v: Optional[str] = None + + def get_message(self) -> str: + """Return the verifier's first SPAKE2 message.""" + point = _add_points( + _multiply_point(_GENERATOR, self._y), + _multiply_point(_decode_point(_SPAKE2_N), self._w0), + ) + self._message1_point = point + self.share_v = _encode_point(point) + return self.share_v + + def finish(self, client_message_hex: str) -> _ServerSharedSecret: + """Finish the verifier handshake using the prover's first message.""" + if self._message1_point is None or self.share_v is None: + raise ValueError("get_message must be called before finish") + + client_point = _decode_point(client_message_hex) + if client_point.is_infinity: + raise ValueError("invalid curve point") + + adjusted = _add_points( + client_point, + _negate(_multiply_point(_decode_point(_SPAKE2_M), self._w0)), + ) + y_point = _multiply_point(adjusted, self._y) + verifier_share = _multiply_point(self._verifier_point, self._y) + transcript = _concat_length_prefixed( + _SPAKE2_CONTEXT, + _CLIENT_IDENTITY, + _SERVER_IDENTITY, + bytes.fromhex(_encode_point(_decode_point(_SPAKE2_M))), + bytes.fromhex(_encode_point(_decode_point(_SPAKE2_N))), + bytes.fromhex(_encode_point(client_point)), + bytes.fromhex(_encode_point(self._message1_point)), + bytes.fromhex(_encode_point(y_point)), + bytes.fromhex(_encode_point(verifier_share)), + _int_to_bytes(self._w0), + ) + return _ServerSharedSecret( + transcript=transcript, + share_p=client_message_hex, + share_v=self.share_v, + ) + + +def _compute_w0_w1(password: str, salt_b64: str) -> tuple[int, int]: + """Derive the SPAKE2 scalars from the user code and bridge salt.""" + derived = hashlib.scrypt( + password.encode("utf-8"), + salt=_b64_to_bytes(salt_b64), + **_SCRYPT_PARAMS, + ) + midpoint = len(derived) // 2 + return ( + int.from_bytes(derived[:midpoint], "big"), + int.from_bytes(derived[midpoint:], "big"), + ) + + +def _random_nonzero_scalar() -> int: + """Return a random scalar in the non-zero P-256 subgroup range.""" + scalar = 0 + while scalar == 0: + scalar = secrets.randbelow(_P256_ORDER) + return scalar + + +class TrustedDeviceBridgeProver: + """Client-side prover mirroring Apple's prover worker.""" + + def __init__(self) -> None: + """Initialize empty prover state for one bridge verification attempt.""" + self._client: Optional[_ClientHandshake] = None + self._shared_secret: Optional[_ClientSharedSecret] = None + self._raw_key: Optional[str] = None + self._verified = False + self._verifier_key: Optional[str] = None + self._prover_key: Optional[str] = None + + def init_with_salt(self, salt_b64: str, code: str) -> None: + """Initialize the prover with Apple's salt and the user-entered code.""" + w0, w1 = _compute_w0_w1(code, salt_b64) + self._client = _ClientHandshake( + x_scalar=_random_nonzero_scalar(), + w0=w0, + w1=w1, + ) + self._shared_secret = None + self._raw_key = None + self._verified = False + self._verifier_key = None + self._prover_key = None + + def get_message1(self) -> str: + """Return the prover's first bridge message.""" + if self._client is None: + raise ValueError("init_with_salt must be called before get_message1") + return self._client.get_message() + + def process_message1(self, message_hex: str) -> str: + """Process Apple's first bridge message and return the prover confirmation.""" + if self._client is None: + raise ValueError("init_with_salt must be called before process_message1") + self._shared_secret = self._client.finish(message_hex) + return self.get_message2() + + def get_message2(self) -> str: + """Return the prover confirmation generated from the shared transcript.""" + if self._shared_secret is None: + raise ValueError("process_message1 must be called before get_message2") + return self._shared_secret.get_confirmation() + + def process_message2(self, message_hex: str) -> dict[str, object]: + """Verify Apple's confirmation and persist the derived bridge keys.""" + if self._shared_secret is None: + raise ValueError("process_message1 must be called before process_message2") + raw_key = self._shared_secret.verify(message_hex).hex() + self._raw_key = raw_key + self._verifier_key, self._prover_key = _derive_prover_and_verifier_keys(raw_key) + self._verified = True + return {"isVerified": True, "key": raw_key} + + def is_verified(self) -> bool: + """Return whether the bridge confirmation exchange has completed.""" + return self._verified + + def get_key(self) -> str: + """Return the raw shared bridge key as hexadecimal.""" + if self._raw_key is None: + raise ValueError("No bridge key is available yet.") + return self._raw_key + + def decrypt_message(self, ciphertext_b64: str) -> str: + """Decrypt Apple's final encrypted validation code.""" + if self._verifier_key is None: + raise ValueError("Bridge verifier key is not available.") + try: + payload = _b64_to_bytes(ciphertext_b64) + version = payload[0] + iv_length, tag_length = _AES_GCM_LAYOUTS[version] + iv = payload[1 : 1 + iv_length] + tag = payload[1 + iv_length : 1 + iv_length + tag_length] + ciphertext = payload[1 + iv_length + tag_length :] + plaintext = AESGCM(bytes.fromhex(self._verifier_key)).decrypt( + iv, + ciphertext + tag, + bytes([version]), + ) + return plaintext.decode("utf-8") + except (IndexError, KeyError, InvalidTag, UnicodeDecodeError) as exc: + raise ValueError("Malformed bridge payload") from exc + + +class _TrustedDeviceBridgeServerProver: + """Internal test helper mirroring Apple's server-side bridge flow.""" + + def __init__(self, *, password: str, salt_b64: str) -> None: + """Initialize the local verifier helper with the same password and salt.""" + w0, w1 = _compute_w0_w1(password, salt_b64) + verifier_point = _multiply_point(_GENERATOR, w1) + self._server = _ServerHandshake( + y_scalar=_random_nonzero_scalar(), + w0=w0, + verifier_point=verifier_point, + ) + self._shared_secret: Optional[_ServerSharedSecret] = None + self._raw_key: Optional[str] = None + self._verifier_key: Optional[str] = None + self._prover_key: Optional[str] = None + + def get_message1(self) -> str: + """Return the verifier's first bridge message.""" + return self._server.get_message() + + def process_message1(self, client_message_hex: str) -> str: + """Process the prover message and return the verifier confirmation.""" + self._shared_secret = self._server.finish(client_message_hex) + return self.get_message2() + + def get_message2(self) -> str: + """Return the verifier confirmation generated from the shared transcript.""" + if self._shared_secret is None: + raise ValueError("process_message1 must be called before get_message2") + return self._shared_secret.get_confirmation() + + def verify_message2(self, message_hex: str) -> str: + """Verify the prover confirmation and persist the derived bridge keys.""" + if self._shared_secret is None: + raise ValueError("process_message1 must be called before verify_message2") + raw_key = self._shared_secret.verify(message_hex).hex() + self._raw_key = raw_key + self._verifier_key, self._prover_key = _derive_prover_and_verifier_keys(raw_key) + return raw_key + + def encrypt_message(self, plaintext: str) -> str: + """Encrypt a plaintext validation code using Apple's AES-GCM payload layout.""" + if self._verifier_key is None: + raise ValueError("Bridge verifier key is not available.") + version = 0 + iv_length, tag_length = _AES_GCM_LAYOUTS[version] + iv = secrets.token_bytes(iv_length) + encrypted = AESGCM(bytes.fromhex(self._verifier_key)).encrypt( + iv, + plaintext.encode("utf-8"), + bytes([version]), + ) + ciphertext = encrypted[:-tag_length] + tag = encrypted[-tag_length:] + payload = bytes([version]) + iv + tag + ciphertext + return _bytes_to_b64(payload) diff --git a/pyicloud/session.py b/pyicloud/session.py index b696bfd9..c134d25b 100644 --- a/pyicloud/session.py +++ b/pyicloud/session.py @@ -33,6 +33,30 @@ from pyicloud.base import PyiCloudService +NON_PERSISTED_SESSION_KEYS = frozenset( + { + "akdata", + "connection_path", + "data", + "encryptedCode", + "encrypted_code", + "idmsdata", + "mid", + "nextStep", + "next_step", + "ptkn", + "push_token", + "salt", + "sessionUUID", + "session_uuid", + "source_app_id", + "topic", + "topics_by_hash", + "txnid", + } +) + + class PyiCloudSession(requests.Session): """iCloud session.""" @@ -44,6 +68,7 @@ def __init__( verify: bool = False, headers: Optional[dict[str, str]] = None, ) -> None: + """Initialize the persisted requests session used by the service.""" super().__init__() self._service: PyiCloudService = service @@ -102,7 +127,14 @@ def _save_session_data(self) -> None: os.makedirs(self._cookie_directory, exist_ok=True) with open(self.session_path, "w", encoding="utf-8") as outfile: # Copy to avoid dict mutation during concurrent access - dump(dict(self._data), outfile) + dump( + { + key: value + for key, value in dict(self._data).items() + if key not in NON_PERSISTED_SESSION_KEYS + }, + outfile, + ) self.logger.debug("Saved session data to file: %s", self.session_path) try: @@ -143,6 +175,7 @@ def _update_session_data(self, response: Response) -> None: self._data.update({session_arg: response.headers.get(header)}) def _is_json_response(self, response: Response) -> bool: + """Return whether a response advertises one of the accepted JSON mimetypes.""" content_type: str = response.headers.get(CONTENT_TYPE, "") json_mimetypes: list[str] = [ CONTENT_TYPE_JSON, @@ -169,6 +202,7 @@ def request( cert=None, json=None, ) -> Response: + """Dispatch a request through the normalized session request pipeline.""" return self._request( method, url, @@ -188,6 +222,71 @@ def request( json=json, ) + def request_raw( + self, + method, + url, + params=None, + data=None, + headers=None, + cookies=None, + files=None, + auth=None, + timeout=None, + allow_redirects=True, + proxies=None, + hooks=None, + stream=None, + verify=None, + cert=None, + json=None, + ) -> Response: + """Dispatch a request without response-status normalization.""" + + return self._request_raw( + method, + url, + params=params, + data=data, + headers=headers, + cookies=cookies, + files=files, + auth=auth, + timeout=timeout, + allow_redirects=allow_redirects, + proxies=proxies, + hooks=hooks, + stream=stream, + verify=verify, + cert=cert, + json=json, + ) + + def _request_raw( + self, + method, + url, + **kwargs, + ) -> Response: + """Perform a request and persist cookies/session data without raising.""" + + self.logger.debug( + "%s %s", + method, + url, + ) + try: + response: Response = super().request( + method=method, + url=url, + **kwargs, + ) + except requests.exceptions.RequestException as err: + self._raise_request_exception(err) + self._update_session_data(response) + self._save_session_data() + return response + def _request( self, method, @@ -236,13 +335,19 @@ def _request( self._decode_json_response(response) return response - except requests.HTTPError as err: + except requests.exceptions.RequestException as err: + self._raise_request_exception(err) + + @staticmethod + def _raise_request_exception(err: requests.exceptions.RequestException) -> NoReturn: + """Normalize low-level requests failures into the session's public error type.""" + + if isinstance(err, requests.HTTPError) and err.response is not None: raise PyiCloudAPIResponseException( reason=err.response.text, code=err.response.status_code, ) from err - except requests.exceptions.RequestException as err: - raise PyiCloudAPIResponseException("Request failed to iCloud") from err + raise PyiCloudAPIResponseException("Request failed to iCloud") from err def _handle_request_error( self, @@ -297,6 +402,7 @@ def _decode_json_response(self, response: Response) -> None: def _raise_error( self, response: Response, code: Optional[Union[int, str]], reason: str ) -> NoReturn: + """Raise the session's public exception for a parsed iCloud error payload.""" if ( self.service.requires_2sa and reason == "Missing X-APPLE-WEBAUTH-TOKEN cookie" diff --git a/requirements.txt b/requirements.txt index 06f2a422..6e0bc574 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ certifi>=2024.12.14 click>=8.1.8 +cryptography>=44.0.0 fido2>=2.0.0 keyring>=25.6.0 keyrings.alt>=5.0.2 diff --git a/tests/test_base.py b/tests/test_base.py index 01f8bdfe..0bebaf05 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -5,10 +5,13 @@ import json import secrets +import tempfile +from pathlib import Path from typing import Any, List from unittest.mock import MagicMock, mock_open, patch import pytest +import requests from fido2.hid import CtapHidDevice from requests import HTTPError, Response @@ -21,6 +24,8 @@ PyiCloudFailedLoginException, PyiCloudServiceNotActivatedException, PyiCloudServiceUnavailable, + PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, ) from pyicloud.services.calendar import CalendarService from pyicloud.services.contacts import ContactsService @@ -257,6 +262,427 @@ def test_validate_2fa_code(pyicloud_service: PyiCloudService) -> None: assert pyicloud_service.validate_2fa_code("123456") +def test_validate_2fa_code_uses_bridge_verifier_for_step2_state( + pyicloud_service: PyiCloudService, +) -> None: + """Bridge-backed trusted-device prompts should use the bridge verifier instead of the legacy endpoint.""" + + pyicloud_service.data = {"dsInfo": {"hsaVersion": 2}, "hsaChallengeRequired": False} + pyicloud_service._two_factor_delivery_method = "trusted_device" + bridge_state = MagicMock(uses_legacy_trusted_device_verifier=False) + pyicloud_service._trusted_device_bridge_state = bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.validate_code.return_value = True + pyicloud_service.trust_session = MagicMock( + side_effect=lambda: pyicloud_service.data.update({"hsaTrustedBrowser": True}) + or True + ) + pyicloud_service._session = MagicMock() + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.validate_2fa_code("123456") is True + + pyicloud_service._trusted_device_bridge.validate_code.assert_called_once() + pyicloud_service.session.post.assert_not_called() + pyicloud_service._trusted_device_bridge.close.assert_called_once_with(bridge_state) + pyicloud_service.trust_session.assert_called_once_with() + + +def test_validate_2fa_code_keeps_legacy_endpoint_for_bridge_w_subtype( + pyicloud_service: PyiCloudService, +) -> None: + """Apple's `_W` bridge subtype should keep using the legacy trusted-device verifier.""" + + pyicloud_service.data = {"dsInfo": {"hsaVersion": 2}, "hsaChallengeRequired": False} + pyicloud_service._two_factor_delivery_method = "trusted_device" + bridge_state = MagicMock(uses_legacy_trusted_device_verifier=True) + pyicloud_service._trusted_device_bridge_state = bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service.trust_session = MagicMock( + side_effect=lambda: pyicloud_service.data.update({"hsaTrustedBrowser": True}) + or True + ) + pyicloud_service._session = MagicMock() + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + pyicloud_service.session.post.return_value = MagicMock(status_code=200) + + assert pyicloud_service.validate_2fa_code("123456") is True + + pyicloud_service._trusted_device_bridge.validate_code.assert_not_called() + args = pyicloud_service.session.post.call_args.args + assert args[0] == ( + f"{pyicloud_service._auth_endpoint}/verify/trusteddevice/securitycode" + ) + pyicloud_service._trusted_device_bridge.close.assert_called_once_with(bridge_state) + + +def test_validate_2fa_code_bridge_verification_exception_propagates( + pyicloud_service: PyiCloudService, +) -> None: + """Bridge verification failures should not be downgraded to generic invalid-code results.""" + + pyicloud_service._two_factor_delivery_method = "trusted_device" + bridge_state = MagicMock(uses_legacy_trusted_device_verifier=False) + pyicloud_service._trusted_device_bridge_state = bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.validate_code.side_effect = ( + PyiCloudTrustedDeviceVerificationException("bridge verification failed") + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="bridge verification failed", + ): + pyicloud_service.validate_2fa_code("123456") + + pyicloud_service._trusted_device_bridge.close.assert_called_once_with(bridge_state) + + +def test_request_2fa_code_requests_sms_delivery( + pyicloud_service: PyiCloudService, +) -> None: + """Nested phone verification data should trigger SMS delivery.""" + + pyicloud_service._auth_data = { + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + } + } + + with patch("pyicloud.base.PyiCloudSession") as mock_session: + pyicloud_service._session = mock_session + mock_session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.request_2fa_code() is True + + args = mock_session.put.call_args.args + kwargs = mock_session.put.call_args.kwargs + assert args[0] == f"{pyicloud_service._auth_endpoint}/verify/phone" + assert kwargs["json"] == { + "phoneNumber": {"id": 3, "nonFTEU": False}, + "mode": "sms", + } + assert kwargs["headers"]["Accept"] == "application/json" + + +def test_get_mfa_auth_options_parses_hsa2_boot_html( + pyicloud_service: PyiCloudService, +) -> None: + """GET /appleauth/auth HTML should populate the HSA2 boot context.""" + + response = MagicMock() + response.json.side_effect = ValueError("not json") + response.text = """ + + + + """ + pyicloud_service._session = MagicMock() + pyicloud_service.session.get.return_value = response + + auth_options = pyicloud_service._get_mfa_auth_options() + + _, kwargs = pyicloud_service.session.get.call_args + assert kwargs["headers"]["Accept"] == "text/html" + assert auth_options["authInitialRoute"] == "auth/bridge/step" + assert auth_options["hasTrustedDevices"] is True + assert auth_options["authFactors"] == ["web_piggybacking", "sms"] + assert auth_options["bridgeInitiateData"]["webSocketUrl"] == ( + "websocket.push.apple.com" + ) + assert auth_options["phoneNumberVerification"]["trustedPhoneNumber"]["id"] == 3 + assert auth_options["sourceAppId"] == "1159" + assert pyicloud_service._hsa2_boot_context is not None + assert pyicloud_service._hsa2_boot_context.auth_initial_route == ( + "auth/bridge/step" + ) + assert pyicloud_service._hsa2_boot_context.has_trusted_devices is True + + +def test_request_2fa_code_prefers_trusted_device_bridge( + pyicloud_service: PyiCloudService, +) -> None: + """Request-7 style HSA2 challenges should start the bridge before SMS.""" + + pyicloud_service.data = { + "dsInfo": {"hsaVersion": 2}, + "hsaChallengeRequired": True, + "hsaTrustedBrowser": False, + } + pyicloud_service._auth_data = { + "authInitialRoute": "auth/bridge/step", + "hasTrustedDevices": True, + "authFactors": ["web_piggybacking", "sms"], + "bridgeInitiateData": { + "apnsTopic": "com.apple.idmsauthwidget", + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + }, + } + + bridge_state = MagicMock() + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.start.return_value = bridge_state + pyicloud_service._session = MagicMock() + pyicloud_service.session.headers = {"User-Agent": "test-agent"} + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.request_2fa_code() is True + + pyicloud_service._trusted_device_bridge.start.assert_called_once() + pyicloud_service.session.put.assert_not_called() + assert pyicloud_service.two_factor_delivery_method == "trusted_device" + assert pyicloud_service._trusted_device_bridge_state is bridge_state + + +def test_request_2fa_code_replaces_existing_bridge_state_before_restart( + pyicloud_service: PyiCloudService, +) -> None: + """Starting a new bridge prompt should close any previous in-memory bridge session.""" + + pyicloud_service._auth_data = { + "authInitialRoute": "auth/bridge/step", + "hasTrustedDevices": True, + "bridgeInitiateData": { + "apnsTopic": "com.apple.idmsauthwidget", + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + } + + previous_bridge_state = MagicMock() + next_bridge_state = MagicMock() + pyicloud_service._trusted_device_bridge_state = previous_bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.start.return_value = next_bridge_state + pyicloud_service._session = MagicMock() + pyicloud_service.session.headers = {"User-Agent": "test-agent"} + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.request_2fa_code() is True + + pyicloud_service._trusted_device_bridge.close.assert_called_once_with( + previous_bridge_state + ) + assert pyicloud_service._trusted_device_bridge_state is next_bridge_state + + +def test_request_2fa_code_falls_back_to_sms_when_bridge_fails( + pyicloud_service: PyiCloudService, +) -> None: + """Bridge bootstrap failures should fall back to SMS when Apple exposes it.""" + + pyicloud_service._auth_data = { + "authInitialRoute": "auth/bridge/step", + "hasTrustedDevices": True, + "bridgeInitiateData": { + "apnsTopic": "com.apple.idmsauthwidget", + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + }, + } + + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._trusted_device_bridge.start.side_effect = ( + PyiCloudTrustedDevicePromptException("bridge failed") + ) + pyicloud_service._session = MagicMock() + pyicloud_service.session.headers = {"User-Agent": "test-agent"} + pyicloud_service.session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + } + + assert pyicloud_service.request_2fa_code() is True + + args = pyicloud_service.session.put.call_args.args + kwargs = pyicloud_service.session.put.call_args.kwargs + assert args[0] == f"{pyicloud_service._auth_endpoint}/verify/phone" + assert kwargs["json"] == { + "phoneNumber": {"id": 3, "nonFTEU": False}, + "mode": "sms", + } + assert pyicloud_service.two_factor_delivery_method == "sms" + assert pyicloud_service.two_factor_delivery_notice == ( + "Trusted-device prompt failed; falling back to SMS." + ) + + +def test_request_2fa_code_keeps_security_key_path_separate( + pyicloud_service: PyiCloudService, +) -> None: + """Security-key challenges should not start the bridge or SMS flows.""" + + pyicloud_service._auth_data = { + "fsaChallenge": {"challenge": "abc"}, + "authInitialRoute": "auth/bridge/step", + "hasTrustedDevices": True, + "bridgeInitiateData": { + "apnsTopic": "com.apple.idmsauthwidget", + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + }, + } + + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service._session = MagicMock() + pyicloud_service.session.headers = {"User-Agent": "test-agent"} + + assert pyicloud_service.request_2fa_code() is False + + pyicloud_service._trusted_device_bridge.start.assert_not_called() + pyicloud_service.session.put.assert_not_called() + assert pyicloud_service.two_factor_delivery_method == "security_key" + + +def test_validate_2fa_code_uses_nested_sms_phone_number( + pyicloud_service: PyiCloudService, +) -> None: + """Nested phone verification data should validate via the SMS endpoint.""" + + pyicloud_service.data = {"dsInfo": {"hsaVersion": 1}, "hsaChallengeRequired": False} + pyicloud_service._auth_data = { + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": "sms", + } + } + } + pyicloud_service.trust_session = MagicMock( + side_effect=lambda: pyicloud_service.data.update({"hsaTrustedBrowser": True}) + or True + ) + + with patch("pyicloud.base.PyiCloudSession") as mock_session: + pyicloud_service._session = mock_session + mock_session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + "session_token": "test_session_token", + } + + mock_post_response = MagicMock() + mock_post_response.status_code = 200 + mock_post_response.json.return_value = {"success": True} + mock_session.post.return_value = mock_post_response + + assert pyicloud_service.validate_2fa_code("123456") + + args = mock_session.post.call_args.args + kwargs = mock_session.post.call_args.kwargs + assert args[0] == f"{pyicloud_service._auth_endpoint}/verify/phone/securitycode" + assert kwargs["json"] == { + "phoneNumber": {"id": 3, "nonFTEU": False}, + "securityCode": {"code": "123456"}, + "mode": "sms", + } + + +def test_validate_2fa_code_defaults_sms_mode_when_push_mode_missing( + pyicloud_service: PyiCloudService, +) -> None: + """Missing SMS pushMode should still validate using the delivery mode used to trigger SMS.""" + + pyicloud_service.data = {"dsInfo": {"hsaVersion": 1}, "hsaChallengeRequired": False} + pyicloud_service._auth_data = { + "phoneNumberVerification": { + "trustedPhoneNumber": { + "id": 3, + "nonFTEU": False, + "pushMode": None, + } + } + } + pyicloud_service._two_factor_delivery_method = "sms" + pyicloud_service.trust_session = MagicMock( + side_effect=lambda: pyicloud_service.data.update({"hsaTrustedBrowser": True}) + or True + ) + + with patch("pyicloud.base.PyiCloudSession") as mock_session: + pyicloud_service._session = mock_session + mock_session.data = { + "scnt": "test_scnt", + "session_id": "test_session_id", + "session_token": "test_session_token", + } + + mock_post_response = MagicMock() + mock_post_response.status_code = 200 + mock_post_response.json.return_value = {"success": True} + mock_session.post.return_value = mock_post_response + + assert pyicloud_service.validate_2fa_code("123456") + + kwargs = mock_session.post.call_args.kwargs + assert kwargs["json"]["mode"] == "sms" + + def test_validate_2fa_code_failure(pyicloud_service: PyiCloudService) -> None: """Test the validate_2fa_code method with an invalid code.""" exception = PyiCloudAPIResponseException("Invalid code") @@ -431,6 +857,24 @@ def test_logout_clears_authenticated_state( assert pyicloud_service._devices is None +def test_logout_closes_active_trusted_device_bridge_state( + pyicloud_service: PyiCloudService, +) -> None: + """Logout should close any active trusted-device bridge session before clearing state.""" + + bridge_state = MagicMock() + pyicloud_service._trusted_device_bridge_state = bridge_state + pyicloud_service._trusted_device_bridge = MagicMock() + pyicloud_service.session.cookies = MagicMock() + pyicloud_service.session.cookies.get.return_value = None + pyicloud_service.session.clear_persistence = MagicMock() + + pyicloud_service.logout() + + pyicloud_service._trusted_device_bridge.close.assert_called_once_with(bridge_state) + assert pyicloud_service._trusted_device_bridge_state is None + + def test_cookiejar_path_property(pyicloud_session: PyiCloudSession) -> None: """Test the cookiejar_path property.""" path: str = pyicloud_session.cookiejar_path @@ -557,6 +1001,59 @@ def test_request_success(pyicloud_service_working: PyiCloudService) -> None: ) +def test_session_persistence_excludes_trusted_device_bridge_state( + pyicloud_service_working: PyiCloudService, +) -> None: + """Bridge-only state should remain in memory and never be written to persisted session files.""" + + test_base = Path(tempfile.gettempdir()) / "python-test-results" + test_base.mkdir(parents=True, exist_ok=True) + temp_root = Path(tempfile.mkdtemp(prefix="bridge-auth-", dir=test_base)) + session = PyiCloudSession( + service=pyicloud_service_working, + client_id="", + cookie_directory=str(temp_root), + ) + pyicloud_service_working._session = session + bridge_state = MagicMock( + push_token="bridge-ptkn", + session_uuid="bridge-session-uuid", + idmsdata="bridge-idmsdata", + encrypted_code="bridge-encrypted-code", + ) + pyicloud_service_working._trusted_device_bridge_state = bridge_state + session._data = { + "session_token": "valid-token", + "session_id": "persisted-session-id", + "push_token": bridge_state.push_token, + "session_uuid": bridge_state.session_uuid, + "idmsdata": bridge_state.idmsdata, + "encrypted_code": bridge_state.encrypted_code, + } + + session._save_session_data() + + persisted_session = Path(session.session_path).read_text(encoding="utf-8") + for secret_value in ( + "bridge-ptkn", + "bridge-session-uuid", + "bridge-idmsdata", + "bridge-encrypted-code", + ): + assert secret_value not in persisted_session + + cookiejar_path = Path(session.cookiejar_path) + if cookiejar_path.exists(): + persisted_cookiejar = cookiejar_path.read_text(encoding="utf-8") + for secret_value in ( + "bridge-ptkn", + "bridge-session-uuid", + "bridge-idmsdata", + "bridge-encrypted-code", + ): + assert secret_value not in persisted_cookiejar + + def test_request_failure(pyicloud_service_working: PyiCloudService) -> None: """Test the request method with a failure response.""" @@ -601,6 +1098,26 @@ def test_request_failure(pyicloud_service_working: PyiCloudService) -> None: assert open_mock.call_count == 2 +def test_request_raw_normalizes_transport_failure( + pyicloud_service_working: PyiCloudService, +) -> None: + """Raw requests should keep the session's normalized transport failure contract.""" + + with patch("requests.Session.request") as mock_request: + mock_request.side_effect = requests.exceptions.Timeout("timed out") + test_base = Path(tempfile.gettempdir()) / "python-test-results" + test_base.mkdir(parents=True, exist_ok=True) + temp_root = Path(tempfile.mkdtemp(prefix="request-raw-", dir=test_base)) + pyicloud_session = PyiCloudSession( + pyicloud_service_working, "", cookie_directory=str(temp_root) + ) + + with pytest.raises( + PyiCloudAPIResponseException, match="Request failed to iCloud" + ): + pyicloud_session.request_raw("GET", "https://example.com") + + def test_request_with_custom_headers(pyicloud_service_working: PyiCloudService) -> None: """Test the request method with custom headers.""" with ( diff --git a/tests/test_cmdline.py b/tests/test_cmdline.py index 64dbc87a..bc92bc6c 100644 --- a/tests/test_cmdline.py +++ b/tests/test_cmdline.py @@ -10,7 +10,7 @@ from pathlib import Path from types import SimpleNamespace from typing import Any, Optional -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch from uuid import uuid4 import click @@ -2347,6 +2347,195 @@ def test_trusted_device_2sa_flow() -> None: fake_api.trusted_devices[0], "123456" ) +def test_sms_2fa_flow_requests_sms_before_prompt() -> None: + """Auth login should request SMS delivery before prompting for the code.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.two_factor_delivery_method = "sms" + fake_api.request_2fa_code.return_value = True + with patch.object(context_module.typer, "prompt", return_value="123456"): + result = _invoke(fake_api, "auth", "login", interactive=True) + assert result.exit_code == 0 + assert "Requested a 2FA code by SMS." in result.stdout + fake_api.request_2fa_code.assert_called_once_with() + fake_api.validate_2fa_code.assert_called_once_with("123456") + + +def test_trusted_device_2fa_flow_reports_device_prompt() -> None: + """Auth login should report trusted-device prompt delivery when bridge succeeds.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + + def request_prompt() -> bool: + fake_api.two_factor_delivery_method = "trusted_device" + return True + + fake_api.request_2fa_code.side_effect = request_prompt + + with patch.object(context_module.typer, "prompt", return_value="123456"): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code == 0 + assert "Requested a 2FA prompt on your trusted Apple devices." in result.stdout + fake_api.validate_2fa_code.assert_called_once_with("123456") + + +def test_code_prompt_aborts_when_request_2fa_code_requires_security_key() -> None: + """Auth login should not enter the numeric 2FA prompt loop for key-only challenges.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.request_2fa_code.return_value = False + + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == ( + "This 2FA challenge requires a security key. Connect one and retry." + ) + fake_api.validate_2fa_code.assert_not_called() + + +def test_trusted_device_2fa_retries_invalid_codes_before_success() -> None: + """Auth login should allow up to three trusted-device 2FA attempts.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + + def request_prompt() -> bool: + fake_api.two_factor_delivery_method = "trusted_device" + return True + + fake_api.request_2fa_code.side_effect = request_prompt + fake_api.validate_2fa_code.side_effect = [False, False, True] + + with patch.object( + context_module.typer, + "prompt", + side_effect=["111111", "222222", "333333"], + ): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code == 0 + assert "Invalid 2FA code. 2 attempt(s) remaining." in result.stdout + assert "Invalid 2FA code. 1 attempt(s) remaining." in result.stdout + assert fake_api.validate_2fa_code.call_args_list == [ + call("111111"), + call("222222"), + call("333333"), + ] + + +def test_sms_2fa_aborts_after_three_invalid_codes() -> None: + """Auth login should stop after three invalid 2FA attempts.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.two_factor_delivery_method = "sms" + fake_api.request_2fa_code.return_value = True + fake_api.validate_2fa_code.side_effect = [False, False, False] + + with patch.object( + context_module.typer, + "prompt", + side_effect=["111111", "222222", "333333"], + ): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == "Failed to verify the 2FA code." + assert "Invalid 2FA code. 2 attempt(s) remaining." in result.stdout + assert "Invalid 2FA code. 1 attempt(s) remaining." in result.stdout + assert fake_api.validate_2fa_code.call_args_list == [ + call("111111"), + call("222222"), + call("333333"), + ] + + +def test_trusted_device_2fa_bridge_fallback_reports_notice() -> None: + """Auth login should print the bridge fallback notice before the SMS message.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + + def request_sms_fallback() -> bool: + fake_api.two_factor_delivery_method = "sms" + fake_api.two_factor_delivery_notice = ( + "Trusted-device prompt failed; falling back to SMS." + ) + return True + + fake_api.request_2fa_code.side_effect = request_sms_fallback + + with patch.object(context_module.typer, "prompt", return_value="123456"): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code == 0 + assert "Trusted-device prompt failed; falling back to SMS." in result.stdout + assert "Requested a 2FA code by SMS." in result.stdout + fake_api.validate_2fa_code.assert_called_once_with("123456") + + +def test_sms_2fa_request_failure_aborts() -> None: + """Auth login should surface SMS delivery request failures clearly.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.request_2fa_code.side_effect = context_module.PyiCloudAPIResponseException( + "sms request failed" + ) + + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == "Failed to request the 2FA SMS code." + fake_api.validate_2fa_code.assert_not_called() + + +def test_trusted_device_2fa_request_failure_aborts() -> None: + """Auth login should surface bridge delivery failures clearly.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + fake_api.request_2fa_code.side_effect = ( + context_module.PyiCloudTrustedDevicePromptException("bridge failed") + ) + + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == ( + "Failed to request the 2FA trusted-device prompt." + ) + fake_api.validate_2fa_code.assert_not_called() + + +def test_trusted_device_2fa_verification_failure_aborts() -> None: + """Auth login should surface bridge verification failures clearly.""" + + fake_api = FakeAPI() + fake_api.requires_2fa = True + + def request_prompt() -> bool: + fake_api.two_factor_delivery_method = "trusted_device" + return True + + fake_api.request_2fa_code.side_effect = request_prompt + fake_api.validate_2fa_code.side_effect = ( + context_module.PyiCloudTrustedDeviceVerificationException( + "bridge verification failed" + ) + ) + + with patch.object(context_module.typer, "prompt", return_value="123456"): + result = _invoke(fake_api, "auth", "login", interactive=True) + + assert result.exit_code != 0 + assert result.exception.args[0] == ("Failed to verify the 2FA trusted-device code.") + def test_notes_commands() -> None: """Notes commands should expose list, detail, render, export, and sync flows.""" diff --git a/tests/test_hsa2_bridge.py b/tests/test_hsa2_bridge.py new file mode 100644 index 00000000..2e64f398 --- /dev/null +++ b/tests/test_hsa2_bridge.py @@ -0,0 +1,1357 @@ +"""Tests for the HSA2 trusted-device bridge helpers.""" + +from __future__ import annotations + +import base64 +import json +import socket +from binascii import unhexlify +from typing import Callable +from unittest.mock import MagicMock, call + +import pytest + +import pyicloud.hsa2_bridge as bridge_module +from pyicloud.exceptions import ( + PyiCloudTrustedDevicePromptException, + PyiCloudTrustedDeviceVerificationException, +) +from pyicloud.hsa2_bridge import ( + BRIDGE_DONE_DATA_B64, + BridgePushPayload, + Hsa2BootContext, + TrustedDeviceBridgeBootstrapper, + _encode_ack_message, + _encode_bytes_field, + _encode_string_field, + _encode_uint32_field, + _encode_web_filter_message, + _extract_json_payload, + _hex_to_b64, + _topic_hash, + parse_boot_args_html, +) +from pyicloud.hsa2_bridge_prover import ( + TrustedDeviceBridgeProver, + _TrustedDeviceBridgeServerProver, +) + + +class _FakeWebSocket: + def __init__( + self, + messages: list[bytes | Exception], + *, + on_read: Callable[[int], None] | None = None, + ) -> None: + self._messages = list(messages) + self._on_read = on_read + self.sent_messages: list[bytes] = [] + self.closed = False + self.read_count = 0 + + def send_binary(self, payload: bytes) -> None: + self.sent_messages.append(payload) + + def read_message(self) -> bytes: + self.read_count += 1 + if self._on_read is not None: + self._on_read(self.read_count) + message = self._messages.pop(0) + if isinstance(message, Exception): + raise message + return message + + def close(self) -> None: + self.closed = True + + +class _FakePrivateKey: + def sign(self, nonce: bytes, _algorithm: object) -> bytes: + return b"signature-for-" + nonce[:4] + + +def _encode_connection_response(push_token: bytes) -> bytes: + payload = b"".join( + [ + _encode_string_field(1, base64.b64encode(push_token).decode("ascii")), + _encode_uint32_field(2, 0), + ] + ) + return _encode_bytes_field(1, payload) + + +def _encode_connection_response_with_token_b64(push_token_b64: str) -> bytes: + payload = b"".join( + [ + _encode_string_field(1, push_token_b64), + _encode_uint32_field(2, 0), + ] + ) + return _encode_bytes_field(1, payload) + + +def _encode_push_message( + topic: str, payload: dict[str, object], message_id: int +) -> bytes: + topic_bytes = bytes.fromhex(_topic_hash(topic)) + body = b"".join( + [ + _encode_bytes_field(1, topic_bytes), + _encode_uint32_field(2, message_id), + _encode_bytes_field(4, json.dumps(payload).encode("utf-8")), + ] + ) + return _encode_bytes_field(2, body) + + +def _encode_channel_subscription_response(topic: str, message_id: int = 1) -> bytes: + channel_response = b"".join( + [ + _encode_string_field(1, topic), + _encode_bytes_field(2, _encode_bytes_field(1, b"channel-id")), + ] + ) + payload = _encode_bytes_field(1, channel_response) + body = b"".join( + [ + _encode_bytes_field(1, payload), + _encode_uint32_field(2, message_id), + _encode_uint32_field(3, 0), + ] + ) + return _encode_bytes_field(3, body) + + +def _read_varint(data: bytes, offset: int) -> tuple[int, int]: + value = 0 + shift = 0 + while True: + byte = data[offset] + offset += 1 + value |= (byte & 0x7F) << shift + if not (byte & 0x80): + return value, offset + shift += 7 + + +def _decode_fields(data: bytes) -> dict[int, list[int | bytes]]: + offset = 0 + fields: dict[int, list[int | bytes]] = {} + while offset < len(data): + key, offset = _read_varint(data, offset) + field_number = key >> 3 + wire_type = key & 0x07 + + if wire_type == 0: + value, offset = _read_varint(data, offset) + elif wire_type == 2: + length, offset = _read_varint(data, offset) + value = data[offset : offset + length] + offset += length + else: + raise AssertionError(f"Unexpected wire type {wire_type}") + + fields.setdefault(field_number, []).append(value) + return fields + + +def _boot_context(topic: str = "com.apple.idmsauthwidget") -> Hsa2BootContext: + return Hsa2BootContext( + auth_initial_route="auth/bridge/step", + has_trusted_devices=True, + auth_factors=("web_piggybacking", "sms"), + bridge_initiate_data={ + "apnsTopic": topic, + "apnsEnvironment": "prod", + "webSocketUrl": "websocket.push.apple.com", + }, + source_app_id="1159", + ) + + +def _response(status_code: int) -> MagicMock: + response = MagicMock() + response.status_code = status_code + response.text = "" + return response + + +def test_parse_boot_args_html_extracts_bridge_context() -> None: + """Request-5 style boot args should yield the bridge routing metadata.""" + + html = """ + + + + """ + + boot_context = parse_boot_args_html(html) + + assert boot_context.auth_initial_route == "auth/bridge/step" + assert boot_context.has_trusted_devices is True + assert boot_context.auth_factors == ( + "web_piggybacking", + "robocall", + "sms", + "generatedcode", + ) + assert boot_context.bridge_initiate_data["webSocketUrl"] == ( + "websocket.push.apple.com" + ) + assert boot_context.phone_number_verification["trustedPhoneNumber"]["id"] == 3 + assert boot_context.source_app_id == "1159" + + +def test_parse_boot_args_html_accepts_reordered_script_attributes() -> None: + """boot_args extraction should not depend on one exact script tag string.""" + + html = """ + + + + """ + + boot_context = parse_boot_args_html(html) + + assert boot_context.auth_initial_route == "auth/bridge/step" + assert boot_context.has_trusted_devices is True + assert boot_context.bridge_initiate_data["webSocketUrl"] == ( + "websocket.push.apple.com" + ) + + +def test_read_varint_rejects_malformed_overlong_varint() -> None: + """Malformed bridge varints should fail immediately instead of reading forever.""" + + with pytest.raises( + PyiCloudTrustedDevicePromptException, + match="Malformed protobuf varint", + ): + bridge_module._read_varint(b"\x80" * 10, 0) + + +def test_decode_fields_rejects_truncated_length_delimited_field() -> None: + """Length-delimited bridge fields must fit inside the current frame.""" + + with pytest.raises( + PyiCloudTrustedDevicePromptException, + match="Truncated protobuf field", + ): + bridge_module._decode_fields(b"\x0a\x05abc") + + +@pytest.mark.parametrize( + ("payload", "message"), + [ + ({"sessionUUID": 123}, "Malformed trusted-device bridge push payload"), + ({"sessionUUID": " "}, "Malformed trusted-device bridge push payload"), + ( + {"sessionUUID": "bridge-session", "nextStep": " "}, + "Malformed trusted-device bridge push payload", + ), + ( + {"sessionUUID": "bridge-session", "encryptedCode": " "}, + "Malformed trusted-device bridge push payload", + ), + ( + {"sessionUUID": "bridge-session", "ec": "oops"}, + "Malformed trusted-device bridge push payload", + ), + ], +) +def test_bridge_push_payload_rejects_malformed_fields( + payload: dict[str, object], message: str +) -> None: + """Bridge push validation should reject coerced or blank protocol fields.""" + + with pytest.raises(PyiCloudTrustedDevicePromptException, match=message): + BridgePushPayload.from_payload(payload) + + +def test_bridge_push_payload_preserves_unknown_extra_fields() -> None: + """Unknown Apple bridge fields should survive strict validation unchanged.""" + + payload = BridgePushPayload.from_payload( + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "extraField": {"foo": "bar"}, + } + ) + + assert payload.session_uuid == "bridge-session" + assert payload.payload["extraField"] == {"foo": "bar"} + + +def test_extract_json_payload_finds_embedded_json() -> None: + """Request-8 style binary payloads should yield the embedded JSON envelope.""" + + expected_payload = { + "sessionUUID": "bridge-session", + "nextStep": "2", + "ruiURLKey": "hsa2TwoFactorAuthApprovalFlowUrl", + } + noisy_payload = ( + b"\x12\xa8\x07\x00" + + json.dumps(expected_payload).encode("utf-8") + + b"\x18\x00\x01" + ) + + assert _extract_json_payload(noisy_payload) == expected_payload + + +def test_trusted_device_bridge_prover_roundtrip() -> None: + """The Python prover port should match the worker's SPAKE2+/AES-GCM flow.""" + + salt_b64 = base64.b64encode(b"0123456789abcdef").decode("ascii") + prover = TrustedDeviceBridgeProver() + server = _TrustedDeviceBridgeServerProver(password="050044", salt_b64=salt_b64) + + prover.init_with_salt(salt_b64, "050044") + client_message1 = prover.get_message1() + server_message1 = server.get_message1() + server_message2 = server.process_message1(client_message1) + client_message2 = prover.process_message1(server_message1) + server_key = server.verify_message2(client_message2) + client_key = prover.process_message2(server_message2)["key"] + + assert prover.is_verified() is True + assert client_key == server_key + encrypted_code = server.encrypt_message("derived-device-code") + assert prover.decrypt_message(encrypted_code) == "derived-device-code" + + +def test_trusted_device_bridge_prover_retries_zero_ephemeral_scalars( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Ephemeral prover scalars must stay in the non-zero subgroup range.""" + + draws = iter([0, 7, 0, 9]) + monkeypatch.setattr( + "pyicloud.hsa2_bridge_prover.secrets.randbelow", + lambda _limit: next(draws), + ) + + salt_b64 = base64.b64encode(b"0123456789abcdef").decode("ascii") + prover = TrustedDeviceBridgeProver() + prover.init_with_salt(salt_b64, "050044") + server = _TrustedDeviceBridgeServerProver(password="050044", salt_b64=salt_b64) + + assert prover._client is not None + assert prover._client._x == 7 + assert server._server._y == 9 + + +def test_trusted_device_bridge_prover_normalizes_malformed_bridge_payloads() -> None: + """Malformed encrypted payloads should surface as ValueError.""" + + prover = TrustedDeviceBridgeProver() + prover._verifier_key = "00" * 32 + + with pytest.raises(ValueError, match="Malformed bridge payload"): + prover.decrypt_message(base64.b64encode(b"").decode("ascii")) + + with pytest.raises(ValueError, match="Malformed bridge payload"): + prover.decrypt_message(base64.b64encode(b"\x01truncated").decode("ascii")) + + +def test_trusted_device_bridge_bootstrap_keeps_websocket_open_and_persists_step2() -> ( + None +): + """The bridge bootstrap should keep the websocket alive after step 0 succeeds.""" + + topic = "com.apple.idmsauthwidget" + bridge_payload = { + "sessionUUID": "bridge-session", + "nextStep": "2", + "ruiURLKey": "hsa2TwoFactorAuthApprovalFlowUrl", + "txnid": "2300_282820214_S", + "salt": "c2FsdA==", + "mid": "bridge-mid", + "idmsdata": "idms-data", + "akdata": {"lat": 49.52, "lng": 6.1}, + } + websocket_urls: list[tuple[str, float, str, str]] = [] + session = MagicMock() + session.request_raw.return_value = _response(200) + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message(topic, bridge_payload, 2300), + ], + on_read=lambda read_count: ( + read_count == 1 or session.request_raw.call_count == 1 + ) + or (_ for _ in ()).throw( + AssertionError("Bridge step 0 should be posted before waiting for push") + ), + ) + + def websocket_factory( + url: str, timeout: float, origin: str, user_agent: str + ) -> _FakeWebSocket: + websocket_urls.append((url, timeout, origin, user_agent)) + return websocket + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=websocket_factory, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + websocket_url = websocket_urls[0][0] + assert websocket_url.startswith("wss://websocket.push.apple.com/v2/") + assert state.connection_path == websocket_url.rsplit("/", 1)[1] + connection_message = unhexlify(state.connection_path) + outer_fields = _decode_fields(connection_message) + inner_fields = _decode_fields(outer_fields[1][0]) + assert inner_fields[1][0] == b"\x04public-key" + assert bytes(inner_fields[3][0]).startswith(b"\x01\x03signature-for-") + assert state.push_token == b"push-token".hex() + assert state.session_uuid == "bridge-session" + assert state.next_step == "2" + assert state.rui_url_key == "hsa2TwoFactorAuthApprovalFlowUrl" + assert state.txnid == "2300_282820214_S" + assert state.salt == "c2FsdA==" + assert state.mid == "bridge-mid" + assert state.idmsdata == "idms-data" + assert state.akdata == {"lat": 49.52, "lng": 6.1} + assert state.websocket is websocket + assert websocket.sent_messages[0] == _encode_web_filter_message([topic]) + assert websocket.sent_messages[1] == _encode_ack_message( + bytes.fromhex(_topic_hash(topic)), + 2300, + ) + session.request_raw.assert_called_once_with( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/0", + json={ + "sessionUUID": "bridge-session", + "ptkn": b"push-token".hex(), + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ) + assert websocket.closed is False + bootstrapper.close(state) + assert websocket.closed is True + assert state.websocket is None + + +def test_trusted_device_bridge_rejects_malformed_push_token() -> None: + """Malformed push tokens should surface as bridge prompt failures.""" + + websocket = _FakeWebSocket( + [_encode_connection_response_with_token_b64("%%%not-base64%%%")] + ) + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + + with pytest.raises( + PyiCloudTrustedDevicePromptException, + match="Failed to bootstrap the trusted-device bridge prompt.", + ) as exc_info: + bootstrapper.start( + session=MagicMock(), + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(), + user_agent="test-agent", + ) + + assert isinstance(exc_info.value.__cause__, PyiCloudTrustedDevicePromptException) + assert "Malformed bridge push token" in str(exc_info.value.__cause__) + assert websocket.closed is True + + +def test_trusted_device_bridge_rejects_mismatched_session_uuid() -> None: + """The first bridge push should match the session UUID used for step 0.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "different-session", + "nextStep": "2", + }, + 2300, + ), + ] + ) + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.return_value = _response(200) + + with pytest.raises( + PyiCloudTrustedDevicePromptException, + match="Failed to bootstrap the trusted-device bridge prompt.", + ) as exc_info: + bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + assert isinstance(exc_info.value.__cause__, PyiCloudTrustedDevicePromptException) + assert "mismatched session UUID" in str(exc_info.value.__cause__) + assert websocket.closed is True + + +def test_trusted_device_bridge_start_propagates_unexpected_exception() -> None: + """Unexpected bootstrap bugs should surface directly instead of being wrapped.""" + + websocket = _FakeWebSocket([TypeError("boom")]) + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + + with pytest.raises(TypeError, match="boom"): + bootstrapper.start( + session=MagicMock(), + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(), + user_agent="test-agent", + ) + + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_runs_step2_step4_step6_sequence() -> None: + """Bridge-backed trusted-device verification should follow Apple's step 2/4/6 flow.""" + + topic = "com.apple.idmsauthwidget" + initial_push = { + "sessionUUID": "bridge-session", + "nextStep": "2", + "ruiURLKey": "hsa2TwoFactorAuthApprovalFlowUrl", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "mid": "bridge-mid", + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + } + server_message1_hex = "aa01" + server_message2_hex = "bb02" + step4_data = base64.b64encode( + ( + _hex_to_b64(server_message1_hex) + "_" + _hex_to_b64(server_message2_hex) + ).encode("utf-8") + ).decode("ascii") + step4_push = { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": step4_data, + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + } + step6_push = { + "sessionUUID": "bridge-session", + "nextStep": "6", + "txnid": "2300_282820214_S", + "encryptedCode": "ciphertext", + "idmsdata": "step6-idms", + "akdata": {"step": 6}, + "mid": "bridge-mid", + } + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message(topic, initial_push, 2300), + _encode_push_message(topic, step4_push, 2301), + _encode_push_message(topic, step6_push, 2302), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + prover.get_key.return_value = "deadbeef" + prover.decrypt_message.return_value = "derived-device-code" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + + session = MagicMock() + session.request_raw.side_effect = [ + _response(200), + _response(200), + _response(200), + _response(409), + _response(204), + ] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + assert ( + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + is True + ) + + prover.init_with_salt.assert_called_once_with(initial_push["salt"], "050044") + prover.process_message1.assert_called_once_with(server_message1_hex) + prover.process_message2.assert_called_once_with(server_message2_hex) + prover.decrypt_message.assert_called_once_with("ciphertext") + assert session.request_raw.call_args_list == [ + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/0", + json={ + "sessionUUID": "bridge-session", + "ptkn": b"push-token".hex(), + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/2", + json={ + "sessionUUID": "bridge-session", + "data": _hex_to_b64("abcd"), + "ptkn": b"push-token".hex(), + "nextStep": 2, + "idmsdata": "initial-idms", + "akdata": '{"lat":49.52}', + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/4", + json={ + "sessionUUID": "bridge-session", + "data": _hex_to_b64("ef01"), + "ptkn": b"push-token".hex(), + "nextStep": 4, + "idmsdata": "step4-idms", + "akdata": '{"step":4}', + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/code/validate", + json={ + "sessionUUID": "bridge-session", + "code": "derived-device-code", + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/6", + json={ + "sessionUUID": "bridge-session", + "data": BRIDGE_DONE_DATA_B64, + "ptkn": b"push-token".hex(), + "nextStep": 6, + "idmsdata": "step6-idms", + "akdata": '{"step":6}', + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ), + ] + assert websocket.sent_messages[2] == _encode_ack_message( + bytes.fromhex(_topic_hash(topic)), + 2301, + ) + assert websocket.sent_messages[3] == _encode_ack_message( + bytes.fromhex(_topic_hash(topic)), + 2302, + ) + assert websocket.closed is True + assert state.websocket is None + + +def test_trusted_device_bridge_validate_code_accepts_step4_encrypted_code_final_push() -> ( + None +): + """Apple can finish the bridge flow with nextStep=4 when encryptedCode is present.""" + + topic = "com.apple.idmsauthwidget" + initial_push = { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + } + step4_data = base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode("utf-8") + ).decode("ascii") + prover_push = { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": step4_data, + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + } + final_push = { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "encryptedCode": "ciphertext", + "idmsdata": "final-idms", + "akdata": {"step": "final"}, + } + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message(topic, initial_push, 2300), + _encode_push_message(topic, prover_push, 2301), + _encode_push_message(topic, final_push, 2302), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + prover.decrypt_message.return_value = "derived-device-code" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [ + _response(200), + _response(200), + _response(200), + _response(200), + _response(204), + ] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + assert ( + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + is True + ) + assert session.request_raw.call_args_list[-1] == call( + "POST", + "https://idmsa.apple.com/appleauth/auth/bridge/step/4", + json={ + "sessionUUID": "bridge-session", + "data": BRIDGE_DONE_DATA_B64, + "ptkn": b"push-token".hex(), + "nextStep": 4, + "idmsdata": "final-idms", + "akdata": '{"step":"final"}', + }, + headers={"scnt": "test-scnt", "X-Apple-App-Id": "1159"}, + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_returns_false_on_412() -> None: + """A bridge code-validate 412 should be treated as an invalid code, not a transport failure.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + }, + 2301, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "6", + "txnid": "2300_282820214_S", + "encryptedCode": "ciphertext", + "idmsdata": "step6-idms", + "akdata": {"step": 6}, + }, + 2302, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + prover.decrypt_message.return_value = "derived-device-code" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [ + _response(200), + _response(200), + _response(200), + _response(412), + _response(204), + ] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + assert ( + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + is False + ) + assert session.request_raw.call_args_list[-1].args[1].endswith("/bridge/step/6") + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_rejects_error_push() -> None: + """Bridge error pushes should surface as verification exceptions.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + }, + 2301, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "6", + "txnid": "2300_282820214_S", + "ec": 7, + }, + 2302, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="error push", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_rejects_malformed_final_push() -> None: + """Final bridge pushes must include encryptedCode once the prover flow is complete.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + }, + 2301, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + }, + 2302, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.return_value = "ef01" + prover.process_message2.return_value = {"isVerified": True, "key": "deadbeef"} + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="unexpected final payload", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_rejects_mismatched_followup_push() -> None: + """Follow-up bridge pushes must stay on the same bridge session.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "different-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + }, + 2301, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="mismatched session UUID", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_closes_on_timeout() -> None: + """Timeouts after prompt delivery should surface as bridge verification failures.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + socket.timeout("timed out"), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="websocket transport error", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True + + +def test_trusted_device_bridge_validate_code_wraps_step4_prover_message1_failure() -> ( + None +): + """Malformed step-4 prover data should surface as bridge verification failures.""" + + topic = "com.apple.idmsauthwidget" + websocket = _FakeWebSocket( + [ + _encode_connection_response(b"push-token"), + _encode_channel_subscription_response(topic), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "2", + "txnid": "2300_282820214_S", + "salt": base64.b64encode(b"0123456789abcdef").decode("ascii"), + "idmsdata": "initial-idms", + "akdata": {"lat": 49.52}, + }, + 2300, + ), + _encode_push_message( + topic, + { + "sessionUUID": "bridge-session", + "nextStep": "4", + "txnid": "2300_282820214_S", + "data": base64.b64encode( + (_hex_to_b64("aa01") + "_" + _hex_to_b64("bb02")).encode( + "utf-8" + ) + ).decode("ascii"), + "idmsdata": "step4-idms", + "akdata": {"step": 4}, + }, + 2301, + ), + ] + ) + prover = MagicMock() + prover.get_message1.return_value = "abcd" + prover.process_message1.side_effect = ValueError("bad point") + + bootstrapper = TrustedDeviceBridgeBootstrapper( + timeout=1.0, + websocket_factory=lambda *_args: websocket, + prover_factory=lambda: prover, + ) + bootstrapper._generate_keypair = MagicMock( # type: ignore[attr-defined] + return_value=(b"\x04public-key", _FakePrivateKey()) + ) + bootstrapper._generate_session_uuid = MagicMock( # type: ignore[attr-defined] + return_value="bridge-session" + ) + session = MagicMock() + session.request_raw.side_effect = [_response(200), _response(200)] + + state = bootstrapper.start( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + boot_context=_boot_context(topic), + user_agent="test-agent", + ) + + with pytest.raises( + PyiCloudTrustedDeviceVerificationException, + match="step 4 payload is malformed", + ): + bootstrapper.validate_code( + session=session, + auth_endpoint="https://idmsa.apple.com/appleauth/auth", + headers={"scnt": "test-scnt"}, + bridge_state=state, + code="050044", + ) + assert websocket.closed is True From b8cfb2b72782c2e6ab77dc89447996fa9215123f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 3 Apr 2026 18:08:45 +0100 Subject: [PATCH 11/13] Update protobuf requirement from <7,>=6.31.1 to >=6.31.1,<8 (#208) Updates the requirements on [protobuf](https://github.com/protocolbuffers/protobuf) to permit the latest version. - [Release notes](https://github.com/protocolbuffers/protobuf/releases) - [Commits](https://github.com/protocolbuffers/protobuf/commits) --- updated-dependencies: - dependency-name: protobuf dependency-version: 7.34.1 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6e0bc574..f974c314 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ cryptography>=44.0.0 fido2>=2.0.0 keyring>=25.6.0 keyrings.alt>=5.0.2 -protobuf>=6.31.1,<7 +protobuf>=6.31.1,<8 pydantic>=2.12,<3 requests>=2.31.0 rich>=13.0.0 From 66429652140e9c1dd39b0f42cfd93bdb65629f54 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 3 Apr 2026 18:08:54 +0100 Subject: [PATCH 12/13] Bump github/codeql-action from 4.34.1 to 4.35.1 (#209) Bumps [github/codeql-action](https://github.com/github/codeql-action) from 4.34.1 to 4.35.1. - [Release notes](https://github.com/github/codeql-action/releases) - [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/github/codeql-action/compare/38697555549f1db7851b81482ff19f1fa5c4fedc...c10b8064de6f491fea524254123dbe5e09572f13) --- updated-dependencies: - dependency-name: github/codeql-action dependency-version: 4.35.1 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/codeql.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 1ebddc37..6c07b721 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -26,11 +26,11 @@ jobs: uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Initialize CodeQL - uses: github/codeql-action/init@38697555549f1db7851b81482ff19f1fa5c4fedc # v4.34.1 + uses: github/codeql-action/init@c10b8064de6f491fea524254123dbe5e09572f13 # v4.35.1 with: languages: python - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@38697555549f1db7851b81482ff19f1fa5c4fedc # v4.34.1 + uses: github/codeql-action/analyze@c10b8064de6f491fea524254123dbe5e09572f13 # v4.35.1 with: category: "/language:python" From da067bf6a3c435f8c068f48544fde2be2d4839f2 Mon Sep 17 00:00:00 2001 From: Tim Laing <11019084+timlaing@users.noreply.github.com> Date: Fri, 3 Apr 2026 18:04:22 +0000 Subject: [PATCH 13/13] refactor: remove obsolete 2FA tests for trusted device and SMS flows --- tests/test_cmdline.py | 156 +----------------------------------------- 1 file changed, 1 insertion(+), 155 deletions(-) diff --git a/tests/test_cmdline.py b/tests/test_cmdline.py index 4f8c0e7a..d691a89a 100644 --- a/tests/test_cmdline.py +++ b/tests/test_cmdline.py @@ -2347,6 +2347,7 @@ def test_trusted_device_2sa_flow() -> None: fake_api.trusted_devices[0], "123456" ) + def test_sms_2fa_flow_requests_sms_before_prompt() -> None: """Auth login should request SMS delivery before prompting for the code.""" @@ -2557,161 +2558,6 @@ def request_prompt() -> bool: fake_api.validate_2fa_code.assert_called_once_with("123456") -def test_code_prompt_aborts_when_request_2fa_code_requires_security_key() -> None: - """Auth login should not enter the numeric 2FA prompt loop for key-only challenges.""" - - fake_api = FakeAPI() - fake_api.requires_2fa = True - fake_api.request_2fa_code.return_value = False - - result = _invoke(fake_api, "auth", "login", interactive=True) - - assert result.exit_code != 0 - assert result.exception.args[0] == ( - "This 2FA challenge requires a security key. Connect one and retry." - ) - fake_api.validate_2fa_code.assert_not_called() - - -def test_trusted_device_2fa_retries_invalid_codes_before_success() -> None: - """Auth login should allow up to three trusted-device 2FA attempts.""" - - fake_api = FakeAPI() - fake_api.requires_2fa = True - - def request_prompt() -> bool: - fake_api.two_factor_delivery_method = "trusted_device" - return True - - fake_api.request_2fa_code.side_effect = request_prompt - fake_api.validate_2fa_code.side_effect = [False, False, True] - - with patch.object( - context_module.typer, - "prompt", - side_effect=["111111", "222222", "333333"], - ): - result = _invoke(fake_api, "auth", "login", interactive=True) - - assert result.exit_code == 0 - assert "Invalid 2FA code. 2 attempt(s) remaining." in result.stdout - assert "Invalid 2FA code. 1 attempt(s) remaining." in result.stdout - assert fake_api.validate_2fa_code.call_args_list == [ - call("111111"), - call("222222"), - call("333333"), - ] - - -def test_sms_2fa_aborts_after_three_invalid_codes() -> None: - """Auth login should stop after three invalid 2FA attempts.""" - - fake_api = FakeAPI() - fake_api.requires_2fa = True - fake_api.two_factor_delivery_method = "sms" - fake_api.request_2fa_code.return_value = True - fake_api.validate_2fa_code.side_effect = [False, False, False] - - with patch.object( - context_module.typer, - "prompt", - side_effect=["111111", "222222", "333333"], - ): - result = _invoke(fake_api, "auth", "login", interactive=True) - - assert result.exit_code != 0 - assert result.exception.args[0] == "Failed to verify the 2FA code." - assert "Invalid 2FA code. 2 attempt(s) remaining." in result.stdout - assert "Invalid 2FA code. 1 attempt(s) remaining." in result.stdout - assert fake_api.validate_2fa_code.call_args_list == [ - call("111111"), - call("222222"), - call("333333"), - ] - - -def test_trusted_device_2fa_bridge_fallback_reports_notice() -> None: - """Auth login should print the bridge fallback notice before the SMS message.""" - - fake_api = FakeAPI() - fake_api.requires_2fa = True - - def request_sms_fallback() -> bool: - fake_api.two_factor_delivery_method = "sms" - fake_api.two_factor_delivery_notice = ( - "Trusted-device prompt failed; falling back to SMS." - ) - return True - - fake_api.request_2fa_code.side_effect = request_sms_fallback - - with patch.object(context_module.typer, "prompt", return_value="123456"): - result = _invoke(fake_api, "auth", "login", interactive=True) - - assert result.exit_code == 0 - assert "Trusted-device prompt failed; falling back to SMS." in result.stdout - assert "Requested a 2FA code by SMS." in result.stdout - fake_api.validate_2fa_code.assert_called_once_with("123456") - - -def test_sms_2fa_request_failure_aborts() -> None: - """Auth login should surface SMS delivery request failures clearly.""" - - fake_api = FakeAPI() - fake_api.requires_2fa = True - fake_api.request_2fa_code.side_effect = context_module.PyiCloudAPIResponseException( - "sms request failed" - ) - - result = _invoke(fake_api, "auth", "login", interactive=True) - - assert result.exit_code != 0 - assert result.exception.args[0] == "Failed to request the 2FA SMS code." - fake_api.validate_2fa_code.assert_not_called() - - -def test_trusted_device_2fa_request_failure_aborts() -> None: - """Auth login should surface bridge delivery failures clearly.""" - - fake_api = FakeAPI() - fake_api.requires_2fa = True - fake_api.request_2fa_code.side_effect = ( - context_module.PyiCloudTrustedDevicePromptException("bridge failed") - ) - - result = _invoke(fake_api, "auth", "login", interactive=True) - - assert result.exit_code != 0 - assert result.exception.args[0] == ( - "Failed to request the 2FA trusted-device prompt." - ) - fake_api.validate_2fa_code.assert_not_called() - - -def test_trusted_device_2fa_verification_failure_aborts() -> None: - """Auth login should surface bridge verification failures clearly.""" - - fake_api = FakeAPI() - fake_api.requires_2fa = True - - def request_prompt() -> bool: - fake_api.two_factor_delivery_method = "trusted_device" - return True - - fake_api.request_2fa_code.side_effect = request_prompt - fake_api.validate_2fa_code.side_effect = ( - context_module.PyiCloudTrustedDeviceVerificationException( - "bridge verification failed" - ) - ) - - with patch.object(context_module.typer, "prompt", return_value="123456"): - result = _invoke(fake_api, "auth", "login", interactive=True) - - assert result.exit_code != 0 - assert result.exception.args[0] == ("Failed to verify the 2FA trusted-device code.") - - def test_notes_commands() -> None: """Notes commands should expose list, detail, render, export, and sync flows."""