From 1d01bc21a1892feb5287f89bb21c310700026f0b Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Mon, 23 Feb 2026 22:03:28 +0100 Subject: [PATCH 01/16] adding websocket client --- README.md | 197 +++++++++++++----- pyproject.toml | 1 + src/mistapi/__init__.py | 2 + src/mistapi/websockets/__init__.py | 24 +++ src/mistapi/websockets/__ws_client.py | 182 ++++++++++++++++ src/mistapi/websockets/location/__init__.py | 21 ++ src/mistapi/websockets/location/ble_assets.py | 60 ++++++ .../websockets/location/clients_connected.py | 61 ++++++ .../websockets/location/clients_sdk.py | 61 ++++++ .../location/clients_unconnected.py | 61 ++++++ .../location/discovered_ble_assets.py | 61 ++++++ src/mistapi/websockets/orgs/__init__.py | 15 ++ src/mistapi/websockets/orgs/insights.py | 56 +++++ src/mistapi/websockets/orgs/mxedges_stats.py | 56 +++++ .../websockets/orgs/mxedges_upgrades.py | 56 +++++ src/mistapi/websockets/sites/__init__.py | 19 ++ src/mistapi/websockets/sites/clients_stats.py | 56 +++++ src/mistapi/websockets/sites/devices_cmd.py | 60 ++++++ src/mistapi/websockets/sites/devices_stats.py | 56 +++++ .../websockets/sites/devices_upgrades.py | 56 +++++ src/mistapi/websockets/sites/mxedges_stats.py | 56 +++++ src/mistapi/websockets/sites/pcap.py | 56 +++++ uv.lock | 11 + 23 files changed, 1237 insertions(+), 47 deletions(-) create mode 100644 src/mistapi/websockets/__init__.py create mode 100644 src/mistapi/websockets/__ws_client.py create mode 100644 src/mistapi/websockets/location/__init__.py create mode 100644 src/mistapi/websockets/location/ble_assets.py create mode 100644 src/mistapi/websockets/location/clients_connected.py create mode 100644 src/mistapi/websockets/location/clients_sdk.py create mode 100644 src/mistapi/websockets/location/clients_unconnected.py create mode 100644 src/mistapi/websockets/location/discovered_ble_assets.py create mode 100644 src/mistapi/websockets/orgs/__init__.py create mode 100644 src/mistapi/websockets/orgs/insights.py create mode 100644 src/mistapi/websockets/orgs/mxedges_stats.py create mode 100644 src/mistapi/websockets/orgs/mxedges_upgrades.py create mode 100644 src/mistapi/websockets/sites/__init__.py create mode 100644 src/mistapi/websockets/sites/clients_stats.py create mode 100644 src/mistapi/websockets/sites/devices_cmd.py create mode 100644 src/mistapi/websockets/sites/devices_stats.py create mode 100644 src/mistapi/websockets/sites/devices_upgrades.py create mode 100644 src/mistapi/websockets/sites/mxedges_stats.py create mode 100644 src/mistapi/websockets/sites/pcap.py diff --git a/README.md b/README.md index df7002c..c36ebc3 100644 --- a/README.md +++ b/README.md @@ -14,11 +14,26 @@ A comprehensive Python package to interact with the Mist Cloud APIs, built from - [Installation](#installation) - [Quick Start](#quick-start) - [Configuration](#configuration) + - [Using Environment File](#using-environment-file) + - [Environment Variables](#environment-variables) - [Authentication](#authentication) -- [Usage](#usage) -- [CLI Helper Functions](#cli-helper-functions) -- [Pagination](#pagination-support) -- [Examples](#examples) + - [Interactive Authentication](#interactive-authentication) + - [Environment File Authentication](#environment-file-authentication) + - [HashiCorp Vault Authentication](#hashicorp-vault-authentication) + - [System Keyring Authentication](#system-keyring-authentication) + - [Direct Parameter Authentication](#direct-parameter-authentication) +- [API Requests Usage](#api-requests-usage) + - [Basic API Calls](#basic-api-calls) + - [Error Handling](#error-handling) + - [Log Sanitization](#log-sanitization) + - [Getting Help](#getting-help) + - [CLI Helper Functions](#cli-helper-functions) + - [Pagination](#pagination-support) + - [Examples](#examples) +- [WebSocket Streaming](#websocket-streaming) + - [Available Channels](#available-channels) + - [Callbacks](#callbacks) + - [Usage Patterns](#usage-patterns) - [Development](#development-and-testing) - [Contributing](#contributing) - [License](#license) @@ -150,22 +165,24 @@ MIST_APITOKEN=your_api_token_here # LOGGING_LOG_LEVEL=10 ``` -### All Configuration Options +### Configuration Options + +| Environment Variable | APISession Parameter | Type | Default | Description | +|---|---|---|---|---| +| `MIST_HOST` | `host` | string | None | Mist Cloud API endpoint (e.g., `api.mist.com`) | +| `MIST_APITOKEN` | `apitoken` | string | None | API Token for authentication (recommended) | +| `MIST_USER` | `email` | string | None | Username/email for authentication | +| `MIST_PASSWORD` | `password` | string | None | Password for authentication | +| `MIST_KEYRING_SERVICE` | `keyring_service` | string | None | System keyring service name | +| `MIST_VAULT_URL` | `vault_url` | string | https://127.0.0.1:8200 | HashiCorp Vault URL | +| `MIST_VAULT_PATH` | `vault_path` | string | None | Path to secret in Vault | +| `MIST_VAULT_MOUNT_POINT` | `vault_mount_point` | string | secret | Vault mount point | +| `MIST_VAULT_TOKEN` | `vault_token` | string | None | Vault authentication token | +| `CONSOLE_LOG_LEVEL` | `console_log_level` | int | 20 | Console log level (0-50) | +| `LOGGING_LOG_LEVEL` | `logging_log_level` | int | 10 | File log level (0-50) | +| `HTTPS_PROXY` | `https_proxy` | string | None | HTTP/HTTPS proxy URL | +| | `env_file` | str | None | Path to `.env` file | -| Variable | Type | Default | Description | -|----------|------|---------|-------------| -| `MIST_HOST` | string | None | Mist Cloud API endpoint (e.g., `api.mist.com`) | -| `MIST_APITOKEN` | string | None | API Token for authentication (recommended) | -| `MIST_USER` | string | None | Username if not using API token | -| `MIST_PASSWORD` | string | None | Password if not using API token | -| `MIST_KEYRING_SERVICE` | string | None | Load credentials from system keyring | -| `MIST_VAULT_URL` | string | https://127.0.0.1:8200 | HashiCorp Vault URL | -| `MIST_VAULT_PATH` | string | None | Path to secret in Vault | -| `MIST_VAULT_MOUNT_POINT` | string | secret | Vault mount point | -| `MIST_VAULT_TOKEN` | string | None | Vault authentication token | -| `CONSOLE_LOG_LEVEL` | int | 20 | Console log level (0-50) | -| `LOGGING_LOG_LEVEL` | int | 10 | File log level (0-50) | -| `HTTPS_PROXY` | string | None | HTTP/HTTPS proxy URL | --- @@ -253,27 +270,10 @@ apisession = mistapi.APISession( apisession.login() ``` -### APISession Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `email` | str | None | Email for login/password authentication | -| `password` | str | None | Password for login/password authentication | -| `apitoken` | str | None | API token (recommended method) | -| `host` | str | None | Mist Cloud endpoint (e.g., "api.mist.com") | -| `keyring_service` | str | None | System keyring service name | -| `vault_url` | str | https://127.0.0.1:8200 | HashiCorp Vault URL | -| `vault_path` | str | None | Path to secret in Vault | -| `vault_mount_point` | str | secret | Vault mount point | -| `vault_token` | str | None | Vault authentication token | -| `env_file` | str | None | Path to `.env` file | -| `console_log_level` | int | 20 | Console logging level (0-50) | -| `logging_log_level` | int | 10 | File logging level (0-50) | -| `https_proxy` | str | None | Proxy URL | --- -## Usage +## API Requests Usage ### Basic API Calls @@ -345,11 +345,11 @@ help(mistapi.api.v1.orgs.stats.getOrgStats) --- -## CLI Helper Functions +### CLI Helper Functions Interactive functions for selecting organizations and sites. -### Organization Selection +#### Organization Selection ```python # Select single organization @@ -368,7 +368,7 @@ Available organizations: Select an Org (0 to 1, or q to exit): 0 ``` -### Site Selection +#### Site Selection ```python # Select site within an organization @@ -386,9 +386,9 @@ Select a Site (0 to 1, or q to exit): 0 --- -## Pagination Support +### Pagination Support -### Get Next Page +#### Get Next Page ```python # Get first page @@ -403,7 +403,7 @@ if response.next: print(f"Second page: {len(response_2.data['results'])} results") ``` -### Get All Pages Automatically +#### Get All Pages Automatically ```python # Get all pages with a single call @@ -419,11 +419,11 @@ print(f"Total results across all pages: {len(all_data)}") --- -## Examples +### Examples Comprehensive examples are available in the [Mist Library repository](https://github.com/tmunzer/mist_library). -### Device Management +#### Device Management ```python # List all devices in an organization @@ -441,7 +441,7 @@ result = mistapi.api.v1.orgs.devices.updateOrgDevice( ) ``` -### Site Management +#### Site Management ```python # Create a new site @@ -458,7 +458,7 @@ new_site = mistapi.api.v1.orgs.sites.createOrgSite( site_stats = mistapi.api.v1.sites.stats.getSiteStats(apisession, new_site.id) ``` -### Client Analytics +#### Client Analytics ```python # Search for wireless clients @@ -478,6 +478,109 @@ events = mistapi.api.v1.orgs.clients.searchOrgClientsEvents( --- +## WebSocket Streaming + +The package provides a WebSocket client for real-time event streaming from the Mist API (`wss://{host}/api-ws/v1/stream`). Authentication is handled automatically using the same session credentials (API token or login/password). + +### Available Channels + +#### Organization Channels + +| Class | Channel | Description | +|-------|---------|-------------| +| `mistapi.websockets.orgs.OrgInsightsEvents` | `/orgs/{org_id}/insights/summary` | Real-time insights events for an organization | +| `mistapi.websockets.orgs.OrgMxEdgesStatsEvents` | `/orgs/{org_id}/stats/mxedges` | Real-time MX edges stats for an organization | +| `mistapi.websockets.orgs.OrgMxEdgesUpgradesEvents` | `/orgs/{org_id}/mxedges` | Real-time MX edges upgrades events for an organization | + +#### Site Channels + +| Class | Channel | Description | +|-------|---------|-------------| +| `mistapi.websockets.sites.SiteClientsStatsEvents` | `/sites/{site_id}/stats/clients` | Real-time clients stats for a site | +| `mistapi.websockets.sites.SiteDeviceCmdEvents` | `/sites/{site_id}/devices/{device_id}/cmd` | Real-time device command events for a site | +| `mistapi.websockets.sites.SiteDeviceStatsEvents` | `/sites/{site_id}/stats/devices` | Real-time device stats for a site | +| `mistapi.websockets.sites.SiteDeviceUpgradesEvents` | `/sites/{site_id}/devices` | Real-time device upgrades events for a site | +| `mistapi.websockets.sites.SitePcapEvents` | `/sites/{site_id}/pcap` | Real-time PCAP events for a site | + +#### Location Channels + +| Class | Channel | Description | +|-------|---------|-------------| +| `mistapi.websockets.location.LocationBleAssetsEvents` | `/sites/{site_id}/stats/maps/{map_id}/assets` | Real-time BLE assets location events | +| `mistapi.websockets.location.LocationConnectedClientsEvents` | `/sites/{site_id}/stats/maps/{map_id}/clients` | Real-time connected clients location events | +| `mistapi.websockets.location.LocationSdkClientsEvents` | `/sites/{site_id}/stats/maps/{map_id}/sdkclients` | Real-time SDK clients location events | +| `mistapi.websockets.location.LocationUnconnectedClientsEvents` | `/sites/{site_id}/stats/maps/{map_id}/unconnected_clients` | Real-time unconnected clients location events | +| `mistapi.websockets.location.LocationDiscoveredBleAssetsEvents` | `/sites/{site_id}/stats/maps/{map_id}/discovered_assets` | Real-time discovered BLE assets location events | + +### Callbacks + +| Method | Signature | Description | +|--------|-----------|-------------| +| `ws.on_open(cb)` | `cb()` | Called when the connection is established | +| `ws.on_message(cb)` | `cb(data: dict)` | Called for every incoming message | +| `ws.on_error(cb)` | `cb(error: Exception)` | Called on WebSocket errors | +| `ws.on_close(cb)` | `cb(status_code: int, msg: str)` | Called when the connection closes | + +### Usage Patterns + +#### Callback style (recommended) + +`connect()` returns immediately; messages are delivered to the registered callback in a background thread. + +```python +import mistapi + +apisession = mistapi.APISession(env_file="~/.mist_env") +apisession.login() + +ws = mistapi.websockets.sites.SiteDeviceStatsEvents(apisession, site_id="") +ws.on_message(lambda data: print(data)) +ws.connect() # non-blocking + +input("Press Enter to stop") +ws.disconnect() +``` + +#### Generator style + +Iterate over incoming messages as a blocking generator. Useful when you want to process messages sequentially in a loop. + +```python +ws = mistapi.websockets.sites.SiteDeviceStatsEvents(apisession, site_id="") +ws.connect(run_in_background=True) + +for msg in ws.receive(): # blocks, yields each message as a dict + print(msg) + if some_condition: + ws.disconnect() # stops the generator cleanly +``` + +#### Blocking style + +`connect(run_in_background=False)` blocks the calling thread until the connection closes. Useful for simple scripts. + +```python +ws = mistapi.websockets.sites.SiteDeviceStatsEvents(apisession, site_id="") +ws.on_message(lambda data: print(data)) +ws.connect(run_in_background=False) # blocks until disconnected +``` + +#### Context manager + +`disconnect()` is called automatically on exit, even if an exception is raised. + +```python +import time + +with mistapi.websockets.sites.SiteDeviceStatsEvents(apisession, site_id="") as ws: + ws.on_message(lambda data: print(data)) + ws.connect() + time.sleep(60) +# ws.disconnect() called automatically here +``` + +--- + ## Development and Testing ### Development Setup diff --git a/pyproject.toml b/pyproject.toml index ccf782e..a72fa16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "deprecation>=2.1.0", "hvac>=2.3.0", "keyring>=24.3.0", + "websocket-client>=1.8.0", ] [project.urls] diff --git a/src/mistapi/__init__.py b/src/mistapi/__init__.py index 891e035..9d04e29 100644 --- a/src/mistapi/__init__.py +++ b/src/mistapi/__init__.py @@ -10,9 +10,11 @@ -------------------------------------------------------------------------------- """ +# isort: skip_file from mistapi.__api_session import APISession as APISession from mistapi import api as api from mistapi import cli as cli +from mistapi import websockets as websockets from mistapi.__pagination import get_all as get_all from mistapi.__pagination import get_next as get_next from mistapi.__version import __author__ as __author__ diff --git a/src/mistapi/websockets/__init__.py b/src/mistapi/websockets/__init__.py new file mode 100644 index 0000000..6ca5838 --- /dev/null +++ b/src/mistapi/websockets/__init__.py @@ -0,0 +1,24 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel classes for real-time Mist API streaming. + +Usage example:: + + import mistapi + session = mistapi.APISession(...) + session.login() + + ws = mistapi.websockets.sites.SiteDeviceStatsEvents(session, site_id="") + ws.on_message(lambda data: print(data)) + ws.connect() +""" + +from mistapi.websockets import location, orgs, sites diff --git a/src/mistapi/websockets/__ws_client.py b/src/mistapi/websockets/__ws_client.py new file mode 100644 index 0000000..b0f4268 --- /dev/null +++ b/src/mistapi/websockets/__ws_client.py @@ -0,0 +1,182 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +This module provides the _MistWebsocket base class for WebSocket connections +to the Mist API streaming endpoint (wss://{host}/api-ws/v1/stream). +""" + +import json +import queue +import threading +from collections.abc import Callable, Generator +from typing import TYPE_CHECKING + +import websocket + +if TYPE_CHECKING: + from mistapi import APISession + + +class _MistWebsocket: + """ + Base class for Mist API WebSocket channels. + + Connects to wss://{host}/api-ws/v1/stream and subscribes to a channel + by sending {"subscribe": ""} on open. + + Auth is handled automatically: + - API token sessions use an Authorization header. + - Login/password sessions pass the requests Session cookies. + """ + + def __init__(self, mist_session: "APISession", channel: str) -> None: + self._mist_session = mist_session + self._channel = channel + self._ws: websocket.WebSocketApp | None = None + self._thread: threading.Thread | None = None + self._queue: queue.Queue[dict | None] = queue.Queue() + self._on_message_cb: Callable[[dict], None] | None = None + self._on_error_cb: Callable[[Exception], None] | None = None + self._on_open_cb: Callable[[], None] | None = None + self._on_close_cb: Callable[[int, str], None] | None = None + + # ------------------------------------------------------------------ + # Auth / URL helpers + + def _build_ws_url(self) -> str: + return f"wss://{self._mist_session._cloud_uri.replace('api.', 'api-ws.')}/api-ws/v1/stream" + + def _get_headers(self) -> dict: + if self._mist_session._apitoken: + token = self._mist_session._apitoken[self._mist_session._apitoken_index] + return {"Authorization": f"Token {token}"} + return {} + + def _get_cookie(self) -> str | None: + cookies = self._mist_session._session.cookies + if cookies: + pairs = "; ".join(f"{c.name}={c.value}" for c in cookies) + return pairs if pairs else None + return None + + # ------------------------------------------------------------------ + # Callback registration + + def on_message(self, callback: Callable[[dict], None]) -> None: + """Register a callback invoked for every incoming message.""" + self._on_message_cb = callback + + def on_error(self, callback: Callable[[Exception], None]) -> None: + """Register a callback invoked on WebSocket errors.""" + self._on_error_cb = callback + + def on_open(self, callback: Callable[[], None]) -> None: + """Register a callback invoked when the connection is established.""" + self._on_open_cb = callback + + def on_close(self, callback: Callable[[int, str], None]) -> None: + """Register a callback invoked when the connection closes.""" + self._on_close_cb = callback + + # ------------------------------------------------------------------ + # Internal WebSocketApp handlers + + def _handle_open(self, ws: websocket.WebSocketApp) -> None: + ws.send(json.dumps({"subscribe": self._channel})) + if self._on_open_cb: + self._on_open_cb() + + def _handle_message(self, ws: websocket.WebSocketApp, message: str) -> None: + try: + data = json.loads(message) + except json.JSONDecodeError: + data = {"raw": message} + self._queue.put(data) + if self._on_message_cb: + self._on_message_cb(data) + + def _handle_error(self, ws: websocket.WebSocketApp, error: Exception) -> None: + if self._on_error_cb: + self._on_error_cb(error) + + def _handle_close( + self, + ws: websocket.WebSocketApp, + close_status_code: int, + close_msg: str, + ) -> None: + self._queue.put(None) # Signals receive() generator to stop + if self._on_close_cb: + self._on_close_cb(close_status_code, close_msg) + + # ------------------------------------------------------------------ + # Lifecycle + + def connect(self, run_in_background: bool = True) -> None: + """ + Open the WebSocket connection and subscribe to the channel. + + PARAMS + ----------- + run_in_background : bool, default True + If True, runs the WebSocket loop in a daemon thread (non-blocking). + If False, blocks the calling thread until disconnected. + """ + self._ws = websocket.WebSocketApp( + self._build_ws_url(), + header=self._get_headers(), + cookie=self._get_cookie(), + on_open=self._handle_open, + on_message=self._handle_message, + on_error=self._handle_error, + on_close=self._handle_close, + ) + if run_in_background: + self._thread = threading.Thread(target=self._run_forever_safe, daemon=True) + self._thread.start() + else: + self._run_forever_safe() + + def _run_forever_safe(self) -> None: + if self._ws: + try: + self._ws.run_forever(ping_interval=30, ping_timeout=10) + except Exception as exc: + self._handle_error(self._ws, exc) + self._handle_close(self._ws, -1, str(exc)) + + def disconnect(self) -> None: + """Close the WebSocket connection.""" + if self._ws: + self._ws.close() + + def receive(self) -> Generator[dict, None, None]: + """ + Blocking generator that yields each incoming message as a dict. + + Exits cleanly when the connection closes (disconnect() is called or + the server closes the connection). + + Intended for use after connect(run_in_background=True). + """ + while True: + item = self._queue.get() + if item is None: + break + yield item + + # ------------------------------------------------------------------ + # Context manager + + def __enter__(self) -> "_MistWebsocket": + return self + + def __exit__(self, *args) -> None: + self.disconnect() diff --git a/src/mistapi/websockets/location/__init__.py b/src/mistapi/websockets/location/__init__.py new file mode 100644 index 0000000..fed0c6a --- /dev/null +++ b/src/mistapi/websockets/location/__init__.py @@ -0,0 +1,21 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi.websockets.location.ble_assets import LocationBleAssetsEvents +from mistapi.websockets.location.clients_connected import LocationConnectedClientsEvents +from mistapi.websockets.location.clients_sdk import LocationSdkClientsEvents +from mistapi.websockets.location.clients_unconnected import ( + LocationUnconnectedClientsEvents, +) +from mistapi.websockets.location.discovered_ble_assets import ( + LocationDiscoveredBleAssetsEvents, +) diff --git a/src/mistapi/websockets/location/ble_assets.py b/src/mistapi/websockets/location/ble_assets.py new file mode 100644 index 0000000..fe3a668 --- /dev/null +++ b/src/mistapi/websockets/location/ble_assets.py @@ -0,0 +1,60 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for BLE assets location events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class LocationBleAssetsEvents(_MistWebsocket): + """WebSocket stream for location BLE assets events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/assets`` channel and delivers + real-time BLE assets events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : str + UUID of the map to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + super().__init__( + mist_session, channel=f"/sites/{site_id}/stats/maps/{map_id}/assets" + ) diff --git a/src/mistapi/websockets/location/clients_connected.py b/src/mistapi/websockets/location/clients_connected.py new file mode 100644 index 0000000..ed17636 --- /dev/null +++ b/src/mistapi/websockets/location/clients_connected.py @@ -0,0 +1,61 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for connected clients location events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class LocationConnectedClientsEvents(_MistWebsocket): + """WebSocket stream for location connected clients events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/clients`` channel and delivers + real-time connected clients events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : str + UUID of the map to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + super().__init__( + mist_session, + channel=f"/sites/{site_id}/stats/maps/{map_id}/clients", + ) diff --git a/src/mistapi/websockets/location/clients_sdk.py b/src/mistapi/websockets/location/clients_sdk.py new file mode 100644 index 0000000..e490543 --- /dev/null +++ b/src/mistapi/websockets/location/clients_sdk.py @@ -0,0 +1,61 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for SDK Clients location events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class LocationSdkClientsEvents(_MistWebsocket): + """WebSocket stream for location SDK clients events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/sdkclients`` channel and delivers + real-time SDK clients events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : str + UUID of the map to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + super().__init__( + mist_session, + channel=f"/sites/{site_id}/stats/maps/{map_id}/sdkclients", + ) diff --git a/src/mistapi/websockets/location/clients_unconnected.py b/src/mistapi/websockets/location/clients_unconnected.py new file mode 100644 index 0000000..2c48f35 --- /dev/null +++ b/src/mistapi/websockets/location/clients_unconnected.py @@ -0,0 +1,61 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for unconnected clients location events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class LocationUnconnectedClientsEvents(_MistWebsocket): + """WebSocket stream for location unconnected clients events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/unconnected_clients`` channel and delivers + real-time unconnected clients events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : str + UUID of the map to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + super().__init__( + mist_session, + channel=f"/sites/{site_id}/stats/maps/{map_id}/unconnected_clients", + ) diff --git a/src/mistapi/websockets/location/discovered_ble_assets.py b/src/mistapi/websockets/location/discovered_ble_assets.py new file mode 100644 index 0000000..97cf510 --- /dev/null +++ b/src/mistapi/websockets/location/discovered_ble_assets.py @@ -0,0 +1,61 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for discovered BLE assets location events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class LocationDiscoveredBleAssetsEvents(_MistWebsocket): + """WebSocket stream for location discovered BLE assets events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/discovered_assets`` channel and delivers + real-time discovered BLE assets events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : str + UUID of the map to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + super().__init__( + mist_session, + channel=f"/sites/{site_id}/stats/maps/{map_id}/discovered_assets", + ) diff --git a/src/mistapi/websockets/orgs/__init__.py b/src/mistapi/websockets/orgs/__init__.py new file mode 100644 index 0000000..3f3353e --- /dev/null +++ b/src/mistapi/websockets/orgs/__init__.py @@ -0,0 +1,15 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi.websockets.orgs.insights import OrgInsightsEvents +from mistapi.websockets.orgs.mxedges_stats import OrgMxEdgesStatsEvents +from mistapi.websockets.orgs.mxedges_upgrades import OrgMxEdgesUpgradesEvents diff --git a/src/mistapi/websockets/orgs/insights.py b/src/mistapi/websockets/orgs/insights.py new file mode 100644 index 0000000..9153678 --- /dev/null +++ b/src/mistapi/websockets/orgs/insights.py @@ -0,0 +1,56 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for organization insights events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class OrgInsightsEvents(_MistWebsocket): + """WebSocket stream for organization insights events. + + Subscribes to the ``orgs/{org_id}/insights/summary`` channel and delivers + real-time insights events for the given organization. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + org_id : str + UUID of the organization to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = OrgInsightsEvents(session, org_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = OrgInsightsEvents(session, org_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with OrgInsightsEvents(session, org_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, org_id: str) -> None: + super().__init__(mist_session, channel=f"/orgs/{org_id}/insights/summary") diff --git a/src/mistapi/websockets/orgs/mxedges_stats.py b/src/mistapi/websockets/orgs/mxedges_stats.py new file mode 100644 index 0000000..4b20fbf --- /dev/null +++ b/src/mistapi/websockets/orgs/mxedges_stats.py @@ -0,0 +1,56 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for site MX edges stats events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class OrgMxEdgesStatsEvents(_MistWebsocket): + """WebSocket stream for organization MX edges stats events. + + Subscribes to the ``orgs/{org_id}/stats/mxedges`` channel and delivers + real-time MX edges stats events for the given organization. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + org_id : str + UUID of the organization to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = OrgMxEdgesStatsEvents(session, org_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = OrgMxEdgesStatsEvents(session, org_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with OrgMxEdgesStatsEvents(session, org_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, org_id: str) -> None: + super().__init__(mist_session, channel=f"/orgs/{org_id}/stats/mxedges") diff --git a/src/mistapi/websockets/orgs/mxedges_upgrades.py b/src/mistapi/websockets/orgs/mxedges_upgrades.py new file mode 100644 index 0000000..d2882ab --- /dev/null +++ b/src/mistapi/websockets/orgs/mxedges_upgrades.py @@ -0,0 +1,56 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for org MX edges upgrades events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class OrgMxEdgesUpgradesEvents(_MistWebsocket): + """WebSocket stream for org MX edges upgrades events. + + Subscribes to the ``orgs/{org_id}/mxedges`` channel and delivers + real-time MX edges upgrades events for the given org. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + org_id : str + UUID of the org to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = OrgMxEdgesUpgradesEvents(session, org_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = OrgMxEdgesUpgradesEvents(session, org_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with OrgMxEdgesUpgradesEvents(session, org_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, org_id: str) -> None: + super().__init__(mist_session, channel=f"/orgs/{org_id}/mxedges") diff --git a/src/mistapi/websockets/sites/__init__.py b/src/mistapi/websockets/sites/__init__.py new file mode 100644 index 0000000..369357e --- /dev/null +++ b/src/mistapi/websockets/sites/__init__.py @@ -0,0 +1,19 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi.websockets.sites.clients_stats import SiteClientsStatsEvents +from mistapi.websockets.sites.devices_cmd import SiteDeviceCmdEvents +from mistapi.websockets.sites.devices_stats import SiteDeviceStatsEvents +from mistapi.websockets.sites.devices_upgrades import SiteDeviceUpgradesEvents + +# from mistapi.websockets.sites.mxedges_stats import SiteMxEdgesStatsEvents +from mistapi.websockets.sites.pcap import SitePcapEvents diff --git a/src/mistapi/websockets/sites/clients_stats.py b/src/mistapi/websockets/sites/clients_stats.py new file mode 100644 index 0000000..bdfb5f8 --- /dev/null +++ b/src/mistapi/websockets/sites/clients_stats.py @@ -0,0 +1,56 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for site clients stats events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class SiteClientsStatsEvents(_MistWebsocket): + """WebSocket stream for site clients stats events. + + Subscribes to the ``sites/{site_id}/stats/clients`` channel and delivers + real-time clients stats events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SiteClientsStatsEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SiteClientsStatsEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SiteClientsStatsEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/stats/clients") diff --git a/src/mistapi/websockets/sites/devices_cmd.py b/src/mistapi/websockets/sites/devices_cmd.py new file mode 100644 index 0000000..8febf4d --- /dev/null +++ b/src/mistapi/websockets/sites/devices_cmd.py @@ -0,0 +1,60 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for site device command events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class SiteDeviceCmdEvents(_MistWebsocket): + """WebSocket stream for site device command events. + + Subscribes to the ``sites/{site_id}/devices/{device_id}/cmd`` channel and delivers + real-time device command events for the given site and device. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + device_id : str + UUID of the device to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str, device_id: str) -> None: + super().__init__( + mist_session, channel=f"/sites/{site_id}/devices/{device_id}/cmd" + ) diff --git a/src/mistapi/websockets/sites/devices_stats.py b/src/mistapi/websockets/sites/devices_stats.py new file mode 100644 index 0000000..a9edca4 --- /dev/null +++ b/src/mistapi/websockets/sites/devices_stats.py @@ -0,0 +1,56 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for site device stats events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class SiteDeviceStatsEvents(_MistWebsocket): + """WebSocket stream for site device stats events. + + Subscribes to the ``sites/{site_id}/stats/devices`` channel and delivers + real-time device stats events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SiteDeviceStatsEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SiteDeviceStatsEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SiteDeviceStatsEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/stats/devices") diff --git a/src/mistapi/websockets/sites/devices_upgrades.py b/src/mistapi/websockets/sites/devices_upgrades.py new file mode 100644 index 0000000..29f498c --- /dev/null +++ b/src/mistapi/websockets/sites/devices_upgrades.py @@ -0,0 +1,56 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for site device upgrades events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class SiteDeviceUpgradesEvents(_MistWebsocket): + """WebSocket stream for site device upgrades events. + + Subscribes to the ``sites/{site_id}/devices`` channel and delivers + real-time device upgrades events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SiteDeviceUpgradesEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SiteDeviceUpgradesEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SiteDeviceUpgradesEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/devices") diff --git a/src/mistapi/websockets/sites/mxedges_stats.py b/src/mistapi/websockets/sites/mxedges_stats.py new file mode 100644 index 0000000..dce7984 --- /dev/null +++ b/src/mistapi/websockets/sites/mxedges_stats.py @@ -0,0 +1,56 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for site MX edges stats events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class SiteMxEdgesStatsEvents(_MistWebsocket): + """WebSocket stream for site MX edges stats events. + + Subscribes to the ``sites/{site_id}/stats/mxedges`` channel and delivers + real-time MX edges stats events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SiteMxEdgesStatsEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SiteMxEdgesStatsEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SiteMxEdgesStatsEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/stats/mxedges") diff --git a/src/mistapi/websockets/sites/pcap.py b/src/mistapi/websockets/sites/pcap.py new file mode 100644 index 0000000..cc60be5 --- /dev/null +++ b/src/mistapi/websockets/sites/pcap.py @@ -0,0 +1,56 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for site PCAP events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class SitePcapEvents(_MistWebsocket): + """WebSocket stream for site PCAP events. + + Subscribes to the ``sites/{site_id}/pcap`` channel and delivers + real-time PCAP events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SitePcapEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SitePcapEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SitePcapEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/pcap") diff --git a/uv.lock b/uv.lock index ca7641b..2a738e2 100644 --- a/uv.lock +++ b/uv.lock @@ -546,6 +546,7 @@ dependencies = [ { name = "python-dotenv" }, { name = "requests" }, { name = "tabulate" }, + { name = "websocket-client" }, ] [package.dev-dependencies] @@ -569,6 +570,7 @@ requires-dist = [ { name = "python-dotenv", specifier = ">=1.1.0" }, { name = "requests", specifier = ">=2.32.3" }, { name = "tabulate", specifier = ">=0.9.0" }, + { name = "websocket-client", specifier = ">=1.8.0" }, ] [package.metadata.requires-dev] @@ -1006,6 +1008,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6d/b9/4095b668ea3678bf6a0af005527f39de12fb026516fb3df17495a733b7f8/urllib3-2.6.2-py3-none-any.whl", hash = "sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd", size = 131182, upload-time = "2025-12-11T15:56:38.584Z" }, ] +[[package]] +name = "websocket-client" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/41/aa4bf9664e4cda14c3b39865b12251e8e7d239f4cd0e3cc1b6c2ccde25c1/websocket_client-1.9.0.tar.gz", hash = "sha256:9e813624b6eb619999a97dc7958469217c3176312b3a16a4bd1bc7e08a46ec98", size = 70576, upload-time = "2025-10-07T21:16:36.495Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/db/b10e48aa8fff7407e67470363eac595018441cf32d5e1001567a7aeba5d2/websocket_client-1.9.0-py3-none-any.whl", hash = "sha256:af248a825037ef591efbf6ed20cc5faa03d3b47b9e5a2230a529eeee1c1fc3ef", size = 82616, upload-time = "2025-10-07T21:16:34.951Z" }, +] + [[package]] name = "zipp" version = "3.23.0" From 435115989f73b0e4f903e0d41d095a8585421893 Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Tue, 24 Feb 2026 12:17:13 +0100 Subject: [PATCH 02/16] reorganize websocket files --- CLAUDE.md | 7 + src/mistapi/websockets/location.py | 244 ++++++++++++++++ src/mistapi/websockets/location/__init__.py | 21 -- src/mistapi/websockets/location/ble_assets.py | 60 ---- .../websockets/location/clients_connected.py | 61 ---- .../websockets/location/clients_sdk.py | 61 ---- .../location/clients_unconnected.py | 61 ---- .../location/discovered_ble_assets.py | 61 ---- src/mistapi/websockets/orgs.py | 138 +++++++++ src/mistapi/websockets/orgs/__init__.py | 15 - src/mistapi/websockets/orgs/insights.py | 56 ---- src/mistapi/websockets/orgs/mxedges_stats.py | 56 ---- .../websockets/orgs/mxedges_upgrades.py | 56 ---- src/mistapi/websockets/sites.py | 271 ++++++++++++++++++ src/mistapi/websockets/sites/__init__.py | 19 -- src/mistapi/websockets/sites/clients_stats.py | 56 ---- src/mistapi/websockets/sites/devices_cmd.py | 60 ---- src/mistapi/websockets/sites/devices_stats.py | 56 ---- .../websockets/sites/devices_upgrades.py | 56 ---- src/mistapi/websockets/sites/mxedges_stats.py | 56 ---- src/mistapi/websockets/sites/pcap.py | 56 ---- 21 files changed, 660 insertions(+), 867 deletions(-) create mode 100644 CLAUDE.md create mode 100644 src/mistapi/websockets/location.py delete mode 100644 src/mistapi/websockets/location/__init__.py delete mode 100644 src/mistapi/websockets/location/ble_assets.py delete mode 100644 src/mistapi/websockets/location/clients_connected.py delete mode 100644 src/mistapi/websockets/location/clients_sdk.py delete mode 100644 src/mistapi/websockets/location/clients_unconnected.py delete mode 100644 src/mistapi/websockets/location/discovered_ble_assets.py create mode 100644 src/mistapi/websockets/orgs.py delete mode 100644 src/mistapi/websockets/orgs/__init__.py delete mode 100644 src/mistapi/websockets/orgs/insights.py delete mode 100644 src/mistapi/websockets/orgs/mxedges_stats.py delete mode 100644 src/mistapi/websockets/orgs/mxedges_upgrades.py create mode 100644 src/mistapi/websockets/sites.py delete mode 100644 src/mistapi/websockets/sites/__init__.py delete mode 100644 src/mistapi/websockets/sites/clients_stats.py delete mode 100644 src/mistapi/websockets/sites/devices_cmd.py delete mode 100644 src/mistapi/websockets/sites/devices_stats.py delete mode 100644 src/mistapi/websockets/sites/devices_upgrades.py delete mode 100644 src/mistapi/websockets/sites/mxedges_stats.py delete mode 100644 src/mistapi/websockets/sites/pcap.py diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..94b8344 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,7 @@ +This is the repo for my Mist API Python client, which is a wrapper around the Mist API. It allows you to easily interact with the Mist API and perform various actions such as creating and managing devices, sites, and more. + +The code in src/mistapi/api is automatically generated from the OpenAPI specification provided by Mist. This means that the code is always up to date with the latest version of the API, and you can be confident that it will work correctly with the Mist API. +The code in src/mistapi/api is organized into different modules, each corresponding to a different aspect of the Mist API. For example, there are modules for managing devices, sites, and more. Each module contains functions that correspond to the various endpoints of the Mist API, allowing you to easily perform actions such as creating a new device, retrieving information about a site, and more. + + +The code in src/mistapi/websocket is here to provide a WebSocket client for the Mist API. This allows you to receive real-time updates from the Mist API, such as when a new device is added or when a site is updated. The WebSocket client is built using the popular websocket-client library, and it provides an easy-to-use interface for connecting to the Mist API and receiving updates. \ No newline at end of file diff --git a/src/mistapi/websockets/location.py b/src/mistapi/websockets/location.py new file mode 100644 index 0000000..bc82754 --- /dev/null +++ b/src/mistapi/websockets/location.py @@ -0,0 +1,244 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for Location events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class LocationBleAssetsEvents(_MistWebsocket): + """WebSocket stream for location BLE assets events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/assets`` channel and delivers + real-time BLE assets events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : str + UUID of the map to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + super().__init__( + mist_session, channel=f"/sites/{site_id}/stats/maps/{map_id}/assets" + ) + + +class LocationConnectedClientsEvents(_MistWebsocket): + """WebSocket stream for location connected clients events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/clients`` channel and delivers + real-time connected clients events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : str + UUID of the map to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + super().__init__( + mist_session, + channel=f"/sites/{site_id}/stats/maps/{map_id}/clients", + ) + + +class LocationSdkClientsEvents(_MistWebsocket): + """WebSocket stream for location SDK clients events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/sdkclients`` channel and delivers + real-time SDK clients events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : str + UUID of the map to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + super().__init__( + mist_session, + channel=f"/sites/{site_id}/stats/maps/{map_id}/sdkclients", + ) + + +class LocationUnconnectedClientsEvents(_MistWebsocket): + """WebSocket stream for location unconnected clients events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/unconnected_clients`` channel and delivers + real-time unconnected clients events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : str + UUID of the map to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + super().__init__( + mist_session, + channel=f"/sites/{site_id}/stats/maps/{map_id}/unconnected_clients", + ) + + +class LocationDiscoveredBleAssetsEvents(_MistWebsocket): + """WebSocket stream for location discovered BLE assets events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/discovered_assets`` channel and delivers + real-time discovered BLE assets events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : str + UUID of the map to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + super().__init__( + mist_session, + channel=f"/sites/{site_id}/stats/maps/{map_id}/discovered_assets", + ) diff --git a/src/mistapi/websockets/location/__init__.py b/src/mistapi/websockets/location/__init__.py deleted file mode 100644 index fed0c6a..0000000 --- a/src/mistapi/websockets/location/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -""" - -from mistapi.websockets.location.ble_assets import LocationBleAssetsEvents -from mistapi.websockets.location.clients_connected import LocationConnectedClientsEvents -from mistapi.websockets.location.clients_sdk import LocationSdkClientsEvents -from mistapi.websockets.location.clients_unconnected import ( - LocationUnconnectedClientsEvents, -) -from mistapi.websockets.location.discovered_ble_assets import ( - LocationDiscoveredBleAssetsEvents, -) diff --git a/src/mistapi/websockets/location/ble_assets.py b/src/mistapi/websockets/location/ble_assets.py deleted file mode 100644 index fe3a668..0000000 --- a/src/mistapi/websockets/location/ble_assets.py +++ /dev/null @@ -1,60 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -WebSocket channel for BLE assets location events. -""" - -from mistapi import APISession -from mistapi.websockets.__ws_client import _MistWebsocket - - -class LocationBleAssetsEvents(_MistWebsocket): - """WebSocket stream for location BLE assets events. - - Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/assets`` channel and delivers - real-time BLE assets events for the given location. - - PARAMS - ----------- - mist_session : mistapi.APISession - Authenticated API session. - site_id : str - UUID of the site to stream events from. - map_id : str - UUID of the map to stream events from. - - EXAMPLE - ----------- - Callback style (background thread):: - - ws = LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") - ws.on_message(lambda data: print(data)) - ws.connect() - input("Press Enter to stop") - ws.disconnect() - - Generator style (background thread):: - - ws = LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") - ws.connect(run_in_background=True) - for msg in ws.receive(): - process(msg) - - Context manager:: - - with LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") as ws: - ws.on_message(my_handler) - time.sleep(60) - """ - - def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: - super().__init__( - mist_session, channel=f"/sites/{site_id}/stats/maps/{map_id}/assets" - ) diff --git a/src/mistapi/websockets/location/clients_connected.py b/src/mistapi/websockets/location/clients_connected.py deleted file mode 100644 index ed17636..0000000 --- a/src/mistapi/websockets/location/clients_connected.py +++ /dev/null @@ -1,61 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -WebSocket channel for connected clients location events. -""" - -from mistapi import APISession -from mistapi.websockets.__ws_client import _MistWebsocket - - -class LocationConnectedClientsEvents(_MistWebsocket): - """WebSocket stream for location connected clients events. - - Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/clients`` channel and delivers - real-time connected clients events for the given location. - - PARAMS - ----------- - mist_session : mistapi.APISession - Authenticated API session. - site_id : str - UUID of the site to stream events from. - map_id : str - UUID of the map to stream events from. - - EXAMPLE - ----------- - Callback style (background thread):: - - ws = LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") - ws.on_message(lambda data: print(data)) - ws.connect() - input("Press Enter to stop") - ws.disconnect() - - Generator style (background thread):: - - ws = LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") - ws.connect(run_in_background=True) - for msg in ws.receive(): - process(msg) - - Context manager:: - - with LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") as ws: - ws.on_message(my_handler) - time.sleep(60) - """ - - def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: - super().__init__( - mist_session, - channel=f"/sites/{site_id}/stats/maps/{map_id}/clients", - ) diff --git a/src/mistapi/websockets/location/clients_sdk.py b/src/mistapi/websockets/location/clients_sdk.py deleted file mode 100644 index e490543..0000000 --- a/src/mistapi/websockets/location/clients_sdk.py +++ /dev/null @@ -1,61 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -WebSocket channel for SDK Clients location events. -""" - -from mistapi import APISession -from mistapi.websockets.__ws_client import _MistWebsocket - - -class LocationSdkClientsEvents(_MistWebsocket): - """WebSocket stream for location SDK clients events. - - Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/sdkclients`` channel and delivers - real-time SDK clients events for the given location. - - PARAMS - ----------- - mist_session : mistapi.APISession - Authenticated API session. - site_id : str - UUID of the site to stream events from. - map_id : str - UUID of the map to stream events from. - - EXAMPLE - ----------- - Callback style (background thread):: - - ws = LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") - ws.on_message(lambda data: print(data)) - ws.connect() - input("Press Enter to stop") - ws.disconnect() - - Generator style (background thread):: - - ws = LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") - ws.connect(run_in_background=True) - for msg in ws.receive(): - process(msg) - - Context manager:: - - with LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") as ws: - ws.on_message(my_handler) - time.sleep(60) - """ - - def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: - super().__init__( - mist_session, - channel=f"/sites/{site_id}/stats/maps/{map_id}/sdkclients", - ) diff --git a/src/mistapi/websockets/location/clients_unconnected.py b/src/mistapi/websockets/location/clients_unconnected.py deleted file mode 100644 index 2c48f35..0000000 --- a/src/mistapi/websockets/location/clients_unconnected.py +++ /dev/null @@ -1,61 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -WebSocket channel for unconnected clients location events. -""" - -from mistapi import APISession -from mistapi.websockets.__ws_client import _MistWebsocket - - -class LocationUnconnectedClientsEvents(_MistWebsocket): - """WebSocket stream for location unconnected clients events. - - Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/unconnected_clients`` channel and delivers - real-time unconnected clients events for the given location. - - PARAMS - ----------- - mist_session : mistapi.APISession - Authenticated API session. - site_id : str - UUID of the site to stream events from. - map_id : str - UUID of the map to stream events from. - - EXAMPLE - ----------- - Callback style (background thread):: - - ws = LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") - ws.on_message(lambda data: print(data)) - ws.connect() - input("Press Enter to stop") - ws.disconnect() - - Generator style (background thread):: - - ws = LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") - ws.connect(run_in_background=True) - for msg in ws.receive(): - process(msg) - - Context manager:: - - with LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") as ws: - ws.on_message(my_handler) - time.sleep(60) - """ - - def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: - super().__init__( - mist_session, - channel=f"/sites/{site_id}/stats/maps/{map_id}/unconnected_clients", - ) diff --git a/src/mistapi/websockets/location/discovered_ble_assets.py b/src/mistapi/websockets/location/discovered_ble_assets.py deleted file mode 100644 index 97cf510..0000000 --- a/src/mistapi/websockets/location/discovered_ble_assets.py +++ /dev/null @@ -1,61 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -WebSocket channel for discovered BLE assets location events. -""" - -from mistapi import APISession -from mistapi.websockets.__ws_client import _MistWebsocket - - -class LocationDiscoveredBleAssetsEvents(_MistWebsocket): - """WebSocket stream for location discovered BLE assets events. - - Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/discovered_assets`` channel and delivers - real-time discovered BLE assets events for the given location. - - PARAMS - ----------- - mist_session : mistapi.APISession - Authenticated API session. - site_id : str - UUID of the site to stream events from. - map_id : str - UUID of the map to stream events from. - - EXAMPLE - ----------- - Callback style (background thread):: - - ws = LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") - ws.on_message(lambda data: print(data)) - ws.connect() - input("Press Enter to stop") - ws.disconnect() - - Generator style (background thread):: - - ws = LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") - ws.connect(run_in_background=True) - for msg in ws.receive(): - process(msg) - - Context manager:: - - with LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") as ws: - ws.on_message(my_handler) - time.sleep(60) - """ - - def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: - super().__init__( - mist_session, - channel=f"/sites/{site_id}/stats/maps/{map_id}/discovered_assets", - ) diff --git a/src/mistapi/websockets/orgs.py b/src/mistapi/websockets/orgs.py new file mode 100644 index 0000000..cc84691 --- /dev/null +++ b/src/mistapi/websockets/orgs.py @@ -0,0 +1,138 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for Org events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class OrgInsightsEvents(_MistWebsocket): + """WebSocket stream for organization insights events. + + Subscribes to the ``orgs/{org_id}/insights/summary`` channel and delivers + real-time insights events for the given organization. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + org_id : str + UUID of the organization to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = OrgInsightsEvents(session, org_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = OrgInsightsEvents(session, org_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with OrgInsightsEvents(session, org_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, org_id: str) -> None: + super().__init__(mist_session, channel=f"/orgs/{org_id}/insights/summary") + + +class OrgMxEdgesStatsEvents(_MistWebsocket): + """WebSocket stream for organization MX edges stats events. + + Subscribes to the ``orgs/{org_id}/stats/mxedges`` channel and delivers + real-time MX edges stats events for the given organization. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + org_id : str + UUID of the organization to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = OrgMxEdgesStatsEvents(session, org_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = OrgMxEdgesStatsEvents(session, org_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with OrgMxEdgesStatsEvents(session, org_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, org_id: str) -> None: + super().__init__(mist_session, channel=f"/orgs/{org_id}/stats/mxedges") + + +class OrgMxEdgesUpgradesEvents(_MistWebsocket): + """WebSocket stream for org MX edges upgrades events. + + Subscribes to the ``orgs/{org_id}/mxedges`` channel and delivers + real-time MX edges upgrades events for the given org. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + org_id : str + UUID of the org to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = OrgMxEdgesUpgradesEvents(session, org_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = OrgMxEdgesUpgradesEvents(session, org_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with OrgMxEdgesUpgradesEvents(session, org_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, org_id: str) -> None: + super().__init__(mist_session, channel=f"/orgs/{org_id}/mxedges") diff --git a/src/mistapi/websockets/orgs/__init__.py b/src/mistapi/websockets/orgs/__init__.py deleted file mode 100644 index 3f3353e..0000000 --- a/src/mistapi/websockets/orgs/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -""" - -from mistapi.websockets.orgs.insights import OrgInsightsEvents -from mistapi.websockets.orgs.mxedges_stats import OrgMxEdgesStatsEvents -from mistapi.websockets.orgs.mxedges_upgrades import OrgMxEdgesUpgradesEvents diff --git a/src/mistapi/websockets/orgs/insights.py b/src/mistapi/websockets/orgs/insights.py deleted file mode 100644 index 9153678..0000000 --- a/src/mistapi/websockets/orgs/insights.py +++ /dev/null @@ -1,56 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -WebSocket channel for organization insights events. -""" - -from mistapi import APISession -from mistapi.websockets.__ws_client import _MistWebsocket - - -class OrgInsightsEvents(_MistWebsocket): - """WebSocket stream for organization insights events. - - Subscribes to the ``orgs/{org_id}/insights/summary`` channel and delivers - real-time insights events for the given organization. - - PARAMS - ----------- - mist_session : mistapi.APISession - Authenticated API session. - org_id : str - UUID of the organization to stream events from. - - EXAMPLE - ----------- - Callback style (background thread):: - - ws = OrgInsightsEvents(session, org_id="abc123") - ws.on_message(lambda data: print(data)) - ws.connect() - input("Press Enter to stop") - ws.disconnect() - - Generator style:: - - ws = OrgInsightsEvents(session, org_id="abc123") - ws.connect(run_in_background=True) - for msg in ws.receive(): - process(msg) - - Context manager:: - - with OrgInsightsEvents(session, org_id="abc123") as ws: - ws.on_message(my_handler) - time.sleep(60) - """ - - def __init__(self, mist_session: APISession, org_id: str) -> None: - super().__init__(mist_session, channel=f"/orgs/{org_id}/insights/summary") diff --git a/src/mistapi/websockets/orgs/mxedges_stats.py b/src/mistapi/websockets/orgs/mxedges_stats.py deleted file mode 100644 index 4b20fbf..0000000 --- a/src/mistapi/websockets/orgs/mxedges_stats.py +++ /dev/null @@ -1,56 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -WebSocket channel for site MX edges stats events. -""" - -from mistapi import APISession -from mistapi.websockets.__ws_client import _MistWebsocket - - -class OrgMxEdgesStatsEvents(_MistWebsocket): - """WebSocket stream for organization MX edges stats events. - - Subscribes to the ``orgs/{org_id}/stats/mxedges`` channel and delivers - real-time MX edges stats events for the given organization. - - PARAMS - ----------- - mist_session : mistapi.APISession - Authenticated API session. - org_id : str - UUID of the organization to stream events from. - - EXAMPLE - ----------- - Callback style (background thread):: - - ws = OrgMxEdgesStatsEvents(session, org_id="abc123") - ws.on_message(lambda data: print(data)) - ws.connect() - input("Press Enter to stop") - ws.disconnect() - - Generator style:: - - ws = OrgMxEdgesStatsEvents(session, org_id="abc123") - ws.connect(run_in_background=True) - for msg in ws.receive(): - process(msg) - - Context manager:: - - with OrgMxEdgesStatsEvents(session, org_id="abc123") as ws: - ws.on_message(my_handler) - time.sleep(60) - """ - - def __init__(self, mist_session: APISession, org_id: str) -> None: - super().__init__(mist_session, channel=f"/orgs/{org_id}/stats/mxedges") diff --git a/src/mistapi/websockets/orgs/mxedges_upgrades.py b/src/mistapi/websockets/orgs/mxedges_upgrades.py deleted file mode 100644 index d2882ab..0000000 --- a/src/mistapi/websockets/orgs/mxedges_upgrades.py +++ /dev/null @@ -1,56 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -WebSocket channel for org MX edges upgrades events. -""" - -from mistapi import APISession -from mistapi.websockets.__ws_client import _MistWebsocket - - -class OrgMxEdgesUpgradesEvents(_MistWebsocket): - """WebSocket stream for org MX edges upgrades events. - - Subscribes to the ``orgs/{org_id}/mxedges`` channel and delivers - real-time MX edges upgrades events for the given org. - - PARAMS - ----------- - mist_session : mistapi.APISession - Authenticated API session. - org_id : str - UUID of the org to stream events from. - - EXAMPLE - ----------- - Callback style (background thread):: - - ws = OrgMxEdgesUpgradesEvents(session, org_id="abc123") - ws.on_message(lambda data: print(data)) - ws.connect() - input("Press Enter to stop") - ws.disconnect() - - Generator style:: - - ws = OrgMxEdgesUpgradesEvents(session, org_id="abc123") - ws.connect(run_in_background=True) - for msg in ws.receive(): - process(msg) - - Context manager:: - - with OrgMxEdgesUpgradesEvents(session, org_id="abc123") as ws: - ws.on_message(my_handler) - time.sleep(60) - """ - - def __init__(self, mist_session: APISession, org_id: str) -> None: - super().__init__(mist_session, channel=f"/orgs/{org_id}/mxedges") diff --git a/src/mistapi/websockets/sites.py b/src/mistapi/websockets/sites.py new file mode 100644 index 0000000..8138f11 --- /dev/null +++ b/src/mistapi/websockets/sites.py @@ -0,0 +1,271 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for Site events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class SiteClientsStatsEvents(_MistWebsocket): + """WebSocket stream for site clients stats events. + + Subscribes to the ``sites/{site_id}/stats/clients`` channel and delivers + real-time clients stats events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SiteClientsStatsEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SiteClientsStatsEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SiteClientsStatsEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/stats/clients") + + +class SiteDeviceCmdEvents(_MistWebsocket): + """WebSocket stream for site device command events. + + Subscribes to the ``sites/{site_id}/devices/{device_id}/cmd`` channel and delivers + real-time device command events for the given site and device. + + Device commands functions: + mistapi.api.v1.sites.devices.arpFromDevice + mistapi.api.v1.sites.devices.bounceDevicePort + mistapi.api.v1.sites.devices.cableTestFromSwitch + mistapi.api.v1.sites.devices.clearSiteDeviceMacTable + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + device_id : str + UUID of the device to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str, device_id: str) -> None: + super().__init__( + mist_session, channel=f"/sites/{site_id}/devices/{device_id}/cmd" + ) + + +class SiteDeviceStatsEvents(_MistWebsocket): + """WebSocket stream for site device stats events. + + Subscribes to the ``sites/{site_id}/stats/devices`` channel and delivers + real-time device stats events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SiteDeviceStatsEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SiteDeviceStatsEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SiteDeviceStatsEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/stats/devices") + + +class SiteDeviceUpgradesEvents(_MistWebsocket): + """WebSocket stream for site device upgrades events. + + Subscribes to the ``sites/{site_id}/devices`` channel and delivers + real-time device upgrades events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SiteDeviceUpgradesEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SiteDeviceUpgradesEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SiteDeviceUpgradesEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/devices") + + +class SiteMxEdgesStatsEvents(_MistWebsocket): + """WebSocket stream for site MX edges stats events. + + Subscribes to the ``sites/{site_id}/stats/mxedges`` channel and delivers + real-time MX edges stats events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SiteMxEdgesStatsEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SiteMxEdgesStatsEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SiteMxEdgesStatsEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/stats/mxedges") + + +class SitePcapEvents(_MistWebsocket): + """WebSocket stream for site PCAP events. + + Subscribes to the ``sites/{site_id}/pcap`` channel and delivers + real-time PCAP events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SitePcapEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SitePcapEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SitePcapEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__(self, mist_session: APISession, site_id: str) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/pcap") diff --git a/src/mistapi/websockets/sites/__init__.py b/src/mistapi/websockets/sites/__init__.py deleted file mode 100644 index 369357e..0000000 --- a/src/mistapi/websockets/sites/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -""" - -from mistapi.websockets.sites.clients_stats import SiteClientsStatsEvents -from mistapi.websockets.sites.devices_cmd import SiteDeviceCmdEvents -from mistapi.websockets.sites.devices_stats import SiteDeviceStatsEvents -from mistapi.websockets.sites.devices_upgrades import SiteDeviceUpgradesEvents - -# from mistapi.websockets.sites.mxedges_stats import SiteMxEdgesStatsEvents -from mistapi.websockets.sites.pcap import SitePcapEvents diff --git a/src/mistapi/websockets/sites/clients_stats.py b/src/mistapi/websockets/sites/clients_stats.py deleted file mode 100644 index bdfb5f8..0000000 --- a/src/mistapi/websockets/sites/clients_stats.py +++ /dev/null @@ -1,56 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -WebSocket channel for site clients stats events. -""" - -from mistapi import APISession -from mistapi.websockets.__ws_client import _MistWebsocket - - -class SiteClientsStatsEvents(_MistWebsocket): - """WebSocket stream for site clients stats events. - - Subscribes to the ``sites/{site_id}/stats/clients`` channel and delivers - real-time clients stats events for the given site. - - PARAMS - ----------- - mist_session : mistapi.APISession - Authenticated API session. - site_id : str - UUID of the site to stream events from. - - EXAMPLE - ----------- - Callback style (background thread):: - - ws = SiteClientsStatsEvents(session, site_id="abc123") - ws.on_message(lambda data: print(data)) - ws.connect() - input("Press Enter to stop") - ws.disconnect() - - Generator style:: - - ws = SiteClientsStatsEvents(session, site_id="abc123") - ws.connect(run_in_background=True) - for msg in ws.receive(): - process(msg) - - Context manager:: - - with SiteClientsStatsEvents(session, site_id="abc123") as ws: - ws.on_message(my_handler) - time.sleep(60) - """ - - def __init__(self, mist_session: APISession, site_id: str) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/stats/clients") diff --git a/src/mistapi/websockets/sites/devices_cmd.py b/src/mistapi/websockets/sites/devices_cmd.py deleted file mode 100644 index 8febf4d..0000000 --- a/src/mistapi/websockets/sites/devices_cmd.py +++ /dev/null @@ -1,60 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -WebSocket channel for site device command events. -""" - -from mistapi import APISession -from mistapi.websockets.__ws_client import _MistWebsocket - - -class SiteDeviceCmdEvents(_MistWebsocket): - """WebSocket stream for site device command events. - - Subscribes to the ``sites/{site_id}/devices/{device_id}/cmd`` channel and delivers - real-time device command events for the given site and device. - - PARAMS - ----------- - mist_session : mistapi.APISession - Authenticated API session. - site_id : str - UUID of the site to stream events from. - device_id : str - UUID of the device to stream events from. - - EXAMPLE - ----------- - Callback style (background thread):: - - ws = SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") - ws.on_message(lambda data: print(data)) - ws.connect() - input("Press Enter to stop") - ws.disconnect() - - Generator style:: - - ws = SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") - ws.connect(run_in_background=True) - for msg in ws.receive(): - process(msg) - - Context manager:: - - with SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") as ws: - ws.on_message(my_handler) - time.sleep(60) - """ - - def __init__(self, mist_session: APISession, site_id: str, device_id: str) -> None: - super().__init__( - mist_session, channel=f"/sites/{site_id}/devices/{device_id}/cmd" - ) diff --git a/src/mistapi/websockets/sites/devices_stats.py b/src/mistapi/websockets/sites/devices_stats.py deleted file mode 100644 index a9edca4..0000000 --- a/src/mistapi/websockets/sites/devices_stats.py +++ /dev/null @@ -1,56 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -WebSocket channel for site device stats events. -""" - -from mistapi import APISession -from mistapi.websockets.__ws_client import _MistWebsocket - - -class SiteDeviceStatsEvents(_MistWebsocket): - """WebSocket stream for site device stats events. - - Subscribes to the ``sites/{site_id}/stats/devices`` channel and delivers - real-time device stats events for the given site. - - PARAMS - ----------- - mist_session : mistapi.APISession - Authenticated API session. - site_id : str - UUID of the site to stream events from. - - EXAMPLE - ----------- - Callback style (background thread):: - - ws = SiteDeviceStatsEvents(session, site_id="abc123") - ws.on_message(lambda data: print(data)) - ws.connect() - input("Press Enter to stop") - ws.disconnect() - - Generator style:: - - ws = SiteDeviceStatsEvents(session, site_id="abc123") - ws.connect(run_in_background=True) - for msg in ws.receive(): - process(msg) - - Context manager:: - - with SiteDeviceStatsEvents(session, site_id="abc123") as ws: - ws.on_message(my_handler) - time.sleep(60) - """ - - def __init__(self, mist_session: APISession, site_id: str) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/stats/devices") diff --git a/src/mistapi/websockets/sites/devices_upgrades.py b/src/mistapi/websockets/sites/devices_upgrades.py deleted file mode 100644 index 29f498c..0000000 --- a/src/mistapi/websockets/sites/devices_upgrades.py +++ /dev/null @@ -1,56 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -WebSocket channel for site device upgrades events. -""" - -from mistapi import APISession -from mistapi.websockets.__ws_client import _MistWebsocket - - -class SiteDeviceUpgradesEvents(_MistWebsocket): - """WebSocket stream for site device upgrades events. - - Subscribes to the ``sites/{site_id}/devices`` channel and delivers - real-time device upgrades events for the given site. - - PARAMS - ----------- - mist_session : mistapi.APISession - Authenticated API session. - site_id : str - UUID of the site to stream events from. - - EXAMPLE - ----------- - Callback style (background thread):: - - ws = SiteDeviceUpgradesEvents(session, site_id="abc123") - ws.on_message(lambda data: print(data)) - ws.connect() - input("Press Enter to stop") - ws.disconnect() - - Generator style:: - - ws = SiteDeviceUpgradesEvents(session, site_id="abc123") - ws.connect(run_in_background=True) - for msg in ws.receive(): - process(msg) - - Context manager:: - - with SiteDeviceUpgradesEvents(session, site_id="abc123") as ws: - ws.on_message(my_handler) - time.sleep(60) - """ - - def __init__(self, mist_session: APISession, site_id: str) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/devices") diff --git a/src/mistapi/websockets/sites/mxedges_stats.py b/src/mistapi/websockets/sites/mxedges_stats.py deleted file mode 100644 index dce7984..0000000 --- a/src/mistapi/websockets/sites/mxedges_stats.py +++ /dev/null @@ -1,56 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -WebSocket channel for site MX edges stats events. -""" - -from mistapi import APISession -from mistapi.websockets.__ws_client import _MistWebsocket - - -class SiteMxEdgesStatsEvents(_MistWebsocket): - """WebSocket stream for site MX edges stats events. - - Subscribes to the ``sites/{site_id}/stats/mxedges`` channel and delivers - real-time MX edges stats events for the given site. - - PARAMS - ----------- - mist_session : mistapi.APISession - Authenticated API session. - site_id : str - UUID of the site to stream events from. - - EXAMPLE - ----------- - Callback style (background thread):: - - ws = SiteMxEdgesStatsEvents(session, site_id="abc123") - ws.on_message(lambda data: print(data)) - ws.connect() - input("Press Enter to stop") - ws.disconnect() - - Generator style:: - - ws = SiteMxEdgesStatsEvents(session, site_id="abc123") - ws.connect(run_in_background=True) - for msg in ws.receive(): - process(msg) - - Context manager:: - - with SiteMxEdgesStatsEvents(session, site_id="abc123") as ws: - ws.on_message(my_handler) - time.sleep(60) - """ - - def __init__(self, mist_session: APISession, site_id: str) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/stats/mxedges") diff --git a/src/mistapi/websockets/sites/pcap.py b/src/mistapi/websockets/sites/pcap.py deleted file mode 100644 index cc60be5..0000000 --- a/src/mistapi/websockets/sites/pcap.py +++ /dev/null @@ -1,56 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- -WebSocket channel for site PCAP events. -""" - -from mistapi import APISession -from mistapi.websockets.__ws_client import _MistWebsocket - - -class SitePcapEvents(_MistWebsocket): - """WebSocket stream for site PCAP events. - - Subscribes to the ``sites/{site_id}/pcap`` channel and delivers - real-time PCAP events for the given site. - - PARAMS - ----------- - mist_session : mistapi.APISession - Authenticated API session. - site_id : str - UUID of the site to stream events from. - - EXAMPLE - ----------- - Callback style (background thread):: - - ws = SitePcapEvents(session, site_id="abc123") - ws.on_message(lambda data: print(data)) - ws.connect() - input("Press Enter to stop") - ws.disconnect() - - Generator style:: - - ws = SitePcapEvents(session, site_id="abc123") - ws.connect(run_in_background=True) - for msg in ws.receive(): - process(msg) - - Context manager:: - - with SitePcapEvents(session, site_id="abc123") as ws: - ws.on_message(my_handler) - time.sleep(60) - """ - - def __init__(self, mist_session: APISession, site_id: str) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/pcap") From be3d06d23282181d8eff8ba450c46c5e4f73e82f Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Tue, 24 Feb 2026 12:23:59 +0100 Subject: [PATCH 03/16] add ping_interval and ping_timeout params --- README.md | 40 ++++++++++++++++++++------- src/mistapi/websockets/__ws_client.py | 12 ++++++-- src/mistapi/websockets/location.py | 16 +++++++---- src/mistapi/websockets/orgs.py | 12 ++++---- src/mistapi/websockets/sites.py | 24 ++++++++-------- 5 files changed, 68 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index c36ebc3..8e894fc 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,9 @@ A comprehensive Python package to interact with the Mist Cloud APIs, built from - [Pagination](#pagination-support) - [Examples](#examples) - [WebSocket Streaming](#websocket-streaming) - - [Available Channels](#available-channels) + - [Connection Parameters](#connection-parameters) - [Callbacks](#callbacks) + - [Available Channels](#available-channels) - [Usage Patterns](#usage-patterns) - [Development](#development-and-testing) - [Contributing](#contributing) @@ -482,6 +483,34 @@ events = mistapi.api.v1.orgs.clients.searchOrgClientsEvents( The package provides a WebSocket client for real-time event streaming from the Mist API (`wss://{host}/api-ws/v1/stream`). Authentication is handled automatically using the same session credentials (API token or login/password). +### Connection Parameters + +All channel classes accept the following optional keyword arguments to control the WebSocket keep-alive behaviour: + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `ping_interval` | `int` | `30` | Seconds between automatic ping frames. Set to `0` to disable pings. | +| `ping_timeout` | `int` | `10` | Seconds to wait for a pong response before treating the connection as dead. | + +```python +ws = mistapi.websockets.sites.SiteDeviceStatsEvents( + apisession, + site_id="", + ping_interval=60, # ping every 60 s + ping_timeout=20, # wait up to 20 s for pong +) +ws.connect() +``` + +### Callbacks + +| Method | Signature | Description | +|--------|-----------|-------------| +| `ws.on_open(cb)` | `cb()` | Called when the connection is established | +| `ws.on_message(cb)` | `cb(data: dict)` | Called for every incoming message | +| `ws.on_error(cb)` | `cb(error: Exception)` | Called on WebSocket errors | +| `ws.on_close(cb)` | `cb(status_code: int, msg: str)` | Called when the connection closes | + ### Available Channels #### Organization Channels @@ -512,15 +541,6 @@ The package provides a WebSocket client for real-time event streaming from the M | `mistapi.websockets.location.LocationUnconnectedClientsEvents` | `/sites/{site_id}/stats/maps/{map_id}/unconnected_clients` | Real-time unconnected clients location events | | `mistapi.websockets.location.LocationDiscoveredBleAssetsEvents` | `/sites/{site_id}/stats/maps/{map_id}/discovered_assets` | Real-time discovered BLE assets location events | -### Callbacks - -| Method | Signature | Description | -|--------|-----------|-------------| -| `ws.on_open(cb)` | `cb()` | Called when the connection is established | -| `ws.on_message(cb)` | `cb(data: dict)` | Called for every incoming message | -| `ws.on_error(cb)` | `cb(error: Exception)` | Called on WebSocket errors | -| `ws.on_close(cb)` | `cb(status_code: int, msg: str)` | Called when the connection closes | - ### Usage Patterns #### Callback style (recommended) diff --git a/src/mistapi/websockets/__ws_client.py b/src/mistapi/websockets/__ws_client.py index b0f4268..1dcd765 100644 --- a/src/mistapi/websockets/__ws_client.py +++ b/src/mistapi/websockets/__ws_client.py @@ -36,9 +36,17 @@ class _MistWebsocket: - Login/password sessions pass the requests Session cookies. """ - def __init__(self, mist_session: "APISession", channel: str) -> None: + def __init__( + self, + mist_session: "APISession", + channel: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: self._mist_session = mist_session self._channel = channel + self._ping_interval = ping_interval + self._ping_timeout = ping_timeout self._ws: websocket.WebSocketApp | None = None self._thread: threading.Thread | None = None self._queue: queue.Queue[dict | None] = queue.Queue() @@ -147,7 +155,7 @@ def connect(self, run_in_background: bool = True) -> None: def _run_forever_safe(self) -> None: if self._ws: try: - self._ws.run_forever(ping_interval=30, ping_timeout=10) + self._ws.run_forever(ping_interval=self._ping_interval, ping_timeout=self._ping_timeout) except Exception as exc: self._handle_error(self._ws, exc) self._handle_close(self._ws, -1, str(exc)) diff --git a/src/mistapi/websockets/location.py b/src/mistapi/websockets/location.py index bc82754..54c928b 100644 --- a/src/mistapi/websockets/location.py +++ b/src/mistapi/websockets/location.py @@ -54,9 +54,9 @@ class LocationBleAssetsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + def __init__(self, mist_session: APISession, site_id: str, map_id: str, **kwargs) -> None: super().__init__( - mist_session, channel=f"/sites/{site_id}/stats/maps/{map_id}/assets" + mist_session, channel=f"/sites/{site_id}/stats/maps/{map_id}/assets", **kwargs ) @@ -99,10 +99,11 @@ class LocationConnectedClientsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + def __init__(self, mist_session: APISession, site_id: str, map_id: str, **kwargs) -> None: super().__init__( mist_session, channel=f"/sites/{site_id}/stats/maps/{map_id}/clients", + **kwargs, ) @@ -145,10 +146,11 @@ class LocationSdkClientsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + def __init__(self, mist_session: APISession, site_id: str, map_id: str, **kwargs) -> None: super().__init__( mist_session, channel=f"/sites/{site_id}/stats/maps/{map_id}/sdkclients", + **kwargs, ) @@ -191,10 +193,11 @@ class LocationUnconnectedClientsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + def __init__(self, mist_session: APISession, site_id: str, map_id: str, **kwargs) -> None: super().__init__( mist_session, channel=f"/sites/{site_id}/stats/maps/{map_id}/unconnected_clients", + **kwargs, ) @@ -237,8 +240,9 @@ class LocationDiscoveredBleAssetsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, map_id: str) -> None: + def __init__(self, mist_session: APISession, site_id: str, map_id: str, **kwargs) -> None: super().__init__( mist_session, channel=f"/sites/{site_id}/stats/maps/{map_id}/discovered_assets", + **kwargs, ) diff --git a/src/mistapi/websockets/orgs.py b/src/mistapi/websockets/orgs.py index cc84691..7463571 100644 --- a/src/mistapi/websockets/orgs.py +++ b/src/mistapi/websockets/orgs.py @@ -52,8 +52,8 @@ class OrgInsightsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, org_id: str) -> None: - super().__init__(mist_session, channel=f"/orgs/{org_id}/insights/summary") + def __init__(self, mist_session: APISession, org_id: str, **kwargs) -> None: + super().__init__(mist_session, channel=f"/orgs/{org_id}/insights/summary", **kwargs) class OrgMxEdgesStatsEvents(_MistWebsocket): @@ -93,8 +93,8 @@ class OrgMxEdgesStatsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, org_id: str) -> None: - super().__init__(mist_session, channel=f"/orgs/{org_id}/stats/mxedges") + def __init__(self, mist_session: APISession, org_id: str, **kwargs) -> None: + super().__init__(mist_session, channel=f"/orgs/{org_id}/stats/mxedges", **kwargs) class OrgMxEdgesUpgradesEvents(_MistWebsocket): @@ -134,5 +134,5 @@ class OrgMxEdgesUpgradesEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, org_id: str) -> None: - super().__init__(mist_session, channel=f"/orgs/{org_id}/mxedges") + def __init__(self, mist_session: APISession, org_id: str, **kwargs) -> None: + super().__init__(mist_session, channel=f"/orgs/{org_id}/mxedges", **kwargs) diff --git a/src/mistapi/websockets/sites.py b/src/mistapi/websockets/sites.py index 8138f11..9d32711 100644 --- a/src/mistapi/websockets/sites.py +++ b/src/mistapi/websockets/sites.py @@ -52,8 +52,8 @@ class SiteClientsStatsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/stats/clients") + def __init__(self, mist_session: APISession, site_id: str, **kwargs) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/stats/clients", **kwargs) class SiteDeviceCmdEvents(_MistWebsocket): @@ -101,9 +101,9 @@ class SiteDeviceCmdEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, device_id: str) -> None: + def __init__(self, mist_session: APISession, site_id: str, device_id: str, **kwargs) -> None: super().__init__( - mist_session, channel=f"/sites/{site_id}/devices/{device_id}/cmd" + mist_session, channel=f"/sites/{site_id}/devices/{device_id}/cmd", **kwargs ) @@ -144,8 +144,8 @@ class SiteDeviceStatsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/stats/devices") + def __init__(self, mist_session: APISession, site_id: str, **kwargs) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/stats/devices", **kwargs) class SiteDeviceUpgradesEvents(_MistWebsocket): @@ -185,8 +185,8 @@ class SiteDeviceUpgradesEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/devices") + def __init__(self, mist_session: APISession, site_id: str, **kwargs) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/devices", **kwargs) class SiteMxEdgesStatsEvents(_MistWebsocket): @@ -226,8 +226,8 @@ class SiteMxEdgesStatsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/stats/mxedges") + def __init__(self, mist_session: APISession, site_id: str, **kwargs) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/stats/mxedges", **kwargs) class SitePcapEvents(_MistWebsocket): @@ -267,5 +267,5 @@ class SitePcapEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/pcap") + def __init__(self, mist_session: APISession, site_id: str, **kwargs) -> None: + super().__init__(mist_session, channel=f"/sites/{site_id}/pcap", **kwargs) From cef5e9044d6b978397400e0714004e7046815846 Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Tue, 24 Feb 2026 19:01:31 +0100 Subject: [PATCH 04/16] add ws.ready() and explicit parameters --- src/mistapi/websockets/__ws_client.py | 8 +- src/mistapi/websockets/location.py | 83 ++++++++++++++++--- src/mistapi/websockets/orgs.py | 57 +++++++++++-- src/mistapi/websockets/sites.py | 113 +++++++++++++++++++++++--- 4 files changed, 232 insertions(+), 29 deletions(-) diff --git a/src/mistapi/websockets/__ws_client.py b/src/mistapi/websockets/__ws_client.py index 1dcd765..13a33b1 100644 --- a/src/mistapi/websockets/__ws_client.py +++ b/src/mistapi/websockets/__ws_client.py @@ -155,7 +155,9 @@ def connect(self, run_in_background: bool = True) -> None: def _run_forever_safe(self) -> None: if self._ws: try: - self._ws.run_forever(ping_interval=self._ping_interval, ping_timeout=self._ping_timeout) + self._ws.run_forever( + ping_interval=self._ping_interval, ping_timeout=self._ping_timeout + ) except Exception as exc: self._handle_error(self._ws, exc) self._handle_close(self._ws, -1, str(exc)) @@ -188,3 +190,7 @@ def __enter__(self) -> "_MistWebsocket": def __exit__(self, *args) -> None: self.disconnect() + + def ready(self) -> bool | None: + """Returns True if the WebSocket connection is open and ready.""" + return self._ws is not None and self._ws.ready() diff --git a/src/mistapi/websockets/location.py b/src/mistapi/websockets/location.py index 54c928b..d8aecd6 100644 --- a/src/mistapi/websockets/location.py +++ b/src/mistapi/websockets/location.py @@ -29,6 +29,11 @@ class LocationBleAssetsEvents(_MistWebsocket): UUID of the site to stream events from. map_id : str UUID of the map to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + EXAMPLE ----------- @@ -54,9 +59,19 @@ class LocationBleAssetsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, map_id: str, **kwargs) -> None: + def __init__( + self, + mist_session: APISession, + site_id: str, + map_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: super().__init__( - mist_session, channel=f"/sites/{site_id}/stats/maps/{map_id}/assets", **kwargs + mist_session, + channel=f"/sites/{site_id}/stats/maps/{map_id}/assets", + ping_interval=ping_interval, + ping_timeout=ping_timeout, ) @@ -74,6 +89,10 @@ class LocationConnectedClientsEvents(_MistWebsocket): UUID of the site to stream events from. map_id : str UUID of the map to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. EXAMPLE ----------- @@ -99,11 +118,19 @@ class LocationConnectedClientsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, map_id: str, **kwargs) -> None: + def __init__( + self, + mist_session: APISession, + site_id: str, + map_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: super().__init__( mist_session, channel=f"/sites/{site_id}/stats/maps/{map_id}/clients", - **kwargs, + ping_interval=ping_interval, + ping_timeout=ping_timeout, ) @@ -121,6 +148,10 @@ class LocationSdkClientsEvents(_MistWebsocket): UUID of the site to stream events from. map_id : str UUID of the map to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. EXAMPLE ----------- @@ -146,11 +177,19 @@ class LocationSdkClientsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, map_id: str, **kwargs) -> None: + def __init__( + self, + mist_session: APISession, + site_id: str, + map_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: super().__init__( mist_session, channel=f"/sites/{site_id}/stats/maps/{map_id}/sdkclients", - **kwargs, + ping_interval=ping_interval, + ping_timeout=ping_timeout, ) @@ -168,6 +207,10 @@ class LocationUnconnectedClientsEvents(_MistWebsocket): UUID of the site to stream events from. map_id : str UUID of the map to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. EXAMPLE ----------- @@ -193,11 +236,19 @@ class LocationUnconnectedClientsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, map_id: str, **kwargs) -> None: + def __init__( + self, + mist_session: APISession, + site_id: str, + map_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: super().__init__( mist_session, channel=f"/sites/{site_id}/stats/maps/{map_id}/unconnected_clients", - **kwargs, + ping_interval=ping_interval, + ping_timeout=ping_timeout, ) @@ -215,6 +266,10 @@ class LocationDiscoveredBleAssetsEvents(_MistWebsocket): UUID of the site to stream events from. map_id : str UUID of the map to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. EXAMPLE ----------- @@ -240,9 +295,17 @@ class LocationDiscoveredBleAssetsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, map_id: str, **kwargs) -> None: + def __init__( + self, + mist_session: APISession, + site_id: str, + map_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: super().__init__( mist_session, channel=f"/sites/{site_id}/stats/maps/{map_id}/discovered_assets", - **kwargs, + ping_interval=ping_interval, + ping_timeout=ping_timeout, ) diff --git a/src/mistapi/websockets/orgs.py b/src/mistapi/websockets/orgs.py index 7463571..c3111c8 100644 --- a/src/mistapi/websockets/orgs.py +++ b/src/mistapi/websockets/orgs.py @@ -27,6 +27,10 @@ class OrgInsightsEvents(_MistWebsocket): Authenticated API session. org_id : str UUID of the organization to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. EXAMPLE ----------- @@ -52,8 +56,19 @@ class OrgInsightsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, org_id: str, **kwargs) -> None: - super().__init__(mist_session, channel=f"/orgs/{org_id}/insights/summary", **kwargs) + def __init__( + self, + mist_session: APISession, + org_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + super().__init__( + mist_session, + channel=f"/orgs/{org_id}/insights/summary", + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) class OrgMxEdgesStatsEvents(_MistWebsocket): @@ -68,6 +83,10 @@ class OrgMxEdgesStatsEvents(_MistWebsocket): Authenticated API session. org_id : str UUID of the organization to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. EXAMPLE ----------- @@ -93,8 +112,19 @@ class OrgMxEdgesStatsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, org_id: str, **kwargs) -> None: - super().__init__(mist_session, channel=f"/orgs/{org_id}/stats/mxedges", **kwargs) + def __init__( + self, + mist_session: APISession, + org_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + super().__init__( + mist_session, + channel=f"/orgs/{org_id}/stats/mxedges", + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) class OrgMxEdgesUpgradesEvents(_MistWebsocket): @@ -109,6 +139,10 @@ class OrgMxEdgesUpgradesEvents(_MistWebsocket): Authenticated API session. org_id : str UUID of the org to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. EXAMPLE ----------- @@ -134,5 +168,16 @@ class OrgMxEdgesUpgradesEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, org_id: str, **kwargs) -> None: - super().__init__(mist_session, channel=f"/orgs/{org_id}/mxedges", **kwargs) + def __init__( + self, + mist_session: APISession, + org_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + super().__init__( + mist_session, + channel=f"/orgs/{org_id}/mxedges", + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) diff --git a/src/mistapi/websockets/sites.py b/src/mistapi/websockets/sites.py index 9d32711..9ab09de 100644 --- a/src/mistapi/websockets/sites.py +++ b/src/mistapi/websockets/sites.py @@ -27,6 +27,10 @@ class SiteClientsStatsEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. EXAMPLE ----------- @@ -52,8 +56,19 @@ class SiteClientsStatsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, **kwargs) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/stats/clients", **kwargs) + def __init__( + self, + mist_session: APISession, + site_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + super().__init__( + mist_session, + channel=f"/sites/{site_id}/stats/clients", + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) class SiteDeviceCmdEvents(_MistWebsocket): @@ -76,6 +91,10 @@ class SiteDeviceCmdEvents(_MistWebsocket): UUID of the site to stream events from. device_id : str UUID of the device to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. EXAMPLE ----------- @@ -101,9 +120,19 @@ class SiteDeviceCmdEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, device_id: str, **kwargs) -> None: + def __init__( + self, + mist_session: APISession, + site_id: str, + device_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: super().__init__( - mist_session, channel=f"/sites/{site_id}/devices/{device_id}/cmd", **kwargs + mist_session, + channel=f"/sites/{site_id}/devices/{device_id}/cmd", + ping_interval=ping_interval, + ping_timeout=ping_timeout, ) @@ -119,6 +148,10 @@ class SiteDeviceStatsEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. EXAMPLE ----------- @@ -144,8 +177,19 @@ class SiteDeviceStatsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, **kwargs) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/stats/devices", **kwargs) + def __init__( + self, + mist_session: APISession, + site_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + super().__init__( + mist_session, + channel=f"/sites/{site_id}/stats/devices", + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) class SiteDeviceUpgradesEvents(_MistWebsocket): @@ -160,6 +204,10 @@ class SiteDeviceUpgradesEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. EXAMPLE ----------- @@ -185,8 +233,19 @@ class SiteDeviceUpgradesEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, **kwargs) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/devices", **kwargs) + def __init__( + self, + mist_session: APISession, + site_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + super().__init__( + mist_session, + channel=f"/sites/{site_id}/devices", + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) class SiteMxEdgesStatsEvents(_MistWebsocket): @@ -201,6 +260,10 @@ class SiteMxEdgesStatsEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. EXAMPLE ----------- @@ -226,8 +289,19 @@ class SiteMxEdgesStatsEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, **kwargs) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/stats/mxedges", **kwargs) + def __init__( + self, + mist_session: APISession, + site_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + super().__init__( + mist_session, + channel=f"/sites/{site_id}/stats/mxedges", + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) class SitePcapEvents(_MistWebsocket): @@ -242,6 +316,10 @@ class SitePcapEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. EXAMPLE ----------- @@ -267,5 +345,16 @@ class SitePcapEvents(_MistWebsocket): time.sleep(60) """ - def __init__(self, mist_session: APISession, site_id: str, **kwargs) -> None: - super().__init__(mist_session, channel=f"/sites/{site_id}/pcap", **kwargs) + def __init__( + self, + mist_session: APISession, + site_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + super().__init__( + mist_session, + channel=f"/sites/{site_id}/pcap", + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) From 84c7615ef871aa7dfb0f6fe7e5416392765c4c87 Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Tue, 3 Mar 2026 21:55:21 +0100 Subject: [PATCH 05/16] WIP --- pyproject.toml | 6 + src/mistapi/api/v1/sites/devices.py | 8 +- src/mistapi/websockets/__init__.py | 2 +- src/mistapi/websockets/__ws_client.py | 7 +- src/mistapi/websockets/location.py | 59 +- src/mistapi/websockets/orgs.py | 12 +- src/mistapi/websockets/session.py | 72 +++ src/mistapi/websockets/sites.py | 68 ++- src/mistapi/websockets/utils/__init__.py | 1 + src/mistapi/websockets/utils/__ws_wrapper.py | 155 +++++ src/mistapi/websockets/utils/common.py | 545 ++++++++++++++++++ src/mistapi/websockets/utils/gateway.py | 249 ++++++++ src/mistapi/websockets/utils/junos.py | 110 ++++ src/mistapi/websockets/utils/switch.py | 570 +++++++++++++++++++ test.py | 46 ++ 15 files changed, 1841 insertions(+), 69 deletions(-) create mode 100644 src/mistapi/websockets/session.py create mode 100644 src/mistapi/websockets/utils/__init__.py create mode 100644 src/mistapi/websockets/utils/__ws_wrapper.py create mode 100644 src/mistapi/websockets/utils/common.py create mode 100644 src/mistapi/websockets/utils/gateway.py create mode 100644 src/mistapi/websockets/utils/junos.py create mode 100644 src/mistapi/websockets/utils/switch.py create mode 100644 test.py diff --git a/pyproject.toml b/pyproject.toml index a72fa16..0430154 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,12 @@ dependencies = [ "Bug Tracker" = "https://github.com/tmunzer/mistapi_python/issues" # UV-specific configuration +[tool.uv.scripts] +test = "pytest" +lint = "ruff check src/" +fmt = "ruff format src/" +build = "python -m build" + [tool.uv] dev-dependencies = [ # Testing dependencies diff --git a/src/mistapi/api/v1/sites/devices.py b/src/mistapi/api/v1/sites/devices.py index 29ed517..34b2294 100644 --- a/src/mistapi/api/v1/sites/devices.py +++ b/src/mistapi/api/v1/sites/devices.py @@ -1734,7 +1734,7 @@ def clearSiteDevicePendingVersion( def clearSiteDevicePolicyHitCount( - mist_session: _APISession, site_id: str, device_id: str + mist_session: _APISession, site_id: str, device_id: str, body: dict ) -> _APIResponse: """ API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/utilities/common/clear-site-device-policy-hit-count @@ -1756,7 +1756,7 @@ def clearSiteDevicePolicyHitCount( """ uri = f"/api/v1/sites/{site_id}/devices/{device_id}/clear_policy_hit_count" - resp = mist_session.mist_post(uri=uri) + resp = mist_session.mist_post(uri=uri, body=body) return resp @@ -2375,7 +2375,7 @@ def getSiteDeviceZtpPassword( def testSiteSsrDnsResolution( - mist_session: _APISession, site_id: str, device_id: str + mist_session: _APISession, site_id: str, device_id: str, body: dict ) -> _APIResponse: """ API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/utilities/wan/test-site-ssr-dns-resolution @@ -2397,7 +2397,7 @@ def testSiteSsrDnsResolution( """ uri = f"/api/v1/sites/{site_id}/devices/{device_id}/resolve_dns" - resp = mist_session.mist_post(uri=uri) + resp = mist_session.mist_post(uri=uri, body=body) return resp diff --git a/src/mistapi/websockets/__init__.py b/src/mistapi/websockets/__init__.py index 6ca5838..0e89fd7 100644 --- a/src/mistapi/websockets/__init__.py +++ b/src/mistapi/websockets/__init__.py @@ -21,4 +21,4 @@ ws.connect() """ -from mistapi.websockets import location, orgs, sites +from mistapi.websockets import location, orgs, session, sites, utils diff --git a/src/mistapi/websockets/__ws_client.py b/src/mistapi/websockets/__ws_client.py index 13a33b1..29d15fa 100644 --- a/src/mistapi/websockets/__ws_client.py +++ b/src/mistapi/websockets/__ws_client.py @@ -39,12 +39,12 @@ class _MistWebsocket: def __init__( self, mist_session: "APISession", - channel: str, + channels: list[str], ping_interval: int = 30, ping_timeout: int = 10, ) -> None: self._mist_session = mist_session - self._channel = channel + self._channels = channels self._ping_interval = ping_interval self._ping_timeout = ping_timeout self._ws: websocket.WebSocketApp | None = None @@ -97,7 +97,8 @@ def on_close(self, callback: Callable[[int, str], None]) -> None: # Internal WebSocketApp handlers def _handle_open(self, ws: websocket.WebSocketApp) -> None: - ws.send(json.dumps({"subscribe": self._channel})) + for channel in self._channels: + ws.send(json.dumps({"subscribe": channel})) if self._on_open_cb: self._on_open_cb() diff --git a/src/mistapi/websockets/location.py b/src/mistapi/websockets/location.py index d8aecd6..f010240 100644 --- a/src/mistapi/websockets/location.py +++ b/src/mistapi/websockets/location.py @@ -15,7 +15,7 @@ from mistapi.websockets.__ws_client import _MistWebsocket -class LocationBleAssetsEvents(_MistWebsocket): +class BleAssetsEvents(_MistWebsocket): """WebSocket stream for location BLE assets events. Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/assets`` channel and delivers @@ -27,8 +27,8 @@ class LocationBleAssetsEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. - map_id : str - UUID of the map to stream events from. + map_id : list[str] + UUIDs of the maps to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 @@ -63,19 +63,20 @@ def __init__( self, mist_session: APISession, site_id: str, - map_id: str, + map_id: list[str], ping_interval: int = 30, ping_timeout: int = 10, ) -> None: + channels = [f"/sites/{site_id}/stats/maps/{mid}/assets" for mid in map_id] super().__init__( mist_session, - channel=f"/sites/{site_id}/stats/maps/{map_id}/assets", + channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, ) -class LocationConnectedClientsEvents(_MistWebsocket): +class ConnectedClientsEvents(_MistWebsocket): """WebSocket stream for location connected clients events. Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/clients`` channel and delivers @@ -87,8 +88,8 @@ class LocationConnectedClientsEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. - map_id : str - UUID of the map to stream events from. + map_id : list[str] + UUIDs of the maps to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 @@ -122,19 +123,20 @@ def __init__( self, mist_session: APISession, site_id: str, - map_id: str, + map_id: list[str], ping_interval: int = 30, ping_timeout: int = 10, ) -> None: + channels = [f"/sites/{site_id}/stats/maps/{mid}/clients" for mid in map_id] super().__init__( mist_session, - channel=f"/sites/{site_id}/stats/maps/{map_id}/clients", + channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, ) -class LocationSdkClientsEvents(_MistWebsocket): +class SdkClientsEvents(_MistWebsocket): """WebSocket stream for location SDK clients events. Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/sdkclients`` channel and delivers @@ -146,8 +148,8 @@ class LocationSdkClientsEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. - map_id : str - UUID of the map to stream events from. + map_id : list[str] + UUIDs of the maps to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 @@ -181,19 +183,20 @@ def __init__( self, mist_session: APISession, site_id: str, - map_id: str, + map_id: list[str], ping_interval: int = 30, ping_timeout: int = 10, ) -> None: + channels = [f"/sites/{site_id}/stats/maps/{mid}/sdkclients" for mid in map_id] super().__init__( mist_session, - channel=f"/sites/{site_id}/stats/maps/{map_id}/sdkclients", + channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, ) -class LocationUnconnectedClientsEvents(_MistWebsocket): +class UnconnectedClientsEvents(_MistWebsocket): """WebSocket stream for location unconnected clients events. Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/unconnected_clients`` channel and delivers @@ -205,8 +208,8 @@ class LocationUnconnectedClientsEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. - map_id : str - UUID of the map to stream events from. + map_id : list[str] + UUIDs of the maps to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 @@ -240,19 +243,22 @@ def __init__( self, mist_session: APISession, site_id: str, - map_id: str, + map_id: list[str], ping_interval: int = 30, ping_timeout: int = 10, ) -> None: + channels = [ + f"/sites/{site_id}/stats/maps/{mid}/unconnected_clients" for mid in map_id + ] super().__init__( mist_session, - channel=f"/sites/{site_id}/stats/maps/{map_id}/unconnected_clients", + channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, ) -class LocationDiscoveredBleAssetsEvents(_MistWebsocket): +class DiscoveredBleAssetsEvents(_MistWebsocket): """WebSocket stream for location discovered BLE assets events. Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/discovered_assets`` channel and delivers @@ -264,8 +270,8 @@ class LocationDiscoveredBleAssetsEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. - map_id : str - UUID of the map to stream events from. + map_id : list[str] + UUIDs of the maps to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 @@ -299,13 +305,16 @@ def __init__( self, mist_session: APISession, site_id: str, - map_id: str, + map_id: list[str], ping_interval: int = 30, ping_timeout: int = 10, ) -> None: + channels = [ + f"/sites/{site_id}/stats/maps/{mid}/discovered_assets" for mid in map_id + ] super().__init__( mist_session, - channel=f"/sites/{site_id}/stats/maps/{map_id}/discovered_assets", + channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, ) diff --git a/src/mistapi/websockets/orgs.py b/src/mistapi/websockets/orgs.py index c3111c8..1f9ce9a 100644 --- a/src/mistapi/websockets/orgs.py +++ b/src/mistapi/websockets/orgs.py @@ -15,7 +15,7 @@ from mistapi.websockets.__ws_client import _MistWebsocket -class OrgInsightsEvents(_MistWebsocket): +class InsightsEvents(_MistWebsocket): """WebSocket stream for organization insights events. Subscribes to the ``orgs/{org_id}/insights/summary`` channel and delivers @@ -65,13 +65,13 @@ def __init__( ) -> None: super().__init__( mist_session, - channel=f"/orgs/{org_id}/insights/summary", + channels=[f"/orgs/{org_id}/insights/summary"], ping_interval=ping_interval, ping_timeout=ping_timeout, ) -class OrgMxEdgesStatsEvents(_MistWebsocket): +class MxEdgesStatsEvents(_MistWebsocket): """WebSocket stream for organization MX edges stats events. Subscribes to the ``orgs/{org_id}/stats/mxedges`` channel and delivers @@ -121,13 +121,13 @@ def __init__( ) -> None: super().__init__( mist_session, - channel=f"/orgs/{org_id}/stats/mxedges", + channels=[f"/orgs/{org_id}/stats/mxedges"], ping_interval=ping_interval, ping_timeout=ping_timeout, ) -class OrgMxEdgesUpgradesEvents(_MistWebsocket): +class MxEdgesUpgradesEvents(_MistWebsocket): """WebSocket stream for org MX edges upgrades events. Subscribes to the ``orgs/{org_id}/mxedges`` channel and delivers @@ -177,7 +177,7 @@ def __init__( ) -> None: super().__init__( mist_session, - channel=f"/orgs/{org_id}/mxedges", + channels=[f"/orgs/{org_id}/mxedges"], ping_interval=ping_interval, ping_timeout=ping_timeout, ) diff --git a/src/mistapi/websockets/session.py b/src/mistapi/websockets/session.py new file mode 100644 index 0000000..c2ef382 --- /dev/null +++ b/src/mistapi/websockets/session.py @@ -0,0 +1,72 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for Remote Commands events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class SessionWithUrl(_MistWebsocket): + """WebSocket stream for remote commands events. + + Open a WebSocket connection to a custom channel URL for remote command events. + This is a base class that can be used to implement specific remote command event streams by providing the + appropriate channel URL. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + url : str + URL of the WebSocket channel to connect to. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = sessionWithUrl(session, url="wss://example.com/channel") + ws.on_message(lambda data: print(data)) + ws.connect() + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = sessionWithUrl(session, url="wss://example.com/channel") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with sessionWithUrl(session, url="wss://example.com/channel") as ws: + ws.on_message(my_handler) + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + url: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + super().__init__( + mist_session, + channels=[url], + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) diff --git a/src/mistapi/websockets/sites.py b/src/mistapi/websockets/sites.py index 9ab09de..27db901 100644 --- a/src/mistapi/websockets/sites.py +++ b/src/mistapi/websockets/sites.py @@ -15,7 +15,7 @@ from mistapi.websockets.__ws_client import _MistWebsocket -class SiteClientsStatsEvents(_MistWebsocket): +class ClientsStatsEvents(_MistWebsocket): """WebSocket stream for site clients stats events. Subscribes to the ``sites/{site_id}/stats/clients`` channel and delivers @@ -25,8 +25,8 @@ class SiteClientsStatsEvents(_MistWebsocket): ----------- mist_session : mistapi.APISession Authenticated API session. - site_id : str - UUID of the site to stream events from. + site_ids : list[str] + UUIDs of the sites to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 @@ -59,19 +59,20 @@ class SiteClientsStatsEvents(_MistWebsocket): def __init__( self, mist_session: APISession, - site_id: str, + site_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, ) -> None: + channels = [f"/sites/{site_id}/stats/clients" for site_id in site_ids] super().__init__( mist_session, - channel=f"/sites/{site_id}/stats/clients", + channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, ) -class SiteDeviceCmdEvents(_MistWebsocket): +class DeviceCmdEvents(_MistWebsocket): """WebSocket stream for site device command events. Subscribes to the ``sites/{site_id}/devices/{device_id}/cmd`` channel and delivers @@ -89,8 +90,8 @@ class SiteDeviceCmdEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. - device_id : str - UUID of the device to stream events from. + device_ids : list[str] + UUIDs of the devices to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 @@ -124,19 +125,22 @@ def __init__( self, mist_session: APISession, site_id: str, - device_id: str, + device_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, ) -> None: + channels = [ + f"/sites/{site_id}/devices/{device_id}/cmd" for device_id in device_ids + ] super().__init__( mist_session, - channel=f"/sites/{site_id}/devices/{device_id}/cmd", + channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, ) -class SiteDeviceStatsEvents(_MistWebsocket): +class DeviceStatsEvents(_MistWebsocket): """WebSocket stream for site device stats events. Subscribes to the ``sites/{site_id}/stats/devices`` channel and delivers @@ -146,8 +150,8 @@ class SiteDeviceStatsEvents(_MistWebsocket): ----------- mist_session : mistapi.APISession Authenticated API session. - site_id : str - UUID of the site to stream events from. + site_ids : list[str] + UUIDs of the sites to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 @@ -180,19 +184,20 @@ class SiteDeviceStatsEvents(_MistWebsocket): def __init__( self, mist_session: APISession, - site_id: str, + site_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, ) -> None: + channels = [f"/sites/{site_id}/stats/devices" for site_id in site_ids] super().__init__( mist_session, - channel=f"/sites/{site_id}/stats/devices", + channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, ) -class SiteDeviceUpgradesEvents(_MistWebsocket): +class DeviceUpgradesEvents(_MistWebsocket): """WebSocket stream for site device upgrades events. Subscribes to the ``sites/{site_id}/devices`` channel and delivers @@ -202,8 +207,8 @@ class SiteDeviceUpgradesEvents(_MistWebsocket): ----------- mist_session : mistapi.APISession Authenticated API session. - site_id : str - UUID of the site to stream events from. + site_ids : list[str] + UUIDs of the sites to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 @@ -236,19 +241,20 @@ class SiteDeviceUpgradesEvents(_MistWebsocket): def __init__( self, mist_session: APISession, - site_id: str, + site_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, ) -> None: + channels = [f"/sites/{site_id}/devices" for site_id in site_ids] super().__init__( mist_session, - channel=f"/sites/{site_id}/devices", + channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, ) -class SiteMxEdgesStatsEvents(_MistWebsocket): +class MxEdgesStatsEvents(_MistWebsocket): """WebSocket stream for site MX edges stats events. Subscribes to the ``sites/{site_id}/stats/mxedges`` channel and delivers @@ -258,8 +264,8 @@ class SiteMxEdgesStatsEvents(_MistWebsocket): ----------- mist_session : mistapi.APISession Authenticated API session. - site_id : str - UUID of the site to stream events from. + site_ids : list[str] + UUIDs of the sites to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 @@ -292,19 +298,20 @@ class SiteMxEdgesStatsEvents(_MistWebsocket): def __init__( self, mist_session: APISession, - site_id: str, + site_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, ) -> None: + channels = [f"/sites/{site_id}/stats/mxedges" for site_id in site_ids] super().__init__( mist_session, - channel=f"/sites/{site_id}/stats/mxedges", + channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, ) -class SitePcapEvents(_MistWebsocket): +class PcapEvents(_MistWebsocket): """WebSocket stream for site PCAP events. Subscribes to the ``sites/{site_id}/pcap`` channel and delivers @@ -314,8 +321,8 @@ class SitePcapEvents(_MistWebsocket): ----------- mist_session : mistapi.APISession Authenticated API session. - site_id : str - UUID of the site to stream events from. + site_ids : list[str] + UUID of the sites to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 @@ -348,13 +355,14 @@ class SitePcapEvents(_MistWebsocket): def __init__( self, mist_session: APISession, - site_id: str, + site_ids: list[str], ping_interval: int = 30, ping_timeout: int = 10, ) -> None: + channels = [f"/sites/{site_id}/pcap" for site_id in site_ids] super().__init__( mist_session, - channel=f"/sites/{site_id}/pcap", + channels=channels, ping_interval=ping_interval, ping_timeout=ping_timeout, ) diff --git a/src/mistapi/websockets/utils/__init__.py b/src/mistapi/websockets/utils/__init__.py new file mode 100644 index 0000000..943352e --- /dev/null +++ b/src/mistapi/websockets/utils/__init__.py @@ -0,0 +1 @@ +from mistapi.websockets.utils import common, gateway, junos, switch diff --git a/src/mistapi/websockets/utils/__ws_wrapper.py b/src/mistapi/websockets/utils/__ws_wrapper.py new file mode 100644 index 0000000..1ba03af --- /dev/null +++ b/src/mistapi/websockets/utils/__ws_wrapper.py @@ -0,0 +1,155 @@ +import json +import threading +import time + +from mistapi import APISession +from mistapi.__api_response import APIResponse as _APIResponse +from mistapi.__logger import logger as LOGGER +from mistapi.websockets.session import SessionWithUrl +from mistapi.websockets.sites import DeviceCmdEvents + + +class UtilResponse: + """ + A simple class to encapsulate the response from utility WebSocket functions. + This class can be extended in the future to include additional metadata or helper methods. + """ + + def __init__( + self, + api_response: _APIResponse, + ) -> None: + self.trigger_api_response = api_response + self.ws_required: bool = False # This can be set to True if the WebSocket connection was successfully initiated + self.ws_data: list[str] = [] + self.ws_raw_events: list[str] = [] + + +class WebSocketWrapper: + """ + A wrapper class for managing WebSocket connections and events. + This class provides a simplified interface for connecting to WebSocket channels, + handling messages, and managing connection timeouts. + """ + + def __init__( + self, + apissession: APISession, + util_response: UtilResponse, + timeout: int = 10, + max_duration: int = 60, + ) -> None: + self.apissession = apissession + self.util_response = util_response + self.timeout_timer = None + self.timeout = timeout + self.max_duration_timer = None + self.max_duration = max_duration + self.received_messages = 0 + self.data = [] + self.raw_events = [] + self.ws = None + + def _on_open(self): + LOGGER.info("WebSocket connection opened") + if self.max_duration_timer and self.ws: + self.max_duration_timer = threading.Timer( + self.max_duration, self.ws.disconnect + ) + self.max_duration_timer.start() + self._reset_timer() # Start the timer when the connection opens + + def _reset_timer(self): + if self.timeout_timer: + self.timeout_timer.cancel() + if self.ws: + self.timeout_timer = threading.Timer(self.timeout, self.ws.disconnect) + self.timeout_timer.start() + + def _extract_raw(self, message): + self.raw_events.append(message) + event = message + if isinstance(event, str): + try: + event = json.loads(message) + if isinstance(event, dict) and "raw" in event: + return event["raw"] + except json.JSONDecodeError: + return + if event.get("event") == "data" and event.get("data"): + return self._extract_raw(event["data"]) + elif event.get("raw"): + self.received_messages += 1 + LOGGER.debug(f"Received raw message: {event['raw']}") + return event["raw"] + return None + + def _handle_message(self, msg): + if isinstance(msg, dict) and msg.get("event") == "channel_subscribed": + LOGGER.debug(msg) + else: + LOGGER.debug(msg) + raw = self._extract_raw(msg) + if raw: + self.data.append(raw) + self._reset_timer() # Reset timeout on each message + + async def startCmdEvents(self, site_id: str, device_id: str) -> UtilResponse: + """ + Start a WebSocket stream for site device command events. + + PARAMS + ----------- + site_id : str + UUID of the site to stream events from. + device_id : str + UUID of the device to stream events from. + """ + self.ws = DeviceCmdEvents( + self.apissession, site_id=site_id, device_ids=[device_id] + ) + self.ws.on_message(self._handle_message) + self.ws.on_error(lambda error: LOGGER.error(f"Error: {error}")) + self.ws.on_close( + lambda code, msg: LOGGER.info(f"WebSocket closed: {code} - {msg}") + ) + self.ws.on_open(self._on_open) + self.ws.connect() # non-blocking + LOGGER.info("WebSocket connection initiated") + time.sleep(1) + while self.ws and self.ws.ready(): + time.sleep(1) + LOGGER.info("WebSocket connection closed, exiting") + self.util_response.ws_required = True + self.util_response.ws_data = self.data + self.util_response.ws_raw_events = self.raw_events + return self.util_response + + async def startSessionUrl(self, url: str) -> UtilResponse: + """ + Start a WebSocket stream using a custom URL. + This should be used when Mist is returning a WebSocket URL from an API call. + + PARAMS + ----------- + url : str + Full WebSocket URL to connect to (e.g., wss://api.mist.com/ws/v1/orgs/{org_id}/sites/{site_id}/devices/{device_id}/cmds). + + """ + self.ws = SessionWithUrl(self.apissession, url=url) + self.ws.on_message(self._handle_message) + self.ws.on_error(lambda error: LOGGER.error(f"Error: {error}")) + self.ws.on_close( + lambda code, msg: LOGGER.info(f"WebSocket closed: {code} - {msg}") + ) + self.ws.on_open(self._on_open) + self.ws.connect() # non-blocking + LOGGER.info("WebSocket connection initiated") + time.sleep(1) + while self.ws and self.ws.ready(): + time.sleep(1) + LOGGER.info("WebSocket connection closed, exiting") + self.util_response.ws_required = True + self.util_response.ws_data = self.data + self.util_response.ws_raw_events = self.raw_events + return self.util_response diff --git a/src/mistapi/websockets/utils/common.py b/src/mistapi/websockets/utils/common.py new file mode 100644 index 0000000..0410d5f --- /dev/null +++ b/src/mistapi/websockets/utils/common.py @@ -0,0 +1,545 @@ +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.websockets.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +class Node(Enum): + NODE0 = "node0" + NODE1 = "node1" + + +class RouteProtocol(Enum): + ANY = "any" + BGP = "bgp" + DIRECT = "direct" + EVPN = "evpn" + OSPF = "ospf" + STATIC = "static" + + +async def retrieve_arp_table( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + timeout=5, +) -> UtilResponse: + """ + Retrieves the ARP table from a device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve the ARP table from. + node : Node, optional + Node information for the ARP table retrieval command. + timeout : int, optional + Timeout for the ARP table retrieval command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + # AP is returnning RAW data + # SWITCH is returning ??? + # GATEWAY is returning JSON + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + trigger = devices.arpFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Show ARP command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" + ) # Give the show ARP command a moment to take effect + return util_response + + +async def bounce_ports( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + timeout=5, +) -> UtilResponse: + """ + Initiates a bounce command on the specified ports of a device and streams the results. + + PARAMS + ----------- + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to perform the bounce command on. + port_ids : list[str] + List of port IDs to bounce. + timeout : int, async default 5 + Timeout for the bounce command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if port_ids: + body["ports"] = port_ids + trigger = devices.bounceDevicePort( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info( + f"Bounce command triggered for ports {port_ids} on device {device_id}" + ) + util_response = await WebSocketWrapper( + apissession, util_response, timeout + ).startCmdEvents(site_id=site_id, device_id=device_id) + else: + LOGGER.error( + f"Failed to trigger bounce command: {trigger.status_code} - {trigger.data}" + ) # Give the bounce command a moment to take effect + return util_response + + +async def clear_mac_table( + apissession: _APISession, + site_id: str, + device_id: str, + mac_address: str | None = None, + port_id: str | None = None, + vlan_id: str | None = None, + # timeout=30, +) -> UtilResponse: + """ + Clears the MAC table on a device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to clear the MAC table from. + mac_address : str, optional + MAC address to clear from the MAC table. + port_id : str, optional + Port ID to clear from the MAC table. + vlan_id : str, optional + VLAN ID to clear from the MAC table. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + # AP is returnning RAW data + # SWITCH is returning ??? + # GATEWAY is returning JSON + body: dict[str, str | list | int] = {} + if mac_address: + body["mac_address"] = mac_address + if port_id: + body["port_id"] = port_id + if vlan_id: + body["vlan_id"] = vlan_id + trigger = devices.clearSiteDeviceMacTable( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear MAC Table command triggered for device {device_id}") + # util_response = await WebSocketWrapper( + # apissession, util_response, timeout=timeout + # ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger clear MAC Table command: {trigger.status_code} - {trigger.data}" + ) # Give the clear MAC Table command a moment to take effect + return util_response + + +async def release_dhcp_leases( + apissession: _APISession, + site_id: str, + device_id: str, + macs: list[str] | None = None, + network: str | None = None, + node: Node | None = None, + port_id: str | None = None, + timeout=5, +) -> UtilResponse: + """ + Releases DHCP leases on a device and streams the results. + valid combinations are: + - network + - network + macs + - network + port_id + - port_id + - port_id + macs + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to release DHCP leases on. + macs : list[str], optional + List of MAC addresses to release DHCP leases for. + network : str, optional + Network to release DHCP leases for. + node : Node, optional + Node information for the DHCP lease release command. + port_id : str, optional + Port ID to release DHCP leases for. + timeout : int, optional + Timeout for the release DHCP leases command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if macs: + body["macs"] = macs + if network: + body["network"] = network + if node: + body["node"] = node.value + if port_id: + body["port_id"] = port_id + trigger = devices.releaseSiteDeviceDhcpLease( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Release DHCP leases command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger release DHCP leases command: {trigger.status_code} - {trigger.data}" + ) # Give the release DHCP leases command a moment to take effect + return util_response + + +# TODO +async def retrieve_dhcp_leases( + apissession: _APISession, + site_id: str, + device_id: str, + network: str, + node: Node | None = None, + timeout=15, +) -> UtilResponse: + """ + Retrieves DHCP leases on a device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve DHCP leases from. + network : str + Network to release DHCP leases for. + node : Node, optional + Node information for the DHCP lease release command. + port_id : str, optional + Port ID to release DHCP leases for. + timeout : int, optional + Timeout for the release DHCP leases command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"network": network} + if node: + body["node"] = node.value + trigger = devices.showSiteDeviceDhcpLeases( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Retrieve DHCP leases command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger retrieve DHCP leases command: {trigger.status_code} - {trigger.data}" + ) # Give the release DHCP leases command a moment to take effect + return util_response + + +####################################################### +## Switch +####################################################### + + +async def switch_clear_bpdu_error( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + Clears BPDU error state on the specified ports of a switch. + + PARAMS + ----------- + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to clear BPDU errors on. + port_ids : list[str] + List of port IDs to clear BPDU errors on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearBpduErrorsFromPortsOnSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear BPDU error command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear BPDU error command: {trigger.status_code} - {trigger.data}" + ) # Give the clear BPDU error command a moment to take effect + return util_response + + +async def switch_clear_learned_mac( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + Clears learned MAC addresses on the specified ports of a device. + + PARAMS + ----------- + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to clear learned MAC addresses on. + port_ids : list[str] + List of port IDs to clear learned MAC addresses on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearSiteDeviceDot1xSession( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear learned MACs command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" + ) # Give the clear learned MACs command a moment to take effect + return util_response + + +async def switch_clear_dot1x_sessions( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + Clears dot1x sessions on the specified ports of a switch. + + PARAMS + ----------- + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to clear dot1x sessions on. + port_ids : list[str] + List of port IDs to clear dot1x sessions on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearAllLearnedMacsFromPortOnSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear learned MACs command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" + ) # Give the clear learned MACs command a moment to take effect + return util_response + + +####################################################### +## Websocket +####################################################### + + +async def ping( + apissession: _APISession, + site_id: str, + device_id: str, + host: str, + count: int | None = None, + node: None | None = None, + size: int | None = None, + vrf: str | None = None, + timeout: int = 5, +) -> UtilResponse: + """ + Initiates a ping command from a device to a specified host and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to initiate the ping from. + host : str + The host to ping. + count : int, optional + Number of ping requests to send. + node : None, optional + Node information for the ping command. + size : int, optional + Size of the ping packet. + vrf : str, optional + VRF to use for the ping command. + timeout : int, optional + Timeout for the ping command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if count: + body["count"] = count + if host: + body["host"] = host + if node: + body["node"] = node.value + if size: + body["size"] = size + if vrf: + body["vrf"] = vrf + trigger = devices.pingFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Ping command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger ping command: {trigger.status_code} - {trigger.data}" + ) # Give the ping command a moment to take effect + return util_response + + +# async def gateway_dns_resolution( +# self, +# site_id: str, +# device_id: str, +# timeout=10, +# ) -> list[str]: +# """For SSR Only. Initiates a DNS resolution command on the gateway and streams the results.""" +# self.timeout = timeout +# trigger = testSiteSsrDnsResolution( +# apissession, +# site_id=site_id, +# device_id=device_id, +# ) +# if trigger.status_code == 200: +# print(trigger.data) +# print(f"SSR DNS resolution command triggered for device {device_id}") +# self.startCmdEvents(site_id, device_id) +# else: +# print( +# f"Failed to trigger SSR DNS resolution command: {trigger.status_code} - {trigger.data}" +# ) # Give the SSR DNS resolution command a moment to take effect +# return util_response + +####################################################### +## Websocket Session +####################################################### diff --git a/src/mistapi/websockets/utils/gateway.py b/src/mistapi/websockets/utils/gateway.py new file mode 100644 index 0000000..18f0916 --- /dev/null +++ b/src/mistapi/websockets/utils/gateway.py @@ -0,0 +1,249 @@ +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.websockets.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +class Node(Enum): + NODE0 = "node0" + NODE1 = "node1" + + +class RouteProtocol(Enum): + ANY = "any" + BGP = "bgp" + DIRECT = "direct" + EVPN = "evpn" + OSPF = "ospf" + STATIC = "static" + + +async def show_routes( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + prefix: str | None = None, + protocol: RouteProtocol | None = None, + route_type: str | None = None, + vrf: str | None = None, + timeout=5, +) -> UtilResponse: + """ + For SSR and SRX. Initiates a show service path command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show routes command on. + node : Node, optional + Node information for the show routes command. + prefix : str, optional + Prefix to filter the routes. + protocol : RouteProtocol, optional + Protocol to filter the routes. + route_type : str, optional + Type of the route to filter. + vrf : str, optional + VRF to filter the routes. + timeout : int, optional + Timeout for the command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if prefix: + body["prefix"] = prefix + if protocol: + body["protocol"] = protocol.value + if route_type: + body["route_type"] = route_type + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteSsrAndSrxRoutes( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"SSR service path command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger SSR service path command: {trigger.status_code} - {trigger.data}" + ) # Give the SSR service path command a moment to take effect + return util_response + + +async def test_dns_resolution( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + hostname: str | None = None, + timeout=5, +) -> UtilResponse: + """ + For SSR Only. Initiates a DNS resolution command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the DNS resolution command on. + node : Node, optional + Node information for the DNS resolution command. + hostname : str, optional + Hostname to resolve. + timeout : int, optional + Timeout for the command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if hostname: + body["hostname"] = hostname + trigger = devices.testSiteSsrDnsResolution( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"SSR DNS resolution command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger SSR DNS resolution command: {trigger.status_code} - {trigger.data}" + ) # Give the SSR DNS resolution command a moment to take effect + return util_response + + +async def show_service_path( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + service_name: str | None = None, + timeout=5, +) -> UtilResponse: + """ + For SSR Only. Initiates a show service path command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show service path command on. + node : Node, optional + Node information for the show service path command. + service_name : str, optional + Name of the service to show the path for. + timeout : int, optional + Timeout for the command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if service_name: + body["service_name"] = service_name + trigger = devices.showSiteSsrServicePath( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"SSR service path command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger SSR service path command: {trigger.status_code} - {trigger.data}" + ) # Give the SSR service path command a moment to take effect + return util_response + + +async def clear_policy_hit_count( + apissession: _APISession, + site_id: str, + device_id: str, + policy_name: str, + # timeout: int = 10, +) -> UtilResponse: + """ + Clears the policy hit count on a device. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to clear the policy hit count on. + policy_name : str + Name of the policy to clear the hit count for. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + trigger = devices.clearSiteDevicePolicyHitCount( + apissession, + site_id=site_id, + device_id=device_id, + body={"policy_name": policy_name}, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Clear policy hit count command triggered for device {device_id}") + # util_response = await WebSocketWrapper( + # apissession, util_response, timeout=timeout + # ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger clear policy hit count command: {trigger.status_code} - {trigger.data}" + ) # Give the clear policy hit count command a moment to take effect + return util_response diff --git a/src/mistapi/websockets/utils/junos.py b/src/mistapi/websockets/utils/junos.py new file mode 100644 index 0000000..2b12f44 --- /dev/null +++ b/src/mistapi/websockets/utils/junos.py @@ -0,0 +1,110 @@ +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.websockets.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +# TODO +async def monitor_traffic( + apissession: _APISession, + site_id: str, + device_id: str, + port_id: str | None = None, + timeout=30, +) -> UtilResponse: + """ + For EX and SRX Only. Initiates a monitor traffic command on the device and streams the results. + + * if `port_id` is provided, JUNOS uses cmd "monitor interface" to monitor traffic on particular + * if `port_id` is not provided, JUNOS uses cmd "monitor interface traffic" to monitor traffic on all ports + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to monitor traffic on. + port_id : str, optional + Port ID to filter the traffic. + timeout : int, optional + Timeout for the monitor traffic command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | int] = {"duration": 60} + if port_id: + body["port"] = port_id + trigger = devices.monitorSiteDeviceTraffic( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Monitor traffic command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startSessionUrl(trigger.data.get("url", "")) + else: + LOGGER.error( + f"Failed to trigger monitor traffic command: {trigger.status_code} - {trigger.data}" + ) # Give the monitor traffic command a moment to take effect + return util_response + + +async def clear_policy_hit_count( + apissession: _APISession, + site_id: str, + device_id: str, + policy_name: str, + timeout=30, +) -> UtilResponse: + """ + For EX and SRX Only. Clears the policy hit count on the device. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to clear the policy hit count on. + policy_name : str + Name of the policy to clear the hit count for. + timeout : int, optional + Timeout for the clear policy hit count command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str] = {} + if policy_name: + body["policy_name"] = policy_name + trigger = devices.clearSiteDevicePolicyHitCount( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear policy hit count command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger clear policy hit count command: {trigger.status_code} - {trigger.data}" + ) # Give the clear policy hit count command a moment to take effect + return util_response diff --git a/src/mistapi/websockets/utils/switch.py b/src/mistapi/websockets/utils/switch.py new file mode 100644 index 0000000..8419b19 --- /dev/null +++ b/src/mistapi/websockets/utils/switch.py @@ -0,0 +1,570 @@ +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.websockets.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +class Node(Enum): + NODE0 = "node0" + NODE1 = "node1" + + +class RouteProtocol(Enum): + ANY = "any" + BGP = "bgp" + DIRECT = "direct" + EVPN = "evpn" + OSPF = "ospf" + STATIC = "static" + + +async def bounce_ports( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + timeout=5, +) -> UtilResponse: + """ + Initiates a bounce command on the specified ports of a device and streams the results. + + PARAMS + ----------- + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to perform the bounce command on. + port_ids : list[str] + List of port IDs to bounce. + timeout : int, async default 5 + Timeout for the bounce command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if port_ids: + body["ports"] = port_ids + trigger = devices.bounceDevicePort( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info( + f"Bounce command triggered for ports {port_ids} on device {device_id}" + ) + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id=site_id, device_id=device_id) + else: + LOGGER.error( + f"Failed to trigger bounce command: {trigger.status_code} - {trigger.data}" + ) # Give the bounce command a moment to take effect + return util_response + + +async def retrieve_arp_table( + apissession: _APISession, + site_id: str, + device_id: str, + ip: str | None = None, + port_id: str | None = None, + vrf: str | None = None, + timeout=5, +) -> UtilResponse: + """ + Retrieve the ARP table from a device with optional filters for IP, port, and VRF. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve the ARP table from. + ip : str, optional + IP address to filter the ARP table. + port_id : str, optional + Port ID to filter the ARP table. + vrf : str, optional + VRF to filter the ARP table. + timeout : int, optional + Timeout for the ARP table retrieval command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"duration": 1, "interval": 1} + if ip: + body["ip"] = ip + if vrf: + body["vrf"] = vrf + if port_id: + body["port_id"] = port_id + trigger = devices.showSiteDeviceArpTable( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Show ARP command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" + ) # Give the show ARP command a moment to take effect + return util_response + + +####################################################### +## Switch +####################################################### + + +async def switch_clear_bpdu_error( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + Clears BPDU error state on the specified ports of a switch. + + PARAMS + ----------- + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to clear BPDU errors on. + port_ids : list[str] + List of port IDs to clear BPDU errors on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearBpduErrorsFromPortsOnSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear BPDU error command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear BPDU error command: {trigger.status_code} - {trigger.data}" + ) # Give the clear BPDU error command a moment to take effect + return util_response + + +async def switch_clear_learned_mac( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + Clears learned MAC addresses on the specified ports of a device. + + PARAMS + ----------- + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to clear learned MAC addresses on. + port_ids : list[str] + List of port IDs to clear learned MAC addresses on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearSiteDeviceDot1xSession( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear learned MACs command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" + ) # Give the clear learned MACs command a moment to take effect + return util_response + + +async def switch_clear_dot1x_sessions( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + Clears dot1x sessions on the specified ports of a switch. + + PARAMS + ----------- + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to clear dot1x sessions on. + port_ids : list[str] + List of port IDs to clear dot1x sessions on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearAllLearnedMacsFromPortOnSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear learned MACs command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" + ) # Give the clear learned MACs command a moment to take effect + return util_response + + +####################################################### +## Websocket +####################################################### + + +async def clear_mac_table( + apissession: _APISession, + site_id: str, + device_id: str, + mac_address: str, + port_id: str, + vlan_id: str, + timeout=5, +) -> UtilResponse: + """ + Clears the MAC table on a device for a specific MAC address, port, or VLAN and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to clear the MAC table on. + mac_address : str + MAC address to clear from the MAC table. + port_id : str + Port ID to clear the MAC table on. + vlan_id : str + VLAN ID to clear the MAC table on. + timeout : int, optional + Timeout for the clear MAC table command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if mac_address: + body["mac_address"] = mac_address + if port_id: + body["port_id"] = port_id + if vlan_id: + body["vlan_id"] = vlan_id + trigger = devices.clearSiteDeviceMacTable( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Clear MAC table command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger clear MAC table command: {trigger.status_code} - {trigger.data}" + ) # Give the clear MAC table command a moment to take effect + return util_response + + +async def ping( + apissession: _APISession, + site_id: str, + device_id: str, + host: str, + count: int | None = None, + node: None | None = None, + size: int | None = None, + vrf: str | None = None, + timeout: int = 5, +) -> UtilResponse: + """ + Initiates a ping command from a device to a specified host and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to initiate the ping from. + host : str + The host to ping. + count : int, optional + Number of ping requests to send. + node : None, optional + Node information for the ping command. + size : int, optional + Size of the ping packet. + vrf : str, optional + VRF to use for the ping command. + timeout : int, optional + Timeout for the ping command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if count: + body["count"] = count + if host: + body["host"] = host + if node: + body["node"] = node.value + if size: + body["size"] = size + if vrf: + body["vrf"] = vrf + trigger = devices.pingFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Ping command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger ping command: {trigger.status_code} - {trigger.data}" + ) # Give the ping command a moment to take effect + return util_response + + +async def release_dhcp_leases( + apissession: _APISession, + site_id: str, + device_id: str, + macs: list[str] | None = None, + network: str | None = None, + node: Node | None = None, + port_id: str | None = None, + timeout=5, +) -> UtilResponse: + """ + Releases DHCP leases on a device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to release DHCP leases on. + macs : list[str], optional + List of MAC addresses to release DHCP leases for. + network : str, optional + Network to release DHCP leases for. + node : Node, optional + Node information for the DHCP lease release command. + port_id : str, optional + Port ID to release DHCP leases for. + timeout : int, optional + Timeout for the release DHCP leases command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if macs: + body["macs"] = macs + if network: + body["network"] = network + if node: + body["node"] = node.value + if port_id: + body["port_id"] = port_id + trigger = devices.releaseSiteDeviceDhcpLease( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Release DHCP leases command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger release DHCP leases command: {trigger.status_code} - {trigger.data}" + ) # Give the release DHCP leases command a moment to take effect + return util_response + + +async def stream_arp_table( + apissession: _APISession, + site_id: str, + device_id: str, + ip: str | None = None, + port_id: str | None = None, + vrf: str | None = None, + timeout=5, +) -> UtilResponse: + """ + Streams the ARP table from a device with optional filters for IP, port, and VRF. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve the ARP table from. + ip : str, optional + IP address to filter the ARP table. + port_id : str, optional + Port ID to filter the ARP table. + vrf : str, optional + VRF to filter the ARP table. + timeout : int, optional + Timeout for the ARP table retrieval command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"duration": 1, "interval": 1} + if ip: + body["ip"] = ip + if vrf: + body["vrf"] = vrf + if port_id: + body["port_id"] = port_id + trigger = devices.showSiteDeviceArpTable( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Show ARP command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" + ) # Give the show ARP command a moment to take effect + return util_response + + +async def switch_cable_test( + apissession: _APISession, + site_id: str, + device_id: str, + port_id: str, + timeout=10, +) -> UtilResponse: + """ + Initiates a cable test on a switch port and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to perform the cable test on. + port_id : str + Port ID to perform the cable test on. + timeout : int, optional + Timeout for the cable test command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"port": port_id} + trigger = devices.cableTestFromSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Cable test command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger cable test command: {trigger.status_code} - {trigger.data}" + ) # Give the cable test command a moment to take effect + return util_response diff --git a/test.py b/test.py new file mode 100644 index 0000000..e81f846 --- /dev/null +++ b/test.py @@ -0,0 +1,46 @@ +import asyncio + +import src.mistapi as mistapi + +# APISESSION = mistapi.APISession(env_file="~/.mist_env_ld_ro", show_cli_notif=False) +# ORG_ID = "9777c1a0-6ef6-11e6-8bbf-02e208b2d34f" +# SITE_ID = "a925ea04-8393-4e0f-ab6b-209f11382cee" +# AP_ID = "00000000-0000-0000-1000-04a92439fb75" +# SWITCH_ID = "00000000-0000-0000-1000-2093390b3580" +# GATEWAY_ID = "00000000-0000-0000-1000-409ea4e60b00" + +APISESSION = mistapi.APISession(env_file="~/.mist_env_gc1", show_cli_notif=False) +ORG_ID = "8aa21779-1178-4357-b3e0-42c02b93b870" +SITE_ID = "d6fb4f96-3ba4-4cf5-8af2-a8d7b85087ac" +AP_ID = "00000000-0000-0000-1000-04a92439fb75" +SWITCH_ID = "00000000-0000-0000-1000-2093390b3580" +GATEWAY_ID = "00000000-0000-0000-1000-0200010edbca" + +APISESSION.login() + +# data = asyncio.run( +# mistapi.websockets.utils.common.bounce_ports( +# apissession=APISESSION, +# site_id=SITE_ID, +# device_id=GATEWAY_ID, +# port_ids=["ge-0/0/3"], +# ) +# ) + + +data = asyncio.run( + mistapi.websockets.utils.junos.monitor_traffic( + apissession=APISESSION, + site_id=SITE_ID, + device_id=SWITCH_ID, + ) +) +print(data.trigger_api_response.data) +print("".center(50, "-")) +if data.ws_required: + if isinstance(data.ws_data, list): + print("".join(data.ws_data)) + else: + print(data.ws_data) +else: + print("No WebSocket data available.") From 08dcf7b46c0a5466df4003a26b2d21a0aceb1f1a Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Thu, 12 Mar 2026 21:56:42 +0100 Subject: [PATCH 06/16] websocket tools added --- .gitmodules | 2 +- mist_openapi | 2 +- pyproject.toml | 21 +- src/mistapi/__init__.py | 1 + src/mistapi/__version.py | 2 +- src/mistapi/api/v1/orgs/alarms.py | 69 +- src/mistapi/api/v1/orgs/devices.py | 2 +- src/mistapi/api/v1/orgs/inventory.py | 19 +- src/mistapi/api/v1/orgs/jsi.py | 40 +- src/mistapi/api/v1/orgs/stats.py | 2 +- src/mistapi/api/v1/sites/__init__.py | 2 + src/mistapi/api/v1/sites/alarms.py | 67 +- src/mistapi/api/v1/sites/devices.py | 15 +- src/mistapi/api/v1/sites/insights.py | 8 +- src/mistapi/api/v1/sites/mapstacks.py | 88 +++ src/mistapi/api/v1/sites/sle.py | 4 +- src/mistapi/api/v1/sites/stats.py | 2 +- src/mistapi/utils/__init__.py | 92 +++ src/mistapi/utils/__ws_wrapper.py | 309 ++++++++ src/mistapi/utils/ap.py | 29 + src/mistapi/utils/arp.py | 204 +++++ src/mistapi/utils/bgp.py | 63 ++ src/mistapi/utils/bpdu.py | 61 ++ src/mistapi/utils/dhcp.py | 162 ++++ src/mistapi/utils/dns.py | 84 +++ src/mistapi/utils/dot1x.py | 60 ++ src/mistapi/utils/ex.py | 78 ++ src/mistapi/utils/mac.py | 192 +++++ src/mistapi/utils/ospf.py | 273 +++++++ src/mistapi/utils/policy.py | 62 ++ src/mistapi/utils/port.py | 122 +++ src/mistapi/utils/routes.py | 107 +++ src/mistapi/utils/service_path.py | 84 +++ src/mistapi/utils/sessions.py | 162 ++++ src/mistapi/utils/srx.py | 61 ++ src/mistapi/utils/ssr.py | 65 ++ src/mistapi/utils/tools.py | 740 +++++++++++++++++++ src/mistapi/websockets/__init__.py | 21 +- src/mistapi/websockets/sites.py | 8 +- src/mistapi/websockets/utils/__init__.py | 1 - src/mistapi/websockets/utils/__ws_wrapper.py | 155 ---- src/mistapi/websockets/utils/common.py | 545 -------------- src/mistapi/websockets/utils/gateway.py | 249 ------- src/mistapi/websockets/utils/junos.py | 110 --- src/mistapi/websockets/utils/switch.py | 570 -------------- uv.lock | 2 +- 46 files changed, 3264 insertions(+), 1753 deletions(-) create mode 100644 src/mistapi/api/v1/sites/mapstacks.py create mode 100644 src/mistapi/utils/__init__.py create mode 100644 src/mistapi/utils/__ws_wrapper.py create mode 100644 src/mistapi/utils/ap.py create mode 100644 src/mistapi/utils/arp.py create mode 100644 src/mistapi/utils/bgp.py create mode 100644 src/mistapi/utils/bpdu.py create mode 100644 src/mistapi/utils/dhcp.py create mode 100644 src/mistapi/utils/dns.py create mode 100644 src/mistapi/utils/dot1x.py create mode 100644 src/mistapi/utils/ex.py create mode 100644 src/mistapi/utils/mac.py create mode 100644 src/mistapi/utils/ospf.py create mode 100644 src/mistapi/utils/policy.py create mode 100644 src/mistapi/utils/port.py create mode 100644 src/mistapi/utils/routes.py create mode 100644 src/mistapi/utils/service_path.py create mode 100644 src/mistapi/utils/sessions.py create mode 100644 src/mistapi/utils/srx.py create mode 100644 src/mistapi/utils/ssr.py create mode 100644 src/mistapi/utils/tools.py delete mode 100644 src/mistapi/websockets/utils/__init__.py delete mode 100644 src/mistapi/websockets/utils/__ws_wrapper.py delete mode 100644 src/mistapi/websockets/utils/common.py delete mode 100644 src/mistapi/websockets/utils/gateway.py delete mode 100644 src/mistapi/websockets/utils/junos.py delete mode 100644 src/mistapi/websockets/utils/switch.py diff --git a/.gitmodules b/.gitmodules index 4088015..081acbd 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "mist_openapi"] path = mist_openapi url = https://github.com/mistsys/mist_openapi.git - branch = 2602.1.2 \ No newline at end of file + branch = master \ No newline at end of file diff --git a/mist_openapi b/mist_openapi index b6718f7..c0a88a3 160000 --- a/mist_openapi +++ b/mist_openapi @@ -1 +1 @@ -Subproject commit b6718f784f96c6bf9fea1fae241d1276fe2140ff +Subproject commit c0a88a3c79e42d233ea45a92ffffd13f968a2a6b diff --git a/pyproject.toml b/pyproject.toml index 0430154..d4f79ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "mistapi" -version = "0.60.3" +version = "0.55.15" authors = [{ name = "Thomas Munzer", email = "tmunzer@juniper.net" }] description = "Python package to simplify the Mist System APIs usage" keywords = ["Mist", "Juniper", "API"] @@ -35,14 +35,17 @@ dependencies = [ "Bug Tracker" = "https://github.com/tmunzer/mistapi_python/issues" # UV-specific configuration -[tool.uv.scripts] -test = "pytest" -lint = "ruff check src/" -fmt = "ruff format src/" -build = "python -m build" - -[tool.uv] -dev-dependencies = [ +#[tool.uv] +#preview = true + +#[tool.uv.scripts] +#test = "pytest" +#lint = "ruff check src/" +#fmt = "ruff format src/" +#build = "python -m build" + +[dependency-groups] +dev = [ # Testing dependencies "pytest>=8.4.0", "pytest-cov>=6.1.1", diff --git a/src/mistapi/__init__.py b/src/mistapi/__init__.py index 9d04e29..6a18153 100644 --- a/src/mistapi/__init__.py +++ b/src/mistapi/__init__.py @@ -15,6 +15,7 @@ from mistapi import api as api from mistapi import cli as cli from mistapi import websockets as websockets +from mistapi import utils as utils from mistapi.__pagination import get_all as get_all from mistapi.__pagination import get_next as get_next from mistapi.__version import __author__ as __author__ diff --git a/src/mistapi/__version.py b/src/mistapi/__version.py index 4896894..f9533d2 100644 --- a/src/mistapi/__version.py +++ b/src/mistapi/__version.py @@ -1,2 +1,2 @@ -__version__ = "0.60.3" +__version__ = "0.55.15" __author__ = "Thomas Munzer " diff --git a/src/mistapi/api/v1/orgs/alarms.py b/src/mistapi/api/v1/orgs/alarms.py index f929d73..19b4599 100644 --- a/src/mistapi/api/v1/orgs/alarms.py +++ b/src/mistapi/api/v1/orgs/alarms.py @@ -144,38 +144,43 @@ def searchOrgAlarms( search_after: str | None = None, ) -> _APIResponse: """ - API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/orgs/alarms/search-org-alarms - - PARAMS - ----------- - mistapi.APISession : mist_session - mistapi session including authentication and Mist host information - - PATH PARAMS - ----------- - org_id : str - - QUERY PARAMS - ------------ - site_id : str - group : str{'infrastructure', 'marvis', 'security'} - Alarm group. enum: `infrastructure`, `marvis`, `security` - severity : str{'critical', 'info', 'warn'} - Severity of the alarm. enum: `critical`, `info`, `warn` - type : str - ack_admin_name : str - acked : bool - start : str - end : str - duration : str, default: 1d - limit : int, default: 100 - sort : str, default: timestamp - search_after : str - - RETURN - ----------- - mistapi.APIResponse - response from the API call + API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/orgs/alarms/search-org-alarms + + PARAMS + ----------- + mistapi.APISession : mist_session + mistapi session including authentication and Mist host information + + PATH PARAMS + ----------- + org_id : str + + QUERY PARAMS + ------------ + site_id : str + group : str{'infrastructure', 'marvis', 'security'} + Alarm group. enum: `infrastructure`, `marvis`, `security`. + The `marvis` group is used to retrieve AI-driven network issue detections. + Known Marvis alarm types include: `bad_cable`, `bad_wan_uplink`, `dns_failure`, + `arp_failure`, `auth_failure`, `dhcp_failure`, `missing_vlan`, + `negotiation_mismatch`, `port_flap`. Results include resolution status + (`status`, `resolved_time`) and affected entity details." + severity : str{'critical', 'info', 'warn'} + Severity of the alarm. enum: `critical`, `info`, `warn` + type : str + ack_admin_name : str + acked : bool + start : str + end : str + duration : str, default: 1d + limit : int, default: 100 + sort : str, default: timestamp + search_after : str + + RETURN + ----------- + mistapi.APIResponse + response from the API call """ uri = f"/api/v1/orgs/{org_id}/alarms/search" diff --git a/src/mistapi/api/v1/orgs/devices.py b/src/mistapi/api/v1/orgs/devices.py index ab2e9a0..ac09149 100644 --- a/src/mistapi/api/v1/orgs/devices.py +++ b/src/mistapi/api/v1/orgs/devices.py @@ -260,7 +260,7 @@ def searchOrgDeviceEvents( ------------ mac : str model : str - device_type : str{'all', 'ap', 'gateway', 'switch'}, default: ap + device_type : str, default: ap text : str timestamp : str type : str diff --git a/src/mistapi/api/v1/orgs/inventory.py b/src/mistapi/api/v1/orgs/inventory.py index b79f69a..85103b0 100644 --- a/src/mistapi/api/v1/orgs/inventory.py +++ b/src/mistapi/api/v1/orgs/inventory.py @@ -316,8 +316,8 @@ def searchOrgInventory( org_id: str, type: str | None = None, mac: str | None = None, - vc_mac: str | None = None, - master_mac: str | None = None, + model: str | None = None, + name: str | None = None, site_id: str | None = None, serial: str | None = None, master: str | None = None, @@ -345,14 +345,15 @@ def searchOrgInventory( ------------ type : str{'ap', 'gateway', 'switch'}, default: ap mac : str - vc_mac : str - master_mac : str + model : str + name : str site_id : str serial : str master : str sku : str version : str - status : str + status : str{'connected', 'disconnected'} + Device status. enum: `connected`, `disconnected` text : str limit : int, default: 100 sort : str, default: timestamp @@ -370,10 +371,10 @@ def searchOrgInventory( query_params["type"] = str(type) if mac: query_params["mac"] = str(mac) - if vc_mac: - query_params["vc_mac"] = str(vc_mac) - if master_mac: - query_params["master_mac"] = str(master_mac) + if model: + query_params["model"] = str(model) + if name: + query_params["name"] = str(name) if site_id: query_params["site_id"] = str(site_id) if serial: diff --git a/src/mistapi/api/v1/orgs/jsi.py b/src/mistapi/api/v1/orgs/jsi.py index a950555..dde8c0b 100644 --- a/src/mistapi/api/v1/orgs/jsi.py +++ b/src/mistapi/api/v1/orgs/jsi.py @@ -245,9 +245,15 @@ def searchOrgJsiAssetsAndContracts( sku: str | None = None, status: str | None = None, warranty_type: str | None = None, - eol_duration: str | None = None, - eos_duration: str | None = None, + eol_after: str | None = None, + eol_before: str | None = None, + eos_after: str | None = None, + eos_before: str | None = None, + version_eos_after: str | None = None, + version_eos_before: str | None = None, has_support: bool | None = None, + sirt_id: str | None = None, + pbn_id: str | None = None, text: str | None = None, limit: int | None = None, sort: str | None = None, @@ -275,9 +281,15 @@ def searchOrgJsiAssetsAndContracts( Device status warranty_type : str{'Standard Hardware Warranty', 'Enhanced Hardware Warranty', 'Dead On Arrival Warranty', 'Limited Lifetime Warranty', 'Software Warranty', 'Limited Lifetime Warranty for WLA', 'Warranty-JCPO EOL (DOA Not Included)', 'MIST Enhanced Hardware Warranty', 'MIST Standard Warranty', 'Determine Lifetime warranty'} Device warranty type - eol_duration : str - eos_duration : str + eol_after : str + eol_before : str + eos_after : str + eos_before : str + version_eos_after : str + version_eos_before : str has_support : bool + sirt_id : str + pbn_id : str text : str limit : int, default: 100 sort : str, default: timestamp @@ -303,12 +315,24 @@ def searchOrgJsiAssetsAndContracts( query_params["status"] = str(status) if warranty_type: query_params["warranty_type"] = str(warranty_type) - if eol_duration: - query_params["eol_duration"] = str(eol_duration) - if eos_duration: - query_params["eos_duration"] = str(eos_duration) + if eol_after: + query_params["eol_after"] = str(eol_after) + if eol_before: + query_params["eol_before"] = str(eol_before) + if eos_after: + query_params["eos_after"] = str(eos_after) + if eos_before: + query_params["eos_before"] = str(eos_before) + if version_eos_after: + query_params["version_eos_after"] = str(version_eos_after) + if version_eos_before: + query_params["version_eos_before"] = str(version_eos_before) if has_support: query_params["has_support"] = str(has_support) + if sirt_id: + query_params["sirt_id"] = str(sirt_id) + if pbn_id: + query_params["pbn_id"] = str(pbn_id) if text: query_params["text"] = str(text) if limit: diff --git a/src/mistapi/api/v1/orgs/stats.py b/src/mistapi/api/v1/orgs/stats.py index 223f861..51e8914 100644 --- a/src/mistapi/api/v1/orgs/stats.py +++ b/src/mistapi/api/v1/orgs/stats.py @@ -410,7 +410,7 @@ def listOrgDevicesStats( QUERY PARAMS ------------ - type : str{'all', 'ap', 'gateway', 'switch'}, default: ap + type : str, default: ap status : str{'all', 'connected', 'disconnected'}, default: all site_id : str mac : str diff --git a/src/mistapi/api/v1/sites/__init__.py b/src/mistapi/api/v1/sites/__init__.py index 12dc829..8ebdbb5 100644 --- a/src/mistapi/api/v1/sites/__init__.py +++ b/src/mistapi/api/v1/sites/__init__.py @@ -34,6 +34,7 @@ licenses, location, maps, + mapstacks, mxedges, mxtunnels, nac_clients, @@ -99,6 +100,7 @@ "licenses", "location", "maps", + "mapstacks", "mxedges", "mxtunnels", "nac_clients", diff --git a/src/mistapi/api/v1/sites/alarms.py b/src/mistapi/api/v1/sites/alarms.py index 6c7616c..dcc0ab9 100644 --- a/src/mistapi/api/v1/sites/alarms.py +++ b/src/mistapi/api/v1/sites/alarms.py @@ -164,37 +164,42 @@ def searchSiteAlarms( search_after: str | None = None, ) -> _APIResponse: """ - API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/alarms/search-site-alarms - - PARAMS - ----------- - mistapi.APISession : mist_session - mistapi session including authentication and Mist host information - - PATH PARAMS - ----------- - site_id : str - - QUERY PARAMS - ------------ - group : str{'infrastructure', 'marvis', 'security'} - Alarm group. enum: `infrastructure`, `marvis`, `security` - severity : str{'critical', 'info', 'warn'} - Severity of the alarm. enum: `critical`, `info`, `warn` - type : str - ack_admin_name : str - acked : bool - limit : int, default: 100 - start : str - end : str - duration : str, default: 1d - sort : str, default: timestamp - search_after : str - - RETURN - ----------- - mistapi.APIResponse - response from the API call + API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/alarms/search-site-alarms + + PARAMS + ----------- + mistapi.APISession : mist_session + mistapi session including authentication and Mist host information + + PATH PARAMS + ----------- + site_id : str + + QUERY PARAMS + ------------ + group : str{'infrastructure', 'marvis', 'security'} + Alarm group. enum: `infrastructure`, `marvis`, `security`. + The `marvis` group is used to retrieve AI-driven network issue detections. + Known Marvis alarm types include: `bad_cable`, `bad_wan_uplink`, `dns_failure`, + `arp_failure`, `auth_failure`, `dhcp_failure`, `missing_vlan`, + `negotiation_mismatch`, `port_flap`. Results include resolution status + (`status`, `resolved_time`) and affected entity details." + severity : str{'critical', 'info', 'warn'} + Severity of the alarm. enum: `critical`, `info`, `warn` + type : str + ack_admin_name : str + acked : bool + limit : int, default: 100 + start : str + end : str + duration : str, default: 1d + sort : str, default: timestamp + search_after : str + + RETURN + ----------- + mistapi.APIResponse + response from the API call """ uri = f"/api/v1/sites/{site_id}/alarms/search" diff --git a/src/mistapi/api/v1/sites/devices.py b/src/mistapi/api/v1/sites/devices.py index 34b2294..6fcd1b9 100644 --- a/src/mistapi/api/v1/sites/devices.py +++ b/src/mistapi/api/v1/sites/devices.py @@ -36,7 +36,7 @@ def listSiteDevices( QUERY PARAMS ------------ - type : str{'all', 'ap', 'gateway', 'switch'}, default: ap + type : str, default: ap name : str limit : int, default: 100 page : int, default: 1 @@ -1734,7 +1734,7 @@ def clearSiteDevicePendingVersion( def clearSiteDevicePolicyHitCount( - mist_session: _APISession, site_id: str, device_id: str, body: dict + mist_session: _APISession, site_id: str, device_id: str, body: dict | list ) -> _APIResponse: """ API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/utilities/common/clear-site-device-policy-hit-count @@ -1749,6 +1749,11 @@ def clearSiteDevicePolicyHitCount( site_id : str device_id : str + BODY PARAMS + ----------- + body : dict + JSON object to send to Mist Cloud (see API doc above for more details) + RETURN ----------- mistapi.APIResponse @@ -2375,7 +2380,7 @@ def getSiteDeviceZtpPassword( def testSiteSsrDnsResolution( - mist_session: _APISession, site_id: str, device_id: str, body: dict + mist_session: _APISession, site_id: str, device_id: str ) -> _APIResponse: """ API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/utilities/wan/test-site-ssr-dns-resolution @@ -2397,7 +2402,7 @@ def testSiteSsrDnsResolution( """ uri = f"/api/v1/sites/{site_id}/devices/{device_id}/resolve_dns" - resp = mist_session.mist_post(uri=uri, body=body) + resp = mist_session.mist_post(uri=uri) return resp @@ -2587,7 +2592,7 @@ def showSiteDeviceArpTable( mist_session: _APISession, site_id: str, device_id: str, body: dict | list ) -> _APIResponse: """ - API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/utilities/common/show-site-device-arp-table + API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/utilities/lan/show-site-device-arp-table PARAMS ----------- diff --git a/src/mistapi/api/v1/sites/insights.py b/src/mistapi/api/v1/sites/insights.py index 9741b2c..3033df7 100644 --- a/src/mistapi/api/v1/sites/insights.py +++ b/src/mistapi/api/v1/sites/insights.py @@ -132,7 +132,7 @@ def getSiteInsightMetricsForDevice( return resp -def countOrgClientFingerprints( +def countSiteClientFingerprints( mist_session: _APISession, site_id: str, distinct: str | None = None, @@ -142,7 +142,7 @@ def countOrgClientFingerprints( limit: int | None = None, ) -> _APIResponse: """ - API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/orgs/nac-fingerprints/count-org-client-fingerprints + API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/nac-fingerprints/count-site-client-fingerprints PARAMS ----------- @@ -183,7 +183,7 @@ def countOrgClientFingerprints( return resp -def searchOrgClientFingerprints( +def searchSiteClientFingerprints( mist_session: _APISession, site_id: str, family: str | None = None, @@ -202,7 +202,7 @@ def searchOrgClientFingerprints( search_after: str | None = None, ) -> _APIResponse: """ - API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/orgs/nac-fingerprints/search-org-client-fingerprints + API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/nac-fingerprints/search-site-client-fingerprints PARAMS ----------- diff --git a/src/mistapi/api/v1/sites/mapstacks.py b/src/mistapi/api/v1/sites/mapstacks.py new file mode 100644 index 0000000..33805d4 --- /dev/null +++ b/src/mistapi/api/v1/sites/mapstacks.py @@ -0,0 +1,88 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__api_response import APIResponse as _APIResponse + + +def listSiteMapStacks( + mist_session: _APISession, + site_id: str, + limit: int | None = None, + page: int | None = None, + name: str | None = None, +) -> _APIResponse: + """ + API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/map-stacks/list-site-map-stacks + + PARAMS + ----------- + mistapi.APISession : mist_session + mistapi session including authentication and Mist host information + + PATH PARAMS + ----------- + site_id : str + + QUERY PARAMS + ------------ + limit : int, default: 100 + page : int, default: 1 + name : str + + RETURN + ----------- + mistapi.APIResponse + response from the API call + """ + + uri = f"/api/v1/sites/{site_id}/mapstacks" + query_params: dict[str, str] = {} + if limit: + query_params["limit"] = str(limit) + if page: + query_params["page"] = str(page) + if name: + query_params["name"] = str(name) + resp = mist_session.mist_get(uri=uri, query=query_params) + return resp + + +def createSiteMapStack( + mist_session: _APISession, site_id: str, body: dict | list +) -> _APIResponse: + """ + API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/map-stacks/create-site-map-stack + + PARAMS + ----------- + mistapi.APISession : mist_session + mistapi session including authentication and Mist host information + + PATH PARAMS + ----------- + site_id : str + + BODY PARAMS + ----------- + body : dict + JSON object to send to Mist Cloud (see API doc above for more details) + + RETURN + ----------- + mistapi.APIResponse + response from the API call + """ + + uri = f"/api/v1/sites/{site_id}/mapstacks" + resp = mist_session.mist_post(uri=uri, body=body) + return resp diff --git a/src/mistapi/api/v1/sites/sle.py b/src/mistapi/api/v1/sites/sle.py index 89c6810..b5d4c60 100644 --- a/src/mistapi/api/v1/sites/sle.py +++ b/src/mistapi/api/v1/sites/sle.py @@ -18,7 +18,7 @@ @deprecation.deprecated( deprecated_in="0.59.2", removed_in="0.65.0", - current_version="0.60.3", + current_version="0.55.15", details="function replaced with getSiteSleClassifierSummaryTrend", ) def getSiteSleClassifierDetails( @@ -741,7 +741,7 @@ def listSiteSleImpactedWirelessClients( @deprecation.deprecated( deprecated_in="0.59.2", removed_in="0.65.0", - current_version="0.60.3", + current_version="0.55.15", details="function replaced with getSiteSleSummaryTrend", ) def getSiteSleSummary( diff --git a/src/mistapi/api/v1/sites/stats.py b/src/mistapi/api/v1/sites/stats.py index 7dbd2c7..e2098dd 100644 --- a/src/mistapi/api/v1/sites/stats.py +++ b/src/mistapi/api/v1/sites/stats.py @@ -949,7 +949,7 @@ def listSiteDevicesStats( QUERY PARAMS ------------ - type : str{'all', 'ap', 'gateway', 'switch'}, default: ap + type : str, default: ap status : str{'all', 'connected', 'disconnected'}, default: all limit : int, default: 100 page : int, default: 1 diff --git a/src/mistapi/utils/__init__.py b/src/mistapi/utils/__init__.py new file mode 100644 index 0000000..4ff6e43 --- /dev/null +++ b/src/mistapi/utils/__init__.py @@ -0,0 +1,92 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Mist API Utilities +================== + +This package provides utility functions for interacting with Mist devices. + +Device-Specific Modules (Recommended) +-------------------------------------- +Import device-specific modules for a clean, organized API: + + from mistapi.utils import ap, ex, srx, ssr + + # Use device-specific utilities + await ap.ping(session, site_id, device_id, host) + await ex.cable_test(session, site_id, device_id, port_id) + await ssr.show_service_path(session, site_id, device_id) + +Supported Devices: +- ap: Mist Access Points +- ex: Juniper EX Switches +- srx: Juniper SRX Firewalls +- ssr: Juniper Session Smart Routers + +Function-Based Modules (Legacy) +--------------------------------- +Original organization by function type (still available): + + from mistapi.utils import arp, bgp, dhcp, mac, port, routes, tools + +Available modules: arp, bgp, bpdu, dhcp, dns, dot1x, mac, policy, port, routes, + service_path, tools +""" + +# Device-specific modules (recommended) +# Function-based modules (legacy, still supported) +# Internal modules +from mistapi.utils import ( + __ws_wrapper, + ap, + arp, + bgp, + bpdu, + dhcp, + dns, + dot1x, + ex, + mac, + ospf, + policy, + port, + routes, + service_path, + sessions, + srx, + ssr, + tools, +) + +__all__ = [ + # Device-specific modules (recommended) + "ap", + "ex", + "srx", + "ssr", + # Function-based modules (legacy) + "arp", + "bgp", + "bpdu", + "dhcp", + "dns", + "dot1x", + "mac", + "ospf", + "policy", + "port", + "routes", + "service_path", + "sessions", + "tools", + # Internal + "__ws_wrapper", +] diff --git a/src/mistapi/utils/__ws_wrapper.py b/src/mistapi/utils/__ws_wrapper.py new file mode 100644 index 0000000..1e3c331 --- /dev/null +++ b/src/mistapi/utils/__ws_wrapper.py @@ -0,0 +1,309 @@ +import json +import threading +import time +from enum import Enum + +from mistapi import APISession +from mistapi.__api_response import APIResponse as _APIResponse +from mistapi.__logger import logger as LOGGER +from mistapi.websockets.session import SessionWithUrl +from mistapi.websockets.sites import DeviceCmdEvents, PcapEvents + + +class TimerAction(Enum): + START = "start" + STOP = "stop" + RESET = "reset" + + +class Timer(Enum): + TIMEOUT = "timeout" + FIRST_MESSAGE_TIMEOUT = "first_message_timeout" + MAX_DURATION = "max_duration" + + +class UtilResponse: + """ + A simple class to encapsulate the response from utility WebSocket functions. + This class can be extended in the future to include additional metadata or helper methods. + """ + + def __init__( + self, + api_response: _APIResponse, + ) -> None: + self.trigger_api_response = api_response + self.ws_required: bool = False # This can be set to True if the WebSocket connection was successfully initiated + self.ws_data: list[str] = [] + self.ws_raw_events: list[str] = [] + + +class WebSocketWrapper: + """ + A wrapper class for managing WebSocket connections and events. + This class provides a simplified interface for connecting to WebSocket channels, + handling messages, and managing connection timeouts. + """ + + def __init__( + self, + apissession: APISession, + util_response: UtilResponse, + timeout: int = 10, + max_duration: int = 60, + ) -> None: + self.apissession = apissession + self.util_response = util_response + self.timers = { + Timer.TIMEOUT.value: { + "thread": None, + "duration": timeout, + }, + Timer.FIRST_MESSAGE_TIMEOUT.value: { + "thread": None, + "duration": 30, + }, + Timer.MAX_DURATION.value: { + "thread": None, + "duration": max_duration, + }, + } + self.received_messages = 0 + self.data = [] + self.raw_events = [] + self.ws = None + self.session_id: str | None = None + self.capture_id: str | None = None + + LOGGER.debug( + "trigger response: %s", self.util_response.trigger_api_response.data + ) + if self.util_response.trigger_api_response.data and isinstance( + self.util_response.trigger_api_response.data, dict + ): + self.session_id = self.util_response.trigger_api_response.data.get( + "session", None + ) + self.capture_id = self.util_response.trigger_api_response.data.get( + "id", None + ) + LOGGER.debug("Extracted session_id: %s", self.session_id) + LOGGER.debug("Extracted capture_id: %s", self.capture_id) + + def _on_open(self): + LOGGER.info("WebSocket connection opened") + # Start the max duration timer + self._timeout_handler(Timer.MAX_DURATION, TimerAction.START) + # self._reset_timer() # Start the timer when the connection opens + + #################################################################################################################### + ## Helper methods for managing timers + def _timeout_handler(self, timer_type: Timer, action: TimerAction): + duration = self.timers[timer_type.value]["duration"] + if action == TimerAction.STOP or action == TimerAction.RESET: + if self.timers[timer_type.value]["thread"]: + LOGGER.debug("Stopping %s timer", timer_type.value) + self.timers[timer_type.value]["thread"].cancel() + self.timers[timer_type.value]["thread"] = None + elif action == TimerAction.STOP: + # Only warn when explicitly stopping (not resetting) a non-active timer + LOGGER.warning("%s timer is not active to stop", timer_type.value) + if action == TimerAction.START or action == TimerAction.RESET: + if self.ws: + LOGGER.debug( + "Starting %s timer with duration: %s seconds", + timer_type.value, + duration, + ) + self.timers[timer_type.value]["thread"] = threading.Timer( + duration, self.ws.disconnect + ) + self.timers[timer_type.value]["thread"].start() + else: + LOGGER.warning( + "WebSocket is not available to start %s timer", timer_type.value + ) + + #################################################################################################################### + ## WebSocket event handlers + + def _handle_message(self, msg): + if isinstance(msg, dict) and msg.get("event") == "channel_subscribed": + LOGGER.debug("channel_subscribed: %s", msg) + # Start the first message timeout timer when the channel is successfully subscribed + self._timeout_handler(Timer.FIRST_MESSAGE_TIMEOUT, TimerAction.START) + elif self._extract_session_id(msg): + # Stop the first message timeout timer on receiving the first message + self._timeout_handler(Timer.FIRST_MESSAGE_TIMEOUT, TimerAction.STOP) + LOGGER.debug("data: %s", msg) + raw = self._extract_raw(msg) + if raw: + self.data.append(raw) + self._timeout_handler(Timer.TIMEOUT, TimerAction.RESET) + + #################################################################################################################### + ## Message processing and WebSocket connection management + def _extract_session_id(self, message) -> bool: + """ + Extracts the session_id from the message and compares it to the expected session_id. + This method is designed to handle messages that may have the session_id nested at different levels. + If the expected session_id is None, it will accept all messages. + """ + if not self.session_id and not self.capture_id: + LOGGER.debug("No session_id or capture_id provided, accepting all messages") + return True + if isinstance(message, str): + LOGGER.debug("Trying to decode message: %s", message) + try: + message = json.loads(message) + except json.JSONDecodeError: + LOGGER.warning("Failed to decode message as JSON: %s", message) + return False + if isinstance(message, dict): + if message.get("event") == "data" and message.get("data"): + LOGGER.debug( + "Checking nested data for session_id or capture_id: %s", + message["data"], + ) + return self._extract_session_id(message["data"]) + if message.get("session") == self.session_id: + LOGGER.info( + "Message session_id matches expected session_id: %s", + self.session_id, + ) + return True + if message.get("capture_id") == self.capture_id: + LOGGER.info( + "Message capture_id matches expected capture_id: %s", + self.capture_id, + ) + return True + return False + + def _extract_raw(self, message): + """ + Extracts the raw message from the given message. + This method is designed to handle messages that may have the raw message nested at different levels. + Handles both command events (with "raw" field) and pcap events (with "pcap_dict" field). + """ + self.raw_events.append(message) + event = message + if isinstance(event, str): + try: + event = json.loads(message) + if isinstance(event, dict): + # Check for raw field (command events) + if "raw" in event: + LOGGER.debug("Extracted raw message: %s", event["raw"]) + return event["raw"] + # Check for pcap_dict field (pcap events) + if "pcap_dict" in event: + LOGGER.debug("Extracted pcap_dict: %s", event["pcap_dict"]) + return event["pcap_dict"] + except json.JSONDecodeError: + LOGGER.warning("Failed to decode message as JSON: %s", message) + return None + if event.get("event") == "data" and event.get("data"): + return self._extract_raw(event["data"]) + if event.get("raw"): + self.received_messages += 1 + LOGGER.debug("Received raw message: %s", event.get("raw")) + return event["raw"] + if event.get("pcap_dict"): + self.received_messages += 1 + LOGGER.debug("Received pcap data: %s", event["pcap_dict"]) + return event["pcap_dict"] + return None + + #################################################################################################################### + ## WebSocket connection management + async def startCmdEvents(self, site_id: str, device_id: str) -> UtilResponse: + """ + Start a WebSocket stream for site device command events. + + PARAMS + ----------- + site_id : str + UUID of the site to stream events from. + device_id : str + UUID of the device to stream events from. + """ + self.ws = DeviceCmdEvents( + self.apissession, site_id=site_id, device_ids=[device_id] + ) + self.ws.on_message(self._handle_message) + self.ws.on_error(lambda error: LOGGER.error(f"Error: {error}")) + self.ws.on_close( + lambda code, msg: LOGGER.info(f"WebSocket closed: {code} - {msg}") + ) + self.ws.on_open(self._on_open) + self.ws.connect() # non-blocking + LOGGER.info( + "WebSocket connection initiated: site_id=%s, device_id=%s", + site_id, + device_id, + ) + time.sleep(1) + while self.ws and self.ws.ready(): + time.sleep(1) + LOGGER.info("WebSocket connection closed, exiting") + self.util_response.ws_required = True + self.util_response.ws_data = self.data + self.util_response.ws_raw_events = self.raw_events + return self.util_response + + async def startSessionUrl(self, url: str) -> UtilResponse: + """ + Start a WebSocket stream using a custom URL. + This should be used when Mist is returning a WebSocket URL from an API call. + + PARAMS + ----------- + url : str + Full WebSocket URL to connect to (e.g., wss://api-ws.mist.com/ssh?jwt=eyJhbGciOiJI...). + """ + self.ws = SessionWithUrl(self.apissession, url=url) + self.ws.on_message(self._handle_message) + self.ws.on_error(lambda error: LOGGER.error(f"Error: {error}")) + self.ws.on_close( + lambda code, msg: LOGGER.info(f"WebSocket closed: {code} - {msg}") + ) + self.ws.on_open(self._on_open) + self.ws.connect() # non-blocking + LOGGER.info("WebSocket connection initiated: url=%s", url) + time.sleep(1) + while self.ws and self.ws.ready(): + time.sleep(1) + LOGGER.info("WebSocket connection closed, exiting") + self.util_response.ws_required = True + self.util_response.ws_data = self.data + self.util_response.ws_raw_events = self.raw_events + return self.util_response + + async def startRemotePcap(self, site_id: str) -> UtilResponse: + """ + Start a WebSocket stream for remote PCAP events. + This should be used when Mist is returning a WebSocket URL from an API call. + + PARAMS + ----------- + site_id : str + UUID of the site to stream PCAP events from. + """ + self.ws = PcapEvents(self.apissession, site_id=site_id) + self.ws.on_message(self._handle_message) + self.ws.on_error(lambda error: LOGGER.error(f"Error: {error}")) + self.ws.on_close( + lambda code, msg: LOGGER.info(f"WebSocket closed: {code} - {msg}") + ) + self.ws.on_open(self._on_open) + self.ws.connect() # non-blocking + LOGGER.info("WebSocket connection initiated: /sites/%s/pcaps", site_id) + time.sleep(1) + while self.ws and self.ws.ready(): + time.sleep(1) + LOGGER.info("WebSocket connection closed, exiting") + self.util_response.ws_required = True + self.util_response.ws_data = self.data + self.util_response.ws_raw_events = self.raw_events + return self.util_response diff --git a/src/mistapi/utils/ap.py b/src/mistapi/utils/ap.py new file mode 100644 index 0000000..b5320a8 --- /dev/null +++ b/src/mistapi/utils/ap.py @@ -0,0 +1,29 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Utility functions for Mist Access Points (AP). + +This module provides a device-specific namespace for AP utilities. +All functions are imported from their respective functional modules. +""" + +# Re-export shared classes and types +from mistapi.utils.arp import Node +from mistapi.utils.arp import retrieve_ap_arp_table as retrieve_arp_table +from mistapi.utils.tools import TracerouteProtocol, ping, traceroute + +__all__ = [ + "Node", + "ping", + "traceroute", + "TracerouteProtocol", + "retrieve_arp_table", +] diff --git a/src/mistapi/utils/arp.py b/src/mistapi/utils/arp.py new file mode 100644 index 0000000..a15efce --- /dev/null +++ b/src/mistapi/utils/arp.py @@ -0,0 +1,204 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +class Node(Enum): + """Node Enum for specifying node information in ARP commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +async def retrieve_ap_arp_table( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + timeout=1, +) -> UtilResponse: + """ + DEVICES: AP + + Retrieves the ARP table from a Mist Access Point and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve the ARP table from. + node : Node, optional + Node information for the ARP table retrieval command. + timeout : int, optional + Timeout for the ARP table retrieval command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + # AP is returning RAW data + # SWITCH is returning ??? + # GATEWAY is returning JSON + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + trigger = devices.arpFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Show ARP command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" + ) # Give the show ARP command a moment to take effect + return util_response + + +async def retrieve_ssr_arp_table( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + timeout=1, +) -> UtilResponse: + """ + DEVICES: SSR + + Retrieves the ARP table from a SSR Gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve the ARP table from. + node : Node, optional + Node information for the ARP table retrieval command. + timeout : int, optional + Timeout for the ARP table retrieval command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from + the WebSocket stream. + """ + # AP is returning RAW data + # SWITCH is returning ??? + # GATEWAY is returning JSON + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + trigger = devices.arpFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Show ARP command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" + ) # Give the show ARP command a moment to take effect + return util_response + + +async def retrieve_junos_arp_table( + apissession: _APISession, + site_id: str, + device_id: str, + ip: str | None = None, + port_id: str | None = None, + vrf: str | None = None, + timeout=1, +) -> UtilResponse: + """ + DEVICES: EX, SRX + + Retrieve the ARP table from a Junos device (EX / SRX) with optional filters for IP, port, + and VRF. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve the ARP table from. + ip : str, optional + IP address to filter the ARP table. + port_id : str, optional + Port ID to filter the ARP table. + vrf : str, optional + VRF to filter the ARP table. + timeout : int, optional + Timeout for the ARP table retrieval command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"duration": 1, "interval": 1} + if ip: + body["ip"] = ip + if vrf: + body["vrf"] = vrf + if port_id: + body["port_id"] = port_id + trigger = devices.showSiteDeviceArpTable( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Show ARP command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" + ) # Give the show ARP command a moment to take effect + return util_response diff --git a/src/mistapi/utils/bgp.py b/src/mistapi/utils/bgp.py new file mode 100644 index 0000000..2db700f --- /dev/null +++ b/src/mistapi/utils/bgp.py @@ -0,0 +1,63 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +async def show_summary( + apissession: _APISession, + site_id: str, + device_id: str, + timeout=5, +) -> UtilResponse: + """ + DEVICES: EX, SRX, SSR + + Shows BGP summary on a device (EX/ SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show BGP summary on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"protocol": "bgp"} + trigger = devices.showSiteDeviceBgpSummary( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"BGP summary command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger BGP summary command: {trigger.status_code} - {trigger.data}" + ) # Give the BGP summary command a moment to take effect + return util_response diff --git a/src/mistapi/utils/bpdu.py b/src/mistapi/utils/bpdu.py new file mode 100644 index 0000000..26eccb3 --- /dev/null +++ b/src/mistapi/utils/bpdu.py @@ -0,0 +1,61 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.utils.__ws_wrapper import UtilResponse + + +async def clear_error( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + DEVICES: EX + + Clears BPDU error state on the specified ports of a switch. + + PARAMS + ----------- + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to clear BPDU errors on. + port_ids : list[str] + List of port IDs to clear BPDU errors on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearBpduErrorsFromPortsOnSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear BPDU error command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear BPDU error command: {trigger.status_code} - {trigger.data}" + ) # Give the clear BPDU error command a moment to take effect + return util_response diff --git a/src/mistapi/utils/dhcp.py b/src/mistapi/utils/dhcp.py new file mode 100644 index 0000000..0738705 --- /dev/null +++ b/src/mistapi/utils/dhcp.py @@ -0,0 +1,162 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +class Node(Enum): + """Node Enum for specifying node information in DHCP commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +async def release_dhcp_leases( + apissession: _APISession, + site_id: str, + device_id: str, + macs: list[str] | None = None, + network: str | None = None, + node: Node | None = None, + port_id: str | None = None, + timeout=5, +) -> UtilResponse: + """ + DEVICES: EX, SRX, SSR + + Releases DHCP leases on a device (EX/ SRX / SSR) and streams the results. + + valid combinations for EX are: + - network + macs + - network + port_id + - port_id + + valid combinations for SRX / SSR are: + - network + - network + macs + - network + port_id + - port_id + - port_id + macs + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to release DHCP leases on. + macs : list[str], optional + List of MAC addresses to release DHCP leases for. + network : str, optional + Network to release DHCP leases for. + node : Node, optional + Node information for the DHCP lease release command. + port_id : str, optional + Port ID to release DHCP leases for. + timeout : int, optional + Timeout for the release DHCP leases command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if macs: + body["macs"] = macs + if network: + body["network"] = network + if node: + body["node"] = node.value + if port_id: + body["port_id"] = port_id + trigger = devices.releaseSiteDeviceDhcpLease( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Release DHCP leases command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger release DHCP leases command: {trigger.status_code} - {trigger.data}" + ) # Give the release DHCP leases command a moment to take effect + return util_response + + +async def retrieve_dhcp_leases( + apissession: _APISession, + site_id: str, + device_id: str, + network: str, + node: Node | None = None, + timeout=15, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Retrieves DHCP leases on a gateway (SRX / SSR) and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve DHCP leases from. + network : str + Network to release DHCP leases for. + node : Node, optional + Node information for the DHCP lease release command. + port_id : str, optional + Port ID to release DHCP leases for. + timeout : int, optional + Timeout for the release DHCP leases command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"network": network} + if node: + body["node"] = node.value + trigger = devices.showSiteDeviceDhcpLeases( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Retrieve DHCP leases command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger retrieve DHCP leases command: {trigger.status_code} - {trigger.data}" + ) # Give the release DHCP leases command a moment to take effect + return util_response diff --git a/src/mistapi/utils/dns.py b/src/mistapi/utils/dns.py new file mode 100644 index 0000000..75c18f6 --- /dev/null +++ b/src/mistapi/utils/dns.py @@ -0,0 +1,84 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +class Node(Enum): + """Node Enum for specifying node information in DNS commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +async def test_resolution( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + hostname: str | None = None, + timeout=5, +) -> UtilResponse: + """ + DEVICES: SSR + + Initiates a DNS resolution command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the DNS resolution command on. + node : Node, optional + Node information for the DNS resolution command. + hostname : str, optional + Hostname to resolve. + timeout : int, optional + Timeout for the command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if hostname: + body["hostname"] = hostname + trigger = devices.testSiteSsrDnsResolution( + apissession, + site_id=site_id, + device_id=device_id, + # body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"SSR DNS resolution command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger SSR DNS resolution command: {trigger.status_code} - {trigger.data}" + ) # Give the SSR DNS resolution command a moment to take effect + return util_response diff --git a/src/mistapi/utils/dot1x.py b/src/mistapi/utils/dot1x.py new file mode 100644 index 0000000..abece84 --- /dev/null +++ b/src/mistapi/utils/dot1x.py @@ -0,0 +1,60 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.utils.__ws_wrapper import UtilResponse + + +async def clear_sessions( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + DEVICES: EX + + Clears dot1x sessions on the specified ports of a switch (EX). + + PARAMS + ----------- + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to clear dot1x sessions on. + port_ids : list[str] + List of port IDs to clear dot1x sessions on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearAllLearnedMacsFromPortOnSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear learned MACs command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" + ) # Give the clear learned MACs command a moment to take effect + return util_response diff --git a/src/mistapi/utils/ex.py b/src/mistapi/utils/ex.py new file mode 100644 index 0000000..f9c5455 --- /dev/null +++ b/src/mistapi/utils/ex.py @@ -0,0 +1,78 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Utility functions for Juniper EX Switches. + +This module provides a device-specific namespace for EX switch utilities. +All functions are imported from their respective functional modules. +""" + +# Re-export shared classes and types +from mistapi.utils.arp import Node + +# ARP functions +from mistapi.utils.arp import retrieve_junos_arp_table as retrieve_arp_table + +# BGP functions +from mistapi.utils.bgp import show_summary as show_bgp_summary + +# BPDU functions +from mistapi.utils.bpdu import clear_error as clear_bpdu_error + +# DHCP functions +from mistapi.utils.dhcp import release_dhcp_leases + +# Dot1x functions +from mistapi.utils.dot1x import clear_sessions as clear_dot1x_sessions + +# MAC table functions +from mistapi.utils.mac import ( + clear_learned_mac, + clear_mac_table, + retrieve_mac_table, +) + +# Policy functions +from mistapi.utils.policy import clear_hit_count + +# Port functions +from mistapi.utils.port import bounce as bounce_port +from mistapi.utils.port import cable_test + +# Tools (ping, monitor traffic) +from mistapi.utils.tools import monitor_traffic, ping + +__all__ = [ + # Classes/Enums + "Node", + # ARP + "retrieve_arp_table", + # BGP + "show_bgp_summary", + # BPDU + "clear_bpdu_error", + # DHCP + "release_dhcp_leases", + # Dot1x + "clear_dot1x_sessions", + # MAC + "clear_learned_mac", + "clear_mac_table", + "retrieve_mac_table", + # Port + "bounce_port", + "cable_test", + # Policy + "clear_hit_count", + # Tools + "monitor_traffic", + "ping", +] diff --git a/src/mistapi/utils/mac.py b/src/mistapi/utils/mac.py new file mode 100644 index 0000000..e4c25d5 --- /dev/null +++ b/src/mistapi/utils/mac.py @@ -0,0 +1,192 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +async def clear_mac_table( + apissession: _APISession, + site_id: str, + device_id: str, + mac_address: str | None = None, + port_id: str | None = None, + vlan_id: str | None = None, + # timeout=30, +) -> UtilResponse: + """ + DEVICES: EX + + Clears the MAC table on a switch (EX). + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to clear the MAC table from. + mac_address : str, optional + MAC address to clear from the MAC table. + port_id : str, optional + Port ID to clear from the MAC table. + vlan_id : str, optional + VLAN ID to clear from the MAC table. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + # AP is returning RAW data + # SWITCH is returning ??? + # GATEWAY is returning JSON + body: dict[str, str | list | int] = {} + if mac_address: + body["mac_address"] = mac_address + if port_id: + body["port_id"] = port_id + if vlan_id: + body["vlan_id"] = vlan_id + trigger = devices.clearSiteDeviceMacTable( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear MAC Table command triggered for device {device_id}") + # util_response = await WebSocketWrapper( + # apissession, util_response, timeout=timeout + # ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger clear MAC Table command: {trigger.status_code} - {trigger.data}" + ) # Give the clear MAC Table command a moment to take effect + return util_response + + +async def retrieve_mac_table( + apissession: _APISession, + site_id: str, + device_id: str, + mac_address: str | None = None, + port_id: str | None = None, + vlan_id: str | None = None, + timeout=5, +) -> UtilResponse: + """ + DEVICES: EX + + Retrieves the MAC Table table from a switch (EX) and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve the ARP table from. + mac_address : str, optional + MAC address to filter the ARP table retrieval. + port_id : str, optional + Port ID to filter the ARP table retrieval. + vlan_id : str, optional + VLAN ID to filter the ARP table retrieval. + timeout : int, optional + Timeout for the ARP table retrieval command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + # AP is returning RAW data + # SWITCH is returning ??? + # GATEWAY is returning JSON + body: dict[str, str | list | int] = {} + if mac_address: + body["mac_address"] = mac_address + if port_id: + body["port_id"] = port_id + if vlan_id: + body["vlan_id"] = vlan_id + trigger = devices.showSiteDeviceMacTable( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Show MAC Table command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger show MAC Table command: {trigger.status_code} - {trigger.data}" + ) # Give the show ARP command a moment to take effect + return util_response + + +async def clear_learned_mac( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + DEVICES: EX + + Clears learned MAC addresses on the specified ports of a switch (EX). + + PARAMS + ----------- + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to clear learned MAC addresses on. + port_ids : list[str] + List of port IDs to clear learned MAC addresses on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearSiteDeviceDot1xSession( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear learned MACs command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" + ) # Give the clear learned MACs command a moment to take effect + return util_response diff --git a/src/mistapi/utils/ospf.py b/src/mistapi/utils/ospf.py new file mode 100644 index 0000000..36ed711 --- /dev/null +++ b/src/mistapi/utils/ospf.py @@ -0,0 +1,273 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +class Node(Enum): + """Node Enum for specifying node information in OSPF commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +async def show_database( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + self_originate: bool | None = None, + vrf: str | None = None, + timeout=5, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF database on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF database on. + node : Node, optional + Node information for the show OSPF database command. + self_originate : bool, optional + Filter for self-originated routes in the OSPF database. + vrf : str, optional + VRF to filter the OSPF database. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if self_originate is not None: + body["self_originate"] = self_originate + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteGatewayOspfDatabase( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF database command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger OSPF database command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF database command a moment to take effect + return util_response + + +async def show_interfaces( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + port_id: str | None = None, + vrf: str | None = None, + timeout=5, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF interfaces on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF interfaces on. + node : Node, optional + Node information for the show OSPF interfaces command. + port_id : str, optional + Port ID to filter the OSPF interfaces. + vrf : str, optional + VRF to filter the OSPF interfaces. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if port_id: + body["port_id"] = port_id + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteGatewayOspfInterfaces( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF interfaces command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger OSPF interfaces command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF interfaces command a moment to take effect + return util_response + + +async def show_neighbors( + apissession: _APISession, + site_id: str, + device_id: str, + neighbor: str | None = None, + node: Node | None = None, + port_id: str | None = None, + vrf: str | None = None, + timeout=5, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF neighbors on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF neighbors on. + neighbor : str, optional + Neighbor IP address to filter the OSPF neighbors. + node : Node, optional + Node information for the show OSPF neighbors command. + port_id : str, optional + Port ID to filter the OSPF neighbors. + vrf : str, optional + VRF to filter the OSPF neighbors. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if port_id: + body["port_id"] = port_id + if vrf: + body["vrf"] = vrf + if neighbor: + body["neighbor"] = neighbor + trigger = devices.showSiteGatewayOspfNeighbors( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF neighbors command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger OSPF neighbors command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF neighbors command a moment to take effect + return util_response + + +async def show_summary( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + vrf: str | None = None, + timeout=5, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF summary on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF summary on. + node : Node, optional + Node information for the show OSPF summary command. + vrf : str, optional + VRF to filter the OSPF summary. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteGatewayOspfSummary( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF summary command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger OSPF summary command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF summary command a moment to take effect + return util_response diff --git a/src/mistapi/utils/policy.py b/src/mistapi/utils/policy.py new file mode 100644 index 0000000..77828d5 --- /dev/null +++ b/src/mistapi/utils/policy.py @@ -0,0 +1,62 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.utils.__ws_wrapper import UtilResponse + + +async def clear_hit_count( + apissession: _APISession, + site_id: str, + device_id: str, + policy_name: str, +) -> UtilResponse: + """ + DEVICE: EX + + Clears the policy hit count on a device. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to clear the policy hit count on. + policy_name : str + Name of the policy to clear the hit count for. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + trigger = devices.clearSiteDevicePolicyHitCount( + apissession, + site_id=site_id, + device_id=device_id, + body={"policy_name": policy_name}, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Clear policy hit count command triggered for device {device_id}") + # util_response = await WebSocketWrapper( + # apissession, util_response, timeout=timeout + # ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger clear policy hit count command: {trigger.status_code} - {trigger.data}" + ) # Give the clear policy hit count command a moment to take effect + return util_response diff --git a/src/mistapi/utils/port.py b/src/mistapi/utils/port.py new file mode 100644 index 0000000..b13040a --- /dev/null +++ b/src/mistapi/utils/port.py @@ -0,0 +1,122 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +async def bounce( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + timeout=60, +) -> UtilResponse: + """ + DEVICE: EX, SRX, SSR + + Initiates a bounce command on the specified ports of a device (EX / SRX / SSR) and streams + the results. + + PARAMS + ----------- + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to perform the bounce command on. + port_ids : list[str] + List of port IDs to bounce. + timeout : int, async default 5 + Timeout for the bounce command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if port_ids: + body["ports"] = port_ids + trigger = devices.bounceDevicePort( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info( + f"Bounce command triggered for ports {port_ids} on device {device_id}" + ) + util_response = await WebSocketWrapper( + apissession, util_response, timeout + ).startCmdEvents(site_id=site_id, device_id=device_id) + else: + LOGGER.error( + f"Failed to trigger bounce command: {trigger.status_code} - {trigger.data}" + ) # Give the bounce command a moment to take effect + return util_response + + +async def cable_test( + apissession: _APISession, + site_id: str, + device_id: str, + port_id: str, + timeout=10, +) -> UtilResponse: + """ + DEVICES: EX + + Initiates a cable test on a switch port and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to perform the cable test on. + port_id : str + Port ID to perform the cable test on. + timeout : int, optional + Timeout for the cable test command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"port": port_id} + trigger = devices.cableTestFromSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Cable test command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger cable test command: {trigger.status_code} - {trigger.data}" + ) # Give the cable test command a moment to take effect + return util_response diff --git a/src/mistapi/utils/routes.py b/src/mistapi/utils/routes.py new file mode 100644 index 0000000..ff0f511 --- /dev/null +++ b/src/mistapi/utils/routes.py @@ -0,0 +1,107 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +class Node(Enum): + NODE0 = "node0" + NODE1 = "node1" + + +class RouteProtocol(Enum): + ANY = "any" + BGP = "bgp" + DIRECT = "direct" + EVPN = "evpn" + OSPF = "ospf" + STATIC = "static" + + +async def show( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + prefix: str | None = None, + protocol: RouteProtocol | None = None, + route_type: str | None = None, + vrf: str | None = None, + timeout=2, +) -> UtilResponse: + """ + DEVICE: SSR, SRX + + Initiates a show routes command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show routes command on. + node : Node, optional + Node information for the show routes command. + prefix : str, optional + Prefix to filter the routes. + protocol : RouteProtocol, optional + Protocol to filter the routes. + route_type : str, optional + Type of the route to filter. + vrf : str, optional + VRF to filter the routes. + timeout : int, optional + Timeout for the command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if prefix: + body["prefix"] = prefix + if protocol: + body["protocol"] = protocol.value + if route_type: + body["route_type"] = route_type + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteSsrAndSrxRoutes( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Device Routes command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger Device Routes command: {trigger.status_code} - {trigger.data}" + ) # Give the Device Routes command a moment to take effect + return util_response diff --git a/src/mistapi/utils/service_path.py b/src/mistapi/utils/service_path.py new file mode 100644 index 0000000..1302b9c --- /dev/null +++ b/src/mistapi/utils/service_path.py @@ -0,0 +1,84 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +class Node(Enum): + """Node Enum for specifying node information in service path commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +async def show_service_path( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + service_name: str | None = None, + timeout: int = 5, +) -> UtilResponse: + """ + DEVICES: SSR + + Initiates a show service path command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show service path command on. + node : Node, optional + Node information for the show service path command. + service_name : str, optional + Name of the service to show the path for. + timeout : int, optional + Timeout for the command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if service_name: + body["service_name"] = service_name + trigger = devices.showSiteSsrServicePath( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"SSR service path command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger SSR service path command: {trigger.status_code} - {trigger.data}" + ) # Give the SSR service path command a moment to take effect + return util_response diff --git a/src/mistapi/utils/sessions.py b/src/mistapi/utils/sessions.py new file mode 100644 index 0000000..dde41f2 --- /dev/null +++ b/src/mistapi/utils/sessions.py @@ -0,0 +1,162 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +class Node(Enum): + """Node Enum for specifying node information in session commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +async def clear( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + service_name: str | None = None, + service_ids: list[str] | None = None, + vrf: str | None = None, + timeout=2, +) -> UtilResponse: + """ + DEVICE: SSR, SRX + + Initiates a clear sessions command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show routes command on. + node : Node, optional + Node information for the show routes command. + prefix : str, optional + Prefix to filter the routes. + protocol : RouteProtocol, optional + Protocol to filter the routes. + route_type : str, optional + Type of the route to filter. + vrf : str, optional + VRF to filter the routes. + timeout : int, optional + Timeout for the command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if service_name: + body["service_name"] = service_name + if service_ids: + body["service_ids"] = service_ids + if vrf: + body["vrf"] = vrf + trigger = devices.clearSiteDeviceSession( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Device Sessions command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger Device Sessions command: {trigger.status_code} - {trigger.data}" + ) # Give the Device Sessions command a moment to take effect + return util_response + + +async def show( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + service_name: str | None = None, + service_ids: list[str] | None = None, + timeout=2, +) -> UtilResponse: + """ + DEVICE: SSR, SRX + + Initiates a show sessions command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show sessions command on. + node : Node, optional + Node information for the show sessions command. + service_name : str, optional + Name of the service to filter the sessions. + service_ids : list[str], optional + List of service IDs to filter the sessions. + timeout : int, optional + Timeout for the command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if service_name: + body["service_name"] = service_name + if service_ids: + body["service_ids"] = service_ids + trigger = devices.showSiteSsrAndSrxSessions( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Device Sessions command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger Device Sessions command: {trigger.status_code} - {trigger.data}" + ) # Give the Device Sessions command a moment to take effect + return util_response diff --git a/src/mistapi/utils/srx.py b/src/mistapi/utils/srx.py new file mode 100644 index 0000000..6d8148b --- /dev/null +++ b/src/mistapi/utils/srx.py @@ -0,0 +1,61 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Utility functions for Juniper SRX Firewalls. + +This module provides a device-specific namespace for SRX firewall utilities. +All functions are imported from their respective functional modules. +""" + +# Re-export shared classes and types +from mistapi.utils.arp import Node + +# ARP functions +from mistapi.utils.arp import retrieve_junos_arp_table as retrieve_arp_table + +# BGP functions +from mistapi.utils.bgp import show_summary as show_bgp_summary + +# DHCP functions +from mistapi.utils.dhcp import release_dhcp_leases, retrieve_dhcp_leases + +# Policy functions +from mistapi.utils.policy import clear_hit_count + +# Port functions +from mistapi.utils.port import bounce as bounce_port + +# Route functions +from mistapi.utils.routes import show + +# Tools (ping, monitor traffic) +from mistapi.utils.tools import monitor_traffic, ping + +__all__ = [ + # Classes/Enums + "Node", + # ARP + "retrieve_arp_table", + # BGP + "show_bgp_summary", + # DHCP + "release_dhcp_leases", + "retrieve_dhcp_leases", + # Port + "bounce_port", + # Policy + "clear_hit_count", + # Routes + "show", + # Tools + "monitor_traffic", + "ping", +] diff --git a/src/mistapi/utils/ssr.py b/src/mistapi/utils/ssr.py new file mode 100644 index 0000000..9f7afad --- /dev/null +++ b/src/mistapi/utils/ssr.py @@ -0,0 +1,65 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Utility functions for Juniper Session Smart Routers (SSR). + +This module provides a device-specific namespace for SSR router utilities. +All functions are imported from their respective functional modules. +""" + +# Re-export shared classes and types +from mistapi.utils.arp import Node + +# ARP functions +from mistapi.utils.arp import retrieve_ssr_arp_table as retrieve_arp_table + +# BGP functions +from mistapi.utils.bgp import show_summary as show_bgp_summary + +# DHCP functions +from mistapi.utils.dhcp import release_dhcp_leases, retrieve_dhcp_leases + +# DNS functions +from mistapi.utils.dns import test_resolution as test_dns_resolution + +# Policy functions +from mistapi.utils.policy import clear_hit_count + +# Port functions +from mistapi.utils.port import bounce as bounce_port + +# Service Path functions +from mistapi.utils.service_path import show_service_path + +# Tools (ping only - no monitor_traffic for SSR) +from mistapi.utils.tools import ping + +__all__ = [ + # Classes/Enums + "Node", + # ARP + "retrieve_arp_table", + # BGP + "show_bgp_summary", + # DHCP + "release_dhcp_leases", + "retrieve_dhcp_leases", + # DNS + "test_dns_resolution", + # Port + "bounce_port", + # Policy + "clear_hit_count", + # Service Path + "show_service_path", + # Tools + "ping", +] diff --git a/src/mistapi/utils/tools.py b/src/mistapi/utils/tools.py new file mode 100644 index 0000000..aca66c1 --- /dev/null +++ b/src/mistapi/utils/tools.py @@ -0,0 +1,740 @@ +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices, pcaps +from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper + + +class Node(Enum): + """Node Enum for specifying node information in commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +class TracerouteProtocol(Enum): + """Enum for specifying protocol in traceroute command.""" + + ICMP = "icmp" + UDP = "udp" + + +async def ping( + apissession: _APISession, + site_id: str, + device_id: str, + host: str, + count: int | None = None, + node: None | None = None, + size: int | None = None, + vrf: str | None = None, + timeout: int = 3, +) -> UtilResponse: + """ + DEVICES: AP, EX, SRX, SSR + + Initiates a ping command from a device (AP / EX/ SRX / SSR) to a specified host and + streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to initiate the ping from. + host : str + The host to ping. + count : int, optional + Number of ping requests to send. + node : None, optional + Node information for the ping command. + size : int, optional + Size of the ping packet. + vrf : str, optional + VRF to use for the ping command. + timeout : int, optional + Timeout for the ping command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if count: + body["count"] = count + if host: + body["host"] = host + if node: + body["node"] = node.value + if size: + body["size"] = size + if vrf: + body["vrf"] = vrf + trigger = devices.pingFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Ping command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger ping command: {trigger.status_code} - {trigger.data}" + ) # Give the ping command a moment to take effect + return util_response + + +## NO DATA +# async def service_ping( +# apissession: _APISession, +# site_id: str, +# device_id: str, +# host: str, +# service: str, +# tenant: str, +# count: int | None = None, +# node: None | None = None, +# size: int | None = None, +# timeout: int = 3, +# ) -> UtilResponse: +# """ +# DEVICES: SSR + +# Initiates a service ping command from a SSR to a specified host and streams the results. + +# PARAMS +# ----------- +# apissession : _APISession +# The API session to use for the request. +# site_id : str +# UUID of the site where the device is located. +# device_id : str +# UUID of the device to initiate the ping from. +# host : str +# The host to ping. +# service : str +# The service to ping. +# tenant : str +# Tenant to use for the ping command. +# count : int, optional +# Number of ping requests to send. +# node : None, optional +# Node information for the ping command. +# size : int, optional +# Size of the ping packet. +# timeout : int, optional +# Timeout for the ping command in seconds. + +# RETURNS +# ----------- +# UtilResponse +# A UtilResponse object containing the API response and a list of raw messages received +# from the WebSocket stream. +# """ +# body: dict[str, str | list | int] = {} +# if count: +# body["count"] = count +# if host: +# body["host"] = host +# if node: +# body["node"] = node.value +# if size: +# body["size"] = size +# if tenant: +# body["tenant"] = tenant +# if service: +# body["service"] = service +# trigger = devices.servicePingFromSsr( +# apissession, +# site_id=site_id, +# device_id=device_id, +# body=body, +# ) +# util_response = UtilResponse(trigger) +# if trigger.status_code == 200: +# LOGGER.info(f"Service Ping command triggered for device {device_id}") +# util_response = await WebSocketWrapper( +# apissession, util_response, timeout +# ).startCmdEvents(site_id, device_id) +# else: +# LOGGER.error( +# f"Failed to trigger Service Ping command: {trigger.status_code} - {trigger.data}" +# ) # Give the ping command a moment to take effect +# return util_response + + +async def traceroute( + apissession: _APISession, + site_id: str, + device_id: str, + host: str, + protocol: TracerouteProtocol = TracerouteProtocol.ICMP, + port: int | None = None, + timeout: int = 10, +) -> UtilResponse: + """ + DEVICES: AP, EX, SRX, SSR + + Initiates a traceroute command from a device (AP / EX/ SRX / SSR) to a specified host and + streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to initiate the traceroute from. + host : str + The host to traceroute. + protocol : TracerouteProtocol, optional + Protocol to use for the traceroute command (icmp or udp). + port : int, optional + Port to use for UDP traceroute. + timeout : int, optional + Timeout for the traceroute command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"host": host} + if protocol: + body["protocol"] = protocol.value + if port: + body["port"] = port + trigger = devices.tracerouteFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Traceroute command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout + ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger traceroute command: {trigger.status_code} - {trigger.data}" + ) # Give the traceroute command a moment to take effect + return util_response + + +async def monitor_traffic( + apissession: _APISession, + site_id: str, + device_id: str, + port_id: str | None = None, + timeout=30, +) -> UtilResponse: + """ + DEVICE: EX, SRX + + Initiates a monitor traffic command on the device and streams the results. + + * if `port_id` is provided, JUNOS uses cmd "monitor interface" to monitor traffic on particular + * if `port_id` is not provided, JUNOS uses cmd "monitor interface traffic" to monitor traffic + on all ports + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to monitor traffic on. + port_id : str, optional + Port ID to filter the traffic. + timeout : int, optional + Timeout for the monitor traffic command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | int] = {"duration": 60} + if port_id: + body["port"] = port_id + trigger = devices.monitorSiteDeviceTraffic( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Monitor traffic command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startSessionUrl(trigger.data.get("url", "")) + else: + LOGGER.error( + f"Failed to trigger monitor traffic command: {trigger.status_code} - {trigger.data}" + ) # Give the monitor traffic command a moment to take effect + return util_response + + +async def ap_remote_pcap_wireless( + apissession: _APISession, + site_id: str, + device_id: str, + band: str, + tcpdump_expression: str | None = None, + ssid: str | None = None, + ap_mac: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, +) -> UtilResponse: + """ + DEVICE: AP + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + band : str + Comma-separated list of radio bands (24, 5, or 6). + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "type mgt or type ctl -vvv -tttt -en" + ssid : str, optional + SSID to filter the wireless traffic. + ap_mac : str, optional + AP MAC address to filter the wireless traffic. + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | int] = { + "band": band, + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + "type": "radiotap", + "format": "stream", + } + if ssid: + body["ssid"] = ssid + if ap_mac: + body["ap_mac"] = ap_mac + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startRemotePcap(site_id) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +async def ap_remote_pcap_wired( + apissession: _APISession, + site_id: str, + device_id: str, + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, +) -> UtilResponse: + """ + DEVICE: AP + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | int] = { + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + "type": "wired", + "format": "stream", + } + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startRemotePcap(site_id) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +async def srx_remote_pcap( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, +) -> UtilResponse: + """ + DEVICE: SRX + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + port_ids : list[str] + List of port IDs to monitor. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + gateway_mac = device_id.split("-")[-1] + body: dict[str, str | int | dict] = { + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + "gateways": {gateway_mac: {"ports": {}}}, + "type": "gateway", + "format": "stream", + } + for port_id in port_ids: + gateway_dict = body["gateways"] + assert isinstance(gateway_dict, dict) + mac_dict = gateway_dict[gateway_mac] + assert isinstance(mac_dict, dict) + ports_dict = mac_dict["ports"] + assert isinstance(ports_dict, dict) + ports_dict[port_id] = {"tcpdump_expression": tcpdump_expression} + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startRemotePcap(site_id) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +async def ssr_remote_pcap( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, +) -> UtilResponse: + """ + DEVICE: SSR + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + port_ids : list[str] + List of port IDs to monitor. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + gateway_mac = device_id.split("-")[-1] + body: dict[str, str | int | dict] = { + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + "raw": False, + "gateways": {gateway_mac: {"ports": {}}}, + "type": "gateway", + "format": "stream", + } + for port_id in port_ids: + gateway_dict = body["gateways"] + assert isinstance(gateway_dict, dict) + mac_dict = gateway_dict[gateway_mac] + assert isinstance(mac_dict, dict) + ports_dict = mac_dict["ports"] + assert isinstance(ports_dict, dict) + ports_dict[port_id] = {"tcpdump_expression": tcpdump_expression} + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startRemotePcap(site_id) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +async def ex_remote_pcap( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, +) -> UtilResponse: + """ + DEVICE: EX + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + port_ids : list[str] + List of port IDs to monitor. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + switch_mac = device_id.split("-")[-1] + body: dict[str, str | int | dict] = { + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + "switches": {switch_mac: {"ports": {}}}, + "type": "switch", + "format": "stream", + } + for port_id in port_ids: + switch_dict = body["switches"] + assert isinstance(switch_dict, dict) + mac_dict = switch_dict[switch_mac] + assert isinstance(mac_dict, dict) + ports_dict = mac_dict["ports"] + assert isinstance(ports_dict, dict) + ports_dict[port_id] = {"tcpdump_expression": tcpdump_expression} + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + util_response = await WebSocketWrapper( + apissession, util_response, timeout=timeout + ).startRemotePcap(site_id) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +## NO DATA +# async def srx_top_command( +# apissession: _APISession, +# site_id: str, +# device_id: str, +# timeout=10, +# ) -> UtilResponse: +# """ +# DEVICE: SRX + +# For SRX Only. Initiates a top command on the device and streams the results. + +# PARAMS +# ----------- +# apissession : _APISession +# The API session to use for the request. +# site_id : str +# UUID of the site where the device is located. +# device_id : str +# UUID of the device to run the top command on. +# timeout : int, optional +# Timeout for the top command in seconds. + +# RETURNS +# ----------- +# UtilResponse +# A UtilResponse object containing the API response and a list of raw messages received +# from the WebSocket stream. +# """ +# trigger = devices.runSiteSrxTopCommand( +# apissession, +# site_id=site_id, +# device_id=device_id, +# ) +# util_response = UtilResponse(trigger) +# if trigger.status_code == 200: +# LOGGER.info(trigger.data) +# print(f"Top command triggered for device {device_id}") +# util_response = await WebSocketWrapper( +# apissession, util_response, timeout=timeout +# ).startSessionUrl(site_id) +# else: +# LOGGER.error( +# f"Failed to trigger top command: {trigger.status_code} - {trigger.data}" +# ) # Give the top command a moment to take effect +# return util_response diff --git a/src/mistapi/websockets/__init__.py b/src/mistapi/websockets/__init__.py index 0e89fd7..81203b7 100644 --- a/src/mistapi/websockets/__init__.py +++ b/src/mistapi/websockets/__init__.py @@ -8,17 +8,14 @@ This package is licensed under the MIT License. -------------------------------------------------------------------------------- -WebSocket channel classes for real-time Mist API streaming. - -Usage example:: - - import mistapi - session = mistapi.APISession(...) - session.login() - - ws = mistapi.websockets.sites.SiteDeviceStatsEvents(session, site_id="") - ws.on_message(lambda data: print(data)) - ws.connect() """ -from mistapi.websockets import location, orgs, session, sites, utils +from mistapi.websockets import __ws_client, location, orgs, session, sites + +__all__ = [ + "location", + "orgs", + "session", + "sites", + "__ws_client", +] diff --git a/src/mistapi/websockets/sites.py b/src/mistapi/websockets/sites.py index 27db901..4c24cc4 100644 --- a/src/mistapi/websockets/sites.py +++ b/src/mistapi/websockets/sites.py @@ -321,8 +321,8 @@ class PcapEvents(_MistWebsocket): ----------- mist_session : mistapi.APISession Authenticated API session. - site_ids : list[str] - UUID of the sites to stream events from. + site_id : str + UUID of the site to stream events from. ping_interval : int, default 30 Interval in seconds to send WebSocket ping frames (keep-alive). ping_timeout : int, default 10 @@ -355,11 +355,11 @@ class PcapEvents(_MistWebsocket): def __init__( self, mist_session: APISession, - site_ids: list[str], + site_id: str, ping_interval: int = 30, ping_timeout: int = 10, ) -> None: - channels = [f"/sites/{site_id}/pcap" for site_id in site_ids] + channels = [f"/sites/{site_id}/pcaps"] super().__init__( mist_session, channels=channels, diff --git a/src/mistapi/websockets/utils/__init__.py b/src/mistapi/websockets/utils/__init__.py deleted file mode 100644 index 943352e..0000000 --- a/src/mistapi/websockets/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from mistapi.websockets.utils import common, gateway, junos, switch diff --git a/src/mistapi/websockets/utils/__ws_wrapper.py b/src/mistapi/websockets/utils/__ws_wrapper.py deleted file mode 100644 index 1ba03af..0000000 --- a/src/mistapi/websockets/utils/__ws_wrapper.py +++ /dev/null @@ -1,155 +0,0 @@ -import json -import threading -import time - -from mistapi import APISession -from mistapi.__api_response import APIResponse as _APIResponse -from mistapi.__logger import logger as LOGGER -from mistapi.websockets.session import SessionWithUrl -from mistapi.websockets.sites import DeviceCmdEvents - - -class UtilResponse: - """ - A simple class to encapsulate the response from utility WebSocket functions. - This class can be extended in the future to include additional metadata or helper methods. - """ - - def __init__( - self, - api_response: _APIResponse, - ) -> None: - self.trigger_api_response = api_response - self.ws_required: bool = False # This can be set to True if the WebSocket connection was successfully initiated - self.ws_data: list[str] = [] - self.ws_raw_events: list[str] = [] - - -class WebSocketWrapper: - """ - A wrapper class for managing WebSocket connections and events. - This class provides a simplified interface for connecting to WebSocket channels, - handling messages, and managing connection timeouts. - """ - - def __init__( - self, - apissession: APISession, - util_response: UtilResponse, - timeout: int = 10, - max_duration: int = 60, - ) -> None: - self.apissession = apissession - self.util_response = util_response - self.timeout_timer = None - self.timeout = timeout - self.max_duration_timer = None - self.max_duration = max_duration - self.received_messages = 0 - self.data = [] - self.raw_events = [] - self.ws = None - - def _on_open(self): - LOGGER.info("WebSocket connection opened") - if self.max_duration_timer and self.ws: - self.max_duration_timer = threading.Timer( - self.max_duration, self.ws.disconnect - ) - self.max_duration_timer.start() - self._reset_timer() # Start the timer when the connection opens - - def _reset_timer(self): - if self.timeout_timer: - self.timeout_timer.cancel() - if self.ws: - self.timeout_timer = threading.Timer(self.timeout, self.ws.disconnect) - self.timeout_timer.start() - - def _extract_raw(self, message): - self.raw_events.append(message) - event = message - if isinstance(event, str): - try: - event = json.loads(message) - if isinstance(event, dict) and "raw" in event: - return event["raw"] - except json.JSONDecodeError: - return - if event.get("event") == "data" and event.get("data"): - return self._extract_raw(event["data"]) - elif event.get("raw"): - self.received_messages += 1 - LOGGER.debug(f"Received raw message: {event['raw']}") - return event["raw"] - return None - - def _handle_message(self, msg): - if isinstance(msg, dict) and msg.get("event") == "channel_subscribed": - LOGGER.debug(msg) - else: - LOGGER.debug(msg) - raw = self._extract_raw(msg) - if raw: - self.data.append(raw) - self._reset_timer() # Reset timeout on each message - - async def startCmdEvents(self, site_id: str, device_id: str) -> UtilResponse: - """ - Start a WebSocket stream for site device command events. - - PARAMS - ----------- - site_id : str - UUID of the site to stream events from. - device_id : str - UUID of the device to stream events from. - """ - self.ws = DeviceCmdEvents( - self.apissession, site_id=site_id, device_ids=[device_id] - ) - self.ws.on_message(self._handle_message) - self.ws.on_error(lambda error: LOGGER.error(f"Error: {error}")) - self.ws.on_close( - lambda code, msg: LOGGER.info(f"WebSocket closed: {code} - {msg}") - ) - self.ws.on_open(self._on_open) - self.ws.connect() # non-blocking - LOGGER.info("WebSocket connection initiated") - time.sleep(1) - while self.ws and self.ws.ready(): - time.sleep(1) - LOGGER.info("WebSocket connection closed, exiting") - self.util_response.ws_required = True - self.util_response.ws_data = self.data - self.util_response.ws_raw_events = self.raw_events - return self.util_response - - async def startSessionUrl(self, url: str) -> UtilResponse: - """ - Start a WebSocket stream using a custom URL. - This should be used when Mist is returning a WebSocket URL from an API call. - - PARAMS - ----------- - url : str - Full WebSocket URL to connect to (e.g., wss://api.mist.com/ws/v1/orgs/{org_id}/sites/{site_id}/devices/{device_id}/cmds). - - """ - self.ws = SessionWithUrl(self.apissession, url=url) - self.ws.on_message(self._handle_message) - self.ws.on_error(lambda error: LOGGER.error(f"Error: {error}")) - self.ws.on_close( - lambda code, msg: LOGGER.info(f"WebSocket closed: {code} - {msg}") - ) - self.ws.on_open(self._on_open) - self.ws.connect() # non-blocking - LOGGER.info("WebSocket connection initiated") - time.sleep(1) - while self.ws and self.ws.ready(): - time.sleep(1) - LOGGER.info("WebSocket connection closed, exiting") - self.util_response.ws_required = True - self.util_response.ws_data = self.data - self.util_response.ws_raw_events = self.raw_events - return self.util_response diff --git a/src/mistapi/websockets/utils/common.py b/src/mistapi/websockets/utils/common.py deleted file mode 100644 index 0410d5f..0000000 --- a/src/mistapi/websockets/utils/common.py +++ /dev/null @@ -1,545 +0,0 @@ -from enum import Enum - -from mistapi import APISession as _APISession -from mistapi.__logger import logger as LOGGER -from mistapi.api.v1.sites import devices -from mistapi.websockets.utils.__ws_wrapper import UtilResponse, WebSocketWrapper - - -class Node(Enum): - NODE0 = "node0" - NODE1 = "node1" - - -class RouteProtocol(Enum): - ANY = "any" - BGP = "bgp" - DIRECT = "direct" - EVPN = "evpn" - OSPF = "ospf" - STATIC = "static" - - -async def retrieve_arp_table( - apissession: _APISession, - site_id: str, - device_id: str, - node: Node | None = None, - timeout=5, -) -> UtilResponse: - """ - Retrieves the ARP table from a device and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to retrieve the ARP table from. - node : Node, optional - Node information for the ARP table retrieval command. - timeout : int, optional - Timeout for the ARP table retrieval command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - # AP is returnning RAW data - # SWITCH is returning ??? - # GATEWAY is returning JSON - body: dict[str, str | list | int] = {} - if node: - body["node"] = node.value - trigger = devices.arpFromDevice( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Show ARP command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" - ) # Give the show ARP command a moment to take effect - return util_response - - -async def bounce_ports( - apissession: _APISession, - site_id: str, - device_id: str, - port_ids: list[str], - timeout=5, -) -> UtilResponse: - """ - Initiates a bounce command on the specified ports of a device and streams the results. - - PARAMS - ----------- - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to perform the bounce command on. - port_ids : list[str] - List of port IDs to bounce. - timeout : int, async default 5 - Timeout for the bounce command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if port_ids: - body["ports"] = port_ids - trigger = devices.bounceDevicePort( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info( - f"Bounce command triggered for ports {port_ids} on device {device_id}" - ) - util_response = await WebSocketWrapper( - apissession, util_response, timeout - ).startCmdEvents(site_id=site_id, device_id=device_id) - else: - LOGGER.error( - f"Failed to trigger bounce command: {trigger.status_code} - {trigger.data}" - ) # Give the bounce command a moment to take effect - return util_response - - -async def clear_mac_table( - apissession: _APISession, - site_id: str, - device_id: str, - mac_address: str | None = None, - port_id: str | None = None, - vlan_id: str | None = None, - # timeout=30, -) -> UtilResponse: - """ - Clears the MAC table on a device and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to clear the MAC table from. - mac_address : str, optional - MAC address to clear from the MAC table. - port_id : str, optional - Port ID to clear from the MAC table. - vlan_id : str, optional - VLAN ID to clear from the MAC table. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - # AP is returnning RAW data - # SWITCH is returning ??? - # GATEWAY is returning JSON - body: dict[str, str | list | int] = {} - if mac_address: - body["mac_address"] = mac_address - if port_id: - body["port_id"] = port_id - if vlan_id: - body["vlan_id"] = vlan_id - trigger = devices.clearSiteDeviceMacTable( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Clear MAC Table command triggered for device {device_id}") - # util_response = await WebSocketWrapper( - # apissession, util_response, timeout=timeout - # ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger clear MAC Table command: {trigger.status_code} - {trigger.data}" - ) # Give the clear MAC Table command a moment to take effect - return util_response - - -async def release_dhcp_leases( - apissession: _APISession, - site_id: str, - device_id: str, - macs: list[str] | None = None, - network: str | None = None, - node: Node | None = None, - port_id: str | None = None, - timeout=5, -) -> UtilResponse: - """ - Releases DHCP leases on a device and streams the results. - valid combinations are: - - network - - network + macs - - network + port_id - - port_id - - port_id + macs - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to release DHCP leases on. - macs : list[str], optional - List of MAC addresses to release DHCP leases for. - network : str, optional - Network to release DHCP leases for. - node : Node, optional - Node information for the DHCP lease release command. - port_id : str, optional - Port ID to release DHCP leases for. - timeout : int, optional - Timeout for the release DHCP leases command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if macs: - body["macs"] = macs - if network: - body["network"] = network - if node: - body["node"] = node.value - if port_id: - body["port_id"] = port_id - trigger = devices.releaseSiteDeviceDhcpLease( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Release DHCP leases command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger release DHCP leases command: {trigger.status_code} - {trigger.data}" - ) # Give the release DHCP leases command a moment to take effect - return util_response - - -# TODO -async def retrieve_dhcp_leases( - apissession: _APISession, - site_id: str, - device_id: str, - network: str, - node: Node | None = None, - timeout=15, -) -> UtilResponse: - """ - Retrieves DHCP leases on a device and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to retrieve DHCP leases from. - network : str - Network to release DHCP leases for. - node : Node, optional - Node information for the DHCP lease release command. - port_id : str, optional - Port ID to release DHCP leases for. - timeout : int, optional - Timeout for the release DHCP leases command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {"network": network} - if node: - body["node"] = node.value - trigger = devices.showSiteDeviceDhcpLeases( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Retrieve DHCP leases command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger retrieve DHCP leases command: {trigger.status_code} - {trigger.data}" - ) # Give the release DHCP leases command a moment to take effect - return util_response - - -####################################################### -## Switch -####################################################### - - -async def switch_clear_bpdu_error( - apissession: _APISession, - site_id: str, - device_id: str, - port_ids: list[str], -) -> UtilResponse: - """ - Clears BPDU error state on the specified ports of a switch. - - PARAMS - ----------- - site_id : str - UUID of the site where the switch is located. - device_id : str - UUID of the switch to clear BPDU errors on. - port_ids : list[str] - List of port IDs to clear BPDU errors on. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - - body: dict[str, str | list | int] = {"ports": port_ids} - trigger = devices.clearBpduErrorsFromPortsOnSwitch( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Clear BPDU error command triggered for device {device_id}") - else: - LOGGER.error( - f"Failed to trigger clear BPDU error command: {trigger.status_code} - {trigger.data}" - ) # Give the clear BPDU error command a moment to take effect - return util_response - - -async def switch_clear_learned_mac( - apissession: _APISession, - site_id: str, - device_id: str, - port_ids: list[str], -) -> UtilResponse: - """ - Clears learned MAC addresses on the specified ports of a device. - - PARAMS - ----------- - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to clear learned MAC addresses on. - port_ids : list[str] - List of port IDs to clear learned MAC addresses on. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {"ports": port_ids} - trigger = devices.clearSiteDeviceDot1xSession( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Clear learned MACs command triggered for device {device_id}") - else: - LOGGER.error( - f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" - ) # Give the clear learned MACs command a moment to take effect - return util_response - - -async def switch_clear_dot1x_sessions( - apissession: _APISession, - site_id: str, - device_id: str, - port_ids: list[str], -) -> UtilResponse: - """ - Clears dot1x sessions on the specified ports of a switch. - - PARAMS - ----------- - site_id : str - UUID of the site where the switch is located. - device_id : str - UUID of the switch to clear dot1x sessions on. - port_ids : list[str] - List of port IDs to clear dot1x sessions on. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {"ports": port_ids} - trigger = devices.clearAllLearnedMacsFromPortOnSwitch( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Clear learned MACs command triggered for device {device_id}") - else: - LOGGER.error( - f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" - ) # Give the clear learned MACs command a moment to take effect - return util_response - - -####################################################### -## Websocket -####################################################### - - -async def ping( - apissession: _APISession, - site_id: str, - device_id: str, - host: str, - count: int | None = None, - node: None | None = None, - size: int | None = None, - vrf: str | None = None, - timeout: int = 5, -) -> UtilResponse: - """ - Initiates a ping command from a device to a specified host and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to initiate the ping from. - host : str - The host to ping. - count : int, optional - Number of ping requests to send. - node : None, optional - Node information for the ping command. - size : int, optional - Size of the ping packet. - vrf : str, optional - VRF to use for the ping command. - timeout : int, optional - Timeout for the ping command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if count: - body["count"] = count - if host: - body["host"] = host - if node: - body["node"] = node.value - if size: - body["size"] = size - if vrf: - body["vrf"] = vrf - trigger = devices.pingFromDevice( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Ping command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout - ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger ping command: {trigger.status_code} - {trigger.data}" - ) # Give the ping command a moment to take effect - return util_response - - -# async def gateway_dns_resolution( -# self, -# site_id: str, -# device_id: str, -# timeout=10, -# ) -> list[str]: -# """For SSR Only. Initiates a DNS resolution command on the gateway and streams the results.""" -# self.timeout = timeout -# trigger = testSiteSsrDnsResolution( -# apissession, -# site_id=site_id, -# device_id=device_id, -# ) -# if trigger.status_code == 200: -# print(trigger.data) -# print(f"SSR DNS resolution command triggered for device {device_id}") -# self.startCmdEvents(site_id, device_id) -# else: -# print( -# f"Failed to trigger SSR DNS resolution command: {trigger.status_code} - {trigger.data}" -# ) # Give the SSR DNS resolution command a moment to take effect -# return util_response - -####################################################### -## Websocket Session -####################################################### diff --git a/src/mistapi/websockets/utils/gateway.py b/src/mistapi/websockets/utils/gateway.py deleted file mode 100644 index 18f0916..0000000 --- a/src/mistapi/websockets/utils/gateway.py +++ /dev/null @@ -1,249 +0,0 @@ -from enum import Enum - -from mistapi import APISession as _APISession -from mistapi.__logger import logger as LOGGER -from mistapi.api.v1.sites import devices -from mistapi.websockets.utils.__ws_wrapper import UtilResponse, WebSocketWrapper - - -class Node(Enum): - NODE0 = "node0" - NODE1 = "node1" - - -class RouteProtocol(Enum): - ANY = "any" - BGP = "bgp" - DIRECT = "direct" - EVPN = "evpn" - OSPF = "ospf" - STATIC = "static" - - -async def show_routes( - apissession: _APISession, - site_id: str, - device_id: str, - node: Node | None = None, - prefix: str | None = None, - protocol: RouteProtocol | None = None, - route_type: str | None = None, - vrf: str | None = None, - timeout=5, -) -> UtilResponse: - """ - For SSR and SRX. Initiates a show service path command on the gateway and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the gateway is located. - device_id : str - UUID of the gateway to perform the show routes command on. - node : Node, optional - Node information for the show routes command. - prefix : str, optional - Prefix to filter the routes. - protocol : RouteProtocol, optional - Protocol to filter the routes. - route_type : str, optional - Type of the route to filter. - vrf : str, optional - VRF to filter the routes. - timeout : int, optional - Timeout for the command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - - body: dict[str, str | list | int] = {} - if node: - body["node"] = node.value - if prefix: - body["prefix"] = prefix - if protocol: - body["protocol"] = protocol.value - if route_type: - body["route_type"] = route_type - if vrf: - body["vrf"] = vrf - trigger = devices.showSiteSsrAndSrxRoutes( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"SSR service path command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger SSR service path command: {trigger.status_code} - {trigger.data}" - ) # Give the SSR service path command a moment to take effect - return util_response - - -async def test_dns_resolution( - apissession: _APISession, - site_id: str, - device_id: str, - node: Node | None = None, - hostname: str | None = None, - timeout=5, -) -> UtilResponse: - """ - For SSR Only. Initiates a DNS resolution command on the gateway and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the gateway is located. - device_id : str - UUID of the gateway to perform the DNS resolution command on. - node : Node, optional - Node information for the DNS resolution command. - hostname : str, optional - Hostname to resolve. - timeout : int, optional - Timeout for the command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if node: - body["node"] = node.value - if hostname: - body["hostname"] = hostname - trigger = devices.testSiteSsrDnsResolution( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"SSR DNS resolution command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger SSR DNS resolution command: {trigger.status_code} - {trigger.data}" - ) # Give the SSR DNS resolution command a moment to take effect - return util_response - - -async def show_service_path( - apissession: _APISession, - site_id: str, - device_id: str, - node: Node | None = None, - service_name: str | None = None, - timeout=5, -) -> UtilResponse: - """ - For SSR Only. Initiates a show service path command on the gateway and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the gateway is located. - device_id : str - UUID of the gateway to perform the show service path command on. - node : Node, optional - Node information for the show service path command. - service_name : str, optional - Name of the service to show the path for. - timeout : int, optional - Timeout for the command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if node: - body["node"] = node.value - if service_name: - body["service_name"] = service_name - trigger = devices.showSiteSsrServicePath( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"SSR service path command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger SSR service path command: {trigger.status_code} - {trigger.data}" - ) # Give the SSR service path command a moment to take effect - return util_response - - -async def clear_policy_hit_count( - apissession: _APISession, - site_id: str, - device_id: str, - policy_name: str, - # timeout: int = 10, -) -> UtilResponse: - """ - Clears the policy hit count on a device. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to clear the policy hit count on. - policy_name : str - Name of the policy to clear the hit count for. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - trigger = devices.clearSiteDevicePolicyHitCount( - apissession, - site_id=site_id, - device_id=device_id, - body={"policy_name": policy_name}, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Clear policy hit count command triggered for device {device_id}") - # util_response = await WebSocketWrapper( - # apissession, util_response, timeout=timeout - # ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger clear policy hit count command: {trigger.status_code} - {trigger.data}" - ) # Give the clear policy hit count command a moment to take effect - return util_response diff --git a/src/mistapi/websockets/utils/junos.py b/src/mistapi/websockets/utils/junos.py deleted file mode 100644 index 2b12f44..0000000 --- a/src/mistapi/websockets/utils/junos.py +++ /dev/null @@ -1,110 +0,0 @@ -from mistapi import APISession as _APISession -from mistapi.__logger import logger as LOGGER -from mistapi.api.v1.sites import devices -from mistapi.websockets.utils.__ws_wrapper import UtilResponse, WebSocketWrapper - - -# TODO -async def monitor_traffic( - apissession: _APISession, - site_id: str, - device_id: str, - port_id: str | None = None, - timeout=30, -) -> UtilResponse: - """ - For EX and SRX Only. Initiates a monitor traffic command on the device and streams the results. - - * if `port_id` is provided, JUNOS uses cmd "monitor interface" to monitor traffic on particular - * if `port_id` is not provided, JUNOS uses cmd "monitor interface traffic" to monitor traffic on all ports - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to monitor traffic on. - port_id : str, optional - Port ID to filter the traffic. - timeout : int, optional - Timeout for the monitor traffic command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | int] = {"duration": 60} - if port_id: - body["port"] = port_id - trigger = devices.monitorSiteDeviceTraffic( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Monitor traffic command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startSessionUrl(trigger.data.get("url", "")) - else: - LOGGER.error( - f"Failed to trigger monitor traffic command: {trigger.status_code} - {trigger.data}" - ) # Give the monitor traffic command a moment to take effect - return util_response - - -async def clear_policy_hit_count( - apissession: _APISession, - site_id: str, - device_id: str, - policy_name: str, - timeout=30, -) -> UtilResponse: - """ - For EX and SRX Only. Clears the policy hit count on the device. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to clear the policy hit count on. - policy_name : str - Name of the policy to clear the hit count for. - timeout : int, optional - Timeout for the clear policy hit count command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str] = {} - if policy_name: - body["policy_name"] = policy_name - trigger = devices.clearSiteDevicePolicyHitCount( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Clear policy hit count command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger clear policy hit count command: {trigger.status_code} - {trigger.data}" - ) # Give the clear policy hit count command a moment to take effect - return util_response diff --git a/src/mistapi/websockets/utils/switch.py b/src/mistapi/websockets/utils/switch.py deleted file mode 100644 index 8419b19..0000000 --- a/src/mistapi/websockets/utils/switch.py +++ /dev/null @@ -1,570 +0,0 @@ -from enum import Enum - -from mistapi import APISession as _APISession -from mistapi.__logger import logger as LOGGER -from mistapi.api.v1.sites import devices -from mistapi.websockets.utils.__ws_wrapper import UtilResponse, WebSocketWrapper - - -class Node(Enum): - NODE0 = "node0" - NODE1 = "node1" - - -class RouteProtocol(Enum): - ANY = "any" - BGP = "bgp" - DIRECT = "direct" - EVPN = "evpn" - OSPF = "ospf" - STATIC = "static" - - -async def bounce_ports( - apissession: _APISession, - site_id: str, - device_id: str, - port_ids: list[str], - timeout=5, -) -> UtilResponse: - """ - Initiates a bounce command on the specified ports of a device and streams the results. - - PARAMS - ----------- - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to perform the bounce command on. - port_ids : list[str] - List of port IDs to bounce. - timeout : int, async default 5 - Timeout for the bounce command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if port_ids: - body["ports"] = port_ids - trigger = devices.bounceDevicePort( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info( - f"Bounce command triggered for ports {port_ids} on device {device_id}" - ) - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id=site_id, device_id=device_id) - else: - LOGGER.error( - f"Failed to trigger bounce command: {trigger.status_code} - {trigger.data}" - ) # Give the bounce command a moment to take effect - return util_response - - -async def retrieve_arp_table( - apissession: _APISession, - site_id: str, - device_id: str, - ip: str | None = None, - port_id: str | None = None, - vrf: str | None = None, - timeout=5, -) -> UtilResponse: - """ - Retrieve the ARP table from a device with optional filters for IP, port, and VRF. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to retrieve the ARP table from. - ip : str, optional - IP address to filter the ARP table. - port_id : str, optional - Port ID to filter the ARP table. - vrf : str, optional - VRF to filter the ARP table. - timeout : int, optional - Timeout for the ARP table retrieval command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {"duration": 1, "interval": 1} - if ip: - body["ip"] = ip - if vrf: - body["vrf"] = vrf - if port_id: - body["port_id"] = port_id - trigger = devices.showSiteDeviceArpTable( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Show ARP command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" - ) # Give the show ARP command a moment to take effect - return util_response - - -####################################################### -## Switch -####################################################### - - -async def switch_clear_bpdu_error( - apissession: _APISession, - site_id: str, - device_id: str, - port_ids: list[str], -) -> UtilResponse: - """ - Clears BPDU error state on the specified ports of a switch. - - PARAMS - ----------- - site_id : str - UUID of the site where the switch is located. - device_id : str - UUID of the switch to clear BPDU errors on. - port_ids : list[str] - List of port IDs to clear BPDU errors on. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - - body: dict[str, str | list | int] = {"ports": port_ids} - trigger = devices.clearBpduErrorsFromPortsOnSwitch( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Clear BPDU error command triggered for device {device_id}") - else: - LOGGER.error( - f"Failed to trigger clear BPDU error command: {trigger.status_code} - {trigger.data}" - ) # Give the clear BPDU error command a moment to take effect - return util_response - - -async def switch_clear_learned_mac( - apissession: _APISession, - site_id: str, - device_id: str, - port_ids: list[str], -) -> UtilResponse: - """ - Clears learned MAC addresses on the specified ports of a device. - - PARAMS - ----------- - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to clear learned MAC addresses on. - port_ids : list[str] - List of port IDs to clear learned MAC addresses on. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {"ports": port_ids} - trigger = devices.clearSiteDeviceDot1xSession( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Clear learned MACs command triggered for device {device_id}") - else: - LOGGER.error( - f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" - ) # Give the clear learned MACs command a moment to take effect - return util_response - - -async def switch_clear_dot1x_sessions( - apissession: _APISession, - site_id: str, - device_id: str, - port_ids: list[str], -) -> UtilResponse: - """ - Clears dot1x sessions on the specified ports of a switch. - - PARAMS - ----------- - site_id : str - UUID of the site where the switch is located. - device_id : str - UUID of the switch to clear dot1x sessions on. - port_ids : list[str] - List of port IDs to clear dot1x sessions on. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {"ports": port_ids} - trigger = devices.clearAllLearnedMacsFromPortOnSwitch( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Clear learned MACs command triggered for device {device_id}") - else: - LOGGER.error( - f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" - ) # Give the clear learned MACs command a moment to take effect - return util_response - - -####################################################### -## Websocket -####################################################### - - -async def clear_mac_table( - apissession: _APISession, - site_id: str, - device_id: str, - mac_address: str, - port_id: str, - vlan_id: str, - timeout=5, -) -> UtilResponse: - """ - Clears the MAC table on a device for a specific MAC address, port, or VLAN and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to clear the MAC table on. - mac_address : str - MAC address to clear from the MAC table. - port_id : str - Port ID to clear the MAC table on. - vlan_id : str - VLAN ID to clear the MAC table on. - timeout : int, optional - Timeout for the clear MAC table command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if mac_address: - body["mac_address"] = mac_address - if port_id: - body["port_id"] = port_id - if vlan_id: - body["vlan_id"] = vlan_id - trigger = devices.clearSiteDeviceMacTable( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Clear MAC table command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger clear MAC table command: {trigger.status_code} - {trigger.data}" - ) # Give the clear MAC table command a moment to take effect - return util_response - - -async def ping( - apissession: _APISession, - site_id: str, - device_id: str, - host: str, - count: int | None = None, - node: None | None = None, - size: int | None = None, - vrf: str | None = None, - timeout: int = 5, -) -> UtilResponse: - """ - Initiates a ping command from a device to a specified host and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to initiate the ping from. - host : str - The host to ping. - count : int, optional - Number of ping requests to send. - node : None, optional - Node information for the ping command. - size : int, optional - Size of the ping packet. - vrf : str, optional - VRF to use for the ping command. - timeout : int, optional - Timeout for the ping command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if count: - body["count"] = count - if host: - body["host"] = host - if node: - body["node"] = node.value - if size: - body["size"] = size - if vrf: - body["vrf"] = vrf - trigger = devices.pingFromDevice( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Ping command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger ping command: {trigger.status_code} - {trigger.data}" - ) # Give the ping command a moment to take effect - return util_response - - -async def release_dhcp_leases( - apissession: _APISession, - site_id: str, - device_id: str, - macs: list[str] | None = None, - network: str | None = None, - node: Node | None = None, - port_id: str | None = None, - timeout=5, -) -> UtilResponse: - """ - Releases DHCP leases on a device and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to release DHCP leases on. - macs : list[str], optional - List of MAC addresses to release DHCP leases for. - network : str, optional - Network to release DHCP leases for. - node : Node, optional - Node information for the DHCP lease release command. - port_id : str, optional - Port ID to release DHCP leases for. - timeout : int, optional - Timeout for the release DHCP leases command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {} - if macs: - body["macs"] = macs - if network: - body["network"] = network - if node: - body["node"] = node.value - if port_id: - body["port_id"] = port_id - trigger = devices.releaseSiteDeviceDhcpLease( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(f"Release DHCP leases command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger release DHCP leases command: {trigger.status_code} - {trigger.data}" - ) # Give the release DHCP leases command a moment to take effect - return util_response - - -async def stream_arp_table( - apissession: _APISession, - site_id: str, - device_id: str, - ip: str | None = None, - port_id: str | None = None, - vrf: str | None = None, - timeout=5, -) -> UtilResponse: - """ - Streams the ARP table from a device with optional filters for IP, port, and VRF. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the device is located. - device_id : str - UUID of the device to retrieve the ARP table from. - ip : str, optional - IP address to filter the ARP table. - port_id : str, optional - Port ID to filter the ARP table. - vrf : str, optional - VRF to filter the ARP table. - timeout : int, optional - Timeout for the ARP table retrieval command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {"duration": 1, "interval": 1} - if ip: - body["ip"] = ip - if vrf: - body["vrf"] = vrf - if port_id: - body["port_id"] = port_id - trigger = devices.showSiteDeviceArpTable( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Show ARP command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" - ) # Give the show ARP command a moment to take effect - return util_response - - -async def switch_cable_test( - apissession: _APISession, - site_id: str, - device_id: str, - port_id: str, - timeout=10, -) -> UtilResponse: - """ - Initiates a cable test on a switch port and streams the results. - - PARAMS - ----------- - apissession : _APISession - The API session to use for the request. - site_id : str - UUID of the site where the switch is located. - device_id : str - UUID of the switch to perform the cable test on. - port_id : str - Port ID to perform the cable test on. - timeout : int, optional - Timeout for the cable test command in seconds. - - RETURNS - ----------- - UtilResponse - A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. - """ - body: dict[str, str | list | int] = {"port": port_id} - trigger = devices.cableTestFromSwitch( - apissession, - site_id=site_id, - device_id=device_id, - body=body, - ) - util_response = UtilResponse(trigger) - if trigger.status_code == 200: - LOGGER.info(trigger.data) - print(f"Cable test command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) - else: - LOGGER.error( - f"Failed to trigger cable test command: {trigger.status_code} - {trigger.data}" - ) # Give the cable test command a moment to take effect - return util_response diff --git a/uv.lock b/uv.lock index 2a738e2..7d80793 100644 --- a/uv.lock +++ b/uv.lock @@ -537,7 +537,7 @@ wheels = [ [[package]] name = "mistapi" -version = "0.60.3" +version = "0.55.15" source = { editable = "." } dependencies = [ { name = "deprecation" }, From ff8bf6fbc75f19bc5116bc9fbb5510256a1e3e8f Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Fri, 13 Mar 2026 16:36:04 +0100 Subject: [PATCH 07/16] refactor: consolidate error handling and improve security Major refactoring to reduce code duplication and address security concerns: * feat(__api_request): consolidate retry logic into _request_with_retry - Extract HTTP operations into inner functions (_do_get, _do_post, etc.) - Centralize error handling for all HTTP methods - Reduces code duplication by ~55 lines * security(__api_session): remove SSL verification bypass in Vault client - Remove verify=False from hvac.Client initialization - Make vault attributes private (_vault_url, _vault_path, etc.) - Improve vault credentials cleanup in finally block * feat(__api_session): improve session management - Add _new_session() helper for consistent session initialization - Add validate parameter to set_api_token() for optional token validation - Fix delete_api_token() to return APIResponse instead of Response - Use mist_delete() method instead of raw session.delete() * perf(__init__): implement lazy loading for heavy subpackages - Defer api and cli imports until accessed - Improves initial import performance * fix(__logger): correct logging sanitization - Use getMessage() instead of direct msg access - Clear record.args after sanitization to prevent re-formatting This refactoring improves maintainability, security, and performance without changing the public API surface. --- src/mistapi/__api_request.py | 384 ++++++++++++++--------------------- src/mistapi/__api_session.py | 89 +++++--- src/mistapi/__init__.py | 15 ++ src/mistapi/__logger.py | 3 +- 4 files changed, 220 insertions(+), 271 deletions(-) diff --git a/src/mistapi/__api_request.py b/src/mistapi/__api_request.py index e6f42a1..a5aabc5 100644 --- a/src/mistapi/__api_request.py +++ b/src/mistapi/__api_request.py @@ -17,7 +17,9 @@ import json import os import re -import sys +import time +import urllib.parse +from collections.abc import Callable from typing import Any import requests @@ -33,6 +35,9 @@ class APIRequest: Class handling API Request to the Mist Cloud """ + _MAX_429_RETRIES: int = 3 + _DEFAULT_RETRY_AFTER: int = 5 + def __init__(self) -> None: self._cloud_uri: str = "" self._session = requests.session() @@ -77,8 +82,7 @@ def _log_proxy(self) -> None: re.sub(pwd_regex, ":*********@", self._session.proxies["https"]), ) print( - "apirequest:sending request to proxy server %s", - re.sub(pwd_regex, ":*********@", self._session.proxies["https"]), + f"apirequest:sending request to proxy server {re.sub(pwd_regex, ':*********@', self._session.proxies['https'])}" ) def _next_apitoken(self) -> None: @@ -111,18 +115,16 @@ def _next_apitoken(self) -> None: " For large organization, it is recommended to configure" " multiple API Tokens (comma separated list) to avoid this issue" ) - logger.critical(" Exiting...") - sys.exit(255) + raise RuntimeError( + "API rate limit reached and no other API Token available. " + "For large organizations, configure multiple API Tokens " + "(comma separated list) to avoid this issue." + ) def _gen_query(self, query: dict[str, str] | None) -> str: - logger.debug(f"apirequest:_gen_query:processing query {query}") - html_query = "?" - if query: - for query_param in query: - html_query += f"{query_param}={query[query_param]}&" - logger.debug(f"apirequest:_gen_query:generated query:{html_query}") - html_query = html_query[:-1] - return html_query + if not query: + return "" + return "?" + urllib.parse.urlencode(query) def _remove_auth_from_headers(self, resp: requests.Response): headers = resp.request.headers @@ -157,6 +159,72 @@ def remove_file_from_body(self, resp: requests.Response): request_body += f"\r\n{i}" return request_body + def _handle_rate_limit(self, resp: requests.Response, attempt: int) -> None: + retry_after = resp.headers.get("Retry-After") + if retry_after: + try: + wait = int(retry_after) + except ValueError: + wait = self._DEFAULT_RETRY_AFTER * (2**attempt) + else: + wait = self._DEFAULT_RETRY_AFTER * (2**attempt) + logger.info( + "apirequest:rate_limited:sleeping %ss (attempt %s/%s)", + wait, + attempt + 1, + self._MAX_429_RETRIES, + ) + time.sleep(wait) + + def _request_with_retry( + self, method_name: str, request_fn: Callable, url: str + ) -> APIResponse: + """Shared retry wrapper for all HTTP methods.""" + resp = None + proxy_failed = False + for attempt in range(self._MAX_429_RETRIES + 1): + try: + logger.info(f"apirequest:{method_name}:sending request to {url}") + self._log_proxy() + resp = request_fn() + logger.debug( + f"apirequest:{method_name}:request headers:{self._remove_auth_from_headers(resp)}" + ) + resp.raise_for_status() + break + except requests.exceptions.ProxyError as e: + logger.error(f"apirequest:{method_name}:Proxy Error: {e}") + proxy_failed = True + break + except requests.exceptions.ConnectionError as e: + logger.error(f"apirequest:{method_name}:Connection Error: {e}") + break + except HTTPError as e: + if e.response.status_code == 429 and attempt < self._MAX_429_RETRIES: + logger.warning( + f"apirequest:{method_name}:HTTP 429 (attempt {attempt + 1}/{self._MAX_429_RETRIES})" + ) + try: + self._next_apitoken() + except RuntimeError: + pass # single token — still retry with backoff + self._handle_rate_limit(e.response, attempt) + continue + logger.error(f"apirequest:{method_name}:HTTP error: {e}") + if resp: + logger.error( + f"apirequest:{method_name}:HTTP error description: {resp.json()}" + ) + break + except Exception as e: + logger.error(f"apirequest:{method_name}:error: {e}") + logger.error( + f"apirequest:{method_name}:Exception occurred", exc_info=True + ) + break + self._count += 1 + return APIResponse(url=url, response=resp, proxy_error=proxy_failed) + def mist_get(self, uri: str, query: dict[str, str] | None = None) -> APIResponse: """ GET HTTP Request @@ -173,41 +241,8 @@ def mist_get(self, uri: str, query: dict[str, str] | None = None) -> APIResponse mistapi.APIResponse response from the API call """ - resp = None - proxy_failed = False - try: - url = self._url(uri) + self._gen_query(query) - logger.info(f"apirequest:mist_get:sending request to {url}") - self._log_proxy() - resp = self._session.get(url) - logger.debug( - f"apirequest:mist_get:request headers:{self._remove_auth_from_headers(resp)}" - ) - resp.raise_for_status() - except requests.exceptions.ProxyError as proxy_error: - logger.error(f"apirequest:mist_get:Proxy Error: {proxy_error}") - proxy_failed = True - except requests.exceptions.ConnectionError as connexion_error: - logger.error(f"apirequest:mist_get:Connection Error: {connexion_error}") - except HTTPError as http_err: - if http_err.response.status_code == 429: - logger.warning( - "apirequest:mist_get:" - "got HTTP Error 429 from Mist. Will try with next API Token" - ) - self._next_apitoken() - return self.mist_get(uri, query) - logger.error(f"apirequest:mist_get:HTTP error occurred: {http_err}") - if resp: - logger.error( - f"apirequest:mist_get:HTTP error description: {resp.json()}" - ) - except Exception as err: - logger.error(f"apirequest:mist_get:Other error occurred: {err}") - logger.error("apirequest:mist_get:Exception occurred", exc_info=True) - finally: - self._count += 1 - return APIResponse(url=url, response=resp, proxy_error=proxy_failed) + url = self._url(uri) + self._gen_query(query) + return self._request_with_retry("mist_get", lambda: self._session.get(url), url) def mist_post(self, uri: str, body: dict | list | None = None) -> APIResponse: """ @@ -224,48 +259,14 @@ def mist_post(self, uri: str, body: dict | list | None = None) -> APIResponse: mistapi.APIResponse response from the API call """ - resp = None - proxy_failed = False - try: - url = self._url(uri) - logger.info(f"apirequest:mist_post:sending request to {url}") - headers = {"Content-Type": "application/json"} - logger.debug(f"apirequest:mist_post:Request body:{body}") - if isinstance(body, str): - self._log_proxy() - resp = self._session.post(url, data=body, headers=headers) - else: - self._log_proxy() - resp = self._session.post(url, json=body, headers=headers) - logger.debug( - f"apirequest:mist_post:request headers:{self._remove_auth_from_headers(resp)}" - ) - logger.debug("apirequest:mist_post:request body: %s", resp.request.body) - resp.raise_for_status() - except requests.exceptions.ProxyError as proxy_error: - logger.error(f"apirequest:mist_post:Proxy Error: {proxy_error}") - proxy_failed = True - except requests.exceptions.ConnectionError as connexion_error: - logger.error(f"apirequest:mist_post:Connection Error: {connexion_error}") - except HTTPError as http_err: - if http_err.response.status_code == 429: - logger.warning( - "apirequest:mist_post:" - "got HTTP Error 429 from Mist. Will try with next API Token" - ) - self._next_apitoken() - return self.mist_post(uri, body) - logger.error(f"apirequest:mist_post: HTTP error occurred: {http_err}") - if resp: - logger.error( - f"apirequest:mist_post: HTTP error description: {resp.json()}" - ) - except Exception as err: - logger.error(f"apirequest:mist_post: Other error occurred: {err}") - logger.error("apirequest:mist_post: Exception occurred", exc_info=True) - finally: - self._count += 1 - return APIResponse(url=url, response=resp, proxy_error=proxy_failed) + url = self._url(uri) + headers = {"Content-Type": "application/json"} + logger.debug(f"apirequest:mist_post:Request body:{body}") + if isinstance(body, str): + fn = lambda: self._session.post(url, data=body, headers=headers) + else: + fn = lambda: self._session.post(url, json=body, headers=headers) + return self._request_with_retry("mist_post", fn, url) def mist_put(self, uri: str, body: dict | None = None) -> APIResponse: """ @@ -282,48 +283,14 @@ def mist_put(self, uri: str, body: dict | None = None) -> APIResponse: mistapi.APIResponse response from the API call """ - resp = None - proxy_failed = False - try: - url = self._url(uri) - logger.info(f"apirequest:mist_put:sending request to {url}") - headers = {"Content-Type": "application/json"} - logger.debug(f"apirequest:mist_put:Request body:{body}") - if isinstance(body, str): - self._log_proxy() - resp = self._session.put(url, data=body, headers=headers) - else: - self._log_proxy() - resp = self._session.put(url, json=body, headers=headers) - logger.debug( - f"apirequest:mist_put:request headers:{self._remove_auth_from_headers(resp)}" - ) - logger.debug("apirequest:mist_put:request body:%s", resp.request.body) - resp.raise_for_status() - except requests.exceptions.ProxyError as proxy_error: - logger.error(f"apirequest:mist_put:Proxy Error: {proxy_error}") - proxy_failed = True - except requests.exceptions.ConnectionError as connexion_error: - logger.error(f"apirequest:mist_put:Connection Error: {connexion_error}") - except HTTPError as http_err: - if http_err.response.status_code == 429: - logger.warning( - "apirequest:mist_put:" - "got HTTP Error 429 from Mist. Will try with next API Token" - ) - self._next_apitoken() - return self.mist_put(uri, body) - logger.error(f"apirequest:mist_put: HTTP error occurred: {http_err}") - if resp: - logger.error( - f"apirequest:mist_put: HTTP error description: {resp.json()}" - ) - except Exception as err: - logger.error(f"apirequest:mist_put: Other error occurred: {err}") - logger.error("apirequest:mist_put: Exception occurred", exc_info=True) - finally: - self._count += 1 - return APIResponse(url=url, response=resp, proxy_error=proxy_failed) + url = self._url(uri) + headers = {"Content-Type": "application/json"} + logger.debug(f"apirequest:mist_put:Request body:{body}") + if isinstance(body, str): + fn = lambda: self._session.put(url, data=body, headers=headers) + else: + fn = lambda: self._session.put(url, json=body, headers=headers) + return self._request_with_retry("mist_put", fn, url) def mist_delete(self, uri: str, query: dict | None = None) -> APIResponse: """ @@ -339,37 +306,10 @@ def mist_delete(self, uri: str, query: dict | None = None) -> APIResponse: mistapi.APIResponse response from the API call """ - resp = None - proxy_failed = False - try: - url = self._url(uri) + self._gen_query(query) - logger.info(f"apirequest:mist_delete:sending request to {url}") - self._log_proxy() - resp = self._session.delete(url) - logger.debug( - f"apirequest:mist_delete:request headers:{self._remove_auth_from_headers(resp)}" - ) - resp.raise_for_status() - except requests.exceptions.ProxyError as proxy_error: - logger.error(f"apirequest:mist_delete:Proxy Error: {proxy_error}") - proxy_failed = True - except requests.exceptions.ConnectionError as connexion_error: - logger.error(f"apirequest:mist_delete:Connection Error: {connexion_error}") - except HTTPError as http_err: - if http_err.response.status_code == 429: - logger.warning( - "apirequest:mist_delete:" - "got HTTP Error 429 from Mist. Will try with next API Token" - ) - self._next_apitoken() - return self.mist_delete(uri, query) - logger.error(f"apirequest:mist_delete: HTTP error occurred: {http_err}") - except Exception as err: - logger.error(f"apirequest:mist_delete: Other error occurred: {err}") - logger.error("apirequest:mist_delete: Exception occurred", exc_info=True) - finally: - self._count += 1 - return APIResponse(url=url, response=resp, proxy_error=proxy_failed) + url = self._url(uri) + self._gen_query(query) + return self._request_with_retry( + "mist_delete", lambda: self._session.delete(url), url + ) def mist_post_file( self, uri: str, multipart_form_data: dict | None = None @@ -389,85 +329,55 @@ def mist_post_file( mistapi.APIResponse response from the API call """ - resp = None - proxy_failed = False - try: - if multipart_form_data is None: - multipart_form_data = {} - url = self._url(uri) - logger.info(f"apirequest:mist_post_file:sending request to {url}") + if multipart_form_data is None: + multipart_form_data = {} + url = self._url(uri) + logger.debug( + f"apirequest:mist_post_file:initial multipart_form_data:{multipart_form_data}" + ) + generated_multipart_form_data: dict[str, Any] = {} + for key in multipart_form_data: logger.debug( - f"apirequest:mist_post_file:initial multipart_form_data:{multipart_form_data}" + f"apirequest:mist_post_file:" + f"multipart_form_data:{key} = {multipart_form_data[key]}" ) - generated_multipart_form_data: dict[str, Any] = {} - for key in multipart_form_data: - logger.debug( - f"apirequest:mist_post_file:" - f"multipart_form_data:{key} = {multipart_form_data[key]}" - ) - if multipart_form_data[key]: - try: - if key in ["csv", "file"]: - logger.debug( - f"apirequest:mist_post_file:reading file:{multipart_form_data[key]}" - ) - f = open(multipart_form_data[key], "rb") - generated_multipart_form_data[key] = ( - os.path.basename(multipart_form_data[key]), - f, - "application/octet-stream", - ) - else: - generated_multipart_form_data[key] = ( - None, - json.dumps(multipart_form_data[key]), - ) - except (OSError, json.JSONDecodeError): - logger.error( - f"apirequest:mist_post_file:multipart_form_data:" - f"Unable to parse JSON object {key} " - f"with value {multipart_form_data[key]}" + if multipart_form_data[key]: + try: + if key in ["csv", "file"]: + logger.debug( + f"apirequest:mist_post_file:reading file:{multipart_form_data[key]}" ) - logger.error( - "apirequest:mist_post_file: Exception occurred", - exc_info=True, + f = open(multipart_form_data[key], "rb") + generated_multipart_form_data[key] = ( + os.path.basename(multipart_form_data[key]), + f, + "application/octet-stream", ) - logger.debug( - f"apirequest:mist_post_file:" - f"final multipart_form_data:{generated_multipart_form_data}" - ) - self._log_proxy() + else: + generated_multipart_form_data[key] = ( + None, + json.dumps(multipart_form_data[key]), + ) + except (OSError, json.JSONDecodeError): + logger.error( + f"apirequest:mist_post_file:multipart_form_data:" + f"Unable to parse JSON object {key} " + f"with value {multipart_form_data[key]}" + ) + logger.error( + "apirequest:mist_post_file: Exception occurred", + exc_info=True, + ) + logger.debug( + f"apirequest:mist_post_file:" + f"final multipart_form_data:{generated_multipart_form_data}" + ) + + def _do_post_file(): resp = self._session.post(url, files=generated_multipart_form_data) - logger.debug( - f"apirequest:mist_post_file:request headers:{self._remove_auth_from_headers(resp)}" - ) logger.debug( f"apirequest:mist_post_file:request body:{self.remove_file_from_body(resp)}" ) - resp.raise_for_status() - except requests.exceptions.ProxyError as proxy_error: - logger.error(f"apirequest:mist_post_file:Proxy Error: {proxy_error}") - proxy_failed = True - except requests.exceptions.ConnectionError as connexion_error: - logger.error( - f"apirequest:mist_post_file:Connection Error: {connexion_error}" - ) - except HTTPError as http_err: - if http_err.response.status_code == 429: - logger.warning( - "apirequest:mist_post_file:" - "got HTTP Error 429 from Mist. Will try with next API Token" - ) - self._next_apitoken() - return self.mist_post_file(uri, multipart_form_data) - logger.error(f"apirequest:mist_post_file: HTTP error occurred: {http_err}") - if resp: - logger.error( - f"apirequest:mist_post_file: HTTP error description: {resp.json()}" - ) - except Exception as err: - logger.error(f"apirequest:mist_post_file: Other error occurred: {err}") - logger.error("apirequest:mist_post_file: Exception occurred", exc_info=True) - finally: - self._count += 1 - return APIResponse(url=url, response=resp, proxy_error=proxy_failed) + return resp + + return self._request_with_retry("mist_post_file", _do_post_file, url) diff --git a/src/mistapi/__api_session.py b/src/mistapi/__api_session.py index 154fb52..48e62d0 100644 --- a/src/mistapi/__api_session.py +++ b/src/mistapi/__api_session.py @@ -21,7 +21,7 @@ import keyring import requests from dotenv import load_dotenv -from requests import Response, Session +from requests import Session from mistapi.__api_request import APIRequest from mistapi.__api_response import APIResponse @@ -132,17 +132,26 @@ def __init__( self._logging_log_level = logging_log_level self._show_cli_notif = show_cli_notif self._proxies = {"https": https_proxy} - self.vault_url = vault_url - self.vault_path = vault_path - self.vault_mount_point = vault_mount_point - self.vault_token = vault_token + self._vault_url = vault_url + self._vault_path = vault_path + self._vault_mount_point = vault_mount_point + self._vault_token = vault_token CONSOLE._set_log_level(console_log_level, logging_log_level) self._load_env(env_file) if keyring_service: self._load_keyring(keyring_service) - if self.vault_path: - self._load_vault() + if self._vault_path: + self._load_vault() # finally block deletes _vault_* attrs + else: + for attr in ( + "_vault_url", + "_vault_token", + "_vault_path", + "_vault_mount_point", + ): + if hasattr(self, attr): + delattr(self, attr) # Filter out None values before updating proxies filtered_proxies = {k: v for k, v in self._proxies.items() if v is not None} self._session.proxies.update(filtered_proxies) @@ -165,6 +174,18 @@ def __init__( LOGGER.debug("apisession:__init__: API Session initialized") + def _new_session(self) -> Session: + session = requests.session() + session.headers["Accept"] = "application/json, application/vnd.api+json" + filtered_proxies = {k: v for k, v in self._proxies.items() if v is not None} + if filtered_proxies: + session.proxies.update(filtered_proxies) + if self._apitoken and self._apitoken_index >= 0: + session.headers["Authorization"] = ( + "Token " + self._apitoken[self._apitoken_index] + ) + return session + def __str__(self) -> str: fields = [ "email", @@ -206,7 +227,7 @@ def _load_vault( Load Vault settings from env file """ LOGGER.info("apisession:_load_vault: Loading Vault settings") - client = hvac.Client(url=self.vault_url, token=self.vault_token, verify=False) + client = hvac.Client(url=self._vault_url, token=self._vault_token) if not client.is_authenticated(): LOGGER.error("apisession:_load_vault: Vault authentication failed") CONSOLE.error("Vault authentication failed") @@ -216,7 +237,7 @@ def _load_vault( ) try: read_response = client.secrets.kv.v2.read_secret( - path=self.vault_path, mount_point=self.vault_mount_point + path=self._vault_path, mount_point=self._vault_mount_point ) LOGGER.info("apisession:_load_vault: Secret retrieved successfully") @@ -232,10 +253,10 @@ def _load_vault( LOGGER.error("apisession:_load_vault: Failed to retrieve secret") CONSOLE.error("Failed to retrieve secret") finally: - del self.vault_url - del self.vault_path - del self.vault_mount_point - del self.vault_token + del self._vault_url + del self._vault_path + del self._vault_mount_point + del self._vault_token def _load_keyring(self, keyring_service) -> None: """ @@ -322,17 +343,17 @@ def _load_env(self, env_file=None) -> None: except ValueError: self._logging_log_level = 10 # Default fallback - if os.getenv("MIST_VAULT_URL") and not self.vault_url: - self.vault_url = os.getenv("MIST_VAULT_URL") + if os.getenv("MIST_VAULT_URL") and not self._vault_url: + self._vault_url = os.getenv("MIST_VAULT_URL") - if os.getenv("MIST_VAULT_PATH") and not self.vault_path: - self.vault_path = os.getenv("MIST_VAULT_PATH") + if os.getenv("MIST_VAULT_PATH") and not self._vault_path: + self._vault_path = os.getenv("MIST_VAULT_PATH") - if os.getenv("MIST_VAULT_MOUNT_POINT") and not self.vault_mount_point: - self.vault_mount_point = os.getenv("MIST_VAULT_MOUNT_POINT") + if os.getenv("MIST_VAULT_MOUNT_POINT") and not self._vault_mount_point: + self._vault_mount_point = os.getenv("MIST_VAULT_MOUNT_POINT") - if os.getenv("MIST_VAULT_TOKEN") and not self.vault_token: - self.vault_token = os.getenv("MIST_VAULT_TOKEN") + if os.getenv("MIST_VAULT_TOKEN") and not self._vault_token: + self._vault_token = os.getenv("MIST_VAULT_TOKEN") if os.getenv("MIST_KEYRING_SERVICE"): self.keyring_service = os.getenv("MIST_KEYRING_SERVICE") @@ -465,7 +486,7 @@ def set_password(self, password: str | None = None) -> None: LOGGER.info("apisession:set_password:password configured") CONSOLE.debug("Password configured") - def set_api_token(self, apitoken: str) -> None: + def set_api_token(self, apitoken: str, validate: bool = True) -> None: """ Set Mist API Token @@ -473,6 +494,9 @@ def set_api_token(self, apitoken: str) -> None: ----------- apitoken : str API Token to add in the requests headers for authentication and authorization + validate : bool, default True + If True, validate the API tokens against the Mist Cloud before using them. + If False, accept the tokens directly without validation. """ LOGGER.debug("apisession:set_api_token") apitokens_in = apitoken.split(",") @@ -483,7 +507,10 @@ def set_api_token(self, apitoken: str) -> None: apitokens_out.append(token) LOGGER.info("apisession:set_api_token:found %s API Tokens", len(apitokens_out)) - valid_api_tokens = self._check_api_tokens(apitokens_out) + if validate: + valid_api_tokens = self._check_api_tokens(apitokens_out) + else: + valid_api_tokens = apitokens_out if valid_api_tokens: self._apitoken = valid_api_tokens self._apitoken_index = 0 @@ -666,7 +693,7 @@ def _process_login(self, retry: bool = True) -> str | None: print(" Login/Pwd authentication ".center(80, "-")) print() - self._session = requests.session() + self._session = self._new_session() if not self.email: self.set_email() if not self._password: @@ -776,7 +803,7 @@ def login_with_return( Error message from Mist (if any) """ LOGGER.debug("apisession:login_with_return") - self._session = requests.session() + self._session = self._new_session() if apitoken: self.set_api_token(apitoken) if email: @@ -944,8 +971,7 @@ def get_api_token(self) -> APIResponse: LOGGER.info( 'apisession:get_api_token: Sending GET request to "/api/v1/self/apitokens"' ) - resp = self.mist_get("/api/v1/self/apitokens") - return resp + return self.mist_get("/api/v1/self/apitokens") def create_api_token(self, token_name: str | None = None) -> APIResponse: """ @@ -970,10 +996,9 @@ def create_api_token(self, token_name: str | None = None) -> APIResponse: 'sending POST request to "/api/v1/self/apitokens" with name "%s"', token_name, ) - resp = self.mist_post("/api/v1/self/apitokens", body=body) - return resp + return self.mist_post("/api/v1/self/apitokens", body=body) - def delete_api_token(self, apitoken_id: str) -> Response: + def delete_api_token(self, apitoken_id: str) -> APIResponse: """ Delete an API Token based on its token_id @@ -993,9 +1018,7 @@ def delete_api_token(self, apitoken_id: str) -> Response: 'sending DELETE request to "/api/v1/self/apitokens" with token_id "%s"', apitoken_id, ) - uri = f"https://{self._cloud_uri}/api/v1/self/apitokens/{apitoken_id}" - resp = self._session.delete(uri) - return resp + return self.mist_delete(f"/api/v1/self/apitokens/{apitoken_id}") def _two_factor_authentication(self, two_factor: str) -> bool: """ diff --git a/src/mistapi/__init__.py b/src/mistapi/__init__.py index 6a18153..025bf45 100644 --- a/src/mistapi/__init__.py +++ b/src/mistapi/__init__.py @@ -20,3 +20,18 @@ from mistapi.__pagination import get_next as get_next from mistapi.__version import __author__ as __author__ from mistapi.__version import __version__ as __version__ + +_LAZY_SUBPACKAGES = { + "api": "mistapi.api", + "cli": "mistapi.cli", +} + + +def __getattr__(name: str): + if name in _LAZY_SUBPACKAGES: + import importlib + + module = importlib.import_module(_LAZY_SUBPACKAGES[name]) + globals()[name] = module + return module + raise AttributeError(f"module 'mistapi' has no attribute {name!r}") diff --git a/src/mistapi/__logger.py b/src/mistapi/__logger.py index e870683..51b9347 100644 --- a/src/mistapi/__logger.py +++ b/src/mistapi/__logger.py @@ -238,7 +238,8 @@ def __init__(self): self.console = Console() def filter(self, record): - record.msg = self.console.sanitize(record.msg) + record.msg = self.console.sanitize(record.getMessage()) + record.args = None return True From b5438da4144dae68914afe7cd74016c178eb3d3d Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Fri, 13 Mar 2026 20:47:04 +0100 Subject: [PATCH 08/16] refactor: rename utils to device_utils and adopt camelCase public API refactor: rename utils to device_utils and adopt camelCase public API Reorganize device utility modules from `mistapi.utils` to `mistapi.device_utils` with a clean separation between internal implementations (`__tools/`) and public-facing device modules (ap, ex, srx, ssr). Public function names now use camelCase (e.g. retrieveArpTable, monitorTraffic) for consistency with the auto-generated REST API. Also adds comprehensive unit tests for api_request, api_response, api_session, models, pagination, logger, init, and websocket_client. --- README.md | 175 +++- src/mistapi/__init__.py | 14 +- src/mistapi/api/v1/sites/devices.py | 4 +- .../{utils => device_utils}/__init__.py | 40 +- src/mistapi/device_utils/__tools/__init__.py | 42 + .../__tools}/__ws_wrapper.py | 169 ++-- .../{utils => device_utils/__tools}/arp.py | 40 +- .../{utils => device_utils/__tools}/bgp.py | 17 +- .../{utils => device_utils/__tools}/bpdu.py | 2 +- src/mistapi/device_utils/__tools/dhcp.py | 172 ++++ src/mistapi/device_utils/__tools/dns.py | 84 ++ .../{utils => device_utils/__tools}/dot1x.py | 2 +- .../{utils => device_utils/__tools}/mac.py | 27 +- .../device_utils/__tools/miscellaneous.py | 364 ++++++++ src/mistapi/device_utils/__tools/ospf.py | 291 +++++++ .../{utils => device_utils/__tools}/policy.py | 2 +- src/mistapi/device_utils/__tools/port.py | 133 +++ .../device_utils/__tools/remote_capture.py | 444 ++++++++++ .../{utils => device_utils/__tools}/routes.py | 16 +- .../__tools/service_path.py} | 46 +- .../__tools}/sessions.py | 28 +- src/mistapi/{utils => device_utils}/ap.py | 12 +- src/mistapi/device_utils/bgp.py | 70 ++ src/mistapi/device_utils/bpdu.py | 61 ++ src/mistapi/{utils => device_utils}/dhcp.py | 28 +- src/mistapi/device_utils/dot1x.py | 60 ++ src/mistapi/device_utils/ex.py | 78 ++ src/mistapi/{utils => device_utils}/ospf.py | 52 +- src/mistapi/device_utils/policy.py | 62 ++ src/mistapi/{utils => device_utils}/port.py | 31 +- .../{utils => device_utils}/service_path.py | 16 +- src/mistapi/device_utils/sessions.py | 172 ++++ src/mistapi/device_utils/srx.py | 69 ++ src/mistapi/device_utils/ssr.py | 76 ++ src/mistapi/{utils => device_utils}/tools.py | 251 +++--- src/mistapi/utils/ex.py | 78 -- src/mistapi/utils/srx.py | 61 -- src/mistapi/utils/ssr.py | 65 -- src/mistapi/websockets/__init__.py | 3 +- src/mistapi/websockets/__ws_client.py | 61 +- src/mistapi/websockets/location.py | 15 +- src/mistapi/websockets/orgs.py | 9 +- src/mistapi/websockets/session.py | 3 +- src/mistapi/websockets/sites.py | 18 +- tests/unit/test_api_request.py | 793 ++++++++++++++++++ tests/unit/test_api_response.py | 328 ++++++++ tests/unit/test_api_session.py | 152 ++++ tests/unit/test_init.py | 288 +++++++ tests/unit/test_logger.py | 316 +++++++ tests/unit/test_models.py | 506 +++++++++++ tests/unit/test_pagination.py | 296 +++++++ tests/unit/test_websocket_client.py | 774 +++++++++++++++++ 52 files changed, 6283 insertions(+), 633 deletions(-) rename src/mistapi/{utils => device_utils}/__init__.py (69%) create mode 100644 src/mistapi/device_utils/__tools/__init__.py rename src/mistapi/{utils => device_utils/__tools}/__ws_wrapper.py (60%) rename src/mistapi/{utils => device_utils/__tools}/arp.py (79%) rename src/mistapi/{utils => device_utils/__tools}/bgp.py (75%) rename src/mistapi/{utils => device_utils/__tools}/bpdu.py (96%) create mode 100644 src/mistapi/device_utils/__tools/dhcp.py create mode 100644 src/mistapi/device_utils/__tools/dns.py rename src/mistapi/{utils => device_utils/__tools}/dot1x.py (96%) rename src/mistapi/{utils => device_utils/__tools}/mac.py (87%) create mode 100644 src/mistapi/device_utils/__tools/miscellaneous.py create mode 100644 src/mistapi/device_utils/__tools/ospf.py rename src/mistapi/{utils => device_utils/__tools}/policy.py (96%) create mode 100644 src/mistapi/device_utils/__tools/port.py create mode 100644 src/mistapi/device_utils/__tools/remote_capture.py rename src/mistapi/{utils => device_utils/__tools}/routes.py (83%) rename src/mistapi/{utils/dns.py => device_utils/__tools/service_path.py} (53%) rename src/mistapi/{utils => device_utils/__tools}/sessions.py (82%) rename src/mistapi/{utils => device_utils}/ap.py (76%) create mode 100644 src/mistapi/device_utils/bgp.py create mode 100644 src/mistapi/device_utils/bpdu.py rename src/mistapi/{utils => device_utils}/dhcp.py (82%) create mode 100644 src/mistapi/device_utils/dot1x.py create mode 100644 src/mistapi/device_utils/ex.py rename src/mistapi/{utils => device_utils}/ospf.py (80%) create mode 100644 src/mistapi/device_utils/policy.py rename src/mistapi/{utils => device_utils}/port.py (76%) rename src/mistapi/{utils => device_utils}/service_path.py (80%) create mode 100644 src/mistapi/device_utils/sessions.py create mode 100644 src/mistapi/device_utils/srx.py create mode 100644 src/mistapi/device_utils/ssr.py rename src/mistapi/{utils => device_utils}/tools.py (78%) delete mode 100644 src/mistapi/utils/ex.py delete mode 100644 src/mistapi/utils/srx.py delete mode 100644 src/mistapi/utils/ssr.py create mode 100644 tests/unit/test_init.py create mode 100644 tests/unit/test_logger.py create mode 100644 tests/unit/test_websocket_client.py diff --git a/README.md b/README.md index 8e894fc..09d27e4 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,10 @@ A comprehensive Python package to interact with the Mist Cloud APIs, built from - [Callbacks](#callbacks) - [Available Channels](#available-channels) - [Usage Patterns](#usage-patterns) +- [Device Utilities](#device-utilities) + - [Supported Devices](#supported-devices) + - [Usage](#device-utilities-usage) + - [UtilResponse Object](#utilresponse-object) - [Development](#development-and-testing) - [Contributing](#contributing) - [License](#license) @@ -60,18 +64,12 @@ Support for all Mist cloud instances worldwide: ### Core Features - **Complete API Coverage**: Auto-generated from OpenAPI specs - **Automatic Pagination**: Built-in support for paginated responses +- **WebSocket Streaming**: Real-time event streaming for devices, clients, and location data +- **Device Diagnostics**: High-level utilities for ping, traceroute, ARP, BGP, OSPF, and more - **Error Handling**: Detailed error responses and logging - **Proxy Support**: HTTP/HTTPS proxy configuration - **Log Sanitization**: Automatic redaction of sensitive data in logs -### API Coverage -**Organization Level**: Organizations, Sites, Devices (APs/Switches/Gateways), WLANs, VPNs, Networks, NAC, Users, Admins, Guests, Alarms, Events, Statistics, SLE, Assets, Licenses, Webhooks, Security Policies, MSP management - -**Site Level**: Device management, RF optimization, Location services, Maps, Client analytics, Asset tracking, Synthetic testing, Anomaly detection - -**Constants & Utilities**: Device models, AP channels, Applications, Country codes, Alarm definitions, Event types, Webhook topics - -**Additional Services**: OAuth, Two-factor authentication, Account recovery, Invitations, MDM workflows --- @@ -97,16 +95,31 @@ python3 -m pip install --upgrade mistapi py -m pip install --upgrade mistapi ``` +### Installation with uv + +[uv](https://docs.astral.sh/uv/) is a fast Python package manager: + +```bash +# Install in current project +uv add mistapi + +# Or run directly without installing +uv run --with mistapi python my_script.py +``` + ### Development Installation ```bash -# Install with development dependencies (for contributors) -pip install mistapi[dev] +# With pip +pip install -e ".[dev]" + +# With uv +uv sync ``` ### Requirements - Python 3.10 or higher -- Dependencies: `requests`, `python-dotenv`, `tabulate`, `deprecation`, `hvac`, `keyring` +- Dependencies: `requests`, `python-dotenv`, `tabulate`, `deprecation`, `hvac`, `keyring`, `websocket-client` --- @@ -175,9 +188,9 @@ MIST_APITOKEN=your_api_token_here | `MIST_USER` | `email` | string | None | Username/email for authentication | | `MIST_PASSWORD` | `password` | string | None | Password for authentication | | `MIST_KEYRING_SERVICE` | `keyring_service` | string | None | System keyring service name | -| `MIST_VAULT_URL` | `vault_url` | string | https://127.0.0.1:8200 | HashiCorp Vault URL | +| `MIST_VAULT_URL` | `vault_url` | string | None | HashiCorp Vault URL | | `MIST_VAULT_PATH` | `vault_path` | string | None | Path to secret in Vault | -| `MIST_VAULT_MOUNT_POINT` | `vault_mount_point` | string | secret | Vault mount point | +| `MIST_VAULT_MOUNT_POINT` | `vault_mount_point` | string | None | Vault mount point | | `MIST_VAULT_TOKEN` | `vault_token` | string | None | Vault authentication token | | `CONSOLE_LOG_LEVEL` | `console_log_level` | int | 20 | Console log level (0-50) | | `LOGGING_LOG_LEVEL` | `logging_log_level` | int | 10 | File log level (0-50) | @@ -473,7 +486,7 @@ clients = mistapi.api.v1.orgs.clients.searchOrgWirelessClients( events = mistapi.api.v1.orgs.clients.searchOrgClientsEvents( apisession, org_id, duration="1h", - client_mac="aa:bb:cc:dd:ee:ff" + client_mac="aabbccddeeff" ) ``` @@ -493,9 +506,9 @@ All channel classes accept the following optional keyword arguments to control t | `ping_timeout` | `int` | `10` | Seconds to wait for a pong response before treating the connection as dead. | ```python -ws = mistapi.websockets.sites.SiteDeviceStatsEvents( +ws = mistapi.websockets.sites.DeviceStatsEvents( apisession, - site_id="", + site_ids=[""], ping_interval=60, # ping every 60 s ping_timeout=20, # wait up to 20 s for pong ) @@ -510,6 +523,7 @@ ws.connect() | `ws.on_message(cb)` | `cb(data: dict)` | Called for every incoming message | | `ws.on_error(cb)` | `cb(error: Exception)` | Called on WebSocket errors | | `ws.on_close(cb)` | `cb(status_code: int, msg: str)` | Called when the connection closes | +| `ws.ready()` | `-> bool \| None` | Returns `True` if the connection is open and ready | ### Available Channels @@ -517,35 +531,42 @@ ws.connect() | Class | Channel | Description | |-------|---------|-------------| -| `mistapi.websockets.orgs.OrgInsightsEvents` | `/orgs/{org_id}/insights/summary` | Real-time insights events for an organization | -| `mistapi.websockets.orgs.OrgMxEdgesStatsEvents` | `/orgs/{org_id}/stats/mxedges` | Real-time MX edges stats for an organization | -| `mistapi.websockets.orgs.OrgMxEdgesUpgradesEvents` | `/orgs/{org_id}/mxedges` | Real-time MX edges upgrades events for an organization | +| `mistapi.websockets.orgs.InsightsEvents` | `/orgs/{org_id}/insights/summary` | Real-time insights events for an organization | +| `mistapi.websockets.orgs.MxEdgesStatsEvents` | `/orgs/{org_id}/stats/mxedges` | Real-time MX edges stats for an organization | +| `mistapi.websockets.orgs.MxEdgesUpgradesEvents` | `/orgs/{org_id}/mxedges` | Real-time MX edges upgrades events for an organization | #### Site Channels | Class | Channel | Description | |-------|---------|-------------| -| `mistapi.websockets.sites.SiteClientsStatsEvents` | `/sites/{site_id}/stats/clients` | Real-time clients stats for a site | -| `mistapi.websockets.sites.SiteDeviceCmdEvents` | `/sites/{site_id}/devices/{device_id}/cmd` | Real-time device command events for a site | -| `mistapi.websockets.sites.SiteDeviceStatsEvents` | `/sites/{site_id}/stats/devices` | Real-time device stats for a site | -| `mistapi.websockets.sites.SiteDeviceUpgradesEvents` | `/sites/{site_id}/devices` | Real-time device upgrades events for a site | -| `mistapi.websockets.sites.SitePcapEvents` | `/sites/{site_id}/pcap` | Real-time PCAP events for a site | +| `mistapi.websockets.sites.ClientsStatsEvents` | `/sites/{site_id}/stats/clients` | Real-time clients stats for a site | +| `mistapi.websockets.sites.DeviceCmdEvents` | `/sites/{site_id}/devices/{device_id}/cmd` | Real-time device command events for a site | +| `mistapi.websockets.sites.DeviceStatsEvents` | `/sites/{site_id}/stats/devices` | Real-time device stats for a site | +| `mistapi.websockets.sites.DeviceUpgradesEvents` | `/sites/{site_id}/devices` | Real-time device upgrades events for a site | +| `mistapi.websockets.sites.MxEdgesStatsEvents` | `/sites/{site_id}/stats/mxedges` | Real-time MX edges stats for a site | +| `mistapi.websockets.sites.PcapEvents` | `/sites/{site_id}/pcap` | Real-time PCAP events for a site | #### Location Channels | Class | Channel | Description | |-------|---------|-------------| -| `mistapi.websockets.location.LocationBleAssetsEvents` | `/sites/{site_id}/stats/maps/{map_id}/assets` | Real-time BLE assets location events | -| `mistapi.websockets.location.LocationConnectedClientsEvents` | `/sites/{site_id}/stats/maps/{map_id}/clients` | Real-time connected clients location events | -| `mistapi.websockets.location.LocationSdkClientsEvents` | `/sites/{site_id}/stats/maps/{map_id}/sdkclients` | Real-time SDK clients location events | -| `mistapi.websockets.location.LocationUnconnectedClientsEvents` | `/sites/{site_id}/stats/maps/{map_id}/unconnected_clients` | Real-time unconnected clients location events | -| `mistapi.websockets.location.LocationDiscoveredBleAssetsEvents` | `/sites/{site_id}/stats/maps/{map_id}/discovered_assets` | Real-time discovered BLE assets location events | +| `mistapi.websockets.location.BleAssetsEvents` | `/sites/{site_id}/stats/maps/{map_id}/assets` | Real-time BLE assets location events | +| `mistapi.websockets.location.ConnectedClientsEvents` | `/sites/{site_id}/stats/maps/{map_id}/clients` | Real-time connected clients location events | +| `mistapi.websockets.location.SdkClientsEvents` | `/sites/{site_id}/stats/maps/{map_id}/sdkclients` | Real-time SDK clients location events | +| `mistapi.websockets.location.UnconnectedClientsEvents` | `/sites/{site_id}/stats/maps/{map_id}/unconnected_clients` | Real-time unconnected clients location events | +| `mistapi.websockets.location.DiscoveredBleAssetsEvents` | `/sites/{site_id}/stats/maps/{map_id}/discovered_assets` | Real-time discovered BLE assets location events | + +#### Session Channels + +| Class | Channel | Description | +|-------|---------|-------------| +| `mistapi.websockets.session.SessionWithUrl` | Custom URL | Connect to a custom WebSocket channel URL | ### Usage Patterns #### Callback style (recommended) -`connect()` returns immediately; messages are delivered to the registered callback in a background thread. +`connect()` defaults to `run_in_background=True` and returns immediately. The WebSocket runs in a daemon thread, so your program must stay alive (e.g., with `input()` or an event loop). Messages are delivered to the registered callback in the background thread. ```python import mistapi @@ -553,7 +574,7 @@ import mistapi apisession = mistapi.APISession(env_file="~/.mist_env") apisession.login() -ws = mistapi.websockets.sites.SiteDeviceStatsEvents(apisession, site_id="") +ws = mistapi.websockets.sites.DeviceStatsEvents(apisession, site_ids=[""]) ws.on_message(lambda data: print(data)) ws.connect() # non-blocking @@ -566,7 +587,7 @@ ws.disconnect() Iterate over incoming messages as a blocking generator. Useful when you want to process messages sequentially in a loop. ```python -ws = mistapi.websockets.sites.SiteDeviceStatsEvents(apisession, site_id="") +ws = mistapi.websockets.sites.DeviceStatsEvents(apisession, site_ids=[""]) ws.connect(run_in_background=True) for msg in ws.receive(): # blocks, yields each message as a dict @@ -580,7 +601,7 @@ for msg in ws.receive(): # blocks, yields each message as a dict `connect(run_in_background=False)` blocks the calling thread until the connection closes. Useful for simple scripts. ```python -ws = mistapi.websockets.sites.SiteDeviceStatsEvents(apisession, site_id="") +ws = mistapi.websockets.sites.DeviceStatsEvents(apisession, site_ids=[""]) ws.on_message(lambda data: print(data)) ws.connect(run_in_background=False) # blocks until disconnected ``` @@ -592,7 +613,7 @@ ws.connect(run_in_background=False) # blocks until disconnected ```python import time -with mistapi.websockets.sites.SiteDeviceStatsEvents(apisession, site_id="") as ws: +with mistapi.websockets.sites.DeviceStatsEvents(apisession, site_ids=[""]) as ws: ws.on_message(lambda data: print(data)) ws.connect() time.sleep(60) @@ -601,6 +622,57 @@ with mistapi.websockets.sites.SiteDeviceStatsEvents(apisession, site_id=" _APIResponse: """ API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/utilities/wan/test-site-ssr-dns-resolution @@ -2402,7 +2402,7 @@ def testSiteSsrDnsResolution( """ uri = f"/api/v1/sites/{site_id}/devices/{device_id}/resolve_dns" - resp = mist_session.mist_post(uri=uri) + resp = mist_session.mist_post(uri=uri, body=body) return resp diff --git a/src/mistapi/utils/__init__.py b/src/mistapi/device_utils/__init__.py similarity index 69% rename from src/mistapi/utils/__init__.py rename to src/mistapi/device_utils/__init__.py index 4ff6e43..dea7134 100644 --- a/src/mistapi/utils/__init__.py +++ b/src/mistapi/device_utils/__init__.py @@ -21,9 +21,9 @@ from mistapi.utils import ap, ex, srx, ssr # Use device-specific utilities - await ap.ping(session, site_id, device_id, host) - await ex.cable_test(session, site_id, device_id, port_id) - await ssr.show_service_path(session, site_id, device_id) + ap.ping(session, site_id, device_id, host) + ex.cableTest(session, site_id, device_id, port_id) + ssr.showServicePath(session, site_id, device_id) Supported Devices: - ap: Mist Access Points @@ -44,26 +44,11 @@ # Device-specific modules (recommended) # Function-based modules (legacy, still supported) # Internal modules -from mistapi.utils import ( - __ws_wrapper, +from mistapi.device_utils import ( ap, - arp, - bgp, - bpdu, - dhcp, - dns, - dot1x, ex, - mac, - ospf, - policy, - port, - routes, - service_path, - sessions, srx, ssr, - tools, ) __all__ = [ @@ -72,21 +57,4 @@ "ex", "srx", "ssr", - # Function-based modules (legacy) - "arp", - "bgp", - "bpdu", - "dhcp", - "dns", - "dot1x", - "mac", - "ospf", - "policy", - "port", - "routes", - "service_path", - "sessions", - "tools", - # Internal - "__ws_wrapper", ] diff --git a/src/mistapi/device_utils/__tools/__init__.py b/src/mistapi/device_utils/__tools/__init__.py new file mode 100644 index 0000000..1110f88 --- /dev/null +++ b/src/mistapi/device_utils/__tools/__init__.py @@ -0,0 +1,42 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Mist API Utilities +================== + +This package provides utility functions for interacting with Mist devices. + +Device-Specific Modules (Recommended) +-------------------------------------- +Import device-specific modules for a clean, organized API: + + from mistapi.utils import ap, ex, srx, ssr + + # Use device-specific utilities + ap.ping(session, site_id, device_id, host) + ex.cable_test(session, site_id, device_id, port_id) + ssr.show_service_path(session, site_id, device_id) + +Supported Devices: +- ap: Mist Access Points +- ex: Juniper EX Switches +- srx: Juniper SRX Firewalls +- ssr: Juniper Session Smart Routers + +Function-Based Modules (Legacy) +--------------------------------- +Original organization by function type (still available): + + from mistapi.utils import arp, bgp, dhcp, mac, port, routes, tools + +Available modules: arp, bgp, bpdu, dhcp, dns, dot1x, mac, policy, port, routes, + service_path, tools +""" diff --git a/src/mistapi/utils/__ws_wrapper.py b/src/mistapi/device_utils/__tools/__ws_wrapper.py similarity index 60% rename from src/mistapi/utils/__ws_wrapper.py rename to src/mistapi/device_utils/__tools/__ws_wrapper.py index 1e3c331..a1d8d8a 100644 --- a/src/mistapi/utils/__ws_wrapper.py +++ b/src/mistapi/device_utils/__tools/__ws_wrapper.py @@ -1,22 +1,28 @@ import json import threading -import time +from collections.abc import Callable from enum import Enum from mistapi import APISession from mistapi.__api_response import APIResponse as _APIResponse from mistapi.__logger import logger as LOGGER -from mistapi.websockets.session import SessionWithUrl -from mistapi.websockets.sites import DeviceCmdEvents, PcapEvents class TimerAction(Enum): + """ + TimerAction Enum for managing timer actions in WebSocketWrapper. + """ + START = "start" STOP = "stop" RESET = "reset" class Timer(Enum): + """ + Timer Enum for specifying different timer types in WebSocketWrapper. + """ + TIMEOUT = "timeout" FIRST_MESSAGE_TIMEOUT = "first_message_timeout" MAX_DURATION = "max_duration" @@ -33,7 +39,8 @@ def __init__( api_response: _APIResponse, ) -> None: self.trigger_api_response = api_response - self.ws_required: bool = False # This can be set to True if the WebSocket connection was successfully initiated + # This can be set to True if the WebSocket connection was successfully initiated + self.ws_required: bool = False self.ws_data: list[str] = [] self.ws_raw_events: list[str] = [] @@ -51,6 +58,7 @@ def __init__( util_response: UtilResponse, timeout: int = 10, max_duration: int = 60, + on_message: Callable[[dict], None] | None = None, ) -> None: self.apissession = apissession self.util_response = util_response @@ -74,6 +82,8 @@ def __init__( self.ws = None self.session_id: str | None = None self.capture_id: str | None = None + self._on_message_cb = on_message + self._closed = threading.Event() LOGGER.debug( "trigger response: %s", self.util_response.trigger_api_response.data @@ -94,9 +104,12 @@ def _on_open(self): LOGGER.info("WebSocket connection opened") # Start the max duration timer self._timeout_handler(Timer.MAX_DURATION, TimerAction.START) - # self._reset_timer() # Start the timer when the connection opens - #################################################################################################################### + def _on_close(self, code, msg): + LOGGER.info("WebSocket closed: %s - %s", code, msg) + self._closed.set() + + ########################################################################## ## Helper methods for managing timers def _timeout_handler(self, timer_type: Timer, action: TimerAction): duration = self.timers[timer_type.value]["duration"] @@ -124,7 +137,13 @@ def _timeout_handler(self, timer_type: Timer, action: TimerAction): "WebSocket is not available to start %s timer", timer_type.value ) - #################################################################################################################### + def _stop_all_timers(self): + for timer_info in self.timers.values(): + if timer_info["thread"]: + timer_info["thread"].cancel() + timer_info["thread"] = None + + ########################################################################## ## WebSocket event handlers def _handle_message(self, msg): @@ -139,14 +158,17 @@ def _handle_message(self, msg): raw = self._extract_raw(msg) if raw: self.data.append(raw) + if self._on_message_cb: + self._on_message_cb(raw) self._timeout_handler(Timer.TIMEOUT, TimerAction.RESET) - #################################################################################################################### + ########################################################################## ## Message processing and WebSocket connection management def _extract_session_id(self, message) -> bool: """ Extracts the session_id from the message and compares it to the expected session_id. - This method is designed to handle messages that may have the session_id nested at different levels. + This method is designed to handle messages that may have the session_id nested at + different levels. If the expected session_id is None, it will accept all messages. """ if not self.session_id and not self.capture_id: @@ -183,126 +205,49 @@ def _extract_session_id(self, message) -> bool: def _extract_raw(self, message): """ Extracts the raw message from the given message. - This method is designed to handle messages that may have the raw message nested at different levels. + This method is designed to handle messages that may have the raw message nested at + different levels. Handles both command events (with "raw" field) and pcap events (with "pcap_dict" field). """ self.raw_events.append(message) event = message if isinstance(event, str): try: - event = json.loads(message) - if isinstance(event, dict): - # Check for raw field (command events) - if "raw" in event: - LOGGER.debug("Extracted raw message: %s", event["raw"]) - return event["raw"] - # Check for pcap_dict field (pcap events) - if "pcap_dict" in event: - LOGGER.debug("Extracted pcap_dict: %s", event["pcap_dict"]) - return event["pcap_dict"] + event = json.loads(event) except json.JSONDecodeError: LOGGER.warning("Failed to decode message as JSON: %s", message) return None - if event.get("event") == "data" and event.get("data"): - return self._extract_raw(event["data"]) - if event.get("raw"): - self.received_messages += 1 - LOGGER.debug("Received raw message: %s", event.get("raw")) - return event["raw"] - if event.get("pcap_dict"): - self.received_messages += 1 - LOGGER.debug("Received pcap data: %s", event["pcap_dict"]) - return event["pcap_dict"] + if isinstance(event, dict): + if event.get("event") == "data" and event.get("data"): + return self._extract_raw(event["data"]) + if "raw" in event: + self.received_messages += 1 + LOGGER.debug("Extracted raw message: %s", event["raw"]) + return event["raw"] + if "pcap_dict" in event: + self.received_messages += 1 + LOGGER.debug("Extracted pcap data: %s", event["pcap_dict"]) + return event["pcap_dict"] return None - #################################################################################################################### + ########################################################################## ## WebSocket connection management - async def startCmdEvents(self, site_id: str, device_id: str) -> UtilResponse: + def start(self, ws) -> UtilResponse: """ - Start a WebSocket stream for site device command events. + Start the WS connection, block until closed, return UtilResponse. PARAMS ----------- - site_id : str - UUID of the site to stream events from. - device_id : str - UUID of the device to stream events from. - """ - self.ws = DeviceCmdEvents( - self.apissession, site_id=site_id, device_ids=[device_id] - ) - self.ws.on_message(self._handle_message) - self.ws.on_error(lambda error: LOGGER.error(f"Error: {error}")) - self.ws.on_close( - lambda code, msg: LOGGER.info(f"WebSocket closed: {code} - {msg}") - ) - self.ws.on_open(self._on_open) - self.ws.connect() # non-blocking - LOGGER.info( - "WebSocket connection initiated: site_id=%s, device_id=%s", - site_id, - device_id, - ) - time.sleep(1) - while self.ws and self.ws.ready(): - time.sleep(1) - LOGGER.info("WebSocket connection closed, exiting") - self.util_response.ws_required = True - self.util_response.ws_data = self.data - self.util_response.ws_raw_events = self.raw_events - return self.util_response - - async def startSessionUrl(self, url: str) -> UtilResponse: + ws : _MistWebsocket + An already-constructed WebSocket channel object. """ - Start a WebSocket stream using a custom URL. - This should be used when Mist is returning a WebSocket URL from an API call. - - PARAMS - ----------- - url : str - Full WebSocket URL to connect to (e.g., wss://api-ws.mist.com/ssh?jwt=eyJhbGciOiJI...). - """ - self.ws = SessionWithUrl(self.apissession, url=url) - self.ws.on_message(self._handle_message) - self.ws.on_error(lambda error: LOGGER.error(f"Error: {error}")) - self.ws.on_close( - lambda code, msg: LOGGER.info(f"WebSocket closed: {code} - {msg}") - ) - self.ws.on_open(self._on_open) - self.ws.connect() # non-blocking - LOGGER.info("WebSocket connection initiated: url=%s", url) - time.sleep(1) - while self.ws and self.ws.ready(): - time.sleep(1) - LOGGER.info("WebSocket connection closed, exiting") - self.util_response.ws_required = True - self.util_response.ws_data = self.data - self.util_response.ws_raw_events = self.raw_events - return self.util_response - - async def startRemotePcap(self, site_id: str) -> UtilResponse: - """ - Start a WebSocket stream for remote PCAP events. - This should be used when Mist is returning a WebSocket URL from an API call. - - PARAMS - ----------- - site_id : str - UUID of the site to stream PCAP events from. - """ - self.ws = PcapEvents(self.apissession, site_id=site_id) - self.ws.on_message(self._handle_message) - self.ws.on_error(lambda error: LOGGER.error(f"Error: {error}")) - self.ws.on_close( - lambda code, msg: LOGGER.info(f"WebSocket closed: {code} - {msg}") - ) - self.ws.on_open(self._on_open) - self.ws.connect() # non-blocking - LOGGER.info("WebSocket connection initiated: /sites/%s/pcaps", site_id) - time.sleep(1) - while self.ws and self.ws.ready(): - time.sleep(1) - LOGGER.info("WebSocket connection closed, exiting") + self.ws = ws + ws.on_message(self._handle_message) + ws.on_error(lambda error: LOGGER.error("Error: %s", error)) + ws.on_close(self._on_close) + ws.on_open(self._on_open) + ws.connect(run_in_background=False) # blocks until _on_close fires + self._stop_all_timers() self.util_response.ws_required = True self.util_response.ws_data = self.data self.util_response.ws_raw_events = self.raw_events diff --git a/src/mistapi/utils/arp.py b/src/mistapi/device_utils/__tools/arp.py similarity index 79% rename from src/mistapi/utils/arp.py rename to src/mistapi/device_utils/__tools/arp.py index a15efce..f9b3d6d 100644 --- a/src/mistapi/utils/arp.py +++ b/src/mistapi/device_utils/__tools/arp.py @@ -10,12 +10,14 @@ -------------------------------------------------------------------------------- """ +from collections.abc import Callable from enum import Enum from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents class Node(Enum): @@ -25,12 +27,13 @@ class Node(Enum): NODE1 = "node1" -async def retrieve_ap_arp_table( +def retrieve_ap_arp_table( apissession: _APISession, site_id: str, device_id: str, node: Node | None = None, timeout=1, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: AP @@ -49,6 +52,8 @@ async def retrieve_ap_arp_table( Node information for the ARP table retrieval command. timeout : int, optional Timeout for the ARP table retrieval command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -72,9 +77,10 @@ async def retrieve_ap_arp_table( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Show ARP command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" @@ -82,12 +88,13 @@ async def retrieve_ap_arp_table( return util_response -async def retrieve_ssr_arp_table( +def retrieve_ssr_arp_table( apissession: _APISession, site_id: str, device_id: str, node: Node | None = None, timeout=1, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: SSR @@ -106,6 +113,8 @@ async def retrieve_ssr_arp_table( Node information for the ARP table retrieval command. timeout : int, optional Timeout for the ARP table retrieval command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -129,9 +138,10 @@ async def retrieve_ssr_arp_table( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Show ARP command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" @@ -139,7 +149,7 @@ async def retrieve_ssr_arp_table( return util_response -async def retrieve_junos_arp_table( +def retrieve_junos_arp_table( apissession: _APISession, site_id: str, device_id: str, @@ -147,6 +157,7 @@ async def retrieve_junos_arp_table( port_id: str | None = None, vrf: str | None = None, timeout=1, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: EX, SRX @@ -170,6 +181,8 @@ async def retrieve_junos_arp_table( VRF to filter the ARP table. timeout : int, optional Timeout for the ARP table retrieval command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -194,9 +207,10 @@ async def retrieve_junos_arp_table( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Show ARP command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" diff --git a/src/mistapi/utils/bgp.py b/src/mistapi/device_utils/__tools/bgp.py similarity index 75% rename from src/mistapi/utils/bgp.py rename to src/mistapi/device_utils/__tools/bgp.py index 2db700f..f545c57 100644 --- a/src/mistapi/utils/bgp.py +++ b/src/mistapi/device_utils/__tools/bgp.py @@ -10,17 +10,21 @@ -------------------------------------------------------------------------------- """ +from collections.abc import Callable + from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents -async def show_summary( +def summary( apissession: _APISession, site_id: str, device_id: str, timeout=5, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: EX, SRX, SSR @@ -36,6 +40,8 @@ async def show_summary( UUID of the site where the device is located. device_id : str UUID of the device to show BGP summary on. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -53,9 +59,10 @@ async def show_summary( util_response = UtilResponse(trigger) if trigger.status_code == 200: LOGGER.info(f"BGP summary command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger BGP summary command: {trigger.status_code} - {trigger.data}" diff --git a/src/mistapi/utils/bpdu.py b/src/mistapi/device_utils/__tools/bpdu.py similarity index 96% rename from src/mistapi/utils/bpdu.py rename to src/mistapi/device_utils/__tools/bpdu.py index 26eccb3..0bdf96b 100644 --- a/src/mistapi/utils/bpdu.py +++ b/src/mistapi/device_utils/__tools/bpdu.py @@ -13,7 +13,7 @@ from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.utils.__ws_wrapper import UtilResponse +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse async def clear_error( diff --git a/src/mistapi/device_utils/__tools/dhcp.py b/src/mistapi/device_utils/__tools/dhcp.py new file mode 100644 index 0000000..e0ed5f0 --- /dev/null +++ b/src/mistapi/device_utils/__tools/dhcp.py @@ -0,0 +1,172 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + """Node Enum for specifying node information in DHCP commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +def release_dhcp_leases( + apissession: _APISession, + site_id: str, + device_id: str, + macs: list[str] | None = None, + network: str | None = None, + node: Node | None = None, + port_id: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: EX, SRX, SSR + + Releases DHCP leases on a device (EX/ SRX / SSR) and streams the results. + + valid combinations for EX are: + - network + macs + - network + port_id + - port_id + + valid combinations for SRX / SSR are: + - network + - network + macs + - network + port_id + - port_id + - port_id + macs + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to release DHCP leases on. + macs : list[str], optional + List of MAC addresses to release DHCP leases for. + network : str, optional + Network to release DHCP leases for. + node : Node, optional + Node information for the DHCP lease release command. + port_id : str, optional + Port ID to release DHCP leases for. + timeout : int, optional + Timeout for the release DHCP leases command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if macs: + body["macs"] = macs + if network: + body["network"] = network + if node: + body["node"] = node.value + if port_id: + body["port_id"] = port_id + trigger = devices.releaseSiteDeviceDhcpLease( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Release DHCP leases command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger release DHCP leases command: {trigger.status_code} - {trigger.data}" + ) # Give the release DHCP leases command a moment to take effect + return util_response + + +def retrieve_dhcp_leases( + apissession: _APISession, + site_id: str, + device_id: str, + network: str, + node: Node | None = None, + timeout=15, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Retrieves DHCP leases on a gateway (SRX / SSR) and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve DHCP leases from. + network : str + Network to release DHCP leases for. + node : Node, optional + Node information for the DHCP lease release command. + port_id : str, optional + Port ID to release DHCP leases for. + timeout : int, optional + Timeout for the release DHCP leases command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"network": network} + if node: + body["node"] = node.value + trigger = devices.showSiteDeviceDhcpLeases( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Retrieve DHCP leases command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger retrieve DHCP leases command: {trigger.status_code} - {trigger.data}" + ) # Give the release DHCP leases command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/__tools/dns.py b/src/mistapi/device_utils/__tools/dns.py new file mode 100644 index 0000000..4f5cca2 --- /dev/null +++ b/src/mistapi/device_utils/__tools/dns.py @@ -0,0 +1,84 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from enum import Enum + + +class Node(Enum): + """Node Enum for specifying node information in DNS commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +## NO DATA +# def test_resolution( +# apissession: _APISession, +# site_id: str, +# device_id: str, +# node: Node | None = None, +# hostname: str | None = None, +# timeout=5, +# on_message: Callable[[dict], None] | None = None, +# ) -> UtilResponse: +# """ +# DEVICES: SSR + +# Initiates a DNS resolution command on the gateway and streams the results. + +# PARAMS +# ----------- +# apissession : _APISession +# The API session to use for the request. +# site_id : str +# UUID of the site where the gateway is located. +# device_id : str +# UUID of the gateway to perform the DNS resolution command on. +# node : Node, optional +# Node information for the DNS resolution command. +# hostname : str, optional +# Hostname to resolve. +# timeout : int, optional +# Timeout for the command in seconds. +# on_message : Callable, optional +# Callback invoked with each extracted raw message as it arrives. + +# RETURNS +# ----------- +# UtilResponse +# A UtilResponse object containing the API response and a list of raw messages received +# from the WebSocket stream. +# """ +# body: dict[str, str | list | int] = {} +# if node: +# body["node"] = node.value +# if hostname: +# body["hostname"] = hostname +# trigger = devices.testSiteSsrDnsResolution( +# apissession, +# site_id=site_id, +# device_id=device_id, +# body=body, +# ) +# util_response = UtilResponse(trigger) +# if trigger.status_code == 200: +# LOGGER.info(trigger.data) +# print(f"SSR DNS resolution command triggered for device {device_id}") +# ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) +# util_response = WebSocketWrapper( +# apissession, util_response, timeout=timeout, on_message=on_message +# ).start(ws) +# else: +# LOGGER.error( +# f"Failed to trigger SSR DNS resolution command: {trigger.status_code} - {trigger.data}" +# ) # Give the SSR DNS resolution command a moment to take effect +# return util_response diff --git a/src/mistapi/utils/dot1x.py b/src/mistapi/device_utils/__tools/dot1x.py similarity index 96% rename from src/mistapi/utils/dot1x.py rename to src/mistapi/device_utils/__tools/dot1x.py index abece84..537e65d 100644 --- a/src/mistapi/utils/dot1x.py +++ b/src/mistapi/device_utils/__tools/dot1x.py @@ -13,7 +13,7 @@ from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.utils.__ws_wrapper import UtilResponse +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse async def clear_sessions( diff --git a/src/mistapi/utils/mac.py b/src/mistapi/device_utils/__tools/mac.py similarity index 87% rename from src/mistapi/utils/mac.py rename to src/mistapi/device_utils/__tools/mac.py index e4c25d5..d68441a 100644 --- a/src/mistapi/utils/mac.py +++ b/src/mistapi/device_utils/__tools/mac.py @@ -10,13 +10,16 @@ -------------------------------------------------------------------------------- """ +from collections.abc import Callable + from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents -async def clear_mac_table( +def clear_mac_table( apissession: _APISession, site_id: str, device_id: str, @@ -71,9 +74,9 @@ async def clear_mac_table( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Clear MAC Table command triggered for device {device_id}") - # util_response = await WebSocketWrapper( - # apissession, util_response, timeout=timeout - # ).startCmdEvents(site_id, device_id) + # util_response = WebSocketWrapper( + # apissession, util_response, timeout=timeout, on_message=on_message + # ).start(ws) else: LOGGER.error( f"Failed to trigger clear MAC Table command: {trigger.status_code} - {trigger.data}" @@ -81,7 +84,7 @@ async def clear_mac_table( return util_response -async def retrieve_mac_table( +def retrieve_mac_table( apissession: _APISession, site_id: str, device_id: str, @@ -89,6 +92,7 @@ async def retrieve_mac_table( port_id: str | None = None, vlan_id: str | None = None, timeout=5, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: EX @@ -111,6 +115,8 @@ async def retrieve_mac_table( VLAN ID to filter the ARP table retrieval. timeout : int, optional Timeout for the ARP table retrieval command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -138,9 +144,10 @@ async def retrieve_mac_table( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Show MAC Table command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger show MAC Table command: {trigger.status_code} - {trigger.data}" @@ -148,7 +155,7 @@ async def retrieve_mac_table( return util_response -async def clear_learned_mac( +def clear_learned_mac( apissession: _APISession, site_id: str, device_id: str, diff --git a/src/mistapi/device_utils/__tools/miscellaneous.py b/src/mistapi/device_utils/__tools/miscellaneous.py new file mode 100644 index 0000000..ccc1bc4 --- /dev/null +++ b/src/mistapi/device_utils/__tools/miscellaneous.py @@ -0,0 +1,364 @@ +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.session import SessionWithUrl +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + """Node Enum for specifying node information in commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +class TracerouteProtocol(Enum): + """Enum for specifying protocol in traceroute command.""" + + ICMP = "icmp" + UDP = "udp" + + +def ping( + apissession: _APISession, + site_id: str, + device_id: str, + host: str, + count: int | None = None, + node: Node | None = None, + size: int | None = None, + vrf: str | None = None, + timeout: int = 3, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: AP, EX, SRX, SSR + + Initiates a ping command from a device (AP / EX/ SRX / SSR) to a specified host and + streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to initiate the ping from. + host : str + The host to ping. + count : int, optional + Number of ping requests to send. + node : None, optional + Node information for the ping command. + size : int, optional + Size of the ping packet. + vrf : str, optional + VRF to use for the ping command. + timeout : int, optional + Timeout for the ping command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if count: + body["count"] = count + if host: + body["host"] = host + if node: + body["node"] = node.value + if size: + body["size"] = size + if vrf: + body["vrf"] = vrf + trigger = devices.pingFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Ping command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger ping command: {trigger.status_code} - {trigger.data}" + ) # Give the ping command a moment to take effect + return util_response + + +## NO DATA +# def service_ping( +# apissession: _APISession, +# site_id: str, +# device_id: str, +# host: str, +# service: str, +# tenant: str, +# count: int | None = None, +# node: None | None = None, +# size: int | None = None, +# timeout: int = 3, +# on_message: Callable[[dict], None] | None = None, +# ) -> UtilResponse: +# """ +# DEVICES: SSR + +# Initiates a service ping command from a SSR to a specified host and streams the results. + +# PARAMS +# ----------- +# apissession : _APISession +# The API session to use for the request. +# site_id : str +# UUID of the site where the device is located. +# device_id : str +# UUID of the device to initiate the ping from. +# host : str +# The host to ping. +# service : str +# The service to ping. +# tenant : str +# Tenant to use for the ping command. +# count : int, optional +# Number of ping requests to send. +# node : None, optional +# Node information for the ping command. +# size : int, optional +# Size of the ping packet. +# timeout : int, optional +# Timeout for the ping command in seconds. +# on_message : Callable, optional +# Callback invoked with each extracted raw message as it arrives. + +# RETURNS +# ----------- +# UtilResponse +# A UtilResponse object containing the API response and a list of raw messages received +# from the WebSocket stream. +# """ +# body: dict[str, str | list | int] = {} +# if count: +# body["count"] = count +# if host: +# body["host"] = host +# if node: +# body["node"] = node.value +# if size: +# body["size"] = size +# if tenant: +# body["tenant"] = tenant +# if service: +# body["service"] = service +# trigger = devices.servicePingFromSsr( +# apissession, +# site_id=site_id, +# device_id=device_id, +# body=body, +# ) +# util_response = UtilResponse(trigger) +# if trigger.status_code == 200: +# LOGGER.info(f"Service Ping command triggered for device {device_id}") +# ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) +# util_response = WebSocketWrapper( +# apissession, util_response, timeout, on_message=on_message +# ).start(ws) +# else: +# LOGGER.error( +# f"Failed to trigger Service Ping command: {trigger.status_code} - {trigger.data}" +# ) # Give the ping command a moment to take effect +# return util_response + + +def traceroute( + apissession: _APISession, + site_id: str, + device_id: str, + host: str, + protocol: TracerouteProtocol = TracerouteProtocol.ICMP, + port: int | None = None, + timeout: int = 10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: AP, EX, SRX, SSR + + Initiates a traceroute command from a device (AP / EX/ SRX / SSR) to a specified host and + streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to initiate the traceroute from. + host : str + The host to traceroute. + protocol : TracerouteProtocol, optional + Protocol to use for the traceroute command (icmp or udp). + port : int, optional + Port to use for UDP traceroute. + timeout : int, optional + Timeout for the traceroute command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"host": host} + if protocol: + body["protocol"] = protocol.value + if port: + body["port"] = port + trigger = devices.tracerouteFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Traceroute command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger traceroute command: {trigger.status_code} - {trigger.data}" + ) # Give the traceroute command a moment to take effect + return util_response + + +def monitor_traffic( + apissession: _APISession, + site_id: str, + device_id: str, + port_id: str | None = None, + timeout=30, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: EX, SRX + + Initiates a monitor traffic command on the device and streams the results. + + * if `port_id` is provided, JUNOS uses cmd "monitor interface" to monitor traffic on particular + * if `port_id` is not provided, JUNOS uses cmd "monitor interface traffic" to monitor traffic + on all ports + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to monitor traffic on. + port_id : str, optional + Port ID to filter the traffic. + timeout : int, optional + Timeout for the monitor traffic command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | int] = {"duration": 60} + if port_id: + body["port"] = port_id + trigger = devices.monitorSiteDeviceTraffic( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Monitor traffic command triggered for device {device_id}") + ws = SessionWithUrl(apissession, url=trigger.data.get("url", "")) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger monitor traffic command: {trigger.status_code} - {trigger.data}" + ) # Give the monitor traffic command a moment to take effect + return util_response + + +## NO DATA +# def srx_top_command( +# apissession: _APISession, +# site_id: str, +# device_id: str, +# timeout=10, +# on_message: Callable[[dict], None] | None = None, +# ) -> UtilResponse: +# """ +# DEVICE: SRX + +# For SRX Only. Initiates a top command on the device and streams the results. + +# PARAMS +# ----------- +# apissession : _APISession +# The API session to use for the request. +# site_id : str +# UUID of the site where the device is located. +# device_id : str +# UUID of the device to run the top command on. +# timeout : int, optional +# Timeout for the top command in seconds. +# on_message : Callable, optional +# Callback invoked with each extracted raw message as it arrives. + +# RETURNS +# ----------- +# UtilResponse +# A UtilResponse object containing the API response and a list of raw messages received +# from the WebSocket stream. +# """ +# trigger = devices.runSiteSrxTopCommand( +# apissession, +# site_id=site_id, +# device_id=device_id, +# ) +# util_response = UtilResponse(trigger) +# if trigger.status_code == 200: +# LOGGER.info(trigger.data) +# print(f"Top command triggered for device {device_id}") +# ws = SessionWithUrl(apissession, url=trigger.data.get("url", "")) +# util_response = WebSocketWrapper( +# apissession, util_response, timeout=timeout, on_message=on_message +# ).start(ws) +# else: +# LOGGER.error( +# f"Failed to trigger top command: {trigger.status_code} - {trigger.data}" +# ) # Give the top command a moment to take effect +# return util_response diff --git a/src/mistapi/device_utils/__tools/ospf.py b/src/mistapi/device_utils/__tools/ospf.py new file mode 100644 index 0000000..09eda9e --- /dev/null +++ b/src/mistapi/device_utils/__tools/ospf.py @@ -0,0 +1,291 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + """Node Enum for specifying node information in OSPF commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +def show_database( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + self_originate: bool | None = None, + vrf: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF database on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF database on. + node : Node, optional + Node information for the show OSPF database command. + self_originate : bool, optional + Filter for self-originated routes in the OSPF database. + vrf : str, optional + VRF to filter the OSPF database. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if self_originate is not None: + body["self_originate"] = self_originate + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteGatewayOspfDatabase( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF database command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger OSPF database command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF database command a moment to take effect + return util_response + + +def show_interfaces( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + port_id: str | None = None, + vrf: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF interfaces on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF interfaces on. + node : Node, optional + Node information for the show OSPF interfaces command. + port_id : str, optional + Port ID to filter the OSPF interfaces. + vrf : str, optional + VRF to filter the OSPF interfaces. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if port_id: + body["port_id"] = port_id + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteGatewayOspfInterfaces( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF interfaces command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger OSPF interfaces command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF interfaces command a moment to take effect + return util_response + + +def show_neighbors( + apissession: _APISession, + site_id: str, + device_id: str, + neighbor: str | None = None, + node: Node | None = None, + port_id: str | None = None, + vrf: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF neighbors on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF neighbors on. + neighbor : str, optional + Neighbor IP address to filter the OSPF neighbors. + node : Node, optional + Node information for the show OSPF neighbors command. + port_id : str, optional + Port ID to filter the OSPF neighbors. + vrf : str, optional + VRF to filter the OSPF neighbors. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if port_id: + body["port_id"] = port_id + if vrf: + body["vrf"] = vrf + if neighbor: + body["neighbor"] = neighbor + trigger = devices.showSiteGatewayOspfNeighbors( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF neighbors command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger OSPF neighbors command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF neighbors command a moment to take effect + return util_response + + +def show_summary( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + vrf: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF summary on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF summary on. + node : Node, optional + Node information for the show OSPF summary command. + vrf : str, optional + VRF to filter the OSPF summary. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteGatewayOspfSummary( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF summary command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger OSPF summary command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF summary command a moment to take effect + return util_response diff --git a/src/mistapi/utils/policy.py b/src/mistapi/device_utils/__tools/policy.py similarity index 96% rename from src/mistapi/utils/policy.py rename to src/mistapi/device_utils/__tools/policy.py index 77828d5..2d57303 100644 --- a/src/mistapi/utils/policy.py +++ b/src/mistapi/device_utils/__tools/policy.py @@ -13,7 +13,7 @@ from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.utils.__ws_wrapper import UtilResponse +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse async def clear_hit_count( diff --git a/src/mistapi/device_utils/__tools/port.py b/src/mistapi/device_utils/__tools/port.py new file mode 100644 index 0000000..d0c150e --- /dev/null +++ b/src/mistapi/device_utils/__tools/port.py @@ -0,0 +1,133 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +def bounce( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + timeout=60, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: EX, SRX, SSR + + Initiates a bounce command on the specified ports of a device (EX / SRX / SSR) and streams + the results. + + PARAMS + ----------- + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to perform the bounce command on. + port_ids : list[str] + List of port IDs to bounce. + timeout : int, default 5 + Timeout for the bounce command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if port_ids: + body["ports"] = port_ids + trigger = devices.bounceDevicePort( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info( + f"Bounce command triggered for ports {port_ids} on device {device_id}" + ) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger bounce command: {trigger.status_code} - {trigger.data}" + ) # Give the bounce command a moment to take effect + return util_response + + +def cable_test( + apissession: _APISession, + site_id: str, + device_id: str, + port_id: str, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: EX + + Initiates a cable test on a switch port and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to perform the cable test on. + port_id : str + Port ID to perform the cable test on. + timeout : int, optional + Timeout for the cable test command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"port": port_id} + trigger = devices.cableTestFromSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Cable test command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger cable test command: {trigger.status_code} - {trigger.data}" + ) # Give the cable test command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/__tools/remote_capture.py b/src/mistapi/device_utils/__tools/remote_capture.py new file mode 100644 index 0000000..90a438d --- /dev/null +++ b/src/mistapi/device_utils/__tools/remote_capture.py @@ -0,0 +1,444 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import pcaps +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import PcapEvents + + +def _build_pcap_body( + device_id: str, + port_ids: list[str], + device_key: str, + device_type: str, + tcpdump_expression: str | None, + duration: int, + max_pkt_len: int, + num_packets: int, + raw: bool | None = None, +) -> dict: + """Build the request body for remote pcap commands (SRX, SSR, EX).""" + mac = device_id.split("-")[-1] + body: dict = { + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + device_key: {mac: {"ports": {}}}, + "type": device_type, + "format": "stream", + } + if raw is not None: + body["raw"] = raw + for port_id in port_ids: + port_entry: dict = {} + if tcpdump_expression is not None: + port_entry["tcpdump_expression"] = tcpdump_expression + body[device_key][mac]["ports"][port_id] = port_entry + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + return body + + +def ap_remote_pcap_wireless( + apissession: _APISession, + site_id: str, + device_id: str, + band: str, + tcpdump_expression: str | None = None, + ssid: str | None = None, + ap_mac: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: AP + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + band : str + Comma-separated list of radio bands (24, 5, or 6). + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "type mgt or type ctl -vvv -tttt -en" + ssid : str, optional + SSID to filter the wireless traffic. + ap_mac : str, optional + AP MAC address to filter the wireless traffic. + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | int] = { + "band": band, + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + "type": "radiotap", + "format": "stream", + } + if ssid: + body["ssid"] = ssid + if ap_mac: + body["ap_mac"] = ap_mac + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +def ap_remote_pcap_wired( + apissession: _APISession, + site_id: str, + device_id: str, + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: AP + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | int] = { + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + "type": "wired", + "format": "stream", + } + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +def srx_remote_pcap( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: SRX + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + port_ids : list[str] + List of port IDs to monitor. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body = _build_pcap_body( + device_id, + port_ids, + "gateways", + "gateway", + tcpdump_expression, + duration, + max_pkt_len, + num_packets, + ) + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +def ssr_remote_pcap( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: SSR + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + port_ids : list[str] + List of port IDs to monitor. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body = _build_pcap_body( + device_id, + port_ids, + "gateways", + "gateway", + tcpdump_expression, + duration, + max_pkt_len, + num_packets, + raw=False, + ) + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +def ex_remote_pcap( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: EX + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + port_ids : list[str] + List of port IDs to monitor. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body = _build_pcap_body( + device_id, + port_ids, + "switches", + "switch", + tcpdump_expression, + duration, + max_pkt_len, + num_packets, + ) + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response diff --git a/src/mistapi/utils/routes.py b/src/mistapi/device_utils/__tools/routes.py similarity index 83% rename from src/mistapi/utils/routes.py rename to src/mistapi/device_utils/__tools/routes.py index ff0f511..6022f02 100644 --- a/src/mistapi/utils/routes.py +++ b/src/mistapi/device_utils/__tools/routes.py @@ -10,12 +10,14 @@ -------------------------------------------------------------------------------- """ +from collections.abc import Callable from enum import Enum from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents class Node(Enum): @@ -32,7 +34,7 @@ class RouteProtocol(Enum): STATIC = "static" -async def show( +def show( apissession: _APISession, site_id: str, device_id: str, @@ -42,6 +44,7 @@ async def show( route_type: str | None = None, vrf: str | None = None, timeout=2, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICE: SSR, SRX @@ -68,6 +71,8 @@ async def show( VRF to filter the routes. timeout : int, optional Timeout for the command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -97,9 +102,10 @@ async def show( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Device Routes command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger Device Routes command: {trigger.status_code} - {trigger.data}" diff --git a/src/mistapi/utils/dns.py b/src/mistapi/device_utils/__tools/service_path.py similarity index 53% rename from src/mistapi/utils/dns.py rename to src/mistapi/device_utils/__tools/service_path.py index 75c18f6..5f53fc0 100644 --- a/src/mistapi/utils/dns.py +++ b/src/mistapi/device_utils/__tools/service_path.py @@ -10,33 +10,36 @@ -------------------------------------------------------------------------------- """ +from collections.abc import Callable from enum import Enum from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents class Node(Enum): - """Node Enum for specifying node information in DNS commands.""" + """Node Enum for specifying node information in service path commands.""" NODE0 = "node0" NODE1 = "node1" -async def test_resolution( +def show_service_path( apissession: _APISession, site_id: str, device_id: str, node: Node | None = None, - hostname: str | None = None, - timeout=5, + service_name: str | None = None, + timeout: int = 5, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: SSR - Initiates a DNS resolution command on the gateway and streams the results. + Initiates a show service path command on the gateway and streams the results. PARAMS ----------- @@ -45,13 +48,15 @@ async def test_resolution( site_id : str UUID of the site where the gateway is located. device_id : str - UUID of the gateway to perform the DNS resolution command on. + UUID of the gateway to perform the show service path command on. node : Node, optional - Node information for the DNS resolution command. - hostname : str, optional - Hostname to resolve. + Node information for the show service path command. + service_name : str, optional + Name of the service to show the path for. timeout : int, optional Timeout for the command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -62,23 +67,24 @@ async def test_resolution( body: dict[str, str | list | int] = {} if node: body["node"] = node.value - if hostname: - body["hostname"] = hostname - trigger = devices.testSiteSsrDnsResolution( + if service_name: + body["service_name"] = service_name + trigger = devices.showSiteSsrServicePath( apissession, site_id=site_id, device_id=device_id, - # body=body, + body=body, ) util_response = UtilResponse(trigger) if trigger.status_code == 200: LOGGER.info(trigger.data) - print(f"SSR DNS resolution command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + print(f"SSR service path command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( - f"Failed to trigger SSR DNS resolution command: {trigger.status_code} - {trigger.data}" - ) # Give the SSR DNS resolution command a moment to take effect + f"Failed to trigger SSR service path command: {trigger.status_code} - {trigger.data}" + ) # Give the SSR service path command a moment to take effect return util_response diff --git a/src/mistapi/utils/sessions.py b/src/mistapi/device_utils/__tools/sessions.py similarity index 82% rename from src/mistapi/utils/sessions.py rename to src/mistapi/device_utils/__tools/sessions.py index dde41f2..c019f10 100644 --- a/src/mistapi/utils/sessions.py +++ b/src/mistapi/device_utils/__tools/sessions.py @@ -10,12 +10,14 @@ -------------------------------------------------------------------------------- """ +from collections.abc import Callable from enum import Enum from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents class Node(Enum): @@ -25,7 +27,7 @@ class Node(Enum): NODE1 = "node1" -async def clear( +def clear( apissession: _APISession, site_id: str, device_id: str, @@ -34,6 +36,7 @@ async def clear( service_ids: list[str] | None = None, vrf: str | None = None, timeout=2, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICE: SSR, SRX @@ -60,6 +63,8 @@ async def clear( VRF to filter the routes. timeout : int, optional Timeout for the command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -87,9 +92,10 @@ async def clear( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Device Sessions command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger Device Sessions command: {trigger.status_code} - {trigger.data}" @@ -97,7 +103,7 @@ async def clear( return util_response -async def show( +def show( apissession: _APISession, site_id: str, device_id: str, @@ -105,6 +111,7 @@ async def show( service_name: str | None = None, service_ids: list[str] | None = None, timeout=2, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICE: SSR, SRX @@ -127,6 +134,8 @@ async def show( List of service IDs to filter the sessions. timeout : int, optional Timeout for the command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -152,9 +161,10 @@ async def show( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Device Sessions command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger Device Sessions command: {trigger.status_code} - {trigger.data}" diff --git a/src/mistapi/utils/ap.py b/src/mistapi/device_utils/ap.py similarity index 76% rename from src/mistapi/utils/ap.py rename to src/mistapi/device_utils/ap.py index b5320a8..73e34df 100644 --- a/src/mistapi/utils/ap.py +++ b/src/mistapi/device_utils/ap.py @@ -16,14 +16,16 @@ """ # Re-export shared classes and types -from mistapi.utils.arp import Node -from mistapi.utils.arp import retrieve_ap_arp_table as retrieve_arp_table -from mistapi.utils.tools import TracerouteProtocol, ping, traceroute +from mistapi.device_utils.__tools.arp import retrieve_ap_arp_table as retrieveArpTable +from mistapi.device_utils.__tools.miscellaneous import ( + TracerouteProtocol, + ping, + traceroute, +) __all__ = [ - "Node", "ping", "traceroute", "TracerouteProtocol", - "retrieve_arp_table", + "retrieveArpTable", ] diff --git a/src/mistapi/device_utils/bgp.py b/src/mistapi/device_utils/bgp.py new file mode 100644 index 0000000..f545c57 --- /dev/null +++ b/src/mistapi/device_utils/bgp.py @@ -0,0 +1,70 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +def summary( + apissession: _APISession, + site_id: str, + device_id: str, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: EX, SRX, SSR + + Shows BGP summary on a device (EX/ SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show BGP summary on. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"protocol": "bgp"} + trigger = devices.showSiteDeviceBgpSummary( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"BGP summary command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger BGP summary command: {trigger.status_code} - {trigger.data}" + ) # Give the BGP summary command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/bpdu.py b/src/mistapi/device_utils/bpdu.py new file mode 100644 index 0000000..c565903 --- /dev/null +++ b/src/mistapi/device_utils/bpdu.py @@ -0,0 +1,61 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse + + +async def clearError( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + DEVICES: EX + + Clears BPDU error state on the specified ports of a switch. + + PARAMS + ----------- + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to clear BPDU errors on. + port_ids : list[str] + List of port IDs to clear BPDU errors on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearBpduErrorsFromPortsOnSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear BPDU error command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear BPDU error command: {trigger.status_code} - {trigger.data}" + ) # Give the clear BPDU error command a moment to take effect + return util_response diff --git a/src/mistapi/utils/dhcp.py b/src/mistapi/device_utils/dhcp.py similarity index 82% rename from src/mistapi/utils/dhcp.py rename to src/mistapi/device_utils/dhcp.py index 0738705..c967c34 100644 --- a/src/mistapi/utils/dhcp.py +++ b/src/mistapi/device_utils/dhcp.py @@ -10,12 +10,14 @@ -------------------------------------------------------------------------------- """ +from collections.abc import Callable from enum import Enum from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents class Node(Enum): @@ -25,7 +27,7 @@ class Node(Enum): NODE1 = "node1" -async def release_dhcp_leases( +def releaseDhcpLeases( apissession: _APISession, site_id: str, device_id: str, @@ -34,6 +36,7 @@ async def release_dhcp_leases( node: Node | None = None, port_id: str | None = None, timeout=5, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: EX, SRX, SSR @@ -70,6 +73,8 @@ async def release_dhcp_leases( Port ID to release DHCP leases for. timeout : int, optional Timeout for the release DHCP leases command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -95,9 +100,10 @@ async def release_dhcp_leases( util_response = UtilResponse(trigger) if trigger.status_code == 200: LOGGER.info(f"Release DHCP leases command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger release DHCP leases command: {trigger.status_code} - {trigger.data}" @@ -105,13 +111,14 @@ async def release_dhcp_leases( return util_response -async def retrieve_dhcp_leases( +def retrieveDhcpLeases( apissession: _APISession, site_id: str, device_id: str, network: str, node: Node | None = None, timeout=15, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: SRX, SSR @@ -134,6 +141,8 @@ async def retrieve_dhcp_leases( Port ID to release DHCP leases for. timeout : int, optional Timeout for the release DHCP leases command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -152,9 +161,10 @@ async def retrieve_dhcp_leases( util_response = UtilResponse(trigger) if trigger.status_code == 200: LOGGER.info(f"Retrieve DHCP leases command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger retrieve DHCP leases command: {trigger.status_code} - {trigger.data}" diff --git a/src/mistapi/device_utils/dot1x.py b/src/mistapi/device_utils/dot1x.py new file mode 100644 index 0000000..af5c322 --- /dev/null +++ b/src/mistapi/device_utils/dot1x.py @@ -0,0 +1,60 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse + + +async def clearSessions( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + DEVICES: EX + + Clears dot1x sessions on the specified ports of a switch (EX). + + PARAMS + ----------- + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to clear dot1x sessions on. + port_ids : list[str] + List of port IDs to clear dot1x sessions on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearAllLearnedMacsFromPortOnSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear learned MACs command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" + ) # Give the clear learned MACs command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/ex.py b/src/mistapi/device_utils/ex.py new file mode 100644 index 0000000..dd3680f --- /dev/null +++ b/src/mistapi/device_utils/ex.py @@ -0,0 +1,78 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Utility functions for Juniper EX Switches. + +This module provides a device-specific namespace for EX switch utilities. +All functions are imported from their respective functional modules. +""" + +# Re-export shared classes and types + +# ARP functions +from mistapi.device_utils.__tools.arp import ( + retrieve_junos_arp_table as retrieveArpTable, +) + +# BGP functions +from mistapi.device_utils.__tools.bgp import summary as retrieveBgpSummary + +# BPDU functions +from mistapi.device_utils.__tools.bpdu import clear_error as clearBpduError + +# DHCP functions +from mistapi.device_utils.__tools.dhcp import release_dhcp_leases as releaseDhcpLeases +from mistapi.device_utils.__tools.dhcp import retrieve_dhcp_leases as retrieveDhcpLeases + +# Dot1x functions +from mistapi.device_utils.__tools.dot1x import clear_sessions as clearDot1xSessions + +# MAC table functions +from mistapi.device_utils.__tools.mac import clear_learned_mac as clearLearnedMac +from mistapi.device_utils.__tools.mac import clear_mac_table as clearMacTable +from mistapi.device_utils.__tools.mac import retrieve_mac_table as retrieveMacTable + +# Tools (ping, monitor traffic) +from mistapi.device_utils.__tools.miscellaneous import monitor_traffic as monitorTraffic +from mistapi.device_utils.__tools.miscellaneous import ping + +# Policy functions +from mistapi.device_utils.__tools.policy import clear_hit_count as clearHitCount + +# Port functions +from mistapi.device_utils.__tools.port import bounce as bouncePort +from mistapi.device_utils.__tools.port import cable_test as cableTest + +__all__ = [ + # ARP + "retrieveArpTable", + # BGP + "retrieveBgpSummary", + # BPDU + "clearBpduError", + # DHCP + "retrieveDhcpLeases", + "releaseDhcpLeases", + # Dot1x + "clearDot1xSessions", + # MAC + "clearLearnedMac", + "clearMacTable", + "retrieveMacTable", + # Policy + "clearHitCount", + # Port + "bouncePort", + "cableTest", + # Tools + "monitorTraffic", + "ping", +] diff --git a/src/mistapi/utils/ospf.py b/src/mistapi/device_utils/ospf.py similarity index 80% rename from src/mistapi/utils/ospf.py rename to src/mistapi/device_utils/ospf.py index 36ed711..4903a52 100644 --- a/src/mistapi/utils/ospf.py +++ b/src/mistapi/device_utils/ospf.py @@ -10,12 +10,14 @@ -------------------------------------------------------------------------------- """ +from collections.abc import Callable from enum import Enum from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents class Node(Enum): @@ -25,7 +27,7 @@ class Node(Enum): NODE1 = "node1" -async def show_database( +def showDatabase( apissession: _APISession, site_id: str, device_id: str, @@ -33,6 +35,7 @@ async def show_database( self_originate: bool | None = None, vrf: str | None = None, timeout=5, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: SRX, SSR @@ -54,6 +57,8 @@ async def show_database( Filter for self-originated routes in the OSPF database. vrf : str, optional VRF to filter the OSPF database. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -77,9 +82,10 @@ async def show_database( util_response = UtilResponse(trigger) if trigger.status_code == 200: LOGGER.info(f"OSPF database command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger OSPF database command: {trigger.status_code} - {trigger.data}" @@ -87,7 +93,7 @@ async def show_database( return util_response -async def show_interfaces( +def showInterfaces( apissession: _APISession, site_id: str, device_id: str, @@ -95,6 +101,7 @@ async def show_interfaces( port_id: str | None = None, vrf: str | None = None, timeout=5, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: SRX, SSR @@ -116,6 +123,8 @@ async def show_interfaces( Port ID to filter the OSPF interfaces. vrf : str, optional VRF to filter the OSPF interfaces. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -139,9 +148,10 @@ async def show_interfaces( util_response = UtilResponse(trigger) if trigger.status_code == 200: LOGGER.info(f"OSPF interfaces command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger OSPF interfaces command: {trigger.status_code} - {trigger.data}" @@ -149,7 +159,7 @@ async def show_interfaces( return util_response -async def show_neighbors( +def showNeighbors( apissession: _APISession, site_id: str, device_id: str, @@ -158,6 +168,7 @@ async def show_neighbors( port_id: str | None = None, vrf: str | None = None, timeout=5, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: SRX, SSR @@ -181,6 +192,8 @@ async def show_neighbors( Port ID to filter the OSPF neighbors. vrf : str, optional VRF to filter the OSPF neighbors. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -206,9 +219,10 @@ async def show_neighbors( util_response = UtilResponse(trigger) if trigger.status_code == 200: LOGGER.info(f"OSPF neighbors command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger OSPF neighbors command: {trigger.status_code} - {trigger.data}" @@ -216,13 +230,14 @@ async def show_neighbors( return util_response -async def show_summary( +def showSummary( apissession: _APISession, site_id: str, device_id: str, node: Node | None = None, vrf: str | None = None, timeout=5, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: SRX, SSR @@ -242,6 +257,8 @@ async def show_summary( Node information for the show OSPF summary command. vrf : str, optional VRF to filter the OSPF summary. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -263,9 +280,10 @@ async def show_summary( util_response = UtilResponse(trigger) if trigger.status_code == 200: LOGGER.info(f"OSPF summary command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger OSPF summary command: {trigger.status_code} - {trigger.data}" diff --git a/src/mistapi/device_utils/policy.py b/src/mistapi/device_utils/policy.py new file mode 100644 index 0000000..ba8d606 --- /dev/null +++ b/src/mistapi/device_utils/policy.py @@ -0,0 +1,62 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse + + +async def clearHitCount( + apissession: _APISession, + site_id: str, + device_id: str, + policy_name: str, +) -> UtilResponse: + """ + DEVICE: EX + + Clears the policy hit count on a device. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to clear the policy hit count on. + policy_name : str + Name of the policy to clear the hit count for. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + trigger = devices.clearSiteDevicePolicyHitCount( + apissession, + site_id=site_id, + device_id=device_id, + body={"policy_name": policy_name}, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Clear policy hit count command triggered for device {device_id}") + # util_response = await WebSocketWrapper( + # apissession, util_response, timeout=timeout + # ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger clear policy hit count command: {trigger.status_code} - {trigger.data}" + ) # Give the clear policy hit count command a moment to take effect + return util_response diff --git a/src/mistapi/utils/port.py b/src/mistapi/device_utils/port.py similarity index 76% rename from src/mistapi/utils/port.py rename to src/mistapi/device_utils/port.py index b13040a..5757c0f 100644 --- a/src/mistapi/utils/port.py +++ b/src/mistapi/device_utils/port.py @@ -10,18 +10,22 @@ -------------------------------------------------------------------------------- """ +from collections.abc import Callable + from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents -async def bounce( +def bounce( apissession: _APISession, site_id: str, device_id: str, port_ids: list[str], timeout=60, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICE: EX, SRX, SSR @@ -37,8 +41,10 @@ async def bounce( UUID of the device to perform the bounce command on. port_ids : list[str] List of port IDs to bounce. - timeout : int, async default 5 + timeout : int, default 5 Timeout for the bounce command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -60,9 +66,10 @@ async def bounce( LOGGER.info( f"Bounce command triggered for ports {port_ids} on device {device_id}" ) - util_response = await WebSocketWrapper( - apissession, util_response, timeout - ).startCmdEvents(site_id=site_id, device_id=device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger bounce command: {trigger.status_code} - {trigger.data}" @@ -70,12 +77,13 @@ async def bounce( return util_response -async def cable_test( +def cableTest( apissession: _APISession, site_id: str, device_id: str, port_id: str, timeout=10, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: EX @@ -94,6 +102,8 @@ async def cable_test( Port ID to perform the cable test on. timeout : int, optional Timeout for the cable test command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -112,9 +122,10 @@ async def cable_test( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Cable test command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger cable test command: {trigger.status_code} - {trigger.data}" diff --git a/src/mistapi/utils/service_path.py b/src/mistapi/device_utils/service_path.py similarity index 80% rename from src/mistapi/utils/service_path.py rename to src/mistapi/device_utils/service_path.py index 1302b9c..2973c23 100644 --- a/src/mistapi/utils/service_path.py +++ b/src/mistapi/device_utils/service_path.py @@ -10,12 +10,14 @@ -------------------------------------------------------------------------------- """ +from collections.abc import Callable from enum import Enum from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices -from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents class Node(Enum): @@ -25,13 +27,14 @@ class Node(Enum): NODE1 = "node1" -async def show_service_path( +def showServicePath( apissession: _APISession, site_id: str, device_id: str, node: Node | None = None, service_name: str | None = None, timeout: int = 5, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: SSR @@ -52,6 +55,8 @@ async def show_service_path( Name of the service to show the path for. timeout : int, optional Timeout for the command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -74,9 +79,10 @@ async def show_service_path( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"SSR service path command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger SSR service path command: {trigger.status_code} - {trigger.data}" diff --git a/src/mistapi/device_utils/sessions.py b/src/mistapi/device_utils/sessions.py new file mode 100644 index 0000000..c019f10 --- /dev/null +++ b/src/mistapi/device_utils/sessions.py @@ -0,0 +1,172 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + """Node Enum for specifying node information in session commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +def clear( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + service_name: str | None = None, + service_ids: list[str] | None = None, + vrf: str | None = None, + timeout=2, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: SSR, SRX + + Initiates a clear sessions command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show routes command on. + node : Node, optional + Node information for the show routes command. + prefix : str, optional + Prefix to filter the routes. + protocol : RouteProtocol, optional + Protocol to filter the routes. + route_type : str, optional + Type of the route to filter. + vrf : str, optional + VRF to filter the routes. + timeout : int, optional + Timeout for the command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if service_name: + body["service_name"] = service_name + if service_ids: + body["service_ids"] = service_ids + if vrf: + body["vrf"] = vrf + trigger = devices.clearSiteDeviceSession( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Device Sessions command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger Device Sessions command: {trigger.status_code} - {trigger.data}" + ) # Give the Device Sessions command a moment to take effect + return util_response + + +def show( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + service_name: str | None = None, + service_ids: list[str] | None = None, + timeout=2, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: SSR, SRX + + Initiates a show sessions command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show sessions command on. + node : Node, optional + Node information for the show sessions command. + service_name : str, optional + Name of the service to filter the sessions. + service_ids : list[str], optional + List of service IDs to filter the sessions. + timeout : int, optional + Timeout for the command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if service_name: + body["service_name"] = service_name + if service_ids: + body["service_ids"] = service_ids + trigger = devices.showSiteSsrAndSrxSessions( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Device Sessions command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger Device Sessions command: {trigger.status_code} - {trigger.data}" + ) # Give the Device Sessions command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/srx.py b/src/mistapi/device_utils/srx.py new file mode 100644 index 0000000..a93f124 --- /dev/null +++ b/src/mistapi/device_utils/srx.py @@ -0,0 +1,69 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Utility functions for Juniper SRX Firewalls. + +This module provides a device-specific namespace for SRX firewall utilities. +All functions are imported from their respective functional modules. +""" + +# Re-export shared classes and types +from mistapi.device_utils.__tools.arp import Node + +# ARP functions +from mistapi.device_utils.__tools.arp import ( + retrieve_junos_arp_table as retrieveArpTable, +) + +# BGP functions +from mistapi.device_utils.__tools.bgp import summary as retrieveBgpSummary + +# DHCP functions +from mistapi.device_utils.__tools.dhcp import release_dhcp_leases as releaseDhcpLeases +from mistapi.device_utils.__tools.dhcp import retrieve_dhcp_leases as retrieveDhcpLeases + +# Tools (ping, monitor traffic) +from mistapi.device_utils.__tools.miscellaneous import monitor_traffic as monitorTraffic +from mistapi.device_utils.__tools.miscellaneous import ping + +# OSPF functions +from mistapi.device_utils.__tools.ospf import show_database as showDatabase +from mistapi.device_utils.__tools.ospf import show_interfaces as showInterfaces +from mistapi.device_utils.__tools.ospf import show_neighbors as showNeighbors + +# Port functions +from mistapi.device_utils.__tools.port import bounce as bouncePort + +# Route functions +from mistapi.device_utils.__tools.routes import show as retrieveRoutes + +__all__ = [ + # Classes/Enums + "Node", + # ARP + "retrieveArpTable", + # BGP + "retrieveBgpSummary", + # DHCP + "releaseDhcpLeases", + "retrieveDhcpLeases", + # OSPF + "showDatabase", + "showNeighbors", + "showInterfaces", + # Port + "bouncePort", + # Routes + "retrieveRoutes", + # Tools + "monitorTraffic", + "ping", +] diff --git a/src/mistapi/device_utils/ssr.py b/src/mistapi/device_utils/ssr.py new file mode 100644 index 0000000..d68abd5 --- /dev/null +++ b/src/mistapi/device_utils/ssr.py @@ -0,0 +1,76 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Utility functions for Juniper Session Smart Routers (SSR). + +This module provides a device-specific namespace for SSR router utilities. +All functions are imported from their respective functional modules. +""" + +# Re-export shared classes and types +from mistapi.device_utils.__tools.arp import Node + +# ARP functions +from mistapi.device_utils.__tools.arp import ( + retrieve_ssr_arp_table as retrieveArpTable, +) + +# BGP functions +from mistapi.device_utils.__tools.bgp import summary as retrieveBgpSummary + +# DHCP functions +from mistapi.device_utils.__tools.dhcp import release_dhcp_leases as releaseDhcpLeases +from mistapi.device_utils.__tools.dhcp import retrieve_dhcp_leases as retrieveDhcpLeases + +# Tools (ping only - no monitor_traffic for SSR) +from mistapi.device_utils.__tools.miscellaneous import ping + +# DNS functions +# from mistapi.utils.dns import test_resolution as test_dns_resolution +# OSPF functions +from mistapi.device_utils.__tools.ospf import show_database as showDatabase +from mistapi.device_utils.__tools.ospf import show_interfaces as showInterfaces +from mistapi.device_utils.__tools.ospf import show_neighbors as showNeighbors + +# Port functions +from mistapi.device_utils.__tools.port import bounce as bouncePort + +# Route functions +from mistapi.device_utils.__tools.routes import show as retrieveRoutes + +# Service Path functions +from mistapi.device_utils.__tools.service_path import show_service_path as showServicePath + +__all__ = [ + # Classes/Enums + "Node", + # ARP + "retrieveArpTable", + # BGP + "retrieveBgpSummary", + # DHCP + "releaseDhcpLeases", + "retrieveDhcpLeases", + # DNS + # "test_dns_resolution", + # OSPF + "showDatabase", + "showNeighbors", + "showInterfaces", + # Port + "bouncePort", + # Routes + "retrieveRoutes", + # Service Path + "showServicePath", + # Tools + "ping", +] diff --git a/src/mistapi/utils/tools.py b/src/mistapi/device_utils/tools.py similarity index 78% rename from src/mistapi/utils/tools.py rename to src/mistapi/device_utils/tools.py index aca66c1..8a95822 100644 --- a/src/mistapi/utils/tools.py +++ b/src/mistapi/device_utils/tools.py @@ -1,9 +1,12 @@ +from collections.abc import Callable from enum import Enum from mistapi import APISession as _APISession from mistapi.__logger import logger as LOGGER from mistapi.api.v1.sites import devices, pcaps -from mistapi.utils.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.session import SessionWithUrl +from mistapi.websockets.sites import DeviceCmdEvents, PcapEvents class Node(Enum): @@ -20,16 +23,50 @@ class TracerouteProtocol(Enum): UDP = "udp" -async def ping( +def _build_pcap_body( + device_id: str, + port_ids: list[str], + device_key: str, + device_type: str, + tcpdump_expression: str | None, + duration: int, + max_pkt_len: int, + num_packets: int, + raw: bool | None = None, +) -> dict: + """Build the request body for remote pcap commands (SRX, SSR, EX).""" + mac = device_id.split("-")[-1] + body: dict = { + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + device_key: {mac: {"ports": {}}}, + "type": device_type, + "format": "stream", + } + if raw is not None: + body["raw"] = raw + for port_id in port_ids: + port_entry: dict = {} + if tcpdump_expression is not None: + port_entry["tcpdump_expression"] = tcpdump_expression + body[device_key][mac]["ports"][port_id] = port_entry + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + return body + + +def ping( apissession: _APISession, site_id: str, device_id: str, host: str, count: int | None = None, - node: None | None = None, + node: Node | None = None, size: int | None = None, vrf: str | None = None, timeout: int = 3, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: AP, EX, SRX, SSR @@ -57,6 +94,8 @@ async def ping( VRF to use for the ping command. timeout : int, optional Timeout for the ping command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -84,9 +123,10 @@ async def ping( util_response = UtilResponse(trigger) if trigger.status_code == 200: LOGGER.info(f"Ping command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger ping command: {trigger.status_code} - {trigger.data}" @@ -94,8 +134,8 @@ async def ping( return util_response -## NO DATA -# async def service_ping( +## NO DATA +# def service_ping( # apissession: _APISession, # site_id: str, # device_id: str, @@ -106,6 +146,7 @@ async def ping( # node: None | None = None, # size: int | None = None, # timeout: int = 3, +# on_message: Callable[[dict], None] | None = None, # ) -> UtilResponse: # """ # DEVICES: SSR @@ -134,6 +175,8 @@ async def ping( # Size of the ping packet. # timeout : int, optional # Timeout for the ping command in seconds. +# on_message : Callable, optional +# Callback invoked with each extracted raw message as it arrives. # RETURNS # ----------- @@ -163,9 +206,10 @@ async def ping( # util_response = UtilResponse(trigger) # if trigger.status_code == 200: # LOGGER.info(f"Service Ping command triggered for device {device_id}") -# util_response = await WebSocketWrapper( -# apissession, util_response, timeout -# ).startCmdEvents(site_id, device_id) +# ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) +# util_response = WebSocketWrapper( +# apissession, util_response, timeout, on_message=on_message +# ).start(ws) # else: # LOGGER.error( # f"Failed to trigger Service Ping command: {trigger.status_code} - {trigger.data}" @@ -173,7 +217,7 @@ async def ping( # return util_response -async def traceroute( +def traceroute( apissession: _APISession, site_id: str, device_id: str, @@ -181,6 +225,7 @@ async def traceroute( protocol: TracerouteProtocol = TracerouteProtocol.ICMP, port: int | None = None, timeout: int = 10, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICES: AP, EX, SRX, SSR @@ -204,6 +249,8 @@ async def traceroute( Port to use for UDP traceroute. timeout : int, optional Timeout for the traceroute command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -225,9 +272,10 @@ async def traceroute( util_response = UtilResponse(trigger) if trigger.status_code == 200: LOGGER.info(f"Traceroute command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout - ).startCmdEvents(site_id, device_id) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger traceroute command: {trigger.status_code} - {trigger.data}" @@ -235,12 +283,13 @@ async def traceroute( return util_response -async def monitor_traffic( +def monitorTraffic( apissession: _APISession, site_id: str, device_id: str, port_id: str | None = None, timeout=30, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICE: EX, SRX @@ -263,6 +312,8 @@ async def monitor_traffic( Port ID to filter the traffic. timeout : int, optional Timeout for the monitor traffic command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -283,9 +334,10 @@ async def monitor_traffic( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Monitor traffic command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startSessionUrl(trigger.data.get("url", "")) + ws = SessionWithUrl(apissession, url=trigger.data.get("url", "")) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger monitor traffic command: {trigger.status_code} - {trigger.data}" @@ -293,7 +345,7 @@ async def monitor_traffic( return util_response -async def ap_remote_pcap_wireless( +def apRemotePcapWireless( apissession: _APISession, site_id: str, device_id: str, @@ -305,6 +357,7 @@ async def ap_remote_pcap_wireless( max_pkt_len: int = 512, num_packets: int = 1024, timeout=10, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICE: AP @@ -336,6 +389,8 @@ async def ap_remote_pcap_wireless( Maximum number of packets to capture (default: 1024). timeout : int, optional Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -366,9 +421,10 @@ async def ap_remote_pcap_wireless( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Remote pcap command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startRemotePcap(site_id) + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" @@ -376,7 +432,7 @@ async def ap_remote_pcap_wireless( return util_response -async def ap_remote_pcap_wired( +def apRemotePcapWired( apissession: _APISession, site_id: str, device_id: str, @@ -385,6 +441,7 @@ async def ap_remote_pcap_wired( max_pkt_len: int = 512, num_packets: int = 1024, timeout=10, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICE: AP @@ -410,6 +467,8 @@ async def ap_remote_pcap_wired( Maximum number of packets to capture (default: 1024). timeout : int, optional Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -435,9 +494,10 @@ async def ap_remote_pcap_wired( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Remote pcap command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startRemotePcap(site_id) + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" @@ -445,7 +505,7 @@ async def ap_remote_pcap_wired( return util_response -async def srx_remote_pcap( +def srxRemotePcap( apissession: _APISession, site_id: str, device_id: str, @@ -455,6 +515,7 @@ async def srx_remote_pcap( max_pkt_len: int = 512, num_packets: int = 1024, timeout=10, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICE: SRX @@ -482,6 +543,8 @@ async def srx_remote_pcap( Maximum number of packets to capture (default: 1024). timeout : int, optional Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -489,25 +552,16 @@ async def srx_remote_pcap( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ - gateway_mac = device_id.split("-")[-1] - body: dict[str, str | int | dict] = { - "duration": duration, - "max_pkt_len": max_pkt_len, - "num_packets": num_packets, - "gateways": {gateway_mac: {"ports": {}}}, - "type": "gateway", - "format": "stream", - } - for port_id in port_ids: - gateway_dict = body["gateways"] - assert isinstance(gateway_dict, dict) - mac_dict = gateway_dict[gateway_mac] - assert isinstance(mac_dict, dict) - ports_dict = mac_dict["ports"] - assert isinstance(ports_dict, dict) - ports_dict[port_id] = {"tcpdump_expression": tcpdump_expression} - if tcpdump_expression: - body["tcpdump_expression"] = tcpdump_expression + body = _build_pcap_body( + device_id, + port_ids, + "gateways", + "gateway", + tcpdump_expression, + duration, + max_pkt_len, + num_packets, + ) trigger = pcaps.startSitePacketCapture( apissession, site_id=site_id, @@ -517,9 +571,10 @@ async def srx_remote_pcap( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Remote pcap command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startRemotePcap(site_id) + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" @@ -527,7 +582,7 @@ async def srx_remote_pcap( return util_response -async def ssr_remote_pcap( +def ssrRemotePcap( apissession: _APISession, site_id: str, device_id: str, @@ -537,6 +592,7 @@ async def ssr_remote_pcap( max_pkt_len: int = 512, num_packets: int = 1024, timeout=10, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICE: SSR @@ -564,6 +620,8 @@ async def ssr_remote_pcap( Maximum number of packets to capture (default: 1024). timeout : int, optional Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -571,26 +629,17 @@ async def ssr_remote_pcap( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ - gateway_mac = device_id.split("-")[-1] - body: dict[str, str | int | dict] = { - "duration": duration, - "max_pkt_len": max_pkt_len, - "num_packets": num_packets, - "raw": False, - "gateways": {gateway_mac: {"ports": {}}}, - "type": "gateway", - "format": "stream", - } - for port_id in port_ids: - gateway_dict = body["gateways"] - assert isinstance(gateway_dict, dict) - mac_dict = gateway_dict[gateway_mac] - assert isinstance(mac_dict, dict) - ports_dict = mac_dict["ports"] - assert isinstance(ports_dict, dict) - ports_dict[port_id] = {"tcpdump_expression": tcpdump_expression} - if tcpdump_expression: - body["tcpdump_expression"] = tcpdump_expression + body = _build_pcap_body( + device_id, + port_ids, + "gateways", + "gateway", + tcpdump_expression, + duration, + max_pkt_len, + num_packets, + raw=False, + ) trigger = pcaps.startSitePacketCapture( apissession, site_id=site_id, @@ -600,9 +649,10 @@ async def ssr_remote_pcap( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Remote pcap command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startRemotePcap(site_id) + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" @@ -610,7 +660,7 @@ async def ssr_remote_pcap( return util_response -async def ex_remote_pcap( +def exRemotePcap( apissession: _APISession, site_id: str, device_id: str, @@ -620,6 +670,7 @@ async def ex_remote_pcap( max_pkt_len: int = 512, num_packets: int = 1024, timeout=10, + on_message: Callable[[dict], None] | None = None, ) -> UtilResponse: """ DEVICE: EX @@ -647,6 +698,8 @@ async def ex_remote_pcap( Maximum number of packets to capture (default: 1024). timeout : int, optional Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. RETURNS ----------- @@ -654,25 +707,16 @@ async def ex_remote_pcap( A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. """ - switch_mac = device_id.split("-")[-1] - body: dict[str, str | int | dict] = { - "duration": duration, - "max_pkt_len": max_pkt_len, - "num_packets": num_packets, - "switches": {switch_mac: {"ports": {}}}, - "type": "switch", - "format": "stream", - } - for port_id in port_ids: - switch_dict = body["switches"] - assert isinstance(switch_dict, dict) - mac_dict = switch_dict[switch_mac] - assert isinstance(mac_dict, dict) - ports_dict = mac_dict["ports"] - assert isinstance(ports_dict, dict) - ports_dict[port_id] = {"tcpdump_expression": tcpdump_expression} - if tcpdump_expression: - body["tcpdump_expression"] = tcpdump_expression + body = _build_pcap_body( + device_id, + port_ids, + "switches", + "switch", + tcpdump_expression, + duration, + max_pkt_len, + num_packets, + ) trigger = pcaps.startSitePacketCapture( apissession, site_id=site_id, @@ -682,9 +726,10 @@ async def ex_remote_pcap( if trigger.status_code == 200: LOGGER.info(trigger.data) print(f"Remote pcap command triggered for device {device_id}") - util_response = await WebSocketWrapper( - apissession, util_response, timeout=timeout - ).startRemotePcap(site_id) + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) else: LOGGER.error( f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" @@ -693,11 +738,12 @@ async def ex_remote_pcap( ## NO DATA -# async def srx_top_command( +# def srx_top_command( # apissession: _APISession, # site_id: str, # device_id: str, # timeout=10, +# on_message: Callable[[dict], None] | None = None, # ) -> UtilResponse: # """ # DEVICE: SRX @@ -714,6 +760,8 @@ async def ex_remote_pcap( # UUID of the device to run the top command on. # timeout : int, optional # Timeout for the top command in seconds. +# on_message : Callable, optional +# Callback invoked with each extracted raw message as it arrives. # RETURNS # ----------- @@ -730,9 +778,10 @@ async def ex_remote_pcap( # if trigger.status_code == 200: # LOGGER.info(trigger.data) # print(f"Top command triggered for device {device_id}") -# util_response = await WebSocketWrapper( -# apissession, util_response, timeout=timeout -# ).startSessionUrl(site_id) +# ws = SessionWithUrl(apissession, url=trigger.data.get("url", "")) +# util_response = WebSocketWrapper( +# apissession, util_response, timeout=timeout, on_message=on_message +# ).start(ws) # else: # LOGGER.error( # f"Failed to trigger top command: {trigger.status_code} - {trigger.data}" diff --git a/src/mistapi/utils/ex.py b/src/mistapi/utils/ex.py deleted file mode 100644 index f9c5455..0000000 --- a/src/mistapi/utils/ex.py +++ /dev/null @@ -1,78 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- - -Utility functions for Juniper EX Switches. - -This module provides a device-specific namespace for EX switch utilities. -All functions are imported from their respective functional modules. -""" - -# Re-export shared classes and types -from mistapi.utils.arp import Node - -# ARP functions -from mistapi.utils.arp import retrieve_junos_arp_table as retrieve_arp_table - -# BGP functions -from mistapi.utils.bgp import show_summary as show_bgp_summary - -# BPDU functions -from mistapi.utils.bpdu import clear_error as clear_bpdu_error - -# DHCP functions -from mistapi.utils.dhcp import release_dhcp_leases - -# Dot1x functions -from mistapi.utils.dot1x import clear_sessions as clear_dot1x_sessions - -# MAC table functions -from mistapi.utils.mac import ( - clear_learned_mac, - clear_mac_table, - retrieve_mac_table, -) - -# Policy functions -from mistapi.utils.policy import clear_hit_count - -# Port functions -from mistapi.utils.port import bounce as bounce_port -from mistapi.utils.port import cable_test - -# Tools (ping, monitor traffic) -from mistapi.utils.tools import monitor_traffic, ping - -__all__ = [ - # Classes/Enums - "Node", - # ARP - "retrieve_arp_table", - # BGP - "show_bgp_summary", - # BPDU - "clear_bpdu_error", - # DHCP - "release_dhcp_leases", - # Dot1x - "clear_dot1x_sessions", - # MAC - "clear_learned_mac", - "clear_mac_table", - "retrieve_mac_table", - # Port - "bounce_port", - "cable_test", - # Policy - "clear_hit_count", - # Tools - "monitor_traffic", - "ping", -] diff --git a/src/mistapi/utils/srx.py b/src/mistapi/utils/srx.py deleted file mode 100644 index 6d8148b..0000000 --- a/src/mistapi/utils/srx.py +++ /dev/null @@ -1,61 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- - -Utility functions for Juniper SRX Firewalls. - -This module provides a device-specific namespace for SRX firewall utilities. -All functions are imported from their respective functional modules. -""" - -# Re-export shared classes and types -from mistapi.utils.arp import Node - -# ARP functions -from mistapi.utils.arp import retrieve_junos_arp_table as retrieve_arp_table - -# BGP functions -from mistapi.utils.bgp import show_summary as show_bgp_summary - -# DHCP functions -from mistapi.utils.dhcp import release_dhcp_leases, retrieve_dhcp_leases - -# Policy functions -from mistapi.utils.policy import clear_hit_count - -# Port functions -from mistapi.utils.port import bounce as bounce_port - -# Route functions -from mistapi.utils.routes import show - -# Tools (ping, monitor traffic) -from mistapi.utils.tools import monitor_traffic, ping - -__all__ = [ - # Classes/Enums - "Node", - # ARP - "retrieve_arp_table", - # BGP - "show_bgp_summary", - # DHCP - "release_dhcp_leases", - "retrieve_dhcp_leases", - # Port - "bounce_port", - # Policy - "clear_hit_count", - # Routes - "show", - # Tools - "monitor_traffic", - "ping", -] diff --git a/src/mistapi/utils/ssr.py b/src/mistapi/utils/ssr.py deleted file mode 100644 index 9f7afad..0000000 --- a/src/mistapi/utils/ssr.py +++ /dev/null @@ -1,65 +0,0 @@ -""" --------------------------------------------------------------------------------- -------------------------- Mist API Python CLI Session -------------------------- - - Written by: Thomas Munzer (tmunzer@juniper.net) - Github : https://github.com/tmunzer/mistapi_python - - This package is licensed under the MIT License. - --------------------------------------------------------------------------------- - -Utility functions for Juniper Session Smart Routers (SSR). - -This module provides a device-specific namespace for SSR router utilities. -All functions are imported from their respective functional modules. -""" - -# Re-export shared classes and types -from mistapi.utils.arp import Node - -# ARP functions -from mistapi.utils.arp import retrieve_ssr_arp_table as retrieve_arp_table - -# BGP functions -from mistapi.utils.bgp import show_summary as show_bgp_summary - -# DHCP functions -from mistapi.utils.dhcp import release_dhcp_leases, retrieve_dhcp_leases - -# DNS functions -from mistapi.utils.dns import test_resolution as test_dns_resolution - -# Policy functions -from mistapi.utils.policy import clear_hit_count - -# Port functions -from mistapi.utils.port import bounce as bounce_port - -# Service Path functions -from mistapi.utils.service_path import show_service_path - -# Tools (ping only - no monitor_traffic for SSR) -from mistapi.utils.tools import ping - -__all__ = [ - # Classes/Enums - "Node", - # ARP - "retrieve_arp_table", - # BGP - "show_bgp_summary", - # DHCP - "release_dhcp_leases", - "retrieve_dhcp_leases", - # DNS - "test_dns_resolution", - # Port - "bounce_port", - # Policy - "clear_hit_count", - # Service Path - "show_service_path", - # Tools - "ping", -] diff --git a/src/mistapi/websockets/__init__.py b/src/mistapi/websockets/__init__.py index 81203b7..e269ee6 100644 --- a/src/mistapi/websockets/__init__.py +++ b/src/mistapi/websockets/__init__.py @@ -10,12 +10,11 @@ -------------------------------------------------------------------------------- """ -from mistapi.websockets import __ws_client, location, orgs, session, sites +from mistapi.websockets import location, orgs, session, sites __all__ = [ "location", "orgs", "session", "sites", - "__ws_client", ] diff --git a/src/mistapi/websockets/__ws_client.py b/src/mistapi/websockets/__ws_client.py index 29d15fa..cb4811d 100644 --- a/src/mistapi/websockets/__ws_client.py +++ b/src/mistapi/websockets/__ws_client.py @@ -14,12 +14,15 @@ import json import queue +import ssl import threading from collections.abc import Callable, Generator from typing import TYPE_CHECKING import websocket +from mistapi.__logger import logger + if TYPE_CHECKING: from mistapi import APISession @@ -50,6 +53,9 @@ def __init__( self._ws: websocket.WebSocketApp | None = None self._thread: threading.Thread | None = None self._queue: queue.Queue[dict | None] = queue.Queue() + self._connected = ( + threading.Event() + ) # tracks whether the WebSocket connection is currently open self._on_message_cb: Callable[[dict], None] | None = None self._on_error_cb: Callable[[Exception], None] | None = None self._on_open_cb: Callable[[], None] | None = None @@ -70,10 +76,38 @@ def _get_headers(self) -> dict: def _get_cookie(self) -> str | None: cookies = self._mist_session._session.cookies if cookies: - pairs = "; ".join(f"{c.name}={c.value}" for c in cookies) - return pairs if pairs else None + safe = [] + for c in cookies: + has_crlf = "\r" in c.name or "\n" in c.name or ( + c.value and ("\r" in c.value or "\n" in c.value) + ) + if has_crlf: + logger.warning( + "Skipping cookie %r: contains CRLF characters (possible header injection)", + c.name, + ) + continue + safe.append(f"{c.name}={c.value}") + return "; ".join(safe) if safe else None return None + def _build_sslopt(self) -> dict: + """Build SSL options from the APISession's requests.Session.""" + sslopt: dict = {} + session = self._mist_session._session + if session.verify is False: + sslopt["cert_reqs"] = ssl.CERT_NONE + elif isinstance(session.verify, str): + sslopt["ca_certs"] = session.verify + if session.cert: + if isinstance(session.cert, str): + sslopt["certfile"] = session.cert + elif isinstance(session.cert, tuple): + sslopt["certfile"] = session.cert[0] + if len(session.cert) > 1: + sslopt["keyfile"] = session.cert[1] + return sslopt + # ------------------------------------------------------------------ # Callback registration @@ -99,6 +133,7 @@ def on_close(self, callback: Callable[[int, str], None]) -> None: def _handle_open(self, ws: websocket.WebSocketApp) -> None: for channel in self._channels: ws.send(json.dumps({"subscribe": channel})) + self._connected.set() if self._on_open_cb: self._on_open_cb() @@ -121,6 +156,7 @@ def _handle_close( close_status_code: int, close_msg: str, ) -> None: + self._connected.clear() self._queue.put(None) # Signals receive() generator to stop if self._on_close_cb: self._on_close_cb(close_status_code, close_msg) @@ -138,6 +174,13 @@ def connect(self, run_in_background: bool = True) -> None: If True, runs the WebSocket loop in a daemon thread (non-blocking). If False, blocks the calling thread until disconnected. """ + # Drain stale sentinel from previous connection + while not self._queue.empty(): + try: + self._queue.get_nowait() + except queue.Empty: + break + self._ws = websocket.WebSocketApp( self._build_ws_url(), header=self._get_headers(), @@ -156,8 +199,11 @@ def connect(self, run_in_background: bool = True) -> None: def _run_forever_safe(self) -> None: if self._ws: try: + sslopt = self._build_sslopt() self._ws.run_forever( - ping_interval=self._ping_interval, ping_timeout=self._ping_timeout + ping_interval=self._ping_interval, + ping_timeout=self._ping_timeout, + sslopt=sslopt, ) except Exception as exc: self._handle_error(self._ws, exc) @@ -177,8 +223,15 @@ def receive(self) -> Generator[dict, None, None]: Intended for use after connect(run_in_background=True). """ + if not self._connected.wait(timeout=10): + return while True: - item = self._queue.get() + try: + item = self._queue.get(timeout=1) + except queue.Empty: + if not self._connected.is_set() and self._queue.empty(): + break + continue if item is None: break yield item diff --git a/src/mistapi/websockets/location.py b/src/mistapi/websockets/location.py index f010240..2c40842 100644 --- a/src/mistapi/websockets/location.py +++ b/src/mistapi/websockets/location.py @@ -41,7 +41,7 @@ class BleAssetsEvents(_MistWebsocket): ws = LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -56,6 +56,7 @@ class BleAssetsEvents(_MistWebsocket): with LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ @@ -101,7 +102,7 @@ class ConnectedClientsEvents(_MistWebsocket): ws = LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -116,6 +117,7 @@ class ConnectedClientsEvents(_MistWebsocket): with LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ @@ -161,7 +163,7 @@ class SdkClientsEvents(_MistWebsocket): ws = LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -176,6 +178,7 @@ class SdkClientsEvents(_MistWebsocket): with LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ @@ -221,7 +224,7 @@ class UnconnectedClientsEvents(_MistWebsocket): ws = LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -236,6 +239,7 @@ class UnconnectedClientsEvents(_MistWebsocket): with LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ @@ -283,7 +287,7 @@ class DiscoveredBleAssetsEvents(_MistWebsocket): ws = LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -298,6 +302,7 @@ class DiscoveredBleAssetsEvents(_MistWebsocket): with LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ diff --git a/src/mistapi/websockets/orgs.py b/src/mistapi/websockets/orgs.py index 1f9ce9a..e8a24ff 100644 --- a/src/mistapi/websockets/orgs.py +++ b/src/mistapi/websockets/orgs.py @@ -38,7 +38,7 @@ class InsightsEvents(_MistWebsocket): ws = OrgInsightsEvents(session, org_id="abc123") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -53,6 +53,7 @@ class InsightsEvents(_MistWebsocket): with OrgInsightsEvents(session, org_id="abc123") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ @@ -94,7 +95,7 @@ class MxEdgesStatsEvents(_MistWebsocket): ws = OrgMxEdgesStatsEvents(session, org_id="abc123") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -109,6 +110,7 @@ class MxEdgesStatsEvents(_MistWebsocket): with OrgMxEdgesStatsEvents(session, org_id="abc123") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ @@ -150,7 +152,7 @@ class MxEdgesUpgradesEvents(_MistWebsocket): ws = OrgMxEdgesUpgradesEvents(session, org_id="abc123") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -165,6 +167,7 @@ class MxEdgesUpgradesEvents(_MistWebsocket): with OrgMxEdgesUpgradesEvents(session, org_id="abc123") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ diff --git a/src/mistapi/websockets/session.py b/src/mistapi/websockets/session.py index c2ef382..8b87801 100644 --- a/src/mistapi/websockets/session.py +++ b/src/mistapi/websockets/session.py @@ -39,7 +39,7 @@ class SessionWithUrl(_MistWebsocket): ws = sessionWithUrl(session, url="wss://example.com/channel") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -54,6 +54,7 @@ class SessionWithUrl(_MistWebsocket): with sessionWithUrl(session, url="wss://example.com/channel") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ diff --git a/src/mistapi/websockets/sites.py b/src/mistapi/websockets/sites.py index 4c24cc4..291ca14 100644 --- a/src/mistapi/websockets/sites.py +++ b/src/mistapi/websockets/sites.py @@ -38,7 +38,7 @@ class ClientsStatsEvents(_MistWebsocket): ws = SiteClientsStatsEvents(session, site_id="abc123") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -53,6 +53,7 @@ class ClientsStatsEvents(_MistWebsocket): with SiteClientsStatsEvents(session, site_id="abc123") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ @@ -103,7 +104,7 @@ class DeviceCmdEvents(_MistWebsocket): ws = SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -118,6 +119,7 @@ class DeviceCmdEvents(_MistWebsocket): with SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ @@ -163,7 +165,7 @@ class DeviceStatsEvents(_MistWebsocket): ws = SiteDeviceStatsEvents(session, site_id="abc123") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -178,6 +180,7 @@ class DeviceStatsEvents(_MistWebsocket): with SiteDeviceStatsEvents(session, site_id="abc123") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ @@ -220,7 +223,7 @@ class DeviceUpgradesEvents(_MistWebsocket): ws = SiteDeviceUpgradesEvents(session, site_id="abc123") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -235,6 +238,7 @@ class DeviceUpgradesEvents(_MistWebsocket): with SiteDeviceUpgradesEvents(session, site_id="abc123") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ @@ -277,7 +281,7 @@ class MxEdgesStatsEvents(_MistWebsocket): ws = SiteMxEdgesStatsEvents(session, site_id="abc123") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -292,6 +296,7 @@ class MxEdgesStatsEvents(_MistWebsocket): with SiteMxEdgesStatsEvents(session, site_id="abc123") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ @@ -334,7 +339,7 @@ class PcapEvents(_MistWebsocket): ws = SitePcapEvents(session, site_id="abc123") ws.on_message(lambda data: print(data)) - ws.connect() + ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") ws.disconnect() @@ -349,6 +354,7 @@ class PcapEvents(_MistWebsocket): with SitePcapEvents(session, site_id="abc123") as ws: ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread time.sleep(60) """ diff --git a/tests/unit/test_api_request.py b/tests/unit/test_api_request.py index e69de29..09fa9b6 100644 --- a/tests/unit/test_api_request.py +++ b/tests/unit/test_api_request.py @@ -0,0 +1,793 @@ +# tests/unit/test_api_request.py +""" +Comprehensive unit tests for mistapi.__api_request.APIRequest. + +Tests cover URL generation, query string encoding, token rotation, +proxy logging, header sanitisation, rate-limit handling, the shared +retry wrapper, and each HTTP-method convenience function. +""" + +import json +from unittest.mock import Mock, patch + +import pytest +import requests +from requests.exceptions import HTTPError + +from mistapi.__api_request import APIRequest +from mistapi.__api_response import APIResponse + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_api_request(cloud_uri="api.mist.com", tokens=None): + """Create an APIRequest with a mocked session for isolated testing.""" + with patch("mistapi.__api_request.requests.session") as mock_session_cls: + mock_session = Mock() + mock_session.headers = {} + mock_session.proxies = {} + mock_session.cookies = {} + mock_session_cls.return_value = mock_session + + req = APIRequest() + req._session = mock_session + req._cloud_uri = cloud_uri + + if tokens: + req._apitoken = list(tokens) + req._apitoken_index = 0 + req._session.headers.update( + {"Authorization": "Token " + tokens[0]} + ) + return req + + +def _mock_response( + status_code=200, + json_data=None, + headers=None, + raise_for_status_effect=None, +): + """Build a mock requests.Response.""" + resp = Mock(spec=requests.Response) + resp.status_code = status_code + resp.headers = headers or {} + resp.json.return_value = json_data if json_data is not None else {} + resp.content = json.dumps(json_data or {}).encode() + + # For _remove_auth_from_headers — attach a mock PreparedRequest + prep = Mock() + prep.headers = {} + resp.request = prep + + if raise_for_status_effect: + resp.raise_for_status.side_effect = raise_for_status_effect + else: + resp.raise_for_status.return_value = None + return resp + + +# =========================================================================== +# Tests +# =========================================================================== + + +class TestUrl: + """APIRequest._url() builds the full URL from cloud_uri + uri.""" + + def test_basic_url(self): + req = _make_api_request("api.mist.com") + assert req._url("/api/v1/self") == "https://api.mist.com/api/v1/self" + + def test_empty_uri(self): + req = _make_api_request("api.mist.com") + assert req._url("") == "https://api.mist.com" + + def test_eu_host(self): + req = _make_api_request("api.eu.mist.com") + assert req._url("/api/v1/orgs") == "https://api.eu.mist.com/api/v1/orgs" + + def test_uri_with_path_segments(self): + req = _make_api_request("api.mist.com") + org_id = "203d3d02-dbc0-4c1b-9f41-76896a3330f4" + uri = f"/api/v1/orgs/{org_id}/sites" + assert req._url(uri) == f"https://api.mist.com{uri}" + + +class TestGenQuery: + """APIRequest._gen_query() builds URL-encoded query strings.""" + + def test_none_returns_empty(self): + req = _make_api_request() + assert req._gen_query(None) == "" + + def test_empty_dict_returns_empty(self): + req = _make_api_request() + assert req._gen_query({}) == "" + + def test_single_param(self): + req = _make_api_request() + assert req._gen_query({"page": "2"}) == "?page=2" + + def test_multiple_params(self): + req = _make_api_request() + result = req._gen_query({"page": "1", "limit": "100"}) + assert result.startswith("?") + assert "page=1" in result + assert "limit=100" in result + + def test_special_chars_encoded(self): + req = _make_api_request() + result = req._gen_query({"filter": "name=hello world&foo"}) + assert "?" in result + # urllib.parse.urlencode encodes spaces as + and & as %26 + assert "hello+world" in result or "hello%20world" in result + assert "%26" in result + + def test_preserves_insertion_order(self): + req = _make_api_request() + # dict preserves insertion order in Python 3.7+ + result = req._gen_query({"a": "1", "b": "2", "c": "3"}) + assert result == "?a=1&b=2&c=3" + + +class TestNextApiToken: + """APIRequest._next_apitoken() rotates through tokens and raises + RuntimeError when only one token is available.""" + + def test_rotates_to_next_token(self): + req = _make_api_request(tokens=["tok_aaa1", "tok_bbb2", "tok_ccc3"]) + req._apitoken_index = 0 + req._next_apitoken() + assert req._apitoken_index == 1 + assert req._session.headers["Authorization"] == "Token tok_bbb2" + + def test_wraps_around_to_first_token(self): + req = _make_api_request(tokens=["tok_aaa1", "tok_bbb2"]) + req._apitoken_index = 1 + req._next_apitoken() + assert req._apitoken_index == 0 + assert req._session.headers["Authorization"] == "Token tok_aaa1" + + def test_single_token_raises_runtime_error(self): + req = _make_api_request(tokens=["tok_only1"]) + req._apitoken_index = 0 + with pytest.raises(RuntimeError, match="API rate limit reached"): + req._next_apitoken() + + def test_rotation_cycle(self): + tokens = ["tok_aaa1", "tok_bbb2", "tok_ccc3"] + req = _make_api_request(tokens=tokens) + req._apitoken_index = 0 + # Rotate through all tokens and back + req._next_apitoken() + assert req._apitoken_index == 1 + req._next_apitoken() + assert req._apitoken_index == 2 + req._next_apitoken() + assert req._apitoken_index == 0 + + +class TestLogProxy: + """APIRequest._log_proxy() prints a masked proxy URL.""" + + def test_prints_masked_proxy(self, capsys): + req = _make_api_request() + req._session.proxies = { + "https": "http://user:secret_password@proxy.example.com:8080" + } + req._log_proxy() + captured = capsys.readouterr() + assert "proxy.example.com:8080" in captured.out + assert "*********" in captured.out + assert "secret_password" not in captured.out + + def test_no_proxy_does_nothing(self, capsys): + req = _make_api_request() + req._session.proxies = {} + req._log_proxy() + captured = capsys.readouterr() + assert captured.out == "" + + def test_proxy_without_password(self, capsys): + req = _make_api_request() + req._session.proxies = {"https": "http://proxy.example.com:8080"} + req._log_proxy() + captured = capsys.readouterr() + assert "proxy.example.com:8080" in captured.out + + +class TestRemoveAuthFromHeaders: + """APIRequest._remove_auth_from_headers() masks sensitive headers.""" + + def test_masks_authorization(self): + req = _make_api_request() + resp = Mock() + resp.request.headers = {"Authorization": "Token secret123"} + headers = req._remove_auth_from_headers(resp) + assert headers["Authorization"] == "***hidden***" + + def test_masks_csrf_token(self): + req = _make_api_request() + resp = Mock() + resp.request.headers = {"X-CSRFToken": "csrf_value"} + headers = req._remove_auth_from_headers(resp) + assert headers["X-CSRFToken"] == "***hidden***" + + def test_masks_cookie(self): + req = _make_api_request() + resp = Mock() + resp.request.headers = {"Cookie": "session=abc123"} + headers = req._remove_auth_from_headers(resp) + assert headers["Cookie"] == "***hidden***" + + def test_masks_all_three(self): + req = _make_api_request() + resp = Mock() + resp.request.headers = { + "Authorization": "Token x", + "X-CSRFToken": "y", + "Cookie": "z", + "Content-Type": "application/json", + } + headers = req._remove_auth_from_headers(resp) + assert headers["Authorization"] == "***hidden***" + assert headers["X-CSRFToken"] == "***hidden***" + assert headers["Cookie"] == "***hidden***" + # Non-sensitive header left untouched + assert headers["Content-Type"] == "application/json" + + def test_leaves_non_sensitive_headers(self): + req = _make_api_request() + resp = Mock() + resp.request.headers = { + "Accept": "application/json", + "User-Agent": "python-requests/2.32", + } + headers = req._remove_auth_from_headers(resp) + assert headers["Accept"] == "application/json" + assert headers["User-Agent"] == "python-requests/2.32" + + +class TestHandleRateLimit: + """APIRequest._handle_rate_limit() sleeps according to Retry-After + header or falls back to exponential backoff.""" + + @patch("mistapi.__api_request.time.sleep") + def test_uses_retry_after_header(self, mock_sleep): + req = _make_api_request() + resp = Mock() + resp.headers = {"Retry-After": "10"} + req._handle_rate_limit(resp, attempt=0) + mock_sleep.assert_called_once_with(10) + + @patch("mistapi.__api_request.time.sleep") + def test_exponential_backoff_when_no_header(self, mock_sleep): + req = _make_api_request() + resp = Mock() + resp.headers = {} + # attempt 0 => 5 * (2**0) = 5 + req._handle_rate_limit(resp, attempt=0) + mock_sleep.assert_called_once_with(5) + + @patch("mistapi.__api_request.time.sleep") + def test_exponential_backoff_attempt_1(self, mock_sleep): + req = _make_api_request() + resp = Mock() + resp.headers = {} + # attempt 1 => 5 * (2**1) = 10 + req._handle_rate_limit(resp, attempt=1) + mock_sleep.assert_called_once_with(10) + + @patch("mistapi.__api_request.time.sleep") + def test_exponential_backoff_attempt_2(self, mock_sleep): + req = _make_api_request() + resp = Mock() + resp.headers = {} + # attempt 2 => 5 * (2**2) = 20 + req._handle_rate_limit(resp, attempt=2) + mock_sleep.assert_called_once_with(20) + + @patch("mistapi.__api_request.time.sleep") + def test_invalid_retry_after_falls_back(self, mock_sleep): + req = _make_api_request() + resp = Mock() + resp.headers = {"Retry-After": "not-a-number"} + # attempt 0 => fallback 5 * (2**0) = 5 + req._handle_rate_limit(resp, attempt=0) + mock_sleep.assert_called_once_with(5) + + +class TestRequestWithRetrySuccess: + """_request_with_retry() on successful responses.""" + + def test_returns_api_response_on_200(self): + req = _make_api_request() + resp = _mock_response(status_code=200, json_data={"ok": True}) + fn = Mock(return_value=resp) + + result = req._request_with_retry("test", fn, "https://api.mist.com/api/v1/self") + + assert isinstance(result, APIResponse) + assert result.status_code == 200 + fn.assert_called_once() + + def test_increments_count(self): + req = _make_api_request() + assert req._count == 0 + resp = _mock_response(status_code=200) + fn = Mock(return_value=resp) + + req._request_with_retry("test", fn, "https://example.com") + assert req._count == 1 + + req._request_with_retry("test", fn, "https://example.com") + assert req._count == 2 + + +class TestRequestWithRetryProxyError: + """_request_with_retry() on ProxyError.""" + + def test_proxy_error_sets_flag(self): + req = _make_api_request() + fn = Mock(side_effect=requests.exceptions.ProxyError("proxy down")) + + result = req._request_with_retry("test", fn, "https://example.com") + + assert result.proxy_error is True + assert result.status_code is None + assert req._count == 1 + + +class TestRequestWithRetryConnectionError: + """_request_with_retry() on ConnectionError.""" + + def test_connection_error_returns_none_response(self): + req = _make_api_request() + fn = Mock(side_effect=requests.exceptions.ConnectionError("no route")) + + result = req._request_with_retry("test", fn, "https://example.com") + + assert result.status_code is None + assert result.proxy_error is False + assert req._count == 1 + + +class TestRequestWithRetryHTTPErrorNon429: + """_request_with_retry() on non-429 HTTPError.""" + + def test_non_429_stops_immediately(self): + req = _make_api_request() + resp = _mock_response(status_code=403, json_data={"error": "forbidden"}) + http_err = HTTPError(response=resp) + resp.raise_for_status.side_effect = http_err + + fn = Mock(return_value=resp) + result = req._request_with_retry("test", fn, "https://example.com") + + # Should only call request_fn once (no retries for non-429) + fn.assert_called_once() + assert result.status_code == 403 + assert req._count == 1 + + def test_500_error(self): + req = _make_api_request() + resp = _mock_response(status_code=500, json_data={"error": "server error"}) + http_err = HTTPError(response=resp) + resp.raise_for_status.side_effect = http_err + + fn = Mock(return_value=resp) + result = req._request_with_retry("test", fn, "https://example.com") + + fn.assert_called_once() + assert result.status_code == 500 + + +class TestRequestWithRetry429: + """_request_with_retry() on 429 rate-limit responses.""" + + @patch("mistapi.__api_request.time.sleep") + def test_retries_on_429(self, mock_sleep): + req = _make_api_request(tokens=["tok_aaa1", "tok_bbb2"]) + + # First call returns 429, second call succeeds + resp_429 = _mock_response(status_code=429, headers={"Retry-After": "1"}) + http_err = HTTPError(response=resp_429) + resp_429.raise_for_status.side_effect = http_err + + resp_ok = _mock_response(status_code=200, json_data={"ok": True}) + + fn = Mock(side_effect=[resp_429, resp_ok]) + result = req._request_with_retry("test", fn, "https://example.com") + + assert fn.call_count == 2 + assert result.status_code == 200 + mock_sleep.assert_called_once_with(1) + + @patch("mistapi.__api_request.time.sleep") + def test_rotates_token_on_429(self, mock_sleep): + req = _make_api_request(tokens=["tok_aaa1", "tok_bbb2"]) + + resp_429 = _mock_response(status_code=429, headers={"Retry-After": "1"}) + http_err = HTTPError(response=resp_429) + resp_429.raise_for_status.side_effect = http_err + + resp_ok = _mock_response(status_code=200) + fn = Mock(side_effect=[resp_429, resp_ok]) + req._request_with_retry("test", fn, "https://example.com") + + # Token should have rotated from index 0 to index 1 + assert req._apitoken_index == 1 + + @patch("mistapi.__api_request.time.sleep") + def test_429_exhausted_after_max_retries(self, mock_sleep): + req = _make_api_request(tokens=["tok_aaa1", "tok_bbb2", "tok_ccc3", "tok_ddd4"]) + + resp_429 = _mock_response(status_code=429, headers={"Retry-After": "1"}) + http_err = HTTPError(response=resp_429) + resp_429.raise_for_status.side_effect = http_err + + # All 4 calls (1 initial + 3 retries) return 429 + fn = Mock(return_value=resp_429) + result = req._request_with_retry("test", fn, "https://example.com") + + # MAX_429_RETRIES = 3, so total calls = 4 (attempt 0,1,2,3) + assert fn.call_count == 4 + assert result.status_code == 429 + assert mock_sleep.call_count == 3 + assert req._count == 1 + + @patch("mistapi.__api_request.time.sleep") + def test_429_single_token_still_retries_with_backoff(self, mock_sleep): + """Even with one token (RuntimeError on rotation), retry with backoff.""" + req = _make_api_request(tokens=["tok_only1"]) + + resp_429 = _mock_response(status_code=429, headers={}) + http_err = HTTPError(response=resp_429) + resp_429.raise_for_status.side_effect = http_err + + resp_ok = _mock_response(status_code=200) + fn = Mock(side_effect=[resp_429, resp_ok]) + result = req._request_with_retry("test", fn, "https://example.com") + + assert fn.call_count == 2 + assert result.status_code == 200 + # Backoff with attempt=0 => 5 * 1 = 5 + mock_sleep.assert_called_once_with(5) + + @patch("mistapi.__api_request.time.sleep") + def test_429_calls_handle_rate_limit(self, mock_sleep): + req = _make_api_request(tokens=["tok_aaa1", "tok_bbb2"]) + + resp_429 = _mock_response(status_code=429, headers={"Retry-After": "7"}) + http_err = HTTPError(response=resp_429) + resp_429.raise_for_status.side_effect = http_err + + resp_ok = _mock_response(status_code=200) + fn = Mock(side_effect=[resp_429, resp_ok]) + + with patch.object(req, "_handle_rate_limit", wraps=req._handle_rate_limit) as wrapped: + req._request_with_retry("test", fn, "https://example.com") + wrapped.assert_called_once_with(resp_429, 0) + + mock_sleep.assert_called_once_with(7) + + +class TestRequestWithRetryGenericException: + """_request_with_retry() on unexpected exceptions.""" + + def test_generic_exception_breaks_loop(self): + req = _make_api_request() + fn = Mock(side_effect=ValueError("something unexpected")) + result = req._request_with_retry("test", fn, "https://example.com") + + fn.assert_called_once() + assert result.status_code is None + assert req._count == 1 + + +class TestMistGet: + """mist_get() delegates to _request_with_retry with correct URL+query.""" + + def test_get_without_query(self): + req = _make_api_request() + resp = _mock_response(status_code=200, json_data={"items": []}) + req._session.get.return_value = resp + + result = req.mist_get("/api/v1/self") + + assert result.status_code == 200 + req._session.get.assert_called_once_with("https://api.mist.com/api/v1/self") + + def test_get_with_query(self): + req = _make_api_request() + resp = _mock_response(status_code=200, json_data={}) + req._session.get.return_value = resp + + result = req.mist_get("/api/v1/orgs", query={"page": "2", "limit": "50"}) + + assert result.status_code == 200 + expected_url = "https://api.mist.com/api/v1/orgs?page=2&limit=50" + req._session.get.assert_called_once_with(expected_url) + + def test_get_increments_count(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.get.return_value = resp + + req.mist_get("/api/v1/self") + assert req.get_request_count() == 1 + + +class TestMistPost: + """mist_post() sends JSON body or raw string data.""" + + def test_post_dict_body(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + body = {"name": "Test Site"} + result = req.mist_post("/api/v1/sites", body=body) + + assert result.status_code == 200 + req._session.post.assert_called_once_with( + "https://api.mist.com/api/v1/sites", + json=body, + headers={"Content-Type": "application/json"}, + ) + + def test_post_string_body(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + body = '{"name": "Test Site"}' + result = req.mist_post("/api/v1/sites", body=body) + + assert result.status_code == 200 + req._session.post.assert_called_once_with( + "https://api.mist.com/api/v1/sites", + data=body, + headers={"Content-Type": "application/json"}, + ) + + def test_post_list_body(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + body = [{"name": "Site A"}, {"name": "Site B"}] + result = req.mist_post("/api/v1/sites", body=body) + + assert result.status_code == 200 + req._session.post.assert_called_once_with( + "https://api.mist.com/api/v1/sites", + json=body, + headers={"Content-Type": "application/json"}, + ) + + def test_post_none_body(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + result = req.mist_post("/api/v1/sites", body=None) + + assert result.status_code == 200 + req._session.post.assert_called_once_with( + "https://api.mist.com/api/v1/sites", + json=None, + headers={"Content-Type": "application/json"}, + ) + + +class TestMistPut: + """mist_put() sends JSON body or raw string data.""" + + def test_put_dict_body(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.put.return_value = resp + + body = {"name": "Updated Site"} + result = req.mist_put("/api/v1/sites/123", body=body) + + assert result.status_code == 200 + req._session.put.assert_called_once_with( + "https://api.mist.com/api/v1/sites/123", + json=body, + headers={"Content-Type": "application/json"}, + ) + + def test_put_string_body(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.put.return_value = resp + + body = '{"name": "Updated Site"}' + result = req.mist_put("/api/v1/sites/123", body=body) + + assert result.status_code == 200 + req._session.put.assert_called_once_with( + "https://api.mist.com/api/v1/sites/123", + data=body, + headers={"Content-Type": "application/json"}, + ) + + def test_put_none_body(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.put.return_value = resp + + result = req.mist_put("/api/v1/sites/123", body=None) + + assert result.status_code == 200 + req._session.put.assert_called_once_with( + "https://api.mist.com/api/v1/sites/123", + json=None, + headers={"Content-Type": "application/json"}, + ) + + +class TestMistDelete: + """mist_delete() delegates with correct URL+query.""" + + def test_delete_without_query(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.delete.return_value = resp + + result = req.mist_delete("/api/v1/sites/123") + + assert result.status_code == 200 + req._session.delete.assert_called_once_with( + "https://api.mist.com/api/v1/sites/123" + ) + + def test_delete_with_query(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.delete.return_value = resp + + result = req.mist_delete("/api/v1/sites/123", query={"force": "true"}) + + assert result.status_code == 200 + req._session.delete.assert_called_once_with( + "https://api.mist.com/api/v1/sites/123?force=true" + ) + + +class TestMistPostFile: + """mist_post_file() builds multipart form data and delegates to retry wrapper.""" + + def test_post_file_with_file_key(self, tmp_path): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + # Create a real temporary file + test_file = tmp_path / "upload.bin" + test_file.write_bytes(b"file content here") + + result = req.mist_post_file( + "/api/v1/sites/123/maps/import", + multipart_form_data={"file": str(test_file)}, + ) + + assert result.status_code == 200 + req._session.post.assert_called_once() + call_kwargs = req._session.post.call_args + files_arg = call_kwargs.kwargs.get("files") or call_kwargs[1].get("files") + assert "file" in files_arg + # Tuple structure: (filename, file_obj, content_type) + assert files_arg["file"][0] == "upload.bin" + assert files_arg["file"][2] == "application/octet-stream" + + def test_post_file_with_csv_key(self, tmp_path): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + csv_file = tmp_path / "data.csv" + csv_file.write_text("col1,col2\na,b\n") + + result = req.mist_post_file( + "/api/v1/orgs/123/inventory", + multipart_form_data={"csv": str(csv_file)}, + ) + + assert result.status_code == 200 + call_kwargs = req._session.post.call_args + files_arg = call_kwargs.kwargs.get("files") or call_kwargs[1].get("files") + assert "csv" in files_arg + assert files_arg["csv"][0] == "data.csv" + + def test_post_file_with_json_field(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + result = req.mist_post_file( + "/api/v1/sites/123/maps", + multipart_form_data={"json": {"name": "Floor 1"}}, + ) + + assert result.status_code == 200 + call_kwargs = req._session.post.call_args + files_arg = call_kwargs.kwargs.get("files") or call_kwargs[1].get("files") + assert "json" in files_arg + # Non-file keys produce (None, json_string) tuples + assert files_arg["json"][0] is None + assert json.loads(files_arg["json"][1]) == {"name": "Floor 1"} + + def test_post_file_none_defaults_to_empty(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + result = req.mist_post_file("/api/v1/sites/123/maps") + + assert result.status_code == 200 + call_kwargs = req._session.post.call_args + files_arg = call_kwargs.kwargs.get("files") or call_kwargs[1].get("files") + assert files_arg == {} + + def test_post_file_skips_falsy_values(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + result = req.mist_post_file( + "/api/v1/sites/123/maps", + multipart_form_data={"json": None, "file": ""}, + ) + + assert result.status_code == 200 + call_kwargs = req._session.post.call_args + files_arg = call_kwargs.kwargs.get("files") or call_kwargs[1].get("files") + assert files_arg == {} + + def test_post_file_missing_file_handled_gracefully(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + # Point to a file that does not exist + result = req.mist_post_file( + "/api/v1/sites/123/maps", + multipart_form_data={"file": "/nonexistent/path/file.bin"}, + ) + + assert result.status_code == 200 + # The OSError is caught silently; file key is not in the generated data + call_kwargs = req._session.post.call_args + files_arg = call_kwargs.kwargs.get("files") or call_kwargs[1].get("files") + assert "file" not in files_arg + + +class TestGetRequestCount: + """get_request_count() returns the cumulative request count.""" + + def test_initial_count_is_zero(self): + req = _make_api_request() + assert req.get_request_count() == 0 + + def test_count_after_requests(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.get.return_value = resp + + req.mist_get("/api/v1/self") + req.mist_get("/api/v1/self") + req.mist_get("/api/v1/self") + + assert req.get_request_count() == 3 + + def test_count_increments_on_error(self): + req = _make_api_request() + fn = Mock(side_effect=requests.exceptions.ConnectionError("fail")) + req._request_with_retry("test", fn, "https://example.com") + assert req.get_request_count() == 1 diff --git a/tests/unit/test_api_response.py b/tests/unit/test_api_response.py index e69de29..850fc0a 100644 --- a/tests/unit/test_api_response.py +++ b/tests/unit/test_api_response.py @@ -0,0 +1,328 @@ +""" +Unit tests for mistapi.__api_response.APIResponse + +Tests cover: +- Construction with None response (default field values) +- Construction with a valid JSON response (data, status_code, headers) +- _check_next() when "next" key is present in response data +- _check_next() pagination via X-Page-Total / X-Page-Limit / X-Page-Page headers +- _check_next() on the last page (next stays None) +- _check_next() when the URL already contains a page= parameter +- Error responses (4xx status codes) +- proxy_error flag propagation +- Non-JSON responses (exception path in __init__) +""" + +import json +from unittest.mock import Mock + +import pytest + +from mistapi.__api_response import APIResponse + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_mock_response(status_code=200, data=None, headers=None, json_raises=False): + """Build a mock requests.Response with the given attributes.""" + mock = Mock() + mock.status_code = status_code + mock.headers = headers or {} + if json_raises: + mock.content = b"not-json" + mock.json.side_effect = ValueError("No JSON") + else: + payload = data if data is not None else {} + mock.content = json.dumps(payload).encode() + mock.json.return_value = payload + return mock + + +# --------------------------------------------------------------------------- +# Tests: construction / default fields +# --------------------------------------------------------------------------- + +class TestAPIResponseConstruction: + """Tests for APIResponse.__init__ with different response inputs.""" + + def test_none_response_defaults(self): + """When response is None every field should keep its default value.""" + resp = APIResponse(response=None, url="https://api.mist.com/api/v1/test") + + assert resp.raw_data == "" + assert resp.data == {} + assert resp.url == "https://api.mist.com/api/v1/test" + assert resp.next is None + assert resp.headers is None + assert resp.status_code is None + assert resp.proxy_error is False + + def test_200_response_with_json(self, api_response_factory): + """A 200 response with JSON body should populate data, status_code, headers.""" + data = {"id": "abc-123", "name": "widget"} + headers = {"Content-Type": "application/json"} + + resp = api_response_factory(status_code=200, data=data, headers=headers) + + assert resp.status_code == 200 + assert resp.data == data + assert resp.headers == headers + assert resp.url == "https://api.mist.com/api/v1/test" + + def test_raw_data_is_string_of_content(self): + """raw_data should be str(response.content).""" + data = {"key": "value"} + mock = _make_mock_response(data=data) + resp = APIResponse(response=mock, url="https://host/api/v1/x") + + assert resp.raw_data == str(json.dumps(data).encode()) + + def test_proxy_error_true(self): + """proxy_error=True should be stored on the instance.""" + resp = APIResponse(response=None, url="https://host/api/v1/x", proxy_error=True) + + assert resp.proxy_error is True + + def test_proxy_error_false_by_default(self): + """proxy_error defaults to False.""" + resp = APIResponse(response=None, url="https://host/api/v1/x") + + assert resp.proxy_error is False + + def test_proxy_error_with_response(self): + """proxy_error flag should propagate even when a valid response is present.""" + mock = _make_mock_response(status_code=502, data={"error": "bad gateway"}) + resp = APIResponse(response=mock, url="https://host/api/v1/x", proxy_error=True) + + assert resp.proxy_error is True + assert resp.status_code == 502 + + +# --------------------------------------------------------------------------- +# Tests: error responses +# --------------------------------------------------------------------------- + +class TestAPIResponseErrors: + """Tests for error HTTP status codes and error payloads.""" + + @pytest.mark.parametrize("status_code", [400, 401, 403, 404, 500, 502, 503]) + def test_error_status_codes_stored(self, api_response_factory, status_code): + """Error status codes should be stored without raising.""" + data = {"error": "something went wrong"} + resp = api_response_factory(status_code=status_code, data=data) + + assert resp.status_code == status_code + assert resp.data == data + + def test_error_key_in_200_response(self, api_response_factory): + """A 200 with an 'error' key in the body should still store data.""" + data = {"error": "unexpected error in body"} + resp = api_response_factory(status_code=200, data=data) + + assert resp.status_code == 200 + assert resp.data == data + + def test_non_json_response_handled_gracefully(self): + """When response.json() raises, the exception path should not propagate.""" + mock = _make_mock_response(json_raises=True) + # Should not raise + resp = APIResponse(response=mock, url="https://host/api/v1/x") + + # data stays at default because json() failed + assert resp.data == {} + assert resp.status_code == 200 + assert resp.next is None + + +# --------------------------------------------------------------------------- +# Tests: _check_next with "next" in data +# --------------------------------------------------------------------------- + +class TestCheckNextFromData: + """Tests for _check_next() when the response body contains a 'next' key.""" + + def test_next_in_data(self, api_response_factory): + """When data contains 'next', self.next should be set from data.""" + data = {"results": [], "next": "/api/v1/test?page=2"} + resp = api_response_factory(data=data) + + assert resp.next == "/api/v1/test?page=2" + + def test_next_in_data_takes_precedence_over_headers(self): + """'next' in data should take precedence over pagination headers.""" + headers = { + "X-Page-Total": "100", + "X-Page-Limit": "10", + "X-Page-Page": "1", + } + data = {"next": "/api/v1/custom-next"} + mock = _make_mock_response(data=data, headers=headers) + resp = APIResponse(response=mock, url="https://host/api/v1/items") + + assert resp.next == "/api/v1/custom-next" + + def test_next_value_none_in_data(self, api_response_factory): + """When data['next'] is None, self.next should be set to None.""" + data = {"results": [], "next": None} + resp = api_response_factory(data=data) + + assert resp.next is None + + +# --------------------------------------------------------------------------- +# Tests: _check_next with pagination headers +# --------------------------------------------------------------------------- + +class TestCheckNextFromHeaders: + """Tests for _check_next() computing the next URL from pagination headers.""" + + def _make_paginated_response(self, total, limit, page, url=None): + """Helper: build an APIResponse with pagination headers.""" + url = url or "https://api.mist.com/api/v1/sites" + headers = { + "X-Page-Total": str(total), + "X-Page-Limit": str(limit), + "X-Page-Page": str(page), + } + mock = _make_mock_response(data={"results": []}, headers=headers) + return APIResponse(response=mock, url=url) + + def test_next_page_computed_from_headers(self): + """When there are more pages, next should be computed from headers.""" + resp = self._make_paginated_response(total=50, limit=10, page=1) + + assert resp.next == "/api/v1/sites?page=2" + + def test_next_page_with_existing_query_string(self): + """When URL already has a query string, page should use '&'.""" + resp = self._make_paginated_response( + total=50, limit=10, page=1, + url="https://api.mist.com/api/v1/sites?limit=10" + ) + + assert resp.next == "/api/v1/sites?limit=10&page=2" + + def test_last_page_next_is_none(self): + """On the last page (limit*page >= total), next should remain None.""" + resp = self._make_paginated_response(total=30, limit=10, page=3) + + assert resp.next is None + + def test_beyond_last_page_next_is_none(self): + """When limit*page > total, next should remain None.""" + resp = self._make_paginated_response(total=25, limit=10, page=3) + + assert resp.next is None + + def test_single_page_result(self): + """When all results fit in one page, next stays None.""" + resp = self._make_paginated_response(total=5, limit=10, page=1) + + assert resp.next is None + + def test_existing_page_param_replaced(self): + """When URL already contains page=N, it should be replaced.""" + resp = self._make_paginated_response( + total=100, limit=10, page=2, + url="https://api.mist.com/api/v1/sites?limit=10&page=2" + ) + + assert resp.next == "/api/v1/sites?limit=10&page=3" + + def test_existing_page_param_first_page(self): + """Replacing page=1 with page=2 when page param already in URL.""" + resp = self._make_paginated_response( + total=100, limit=10, page=1, + url="https://api.mist.com/api/v1/sites?page=1&limit=10" + ) + + assert resp.next == "/api/v1/sites?page=2&limit=10" + + def test_missing_total_header(self): + """When X-Page-Total is missing, next should remain None.""" + headers = { + "X-Page-Limit": "10", + "X-Page-Page": "1", + } + mock = _make_mock_response(data={"results": []}, headers=headers) + resp = APIResponse(response=mock, url="https://host/api/v1/items") + + assert resp.next is None + + def test_missing_limit_header(self): + """When X-Page-Limit is missing, next should remain None.""" + headers = { + "X-Page-Total": "50", + "X-Page-Page": "1", + } + mock = _make_mock_response(data={"results": []}, headers=headers) + resp = APIResponse(response=mock, url="https://host/api/v1/items") + + assert resp.next is None + + def test_missing_page_header(self): + """When X-Page-Page is missing, next should remain None.""" + headers = { + "X-Page-Total": "50", + "X-Page-Limit": "10", + } + mock = _make_mock_response(data={"results": []}, headers=headers) + resp = APIResponse(response=mock, url="https://host/api/v1/items") + + assert resp.next is None + + def test_non_numeric_headers_handled(self): + """Non-numeric pagination header values should not raise.""" + headers = { + "X-Page-Total": "abc", + "X-Page-Limit": "10", + "X-Page-Page": "1", + } + mock = _make_mock_response(data={"results": []}, headers=headers) + resp = APIResponse(response=mock, url="https://host/api/v1/items") + + assert resp.next is None + + def test_no_headers_at_all(self, api_response_factory): + """When there are no pagination headers and no 'next' in data, next is None.""" + resp = api_response_factory(data={"results": []}) + + assert resp.next is None + + def test_pagination_strips_host_prefix(self): + """The computed next URL should be a relative /api/... path, not absolute.""" + resp = self._make_paginated_response( + total=100, limit=10, page=1, + url="https://api.eu.mist.com/api/v1/orgs/abc/devices" + ) + + assert resp.next.startswith("/api/v1/") + assert "api.eu.mist.com" not in resp.next + assert resp.next == "/api/v1/orgs/abc/devices?page=2" + + +# --------------------------------------------------------------------------- +# Tests: data types preserved +# --------------------------------------------------------------------------- + +class TestDataTypes: + """Verify that different JSON response shapes are handled correctly.""" + + def test_list_response(self): + """An API that returns a JSON list should store it in data.""" + data = [{"id": "a"}, {"id": "b"}] + mock = _make_mock_response(data=data) + resp = APIResponse(response=mock, url="https://host/api/v1/x") + + assert resp.data == data + # When data is a list, there is no 'next' key lookup issue + assert resp.next is None + + def test_empty_dict_response(self, api_response_factory): + """An empty dict body should result in data=={} and next==None.""" + resp = api_response_factory(data={}) + + assert resp.data == {} + assert resp.next is None diff --git a/tests/unit/test_api_session.py b/tests/unit/test_api_session.py index f3eed30..50716d2 100644 --- a/tests/unit/test_api_session.py +++ b/tests/unit/test_api_session.py @@ -403,3 +403,155 @@ def test_environment_variable_type_handling(self) -> None: # Assert - Should convert string to int assert session._console_log_level == 30 # int, not '30' assert session._logging_log_level == 20 # int, not '20' + + +class TestNewSession: + """Test _new_session() method""" + + def test_new_session_returns_session_with_headers(self, authenticated_session) -> None: + """_new_session creates a requests.Session with correct Accept header""" + with patch("mistapi.__api_session.requests.session") as mock_session_cls: + mock_sess = Mock() + mock_sess.headers = {} + mock_sess.proxies = {} + mock_session_cls.return_value = mock_sess + + result = authenticated_session._new_session() + + assert result.headers["Accept"] == "application/json, application/vnd.api+json" + + def test_new_session_sets_auth_header(self, authenticated_session) -> None: + """_new_session includes Authorization header when API token is configured""" + with patch("mistapi.__api_session.requests.session") as mock_session_cls: + mock_sess = Mock() + mock_sess.headers = {} + mock_sess.proxies = {} + mock_session_cls.return_value = mock_sess + + result = authenticated_session._new_session() + + expected_token = authenticated_session._apitoken[authenticated_session._apitoken_index] + assert result.headers["Authorization"] == f"Token {expected_token}" + + def test_new_session_sets_proxies(self) -> None: + """_new_session applies proxies when configured""" + with patch.dict(os.environ, {}, clear=True): + with patch("mistapi.__api_session.requests.session") as mock_session_cls: + mock_sess = Mock() + mock_sess.headers = {} + mock_sess.proxies = {} + mock_session_cls.return_value = mock_sess + + session = APISession( + console_log_level=50, https_proxy="http://proxy:8080" + ) + session._apitoken = [] + + # Create new session - use a real dict for proxies so update works + new_mock = Mock() + new_mock.headers = {} + new_mock.proxies = {} + mock_session_cls.return_value = new_mock + + result = session._new_session() + assert result.proxies == {"https": "http://proxy:8080"} + + def test_new_session_no_auth_without_token(self, isolated_session) -> None: + """_new_session omits Authorization when no token configured""" + with patch("mistapi.__api_session.requests.session") as mock_session_cls: + mock_sess = Mock() + mock_sess.headers = {} + mock_sess.proxies = {} + mock_session_cls.return_value = mock_sess + + result = isolated_session._new_session() + + assert "Authorization" not in result.headers + + +class TestSetApiTokenValidation: + """Test set_api_token with validate=False""" + + def test_set_api_token_no_validate(self, isolated_session) -> None: + """set_api_token(validate=False) accepts tokens without calling _check_api_tokens""" + isolated_session.set_cloud("api.mist.com") + with patch.object( + isolated_session, "_check_api_tokens" + ) as mock_check: + isolated_session.set_api_token("token_abc_123", validate=False) + + mock_check.assert_not_called() + assert isolated_session._apitoken == ["token_abc_123"] + assert isolated_session._apitoken_index == 0 + + def test_set_api_token_validate_true_calls_check(self, isolated_session) -> None: + """set_api_token(validate=True) calls _check_api_tokens""" + isolated_session.set_cloud("api.mist.com") + with patch.object( + isolated_session, "_check_api_tokens", return_value=["token_abc"] + ) as mock_check: + isolated_session.set_api_token("token_abc", validate=True) + + mock_check.assert_called_once_with(["token_abc"]) + + +class TestDeleteApiToken: + """Test delete_api_token method""" + + def test_delete_api_token_calls_mist_delete(self, authenticated_session) -> None: + """delete_api_token delegates to mist_delete with correct URI""" + with patch.object(authenticated_session, "mist_delete") as mock_delete: + mock_resp = Mock() + mock_delete.return_value = mock_resp + + result = authenticated_session.delete_api_token("token-id-123") + + mock_delete.assert_called_once_with( + "/api/v1/self/apitokens/token-id-123" + ) + assert result is mock_resp + + +class TestVaultAttrsCleanup: + """Test that vault attributes are cleaned up after init""" + + def test_no_vault_attrs_without_vault(self, isolated_session) -> None: + """Vault attributes are deleted when vault_path is not set""" + assert not hasattr(isolated_session, "_vault_url") + assert not hasattr(isolated_session, "_vault_token") + assert not hasattr(isolated_session, "_vault_path") + assert not hasattr(isolated_session, "_vault_mount_point") + + +class TestLoadEnvVault: + """Test _load_env vault variable loading""" + + def test_load_env_vault_vars(self) -> None: + """_load_env populates _vault_* from env when not already set""" + env_vars = { + "MIST_VAULT_URL": "https://vault.example.com", + "MIST_VAULT_PATH": "secret/data/mist", + "MIST_VAULT_MOUNT_POINT": "kv", + "MIST_VAULT_TOKEN": "vault-token-123", + } + with patch.dict(os.environ, env_vars, clear=True): + with patch("mistapi.__api_session.requests.session") as mock_cls: + mock_sess = Mock() + mock_sess.headers = {} + mock_sess.proxies = {} + mock_cls.return_value = mock_sess + with patch("mistapi.__api_session.hvac") as mock_hvac: + mock_client = Mock() + mock_client.is_authenticated.return_value = True + mock_client.secrets.kv.v2.read_secret.return_value = { + "data": {"data": {}} + } + mock_hvac.Client.return_value = mock_client + + session = APISession(console_log_level=50) + + # Vault was loaded since _vault_path was set from env + mock_hvac.Client.assert_called_once_with( + url="https://vault.example.com", + token="vault-token-123", + ) diff --git a/tests/unit/test_init.py b/tests/unit/test_init.py new file mode 100644 index 0000000..b8d4707 --- /dev/null +++ b/tests/unit/test_init.py @@ -0,0 +1,288 @@ +# tests/unit/test_init.py +""" +Unit tests for mistapi.__init__ module. + +Tests verify: +- Direct imports (APISession, get_all, get_next, __version__, __author__) +- Lazy subpackage loading (api, utils, websockets, cli) +- AttributeError for unknown attributes +""" + +import importlib +import types +from unittest.mock import patch + +import pytest + + +class TestDirectImports: + """Test that directly-imported names are available on the mistapi module.""" + + def test_apisession_is_available(self): + """APISession should be importable from mistapi.""" + from mistapi import APISession + + assert APISession is not None + + def test_apisession_is_the_real_class(self): + """APISession should be the class from mistapi.__api_session.""" + from mistapi import APISession + from mistapi.__api_session import APISession as RealAPISession + + assert APISession is RealAPISession + + def test_get_all_is_available(self): + """get_all should be importable from mistapi.""" + from mistapi import get_all + + assert callable(get_all) + + def test_get_all_is_the_real_function(self): + """get_all should be the function from mistapi.__pagination.""" + from mistapi import get_all + from mistapi.__pagination import get_all as real_get_all + + assert get_all is real_get_all + + def test_get_next_is_available(self): + """get_next should be importable from mistapi.""" + from mistapi import get_next + + assert callable(get_next) + + def test_get_next_is_the_real_function(self): + """get_next should be the function from mistapi.__pagination.""" + from mistapi import get_next + from mistapi.__pagination import get_next as real_get_next + + assert get_next is real_get_next + + def test_version_is_available(self): + """__version__ should be importable from mistapi.""" + from mistapi import __version__ + + assert isinstance(__version__, str) + assert len(__version__) > 0 + + def test_version_matches_version_module(self): + """__version__ should match the value in mistapi.__version.""" + import mistapi + from mistapi.__version import __version__ as real_version + + assert mistapi.__version__ == real_version + + def test_author_is_available(self): + """__author__ should be importable from mistapi.""" + from mistapi import __author__ + + assert isinstance(__author__, str) + assert len(__author__) > 0 + + def test_author_matches_version_module(self): + """__author__ should match the value in mistapi.__version.""" + import mistapi + from mistapi.__version import __author__ as real_author + + assert mistapi.__author__ == real_author + + +class TestLazyImportApi: + """Test lazy loading of the mistapi.api subpackage.""" + + def test_api_loads_on_access(self): + """Accessing mistapi.api should trigger lazy import and return a module.""" + import mistapi + + # Remove cached attribute if present so __getattr__ fires + mistapi.__dict__.pop("api", None) + + with patch("importlib.import_module", wraps=importlib.import_module) as spy: + result = mistapi.api + spy.assert_any_call("mistapi.api") + + assert isinstance(result, types.ModuleType) + assert result.__name__ == "mistapi.api" + + def test_api_is_cached_after_first_access(self): + """After first access, mistapi.api should be cached in globals.""" + import mistapi + + # Force a fresh lazy load + mistapi.__dict__.pop("api", None) + first = mistapi.api + second = mistapi.api + + assert first is second + assert "api" in mistapi.__dict__ + + +class TestLazyImportUtils: + """Test lazy loading of the mistapi.device_utils subpackage.""" + + def test_utils_loads_on_access(self): + """Accessing mistapi.device_utils should trigger lazy import and return a module.""" + import mistapi + + mistapi.__dict__.pop("device_utils", None) + + with patch("importlib.import_module", wraps=importlib.import_module) as spy: + result = mistapi.device_utils + spy.assert_any_call("mistapi.device_utils") + + assert isinstance(result, types.ModuleType) + assert result.__name__ == "mistapi.device_utils" + + def test_utils_is_cached_after_first_access(self): + """After first access, mistapi.device_utils should be cached in globals.""" + import mistapi + + mistapi.__dict__.pop("device_utils", None) + first = mistapi.device_utils + second = mistapi.device_utils + + assert first is second + assert "device_utils" in mistapi.__dict__ + + +class TestLazyImportWebsockets: + """Test lazy loading of the mistapi.websockets subpackage.""" + + def test_websockets_loads_on_access(self): + """Accessing mistapi.websockets should trigger lazy import and return a module.""" + import mistapi + + mistapi.__dict__.pop("websockets", None) + + with patch("importlib.import_module", wraps=importlib.import_module) as spy: + result = mistapi.websockets + spy.assert_any_call("mistapi.websockets") + + assert isinstance(result, types.ModuleType) + assert result.__name__ == "mistapi.websockets" + + def test_websockets_is_cached_after_first_access(self): + """After first access, mistapi.websockets should be cached in globals.""" + import mistapi + + mistapi.__dict__.pop("websockets", None) + first = mistapi.websockets + second = mistapi.websockets + + assert first is second + assert "websockets" in mistapi.__dict__ + + +class TestLazyImportCli: + """Test lazy loading of the mistapi.cli subpackage.""" + + def test_cli_loads_on_access(self): + """Accessing mistapi.cli should trigger lazy import and return a module.""" + import mistapi + + mistapi.__dict__.pop("cli", None) + + with patch("importlib.import_module", wraps=importlib.import_module) as spy: + result = mistapi.cli + spy.assert_any_call("mistapi.cli") + + assert isinstance(result, types.ModuleType) + assert result.__name__ == "mistapi.cli" + + def test_cli_is_cached_after_first_access(self): + """After first access, mistapi.cli should be cached in globals.""" + import mistapi + + mistapi.__dict__.pop("cli", None) + first = mistapi.cli + second = mistapi.cli + + assert first is second + assert "cli" in mistapi.__dict__ + + +class TestLazyImportMechanism: + """Test the __getattr__ mechanism in general.""" + + def test_lazy_subpackages_dict_has_expected_keys(self): + """_LAZY_SUBPACKAGES should contain the expected subpackage mappings.""" + import mistapi + + expected = {"api", "cli", "websockets", "device_utils"} + assert set(mistapi._LAZY_SUBPACKAGES.keys()) == expected + + def test_lazy_subpackages_values_are_dotted_paths(self): + """Each value in _LAZY_SUBPACKAGES should be a fully-qualified module path.""" + import mistapi + + for key, value in mistapi._LAZY_SUBPACKAGES.items(): + assert value == f"mistapi.{key}" + + def test_getattr_delegates_to_importlib(self): + """__getattr__ should call importlib.import_module for known subpackages.""" + import mistapi + + sentinel = types.ModuleType("mistapi.api") + mistapi.__dict__.pop("api", None) + + with patch("importlib.import_module", return_value=sentinel) as mock_import: + result = mistapi.__getattr__("api") + + mock_import.assert_called_once_with("mistapi.api") + assert result is sentinel + + def test_getattr_caches_result_in_globals(self): + """__getattr__ should store the imported module in the package globals.""" + import mistapi + + sentinel = types.ModuleType("mistapi.device_utils") + mistapi.__dict__.pop("device_utils", None) + + with patch("importlib.import_module", return_value=sentinel): + mistapi.__getattr__("device_utils") + + assert mistapi.__dict__["device_utils"] is sentinel + + +class TestInvalidAttribute: + """Test that accessing undefined attributes raises AttributeError.""" + + def test_unknown_attribute_raises_attribute_error(self): + """Accessing a non-existent attribute should raise AttributeError.""" + import mistapi + + with pytest.raises( + AttributeError, match=r"module 'mistapi' has no attribute 'nonexistent'" + ): + mistapi.__getattr__("nonexistent") + + def test_unknown_attribute_via_getattr_builtin(self): + """getattr on an unknown name without default should raise AttributeError.""" + import mistapi + + with pytest.raises(AttributeError): + getattr(mistapi, "totally_made_up_attribute_xyz") + + def test_unknown_attribute_with_default(self): + """getattr with a default should return the default for unknown names.""" + import mistapi + + result = getattr(mistapi, "totally_made_up_attribute_xyz", "fallback") + assert result == "fallback" + + def test_error_message_includes_attribute_name(self): + """The AttributeError message should include the missing attribute name.""" + import mistapi + + with pytest.raises(AttributeError, match="'bogus_name'"): + mistapi.__getattr__("bogus_name") + + @pytest.mark.parametrize( + "name", + ["foo", "bar", "API", "Websockets", "Api", "UTILS"], + ) + def test_various_invalid_names(self, name): + """Various invalid names should all raise AttributeError (case-sensitive).""" + import mistapi + + with pytest.raises(AttributeError): + mistapi.__getattr__(name) diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py new file mode 100644 index 0000000..1f5bdef --- /dev/null +++ b/tests/unit/test_logger.py @@ -0,0 +1,316 @@ +# tests/unit/test_logger.py +""" +Unit tests for the Console logger and LogSanitizer. + +These tests cover sensitive-field redaction, log-level gating, +the LogSanitizer filter, and the _set_log_level helper. +""" + +import logging + +import pytest + +from mistapi.__logger import Console, LogSanitizer, SENSITIVE_FIELDS, logger + + +# --------------------------------------------------------------------------- +# Sanitization +# --------------------------------------------------------------------------- +class TestConsoleSanitize: + """Tests for Console.sanitize() redaction logic.""" + + def test_redacts_password_double_quotes(self) -> None: + """A plain 'password' field in double quotes is redacted.""" + c = Console() + raw = '{"password": "s3cret!"}' + result = c.sanitize(raw) + assert "s3cret!" not in result + assert '******' in result + + def test_redacts_password_single_quotes(self) -> None: + """A 'password' field wrapped in single quotes is redacted.""" + c = Console() + raw = "{'password': 'mysecret'}" + result = c.sanitize(raw) + assert "mysecret" not in result + assert '******' in result + + @pytest.mark.parametrize( + "field", + SENSITIVE_FIELDS, + ids=SENSITIVE_FIELDS, + ) + def test_redacts_every_sensitive_field(self, field: str) -> None: + """Every entry in SENSITIVE_FIELDS must be redacted.""" + c = Console() + raw = f'{{"{field}": "topSecret123"}}' + result = c.sanitize(raw) + assert "topSecret123" not in result + assert '******' in result + + def test_redacts_case_insensitively(self) -> None: + """Field matching is case-insensitive.""" + c = Console() + raw = '{"PASSWORD": "abc"}' + result = c.sanitize(raw) + assert "abc" not in result + assert '******' in result + + def test_redacts_multiple_fields_in_one_string(self) -> None: + """Multiple sensitive fields in the same string are all redacted.""" + c = Console() + raw = '{"password": "pw1", "apitoken": "tok1", "key": "k1"}' + result = c.sanitize(raw) + assert "pw1" not in result + assert "tok1" not in result + assert "k1" not in result + + def test_no_sensitive_data_unchanged(self) -> None: + """A string with no sensitive fields is returned unchanged.""" + c = Console() + raw = '{"name": "Alice", "age": "30"}' + result = c.sanitize(raw) + assert result == raw + + def test_non_string_input_dict(self) -> None: + """A dict is JSON-serialised before sanitisation.""" + c = Console() + data = {"password": "hunter2", "user": "admin"} + result = c.sanitize(data) + assert "hunter2" not in result + assert "admin" in result + assert '******' in result + + def test_non_string_input_list(self) -> None: + """A list containing a dict with sensitive data is sanitised.""" + c = Console() + data = [{"apitoken": "secret_tok"}] + result = c.sanitize(data) + assert "secret_tok" not in result + assert '******' in result + + def test_non_string_input_int(self) -> None: + """An integer is serialised and returned as-is (no sensitive data).""" + c = Console() + result = c.sanitize(42) + assert result == "42" + + def test_empty_string(self) -> None: + """An empty string returns an empty string.""" + c = Console() + assert c.sanitize("") == "" + + def test_empty_value_still_redacted(self) -> None: + """A sensitive field whose value is empty is still redacted.""" + c = Console() + raw = '{"password": ""}' + result = c.sanitize(raw) + assert '******' in result + + +# --------------------------------------------------------------------------- +# Log-level methods (print gating) +# --------------------------------------------------------------------------- +class TestConsoleLogLevelMethods: + """Each log method should print only when the level threshold is met.""" + + # Mapping: (method_name, method_threshold) + # critical prints at level <= 50 + # error prints at level <= 40 + # warning prints at level <= 30 + # info prints at level <= 20 + # debug prints at level <= 10 + LOG_METHODS = [ + ("critical", 50), + ("error", 40), + ("warning", 30), + ("info", 20), + ("debug", 10), + ] + + @pytest.mark.parametrize("method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS]) + def test_prints_when_level_equals_threshold(self, capsys, method_name, threshold) -> None: + """Method prints when console level == method threshold.""" + c = Console(level=threshold) + getattr(c, method_name)("hello") + captured = capsys.readouterr() + assert "hello" in captured.out + + @pytest.mark.parametrize("method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS]) + def test_prints_when_level_below_threshold(self, capsys, method_name, threshold) -> None: + """Method prints when console level is below the method threshold.""" + c = Console(level=max(threshold - 10, 1)) + getattr(c, method_name)("below") + captured = capsys.readouterr() + assert "below" in captured.out + + @pytest.mark.parametrize("method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS]) + def test_silent_when_level_above_threshold(self, capsys, method_name, threshold) -> None: + """Method is silent when console level exceeds the method threshold.""" + c = Console(level=threshold + 10) + getattr(c, method_name)("nope") + captured = capsys.readouterr() + assert captured.out == "" + + @pytest.mark.parametrize("method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS]) + def test_silent_when_level_is_zero(self, capsys, method_name, threshold) -> None: + """No output at all when level is 0 (disabled).""" + c = Console(level=0) + getattr(c, method_name)("disabled") + captured = capsys.readouterr() + assert captured.out == "" + + def test_output_sanitised(self, capsys) -> None: + """Print output has sensitive data redacted.""" + c = Console(level=20) + c.info('{"password": "oops"}') + captured = capsys.readouterr() + assert "oops" not in captured.out + assert '******' in captured.out + + def test_output_has_bracket_prefix(self, capsys) -> None: + """All log lines are wrapped with a bracket prefix.""" + c = Console(level=10) + c.debug("test_msg") + captured = capsys.readouterr() + assert captured.out.startswith("[") + assert "test_msg" in captured.out + + +# --------------------------------------------------------------------------- +# Default level +# --------------------------------------------------------------------------- +class TestConsoleDefaults: + """Verify constructor defaults.""" + + def test_default_level_is_20(self) -> None: + c = Console() + assert c.level == 20 + + def test_custom_level(self) -> None: + c = Console(level=40) + assert c.level == 40 + + +# --------------------------------------------------------------------------- +# _set_log_level +# --------------------------------------------------------------------------- +class TestSetLogLevel: + """Tests for _set_log_level() which adjusts both console and logging levels.""" + + @pytest.fixture(autouse=True) + def _restore_logger_level(self): + """Save and restore the module-level logger level between tests.""" + original = logger.level + yield + logger.setLevel(original) + + def test_sets_console_level(self) -> None: + c = Console(level=20) + c._set_log_level(console_log_level=40, logging_log_level=10) + assert c.level == 40 + + def test_sets_logging_level(self) -> None: + c = Console(level=20) + c._set_log_level(console_log_level=20, logging_log_level=30) + assert logger.level == 30 + + def test_default_arguments(self) -> None: + c = Console(level=50) + c._set_log_level() + assert c.level == 20 + assert logger.level == 10 + + def test_set_level_to_zero_disables_console(self, capsys) -> None: + c = Console(level=20) + c._set_log_level(console_log_level=0) + c.critical("should_not_appear") + captured = capsys.readouterr() + assert captured.out == "" + + +# --------------------------------------------------------------------------- +# LogSanitizer filter +# --------------------------------------------------------------------------- +class TestLogSanitizer: + """Tests for the logging.Filter subclass that redacts sensitive data.""" + + def _make_record(self, msg: str, *args) -> logging.LogRecord: + """Create a minimal LogRecord for testing.""" + record = logging.LogRecord( + name="mistapi", + level=logging.INFO, + pathname="", + lineno=0, + msg=msg, + args=args if args else None, + exc_info=None, + ) + return record + + def test_filter_returns_true(self) -> None: + """The filter never drops records; it always returns True.""" + f = LogSanitizer() + record = self._make_record("safe message") + assert f.filter(record) is True + + def test_filter_sanitises_message(self) -> None: + """Sensitive data inside the message is replaced.""" + f = LogSanitizer() + record = self._make_record('{"password": "leak"}') + f.filter(record) + assert "leak" not in record.msg + assert '******' in record.msg + + def test_filter_clears_args(self) -> None: + """record.args is set to None after filtering.""" + f = LogSanitizer() + record = self._make_record("value is %s", "something") + assert record.args is not None + f.filter(record) + assert record.args is None + + def test_filter_handles_format_args(self) -> None: + """getMessage() expands %-formatting; filter sees the expanded string.""" + f = LogSanitizer() + record = self._make_record( + '{"apitoken": "%s"}', "my_secret_token" + ) + # Before filtering, getMessage() should expand the arg + expanded = record.getMessage() + assert "my_secret_token" in expanded + + f.filter(record) + assert "my_secret_token" not in record.msg + assert '******' in record.msg + + def test_filter_safe_message_unchanged(self) -> None: + """A record without sensitive data passes through with its message intact.""" + f = LogSanitizer() + record = self._make_record("just a normal log line") + f.filter(record) + assert record.msg == "just a normal log line" + assert record.args is None + + +# --------------------------------------------------------------------------- +# Module-level singletons +# --------------------------------------------------------------------------- +class TestModuleLevelObjects: + """The module exposes pre-built console and logger objects.""" + + def test_module_console_exists(self) -> None: + from mistapi.__logger import console as mod_console + assert isinstance(mod_console, Console) + + def test_module_logger_name(self) -> None: + assert logger.name == "mistapi" + + def test_module_logger_has_sanitizer_filter(self) -> None: + filters = logger.filters + assert any(isinstance(f, LogSanitizer) for f in filters) + + def test_module_logger_level_is_integer(self) -> None: + """Logger level is always a valid integer (may be mutated by other tests).""" + assert isinstance(logger.level, int) + assert logger.level >= 0 diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index e69de29..5c67695 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -0,0 +1,506 @@ +# tests/unit/test_models.py +""" +Unit tests for privilege model classes. + +These tests cover _Privilege and Privileges from mistapi.__models.privilege, +verifying construction, string representation, iteration, and field access. +""" + +import pytest + +from mistapi.__models.privilege import Privileges, _Privilege + + +# --------------------------------------------------------------------------- +# Constants re-declared locally so tests are self-contained within the file, +# but the sample_privileges fixture from conftest is reused for consistency. +# --------------------------------------------------------------------------- +TEST_ORG_ID = "203d3d02-dbc0-4c1b-9f41-76896a3330f4" +TEST_SITE_ID = "f5fcbee5-fbca-45b3-8bf1-1619ede87879" + + +# =================================================================== +# _Privilege tests +# =================================================================== +class TestPrivilegeCreation: + """Test _Privilege initialisation and field population""" + + def test_all_fields_populated_from_dict(self) -> None: + """All dict keys should become attributes on the _Privilege object""" + # Arrange + data = { + "scope": "org", + "role": "admin", + "org_id": TEST_ORG_ID, + "org_name": "Acme Corp", + "msp_id": "msp-id-1", + "msp_name": "MSP One", + "orggroup_ids": ["grp-1", "grp-2"], + "name": "Test Privilege", + "site_id": TEST_SITE_ID, + "sitegroup_ids": ["sg-1"], + "views": ["monitoring", "location"], + } + + # Act + priv = _Privilege(data) + + # Assert + assert priv.scope == "org" + assert priv.role == "admin" + assert priv.org_id == TEST_ORG_ID + assert priv.org_name == "Acme Corp" + assert priv.msp_id == "msp-id-1" + assert priv.msp_name == "MSP One" + assert priv.orggroup_ids == ["grp-1", "grp-2"] + assert priv.name == "Test Privilege" + assert priv.site_id == TEST_SITE_ID + assert priv.sitegroup_ids == ["sg-1"] + assert priv.views == ["monitoring", "location"] + + def test_partial_dict_sets_defaults_for_missing_fields(self) -> None: + """Fields not present in the dict should retain their defaults""" + # Arrange + data = {"scope": "org", "role": "read"} + + # Act + priv = _Privilege(data) + + # Assert — supplied values + assert priv.scope == "org" + assert priv.role == "read" + + # Assert — default values + assert priv.org_id == "" + assert priv.org_name == "" + assert priv.msp_id == "" + assert priv.msp_name == "" + assert priv.orggroup_ids == [] + assert priv.name == "" + assert priv.site_id == "" + assert priv.sitegroup_ids == [] + assert priv.views == [] + + def test_empty_dict_uses_all_defaults(self) -> None: + """An empty dict should produce a _Privilege with all default values""" + # Act + priv = _Privilege({}) + + # Assert + assert priv.scope == "" + assert priv.role == "" + assert priv.org_id == "" + assert priv.name == "" + + def test_extra_keys_become_attributes(self) -> None: + """Keys not in the predefined set should still be set as attributes""" + # Arrange + data = {"scope": "org", "custom_field": "custom_value"} + + # Act + priv = _Privilege(data) + + # Assert + assert priv.scope == "org" + assert priv.custom_field == "custom_value" + + def test_creation_from_sample_fixture(self, sample_privileges) -> None: + """Verify creation using the sample_privileges fixture data""" + # Act + priv_org = _Privilege(sample_privileges[0]) + priv_site = _Privilege(sample_privileges[1]) + + # Assert + assert priv_org.scope == "org" + assert priv_org.role == "admin" + assert priv_org.org_id == TEST_ORG_ID + assert priv_org.name == "Test Organisation" + + assert priv_site.scope == "site" + assert priv_site.role == "write" + assert priv_site.site_id == TEST_SITE_ID + assert priv_site.org_id == TEST_ORG_ID + assert priv_site.name == "Test Site" + + +class TestPrivilegeStr: + """Test _Privilege.__str__() output""" + + def test_str_includes_non_empty_fields(self) -> None: + """Non-empty string fields should appear in the output""" + # Arrange + data = { + "scope": "org", + "role": "admin", + "org_id": TEST_ORG_ID, + "name": "My Org", + } + + # Act + priv = _Privilege(data) + result = str(priv) + + # Assert + assert "scope: org" in result + assert "role: admin" in result + assert f"org_id: {TEST_ORG_ID}" in result + assert "name: My Org" in result + + def test_str_excludes_empty_string_fields(self) -> None: + """Empty string fields should not appear in the output""" + # Arrange + data = {"scope": "org", "role": "admin"} + + # Act + priv = _Privilege(data) + result = str(priv) + + # Assert — these fields are empty strings and should be absent + assert "org_id:" not in result + assert "org_name:" not in result + assert "msp_id:" not in result + assert "msp_name:" not in result + assert "site_id:" not in result + + def test_str_includes_non_empty_list_fields(self) -> None: + """Non-empty list fields (orggroup_ids, sitegroup_ids) should appear""" + # Arrange + data = { + "scope": "org", + "role": "admin", + "orggroup_ids": ["grp-1"], + "sitegroup_ids": ["sg-1"], + } + + # Act + priv = _Privilege(data) + result = str(priv) + + # Assert + assert "orggroup_ids:" in result + assert "sitegroup_ids:" in result + + def test_str_uses_crlf_separator(self) -> None: + """Each field line should end with ' \\r\\n'""" + # Arrange + data = {"scope": "org", "role": "admin"} + + # Act + priv = _Privilege(data) + result = str(priv) + + # Assert + assert "scope: org \r\n" in result + assert "role: admin \r\n" in result + + def test_str_empty_privilege(self) -> None: + """A privilege with all defaults should only contain list fields""" + # Act + priv = _Privilege({}) + result = str(priv) + + # Assert - empty strings are excluded but empty lists [] are not "" + # so orggroup_ids and sitegroup_ids will still appear + assert "scope:" not in result + assert "role:" not in result + assert "org_id:" not in result + + +class TestPrivilegeGet: + """Test _Privilege.get() method""" + + def test_get_returns_value_when_present(self) -> None: + """get() should return the attribute value when it exists and is truthy""" + # Arrange + data = {"scope": "org", "role": "admin", "org_id": TEST_ORG_ID} + priv = _Privilege(data) + + # Act & Assert + assert priv.get("scope") == "org" + assert priv.get("role") == "admin" + assert priv.get("org_id") == TEST_ORG_ID + + def test_get_returns_default_when_key_missing(self) -> None: + """get() should return the default when the attribute does not exist""" + # Arrange + priv = _Privilege({"scope": "org"}) + + # Act & Assert + assert priv.get("nonexistent") is None + assert priv.get("nonexistent", "fallback") == "fallback" + + def test_get_returns_default_when_value_is_empty_string(self) -> None: + """get() should return the default when the attribute is an empty string (falsy)""" + # Arrange + priv = _Privilege({}) # org_id defaults to "" + + # Act & Assert + assert priv.get("org_id") is None + assert priv.get("org_id", "default_org") == "default_org" + + def test_get_returns_default_when_value_is_empty_list(self) -> None: + """get() should return the default when the attribute is an empty list (falsy)""" + # Arrange + priv = _Privilege({}) # views defaults to [] + + # Act & Assert + assert priv.get("views") is None + assert priv.get("views", ["default_view"]) == ["default_view"] + + def test_get_returns_list_when_non_empty(self) -> None: + """get() should return the list value when it is non-empty""" + # Arrange + data = {"views": ["monitoring", "location"]} + priv = _Privilege(data) + + # Act & Assert + assert priv.get("views") == ["monitoring", "location"] + + def test_get_default_parameter_defaults_to_none(self) -> None: + """get() default parameter should be None when not specified""" + # Arrange + priv = _Privilege({}) + + # Act & Assert + assert priv.get("nonexistent") is None + + +# =================================================================== +# Privileges tests +# =================================================================== +class TestPrivilegesCreation: + """Test Privileges initialisation""" + + def test_creates_privilege_objects_from_list_of_dicts(self, sample_privileges) -> None: + """Privileges should wrap each dict into a _Privilege object""" + # Act + privs = Privileges(sample_privileges) + + # Assert + assert len(privs.privileges) == 2 + assert all(isinstance(p, _Privilege) for p in privs.privileges) + + def test_first_entry_matches_source_data(self, sample_privileges) -> None: + """First _Privilege should carry the values from the first dict""" + # Act + privs = Privileges(sample_privileges) + + # Assert + first = privs.privileges[0] + assert first.scope == "org" + assert first.role == "admin" + assert first.org_id == TEST_ORG_ID + assert first.name == "Test Organisation" + + def test_second_entry_matches_source_data(self, sample_privileges) -> None: + """Second _Privilege should carry the values from the second dict""" + # Act + privs = Privileges(sample_privileges) + + # Assert + second = privs.privileges[1] + assert second.scope == "site" + assert second.role == "write" + assert second.site_id == TEST_SITE_ID + assert second.org_id == TEST_ORG_ID + assert second.name == "Test Site" + + def test_empty_list_produces_empty_privileges(self) -> None: + """An empty list should result in an empty privileges list""" + # Act + privs = Privileges([]) + + # Assert + assert privs.privileges == [] + + def test_single_privilege(self) -> None: + """A single-element list should produce exactly one _Privilege""" + # Arrange + data = [{"scope": "msp", "role": "admin", "msp_id": "msp-1"}] + + # Act + privs = Privileges(data) + + # Assert + assert len(privs.privileges) == 1 + assert privs.privileges[0].scope == "msp" + assert privs.privileges[0].msp_id == "msp-1" + + +class TestPrivilegesIter: + """Test Privileges.__iter__()""" + + def test_iter_yields_all_privileges(self, sample_privileges) -> None: + """Iterating should yield every _Privilege in order""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + result = list(privs) + + # Assert + assert len(result) == 2 + assert result[0].scope == "org" + assert result[1].scope == "site" + + def test_iter_returns_privilege_instances(self, sample_privileges) -> None: + """Each iterated element should be a _Privilege instance""" + # Arrange + privs = Privileges(sample_privileges) + + # Act & Assert + for priv in privs: + assert isinstance(priv, _Privilege) + + def test_iter_empty_privileges_returns_empty_iterator(self) -> None: + """Iterating over empty Privileges should produce no elements""" + # Arrange + privs = Privileges([]) + + # Act + result = list(privs) + + # Assert + assert result == [] + + def test_iter_can_be_used_in_for_loop(self, sample_privileges) -> None: + """Privileges should work naturally in a for loop""" + # Arrange + privs = Privileges(sample_privileges) + scopes = [] + + # Act + for priv in privs: + scopes.append(priv.scope) + + # Assert + assert scopes == ["org", "site"] + + def test_iter_supports_next(self, sample_privileges) -> None: + """The iterator returned by __iter__ should support next()""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + it = iter(privs) + first = next(it) + second = next(it) + + # Assert + assert first.scope == "org" + assert second.scope == "site" + + with pytest.raises(StopIteration): + next(it) + + def test_iter_supports_generator_expression(self, sample_privileges) -> None: + """Privileges should work with generator expressions (e.g. next(...))""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + found = next((p for p in privs if p.org_id == TEST_ORG_ID), None) + + # Assert + assert found is not None + assert found.org_id == TEST_ORG_ID + + +class TestPrivilegesStr: + """Test Privileges.__str__() output""" + + def test_str_produces_tabulate_table(self, sample_privileges) -> None: + """String output should be a tabulate-formatted table""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + result = str(privs) + + # Assert — column headers should be present + assert "scope" in result + assert "role" in result + assert "name" in result + assert "site_id" in result + assert "org_name" in result + assert "org_id" in result + assert "msp_name" in result + assert "msp_id" in result + assert "views" in result + + def test_str_contains_privilege_data(self, sample_privileges) -> None: + """The table should contain actual privilege field values""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + result = str(privs) + + # Assert + assert "org" in result + assert "admin" in result + assert "Test Organisation" in result + assert TEST_ORG_ID in result + assert "site" in result + assert "write" in result + assert "Test Site" in result + assert TEST_SITE_ID in result + + def test_str_empty_privileges(self) -> None: + """An empty Privileges should produce just headers (or an empty table)""" + # Arrange + privs = Privileges([]) + + # Act + result = str(privs) + + # Assert — tabulate with an empty table and headers produces header row(s) + # At minimum it should not raise and should be a string + assert isinstance(result, str) + + def test_str_is_consistent(self, sample_privileges) -> None: + """Calling str() multiple times should produce the same result""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + result1 = str(privs) + result2 = str(privs) + + # Assert + assert result1 == result2 + + +class TestPrivilegesDisplay: + """Test Privileges.display() method""" + + def test_display_returns_same_as_str(self, sample_privileges) -> None: + """display() should return exactly the same string as __str__()""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + display_result = privs.display() + str_result = str(privs) + + # Assert + assert display_result == str_result + + def test_display_returns_string_type(self, sample_privileges) -> None: + """display() should return a str""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + result = privs.display() + + # Assert + assert isinstance(result, str) + + def test_display_empty_privileges(self) -> None: + """display() on empty Privileges should match str() on empty Privileges""" + # Arrange + privs = Privileges([]) + + # Act & Assert + assert privs.display() == str(privs) diff --git a/tests/unit/test_pagination.py b/tests/unit/test_pagination.py index e69de29..1cdfee5 100644 --- a/tests/unit/test_pagination.py +++ b/tests/unit/test_pagination.py @@ -0,0 +1,296 @@ +""" +Unit tests for mistapi.__pagination module. + +Tests the get_next() and get_all() pagination helper functions using +mocked APISession and APIResponse objects. +""" + +from unittest.mock import Mock + +from mistapi.__pagination import get_all, get_next + + +def _make_response(data, next_url=None): + """Create a mock APIResponse with the given data and next link.""" + response = Mock() + response.data = data + response.next = next_url + return response + + +class TestGetNext: + """Tests for get_next().""" + + def test_calls_mist_get_when_next_exists(self): + """get_next() should call mist_session.mist_get with response.next.""" + session = Mock() + next_url = "/api/v1/sites?page=2" + response = _make_response(data=[], next_url=next_url) + expected = _make_response(data=[{"id": "second"}]) + session.mist_get.return_value = expected + + result = get_next(session, response) + + session.mist_get.assert_called_once_with(next_url) + assert result is expected + + def test_returns_none_when_next_is_none(self): + """get_next() should return None when response.next is None.""" + session = Mock() + response = _make_response(data=[], next_url=None) + + result = get_next(session, response) + + assert result is None + session.mist_get.assert_not_called() + + def test_returns_none_when_next_is_empty_string(self): + """get_next() should return None when response.next is an empty string.""" + session = Mock() + response = _make_response(data=[], next_url="") + + result = get_next(session, response) + + assert result is None + session.mist_get.assert_not_called() + + +class TestGetAllList: + """Tests for get_all() when response.data is a list.""" + + def test_single_page_returns_data(self): + """get_all() should return the list data as-is when there is no next page.""" + session = Mock() + items = [{"id": "a"}, {"id": "b"}] + response = _make_response(data=items, next_url=None) + + result = get_all(session, response) + + assert result == items + session.mist_get.assert_not_called() + + def test_single_page_returns_copy(self): + """get_all() should return a new list, not the original reference.""" + session = Mock() + items = [{"id": "a"}] + response = _make_response(data=items, next_url=None) + + result = get_all(session, response) + + assert result == items + assert result is not items + + def test_multi_page_concatenates(self): + """get_all() should follow next links and concatenate all pages.""" + session = Mock() + + page1 = _make_response( + data=[{"id": "1"}, {"id": "2"}], + next_url="/api/v1/items?page=2", + ) + page2 = _make_response( + data=[{"id": "3"}, {"id": "4"}], + next_url="/api/v1/items?page=3", + ) + page3 = _make_response( + data=[{"id": "5"}], + next_url=None, + ) + + session.mist_get.side_effect = [page2, page3] + + result = get_all(session, page1) + + assert result == [ + {"id": "1"}, + {"id": "2"}, + {"id": "3"}, + {"id": "4"}, + {"id": "5"}, + ] + assert session.mist_get.call_count == 2 + + def test_empty_list_returns_empty(self): + """get_all() should return an empty list for an empty first page.""" + session = Mock() + response = _make_response(data=[], next_url=None) + + result = get_all(session, response) + + assert result == [] + + def test_empty_list_with_next_follows_links(self): + """get_all() should still follow next even when the first page is empty.""" + session = Mock() + + page1 = _make_response(data=[], next_url="/api/v1/items?page=2") + page2 = _make_response(data=[{"id": "1"}], next_url=None) + session.mist_get.return_value = page2 + + result = get_all(session, page1) + + assert result == [{"id": "1"}] + + +class TestGetAllDict: + """Tests for get_all() when response.data is a dict with 'results' key.""" + + def test_single_page_extracts_results(self): + """get_all() should extract and return the 'results' list from a dict response.""" + session = Mock() + items = [{"id": "a"}, {"id": "b"}] + response = _make_response( + data={"results": items, "total": 2, "limit": 100}, + next_url=None, + ) + + result = get_all(session, response) + + assert result == items + session.mist_get.assert_not_called() + + def test_single_page_returns_copy(self): + """get_all() should return a copy of results, not the original.""" + session = Mock() + items = [{"id": "a"}] + response = _make_response( + data={"results": items}, + next_url=None, + ) + + result = get_all(session, response) + + assert result == items + assert result is not items + + def test_multi_page_concatenates_results(self): + """get_all() should follow next links and concatenate results from dict responses.""" + session = Mock() + + page1 = _make_response( + data={"results": [{"id": "1"}], "total": 3, "limit": 1}, + next_url="/api/v1/items?page=2", + ) + page2 = _make_response( + data={"results": [{"id": "2"}], "total": 3, "limit": 1}, + next_url="/api/v1/items?page=3", + ) + page3 = _make_response( + data={"results": [{"id": "3"}], "total": 3, "limit": 1}, + next_url=None, + ) + + session.mist_get.side_effect = [page2, page3] + + result = get_all(session, page1) + + assert result == [{"id": "1"}, {"id": "2"}, {"id": "3"}] + assert session.mist_get.call_count == 2 + + def test_empty_results_returns_empty(self): + """get_all() should return an empty list when results is empty.""" + session = Mock() + response = _make_response( + data={"results": [], "total": 0}, + next_url=None, + ) + + result = get_all(session, response) + + assert result == [] + + def test_empty_results_with_next_follows_links(self): + """get_all() should follow next even when the first page results are empty.""" + session = Mock() + + page1 = _make_response( + data={"results": []}, + next_url="/api/v1/items?page=2", + ) + page2 = _make_response( + data={"results": [{"id": "1"}]}, + next_url=None, + ) + session.mist_get.return_value = page2 + + result = get_all(session, page1) + + assert result == [{"id": "1"}] + + +class TestGetAllEdgeCases: + """Tests for get_all() edge cases and unsupported data types.""" + + def test_dict_without_results_key_returns_empty(self): + """get_all() should return empty list when data is a dict without 'results'.""" + session = Mock() + response = _make_response( + data={"items": [1, 2, 3]}, + next_url=None, + ) + + result = get_all(session, response) + + assert result == [] + + def test_non_list_non_dict_returns_empty(self): + """get_all() should return empty list for unsupported data types.""" + session = Mock() + response = _make_response(data="some string", next_url=None) + + result = get_all(session, response) + + assert result == [] + + def test_none_data_returns_empty(self): + """get_all() should return empty list when data is None.""" + session = Mock() + response = _make_response(data=None, next_url=None) + + result = get_all(session, response) + + assert result == [] + + def test_get_next_returns_none_on_last_page(self): + """get_all() should stop when get_next returns None (no more pages).""" + session = Mock() + + # Simulate: page1 has next, page2 does not. + page1 = _make_response( + data=[{"id": "1"}], + next_url="/api/v1/items?page=2", + ) + page2 = _make_response( + data=[{"id": "2"}], + next_url=None, + ) + session.mist_get.return_value = page2 + + result = get_all(session, page1) + + assert result == [{"id": "1"}, {"id": "2"}] + session.mist_get.assert_called_once_with("/api/v1/items?page=2") + + def test_does_not_mutate_original_response(self): + """get_all() should not modify the original response object's data.""" + session = Mock() + original_items = [{"id": "1"}, {"id": "2"}] + response = _make_response(data=original_items, next_url=None) + + get_all(session, response) + + assert response.data == [{"id": "1"}, {"id": "2"}] + + def test_dict_does_not_mutate_original_results(self): + """get_all() should not mutate the original results list in a dict response.""" + session = Mock() + original_results = [{"id": "1"}] + response = _make_response( + data={"results": original_results}, + next_url=None, + ) + + result = get_all(session, response) + + assert original_results == [{"id": "1"}] + assert result is not original_results diff --git a/tests/unit/test_websocket_client.py b/tests/unit/test_websocket_client.py new file mode 100644 index 0000000..47457aa --- /dev/null +++ b/tests/unit/test_websocket_client.py @@ -0,0 +1,774 @@ +# tests/unit/test_websocket_client.py +""" +Unit tests for _MistWebsocket base class and public WebSocket channel classes. + +These tests cover URL building, authentication helpers, SSL options, +callback registration, internal handlers, connect/disconnect lifecycle, +the receive() generator, context-manager support, and the public API +surface of all channel classes (sites, orgs, location, session). +""" + +import json +import queue +import ssl +from unittest.mock import Mock, call, patch + +import pytest + +from mistapi.websockets.__ws_client import _MistWebsocket +from mistapi.websockets.location import ( + BleAssetsEvents, + ConnectedClientsEvents, + DiscoveredBleAssetsEvents, + SdkClientsEvents, + UnconnectedClientsEvents, +) +from mistapi.websockets.orgs import ( + InsightsEvents, + MxEdgesStatsEvents as OrgMxEdgesStatsEvents, + MxEdgesUpgradesEvents, +) +from mistapi.websockets.session import SessionWithUrl +from mistapi.websockets.sites import ( + ClientsStatsEvents, + DeviceCmdEvents, + DeviceStatsEvents, + DeviceUpgradesEvents, + MxEdgesStatsEvents as SiteMxEdgesStatsEvents, + PcapEvents, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_session(): + """ + Lightweight mock of APISession with the internal attributes that + _MistWebsocket accesses directly. + """ + session = Mock() + session._cloud_uri = "api.mist.com" + session._apitoken = ["test_token"] + session._apitoken_index = 0 + + # requests.Session stand-in + requests_session = Mock() + requests_session.cookies = [] + requests_session.verify = True + requests_session.cert = None + session._session = requests_session + + return session + + +@pytest.fixture +def ws_client(mock_session): + """A _MistWebsocket wired to mock_session with two channels.""" + return _MistWebsocket( + mist_session=mock_session, + channels=["/test/channel1", "/test/channel2"], + ) + + +@pytest.fixture +def single_channel_client(mock_session): + """A _MistWebsocket with a single channel and custom ping settings.""" + return _MistWebsocket( + mist_session=mock_session, + channels=["/events"], + ping_interval=15, + ping_timeout=5, + ) + + +# --------------------------------------------------------------------------- +# URL building +# --------------------------------------------------------------------------- + + +class TestBuildWsUrl: + """Tests for _build_ws_url().""" + + def test_replaces_api_with_api_ws(self, ws_client) -> None: + url = ws_client._build_ws_url() + assert url == "wss://api-ws.mist.com/api-ws/v1/stream" + + def test_eu_cloud_url(self, mock_session) -> None: + mock_session._cloud_uri = "api.eu.mist.com" + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._build_ws_url() == "wss://api-ws.eu.mist.com/api-ws/v1/stream" + + def test_gc1_cloud_url(self, mock_session) -> None: + mock_session._cloud_uri = "api.gc1.mist.com" + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._build_ws_url() == "wss://api-ws.gc1.mist.com/api-ws/v1/stream" + + +# --------------------------------------------------------------------------- +# Headers (token auth) +# --------------------------------------------------------------------------- + + +class TestGetHeaders: + """Tests for _get_headers().""" + + def test_returns_authorization_header_with_token(self, ws_client) -> None: + headers = ws_client._get_headers() + assert headers == {"Authorization": "Token test_token"} + + def test_uses_correct_token_index(self, mock_session) -> None: + mock_session._apitoken = ["token_a", "token_b"] + mock_session._apitoken_index = 1 + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_headers() == {"Authorization": "Token token_b"} + + def test_returns_empty_dict_without_token(self, mock_session) -> None: + mock_session._apitoken = [] + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_headers() == {} + + def test_returns_empty_dict_when_token_is_none(self, mock_session) -> None: + mock_session._apitoken = None + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_headers() == {} + + +# --------------------------------------------------------------------------- +# Cookies (session auth) +# --------------------------------------------------------------------------- + + +class TestGetCookie: + """Tests for _get_cookie().""" + + def test_formats_cookies_as_semicolon_pairs(self, mock_session) -> None: + cookie1 = Mock(name="csrftoken", value="abc123") + cookie1.name = "csrftoken" + cookie1.value = "abc123" + cookie2 = Mock(name="sessionid", value="xyz789") + cookie2.name = "sessionid" + cookie2.value = "xyz789" + mock_session._session.cookies = [cookie1, cookie2] + + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_cookie() == "csrftoken=abc123; sessionid=xyz789" + + def test_returns_none_when_no_cookies(self, mock_session) -> None: + mock_session._session.cookies = [] + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_cookie() is None + + def test_filters_cookies_with_cr_in_name(self, mock_session) -> None: + good = Mock() + good.name = "ok" + good.value = "val" + bad = Mock() + bad.name = "bad\rname" + bad.value = "val" + mock_session._session.cookies = [good, bad] + + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_cookie() == "ok=val" + + def test_filters_cookies_with_lf_in_name(self, mock_session) -> None: + bad = Mock() + bad.name = "bad\nname" + bad.value = "val" + mock_session._session.cookies = [bad] + + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_cookie() is None + + def test_filters_cookies_with_cr_in_value(self, mock_session) -> None: + bad = Mock() + bad.name = "name" + bad.value = "bad\rvalue" + mock_session._session.cookies = [bad] + + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_cookie() is None + + def test_filters_cookies_with_lf_in_value(self, mock_session) -> None: + good = Mock() + good.name = "safe" + good.value = "clean" + bad = Mock() + bad.name = "name" + bad.value = "bad\nvalue" + mock_session._session.cookies = [good, bad] + + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_cookie() == "safe=clean" + + def test_returns_none_when_all_cookies_filtered(self, mock_session) -> None: + bad1 = Mock() + bad1.name = "a\r" + bad1.value = "v" + bad2 = Mock() + bad2.name = "b" + bad2.value = "v\n" + mock_session._session.cookies = [bad1, bad2] + + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_cookie() is None + + +# --------------------------------------------------------------------------- +# SSL options +# --------------------------------------------------------------------------- + + +class TestBuildSslopt: + """Tests for _build_sslopt().""" + + def test_defaults_returns_empty_dict(self, ws_client) -> None: + # verify=True, cert=None => empty sslopt + assert ws_client._build_sslopt() == {} + + def test_verify_false(self, mock_session) -> None: + mock_session._session.verify = False + mock_session._session.cert = None + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._build_sslopt() == {"cert_reqs": ssl.CERT_NONE} + + def test_verify_custom_ca_path(self, mock_session) -> None: + mock_session._session.verify = "/etc/ssl/custom-ca.pem" + mock_session._session.cert = None + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._build_sslopt() == {"ca_certs": "/etc/ssl/custom-ca.pem"} + + def test_cert_as_string(self, mock_session) -> None: + mock_session._session.cert = "/path/to/client.pem" + client = _MistWebsocket(mock_session, channels=["/ch"]) + sslopt = client._build_sslopt() + assert sslopt["certfile"] == "/path/to/client.pem" + assert "keyfile" not in sslopt + + def test_cert_as_tuple(self, mock_session) -> None: + mock_session._session.cert = ("/path/cert.pem", "/path/key.pem") + client = _MistWebsocket(mock_session, channels=["/ch"]) + sslopt = client._build_sslopt() + assert sslopt["certfile"] == "/path/cert.pem" + assert sslopt["keyfile"] == "/path/key.pem" + + def test_cert_tuple_single_element(self, mock_session) -> None: + mock_session._session.cert = ("/path/cert.pem",) + client = _MistWebsocket(mock_session, channels=["/ch"]) + sslopt = client._build_sslopt() + assert sslopt["certfile"] == "/path/cert.pem" + assert "keyfile" not in sslopt + + def test_verify_false_with_cert_tuple(self, mock_session) -> None: + mock_session._session.verify = False + mock_session._session.cert = ("/path/cert.pem", "/path/key.pem") + client = _MistWebsocket(mock_session, channels=["/ch"]) + sslopt = client._build_sslopt() + assert sslopt == { + "cert_reqs": ssl.CERT_NONE, + "certfile": "/path/cert.pem", + "keyfile": "/path/key.pem", + } + + +# --------------------------------------------------------------------------- +# Callback registration +# --------------------------------------------------------------------------- + + +class TestCallbackRegistration: + """Tests for on_message / on_error / on_open / on_close setters.""" + + def test_on_message_stores_callback(self, ws_client) -> None: + cb = Mock() + ws_client.on_message(cb) + assert ws_client._on_message_cb is cb + + def test_on_error_stores_callback(self, ws_client) -> None: + cb = Mock() + ws_client.on_error(cb) + assert ws_client._on_error_cb is cb + + def test_on_open_stores_callback(self, ws_client) -> None: + cb = Mock() + ws_client.on_open(cb) + assert ws_client._on_open_cb is cb + + def test_on_close_stores_callback(self, ws_client) -> None: + cb = Mock() + ws_client.on_close(cb) + assert ws_client._on_close_cb is cb + + def test_callbacks_initially_none(self, ws_client) -> None: + assert ws_client._on_message_cb is None + assert ws_client._on_error_cb is None + assert ws_client._on_open_cb is None + assert ws_client._on_close_cb is None + + +# --------------------------------------------------------------------------- +# Internal handlers +# --------------------------------------------------------------------------- + + +class TestHandleOpen: + """Tests for _handle_open().""" + + def test_subscribes_to_each_channel(self, ws_client) -> None: + mock_ws = Mock() + ws_client._handle_open(mock_ws) + expected_calls = [ + call(json.dumps({"subscribe": "/test/channel1"})), + call(json.dumps({"subscribe": "/test/channel2"})), + ] + mock_ws.send.assert_has_calls(expected_calls) + assert mock_ws.send.call_count == 2 + + def test_sets_connected_event(self, ws_client) -> None: + mock_ws = Mock() + assert not ws_client._connected.is_set() + ws_client._handle_open(mock_ws) + assert ws_client._connected.is_set() + + def test_calls_on_open_callback(self, ws_client) -> None: + cb = Mock() + ws_client.on_open(cb) + ws_client._handle_open(Mock()) + cb.assert_called_once_with() + + def test_no_error_without_on_open_callback(self, ws_client) -> None: + ws_client._handle_open(Mock()) # Should not raise + + +class TestHandleMessage: + """Tests for _handle_message().""" + + def test_parses_valid_json_and_enqueues(self, ws_client) -> None: + payload = {"event": "device_update", "id": "abc"} + ws_client._handle_message(Mock(), json.dumps(payload)) + assert ws_client._queue.get_nowait() == payload + + def test_wraps_invalid_json_in_raw_key(self, ws_client) -> None: + ws_client._handle_message(Mock(), "not valid json {{{") + item = ws_client._queue.get_nowait() + assert item == {"raw": "not valid json {{{"} + + def test_calls_on_message_callback_with_parsed_data(self, ws_client) -> None: + cb = Mock() + ws_client.on_message(cb) + payload = {"type": "event"} + ws_client._handle_message(Mock(), json.dumps(payload)) + cb.assert_called_once_with(payload) + + def test_calls_on_message_callback_with_raw_fallback(self, ws_client) -> None: + cb = Mock() + ws_client.on_message(cb) + ws_client._handle_message(Mock(), "plain text") + cb.assert_called_once_with({"raw": "plain text"}) + + def test_no_error_without_on_message_callback(self, ws_client) -> None: + ws_client._handle_message(Mock(), '{"ok": true}') # Should not raise + + +class TestHandleError: + """Tests for _handle_error().""" + + def test_calls_on_error_callback(self, ws_client) -> None: + cb = Mock() + ws_client.on_error(cb) + exc = ConnectionError("lost connection") + ws_client._handle_error(Mock(), exc) + cb.assert_called_once_with(exc) + + def test_no_error_without_callback(self, ws_client) -> None: + ws_client._handle_error(Mock(), RuntimeError("boom")) # Should not raise + + +class TestHandleClose: + """Tests for _handle_close().""" + + def test_clears_connected_event(self, ws_client) -> None: + ws_client._connected.set() + ws_client._handle_close(Mock(), 1000, "normal closure") + assert not ws_client._connected.is_set() + + def test_puts_none_sentinel_on_queue(self, ws_client) -> None: + ws_client._handle_close(Mock(), 1000, "normal closure") + assert ws_client._queue.get_nowait() is None + + def test_calls_on_close_callback(self, ws_client) -> None: + cb = Mock() + ws_client.on_close(cb) + ws_client._handle_close(Mock(), 1001, "going away") + cb.assert_called_once_with(1001, "going away") + + def test_no_error_without_callback(self, ws_client) -> None: + ws_client._handle_close(Mock(), 1000, "") # Should not raise + + +# --------------------------------------------------------------------------- +# Connect / disconnect lifecycle +# --------------------------------------------------------------------------- + + +class TestConnect: + """Tests for connect() and disconnect().""" + + @patch("mistapi.websockets.__ws_client.websocket.WebSocketApp") + def test_connect_creates_websocket_app(self, mock_ws_cls, ws_client) -> None: + mock_ws_instance = Mock() + mock_ws_cls.return_value = mock_ws_instance + + ws_client.connect(run_in_background=False) + + mock_ws_cls.assert_called_once_with( + "wss://api-ws.mist.com/api-ws/v1/stream", + header={"Authorization": "Token test_token"}, + cookie=None, + on_open=ws_client._handle_open, + on_message=ws_client._handle_message, + on_error=ws_client._handle_error, + on_close=ws_client._handle_close, + ) + mock_ws_instance.run_forever.assert_called_once() + + @patch("mistapi.websockets.__ws_client.websocket.WebSocketApp") + def test_connect_drains_stale_queue_items(self, mock_ws_cls, ws_client) -> None: + # Pre-populate queue with stale sentinel and data + ws_client._queue.put(None) + ws_client._queue.put({"old": "data"}) + assert not ws_client._queue.empty() + + mock_ws_cls.return_value = Mock() + ws_client.connect(run_in_background=False) + + # Queue should have been drained before creating the WebSocketApp + assert ws_client._queue.empty() + + @patch("mistapi.websockets.__ws_client.websocket.WebSocketApp") + def test_connect_background_starts_thread(self, mock_ws_cls, ws_client) -> None: + mock_ws_instance = Mock() + mock_ws_cls.return_value = mock_ws_instance + + with patch("mistapi.websockets.__ws_client.threading.Thread") as mock_thread_cls: + mock_thread = Mock() + mock_thread_cls.return_value = mock_thread + + ws_client.connect(run_in_background=True) + + mock_thread_cls.assert_called_once_with( + target=ws_client._run_forever_safe, daemon=True + ) + mock_thread.start.assert_called_once() + + @patch("mistapi.websockets.__ws_client.websocket.WebSocketApp") + def test_disconnect_calls_close(self, mock_ws_cls, ws_client) -> None: + mock_ws_instance = Mock() + mock_ws_cls.return_value = mock_ws_instance + + ws_client.connect(run_in_background=False) + ws_client.disconnect() + + mock_ws_instance.close.assert_called_once() + + def test_disconnect_without_connect_is_noop(self, ws_client) -> None: + ws_client.disconnect() # Should not raise + + +# --------------------------------------------------------------------------- +# _run_forever_safe +# --------------------------------------------------------------------------- + + +class TestRunForeverSafe: + """Tests for _run_forever_safe().""" + + def test_passes_ping_and_ssl_to_run_forever(self, single_channel_client) -> None: + mock_ws = Mock() + single_channel_client._ws = mock_ws + # verify=True, cert=None => empty sslopt dict + single_channel_client._run_forever_safe() + mock_ws.run_forever.assert_called_once_with( + ping_interval=15, + ping_timeout=5, + sslopt={}, + ) + + def test_passes_sslopt_when_verify_false(self, mock_session) -> None: + mock_session._session.verify = False + mock_session._session.cert = None + client = _MistWebsocket(mock_session, channels=["/ch"]) + mock_ws = Mock() + client._ws = mock_ws + client._run_forever_safe() + mock_ws.run_forever.assert_called_once_with( + ping_interval=30, + ping_timeout=10, + sslopt={"cert_reqs": ssl.CERT_NONE}, + ) + + def test_exception_triggers_error_and_close_handlers(self, ws_client) -> None: + mock_ws = Mock() + mock_ws.run_forever.side_effect = RuntimeError("connection failed") + ws_client._ws = mock_ws + + error_cb = Mock() + close_cb = Mock() + ws_client.on_error(error_cb) + ws_client.on_close(close_cb) + + ws_client._run_forever_safe() + + error_cb.assert_called_once() + assert isinstance(error_cb.call_args[0][0], RuntimeError) + close_cb.assert_called_once_with(-1, "connection failed") + + def test_noop_when_ws_is_none(self, ws_client) -> None: + ws_client._ws = None + ws_client._run_forever_safe() # Should not raise + + +# --------------------------------------------------------------------------- +# receive() generator +# --------------------------------------------------------------------------- + + +class TestReceive: + """Tests for the receive() generator.""" + + def test_yields_queued_messages(self, ws_client) -> None: + ws_client._connected.set() + ws_client._queue.put({"event": "a"}) + ws_client._queue.put({"event": "b"}) + ws_client._queue.put(None) # sentinel + + results = list(ws_client.receive()) + assert results == [{"event": "a"}, {"event": "b"}] + + def test_returns_immediately_when_not_connected_within_timeout(self, ws_client) -> None: + # _connected is never set, so wait(timeout=10) returns False. + # Override timeout via monkey-patching for speed. + original_wait = ws_client._connected.wait + ws_client._connected.wait = lambda timeout=None: False + + results = list(ws_client.receive()) + assert results == [] + + ws_client._connected.wait = original_wait + + def test_stops_on_none_sentinel(self, ws_client) -> None: + ws_client._connected.set() + ws_client._queue.put({"first": True}) + ws_client._queue.put(None) + ws_client._queue.put({"should_not_appear": True}) + + results = list(ws_client.receive()) + assert results == [{"first": True}] + + def test_stops_when_disconnected_and_queue_empty(self, ws_client) -> None: + ws_client._connected.set() + ws_client._queue.put({"msg": 1}) + + gen = ws_client.receive() + assert next(gen) == {"msg": 1} + + # Simulate disconnect: clear connected, queue is empty + ws_client._connected.clear() + results = list(gen) + assert results == [] + + +# --------------------------------------------------------------------------- +# Context manager +# --------------------------------------------------------------------------- + + +class TestContextManager: + """Tests for __enter__ / __exit__.""" + + def test_enter_returns_self(self, ws_client) -> None: + assert ws_client.__enter__() is ws_client + + def test_exit_calls_disconnect(self, ws_client) -> None: + mock_ws = Mock() + ws_client._ws = mock_ws + ws_client.__exit__(None, None, None) + mock_ws.close.assert_called_once() + + def test_with_statement(self, mock_session) -> None: + with _MistWebsocket(mock_session, channels=["/ch"]) as client: + assert isinstance(client, _MistWebsocket) + # After exiting, disconnect should have been called (no-op here since _ws is None) + + @patch("mistapi.websockets.__ws_client.websocket.WebSocketApp") + def test_exit_disconnects_active_connection(self, mock_ws_cls, mock_session) -> None: + mock_ws_instance = Mock() + mock_ws_cls.return_value = mock_ws_instance + + with _MistWebsocket(mock_session, channels=["/ch"]) as client: + client.connect(run_in_background=False) + + mock_ws_instance.close.assert_called_once() + + +# --------------------------------------------------------------------------- +# ready() +# --------------------------------------------------------------------------- + + +class TestReady: + """Tests for ready().""" + + def test_returns_false_when_ws_is_none(self, ws_client) -> None: + assert ws_client.ready() is False + + def test_returns_true_when_ws_reports_ready(self, ws_client) -> None: + mock_ws = Mock() + mock_ws.ready.return_value = True + ws_client._ws = mock_ws + assert ws_client.ready() is True + + def test_returns_false_when_ws_not_ready(self, ws_client) -> None: + mock_ws = Mock() + mock_ws.ready.return_value = False + ws_client._ws = mock_ws + assert ws_client.ready() is False + + +# --------------------------------------------------------------------------- +# Initialisation defaults +# --------------------------------------------------------------------------- + + +class TestInit: + """Tests for __init__ defaults.""" + + def test_default_ping_interval_and_timeout(self, ws_client) -> None: + assert ws_client._ping_interval == 30 + assert ws_client._ping_timeout == 10 + + def test_custom_ping_interval_and_timeout(self, single_channel_client) -> None: + assert single_channel_client._ping_interval == 15 + assert single_channel_client._ping_timeout == 5 + + def test_queue_starts_empty(self, ws_client) -> None: + assert ws_client._queue.empty() + + def test_connected_event_starts_unset(self, ws_client) -> None: + assert not ws_client._connected.is_set() + + def test_ws_starts_none(self, ws_client) -> None: + assert ws_client._ws is None + + def test_thread_starts_none(self, ws_client) -> None: + assert ws_client._thread is None + + +# --------------------------------------------------------------------------- +# Public WebSocket channel classes +# --------------------------------------------------------------------------- + + +class TestSiteChannels: + """Tests for public site-level WebSocket channel classes.""" + + def test_clients_stats_events_channels(self, mock_session) -> None: + ws = ClientsStatsEvents(mock_session, site_ids=["s1", "s2"]) + assert ws._channels == ["/sites/s1/stats/clients", "/sites/s2/stats/clients"] + + def test_device_cmd_events_channels(self, mock_session) -> None: + ws = DeviceCmdEvents(mock_session, site_id="s1", device_ids=["d1", "d2"]) + assert ws._channels == ["/sites/s1/devices/d1/cmd", "/sites/s1/devices/d2/cmd"] + + def test_device_stats_events_channels(self, mock_session) -> None: + ws = DeviceStatsEvents(mock_session, site_ids=["s1"]) + assert ws._channels == ["/sites/s1/stats/devices"] + + def test_device_upgrades_events_channels(self, mock_session) -> None: + ws = DeviceUpgradesEvents(mock_session, site_ids=["s1"]) + assert ws._channels == ["/sites/s1/devices"] + + def test_site_mxedges_stats_events_channels(self, mock_session) -> None: + ws = SiteMxEdgesStatsEvents(mock_session, site_ids=["s1"]) + assert ws._channels == ["/sites/s1/stats/mxedges"] + + def test_pcap_events_channels(self, mock_session) -> None: + ws = PcapEvents(mock_session, site_id="s1") + assert ws._channels == ["/sites/s1/pcaps"] + + def test_custom_ping_settings(self, mock_session) -> None: + ws = DeviceStatsEvents( + mock_session, site_ids=["s1"], ping_interval=60, ping_timeout=20 + ) + assert ws._ping_interval == 60 + assert ws._ping_timeout == 20 + + def test_inherits_from_mist_websocket(self, mock_session) -> None: + ws = DeviceCmdEvents(mock_session, site_id="s1", device_ids=["d1"]) + assert isinstance(ws, _MistWebsocket) + + +class TestOrgChannels: + """Tests for public org-level WebSocket channel classes.""" + + def test_insights_events_channels(self, mock_session) -> None: + ws = InsightsEvents(mock_session, org_id="o1") + assert ws._channels == ["/orgs/o1/insights/summary"] + + def test_org_mxedges_stats_events_channels(self, mock_session) -> None: + ws = OrgMxEdgesStatsEvents(mock_session, org_id="o1") + assert ws._channels == ["/orgs/o1/stats/mxedges"] + + def test_mxedges_upgrades_events_channels(self, mock_session) -> None: + ws = MxEdgesUpgradesEvents(mock_session, org_id="o1") + assert ws._channels == ["/orgs/o1/mxedges"] + + def test_inherits_from_mist_websocket(self, mock_session) -> None: + ws = InsightsEvents(mock_session, org_id="o1") + assert isinstance(ws, _MistWebsocket) + + +class TestLocationChannels: + """Tests for public location-level WebSocket channel classes.""" + + def test_ble_assets_events_channels(self, mock_session) -> None: + ws = BleAssetsEvents(mock_session, site_id="s1", map_id=["m1", "m2"]) + assert ws._channels == [ + "/sites/s1/stats/maps/m1/assets", + "/sites/s1/stats/maps/m2/assets", + ] + + def test_connected_clients_events_channels(self, mock_session) -> None: + ws = ConnectedClientsEvents(mock_session, site_id="s1", map_id=["m1"]) + assert ws._channels == ["/sites/s1/stats/maps/m1/clients"] + + def test_sdk_clients_events_channels(self, mock_session) -> None: + ws = SdkClientsEvents(mock_session, site_id="s1", map_id=["m1"]) + assert ws._channels == ["/sites/s1/stats/maps/m1/sdkclients"] + + def test_unconnected_clients_events_channels(self, mock_session) -> None: + ws = UnconnectedClientsEvents(mock_session, site_id="s1", map_id=["m1"]) + assert ws._channels == ["/sites/s1/stats/maps/m1/unconnected_clients"] + + def test_discovered_ble_assets_events_channels(self, mock_session) -> None: + ws = DiscoveredBleAssetsEvents(mock_session, site_id="s1", map_id=["m1"]) + assert ws._channels == ["/sites/s1/stats/maps/m1/discovered_assets"] + + def test_inherits_from_mist_websocket(self, mock_session) -> None: + ws = BleAssetsEvents(mock_session, site_id="s1", map_id=["m1"]) + assert isinstance(ws, _MistWebsocket) + + +class TestSessionChannel: + """Tests for the SessionWithUrl WebSocket channel class.""" + + def test_session_with_url_channels(self, mock_session) -> None: + ws = SessionWithUrl(mock_session, url="wss://example.com/custom") + assert ws._channels == ["wss://example.com/custom"] + + def test_inherits_from_mist_websocket(self, mock_session) -> None: + ws = SessionWithUrl(mock_session, url="wss://example.com/custom") + assert isinstance(ws, _MistWebsocket) From 843e12c4a2d9d714bc094596f1dec2ed0682b53b Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Fri, 13 Mar 2026 20:53:31 +0100 Subject: [PATCH 09/16] refactor: improve logging format in APIRequest class --- src/mistapi/__api_request.py | 55 ++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/src/mistapi/__api_request.py b/src/mistapi/__api_request.py index a5aabc5..c5b1c08 100644 --- a/src/mistapi/__api_request.py +++ b/src/mistapi/__api_request.py @@ -88,9 +88,9 @@ def _log_proxy(self) -> None: def _next_apitoken(self) -> None: logger.info("apirequest:_next_apitoken:rotating API Token") logger.debug( - f"apirequest:_next_apitoken:current API Token is " - f"{self._apitoken[self._apitoken_index][:4]}..." - f"{self._apitoken[self._apitoken_index][-4:]}" + "apirequest:_next_apitoken:current API Token is %s...%s", + self._apitoken[self._apitoken_index][:4], + self._apitoken[self._apitoken_index][-4:], ) new_index = self._apitoken_index + 1 if new_index >= len(self._apitoken): @@ -101,9 +101,9 @@ def _next_apitoken(self) -> None: {"Authorization": "Token " + self._apitoken[self._apitoken_index]} ) logger.debug( - f"apirequest:_next_apitoken:new API Token is " - f"{self._apitoken[self._apitoken_index][:4]}..." - f"{self._apitoken[self._apitoken_index][-4:]}" + "apirequest:_next_apitoken:new API Token is %s...%s", + self._apitoken[self._apitoken_index][:4], + self._apitoken[self._apitoken_index][-4:], ) else: logger.critical(" /!\\ API TOKEN CRITICAL ERROR /!\\") @@ -184,25 +184,26 @@ def _request_with_retry( proxy_failed = False for attempt in range(self._MAX_429_RETRIES + 1): try: - logger.info(f"apirequest:{method_name}:sending request to {url}") + logger.info("apirequest:%s:sending request to %s", method_name, url) self._log_proxy() resp = request_fn() logger.debug( - f"apirequest:{method_name}:request headers:{self._remove_auth_from_headers(resp)}" + "apirequest:%s:request headers:%s", method_name, self._remove_auth_from_headers(resp) ) resp.raise_for_status() break except requests.exceptions.ProxyError as e: - logger.error(f"apirequest:{method_name}:Proxy Error: {e}") + logger.error("apirequest:%s:Proxy Error: %s", method_name, e) proxy_failed = True break except requests.exceptions.ConnectionError as e: - logger.error(f"apirequest:{method_name}:Connection Error: {e}") + logger.error("apirequest:%s:Connection Error: %s", method_name, e) break except HTTPError as e: if e.response.status_code == 429 and attempt < self._MAX_429_RETRIES: logger.warning( - f"apirequest:{method_name}:HTTP 429 (attempt {attempt + 1}/{self._MAX_429_RETRIES})" + "apirequest:%s:HTTP 429 (attempt %s/%s)", + method_name, attempt + 1, self._MAX_429_RETRIES, ) try: self._next_apitoken() @@ -210,16 +211,16 @@ def _request_with_retry( pass # single token — still retry with backoff self._handle_rate_limit(e.response, attempt) continue - logger.error(f"apirequest:{method_name}:HTTP error: {e}") + logger.error("apirequest:%s:HTTP error: %s", method_name, e) if resp: logger.error( - f"apirequest:{method_name}:HTTP error description: {resp.json()}" + "apirequest:%s:HTTP error description: %s", method_name, resp.json() ) break except Exception as e: - logger.error(f"apirequest:{method_name}:error: {e}") + logger.error("apirequest:%s:error: %s", method_name, e) logger.error( - f"apirequest:{method_name}:Exception occurred", exc_info=True + "apirequest:%s:Exception occurred", method_name, exc_info=True ) break self._count += 1 @@ -261,7 +262,7 @@ def mist_post(self, uri: str, body: dict | list | None = None) -> APIResponse: """ url = self._url(uri) headers = {"Content-Type": "application/json"} - logger.debug(f"apirequest:mist_post:Request body:{body}") + logger.debug("apirequest:mist_post:Request body:%s", body) if isinstance(body, str): fn = lambda: self._session.post(url, data=body, headers=headers) else: @@ -285,7 +286,7 @@ def mist_put(self, uri: str, body: dict | None = None) -> APIResponse: """ url = self._url(uri) headers = {"Content-Type": "application/json"} - logger.debug(f"apirequest:mist_put:Request body:{body}") + logger.debug("apirequest:mist_put:Request body:%s", body) if isinstance(body, str): fn = lambda: self._session.put(url, data=body, headers=headers) else: @@ -333,19 +334,19 @@ def mist_post_file( multipart_form_data = {} url = self._url(uri) logger.debug( - f"apirequest:mist_post_file:initial multipart_form_data:{multipart_form_data}" + "apirequest:mist_post_file:initial multipart_form_data:%s", multipart_form_data ) generated_multipart_form_data: dict[str, Any] = {} for key in multipart_form_data: logger.debug( - f"apirequest:mist_post_file:" - f"multipart_form_data:{key} = {multipart_form_data[key]}" + "apirequest:mist_post_file:multipart_form_data:%s = %s", + key, multipart_form_data[key], ) if multipart_form_data[key]: try: if key in ["csv", "file"]: logger.debug( - f"apirequest:mist_post_file:reading file:{multipart_form_data[key]}" + "apirequest:mist_post_file:reading file:%s", multipart_form_data[key] ) f = open(multipart_form_data[key], "rb") generated_multipart_form_data[key] = ( @@ -360,23 +361,23 @@ def mist_post_file( ) except (OSError, json.JSONDecodeError): logger.error( - f"apirequest:mist_post_file:multipart_form_data:" - f"Unable to parse JSON object {key} " - f"with value {multipart_form_data[key]}" + "apirequest:mist_post_file:multipart_form_data:" + "Unable to parse JSON object %s with value %s", + key, multipart_form_data[key], ) logger.error( "apirequest:mist_post_file: Exception occurred", exc_info=True, ) logger.debug( - f"apirequest:mist_post_file:" - f"final multipart_form_data:{generated_multipart_form_data}" + "apirequest:mist_post_file:final multipart_form_data:%s", + generated_multipart_form_data, ) def _do_post_file(): resp = self._session.post(url, files=generated_multipart_form_data) logger.debug( - f"apirequest:mist_post_file:request body:{self.remove_file_from_body(resp)}" + "apirequest:mist_post_file:request body:%s", self.remove_file_from_body(resp) ) return resp From b84184480d10c1d793f203e991d18ab553fee616 Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Fri, 13 Mar 2026 20:58:04 +0100 Subject: [PATCH 10/16] refactor: update logging format to use %-style for consistency --- src/mistapi/__api_session.py | 10 +++--- src/mistapi/__logger.py | 59 +++++++++--------------------------- 2 files changed, 20 insertions(+), 49 deletions(-) diff --git a/src/mistapi/__api_session.py b/src/mistapi/__api_session.py index 48e62d0..c510407 100644 --- a/src/mistapi/__api_session.py +++ b/src/mistapi/__api_session.py @@ -308,7 +308,7 @@ def _load_env(self, env_file=None) -> None: os.path.expanduser("~"), env_file.replace("~/", "") ) env_file = os.path.abspath(env_file) - CONSOLE.debug(f"Loading settings from {env_file}") + CONSOLE.debug("Loading settings from %s", env_file) LOGGER.debug("apisession:_load_env:loading settings from %s", env_file) dotenv_path = Path(env_file) load_dotenv(dotenv_path=dotenv_path, override=True) @@ -389,10 +389,10 @@ def set_cloud(self, cloud_uri: str) -> None: LOGGER.debug( "apisession:set_cloud:Mist Cloud configured to %s", self._cloud_uri ) - CONSOLE.debug(f"Mist Cloud configured to {self._cloud_uri}") + CONSOLE.debug("Mist Cloud configured to %s", self._cloud_uri) else: LOGGER.error("apisession:set_cloud: %s is not valid", cloud_uri) - CONSOLE.error(f"{cloud_uri} is not valid") + CONSOLE.error("%s is not valid", cloud_uri) def get_cloud(self): """ @@ -467,7 +467,7 @@ def set_email(self, email: str | None = None) -> None: else: self.email = input("Login: ") LOGGER.info("apisession:set_email:email configured to %s", self.email) - CONSOLE.debug(f"Email configured to {self.email}") + CONSOLE.debug("Email configured to %s", self.email) def set_password(self, password: str | None = None) -> None: """ @@ -713,7 +713,7 @@ def _process_login(self, retry: bool = True) -> str | None: LOGGER.error( "apisession:_process_login:authentication failed:%s", error ) - CONSOLE.error(f"Authentication failed: {error}\r\n") + CONSOLE.error("Authentication failed: %s\r\n", error) self.email = None self._password = None LOGGER.info( diff --git a/src/mistapi/__logger.py b/src/mistapi/__logger.py index 51b9347..f69c366 100644 --- a/src/mistapi/__logger.py +++ b/src/mistapi/__logger.py @@ -150,60 +150,31 @@ def sanitize(self, data) -> str: sanitized_data = self.sensitive_pattern.sub(r'\1"******"', data_str) return sanitized_data - def critical(self, message) -> None: - """ - Docstring for critical + def _format(self, message, args) -> str: + """Apply %-style formatting if args are provided, then sanitize.""" + if args: + message = str(message) % args + return self.sanitize(message) - :param self: Description - :param message: Description - :type message: str - """ + def critical(self, message, *args) -> None: if self.level <= 50 and self.level > 0: - print(f"[{magenta('CRITICAL ')}] {self.sanitize(message)}") - - def error(self, message) -> None: - """ - Docstring for error + print(f"[{magenta('CRITICAL ')}] {self._format(message, args)}") - :param self: Description - :param message: Description - :type message: str - """ + def error(self, message, *args) -> None: if self.level <= 40 and self.level > 0: - print(f"[{red(' ERROR ')}] {self.sanitize(message)}") - - def warning(self, message) -> None: - """ - Docstring for warning + print(f"[{red(' ERROR ')}] {self._format(message, args)}") - :param self: Description - :param message: Description - :type message: str - """ + def warning(self, message, *args) -> None: if self.level <= 30 and self.level > 0: - print(f"[{yellow(' WARNING ')}] {self.sanitize(message)}") - - def info(self, message) -> None: - """ - Docstring for info + print(f"[{yellow(' WARNING ')}] {self._format(message, args)}") - :param self: Description - :param message: Description - :type message: str - """ + def info(self, message, *args) -> None: if self.level <= 20 and self.level > 0: - print(f"[{green(' INFO ')}] {self.sanitize(message)}") + print(f"[{green(' INFO ')}] {self._format(message, args)}") - def debug(self, message) -> None: - """ - Docstring for debug - - :param self: Description - :param message: Description - :type message: str - """ + def debug(self, message, *args) -> None: if self.level <= 10 and self.level > 0: - print(f"[{white('DEBUG ')}] {self.sanitize(message)}") + print(f"[{white('DEBUG ')}] {self._format(message, args)}") def _set_log_level( self, console_log_level: int = 20, logging_log_level: int = 10 From ed05bf1c1a1d79fed185a937ab9656fbe67d02b2 Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Fri, 13 Mar 2026 21:08:38 +0100 Subject: [PATCH 11/16] refactor: improve logging messages for security --- src/mistapi/__api_session.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/mistapi/__api_session.py b/src/mistapi/__api_session.py index c510407..0f4aad5 100644 --- a/src/mistapi/__api_session.py +++ b/src/mistapi/__api_session.py @@ -285,7 +285,7 @@ def _load_keyring(self, keyring_service) -> None: self.set_api_token(mist_apitoken) mist_user = keyring.get_password(keyring_service, "MIST_USER") if mist_user: - LOGGER.info("apisession:_load_keyring: MIST_USER=%s", mist_user) + LOGGER.info("apisession:_load_keyring: MIST_USER retrieved") self.set_email(mist_user) mist_password = keyring.get_password(keyring_service, "MIST_PASSWORD") if mist_password: @@ -466,8 +466,8 @@ def set_email(self, email: str | None = None) -> None: self.email = email else: self.email = input("Login: ") - LOGGER.info("apisession:set_email:email configured to %s", self.email) - CONSOLE.debug("Email configured to %s", self.email) + LOGGER.info("apisession:set_email:email configured") + CONSOLE.debug("Email configured") def set_password(self, password: str | None = None) -> None: """ @@ -573,7 +573,10 @@ def _get_api_token_data(self, apitoken) -> tuple[str | None, list | None]: data.status_code, ) CONSOLE.critical( - f"Invalid API Token {apitoken[:4]}...{apitoken[-4:]}: status code {data.status_code}\r\n" + "Invalid API Token %s...%s: status code %s\r\n", + apitoken[:4], + apitoken[-4:], + data.status_code, ) raise ValueError( f"Invalid API Token {apitoken[:4]}...{apitoken[-4:]}: status code {data.status_code}" @@ -856,7 +859,7 @@ def login_with_return( LOGGER.info("apisession:login_with_return:access authorized") return {"authenticated": True, "error": ""} else: - LOGGER.error("apisession:login_with_return:access denied: %s", resp.data) + LOGGER.error("apisession:login_with_return:access denied: status code %s", resp.status_code) return {"authenticated": False, "error": resp.data} def logout(self) -> None: @@ -1057,7 +1060,8 @@ def _two_factor_authentication(self, two_factor: str) -> bool: resp.status_code, ) CONSOLE.error( - f"2FA authentication failed with error code: {resp.status_code}\r\n" + "2FA authentication failed with error code: %s\r\n", + resp.status_code, ) return False @@ -1099,9 +1103,7 @@ def _getself(self) -> bool: print(" Authenticated ".center(80, "-")) print(f"\r\nWelcome {self.first_name} {self.last_name}!\r\n") LOGGER.info( - "apisession:_getself:account used: %s %s", - self.first_name, - self.last_name, + "apisession:_getself:account info processed successfully" ) return True elif resp.proxy_error: From 3c1dbcab0e8b79d18dae51dc7eddbf27e0bcd48c Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Fri, 13 Mar 2026 21:14:17 +0100 Subject: [PATCH 12/16] Refactor API parameter handling and improve documentation - Updated `searchOrgNacClients` and `searchSiteNacClients` to correct status options from 'session_ended' to 'session_stopped'. - Rearranged parameters in `searchOrgWanClients`, `searchSiteWanClients`, and `searchSiteWirelessClients` for consistency and clarity. - Added new parameters `cert_expiry_duration` and `psk_name` to `searchSiteNacClients` and `searchSiteWirelessClients` respectively. - Removed unused parameters and cleaned up query parameter handling in various search functions. - Enhanced logging and sanitization in `test_logger.py` to ensure sensitive data is properly redacted. - Deleted obsolete test file `test.py` to streamline the test suite. - Improved test readability and structure across multiple test files. --- mist_openapi | 2 +- src/mistapi/__api_request.py | 27 ++- src/mistapi/__api_session.py | 9 +- src/mistapi/api/v1/orgs/clients.py | 56 +++--- src/mistapi/api/v1/orgs/devices.py | 77 ++++----- src/mistapi/api/v1/orgs/mxedges.py | 12 +- src/mistapi/api/v1/orgs/nac_clients.py | 4 +- src/mistapi/api/v1/orgs/wan_clients.py | 16 +- src/mistapi/api/v1/sites/clients.py | 64 ++++--- src/mistapi/api/v1/sites/devices.py | 218 ++++++++++++------------ src/mistapi/api/v1/sites/nac_clients.py | 28 +-- src/mistapi/api/v1/sites/wan_clients.py | 16 +- src/mistapi/device_utils/ssr.py | 4 +- src/mistapi/websockets/__ws_client.py | 6 +- test.py | 46 ----- tests/unit/test_api_request.py | 9 +- tests/unit/test_api_response.py | 27 ++- tests/unit/test_api_session.py | 20 ++- tests/unit/test_logger.py | 53 +++--- tests/unit/test_models.py | 4 +- tests/unit/test_websocket_client.py | 12 +- 21 files changed, 374 insertions(+), 336 deletions(-) delete mode 100644 test.py diff --git a/mist_openapi b/mist_openapi index c0a88a3..2efa6e1 160000 --- a/mist_openapi +++ b/mist_openapi @@ -1 +1 @@ -Subproject commit c0a88a3c79e42d233ea45a92ffffd13f968a2a6b +Subproject commit 2efa6e1024bd4e3532695a6aeb6be5787b51d9f5 diff --git a/src/mistapi/__api_request.py b/src/mistapi/__api_request.py index c5b1c08..1ceb605 100644 --- a/src/mistapi/__api_request.py +++ b/src/mistapi/__api_request.py @@ -188,7 +188,9 @@ def _request_with_retry( self._log_proxy() resp = request_fn() logger.debug( - "apirequest:%s:request headers:%s", method_name, self._remove_auth_from_headers(resp) + "apirequest:%s:request headers:%s", + method_name, + self._remove_auth_from_headers(resp), ) resp.raise_for_status() break @@ -203,7 +205,9 @@ def _request_with_retry( if e.response.status_code == 429 and attempt < self._MAX_429_RETRIES: logger.warning( "apirequest:%s:HTTP 429 (attempt %s/%s)", - method_name, attempt + 1, self._MAX_429_RETRIES, + method_name, + attempt + 1, + self._MAX_429_RETRIES, ) try: self._next_apitoken() @@ -214,7 +218,9 @@ def _request_with_retry( logger.error("apirequest:%s:HTTP error: %s", method_name, e) if resp: logger.error( - "apirequest:%s:HTTP error description: %s", method_name, resp.json() + "apirequest:%s:HTTP error description: %s", + method_name, + resp.json(), ) break except Exception as e: @@ -334,19 +340,22 @@ def mist_post_file( multipart_form_data = {} url = self._url(uri) logger.debug( - "apirequest:mist_post_file:initial multipart_form_data:%s", multipart_form_data + "apirequest:mist_post_file:initial multipart_form_data:%s", + multipart_form_data, ) generated_multipart_form_data: dict[str, Any] = {} for key in multipart_form_data: logger.debug( "apirequest:mist_post_file:multipart_form_data:%s = %s", - key, multipart_form_data[key], + key, + multipart_form_data[key], ) if multipart_form_data[key]: try: if key in ["csv", "file"]: logger.debug( - "apirequest:mist_post_file:reading file:%s", multipart_form_data[key] + "apirequest:mist_post_file:reading file:%s", + multipart_form_data[key], ) f = open(multipart_form_data[key], "rb") generated_multipart_form_data[key] = ( @@ -363,7 +372,8 @@ def mist_post_file( logger.error( "apirequest:mist_post_file:multipart_form_data:" "Unable to parse JSON object %s with value %s", - key, multipart_form_data[key], + key, + multipart_form_data[key], ) logger.error( "apirequest:mist_post_file: Exception occurred", @@ -377,7 +387,8 @@ def mist_post_file( def _do_post_file(): resp = self._session.post(url, files=generated_multipart_form_data) logger.debug( - "apirequest:mist_post_file:request body:%s", self.remove_file_from_body(resp) + "apirequest:mist_post_file:request body:%s", + self.remove_file_from_body(resp), ) return resp diff --git a/src/mistapi/__api_session.py b/src/mistapi/__api_session.py index 0f4aad5..4f3fa82 100644 --- a/src/mistapi/__api_session.py +++ b/src/mistapi/__api_session.py @@ -859,7 +859,10 @@ def login_with_return( LOGGER.info("apisession:login_with_return:access authorized") return {"authenticated": True, "error": ""} else: - LOGGER.error("apisession:login_with_return:access denied: status code %s", resp.status_code) + LOGGER.error( + "apisession:login_with_return:access denied: status code %s", + resp.status_code, + ) return {"authenticated": False, "error": resp.data} def logout(self) -> None: @@ -1102,9 +1105,7 @@ def _getself(self) -> bool: print() print(" Authenticated ".center(80, "-")) print(f"\r\nWelcome {self.first_name} {self.last_name}!\r\n") - LOGGER.info( - "apisession:_getself:account info processed successfully" - ) + LOGGER.info("apisession:_getself:account info processed successfully") return True elif resp.proxy_error: LOGGER.critical("apisession:_getself:proxy not valid...") diff --git a/src/mistapi/api/v1/orgs/clients.py b/src/mistapi/api/v1/orgs/clients.py index ff81659..9e765f5 100644 --- a/src/mistapi/api/v1/orgs/clients.py +++ b/src/mistapi/api/v1/orgs/clients.py @@ -284,20 +284,20 @@ def searchOrgWirelessClients( mist_session: _APISession, org_id: str, site_id: str | None = None, - mac: str | None = None, - ip: str | None = None, - hostname: str | None = None, + ap: str | None = None, band: str | None = None, device: str | None = None, - os: str | None = None, + hostname: str | None = None, + ip: str | None = None, + mac: str | None = None, model: str | None = None, - ap: str | None = None, + os: str | None = None, psk_id: str | None = None, psk_name: str | None = None, - username: str | None = None, - vlan: str | None = None, ssid: str | None = None, text: str | None = None, + username: str | None = None, + vlan: str | None = None, limit: int | None = None, start: str | None = None, end: str | None = None, @@ -320,20 +320,20 @@ def searchOrgWirelessClients( QUERY PARAMS ------------ site_id : str - mac : str - ip : str - hostname : str + ap : str band : str device : str - os : str + hostname : str + ip : str + mac : str model : str - ap : str + os : str psk_id : str psk_name : str - username : str - vlan : str ssid : str text : str + username : str + vlan : str limit : int, default: 100 start : str end : str @@ -351,34 +351,34 @@ def searchOrgWirelessClients( query_params: dict[str, str] = {} if site_id: query_params["site_id"] = str(site_id) - if mac: - query_params["mac"] = str(mac) - if ip: - query_params["ip"] = str(ip) - if hostname: - query_params["hostname"] = str(hostname) + if ap: + query_params["ap"] = str(ap) if band: query_params["band"] = str(band) if device: query_params["device"] = str(device) - if os: - query_params["os"] = str(os) + if hostname: + query_params["hostname"] = str(hostname) + if ip: + query_params["ip"] = str(ip) + if mac: + query_params["mac"] = str(mac) if model: query_params["model"] = str(model) - if ap: - query_params["ap"] = str(ap) + if os: + query_params["os"] = str(os) if psk_id: query_params["psk_id"] = str(psk_id) if psk_name: query_params["psk_name"] = str(psk_name) - if username: - query_params["username"] = str(username) - if vlan: - query_params["vlan"] = str(vlan) if ssid: query_params["ssid"] = str(ssid) if text: query_params["text"] = str(text) + if username: + query_params["username"] = str(username) + if vlan: + query_params["vlan"] = str(vlan) if limit: query_params["limit"] = str(limit) if start: diff --git a/src/mistapi/api/v1/orgs/devices.py b/src/mistapi/api/v1/orgs/devices.py index ac09149..bc97744 100644 --- a/src/mistapi/api/v1/orgs/devices.py +++ b/src/mistapi/api/v1/orgs/devices.py @@ -485,17 +485,16 @@ def listOrgApsMacs( def searchOrgDevices( mist_session: _APISession, org_id: str, - band_24_bandwidth: int | None = None, band_24_channel: int | None = None, - band_24_power: int | None = None, - band_5_bandwidth: int | None = None, band_5_channel: int | None = None, - band_5_power: int | None = None, - band_6_bandwidth: int | None = None, band_6_channel: int | None = None, + band_24_bandwidth: int | None = None, + band_5_bandwidth: int | None = None, + band_6_bandwidth: int | None = None, + band_24_power: int | None = None, + band_5_power: int | None = None, band_6_power: int | None = None, - cpu: str | None = None, - clustered: str | None = None, + clustered: bool | None = None, eth0_port_speed: int | None = None, evpntopo_id: str | None = None, ext_ip: str | None = None, @@ -505,8 +504,6 @@ def searchOrgDevices( last_hostname: str | None = None, lldp_mgmt_addr: str | None = None, lldp_port_id: str | None = None, - lldp_power_allocated: int | None = None, - lldp_power_draw: int | None = None, lldp_system_desc: str | None = None, lldp_system_name: str | None = None, mac: str | None = None, @@ -518,10 +515,12 @@ def searchOrgDevices( node0_mac: str | None = None, node1_mac: str | None = None, power_constrained: bool | None = None, + radius_stats: str | None = None, site_id: str | None = None, + stats: bool | None = None, t128agent_version: str | None = None, - version: str | None = None, type: str | None = None, + version: str | None = None, limit: int | None = None, start: str | None = None, end: str | None = None, @@ -543,17 +542,16 @@ def searchOrgDevices( QUERY PARAMS ------------ - band_24_bandwidth : int band_24_channel : int - band_24_power : int - band_5_bandwidth : int band_5_channel : int - band_5_power : int - band_6_bandwidth : int band_6_channel : int + band_24_bandwidth : int + band_5_bandwidth : int + band_6_bandwidth : int + band_24_power : int + band_5_power : int band_6_power : int - cpu : str - clustered : str + clustered : bool eth0_port_speed : int evpntopo_id : str ext_ip : str @@ -563,8 +561,6 @@ def searchOrgDevices( last_hostname : str lldp_mgmt_addr : str lldp_port_id : str - lldp_power_allocated : int - lldp_power_draw : int lldp_system_desc : str lldp_system_name : str mac : str @@ -572,16 +568,19 @@ def searchOrgDevices( mxedge_id : str mxedge_ids : str mxtunnel_status : str{'down', 'up'} - If `type`==`ap`, MxTunnel status, up / down - node : str + When `type`==`ap`, MxTunnel status, up / down. + node : str{'node0', 'node1'} + When `type`==`gateway`. enum: `node0`, `node1` node0_mac : str node1_mac : str power_constrained : bool + radius_stats : str site_id : str + stats : bool t128agent_version : str - version : str type : str{'ap', 'gateway', 'switch'}, default: ap Type of device. enum: `ap`, `gateway`, `switch` + version : str limit : int, default: 100 start : str end : str @@ -597,26 +596,24 @@ def searchOrgDevices( uri = f"/api/v1/orgs/{org_id}/devices/search" query_params: dict[str, str] = {} - if band_24_bandwidth: - query_params["band_24_bandwidth"] = str(band_24_bandwidth) if band_24_channel: query_params["band_24_channel"] = str(band_24_channel) - if band_24_power: - query_params["band_24_power"] = str(band_24_power) - if band_5_bandwidth: - query_params["band_5_bandwidth"] = str(band_5_bandwidth) if band_5_channel: query_params["band_5_channel"] = str(band_5_channel) - if band_5_power: - query_params["band_5_power"] = str(band_5_power) - if band_6_bandwidth: - query_params["band_6_bandwidth"] = str(band_6_bandwidth) if band_6_channel: query_params["band_6_channel"] = str(band_6_channel) + if band_24_bandwidth: + query_params["band_24_bandwidth"] = str(band_24_bandwidth) + if band_5_bandwidth: + query_params["band_5_bandwidth"] = str(band_5_bandwidth) + if band_6_bandwidth: + query_params["band_6_bandwidth"] = str(band_6_bandwidth) + if band_24_power: + query_params["band_24_power"] = str(band_24_power) + if band_5_power: + query_params["band_5_power"] = str(band_5_power) if band_6_power: query_params["band_6_power"] = str(band_6_power) - if cpu: - query_params["cpu"] = str(cpu) if clustered: query_params["clustered"] = str(clustered) if eth0_port_speed: @@ -637,10 +634,6 @@ def searchOrgDevices( query_params["lldp_mgmt_addr"] = str(lldp_mgmt_addr) if lldp_port_id: query_params["lldp_port_id"] = str(lldp_port_id) - if lldp_power_allocated: - query_params["lldp_power_allocated"] = str(lldp_power_allocated) - if lldp_power_draw: - query_params["lldp_power_draw"] = str(lldp_power_draw) if lldp_system_desc: query_params["lldp_system_desc"] = str(lldp_system_desc) if lldp_system_name: @@ -663,14 +656,18 @@ def searchOrgDevices( query_params["node1_mac"] = str(node1_mac) if power_constrained: query_params["power_constrained"] = str(power_constrained) + if radius_stats: + query_params["radius_stats"] = str(radius_stats) if site_id: query_params["site_id"] = str(site_id) + if stats: + query_params["stats"] = str(stats) if t128agent_version: query_params["t128agent_version"] = str(t128agent_version) - if version: - query_params["version"] = str(version) if type: query_params["type"] = str(type) + if version: + query_params["version"] = str(version) if limit: query_params["limit"] = str(limit) if start: diff --git a/src/mistapi/api/v1/orgs/mxedges.py b/src/mistapi/api/v1/orgs/mxedges.py index 81986a8..4fc974f 100644 --- a/src/mistapi/api/v1/orgs/mxedges.py +++ b/src/mistapi/api/v1/orgs/mxedges.py @@ -379,12 +379,13 @@ def searchOrgMistEdgeEvents( def searchOrgMxEdges( mist_session: _APISession, org_id: str, + hostname: str | None = None, mxedge_id: str | None = None, - site_id: str | None = None, mxcluster_id: str | None = None, model: str | None = None, distro: str | None = None, tunterm_version: str | None = None, + site_id: str | None = None, stats: bool | None = None, limit: int | None = None, start: str | None = None, @@ -407,12 +408,13 @@ def searchOrgMxEdges( QUERY PARAMS ------------ + hostname : str mxedge_id : str - site_id : str mxcluster_id : str model : str distro : str tunterm_version : str + site_id : str stats : bool limit : int, default: 100 start : str @@ -429,10 +431,10 @@ def searchOrgMxEdges( uri = f"/api/v1/orgs/{org_id}/mxedges/search" query_params: dict[str, str] = {} + if hostname: + query_params["hostname"] = str(hostname) if mxedge_id: query_params["mxedge_id"] = str(mxedge_id) - if site_id: - query_params["site_id"] = str(site_id) if mxcluster_id: query_params["mxcluster_id"] = str(mxcluster_id) if model: @@ -441,6 +443,8 @@ def searchOrgMxEdges( query_params["distro"] = str(distro) if tunterm_version: query_params["tunterm_version"] = str(tunterm_version) + if site_id: + query_params["site_id"] = str(site_id) if stats: query_params["stats"] = str(stats) if limit: diff --git a/src/mistapi/api/v1/orgs/nac_clients.py b/src/mistapi/api/v1/orgs/nac_clients.py index 708ebe4..24e6515 100644 --- a/src/mistapi/api/v1/orgs/nac_clients.py +++ b/src/mistapi/api/v1/orgs/nac_clients.py @@ -417,8 +417,8 @@ def searchOrgNacClients( ingress_vlan : str os : str ssid : str - status : str{'permitted', 'session_started', 'session_ended', 'denied'} - Connection status of client i.e "permitted", "denied, "session_stared", "session_ended" + status : str{'permitted', 'session_started', 'session_stopped', 'denied'} + Connection status of client i.e "permitted", "denied, "session_started", "session_stopped" text : str timestamp : float type : str diff --git a/src/mistapi/api/v1/orgs/wan_clients.py b/src/mistapi/api/v1/orgs/wan_clients.py index 4af3281..c191c8b 100644 --- a/src/mistapi/api/v1/orgs/wan_clients.py +++ b/src/mistapi/api/v1/orgs/wan_clients.py @@ -148,12 +148,12 @@ def searchOrgWanClients( mist_session: _APISession, org_id: str, site_id: str | None = None, - mac: str | None = None, hostname: str | None = None, ip: str | None = None, - network: str | None = None, ip_src: str | None = None, + mac: str | None = None, mfg: str | None = None, + network: str | None = None, limit: int | None = None, start: str | None = None, end: str | None = None, @@ -176,12 +176,12 @@ def searchOrgWanClients( QUERY PARAMS ------------ site_id : str - mac : str hostname : str ip : str - network : str ip_src : str + mac : str mfg : str + network : str limit : int, default: 100 start : str end : str @@ -199,18 +199,18 @@ def searchOrgWanClients( query_params: dict[str, str] = {} if site_id: query_params["site_id"] = str(site_id) - if mac: - query_params["mac"] = str(mac) if hostname: query_params["hostname"] = str(hostname) if ip: query_params["ip"] = str(ip) - if network: - query_params["network"] = str(network) if ip_src: query_params["ip_src"] = str(ip_src) + if mac: + query_params["mac"] = str(mac) if mfg: query_params["mfg"] = str(mfg) + if network: + query_params["network"] = str(network) if limit: query_params["limit"] = str(limit) if start: diff --git a/src/mistapi/api/v1/sites/clients.py b/src/mistapi/api/v1/sites/clients.py index 5ec1955..8a52fcc 100644 --- a/src/mistapi/api/v1/sites/clients.py +++ b/src/mistapi/api/v1/sites/clients.py @@ -301,16 +301,20 @@ def searchSiteWirelessClientEvents( def searchSiteWirelessClients( mist_session: _APISession, site_id: str, - mac: str | None = None, - ip: str | None = None, - hostname: str | None = None, + ap: str | None = None, + band: str | None = None, device: str | None = None, - os: str | None = None, + hostname: str | None = None, + ip: str | None = None, + mac: str | None = None, model: str | None = None, - ap: str | None = None, + os: str | None = None, + psk_id: str | None = None, + psk_name: str | None = None, ssid: str | None = None, text: str | None = None, - nacrule_id: str | None = None, + username: str | None = None, + vlan: str | None = None, limit: int | None = None, start: str | None = None, end: str | None = None, @@ -332,16 +336,20 @@ def searchSiteWirelessClients( QUERY PARAMS ------------ - mac : str - ip : str - hostname : str + ap : str + band : str device : str - os : str + hostname : str + ip : str + mac : str model : str - ap : str + os : str + psk_id : str + psk_name : str ssid : str text : str - nacrule_id : str + username : str + vlan : str limit : int, default: 100 start : str end : str @@ -357,26 +365,34 @@ def searchSiteWirelessClients( uri = f"/api/v1/sites/{site_id}/clients/search" query_params: dict[str, str] = {} - if mac: - query_params["mac"] = str(mac) - if ip: - query_params["ip"] = str(ip) - if hostname: - query_params["hostname"] = str(hostname) + if ap: + query_params["ap"] = str(ap) + if band: + query_params["band"] = str(band) if device: query_params["device"] = str(device) - if os: - query_params["os"] = str(os) + if hostname: + query_params["hostname"] = str(hostname) + if ip: + query_params["ip"] = str(ip) + if mac: + query_params["mac"] = str(mac) if model: query_params["model"] = str(model) - if ap: - query_params["ap"] = str(ap) + if os: + query_params["os"] = str(os) + if psk_id: + query_params["psk_id"] = str(psk_id) + if psk_name: + query_params["psk_name"] = str(psk_name) if ssid: query_params["ssid"] = str(ssid) if text: query_params["text"] = str(text) - if nacrule_id: - query_params["nacrule_id"] = str(nacrule_id) + if username: + query_params["username"] = str(username) + if vlan: + query_params["vlan"] = str(vlan) if limit: query_params["limit"] = str(limit) if start: diff --git a/src/mistapi/api/v1/sites/devices.py b/src/mistapi/api/v1/sites/devices.py index 76d90dd..026a4eb 100644 --- a/src/mistapi/api/v1/sites/devices.py +++ b/src/mistapi/api/v1/sites/devices.py @@ -851,39 +851,41 @@ def restoreSiteMultipleDeviceBackupVersion( def searchSiteDevices( mist_session: _APISession, site_id: str, - hostname: str | None = None, - type: str | None = None, - model: str | None = None, - mac: str | None = None, - ext_ip: str | None = None, - version: str | None = None, - power_constrained: bool | None = None, - ip: str | None = None, - mxtunnel_status: str | None = None, - mxedge_id: str | None = None, - mxedge_ids: list | None = None, - last_hostname: str | None = None, - last_config_status: str | None = None, - radius_stats: str | None = None, - cpu: str | None = None, - node0_mac: str | None = None, - clustered: bool | None = None, - t128agent_version: str | None = None, - node1_mac: str | None = None, - node: str | None = None, - evpntopo_id: str | None = None, - lldp_system_name: str | None = None, - lldp_system_desc: str | None = None, - lldp_port_id: str | None = None, - lldp_mgmt_addr: str | None = None, band_24_channel: int | None = None, band_5_channel: int | None = None, band_6_channel: int | None = None, band_24_bandwidth: int | None = None, band_5_bandwidth: int | None = None, band_6_bandwidth: int | None = None, + band_24_power: int | None = None, + band_5_power: int | None = None, + band_6_power: int | None = None, + clustered: bool | None = None, eth0_port_speed: int | None = None, + evpntopo_id: str | None = None, + ext_ip: str | None = None, + hostname: str | None = None, + ip: str | None = None, + last_config_status: str | None = None, + last_hostname: str | None = None, + lldp_mgmt_addr: str | None = None, + lldp_port_id: str | None = None, + lldp_system_desc: str | None = None, + lldp_system_name: str | None = None, + mac: str | None = None, + model: str | None = None, + mxedge_id: str | None = None, + mxedge_ids: str | None = None, + mxtunnel_status: str | None = None, + node: str | None = None, + node0_mac: str | None = None, + node1_mac: str | None = None, + power_constrained: bool | None = None, + radius_stats: str | None = None, stats: bool | None = None, + t128agent_version: str | None = None, + type: str | None = None, + version: str | None = None, limit: int | None = None, start: str | None = None, end: str | None = None, @@ -906,42 +908,44 @@ def searchSiteDevices( QUERY PARAMS ------------ - hostname : str - type : str{'ap', 'gateway', 'switch'}, default: ap - model : str - mac : str - ext_ip : str - version : str - power_constrained : bool - ip : str - mxtunnel_status : str{'down', 'up'} - For APs only, MxTunnel status, up / down. - mxedge_id : str - mxedge_ids : list - For APs only, list of Mist Edge id, if AP is connecting to a Mist Edge - last_hostname : str - last_config_status : str - radius_stats : str - cpu : str - node0_mac : str - clustered : bool - t128agent_version : str - node1_mac : str - node : str{'node0', 'node1'} - For Gateways only. enum: `node0`, `node1` - evpntopo_id : str - lldp_system_name : str - lldp_system_desc : str - lldp_port_id : str - lldp_mgmt_addr : str band_24_channel : int band_5_channel : int band_6_channel : int band_24_bandwidth : int band_5_bandwidth : int band_6_bandwidth : int + band_24_power : int + band_5_power : int + band_6_power : int + clustered : bool eth0_port_speed : int + evpntopo_id : str + ext_ip : str + hostname : str + ip : str + last_config_status : str + last_hostname : str + lldp_mgmt_addr : str + lldp_port_id : str + lldp_system_desc : str + lldp_system_name : str + mac : str + model : str + mxedge_id : str + mxedge_ids : str + mxtunnel_status : str{'down', 'up'} + When `type`==`ap`, MxTunnel status, up / down. + node : str{'node0', 'node1'} + When `type`==`gateway`. enum: `node0`, `node1` + node0_mac : str + node1_mac : str + power_constrained : bool + radius_stats : str stats : bool + t128agent_version : str + type : str{'ap', 'gateway', 'switch'}, default: ap + Type of device. enum: `ap`, `gateway`, `switch` + version : str limit : int, default: 100 start : str end : str @@ -960,56 +964,6 @@ def searchSiteDevices( uri = f"/api/v1/sites/{site_id}/devices/search" query_params: dict[str, str] = {} - if hostname: - query_params["hostname"] = str(hostname) - if type: - query_params["type"] = str(type) - if model: - query_params["model"] = str(model) - if mac: - query_params["mac"] = str(mac) - if ext_ip: - query_params["ext_ip"] = str(ext_ip) - if version: - query_params["version"] = str(version) - if power_constrained: - query_params["power_constrained"] = str(power_constrained) - if ip: - query_params["ip"] = str(ip) - if mxtunnel_status: - query_params["mxtunnel_status"] = str(mxtunnel_status) - if mxedge_id: - query_params["mxedge_id"] = str(mxedge_id) - if mxedge_ids: - query_params["mxedge_ids"] = str(mxedge_ids) - if last_hostname: - query_params["last_hostname"] = str(last_hostname) - if last_config_status: - query_params["last_config_status"] = str(last_config_status) - if radius_stats: - query_params["radius_stats"] = str(radius_stats) - if cpu: - query_params["cpu"] = str(cpu) - if node0_mac: - query_params["node0_mac"] = str(node0_mac) - if clustered: - query_params["clustered"] = str(clustered) - if t128agent_version: - query_params["t128agent_version"] = str(t128agent_version) - if node1_mac: - query_params["node1_mac"] = str(node1_mac) - if node: - query_params["node"] = str(node) - if evpntopo_id: - query_params["evpntopo_id"] = str(evpntopo_id) - if lldp_system_name: - query_params["lldp_system_name"] = str(lldp_system_name) - if lldp_system_desc: - query_params["lldp_system_desc"] = str(lldp_system_desc) - if lldp_port_id: - query_params["lldp_port_id"] = str(lldp_port_id) - if lldp_mgmt_addr: - query_params["lldp_mgmt_addr"] = str(lldp_mgmt_addr) if band_24_channel: query_params["band_24_channel"] = str(band_24_channel) if band_5_channel: @@ -1022,10 +976,64 @@ def searchSiteDevices( query_params["band_5_bandwidth"] = str(band_5_bandwidth) if band_6_bandwidth: query_params["band_6_bandwidth"] = str(band_6_bandwidth) + if band_24_power: + query_params["band_24_power"] = str(band_24_power) + if band_5_power: + query_params["band_5_power"] = str(band_5_power) + if band_6_power: + query_params["band_6_power"] = str(band_6_power) + if clustered: + query_params["clustered"] = str(clustered) if eth0_port_speed: query_params["eth0_port_speed"] = str(eth0_port_speed) + if evpntopo_id: + query_params["evpntopo_id"] = str(evpntopo_id) + if ext_ip: + query_params["ext_ip"] = str(ext_ip) + if hostname: + query_params["hostname"] = str(hostname) + if ip: + query_params["ip"] = str(ip) + if last_config_status: + query_params["last_config_status"] = str(last_config_status) + if last_hostname: + query_params["last_hostname"] = str(last_hostname) + if lldp_mgmt_addr: + query_params["lldp_mgmt_addr"] = str(lldp_mgmt_addr) + if lldp_port_id: + query_params["lldp_port_id"] = str(lldp_port_id) + if lldp_system_desc: + query_params["lldp_system_desc"] = str(lldp_system_desc) + if lldp_system_name: + query_params["lldp_system_name"] = str(lldp_system_name) + if mac: + query_params["mac"] = str(mac) + if model: + query_params["model"] = str(model) + if mxedge_id: + query_params["mxedge_id"] = str(mxedge_id) + if mxedge_ids: + query_params["mxedge_ids"] = str(mxedge_ids) + if mxtunnel_status: + query_params["mxtunnel_status"] = str(mxtunnel_status) + if node: + query_params["node"] = str(node) + if node0_mac: + query_params["node0_mac"] = str(node0_mac) + if node1_mac: + query_params["node1_mac"] = str(node1_mac) + if power_constrained: + query_params["power_constrained"] = str(power_constrained) + if radius_stats: + query_params["radius_stats"] = str(radius_stats) if stats: query_params["stats"] = str(stats) + if t128agent_version: + query_params["t128agent_version"] = str(t128agent_version) + if type: + query_params["type"] = str(type) + if version: + query_params["version"] = str(version) if limit: query_params["limit"] = str(limit) if start: @@ -2380,7 +2388,7 @@ def getSiteDeviceZtpPassword( def testSiteSsrDnsResolution( - mist_session: _APISession, site_id: str, device_id: str, body: dict | list + mist_session: _APISession, site_id: str, device_id: str ) -> _APIResponse: """ API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/utilities/wan/test-site-ssr-dns-resolution @@ -2402,7 +2410,7 @@ def testSiteSsrDnsResolution( """ uri = f"/api/v1/sites/{site_id}/devices/{device_id}/resolve_dns" - resp = mist_session.mist_post(uri=uri, body=body) + resp = mist_session.mist_post(uri=uri) return resp diff --git a/src/mistapi/api/v1/sites/nac_clients.py b/src/mistapi/api/v1/sites/nac_clients.py index 92d7bff..e5a0e81 100644 --- a/src/mistapi/api/v1/sites/nac_clients.py +++ b/src/mistapi/api/v1/sites/nac_clients.py @@ -334,6 +334,7 @@ def searchSiteNacClients( site_id: str, ap: str | None = None, auth_type: str | None = None, + cert_expiry_duration: str | None = None, edr_managed: bool | None = None, edr_provider: str | None = None, edr_status: str | None = None, @@ -341,15 +342,14 @@ def searchSiteNacClients( hostname: str | None = None, idp_id: str | None = None, mac: str | None = None, - mdm_managed: bool | None = None, mdm_compliance: str | None = None, mdm_provider: str | None = None, + mdm_managed: bool | None = None, mfg: str | None = None, model: str | None = None, - mxedge_id: str | None = None, + nacrule_name: str | None = None, nacrule_id: str | None = None, nacrule_matched: bool | None = None, - nacrule_name: str | None = None, nas_vendor: str | None = None, nas_ip: str | None = None, ingress_vlan: str | None = None, @@ -385,6 +385,7 @@ def searchSiteNacClients( ------------ ap : str auth_type : str + cert_expiry_duration : str edr_managed : bool edr_provider : str{'crowdstrike', 'sentinelone'} EDR provider of client's organization @@ -394,22 +395,21 @@ def searchSiteNacClients( hostname : str idp_id : str mac : str - mdm_managed : bool mdm_compliance : str mdm_provider : str + mdm_managed : bool mfg : str model : str - mxedge_id : str + nacrule_name : str nacrule_id : str nacrule_matched : bool - nacrule_name : str nas_vendor : str nas_ip : str ingress_vlan : str os : str ssid : str - status : str{'permitted', 'session_started', 'session_ended', 'denied'} - Connection status of client i.e "permitted", "denied, "session_ended" + status : str{'permitted', 'session_started', 'session_stopped', 'denied'} + Connection status of client i.e "permitted", "denied, "session_started", "session_stopped" text : str timestamp : float type : str @@ -436,6 +436,8 @@ def searchSiteNacClients( query_params["ap"] = str(ap) if auth_type: query_params["auth_type"] = str(auth_type) + if cert_expiry_duration: + query_params["cert_expiry_duration"] = str(cert_expiry_duration) if edr_managed: query_params["edr_managed"] = str(edr_managed) if edr_provider: @@ -450,24 +452,22 @@ def searchSiteNacClients( query_params["idp_id"] = str(idp_id) if mac: query_params["mac"] = str(mac) - if mdm_managed: - query_params["mdm_managed"] = str(mdm_managed) if mdm_compliance: query_params["mdm_compliance"] = str(mdm_compliance) if mdm_provider: query_params["mdm_provider"] = str(mdm_provider) + if mdm_managed: + query_params["mdm_managed"] = str(mdm_managed) if mfg: query_params["mfg"] = str(mfg) if model: query_params["model"] = str(model) - if mxedge_id: - query_params["mxedge_id"] = str(mxedge_id) + if nacrule_name: + query_params["nacrule_name"] = str(nacrule_name) if nacrule_id: query_params["nacrule_id"] = str(nacrule_id) if nacrule_matched: query_params["nacrule_matched"] = str(nacrule_matched) - if nacrule_name: - query_params["nacrule_name"] = str(nacrule_name) if nas_vendor: query_params["nas_vendor"] = str(nas_vendor) if nas_ip: diff --git a/src/mistapi/api/v1/sites/wan_clients.py b/src/mistapi/api/v1/sites/wan_clients.py index b3dc7f8..0b30029 100644 --- a/src/mistapi/api/v1/sites/wan_clients.py +++ b/src/mistapi/api/v1/sites/wan_clients.py @@ -147,10 +147,12 @@ def searchSiteWanClientEvents( def searchSiteWanClients( mist_session: _APISession, site_id: str, - mac: str | None = None, hostname: str | None = None, ip: str | None = None, + ip_src: str | None = None, + mac: str | None = None, mfg: str | None = None, + network: str | None = None, limit: int | None = None, start: str | None = None, end: str | None = None, @@ -172,10 +174,12 @@ def searchSiteWanClients( QUERY PARAMS ------------ - mac : str hostname : str ip : str + ip_src : str + mac : str mfg : str + network : str limit : int, default: 100 start : str end : str @@ -191,14 +195,18 @@ def searchSiteWanClients( uri = f"/api/v1/sites/{site_id}/wan_clients/search" query_params: dict[str, str] = {} - if mac: - query_params["mac"] = str(mac) if hostname: query_params["hostname"] = str(hostname) if ip: query_params["ip"] = str(ip) + if ip_src: + query_params["ip_src"] = str(ip_src) + if mac: + query_params["mac"] = str(mac) if mfg: query_params["mfg"] = str(mfg) + if network: + query_params["network"] = str(network) if limit: query_params["limit"] = str(limit) if start: diff --git a/src/mistapi/device_utils/ssr.py b/src/mistapi/device_utils/ssr.py index d68abd5..0af8df1 100644 --- a/src/mistapi/device_utils/ssr.py +++ b/src/mistapi/device_utils/ssr.py @@ -47,7 +47,9 @@ from mistapi.device_utils.__tools.routes import show as retrieveRoutes # Service Path functions -from mistapi.device_utils.__tools.service_path import show_service_path as showServicePath +from mistapi.device_utils.__tools.service_path import ( + show_service_path as showServicePath, +) __all__ = [ # Classes/Enums diff --git a/src/mistapi/websockets/__ws_client.py b/src/mistapi/websockets/__ws_client.py index cb4811d..237b056 100644 --- a/src/mistapi/websockets/__ws_client.py +++ b/src/mistapi/websockets/__ws_client.py @@ -78,8 +78,10 @@ def _get_cookie(self) -> str | None: if cookies: safe = [] for c in cookies: - has_crlf = "\r" in c.name or "\n" in c.name or ( - c.value and ("\r" in c.value or "\n" in c.value) + has_crlf = ( + "\r" in c.name + or "\n" in c.name + or (c.value and ("\r" in c.value or "\n" in c.value)) ) if has_crlf: logger.warning( diff --git a/test.py b/test.py deleted file mode 100644 index e81f846..0000000 --- a/test.py +++ /dev/null @@ -1,46 +0,0 @@ -import asyncio - -import src.mistapi as mistapi - -# APISESSION = mistapi.APISession(env_file="~/.mist_env_ld_ro", show_cli_notif=False) -# ORG_ID = "9777c1a0-6ef6-11e6-8bbf-02e208b2d34f" -# SITE_ID = "a925ea04-8393-4e0f-ab6b-209f11382cee" -# AP_ID = "00000000-0000-0000-1000-04a92439fb75" -# SWITCH_ID = "00000000-0000-0000-1000-2093390b3580" -# GATEWAY_ID = "00000000-0000-0000-1000-409ea4e60b00" - -APISESSION = mistapi.APISession(env_file="~/.mist_env_gc1", show_cli_notif=False) -ORG_ID = "8aa21779-1178-4357-b3e0-42c02b93b870" -SITE_ID = "d6fb4f96-3ba4-4cf5-8af2-a8d7b85087ac" -AP_ID = "00000000-0000-0000-1000-04a92439fb75" -SWITCH_ID = "00000000-0000-0000-1000-2093390b3580" -GATEWAY_ID = "00000000-0000-0000-1000-0200010edbca" - -APISESSION.login() - -# data = asyncio.run( -# mistapi.websockets.utils.common.bounce_ports( -# apissession=APISESSION, -# site_id=SITE_ID, -# device_id=GATEWAY_ID, -# port_ids=["ge-0/0/3"], -# ) -# ) - - -data = asyncio.run( - mistapi.websockets.utils.junos.monitor_traffic( - apissession=APISESSION, - site_id=SITE_ID, - device_id=SWITCH_ID, - ) -) -print(data.trigger_api_response.data) -print("".center(50, "-")) -if data.ws_required: - if isinstance(data.ws_data, list): - print("".join(data.ws_data)) - else: - print(data.ws_data) -else: - print("No WebSocket data available.") diff --git a/tests/unit/test_api_request.py b/tests/unit/test_api_request.py index 09fa9b6..2e93fef 100644 --- a/tests/unit/test_api_request.py +++ b/tests/unit/test_api_request.py @@ -22,6 +22,7 @@ # Helpers # --------------------------------------------------------------------------- + def _make_api_request(cloud_uri="api.mist.com", tokens=None): """Create an APIRequest with a mocked session for isolated testing.""" with patch("mistapi.__api_request.requests.session") as mock_session_cls: @@ -38,9 +39,7 @@ def _make_api_request(cloud_uri="api.mist.com", tokens=None): if tokens: req._apitoken = list(tokens) req._apitoken_index = 0 - req._session.headers.update( - {"Authorization": "Token " + tokens[0]} - ) + req._session.headers.update({"Authorization": "Token " + tokens[0]}) return req @@ -468,7 +467,9 @@ def test_429_calls_handle_rate_limit(self, mock_sleep): resp_ok = _mock_response(status_code=200) fn = Mock(side_effect=[resp_429, resp_ok]) - with patch.object(req, "_handle_rate_limit", wraps=req._handle_rate_limit) as wrapped: + with patch.object( + req, "_handle_rate_limit", wraps=req._handle_rate_limit + ) as wrapped: req._request_with_retry("test", fn, "https://example.com") wrapped.assert_called_once_with(resp_429, 0) diff --git a/tests/unit/test_api_response.py b/tests/unit/test_api_response.py index 850fc0a..1c283fa 100644 --- a/tests/unit/test_api_response.py +++ b/tests/unit/test_api_response.py @@ -25,6 +25,7 @@ # Helpers # --------------------------------------------------------------------------- + def _make_mock_response(status_code=200, data=None, headers=None, json_raises=False): """Build a mock requests.Response with the given attributes.""" mock = Mock() @@ -44,6 +45,7 @@ def _make_mock_response(status_code=200, data=None, headers=None, json_raises=Fa # Tests: construction / default fields # --------------------------------------------------------------------------- + class TestAPIResponseConstruction: """Tests for APIResponse.__init__ with different response inputs.""" @@ -104,6 +106,7 @@ def test_proxy_error_with_response(self): # Tests: error responses # --------------------------------------------------------------------------- + class TestAPIResponseErrors: """Tests for error HTTP status codes and error payloads.""" @@ -140,6 +143,7 @@ def test_non_json_response_handled_gracefully(self): # Tests: _check_next with "next" in data # --------------------------------------------------------------------------- + class TestCheckNextFromData: """Tests for _check_next() when the response body contains a 'next' key.""" @@ -175,6 +179,7 @@ def test_next_value_none_in_data(self, api_response_factory): # Tests: _check_next with pagination headers # --------------------------------------------------------------------------- + class TestCheckNextFromHeaders: """Tests for _check_next() computing the next URL from pagination headers.""" @@ -198,8 +203,7 @@ def test_next_page_computed_from_headers(self): def test_next_page_with_existing_query_string(self): """When URL already has a query string, page should use '&'.""" resp = self._make_paginated_response( - total=50, limit=10, page=1, - url="https://api.mist.com/api/v1/sites?limit=10" + total=50, limit=10, page=1, url="https://api.mist.com/api/v1/sites?limit=10" ) assert resp.next == "/api/v1/sites?limit=10&page=2" @@ -225,8 +229,10 @@ def test_single_page_result(self): def test_existing_page_param_replaced(self): """When URL already contains page=N, it should be replaced.""" resp = self._make_paginated_response( - total=100, limit=10, page=2, - url="https://api.mist.com/api/v1/sites?limit=10&page=2" + total=100, + limit=10, + page=2, + url="https://api.mist.com/api/v1/sites?limit=10&page=2", ) assert resp.next == "/api/v1/sites?limit=10&page=3" @@ -234,8 +240,10 @@ def test_existing_page_param_replaced(self): def test_existing_page_param_first_page(self): """Replacing page=1 with page=2 when page param already in URL.""" resp = self._make_paginated_response( - total=100, limit=10, page=1, - url="https://api.mist.com/api/v1/sites?page=1&limit=10" + total=100, + limit=10, + page=1, + url="https://api.mist.com/api/v1/sites?page=1&limit=10", ) assert resp.next == "/api/v1/sites?page=2&limit=10" @@ -294,8 +302,10 @@ def test_no_headers_at_all(self, api_response_factory): def test_pagination_strips_host_prefix(self): """The computed next URL should be a relative /api/... path, not absolute.""" resp = self._make_paginated_response( - total=100, limit=10, page=1, - url="https://api.eu.mist.com/api/v1/orgs/abc/devices" + total=100, + limit=10, + page=1, + url="https://api.eu.mist.com/api/v1/orgs/abc/devices", ) assert resp.next.startswith("/api/v1/") @@ -307,6 +317,7 @@ def test_pagination_strips_host_prefix(self): # Tests: data types preserved # --------------------------------------------------------------------------- + class TestDataTypes: """Verify that different JSON response shapes are handled correctly.""" diff --git a/tests/unit/test_api_session.py b/tests/unit/test_api_session.py index 50716d2..e28076c 100644 --- a/tests/unit/test_api_session.py +++ b/tests/unit/test_api_session.py @@ -408,7 +408,9 @@ def test_environment_variable_type_handling(self) -> None: class TestNewSession: """Test _new_session() method""" - def test_new_session_returns_session_with_headers(self, authenticated_session) -> None: + def test_new_session_returns_session_with_headers( + self, authenticated_session + ) -> None: """_new_session creates a requests.Session with correct Accept header""" with patch("mistapi.__api_session.requests.session") as mock_session_cls: mock_sess = Mock() @@ -418,7 +420,9 @@ def test_new_session_returns_session_with_headers(self, authenticated_session) - result = authenticated_session._new_session() - assert result.headers["Accept"] == "application/json, application/vnd.api+json" + assert ( + result.headers["Accept"] == "application/json, application/vnd.api+json" + ) def test_new_session_sets_auth_header(self, authenticated_session) -> None: """_new_session includes Authorization header when API token is configured""" @@ -430,7 +434,9 @@ def test_new_session_sets_auth_header(self, authenticated_session) -> None: result = authenticated_session._new_session() - expected_token = authenticated_session._apitoken[authenticated_session._apitoken_index] + expected_token = authenticated_session._apitoken[ + authenticated_session._apitoken_index + ] assert result.headers["Authorization"] == f"Token {expected_token}" def test_new_session_sets_proxies(self) -> None: @@ -475,9 +481,7 @@ class TestSetApiTokenValidation: def test_set_api_token_no_validate(self, isolated_session) -> None: """set_api_token(validate=False) accepts tokens without calling _check_api_tokens""" isolated_session.set_cloud("api.mist.com") - with patch.object( - isolated_session, "_check_api_tokens" - ) as mock_check: + with patch.object(isolated_session, "_check_api_tokens") as mock_check: isolated_session.set_api_token("token_abc_123", validate=False) mock_check.assert_not_called() @@ -506,9 +510,7 @@ def test_delete_api_token_calls_mist_delete(self, authenticated_session) -> None result = authenticated_session.delete_api_token("token-id-123") - mock_delete.assert_called_once_with( - "/api/v1/self/apitokens/token-id-123" - ) + mock_delete.assert_called_once_with("/api/v1/self/apitokens/token-id-123") assert result is mock_resp diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py index 1f5bdef..0c543c9 100644 --- a/tests/unit/test_logger.py +++ b/tests/unit/test_logger.py @@ -25,7 +25,7 @@ def test_redacts_password_double_quotes(self) -> None: raw = '{"password": "s3cret!"}' result = c.sanitize(raw) assert "s3cret!" not in result - assert '******' in result + assert "******" in result def test_redacts_password_single_quotes(self) -> None: """A 'password' field wrapped in single quotes is redacted.""" @@ -33,7 +33,7 @@ def test_redacts_password_single_quotes(self) -> None: raw = "{'password': 'mysecret'}" result = c.sanitize(raw) assert "mysecret" not in result - assert '******' in result + assert "******" in result @pytest.mark.parametrize( "field", @@ -46,7 +46,7 @@ def test_redacts_every_sensitive_field(self, field: str) -> None: raw = f'{{"{field}": "topSecret123"}}' result = c.sanitize(raw) assert "topSecret123" not in result - assert '******' in result + assert "******" in result def test_redacts_case_insensitively(self) -> None: """Field matching is case-insensitive.""" @@ -54,7 +54,7 @@ def test_redacts_case_insensitively(self) -> None: raw = '{"PASSWORD": "abc"}' result = c.sanitize(raw) assert "abc" not in result - assert '******' in result + assert "******" in result def test_redacts_multiple_fields_in_one_string(self) -> None: """Multiple sensitive fields in the same string are all redacted.""" @@ -79,7 +79,7 @@ def test_non_string_input_dict(self) -> None: result = c.sanitize(data) assert "hunter2" not in result assert "admin" in result - assert '******' in result + assert "******" in result def test_non_string_input_list(self) -> None: """A list containing a dict with sensitive data is sanitised.""" @@ -87,7 +87,7 @@ def test_non_string_input_list(self) -> None: data = [{"apitoken": "secret_tok"}] result = c.sanitize(data) assert "secret_tok" not in result - assert '******' in result + assert "******" in result def test_non_string_input_int(self) -> None: """An integer is serialised and returned as-is (no sensitive data).""" @@ -105,7 +105,7 @@ def test_empty_value_still_redacted(self) -> None: c = Console() raw = '{"password": ""}' result = c.sanitize(raw) - assert '******' in result + assert "******" in result # --------------------------------------------------------------------------- @@ -128,31 +128,45 @@ class TestConsoleLogLevelMethods: ("debug", 10), ] - @pytest.mark.parametrize("method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS]) - def test_prints_when_level_equals_threshold(self, capsys, method_name, threshold) -> None: + @pytest.mark.parametrize( + "method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS] + ) + def test_prints_when_level_equals_threshold( + self, capsys, method_name, threshold + ) -> None: """Method prints when console level == method threshold.""" c = Console(level=threshold) getattr(c, method_name)("hello") captured = capsys.readouterr() assert "hello" in captured.out - @pytest.mark.parametrize("method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS]) - def test_prints_when_level_below_threshold(self, capsys, method_name, threshold) -> None: + @pytest.mark.parametrize( + "method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS] + ) + def test_prints_when_level_below_threshold( + self, capsys, method_name, threshold + ) -> None: """Method prints when console level is below the method threshold.""" c = Console(level=max(threshold - 10, 1)) getattr(c, method_name)("below") captured = capsys.readouterr() assert "below" in captured.out - @pytest.mark.parametrize("method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS]) - def test_silent_when_level_above_threshold(self, capsys, method_name, threshold) -> None: + @pytest.mark.parametrize( + "method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS] + ) + def test_silent_when_level_above_threshold( + self, capsys, method_name, threshold + ) -> None: """Method is silent when console level exceeds the method threshold.""" c = Console(level=threshold + 10) getattr(c, method_name)("nope") captured = capsys.readouterr() assert captured.out == "" - @pytest.mark.parametrize("method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS]) + @pytest.mark.parametrize( + "method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS] + ) def test_silent_when_level_is_zero(self, capsys, method_name, threshold) -> None: """No output at all when level is 0 (disabled).""" c = Console(level=0) @@ -166,7 +180,7 @@ def test_output_sanitised(self, capsys) -> None: c.info('{"password": "oops"}') captured = capsys.readouterr() assert "oops" not in captured.out - assert '******' in captured.out + assert "******" in captured.out def test_output_has_bracket_prefix(self, capsys) -> None: """All log lines are wrapped with a bracket prefix.""" @@ -260,7 +274,7 @@ def test_filter_sanitises_message(self) -> None: record = self._make_record('{"password": "leak"}') f.filter(record) assert "leak" not in record.msg - assert '******' in record.msg + assert "******" in record.msg def test_filter_clears_args(self) -> None: """record.args is set to None after filtering.""" @@ -273,16 +287,14 @@ def test_filter_clears_args(self) -> None: def test_filter_handles_format_args(self) -> None: """getMessage() expands %-formatting; filter sees the expanded string.""" f = LogSanitizer() - record = self._make_record( - '{"apitoken": "%s"}', "my_secret_token" - ) + record = self._make_record('{"apitoken": "%s"}', "my_secret_token") # Before filtering, getMessage() should expand the arg expanded = record.getMessage() assert "my_secret_token" in expanded f.filter(record) assert "my_secret_token" not in record.msg - assert '******' in record.msg + assert "******" in record.msg def test_filter_safe_message_unchanged(self) -> None: """A record without sensitive data passes through with its message intact.""" @@ -301,6 +313,7 @@ class TestModuleLevelObjects: def test_module_console_exists(self) -> None: from mistapi.__logger import console as mod_console + assert isinstance(mod_console, Console) def test_module_logger_name(self) -> None: diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index 5c67695..504b51f 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -271,7 +271,9 @@ def test_get_default_parameter_defaults_to_none(self) -> None: class TestPrivilegesCreation: """Test Privileges initialisation""" - def test_creates_privilege_objects_from_list_of_dicts(self, sample_privileges) -> None: + def test_creates_privilege_objects_from_list_of_dicts( + self, sample_privileges + ) -> None: """Privileges should wrap each dict into a _Privilege object""" # Act privs = Privileges(sample_privileges) diff --git a/tests/unit/test_websocket_client.py b/tests/unit/test_websocket_client.py index 47457aa..2c444b4 100644 --- a/tests/unit/test_websocket_client.py +++ b/tests/unit/test_websocket_client.py @@ -453,7 +453,9 @@ def test_connect_background_starts_thread(self, mock_ws_cls, ws_client) -> None: mock_ws_instance = Mock() mock_ws_cls.return_value = mock_ws_instance - with patch("mistapi.websockets.__ws_client.threading.Thread") as mock_thread_cls: + with patch( + "mistapi.websockets.__ws_client.threading.Thread" + ) as mock_thread_cls: mock_thread = Mock() mock_thread_cls.return_value = mock_thread @@ -548,7 +550,9 @@ def test_yields_queued_messages(self, ws_client) -> None: results = list(ws_client.receive()) assert results == [{"event": "a"}, {"event": "b"}] - def test_returns_immediately_when_not_connected_within_timeout(self, ws_client) -> None: + def test_returns_immediately_when_not_connected_within_timeout( + self, ws_client + ) -> None: # _connected is never set, so wait(timeout=10) returns False. # Override timeout via monkey-patching for speed. original_wait = ws_client._connected.wait @@ -604,7 +608,9 @@ def test_with_statement(self, mock_session) -> None: # After exiting, disconnect should have been called (no-op here since _ws is None) @patch("mistapi.websockets.__ws_client.websocket.WebSocketApp") - def test_exit_disconnects_active_connection(self, mock_ws_cls, mock_session) -> None: + def test_exit_disconnects_active_connection( + self, mock_ws_cls, mock_session + ) -> None: mock_ws_instance = Mock() mock_ws_cls.return_value = mock_ws_instance From d988c376e59edf45a7aef566e7c7d7c14f091a9d Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Fri, 13 Mar 2026 21:28:15 +0100 Subject: [PATCH 13/16] chore: remove CLAUDE.md from branch Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 94b8344..0000000 --- a/CLAUDE.md +++ /dev/null @@ -1,7 +0,0 @@ -This is the repo for my Mist API Python client, which is a wrapper around the Mist API. It allows you to easily interact with the Mist API and perform various actions such as creating and managing devices, sites, and more. - -The code in src/mistapi/api is automatically generated from the OpenAPI specification provided by Mist. This means that the code is always up to date with the latest version of the API, and you can be confident that it will work correctly with the Mist API. -The code in src/mistapi/api is organized into different modules, each corresponding to a different aspect of the Mist API. For example, there are modules for managing devices, sites, and more. Each module contains functions that correspond to the various endpoints of the Mist API, allowing you to easily perform actions such as creating a new device, retrieving information about a site, and more. - - -The code in src/mistapi/websocket is here to provide a WebSocket client for the Mist API. This allows you to receive real-time updates from the Mist API, such as when a new device is added or when a site is updated. The WebSocket client is built using the popular websocket-client library, and it provides an easy-to-use interface for connecting to the Mist API and receiving updates. \ No newline at end of file From fb13128aba7d7d37ec45582aa08ba5f2ccfccf7b Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Fri, 13 Mar 2026 21:51:47 +0100 Subject: [PATCH 14/16] update version and changelog --- CHANGELOG.md | 154 ++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- src/mistapi/__version.py | 2 +- src/mistapi/api/v1/sites/sle.py | 107 +--------------------- src/mistapi/websockets/orgs.py | 12 +-- src/mistapi/websockets/sites.py | 82 ++++++++++++++--- uv.lock | 2 +- 7 files changed, 237 insertions(+), 124 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e240c6..fc985b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,158 @@ # CHANGELOG +## Version 0.61.0 (March 2026) + +**Released**: March 13, 2026 + +**MAJOR RELEASE** with extensive new features, code quality improvements, security enhancements, and performance optimizations. This release adds real-time WebSocket streaming, comprehensive device diagnostic utilities, extensive test coverage, and significant API improvements. + +--- + +### 1. NEW FEATURES + +#### **1.1 WebSocket Streaming Module** (`mistapi.websockets`) +Complete real-time event streaming support with flexible consumption patterns: + +**Available Channels:** +* Organization Channels + +| Class | Description | +|-------|-------------| +| `mistapi.websockets.orgs.InsightsEvents` | Real-time insights events for an organization | +| `mistapi.websockets.orgs.MxEdgesStatsEvents` | Real-time MX edges stats for an organization | +| `mistapi.websockets.orgs.MxEdgesUpgradesEvents` | Real-time MX edges upgrades events for an organization | + +* Site Channels + +| Class | Description | +|-------|-------------| +| `mistapi.websockets.sites.ClientsStatsEvents` | Real-time clients stats for a site | +| `mistapi.websockets.sites.DeviceCmdEvents` | Real-time device command events for a site | +| `mistapi.websockets.sites.DeviceStatsEvents` | Real-time device stats for a site | +| `mistapi.websockets.sites.DeviceUpgradesEvents` | Real-time device upgrades events for a site | +| `mistapi.websockets.sites.MxEdgesStatsEvents` | Real-time MX edges stats for a site | +| `mistapi.websockets.sites.PcapEvents` | Real-time PCAP events for a site | + +* Location Channels + +| Class | Description | +|-------|-------------| +| `mistapi.websockets.location.BleAssetsEvents` | Real-time BLE assets location events | +| `mistapi.websockets.location.ConnectedClientsEvents` | Real-time connected clients location events | +| `mistapi.websockets.location.SdkClientsEvents` | Real-time SDK clients location events | +| `mistapi.websockets.location.UnconnectedClientsEvents` | Real-time unconnected clients location events | +| `mistapi.websockets.location.DiscoveredBleAssetsEvents` | Real-time discovered BLE assets location events | + + +**Features:** +- Callback-based message handling +- Generator-style iteration +- Context manager support +- Automatic reconnection with configurable ping intervals +- Non-blocking background threads +- Type-safe API with full parameter validation + +**Example Usage:** +```python +ws = mistapi.websockets.sites.DeviceStatsEvents(apisession, site_ids=[""]) +ws.connect(run_in_background=True) + +for msg in ws.receive(): # blocks, yields each message as a dict + print(msg) + if some_condition: + ws.disconnect() # stops the generator cleanly +``` + +#### **1.2 Device Utilities Module** (`mistapi.device_utils`) +`mistapi.device_utils` provides high-level utilities for running diagnostic commands on Mist-managed devices. Each function triggers a REST API call and streams the results back via WebSocket. The library handles the connection plumbing — you just call the function and get back a `UtilResponse` object. + +**Device-Specific Modules** (Recommended): +| Module | Device Type | Functions | +|--------|-------------|-----------| +| `device_utils.ap` | Mist Access Points | `ping`, `traceroute`, `retrieveArpTable` | +| `device_utils.ex` | Juniper EX Switches | `ping`, `monitorTraffic`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `retrieveMacTable`, `clearMacTable`, `clearLearnedMac`, `clearBpduError`, `clearDot1xSessions`, `clearHitCount`, `bouncePort`, `cableTest` | +| `device_utils.srx` | Juniper SRX Firewalls | `ping`, `monitorTraffic`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `showDatabase`, `showNeighbors`, `showInterfaces`, `bouncePort`, `retrieveRoutes` | +| `device_utils.ssr` | Juniper SSR Routers | `ping`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `showDatabase`, `showNeighbors`, `showInterfaces`, `bouncePort`, `retrieveRoutes`, `showServicePath` | + +**Example Usage:** +```python +from mistapi.device_utils import ap, ex + +# Ping from an AP +result = ap.ping(apisession, site_id, device_id, host="8.8.8.8") +print(result.ws_data) + +# Retrieve ARP table from a switch +result = ex.retrieveArpTable(apisession, site_id, device_id) +print(result.ws_data) + +# With real-time callback +def handle(msg): + print("got:", msg) + +result = ex.cableTest(apisession, site_id, device_id, port="ge-0/0/0", on_message=handle) +``` + +#### **1.3 New API Endpoints** + +**MapStacks API** (`mistapi.api.v1.sites.mapstacks`): +- `listSiteMapStacks()`: List map stacks with filtering +- `createSiteMapStack()`: Create new map stack + +**Enhanced Query Parameters**: +- Additional filtering options across alarms, clients, and devices endpoints +- Improved parameter handling in JSI, NAC clients, and WAN clients APIs + +--- + +### 2. SECURITY IMPROVEMENTS + +##### **HashiCorp Vault SSL Verification** +- Now properly verifies SSL certificates when connecting to Vault +- Made vault configuration attributes private (`_vault_url`, `_vault_path`, etc.) +- Improved cleanup of vault credentials after loading + +--- + +### 3. PERFORMANCE IMPROVEMENTS + +##### **Lazy Module Loading** +- Implemented lazy loading for `api` and `cli` subpackages +- Reduces initial import time by deferring heavy module imports until accessed +- Uses `__getattr__` for transparent lazy loading + +--- + +### 4. CODE QUALITY IMPROVEMENTS + +##### **HTTP Request Error Handling** +- Consolidated duplicate error handling logic into `_request_with_retry()` method +- Extracts HTTP operations into inner functions for cleaner code +- Reduces code duplication by ~55 lines across GET/POST/PUT/DELETE/POST_FILE methods +- Centralizes 429 rate limit handling and retry logic + +##### **Session Management** +- Added `_new_session()` helper method for consistent session initialization +- Improves code reusability when creating new HTTP sessions + +##### **API Token Management** +- Added `validate` parameter to `set_api_token()` method +- Allows skipping token validation when needed (default: `True`) +- Useful for faster initialization when tokens are known to be valid + +##### **Logging Improvements** +- Fixed logging sanitization to use `getMessage()` instead of direct `msg` access +- Clear `record.args` after sanitization to prevent re-formatting issues +- Improved logging format consistency using %-style formatting + +--- + +### 6. DEPENDENCIES + +##### **New Dependencies** +- Added `websocket-client>=1.8.0` for WebSocket streaming support + +--- + ## Version 0.60.3 (February 2026) **Released**: February 21, 2026 diff --git a/pyproject.toml b/pyproject.toml index d4f79ee..844d985 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "mistapi" -version = "0.55.15" +version = "0.61.0" authors = [{ name = "Thomas Munzer", email = "tmunzer@juniper.net" }] description = "Python package to simplify the Mist System APIs usage" keywords = ["Mist", "Juniper", "API"] diff --git a/src/mistapi/__version.py b/src/mistapi/__version.py index f9533d2..4e16e40 100644 --- a/src/mistapi/__version.py +++ b/src/mistapi/__version.py @@ -1,2 +1,2 @@ -__version__ = "0.55.15" +__version__ = "0.61.0" __author__ = "Thomas Munzer " diff --git a/src/mistapi/api/v1/sites/sle.py b/src/mistapi/api/v1/sites/sle.py index b5d4c60..71c2676 100644 --- a/src/mistapi/api/v1/sites/sle.py +++ b/src/mistapi/api/v1/sites/sle.py @@ -10,15 +10,16 @@ -------------------------------------------------------------------------------- """ +import deprecation + from mistapi import APISession as _APISession from mistapi.__api_response import APIResponse as _APIResponse -import deprecation @deprecation.deprecated( deprecated_in="0.59.2", removed_in="0.65.0", - current_version="0.55.15", + current_version="0.61.0", details="function replaced with getSiteSleClassifierSummaryTrend", ) def getSiteSleClassifierDetails( @@ -72,57 +73,6 @@ def getSiteSleClassifierDetails( return resp -def getSiteSleClassifierDetails( - mist_session: _APISession, - site_id: str, - scope: str, - scope_id: str, - metric: str, - classifier: str, - start: str | None = None, - end: str | None = None, - duration: str | None = None, -) -> _APIResponse: - """ - API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/sles/get-site-sle-classifier-details - - PARAMS - ----------- - mistapi.APISession : mist_session - mistapi session including authentication and Mist host information - - PATH PARAMS - ----------- - site_id : str - scope : str{'ap', 'client', 'gateway', 'site', 'switch'} - scope_id : str - metric : str - classifier : str - - QUERY PARAMS - ------------ - start : str - end : str - duration : str, default: 1d - - RETURN - ----------- - mistapi.APIResponse - response from the API call - """ - - uri = f"/api/v1/sites/{site_id}/sle/{scope}/{scope_id}/metric/{metric}/classifier/{classifier}/summary" - query_params: dict[str, str] = {} - if start: - query_params["start"] = str(start) - if end: - query_params["end"] = str(end) - if duration: - query_params["duration"] = str(duration) - resp = mist_session.mist_get(uri=uri, query=query_params) - return resp - - def getSiteSleClassifierSummaryTrend( mist_session: _APISession, site_id: str, @@ -741,7 +691,7 @@ def listSiteSleImpactedWirelessClients( @deprecation.deprecated( deprecated_in="0.59.2", removed_in="0.65.0", - current_version="0.55.15", + current_version="0.61.0", details="function replaced with getSiteSleSummaryTrend", ) def getSiteSleSummary( @@ -793,55 +743,6 @@ def getSiteSleSummary( return resp -def getSiteSleSummary( - mist_session: _APISession, - site_id: str, - scope: str, - scope_id: str, - metric: str, - start: str | None = None, - end: str | None = None, - duration: str | None = None, -) -> _APIResponse: - """ - API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/sles/get-site-sle-summary - - PARAMS - ----------- - mistapi.APISession : mist_session - mistapi session including authentication and Mist host information - - PATH PARAMS - ----------- - site_id : str - scope : str{'ap', 'client', 'gateway', 'site', 'switch'} - scope_id : str - metric : str - - QUERY PARAMS - ------------ - start : str - end : str - duration : str, default: 1d - - RETURN - ----------- - mistapi.APIResponse - response from the API call - """ - - uri = f"/api/v1/sites/{site_id}/sle/{scope}/{scope_id}/metric/{metric}/summary" - query_params: dict[str, str] = {} - if start: - query_params["start"] = str(start) - if end: - query_params["end"] = str(end) - if duration: - query_params["duration"] = str(duration) - resp = mist_session.mist_get(uri=uri, query=query_params) - return resp - - def getSiteSleSummaryTrend( mist_session: _APISession, site_id: str, diff --git a/src/mistapi/websockets/orgs.py b/src/mistapi/websockets/orgs.py index e8a24ff..d8e5520 100644 --- a/src/mistapi/websockets/orgs.py +++ b/src/mistapi/websockets/orgs.py @@ -129,11 +129,11 @@ def __init__( ) -class MxEdgesUpgradesEvents(_MistWebsocket): - """WebSocket stream for org MX edges upgrades events. +class MxEdgesEvents(_MistWebsocket): + """WebSocket stream for org MX edges events. Subscribes to the ``orgs/{org_id}/mxedges`` channel and delivers - real-time MX edges upgrades events for the given org. + real-time MX edges events for the given org. PARAMS ----------- @@ -150,7 +150,7 @@ class MxEdgesUpgradesEvents(_MistWebsocket): ----------- Callback style (background thread):: - ws = OrgMxEdgesUpgradesEvents(session, org_id="abc123") + ws = MxEdgesEvents(session, org_id="abc123") ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -158,14 +158,14 @@ class MxEdgesUpgradesEvents(_MistWebsocket): Generator style:: - ws = OrgMxEdgesUpgradesEvents(session, org_id="abc123") + ws = MxEdgesEvents(session, org_id="abc123") ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with OrgMxEdgesUpgradesEvents(session, org_id="abc123") as ws: + with MxEdgesEvents(session, org_id="abc123") as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) diff --git a/src/mistapi/websockets/sites.py b/src/mistapi/websockets/sites.py index 291ca14..c63910f 100644 --- a/src/mistapi/websockets/sites.py +++ b/src/mistapi/websockets/sites.py @@ -200,11 +200,11 @@ def __init__( ) -class DeviceUpgradesEvents(_MistWebsocket): - """WebSocket stream for site device upgrades events. +class DeviceEvents(_MistWebsocket): + """WebSocket stream for site device events. Subscribes to the ``sites/{site_id}/devices`` channel and delivers - real-time device upgrades events for the given site. + real-time device events for the given site. PARAMS ----------- @@ -221,7 +221,7 @@ class DeviceUpgradesEvents(_MistWebsocket): ----------- Callback style (background thread):: - ws = SiteDeviceUpgradesEvents(session, site_id="abc123") + ws = DeviceEvents(session, site_id="abc123") ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -229,14 +229,14 @@ class DeviceUpgradesEvents(_MistWebsocket): Generator style:: - ws = SiteDeviceUpgradesEvents(session, site_id="abc123") + ws = DeviceEvents(session, site_id="abc123") ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with SiteDeviceUpgradesEvents(session, site_id="abc123") as ws: + with DeviceEvents(session, site_id="abc123") as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -279,7 +279,7 @@ class MxEdgesStatsEvents(_MistWebsocket): ----------- Callback style (background thread):: - ws = SiteMxEdgesStatsEvents(session, site_id="abc123") + ws = MxEdgesStatsEvents(session, site_id="abc123") ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -287,14 +287,14 @@ class MxEdgesStatsEvents(_MistWebsocket): Generator style:: - ws = SiteMxEdgesStatsEvents(session, site_id="abc123") + ws = MxEdgesStatsEvents(session, site_id="abc123") ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with SiteMxEdgesStatsEvents(session, site_id="abc123") as ws: + with MxEdgesStatsEvents(session, site_id="abc123") as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) @@ -316,6 +316,64 @@ def __init__( ) +class MxEdgesEvents(_MistWebsocket): + """WebSocket stream for site MX edges events. + + Subscribes to the ``sites/{site_id}/mxedges`` channel and delivers + real-time MX edges events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_ids : list[str] + UUIDs of the sites to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = MxEdgesEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = MxEdgesEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with MxEdgesEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + site_ids: list[str], + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + channels = [f"/sites/{site_id}/mxedges" for site_id in site_ids] + super().__init__( + mist_session, + channels=channels, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) + + class PcapEvents(_MistWebsocket): """WebSocket stream for site PCAP events. @@ -337,7 +395,7 @@ class PcapEvents(_MistWebsocket): ----------- Callback style (background thread):: - ws = SitePcapEvents(session, site_id="abc123") + ws = PcapEvents(session, site_id="abc123") ws.on_message(lambda data: print(data)) ws.connect() # non-blocking, runs in background thread input("Press Enter to stop") @@ -345,14 +403,14 @@ class PcapEvents(_MistWebsocket): Generator style:: - ws = SitePcapEvents(session, site_id="abc123") + ws = PcapEvents(session, site_id="abc123") ws.connect(run_in_background=True) for msg in ws.receive(): process(msg) Context manager:: - with SitePcapEvents(session, site_id="abc123") as ws: + with PcapEvents(session, site_id="abc123") as ws: ws.on_message(my_handler) ws.connect() # non-blocking, runs in background thread time.sleep(60) diff --git a/uv.lock b/uv.lock index 7d80793..11c66d0 100644 --- a/uv.lock +++ b/uv.lock @@ -537,7 +537,7 @@ wheels = [ [[package]] name = "mistapi" -version = "0.55.15" +version = "0.61.0" source = { editable = "." } dependencies = [ { name = "deprecation" }, From b8ad936c053c0d0560f20dc0d8604ce7f5cc9109 Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Fri, 13 Mar 2026 21:55:07 +0100 Subject: [PATCH 15/16] code scanning fix --- src/mistapi/__api_session.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/mistapi/__api_session.py b/src/mistapi/__api_session.py index 4f3fa82..db3b2ff 100644 --- a/src/mistapi/__api_session.py +++ b/src/mistapi/__api_session.py @@ -626,31 +626,32 @@ def _check_api_tokens(self, apitokens) -> list[str]: primary_token_type: str | None = "" primary_token_value: str = "" for token in apitokens: - token_value = f"{token[:4]}...{token[-4:]}" + not_sensitive_data = f"{token[:4]}...{token[-4:]}" if token in valid_api_tokens: LOGGER.info( "apisession:_check_api_tokens:API Token %s is already valid", - token_value, + not_sensitive_data, ) continue (token_type, token_privileges) = self._get_api_token_data(token) if token_type is None or token_privileges is None: LOGGER.error( "apisession:_check_api_tokens:API Token %s is not valid", - token_value, + not_sensitive_data, ) LOGGER.error( - "API Token %s is not valid and will not be used", token_value + "API Token %s is not valid and will not be used", + not_sensitive_data, ) elif len(primary_token_privileges) == 0 and token_privileges: primary_token_privileges = token_privileges primary_token_type = token_type - primary_token_value = token_value + primary_token_value = not_sensitive_data valid_api_tokens.append(token) LOGGER.info( "apisession:_check_api_tokens:" "API Token %s set as primary for comparison", - token_value, + not_sensitive_data, ) elif primary_token_privileges == token_privileges: valid_api_tokens.append(token) @@ -659,7 +660,7 @@ def _check_api_tokens(self, apitokens) -> list[str]: "%s API Token %s has same privileges as " "the %s API Token %s", token_type, - token_value, + not_sensitive_data, primary_token_type, primary_token_value, ) @@ -669,13 +670,13 @@ def _check_api_tokens(self, apitokens) -> list[str]: "%s API Token %s has different privileges " "than the %s API Token %s", token_type, - token_value, + not_sensitive_data, primary_token_type, primary_token_value, ) LOGGER.error( "API Token %s has different privileges and will not be used", - token_value, + not_sensitive_data, ) return valid_api_tokens From 0f7114efaf54d93ee7575b97d0b5d152b2b492ee Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Fri, 13 Mar 2026 22:13:56 +0100 Subject: [PATCH 16/16] fix remaining minor issues --- src/mistapi/__api_session.py | 54 +++++++++++++++++++---------- tests/unit/test_websocket_client.py | 16 ++++----- 2 files changed, 43 insertions(+), 27 deletions(-) diff --git a/src/mistapi/__api_session.py b/src/mistapi/__api_session.py index db3b2ff..e84a191 100644 --- a/src/mistapi/__api_session.py +++ b/src/mistapi/__api_session.py @@ -856,7 +856,11 @@ def login_with_return( LOGGER.error("apisession:login_with_return:credentials are missing") return {"authenticated": False, "error": "credentials are missing"} - if resp.status_code == 200 and not resp.data.get("two_factor_required", False): + if ( + resp.status_code == 200 + and isinstance(resp.data, dict) + and not resp.data.get("two_factor_required", False) + ): LOGGER.info("apisession:login_with_return:access authorized") return {"authenticated": True, "error": ""} else: @@ -884,7 +888,8 @@ def logout(self) -> None: self._set_authenticated(False) else: try: - CONSOLE.error(resp.data["detail"]) + if isinstance(resp.data, dict) and "detail" in resp.data: + CONSOLE.error(resp.data["detail"]) except (KeyError, TypeError, AttributeError): if isinstance(resp.raw_data, bytes): CONSOLE.error(resp.raw_data.decode("utf-8", errors="replace")) @@ -1077,7 +1082,7 @@ def _getself(self) -> bool: uri = "/api/v1/self" LOGGER.info('apisession:_getself: sending GET request to "%s"', uri) resp = self.mist_get(uri) - if resp.status_code == 200 and resp.data: + if resp.status_code == 200 and resp.data and isinstance(resp.data, dict): # Deal with 2FA if needed if ( resp.data.get("two_factor_required") is True @@ -1094,20 +1099,27 @@ def _getself(self) -> bool: LOGGER.info( "apisession:_getself:authentication Ok. Processing account privileges" ) - for key, val in resp.data.items(): - if key == "privileges": - self.privileges = Privileges(resp.data["privileges"]) - if key == "tags": - for tag in resp.data["tags"]: - self.tags.append(tag) - else: - setattr(self, key, val) - if self._show_cli_notif: - print() - print(" Authenticated ".center(80, "-")) - print(f"\r\nWelcome {self.first_name} {self.last_name}!\r\n") - LOGGER.info("apisession:_getself:account info processed successfully") - return True + if isinstance(resp.data, dict): + for key, val in resp.data.items(): + if key == "privileges": + self.privileges = Privileges(resp.data["privileges"]) + if key == "tags": + for tag in resp.data["tags"]: + self.tags.append(tag) + else: + setattr(self, key, val) + if self._show_cli_notif: + print() + print(" Authenticated ".center(80, "-")) + print(f"\r\nWelcome {self.first_name} {self.last_name}!\r\n") + LOGGER.info( + "apisession:_getself:account info processed successfully" + ) + return True + else: + raise ValueError( + "Unexpected format for privileges in the response data" + ) elif resp.proxy_error: LOGGER.critical("apisession:_getself:proxy not valid...") CONSOLE.critical("Proxy not valid...\r\n") @@ -1171,7 +1183,11 @@ def get_privilege_by_org_id(self, org_id: str): msp_id = None try: resp = self.mist_get(uri) - if resp.data and resp.data.get("msp_id"): + if ( + resp.data + and isinstance(resp.data, dict) + and resp.data.get("msp_id") + ): LOGGER.info( "apisession:get_privilege_by_org_id:org %s belong to msp_id %s", {org_id}, @@ -1205,7 +1221,7 @@ def get_privilege_by_org_id(self, org_id: str): "unable of find msp %s privileges in user data", msp_id, ) - else: + elif isinstance(resp.data, dict): return { "scope": "org", "org_id": org_id, diff --git a/tests/unit/test_websocket_client.py b/tests/unit/test_websocket_client.py index 2c444b4..5069567 100644 --- a/tests/unit/test_websocket_client.py +++ b/tests/unit/test_websocket_client.py @@ -9,7 +9,6 @@ """ import json -import queue import ssl from unittest.mock import Mock, call, patch @@ -25,19 +24,20 @@ ) from mistapi.websockets.orgs import ( InsightsEvents, - MxEdgesStatsEvents as OrgMxEdgesStatsEvents, - MxEdgesUpgradesEvents, + MxEdgesEvents, ) +from mistapi.websockets.orgs import MxEdgesStatsEvents as OrgMxEdgesStatsEvents from mistapi.websockets.session import SessionWithUrl from mistapi.websockets.sites import ( ClientsStatsEvents, DeviceCmdEvents, + DeviceEvents, DeviceStatsEvents, - DeviceUpgradesEvents, - MxEdgesStatsEvents as SiteMxEdgesStatsEvents, PcapEvents, ) - +from mistapi.websockets.sites import ( + MxEdgesStatsEvents as SiteMxEdgesStatsEvents, +) # --------------------------------------------------------------------------- # Fixtures @@ -694,7 +694,7 @@ def test_device_stats_events_channels(self, mock_session) -> None: assert ws._channels == ["/sites/s1/stats/devices"] def test_device_upgrades_events_channels(self, mock_session) -> None: - ws = DeviceUpgradesEvents(mock_session, site_ids=["s1"]) + ws = DeviceEvents(mock_session, site_ids=["s1"]) assert ws._channels == ["/sites/s1/devices"] def test_site_mxedges_stats_events_channels(self, mock_session) -> None: @@ -729,7 +729,7 @@ def test_org_mxedges_stats_events_channels(self, mock_session) -> None: assert ws._channels == ["/orgs/o1/stats/mxedges"] def test_mxedges_upgrades_events_channels(self, mock_session) -> None: - ws = MxEdgesUpgradesEvents(mock_session, org_id="o1") + ws = MxEdgesEvents(mock_session, org_id="o1") assert ws._channels == ["/orgs/o1/mxedges"] def test_inherits_from_mist_websocket(self, mock_session) -> None: