diff --git a/mobilerun/agent/utils/oauth/anthropic_oauth_llm.py b/mobilerun/agent/utils/oauth/anthropic_oauth_llm.py index ac6f5f46..ce723f07 100644 --- a/mobilerun/agent/utils/oauth/anthropic_oauth_llm.py +++ b/mobilerun/agent/utils/oauth/anthropic_oauth_llm.py @@ -3,6 +3,7 @@ import json import os import secrets +import sys import threading import time import webbrowser @@ -63,6 +64,18 @@ def _pkce_pair() -> tuple[str, str]: return verifier, challenge +def _is_headless_environment() -> bool: + """Detect SSH, WSL, or missing display where browser popups won't work.""" + if os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_TTY"): + return True + if os.environ.get("WSL_DISTRO_NAME"): + return True + if sys.platform.startswith("linux"): + if not os.environ.get("DISPLAY") and not os.environ.get("WAYLAND_DISPLAY"): + return True + return False + + def _normalize_manual_code(raw: str, expected_state: str) -> str: value = raw.strip() if not value: @@ -70,10 +83,17 @@ def _normalize_manual_code(raw: str, expected_state: str) -> str: first_token = value.split()[0] - if "code=" in first_token: + if "error=" in first_token or "code=" in first_token: parsed = urlparse(first_token) params = parse_qs(parsed.query) + error = params.get("error", [None])[0] + if error: + desc = params.get("error_description", [error])[0] + raise RuntimeError(f"OAuth error: {desc}") code = params.get("code", [None])[0] + state_from_url = params.get("state", [None])[0] + if state_from_url and state_from_url != expected_state: + raise RuntimeError("OAuth manual code state mismatch.") if isinstance(code, str) and code: return code @@ -392,6 +412,8 @@ def login( expires_in: Optional[int] = None, ) -> str: result: Dict[str, Optional[str]] = {"code": None, "state": None, "error": None} + manual_code: Dict[str, Optional[str]] = {"code": None} + manual_failed = threading.Event() done = threading.Event() code_verifier, code_challenge = _pkce_pair() @@ -431,7 +453,18 @@ def do_GET(self) -> None: # noqa: N802 def log_message(self, format: str, *args: Any) -> None: # noqa: A003 return - httpd = HTTPServer((callback_host, callback_port), _OAuthHandler) + try: + httpd = HTTPServer((callback_host, callback_port), _OAuthHandler) + except OSError as exc: + self.authorize_url = original_authorize_url + print( + f"Could not bind callback server on {callback_host}:{callback_port} ({exc}). " + "Falling back to manual code entry." + ) + return self.login_manual( + open_browser=open_browser, expires_in=expires_in + ) + actual_port = httpd.server_address[1] redirect_uri = f"http://localhost:{actual_port}{callback_path}" auth_url = self._build_auth_url( @@ -442,23 +475,78 @@ def log_message(self, format: str, *args: Any) -> None: # noqa: A003 server_thread = threading.Thread(target=httpd.serve_forever, daemon=True) server_thread.start() + try: + print(f"Open this URL to login:\n{auth_url}\n") if open_browser: webbrowser.open(auth_url) - else: - print(f"Open this URL to login:\n{auth_url}") + + # Only run the manual-paste race when we can't rely on the local + # browser callback: headless envs (SSH/WSL/no-display), or when the + # user explicitly opts in with DROIDRUN_OAUTH_MANUAL=1. On a normal + # desktop the server always wins anyway, and a blocked input() + # thread would intercept InquirerPy's terminal queries and lag the + # configure wizard. + enable_manual = _is_headless_environment() or os.environ.get( + "DROIDRUN_OAUTH_MANUAL", "" + ).lower() in ("1", "true", "yes") + if enable_manual: + def _read_manual() -> None: + for attempt in range(2): + if done.is_set(): + return + try: + raw = str(input("Or paste the redirect URL / authorization code: ")) + except Exception: + return + if done.is_set(): + return + if not raw.strip(): + if attempt == 0: + print("Invalid paste. Try again.") + continue + if not done.is_set(): + manual_failed.set() + done.set() + return + try: + code = _normalize_manual_code(raw, state) + except Exception: # noqa: BLE001 + if attempt == 0: + print("Invalid paste. Try again.") + continue + print("Invalid paste.") + if not done.is_set(): + manual_failed.set() + done.set() + return + if code: + manual_code["code"] = code + done.set() + return + + manual_thread = threading.Thread(target=_read_manual, daemon=True) + manual_thread.start() if not done.wait(timeout=timeout_seconds): raise TimeoutError("OAuth login timed out before callback was received.") - if result["error"]: - raise RuntimeError(f"OAuth callback returned error: {result['error']}") - if result["state"] != state: - raise RuntimeError("OAuth callback state mismatch.") - if not result["code"]: - raise RuntimeError("OAuth callback did not include an authorization code.") + + if manual_failed.is_set(): + raise RuntimeError("Login failed.") + + if manual_code["code"]: + code_to_exchange = manual_code["code"] + else: + if result["error"]: + raise RuntimeError(f"OAuth callback returned error: {result['error']}") + if result["state"] != state: + raise RuntimeError("OAuth callback state mismatch.") + if not result["code"]: + raise RuntimeError("OAuth callback did not include an authorization code.") + code_to_exchange = result["code"] return self._exchange_authorization_code( - code=result["code"], + code=code_to_exchange, redirect_uri=redirect_uri, code_verifier=code_verifier, state=state, @@ -491,26 +579,38 @@ def login_manual( ) try: + print(f"Open this URL to login:\n{auth_url}") if open_browser: webbrowser.open(auth_url) - print(f"Open this URL to login:\n{auth_url}") - code = _normalize_manual_code( - str(input_fn("Paste authorization code: ")), - state, - ) - if not code: - raise ValueError("Authorization code was empty.") - - return self._exchange_authorization_code( - code=code, - redirect_uri=redirect_uri, - code_verifier=code_verifier, - state=state, - expires_in=expires_in, - ) + for attempt in range(2): + raw = str(input_fn("Paste the redirect URL or authorization code: ")) + if not raw.strip(): + if attempt == 0: + print("Invalid paste. Try again.") + continue + raise RuntimeError("Login failed.") + try: + code = _normalize_manual_code(raw, state) + except Exception: # noqa: BLE001 + if attempt == 0: + print("Invalid paste. Try again.") + continue + raise RuntimeError("Login failed.") + if code: + return self._exchange_authorization_code( + code=code, + redirect_uri=redirect_uri, + code_verifier=code_verifier, + state=state, + expires_in=expires_in, + ) + if attempt == 0: + print("Invalid paste. Try again.") + continue + raise RuntimeError("Login failed.") + raise RuntimeError("Login failed.") finally: self.authorize_url = original_authorize_url - def _resolve_access_token(self) -> str: env_access_token = os.environ.get("ANTHROPIC_OAUTH_TOKEN") diff --git a/mobilerun/agent/utils/oauth/gemini_oauth_code_assist_llm.py b/mobilerun/agent/utils/oauth/gemini_oauth_code_assist_llm.py index 5a2383bc..a4ec51f3 100644 --- a/mobilerun/agent/utils/oauth/gemini_oauth_code_assist_llm.py +++ b/mobilerun/agent/utils/oauth/gemini_oauth_code_assist_llm.py @@ -1,7 +1,9 @@ +import base64 +import hashlib import json import os import secrets -import socket +import sys import threading import time import webbrowser @@ -42,6 +44,62 @@ ) DEFAULT_CLIENT_SECRET = "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" +# LlamaIndex-internal kwargs that must never be forwarded to Google's API. +_IGNORED_REQUEST_KWARGS = {"formatted"} + + +def _b64_no_pad(raw: bytes) -> str: + return base64.urlsafe_b64encode(raw).decode("utf-8").rstrip("=") + + +def _pkce_pair() -> tuple[str, str]: + verifier = _b64_no_pad(secrets.token_bytes(64)) + challenge = _b64_no_pad(hashlib.sha256(verifier.encode("utf-8")).digest()) + return verifier, challenge + + +def _is_headless_environment() -> bool: + """Detect SSH, WSL, or missing display where browser popups won't work.""" + if os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_TTY"): + return True + if os.environ.get("WSL_DISTRO_NAME"): + return True + if sys.platform.startswith("linux"): + if not os.environ.get("DISPLAY") and not os.environ.get("WAYLAND_DISPLAY"): + return True + return False + + +def _normalize_manual_code(raw: str, expected_state: str) -> str: + """Parse pasted input: full URL with code= param, code#state, or bare code.""" + value = raw.strip() + if not value: + return value + + first_token = value.split()[0] + + if "error=" in first_token or "code=" in first_token: + parsed = urlparse(first_token) + params = parse_qs(parsed.query) + error = params.get("error", [None])[0] + if error: + desc = params.get("error_description", [error])[0] + raise RuntimeError(f"OAuth error: {desc}") + code = params.get("code", [None])[0] + state_from_url = params.get("state", [None])[0] + if state_from_url and state_from_url != expected_state: + raise RuntimeError("OAuth manual code state mismatch.") + if isinstance(code, str) and code: + return code + + if "#" in first_token: + code_part, fragment = first_token.split("#", 1) + if fragment and fragment != expected_state: + raise RuntimeError("OAuth manual code state mismatch.") + return code_part + + return first_token + class GeminiOAuthCodeAssistLLM(CustomLLM): """Gemini OAuth LLM that talks to Google Code Assist endpoints. @@ -125,8 +183,13 @@ def __init__( selected_model = custom_model if not selected_model: if model in self.MODEL_PRESETS: + # Passed a preset key like "pro_preview" → resolve to actual id. selected_model = self.MODEL_PRESETS[model] + elif model and model != DEFAULT_MODEL: + # Explicit model name from config/CLI → honor it verbatim. + selected_model = model elif model_preset in self.MODEL_PRESETS: + # Fall back to preset only when no explicit model was provided. selected_model = self.MODEL_PRESETS[model_preset] else: selected_model = model @@ -272,14 +335,26 @@ def _metadata_payload(self) -> Dict[str, str]: "pluginType": "GEMINI", } - def _ensure_project_id(self, token: str) -> Optional[str]: - if self.project_id: - return self.project_id + def _build_headers(self, token: str) -> Dict[str, str]: + """Build Code Assist request headers matching gemini-cli expectations. - headers = { + The private v1internal endpoint requires the X-Goog-Api-Client and + Client-Metadata headers to identify the caller as a gemini-cli-style + client; without them, requests return 400. + """ + return { "Authorization": f"Bearer {token}", "Content-Type": "application/json", + "User-Agent": "google-cloud-sdk vscode_cloudshelleditor/0.1", + "X-Goog-Api-Client": "gl-node/22.17.0", + "Client-Metadata": json.dumps(self._metadata_payload()), } + + def _ensure_project_id(self, token: str) -> Optional[str]: + if self.project_id: + return self.project_id + + headers = self._build_headers(token) metadata = self._metadata_payload() response = self._session.post( self._method_url(DEFAULT_CODE_ASSIST_LOAD_METHOD), @@ -359,16 +434,25 @@ def _refresh_access_token(self) -> str: return access_token - def _exchange_authorization_code(self, code: str, redirect_uri: str) -> str: + def _exchange_authorization_code( + self, + code: str, + redirect_uri: str, + code_verifier: Optional[str] = None, + ) -> str: + payload = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": self.client_id, + "client_secret": self.client_secret, + } + if code_verifier: + payload["code_verifier"] = code_verifier + response = self._session.post( self.token_url, - data={ - "grant_type": "authorization_code", - "code": code, - "redirect_uri": redirect_uri, - "client_id": self.client_id, - "client_secret": self.client_secret, - }, + data=payload, timeout=self.timeout, ) response.raise_for_status() @@ -395,7 +479,13 @@ def _exchange_authorization_code(self, code: str, redirect_uri: str) -> str: self._persist_credentials() return access_token - def _build_auth_url(self, redirect_uri: str, state: str, prompt_consent: bool) -> str: + def _build_auth_url( + self, + redirect_uri: str, + state: str, + prompt_consent: bool, + code_challenge: Optional[str] = None, + ) -> str: scope = " ".join( [ "https://www.googleapis.com/auth/cloud-platform", @@ -413,6 +503,9 @@ def _build_auth_url(self, redirect_uri: str, state: str, prompt_consent: bool) - } if prompt_consent: query["prompt"] = "consent" + if code_challenge: + query["code_challenge"] = code_challenge + query["code_challenge_method"] = "S256" return f"{self.authorize_url}?{urlencode(query)}" def login( @@ -426,8 +519,11 @@ def login( prompt_consent: bool = True, ) -> str: result: Dict[str, Optional[str]] = {"code": None, "state": None, "error": None} + manual_code: Dict[str, Optional[str]] = {"code": None} + manual_failed = threading.Event() done = threading.Event() expected_state = secrets.token_hex(32) + code_verifier, code_challenge = _pkce_pair() class _OAuthHandler(BaseHTTPRequestHandler): def do_GET(self) -> None: # noqa: N802 @@ -459,39 +555,158 @@ def do_GET(self) -> None: # noqa: N802 def log_message(self, format: str, *args: Any) -> None: # noqa: A003 return - httpd = HTTPServer((callback_host, callback_port), _OAuthHandler) + try: + httpd = HTTPServer((callback_host, callback_port), _OAuthHandler) + except OSError as exc: + print( + f"Could not bind callback server on {callback_host}:{callback_port} ({exc}). " + "Falling back to manual code entry." + ) + return self.login_manual( + open_browser=open_browser, prompt_consent=prompt_consent + ) + actual_port = httpd.server_address[1] redirect_uri = f"http://127.0.0.1:{actual_port}{callback_path}" auth_url = self._build_auth_url( redirect_uri=redirect_uri, state=expected_state, prompt_consent=prompt_consent, + code_challenge=code_challenge, ) server_thread = threading.Thread(target=httpd.serve_forever, daemon=True) server_thread.start() try: + print(f"Open this URL to login:\n{auth_url}\n") if open_browser: webbrowser.open(auth_url) - else: - print(f"Open this URL to login:\n{auth_url}") + + # Only run the manual-paste race when we can't rely on the local + # browser callback: headless envs (SSH/WSL/no-display), or when the + # user explicitly opts in with DROIDRUN_OAUTH_MANUAL=1. On a normal + # desktop the server always wins anyway, and a blocked input() + # thread would intercept InquirerPy's terminal queries and lag the + # configure wizard. + enable_manual = _is_headless_environment() or os.environ.get( + "DROIDRUN_OAUTH_MANUAL", "" + ).lower() in ("1", "true", "yes") + if enable_manual: + def _read_manual() -> None: + for attempt in range(2): + if done.is_set(): + return + try: + raw = str(input("Or paste the redirect URL / authorization code: ")) + except Exception: + return + if done.is_set(): + return + if not raw.strip(): + if attempt == 0: + print("Invalid paste. Try again.") + continue + if not done.is_set(): + manual_failed.set() + done.set() + return + try: + code = _normalize_manual_code(raw, expected_state) + except Exception: # noqa: BLE001 + if attempt == 0: + print("Invalid paste. Try again.") + continue + print("Invalid paste.") + if not done.is_set(): + manual_failed.set() + done.set() + return + if code: + manual_code["code"] = code + done.set() + return + + manual_thread = threading.Thread(target=_read_manual, daemon=True) + manual_thread.start() if not done.wait(timeout=timeout_seconds): raise TimeoutError("OAuth login timed out before callback was received.") - if result["error"]: - raise RuntimeError(f"OAuth callback returned error: {result['error']}") - if result["state"] != expected_state: - raise RuntimeError("OAuth callback state mismatch.") - if not result["code"]: - raise RuntimeError("OAuth callback did not include an authorization code.") + if manual_failed.is_set(): + raise RuntimeError("Login failed.") - return self._exchange_authorization_code(result["code"], redirect_uri) + if manual_code["code"]: + code_to_exchange = manual_code["code"] + else: + if result["error"]: + raise RuntimeError(f"OAuth callback returned error: {result['error']}") + if result["state"] != expected_state: + raise RuntimeError("OAuth callback state mismatch.") + if not result["code"]: + raise RuntimeError("OAuth callback did not include an authorization code.") + code_to_exchange = result["code"] + + return self._exchange_authorization_code( + code_to_exchange, redirect_uri, code_verifier=code_verifier + ) finally: httpd.shutdown() httpd.server_close() + def login_manual( + self, + *, + open_browser: bool = True, + input_fn: Any = input, + prompt_consent: bool = True, + ) -> str: + """Manual OAuth flow for headless/VPS/WSL environments. + + Opens (or prints) the auth URL and prompts the user to paste the + redirected URL or bare authorization code from the browser. + """ + code_verifier, code_challenge = _pkce_pair() + expected_state = secrets.token_hex(32) + # Google allows any loopback redirect for installed apps. The browser + # will fail to load the page, but the URL bar will contain the code. + redirect_uri = "http://localhost/oauth2callback" + + auth_url = self._build_auth_url( + redirect_uri=redirect_uri, + state=expected_state, + prompt_consent=prompt_consent, + code_challenge=code_challenge, + ) + + print(f"Open this URL to login:\n{auth_url}") + if open_browser: + webbrowser.open(auth_url) + + for attempt in range(2): + raw = str(input_fn("Paste the redirect URL or authorization code: ")) + if not raw.strip(): + if attempt == 0: + print("Invalid paste. Try again.") + continue + raise RuntimeError("Login failed.") + try: + code = _normalize_manual_code(raw, expected_state) + except Exception: # noqa: BLE001 + if attempt == 0: + print("Invalid paste. Try again.") + continue + raise RuntimeError("Login failed.") + if code: + return self._exchange_authorization_code( + code, redirect_uri, code_verifier=code_verifier + ) + if attempt == 0: + print("Invalid paste. Try again.") + continue + raise RuntimeError("Login failed.") + raise RuntimeError("Login failed.") + def _resolve_access_token(self) -> str: env_access_token = os.environ.get("GEMINI_OAUTH_ACCESS_TOKEN") if env_access_token: @@ -566,11 +781,18 @@ def _to_code_assist_request( payload: Dict[str, Any] = { "model": self.model, "request": request, + "userAgent": "droidrun", + "requestId": f"droidrun-{int(time.time() * 1000)}-{secrets.token_hex(4)}", } if self.project_id: payload["project"] = self.project_id - payload.update(kwargs) + # Strip LlamaIndex-internal kwargs (e.g. ``formatted``) that Google's + # Code Assist API rejects as unknown fields. + safe_kwargs = { + k: v for k, v in kwargs.items() if k not in _IGNORED_REQUEST_KWARGS + } + payload.update(safe_kwargs) return payload def _method_url(self, method: str) -> str: @@ -598,31 +820,76 @@ def _extract_text(chunk: Dict[str, Any]) -> str: @llm_chat_callback() def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + # Google's Code Assist private API (v1internal) does not reliably + # accept the non-streaming generateContent endpoint for every model. + # Route chat through streamGenerateContent and accumulate the stream, + # matching the pattern used by gemini-cli / OpenClaw. token = self._resolve_access_token() self._ensure_project_id(token) payload = self._to_code_assist_request(messages, **kwargs) response = self._session.post( - self._method_url("generateContent"), + self._method_url("streamGenerateContent"), + params={"alt": "sse"}, headers={ - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", + **self._build_headers(token), + "Accept": "text/event-stream", }, json=payload, timeout=self.timeout, + stream=True, ) - response.raise_for_status() + if not response.ok: + raise requests.HTTPError( + f"Code Assist {response.status_code} error: {response.text}", + response=response, + ) - raw = response.json() - text = self._extract_text(raw) + accumulated = "" + last_raw: Dict[str, Any] = {} + buffer: list[str] = [] + + def _flush(buffer: list[str]) -> Optional[Dict[str, Any]]: + if not buffer: + return None + chunk_text = "\n".join(buffer) + try: + return json.loads(chunk_text) + except json.JSONDecodeError: + return None + + for line in response.iter_lines(decode_unicode=True): + if line is None: + continue + stripped = line.strip() + if stripped.startswith("data:"): + buffer.append(stripped[5:].strip()) + continue + if stripped != "" or not buffer: + continue + raw_chunk = _flush(buffer) + buffer = [] + if raw_chunk is None: + continue + last_raw = raw_chunk + delta = self._extract_text(raw_chunk) + if delta: + accumulated += delta + + raw_chunk = _flush(buffer) + if raw_chunk is not None: + last_raw = raw_chunk + delta = self._extract_text(raw_chunk) + if delta: + accumulated += delta return ChatResponse( - message=ChatMessage(role=MessageRole.ASSISTANT, content=text), - raw=raw, + message=ChatMessage(role=MessageRole.ASSISTANT, content=accumulated), + raw=last_raw, additional_kwargs={ - "trace_id": raw.get("traceId"), - "usage": (raw.get("response") or {}).get("usageMetadata"), - "model_version": (raw.get("response") or {}).get("modelVersion"), + "trace_id": last_raw.get("traceId"), + "usage": (last_raw.get("response") or {}).get("usageMetadata"), + "model_version": (last_raw.get("response") or {}).get("modelVersion"), }, ) @@ -648,8 +915,8 @@ def stream_chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatRes self._method_url("streamGenerateContent"), params={"alt": "sse"}, headers={ - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", + **self._build_headers(token), + "Accept": "text/event-stream", }, json=payload, timeout=self.timeout, diff --git a/mobilerun/agent/utils/oauth/openai_oauth_llm.py b/mobilerun/agent/utils/oauth/openai_oauth_llm.py index 394724cb..f2159bdc 100644 --- a/mobilerun/agent/utils/oauth/openai_oauth_llm.py +++ b/mobilerun/agent/utils/oauth/openai_oauth_llm.py @@ -19,6 +19,7 @@ import json import os import secrets +import sys import threading import time from dataclasses import dataclass @@ -60,6 +61,94 @@ def _pkce_pair() -> tuple[str, str]: return verifier, challenge +def _is_headless_environment() -> bool: + """Detect SSH, WSL, or missing display where browser popups won't work.""" + if os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_TTY"): + return True + if os.environ.get("WSL_DISTRO_NAME"): + return True + if sys.platform.startswith("linux"): + if not os.environ.get("DISPLAY") and not os.environ.get("WAYLAND_DISPLAY"): + return True + return False + + +def _normalize_manual_code(raw: str, expected_state: str) -> str: + """Parse pasted input: full URL with code= param, code#state, or bare code.""" + value = raw.strip() + if not value: + return value + + first_token = value.split()[0] + + if "error=" in first_token or "code=" in first_token: + parsed = urlparse(first_token) + params = parse_qs(parsed.query) + error = params.get("error", [None])[0] + if error: + desc = params.get("error_description", [error])[0] + raise RuntimeError(f"OAuth error: {desc}") + code = params.get("code", [None])[0] + state_from_url = params.get("state", [None])[0] + if state_from_url and state_from_url != expected_state: + raise RuntimeError("OAuth manual code state mismatch.") + if isinstance(code, str) and code: + return code + + if "#" in first_token: + code_part, fragment = first_token.split("#", 1) + if fragment and fragment != expected_state: + raise RuntimeError("OAuth manual code state mismatch.") + return code_part + + return first_token + + +def _tls_preflight(issuer: str, timeout: float = 5.0) -> None: + """Probe the OAuth issuer to detect TLS/certificate issues before login. + + Raises RuntimeError on TLS certificate errors (with fix suggestions). + Prints warnings for non-TLS connection errors but does not block. + """ + probe_url = f"{issuer.rstrip('/')}/oauth/authorize" + try: + httpx.head(probe_url, follow_redirects=False, timeout=timeout) + except httpx.ConnectError as exc: + err_str = str(exc).lower() + tls_indicators = ( + "certificate", + "ssl", + "tls", + "unable_to_get_issuer_cert", + "cert_has_expired", + "self_signed_cert", + "verify_leaf_signature", + "altname_invalid", + ) + if any(indicator in err_str for indicator in tls_indicators): + raise RuntimeError( + f"TLS certificate error connecting to {probe_url}: {exc}\n" + "Possible fixes:\n" + " - Update CA certificates " + "(e.g. `sudo update-ca-certificates` on Linux, " + "`brew postinstall ca-certificates` on macOS)\n" + " - Update OpenSSL (e.g. `brew postinstall openssl@3`)\n" + " - Check if a corporate proxy is intercepting HTTPS traffic" + ) from exc + print( + f"Warning: Could not connect to {probe_url}: {exc}\n" + "The login flow may fail if there is a DNS or firewall issue." + ) + except httpx.TimeoutException: + print( + f"Warning: Connection to {probe_url} timed out.\n" + "The login flow may fail if there is a network issue." + ) + except Exception as exc: + # Unexpected error — warn but don't block. + print(f"Warning: TLS preflight check encountered an error: {exc}") + + @dataclass class OpenAIOAuthCredentials: access_token: str @@ -480,7 +569,11 @@ def login( redirect_host: str = DEFAULT_OPENAI_OAUTH_CALLBACK_HOST, scope: str = DEFAULT_OPENAI_OAUTH_SCOPE, ) -> OpenAIOAuthCredentials: + _tls_preflight(self._oauth_manager.issuer) + result: Dict[str, Optional[str]] = {"code": None, "state": None, "error": None} + manual_code: Dict[str, Optional[str]] = {"code": None} + manual_failed = threading.Event() done = threading.Event() code_verifier, code_challenge = _pkce_pair() state = _b64_no_pad(secrets.token_bytes(32)) @@ -515,7 +608,21 @@ def do_GET(self) -> None: # noqa: N802 def log_message(self, format: str, *args: Any) -> None: # noqa: A003 return - httpd = HTTPServer((callback_host, callback_port), _OAuthHandler) + try: + httpd = HTTPServer((callback_host, callback_port), _OAuthHandler) + except OSError as exc: + print( + f"Could not bind callback server on {callback_host}:{callback_port} ({exc}). " + "Falling back to manual code entry." + ) + return self.login_manual( + open_browser=open_browser, + callback_port=callback_port, + callback_path=callback_path, + redirect_host=redirect_host, + scope=scope, + ) + actual_port = httpd.server_address[1] redirect_uri = f"http://{redirect_host}:{actual_port}{callback_path}" auth_url = self._build_auth_url( @@ -529,23 +636,78 @@ def log_message(self, format: str, *args: Any) -> None: # noqa: A003 server_thread = threading.Thread(target=httpd.serve_forever, daemon=True) server_thread.start() + try: + print(f"Open this URL to login:\n{auth_url}\n") if open_browser: webbrowser.open(auth_url) - else: - print(f"Open this URL to login:\n{auth_url}") + + # Only run the manual-paste race when we can't rely on the local + # browser callback: headless envs (SSH/WSL/no-display), or when the + # user explicitly opts in with DROIDRUN_OAUTH_MANUAL=1. On a normal + # desktop the server always wins anyway, and a blocked input() + # thread would intercept InquirerPy's terminal queries and lag the + # configure wizard. + enable_manual = _is_headless_environment() or os.environ.get( + "DROIDRUN_OAUTH_MANUAL", "" + ).lower() in ("1", "true", "yes") + if enable_manual: + def _read_manual() -> None: + for attempt in range(2): + if done.is_set(): + return + try: + raw = str(input("Or paste the redirect URL / authorization code: ")) + except Exception: + return + if done.is_set(): + return + if not raw.strip(): + if attempt == 0: + print("Invalid paste. Try again.") + continue + if not done.is_set(): + manual_failed.set() + done.set() + return + try: + code = _normalize_manual_code(raw, state) + except Exception: # noqa: BLE001 + if attempt == 0: + print("Invalid paste. Try again.") + continue + print("Invalid paste.") + if not done.is_set(): + manual_failed.set() + done.set() + return + if code: + manual_code["code"] = code + done.set() + return + + manual_thread = threading.Thread(target=_read_manual, daemon=True) + manual_thread.start() if not done.wait(timeout=timeout_seconds): raise TimeoutError("OAuth login timed out before callback was received.") - if result["error"]: - raise RuntimeError(f"OAuth callback returned error: {result['error']}") - if result["state"] != state: - raise RuntimeError("OAuth callback state mismatch.") - if not result["code"]: - raise RuntimeError("OAuth callback did not include an authorization code.") + + if manual_failed.is_set(): + raise RuntimeError("Login failed.") + + if manual_code["code"]: + code_to_exchange = manual_code["code"] + else: + if result["error"]: + raise RuntimeError(f"OAuth callback returned error: {result['error']}") + if result["state"] != state: + raise RuntimeError("OAuth callback state mismatch.") + if not result["code"]: + raise RuntimeError("OAuth callback did not include an authorization code.") + code_to_exchange = result["code"] creds = self._oauth_manager.exchange_authorization_code( - code=result["code"], + code=code_to_exchange, redirect_uri=redirect_uri, code_verifier=code_verifier, ) @@ -556,6 +718,67 @@ def log_message(self, format: str, *args: Any) -> None: # noqa: A003 httpd.shutdown() httpd.server_close() + def login_manual( + self, + *, + open_browser: bool = True, + input_fn: Any = input, + callback_port: int = DEFAULT_OPENAI_OAUTH_CALLBACK_PORT, + callback_path: str = DEFAULT_OPENAI_OAUTH_CALLBACK_PATH, + redirect_host: str = DEFAULT_OPENAI_OAUTH_CALLBACK_HOST, + scope: str = DEFAULT_OPENAI_OAUTH_SCOPE, + ) -> OpenAIOAuthCredentials: + """Manual OAuth flow for headless/VPS/WSL environments. + + Uses the same redirect_uri as the browser flow (OpenAI requires + port 1455). The browser will fail to load the redirect page, but + the URL bar will contain the authorization code. + """ + code_verifier, code_challenge = _pkce_pair() + state = _b64_no_pad(secrets.token_bytes(32)) + redirect_uri = f"http://{redirect_host}:{callback_port}{callback_path}" + auth_url = self._build_auth_url( + issuer=self._oauth_manager.issuer, + client_id=self._oauth_manager.client_id, + redirect_uri=redirect_uri, + code_challenge=code_challenge, + state=state, + scope=scope, + ) + + print(f"Open this URL to login:\n{auth_url}") + if open_browser: + webbrowser.open(auth_url) + + for attempt in range(2): + raw = str(input_fn("Paste the redirect URL or authorization code: ")) + if not raw.strip(): + if attempt == 0: + print("Invalid paste. Try again.") + continue + raise RuntimeError("Login failed.") + try: + code = _normalize_manual_code(raw, state) + except Exception: # noqa: BLE001 + if attempt == 0: + print("Invalid paste. Try again.") + continue + raise RuntimeError("Login failed.") + if code: + creds = self._oauth_manager.exchange_authorization_code( + code=code, + redirect_uri=redirect_uri, + code_verifier=code_verifier, + ) + if creds.account_id: + object.__setattr__(self, "_oauth_account_id", creds.account_id) + return creds + if attempt == 0: + print("Invalid paste. Try again.") + continue + raise RuntimeError("Login failed.") + raise RuntimeError("Login failed.") + def _ensure_access_token(self) -> OpenAIOAuthCredentials: creds = self._oauth_manager.get_valid_credentials(skew_ms=self._oauth_refresh_skew_ms)