diff --git a/.dockerignore b/.dockerignore index b98651c..53c45bd 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,8 +1,9 @@ -__pycache__/ -*.pyc -*.pyo -*.pyd -.pytest_cache/ -data/ -secrets/ -tests/ +* +!README.md +!pyproject.toml +!src/ +!src/** +src/**/__pycache__/ +src/**/*.pyc +src/**/*.pyo +src/**/*.pyd diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml new file mode 100644 index 0000000..9cb8fdb --- /dev/null +++ b/.github/workflows/docker-release.yml @@ -0,0 +1,155 @@ +name: docker-release + +on: + pull_request: + release: + types: + - published + workflow_dispatch: + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +permissions: + contents: read + +concurrency: + group: docker-release-${{ github.ref }} + cancel-in-progress: true + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version-file: pyproject.toml + + - name: Install uv + uses: astral-sh/setup-uv@v7 + with: + enable-cache: true + + - name: Install project + run: uv sync --locked --extra dev + + - name: Run tests + run: uv run pytest -q + + docker-validate: + needs: test + if: github.event_name == 'pull_request' || github.event_name == 'workflow_dispatch' + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v4 + + - name: Build Docker image + uses: docker/build-push-action@v7 + with: + context: . 
+ file: ./Dockerfile + platforms: linux/amd64 + push: false + cache-from: type=gha + cache-to: type=gha,mode=max + + docker-release: + needs: test + if: github.event_name == 'release' + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + attestations: write + id-token: write + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version-file: pyproject.toml + + - name: Validate release tag matches package version + shell: python + run: | + import re + import tomllib + from pathlib import Path + + tag = "${{ github.event.release.tag_name }}" + pyproject = tomllib.loads(Path("pyproject.toml").read_text(encoding="utf-8")) + project_version = str(pyproject["project"]["version"]).strip() + + init_text = Path("src/roborock_local_server/__init__.py").read_text(encoding="utf-8") + match = re.search(r'__version__\s*=\s*"([^"]+)"', init_text) + if match is None: + raise SystemExit("Could not find __version__ in src/roborock_local_server/__init__.py") + + module_version = match.group(1).strip() + expected_tag = f"v{project_version}" + if module_version != project_version: + raise SystemExit( + f"Version mismatch: pyproject.toml={project_version}, __init__.py={module_version}" + ) + if tag != expected_tag: + raise SystemExit(f"Git tag {tag} does not match package version {expected_tag}") + + - name: Log in to GHCR + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v4 + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + flavor: latest=false + tags: | + type=semver,pattern={{version}},value=${{ github.event.release.tag_name }} + type=semver,pattern={{major}}.{{minor}},value=${{ 
github.event.release.tag_name }} + type=semver,pattern={{major}},value=${{ github.event.release.tag_name }} + type=raw,value=latest,enable=${{ startsWith(github.event.release.tag_name, 'v') && !contains(github.event.release.tag_name, '-') }} + type=sha,prefix=sha- + labels: | + org.opencontainers.image.title=roborock-local-server + org.opencontainers.image.description=Private Roborock HTTPS and MQTT stack you run on your own system. + + - name: Build and push Docker image + id: push + uses: docker/build-push-action@v7 + with: + context: . + file: ./Dockerfile + platforms: linux/amd64,linux/arm64 + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max + + - name: Generate artifact attestation + uses: actions/attest@v4 + with: + subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + subject-digest: ${{ steps.push.outputs.digest }} + push-to-registry: true diff --git a/Dockerfile b/Dockerfile index 256f57d..406fd81 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,7 +16,8 @@ RUN mkdir -p /opt/acme.sh \ WORKDIR /app -COPY . /app +COPY pyproject.toml README.md /app/ +COPY src /app/src RUN pip install --no-cache-dir /app diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..b917f7d --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Luke Lashley + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 826ac17..405acb5 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,15 @@ # Roborock Local Server +The best way to support this project is the next time you are buying a Roborock device come back here and use one of my affiliate links where I will receive a commission. -If this project helps, you can support it or next time you buy a Roborock device, come back here and use my affiliate links! +[![Amazon Affiliate][badge-amazon]][link-amazon] +[![Roborock 5 Off][badge-roborock-discount]][link-roborock-discount] +[![Roborock Affiliate][badge-roborock-affiliate]][link-roborock-affiliate] + + +You can also support via BMAC or paypal: [![Buy Me a Coffee][badge-bmac]][link-bmac] [![PayPal][badge-paypal]][link-paypal] -[![Roborock 5 Off][badge-roborock-discount]][link-roborock-discount] -[![Roborock Affiliate][badge-roborock-affiliate]][link-roborock-affiliate] -[![Amazon Affiliate][badge-amazon]][link-amazon] Roborock Local Server is a private Roborock HTTPS and MQTT stack you run on your own system. @@ -44,12 +47,21 @@ Additional docs: ## Acknowledgements - [Dennis Giese (@dgiese)](https://dontvacuum.me/) whose research and papers inspired much of the work on reverse-engineering Roborock vacuums +- [Sören Beye (@Hypfer)](https://github.com/Hypfer) creator of [Valetudo](https://valetudo.cloud/), whose work on cloud-free vacuum control has been foundational for this whole space. 
- [@rovo89](https://github.com/rovo89) who has been VERY helpful through this process, giving lots of tips and advice. - [python-miio](https://github.com/rytilahti/python-miio) - Their repo was the basis for a lot of python-roborock's logic. - [@humbertogontijo](https://github.com/humbertogontijo) who first created the python-roborock repo. - [@allenporter](https://github.com/allenporter) who has taken up a significant role in the maintenance of the python-roborock library as well as the Roborock integration. The improvements Allen has made to the repository cannot be overstated. - [@rccoleman](https://github.com/rccoleman) who was the first beta tester and helped work out some kinks! +## Disclaimer + +This software is provided "as is", without warranty of any kind. Running this stack involves modifying how your Roborock vacuum communicates with the network. You are solely responsible for any damage to your hardware, data loss, network exposure, or other consequences. Use at your own risk. This project is not affiliated with, endorsed by, or sponsored by Roborock. + +## License + +This project is licensed under the MIT License — see [LICENSE](LICENSE) for details. 
+ [link-bmac]: https://buymeacoffee.com/lashl [badge-bmac]: https://img.shields.io/badge/Buy%20Me%20a%20Coffee-donate-yellow?style=for-the-badge&logo=buymeacoffee&logoColor=black [link-paypal]: https://paypal.me/LLashley304 @@ -58,5 +70,5 @@ Additional docs: [badge-roborock-discount]: https://img.shields.io/badge/Roborock-5%25%20Off-C00000?style=for-the-badge [link-roborock-affiliate]: https://roborock.pxf.io/B0VYV9 [badge-roborock-affiliate]: https://img.shields.io/badge/Roborock-affiliate-B22222?style=for-the-badge -[link-amazon]: https://amzn.to/4bGfG6B +[link-amazon]: https://amzn.to/4cx8zg3 [badge-amazon]: https://img.shields.io/badge/Amazon-affiliate-FF9900?style=for-the-badge&logo=amazon&logoColor=white diff --git a/compose.yaml b/compose.yaml index b9521d9..d69edd8 100644 --- a/compose.yaml +++ b/compose.yaml @@ -6,14 +6,14 @@ services: container_name: roborock-local-server restart: unless-stopped ports: - - "443:443" - - "8883:8883" + - "${ROBOROCK_SERVER_HTTPS_PORT:-555}:${ROBOROCK_SERVER_HTTPS_PORT:-555}" + - "${ROBOROCK_SERVER_MQTT_TLS_PORT:-8881}:${ROBOROCK_SERVER_MQTT_TLS_PORT:-8881}" volumes: - ./config.toml:/app/config.toml:ro - ./data:/data - ./secrets/cloudflare_token:/run/secrets/cloudflare_token:ro healthcheck: - test: ["CMD", "curl", "-skf", "https://127.0.0.1/admin"] + test: ["CMD", "curl", "-skf", "https://127.0.0.1:${ROBOROCK_SERVER_HTTPS_PORT:-555}/admin"] interval: 30s timeout: 5s retries: 5 diff --git a/config.example.toml b/config.example.toml index c387224..52614eb 100644 --- a/config.example.toml +++ b/config.example.toml @@ -1,9 +1,10 @@ [network] -# The one hostname the stack will serve. +# The one hostname the stack will serve. Keep this as the hostname only. stack_fqdn = "roborock.example.com" bind_host = "0.0.0.0" -https_port = 443 -mqtt_tls_port = 8883 +# Change these if you need the stack to advertise and listen on custom ports. 
+https_port = 555 +mqtt_tls_port = 8881 region = "us" [broker] @@ -31,3 +32,7 @@ acme_server = "zerossl" password_hash = "pbkdf2_sha256$600000$replace_me$replace_me" session_secret = "replace-with-at-least-24-random-characters" session_ttl_seconds = 86400 +protocol_auth_enabled = true +# Home Assistant/app logins use this email plus a local 6-digit PIN entered as the "code". +protocol_login_email = "you@example.com" +protocol_login_pin_hash = "pbkdf2_sha256$600000$replace_me$replace_me" diff --git a/docs/home_assistant.md b/docs/home_assistant.md index 05e1f92..6bcd832 100644 --- a/docs/home_assistant.md +++ b/docs/home_assistant.md @@ -6,10 +6,12 @@ To use this server with Home Assistant, edit your config entry at `config/.stora Find `"roborock.com"` and replace the endpoint values with your local stack URLs: -- `base_url` -> `https://api-roborock.example.com` -- `"a"` -> `https://api-roborock.example.com` -- `"l"` -> `https://api-roborock.example.com` -- `"m"` -> `ssl://mqtt-roborock.example.com:8883` +- `base_url` -> `https://api-roborock.example.com:555` +- `"a"` -> `https://api-roborock.example.com:555` +- `"l"` -> `https://api-roborock.example.com:555` +- `"m"` -> `ssl://mqtt-roborock.example.com:8881` + +If you changed `network.https_port` or `network.mqtt_tls_port`, use those values instead. ## Related Docs diff --git a/docs/installation.md b/docs/installation.md index b831068..b3ddba2 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -8,7 +8,7 @@ Start here for a first-time setup. After the stack is running, continue with [On - Python (I recommend installing [uv](https://docs.astral.sh/uv/getting-started/installation/)) - Two machines - one to run the server and one to do the onboarding - A domain name that you own -- A machine that can host this with ports `443` and `8883` exposed internally on your network +- A machine that can host the stack's HTTPS and MQTT TLS ports internally on your network. The defaults are `555` and `8881`. 
- A Cloudflare API token with DNS edit access for the zone if you want Cloudflare DNS-01 auto-renew. See [Cloudflare setup](cloudflare_setup.md). ## Network Setup @@ -48,9 +48,11 @@ uv run roborock-local-server configure The wizard asks only for: - your `stack_fqdn` (the URL for your server - must start with `api-`) +- your HTTPS and MQTT TLS ports if you do not want the defaults `555` and `8881` - embedded MQTT or your own broker - whether to use Cloudflare DNS-01 auto-renew - your admin password +- your Home Assistant/app login email and 6-digit PIN It then writes `config.toml`, generates `admin.password_hash` and `admin.session_secret`, and if you chose Cloudflare it also writes `secrets/cloudflare_token`. @@ -64,7 +66,15 @@ It then writes `config.toml`, generates `admin.password_hash` and `admin.session docker compose up -d --build ``` -8. Go to the admin dashboard: https://api-roborock.example.com/admin (Replace with your real domain.) + If you changed `network.https_port` or `network.mqtt_tls_port` in `config.toml`, set matching Docker Compose variables before you start the stack so the published ports stay aligned. For example: + + ```bash + ROBOROCK_SERVER_HTTPS_PORT=8443 + ROBOROCK_SERVER_MQTT_TLS_PORT=9443 + docker compose up -d --build + ``` + +8. Go to the admin dashboard: `https://api-roborock.example.com:555/admin` by default, or `https://api-roborock.example.com:YOUR_HTTPS_PORT/admin` if you chose a custom HTTPS port. 9. Import your data from the cloud so things like routines and rooms will work. Enter your email in under cloud import, then hit send code. Once the code is returned enter the code and hit fetch data. 
diff --git a/docs/onboarding.md b/docs/onboarding.md index c70615a..639e4dc 100644 --- a/docs/onboarding.md +++ b/docs/onboarding.md @@ -12,6 +12,8 @@ Run onboarding from a second machine, not from the machine hosting the local ser uv run start_onboarding.py --server api-roborock.example.com ``` +If you omit the port, the CLI assumes the default local stack HTTPS port `555`. If your stack uses a custom HTTPS port, include it in `--server`, for example `api-roborock.example.com:8443`. + This is a standalone script — you can copy `start_onboarding.py` to any machine and run it with just `uv`. The guided CLI will: @@ -44,7 +46,7 @@ You can still pass them explicitly if you prefer: uv run start_onboarding.py --server api-roborock.example.com --ssid "My Wifi" --password "Password123" --timezone "America/New_York" --cst EST5EDT,M3.2.0,M11.1.0 --country-domain us ``` -`server` should be your real stack hostname, usually the same `api-...` hostname you use for `/admin`. +`server` should be your real stack hostname, usually the same `api-...` hostname you use for `/admin`. If you omit the port, the CLI assumes `:555`. Explicit ports are supported, so if your admin page is at `https://api-roborock.example.com:8443/admin`, use `--server api-roborock.example.com:8443`. ## CST Examples diff --git a/docs/roborock_app.md b/docs/roborock_app.md index 5edcf2d..ab98cca 100644 --- a/docs/roborock_app.md +++ b/docs/roborock_app.md @@ -2,6 +2,10 @@ Use this after [Installation](installation.md) and [Onboarding](onboarding.md) if you want the official Roborock app to talk to your local stack. +During the MITM login step, the script now needs to sync the captured protocol-auth session back to your server. Pass `admin.session_secret` from `config.toml` as `--sync-secret`. That sync callback always uses the `--local-api` host and port. + +The launcher now preflights that callback before starting `mitmweb`. 
If the `--local-api` host cannot be reached, if the TLS certificate does not validate for that host, or if the sync secret is rejected, the script exits immediately instead of letting you proceed into a broken login flow. + ## iPhone 1. Log out of the app on your phone. @@ -9,9 +13,21 @@ Use this after [Installation](installation.md) and [Onboarding](onboarding.md) i 2. On a machine that is not running the server, run the MITM script: ```bash - uv run mitm_redirect.py --local-api api-roborock.example.com + uv run mitm_redirect.py --local-api api-roborock.example.com --sync-secret YOUR_ADMIN_SESSION_SECRET + ``` + + Use the `admin.session_secret` value from `config.toml` for `YOUR_ADMIN_SESSION_SECRET`. + + If you use the default local stack ports, host-only values are fine here: the script assumes HTTPS `:555` and MQTT TLS `:8881`. + + If your stack uses custom ports, include them directly. For example: + + ```bash + uv run mitm_redirect.py --local-api api-roborock.example.com:8443 --local-mqtt api-roborock.example.com:9443 --sync-secret YOUR_ADMIN_SESSION_SECRET ``` + The `--local-api` hostname must resolve from the MITM machine and match the HTTPS certificate served by your local stack. A raw IP such as `127.0.0.1` will fail unless your certificate is valid for that IP. + 3. Install the WireGuard app on your phone. Then tap the plus button in WireGuard, choose to add from QR code, and scan the code at `http://127.0.0.1:8081/#/capture`. 4. Open `mitm.it` in your web browser (iPhone). Follow the instructions there for your device. In Safari, complete all device-specific steps, including installing and trusting the certificate. @@ -100,9 +116,21 @@ Make sure you have the following installed: 8. 
On a machine that is not running the server, run the MITM script: ```bash - uv run mitm_redirect.py --local-api api-roborock.example.com + uv run mitm_redirect.py --local-api api-roborock.example.com --sync-secret YOUR_ADMIN_SESSION_SECRET + ``` + + Use the `admin.session_secret` value from `config.toml` for `YOUR_ADMIN_SESSION_SECRET`. + + If you use the default local stack ports, host-only values are fine here: the script assumes HTTPS `:555` and MQTT TLS `:8881`. + + If your stack uses custom ports, include them directly. For example: + + ```bash + uv run mitm_redirect.py --local-api api-roborock.example.com:8443 --local-mqtt api-roborock.example.com:9443 --sync-secret YOUR_ADMIN_SESSION_SECRET ``` + The `--local-api` hostname must resolve from the MITM machine and match the HTTPS certificate served by your local stack. A raw IP such as `127.0.0.1` will fail unless your certificate is valid for that IP. + 9. Install the WireGuard app on your phone. Then tap the plus button in WireGuard, choose to add from QR code, and scan the code at `http://127.0.0.1:8081/#/capture`. 10. Open `mitm.it` in your web browser (Android). Follow the instructions there for your device. In Chrome, complete all device-specific steps, including installing and trusting the certificate. 
diff --git a/docs/support.md b/docs/support.md index 0b03949..7c9f4a8 100644 --- a/docs/support.md +++ b/docs/support.md @@ -13,7 +13,7 @@ Use these links next time you buy a Roborock — it doesn't cost you anything ex - [Roborock Store (5% Off)](https://us.roborock.com/discount/RRSAP202602071713342D18X?redirect=%2Fpages%2Froborock-store%3Fuuid%3DEQe6p1jdZczHEN4Q0nbsG9sZRm0RK1gW5eSM%252FCzcW4Q%253D) - [Roborock Affiliate Link](https://roborock.pxf.io/B0VYV9) -- [Amazon Affiliate Link](https://amzn.to/4bGfG6B) +- [Amazon Affiliate Link](https://amzn.to/4cx8zg3) ## Star the Repo diff --git a/mitm_redirect.py b/mitm_redirect.py index e4e57c6..d5cbe1c 100644 --- a/mitm_redirect.py +++ b/mitm_redirect.py @@ -6,7 +6,7 @@ state comes from your local stack. Usage: - uv run mitm_redirect.py --local-api YOUR_SERVER_HOST [--local-mqtt HOST] [--local-wood HOST] [--mode wireguard] + uv run mitm_redirect.py --local-api YOUR_SERVER_HOST [--local-mqtt HOST] [--local-wood HOST] [--sync-secret SECRET] [--mode wireguard] """ from __future__ import annotations @@ -15,6 +15,11 @@ import json import os import re +import ssl +import tomllib +from urllib.parse import urlsplit +from urllib.error import HTTPError +from urllib.request import Request, urlopen # mitmproxy is only available when loaded as an addon by mitmweb, # not when running this script directly as a CLI launcher. @@ -24,8 +29,17 @@ # Populated by load() from env vars set by the CLI launcher. LOCAL_API: str = "" +LOCAL_API_HOST: str = "" +LOCAL_API_PORT: int | None = None LOCAL_MQTT: str = "" +LOCAL_MQTT_HOST: str = "" +LOCAL_MQTT_PORT: int | None = None LOCAL_WOOD: str = "" +LOCAL_WOOD_HOST: str = "" +LOCAL_WOOD_PORT: int | None = None +LOCAL_SYNC_SECRET: str = "" +DEFAULT_LOCAL_API_PORT = 555 +DEFAULT_LOCAL_MQTT_PORT = 8881 # Domains whose responses are candidates for host rewrite. 
@@ -68,6 +82,35 @@ } +def _compile_host_patterns(hosts: set[str]) -> tuple[re.Pattern[str], ...]: + return tuple( + re.compile(rf"(?<![\w.-]){re.escape(host)}(?::(?P<port>\d+))?", re.IGNORECASE) + for host in hosts + ) + + +_API_REWRITE_PATTERNS = _compile_host_patterns(API_ROUTE_HOSTS) +_MQTT_REWRITE_PATTERNS = _compile_host_patterns( + { + "mqtt.roborock.com", + "mqtt-us.roborock.com", + "mqtt-eu.roborock.com", + "mqtt-cn.roborock.com", + } +) +_MQTT_NUMBERED_REWRITE_PATTERNS = ( + re.compile(r"(?<![\w.-])mqtt-\w+-\d+\.roborock\.com(?::(?P<port>\d+))?", re.IGNORECASE), +) +_WOOD_REWRITE_PATTERNS = _compile_host_patterns( + { + "wood.roborock.com", + "wood-us.roborock.com", + "wood-eu.roborock.com", + "wood-cn.roborock.com", + } +) + + # Keep these on cloud so login/auth keeps working. CLOUD_ONLY_PATH_PREFIXES = ( "/api/v5/auth/", @@ -107,6 +150,25 @@ "/v3/user/", ) +PROTOCOL_AUTH_SYNC_PATH = "/internal/protocol/user-data" +PROTOCOL_AUTH_SYNC_SOURCE = "mitm_cloud_login" +PROTOCOL_AUTH_PREFLIGHT_SOURCE = "mitm_preflight" +LOGIN_SYNC_EXACT_PATHS = { + "/api/v1/loginwithcode", + "/api/v4/auth/email/login/code", + "/api/v4/auth/phone/login/code", + "/api/v4/auth/mobile/login/code", + "/api/v5/auth/email/login/code", + "/api/v5/auth/phone/login/code", + "/api/v5/auth/mobile/login/code", + "/api/v3/auth/email/login", + "/api/v3/auth/phone/login", + "/api/v3/auth/mobile/login", + "/api/v5/auth/email/login/pwd", + "/api/v5/auth/phone/login/pwd", + "/api/v5/auth/mobile/login/pwd", +} + SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) LOG_DIR = "" @@ -114,9 +176,22 @@ LOG_DIR_PASSTHROUGH = "" _seq_rewrite = 0 _seq_passthrough = 0 +_sync_warning_emitted = False _FILENAME_SAFE_RE = re.compile(r'[^A-Za-z0-9._-]+') +class SyncEndpointError(RuntimeError): + def __init__(self, sync_url: str, detail: str, *, status: int | None = None) -> None: + self.sync_url = str(sync_url or "").strip() + self.status = status + self.detail = str(detail or "").strip() or "unknown sync error" + super().__init__(self.detail) + + def __str__(self) -> str: + prefix = 
f"{self.sync_url}: " if self.sync_url else "" + return f"{prefix}{self.detail}" + + def _next_seq_rewrite() -> int: global _seq_rewrite _seq_rewrite += 1 @@ -141,13 +216,162 @@ def _init_log_dir() -> None: ctx.log.info(f"[LOG] Passthrough -> {LOG_DIR_PASSTHROUGH}") +def _parse_endpoint(value: str, *, fallback_host: str = "", fallback_port: int | None = None) -> tuple[str, int | None]: + raw = str(value or "").strip() + if not raw: + return fallback_host, fallback_port + parsed = urlsplit(raw if "://" in raw else f"//{raw}") + host = (parsed.hostname or parsed.path.split("/", 1)[0]).strip().strip("/") + if not host: + return fallback_host, fallback_port + return host, parsed.port if parsed.port is not None else fallback_port + + +def _format_authority(host: str, port: int | None, *, default_port: int | None = None) -> str: + normalized_host = str(host or "").strip().strip("/") + if not normalized_host: + return "" + if port is None: + return normalized_host + if default_port is not None and port == default_port: + return normalized_host + return f"{normalized_host}:{port}" + + +def _sync_callback_url(local_api: str) -> str: + authority = str(local_api or "").strip().strip("/") + return f"https://{authority}{PROTOCOL_AUTH_SYNC_PATH}" + + +def _parse_json_object(content: bytes) -> dict[str, object]: + try: + decoded = content.decode("utf-8") + parsed = json.loads(decoded) + except Exception: + return {} + return parsed if isinstance(parsed, dict) else {} + + +def _describe_sync_http_result(status: int, body: bytes) -> str: + parsed = _parse_json_object(body) + msg = str(parsed.get("msg") or "").strip() + data = parsed.get("data") + reason = "" + detail = "" + if isinstance(data, dict): + reason = str(data.get("reason") or "").strip() + detail = str(data.get("detail") or "").strip() + parts = [f"HTTP {status}"] + if msg: + parts.append(msg) + if reason: + parts.append(reason) + if detail: + parts.append(detail) + return " - ".join(parts) + + +def _post_sync_payload( + 
*, + local_api: str, + sync_secret: str, + payload: dict[str, object], + timeout: float = 5.0, +) -> tuple[str, int, bytes]: + sync_url = _sync_callback_url(local_api) + request = Request( + sync_url, + data=json.dumps(payload, separators=(",", ":")).encode("utf-8"), + headers={ + "Content-Type": "application/json", + "X-Local-Sync-Secret": sync_secret, + }, + method="POST", + ) + context = ssl.create_default_context() + try: + with urlopen(request, timeout=timeout, context=context) as response: + status = getattr(response, "status", 200) + return sync_url, status, response.read() + except HTTPError as exc: + return sync_url, exc.code, exc.read() + except Exception as exc: + raise SyncEndpointError(sync_url, f"request failed: {exc}") from exc + + +def _preflight_sync_endpoint(local_api: str, sync_secret: str) -> None: + sync_url, status, body = _post_sync_payload( + local_api=local_api, + sync_secret=sync_secret, + payload={"source": PROTOCOL_AUTH_PREFLIGHT_SOURCE}, + ) + parsed = _parse_json_object(body) + data = parsed.get("data") + reason = str(data.get("reason") or "").strip() if isinstance(data, dict) else "" + if status == 400 and reason == "missing_user_data": + return + raise SyncEndpointError(sync_url, f"preflight failed: {_describe_sync_http_result(status, body)}", status=status) + + +def _write_sync_failure_response(flow: http.HTTPFlow, exc: SyncEndpointError) -> None: + payload = { + "code": 50241, + "msg": "local_sync_failed", + "data": { + "reason": "sync_unreachable", + "syncUrl": exc.sync_url, + "detail": exc.detail, + }, + } + flow.response.status_code = 502 + flow.response.headers["content-type"] = "application/json" + flow.response.content = json.dumps(payload, separators=(",", ":")).encode("utf-8") + + +def _load_local_sync_secret() -> str: + config_path = os.path.join(SCRIPT_DIR, "config.toml") + if not os.path.exists(config_path): + return "" + try: + with open(config_path, "rb") as handle: + parsed = tomllib.load(handle) + except Exception: + 
return "" + admin = parsed.get("admin") + if not isinstance(admin, dict): + return "" + return str(admin.get("session_secret") or "").strip() + + def load(loader) -> None: - global LOCAL_API, LOCAL_MQTT, LOCAL_WOOD - LOCAL_API = os.environ["MITM_LOCAL_API"] - LOCAL_MQTT = os.environ.get("MITM_LOCAL_MQTT", LOCAL_API) or LOCAL_API - LOCAL_WOOD = os.environ.get("MITM_LOCAL_WOOD", LOCAL_API) or LOCAL_API + global LOCAL_API, LOCAL_API_HOST, LOCAL_API_PORT + global LOCAL_MQTT, LOCAL_MQTT_HOST, LOCAL_MQTT_PORT + global LOCAL_WOOD, LOCAL_WOOD_HOST, LOCAL_WOOD_PORT + global LOCAL_SYNC_SECRET + LOCAL_API_HOST, LOCAL_API_PORT = _parse_endpoint( + os.environ["MITM_LOCAL_API"], + fallback_port=DEFAULT_LOCAL_API_PORT, + ) + LOCAL_API = _format_authority(LOCAL_API_HOST, LOCAL_API_PORT, default_port=443) + LOCAL_MQTT_HOST, LOCAL_MQTT_PORT = _parse_endpoint( + os.environ.get("MITM_LOCAL_MQTT", ""), + fallback_host=LOCAL_API_HOST, + fallback_port=DEFAULT_LOCAL_MQTT_PORT, + ) + LOCAL_MQTT = _format_authority(LOCAL_MQTT_HOST, LOCAL_MQTT_PORT) + LOCAL_WOOD_HOST, LOCAL_WOOD_PORT = _parse_endpoint( + os.environ.get("MITM_LOCAL_WOOD", ""), + fallback_host=LOCAL_API_HOST, + fallback_port=LOCAL_API_PORT, + ) + LOCAL_WOOD = _format_authority(LOCAL_WOOD_HOST, LOCAL_WOOD_PORT, default_port=443) + LOCAL_SYNC_SECRET = str(os.environ.get("MITM_LOCAL_SYNC_SECRET") or "").strip() _init_log_dir() ctx.log.info(f"[CONFIG] LOCAL_API={LOCAL_API} LOCAL_MQTT={LOCAL_MQTT} LOCAL_WOOD={LOCAL_WOOD}") + if LOCAL_SYNC_SECRET: + ctx.log.info(f"[SYNC] protocol auth session sync enabled via {_sync_callback_url(LOCAL_API)}") + else: + ctx.log.warn("[SYNC] protocol auth session sync disabled: no sync secret configured") def _safe_body(content: bytes, content_type: str) -> str: @@ -255,12 +479,66 @@ def request(flow: http.HTTPFlow) -> None: source = flow.request.pretty_host flow.request.scheme = "https" - flow.request.host = LOCAL_API - flow.request.port = 443 + flow.request.host = LOCAL_API_HOST + flow.request.port 
= LOCAL_API_PORT or DEFAULT_LOCAL_API_PORT flow.request.headers["Host"] = LOCAL_API ctx.log.info(f"[ROUTE] {source}{path} -> {LOCAL_API}{path}") +def _clean_path(path: str) -> str: + clean_path = (str(path or "").split("?", 1)[0] or "/").rstrip("/").lower() + return clean_path or "/" + + +def _is_login_sync_candidate(path: str) -> bool: + return _clean_path(path) in LOGIN_SYNC_EXACT_PATHS + + +def _extract_protocol_user_data(payload: object) -> dict[str, object] | None: + if not isinstance(payload, dict): + return None + data = payload.get("data") + if not isinstance(data, dict): + return None + rriot = data.get("rriot") + if not isinstance(rriot, dict): + return None + required_values = ( + str(data.get("token") or "").strip(), + str(data.get("rruid") or "").strip(), + str(rriot.get("u") or "").strip(), + str(rriot.get("s") or "").strip(), + str(rriot.get("h") or "").strip(), + str(rriot.get("k") or "").strip(), + ) + if not all(required_values): + return None + return dict(data) + + +def _sync_protocol_user_data(user_data: dict[str, object]) -> None: + global _sync_warning_emitted + if not LOCAL_SYNC_SECRET: + if not _sync_warning_emitted: + ctx.log.warn("[SYNC] skipped protocol auth sync: no sync secret configured") + _sync_warning_emitted = True + return + + sync_url, status, body = _post_sync_payload( + local_api=LOCAL_API, + sync_secret=LOCAL_SYNC_SECRET, + payload={"source": PROTOCOL_AUTH_SYNC_SOURCE, "user_data": user_data}, + ) + if not 200 <= status < 300: + raise SyncEndpointError(sync_url, _describe_sync_http_result(status, body), status=status) + + hawk_id = str(((user_data.get("rriot") or {}) if isinstance(user_data.get("rriot"), dict) else {}).get("u") or "") + ctx.log.info( + f"[SYNC] stored protocol auth session rruid={str(user_data.get('rruid') or '')} " + f"hawk_id={hawk_id} status={status}" + ) + + def response(flow: http.HTTPFlow) -> None: """Rewrite cloud endpoint references in JSON payloads.""" host = flow.request.pretty_host @@ -279,6 
+557,16 @@ def response(flow: http.HTTPFlow) -> None: if "json" in content_type or _looks_like_json(flow.response.content): try: body = json.loads(flow.response.content) + if _is_login_sync_candidate(flow.request.path): + user_data = _extract_protocol_user_data(body) + if user_data is not None: + try: + _sync_protocol_user_data(user_data) + except SyncEndpointError as exc: + ctx.log.error(f"[SYNC] blocking login response: {exc}") + _write_sync_failure_response(flow, exc) + _log_flow(flow, rewritten=False) + return if _rewrite_json(body, rewrites): _log_flow(flow, rewritten=True, rewrites=rewrites) flow.response.content = json.dumps(body).encode("utf-8") @@ -324,33 +612,50 @@ def _rewrite_json(obj, rewrites: list[str]) -> bool: return changed -def _rewrite_value(text: str) -> str: - for host in ("mqtt.roborock.com", "mqtt-us.roborock.com", "mqtt-eu.roborock.com", "mqtt-cn.roborock.com"): - if host in text: - text = text.replace(host, LOCAL_MQTT) - text = re.sub(r"mqtt-\w+-\d+\.roborock\.com", LOCAL_MQTT, text) - - for host in ( - "api.roborock.com", - "api-us.roborock.com", - "api-eu.roborock.com", - "api-cn.roborock.com", - "usiot.roborock.com", - "euiot.roborock.com", - "cniot.roborock.com", - "cnaccount.roborock.com", - "usaccount.roborock.com", - "euaccount.roborock.com", - "account.roborock.com", - ): - if host in text: - text = text.replace(host, LOCAL_API) +def _rewrite_authorities( + text: str, + *, + patterns: tuple[re.Pattern[str], ...], + replacement_host: str, + replacement_port: int | None, + default_port: int | None = None, +) -> str: + if not replacement_host: + return text - for host in ("wood.roborock.com", "wood-us.roborock.com", "wood-eu.roborock.com", "wood-cn.roborock.com"): - if host in text: - text = text.replace(host, LOCAL_WOOD) + def _replace(match: re.Match[str]) -> str: + original_port = match.group("port") + port = int(original_port) if replacement_port is None and original_port else replacement_port + return 
_format_authority(replacement_host, port, default_port=default_port) - return text + rewritten = text + for pattern in patterns: + rewritten = pattern.sub(_replace, rewritten) + return rewritten + + +def _rewrite_value(text: str) -> str: + rewritten = _rewrite_authorities( + text, + patterns=_MQTT_REWRITE_PATTERNS + _MQTT_NUMBERED_REWRITE_PATTERNS, + replacement_host=LOCAL_MQTT_HOST, + replacement_port=LOCAL_MQTT_PORT, + ) + rewritten = _rewrite_authorities( + rewritten, + patterns=_API_REWRITE_PATTERNS, + replacement_host=LOCAL_API_HOST, + replacement_port=LOCAL_API_PORT, + default_port=443, + ) + rewritten = _rewrite_authorities( + rewritten, + patterns=_WOOD_REWRITE_PATTERNS, + replacement_host=LOCAL_WOOD_HOST, + replacement_port=LOCAL_WOOD_PORT, + default_port=443, + ) + return rewritten if __name__ == "__main__": @@ -361,20 +666,87 @@ def _rewrite_value(text: str) -> str: parser = argparse.ArgumentParser( description="Launch mitmweb with Roborock traffic interception.", ) - parser.add_argument("--local-api", required=True, help="Hostname of your local API server") - parser.add_argument("--local-mqtt", default=None, help="Hostname of your local MQTT server (defaults to --local-api)") - parser.add_argument("--local-wood", default=None, help="Hostname of your local Wood server (defaults to --local-api)") + parser.add_argument( + "--local-api", + required=True, + help="Hostname or HTTPS URL of your local API server. Defaults to HTTPS :555 when omitted.", + ) + parser.add_argument( + "--local-mqtt", + default=None, + help="Hostname or URL of your local MQTT server. Omit to reuse the API hostname and default to MQTT TLS :8881.", + ) + parser.add_argument( + "--local-wood", + default=None, + help="Hostname or HTTPS URL of your local Wood server. Explicit ports are supported.", + ) + parser.add_argument( + "--sync-secret", + default=None, + help="Optional admin.session_secret for protocol auth sync. 
Defaults to config.toml when available.", + ) parser.add_argument("--mode", default="wireguard", help="mitmweb proxy mode (default: wireguard)") parser.add_argument("--listen-port", default=None, help="mitmweb listen port") args, extra = parser.parse_known_args() + local_api_host, local_api_port = _parse_endpoint( + args.local_api, + fallback_port=DEFAULT_LOCAL_API_PORT, + ) + local_api = _format_authority(local_api_host, local_api_port, default_port=443) + local_mqtt_host, local_mqtt_port = _parse_endpoint( + args.local_mqtt or "", + fallback_host=local_api_host, + fallback_port=DEFAULT_LOCAL_MQTT_PORT, + ) + local_mqtt = _format_authority(local_mqtt_host, local_mqtt_port) + local_wood_host, local_wood_port = _parse_endpoint( + args.local_wood or "", + fallback_host=local_api_host, + fallback_port=local_api_port, + ) + local_wood = _format_authority(local_wood_host, local_wood_port, default_port=443) + local_sync_secret = str(args.sync_secret or os.environ.get("MITM_LOCAL_SYNC_SECRET") or _load_local_sync_secret()).strip() + + for label, original, normalized in ( + ("local-api", args.local_api, local_api), + ("local-mqtt", args.local_mqtt or "", local_mqtt), + ("local-wood", args.local_wood or "", local_wood), + ): + if original and normalized and str(original).strip() != normalized: + print(f"[CONFIG] normalized --{label} from {original!r} to {normalized!r}") env = os.environ.copy() - env["MITM_LOCAL_API"] = args.local_api - env["MITM_LOCAL_MQTT"] = args.local_mqtt or args.local_api - env["MITM_LOCAL_WOOD"] = args.local_wood or args.local_api + env["MITM_LOCAL_API"] = local_api + env["MITM_LOCAL_MQTT"] = local_mqtt + env["MITM_LOCAL_WOOD"] = local_wood + env["MITM_LOCAL_SYNC_SECRET"] = local_sync_secret - cmd = ["uvx", "--from", "mitmproxy", "mitmweb", "--mode", args.mode, "-s", os.path.abspath(__file__)] + if local_sync_secret: + try: + _preflight_sync_endpoint(local_api, local_sync_secret) + except SyncEndpointError as exc: + print(f"[SYNC] refusing to start 
mitmweb: {exc}", file=sys.stderr) + sys.exit(2) + print(f"[SYNC] verified protocol auth sync endpoint via {_sync_callback_url(local_api)}") + else: + print("[SYNC] protocol auth session sync disabled: no sync secret configured") + + cmd = [ + "uvx", + "--from", + "mitmproxy", + "mitmweb", + "--mode", + args.mode, + "--set", + "connection_strategy=lazy", + "--set", + "http3=false", + "-s", + os.path.abspath(__file__), + ] if args.listen_port: cmd += ["--listen-port", args.listen_port] cmd += extra diff --git a/pyproject.toml b/pyproject.toml index ca5e36c..e4151b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,10 @@ [project] name = "roborock-local-server" -version = "0.1.0" -description = "Clean release surface for a private local Roborock server stack." +version = "0.0.2" +description = "private local Roborock server stack." requires-python = ">=3.11,<3.14" readme = "README.md" -authors = [{ name = "Codex" }] +authors = [{ name = "Luke Lashley" }] dependencies = [ "aiohttp>=3.9,<4", "cryptography>=42,<47", diff --git a/src/roborock_local_server/__init__.py b/src/roborock_local_server/__init__.py index 3eee9d3..f51fc6f 100644 --- a/src/roborock_local_server/__init__.py +++ b/src/roborock_local_server/__init__.py @@ -2,5 +2,4 @@ __all__ = ["__version__"] -__version__ = "0.1.0" - +__version__ = "0.0.2" diff --git a/src/roborock_local_server/bundled_backend/https_server/endpoint_rules.py b/src/roborock_local_server/bundled_backend/https_server/endpoint_rules.py index 0a80f83..f87799f 100644 --- a/src/roborock_local_server/bundled_backend/https_server/endpoint_rules.py +++ b/src/roborock_local_server/bundled_backend/https_server/endpoint_rules.py @@ -444,7 +444,10 @@ def _filter_home_data_to_runtime_devices(ctx: ServerContext, home_data: dict[str for collection_key in ("devices", "receivedDevices", "received_devices"): devices_value = home_data.get(collection_key) - devices = devices_value if isinstance(devices_value, list) else [] + if not 
isinstance(devices_value, list): + filtered_home.pop(collection_key, None) + continue + devices = devices_value filtered_devices: list[dict[str, Any]] = [] for device in devices: if not isinstance(device, dict): diff --git a/src/roborock_local_server/bundled_backend/https_server/routes/api/v1/user.py b/src/roborock_local_server/bundled_backend/https_server/routes/api/v1/user.py index 757fb67..ce950a8 100644 --- a/src/roborock_local_server/bundled_backend/https_server/routes/api/v1/user.py +++ b/src/roborock_local_server/bundled_backend/https_server/routes/api/v1/user.py @@ -47,7 +47,7 @@ def build_get_url_by_email( region_upper = ctx.region.upper() return ok( { - "url": f"https://{ctx.api_host}", + "url": ctx.api_url(), "countrycode": _default_country_code_for_region(region_upper), "country": region_upper, } @@ -65,10 +65,13 @@ def build_user_info( meta_value = snapshot.get("meta") meta = meta_value if isinstance(meta_value, dict) else {} username = str(meta.get("username") or "").strip() + configured_email = str(getattr(ctx, "protocol_login_email", "") or "").strip() email = str(get_value(cloud_user_data, "email", default="") or "").strip() mobile = str(get_value(cloud_user_data, "mobile", default="") or "").strip() - if not email and "@" in username: + if configured_email: + email = configured_email + elif not email and "@" in username: email = username if not mobile and username.isdigit(): mobile = username diff --git a/src/roborock_local_server/bundled_backend/https_server/routes/auth/service.py b/src/roborock_local_server/bundled_backend/https_server/routes/auth/service.py index 84bc3d6..e790dde 100644 --- a/src/roborock_local_server/bundled_backend/https_server/routes/auth/service.py +++ b/src/roborock_local_server/bundled_backend/https_server/routes/auth/service.py @@ -18,14 +18,12 @@ def cloud_snapshot_path(ctx: ServerContext): def current_server_urls(ctx: ServerContext) -> tuple[str, str, str]: - api_url = f"https://{ctx.api_host}" - mqtt_url = 
f"ssl://{ctx.mqtt_host}:{ctx.mqtt_tls_port}" - wood_url = f"https://{ctx.wood_host}" - return api_url, mqtt_url, wood_url + return ctx.api_url(), ctx.mqtt_url(), ctx.wood_url() def with_current_server_urls(ctx: ServerContext, cloud_user_data: dict[str, Any]) -> dict[str, Any]: api_url, mqtt_url, wood_url = current_server_urls(ctx) + region_code = str(ctx.region or "").upper() or "US" patched_user_data = dict(cloud_user_data) rriot_value = patched_user_data.get("rriot") @@ -33,7 +31,7 @@ def with_current_server_urls(ctx: ServerContext, cloud_user_data: dict[str, Any] rriot = dict(rriot_value) ref_value = rriot.get("r") ref = dict(ref_value) if isinstance(ref_value, dict) else {} - ref.update({"a": api_url, "m": mqtt_url, "l": wood_url}) + ref.update({"r": region_code, "a": api_url, "m": mqtt_url, "l": wood_url}) rriot["r"] = ref patched_user_data["rriot"] = rriot @@ -158,16 +156,15 @@ def build_password_reset_response( return ok(None) -def build_login_data_response(ctx: ServerContext) -> dict[str, Any]: - cloud_user_data = load_cloud_user_data(ctx) - if cloud_user_data is None: +def build_login_data_response(ctx: ServerContext, user_data: dict[str, Any] | None = None) -> dict[str, Any]: + candidate_user_data = with_current_server_urls(ctx, user_data) if isinstance(user_data, dict) else load_cloud_user_data(ctx) + if candidate_user_data is None: return cloud_login_data_required_response(ctx, reason="missing_snapshot_or_user_data") - missing_fields = missing_cloud_login_fields(cloud_user_data) + missing_fields = missing_cloud_login_fields(candidate_user_data) if missing_fields: return cloud_login_data_required_response( ctx, reason="incomplete_cloud_user_data", missing_fields=missing_fields, ) - return ok(cloud_user_data) - + return ok(candidate_user_data) diff --git a/src/roborock_local_server/bundled_backend/https_server/routes/bootstrap/region.py b/src/roborock_local_server/bundled_backend/https_server/routes/bootstrap/region.py index d7eebbf..459a41a 100644 --- 
a/src/roborock_local_server/bundled_backend/https_server/routes/bootstrap/region.py +++ b/src/roborock_local_server/bundled_backend/https_server/routes/bootstrap/region.py @@ -3,6 +3,7 @@ import json from typing import Any +from shared.context import split_host_port from shared.context import ServerContext from shared.http_helpers import wrap_response @@ -23,10 +24,11 @@ def build( ) -> dict[str, Any]: did = ctx.extract_did(query_params, body_params) host_override = request_host_override(query_params) - api_host = host_override or ctx.api_host - mqtt_host = host_override or ctx.mqtt_host - api_url = f"https://{api_host}" - mqtt_url = f"ssl://{mqtt_host}:{ctx.mqtt_tls_port}" + override_host, override_port = split_host_port(host_override) + api_host = override_host or ctx.api_host + mqtt_host = override_host or ctx.mqtt_host + api_url = ctx.api_url(host=api_host, port=override_port if override_host else None) + mqtt_url = ctx.mqtt_url(host=mqtt_host) region_payload = { "apiUrl": api_url, "mqttUrl": mqtt_url, diff --git a/src/roborock_local_server/bundled_backend/https_server/routes/bootstrap/service.py b/src/roborock_local_server/bundled_backend/https_server/routes/bootstrap/service.py index 19ce019..2631e9a 100644 --- a/src/roborock_local_server/bundled_backend/https_server/routes/bootstrap/service.py +++ b/src/roborock_local_server/bundled_backend/https_server/routes/bootstrap/service.py @@ -17,7 +17,7 @@ def request_host_override(query_params: dict[str, list[str]]) -> str: for value in values: candidate = str(value or "").strip() if candidate: - return candidate.split(":", 1)[0].strip() + return candidate return "" @@ -30,4 +30,3 @@ def extract_explicit_did(query_params: dict[str, list[str]], body_params: dict[s + (body_params.get("d") or []) + (body_params.get("duid") or []) ) - diff --git a/src/roborock_local_server/bundled_backend/https_server/routes/plugin/common.py b/src/roborock_local_server/bundled_backend/https_server/routes/plugin/common.py index 
ee3a65e..f95dcc3 100644 --- a/src/roborock_local_server/bundled_backend/https_server/routes/plugin/common.py +++ b/src/roborock_local_server/bundled_backend/https_server/routes/plugin/common.py @@ -110,7 +110,7 @@ def plugin_proxy_url(ctx: ServerContext, source_url: str) -> str: return source digest = hashlib.sha256(source.encode("utf-8")).hexdigest()[:16] encoded_source = quote(source, safe="") - return f"https://{ctx.api_host}/plugin/proxy/{digest}.zip?src={encoded_source}" + return f"{ctx.api_url()}/plugin/proxy/{digest}.zip?src={encoded_source}" def proxied_plugin_records( @@ -198,4 +198,3 @@ async def plugin_proxy_response(*, runtime_dir: Path, source_url: str) -> Respon media_type=media_type, headers={"Cache-Control": "public, max-age=86400", "X-RR-Plugin-Cache": "miss"}, ) - diff --git a/src/roborock_local_server/bundled_backend/https_server/routes/user/homes/service.py b/src/roborock_local_server/bundled_backend/https_server/routes/user/homes/service.py index 4b1c19d..dc58239 100644 --- a/src/roborock_local_server/bundled_backend/https_server/routes/user/homes/service.py +++ b/src/roborock_local_server/bundled_backend/https_server/routes/user/homes/service.py @@ -100,7 +100,10 @@ def _filter_home_data_to_runtime_devices(ctx: ServerContext, home_data: dict[str for collection_key in ("devices", "receivedDevices", "received_devices"): devices_value = home_data.get(collection_key) - devices = devices_value if isinstance(devices_value, list) else [] + if not isinstance(devices_value, list): + filtered_home.pop(collection_key, None) + continue + devices = devices_value filtered_devices: list[dict[str, Any]] = [] for device in devices: if not isinstance(device, dict): diff --git a/src/roborock_local_server/bundled_backend/mqtt_broker_server/server.py b/src/roborock_local_server/bundled_backend/mqtt_broker_server/server.py index 2ee3251..62ac8ca 100644 --- a/src/roborock_local_server/bundled_backend/mqtt_broker_server/server.py +++ 
b/src/roborock_local_server/bundled_backend/mqtt_broker_server/server.py @@ -19,7 +19,7 @@ def build_broker_config(port: int) -> str: return "\n".join( [ "# Generated by roborock_local_server", - f"listener {port}", + f"listener {port} 127.0.0.1", "allow_anonymous true", "connection_messages true", "log_dest stdout", diff --git a/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py b/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py index c73909e..5d4de2b 100644 --- a/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py +++ b/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py @@ -11,18 +11,21 @@ import socket import ssl import threading -from typing import Any +from typing import Any, Callable from shared.constants import MQTT_TYPES from shared.decoder import build_decoder from shared.io_utils import append_jsonl, payload_preview -from shared.runtime_credentials import RuntimeCredentialsStore +from shared.protocol_auth import ProtocolAuthStore +from shared.runtime_credentials import RuntimeCredentialsStore, parse_mqtt_connect_packet from shared.runtime_state import RuntimeState from shared.zone_ranges_store import ZoneRangesStore from .command_handlers import RpcCommandRegistry class MqttTlsProxy: + _MAX_FIRST_PACKET_BYTES = 1024 * 1024 + def __init__( self, *, @@ -35,6 +38,9 @@ def __init__( localkey: str, logger: logging.Logger, decoded_jsonl: Path, + cloud_snapshot_path: Path | None = None, + protocol_auth_sessions_path: Path | None = None, + protocol_auth_enabled: Callable[[], bool] | None = None, runtime_state: RuntimeState | None = None, runtime_credentials: RuntimeCredentialsStore | None = None, zone_ranges_store: ZoneRangesStore | None = None, @@ -48,6 +54,8 @@ def __init__( self.localkey = localkey self.logger = logger self.decoded_jsonl = decoded_jsonl + self.cloud_snapshot_path = cloud_snapshot_path + self._protocol_auth_enabled = protocol_auth_enabled or (lambda: 
True) self.runtime_state = runtime_state self.runtime_credentials = runtime_credentials self.zone_ranges_store = zone_ranges_store @@ -58,6 +66,14 @@ def __init__( self._conn_protocol_levels: dict[str, int] = {} self._trace_queue: queue.Queue[tuple[str, str, bytes] | None] = queue.Queue() self._trace_thread: threading.Thread | None = None + self._protocol_auth = ( + ProtocolAuthStore( + cloud_snapshot_path, + session_store_path=protocol_auth_sessions_path, + ) + if cloud_snapshot_path is not None + else None + ) default_decoder, self._protocol_names = build_decoder(localkey) self._decoder_cache: dict[str, Any] = {localkey: default_decoder} self._command_registry = RpcCommandRegistry() @@ -85,6 +101,12 @@ def _decode_remaining_length(data: bytes, start: int) -> tuple[int | None, int]: break return None, 0 + @staticmethod + def _remaining_length_invalid(data: bytes, start: int) -> bool: + if start + 3 >= len(data): + return False + return (data[start + 3] & 0x80) != 0 + def _extract_packets(self, frame_buf: bytearray) -> list[bytes]: packets: list[bytes] = [] offset = 0 @@ -120,6 +142,91 @@ def _extract_connect_protocol_level(cls, packet: bytes) -> int | None: return None return packet[protocol_level_idx] + @staticmethod + def _build_connect_reject_packet(protocol_level: int | None) -> bytes | None: + if protocol_level == 5: + # MQTT 5 CONNACK with reason code 0x87 "Not authorized". + return b"\x20\x03\x00\x87\x00" + if protocol_level in (None, 3, 4): + # MQTT 3.1/3.1.1 CONNACK with return code 0x05 "Not authorized". 
+ return b"\x20\x02\x00\x05" + return None + + @classmethod + def _read_first_packet(cls, conn: socket.socket) -> tuple[bytes, bytes] | None: + buffer = bytearray() + while True: + chunk = conn.recv(4096) + if not chunk: + return None + buffer.extend(chunk) + if len(buffer) > cls._MAX_FIRST_PACKET_BYTES: + raise ValueError("MQTT CONNECT exceeds maximum supported size") + if len(buffer) < 2: + continue + if cls._remaining_length_invalid(buffer, 1): + raise ValueError("Invalid MQTT remaining length in CONNECT packet") + remaining_len, remaining_len_bytes = cls._decode_remaining_length(buffer, 1) + if remaining_len is None or remaining_len_bytes == 0: + continue + total_len = 1 + remaining_len_bytes + remaining_len + if total_len > cls._MAX_FIRST_PACKET_BYTES: + raise ValueError("MQTT CONNECT exceeds maximum supported size") + if len(buffer) < total_len: + continue + return bytes(buffer[:total_len]), bytes(buffer[total_len:]) + + def _expected_bootstrap_credentials(self) -> tuple[str, str, str] | None: + if self.runtime_credentials is None: + return None + username = str(self.runtime_credentials.bootstrap_value("mqtt_usr", "") or "").strip() + password = str(self.runtime_credentials.bootstrap_value("mqtt_passwd", "") or "").strip() + client_id = str(self.runtime_credentials.bootstrap_value("mqtt_clientid", "") or "").strip() + if not username or not password: + return None + return username, password, client_id + + def _authorize_connect_packet(self, packet: bytes) -> tuple[bool, str, dict[str, Any] | None]: + info = parse_mqtt_connect_packet(packet) + if info is None: + return False, "invalid_connect_packet", None + + username = str(info.get("username") or "").strip() + password = str(info.get("password") or "").strip() + client_id = str(info.get("client_id") or "").strip() + if not username or not password: + return False, "missing_mqtt_credentials", info + + if self._protocol_auth is not None and self._protocol_auth_enabled(): + authorized, auth_reason, 
_matched_user = self._protocol_auth.verify_user_mqtt_credentials(username, password) + if authorized: + return True, auth_reason, info + + bootstrap_credentials = self._expected_bootstrap_credentials() + if bootstrap_credentials is not None: + expected_username, expected_password, expected_client_id = bootstrap_credentials + if username == expected_username and password == expected_password: + if expected_client_id and client_id and client_id != expected_client_id: + return False, "invalid_bootstrap_client_id", info + return True, "bootstrap", info + + if self.runtime_credentials is not None: + authorized, auth_reason, _matched_device = self.runtime_credentials.verify_device_mqtt_credentials( + username=username, + password=password, + ) + if authorized: + return True, auth_reason, info + if auth_reason == "device_mqtt_password_missing": + recovered_device = self.runtime_credentials.recover_device_mqtt_password( + username=username, + password=password, + ) + if recovered_device is not None: + return True, "device_mqtt_recovered", info + + return False, "invalid_mqtt_credentials", info + @classmethod def _extract_publish(cls, packet: bytes, protocol_level: int | None = None) -> tuple[str | None, bytes | None]: if len(packet) < 4: @@ -484,6 +591,8 @@ def _relay(self, src: socket.socket, dst: socket.socket, conn_id: str, direction def _handle_client(self, tls_conn: ssl.SSLSocket, addr: tuple[str, int]) -> None: conn_id = self._next_conn() + backend: socket.socket | None = None + relay_started = False self.logger.info( "[conn %s] backend connect %s:%d from %s:%d", conn_id, @@ -495,10 +604,50 @@ def _handle_client(self, tls_conn: ssl.SSLSocket, addr: tuple[str, int]) -> None if self.runtime_state is not None: self.runtime_state.record_mqtt_connection(conn_id=conn_id, client_ip=addr[0], client_port=addr[1]) try: + first_packet = self._read_first_packet(tls_conn) + if first_packet is None: + self.logger.warning("[conn %s] client closed before MQTT CONNECT", conn_id) + 
return + connect_packet, initial_remainder = first_packet + authorized, auth_reason, connect_info = self._authorize_connect_packet(connect_packet) + if connect_info is not None: + protocol_level = connect_info.get("protocol_level") + if isinstance(protocol_level, int): + self._set_conn_protocol_level(conn_id, protocol_level) + self._queue_trace_packet(conn_id, "c2b", connect_packet) + if not authorized: + self.logger.warning( + "[conn %s] rejected MQTT CONNECT reason=%s client_id=%s username=%s", + conn_id, + auth_reason, + str((connect_info or {}).get("client_id") or ""), + str((connect_info or {}).get("username") or ""), + ) + reject_packet = self._build_connect_reject_packet( + connect_info.get("protocol_level") if isinstance(connect_info, dict) else None + ) + if reject_packet is not None: + try: + tls_conn.sendall(reject_packet) + except (OSError, ConnectionResetError, BrokenPipeError): + pass + else: + self._queue_trace_packet(conn_id, "b2c", reject_packet) + return + backend = socket.socket(socket.AF_INET, socket.SOCK_STREAM) backend.connect((self.backend_host, self.backend_port)) - c2b = threading.Thread(target=self._relay, args=(tls_conn, backend, conn_id, "c2b", bytearray()), daemon=True) + c2b_frame_buf = bytearray(initial_remainder) + for packet in self._extract_packets(c2b_frame_buf): + self._queue_trace_packet(conn_id, "c2b", packet) + backend.sendall(connect_packet + initial_remainder) + c2b = threading.Thread( + target=self._relay, + args=(tls_conn, backend, conn_id, "c2b", c2b_frame_buf), + daemon=True, + ) b2c = threading.Thread(target=self._relay, args=(backend, tls_conn, conn_id, "b2c", bytearray()), daemon=True) + relay_started = True c2b.start() b2c.start() c2b.join() @@ -506,6 +655,14 @@ def _handle_client(self, tls_conn: ssl.SSLSocket, addr: tuple[str, int]) -> None except Exception as exc: self.logger.error("[conn %s] connection error: %s", conn_id, exc) finally: + if not relay_started: + for endpoint in (tls_conn, backend): + if endpoint 
is None: + continue + try: + endpoint.close() + except OSError: + pass if self.runtime_state is not None: self.runtime_state.record_mqtt_disconnect(conn_id=conn_id) with self._lock: diff --git a/src/roborock_local_server/bundled_backend/shared/context.py b/src/roborock_local_server/bundled_backend/shared/context.py index b8e82c5..0a8791f 100644 --- a/src/roborock_local_server/bundled_backend/shared/context.py +++ b/src/roborock_local_server/bundled_backend/shared/context.py @@ -8,6 +8,7 @@ from pathlib import Path import secrets from typing import Any +from urllib.parse import urlsplit from .bootstrap_crypto import BootstrapEncryptor from .device_key_recovery import DeviceKeyCache @@ -17,17 +18,47 @@ from .zone_ranges_store import ZoneRangesStore +def split_host_port(value: str) -> tuple[str, int | None]: + raw = str(value or "").strip() + if not raw: + return "", None + try: + parsed = urlsplit(raw if "://" in raw else f"//{raw}") + host = (parsed.hostname or parsed.path.split("/", 1)[0]).strip().strip("/") + port = parsed.port + except ValueError: + candidate = raw.split("/", 1)[0].strip() + host, _sep, port_text = candidate.partition(":") + host = host.strip().strip("[]") + port = int(port_text) if port_text.isdigit() else None + return host, port + + +def format_authority(host: str, *, port: int | None = None, default_port: int | None = None) -> str: + normalized_host, embedded_port = split_host_port(host) + resolved_port = port if port is not None else embedded_port + if not normalized_host: + return "" + if resolved_port is None: + return normalized_host + if default_port is not None and resolved_port == default_port: + return normalized_host + return f"{normalized_host}:{resolved_port}" + + @dataclass class ServerContext: api_host: str mqtt_host: str wood_host: str region: str + protocol_login_email: str localkey: str duid: str mqtt_usr: str mqtt_passwd: str mqtt_clientid: str + https_port: int mqtt_tls_port: int http_jsonl: Path mqtt_jsonl: Path @@ -48,10 
+79,25 @@ def __post_init__(self) -> None: if self.runtime_credentials is not None: self.runtime_credentials.sync_inventory() + def api_url(self, *, host: str | None = None, port: int | None = None) -> str: + return ( + f"https://" + f"{format_authority(host or self.api_host, port=self.https_port if port is None else port, default_port=443)}" + ) + + def wood_url(self, *, host: str | None = None, port: int | None = None) -> str: + return ( + f"https://" + f"{format_authority(host or self.wood_host, port=self.https_port if port is None else port, default_port=443)}" + ) + + def mqtt_url(self, *, host: str | None = None, port: int | None = None) -> str: + return f"ssl://{format_authority(host or self.mqtt_host, port=self.mqtt_tls_port if port is None else port)}" + def region_payload(self) -> dict[str, Any]: - api_url = f"https://{self.api_host}" - mqtt_url = f"ssl://{self.mqtt_host}:{self.mqtt_tls_port}" - wood_url = f"https://{self.wood_host}" + api_url = self.api_url() + mqtt_url = self.mqtt_url() + wood_url = self.wood_url() return { "apiUrl": api_url, "mqttUrl": mqtt_url, @@ -171,8 +217,8 @@ def nc_payload( source="onboarding_nc", assign_if_missing=True, ) - api_url = f"https://{self.api_host}" - mqtt_url = f"ssl://{self.mqtt_host}:{self.mqtt_tls_port}" + api_url = self.api_url() + mqtt_url = self.mqtt_url() if self.runtime_credentials is not None: self.runtime_credentials.ensure_device( did=did, @@ -201,7 +247,7 @@ def nc_payload( "r": self.region.upper(), "a": api_url, "m": mqtt_url, - "l": f"https://{self.wood_host}", + "l": self.wood_url(), }, }, } diff --git a/src/roborock_local_server/bundled_backend/shared/protocol_auth.py b/src/roborock_local_server/bundled_backend/shared/protocol_auth.py new file mode 100644 index 0000000..23902e4 --- /dev/null +++ b/src/roborock_local_server/bundled_backend/shared/protocol_auth.py @@ -0,0 +1,590 @@ +"""Protocol auth helpers shared by the HTTPS server and MQTT proxy.""" + +from __future__ import annotations + +import 
base64 +from dataclasses import dataclass +from datetime import datetime, timezone +import hashlib +import hmac +import json +from pathlib import Path +import secrets +import threading +import time +from typing import Any, Mapping + + +def _clean_str(value: Any) -> str: + return str(value or "").strip() + + +def _md5hex(value: str) -> str: + return hashlib.md5(value.encode("utf-8")).hexdigest() + + +def _parse_json_body_param_map(body_params: dict[str, list[str]]) -> dict[str, Any]: + for raw in body_params.get("__json") or []: + try: + parsed = json.loads(raw) + except (TypeError, json.JSONDecodeError): + continue + if isinstance(parsed, dict): + return parsed + return {} + + +def _normalize_param_values(params: dict[str, list[str]], *, include_json: bool = False) -> dict[str, Any]: + json_values = _parse_json_body_param_map(params) if include_json else {} + normalized: dict[str, Any] = dict(json_values) + for key, values in params.items(): + if str(key).startswith("__") or not values: + continue + if json_values and values == [""] and str(key).lstrip().startswith(("{", "[")): + continue + normalized[str(key)] = values[0] if len(values) == 1 else list(values) + return normalized + + +def _process_extra_hawk_values(values: dict[str, Any] | None) -> str: + if not values: + return "" + result: list[str] = [] + for key in sorted(values): + result.append(f"{key}={values.get(key)}") + return _md5hex("&".join(result)) + + +def _build_hawk_mac( + *, + hawk_id: str, + hawk_session: str, + hawk_key: str, + path: str, + query_values: dict[str, Any] | None, + form_values: dict[str, Any] | None, + timestamp: int, + nonce: str, +) -> str: + prestr = ":".join( + [ + hawk_id, + hawk_session, + nonce, + str(timestamp), + _md5hex(path), + _process_extra_hawk_values(query_values), + _process_extra_hawk_values(form_values), + ] + ) + return base64.b64encode(hmac.new(hawk_key.encode(), prestr.encode(), hashlib.sha256).digest()).decode() + + +def _parse_hawk_authorization(value: str) 
-> dict[str, str] | None: + raw = _clean_str(value) + if not raw or not raw.lower().startswith("hawk "): + return None + attributes: dict[str, str] = {} + for item in raw[5:].split(","): + if "=" not in item: + continue + key, raw_value = item.split("=", 1) + normalized_key = _clean_str(key).lower() + normalized_value = _clean_str(raw_value) + if normalized_value.startswith('"') and normalized_value.endswith('"') and len(normalized_value) >= 2: + normalized_value = normalized_value[1:-1] + attributes[normalized_key] = normalized_value + required = {"id", "s", "ts", "nonce", "mac"} + if not required.issubset(attributes): + return None + return attributes + + +@dataclass(frozen=True) +class ProtocolUserData: + token: str + rruid: str + hawk_id: str + hawk_session: str + hawk_key: str + mqtt_username: str + mqtt_password: str + source: str = "" + updated_at_utc: str = "" + + +@dataclass(frozen=True) +class ProtocolAvailability: + user: ProtocolUserData | None + reason: str + users: tuple[ProtocolUserData, ...] = () + missing_fields: tuple[str, ...] 
= () + + +def _session_identity(user_data: Mapping[str, Any]) -> tuple[str, str]: + rriot = user_data.get("rriot") + if not isinstance(rriot, Mapping): + return "", "" + return _clean_str(rriot.get("u")), _clean_str(rriot.get("s")) + + +def _minimal_session_user_data(user_data: Mapping[str, Any], *, source: str = "", updated_at_utc: str = "") -> dict[str, Any]: + rriot = dict(user_data.get("rriot") or {}) + normalized: dict[str, Any] = { + "uid": user_data.get("uid"), + "token": _clean_str(user_data.get("token")), + "rruid": _clean_str(user_data.get("rruid")), + "rriot": { + "u": _clean_str(rriot.get("u")), + "s": _clean_str(rriot.get("s")), + "h": _clean_str(rriot.get("h")), + "k": _clean_str(rriot.get("k")), + }, + } + if source: + normalized["source"] = source + if updated_at_utc: + normalized["updated_at_utc"] = updated_at_utc + return normalized + + +def _clone_json_value(value: Any) -> Any: + if isinstance(value, dict): + return {str(key): _clone_json_value(item) for key, item in value.items()} + if isinstance(value, list): + return [_clone_json_value(item) for item in value] + return value + + +def build_hawk_authorization( + *, + user: ProtocolUserData, + path: str, + query_values: dict[str, Any] | None = None, + form_values: dict[str, Any] | None = None, + timestamp: int | None = None, + nonce: str | None = None, +) -> str: + ts = int(time.time() if timestamp is None else timestamp) + normalized_nonce = _clean_str(nonce) or secrets.token_urlsafe(6) + mac = _build_hawk_mac( + hawk_id=user.hawk_id, + hawk_session=user.hawk_session, + hawk_key=user.hawk_key, + path=path, + query_values=query_values, + form_values=form_values, + timestamp=ts, + nonce=normalized_nonce, + ) + return ( + f'Hawk id="{user.hawk_id}",s="{user.hawk_session}",ts="{ts}",' + f'nonce="{normalized_nonce}",mac="{mac}"' + ) + + +class ProtocolAuthStore: + """Loads protocol auth state from the persisted cloud snapshot and session store.""" + + def __init__( + self, + snapshot_path: str | 
Path, + *, + session_store_path: str | Path | None = None, + max_persisted_sessions: int = 8, + hawk_clock_skew_seconds: int = 300, + hawk_nonce_ttl_seconds: int = 600, + ) -> None: + self.snapshot_path = Path(snapshot_path) + self.session_store_path = Path(session_store_path) if session_store_path is not None else None + self.max_persisted_sessions = max(1, int(max_persisted_sessions)) + self.hawk_clock_skew_seconds = hawk_clock_skew_seconds + self.hawk_nonce_ttl_seconds = hawk_nonce_ttl_seconds + self._lock = threading.RLock() + self._snapshot_mtime_ns: int | None = None + self._session_store_mtime_ns: int | None = None + self._persisted_session_records: tuple[dict[str, Any], ...] = () + self._availability = ProtocolAvailability(user=None, reason="missing_snapshot_or_user_data", users=()) + self._nonces: dict[str, float] = {} + + @staticmethod + def _missing_user_fields(user_data: dict[str, Any]) -> list[str]: + missing: list[str] = [] + if not _clean_str(user_data.get("token")): + missing.append("token") + if not _clean_str(user_data.get("rruid")): + missing.append("rruid") + rriot = user_data.get("rriot") + if not isinstance(rriot, dict): + missing.append("rriot") + return missing + if not _clean_str(rriot.get("u")): + missing.append("rriot.u") + if not _clean_str(rriot.get("s")): + missing.append("rriot.s") + if not _clean_str(rriot.get("h")): + missing.append("rriot.h") + if not _clean_str(rriot.get("k")): + missing.append("rriot.k") + return missing + + @staticmethod + def _build_user(user_data: dict[str, Any]) -> ProtocolUserData: + rriot = dict(user_data.get("rriot") or {}) + hawk_id = _clean_str(rriot.get("u")) + hawk_session = _clean_str(rriot.get("s")) + mqtt_key = _clean_str(rriot.get("k")) + return ProtocolUserData( + token=_clean_str(user_data.get("token")), + rruid=_clean_str(user_data.get("rruid")), + hawk_id=hawk_id, + hawk_session=hawk_session, + hawk_key=_clean_str(rriot.get("h")), + mqtt_username=_md5hex(f"{hawk_id}:{mqtt_key}")[2:10], + 
mqtt_password=_md5hex(f"{hawk_session}:{mqtt_key}")[16:], + source=_clean_str(user_data.get("source")), + updated_at_utc=_clean_str(user_data.get("updated_at_utc")), + ) + + def _load_snapshot_user_locked(self) -> tuple[ProtocolUserData | None, str, tuple[str, ...]]: + try: + stat = self.snapshot_path.stat() + except OSError: + self._snapshot_mtime_ns = None + return None, "missing_snapshot_or_user_data", () + + if self._snapshot_mtime_ns == stat.st_mtime_ns: + availability = self._availability + snapshot_user = availability.user if availability.reason == "ok" else None + snapshot_identity = ( + (snapshot_user.hawk_id, snapshot_user.hawk_session) + if snapshot_user is not None + else ("", "") + ) + for user in availability.users: + if (user.hawk_id, user.hawk_session) == snapshot_identity: + return user, "ok", () + if snapshot_user is not None: + return snapshot_user, "ok", () + return None, availability.reason, availability.missing_fields + + self._snapshot_mtime_ns = stat.st_mtime_ns + try: + parsed = json.loads(self.snapshot_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + return None, "missing_snapshot_or_user_data", () + if not isinstance(parsed, dict): + return None, "missing_snapshot_or_user_data", () + + user_data = parsed.get("user_data") + if not isinstance(user_data, dict): + return None, "missing_snapshot_or_user_data", () + + missing_fields = tuple(self._missing_user_fields(user_data)) + if missing_fields: + return None, "incomplete_cloud_user_data", missing_fields + + return self._build_user(user_data), "ok", () + + def _load_persisted_session_records_locked(self) -> tuple[dict[str, Any], ...]: + if self.session_store_path is None: + self._session_store_mtime_ns = None + self._persisted_session_records = () + return () + + try: + stat = self.session_store_path.stat() + except OSError: + self._session_store_mtime_ns = None + self._persisted_session_records = () + return () + + if self._session_store_mtime_ns == 
stat.st_mtime_ns: + return self._persisted_session_records + + self._session_store_mtime_ns = stat.st_mtime_ns + try: + parsed = json.loads(self.session_store_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + self._persisted_session_records = () + return () + + sessions = parsed.get("sessions") if isinstance(parsed, dict) else None + if not isinstance(sessions, list): + self._persisted_session_records = () + return () + + normalized_records: list[dict[str, Any]] = [] + for raw_record in sessions: + if isinstance(raw_record, dict): + user_data = raw_record.get("user_data") if isinstance(raw_record.get("user_data"), dict) else raw_record + if isinstance(user_data, dict): + normalized_records.append(dict(raw_record)) + self._persisted_session_records = tuple(normalized_records) + return self._persisted_session_records + + def _persist_session_records_locked(self, records: list[dict[str, Any]]) -> None: + if self.session_store_path is None: + return + payload = {"version": 1, "sessions": records[: self.max_persisted_sessions]} + self.session_store_path.parent.mkdir(parents=True, exist_ok=True) + self.session_store_path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8") + try: + stat = self.session_store_path.stat() + except OSError: + self._session_store_mtime_ns = None + else: + self._session_store_mtime_ns = stat.st_mtime_ns + self._persisted_session_records = tuple(payload["sessions"]) + + def _persisted_users_locked(self) -> list[ProtocolUserData]: + records = self._load_persisted_session_records_locked() + users: list[ProtocolUserData] = [] + seen: set[tuple[str, str]] = set() + for raw_record in records: + user_data = raw_record.get("user_data") if isinstance(raw_record.get("user_data"), dict) else raw_record + if not isinstance(user_data, dict): + continue + missing_fields = self._missing_user_fields(user_data) + if missing_fields: + continue + user = self._build_user(user_data) + identity = (user.hawk_id, 
user.hawk_session) + if identity in seen: + continue + seen.add(identity) + users.append(user) + return users + + def _refresh_locked(self) -> None: + snapshot_user, snapshot_reason, snapshot_missing_fields = self._load_snapshot_user_locked() + persisted_users = self._persisted_users_locked() + + users: list[ProtocolUserData] = [] + seen: set[tuple[str, str]] = set() + for candidate in [snapshot_user, *persisted_users]: + if candidate is None: + continue + identity = (candidate.hawk_id, candidate.hawk_session) + if identity in seen: + continue + seen.add(identity) + users.append(candidate) + + if users: + self._availability = ProtocolAvailability( + user=users[0], + users=tuple(users), + reason="ok", + ) + return + + self._availability = ProtocolAvailability( + user=None, + users=(), + reason=snapshot_reason, + missing_fields=snapshot_missing_fields, + ) + + def availability(self) -> ProtocolAvailability: + with self._lock: + self._refresh_locked() + return self._availability + + def persisted_sessions(self) -> list[dict[str, Any]]: + with self._lock: + records = self._load_persisted_session_records_locked() + return [_clone_json_value(dict(record)) for record in records] + + def remove_session(self, *, hawk_id: str, hawk_session: str) -> bool: + normalized_identity = (_clean_str(hawk_id), _clean_str(hawk_session)) + if not all(normalized_identity): + return False + + with self._lock: + existing_records = list(self._load_persisted_session_records_locked()) + filtered_records: list[dict[str, Any]] = [] + removed = False + for existing_record in existing_records: + existing_user = ( + existing_record.get("user_data") + if isinstance(existing_record.get("user_data"), dict) + else existing_record + ) + if isinstance(existing_user, Mapping) and _session_identity(existing_user) == normalized_identity: + removed = True + continue + filtered_records.append(existing_record) + if not removed: + return False + self._persist_session_records_locked(filtered_records) + 
self._refresh_locked() + return True + + def expected_user_mqtt_credentials(self) -> tuple[str, str] | None: + availability = self.availability() + if availability.user is None: + return None + return availability.user.mqtt_username, availability.user.mqtt_password + + def verify_user_mqtt_credentials(self, username: str, password: str) -> tuple[bool, str, ProtocolUserData | None]: + availability = self.availability() + if not availability.users: + return False, availability.reason, None + for user in availability.users: + if username == user.mqtt_username and password == user.mqtt_password: + return True, "user_hash", user + return False, "invalid_mqtt_credentials", None + + def upsert_user_data(self, user_data: Mapping[str, Any], *, source: str = "") -> ProtocolUserData: + normalized_user_data = dict(user_data) + missing_fields = tuple(self._missing_user_fields(normalized_user_data)) + if missing_fields: + raise ValueError(f"incomplete protocol user_data: {', '.join(missing_fields)}") + + updated_at_utc = datetime.now(timezone.utc).isoformat() + persisted_user_data = _minimal_session_user_data( + normalized_user_data, + source=source, + updated_at_utc=updated_at_utc, + ) + persisted_record = { + "updated_at_utc": updated_at_utc, + "source": _clean_str(source), + "user_data": persisted_user_data, + } + identity = _session_identity(persisted_user_data) + + with self._lock: + existing_records = list(self._load_persisted_session_records_locked()) + filtered_records: list[dict[str, Any]] = [] + for existing_record in existing_records: + existing_user = ( + existing_record.get("user_data") + if isinstance(existing_record.get("user_data"), dict) + else existing_record + ) + if isinstance(existing_user, Mapping) and _session_identity(existing_user) == identity: + continue + filtered_records.append(existing_record) + filtered_records.insert(0, persisted_record) + self._persist_session_records_locked(filtered_records) + self._refresh_locked() + + return 
self._build_user(persisted_user_data) + + def issue_local_session(self, base_user_data: Mapping[str, Any], *, source: str = "") -> dict[str, Any]: + if not isinstance(base_user_data, Mapping): + raise ValueError("base_user_data must be a mapping") + + issued_user_data = _clone_json_value(dict(base_user_data)) + if not isinstance(issued_user_data, dict): + raise ValueError("base_user_data must be a mapping") + + rruid = _clean_str(issued_user_data.get("rruid")) + if not rruid: + raise ValueError("base_user_data is missing rruid") + + issued_user_data["token"] = f"rr{secrets.token_hex(16)}" + rriot_value = issued_user_data.get("rriot") + rriot = dict(rriot_value) if isinstance(rriot_value, dict) else {} + rriot["u"] = secrets.token_hex(11) + rriot["s"] = secrets.token_hex(6) + rriot["h"] = secrets.token_hex(16) + rriot["k"] = secrets.token_hex(16) + issued_user_data["rriot"] = rriot + + self.upsert_user_data(issued_user_data, source=source or "local_issued") + return issued_user_data + + def upsert_snapshot_user(self, *, source: str = "cloud_snapshot") -> ProtocolUserData: + with self._lock: + try: + parsed = json.loads(self.snapshot_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError) as exc: + raise ValueError("missing_snapshot_or_user_data") from exc + if not isinstance(parsed, dict) or not isinstance(parsed.get("user_data"), dict): + raise ValueError("missing_snapshot_or_user_data") + return self.upsert_user_data(parsed["user_data"], source=source) + + def verify_token(self, headers: Mapping[str, str]) -> tuple[bool, str]: + availability = self.availability() + if not availability.users: + return False, availability.reason + + authorization = _clean_str(headers.get("authorization")) + if not authorization: + return False, "missing_authorization" + if authorization.lower().startswith("bearer "): + authorization = authorization[7:].strip() + matched_user = next((user for user in availability.users if authorization == user.token), None) + if 
matched_user is None: + return False, "invalid_token" + + header_username = _clean_str(headers.get("header_username")) + if header_username and header_username != matched_user.rruid: + return False, "invalid_header_username" + return True, "ok" + + def verify_hawk( + self, + *, + path: str, + query_params: dict[str, list[str]], + body_params: dict[str, list[str]], + headers: Mapping[str, str], + now_ts: float | None = None, + ) -> tuple[bool, str]: + availability = self.availability() + if not availability.users: + return False, availability.reason + + hawk = _parse_hawk_authorization(headers.get("authorization", "")) + if hawk is None: + return False, "missing_authorization" + + hawk_id = _clean_str(hawk.get("id")) + matching_id_users = [user for user in availability.users if user.hawk_id == hawk_id] + if not matching_id_users: + return False, "invalid_hawk_id" + + hawk_session = _clean_str(hawk.get("s")) + user = next((candidate for candidate in matching_id_users if candidate.hawk_session == hawk_session), None) + if user is None: + return False, "invalid_hawk_session" + + try: + timestamp = int(_clean_str(hawk.get("ts"))) + except ValueError: + return False, "invalid_hawk_timestamp" + current_ts = int(time.time() if now_ts is None else now_ts) + if abs(current_ts - timestamp) > self.hawk_clock_skew_seconds: + return False, "stale_hawk_timestamp" + + nonce = _clean_str(hawk.get("nonce")) + if not nonce: + return False, "missing_hawk_nonce" + + expected_mac = _build_hawk_mac( + hawk_id=user.hawk_id, + hawk_session=user.hawk_session, + hawk_key=user.hawk_key, + path=path, + query_values=_normalize_param_values(query_params), + form_values=_normalize_param_values(body_params, include_json=True), + timestamp=timestamp, + nonce=nonce, + ) + if not hmac.compare_digest(_clean_str(hawk.get("mac")), expected_mac): + return False, "invalid_hawk_mac" + + nonce_key = f"{user.hawk_id}:{user.hawk_session}:{timestamp}:{nonce}" + expires_at = float(current_ts + 
self.hawk_nonce_ttl_seconds) + with self._lock: + stale_keys = [key for key, value in self._nonces.items() if value <= current_ts] + for key in stale_keys: + self._nonces.pop(key, None) + if nonce_key in self._nonces: + return False, "replayed_hawk_nonce" + self._nonces[nonce_key] = expires_at + return True, "ok" diff --git a/src/roborock_local_server/bundled_backend/shared/routine_runner.py b/src/roborock_local_server/bundled_backend/shared/routine_runner.py index ae0fccd..8f0b934 100644 --- a/src/roborock_local_server/bundled_backend/shared/routine_runner.py +++ b/src/roborock_local_server/bundled_backend/shared/routine_runner.py @@ -440,9 +440,9 @@ def _create_rriot(self, *, localkey: str) -> RRiot: k=localkey, r=Reference( r=self._context.region.upper(), - a=f"https://{api_host}", + a=self._context.api_url(host=api_host), m=f"tcp://127.0.0.1:{backend_port}", - l=f"https://{wood_host}", + l=self._context.wood_url(host=wood_host), ), ) diff --git a/src/roborock_local_server/bundled_backend/shared/runtime_credentials.py b/src/roborock_local_server/bundled_backend/shared/runtime_credentials.py index c237ba0..6850f92 100644 --- a/src/roborock_local_server/bundled_backend/shared/runtime_credentials.py +++ b/src/roborock_local_server/bundled_backend/shared/runtime_credentials.py @@ -74,6 +74,104 @@ def _extract_from_query(query: str) -> str: return "" +def _decode_remaining_length(data: bytes, start: int) -> tuple[int | None, int]: + multiplier = 1 + value = 0 + consumed = 0 + idx = start + while idx < len(data): + byte_val = data[idx] + consumed += 1 + value += (byte_val & 0x7F) * multiplier + if (byte_val & 0x80) == 0: + return value, consumed + multiplier *= 128 + idx += 1 + if consumed >= 4: + break + return None, 0 + + +def _decode_mqtt_string(packet: bytes, cursor: int) -> tuple[str | None, int]: + if cursor + 2 > len(packet): + return None, cursor + length = int.from_bytes(packet[cursor : cursor + 2], "big") + start = cursor + 2 + end = start + length + if end 
> len(packet): + return None, cursor + return packet[start:end].decode("utf-8", errors="replace"), end + + +def _decode_mqtt_binary(packet: bytes, cursor: int) -> tuple[bytes | None, int]: + if cursor + 2 > len(packet): + return None, cursor + length = int.from_bytes(packet[cursor : cursor + 2], "big") + start = cursor + 2 + end = start + length + if end > len(packet): + return None, cursor + return packet[start:end], end + + +def parse_mqtt_connect_packet(packet: bytes) -> dict[str, Any] | None: + if not packet or (packet[0] >> 4) != 1: + return None + remaining_len, remaining_len_bytes = _decode_remaining_length(packet, 1) + if remaining_len is None or remaining_len_bytes == 0: + return None + cursor = 1 + remaining_len_bytes + protocol_name, cursor = _decode_mqtt_string(packet, cursor) + if protocol_name is None or cursor + 4 > len(packet): + return None + protocol_level = packet[cursor] + connect_flags = packet[cursor + 1] + cursor += 4 + if protocol_level == 5: + property_len, property_len_bytes = _decode_remaining_length(packet, cursor) + if property_len is None or property_len_bytes == 0: + return None + cursor += property_len_bytes + property_len + client_id, cursor = _decode_mqtt_string(packet, cursor) + if client_id is None: + return None + + will_flag = (connect_flags & 0x04) != 0 + username_flag = (connect_flags & 0x80) != 0 + password_flag = (connect_flags & 0x40) != 0 + if will_flag: + if protocol_level == 5: + property_len, property_len_bytes = _decode_remaining_length(packet, cursor) + if property_len is None or property_len_bytes == 0: + return None + cursor += property_len_bytes + property_len + _, cursor = _decode_mqtt_string(packet, cursor) + _, cursor = _decode_mqtt_binary(packet, cursor) + if cursor > len(packet): + return None + + username = "" + password = "" + if username_flag: + decoded_username, cursor = _decode_mqtt_string(packet, cursor) + if decoded_username is None: + return None + username = decoded_username + if password_flag: + 
decoded_password, cursor = _decode_mqtt_binary(packet, cursor) + if decoded_password is None: + return None + password = decoded_password.decode("utf-8", errors="replace") + + return { + "protocol_name": protocol_name, + "protocol_level": protocol_level, + "client_id": client_id, + "username": username, + "password": password, + } + + class RuntimeCredentialsStore: """Persists stack credentials and per-device onboarding keys.""" @@ -149,6 +247,7 @@ def _normalize_device(raw: dict[str, Any]) -> dict[str, str]: "localkey": _clean_str(raw.get("localkey") or raw.get("local_key") or raw.get("localKey") or raw.get("k")), "local_key_source": _clean_str(raw.get("local_key_source") or raw.get("source")), "device_mqtt_usr": _clean_str(raw.get("device_mqtt_usr") or raw.get("mqtt_usr")), + "device_mqtt_pass": _clean_str(raw.get("device_mqtt_pass") or raw.get("mqtt_pass") or raw.get("mqtt_password")), "updated_at": _clean_str(raw.get("updated_at")), "last_nc_at": _clean_str(raw.get("last_nc_at")), "last_mqtt_seen_at": _clean_str(raw.get("last_mqtt_seen_at")), @@ -201,6 +300,7 @@ def _bootstrap_device_locked(self) -> dict[str, str] | None: "localkey": localkey, "local_key_source": "bootstrap", "device_mqtt_usr": "", + "device_mqtt_pass": "", "updated_at": "", "last_nc_at": "", "last_mqtt_seen_at": "", @@ -264,6 +364,7 @@ def _merge_device_records_locked(self, primary_index: int, secondary_index: int) "localkey", "local_key_source", "device_mqtt_usr", + "device_mqtt_pass", ): if primary.get(key) or not secondary.get(key): continue @@ -297,6 +398,7 @@ def ensure_device( localkey: str = "", local_key_source: str = "", device_mqtt_usr: str = "", + device_mqtt_pass: str = "", last_nc_at: str = "", last_mqtt_seen_at: str = "", assign_localkey: bool = False, @@ -309,6 +411,7 @@ def ensure_device( normalized_localkey = _clean_str(localkey) normalized_source = _clean_str(local_key_source) normalized_device_mqtt_usr = _clean_str(device_mqtt_usr) + normalized_device_mqtt_pass = 
_clean_str(device_mqtt_pass) normalized_last_nc_at = _clean_str(last_nc_at) normalized_last_mqtt_seen_at = _clean_str(last_mqtt_seen_at) @@ -340,6 +443,7 @@ def ensure_device( "localkey": "", "local_key_source": "", "device_mqtt_usr": "", + "device_mqtt_pass": "", "updated_at": "", "last_nc_at": "", "last_mqtt_seen_at": "", @@ -354,6 +458,7 @@ def ensure_device( ("model", normalized_model), ("product_id", normalized_product_id), ("device_mqtt_usr", normalized_device_mqtt_usr), + ("device_mqtt_pass", normalized_device_mqtt_pass), ("last_nc_at", normalized_last_nc_at), ("last_mqtt_seen_at", normalized_last_mqtt_seen_at), ): @@ -452,6 +557,16 @@ def resolve_device(self, *, did: str = "", duid: str = "", model: str = "") -> d return None return dict(self._devices[index]) + def resolve_device_by_mqtt_username(self, username: str) -> dict[str, str] | None: + normalized_username = _clean_str(username) + if not normalized_username: + return None + with self._lock: + for device in self._devices: + if _clean_str(device.get("device_mqtt_usr")) == normalized_username: + return dict(device) + return None + def device_for_selector(self, selector: str = "") -> dict[str, str] | None: normalized = _clean_str(selector).lower() devices = self.devices() @@ -503,6 +618,104 @@ def record_mqtt_topic(self, *, topic: str) -> None: assign_localkey=False, ) + def verify_device_mqtt_credentials(self, *, username: str, password: str) -> tuple[bool, str, dict[str, str] | None]: + normalized_username = _clean_str(username) + normalized_password = _clean_str(password) + if not normalized_username or not normalized_password: + return False, "missing_mqtt_credentials", None + + with self._lock: + for device in self._devices: + if _clean_str(device.get("device_mqtt_usr")) != normalized_username: + continue + stored_password = _clean_str(device.get("device_mqtt_pass")) + if not stored_password: + return False, "device_mqtt_password_missing", dict(device) + if stored_password == normalized_password: + 
return True, "device_mqtt_user", dict(device) + return False, "invalid_device_mqtt_password", dict(device) + return False, "unknown_device_mqtt_username", None + + def recover_device_mqtt_password(self, *, username: str, password: str) -> dict[str, str] | None: + normalized_username = _clean_str(username) + normalized_password = _clean_str(password) + if not normalized_username or not normalized_password: + return None + with self._lock: + for device in self._devices: + if _clean_str(device.get("device_mqtt_usr")) != normalized_username: + continue + if _clean_str(device.get("device_mqtt_pass")): + return dict(device) + device["device_mqtt_pass"] = normalized_password + device["updated_at"] = utcnow_iso() + self._save_locked() + return dict(device) + return None + + def recovery_pending_devices(self) -> list[dict[str, str]]: + with self._lock: + pending = [ + dict(device) + for device in self._devices + if _clean_str(device.get("device_mqtt_usr")) and not _clean_str(device.get("device_mqtt_pass")) + ] + pending.sort(key=lambda item: (item.get("name") or "", item.get("duid") or item.get("did") or "")) + return pending + + def backfill_device_mqtt_passwords(self, log_path: str | Path) -> int: + path = Path(log_path) + if not path.exists(): + return 0 + changed = 0 + with self._lock: + pending_usernames = { + _clean_str(device.get("device_mqtt_usr")) + for device in self._devices + if _clean_str(device.get("device_mqtt_usr")) and not _clean_str(device.get("device_mqtt_pass")) + } + if not pending_usernames: + return 0 + + recovered_by_username: dict[str, str] = {} + try: + with path.open("r", encoding="utf-8", errors="replace") as handle: + for raw_line in handle: + line = str(raw_line or "") + if " CONNECT " not in line or " hex=" not in line: + continue + hex_value = line.rsplit(" hex=", 1)[-1].strip() + if not hex_value: + continue + try: + packet = bytes.fromhex(hex_value) + except ValueError: + continue + parsed = parse_mqtt_connect_packet(packet) + if parsed is 
None: + continue + username = _clean_str(parsed.get("username")) + password = _clean_str(parsed.get("password")) + if username in pending_usernames and password: + recovered_by_username[username] = password + except OSError: + return 0 + + with self._lock: + for username, password in recovered_by_username.items(): + for device in self._devices: + if _clean_str(device.get("device_mqtt_usr")) != username: + continue + if _clean_str(device.get("device_mqtt_pass")): + break + device["device_mqtt_pass"] = password + device["updated_at"] = utcnow_iso() + changed += 1 + break + if changed: + self._save_locked() + return changed + def sync_inventory(self) -> None: inventory_devices = self._load_inventory_devices() key_models_by_did = self._load_key_models_by_did() @@ -531,6 +744,7 @@ def sync_inventory(self) -> None: "localkey": "", "local_key_source": "", "device_mqtt_usr": "", + "device_mqtt_pass": "", "updated_at": "", "last_nc_at": "", "last_mqtt_seen_at": "", diff --git a/src/roborock_local_server/config.py b/src/roborock_local_server/config.py index 4d13096..c60e67e 100644 --- a/src/roborock_local_server/config.py +++ b/src/roborock_local_server/config.py @@ -53,6 +53,9 @@ class AdminConfig: password_hash: str session_secret: str session_ttl_seconds: int + protocol_auth_enabled: bool + protocol_login_email: str + protocol_login_pin_hash: str @dataclass(frozen=True) @@ -74,6 +77,7 @@ class AppPaths: acme_dir: Path inventory_path: Path cloud_snapshot_path: Path + protocol_auth_sessions_path: Path runtime_credentials_path: Path device_key_state_path: Path http_jsonl_path: Path @@ -145,8 +149,8 @@ def load_config(path: str | Path) -> AppConfig: network=NetworkConfig( stack_fqdn=_require_non_empty(network.get("stack_fqdn"), "network.stack_fqdn"), bind_host=str(network.get("bind_host", "0.0.0.0")).strip() or "0.0.0.0", - https_port=_as_int(network.get("https_port"), "network.https_port", 443), - mqtt_tls_port=_as_int(network.get("mqtt_tls_port"), "network.mqtt_tls_port", 
8883), + https_port=_as_int(network.get("https_port"), "network.https_port", 555), + mqtt_tls_port=_as_int(network.get("mqtt_tls_port"), "network.mqtt_tls_port", 8881), region=str(network.get("region", "us")).strip().lower() or "us", localkey=str(network.get("localkey", "")).strip(), duid=str(network.get("duid", "")).strip(), @@ -179,11 +183,19 @@ def load_config(path: str | Path) -> AppConfig: password_hash=_require_non_empty(admin.get("password_hash"), "admin.password_hash"), session_secret=_require_non_empty(admin.get("session_secret"), "admin.session_secret"), session_ttl_seconds=_as_int(admin.get("session_ttl_seconds"), "admin.session_ttl_seconds", 86400), + protocol_auth_enabled=_as_bool(admin.get("protocol_auth_enabled"), True), + protocol_login_email=_require_non_empty(admin.get("protocol_login_email"), "admin.protocol_login_email"), + protocol_login_pin_hash=_require_non_empty( + admin.get("protocol_login_pin_hash"), + "admin.protocol_login_pin_hash", + ), ), ) if len(config.admin.session_secret) < 24: raise ValueError("admin.session_secret must be at least 24 characters") + if "@" not in config.admin.protocol_login_email: + raise ValueError("admin.protocol_login_email must be an email address") if config.broker.mode == "external": _require_non_empty(config.broker.host, "broker.host") @@ -230,6 +242,7 @@ def resolve_paths(config_file: str | Path, config: AppConfig) -> AppPaths: acme_dir=acme_dir, inventory_path=runtime_dir / "web_api_inventory.json", cloud_snapshot_path=runtime_dir / "web_api_inventory_full_snapshot.json", + protocol_auth_sessions_path=state_dir / "protocol_auth_sessions.json", runtime_credentials_path=runtime_dir / "runtime_credentials.json", device_key_state_path=state_dir / "device_key_state.json", http_jsonl_path=runtime_dir / "decompiled_http.jsonl", diff --git a/src/roborock_local_server/configure.py b/src/roborock_local_server/configure.py index 23c4125..1c77b0c 100644 --- a/src/roborock_local_server/configure.py +++ 
b/src/roborock_local_server/configure.py @@ -41,6 +41,8 @@ def hash_password(password: str, *, iterations: int = 600_000) -> str: @dataclass(frozen=True) class ConfigureAnswers: stack_fqdn: str + https_port: int + mqtt_tls_port: int broker_mode: str tls_mode: str base_domain: str @@ -48,6 +50,8 @@ class ConfigureAnswers: cloudflare_token: str password_hash: str session_secret: str + protocol_login_email: str + protocol_login_pin_hash: str @dataclass(frozen=True) @@ -101,6 +105,21 @@ def _prompt_hostname(prompt: str, *, field_name: str) -> str: print(exc) +def _prompt_port(prompt: str, *, default: int) -> int: + while True: + raw_value = input(f"{prompt} [{default}]: ").strip() + if not raw_value: + return default + try: + port = int(raw_value) + except ValueError: + print("Please enter a valid port number.") + continue + if 1 <= port <= 65535: + return port + print("Port must be between 1 and 65535.") + + def _prompt_yes_no(prompt: str, *, default: bool) -> bool: suffix = "Y/n" if default else "y/N" while True: @@ -122,12 +141,44 @@ def _prompt_password() -> str: print("A password is required.") +def _prompt_protocol_login_email() -> str: + while True: + email = _prompt_non_empty("Protocol login email for app/HA sign-in: ") + if "@" in email: + return email + print("Protocol login email must look like an email address.") + + +def _validate_protocol_login_pin(pin: str) -> str: + normalized = str(pin or "").strip() + if len(normalized) != 6 or not normalized.isdigit(): + raise ValueError("Protocol login PIN must be exactly 6 digits.") + return normalized + + +def _prompt_protocol_login_pin() -> str: + while True: + pin = getpass("Protocol login PIN (6 digits, input hidden): ").strip() + try: + normalized_pin = _validate_protocol_login_pin(pin) + except ValueError as exc: + print(exc) + continue + confirmation = getpass("Confirm protocol login PIN: ").strip() + if normalized_pin != confirmation: + print("PIN entries did not match.") + continue + return normalized_pin 
+ + def collect_configure_answers() -> ConfigureAnswers: print("This writes a small config.toml with opinionated defaults.") stack_fqdn = _prompt_hostname( "Stack FQDN (hostname only (no 'https://'); it needs to start with api-): ", field_name="stack_fqdn", ) + https_port = _prompt_port("HTTPS port to advertise and listen on", default=555) + mqtt_tls_port = _prompt_port("MQTT TLS port to advertise and listen on", default=8881) use_external_broker = _prompt_yes_no("Use your own MQTT broker instead of the embedded one?", default=False) use_cloudflare_acme = _prompt_yes_no("Use Cloudflare DNS-01 for automatic TLS renewal?", default=True) @@ -149,8 +200,12 @@ def collect_configure_answers() -> ConfigureAnswers: cloudflare_token = getpass("Cloudflare API token (input hidden): ").strip() password = _prompt_password() + protocol_login_email = _prompt_protocol_login_email() + protocol_login_pin = _prompt_protocol_login_pin() return ConfigureAnswers( stack_fqdn=stack_fqdn, + https_port=https_port, + mqtt_tls_port=mqtt_tls_port, broker_mode=broker_mode, tls_mode=tls_mode, base_domain=base_domain, @@ -158,6 +213,8 @@ def collect_configure_answers() -> ConfigureAnswers: cloudflare_token=cloudflare_token, password_hash=hash_password(password), session_secret=secrets.token_urlsafe(32), + protocol_login_email=protocol_login_email, + protocol_login_pin_hash=hash_password(protocol_login_pin), ) @@ -166,8 +223,8 @@ def render_config_toml(answers: ConfigureAnswers) -> str: "[network]", f"stack_fqdn = {_toml_string(answers.stack_fqdn)}", 'bind_host = "0.0.0.0"', - "https_port = 443", - "mqtt_tls_port = 8883", + f"https_port = {answers.https_port}", + f"mqtt_tls_port = {answers.mqtt_tls_port}", 'region = "us"', "", "[broker]", @@ -235,6 +292,9 @@ def render_config_toml(answers: ConfigureAnswers) -> str: f"password_hash = {_toml_string(answers.password_hash)}", f"session_secret = {_toml_string(answers.session_secret)}", "session_ttl_seconds = 86400", + "protocol_auth_enabled = true", + 
f"protocol_login_email = {_toml_string(answers.protocol_login_email)}", + f"protocol_login_pin_hash = {_toml_string(answers.protocol_login_pin_hash)}", "", ] ) diff --git a/src/roborock_local_server/server.py b/src/roborock_local_server/server.py index e47b8b8..ef5f0dd 100644 --- a/src/roborock_local_server/server.py +++ b/src/roborock_local_server/server.py @@ -45,11 +45,24 @@ start_broker, strip_roborock_prefix, ) +from shared.protocol_auth import ProtocolAuthStore +from https_server.routes.auth.service import ( + build_login_data_response, + cloud_login_data_required_response, +) from .bundled_backend.shared.zone_ranges_store import ZoneRangesStore -from .security import AdminSessionManager +from .security import AdminSessionManager, verify_password ALL_HTTP_METHODS = ("GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS") +PROTOCOL_AUTH_SYNC_PATH = "/internal/protocol/user-data" +PROTOCOL_AUTH_SYNC_SECRET_HEADER = "x-local-sync-secret" +_REGION_COUNTRY_CODE = { + "US": "1", + "CN": "86", + "EU": "49", + "RU": "7", +} PROJECT_SUPPORT = { "title": "Support This Project", "text": ( @@ -64,7 +77,7 @@ "url": "https://us.roborock.com/discount/RRSAP202602071713342D18X?redirect=%2Fpages%2Froborock-store%3Fuuid%3DEQe6p1jdZczHEN4Q0nbsG9sZRm0RK1gW5eSM%252FCzcW4Q%253D", }, {"label": "Roborock Affiliate", "url": "https://roborock.pxf.io/B0VYV9"}, - {"label": "Amazon Affiliate", "url": "https://amzn.to/4bGfG6B"}, + {"label": "Amazon Affiliate", "url": "https://amzn.to/4cx8zg3"}, ], } @@ -73,11 +86,16 @@ def _request_query_params(request: Request) -> dict[str, list[str]]: return parse_qs(request.url.query, keep_blank_values=True) -def _request_body_params(raw_body: bytes) -> tuple[str, dict[str, list[str]]]: +def _request_body_params(raw_body: bytes, *, content_type: str = "") -> tuple[str, dict[str, list[str]]]: body_text = raw_body.decode("utf-8", errors="replace") if not body_text: return "", {} - return body_text, parse_qs(body_text, keep_blank_values=True) + body_params = 
parse_qs(body_text, keep_blank_values=True) + stripped = body_text.lstrip() + if "json" in content_type.lower() or stripped.startswith(("{", "[")): + body_params = dict(body_params) + body_params.setdefault("__json", [body_text]) + return body_text, body_params def _pick_first_header(headers: dict[str, str], keys: tuple[str, ...]) -> str: @@ -88,6 +106,47 @@ def _pick_first_header(headers: dict[str, str], keys: tuple[str, ...]) -> str: return "" +def _request_json_object(body_params: dict[str, list[str]]) -> dict[str, Any]: + for raw in body_params.get("__json") or []: + try: + parsed = json.loads(raw) + except (TypeError, json.JSONDecodeError): + continue + if isinstance(parsed, dict): + return parsed + return {} + + +def _request_value( + query_params: dict[str, list[str]], + body_params: dict[str, list[str]], + *keys: str, +) -> str: + json_body = _request_json_object(body_params) + for key in keys: + for value in query_params.get(key, []): + candidate = str(value).strip() + if candidate: + return candidate + for value in body_params.get(key, []): + candidate = str(value).strip() + if candidate: + return candidate + json_value = json_body.get(key) + candidate = str(json_value).strip() if json_value is not None else "" + if candidate: + return candidate + return "" + + +def _headers_for_log(headers: dict[str, str]) -> dict[str, str]: + normalized = dict(headers) + for key in list(normalized): + if str(key).strip().lower() == PROTOCOL_AUTH_SYNC_SECRET_HEADER: + normalized[key] = "" + return normalized + + def _extract_explicit_pid( query_params: dict[str, list[str]], body_params: dict[str, list[str]], @@ -267,6 +326,10 @@ def __init__( inventory_path=self.paths.inventory_path, snapshot_path=self.paths.cloud_snapshot_path, ) + self.protocol_auth = ProtocolAuthStore( + self.paths.cloud_snapshot_path, + session_store_path=self.paths.protocol_auth_sessions_path, + ) self.loggers = self._setup_loggers() if not self.paths.device_key_state_path.exists(): @@ -327,17 
+390,24 @@ def __init__( mqtt_backend_port=self.config.broker.port, ) self.runtime_credentials.sync_inventory() + recovered_device_passwords = self.runtime_credentials.backfill_device_mqtt_passwords( + self.paths.runtime_dir / "mqtt_server.log" + ) + if recovered_device_passwords: + self.root_logger.info("Recovered %d device MQTT password(s) from mqtt_server.log", recovered_device_passwords) self.context = ServerContext( api_host=self.config.network.stack_fqdn, mqtt_host=self.config.network.stack_fqdn, wood_host=self.config.network.stack_fqdn, region=self.config.network.region, + protocol_login_email=self.config.admin.protocol_login_email, localkey=self._bootstrap_credentials["localkey"], duid=self._bootstrap_credentials["duid"], mqtt_usr=self._bootstrap_credentials["mqtt_usr"], mqtt_passwd=self._bootstrap_credentials["mqtt_passwd"], mqtt_clientid=self._bootstrap_credentials["mqtt_clientid"], + https_port=self.config.network.https_port, mqtt_tls_port=self.config.network.mqtt_tls_port, http_jsonl=self.paths.http_jsonl_path, mqtt_jsonl=self.paths.mqtt_jsonl_path, @@ -408,6 +478,346 @@ def _require_admin(self, request: Request) -> None: if not self._authenticated(request): raise HTTPException(status_code=401, detail="Authentication required") + def protocol_auth_enabled(self) -> bool: + return bool(self.config.admin.protocol_auth_enabled) + + def _protocol_login_email(self) -> str: + return str(self.config.admin.protocol_login_email or "").strip() + + @staticmethod + def _protocol_login_pin_valid(pin: str) -> bool: + normalized = str(pin or "").strip() + return len(normalized) == 6 and normalized.isdigit() + + def _protocol_login_email_matches(self, email: str) -> bool: + configured_email = self._protocol_login_email() + normalized_email = str(email or "").strip() + return bool( + configured_email + and normalized_email + and configured_email.casefold() == normalized_email.casefold() + ) + + def _protocol_login_pin_matches(self, pin: str) -> bool: + normalized_pin = 
str(pin or "").strip() + return self._protocol_login_pin_valid(normalized_pin) and verify_password( + normalized_pin, + self.config.admin.protocol_login_pin_hash, + ) + + @staticmethod + def _default_country_code_for_region(region: str) -> str: + return _REGION_COUNTRY_CODE.get(str(region or "").upper(), "1") + + def _local_protocol_identity(self) -> dict[str, Any]: + email = self._protocol_login_email() + normalized_email = email.casefold() + digest = hashlib.sha256(normalized_email.encode("utf-8")).hexdigest() + uid = (int(digest[:12], 16) % 900_000_000) + 100_000_000 + region_upper = self.config.network.region.upper() + return { + "uid": uid, + "rruid": f"rrls-{digest[:20]}", + "email": email, + "country": region_upper, + "countrycode": self._default_country_code_for_region(region_upper), + "nickname": "Local User", + "hasPassword": True, + "rriot": { + "r": { + "r": region_upper, + } + }, + } + + @staticmethod + def _normalized_path(path: str) -> str: + normalized = str(path or "").rstrip("/") + return normalized or "/" + + @classmethod + def _is_public_protocol_path(cls, clean_path: str) -> bool: + normalized = cls._normalized_path(clean_path) + return normalized in { + "/", + "/region", + "/time", + "/location", + "/nc/prepare", + "/api/v1/getUrlByEmail", + "/api/v1/ml/c", + "/api/v1/sendEmailCode", + "/api/v1/sendSmsCode", + "/api/v1/validateEmailCode", + "/api/v1/validateSmsCode", + "/api/v1/loginWithCode", + "/api/v3/key/sign", + "/api/v3/sms/sendCode", + "/api/v4/key/captcha", + "/api/v4/email/code/send", + "/api/v4/sms/code/send", + "/api/v4/email/code/validate", + "/api/v4/sms/code/validate", + "/api/v4/auth/email/login/code", + "/api/v4/auth/phone/login/code", + "/api/v4/auth/mobile/login/code", + "/api/v5/email/code/send", + "/api/v5/sms/code/send", + "/api/v5/email/code/validate", + "/api/v5/sms/code/validate", + "/api/v5/auth/email/login/code", + "/api/v5/auth/phone/login/code", + "/api/v5/auth/mobile/login/code", + "/api/v1/country/version", + 
"/api/v1/country/list", + "/api/v1/appconfig", + "/api/v2/appconfig", + "/api/v1/appfeatureplugin", + "/api/v1/appplugin", + "/api/v1/plugins", + "/api/v4/agreement/latest", + } + + @classmethod + def _is_code_send_path(cls, clean_path: str) -> bool: + normalized = cls._normalized_path(clean_path) + return normalized in { + "/api/v1/sendEmailCode", + "/api/v1/sendSmsCode", + "/api/v3/sms/sendCode", + "/api/v4/email/code/send", + "/api/v4/sms/code/send", + "/api/v5/email/code/send", + "/api/v5/sms/code/send", + } + + @classmethod + def _is_code_validate_path(cls, clean_path: str) -> bool: + normalized = cls._normalized_path(clean_path) + return normalized in { + "/api/v1/validateEmailCode", + "/api/v1/validateSmsCode", + "/api/v4/email/code/validate", + "/api/v4/sms/code/validate", + "/api/v5/email/code/validate", + "/api/v5/sms/code/validate", + } + + @classmethod + def _is_code_submit_path(cls, clean_path: str) -> bool: + normalized = cls._normalized_path(clean_path) + return normalized in { + "/api/v1/loginWithCode", + "/api/v4/auth/email/login/code", + "/api/v4/auth/phone/login/code", + "/api/v4/auth/mobile/login/code", + "/api/v5/auth/email/login/code", + "/api/v5/auth/phone/login/code", + "/api/v5/auth/mobile/login/code", + } + + @classmethod + def _is_password_login_path(cls, clean_path: str) -> bool: + normalized = cls._normalized_path(clean_path) + return normalized in { + "/api/v1/login", + "/api/v3/auth/email/login", + "/api/v3/auth/phone/login", + "/api/v3/auth/mobile/login", + "/api/v5/auth/email/login/pwd", + "/api/v5/auth/phone/login/pwd", + "/api/v5/auth/mobile/login/pwd", + } + + @classmethod + def _is_password_reset_path(cls, clean_path: str) -> bool: + normalized = cls._normalized_path(clean_path) + return normalized in { + "/api/v5/user/password/mobile/reset", + "/api/v5/user/password/email/reset", + } + + @classmethod + def _required_protocol_auth(cls, clean_path: str) -> str | None: + normalized = cls._normalized_path(clean_path) + if 
cls._is_public_protocol_path(normalized): + return None + if normalized.startswith(("/user/", "/v2/user/", "/v3/user/")): + return "hawk" + if normalized.startswith("/api/"): + return "token" + return None + + def _required_protocol_auth_for_request(self, clean_path: str) -> str | None: + if not self.protocol_auth_enabled(): + return None + return self._required_protocol_auth(clean_path) + + def _login_send_success_payload(self) -> dict[str, Any]: + return {"code": 200, "msg": "success", "data": {"sent": True, "validForSec": 300}} + + def _login_validate_success_payload(self) -> dict[str, Any]: + return {"code": 200, "msg": "success", "data": {"valid": True}} + + @staticmethod + def _invalid_login_credentials_payload(reason: str) -> tuple[int, dict[str, Any]]: + return 401, { + "code": 2010, + "msg": "invalid_credentials", + "data": {"reason": reason, "auth": "code"}, + } + + @staticmethod + def _unsupported_password_login_payload() -> dict[str, Any]: + return { + "code": 40031, + "msg": "password_login_not_supported", + "data": {"reason": "code_login_only"}, + } + + @staticmethod + def _protocol_auth_failure_response(reason: str, auth_kind: str) -> tuple[int, dict[str, Any]]: + if auth_kind == "token": + # Home Assistant's python-roborock client only starts reauth when + # Roborock's web API returns the invalid-credentials code it + # already maps to RoborockInvalidCredentials. 
+ return 401, { + "code": 2010, + "msg": "invalid_credentials", + "data": {"reason": reason, "auth": auth_kind}, + } + return 401, { + "code": 40101, + "msg": "authentication_required", + "data": {"reason": reason, "auth": auth_kind}, + } + + def _protocol_auth_not_ready_payload(self) -> tuple[int, dict[str, Any]]: + availability = self.protocol_auth.availability() + payload = cloud_login_data_required_response( + self.context, + reason=availability.reason, + missing_fields=list(availability.missing_fields) or None, + ) + return 412, payload + + @classmethod + def _is_protocol_sync_path(cls, clean_path: str) -> bool: + return cls._normalized_path(clean_path) == PROTOCOL_AUTH_SYNC_PATH + + @staticmethod + def _protocol_sync_success_payload(*, source: str) -> dict[str, Any]: + return { + "code": 200, + "msg": "success", + "data": {"stored": True, "source": source}, + } + + @staticmethod + def _protocol_sync_failure_payload(*, reason: str, detail: str = "") -> dict[str, Any]: + payload: dict[str, Any] = {"reason": reason} + if detail: + payload["detail"] = detail + return {"code": 40041, "msg": "protocol_sync_failed", "data": payload} + + def _sync_secret_matches(self, headers: dict[str, str]) -> bool: + provided = _pick_first_header(headers, ("x-local-sync-secret", "X-Local-Sync-Secret")) + expected = str(self.config.admin.session_secret or "").strip() + return bool(expected) and secrets.compare_digest(provided, expected) + + async def _handle_protocol_sync_route( + self, + *, + method: str, + clean_path: str, + headers: dict[str, str], + body_params: dict[str, list[str]], + ) -> tuple[str, int, dict[str, Any]] | None: + if not self._is_protocol_sync_path(clean_path): + return None + if method.upper() != "POST": + return "protocol_auth_sync_method_not_allowed", 405, self._protocol_sync_failure_payload( + reason="method_not_allowed" + ) + if not self._sync_secret_matches(headers): + return "protocol_auth_sync_unauthorized", 401, self._protocol_sync_failure_payload( + 
reason="invalid_sync_secret" + ) + + payload = _request_json_object(body_params) + user_data = payload.get("user_data") + if not isinstance(user_data, dict): + return "protocol_auth_sync_invalid_payload", 400, self._protocol_sync_failure_payload( + reason="missing_user_data" + ) + + source = str(payload.get("source") or "mitm_cloud_login").strip() or "mitm_cloud_login" + try: + self.protocol_auth.upsert_user_data(user_data, source=source) + except ValueError as exc: + return "protocol_auth_sync_invalid_payload", 400, self._protocol_sync_failure_payload( + reason="invalid_user_data", + detail=str(exc), + ) + + return "protocol_auth_sync", 200, self._protocol_sync_success_payload(source=source) + + async def _handle_protocol_login_route( + self, + *, + clean_path: str, + query_params: dict[str, list[str]], + body_params: dict[str, list[str]], + ) -> tuple[str, int, dict[str, Any]] | None: + normalized = self._normalized_path(clean_path) + if self._is_code_send_path(normalized): + return "protocol_login_request_code", 200, self._login_send_success_payload() + + if self._is_code_validate_path(normalized): + return "protocol_login_validate_code", 200, self._login_validate_success_payload() + + if self._is_code_submit_path(normalized): + account = _request_value( + query_params, + body_params, + "email", + "username", + "account", + "mobile", + "phone", + ) + code = _request_value( + query_params, + body_params, + "code", + "verifyCode", + "emailCode", + "smsCode", + ) + if not self._protocol_login_email_matches(account): + status_code, payload = self._invalid_login_credentials_payload("invalid_login_email") + return "protocol_login_submit_code", status_code, payload + if not self._protocol_login_pin_matches(code): + status_code, payload = self._invalid_login_credentials_payload("invalid_login_pin") + return "protocol_login_submit_code", status_code, payload + try: + issued_user_data = self.protocol_auth.issue_local_session( + self._local_protocol_identity(), + 
source="protocol_code_login", + ) + except ValueError as exc: + return "protocol_login_submit_code", 400, { + "code": 40023, + "msg": "local_session_issue_failed", + "data": {"error": str(exc)}, + } + return "protocol_login_submit_code", 200, build_login_data_response(self.context, issued_user_data) + + if self._is_password_login_path(normalized) or self._is_password_reset_path(normalized): + return "protocol_login_password_unsupported", 400, self._unsupported_password_login_payload() + + return None + async def _handle_roborock_request(self, request: Request) -> Response: host = (request.headers.get("host") or "").strip() group = classify_host(host) @@ -415,14 +825,18 @@ async def _handle_roborock_request(self, request: Request) -> Response: raw_body = await request.body() clean_path = strip_roborock_prefix(request.url.path) query_params = _request_query_params(request) - body_text, body_params = _request_body_params(raw_body) + body_text, body_params = _request_body_params( + raw_body, + content_type=str(request.headers.get("content-type") or ""), + ) body_sha256 = hashlib.sha256(raw_body).hexdigest() + is_protocol_sync_request = self._is_protocol_sync_path(clean_path) if host: - host_no_port = host.split(":", 1)[0].strip() - if host_no_port: + host_authority = host.strip() + if host_authority: query_params = {key: list(values) for key, values in query_params.items()} - query_params.setdefault("__host", [host_no_port]) + query_params.setdefault("__host", [host_authority]) explicit_did = self.context.extract_explicit_did(query_params, body_params) explicit_pid = _extract_explicit_pid(query_params, body_params) @@ -476,17 +890,20 @@ async def _handle_roborock_request(self, request: Request) -> Response: "raw_path": raw_path, "clean_path": clean_path, "query": {key: value for key, value in query_params.items()}, - "headers": dict(request.headers), + "headers": _headers_for_log(dict(request.headers)), "body_len": len(raw_body), "body_sha256": body_sha256, - 
"body_b64": base64.b64encode(raw_body).decode("ascii"), "remote": f"{client_host}:{client_port}", } if explicit_did: entry["did"] = explicit_did if explicit_pid: entry["pid"] = explicit_pid - if body_text: + if is_protocol_sync_request: + entry["body_redacted"] = True + else: + entry["body_b64"] = base64.b64encode(raw_body).decode("ascii") + if body_text and not is_protocol_sync_request: entry["body_text"] = body_text try: entry["body_json"] = json.loads(body_text) @@ -501,6 +918,151 @@ async def _handle_roborock_request(self, request: Request) -> Response: "header_sample_added": header_sample_added, } + custom_sync = await self._handle_protocol_sync_route( + method=request.method, + clean_path=clean_path, + headers=dict(request.headers), + body_params=body_params, + ) + if custom_sync is not None: + route_name, status_code, response_payload = custom_sync + entry["route"] = route_name + entry["response_json"] = response_payload + try: + self.runtime_state.record_http_event( + event_time=str(entry["time"]), + route_name=route_name, + clean_path=clean_path, + raw_path=raw_path, + method=request.method, + host=host, + remote=str(entry["remote"]), + did=explicit_did or None, + pid=explicit_pid or None, + ) + except Exception as exc: # noqa: BLE001 + logger.warning("runtime_state record_http_event failed: %s", exc) + append_jsonl(self.context.http_jsonl, entry) + logger.info( + "%s %s host=%s route=%s status=%d body_sha256=%s", + request.method, + clean_path, + host or "-", + route_name, + status_code, + body_sha256[:16], + ) + return JSONResponse(response_payload, status_code=status_code) + + custom_login = await self._handle_protocol_login_route( + clean_path=clean_path, + query_params=query_params, + body_params=body_params, + ) + if custom_login is not None: + route_name, status_code, response_payload = custom_login + entry["route"] = route_name + entry["response_json"] = response_payload + try: + self.runtime_state.record_http_event( + 
event_time=str(entry["time"]), + route_name=route_name, + clean_path=clean_path, + raw_path=raw_path, + method=request.method, + host=host, + remote=str(entry["remote"]), + did=explicit_did or None, + pid=explicit_pid or None, + ) + except Exception as exc: # noqa: BLE001 + logger.warning("runtime_state record_http_event failed: %s", exc) + append_jsonl(self.context.http_jsonl, entry) + logger.info( + "%s %s host=%s route=%s status=%d body_sha256=%s", + request.method, + clean_path, + host or "-", + route_name, + status_code, + body_sha256[:16], + ) + return JSONResponse(response_payload, status_code=status_code) + + required_auth = self._required_protocol_auth_for_request(clean_path) + if required_auth is not None: + availability = self.protocol_auth.availability() + if availability.user is None: + status_code, response_payload = self._protocol_auth_not_ready_payload() + route_name = f"{required_auth}_auth_not_ready" + entry["route"] = route_name + entry["response_json"] = response_payload + try: + self.runtime_state.record_http_event( + event_time=str(entry["time"]), + route_name=route_name, + clean_path=clean_path, + raw_path=raw_path, + method=request.method, + host=host, + remote=str(entry["remote"]), + did=explicit_did or None, + pid=explicit_pid or None, + ) + except Exception as exc: # noqa: BLE001 + logger.warning("runtime_state record_http_event failed: %s", exc) + append_jsonl(self.context.http_jsonl, entry) + logger.info( + "%s %s host=%s route=%s status=%d body_sha256=%s", + request.method, + clean_path, + host or "-", + route_name, + status_code, + body_sha256[:16], + ) + return JSONResponse(response_payload, status_code=status_code) + + if required_auth == "token": + authenticated, auth_reason = self.protocol_auth.verify_token(request.headers) + else: + authenticated, auth_reason = self.protocol_auth.verify_hawk( + path=self._normalized_path(clean_path), + query_params=query_params, + body_params=body_params, + headers=request.headers, + ) + if not 
authenticated: + route_name = f"{required_auth}_auth_failed" + status_code, response_payload = self._protocol_auth_failure_response(auth_reason, required_auth) + entry["route"] = route_name + entry["response_json"] = response_payload + try: + self.runtime_state.record_http_event( + event_time=str(entry["time"]), + route_name=route_name, + clean_path=clean_path, + raw_path=raw_path, + method=request.method, + host=host, + remote=str(entry["remote"]), + did=explicit_did or None, + pid=explicit_pid or None, + ) + except Exception as exc: # noqa: BLE001 + logger.warning("runtime_state record_http_event failed: %s", exc) + append_jsonl(self.context.http_jsonl, entry) + logger.info( + "%s %s host=%s route=%s status=%d body_sha256=%s", + request.method, + clean_path, + host or "-", + route_name, + status_code, + body_sha256[:16], + ) + return JSONResponse(response_payload, status_code=status_code) + try: plugin_dispatch = await dispatch_plugin_zip_request( clean_path=clean_path, @@ -628,6 +1190,7 @@ def _status_payload(self) -> dict[str, Any]: health["connected_vacuums"] = [vac for vac in merged_vacuums if vac.get("connected")] return { "health": health, + "auth": self._auth_payload(), "pairing": self.runtime_state.pairing_snapshot(), "support": PROJECT_SUPPORT, "inventory_path": str(self.paths.inventory_path), @@ -672,6 +1235,101 @@ def _onboarding_devices_payload(self) -> dict[str, Any]: "generated_at": utcnow_iso(), } + @staticmethod + def _redacted_protocol_session(record: dict[str, Any]) -> dict[str, str]: + user_data = record.get("user_data") if isinstance(record.get("user_data"), dict) else record + if not isinstance(user_data, dict): + return {} + rriot = user_data.get("rriot") if isinstance(user_data.get("rriot"), dict) else {} + return { + "rruid": str(user_data.get("rruid") or "").strip(), + "hawk_id": str(rriot.get("u") or "").strip(), + "hawk_session": str(rriot.get("s") or "").strip(), + "source": str(record.get("source") or user_data.get("source") or 
"").strip(), + "updated_at_utc": str(record.get("updated_at_utc") or user_data.get("updated_at_utc") or "").strip(), + } + + def _pending_device_mqtt_recovery_payload(self) -> list[dict[str, str]]: + devices: list[dict[str, str]] = [] + for device in self.runtime_credentials.recovery_pending_devices(): + devices.append( + { + "did": str(device.get("did") or "").strip(), + "duid": str(device.get("duid") or "").strip(), + "name": str(device.get("name") or device.get("duid") or device.get("did") or "").strip(), + "model": str(device.get("model") or "").strip(), + "device_mqtt_usr": str(device.get("device_mqtt_usr") or "").strip(), + } + ) + return devices + + def _auth_payload(self) -> dict[str, Any]: + sessions = [ + session + for session in ( + self._redacted_protocol_session(record) + for record in self.protocol_auth.persisted_sessions() + ) + if session.get("hawk_id") and session.get("hawk_session") + ] + return { + "protocol_auth_enabled": self.protocol_auth_enabled(), + "protocol_sessions": sessions, + "protocol_session_count": len(sessions), + "pending_device_mqtt_recovery": self._pending_device_mqtt_recovery_payload(), + } + + def _rewrite_admin_bool_setting(self, *, key: str, value: bool) -> None: + config_path = self.paths.config_file + lines = config_path.read_text(encoding="utf-8").splitlines() + rendered_value = "true" if value else "false" + output: list[str] = [] + in_admin_section = False + admin_section_found = False + updated = False + + for line in lines: + stripped = line.strip() + is_section = stripped.startswith("[") and stripped.endswith("]") + if is_section and in_admin_section and not updated: + output.append(f"{key} = {rendered_value}") + updated = True + if stripped == "[admin]": + admin_section_found = True + in_admin_section = True + elif is_section: + in_admin_section = False + if in_admin_section: + is_comment = stripped.startswith(("#", ";")) + if not is_comment and "=" in line: + existing_key, _existing_value = line.split("=", 1) + if 
existing_key.strip() == key: + indent = line[: len(line) - len(line.lstrip())] + output.append(f"{indent}{key} = {rendered_value}") + updated = True + continue + output.append(line) + + if admin_section_found and in_admin_section and not updated: + output.append(f"{key} = {rendered_value}") + updated = True + + if not admin_section_found: + if output and output[-1].strip(): + output.append("") + output.extend(["[admin]", f"{key} = {rendered_value}"]) + + config_path.write_text("\n".join(output) + "\n", encoding="utf-8") + + def set_protocol_auth_enabled(self, enabled: bool) -> dict[str, Any]: + normalized_enabled = bool(enabled) + self._rewrite_admin_bool_setting(key="protocol_auth_enabled", value=normalized_enabled) + self.config = load_config(self.paths.config_file) + return self._auth_payload() + + def remove_protocol_session(self, *, hawk_id: str, hawk_session: str) -> bool: + return self.protocol_auth.remove_session(hawk_id=hawk_id, hawk_session=hawk_session) + def start_onboarding_session(self, *, duid: str) -> dict[str, Any]: normalized_duid = str(duid or "").strip() if not normalized_duid: @@ -746,11 +1404,17 @@ def _is_standalone_route_path(path: str) -> bool: def _register_protocol_routes(self, app: FastAPI) -> None: @app.get("/ui/api/health") - async def ui_health() -> JSONResponse: + async def ui_health(request: Request) -> JSONResponse: + if not self.enable_standalone_admin: + return JSONResponse({"error": "Not Found"}, status_code=404) + self._require_admin(request) return JSONResponse(self._ui_health_payload()) @app.get("/ui/api/vacuums") - async def ui_vacuums() -> JSONResponse: + async def ui_vacuums(request: Request) -> JSONResponse: + if not self.enable_standalone_admin: + return JSONResponse({"error": "Not Found"}, status_code=404) + self._require_admin(request) return JSONResponse(self._ui_vacuums_payload()) @app.api_route("/", methods=list(ALL_HTTP_METHODS)) @@ -801,6 +1465,9 @@ def _start_mqtt_proxy(self) -> None: 
localkey=self.context.localkey, logger=self.loggers["mqtt"], decoded_jsonl=self.context.mqtt_jsonl, + cloud_snapshot_path=self.paths.cloud_snapshot_path, + protocol_auth_sessions_path=self.paths.protocol_auth_sessions_path, + protocol_auth_enabled=self.protocol_auth_enabled, runtime_state=self.runtime_state, runtime_credentials=self.runtime_credentials, zone_ranges_store=self.context.zone_ranges_store, @@ -965,11 +1632,13 @@ def repair_runtime_identities(*, config_file: Path, links: list[str]) -> int: mqtt_host=config.network.stack_fqdn, wood_host=config.network.stack_fqdn, region=config.network.region, + protocol_login_email=config.admin.protocol_login_email, localkey=str(runtime_credentials.bootstrap_value("localkey", "") or ""), duid=str(runtime_credentials.bootstrap_value("duid", "") or ""), mqtt_usr=str(runtime_credentials.bootstrap_value("mqtt_usr", "") or ""), mqtt_passwd=str(runtime_credentials.bootstrap_value("mqtt_passwd", "") or ""), mqtt_clientid=str(runtime_credentials.bootstrap_value("mqtt_clientid", "") or ""), + https_port=config.network.https_port, mqtt_tls_port=config.network.mqtt_tls_port, http_jsonl=paths.http_jsonl_path, mqtt_jsonl=paths.mqtt_jsonl_path, diff --git a/src/roborock_local_server/standalone_admin.py b/src/roborock_local_server/standalone_admin.py index 62a012f..9b16212 100644 --- a/src/roborock_local_server/standalone_admin.py +++ b/src/roborock_local_server/standalone_admin.py @@ -58,6 +58,13 @@ def _admin_dashboard_html(project_support: dict[str, Any]) -> str:
No cloud request yet.
+

Protocol Auth

+ + +
Loading auth state...
+
+
Loading sessions...
+

Health

Vacuums

@@ -128,11 +135,73 @@ def _admin_dashboard_html(project_support: dict[str, Any]) -> str: container.appendChild(card); }} }} + function renderAuth(auth) {{ + const enabled = Boolean(auth.protocol_auth_enabled); + document.getElementById("protocolAuthEnabled").checked = enabled; + document.getElementById("authMeta").textContent = + `Protocol auth: ${{enabled ? "Enabled" : "Disabled"}}. Persisted sessions: ${{Number(auth.protocol_session_count || 0)}}.`; + + const pendingContainer = document.getElementById("pendingRecovery"); + const pendingItems = Array.isArray(auth.pending_device_mqtt_recovery) ? auth.pending_device_mqtt_recovery : []; + if (!pendingItems.length) {{ + pendingContainer.textContent = "No devices are waiting for MQTT password recovery."; + }} else {{ + pendingContainer.textContent = + "Devices waiting for first reconnect MQTT password recovery: " + + pendingItems.map((item) => item.name || item.duid || item.did || item.device_mqtt_usr).join(", "); + }} + + const sessionList = document.getElementById("sessionList"); + sessionList.innerHTML = ""; + const sessions = Array.isArray(auth.protocol_sessions) ? 
auth.protocol_sessions : []; + if (!sessions.length) {{ + const empty = document.createElement("div"); + empty.textContent = "No persisted protocol sessions."; + empty.style.color = "#555"; + sessionList.appendChild(empty); + return; + }} + for (const session of sessions) {{ + const card = document.createElement("div"); + card.style.border = "1px solid #ddd"; + card.style.borderRadius = "6px"; + card.style.padding = "10px"; + card.style.background = "#fafafa"; + const label = document.createElement("div"); + label.textContent = session.rruid || session.hawk_id || "Protocol session"; + label.style.fontWeight = "600"; + card.appendChild(label); + + const detail = document.createElement("div"); + detail.textContent = `source=${{session.source || "unknown"}} updated=${{session.updated_at_utc || "unknown"}} hawk_id=${{session.hawk_id || ""}}`; + detail.style.marginTop = "6px"; + detail.style.fontSize = "12px"; + card.appendChild(detail); + + const remove = document.createElement("button"); + remove.textContent = "Remove"; + remove.style.marginTop = "8px"; + remove.addEventListener("click", async () => {{ + try {{ + await fetchJson( + `/admin/api/auth/sessions/${{encodeURIComponent(session.hawk_id || "")}}/${{encodeURIComponent(session.hawk_session || "")}}`, + {{method: "DELETE"}} + ); + await refresh(); + }} catch (error) {{ + document.getElementById("authMeta").textContent = error.message; + }} + }}); + card.appendChild(remove); + sessionList.appendChild(card); + }} + }} async function refresh() {{ const status = await fetchJson("/admin/api/status"); document.getElementById("overall").textContent = status.health.overall_ok ? 
"Healthy" : "Needs Attention"; document.getElementById("health").textContent = JSON.stringify(status.health, null, 2); + renderAuth(await fetchJson("/admin/api/auth")); const vacuums = await fetchJson("/admin/api/vacuums"); renderVacuumSummary(vacuums.vacuums); document.getElementById("vacuums").textContent = JSON.stringify(vacuums.vacuums, null, 2); @@ -164,6 +233,21 @@ def _admin_dashboard_html(project_support: dict[str, Any]) -> str: document.getElementById("cloudResult").textContent = error.message; }} }}); + document.getElementById("saveAuth").addEventListener("click", async () => {{ + try {{ + const payload = await fetchJson("/admin/api/auth", {{ + method: "POST", + headers: {{"Content-Type":"application/json"}}, + body: JSON.stringify({{ + protocol_auth_enabled: document.getElementById("protocolAuthEnabled").checked + }}) + }}); + renderAuth(payload); + await refresh(); + }} catch (error) {{ + document.getElementById("authMeta").textContent = error.message; + }} + }}); document.getElementById("logout").addEventListener("click", async () => {{ await fetch("/admin/api/logout", {{method:"POST"}}); @@ -225,6 +309,38 @@ async def admin_vacuums(request: Request) -> JSONResponse: supervisor._require_admin(request) return JSONResponse(supervisor._vacuums_payload()) + @app.get("/admin/api/auth") + async def admin_auth(request: Request) -> JSONResponse: + supervisor._require_admin(request) + return JSONResponse(supervisor._auth_payload()) + + @app.post("/admin/api/auth") + async def admin_auth_update(request: Request) -> JSONResponse: + supervisor._require_admin(request) + try: + body = await request.json() + except json.JSONDecodeError: + return JSONResponse({"error": "Invalid JSON body"}, status_code=400) + if not isinstance(body, dict): + return JSONResponse({"error": "JSON body must be an object"}, status_code=400) + if "protocol_auth_enabled" not in body: + return JSONResponse({"error": "protocol_auth_enabled is required"}, status_code=400) + 
protocol_auth_enabled = body.get("protocol_auth_enabled") + if not isinstance(protocol_auth_enabled, bool): + return JSONResponse({"error": "protocol_auth_enabled must be a boolean"}, status_code=400) + try: + payload = supervisor.set_protocol_auth_enabled(protocol_auth_enabled) + except Exception as exc: # noqa: BLE001 + return JSONResponse({"error": str(exc)}, status_code=500) + return JSONResponse(payload) + + @app.delete("/admin/api/auth/sessions/{hawk_id}/{hawk_session}") + async def admin_auth_delete_session(hawk_id: str, hawk_session: str, request: Request) -> JSONResponse: + supervisor._require_admin(request) + if not supervisor.remove_protocol_session(hawk_id=hawk_id, hawk_session=hawk_session): + return JSONResponse({"error": "Protocol session not found"}, status_code=404) + return JSONResponse({"ok": True, "auth": supervisor._auth_payload()}) + @app.get("/admin/api/onboarding/devices") async def admin_onboarding_devices(request: Request) -> JSONResponse: supervisor._require_admin(request) diff --git a/start_onboarding.py b/start_onboarding.py index 965e2da..459ae12 100644 --- a/start_onboarding.py +++ b/start_onboarding.py @@ -38,6 +38,7 @@ CFGWIFI_UID = "1234567890" DEFAULT_COUNTRY_DOMAIN = "us" DEFAULT_TIMEZONE = "America/New_York" +DEFAULT_STACK_HTTPS_PORT = 555 POLL_INTERVAL_SECONDS = 5.0 POLL_TIMEOUT_SECONDS = 300.0 @@ -166,32 +167,49 @@ def recv_with_timeout(sock: socket.socket, timeout: float) -> bytes | None: def sanitize_stack_server(url: str) -> str: - value = str(url or "").strip() - for prefix in ("https://", "http://"): - if value.lower().startswith(prefix): - value = value[len(prefix) :] - value = value.strip().strip("/") - if value.lower().startswith("api-"): - value = value[4:] - if not value: + host, port = _parse_server_target(url, default_port=DEFAULT_STACK_HTTPS_PORT) + if host.lower().startswith("api-"): + host = host[4:] + authority = _format_authority(host, port=port, default_port=443) + if not authority: raise ValueError("A server 
host is required.") - return f"{value}/" + return f"{authority}/" def normalize_api_base_url(url: str) -> str: + host, port = _parse_server_target(url, default_port=DEFAULT_STACK_HTTPS_PORT) + if not host.lower().startswith("api-"): + host = f"api-{host}" + authority = _format_authority(host, port=port, default_port=443) + return f"https://{authority}" + + +def _parse_server_target(url: str, *, default_port: int | None = None) -> tuple[str, int | None]: value = str(url or "").strip() if not value: raise ValueError("A server host is required.") - for prefix in ("https://", "http://"): - if value.lower().startswith(prefix): - value = value[len(prefix) :] - break - value = value.strip().strip("/") - if not value: + parsed = parse.urlsplit(value if "://" in value else f"//{value}") + host = str(parsed.hostname or "").strip().strip("/") + if not host: raise ValueError("A server host is required.") - if not value.lower().startswith("api-"): - value = f"api-{value}" - return f"https://{value}" + try: + port = parsed.port + except ValueError as exc: + raise ValueError("Server port must be numeric.") from exc + if port is None: + port = default_port + return host, port + + +def _format_authority(host: str, *, port: int | None = None, default_port: int | None = None) -> str: + normalized_host = str(host or "").strip().strip("/") + if not normalized_host: + return "" + if port is None: + return normalized_host + if default_port is not None and port == default_port: + return normalized_host + return f"{normalized_host}:{port}" def _format_bool_label(value: bool, true_label: str, false_label: str) -> str: @@ -239,6 +257,10 @@ class GuidedOnboardingConfig: country_domain: str +class ApiReachabilityError(RuntimeError): + """Raised when the admin API is temporarily unreachable.""" + + class RemoteOnboardingApi: """Thin authenticated JSON client for the admin onboarding endpoints.""" @@ -258,6 +280,14 @@ def __init__( self._opener = opener or 
request.build_opener(request.HTTPCookieProcessor()) self._logged_in = False + def _open(self, req: request.Request): + if self._ssl_context is not None: + try: + return self._opener.open(req, timeout=self.timeout_seconds, context=self._ssl_context) + except TypeError: + pass + return self._opener.open(req, timeout=self.timeout_seconds) + def login(self) -> None: if self._logged_in: return @@ -298,24 +328,17 @@ def _request_json( headers["Content-Type"] = "application/json" req = request.Request(f"{self.base_url}{path}", data=data, headers=headers, method=method) try: - with self._opener.open(req, timeout=self.timeout_seconds, context=self._ssl_context) as response: + with self._open(req) as response: raw = response.read().decode("utf-8", errors="replace") - except TypeError: - try: - with self._opener.open(req, timeout=self.timeout_seconds) as response: - raw = response.read().decode("utf-8", errors="replace") - except error.HTTPError as exc: - if exc.code == 401 and allow_401: - raise RuntimeError("Invalid admin password.") from exc - detail = exc.read().decode("utf-8", errors="replace") - raise RuntimeError(_format_http_error(exc.code, detail)) from exc except error.HTTPError as exc: if exc.code == 401 and allow_401: raise RuntimeError("Invalid admin password.") from exc detail = exc.read().decode("utf-8", errors="replace") raise RuntimeError(_format_http_error(exc.code, detail)) from exc except error.URLError as exc: - raise RuntimeError(f"Unable to reach {self.base_url}: {exc.reason}") from exc + raise ApiReachabilityError(f"Unable to reach {self.base_url}: {exc.reason}") from exc + except OSError as exc: + raise ApiReachabilityError(f"Unable to reach {self.base_url}: {exc}") from exc if not raw: return {} try: @@ -344,7 +367,7 @@ def _format_http_error(status_code: int, raw_body: str) -> str: def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Guided Roborock remote onboarding") - parser.add_argument("--server", 
required=True, help="Main server hostname, usually starting with api-") + parser.add_argument("--server", required=True, help="Main server hostname or HTTPS URL, usually starting with api-") parser.add_argument("--admin-password", default="") parser.add_argument("--ssid", default="") parser.add_argument("--password", default="") @@ -484,15 +507,33 @@ def poll_session_until_progress( *, session_id: str, baseline_samples: int, + baseline_status: dict[str, Any] | None = None, output: TextIO, poll_interval_seconds: float = POLL_INTERVAL_SECONDS, timeout_seconds: float = POLL_TIMEOUT_SECONDS, sleep_fn: Callable[[float], None] = time.sleep, ) -> tuple[str, dict[str, Any]]: deadline = time.monotonic() + timeout_seconds - latest = api.get_session(session_id=session_id) + latest = dict(baseline_status or {"session_id": session_id, "query_samples": baseline_samples}) + waiting_for_reconnect = False while True: - latest = api.get_session(session_id=session_id) + try: + latest = api.get_session(session_id=session_id) + waiting_for_reconnect = False + except ApiReachabilityError as exc: + if time.monotonic() >= deadline: + return "timeout", latest + if not waiting_for_reconnect: + output.write( + "The main server is not reachable yet from this machine. 
" + "Finish reconnecting to your normal Wi-Fi and the script will keep retrying.\n" + ) + output.write(f"{exc}\n") + waiting_for_reconnect = True + else: + output.write("Still waiting for this machine to reach the main server again...\n") + sleep_fn(poll_interval_seconds) + continue if str(latest.get("identity_conflict") or "").strip(): return "conflict", latest if bool(latest.get("connected")): @@ -575,6 +616,7 @@ def run_guided_onboarding( api, session_id=session_id, baseline_samples=baseline_samples, + baseline_status=status, output=output, poll_interval_seconds=poll_interval_seconds, timeout_seconds=timeout_seconds, diff --git a/tests/conftest.py b/tests/conftest.py index 226780e..237243e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,8 +17,14 @@ def write_release_config( tmp_path: Path, *, + stack_fqdn: str = "roborock.example.com", + https_port: int = 443, + mqtt_tls_port: int = 8883, broker_mode: str = "external", enable_topic_bridge: bool = False, + protocol_auth_enabled: bool = True, + protocol_login_email: str = "user@example.com", + protocol_login_pin: str = "123456", ) -> Path: cert_dir = tmp_path / "certs" cert_dir.mkdir(parents=True, exist_ok=True) @@ -29,7 +35,9 @@ def write_release_config( config_file.write_text( f""" [network] -stack_fqdn = "roborock.example.com" +stack_fqdn = "{stack_fqdn}" +https_port = {https_port} +mqtt_tls_port = {mqtt_tls_port} [broker] mode = "{broker_mode}" @@ -49,6 +57,9 @@ def write_release_config( password_hash = "{hash_password("correct horse battery staple", iterations=10_000)}" session_secret = "abcdefghijklmnopqrstuvwxyz123456" session_ttl_seconds = 3600 +protocol_auth_enabled = {"true" if protocol_auth_enabled else "false"} +protocol_login_email = "{protocol_login_email}" +protocol_login_pin_hash = "{hash_password(protocol_login_pin, iterations=10_000)}" """.strip(), encoding="utf-8", ) diff --git a/tests/contracts/test_ios_app_init_contract.py b/tests/contracts/test_ios_app_init_contract.py index 
f574487..c8fc9af 100644 --- a/tests/contracts/test_ios_app_init_contract.py +++ b/tests/contracts/test_ios_app_init_contract.py @@ -8,6 +8,7 @@ from conftest import write_release_config from roborock_local_server.config import load_config, resolve_paths from roborock_local_server.server import ReleaseSupervisor +from shared.protocol_auth import ProtocolAuthStore, build_hawk_authorization FIXTURE_PATH = Path(__file__).with_name("fixtures") / "ios_app_init_v4_59_02_anonymized.json" @@ -22,6 +23,28 @@ def _write_json(path: Path, payload: Any) -> None: path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") +def _cloud_snapshot_with_protocol_user_data(snapshot: dict[str, Any]) -> dict[str, Any]: + seeded = dict(snapshot) + seeded["user_data"] = { + "uid": 1001, + "token": "anon-app-token", + "rruid": "anon-rruid", + "rriot": { + "u": "user-anon", + "s": "session-anon", + "h": "contract-hawk-secret", + "k": "contract-hawk-mqtt-key", + "r": { + "r": "US", + "a": "https://api-us.roborock.com", + "m": "ssl://mqtt-us.roborock.com:8883", + "l": "https://wood-us.roborock.com", + }, + }, + } + return seeded + + def _record_runtime_presence(supervisor: ReleaseSupervisor, entries: list[dict[str, Any]]) -> None: for entry in entries: conn_id = str(entry["conn_id"]) @@ -40,6 +63,16 @@ def _record_runtime_presence(supervisor: ReleaseSupervisor, entries: list[dict[s ) +def _normalize_expected_response(request_name: str, payload: dict[str, Any]) -> dict[str, Any]: + normalized = json.loads(json.dumps(payload)) + if request_name == "get_home_data": + for key in ("data", "result"): + section = normalized.get(key) + if isinstance(section, dict): + section.pop("received_devices", None) + return normalized + + def test_ios_app_init_contract_from_anonymized_capture(tmp_path: Path, monkeypatch) -> None: fixture = _load_fixture() @@ -52,7 +85,7 @@ def test_ios_app_init_contract_from_anonymized_capture(tmp_path: Path, monkeypat seed = fixture["seed"] 
_write_json(paths.inventory_path, seed["inventory"]) - _write_json(paths.cloud_snapshot_path, seed["cloud_snapshot"]) + _write_json(paths.cloud_snapshot_path, _cloud_snapshot_with_protocol_user_data(seed["cloud_snapshot"])) _write_json(paths.runtime_credentials_path, seed["runtime_credentials"]) supervisor = ReleaseSupervisor(config=config, paths=paths) @@ -63,10 +96,22 @@ def test_ios_app_init_contract_from_anonymized_capture(tmp_path: Path, monkeypat client = TestClient(supervisor.app) default_headers = fixture["default_headers"] + auth_store = ProtocolAuthStore(paths.cloud_snapshot_path) + user = auth_store.availability().user + assert user is not None - for request in fixture["requests"]: + for index, request in enumerate(fixture["requests"]): headers = dict(default_headers) headers.update(request.get("headers", {})) + if request["path"].startswith(("/user/", "/v2/user/", "/v3/user/")): + headers["authorization"] = build_hawk_authorization( + user=user, + path=request["path"], + query_values=request.get("query"), + form_values=request.get("form") or request.get("json"), + timestamp=fixture["frozen_time"], + nonce=f"contract-{index}", + ) response = client.request( method=request["method"], url=request["path"], @@ -76,4 +121,4 @@ def test_ios_app_init_contract_from_anonymized_capture(tmp_path: Path, monkeypat data=request.get("form"), ) assert response.status_code == 200, request["name"] - assert response.json() == request["expected_response"], request["name"] + assert response.json() == _normalize_expected_response(request["name"], request["expected_response"]), request["name"] diff --git a/tests/test_admin_api.py b/tests/test_admin_api.py index 7def935..6c722b9 100644 --- a/tests/test_admin_api.py +++ b/tests/test_admin_api.py @@ -7,6 +7,7 @@ from conftest import write_release_config from roborock_local_server.config import load_config, resolve_paths from roborock_local_server.server import ReleaseSupervisor, resolve_route +from shared.protocol_auth import 
ProtocolAuthStore, build_hawk_authorization def _scene_zone_step( @@ -146,6 +147,48 @@ def _write_scene_zone_trace(mqtt_jsonl_path: Path) -> None: ) +def _seed_protocol_snapshot(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps( + { + "user_data": { + "uid": 1001, + "token": "local-token-123", + "rruid": "local-rruid-123", + "rriot": { + "u": "hawk-user-123", + "s": "hawk-session-123", + "h": "hawk-secret-123", + "k": "hawk-mqtt-key-123", + "r": { + "r": "US", + "a": "https://api-us.roborock.com", + "m": "ssl://mqtt-us.roborock.com:8883", + "l": "https://wood-us.roborock.com", + }, + }, + } + } + ) + + "\n", + encoding="utf-8", + ) + + +def _hawk_headers(snapshot_path: Path, path: str, *, form_values: dict[str, object] | None = None, json_values: dict[str, object] | None = None) -> dict[str, str]: + user = ProtocolAuthStore(snapshot_path).availability().user + assert user is not None + return { + "Authorization": build_hawk_authorization( + user=user, + path=path, + form_values=form_values or json_values, + nonce=f"nonce-{path.replace('/', '-')}", + ) + } + + def test_admin_login_and_status_flow(tmp_path: Path) -> None: config_file = write_release_config(tmp_path) config = load_config(config_file) @@ -177,7 +220,7 @@ def test_admin_login_and_status_flow(tmp_path: Path) -> None: "https://paypal.me/LLashley304", "https://us.roborock.com/discount/RRSAP202602071713342D18X?redirect=%2Fpages%2Froborock-store%3Fuuid%3DEQe6p1jdZczHEN4Q0nbsG9sZRm0RK1gW5eSM%252FCzcW4Q%253D", "https://roborock.pxf.io/B0VYV9", - "https://amzn.to/4bGfG6B", + "https://amzn.to/4cx8zg3", ] assert payload["health"]["services"] assert payload["pairing"]["active"] is False @@ -189,6 +232,7 @@ def test_admin_login_and_status_flow(tmp_path: Path) -> None: dashboard_page = client.get("/admin") assert dashboard_page.status_code == 200 assert "Cloud Import" in dashboard_page.text + assert "Protocol Auth" in dashboard_page.text assert "Num query samples" 
in dashboard_page.text assert "Public Key determined" in dashboard_page.text @@ -206,6 +250,98 @@ def test_admin_login_and_status_flow(tmp_path: Path) -> None: assert status_after_logout.status_code == 401 +def test_admin_auth_endpoints_toggle_protocol_auth_and_manage_sessions(tmp_path: Path) -> None: + config_file = write_release_config(tmp_path) + config = load_config(config_file) + paths = resolve_paths(config_file, config) + _seed_protocol_snapshot(paths.cloud_snapshot_path) + supervisor = ReleaseSupervisor(config=config, paths=paths) + issued = supervisor.protocol_auth.issue_local_session( + json.loads(paths.cloud_snapshot_path.read_text(encoding="utf-8"))["user_data"], + source="admin_test_session", + ) + + client = TestClient(supervisor.app) + + assert client.get("/admin/api/auth").status_code == 401 + + login = client.post("/admin/api/login", json={"password": "correct horse battery staple"}) + assert login.status_code == 200 + + auth_payload = client.get("/admin/api/auth") + assert auth_payload.status_code == 200 + auth_json = auth_payload.json() + assert auth_json["protocol_auth_enabled"] is True + assert auth_json["protocol_session_count"] >= 1 + session = next(item for item in auth_json["protocol_sessions"] if item["hawk_id"] == issued["rriot"]["u"]) + + toggled = client.post("/admin/api/auth", json={"protocol_auth_enabled": False}) + assert toggled.status_code == 200 + assert toggled.json()["protocol_auth_enabled"] is False + assert 'protocol_auth_enabled = false' in paths.config_file.read_text(encoding="utf-8") + + unauthed_home = client.get("/api/v1/getHomeDetail") + assert unauthed_home.status_code == 200 + + deleted = client.delete(f"/admin/api/auth/sessions/{session['hawk_id']}/{session['hawk_session']}") + assert deleted.status_code == 200 + assert deleted.json()["ok"] is True + assert deleted.json()["auth"]["protocol_session_count"] == 0 + + missing = client.delete(f"/admin/api/auth/sessions/{session['hawk_id']}/{session['hawk_session']}") + 
assert missing.status_code == 404 + + +def test_admin_auth_update_rejects_invalid_payload_types(tmp_path: Path) -> None: + config_file = write_release_config(tmp_path) + config = load_config(config_file) + paths = resolve_paths(config_file, config) + supervisor = ReleaseSupervisor(config=config, paths=paths) + + client = TestClient(supervisor.app) + login = client.post("/admin/api/login", json={"password": "correct horse battery staple"}) + assert login.status_code == 200 + + invalid_string = client.post("/admin/api/auth", json={"protocol_auth_enabled": "false"}) + assert invalid_string.status_code == 400 + assert invalid_string.json()["error"] == "protocol_auth_enabled must be a boolean" + + invalid_container = client.post("/admin/api/auth", json=["not-an-object"]) + assert invalid_container.status_code == 400 + assert invalid_container.json()["error"] == "JSON body must be an object" + + invalid_json = client.post( + "/admin/api/auth", + content="{", + headers={"Content-Type": "application/json"}, + ) + assert invalid_json.status_code == 400 + assert invalid_json.json()["error"] == "Invalid JSON body" + + +def test_set_protocol_auth_enabled_rewrites_only_exact_admin_key(tmp_path: Path) -> None: + config_file = write_release_config(tmp_path) + original = config_file.read_text(encoding="utf-8") + modified = original.replace( + "protocol_auth_enabled = true", + "# protocol_auth_enabled = true\nprotocol_auth_enabled_backup = true", + ) + config_file.write_text(modified, encoding="utf-8") + + config = load_config(config_file) + paths = resolve_paths(config_file, config) + supervisor = ReleaseSupervisor(config=config, paths=paths) + + payload = supervisor.set_protocol_auth_enabled(False) + + rendered = config_file.read_text(encoding="utf-8") + assert payload["protocol_auth_enabled"] is False + assert "# protocol_auth_enabled = true" in rendered + assert "protocol_auth_enabled_backup = true" in rendered + assert "protocol_auth_enabled = false" in rendered + assert 
rendered.count("protocol_auth_enabled = false") == 1 + + def test_admin_onboarding_endpoints_require_auth_and_manage_session(tmp_path: Path) -> None: config_file = write_release_config(tmp_path) config = load_config(config_file) @@ -321,13 +457,13 @@ def test_core_only_mode_disables_standalone_admin_routes(tmp_path: Path) -> None assert admin_page.status_code == 404 ui_health = client.get("/ui/api/health") - assert ui_health.status_code == 200 + assert ui_health.status_code == 404 region_response = client.get("/region") assert region_response.status_code == 200 -def test_ui_api_health_and_vacuums_return_runtime_payload_without_auth(tmp_path: Path) -> None: +def test_ui_api_health_and_vacuums_require_admin_auth(tmp_path: Path) -> None: config_file = write_release_config(tmp_path) config = load_config(config_file) paths = resolve_paths(config_file, config) @@ -362,6 +498,12 @@ def test_ui_api_health_and_vacuums_return_runtime_payload_without_auth(tmp_path: client = TestClient(supervisor.app) + assert client.get("/ui/api/health").status_code == 401 + assert client.get("/ui/api/vacuums").status_code == 401 + + login = client.post("/admin/api/login", json={"password": "correct horse battery staple"}) + assert login.status_code == 200 + health = client.get("/ui/api/health") assert health.status_code == 200 health_payload = health.json() @@ -576,16 +718,34 @@ def test_scene_update_routes_persist_name_and_zone_ranges(tmp_path: Path) -> Non + "\n", encoding="utf-8", ) + _seed_protocol_snapshot(paths.cloud_snapshot_path) _write_scene_zone_trace(paths.mqtt_jsonl_path) supervisor = ReleaseSupervisor(config=config, paths=paths) client = TestClient(supervisor.app) - rename_response = client.put("/user/scene/4491073/name", data={"name": "After dinner"}) + rename_response = client.put( + "/user/scene/4491073/name", + data={"name": "After dinner"}, + headers=_hawk_headers( + paths.cloud_snapshot_path, + "/user/scene/4491073/name", + form_values={"name": "After dinner"}, + ), + ) 
assert rename_response.status_code == 200 assert rename_response.json()["data"]["name"] == "After dinner" - update_response = client.put("/user/scene/4491073/param", json=_after_dinner_param_payload(device_id, include_ranges=False)) + update_payload = _after_dinner_param_payload(device_id, include_ranges=False) + update_response = client.put( + "/user/scene/4491073/param", + json=update_payload, + headers=_hawk_headers( + paths.cloud_snapshot_path, + "/user/scene/4491073/param", + json_values=update_payload, + ), + ) assert update_response.status_code == 200 stored_inventory = json.loads(paths.inventory_path.read_text(encoding="utf-8")) diff --git a/tests/test_config.py b/tests/test_config.py index c16ea74..799b898 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,4 +1,5 @@ from pathlib import Path +import pytest from roborock_local_server.config import load_config, resolve_paths @@ -24,6 +25,8 @@ def test_load_config_and_resolve_paths(tmp_path: Path) -> None: [admin] password_hash = "pbkdf2_sha256$600000$abc$def" session_secret = "abcdefghijklmnopqrstuvwxyz123456" +protocol_login_email = "user@example.com" +protocol_login_pin_hash = "pbkdf2_sha256$600000$ghi$jkl" """.strip(), encoding="utf-8", ) @@ -32,6 +35,39 @@ def test_load_config_and_resolve_paths(tmp_path: Path) -> None: paths = resolve_paths(config_file, config) assert config.network.stack_fqdn == "roborock.example.com" + assert config.network.https_port == 555 + assert config.network.mqtt_tls_port == 8881 + assert config.admin.protocol_auth_enabled is True + assert config.admin.protocol_login_email == "user@example.com" assert paths.data_dir == (tmp_path / "data").resolve() assert paths.cert_file == (tmp_path / "certs" / "fullchain.pem").resolve() assert paths.key_file == (tmp_path / "certs" / "privkey.pem").resolve() + + +def test_load_config_requires_protocol_login_credentials(tmp_path: Path) -> None: + config_file = tmp_path / "config.toml" + config_file.write_text( + """ +[network] 
+stack_fqdn = "roborock.example.com" + +[broker] +mode = "embedded" + +[storage] +data_dir = "data" + +[tls] +mode = "provided" +cert_file = "certs/fullchain.pem" +key_file = "certs/privkey.pem" + +[admin] +password_hash = "pbkdf2_sha256$600000$abc$def" +session_secret = "abcdefghijklmnopqrstuvwxyz123456" + """.strip(), + encoding="utf-8", + ) + + with pytest.raises(ValueError, match="admin.protocol_login_email is required"): + load_config(config_file) diff --git a/tests/test_configure.py b/tests/test_configure.py index 7676d09..2d0299b 100644 --- a/tests/test_configure.py +++ b/tests/test_configure.py @@ -3,16 +3,20 @@ import pytest from roborock_local_server.config import load_config -from roborock_local_server.configure import ConfigureAnswers, write_config_setup +from roborock_local_server.configure import ConfigureAnswers, _validate_protocol_login_pin, write_config_setup def _answers( *, + https_port: int = 555, + mqtt_tls_port: int = 8881, broker_mode: str = "embedded", tls_mode: str = "cloudflare_acme", ) -> ConfigureAnswers: return ConfigureAnswers( stack_fqdn="roborock.example.com", + https_port=https_port, + mqtt_tls_port=mqtt_tls_port, broker_mode=broker_mode, tls_mode=tls_mode, base_domain="example.com" if tls_mode == "cloudflare_acme" else "", @@ -20,6 +24,8 @@ def _answers( cloudflare_token="cloudflare-token" if tls_mode == "cloudflare_acme" else "", password_hash="pbkdf2_sha256$600000$abc$def", session_secret="abcdefghijklmnopqrstuvwxyz123456", + protocol_login_email="user@example.com", + protocol_login_pin_hash="pbkdf2_sha256$600000$ghi$jkl", ) @@ -35,11 +41,15 @@ def test_write_config_setup_embedded_cloudflare(tmp_path: Path) -> None: config = load_config(result.config_file) assert config.network.stack_fqdn == "roborock.example.com" + assert config.network.https_port == 555 + assert config.network.mqtt_tls_port == 8881 assert config.broker.mode == "embedded" assert config.broker.host == "127.0.0.1" assert config.broker.port == 18830 assert 
config.tls.mode == "cloudflare_acme" assert config.tls.cloudflare_token_file == "/run/secrets/cloudflare_token" + assert config.admin.protocol_auth_enabled is True + assert config.admin.protocol_login_email == "user@example.com" def test_write_config_setup_external_broker_requires_host_before_serve(tmp_path: Path) -> None: @@ -56,6 +66,8 @@ def test_write_config_setup_external_broker_requires_host_before_serve(tmp_path: assert 'mode = "external"' in rendered assert 'host = ""' in rendered assert "port = 1883" in rendered + assert "protocol_auth_enabled = true" in rendered + assert 'protocol_login_email = "user@example.com"' in rendered with pytest.raises(ValueError, match="broker.host is required"): load_config(config_file) @@ -67,3 +79,26 @@ def test_write_config_setup_refuses_overwrite_without_force(tmp_path: Path) -> N with pytest.raises(FileExistsError, match="Refusing to overwrite existing file"): write_config_setup(config_file=config_file, answers=_answers()) + + +def test_write_config_setup_persists_custom_ports(tmp_path: Path) -> None: + config_file = tmp_path / "config.toml" + + result = write_config_setup( + config_file=config_file, + answers=_answers(https_port=8443, mqtt_tls_port=9443), + ) + + config = load_config(result.config_file) + assert config.network.https_port == 8443 + assert config.network.mqtt_tls_port == 9443 + + +def test_validate_protocol_login_pin_requires_exactly_six_digits() -> None: + assert _validate_protocol_login_pin("123456") == "123456" + + with pytest.raises(ValueError, match="exactly 6 digits"): + _validate_protocol_login_pin("12345") + + with pytest.raises(ValueError, match="exactly 6 digits"): + _validate_protocol_login_pin("12345a") diff --git a/tests/test_custom_ports.py b/tests/test_custom_ports.py new file mode 100644 index 0000000..1eab546 --- /dev/null +++ b/tests/test_custom_ports.py @@ -0,0 +1,59 @@ +import json + +from fastapi.testclient import TestClient + +from conftest import write_release_config +from 
roborock_local_server.config import load_config, resolve_paths +from roborock_local_server.server import ReleaseSupervisor + + +def _write_json(path, payload: object) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8") + + +def test_server_advertises_custom_https_and_mqtt_ports(tmp_path) -> None: + config_file = write_release_config( + tmp_path, + stack_fqdn="api-roborock.example.com", + https_port=8443, + mqtt_tls_port=9443, + ) + config = load_config(config_file) + paths = resolve_paths(config_file, config) + _write_json(paths.inventory_path, {"home": {"id": 12345, "name": "Test Home"}, "devices": []}) + _write_json( + paths.cloud_snapshot_path, + { + "user_data": { + "uid": 1001, + "token": "local-token-123", + "rruid": "local-rruid-123", + "rriot": { + "u": "hawk-user-123", + "s": "hawk-session-123", + "h": "hawk-secret-123", + "k": "hawk-mqtt-key-123", + "r": { + "r": "US", + "a": "https://api-us.roborock.com", + "m": "ssl://mqtt-us.roborock.com:8883", + "l": "https://wood-us.roborock.com", + }, + }, + } + }, + ) + + supervisor = ReleaseSupervisor(config=config, paths=paths) + client = TestClient(supervisor.app) + + region_response = client.get("/region", headers={"host": "api-roborock.example.com:8443"}) + assert region_response.status_code == 200 + region_payload = region_response.json()["data"] + assert region_payload["apiUrl"] == "https://api-roborock.example.com:8443" + assert region_payload["mqttUrl"] == "ssl://api-roborock.example.com:9443" + + assert supervisor.context.api_url() == "https://api-roborock.example.com:8443" + assert supervisor.context.mqtt_url() == "ssl://api-roborock.example.com:9443" + assert supervisor.context.wood_url() == "https://api-roborock.example.com:8443" diff --git a/tests/test_home_data_online.py b/tests/test_home_data_online.py index 3b91676..b9902b5 100644 --- a/tests/test_home_data_online.py +++ b/tests/test_home_data_online.py @@ -2,10 +2,53 
@@ from pathlib import Path from fastapi.testclient import TestClient +from roborock.data import HomeData from conftest import write_release_config from roborock_local_server.config import load_config, resolve_paths from roborock_local_server.server import ReleaseSupervisor +from shared.protocol_auth import ProtocolAuthStore, build_hawk_authorization + + +def _seed_cloud_snapshot(path: Path, home_data: dict[str, object]) -> None: + path.write_text( + json.dumps( + { + "user_data": { + "uid": 1001, + "token": "local-token-123", + "rruid": "local-rruid-123", + "rriot": { + "u": "hawk-user-123", + "s": "hawk-session-123", + "h": "hawk-secret-123", + "k": "hawk-mqtt-key-123", + "r": { + "r": "US", + "a": "https://api-us.roborock.com", + "m": "ssl://mqtt-us.roborock.com:8883", + "l": "https://wood-us.roborock.com", + }, + }, + }, + "home_data": home_data, + } + ) + + "\n", + encoding="utf-8", + ) + + +def _hawk_headers(snapshot_path: Path, path: str) -> dict[str, str]: + user = ProtocolAuthStore(snapshot_path).availability().user + assert user is not None + return { + "Authorization": build_hawk_authorization( + user=user, + path=path, + nonce=f"nonce-{path.replace('/', '-')}", + ) + } def test_home_data_marks_runtime_connected_device_online_via_runtime_credentials(tmp_path: Path) -> None: @@ -34,6 +77,28 @@ def test_home_data_marks_runtime_connected_device_online_via_runtime_credentials + "\n", encoding="utf-8", ) + _seed_cloud_snapshot( + paths.cloud_snapshot_path, + { + "id": 1233716, + "name": "My Home", + "devices": [ + { + "duid": "1OVJHS7cL6XxkYkoOGr2Hw", + "name": "S7", + "productId": "1YYW18rpgyAJTISwb1NM91", + } + ], + "products": [ + { + "id": "1YYW18rpgyAJTISwb1NM91", + "name": "S7", + "model": "roborock.vacuum.a15", + "category": "robot.vacuum.cleaner", + } + ], + }, + ) paths.runtime_credentials_path.write_text( json.dumps( { @@ -70,7 +135,7 @@ def test_home_data_marks_runtime_connected_device_online_via_runtime_credentials ) client = 
TestClient(supervisor.app) - response = client.get("/v3/user/homes/1233716") + response = client.get("/v3/user/homes/1233716", headers=_hawk_headers(paths.cloud_snapshot_path, "/v3/user/homes/1233716")) assert response.status_code == 200 home_data = response.json()["data"] @@ -78,6 +143,100 @@ def test_home_data_marks_runtime_connected_device_online_via_runtime_credentials assert s7["online"] is True +def test_home_data_response_does_not_emit_empty_snake_case_received_devices(tmp_path: Path) -> None: + config_file = write_release_config(tmp_path) + config = load_config(config_file) + paths = resolve_paths(config_file, config) + + paths.runtime_dir.mkdir(parents=True, exist_ok=True) + paths.state_dir.mkdir(parents=True, exist_ok=True) + paths.inventory_path.write_text( + json.dumps( + { + "home": {"id": 1316433, "name": "My Home"}, + "received_devices": [ + { + "duid": "6HL2zfniaoYYV01CkVuhkO", + "name": "Roborock Qrevo MaxV 2", + "model": "roborock.vacuum.a87", + "product_id": "5gUei3OIJIXVD3eD85Balg", + "local_key": "xPd5Dr8CGGqtdDlH", + "online": True, + "pv": "1.0", + "share": True, + } + ], + } + ) + + "\n", + encoding="utf-8", + ) + _seed_cloud_snapshot( + paths.cloud_snapshot_path, + { + "id": 1316433, + "name": "My Home", + "receivedDevices": [ + { + "duid": "6HL2zfniaoYYV01CkVuhkO", + "name": "Roborock Qrevo MaxV 2", + "productId": "5gUei3OIJIXVD3eD85Balg", + "share": True, + } + ], + "products": [ + { + "id": "5gUei3OIJIXVD3eD85Balg", + "name": "Roborock Qrevo MaxV", + "model": "roborock.vacuum.a87", + "category": "robot.vacuum.cleaner", + } + ], + }, + ) + paths.runtime_credentials_path.write_text( + json.dumps( + { + "schema_version": 2, + "devices": [ + { + "did": "1103821560705", + "duid": "6HL2zfniaoYYV01CkVuhkO", + "name": "Roborock Qrevo MaxV 2", + "model": "roborock.vacuum.a87", + "product_id": "5gUei3OIJIXVD3eD85Balg", + "localkey": "xPd5Dr8CGGqtdDlH", + "local_key_source": "inventory_cloud", + "device_mqtt_usr": "c25b14ceac358d2a", + 
"updated_at": "2026-03-17T22:50:00+00:00", + "last_nc_at": "", + "last_mqtt_seen_at": "", + } + ], + } + ) + + "\n", + encoding="utf-8", + ) + + supervisor = ReleaseSupervisor(config=config, paths=paths) + supervisor.refresh_inventory_state() + + client = TestClient(supervisor.app) + response = client.get("/v3/user/homes/1316433", headers=_hawk_headers(paths.cloud_snapshot_path, "/v3/user/homes/1316433")) + assert response.status_code == 200 + + payload = response.json() + assert "received_devices" not in payload["data"] + assert "received_devices" not in payload["result"] + + parsed_home = HomeData.from_dict(payload["result"]) + assert parsed_home is not None + assert len(parsed_home.received_devices) == 1 + assert len(parsed_home.device_products) == 1 + assert "6HL2zfniaoYYV01CkVuhkO" in parsed_home.device_products + + def test_device_detail_uses_runtime_connection_and_preserves_inventory_fields(tmp_path: Path) -> None: config_file = write_release_config(tmp_path) config = load_config(config_file) @@ -106,33 +265,28 @@ def test_device_detail_uses_runtime_connection_and_preserves_inventory_fields(tm + "\n", encoding="utf-8", ) - paths.cloud_snapshot_path.write_text( - json.dumps( - { - "home_data": { - "id": 1233716, - "name": "My Home", - "devices": [ - { - "duid": "6HL2zfniaoYYV01CkVuhkO", - "name": "Roborock Qrevo MaxV 2", - "productId": "5gUei3OIJIXVD3eD85Balg", - "extra": "{\"RRMonitorPrivacyVersion\": \"1\"}", - } - ], - "products": [ - { - "id": "5gUei3OIJIXVD3eD85Balg", - "name": "Roborock Qrevo MaxV", - "model": "roborock.vacuum.a87", - "category": "RoborockCategory.VACUUM", - } - ], + _seed_cloud_snapshot( + paths.cloud_snapshot_path, + { + "id": 1233716, + "name": "My Home", + "devices": [ + { + "duid": "6HL2zfniaoYYV01CkVuhkO", + "name": "Roborock Qrevo MaxV 2", + "productId": "5gUei3OIJIXVD3eD85Balg", + "extra": "{\"RRMonitorPrivacyVersion\": \"1\"}", } - } - ) - + "\n", - encoding="utf-8", + ], + "products": [ + { + "id": "5gUei3OIJIXVD3eD85Balg", + 
"name": "Roborock Qrevo MaxV", + "model": "roborock.vacuum.a87", + "category": "RoborockCategory.VACUUM", + } + ], + }, ) paths.runtime_credentials_path.write_text( json.dumps( @@ -170,7 +324,10 @@ def test_device_detail_uses_runtime_connection_and_preserves_inventory_fields(tm ) client = TestClient(supervisor.app) - response = client.get("/user/devices/6HL2zfniaoYYV01CkVuhkO") + response = client.get( + "/user/devices/6HL2zfniaoYYV01CkVuhkO", + headers=_hawk_headers(paths.cloud_snapshot_path, "/user/devices/6HL2zfniaoYYV01CkVuhkO"), + ) assert response.status_code == 200 device_data = response.json()["data"] @@ -205,38 +362,33 @@ def test_home_data_preserves_last_working_app_contract(tmp_path: Path) -> None: + "\n", encoding="utf-8", ) - paths.cloud_snapshot_path.write_text( - json.dumps( - { - "home_data": { - "id": 1233716, - "name": "My Home", - "devices": [ - { - "duid": "6HL2zfniaoYYV01CkVuhkO", - "name": "Roborock Qrevo MaxV 2", - "productId": "5gUei3OIJIXVD3eD85Balg", - "extra": "{\"RRMonitorPrivacyVersion\": \"1\"}", - "featureSet": "2233384992473071", - "newFeatureSet": "7", - "f": False, - "share": False, - "createTime": 1712144203, - } - ], - "products": [ - { - "id": "5gUei3OIJIXVD3eD85Balg", - "name": "Roborock Qrevo MaxV", - "model": "roborock.vacuum.a87", - "category": "robot.vacuum.cleaner", - } - ], + _seed_cloud_snapshot( + paths.cloud_snapshot_path, + { + "id": 1233716, + "name": "My Home", + "devices": [ + { + "duid": "6HL2zfniaoYYV01CkVuhkO", + "name": "Roborock Qrevo MaxV 2", + "productId": "5gUei3OIJIXVD3eD85Balg", + "extra": "{\"RRMonitorPrivacyVersion\": \"1\"}", + "featureSet": "2233384992473071", + "newFeatureSet": "7", + "f": False, + "share": False, + "createTime": 1712144203, } - } - ) - + "\n", - encoding="utf-8", + ], + "products": [ + { + "id": "5gUei3OIJIXVD3eD85Balg", + "name": "Roborock Qrevo MaxV", + "model": "roborock.vacuum.a87", + "category": "robot.vacuum.cleaner", + } + ], + }, ) paths.runtime_credentials_path.write_text( 
json.dumps( @@ -267,7 +419,7 @@ def test_home_data_preserves_last_working_app_contract(tmp_path: Path) -> None: supervisor.refresh_inventory_state() client = TestClient(supervisor.app) - response = client.get("/v3/user/homes/1233716") + response = client.get("/v3/user/homes/1233716", headers=_hawk_headers(paths.cloud_snapshot_path, "/v3/user/homes/1233716")) assert response.status_code == 200 home_data = response.json()["data"] diff --git a/tests/test_mitm_redirect.py b/tests/test_mitm_redirect.py new file mode 100644 index 0000000..eb4740d --- /dev/null +++ b/tests/test_mitm_redirect.py @@ -0,0 +1,227 @@ +import importlib +import io +import json +import ssl +import sys +import types +from urllib.error import HTTPError + +import pytest + + +class _FakeResponse: + def __init__(self, status: int = 200) -> None: + self.status = status + + def __enter__(self) -> "_FakeResponse": + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def read(self) -> bytes: + return b"" + + +def _load_mitm_redirect(monkeypatch): + fake_log = types.SimpleNamespace( + info=lambda *args, **kwargs: None, + warn=lambda *args, **kwargs: None, + error=lambda *args, **kwargs: None, + ) + fake_mitmproxy = types.SimpleNamespace( + ctx=types.SimpleNamespace(log=fake_log), + http=types.SimpleNamespace(HTTPFlow=object), + ) + monkeypatch.setitem(sys.modules, "mitmproxy", fake_mitmproxy) + sys.modules.pop("mitm_redirect", None) + return importlib.import_module("mitm_redirect") + + +def _sample_user_data() -> dict[str, object]: + return { + "token": "real-cloud-token-999", + "rruid": "real-cloud-rruid-999", + "rriot": { + "u": "real-cloud-hawk-user", + "s": "real-cloud-hawk-session", + "h": "real-cloud-hawk-secret", + "k": "real-cloud-mqtt-key", + }, + } + + +class _FakeRequest: + def __init__(self, host: str, path: str) -> None: + self.pretty_host = host + self.path = path + self.pretty_url = f"https://{host}{path}" + self.method = "GET" + self.headers: dict[str, str] = 
{} + self.content = b"" + self.scheme = "https" + self.host = host + self.port = 443 + + +class _FakeFlow: + def __init__(self, host: str, path: str) -> None: + self.request = _FakeRequest(host, path) + self.response = None + + +class _FakeResponseFlow: + def __init__(self, host: str, path: str, content: bytes) -> None: + self.request = _FakeRequest(host, path) + self.response = types.SimpleNamespace( + headers={"content-type": "application/json"}, + content=content, + status_code=200, + reason="OK", + ) + + +def test_sync_protocol_user_data_verifies_tls_by_default(monkeypatch) -> None: + mitm_redirect = _load_mitm_redirect(monkeypatch) + captured: dict[str, object] = {} + + def fake_urlopen(request, timeout, context): + captured["request"] = request + captured["timeout"] = timeout + captured["context"] = context + return _FakeResponse() + + monkeypatch.setattr(mitm_redirect, "urlopen", fake_urlopen) + mitm_redirect.LOCAL_SYNC_SECRET = "abcdefghijklmnopqrstuvwxyz123456" + mitm_redirect.LOCAL_API = "api-roborock.example.com" + + mitm_redirect._sync_protocol_user_data(_sample_user_data()) + + context = captured["context"] + assert isinstance(context, ssl.SSLContext) + assert context.verify_mode == ssl.CERT_REQUIRED + assert context.check_hostname is True + + +def test_preflight_sync_endpoint_accepts_expected_missing_user_data(monkeypatch) -> None: + mitm_redirect = _load_mitm_redirect(monkeypatch) + + def fake_urlopen(request, timeout, context): + raise HTTPError( + request.full_url, + 400, + "Bad Request", + hdrs=None, + fp=io.BytesIO( + b'{"code":40041,"msg":"protocol_sync_failed","data":{"reason":"missing_user_data"}}' + ), + ) + + monkeypatch.setattr(mitm_redirect, "urlopen", fake_urlopen) + + mitm_redirect._preflight_sync_endpoint("https://api-roborock.example.com", "abcdefghijklmnopqrstuvwxyz123456") + + +def test_preflight_sync_endpoint_rejects_invalid_secret(monkeypatch) -> None: + mitm_redirect = _load_mitm_redirect(monkeypatch) + + def fake_urlopen(request, 
timeout, context): + raise HTTPError( + request.full_url, + 401, + "Unauthorized", + hdrs=None, + fp=io.BytesIO( + b'{"code":40041,"msg":"protocol_sync_failed","data":{"reason":"invalid_sync_secret"}}' + ), + ) + + monkeypatch.setattr(mitm_redirect, "urlopen", fake_urlopen) + + with pytest.raises(mitm_redirect.SyncEndpointError, match="preflight failed: HTTP 401 - protocol_sync_failed - invalid_sync_secret"): + mitm_redirect._preflight_sync_endpoint("https://api-roborock.example.com", "bad-secret") + + +def test_login_response_is_blocked_when_sync_fails(monkeypatch) -> None: + mitm_redirect = _load_mitm_redirect(monkeypatch) + monkeypatch.setattr(mitm_redirect, "_log_flow", lambda *args, **kwargs: None) + + def fake_sync(_user_data): + raise mitm_redirect.SyncEndpointError( + "https://127.0.0.1:555/internal/protocol/user-data", + "request failed: certificate verify failed", + ) + + monkeypatch.setattr(mitm_redirect, "_sync_protocol_user_data", fake_sync) + flow = _FakeResponseFlow( + "usiot.roborock.com", + "/api/v5/auth/email/login/code", + json.dumps({"data": _sample_user_data()}).encode("utf-8"), + ) + + mitm_redirect.response(flow) + + assert flow.response.status_code == 502 + body = json.loads(flow.response.content) + assert body["msg"] == "local_sync_failed" + assert body["data"]["reason"] == "sync_unreachable" + assert "127.0.0.1:555" in body["data"]["syncUrl"] + + +def test_rewrite_value_supports_custom_ports(monkeypatch) -> None: + mitm_redirect = _load_mitm_redirect(monkeypatch) + mitm_redirect.LOCAL_API_HOST = "api-roborock.example.com" + mitm_redirect.LOCAL_API_PORT = 8443 + mitm_redirect.LOCAL_MQTT_HOST = "mqtt-roborock.example.com" + mitm_redirect.LOCAL_MQTT_PORT = 9443 + mitm_redirect.LOCAL_WOOD_HOST = "wood-roborock.example.com" + mitm_redirect.LOCAL_WOOD_PORT = 8443 + + rewritten = mitm_redirect._rewrite_value( + "https://api-us.roborock.com ssl://mqtt-us.roborock.com:8883 https://wood-us.roborock.com" + ) + + assert 
"https://api-roborock.example.com:8443" in rewritten + assert "ssl://mqtt-roborock.example.com:9443" in rewritten + assert "https://wood-roborock.example.com:8443" in rewritten + + +def test_rewrite_value_preserves_default_mqtt_port_when_only_host_changes(monkeypatch) -> None: + mitm_redirect = _load_mitm_redirect(monkeypatch) + mitm_redirect.LOCAL_MQTT_HOST = "api-roborock.example.com" + mitm_redirect.LOCAL_MQTT_PORT = None + + rewritten = mitm_redirect._rewrite_value("ssl://mqtt-us.roborock.com:8883") + + assert rewritten == "ssl://api-roborock.example.com:8883" + + +def test_request_routes_to_custom_api_port(monkeypatch) -> None: + mitm_redirect = _load_mitm_redirect(monkeypatch) + mitm_redirect.LOCAL_API = "api-roborock.example.com:8443" + mitm_redirect.LOCAL_API_HOST = "api-roborock.example.com" + mitm_redirect.LOCAL_API_PORT = 8443 + flow = _FakeFlow("api-us.roborock.com", "/api/v1/getHomeDetail") + + mitm_redirect.request(flow) + + assert flow.request.host == "api-roborock.example.com" + assert flow.request.port == 8443 + assert flow.request.headers["Host"] == "api-roborock.example.com:8443" + + +def test_parse_endpoint_defaults_to_new_stack_ports(monkeypatch) -> None: + mitm_redirect = _load_mitm_redirect(monkeypatch) + + api_host, api_port = mitm_redirect._parse_endpoint( + "api-roborock.example.com", + fallback_port=mitm_redirect.DEFAULT_LOCAL_API_PORT, + ) + mqtt_host, mqtt_port = mitm_redirect._parse_endpoint( + "", + fallback_host=api_host, + fallback_port=mitm_redirect.DEFAULT_LOCAL_MQTT_PORT, + ) + + assert (api_host, api_port) == ("api-roborock.example.com", 555) + assert (mqtt_host, mqtt_port) == ("api-roborock.example.com", 8881) diff --git a/tests/test_mqtt_tls_proxy.py b/tests/test_mqtt_tls_proxy.py index b26a5ad..0e54cda 100644 --- a/tests/test_mqtt_tls_proxy.py +++ b/tests/test_mqtt_tls_proxy.py @@ -1,6 +1,11 @@ +import json import logging +import socket import threading import time +from pathlib import Path + +import pytest from 
roborock_local_server.backend import MqttTlsProxy @@ -8,13 +13,19 @@ class _FakeSourceSocket: def __init__(self, *chunks: bytes) -> None: self._chunks = list(chunks) + self.sent: list[bytes] = [] self.closed = False + self.recv_calls = 0 def recv(self, _size: int) -> bytes: + self.recv_calls += 1 if self._chunks: return self._chunks.pop(0) return b"" + def sendall(self, chunk: bytes) -> None: + self.sent.append(chunk) + def close(self) -> None: self.closed = True @@ -33,7 +44,97 @@ def close(self) -> None: self.closed = True +class _FakeBackendSocket(_FakeDestinationSocket): + def __init__(self) -> None: + super().__init__() + self.connected_to: tuple[str, int] | None = None + + def connect(self, addr: tuple[str, int]) -> None: + self.connected_to = addr + + def recv(self, _size: int) -> bytes: + return b"" + + +def _write_json(path: Path, payload: object) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8") + + +def _seed_cloud_snapshot(path: Path) -> None: + _write_json( + path, + { + "user_data": { + "token": "local-token-123", + "rruid": "local-rruid-123", + "rriot": { + "u": "hawk-user-123", + "s": "hawk-session-123", + "h": "hawk-secret-123", + "k": "hawk-mqtt-key-123", + "r": { + "r": "US", + "a": "https://api-us.roborock.com", + "m": "ssl://mqtt-us.roborock.com:8883", + "l": "https://wood-us.roborock.com", + }, + }, + } + }, + ) + + +def _seed_protocol_sessions(path: Path) -> None: + _write_json( + path, + { + "version": 1, + "sessions": [ + { + "source": "test_sync", + "updated_at_utc": "2026-04-17T17:00:00+00:00", + "user_data": { + "token": "real-cloud-token-999", + "rruid": "real-cloud-rruid-999", + "rriot": { + "u": "real-cloud-hawk-user", + "s": "real-cloud-hawk-session", + "h": "real-cloud-hawk-secret", + "k": "real-cloud-mqtt-key", + }, + }, + } + ], + }, + ) + + +def _build_connect_packet(*, client_id: str, username: str, password: str, protocol_level: int = 4) -> bytes: 
+ protocol_name = b"MQTT" + variable_header = ( + len(protocol_name).to_bytes(2, "big") + + protocol_name + + bytes([protocol_level, 0xC2]) # clean session + username + password + + (60).to_bytes(2, "big") + ) + if protocol_level == 5: + variable_header += b"\x00" + payload = ( + len(client_id.encode()).to_bytes(2, "big") + + client_id.encode() + + len(username.encode()).to_bytes(2, "big") + + username.encode() + + len(password.encode()).to_bytes(2, "big") + + password.encode() + ) + remaining = variable_header + payload + return bytes([0x10, len(remaining)]) + remaining + + def test_relay_forwards_chunk_before_slow_packet_tracing_finishes(tmp_path, monkeypatch) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) proxy = MqttTlsProxy( cert_file=tmp_path / "fullchain.pem", key_file=tmp_path / "privkey.pem", @@ -44,6 +145,7 @@ def test_relay_forwards_chunk_before_slow_packet_tracing_finishes(tmp_path, monk localkey="test-local-key", logger=logging.getLogger("test.mqtt_tls_proxy"), decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, ) trace_started = threading.Event() trace_finished = threading.Event() @@ -83,3 +185,434 @@ def slow_trace_packet(conn_id: str, direction: str, packet: bytes) -> None: assert dst.closed is True proxy.stop() + + +def test_authorize_connect_accepts_native_user_hash_credentials(tmp_path) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) + runtime_credentials_path = tmp_path / "runtime_credentials.json" + _write_json( + runtime_credentials_path, + { + "schema_version": 2, + "mqtt_usr": "bootstrap-user", + "mqtt_passwd": "bootstrap-pass", + "mqtt_clientid": "bootstrap-client", + "devices": [], + }, + ) + from shared.runtime_credentials import RuntimeCredentialsStore + + runtime_credentials = RuntimeCredentialsStore(runtime_credentials_path) + proxy = MqttTlsProxy( + cert_file=tmp_path / 
"fullchain.pem", + key_file=tmp_path / "privkey.pem", + listen_host="127.0.0.1", + listen_port=8883, + backend_host="127.0.0.1", + backend_port=1883, + localkey="test-local-key", + logger=logging.getLogger("test.mqtt_tls_proxy"), + decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, + runtime_credentials=runtime_credentials, + ) + + packet = _build_connect_packet( + client_id="ha-client", + username="52359d04", + password="cb5af78c8d901feb", + ) + authorized, reason, info = proxy._authorize_connect_packet(packet) + + assert authorized is True + assert reason == "user_hash" + assert info is not None + assert info["client_id"] == "ha-client" + + +def test_authorize_connect_accepts_bootstrap_credentials_and_rejects_wrong_password(tmp_path) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) + runtime_credentials_path = tmp_path / "runtime_credentials.json" + _write_json( + runtime_credentials_path, + { + "schema_version": 2, + "mqtt_usr": "bootstrap-user", + "mqtt_passwd": "bootstrap-pass", + "mqtt_clientid": "bootstrap-client", + "devices": [], + }, + ) + from shared.runtime_credentials import RuntimeCredentialsStore + + runtime_credentials = RuntimeCredentialsStore(runtime_credentials_path) + proxy = MqttTlsProxy( + cert_file=tmp_path / "fullchain.pem", + key_file=tmp_path / "privkey.pem", + listen_host="127.0.0.1", + listen_port=8883, + backend_host="127.0.0.1", + backend_port=1883, + localkey="test-local-key", + logger=logging.getLogger("test.mqtt_tls_proxy"), + decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, + runtime_credentials=runtime_credentials, + ) + + bootstrap_packet = _build_connect_packet( + client_id="bootstrap-client", + username="bootstrap-user", + password="bootstrap-pass", + ) + authorized, reason, _info = proxy._authorize_connect_packet(bootstrap_packet) + assert authorized is True + assert reason == "bootstrap" + + 
wrong_password_packet = _build_connect_packet( + client_id="bootstrap-client", + username="bootstrap-user", + password="wrong-pass", + ) + rejected, reject_reason, _info = proxy._authorize_connect_packet(wrong_password_packet) + assert rejected is False + assert reject_reason == "invalid_mqtt_credentials" + + +def test_authorize_connect_accepts_known_device_mqtt_user(tmp_path) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) + runtime_credentials_path = tmp_path / "runtime_credentials.json" + _write_json( + runtime_credentials_path, + { + "schema_version": 2, + "mqtt_usr": "bootstrap-user", + "mqtt_passwd": "bootstrap-pass", + "mqtt_clientid": "bootstrap-client", + "devices": [ + { + "did": "1103821560705", + "duid": "6HL2zfniaoYYV01CkVuhkO", + "name": "Roborock Qrevo MaxV 2", + "model": "roborock.vacuum.a87", + "product_id": "5gUei3OIJIXVD3eD85Balg", + "localkey": "xPd5Dr8CGGqtdDlH", + "local_key_source": "inventory", + "device_mqtt_usr": "c25b14ceac358d2a", + "device_mqtt_pass": "ff8922d24a9a9af81f18f35dcee9a5a5", + "updated_at": "2026-04-17T17:00:00+00:00", + "last_nc_at": "", + "last_mqtt_seen_at": "", + } + ], + }, + ) + from shared.runtime_credentials import RuntimeCredentialsStore + + runtime_credentials = RuntimeCredentialsStore(runtime_credentials_path) + proxy = MqttTlsProxy( + cert_file=tmp_path / "fullchain.pem", + key_file=tmp_path / "privkey.pem", + listen_host="127.0.0.1", + listen_port=8883, + backend_host="127.0.0.1", + backend_port=1883, + localkey="test-local-key", + logger=logging.getLogger("test.mqtt_tls_proxy"), + decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, + runtime_credentials=runtime_credentials, + ) + + packet = _build_connect_packet( + client_id="a012391cb5f8bc97", + username="c25b14ceac358d2a", + password="ff8922d24a9a9af81f18f35dcee9a5a5", + ) + authorized, reason, info = proxy._authorize_connect_packet(packet) + + assert authorized is 
True + assert reason == "device_mqtt_user" + assert info is not None + assert info["client_id"] == "a012391cb5f8bc97" + + +def test_authorize_connect_recovers_missing_known_device_mqtt_password(tmp_path) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) + runtime_credentials_path = tmp_path / "runtime_credentials.json" + _write_json( + runtime_credentials_path, + { + "schema_version": 2, + "mqtt_usr": "bootstrap-user", + "mqtt_passwd": "bootstrap-pass", + "mqtt_clientid": "bootstrap-client", + "devices": [ + { + "did": "1103821560705", + "duid": "6HL2zfniaoYYV01CkVuhkO", + "name": "Roborock Qrevo MaxV 2", + "model": "roborock.vacuum.a87", + "product_id": "5gUei3OIJIXVD3eD85Balg", + "localkey": "xPd5Dr8CGGqtdDlH", + "local_key_source": "inventory", + "device_mqtt_usr": "c25b14ceac358d2a", + "device_mqtt_pass": "", + "updated_at": "2026-04-17T17:00:00+00:00", + "last_nc_at": "", + "last_mqtt_seen_at": "", + } + ], + }, + ) + from shared.runtime_credentials import RuntimeCredentialsStore + + runtime_credentials = RuntimeCredentialsStore(runtime_credentials_path) + proxy = MqttTlsProxy( + cert_file=tmp_path / "fullchain.pem", + key_file=tmp_path / "privkey.pem", + listen_host="127.0.0.1", + listen_port=8883, + backend_host="127.0.0.1", + backend_port=1883, + localkey="test-local-key", + logger=logging.getLogger("test.mqtt_tls_proxy"), + decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, + runtime_credentials=runtime_credentials, + ) + + packet = _build_connect_packet( + client_id="a012391cb5f8bc97", + username="c25b14ceac358d2a", + password="ff8922d24a9a9af81f18f35dcee9a5a5", + ) + authorized, reason, info = proxy._authorize_connect_packet(packet) + + assert authorized is True + assert reason == "device_mqtt_recovered" + assert info is not None + + recovered_device = runtime_credentials.resolve_device(duid="6HL2zfniaoYYV01CkVuhkO") + assert recovered_device is not None + 
assert recovered_device["device_mqtt_pass"] == "ff8922d24a9a9af81f18f35dcee9a5a5" + + rejected, reject_reason, _info = proxy._authorize_connect_packet( + _build_connect_packet( + client_id="a012391cb5f8bc97", + username="c25b14ceac358d2a", + password="wrong-pass", + ) + ) + assert rejected is False + assert reject_reason == "invalid_mqtt_credentials" + + +def test_authorize_connect_accepts_persisted_synced_user_hash_credentials(tmp_path) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) + protocol_sessions_path = tmp_path / "protocol_sessions.json" + _seed_protocol_sessions(protocol_sessions_path) + proxy = MqttTlsProxy( + cert_file=tmp_path / "fullchain.pem", + key_file=tmp_path / "privkey.pem", + listen_host="127.0.0.1", + listen_port=8883, + backend_host="127.0.0.1", + backend_port=1883, + localkey="test-local-key", + logger=logging.getLogger("test.mqtt_tls_proxy"), + decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, + protocol_auth_sessions_path=protocol_sessions_path, + ) + + packet = _build_connect_packet( + client_id="ios-app-client", + username="7ad5ebc1", + password="558d41e0cece0ee7", + ) + authorized, reason, info = proxy._authorize_connect_packet(packet) + + assert authorized is True + assert reason == "user_hash" + assert info is not None + assert info["client_id"] == "ios-app-client" + + +def test_authorize_connect_rejects_protocol_user_hash_when_protocol_auth_disabled(tmp_path) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) + protocol_sessions_path = tmp_path / "protocol_sessions.json" + _seed_protocol_sessions(protocol_sessions_path) + proxy = MqttTlsProxy( + cert_file=tmp_path / "fullchain.pem", + key_file=tmp_path / "privkey.pem", + listen_host="127.0.0.1", + listen_port=8883, + backend_host="127.0.0.1", + backend_port=1883, + localkey="test-local-key", + 
logger=logging.getLogger("test.mqtt_tls_proxy"), + decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, + protocol_auth_sessions_path=protocol_sessions_path, + protocol_auth_enabled=lambda: False, + ) + + packet = _build_connect_packet( + client_id="ios-app-client", + username="7ad5ebc1", + password="558d41e0cece0ee7", + ) + authorized, reason, info = proxy._authorize_connect_packet(packet) + + assert authorized is False + assert reason == "invalid_mqtt_credentials" + assert info is not None + assert info["client_id"] == "ios-app-client" + + +def test_read_first_packet_rejects_invalid_remaining_length() -> None: + src = _FakeSourceSocket(b"\x10\xff\xff\xff\xff") + + with pytest.raises(ValueError, match="remaining length"): + MqttTlsProxy._read_first_packet(src) + + assert src.recv_calls == 1 + + +def test_handle_client_traces_packets_already_buffered_before_relay(tmp_path, monkeypatch) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) + proxy = MqttTlsProxy( + cert_file=tmp_path / "fullchain.pem", + key_file=tmp_path / "privkey.pem", + listen_host="127.0.0.1", + listen_port=8883, + backend_host="127.0.0.1", + backend_port=1883, + localkey="test-local-key", + logger=logging.getLogger("test.mqtt_tls_proxy"), + decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, + ) + backend = _FakeBackendSocket() + traced: list[tuple[str, str, bytes]] = [] + connect_packet = _build_connect_packet( + client_id="ha-client", + username="52359d04", + password="cb5af78c8d901feb", + ) + tls_conn = _FakeSourceSocket(connect_packet + b"\xc0\x00") + + monkeypatch.setattr(socket, "socket", lambda *args, **kwargs: backend) + monkeypatch.setattr(proxy, "_queue_trace_packet", lambda conn_id, direction, packet: traced.append((conn_id, direction, packet))) + + proxy._running = True + proxy._handle_client(tls_conn, ("127.0.0.1", 4321)) + + assert backend.connected_to == 
("127.0.0.1", 1883) + assert backend.sent == [connect_packet + b"\xc0\x00"] + assert traced == [ + ("1", "c2b", connect_packet), + ("1", "c2b", b"\xc0\x00"), + ] + + +def test_handle_client_closes_tls_conn_when_client_closes_before_connect(tmp_path) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) + proxy = MqttTlsProxy( + cert_file=tmp_path / "fullchain.pem", + key_file=tmp_path / "privkey.pem", + listen_host="127.0.0.1", + listen_port=8883, + backend_host="127.0.0.1", + backend_port=1883, + localkey="test-local-key", + logger=logging.getLogger("test.mqtt_tls_proxy"), + decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, + ) + tls_conn = _FakeSourceSocket() + + proxy._running = True + proxy._handle_client(tls_conn, ("127.0.0.1", 4321)) + + assert tls_conn.closed is True + + +def test_handle_client_closes_tls_conn_when_connect_is_rejected(tmp_path, monkeypatch) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) + proxy = MqttTlsProxy( + cert_file=tmp_path / "fullchain.pem", + key_file=tmp_path / "privkey.pem", + listen_host="127.0.0.1", + listen_port=8883, + backend_host="127.0.0.1", + backend_port=1883, + localkey="test-local-key", + logger=logging.getLogger("test.mqtt_tls_proxy"), + decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, + ) + tls_conn = _FakeSourceSocket( + _build_connect_packet( + client_id="bad-client", + username="unknown-user", + password="unknown-pass", + ) + ) + + def _unexpected_backend(*args, **kwargs): + raise AssertionError("backend socket should not be created for rejected MQTT CONNECT") + + monkeypatch.setattr(socket, "socket", _unexpected_backend) + + proxy._running = True + proxy._handle_client(tls_conn, ("127.0.0.1", 4321)) + + assert tls_conn.sent == [b"\x20\x02\x00\x05"] + assert tls_conn.closed is True + + +def 
test_handle_client_returns_mqtt5_not_authorized_connack_on_rejected_connect(tmp_path, monkeypatch) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) + proxy = MqttTlsProxy( + cert_file=tmp_path / "fullchain.pem", + key_file=tmp_path / "privkey.pem", + listen_host="127.0.0.1", + listen_port=8883, + backend_host="127.0.0.1", + backend_port=1883, + localkey="test-local-key", + logger=logging.getLogger("test.mqtt_tls_proxy"), + decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, + ) + tls_conn = _FakeSourceSocket( + _build_connect_packet( + client_id="bad-client", + username="unknown-user", + password="unknown-pass", + protocol_level=5, + ) + ) + + def _unexpected_backend(*args, **kwargs): + raise AssertionError("backend socket should not be created for rejected MQTT CONNECT") + + monkeypatch.setattr(socket, "socket", _unexpected_backend) + + proxy._running = True + proxy._handle_client(tls_conn, ("127.0.0.1", 4321)) + + assert tls_conn.sent == [b"\x20\x03\x00\x87\x00"] + assert tls_conn.closed is True diff --git a/tests/test_onboarding_cli.py b/tests/test_onboarding_cli.py index 02109ec..c8914c0 100644 --- a/tests/test_onboarding_cli.py +++ b/tests/test_onboarding_cli.py @@ -4,7 +4,15 @@ import pytest -from start_onboarding import GuidedOnboardingConfig, run_guided_onboarding +from start_onboarding import ( + ApiReachabilityError, + GuidedOnboardingConfig, + RemoteOnboardingApi, + normalize_api_base_url, + poll_session_until_progress, + run_guided_onboarding, + sanitize_stack_server, +) class FakeApi: @@ -48,6 +56,27 @@ def delete_session(self, *, session_id: str) -> dict: return {"ok": True} +class _RecordingResponse: + def __enter__(self) -> "_RecordingResponse": + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def read(self) -> bytes: + return b"{}" + + +class _RecordingOpener: + def __init__(self) -> None: + self.urls: list[str] = [] + + 
def open(self, request, timeout=None, context=None): + _ = timeout, context + self.urls.append(request.full_url) + return _RecordingResponse() + + @pytest.fixture def config() -> GuidedOnboardingConfig: return GuidedOnboardingConfig( @@ -323,3 +352,81 @@ def test_guided_onboarding_duplicate_names_still_selects_requested_device( assert api.started_duids == ["cloud-q7-b"] assert "cloud-q7-a" in output.getvalue() assert "cloud-q7-b" in output.getvalue() + + +def test_poll_session_retries_while_machine_reconnects_to_normal_wifi() -> None: + class FlakyApi: + def __init__(self) -> None: + self.calls = 0 + + def get_session(self, *, session_id: str) -> dict: + assert session_id == "sess-1" + self.calls += 1 + if self.calls <= 2: + raise ApiReachabilityError( + "Unable to reach https://api-roborock.example.com: [Errno -2] Name or service not known" + ) + return { + "session_id": session_id, + "query_samples": 2, + "has_public_key": True, + "public_key_state": "ready", + "connected": True, + "guidance": "Device paired and connected.", + "target": {"name": "Q7 Upstairs", "duid": "cloud-q7-a", "did": "1103821560705"}, + } + + api = FlakyApi() + output = StringIO() + sleeps: list[float] = [] + + result, status = poll_session_until_progress( + api, + session_id="sess-1", + baseline_samples=0, + baseline_status={ + "session_id": "sess-1", + "query_samples": 0, + "has_public_key": False, + "public_key_state": "missing", + "connected": False, + "target": {"name": "Q7 Upstairs", "duid": "cloud-q7-a", "did": ""}, + }, + output=output, + poll_interval_seconds=5.0, + timeout_seconds=20.0, + sleep_fn=sleeps.append, + ) + + assert result == "connected" + assert status["connected"] is True + assert sleeps == [5.0, 5.0] + assert "The main server is not reachable yet from this machine." in output.getvalue() + assert "Still waiting for this machine to reach the main server again..." 
in output.getvalue() + + +def test_onboarding_server_normalization_preserves_custom_ports() -> None: + assert normalize_api_base_url("api-roborock.example.com:8443") == "https://api-roborock.example.com:8443" + assert sanitize_stack_server("https://api-roborock.example.com:8443") == "roborock.example.com:8443/" + + +def test_onboarding_server_normalization_defaults_to_port_555() -> None: + assert normalize_api_base_url("api-roborock.example.com") == "https://api-roborock.example.com:555" + assert sanitize_stack_server("https://api-roborock.example.com") == "roborock.example.com:555/" + + +def test_remote_onboarding_api_uses_custom_port_base_url() -> None: + opener = _RecordingOpener() + api = RemoteOnboardingApi( + base_url="https://api-roborock.example.com:8443", + admin_password="secret", + opener=opener, + ) + + api.login() + api.list_devices() + + assert opener.urls == [ + "https://api-roborock.example.com:8443/admin/api/login", + "https://api-roborock.example.com:8443/admin/api/onboarding/devices", + ] diff --git a/tests/test_protocol_auth.py b/tests/test_protocol_auth.py new file mode 100644 index 0000000..9ad762c --- /dev/null +++ b/tests/test_protocol_auth.py @@ -0,0 +1,401 @@ +import json +import time +from pathlib import Path + +from fastapi.testclient import TestClient + +from conftest import write_release_config +from roborock_local_server.config import load_config, resolve_paths +from roborock_local_server.server import PROTOCOL_AUTH_SYNC_PATH, ReleaseSupervisor +from https_server.routes.auth.service import load_cloud_user_data +from shared.protocol_auth import ProtocolAuthStore, build_hawk_authorization + + +def _write_json(path: Path, payload: object) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8") + + +def _seed_cloud_snapshot(path: Path) -> None: + _write_json( + path, + { + "meta": {"username": "user@example.com"}, + "user_data": { + "uid": 1001, + "token": 
"local-token-123", + "rruid": "local-rruid-123", + "rriot": { + "u": "hawk-user-123", + "s": "hawk-session-123", + "h": "hawk-secret-123", + "k": "hawk-mqtt-key-123", + "r": { + "r": "US", + "a": "https://api-us.roborock.com", + "m": "ssl://mqtt-us.roborock.com:8883", + "l": "https://wood-us.roborock.com", + }, + }, + }, + "home_data": {"id": 12345, "name": "Test Home", "devices": []}, + }, + ) + + +def _protocol_user_data( + *, + token: str, + rruid: str, + hawk_id: str, + hawk_session: str, + hawk_key: str, + mqtt_key: str, +) -> dict[str, object]: + return { + "uid": 1001, + "token": token, + "rruid": rruid, + "rriot": { + "u": hawk_id, + "s": hawk_session, + "h": hawk_key, + "k": mqtt_key, + }, + } + + +def _token_headers(login_payload: dict[str, object]) -> dict[str, str]: + return { + "Authorization": str(login_payload["token"]), + "header_username": str(login_payload["rruid"]), + } + + +def _build_supervisor(tmp_path: Path, *, with_snapshot: bool = True) -> tuple[ReleaseSupervisor, object]: + config_file = write_release_config(tmp_path) + config = load_config(config_file) + paths = resolve_paths(config_file, config) + _write_json(paths.inventory_path, {"home": {"id": 12345, "name": "Test Home"}, "devices": []}) + if with_snapshot: + _seed_cloud_snapshot(paths.cloud_snapshot_path) + supervisor = ReleaseSupervisor(config=config, paths=paths) + return supervisor, paths + + +def _build_supervisor_with_protocol_toggle( + tmp_path: Path, + *, + protocol_auth_enabled: bool, +) -> tuple[ReleaseSupervisor, object]: + config_file = write_release_config(tmp_path, protocol_auth_enabled=protocol_auth_enabled) + config = load_config(config_file) + paths = resolve_paths(config_file, config) + _write_json(paths.inventory_path, {"home": {"id": 12345, "name": "Test Home"}, "devices": []}) + _seed_cloud_snapshot(paths.cloud_snapshot_path) + supervisor = ReleaseSupervisor(config=config, paths=paths) + return supervisor, paths + + +def 
test_protected_routes_require_native_token_and_hawk_auth(tmp_path: Path) -> None: + supervisor, paths = _build_supervisor(tmp_path) + client = TestClient(supervisor.app) + + unauth_home = client.get("/api/v1/getHomeDetail") + assert unauth_home.status_code == 401 + assert unauth_home.json()["code"] == 2010 + assert unauth_home.json()["data"]["auth"] == "token" + + token_headers = { + "Authorization": "local-token-123", + "header_username": "local-rruid-123", + } + authed_home = client.get("/api/v1/getHomeDetail", headers=token_headers) + assert authed_home.status_code == 200 + assert authed_home.json()["data"]["rrHomeId"] == 12345 + + unauth_inbox = client.get("/user/inbox/latest") + assert unauth_inbox.status_code == 401 + assert unauth_inbox.json()["code"] == 40101 + assert unauth_inbox.json()["data"]["auth"] == "hawk" + + auth_store = ProtocolAuthStore(paths.cloud_snapshot_path) + user = auth_store.availability().user + assert user is not None + hawk_headers = { + "Authorization": build_hawk_authorization( + user=user, + path="/user/inbox/latest", + timestamp=int(time.time()), + nonce="nonce-protocol-auth", + ) + } + authed_inbox = client.get("/user/inbox/latest", headers=hawk_headers) + assert authed_inbox.status_code == 200 + assert authed_inbox.json()["data"]["count"] == 0 + + +def test_token_auth_failures_use_roborock_invalid_credentials_code(tmp_path: Path) -> None: + supervisor, _paths = _build_supervisor(tmp_path) + client = TestClient(supervisor.app) + + bad_token = client.get("/api/v1/getHomeDetail", headers={"Authorization": "wrong-token"}) + assert bad_token.status_code == 401 + assert bad_token.json()["code"] == 2010 + assert bad_token.json()["msg"] == "invalid_credentials" + assert bad_token.json()["data"]["reason"] == "invalid_token" + + wrong_user = client.get( + "/api/v1/getHomeDetail", + headers={ + "Authorization": "local-token-123", + "header_username": "wrong-rruid", + }, + ) + assert wrong_user.status_code == 401 + assert 
wrong_user.json()["code"] == 2010 + assert wrong_user.json()["data"]["reason"] == "invalid_header_username" + + +def test_local_pin_login_succeeds_without_imported_cloud_snapshot(tmp_path: Path) -> None: + supervisor, _paths = _build_supervisor(tmp_path, with_snapshot=False) + client = TestClient(supervisor.app) + + send_response = client.post( + "/api/v5/email/code/send", + json={"email": "USER@example.com", "baseUrl": "https://api-us.roborock.com"}, + ) + assert send_response.status_code == 200 + assert send_response.json()["data"]["sent"] is True + + login_response = client.post( + "/api/v5/auth/email/login/code", + json={"email": "USER@example.com", "code": "123456", "baseUrl": "https://api-us.roborock.com"}, + ) + assert login_response.status_code == 200 + login_payload = login_response.json()["data"] + assert login_payload["email"] == "user@example.com" + assert login_payload["token"].startswith("rr") + assert login_payload["rriot"]["r"]["a"] == supervisor.context.api_url() + assert login_payload["rriot"]["r"]["m"] == supervisor.context.mqtt_url() + assert login_payload["rriot"]["r"]["l"] == supervisor.context.wood_url() + + home_response = client.get("/api/v1/getHomeDetail", headers=_token_headers(login_payload)) + assert home_response.status_code == 200 + assert home_response.json()["data"]["rrHomeId"] == 12345 + + user_info = client.get("/api/v1/userInfo", headers=_token_headers(login_payload)) + assert user_info.status_code == 200 + assert user_info.json()["data"]["email"] == "user@example.com" + + +def test_protected_routes_skip_protocol_auth_when_disabled(tmp_path: Path) -> None: + supervisor, _paths = _build_supervisor_with_protocol_toggle(tmp_path, protocol_auth_enabled=False) + client = TestClient(supervisor.app) + + home_response = client.get("/api/v1/getHomeDetail") + assert home_response.status_code == 200 + + inbox_response = client.get("/user/inbox/latest") + assert inbox_response.status_code == 200 + + +def 
test_protocol_code_login_routes_use_local_email_and_pin_without_cloud_manager(tmp_path: Path, monkeypatch) -> None: + supervisor, _paths = _build_supervisor(tmp_path) + client = TestClient(supervisor.app) + snapshot = json.loads(supervisor.paths.cloud_snapshot_path.read_text(encoding="utf-8")) + snapshot.setdefault("meta", {})["username"] = "imported@example.com" + snapshot.setdefault("user_data", {})["email"] = "imported@example.com" + supervisor.paths.cloud_snapshot_path.write_text(json.dumps(snapshot, indent=2) + "\n", encoding="utf-8") + + async def fail_request_code(*, email: str, base_url: str = "") -> dict[str, object]: + _ = email, base_url + raise AssertionError("protocol code send must not call cloud_manager.request_code") + + async def fail_submit_code(*, session_id: str, code: str) -> dict[str, object]: + _ = session_id, code + raise AssertionError("protocol code submit must not call cloud_manager.submit_code") + + monkeypatch.setattr(supervisor.cloud_manager, "request_code", fail_request_code) + monkeypatch.setattr(supervisor.cloud_manager, "submit_code", fail_submit_code) + + send_response = client.post( + "/api/v5/email/code/send", + json={"email": "user@example.com", "baseUrl": supervisor.context.api_url()}, + ) + assert send_response.status_code == 200 + assert send_response.json()["data"]["sent"] is True + + login_response = client.post( + "/api/v5/auth/email/login/code", + json={"email": "user@example.com", "code": "123456", "baseUrl": supervisor.context.api_url(), "sessionId": "ignored"}, + ) + assert login_response.status_code == 200 + login_payload = login_response.json()["data"] + assert login_payload["token"] != "local-token-123" + assert login_payload["rruid"] != "local-rruid-123" + assert login_payload["rriot"]["u"] != "hawk-user-123" + assert login_payload["email"] == "user@example.com" + + +def test_protocol_code_login_rejects_wrong_email_and_wrong_pin(tmp_path: Path) -> None: + supervisor, _paths = _build_supervisor(tmp_path, 
with_snapshot=False) + client = TestClient(supervisor.app) + + wrong_email = client.post( + "/api/v5/auth/email/login/code", + json={"email": "other@example.com", "code": "123456"}, + ) + assert wrong_email.status_code == 401 + assert wrong_email.json()["code"] == 2010 + assert wrong_email.json()["data"]["reason"] == "invalid_login_email" + + wrong_pin = client.post( + "/api/v5/auth/email/login/code", + json={"email": "user@example.com", "code": "654321"}, + ) + assert wrong_pin.status_code == 401 + assert wrong_pin.json()["code"] == 2010 + assert wrong_pin.json()["data"]["reason"] == "invalid_login_pin" + + +def test_protocol_password_login_is_rejected(tmp_path: Path) -> None: + supervisor, _paths = _build_supervisor(tmp_path) + client = TestClient(supervisor.app) + + response = client.post("/api/v5/auth/email/login/pwd", json={"email": "user@example.com", "password": "secret"}) + assert response.status_code == 400 + assert response.json()["msg"] == "password_login_not_supported" + + +def test_protocol_sync_route_persists_additional_sessions_and_redacts_logs(tmp_path: Path) -> None: + supervisor, paths = _build_supervisor(tmp_path) + client = TestClient(supervisor.app) + synced_user_data = _protocol_user_data( + token="real-cloud-token-999", + rruid="real-cloud-rruid-999", + hawk_id="real-cloud-hawk-user", + hawk_session="real-cloud-hawk-session", + hawk_key="real-cloud-hawk-secret", + mqtt_key="real-cloud-mqtt-key", + ) + + sync_response = client.post( + PROTOCOL_AUTH_SYNC_PATH, + json={"source": "test_sync", "user_data": synced_user_data}, + headers={"X-Local-Sync-Secret": "abcdefghijklmnopqrstuvwxyz123456"}, + ) + assert sync_response.status_code == 200 + assert sync_response.json()["data"]["stored"] is True + assert paths.protocol_auth_sessions_path.exists() + + token_headers = { + "Authorization": "real-cloud-token-999", + "header_username": "real-cloud-rruid-999", + } + authed_home = client.get("/api/v1/getHomeDetail", headers=token_headers) + assert 
authed_home.status_code == 200 + + auth_store = ProtocolAuthStore( + paths.cloud_snapshot_path, + session_store_path=paths.protocol_auth_sessions_path, + ) + synced_user = next(user for user in auth_store.availability().users if user.token == "real-cloud-token-999") + hawk_headers = { + "Authorization": build_hawk_authorization( + user=synced_user, + path="/user/inbox/latest", + timestamp=int(time.time()), + nonce="nonce-protocol-sync", + ) + } + authed_inbox = client.get("/user/inbox/latest", headers=hawk_headers) + assert authed_inbox.status_code == 200 + + log_entries = [ + json.loads(line) + for line in paths.http_jsonl_path.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + sync_entry = next(entry for entry in log_entries if entry.get("route") == "protocol_auth_sync") + assert sync_entry["body_redacted"] is True + assert "body_text" not in sync_entry + assert sync_entry["headers"]["x-local-sync-secret"] == "" + + +def test_local_issued_and_imported_sessions_coexist(tmp_path: Path) -> None: + supervisor, paths = _build_supervisor(tmp_path) + client = TestClient(supervisor.app) + + local_session = supervisor.protocol_auth.issue_local_session( + load_cloud_user_data(supervisor.context) or {}, + source="test_local_login", + ) + imported_user_data = _protocol_user_data( + token="real-cloud-token-999", + rruid="real-cloud-rruid-999", + hawk_id="real-cloud-hawk-user", + hawk_session="real-cloud-hawk-session", + hawk_key="real-cloud-hawk-secret", + mqtt_key="real-cloud-mqtt-key", + ) + sync_response = client.post( + PROTOCOL_AUTH_SYNC_PATH, + json={"source": "test_sync", "user_data": imported_user_data}, + headers={"X-Local-Sync-Secret": "abcdefghijklmnopqrstuvwxyz123456"}, + ) + assert sync_response.status_code == 200 + + auth_store = ProtocolAuthStore( + paths.cloud_snapshot_path, + session_store_path=paths.protocol_auth_sessions_path, + ) + availability = auth_store.availability() + assert len(availability.users) >= 3 + + local_token_response = 
client.get( + "/api/v1/getHomeDetail", + headers={ + "Authorization": str(local_session["token"]), + "header_username": str(local_session["rruid"]), + }, + ) + assert local_token_response.status_code == 200 + + imported_token_response = client.get( + "/api/v1/getHomeDetail", + headers={ + "Authorization": "real-cloud-token-999", + "header_username": "real-cloud-rruid-999", + }, + ) + assert imported_token_response.status_code == 200 + + local_hawk_user = next(user for user in availability.users if user.token == local_session["token"]) + imported_hawk_user = next(user for user in availability.users if user.token == "real-cloud-token-999") + + local_hawk_response = client.get( + "/user/inbox/latest", + headers={ + "Authorization": build_hawk_authorization( + user=local_hawk_user, + path="/user/inbox/latest", + timestamp=int(time.time()), + nonce="nonce-local-issued", + ) + }, + ) + assert local_hawk_response.status_code == 200 + + imported_hawk_response = client.get( + "/user/inbox/latest", + headers={ + "Authorization": build_hawk_authorization( + user=imported_hawk_user, + path="/user/inbox/latest", + timestamp=int(time.time()), + nonce="nonce-imported", + ) + }, + ) + assert imported_hawk_response.status_code == 200 diff --git a/tests/test_routine_runner.py b/tests/test_routine_runner.py index f37ebb1..cbe2910 100644 --- a/tests/test_routine_runner.py +++ b/tests/test_routine_runner.py @@ -18,11 +18,13 @@ def _test_context(tmp_path: Path) -> ServerContext: mqtt_host="mqtt.example.com", wood_host="wood.example.com", region="us", + protocol_login_email="user@example.com", localkey="local-key", duid="default-duid", mqtt_usr="mqtt-user", mqtt_passwd="mqtt-pass", mqtt_clientid="mqtt-client", + https_port=443, mqtt_tls_port=8883, http_jsonl=tmp_path / "http.jsonl", mqtt_jsonl=tmp_path / "mqtt.jsonl", @@ -456,5 +458,3 @@ async def exercise() -> None: assert len(client.sent_commands) == 2 asyncio.run(exercise()) - - diff --git a/tests/test_runtime_credentials.py 
b/tests/test_runtime_credentials.py index 8b49c44..106f19e 100644 --- a/tests/test_runtime_credentials.py +++ b/tests/test_runtime_credentials.py @@ -20,6 +20,7 @@ def test_ensure_device_merges_split_did_and_duid_records(tmp_path: Path) -> None "localkey": "", "local_key_source": "", "device_mqtt_usr": "mqtt-user-a", + "device_mqtt_pass": "", "updated_at": "2026-03-16T00:22:31.225097+00:00", "last_nc_at": "", "last_mqtt_seen_at": "2026-03-16T00:22:31.225063+00:00", @@ -33,6 +34,7 @@ def test_ensure_device_merges_split_did_and_duid_records(tmp_path: Path) -> None "localkey": "local-key-a", "local_key_source": "inventory", "device_mqtt_usr": "", + "device_mqtt_pass": "", "updated_at": "2026-03-16T00:22:20.199941+00:00", "last_nc_at": "", "last_mqtt_seen_at": "", @@ -49,6 +51,7 @@ def test_ensure_device_merges_split_did_and_duid_records(tmp_path: Path) -> None did="1103821560705", duid="cloud-q7-a", device_mqtt_usr="mqtt-user-a", + device_mqtt_pass="mqtt-pass-a", assign_localkey=False, ) @@ -61,4 +64,53 @@ def test_ensure_device_merges_split_did_and_duid_records(tmp_path: Path) -> None assert devices[0]["product_id"] == "product-q7-a" assert devices[0]["localkey"] == "local-key-a" assert devices[0]["device_mqtt_usr"] == "mqtt-user-a" + assert devices[0]["device_mqtt_pass"] == "mqtt-pass-a" assert merged["duid"] == "cloud-q7-a" + + +def test_backfill_device_mqtt_passwords_updates_only_known_usernames(tmp_path: Path) -> None: + credentials_path = tmp_path / "runtime_credentials.json" + credentials_path.write_text( + json.dumps( + { + "schema_version": 2, + "devices": [ + { + "did": "1103821560705", + "duid": "cloud-q7-a", + "name": "Q7 Upstairs", + "model": "roborock.vacuum.sc05", + "product_id": "product-q7-a", + "localkey": "local-key-a", + "local_key_source": "inventory", + "device_mqtt_usr": "c25b14ceac358d2a", + "device_mqtt_pass": "", + "updated_at": "", + "last_nc_at": "", + "last_mqtt_seen_at": "", + } + ], + } + ) + + "\n", + encoding="utf-8", + ) + log_path = 
tmp_path / "mqtt_server.log" + log_path.write_text( + "\n".join( + [ + "2026-04-17 16:20:32,393 [INFO] [conn 1670 c2b] CONNECT len=82 hex=105000044d51545404c2001e00106130313233393163623566386263393700106332356231346365616333353864326100206666383932326432346139613961663831663138663335646365653961356135", + "2026-04-17 16:20:33,603 [INFO] [conn 1671 c2b] CONNECT len=82 hex=105000044d51545404c2001e00103664343439636537383337643366666600106464323131333035653264343837336200203762336533656138346663383333613937306464363432363032356162313436", + ] + ) + + "\n", + encoding="utf-8", + ) + + store = RuntimeCredentialsStore(credentials_path) + changed = store.backfill_device_mqtt_passwords(log_path) + + assert changed == 1 + device = store.resolve_device(duid="cloud-q7-a") + assert device is not None + assert device["device_mqtt_pass"] == "ff8922d24a9a9af81f18f35dcee9a5a5" diff --git a/tests/test_version_sync.py b/tests/test_version_sync.py new file mode 100644 index 0000000..2c44c2f --- /dev/null +++ b/tests/test_version_sync.py @@ -0,0 +1,16 @@ +import re +import tomllib +from pathlib import Path + +from roborock_local_server import __version__ + + +def test_package_version_matches_pyproject() -> None: + pyproject = tomllib.loads(Path("pyproject.toml").read_text(encoding="utf-8")) + assert pyproject["project"]["version"] == __version__ + + +def test_init_module_exports_single_version_literal() -> None: + init_text = Path("src/roborock_local_server/__init__.py").read_text(encoding="utf-8") + matches = re.findall(r'__version__\s*=\s*"([^"]+)"', init_text) + assert matches == [__version__] diff --git a/uv.lock b/uv.lock index 17f3e4d..fb05c27 100644 --- a/uv.lock +++ b/uv.lock @@ -1266,7 +1266,7 @@ wheels = [ [[package]] name = "roborock-local-server" -version = "0.1.0" +version = "0.0.2" source = { editable = "." } dependencies = [ { name = "aiohttp" },