diff --git a/.gitmodules b/.gitmodules index 081acbd..8386297 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "mist_openapi"] path = mist_openapi url = https://github.com/mistsys/mist_openapi.git - branch = master \ No newline at end of file + branch = 2602.1.7 \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index df1e042..6164ae1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,58 @@ # CHANGELOG +## Version 0.61.2 (March 2026) + +**Released**: March 17, 2026 + +This release adds automatic reconnection support for WebSocket streams, updates the OpenAPI specification, and includes minor bug fixes. + +--- + +### 1. NEW FEATURES + +#### **WebSocket Auto-Reconnect** +`_MistWebsocket` now supports automatic reconnection with configurable parameters: +- `auto_reconnect` — Enable/disable auto-reconnect (default: `False`) +- `max_reconnect_attempts` — Maximum reconnect attempts before giving up (default: `5`) +- `reconnect_backoff` — Base backoff delay in seconds, with exponential increase (default: `2.0`) + +When enabled, the WebSocket automatically reconnects on transient failures using exponential backoff. User-initiated `disconnect()` calls are respected during reconnection attempts. + +```python +ws = mistapi.websockets.sites.DeviceStatsEvents( + apisession, + site_ids=[""], + auto_reconnect=True, + max_reconnect_attempts=5, + reconnect_backoff=2.0 +) +ws.connect(run_in_background=True) +``` + +--- + +### 2. API CHANGES (OpenAPI 2602.1.7) + +Updated to mist_openapi spec version 2602.1.7. + +#### **Insights API** +- **`getSiteInsightMetrics()`** — Now uses `metrics` as a query parameter instead of a path parameter +- **`getSiteInsightMetricsForAP()`** — New function to retrieve insight metrics for a specific AP +- **`getSiteInsightMetricsForClient()`** — Changed `metric` path parameter to `metrics` query parameter +- **`getSiteInsightMetricsForGateway()`** — Changed `metric` path parameter to `metrics` query parameter + +#### **Stats API** +- **`getOrgStats()`** — Removed `start`, `end`, `duration`, `limit`, `page` query parameters +- **`listOrgSiteStats()`** — Removed `start`, `end`, `duration` query parameters + +--- + +### 3. BUG FIXES +- Fixed `ShellSession.recv()` to gracefully handle socket timeout reset when the connection is already closed +- Fixed thread-safety (TOCTOU) race conditions in `ShellSession` by capturing WebSocket reference in local variables across `disconnect()`, `connected`, `send()`, `recv()`, and `resize()` methods +- Fixed thread-safety race condition in `_MistWebsocket.disconnect()` with local variable capture + +--- + ## Version 0.61.1 (March 2026) **Released**: March 15, 2026 diff --git a/README.md b/README.md index b5c343a..688d105 100644 --- a/README.md +++ b/README.md @@ -579,19 +579,23 @@ The package provides a WebSocket client for real-time event streaming from the M ### Connection Parameters -All channel classes accept the following optional keyword arguments to control the WebSocket keep-alive behaviour: +All channel classes accept the following optional keyword arguments: | Parameter | Type | Default | Description | |-----------|------|---------|-------------| | `ping_interval` | `int` | `30` | Seconds between automatic ping frames. Set to `0` to disable pings. | | `ping_timeout` | `int` | `10` | Seconds to wait for a pong response before treating the connection as dead. | +| `auto_reconnect` | `bool` | `False` | Automatically reconnect on transient failures using exponential backoff. | +| `max_reconnect_attempts` | `int` | `5` | Maximum number of reconnect attempts before giving up. | +| `reconnect_backoff` | `float` | `2.0` | Base backoff delay in seconds. Doubles after each failed attempt (2s, 4s, 8s, ...). Resets on successful reconnection. | ```python ws = mistapi.websockets.sites.DeviceStatsEvents( apisession, site_ids=[""], - ping_interval=60, # ping every 60 s - ping_timeout=20, # wait up to 20 s for pong + ping_interval=60, # ping every 60 s + ping_timeout=20, # wait up to 20 s for pong + auto_reconnect=True, # reconnect on transient failures ) ws.connect() ``` diff --git a/pyproject.toml b/pyproject.toml index 2025baa..8436a61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "mistapi" -version = "0.61.1" +version = "0.61.2" authors = [{ name = "Thomas Munzer", email = "tmunzer@juniper.net" }] description = "Python package to simplify the Mist System APIs usage" keywords = ["Mist", "Juniper", "API"] diff --git a/src/mistapi/__api_request.py b/src/mistapi/__api_request.py index 3b777fc..8453180 100644 --- a/src/mistapi/__api_request.py +++ b/src/mistapi/__api_request.py @@ -253,7 +253,8 @@ def _request_with_retry( "apirequest:%s:Exception occurred", method_name, exc_info=True ) break - self._count += 1 + with self._token_lock: + self._count += 1 return APIResponse(url=url, response=resp, proxy_error=proxy_failed) def mist_get(self, uri: str, query: dict[str, str] | None = None) -> APIResponse: diff --git a/src/mistapi/__api_response.py b/src/mistapi/__api_response.py index 7547965..2d878d7 100644 --- a/src/mistapi/__api_response.py +++ b/src/mistapi/__api_response.py @@ -11,6 +11,8 @@ This module manages API responses """ +import re + from requests import Response from requests.structures import CaseInsensitiveDict @@ -85,7 +87,11 @@ def _check_next(self) -> None: separator = "&" if "?" in uri else "?" self.next = f"{uri}{separator}page={page + 1}" else: - self.next = uri.replace(f"page={page}", f"page={page + 1}") + self.next = re.sub( + rf"(?<=[?&])page={page}(?=&|$)", + f"page={page + 1}", + uri, + ) logger.debug(f"apiresponse:_check_next:set next to {self.next}") except ValueError: logger.error( diff --git a/src/mistapi/__api_session.py b/src/mistapi/__api_session.py index f561a31..d984200 100644 --- a/src/mistapi/__api_session.py +++ b/src/mistapi/__api_session.py @@ -580,7 +580,7 @@ def _get_api_token_data(self, apitoken) -> tuple[str | None, list | None]: if data_json.get("email"): token_type = "user" # nosec bandit B105 - for priv in data_json.get("privileges"): + for priv in data_json.get("privileges", []): tmp = { "scope": priv.get("scope"), "role": priv.get("role"), @@ -715,7 +715,7 @@ def _process_login(self, retry: bool = True) -> str | None: "email/password cleaned up. Restarting authentication function" ) if retry: - return self._process_login(retry) + return self._process_login(retry=False) except requests.exceptions.ProxyError as proxy_error: LOGGER.critical("apisession:_process_login:proxy not valid...") CONSOLE.critical("Proxy not valid...\r\n") @@ -935,9 +935,15 @@ def _set_authenticated(self, authentication_status: bool) -> None: LOGGER.error( "apirequest:mist_post_file: Exception occurred", exc_info=True ) - self._csrftoken = self._session.cookies["csrftoken" + cookies_ext] - self._session.headers.update({"X-CSRFToken": self._csrftoken}) - LOGGER.info("apisession:_set_authenticated:CSRF Token stored") + csrf_cookie = self._session.cookies.get("csrftoken" + cookies_ext) + if csrf_cookie: + self._csrftoken = csrf_cookie + self._session.headers.update({"X-CSRFToken": self._csrftoken}) + LOGGER.info("apisession:_set_authenticated:CSRF Token stored") + else: + LOGGER.error( + "apisession:_set_authenticated:CSRF Token cookie not found" + ) elif authentication_status is False: self._authenticated = False LOGGER.info( @@ -1093,7 +1099,7 @@ def _getself(self) -> bool: for key, val in resp.data.items(): if key == "privileges": self.privileges = Privileges(resp.data["privileges"]) - if key == "tags": + elif key == "tags": for tag in resp.data["tags"]: self.tags.append(tag) else: diff --git a/src/mistapi/__version.py b/src/mistapi/__version.py index adc4c80..ccc3927 100644 --- a/src/mistapi/__version.py +++ b/src/mistapi/__version.py @@ -1,2 +1,2 @@ -__version__ = "0.61.1" +__version__ = "0.61.2" __author__ = "Thomas Munzer " diff --git a/src/mistapi/api/v1/orgs/stats.py b/src/mistapi/api/v1/orgs/stats.py index 51e8914..1a1df64 100644 --- a/src/mistapi/api/v1/orgs/stats.py +++ b/src/mistapi/api/v1/orgs/stats.py @@ -14,15 +14,7 @@ from mistapi.__api_response import APIResponse as _APIResponse -def getOrgStats( - mist_session: _APISession, - org_id: str, - start: str | None = None, - end: str | None = None, - duration: str | None = None, - limit: int | None = None, - page: int | None = None, -) -> _APIResponse: +def getOrgStats(mist_session: _APISession, org_id: str) -> _APIResponse: """ API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/orgs/stats/get-org-stats @@ -35,14 +27,6 @@ def getOrgStats( ----------- org_id : str - QUERY PARAMS - ------------ - start : str - end : str - duration : str, default: 1d - limit : int, default: 100 - page : int, default: 1 - RETURN ----------- mistapi.APIResponse @@ -51,16 +35,6 @@ def getOrgStats( uri = f"/api/v1/orgs/{org_id}/stats" query_params: dict[str, str] = {} - if start: - query_params["start"] = str(start) - if end: - query_params["end"] = str(end) - if duration: - query_params["duration"] = str(duration) - if limit: - query_params["limit"] = str(limit) - if page: - query_params["page"] = str(page) resp = mist_session.mist_get(uri=uri, query=query_params) return resp @@ -1017,9 +991,6 @@ def searchOrgSwOrGwPorts( def listOrgSiteStats( mist_session: _APISession, org_id: str, - start: str | None = None, - end: str | None = None, - duration: str | None = None, limit: int | None = None, page: int | None = None, ) -> _APIResponse: @@ -1037,9 +1008,6 @@ def listOrgSiteStats( QUERY PARAMS ------------ - start : str - end : str - duration : str, default: 1d limit : int, default: 100 page : int, default: 1 @@ -1051,12 +1019,6 @@ def listOrgSiteStats( uri = f"/api/v1/orgs/{org_id}/stats/sites" query_params: dict[str, str] = {} - if start: - query_params["start"] = str(start) - if end: - query_params["end"] = str(end) - if duration: - query_params["duration"] = str(duration) if limit: query_params["limit"] = str(limit) if page: diff --git a/src/mistapi/api/v1/sites/insights.py b/src/mistapi/api/v1/sites/insights.py index 3033df7..d0ad09c 100644 --- a/src/mistapi/api/v1/sites/insights.py +++ b/src/mistapi/api/v1/sites/insights.py @@ -14,11 +14,131 @@ from mistapi.__api_response import APIResponse as _APIResponse +def getSiteInsightMetrics( + mist_session: _APISession, + site_id: str, + metrics: str, + start: str | None = None, + end: str | None = None, + duration: str | None = None, + interval: str | None = None, + limit: int | None = None, + page: int | None = None, +) -> _APIResponse: + """ + API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/insights/get-site-insight-metrics + + PARAMS + ----------- + mistapi.APISession : mist_session + mistapi session including authentication and Mist host information + + PATH PARAMS + ----------- + site_id : str + + QUERY PARAMS + ------------ + metrics : str + start : str + end : str + duration : str, default: 1d + interval : str + limit : int, default: 100 + page : int, default: 1 + + RETURN + ----------- + mistapi.APIResponse + response from the API call + """ + + uri = f"/api/v1/sites/{site_id}/insights" + query_params: dict[str, str] = {} + if metrics: + query_params["metrics"] = str(metrics) + if start: + query_params["start"] = str(start) + if end: + query_params["end"] = str(end) + if duration: + query_params["duration"] = str(duration) + if interval: + query_params["interval"] = str(interval) + if limit: + query_params["limit"] = str(limit) + if page: + query_params["page"] = str(page) + resp = mist_session.mist_get(uri=uri, query=query_params) + return resp + + +def getSiteInsightMetricsForAP( + mist_session: _APISession, + site_id: str, + device_id: str, + metrics: str, + start: str | None = None, + end: str | None = None, + duration: str | None = None, + interval: str | None = None, + limit: int | None = None, + page: int | None = None, +) -> _APIResponse: + """ + API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/insights/get-site-insight-metrics-for-a-p + + PARAMS + ----------- + mistapi.APISession : mist_session + mistapi session including authentication and Mist host information + + PATH PARAMS + ----------- + site_id : str + device_id : str + + QUERY PARAMS + ------------ + metrics : str + start : str + end : str + duration : str, default: 1d + interval : str + limit : int, default: 100 + page : int, default: 1 + + RETURN + ----------- + mistapi.APIResponse + response from the API call + """ + + uri = f"/api/v1/sites/{site_id}/insights/ap/{device_id}/stats" + query_params: dict[str, str] = {} + if metrics: + query_params["metrics"] = str(metrics) + if start: + query_params["start"] = str(start) + if end: + query_params["end"] = str(end) + if duration: + query_params["duration"] = str(duration) + if interval: + query_params["interval"] = str(interval) + if limit: + query_params["limit"] = str(limit) + if page: + query_params["page"] = str(page) + resp = mist_session.mist_get(uri=uri, query=query_params) + return resp + + def getSiteInsightMetricsForClient( mist_session: _APISession, site_id: str, client_mac: str, - metric: str, + metrics: str, start: str | None = None, end: str | None = None, duration: str | None = None, @@ -38,10 +158,10 @@ def getSiteInsightMetricsForClient( ----------- site_id : str client_mac : str - metric : str QUERY PARAMS ------------ + metrics : str start : str end : str duration : str, default: 1d @@ -55,8 +175,10 @@ def getSiteInsightMetricsForClient( response from the API call """ - uri = f"/api/v1/sites/{site_id}/insights/client/{client_mac}/{metric}" + uri = f"/api/v1/sites/{site_id}/insights/client/{client_mac}" query_params: dict[str, str] = {} + if metrics: + query_params["metrics"] = str(metrics) if start: query_params["start"] = str(start) if end: @@ -274,8 +396,8 @@ def searchSiteClientFingerprints( def getSiteInsightMetricsForGateway( mist_session: _APISession, site_id: str, - metric: str, device_id: str, + metrics: str, start: str | None = None, end: str | None = None, duration: str | None = None, @@ -294,11 +416,11 @@ def getSiteInsightMetricsForGateway( PATH PARAMS ----------- site_id : str - metric : str device_id : str QUERY PARAMS ------------ + metrics : str start : str end : str duration : str, default: 1d @@ -312,8 +434,10 @@ def getSiteInsightMetricsForGateway( response from the API call """ - uri = f"/api/v1/sites/{site_id}/insights/gateway/{device_id}/stats/{metric}" + uri = f"/api/v1/sites/{site_id}/insights/gateway/{device_id}/stats" query_params: dict[str, str] = {} + if metrics: + query_params["metrics"] = str(metrics) if start: query_params["start"] = str(start) if end: @@ -552,60 +676,3 @@ def getSiteInsightMetricsForSwitch( query_params["page"] = str(page) resp = mist_session.mist_get(uri=uri, query=query_params) return resp - - -def getSiteInsightMetrics( - mist_session: _APISession, - site_id: str, - metric: str, - start: str | None = None, - end: str | None = None, - duration: str | None = None, - interval: str | None = None, - limit: int | None = None, - page: int | None = None, -) -> _APIResponse: - """ - API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/insights/get-site-insight-metrics - - PARAMS - ----------- - mistapi.APISession : mist_session - mistapi session including authentication and Mist host information - - PATH PARAMS - ----------- - site_id : str - metric : str - - QUERY PARAMS - ------------ - start : str - end : str - duration : str, default: 1d - interval : str - limit : int, default: 100 - page : int, default: 1 - - RETURN - ----------- - mistapi.APIResponse - response from the API call - """ - - uri = f"/api/v1/sites/{site_id}/insights/{metric}" - query_params: dict[str, str] = {} - if start: - query_params["start"] = str(start) - if end: - query_params["end"] = str(end) - if duration: - query_params["duration"] = str(duration) - if interval: - query_params["interval"] = str(interval) - if limit: - query_params["limit"] = str(limit) - if page: - query_params["page"] = str(page) - resp = mist_session.mist_get(uri=uri, query=query_params) - return resp diff --git a/src/mistapi/api/v1/sites/sle.py b/src/mistapi/api/v1/sites/sle.py index 353b59f..99280aa 100644 --- a/src/mistapi/api/v1/sites/sle.py +++ b/src/mistapi/api/v1/sites/sle.py @@ -19,7 +19,7 @@ @deprecation.deprecated( deprecated_in="0.59.2", removed_in="0.65.0", - current_version="0.61.1", + current_version="0.61.2", details="function replaced with getSiteSleClassifierSummaryTrend", ) def getSiteSleClassifierDetails( @@ -691,7 +691,7 @@ def listSiteSleImpactedWirelessClients( @deprecation.deprecated( deprecated_in="0.59.2", removed_in="0.65.0", - current_version="0.61.1", + current_version="0.61.2", details="function replaced with getSiteSleSummaryTrend", ) def getSiteSleSummary( diff --git a/src/mistapi/device_utils/__tools/bgp.py b/src/mistapi/device_utils/__tools/bgp.py index 74f8a5c..bdedf5f 100644 --- a/src/mistapi/device_utils/__tools/bgp.py +++ b/src/mistapi/device_utils/__tools/bgp.py @@ -40,6 +40,8 @@ def summary( UUID of the site where the device is located. device_id : str UUID of the device to show BGP summary on. + timeout : int, default 5 + Time in seconds to wait for data before closing the connection. on_message : Callable, optional Callback invoked with each extracted raw message as it arrives. diff --git a/src/mistapi/device_utils/__tools/shell.py b/src/mistapi/device_utils/__tools/shell.py index f81b783..86c923a 100644 --- a/src/mistapi/device_utils/__tools/shell.py +++ b/src/mistapi/device_utils/__tools/shell.py @@ -127,6 +127,7 @@ def _build_sslopt(self) -> dict: session = self._mist_session._session if session.verify is False: sslopt["cert_reqs"] = ssl.CERT_NONE + sslopt["check_hostname"] = False elif isinstance(session.verify, str): sslopt["ca_certs"] = session.verify if session.cert: @@ -143,6 +144,8 @@ def _build_sslopt(self) -> dict: def connect(self) -> None: """Open the WebSocket connection.""" + if self._ws is not None and self._ws.connected: + raise RuntimeError("Already connected; call disconnect() first") LOGGER.info("Connecting to shell WebSocket: %s", self._ws_url) self._ws = websocket.create_connection( self._ws_url, @@ -150,8 +153,12 @@ def connect(self) -> None: cookie=self._get_cookie(), sslopt=self._build_sslopt(), ) - self._ws.settimeout(0.1) - self.resize(self._rows, self._cols) + try: + self._ws.settimeout(0.1) + self.resize(self._rows, self._cols) + except Exception: + self.disconnect() + raise LOGGER.info("Shell WebSocket connected") def disconnect(self) -> None: @@ -208,7 +215,18 @@ def recv(self, timeout: float = 0.1) -> bytes | None: ): return None finally: - ws.settimeout(old_timeout) + try: + ws.settimeout(old_timeout) + except ( + websocket.WebSocketConnectionClosedException, + ConnectionError, + OSError, + ) as exc: + LOGGER.debug( + "ShellSession.recv: failed to restore websocket timeout " + "(socket may be closed): %s", + exc, + ) def resize(self, rows: int, cols: int) -> None: """Send a terminal resize message to the device.""" diff --git a/src/mistapi/device_utils/ex.py b/src/mistapi/device_utils/ex.py index 816e32a..f7078aa 100644 --- a/src/mistapi/device_utils/ex.py +++ b/src/mistapi/device_utils/ex.py @@ -42,13 +42,17 @@ # Shell (interactive SSH) from mistapi.device_utils.__tools.shell import ShellSession -from mistapi.device_utils.__tools.shell import create_shell_session as createShellSession +from mistapi.device_utils.__tools.shell import ( + create_shell_session as createShellSession, +) from mistapi.device_utils.__tools.shell import interactive_shell as interactiveShell -# Tools (ping, monitor traffic) +# Tools (ping, traceroute, monitor traffic) +from mistapi.device_utils.__tools.miscellaneous import TracerouteProtocol from mistapi.device_utils.__tools.miscellaneous import monitor_traffic as monitorTraffic from mistapi.device_utils.__tools.miscellaneous import ping from mistapi.device_utils.__tools.miscellaneous import top_command as topCommand +from mistapi.device_utils.__tools.miscellaneous import traceroute # Policy functions from mistapi.device_utils.__tools.policy import clear_hit_count as clearHitCount @@ -86,4 +90,6 @@ "monitorTraffic", "ping", "topCommand", + "traceroute", + "TracerouteProtocol", ] diff --git a/src/mistapi/device_utils/srx.py b/src/mistapi/device_utils/srx.py index bdd1f5f..2dfbd4c 100644 --- a/src/mistapi/device_utils/srx.py +++ b/src/mistapi/device_utils/srx.py @@ -32,13 +32,17 @@ # Shell (interactive SSH) from mistapi.device_utils.__tools.shell import ShellSession -from mistapi.device_utils.__tools.shell import create_shell_session as createShellSession +from mistapi.device_utils.__tools.shell import ( + create_shell_session as createShellSession, +) from mistapi.device_utils.__tools.shell import interactive_shell as interactiveShell -# Tools (ping, monitor traffic) +# Tools (ping, traceroute, monitor traffic) +from mistapi.device_utils.__tools.miscellaneous import TracerouteProtocol from mistapi.device_utils.__tools.miscellaneous import monitor_traffic as monitorTraffic from mistapi.device_utils.__tools.miscellaneous import ping from mistapi.device_utils.__tools.miscellaneous import top_command as topCommand +from mistapi.device_utils.__tools.miscellaneous import traceroute # OSPF functions from mistapi.device_utils.__tools.ospf import show_database as retrieveOspfDatabase @@ -50,6 +54,7 @@ from mistapi.device_utils.__tools.port import bounce as bouncePort # Route functions +from mistapi.device_utils.__tools.routes import RouteProtocol from mistapi.device_utils.__tools.routes import show as retrieveRoutes # Sessions functions @@ -59,6 +64,8 @@ __all__ = [ # Classes/Enums "Node", + "RouteProtocol", + "TracerouteProtocol", # ARP "retrieveArpTable", # BGP @@ -86,4 +93,5 @@ "monitorTraffic", "ping", "topCommand", + "traceroute", ] diff --git a/src/mistapi/device_utils/ssr.py b/src/mistapi/device_utils/ssr.py index 6c1f64b..00d4f9a 100644 --- a/src/mistapi/device_utils/ssr.py +++ b/src/mistapi/device_utils/ssr.py @@ -30,8 +30,10 @@ from mistapi.device_utils.__tools.dhcp import release_dhcp_leases as releaseDhcpLeases from mistapi.device_utils.__tools.dhcp import retrieve_dhcp_leases as retrieveDhcpLeases -# Tools (ping only - no monitor_traffic for SSR) +# Tools (ping, traceroute - no monitor_traffic for SSR) +from mistapi.device_utils.__tools.miscellaneous import TracerouteProtocol from mistapi.device_utils.__tools.miscellaneous import ping +from mistapi.device_utils.__tools.miscellaneous import traceroute # DNS functions # from mistapi.utils.dns import test_resolution as test_dns_resolution @@ -45,6 +47,7 @@ from mistapi.device_utils.__tools.port import bounce as bouncePort # Route functions +from mistapi.device_utils.__tools.routes import RouteProtocol from mistapi.device_utils.__tools.routes import show as retrieveRoutes # Service Path functions @@ -59,6 +62,8 @@ __all__ = [ # Classes/Enums "Node", + "RouteProtocol", + "TracerouteProtocol", # ARP "retrieveArpTable", # BGP @@ -84,4 +89,5 @@ "clearSessions", # Tools "ping", + "traceroute", ] diff --git a/src/mistapi/websockets/__ws_client.py b/src/mistapi/websockets/__ws_client.py index a60d736..fb2e04b 100644 --- a/src/mistapi/websockets/__ws_client.py +++ b/src/mistapi/websockets/__ws_client.py @@ -45,27 +45,42 @@ def __init__( channels: list[str], ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: + if max_reconnect_attempts < 0: + raise ValueError("max_reconnect_attempts must be >= 0") + if reconnect_backoff <= 0: + raise ValueError("reconnect_backoff must be > 0") + self._mist_session = mist_session self._channels = channels self._ping_interval = ping_interval self._ping_timeout = ping_timeout + self._auto_reconnect = auto_reconnect + self._max_reconnect_attempts = max_reconnect_attempts + self._reconnect_backoff = reconnect_backoff self._ws: websocket.WebSocketApp | None = None self._thread: threading.Thread | None = None self._queue: queue.Queue[dict | None] = queue.Queue() self._connected = ( threading.Event() ) # tracks whether the WebSocket connection is currently open + self._user_disconnect = threading.Event() + self._reconnect_attempts = 0 + self._last_close_code: int | None = None + self._last_close_msg: str | None = None self._on_message_cb: Callable[[dict], None] | None = None self._on_error_cb: Callable[[Exception], None] | None = None self._on_open_cb: Callable[[], None] | None = None - self._on_close_cb: Callable[[int, str], None] | None = None + self._on_close_cb: Callable[[int | None, str | None], None] | None = None # ------------------------------------------------------------------ # Auth / URL helpers def _build_ws_url(self) -> str: - return f"wss://{self._mist_session._cloud_uri.replace('api.', 'api-ws.')}/api-ws/v1/stream" + return f"wss://{self._mist_session._cloud_uri.replace('api.', 'api-ws.', 1)}/api-ws/v1/stream" def _get_headers(self) -> dict: if self._mist_session._apitoken: @@ -99,6 +114,7 @@ def _build_sslopt(self) -> dict: session = self._mist_session._session if session.verify is False: sslopt["cert_reqs"] = ssl.CERT_NONE + sslopt["check_hostname"] = False elif isinstance(session.verify, str): sslopt["ca_certs"] = session.verify if session.cert: @@ -125,7 +141,7 @@ def on_open(self, callback: Callable[[], None]) -> None: """Register a callback invoked when the connection is established.""" self._on_open_cb = callback - def on_close(self, callback: Callable[[int, str], None]) -> None: + def on_close(self, callback: Callable[[int | None, str | None], None]) -> None: """Register a callback invoked when the connection closes.""" self._on_close_cb = callback @@ -135,6 +151,9 @@ def on_close(self, callback: Callable[[int, str], None]) -> None: def _handle_open(self, ws: websocket.WebSocketApp) -> None: for channel in self._channels: ws.send(json.dumps({"subscribe": channel})) + self._reconnect_attempts = 0 + self._last_close_code = None + self._last_close_msg = None self._connected.set() if self._on_open_cb: self._on_open_cb() @@ -157,17 +176,28 @@ def _handle_error(self, ws: websocket.WebSocketApp, error: Exception) -> None: def _handle_close( self, ws: websocket.WebSocketApp, - close_status_code: int, - close_msg: str, + close_status_code: int | None, + close_msg: str | None, ) -> None: self._connected.clear() - self._queue.put(None) # Signals receive() generator to stop - if self._on_close_cb: - self._on_close_cb(close_status_code, close_msg) + self._last_close_code = close_status_code + self._last_close_msg = close_msg # ------------------------------------------------------------------ # Lifecycle + def _create_ws_app(self) -> websocket.WebSocketApp: + """Create a new WebSocketApp instance with current auth/URL.""" + return websocket.WebSocketApp( + self._build_ws_url(), + header=self._get_headers(), + cookie=self._get_cookie(), + on_open=self._handle_open, + on_message=self._handle_message, + on_error=self._handle_error, + on_close=self._handle_close, + ) + def connect(self, run_in_background: bool = True) -> None: """ Open the WebSocket connection and subscribe to the channel. @@ -178,6 +208,10 @@ def connect(self, run_in_background: bool = True) -> None: If True, runs the WebSocket loop in a daemon thread (non-blocking). If False, blocks the calling thread until disconnected. """ + if self._connected.is_set() or (self._thread is not None and self._thread.is_alive()): + raise RuntimeError("Already connected; call disconnect() first") + self._user_disconnect.clear() + self._reconnect_attempts = 0 # Drain stale sentinel from previous connection while not self._queue.empty(): try: @@ -185,15 +219,7 @@ def connect(self, run_in_background: bool = True) -> None: except queue.Empty: break - self._ws = websocket.WebSocketApp( - self._build_ws_url(), - header=self._get_headers(), - cookie=self._get_cookie(), - on_open=self._handle_open, - on_message=self._handle_message, - on_error=self._handle_error, - on_close=self._handle_close, - ) + self._ws = self._create_ws_app() if run_in_background: self._thread = threading.Thread(target=self._run_forever_safe, daemon=True) self._thread.start() @@ -201,7 +227,7 @@ def connect(self, run_in_background: bool = True) -> None: self._run_forever_safe() def _run_forever_safe(self) -> None: - if self._ws: + while True: try: sslopt = self._build_sslopt() self._ws.run_forever( @@ -213,10 +239,45 @@ def _run_forever_safe(self) -> None: self._handle_error(self._ws, exc) self._handle_close(self._ws, -1, str(exc)) + if self._user_disconnect.is_set() or not self._auto_reconnect: + break + + self._reconnect_attempts += 1 + if self._reconnect_attempts > self._max_reconnect_attempts: + logger.warning( + "Max reconnect attempts (%d) reached, giving up", + self._max_reconnect_attempts, + ) + break + + delay = self._reconnect_backoff * (2 ** (self._reconnect_attempts - 1)) + logger.info( + "Reconnecting in %.1fs (attempt %d/%d)", + delay, + self._reconnect_attempts, + self._max_reconnect_attempts, + ) + if self._user_disconnect.wait(timeout=delay): + break # disconnect() called during backoff + + # Guard against a disconnect that happens immediately after the + # backoff wait returns but before creating a new WebSocketApp. + if self._user_disconnect.is_set(): + break + + self._ws = self._create_ws_app() + + # Final close: put sentinel and call callback + self._queue.put(None) + if self._on_close_cb: + self._on_close_cb(self._last_close_code, self._last_close_msg) + def disconnect(self) -> None: """Close the WebSocket connection.""" - if self._ws: - self._ws.close() + self._user_disconnect.set() + ws = self._ws + if ws: + ws.close() def receive(self) -> Generator[dict, None, None]: """ @@ -227,13 +288,20 @@ def receive(self) -> Generator[dict, None, None]: Intended for use after connect(run_in_background=True). """ - if not self._connected.wait(timeout=10): + if self._auto_reconnect: + while not self._connected.is_set() and not self._user_disconnect.is_set(): + self._connected.wait(timeout=1) + if self._user_disconnect.is_set() and not self._connected.is_set(): + return + elif not self._connected.wait(timeout=10): return while True: try: item = self._queue.get(timeout=1) except queue.Empty: if not self._connected.is_set() and self._queue.empty(): + if self._auto_reconnect and not self._user_disconnect.is_set(): + continue # reconnect in progress, keep waiting break continue if item is None: diff --git a/src/mistapi/websockets/location.py b/src/mistapi/websockets/location.py index 2c40842..8f5dcd9 100644 --- a/src/mistapi/websockets/location.py +++ b/src/mistapi/websockets/location.py @@ -27,19 +27,25 @@ class BleAssetsEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. - map_id : list[str] + map_ids : list[str] UUIDs of the maps to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 - Time in seconds to wait for a ping response before considering the connection dead. + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- Callback style (background thread):: - ws = LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws = BleAssetsEvents(session, site_id="abc123", map_ids=["def456"]) ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -47,14 +53,14 @@ class BleAssetsEvents(_MistWebsocket): Generator style (background thread):: - ws = LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws = BleAssetsEvents(session, site_id="abc123", map_ids=["def456"]) ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") as ws: + with BleAssetsEvents(session, site_id="abc123", map_ids=["def456"]) as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -64,16 +70,22 @@ def __init__( self, mist_session: APISession, site_id: str, - map_id: list[str], + map_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: - channels = [f"/sites/{site_id}/stats/maps/{mid}/assets" for mid in map_id] + channels = [f"/sites/{site_id}/stats/maps/{mid}/assets" for mid in map_ids] super().__init__( mist_session, channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) @@ -89,18 +101,24 @@ class ConnectedClientsEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. - map_id : list[str] + map_ids : list[str] UUIDs of the maps to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- Callback style (background thread):: - ws = LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws = ConnectedClientsEvents(session, site_id="abc123", map_ids=["def456"]) ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -108,14 +126,14 @@ class ConnectedClientsEvents(_MistWebsocket): Generator style (background thread):: - ws = LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws = ConnectedClientsEvents(session, site_id="abc123", map_ids=["def456"]) ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") as ws: + with ConnectedClientsEvents(session, site_id="abc123", map_ids=["def456"]) as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -125,16 +143,22 @@ def __init__( self, mist_session: APISession, site_id: str, - map_id: list[str], + map_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: - channels = [f"/sites/{site_id}/stats/maps/{mid}/clients" for mid in map_id] + channels = [f"/sites/{site_id}/stats/maps/{mid}/clients" for mid in map_ids] super().__init__( mist_session, channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) @@ -150,18 +174,24 @@ class SdkClientsEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. - map_id : list[str] + map_ids : list[str] UUIDs of the maps to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- Callback style (background thread):: - ws = LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") + ws = SdkClientsEvents(session, site_id="abc123", map_ids=["def456"]) ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -169,14 +199,14 @@ class SdkClientsEvents(_MistWebsocket): Generator style (background thread):: - ws = LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") + ws = SdkClientsEvents(session, site_id="abc123", map_ids=["def456"]) ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") as ws: + with SdkClientsEvents(session, site_id="abc123", map_ids=["def456"]) as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -186,16 +216,22 @@ def __init__( self, mist_session: APISession, site_id: str, - map_id: list[str], + map_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: - channels = [f"/sites/{site_id}/stats/maps/{mid}/sdkclients" for mid in map_id] + channels = [f"/sites/{site_id}/stats/maps/{mid}/sdkclients" for mid in map_ids] super().__init__( mist_session, channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) @@ -211,18 +247,24 @@ class UnconnectedClientsEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. - map_id : list[str] + map_ids : list[str] UUIDs of the maps to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- Callback style (background thread):: - ws = LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws = UnconnectedClientsEvents(session, site_id="abc123", map_ids=["def456"]) ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -230,14 +272,14 @@ class UnconnectedClientsEvents(_MistWebsocket): Generator style (background thread):: - ws = LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws = UnconnectedClientsEvents(session, site_id="abc123", map_ids=["def456"]) ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") as ws: + with UnconnectedClientsEvents(session, site_id="abc123", map_ids=["def456"]) as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -247,18 +289,24 @@ def __init__( self, mist_session: APISession, site_id: str, - map_id: list[str], + map_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: channels = [ - f"/sites/{site_id}/stats/maps/{mid}/unconnected_clients" for mid in map_id + f"/sites/{site_id}/stats/maps/{mid}/unconnected_clients" for mid in map_ids ] super().__init__( mist_session, channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) @@ -274,18 +322,24 @@ class DiscoveredBleAssetsEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. - map_id : list[str] + map_ids : list[str] UUIDs of the maps to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- Callback style (background thread):: - ws = LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws = DiscoveredBleAssetsEvents(session, site_id="abc123", map_ids=["def456"]) ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -293,14 +347,14 @@ class DiscoveredBleAssetsEvents(_MistWebsocket): Generator style (background thread):: - ws = LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws = DiscoveredBleAssetsEvents(session, site_id="abc123", map_ids=["def456"]) ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") as ws: + with DiscoveredBleAssetsEvents(session, site_id="abc123", map_ids=["def456"]) as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -310,16 +364,22 @@ def __init__( self, mist_session: APISession, site_id: str, - map_id: list[str], + map_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: channels = [ - f"/sites/{site_id}/stats/maps/{mid}/discovered_assets" for mid in map_id + f"/sites/{site_id}/stats/maps/{mid}/discovered_assets" for mid in map_ids ] super().__init__( mist_session, channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) diff --git a/src/mistapi/websockets/orgs.py b/src/mistapi/websockets/orgs.py index d8e5520..b506192 100644 --- a/src/mistapi/websockets/orgs.py +++ b/src/mistapi/websockets/orgs.py @@ -31,12 +31,18 @@ class InsightsEvents(_MistWebsocket): Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- Callback style (background thread):: - ws = OrgInsightsEvents(session, org_id="abc123") + ws = InsightsEvents(session, org_id="abc123") ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -44,14 +50,14 @@ class InsightsEvents(_MistWebsocket): Generator style:: - ws = OrgInsightsEvents(session, org_id="abc123") + ws = InsightsEvents(session, org_id="abc123") ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with OrgInsightsEvents(session, org_id="abc123") as ws: + with InsightsEvents(session, org_id="abc123") as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -63,12 +69,18 @@ def __init__( org_id: str, ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: super().__init__( mist_session, channels=[f"/orgs/{org_id}/insights/summary"], ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) @@ -88,12 +100,18 @@ class MxEdgesStatsEvents(_MistWebsocket): Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- Callback style (background thread):: - ws = OrgMxEdgesStatsEvents(session, org_id="abc123") + ws = MxEdgesStatsEvents(session, org_id="abc123") ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -101,14 +119,14 @@ class MxEdgesStatsEvents(_MistWebsocket): Generator style:: - ws = OrgMxEdgesStatsEvents(session, org_id="abc123") + ws = MxEdgesStatsEvents(session, org_id="abc123") ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with OrgMxEdgesStatsEvents(session, org_id="abc123") as ws: + with MxEdgesStatsEvents(session, org_id="abc123") as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -120,12 +138,18 @@ def __init__( org_id: str, ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: super().__init__( mist_session, channels=[f"/orgs/{org_id}/stats/mxedges"], ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) @@ -145,6 +169,12 @@ class MxEdgesEvents(_MistWebsocket): Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- @@ -177,10 +207,16 @@ def __init__( org_id: str, ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: super().__init__( mist_session, channels=[f"/orgs/{org_id}/mxedges"], ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) diff --git a/src/mistapi/websockets/session.py b/src/mistapi/websockets/session.py index 4cff41f..8a98a14 100644 --- a/src/mistapi/websockets/session.py +++ b/src/mistapi/websockets/session.py @@ -32,12 +32,18 @@ class SessionWithUrl(_MistWebsocket): Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- Callback style (background thread):: - ws = sessionWithUrl(session, url="wss://example.com/channel") + ws = SessionWithUrl(session, url="wss://example.com/channel") ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -45,14 +51,14 @@ class SessionWithUrl(_MistWebsocket): Generator style:: - ws = sessionWithUrl(session, url="wss://example.com/channel") + ws = SessionWithUrl(session, url="wss://example.com/channel") ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with sessionWithUrl(session, url="wss://example.com/channel") as ws: + with SessionWithUrl(session, url="wss://example.com/channel") as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -64,6 +70,9 @@ def __init__( url: str, ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: self._url = url super().__init__( @@ -71,6 +80,9 @@ def __init__( channels=[], ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) def _build_ws_url(self) -> str: diff --git a/src/mistapi/websockets/sites.py b/src/mistapi/websockets/sites.py index c63910f..08f5138 100644 --- a/src/mistapi/websockets/sites.py +++ b/src/mistapi/websockets/sites.py @@ -31,12 +31,18 @@ class ClientsStatsEvents(_MistWebsocket): Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- Callback style (background thread):: - ws = SiteClientsStatsEvents(session, site_id="abc123") + ws = ClientsStatsEvents(session, site_ids=["abc123"]) ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -44,14 +50,14 @@ class ClientsStatsEvents(_MistWebsocket): Generator style:: - ws = SiteClientsStatsEvents(session, site_id="abc123") + ws = ClientsStatsEvents(session, site_ids=["abc123"]) ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with SiteClientsStatsEvents(session, site_id="abc123") as ws: + with ClientsStatsEvents(session, site_ids=["abc123"]) as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -63,6 +69,9 @@ def __init__( site_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: channels = [f"/sites/{site_id}/stats/clients" for site_id in site_ids] super().__init__( @@ -70,6 +79,9 @@ def __init__( channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) @@ -97,12 +109,18 @@ class DeviceCmdEvents(_MistWebsocket): Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- Callback style (background thread):: - ws = SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") + ws = DeviceCmdEvents(session, site_id="abc123", device_ids=["def456"]) ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -110,14 +128,14 @@ class DeviceCmdEvents(_MistWebsocket): Generator style:: - ws = SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") + ws = DeviceCmdEvents(session, site_id="abc123", device_ids=["def456"]) ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") as ws: + with DeviceCmdEvents(session, site_id="abc123", device_ids=["def456"]) as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -130,6 +148,9 @@ def __init__( device_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: channels = [ f"/sites/{site_id}/devices/{device_id}/cmd" for device_id in device_ids @@ -139,6 +160,9 @@ def __init__( channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) @@ -158,12 +182,18 @@ class DeviceStatsEvents(_MistWebsocket): Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- Callback style (background thread):: - ws = SiteDeviceStatsEvents(session, site_id="abc123") + ws = DeviceStatsEvents(session, site_ids=["abc123"]) ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -171,14 +201,14 @@ class DeviceStatsEvents(_MistWebsocket): Generator style:: - ws = SiteDeviceStatsEvents(session, site_id="abc123") + ws = DeviceStatsEvents(session, site_ids=["abc123"]) ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with SiteDeviceStatsEvents(session, site_id="abc123") as ws: + with DeviceStatsEvents(session, site_ids=["abc123"]) as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -190,6 +220,9 @@ def __init__( site_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: channels = [f"/sites/{site_id}/stats/devices" for site_id in site_ids] super().__init__( @@ -197,6 +230,9 @@ def __init__( channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) @@ -216,12 +252,18 @@ class DeviceEvents(_MistWebsocket): Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- Callback style (background thread):: - ws = DeviceEvents(session, site_id="abc123") + ws = DeviceEvents(session, site_ids=["abc123"]) ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -229,14 +271,14 @@ class DeviceEvents(_MistWebsocket): Generator style:: - ws = DeviceEvents(session, site_id="abc123") + ws = DeviceEvents(session, site_ids=["abc123"]) ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with DeviceEvents(session, site_id="abc123") as ws: + with DeviceEvents(session, site_ids=["abc123"]) as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -248,6 +290,9 @@ def __init__( site_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: channels = [f"/sites/{site_id}/devices" for site_id in site_ids] super().__init__( @@ -255,6 +300,9 @@ def __init__( channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) @@ -274,12 +322,18 @@ class MxEdgesStatsEvents(_MistWebsocket): Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- Callback style (background thread):: - ws = MxEdgesStatsEvents(session, site_id="abc123") + ws = MxEdgesStatsEvents(session, site_ids=["abc123"]) ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -287,14 +341,14 @@ class MxEdgesStatsEvents(_MistWebsocket): Generator style:: - ws = MxEdgesStatsEvents(session, site_id="abc123") + ws = MxEdgesStatsEvents(session, site_ids=["abc123"]) ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with MxEdgesStatsEvents(session, site_id="abc123") as ws: + with MxEdgesStatsEvents(session, site_ids=["abc123"]) as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -306,6 +360,9 @@ def __init__( site_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: channels = [f"/sites/{site_id}/stats/mxedges" for site_id in site_ids] super().__init__( @@ -313,6 +370,9 @@ def __init__( channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) @@ -332,12 +392,18 @@ class MxEdgesEvents(_MistWebsocket): Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- Callback style (background thread):: - ws = MxEdgesEvents(session, site_id="abc123") + ws = MxEdgesEvents(session, site_ids=["abc123"]) ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -345,14 +411,14 @@ class MxEdgesEvents(_MistWebsocket): Generator style:: - ws = MxEdgesEvents(session, site_id="abc123") + ws = MxEdgesEvents(session, site_ids=["abc123"]) ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with MxEdgesEvents(session, site_id="abc123") as ws: + with MxEdgesEvents(session, site_ids=["abc123"]) as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -364,6 +430,9 @@ def __init__( site_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: channels = [f"/sites/{site_id}/mxedges" for site_id in site_ids] super().__init__( @@ -371,13 +440,16 @@ def __init__( channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) class PcapEvents(_MistWebsocket): """WebSocket stream for site PCAP events. - Subscribes to the ``sites/{site_id}/pcap`` channel and delivers + Subscribes to the ``sites/{site_id}/pcaps`` channel and delivers real-time PCAP events for the given site. PARAMS @@ -390,6 +462,12 @@ class PcapEvents(_MistWebsocket): Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 Time in seconds to wait for a ping response before considering the connection dead. + auto_reconnect : bool, default False + Automatically reconnect on unexpected disconnections using exponential backoff. + max_reconnect_attempts : int, default 5 + Maximum number of reconnect attempts before giving up. + reconnect_backoff : float, default 2.0 + Base backoff delay in seconds. Doubles after each failed attempt. EXAMPLE ----------- @@ -422,6 +500,9 @@ def __init__( site_id: str, ping_interval: int = 30, ping_timeout: int = 10, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 5, + reconnect_backoff: float = 2.0, ) -> None: channels = [f"/sites/{site_id}/pcaps"] super().__init__( @@ -429,4 +510,7 @@ def __init__( channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_backoff=reconnect_backoff, ) diff --git a/tests/conftest.py b/tests/conftest.py index f77cb21..20110f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -302,7 +302,12 @@ def isolate_env_vars(): original_env = os.environ.copy() # Clear mistapi-related environment variables - mist_vars = [var for var in os.environ.keys() if var.startswith("MIST_")] + mist_vars = [ + var + for var in os.environ.keys() + if var.startswith("MIST_") + or var in ("CONSOLE_LOG_LEVEL", "LOGGING_LOG_LEVEL", "HTTPS_PROXY") + ] for var in mist_vars: os.environ.pop(var, None) diff --git a/tests/unit/test_shell.py b/tests/unit/test_shell.py index 73c9ef6..33f5f40 100644 --- a/tests/unit/test_shell.py +++ b/tests/unit/test_shell.py @@ -111,9 +111,7 @@ def test_build_sslopt_client_cert_tuple(self, mock_apisession) -> None: class TestLifecycle: """Tests for connect/disconnect/connected.""" - def test_connect_calls_create_connection( - self, shell_session, mock_ws - ) -> None: + def test_connect_calls_create_connection(self, shell_session, mock_ws) -> None: with patch.object( shell_module.websocket, "create_connection", @@ -295,15 +293,18 @@ def test_happy_path(self, mock_apisession, mock_ws) -> None: mock_response.status_code = 200 mock_response.data = {"url": "wss://example.com/shell/abc"} - with patch.object( - shell_module.websocket, - "create_connection", - return_value=mock_ws, - ), patch.object( - devices_module, - "createSiteDeviceShellSession", - return_value=mock_response, - ) as mock_shell_api: + with ( + patch.object( + shell_module.websocket, + "create_connection", + return_value=mock_ws, + ), + patch.object( + devices_module, + "createSiteDeviceShellSession", + return_value=mock_response, + ) as mock_shell_api, + ): session = create_shell_session(mock_apisession, "site-1", "device-1") assert isinstance(session, ShellSession) diff --git a/tests/unit/test_websocket_client.py b/tests/unit/test_websocket_client.py index 79bb829..11caf3a 100644 --- a/tests/unit/test_websocket_client.py +++ b/tests/unit/test_websocket_client.py @@ -10,6 +10,7 @@ import json import ssl +import threading from unittest.mock import Mock, call, patch import pytest @@ -233,7 +234,10 @@ def test_verify_false(self, mock_session) -> None: mock_session._session.verify = False mock_session._session.cert = None client = _MistWebsocket(mock_session, channels=["/ch"]) - assert client._build_sslopt() == {"cert_reqs": ssl.CERT_NONE} + assert client._build_sslopt() == { + "cert_reqs": ssl.CERT_NONE, + "check_hostname": False, + } def test_verify_custom_ca_path(self, mock_session) -> None: mock_session._session.verify = "/etc/ssl/custom-ca.pem" @@ -269,6 +273,7 @@ def test_verify_false_with_cert_tuple(self, mock_session) -> None: sslopt = client._build_sslopt() assert sslopt == { "cert_reqs": ssl.CERT_NONE, + "check_hostname": False, "certfile": "/path/cert.pem", "keyfile": "/path/key.pem", } @@ -408,22 +413,26 @@ def test_no_error_without_callback(self, ws_client) -> None: class TestHandleClose: - """Tests for _handle_close().""" + """Tests for _handle_close(). + + Note: _handle_close only clears _connected and stores the close + code/msg. The sentinel and on_close callback are fired by + _run_forever_safe after the reconnect loop exits. + """ def test_clears_connected_event(self, ws_client) -> None: ws_client._connected.set() ws_client._handle_close(Mock(), 1000, "normal closure") assert not ws_client._connected.is_set() - def test_puts_none_sentinel_on_queue(self, ws_client) -> None: - ws_client._handle_close(Mock(), 1000, "normal closure") - assert ws_client._queue.get_nowait() is None - - def test_calls_on_close_callback(self, ws_client) -> None: - cb = Mock() - ws_client.on_close(cb) + def test_stores_close_code_and_msg(self, ws_client) -> None: ws_client._handle_close(Mock(), 1001, "going away") - cb.assert_called_once_with(1001, "going away") + assert ws_client._last_close_code == 1001 + assert ws_client._last_close_msg == "going away" + + def test_does_not_put_sentinel_directly(self, ws_client) -> None: + ws_client._handle_close(Mock(), 1000, "normal closure") + assert ws_client._queue.empty() def test_no_error_without_callback(self, ws_client) -> None: ws_client._handle_close(Mock(), 1000, "") # Should not raise @@ -465,7 +474,9 @@ def test_connect_drains_stale_queue_items(self, mock_ws_cls, ws_client) -> None: mock_ws_cls.return_value = Mock() ws_client.connect(run_in_background=False) - # Queue should have been drained before creating the WebSocketApp + # Stale items should have been drained; only the final sentinel + # from _run_forever_safe remains. + assert ws_client._queue.get_nowait() is None assert ws_client._queue.empty() @patch("mistapi.websockets.__ws_client.websocket.WebSocketApp") @@ -529,7 +540,7 @@ def test_passes_sslopt_when_verify_false(self, mock_session) -> None: mock_ws.run_forever.assert_called_once_with( ping_interval=30, ping_timeout=10, - sslopt={"cert_reqs": ssl.CERT_NONE}, + sslopt={"cert_reqs": ssl.CERT_NONE, "check_hostname": False}, ) def test_exception_triggers_error_and_close_handlers(self, ws_client) -> None: @@ -546,11 +557,14 @@ def test_exception_triggers_error_and_close_handlers(self, ws_client) -> None: error_cb.assert_called_once() assert isinstance(error_cb.call_args[0][0], RuntimeError) + # _handle_close stores (-1, str(exc)), _run_forever_safe forwards it close_cb.assert_called_once_with(-1, "connection failed") - def test_noop_when_ws_is_none(self, ws_client) -> None: - ws_client._ws = None - ws_client._run_forever_safe() # Should not raise + def test_run_forever_safe_puts_sentinel_on_exit(self, ws_client) -> None: + mock_ws = Mock() + ws_client._ws = mock_ws + ws_client._run_forever_safe() + assert ws_client._queue.get_nowait() is None # --------------------------------------------------------------------------- @@ -692,6 +706,18 @@ def test_ws_starts_none(self, ws_client) -> None: def test_thread_starts_none(self, ws_client) -> None: assert ws_client._thread is None + def test_negative_max_reconnect_attempts_raises(self, mock_session) -> None: + with pytest.raises(ValueError, match="max_reconnect_attempts must be >= 0"): + _MistWebsocket(mock_session, channels=["/ch"], max_reconnect_attempts=-1) + + def test_zero_reconnect_backoff_raises(self, mock_session) -> None: + with pytest.raises(ValueError, match="reconnect_backoff must be > 0"): + _MistWebsocket(mock_session, channels=["/ch"], reconnect_backoff=0) + + def test_negative_reconnect_backoff_raises(self, mock_session) -> None: + with pytest.raises(ValueError, match="reconnect_backoff must be > 0"): + _MistWebsocket(mock_session, channels=["/ch"], reconnect_backoff=-1.0) + # --------------------------------------------------------------------------- # Public WebSocket channel classes @@ -761,30 +787,30 @@ class TestLocationChannels: """Tests for public location-level WebSocket channel classes.""" def test_ble_assets_events_channels(self, mock_session) -> None: - ws = BleAssetsEvents(mock_session, site_id="s1", map_id=["m1", "m2"]) + ws = BleAssetsEvents(mock_session, site_id="s1", map_ids=["m1", "m2"]) assert ws._channels == [ "/sites/s1/stats/maps/m1/assets", "/sites/s1/stats/maps/m2/assets", ] def test_connected_clients_events_channels(self, mock_session) -> None: - ws = ConnectedClientsEvents(mock_session, site_id="s1", map_id=["m1"]) + ws = ConnectedClientsEvents(mock_session, site_id="s1", map_ids=["m1"]) assert ws._channels == ["/sites/s1/stats/maps/m1/clients"] def test_sdk_clients_events_channels(self, mock_session) -> None: - ws = SdkClientsEvents(mock_session, site_id="s1", map_id=["m1"]) + ws = SdkClientsEvents(mock_session, site_id="s1", map_ids=["m1"]) assert ws._channels == ["/sites/s1/stats/maps/m1/sdkclients"] def test_unconnected_clients_events_channels(self, mock_session) -> None: - ws = UnconnectedClientsEvents(mock_session, site_id="s1", map_id=["m1"]) + ws = UnconnectedClientsEvents(mock_session, site_id="s1", map_ids=["m1"]) assert ws._channels == ["/sites/s1/stats/maps/m1/unconnected_clients"] def test_discovered_ble_assets_events_channels(self, mock_session) -> None: - ws = DiscoveredBleAssetsEvents(mock_session, site_id="s1", map_id=["m1"]) + ws = DiscoveredBleAssetsEvents(mock_session, site_id="s1", map_ids=["m1"]) assert ws._channels == ["/sites/s1/stats/maps/m1/discovered_assets"] def test_inherits_from_mist_websocket(self, mock_session) -> None: - ws = BleAssetsEvents(mock_session, site_id="s1", map_id=["m1"]) + ws = BleAssetsEvents(mock_session, site_id="s1", map_ids=["m1"]) assert isinstance(ws, _MistWebsocket) @@ -799,3 +825,151 @@ def test_session_with_url_channels(self, mock_session) -> None: def test_inherits_from_mist_websocket(self, mock_session) -> None: ws = SessionWithUrl(mock_session, url="wss://example.com/custom") assert isinstance(ws, _MistWebsocket) + + +# --------------------------------------------------------------------------- +# Auto-reconnect +# --------------------------------------------------------------------------- + + +class TestAutoReconnect: + """Tests for the auto_reconnect feature.""" + + def _make_client(self, mock_session, **kwargs): + defaults = dict( + mist_session=mock_session, + channels=["/ch"], + auto_reconnect=True, + max_reconnect_attempts=3, + reconnect_backoff=0.01, # fast for tests + ) + defaults.update(kwargs) + return _MistWebsocket(**defaults) + + def test_retries_on_transient_failure(self, mock_session) -> None: + client = self._make_client(mock_session, max_reconnect_attempts=2) + call_count = 0 + + def fake_run_forever(**kwargs): + nonlocal call_count + call_count += 1 + # Simulate connection drop + client._handle_close(client._ws, 1006, "abnormal closure") + + mock_ws = Mock() + mock_ws.run_forever.side_effect = fake_run_forever + with patch.object(client, "_create_ws_app", return_value=mock_ws): + client._ws = mock_ws + client._run_forever_safe() + + # 1 initial + 2 retries = 3 calls + assert call_count == 3 + + def test_gives_up_after_max_attempts(self, mock_session) -> None: + client = self._make_client(mock_session, max_reconnect_attempts=2) + close_cb = Mock() + client.on_close(close_cb) + + mock_ws = Mock() + mock_ws.run_forever.side_effect = lambda **kw: client._handle_close( + mock_ws, 1006, "drop" + ) + with patch.object(client, "_create_ws_app", return_value=mock_ws): + client._ws = mock_ws + client._run_forever_safe() + + # Callback fires exactly once (on final close) + close_cb.assert_called_once_with(1006, "drop") + # Sentinel put exactly once + assert client._queue.get_nowait() is None + assert client._queue.empty() + + def test_disconnect_during_backoff_exits_immediately(self, mock_session) -> None: + client = self._make_client( + mock_session, max_reconnect_attempts=5, reconnect_backoff=10.0 + ) + call_count = 0 + entered_backoff = threading.Event() + + original_wait = client._user_disconnect.wait + + def wait_and_signal(timeout=None): + """Signal that the backoff wait has started, then delegate.""" + entered_backoff.set() + return original_wait(timeout=timeout) + + def fake_run_forever(**kwargs): + nonlocal call_count + call_count += 1 + client._handle_close(client._ws, 1006, "drop") + + mock_ws = Mock() + mock_ws.run_forever.side_effect = fake_run_forever + + def disconnect_when_ready(): + entered_backoff.wait() # deterministic: wait until backoff starts + client.disconnect() + + with ( + patch.object(client, "_create_ws_app", return_value=mock_ws), + patch.object(client._user_disconnect, "wait", side_effect=wait_and_signal), + ): + client._ws = mock_ws + t = threading.Thread(target=disconnect_when_ready) + t.start() + client._run_forever_safe() + t.join(timeout=2) + + # Should have run once, then been interrupted during first backoff + assert call_count == 1 + assert client._queue.get_nowait() is None + + def test_handle_open_resets_reconnect_attempts(self, mock_session) -> None: + client = self._make_client(mock_session, max_reconnect_attempts=3) + client._reconnect_attempts = 5 + client._handle_open(Mock()) + assert client._reconnect_attempts == 0 + + def test_successful_reconnect_resets_counter(self, mock_session) -> None: + client = self._make_client(mock_session, max_reconnect_attempts=2) + call_count = 0 + + def fake_run_forever(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First connection drops + client._handle_close(client._ws, 1006, "drop") + elif call_count == 2: + # Reconnect succeeds, then simulate open + later drop + client._handle_open(client._ws) + client._handle_close(client._ws, 1006, "drop again") + elif call_count == 3: + # Another reconnect succeeds, then clean exit + client._handle_open(client._ws) + client._handle_close(client._ws, 1006, "drop again") + elif call_count == 4: + # Final reconnect succeeds then user disconnects + client._user_disconnect.set() + + mock_ws = Mock() + mock_ws.run_forever.side_effect = fake_run_forever + with patch.object(client, "_create_ws_app", return_value=mock_ws): + client._ws = mock_ws + client._run_forever_safe() + + # Counter was reset by _handle_open, so we got more than max_attempts+1 total calls + assert call_count == 4 + + def test_no_reconnect_when_disabled(self, mock_session) -> None: + client = _MistWebsocket( + mist_session=mock_session, + channels=["/ch"], + auto_reconnect=False, + ) + mock_ws = Mock() + client._ws = mock_ws + client._run_forever_safe() + + # run_forever called exactly once, no retry + mock_ws.run_forever.assert_called_once() diff --git a/uv.lock b/uv.lock index 4c81864..a26ce92 100644 --- a/uv.lock +++ b/uv.lock @@ -537,7 +537,7 @@ wheels = [ [[package]] name = "mistapi" -version = "0.61.1" +version = "0.61.2" source = { editable = "." } dependencies = [ { name = "deprecation" },