diff --git a/.gitmodules b/.gitmodules index 415e265..081acbd 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "mist_openapi"] path = mist_openapi url = https://github.com/mistsys/mist_openapi.git - branch = 2602.1.5 \ No newline at end of file + branch = master \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bb5a14..7c00853 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,201 @@ # CHANGELOG +## Version 0.61.0 (March 2026) + +**Released**: March 13, 2026 + +**MAJOR RELEASE** with extensive new features, code quality improvements, security enhancements, and performance optimizations. This release adds real-time WebSocket streaming, comprehensive device diagnostic utilities, extensive test coverage, and significant API improvements. + +--- + +### 1. NEW FEATURES + +#### **1.1 WebSocket Streaming Module** (`mistapi.websockets`) +Complete real-time event streaming support with flexible consumption patterns: + +**Available Channels:** +* Organization Channels + +| Class | Description | +|-------|-------------| +| `mistapi.websockets.orgs.InsightsEvents` | Real-time insights events for an organization | +| `mistapi.websockets.orgs.MxEdgesStatsEvents` | Real-time MX edges stats for an organization | +| `mistapi.websockets.orgs.MxEdgesUpgradesEvents` | Real-time MX edges upgrades events for an organization | + +* Site Channels + +| Class | Description | +|-------|-------------| +| `mistapi.websockets.sites.ClientsStatsEvents` | Real-time clients stats for a site | +| `mistapi.websockets.sites.DeviceCmdEvents` | Real-time device command events for a site | +| `mistapi.websockets.sites.DeviceStatsEvents` | Real-time device stats for a site | +| `mistapi.websockets.sites.DeviceUpgradesEvents` | Real-time device upgrades events for a site | +| `mistapi.websockets.sites.MxEdgesStatsEvents` | Real-time MX edges stats for a site | +| `mistapi.websockets.sites.PcapEvents` | Real-time PCAP events for a site | + +* Location Channels + +| Class | Description | +|-------|-------------| +| `mistapi.websockets.location.BleAssetsEvents` | Real-time BLE assets location events | +| `mistapi.websockets.location.ConnectedClientsEvents` | Real-time connected clients location events | +| `mistapi.websockets.location.SdkClientsEvents` | Real-time SDK clients location events | +| `mistapi.websockets.location.UnconnectedClientsEvents` | Real-time unconnected clients location events | +| `mistapi.websockets.location.DiscoveredBleAssetsEvents` | Real-time discovered BLE assets location events | + + +**Features:** +- Callback-based message handling +- Generator-style iteration +- Context manager support +- Automatic reconnection with configurable ping intervals +- Non-blocking background threads +- Type-safe API with full parameter validation + +**Example Usage:** +```python +ws = mistapi.websockets.sites.DeviceStatsEvents(apisession, site_ids=[""]) +ws.connect(run_in_background=True) + +for msg in ws.receive(): # blocks, yields each message as a dict + print(msg) + if some_condition: + ws.disconnect() # stops the generator cleanly +``` + +#### **1.2 Device Utilities Module** (`mistapi.device_utils`) +`mistapi.device_utils` provides high-level utilities for running diagnostic commands on Mist-managed devices. Each function triggers a REST API call and streams the results back via WebSocket. The library handles the connection plumbing — you just call the function and get back a `UtilResponse` object. + +**Device-Specific Modules** (Recommended): +| Module | Device Type | Functions | +|--------|-------------|-----------| +| `device_utils.ap` | Mist Access Points | `ping`, `traceroute`, `retrieveArpTable` | +| `device_utils.ex` | Juniper EX Switches | `ping`, `monitorTraffic`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `retrieveMacTable`, `clearMacTable`, `clearLearnedMac`, `clearBpduError`, `clearDot1xSessions`, `clearHitCount`, `bouncePort`, `cableTest` | +| `device_utils.srx` | Juniper SRX Firewalls | `ping`, `monitorTraffic`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `showDatabase`, `showNeighbors`, `showInterfaces`, `bouncePort`, `retrieveRoutes` | +| `device_utils.ssr` | Juniper SSR Routers | `ping`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `showDatabase`, `showNeighbors`, `showInterfaces`, `bouncePort`, `retrieveRoutes`, `showServicePath` | + +**Example Usage:** +```python +from mistapi.device_utils import ap, ex + +# Ping from an AP +result = ap.ping(apisession, site_id, device_id, host="8.8.8.8") +print(result.ws_data) + +# Retrieve ARP table from a switch +result = ex.retrieveArpTable(apisession, site_id, device_id) +print(result.ws_data) + +# With real-time callback +def handle(msg): + print("got:", msg) + +result = ex.cableTest(apisession, site_id, device_id, port="ge-0/0/0", on_message=handle) +``` + +#### **1.3 New API Endpoints** + +**MapStacks API** (`mistapi.api.v1.sites.mapstacks`): +- `listSiteMapStacks()`: List map stacks with filtering +- `createSiteMapStack()`: Create new map stack + +**Enhanced Query Parameters**: +- Additional filtering options across alarms, clients, and devices endpoints +- Improved parameter handling in JSI, NAC clients, and WAN clients APIs + +--- + +### 2. SECURITY IMPROVEMENTS + +##### **HashiCorp Vault SSL Verification** +- Now properly verifies SSL certificates when connecting to Vault +- Made vault configuration attributes private (`_vault_url`, `_vault_path`, etc.) +- Improved cleanup of vault credentials after loading + +--- + +### 3. PERFORMANCE IMPROVEMENTS + +##### **Lazy Module Loading** +- Implemented lazy loading for `api` and `cli` subpackages +- Reduces initial import time by deferring heavy module imports until accessed +- Uses `__getattr__` for transparent lazy loading + +--- + +### 4. CODE QUALITY IMPROVEMENTS + +##### **HTTP Request Error Handling** +- Consolidated duplicate error handling logic into `_request_with_retry()` method +- Extracts HTTP operations into inner functions for cleaner code +- Reduces code duplication by ~55 lines across GET/POST/PUT/DELETE/POST_FILE methods +- Centralizes 429 rate limit handling and retry logic + +##### **Session Management** +- Added `_new_session()` helper method for consistent session initialization +- Improves code reusability when creating new HTTP sessions + +##### **API Token Management** +- Added `validate` parameter to `set_api_token()` method +- Allows skipping token validation when needed (default: `True`) +- Useful for faster initialization when tokens are known to be valid + +##### **Logging Improvements** +- Fixed logging sanitization to use `getMessage()` instead of direct `msg` access +- Clear `record.args` after sanitization to prevent re-formatting issues +- Improved logging format consistency using %-style formatting + +--- + +### 6. DEPENDENCIES + +##### **New Dependencies** +- Added `websocket-client>=1.8.0` for WebSocket streaming support + +--- + +## Version 0.60.3 (February 2026) + +**Released**: February 21, 2026 + +This release add a missing query parameter to the `searchOrgWanClients()` function. + +--- + +### 1. CHANGES + +##### **API Function Updates** +- Updated `searchOrgWanClients()` and related functions in `orgs/wan_clients.py`. + +--- + +## Version 0.60.1 (February 2026) + +**Released**: February 21, 2026 + +This release includes function updates and bug fixes in the self/logs.py and sites/sle.py modules. + +--- + +### 1. CHANGES + +##### **API Function Updates** +- Updated `listSelfAuditLogs()` and related functions in `self/logs.py`. +- Updated deprecated and new SLE classifier functions in `sites/sle.py`. + +--- + +### 2. BUG FIXES + +- Minor bug fixes and improvements in API modules. + +--- + +### Breaking Changes + +No breaking changes in this release. + +--- + ## Version 0.60.4 (March 2026) **Released**: March 3, 2026 diff --git a/README.md b/README.md index df7002c..09d27e4 100644 --- a/README.md +++ b/README.md @@ -14,11 +14,31 @@ A comprehensive Python package to interact with the Mist Cloud APIs, built from - [Installation](#installation) - [Quick Start](#quick-start) - [Configuration](#configuration) + - [Using Environment File](#using-environment-file) + - [Environment Variables](#environment-variables) - [Authentication](#authentication) -- [Usage](#usage) -- [CLI Helper Functions](#cli-helper-functions) -- [Pagination](#pagination-support) -- [Examples](#examples) + - [Interactive Authentication](#interactive-authentication) + - [Environment File Authentication](#environment-file-authentication) + - [HashiCorp Vault Authentication](#hashicorp-vault-authentication) + - [System Keyring Authentication](#system-keyring-authentication) + - [Direct Parameter Authentication](#direct-parameter-authentication) +- [API Requests Usage](#api-requests-usage) + - [Basic API Calls](#basic-api-calls) + - [Error Handling](#error-handling) + - [Log Sanitization](#log-sanitization) + - [Getting Help](#getting-help) + - [CLI Helper Functions](#cli-helper-functions) + - [Pagination](#pagination-support) + - [Examples](#examples) +- [WebSocket Streaming](#websocket-streaming) + - [Connection Parameters](#connection-parameters) + - [Callbacks](#callbacks) + - [Available Channels](#available-channels) + - [Usage Patterns](#usage-patterns) +- [Device Utilities](#device-utilities) + - [Supported Devices](#supported-devices) + - [Usage](#device-utilities-usage) + - [UtilResponse Object](#utilresponse-object) - [Development](#development-and-testing) - [Contributing](#contributing) - [License](#license) @@ -44,18 +64,12 @@ Support for all Mist cloud instances worldwide: ### Core Features - **Complete API Coverage**: Auto-generated from OpenAPI specs - **Automatic Pagination**: Built-in support for paginated responses +- **WebSocket Streaming**: Real-time event streaming for devices, clients, and location data +- **Device Diagnostics**: High-level utilities for ping, traceroute, ARP, BGP, OSPF, and more - **Error Handling**: Detailed error responses and logging - **Proxy Support**: HTTP/HTTPS proxy configuration - **Log Sanitization**: Automatic redaction of sensitive data in logs -### API Coverage -**Organization Level**: Organizations, Sites, Devices (APs/Switches/Gateways), WLANs, VPNs, Networks, NAC, Users, Admins, Guests, Alarms, Events, Statistics, SLE, Assets, Licenses, Webhooks, Security Policies, MSP management - -**Site Level**: Device management, RF optimization, Location services, Maps, Client analytics, Asset tracking, Synthetic testing, Anomaly detection - -**Constants & Utilities**: Device models, AP channels, Applications, Country codes, Alarm definitions, Event types, Webhook topics - -**Additional Services**: OAuth, Two-factor authentication, Account recovery, Invitations, MDM workflows --- @@ -81,16 +95,31 @@ python3 -m pip install --upgrade mistapi py -m pip install --upgrade mistapi ``` +### Installation with uv + +[uv](https://docs.astral.sh/uv/) is a fast Python package manager: + +```bash +# Install in current project +uv add mistapi + +# Or run directly without installing +uv run --with mistapi python my_script.py +``` + ### Development Installation ```bash -# Install with development dependencies (for contributors) -pip install mistapi[dev] +# With pip +pip install -e ".[dev]" + +# With uv +uv sync ``` ### Requirements - Python 3.10 or higher -- Dependencies: `requests`, `python-dotenv`, `tabulate`, `deprecation`, `hvac`, `keyring` +- Dependencies: `requests`, `python-dotenv`, `tabulate`, `deprecation`, `hvac`, `keyring`, `websocket-client` --- @@ -150,22 +179,24 @@ MIST_APITOKEN=your_api_token_here # LOGGING_LOG_LEVEL=10 ``` -### All Configuration Options +### Configuration Options + +| Environment Variable | APISession Parameter | Type | Default | Description | +|---|---|---|---|---| +| `MIST_HOST` | `host` | string | None | Mist Cloud API endpoint (e.g., `api.mist.com`) | +| `MIST_APITOKEN` | `apitoken` | string | None | API Token for authentication (recommended) | +| `MIST_USER` | `email` | string | None | Username/email for authentication | +| `MIST_PASSWORD` | `password` | string | None | Password for authentication | +| `MIST_KEYRING_SERVICE` | `keyring_service` | string | None | System keyring service name | +| `MIST_VAULT_URL` | `vault_url` | string | None | HashiCorp Vault URL | +| `MIST_VAULT_PATH` | `vault_path` | string | None | Path to secret in Vault | +| `MIST_VAULT_MOUNT_POINT` | `vault_mount_point` | string | None | Vault mount point | +| `MIST_VAULT_TOKEN` | `vault_token` | string | None | Vault authentication token | +| `CONSOLE_LOG_LEVEL` | `console_log_level` | int | 20 | Console log level (0-50) | +| `LOGGING_LOG_LEVEL` | `logging_log_level` | int | 10 | File log level (0-50) | +| `HTTPS_PROXY` | `https_proxy` | string | None | HTTP/HTTPS proxy URL | +| | `env_file` | str | None | Path to `.env` file | -| Variable | Type | Default | Description | -|----------|------|---------|-------------| -| `MIST_HOST` | string | None | Mist Cloud API endpoint (e.g., `api.mist.com`) | -| `MIST_APITOKEN` | string | None | API Token for authentication (recommended) | -| `MIST_USER` | string | None | Username if not using API token | -| `MIST_PASSWORD` | string | None | Password if not using API token | -| `MIST_KEYRING_SERVICE` | string | None | Load credentials from system keyring | -| `MIST_VAULT_URL` | string | https://127.0.0.1:8200 | HashiCorp Vault URL | -| `MIST_VAULT_PATH` | string | None | Path to secret in Vault | -| `MIST_VAULT_MOUNT_POINT` | string | secret | Vault mount point | -| `MIST_VAULT_TOKEN` | string | None | Vault authentication token | -| `CONSOLE_LOG_LEVEL` | int | 20 | Console log level (0-50) | -| `LOGGING_LOG_LEVEL` | int | 10 | File log level (0-50) | -| `HTTPS_PROXY` | string | None | HTTP/HTTPS proxy URL | --- @@ -253,27 +284,10 @@ apisession = mistapi.APISession( apisession.login() ``` -### APISession Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `email` | str | None | Email for login/password authentication | -| `password` | str | None | Password for login/password authentication | -| `apitoken` | str | None | API token (recommended method) | -| `host` | str | None | Mist Cloud endpoint (e.g., "api.mist.com") | -| `keyring_service` | str | None | System keyring service name | -| `vault_url` | str | https://127.0.0.1:8200 | HashiCorp Vault URL | -| `vault_path` | str | None | Path to secret in Vault | -| `vault_mount_point` | str | secret | Vault mount point | -| `vault_token` | str | None | Vault authentication token | -| `env_file` | str | None | Path to `.env` file | -| `console_log_level` | int | 20 | Console logging level (0-50) | -| `logging_log_level` | int | 10 | File logging level (0-50) | -| `https_proxy` | str | None | Proxy URL | --- -## Usage +## API Requests Usage ### Basic API Calls @@ -345,11 +359,11 @@ help(mistapi.api.v1.orgs.stats.getOrgStats) --- -## CLI Helper Functions +### CLI Helper Functions Interactive functions for selecting organizations and sites. -### Organization Selection +#### Organization Selection ```python # Select single organization @@ -368,7 +382,7 @@ Available organizations: Select an Org (0 to 1, or q to exit): 0 ``` -### Site Selection +#### Site Selection ```python # Select site within an organization @@ -386,9 +400,9 @@ Select a Site (0 to 1, or q to exit): 0 --- -## Pagination Support +### Pagination Support -### Get Next Page +#### Get Next Page ```python # Get first page @@ -403,7 +417,7 @@ if response.next: print(f"Second page: {len(response_2.data['results'])} results") ``` -### Get All Pages Automatically +#### Get All Pages Automatically ```python # Get all pages with a single call @@ -419,11 +433,11 @@ print(f"Total results across all pages: {len(all_data)}") --- -## Examples +### Examples Comprehensive examples are available in the [Mist Library repository](https://github.com/tmunzer/mist_library). -### Device Management +#### Device Management ```python # List all devices in an organization @@ -441,7 +455,7 @@ result = mistapi.api.v1.orgs.devices.updateOrgDevice( ) ``` -### Site Management +#### Site Management ```python # Create a new site @@ -458,7 +472,7 @@ new_site = mistapi.api.v1.orgs.sites.createOrgSite( site_stats = mistapi.api.v1.sites.stats.getSiteStats(apisession, new_site.id) ``` -### Client Analytics +#### Client Analytics ```python # Search for wireless clients @@ -472,10 +486,191 @@ clients = mistapi.api.v1.orgs.clients.searchOrgWirelessClients( events = mistapi.api.v1.orgs.clients.searchOrgClientsEvents( apisession, org_id, duration="1h", - client_mac="aa:bb:cc:dd:ee:ff" + client_mac="aabbccddeeff" +) +``` + +--- + +## WebSocket Streaming + +The package provides a WebSocket client for real-time event streaming from the Mist API (`wss://{host}/api-ws/v1/stream`). Authentication is handled automatically using the same session credentials (API token or login/password). + +### Connection Parameters + +All channel classes accept the following optional keyword arguments to control the WebSocket keep-alive behaviour: + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `ping_interval` | `int` | `30` | Seconds between automatic ping frames. Set to `0` to disable pings. | +| `ping_timeout` | `int` | `10` | Seconds to wait for a pong response before treating the connection as dead. | + +```python +ws = mistapi.websockets.sites.DeviceStatsEvents( + apisession, + site_ids=[""], + ping_interval=60, # ping every 60 s + ping_timeout=20, # wait up to 20 s for pong ) +ws.connect() +``` + +### Callbacks + +| Method | Signature | Description | +|--------|-----------|-------------| +| `ws.on_open(cb)` | `cb()` | Called when the connection is established | +| `ws.on_message(cb)` | `cb(data: dict)` | Called for every incoming message | +| `ws.on_error(cb)` | `cb(error: Exception)` | Called on WebSocket errors | +| `ws.on_close(cb)` | `cb(status_code: int, msg: str)` | Called when the connection closes | +| `ws.ready()` | `-> bool \| None` | Returns `True` if the connection is open and ready | + +### Available Channels + +#### Organization Channels + +| Class | Channel | Description | +|-------|---------|-------------| +| `mistapi.websockets.orgs.InsightsEvents` | `/orgs/{org_id}/insights/summary` | Real-time insights events for an organization | +| `mistapi.websockets.orgs.MxEdgesStatsEvents` | `/orgs/{org_id}/stats/mxedges` | Real-time MX edges stats for an organization | +| `mistapi.websockets.orgs.MxEdgesUpgradesEvents` | `/orgs/{org_id}/mxedges` | Real-time MX edges upgrades events for an organization | + +#### Site Channels + +| Class | Channel | Description | +|-------|---------|-------------| +| `mistapi.websockets.sites.ClientsStatsEvents` | `/sites/{site_id}/stats/clients` | Real-time clients stats for a site | +| `mistapi.websockets.sites.DeviceCmdEvents` | `/sites/{site_id}/devices/{device_id}/cmd` | Real-time device command events for a site | +| `mistapi.websockets.sites.DeviceStatsEvents` | `/sites/{site_id}/stats/devices` | Real-time device stats for a site | +| `mistapi.websockets.sites.DeviceUpgradesEvents` | `/sites/{site_id}/devices` | Real-time device upgrades events for a site | +| `mistapi.websockets.sites.MxEdgesStatsEvents` | `/sites/{site_id}/stats/mxedges` | Real-time MX edges stats for a site | +| `mistapi.websockets.sites.PcapEvents` | `/sites/{site_id}/pcap` | Real-time PCAP events for a site | + +#### Location Channels + +| Class | Channel | Description | +|-------|---------|-------------| +| `mistapi.websockets.location.BleAssetsEvents` | `/sites/{site_id}/stats/maps/{map_id}/assets` | Real-time BLE assets location events | +| `mistapi.websockets.location.ConnectedClientsEvents` | `/sites/{site_id}/stats/maps/{map_id}/clients` | Real-time connected clients location events | +| `mistapi.websockets.location.SdkClientsEvents` | `/sites/{site_id}/stats/maps/{map_id}/sdkclients` | Real-time SDK clients location events | +| `mistapi.websockets.location.UnconnectedClientsEvents` | `/sites/{site_id}/stats/maps/{map_id}/unconnected_clients` | Real-time unconnected clients location events | +| `mistapi.websockets.location.DiscoveredBleAssetsEvents` | `/sites/{site_id}/stats/maps/{map_id}/discovered_assets` | Real-time discovered BLE assets location events | + +#### Session Channels + +| Class | Channel | Description | +|-------|---------|-------------| +| `mistapi.websockets.session.SessionWithUrl` | Custom URL | Connect to a custom WebSocket channel URL | + +### Usage Patterns + +#### Callback style (recommended) + +`connect()` defaults to `run_in_background=True` and returns immediately. The WebSocket runs in a daemon thread, so your program must stay alive (e.g., with `input()` or an event loop). Messages are delivered to the registered callback in the background thread. + +```python +import mistapi + +apisession = mistapi.APISession(env_file="~/.mist_env") +apisession.login() + +ws = mistapi.websockets.sites.DeviceStatsEvents(apisession, site_ids=[""]) +ws.on_message(lambda data: print(data)) +ws.connect() # non-blocking + +input("Press Enter to stop") +ws.disconnect() +``` + +#### Generator style + +Iterate over incoming messages as a blocking generator. Useful when you want to process messages sequentially in a loop. + +```python +ws = mistapi.websockets.sites.DeviceStatsEvents(apisession, site_ids=[""]) +ws.connect(run_in_background=True) + +for msg in ws.receive(): # blocks, yields each message as a dict + print(msg) + if some_condition: + ws.disconnect() # stops the generator cleanly +``` + +#### Blocking style + +`connect(run_in_background=False)` blocks the calling thread until the connection closes. Useful for simple scripts. + +```python +ws = mistapi.websockets.sites.DeviceStatsEvents(apisession, site_ids=[""]) +ws.on_message(lambda data: print(data)) +ws.connect(run_in_background=False) # blocks until disconnected +``` + +#### Context manager + +`disconnect()` is called automatically on exit, even if an exception is raised. + +```python +import time + +with mistapi.websockets.sites.DeviceStatsEvents(apisession, site_ids=[""]) as ws: + ws.on_message(lambda data: print(data)) + ws.connect() + time.sleep(60) +# ws.disconnect() called automatically here +``` + +--- + +## Device Utilities + +`mistapi.device_utils` provides high-level utilities for running diagnostic commands on Mist-managed devices. Each function triggers a REST API call and streams the results back via WebSocket. The library handles the connection plumbing — you just call the function and get back a `UtilResponse` object. + +### Supported Devices + +| Module | Device Type | Functions | +|--------|-------------|-----------| +| `device_utils.ap` | Mist Access Points | `ping`, `traceroute`, `retrieveArpTable` | +| `device_utils.ex` | Juniper EX Switches | `ping`, `monitorTraffic`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `retrieveMacTable`, `clearMacTable`, `clearLearnedMac`, `clearBpduError`, `clearDot1xSessions`, `clearHitCount`, `bouncePort`, `cableTest` | +| `device_utils.srx` | Juniper SRX Firewalls | `ping`, `monitorTraffic`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `showDatabase`, `showNeighbors`, `showInterfaces`, `bouncePort`, `retrieveRoutes` | +| `device_utils.ssr` | Juniper SSR Routers | `ping`, `retrieveArpTable`, `retrieveBgpSummary`, `retrieveDhcpLeases`, `releaseDhcpLeases`, `showDatabase`, `showNeighbors`, `showInterfaces`, `bouncePort`, `retrieveRoutes`, `showServicePath` | + +### Device Utilities Usage + +```python +from mistapi.device_utils import ap, ex + +# Ping from an AP +result = ap.ping(apisession, site_id, device_id, host="8.8.8.8") +print(result.ws_data) + +# Retrieve ARP table from a switch +result = ex.retrieveArpTable(apisession, site_id, device_id) +print(result.ws_data) + +# With real-time callback +def handle(msg): + print("got:", msg) + +result = ex.cableTest(apisession, site_id, device_id, port="ge-0/0/0", on_message=handle) ``` +### UtilResponse Object + +All device utility functions return a `UtilResponse` object: + +| Attribute | Type | Description | +|-----------|------|-------------| +| `trigger_api_response` | `APIResponse` | The initial REST API response that triggered the device command. Contains `status_code`, `data`, and `headers` from the trigger request. | +| `ws_required` | `bool` | `True` if the command required a WebSocket connection to stream results (most diagnostic commands do). `False` if the REST response alone was sufficient. | +| `ws_data` | `list[str]` | Parsed result data extracted from the WebSocket stream. Each entry is a processed output line from the device (e.g., a line of ping output or an ARP table row). | +| `ws_raw_events` | `list[str]` | Raw, unprocessed WebSocket event payloads as received from the Mist API. Useful for debugging or custom parsing. | + +### Enums + +- `ap.TracerouteProtocol` — `ICMP`, `UDP` (for `ap.traceroute()`) +- `srx.Node` / `ssr.Node` — `NODE0`, `NODE1` (for dual-node devices) + --- ## Development and Testing @@ -487,8 +682,11 @@ events = mistapi.api.v1.orgs.clients.searchOrgClientsEvents( git clone https://github.com/tmunzer/mistapi_python.git cd mistapi_python -# Install with development dependencies +# With pip pip install -e ".[dev]" + +# With uv +uv sync ``` ### Running Tests @@ -496,6 +694,8 @@ pip install -e ".[dev]" ```bash # Run all tests pytest +# or with uv +uv run pytest # Run with coverage report pytest --cov=src/mistapi --cov-report=html @@ -505,13 +705,15 @@ pytest tests/unit/test_api_session.py # Run linting ruff check src/ +# or with uv +uv run ruff check src/ ``` ### Package Structure ``` src/mistapi/ -├── __init__.py # Main package exports +├── __init__.py # Main package exports (lazy-loads api, cli, utils, websockets) ├── __api_session.py # Session management and authentication ├── __api_request.py # HTTP request handling ├── __api_response.py # Response parsing and pagination @@ -521,12 +723,24 @@ src/mistapi/ ├── __models/ # Data models │ ├── __init__.py │ └── privilege.py -└── api/v1/ # Auto-generated API endpoints - ├── const/ # Constants and enums - ├── orgs/ # Organization-level APIs - ├── sites/ # Site-level APIs - ├── login/ # Authentication APIs - └── utils/ # Utility functions +├── api/v1/ # Auto-generated API endpoints +│ ├── const/ # Constants and enums +│ ├── orgs/ # Organization-level APIs +│ ├── sites/ # Site-level APIs +│ ├── login/ # Authentication APIs +│ └── utils/ # Utility functions +├── device_utils/ # Device utility implementations +│ ├── ap.py # Access Point utilities +│ ├── ex.py # EX Switch utilities +│ ├── srx.py # SRX Firewall utilities +│ ├── ssr.py # Session Smart Router utilities +│ └── ... # Function-based modules (arp, bgp, dhcp, etc.) +└── websockets/ # Real-time WebSocket streaming + ├── __ws_client.py # Base WebSocket client + ├── orgs.py # Organization-level channels + ├── sites.py # Site-level channels + ├── location.py # Location/map channels + └── session.py # Custom URL session channel ``` --- diff --git a/mist_openapi b/mist_openapi index f013c82..2efa6e1 160000 --- a/mist_openapi +++ b/mist_openapi @@ -1 +1 @@ -Subproject commit f013c825a0d9f9baa7bd2f6c953b3d648f0c1ea8 +Subproject commit 2efa6e1024bd4e3532695a6aeb6be5787b51d9f5 diff --git a/pyproject.toml b/pyproject.toml index 1a36784..844d985 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "mistapi" -version = "0.60.4" +version = "0.61.0" authors = [{ name = "Thomas Munzer", email = "tmunzer@juniper.net" }] description = "Python package to simplify the Mist System APIs usage" keywords = ["Mist", "Juniper", "API"] @@ -27,6 +27,7 @@ dependencies = [ "deprecation>=2.1.0", "hvac>=2.3.0", "keyring>=24.3.0", + "websocket-client>=1.8.0", ] [project.urls] @@ -34,8 +35,17 @@ dependencies = [ "Bug Tracker" = "https://github.com/tmunzer/mistapi_python/issues" # UV-specific configuration -[tool.uv] -dev-dependencies = [ +#[tool.uv] +#preview = true + +#[tool.uv.scripts] +#test = "pytest" +#lint = "ruff check src/" +#fmt = "ruff format src/" +#build = "python -m build" + +[dependency-groups] +dev = [ # Testing dependencies "pytest>=8.4.0", "pytest-cov>=6.1.1", diff --git a/src/mistapi/__api_request.py b/src/mistapi/__api_request.py index e6f42a1..1ceb605 100644 --- a/src/mistapi/__api_request.py +++ b/src/mistapi/__api_request.py @@ -17,7 +17,9 @@ import json import os import re -import sys +import time +import urllib.parse +from collections.abc import Callable from typing import Any import requests @@ -33,6 +35,9 @@ class APIRequest: Class handling API Request to the Mist Cloud """ + _MAX_429_RETRIES: int = 3 + _DEFAULT_RETRY_AFTER: int = 5 + def __init__(self) -> None: self._cloud_uri: str = "" self._session = requests.session() @@ -77,16 +82,15 @@ def _log_proxy(self) -> None: re.sub(pwd_regex, ":*********@", self._session.proxies["https"]), ) print( - "apirequest:sending request to proxy server %s", - re.sub(pwd_regex, ":*********@", self._session.proxies["https"]), + f"apirequest:sending request to proxy server {re.sub(pwd_regex, ':*********@', self._session.proxies['https'])}" ) def _next_apitoken(self) -> None: logger.info("apirequest:_next_apitoken:rotating API Token") logger.debug( - f"apirequest:_next_apitoken:current API Token is " - f"{self._apitoken[self._apitoken_index][:4]}..." - f"{self._apitoken[self._apitoken_index][-4:]}" + "apirequest:_next_apitoken:current API Token is %s...%s", + self._apitoken[self._apitoken_index][:4], + self._apitoken[self._apitoken_index][-4:], ) new_index = self._apitoken_index + 1 if new_index >= len(self._apitoken): @@ -97,9 +101,9 @@ def _next_apitoken(self) -> None: {"Authorization": "Token " + self._apitoken[self._apitoken_index]} ) logger.debug( - f"apirequest:_next_apitoken:new API Token is " - f"{self._apitoken[self._apitoken_index][:4]}..." - f"{self._apitoken[self._apitoken_index][-4:]}" + "apirequest:_next_apitoken:new API Token is %s...%s", + self._apitoken[self._apitoken_index][:4], + self._apitoken[self._apitoken_index][-4:], ) else: logger.critical(" /!\\ API TOKEN CRITICAL ERROR /!\\") @@ -111,18 +115,16 @@ def _next_apitoken(self) -> None: " For large organization, it is recommended to configure" " multiple API Tokens (comma separated list) to avoid this issue" ) - logger.critical(" Exiting...") - sys.exit(255) + raise RuntimeError( + "API rate limit reached and no other API Token available. " + "For large organizations, configure multiple API Tokens " + "(comma separated list) to avoid this issue." + ) def _gen_query(self, query: dict[str, str] | None) -> str: - logger.debug(f"apirequest:_gen_query:processing query {query}") - html_query = "?" - if query: - for query_param in query: - html_query += f"{query_param}={query[query_param]}&" - logger.debug(f"apirequest:_gen_query:generated query:{html_query}") - html_query = html_query[:-1] - return html_query + if not query: + return "" + return "?" + urllib.parse.urlencode(query) def _remove_auth_from_headers(self, resp: requests.Response): headers = resp.request.headers @@ -157,6 +159,79 @@ def remove_file_from_body(self, resp: requests.Response): request_body += f"\r\n{i}" return request_body + def _handle_rate_limit(self, resp: requests.Response, attempt: int) -> None: + retry_after = resp.headers.get("Retry-After") + if retry_after: + try: + wait = int(retry_after) + except ValueError: + wait = self._DEFAULT_RETRY_AFTER * (2**attempt) + else: + wait = self._DEFAULT_RETRY_AFTER * (2**attempt) + logger.info( + "apirequest:rate_limited:sleeping %ss (attempt %s/%s)", + wait, + attempt + 1, + self._MAX_429_RETRIES, + ) + time.sleep(wait) + + def _request_with_retry( + self, method_name: str, request_fn: Callable, url: str + ) -> APIResponse: + """Shared retry wrapper for all HTTP methods.""" + resp = None + proxy_failed = False + for attempt in range(self._MAX_429_RETRIES + 1): + try: + logger.info("apirequest:%s:sending request to %s", method_name, url) + self._log_proxy() + resp = request_fn() + logger.debug( + "apirequest:%s:request headers:%s", + method_name, + self._remove_auth_from_headers(resp), + ) + resp.raise_for_status() + break + except requests.exceptions.ProxyError as e: + logger.error("apirequest:%s:Proxy Error: %s", method_name, e) + proxy_failed = True + break + except requests.exceptions.ConnectionError as e: + logger.error("apirequest:%s:Connection Error: %s", method_name, e) + break + except HTTPError as e: + if e.response.status_code == 429 and attempt < self._MAX_429_RETRIES: + logger.warning( + "apirequest:%s:HTTP 429 (attempt %s/%s)", + method_name, + attempt + 1, + self._MAX_429_RETRIES, + ) + try: + self._next_apitoken() + except RuntimeError: + pass # single token — still retry with backoff + self._handle_rate_limit(e.response, attempt) + continue + logger.error("apirequest:%s:HTTP error: %s", method_name, e) + if resp: + logger.error( + "apirequest:%s:HTTP error description: %s", + method_name, + resp.json(), + ) + break + except Exception as e: + logger.error("apirequest:%s:error: %s", method_name, e) + logger.error( + "apirequest:%s:Exception occurred", method_name, exc_info=True + ) + break + self._count += 1 + return APIResponse(url=url, response=resp, proxy_error=proxy_failed) + def mist_get(self, uri: str, query: dict[str, str] | None = None) -> APIResponse: """ GET HTTP Request @@ -173,41 +248,8 @@ def mist_get(self, uri: str, query: dict[str, str] | None = None) -> APIResponse mistapi.APIResponse response from the API call """ - resp = None - proxy_failed = False - try: - url = self._url(uri) + self._gen_query(query) - logger.info(f"apirequest:mist_get:sending request to {url}") - self._log_proxy() - resp = self._session.get(url) - logger.debug( - f"apirequest:mist_get:request headers:{self._remove_auth_from_headers(resp)}" - ) - resp.raise_for_status() - except requests.exceptions.ProxyError as proxy_error: - logger.error(f"apirequest:mist_get:Proxy Error: {proxy_error}") - proxy_failed = True - except requests.exceptions.ConnectionError as connexion_error: - logger.error(f"apirequest:mist_get:Connection Error: {connexion_error}") - except HTTPError as http_err: - if http_err.response.status_code == 429: - logger.warning( - "apirequest:mist_get:" - "got HTTP Error 429 from Mist. Will try with next API Token" - ) - self._next_apitoken() - return self.mist_get(uri, query) - logger.error(f"apirequest:mist_get:HTTP error occurred: {http_err}") - if resp: - logger.error( - f"apirequest:mist_get:HTTP error description: {resp.json()}" - ) - except Exception as err: - logger.error(f"apirequest:mist_get:Other error occurred: {err}") - logger.error("apirequest:mist_get:Exception occurred", exc_info=True) - finally: - self._count += 1 - return APIResponse(url=url, response=resp, proxy_error=proxy_failed) + url = self._url(uri) + self._gen_query(query) + return self._request_with_retry("mist_get", lambda: self._session.get(url), url) def mist_post(self, uri: str, body: dict | list | None = None) -> APIResponse: """ @@ -224,48 +266,14 @@ def mist_post(self, uri: str, body: dict | list | None = None) -> APIResponse: mistapi.APIResponse response from the API call """ - resp = None - proxy_failed = False - try: - url = self._url(uri) - logger.info(f"apirequest:mist_post:sending request to {url}") - headers = {"Content-Type": "application/json"} - logger.debug(f"apirequest:mist_post:Request body:{body}") - if isinstance(body, str): - self._log_proxy() - resp = self._session.post(url, data=body, headers=headers) - else: - self._log_proxy() - resp = self._session.post(url, json=body, headers=headers) - logger.debug( - f"apirequest:mist_post:request headers:{self._remove_auth_from_headers(resp)}" - ) - logger.debug("apirequest:mist_post:request body: %s", resp.request.body) - resp.raise_for_status() - except requests.exceptions.ProxyError as proxy_error: - logger.error(f"apirequest:mist_post:Proxy Error: {proxy_error}") - proxy_failed = True - except requests.exceptions.ConnectionError as connexion_error: - logger.error(f"apirequest:mist_post:Connection Error: {connexion_error}") - except HTTPError as http_err: - if http_err.response.status_code == 429: - logger.warning( - "apirequest:mist_post:" - "got HTTP Error 429 from Mist. Will try with next API Token" - ) - self._next_apitoken() - return self.mist_post(uri, body) - logger.error(f"apirequest:mist_post: HTTP error occurred: {http_err}") - if resp: - logger.error( - f"apirequest:mist_post: HTTP error description: {resp.json()}" - ) - except Exception as err: - logger.error(f"apirequest:mist_post: Other error occurred: {err}") - logger.error("apirequest:mist_post: Exception occurred", exc_info=True) - finally: - self._count += 1 - return APIResponse(url=url, response=resp, proxy_error=proxy_failed) + url = self._url(uri) + headers = {"Content-Type": "application/json"} + logger.debug("apirequest:mist_post:Request body:%s", body) + if isinstance(body, str): + fn = lambda: self._session.post(url, data=body, headers=headers) + else: + fn = lambda: self._session.post(url, json=body, headers=headers) + return self._request_with_retry("mist_post", fn, url) def mist_put(self, uri: str, body: dict | None = None) -> APIResponse: """ @@ -282,48 +290,14 @@ def mist_put(self, uri: str, body: dict | None = None) -> APIResponse: mistapi.APIResponse response from the API call """ - resp = None - proxy_failed = False - try: - url = self._url(uri) - logger.info(f"apirequest:mist_put:sending request to {url}") - headers = {"Content-Type": "application/json"} - logger.debug(f"apirequest:mist_put:Request body:{body}") - if isinstance(body, str): - self._log_proxy() - resp = self._session.put(url, data=body, headers=headers) - else: - self._log_proxy() - resp = self._session.put(url, json=body, headers=headers) - logger.debug( - f"apirequest:mist_put:request headers:{self._remove_auth_from_headers(resp)}" - ) - logger.debug("apirequest:mist_put:request body:%s", resp.request.body) - resp.raise_for_status() - except requests.exceptions.ProxyError as proxy_error: - logger.error(f"apirequest:mist_put:Proxy Error: {proxy_error}") - proxy_failed = True - except requests.exceptions.ConnectionError as connexion_error: - logger.error(f"apirequest:mist_put:Connection Error: {connexion_error}") - except HTTPError as http_err: - if http_err.response.status_code == 429: - logger.warning( - "apirequest:mist_put:" - "got HTTP Error 429 from Mist. Will try with next API Token" - ) - self._next_apitoken() - return self.mist_put(uri, body) - logger.error(f"apirequest:mist_put: HTTP error occurred: {http_err}") - if resp: - logger.error( - f"apirequest:mist_put: HTTP error description: {resp.json()}" - ) - except Exception as err: - logger.error(f"apirequest:mist_put: Other error occurred: {err}") - logger.error("apirequest:mist_put: Exception occurred", exc_info=True) - finally: - self._count += 1 - return APIResponse(url=url, response=resp, proxy_error=proxy_failed) + url = self._url(uri) + headers = {"Content-Type": "application/json"} + logger.debug("apirequest:mist_put:Request body:%s", body) + if isinstance(body, str): + fn = lambda: self._session.put(url, data=body, headers=headers) + else: + fn = lambda: self._session.put(url, json=body, headers=headers) + return self._request_with_retry("mist_put", fn, url) def mist_delete(self, uri: str, query: dict | None = None) -> APIResponse: """ @@ -339,37 +313,10 @@ def mist_delete(self, uri: str, query: dict | None = None) -> APIResponse: mistapi.APIResponse response from the API call """ - resp = None - proxy_failed = False - try: - url = self._url(uri) + self._gen_query(query) - logger.info(f"apirequest:mist_delete:sending request to {url}") - self._log_proxy() - resp = self._session.delete(url) - logger.debug( - f"apirequest:mist_delete:request headers:{self._remove_auth_from_headers(resp)}" - ) - resp.raise_for_status() - except requests.exceptions.ProxyError as proxy_error: - logger.error(f"apirequest:mist_delete:Proxy Error: {proxy_error}") - proxy_failed = True - except requests.exceptions.ConnectionError as connexion_error: - logger.error(f"apirequest:mist_delete:Connection Error: {connexion_error}") - except HTTPError as http_err: - if http_err.response.status_code == 429: - logger.warning( - "apirequest:mist_delete:" - "got HTTP Error 429 from Mist. Will try with next API Token" - ) - self._next_apitoken() - return self.mist_delete(uri, query) - logger.error(f"apirequest:mist_delete: HTTP error occurred: {http_err}") - except Exception as err: - logger.error(f"apirequest:mist_delete: Other error occurred: {err}") - logger.error("apirequest:mist_delete: Exception occurred", exc_info=True) - finally: - self._count += 1 - return APIResponse(url=url, response=resp, proxy_error=proxy_failed) + url = self._url(uri) + self._gen_query(query) + return self._request_with_retry( + "mist_delete", lambda: self._session.delete(url), url + ) def mist_post_file( self, uri: str, multipart_form_data: dict | None = None @@ -389,85 +336,60 @@ def mist_post_file( mistapi.APIResponse response from the API call """ - resp = None - proxy_failed = False - try: - if multipart_form_data is None: - multipart_form_data = {} - url = self._url(uri) - logger.info(f"apirequest:mist_post_file:sending request to {url}") + if multipart_form_data is None: + multipart_form_data = {} + url = self._url(uri) + logger.debug( + "apirequest:mist_post_file:initial multipart_form_data:%s", + multipart_form_data, + ) + generated_multipart_form_data: dict[str, Any] = {} + for key in multipart_form_data: logger.debug( - f"apirequest:mist_post_file:initial multipart_form_data:{multipart_form_data}" + "apirequest:mist_post_file:multipart_form_data:%s = %s", + key, + multipart_form_data[key], ) - generated_multipart_form_data: dict[str, Any] = {} - for key in multipart_form_data: - logger.debug( - f"apirequest:mist_post_file:" - f"multipart_form_data:{key} = {multipart_form_data[key]}" - ) - if multipart_form_data[key]: - try: - if key in ["csv", "file"]: - logger.debug( - f"apirequest:mist_post_file:reading file:{multipart_form_data[key]}" - ) - f = open(multipart_form_data[key], "rb") - generated_multipart_form_data[key] = ( - os.path.basename(multipart_form_data[key]), - f, - "application/octet-stream", - ) - else: - generated_multipart_form_data[key] = ( - None, - json.dumps(multipart_form_data[key]), - ) - except (OSError, json.JSONDecodeError): - logger.error( - f"apirequest:mist_post_file:multipart_form_data:" - f"Unable to parse JSON object {key} " - f"with value {multipart_form_data[key]}" + if multipart_form_data[key]: + try: + if key in ["csv", "file"]: + logger.debug( + "apirequest:mist_post_file:reading file:%s", + multipart_form_data[key], ) - logger.error( - "apirequest:mist_post_file: Exception occurred", - exc_info=True, + f = open(multipart_form_data[key], "rb") + generated_multipart_form_data[key] = ( + os.path.basename(multipart_form_data[key]), + f, + "application/octet-stream", ) - logger.debug( - f"apirequest:mist_post_file:" - f"final multipart_form_data:{generated_multipart_form_data}" - ) - self._log_proxy() + else: + generated_multipart_form_data[key] = ( + None, + json.dumps(multipart_form_data[key]), + ) + except (OSError, json.JSONDecodeError): + logger.error( + "apirequest:mist_post_file:multipart_form_data:" + "Unable to parse JSON object %s with value %s", + key, + multipart_form_data[key], + ) + logger.error( + "apirequest:mist_post_file: Exception occurred", + exc_info=True, + ) + logger.debug( + "apirequest:mist_post_file:final multipart_form_data:%s", + generated_multipart_form_data, + ) + + def _do_post_file(): resp = self._session.post(url, files=generated_multipart_form_data) logger.debug( - f"apirequest:mist_post_file:request headers:{self._remove_auth_from_headers(resp)}" - ) - logger.debug( - f"apirequest:mist_post_file:request body:{self.remove_file_from_body(resp)}" + "apirequest:mist_post_file:request body:%s", + self.remove_file_from_body(resp), ) - resp.raise_for_status() - except requests.exceptions.ProxyError as proxy_error: - logger.error(f"apirequest:mist_post_file:Proxy Error: {proxy_error}") - proxy_failed = True - except requests.exceptions.ConnectionError as connexion_error: - logger.error( - f"apirequest:mist_post_file:Connection Error: {connexion_error}" - ) - except HTTPError as http_err: - if http_err.response.status_code == 429: - logger.warning( - "apirequest:mist_post_file:" - "got HTTP Error 429 from Mist. Will try with next API Token" - ) - self._next_apitoken() - return self.mist_post_file(uri, multipart_form_data) - logger.error(f"apirequest:mist_post_file: HTTP error occurred: {http_err}") - if resp: - logger.error( - f"apirequest:mist_post_file: HTTP error description: {resp.json()}" - ) - except Exception as err: - logger.error(f"apirequest:mist_post_file: Other error occurred: {err}") - logger.error("apirequest:mist_post_file: Exception occurred", exc_info=True) - finally: - self._count += 1 - return APIResponse(url=url, response=resp, proxy_error=proxy_failed) + return resp + + return self._request_with_retry("mist_post_file", _do_post_file, url) diff --git a/src/mistapi/__api_session.py b/src/mistapi/__api_session.py index 154fb52..e84a191 100644 --- a/src/mistapi/__api_session.py +++ b/src/mistapi/__api_session.py @@ -21,7 +21,7 @@ import keyring import requests from dotenv import load_dotenv -from requests import Response, Session +from requests import Session from mistapi.__api_request import APIRequest from mistapi.__api_response import APIResponse @@ -132,17 +132,26 @@ def __init__( self._logging_log_level = logging_log_level self._show_cli_notif = show_cli_notif self._proxies = {"https": https_proxy} - self.vault_url = vault_url - self.vault_path = vault_path - self.vault_mount_point = vault_mount_point - self.vault_token = vault_token + self._vault_url = vault_url + self._vault_path = vault_path + self._vault_mount_point = vault_mount_point + self._vault_token = vault_token CONSOLE._set_log_level(console_log_level, logging_log_level) self._load_env(env_file) if keyring_service: self._load_keyring(keyring_service) - if self.vault_path: - self._load_vault() + if self._vault_path: + self._load_vault() # finally block deletes _vault_* attrs + else: + for attr in ( + "_vault_url", + "_vault_token", + "_vault_path", + "_vault_mount_point", + ): + if hasattr(self, attr): + delattr(self, attr) # Filter out None values before updating proxies filtered_proxies = {k: v for k, v in self._proxies.items() if v is not None} self._session.proxies.update(filtered_proxies) @@ -165,6 +174,18 @@ def __init__( LOGGER.debug("apisession:__init__: API Session initialized") + def _new_session(self) -> Session: + session = requests.session() + session.headers["Accept"] = "application/json, application/vnd.api+json" + filtered_proxies = {k: v for k, v in self._proxies.items() if v is not None} + if filtered_proxies: + session.proxies.update(filtered_proxies) + if self._apitoken and self._apitoken_index >= 0: + session.headers["Authorization"] = ( + "Token " + self._apitoken[self._apitoken_index] + ) + return session + def __str__(self) -> str: fields = [ "email", @@ -206,7 +227,7 @@ def _load_vault( Load Vault settings from env file """ LOGGER.info("apisession:_load_vault: Loading Vault settings") - client = hvac.Client(url=self.vault_url, token=self.vault_token, verify=False) + client = hvac.Client(url=self._vault_url, token=self._vault_token) if not client.is_authenticated(): LOGGER.error("apisession:_load_vault: Vault authentication failed") CONSOLE.error("Vault authentication failed") @@ -216,7 +237,7 @@ def _load_vault( ) try: read_response = client.secrets.kv.v2.read_secret( - path=self.vault_path, mount_point=self.vault_mount_point + path=self._vault_path, mount_point=self._vault_mount_point ) LOGGER.info("apisession:_load_vault: Secret retrieved successfully") @@ -232,10 +253,10 @@ def _load_vault( LOGGER.error("apisession:_load_vault: Failed to retrieve secret") CONSOLE.error("Failed to retrieve secret") finally: - del self.vault_url - del self.vault_path - del self.vault_mount_point - del self.vault_token + del self._vault_url + del self._vault_path + del self._vault_mount_point + del self._vault_token def _load_keyring(self, keyring_service) -> None: """ @@ -264,7 +285,7 @@ def _load_keyring(self, keyring_service) -> None: self.set_api_token(mist_apitoken) mist_user = keyring.get_password(keyring_service, "MIST_USER") if mist_user: - LOGGER.info("apisession:_load_keyring: MIST_USER=%s", mist_user) + LOGGER.info("apisession:_load_keyring: MIST_USER retrieved") self.set_email(mist_user) mist_password = keyring.get_password(keyring_service, "MIST_PASSWORD") if mist_password: @@ -287,7 +308,7 @@ def _load_env(self, env_file=None) -> None: os.path.expanduser("~"), env_file.replace("~/", "") ) env_file = os.path.abspath(env_file) - CONSOLE.debug(f"Loading settings from {env_file}") + CONSOLE.debug("Loading settings from %s", env_file) LOGGER.debug("apisession:_load_env:loading settings from %s", env_file) dotenv_path = Path(env_file) load_dotenv(dotenv_path=dotenv_path, override=True) @@ -322,17 +343,17 @@ def _load_env(self, env_file=None) -> None: except ValueError: self._logging_log_level = 10 # Default fallback - if os.getenv("MIST_VAULT_URL") and not self.vault_url: - self.vault_url = os.getenv("MIST_VAULT_URL") + if os.getenv("MIST_VAULT_URL") and not self._vault_url: + self._vault_url = os.getenv("MIST_VAULT_URL") - if os.getenv("MIST_VAULT_PATH") and not self.vault_path: - self.vault_path = os.getenv("MIST_VAULT_PATH") + if os.getenv("MIST_VAULT_PATH") and not self._vault_path: + self._vault_path = os.getenv("MIST_VAULT_PATH") - if os.getenv("MIST_VAULT_MOUNT_POINT") and not self.vault_mount_point: - self.vault_mount_point = os.getenv("MIST_VAULT_MOUNT_POINT") + if os.getenv("MIST_VAULT_MOUNT_POINT") and not self._vault_mount_point: + self._vault_mount_point = os.getenv("MIST_VAULT_MOUNT_POINT") - if os.getenv("MIST_VAULT_TOKEN") and not self.vault_token: - self.vault_token = os.getenv("MIST_VAULT_TOKEN") + if os.getenv("MIST_VAULT_TOKEN") and not self._vault_token: + self._vault_token = os.getenv("MIST_VAULT_TOKEN") if os.getenv("MIST_KEYRING_SERVICE"): self.keyring_service = os.getenv("MIST_KEYRING_SERVICE") @@ -368,10 +389,10 @@ def set_cloud(self, cloud_uri: str) -> None: LOGGER.debug( "apisession:set_cloud:Mist Cloud configured to %s", self._cloud_uri ) - CONSOLE.debug(f"Mist Cloud configured to {self._cloud_uri}") + CONSOLE.debug("Mist Cloud configured to %s", self._cloud_uri) else: LOGGER.error("apisession:set_cloud: %s is not valid", cloud_uri) - CONSOLE.error(f"{cloud_uri} is not valid") + CONSOLE.error("%s is not valid", cloud_uri) def get_cloud(self): """ @@ -445,8 +466,8 @@ def set_email(self, email: str | None = None) -> None: self.email = email else: self.email = input("Login: ") - LOGGER.info("apisession:set_email:email configured to %s", self.email) - CONSOLE.debug(f"Email configured to {self.email}") + LOGGER.info("apisession:set_email:email configured") + CONSOLE.debug("Email configured") def set_password(self, password: str | None = None) -> None: """ @@ -465,7 +486,7 @@ def set_password(self, password: str | None = None) -> None: LOGGER.info("apisession:set_password:password configured") CONSOLE.debug("Password configured") - def set_api_token(self, apitoken: str) -> None: + def set_api_token(self, apitoken: str, validate: bool = True) -> None: """ Set Mist API Token @@ -473,6 +494,9 @@ def set_api_token(self, apitoken: str) -> None: ----------- apitoken : str API Token to add in the requests headers for authentication and authorization + validate : bool, default True + If True, validate the API tokens against the Mist Cloud before using them. + If False, accept the tokens directly without validation. """ LOGGER.debug("apisession:set_api_token") apitokens_in = apitoken.split(",") @@ -483,7 +507,10 @@ def set_api_token(self, apitoken: str) -> None: apitokens_out.append(token) LOGGER.info("apisession:set_api_token:found %s API Tokens", len(apitokens_out)) - valid_api_tokens = self._check_api_tokens(apitokens_out) + if validate: + valid_api_tokens = self._check_api_tokens(apitokens_out) + else: + valid_api_tokens = apitokens_out if valid_api_tokens: self._apitoken = valid_api_tokens self._apitoken_index = 0 @@ -546,7 +573,10 @@ def _get_api_token_data(self, apitoken) -> tuple[str | None, list | None]: data.status_code, ) CONSOLE.critical( - f"Invalid API Token {apitoken[:4]}...{apitoken[-4:]}: status code {data.status_code}\r\n" + "Invalid API Token %s...%s: status code %s\r\n", + apitoken[:4], + apitoken[-4:], + data.status_code, ) raise ValueError( f"Invalid API Token {apitoken[:4]}...{apitoken[-4:]}: status code {data.status_code}" @@ -596,31 +626,32 @@ def _check_api_tokens(self, apitokens) -> list[str]: primary_token_type: str | None = "" primary_token_value: str = "" for token in apitokens: - token_value = f"{token[:4]}...{token[-4:]}" + not_sensitive_data = f"{token[:4]}...{token[-4:]}" if token in valid_api_tokens: LOGGER.info( "apisession:_check_api_tokens:API Token %s is already valid", - token_value, + not_sensitive_data, ) continue (token_type, token_privileges) = self._get_api_token_data(token) if token_type is None or token_privileges is None: LOGGER.error( "apisession:_check_api_tokens:API Token %s is not valid", - token_value, + not_sensitive_data, ) LOGGER.error( - "API Token %s is not valid and will not be used", token_value + "API Token %s is not valid and will not be used", + not_sensitive_data, ) elif len(primary_token_privileges) == 0 and token_privileges: primary_token_privileges = token_privileges primary_token_type = token_type - primary_token_value = token_value + primary_token_value = not_sensitive_data valid_api_tokens.append(token) LOGGER.info( "apisession:_check_api_tokens:" "API Token %s set as primary for comparison", - token_value, + not_sensitive_data, ) elif primary_token_privileges == token_privileges: valid_api_tokens.append(token) @@ -629,7 +660,7 @@ def _check_api_tokens(self, apitokens) -> list[str]: "%s API Token %s has same privileges as " "the %s API Token %s", token_type, - token_value, + not_sensitive_data, primary_token_type, primary_token_value, ) @@ -639,13 +670,13 @@ def _check_api_tokens(self, apitokens) -> list[str]: "%s API Token %s has different privileges " "than the %s API Token %s", token_type, - token_value, + not_sensitive_data, primary_token_type, primary_token_value, ) LOGGER.error( "API Token %s has different privileges and will not be used", - token_value, + not_sensitive_data, ) return valid_api_tokens @@ -666,7 +697,7 @@ def _process_login(self, retry: bool = True) -> str | None: print(" Login/Pwd authentication ".center(80, "-")) print() - self._session = requests.session() + self._session = self._new_session() if not self.email: self.set_email() if not self._password: @@ -686,7 +717,7 @@ def _process_login(self, retry: bool = True) -> str | None: LOGGER.error( "apisession:_process_login:authentication failed:%s", error ) - CONSOLE.error(f"Authentication failed: {error}\r\n") + CONSOLE.error("Authentication failed: %s\r\n", error) self.email = None self._password = None LOGGER.info( @@ -776,7 +807,7 @@ def login_with_return( Error message from Mist (if any) """ LOGGER.debug("apisession:login_with_return") - self._session = requests.session() + self._session = self._new_session() if apitoken: self.set_api_token(apitoken) if email: @@ -825,11 +856,18 @@ def login_with_return( LOGGER.error("apisession:login_with_return:credentials are missing") return {"authenticated": False, "error": "credentials are missing"} - if resp.status_code == 200 and not resp.data.get("two_factor_required", False): + if ( + resp.status_code == 200 + and isinstance(resp.data, dict) + and not resp.data.get("two_factor_required", False) + ): LOGGER.info("apisession:login_with_return:access authorized") return {"authenticated": True, "error": ""} else: - LOGGER.error("apisession:login_with_return:access denied: %s", resp.data) + LOGGER.error( + "apisession:login_with_return:access denied: status code %s", + resp.status_code, + ) return {"authenticated": False, "error": resp.data} def logout(self) -> None: @@ -850,7 +888,8 @@ def logout(self) -> None: self._set_authenticated(False) else: try: - CONSOLE.error(resp.data["detail"]) + if isinstance(resp.data, dict) and "detail" in resp.data: + CONSOLE.error(resp.data["detail"]) except (KeyError, TypeError, AttributeError): if isinstance(resp.raw_data, bytes): CONSOLE.error(resp.raw_data.decode("utf-8", errors="replace")) @@ -944,8 +983,7 @@ def get_api_token(self) -> APIResponse: LOGGER.info( 'apisession:get_api_token: Sending GET request to "/api/v1/self/apitokens"' ) - resp = self.mist_get("/api/v1/self/apitokens") - return resp + return self.mist_get("/api/v1/self/apitokens") def create_api_token(self, token_name: str | None = None) -> APIResponse: """ @@ -970,10 +1008,9 @@ def create_api_token(self, token_name: str | None = None) -> APIResponse: 'sending POST request to "/api/v1/self/apitokens" with name "%s"', token_name, ) - resp = self.mist_post("/api/v1/self/apitokens", body=body) - return resp + return self.mist_post("/api/v1/self/apitokens", body=body) - def delete_api_token(self, apitoken_id: str) -> Response: + def delete_api_token(self, apitoken_id: str) -> APIResponse: """ Delete an API Token based on its token_id @@ -993,9 +1030,7 @@ def delete_api_token(self, apitoken_id: str) -> Response: 'sending DELETE request to "/api/v1/self/apitokens" with token_id "%s"', apitoken_id, ) - uri = f"https://{self._cloud_uri}/api/v1/self/apitokens/{apitoken_id}" - resp = self._session.delete(uri) - return resp + return self.mist_delete(f"/api/v1/self/apitokens/{apitoken_id}") def _two_factor_authentication(self, two_factor: str) -> bool: """ @@ -1034,7 +1069,8 @@ def _two_factor_authentication(self, two_factor: str) -> bool: resp.status_code, ) CONSOLE.error( - f"2FA authentication failed with error code: {resp.status_code}\r\n" + "2FA authentication failed with error code: %s\r\n", + resp.status_code, ) return False @@ -1046,7 +1082,7 @@ def _getself(self) -> bool: uri = "/api/v1/self" LOGGER.info('apisession:_getself: sending GET request to "%s"', uri) resp = self.mist_get(uri) - if resp.status_code == 200 and resp.data: + if resp.status_code == 200 and resp.data and isinstance(resp.data, dict): # Deal with 2FA if needed if ( resp.data.get("two_factor_required") is True @@ -1063,24 +1099,27 @@ def _getself(self) -> bool: LOGGER.info( "apisession:_getself:authentication Ok. Processing account privileges" ) - for key, val in resp.data.items(): - if key == "privileges": - self.privileges = Privileges(resp.data["privileges"]) - if key == "tags": - for tag in resp.data["tags"]: - self.tags.append(tag) - else: - setattr(self, key, val) - if self._show_cli_notif: - print() - print(" Authenticated ".center(80, "-")) - print(f"\r\nWelcome {self.first_name} {self.last_name}!\r\n") - LOGGER.info( - "apisession:_getself:account used: %s %s", - self.first_name, - self.last_name, - ) - return True + if isinstance(resp.data, dict): + for key, val in resp.data.items(): + if key == "privileges": + self.privileges = Privileges(resp.data["privileges"]) + if key == "tags": + for tag in resp.data["tags"]: + self.tags.append(tag) + else: + setattr(self, key, val) + if self._show_cli_notif: + print() + print(" Authenticated ".center(80, "-")) + print(f"\r\nWelcome {self.first_name} {self.last_name}!\r\n") + LOGGER.info( + "apisession:_getself:account info processed successfully" + ) + return True + else: + raise ValueError( + "Unexpected format for privileges in the response data" + ) elif resp.proxy_error: LOGGER.critical("apisession:_getself:proxy not valid...") CONSOLE.critical("Proxy not valid...\r\n") @@ -1144,7 +1183,11 @@ def get_privilege_by_org_id(self, org_id: str): msp_id = None try: resp = self.mist_get(uri) - if resp.data and resp.data.get("msp_id"): + if ( + resp.data + and isinstance(resp.data, dict) + and resp.data.get("msp_id") + ): LOGGER.info( "apisession:get_privilege_by_org_id:org %s belong to msp_id %s", {org_id}, @@ -1178,7 +1221,7 @@ def get_privilege_by_org_id(self, org_id: str): "unable of find msp %s privileges in user data", msp_id, ) - else: + elif isinstance(resp.data, dict): return { "scope": "org", "org_id": org_id, diff --git a/src/mistapi/__init__.py b/src/mistapi/__init__.py index 891e035..d211453 100644 --- a/src/mistapi/__init__.py +++ b/src/mistapi/__init__.py @@ -10,10 +10,34 @@ -------------------------------------------------------------------------------- """ +# isort: skip_file from mistapi.__api_session import APISession as APISession -from mistapi import api as api -from mistapi import cli as cli from mistapi.__pagination import get_all as get_all from mistapi.__pagination import get_next as get_next from mistapi.__version import __author__ as __author__ from mistapi.__version import __version__ as __version__ + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from mistapi import api as api + from mistapi import cli as cli + from mistapi import device_utils as device_utils + from mistapi import websockets as websockets + +_LAZY_SUBPACKAGES = { + "api": "mistapi.api", + "cli": "mistapi.cli", + "websockets": "mistapi.websockets", + "device_utils": "mistapi.device_utils", +} + + +def __getattr__(name: str): + if name in _LAZY_SUBPACKAGES: + import importlib + + module = importlib.import_module(_LAZY_SUBPACKAGES[name]) + globals()[name] = module + return module + raise AttributeError(f"module 'mistapi' has no attribute {name!r}") diff --git a/src/mistapi/__logger.py b/src/mistapi/__logger.py index e870683..f69c366 100644 --- a/src/mistapi/__logger.py +++ b/src/mistapi/__logger.py @@ -150,60 +150,31 @@ def sanitize(self, data) -> str: sanitized_data = self.sensitive_pattern.sub(r'\1"******"', data_str) return sanitized_data - def critical(self, message) -> None: - """ - Docstring for critical + def _format(self, message, args) -> str: + """Apply %-style formatting if args are provided, then sanitize.""" + if args: + message = str(message) % args + return self.sanitize(message) - :param self: Description - :param message: Description - :type message: str - """ + def critical(self, message, *args) -> None: if self.level <= 50 and self.level > 0: - print(f"[{magenta('CRITICAL ')}] {self.sanitize(message)}") - - def error(self, message) -> None: - """ - Docstring for error + print(f"[{magenta('CRITICAL ')}] {self._format(message, args)}") - :param self: Description - :param message: Description - :type message: str - """ + def error(self, message, *args) -> None: if self.level <= 40 and self.level > 0: - print(f"[{red(' ERROR ')}] {self.sanitize(message)}") - - def warning(self, message) -> None: - """ - Docstring for warning + print(f"[{red(' ERROR ')}] {self._format(message, args)}") - :param self: Description - :param message: Description - :type message: str - """ + def warning(self, message, *args) -> None: if self.level <= 30 and self.level > 0: - print(f"[{yellow(' WARNING ')}] {self.sanitize(message)}") - - def info(self, message) -> None: - """ - Docstring for info + print(f"[{yellow(' WARNING ')}] {self._format(message, args)}") - :param self: Description - :param message: Description - :type message: str - """ + def info(self, message, *args) -> None: if self.level <= 20 and self.level > 0: - print(f"[{green(' INFO ')}] {self.sanitize(message)}") + print(f"[{green(' INFO ')}] {self._format(message, args)}") - def debug(self, message) -> None: - """ - Docstring for debug - - :param self: Description - :param message: Description - :type message: str - """ + def debug(self, message, *args) -> None: if self.level <= 10 and self.level > 0: - print(f"[{white('DEBUG ')}] {self.sanitize(message)}") + print(f"[{white('DEBUG ')}] {self._format(message, args)}") def _set_log_level( self, console_log_level: int = 20, logging_log_level: int = 10 @@ -238,7 +209,8 @@ def __init__(self): self.console = Console() def filter(self, record): - record.msg = self.console.sanitize(record.msg) + record.msg = self.console.sanitize(record.getMessage()) + record.args = None return True diff --git a/src/mistapi/__version.py b/src/mistapi/__version.py index 957e2ce..4e16e40 100644 --- a/src/mistapi/__version.py +++ b/src/mistapi/__version.py @@ -1,2 +1,2 @@ -__version__ = "0.60.4" +__version__ = "0.61.0" __author__ = "Thomas Munzer " diff --git a/src/mistapi/api/v1/orgs/clients.py b/src/mistapi/api/v1/orgs/clients.py index ff81659..9e765f5 100644 --- a/src/mistapi/api/v1/orgs/clients.py +++ b/src/mistapi/api/v1/orgs/clients.py @@ -284,20 +284,20 @@ def searchOrgWirelessClients( mist_session: _APISession, org_id: str, site_id: str | None = None, - mac: str | None = None, - ip: str | None = None, - hostname: str | None = None, + ap: str | None = None, band: str | None = None, device: str | None = None, - os: str | None = None, + hostname: str | None = None, + ip: str | None = None, + mac: str | None = None, model: str | None = None, - ap: str | None = None, + os: str | None = None, psk_id: str | None = None, psk_name: str | None = None, - username: str | None = None, - vlan: str | None = None, ssid: str | None = None, text: str | None = None, + username: str | None = None, + vlan: str | None = None, limit: int | None = None, start: str | None = None, end: str | None = None, @@ -320,20 +320,20 @@ def searchOrgWirelessClients( QUERY PARAMS ------------ site_id : str - mac : str - ip : str - hostname : str + ap : str band : str device : str - os : str + hostname : str + ip : str + mac : str model : str - ap : str + os : str psk_id : str psk_name : str - username : str - vlan : str ssid : str text : str + username : str + vlan : str limit : int, default: 100 start : str end : str @@ -351,34 +351,34 @@ def searchOrgWirelessClients( query_params: dict[str, str] = {} if site_id: query_params["site_id"] = str(site_id) - if mac: - query_params["mac"] = str(mac) - if ip: - query_params["ip"] = str(ip) - if hostname: - query_params["hostname"] = str(hostname) + if ap: + query_params["ap"] = str(ap) if band: query_params["band"] = str(band) if device: query_params["device"] = str(device) - if os: - query_params["os"] = str(os) + if hostname: + query_params["hostname"] = str(hostname) + if ip: + query_params["ip"] = str(ip) + if mac: + query_params["mac"] = str(mac) if model: query_params["model"] = str(model) - if ap: - query_params["ap"] = str(ap) + if os: + query_params["os"] = str(os) if psk_id: query_params["psk_id"] = str(psk_id) if psk_name: query_params["psk_name"] = str(psk_name) - if username: - query_params["username"] = str(username) - if vlan: - query_params["vlan"] = str(vlan) if ssid: query_params["ssid"] = str(ssid) if text: query_params["text"] = str(text) + if username: + query_params["username"] = str(username) + if vlan: + query_params["vlan"] = str(vlan) if limit: query_params["limit"] = str(limit) if start: diff --git a/src/mistapi/api/v1/orgs/devices.py b/src/mistapi/api/v1/orgs/devices.py index ac09149..bc97744 100644 --- a/src/mistapi/api/v1/orgs/devices.py +++ b/src/mistapi/api/v1/orgs/devices.py @@ -485,17 +485,16 @@ def listOrgApsMacs( def searchOrgDevices( mist_session: _APISession, org_id: str, - band_24_bandwidth: int | None = None, band_24_channel: int | None = None, - band_24_power: int | None = None, - band_5_bandwidth: int | None = None, band_5_channel: int | None = None, - band_5_power: int | None = None, - band_6_bandwidth: int | None = None, band_6_channel: int | None = None, + band_24_bandwidth: int | None = None, + band_5_bandwidth: int | None = None, + band_6_bandwidth: int | None = None, + band_24_power: int | None = None, + band_5_power: int | None = None, band_6_power: int | None = None, - cpu: str | None = None, - clustered: str | None = None, + clustered: bool | None = None, eth0_port_speed: int | None = None, evpntopo_id: str | None = None, ext_ip: str | None = None, @@ -505,8 +504,6 @@ def searchOrgDevices( last_hostname: str | None = None, lldp_mgmt_addr: str | None = None, lldp_port_id: str | None = None, - lldp_power_allocated: int | None = None, - lldp_power_draw: int | None = None, lldp_system_desc: str | None = None, lldp_system_name: str | None = None, mac: str | None = None, @@ -518,10 +515,12 @@ def searchOrgDevices( node0_mac: str | None = None, node1_mac: str | None = None, power_constrained: bool | None = None, + radius_stats: str | None = None, site_id: str | None = None, + stats: bool | None = None, t128agent_version: str | None = None, - version: str | None = None, type: str | None = None, + version: str | None = None, limit: int | None = None, start: str | None = None, end: str | None = None, @@ -543,17 +542,16 @@ def searchOrgDevices( QUERY PARAMS ------------ - band_24_bandwidth : int band_24_channel : int - band_24_power : int - band_5_bandwidth : int band_5_channel : int - band_5_power : int - band_6_bandwidth : int band_6_channel : int + band_24_bandwidth : int + band_5_bandwidth : int + band_6_bandwidth : int + band_24_power : int + band_5_power : int band_6_power : int - cpu : str - clustered : str + clustered : bool eth0_port_speed : int evpntopo_id : str ext_ip : str @@ -563,8 +561,6 @@ def searchOrgDevices( last_hostname : str lldp_mgmt_addr : str lldp_port_id : str - lldp_power_allocated : int - lldp_power_draw : int lldp_system_desc : str lldp_system_name : str mac : str @@ -572,16 +568,19 @@ def searchOrgDevices( mxedge_id : str mxedge_ids : str mxtunnel_status : str{'down', 'up'} - If `type`==`ap`, MxTunnel status, up / down - node : str + When `type`==`ap`, MxTunnel status, up / down. + node : str{'node0', 'node1'} + When `type`==`gateway`. enum: `node0`, `node1` node0_mac : str node1_mac : str power_constrained : bool + radius_stats : str site_id : str + stats : bool t128agent_version : str - version : str type : str{'ap', 'gateway', 'switch'}, default: ap Type of device. enum: `ap`, `gateway`, `switch` + version : str limit : int, default: 100 start : str end : str @@ -597,26 +596,24 @@ def searchOrgDevices( uri = f"/api/v1/orgs/{org_id}/devices/search" query_params: dict[str, str] = {} - if band_24_bandwidth: - query_params["band_24_bandwidth"] = str(band_24_bandwidth) if band_24_channel: query_params["band_24_channel"] = str(band_24_channel) - if band_24_power: - query_params["band_24_power"] = str(band_24_power) - if band_5_bandwidth: - query_params["band_5_bandwidth"] = str(band_5_bandwidth) if band_5_channel: query_params["band_5_channel"] = str(band_5_channel) - if band_5_power: - query_params["band_5_power"] = str(band_5_power) - if band_6_bandwidth: - query_params["band_6_bandwidth"] = str(band_6_bandwidth) if band_6_channel: query_params["band_6_channel"] = str(band_6_channel) + if band_24_bandwidth: + query_params["band_24_bandwidth"] = str(band_24_bandwidth) + if band_5_bandwidth: + query_params["band_5_bandwidth"] = str(band_5_bandwidth) + if band_6_bandwidth: + query_params["band_6_bandwidth"] = str(band_6_bandwidth) + if band_24_power: + query_params["band_24_power"] = str(band_24_power) + if band_5_power: + query_params["band_5_power"] = str(band_5_power) if band_6_power: query_params["band_6_power"] = str(band_6_power) - if cpu: - query_params["cpu"] = str(cpu) if clustered: query_params["clustered"] = str(clustered) if eth0_port_speed: @@ -637,10 +634,6 @@ def searchOrgDevices( query_params["lldp_mgmt_addr"] = str(lldp_mgmt_addr) if lldp_port_id: query_params["lldp_port_id"] = str(lldp_port_id) - if lldp_power_allocated: - query_params["lldp_power_allocated"] = str(lldp_power_allocated) - if lldp_power_draw: - query_params["lldp_power_draw"] = str(lldp_power_draw) if lldp_system_desc: query_params["lldp_system_desc"] = str(lldp_system_desc) if lldp_system_name: @@ -663,14 +656,18 @@ def searchOrgDevices( query_params["node1_mac"] = str(node1_mac) if power_constrained: query_params["power_constrained"] = str(power_constrained) + if radius_stats: + query_params["radius_stats"] = str(radius_stats) if site_id: query_params["site_id"] = str(site_id) + if stats: + query_params["stats"] = str(stats) if t128agent_version: query_params["t128agent_version"] = str(t128agent_version) - if version: - query_params["version"] = str(version) if type: query_params["type"] = str(type) + if version: + query_params["version"] = str(version) if limit: query_params["limit"] = str(limit) if start: diff --git a/src/mistapi/api/v1/orgs/inventory.py b/src/mistapi/api/v1/orgs/inventory.py index 2be3bc5..85103b0 100644 --- a/src/mistapi/api/v1/orgs/inventory.py +++ b/src/mistapi/api/v1/orgs/inventory.py @@ -317,8 +317,7 @@ def searchOrgInventory( type: str | None = None, mac: str | None = None, model: str | None = None, - vc_mac: str | None = None, - master_mac: str | None = None, + name: str | None = None, site_id: str | None = None, serial: str | None = None, master: str | None = None, @@ -347,14 +346,14 @@ def searchOrgInventory( type : str{'ap', 'gateway', 'switch'}, default: ap mac : str model : str - vc_mac : str - master_mac : str + name : str site_id : str serial : str master : str sku : str version : str - status : str + status : str{'connected', 'disconnected'} + Device status. enum: `connected`, `disconnected` text : str limit : int, default: 100 sort : str, default: timestamp @@ -374,10 +373,8 @@ def searchOrgInventory( query_params["mac"] = str(mac) if model: query_params["model"] = str(model) - if vc_mac: - query_params["vc_mac"] = str(vc_mac) - if master_mac: - query_params["master_mac"] = str(master_mac) + if name: + query_params["name"] = str(name) if site_id: query_params["site_id"] = str(site_id) if serial: diff --git a/src/mistapi/api/v1/orgs/mxedges.py b/src/mistapi/api/v1/orgs/mxedges.py index 81986a8..4fc974f 100644 --- a/src/mistapi/api/v1/orgs/mxedges.py +++ b/src/mistapi/api/v1/orgs/mxedges.py @@ -379,12 +379,13 @@ def searchOrgMistEdgeEvents( def searchOrgMxEdges( mist_session: _APISession, org_id: str, + hostname: str | None = None, mxedge_id: str | None = None, - site_id: str | None = None, mxcluster_id: str | None = None, model: str | None = None, distro: str | None = None, tunterm_version: str | None = None, + site_id: str | None = None, stats: bool | None = None, limit: int | None = None, start: str | None = None, @@ -407,12 +408,13 @@ def searchOrgMxEdges( QUERY PARAMS ------------ + hostname : str mxedge_id : str - site_id : str mxcluster_id : str model : str distro : str tunterm_version : str + site_id : str stats : bool limit : int, default: 100 start : str @@ -429,10 +431,10 @@ def searchOrgMxEdges( uri = f"/api/v1/orgs/{org_id}/mxedges/search" query_params: dict[str, str] = {} + if hostname: + query_params["hostname"] = str(hostname) if mxedge_id: query_params["mxedge_id"] = str(mxedge_id) - if site_id: - query_params["site_id"] = str(site_id) if mxcluster_id: query_params["mxcluster_id"] = str(mxcluster_id) if model: @@ -441,6 +443,8 @@ def searchOrgMxEdges( query_params["distro"] = str(distro) if tunterm_version: query_params["tunterm_version"] = str(tunterm_version) + if site_id: + query_params["site_id"] = str(site_id) if stats: query_params["stats"] = str(stats) if limit: diff --git a/src/mistapi/api/v1/orgs/nac_clients.py b/src/mistapi/api/v1/orgs/nac_clients.py index 708ebe4..24e6515 100644 --- a/src/mistapi/api/v1/orgs/nac_clients.py +++ b/src/mistapi/api/v1/orgs/nac_clients.py @@ -417,8 +417,8 @@ def searchOrgNacClients( ingress_vlan : str os : str ssid : str - status : str{'permitted', 'session_started', 'session_ended', 'denied'} - Connection status of client i.e "permitted", "denied, "session_stared", "session_ended" + status : str{'permitted', 'session_started', 'session_stopped', 'denied'} + Connection status of client i.e "permitted", "denied, "session_started", "session_stopped" text : str timestamp : float type : str diff --git a/src/mistapi/api/v1/orgs/wan_clients.py b/src/mistapi/api/v1/orgs/wan_clients.py index 4af3281..c191c8b 100644 --- a/src/mistapi/api/v1/orgs/wan_clients.py +++ b/src/mistapi/api/v1/orgs/wan_clients.py @@ -148,12 +148,12 @@ def searchOrgWanClients( mist_session: _APISession, org_id: str, site_id: str | None = None, - mac: str | None = None, hostname: str | None = None, ip: str | None = None, - network: str | None = None, ip_src: str | None = None, + mac: str | None = None, mfg: str | None = None, + network: str | None = None, limit: int | None = None, start: str | None = None, end: str | None = None, @@ -176,12 +176,12 @@ def searchOrgWanClients( QUERY PARAMS ------------ site_id : str - mac : str hostname : str ip : str - network : str ip_src : str + mac : str mfg : str + network : str limit : int, default: 100 start : str end : str @@ -199,18 +199,18 @@ def searchOrgWanClients( query_params: dict[str, str] = {} if site_id: query_params["site_id"] = str(site_id) - if mac: - query_params["mac"] = str(mac) if hostname: query_params["hostname"] = str(hostname) if ip: query_params["ip"] = str(ip) - if network: - query_params["network"] = str(network) if ip_src: query_params["ip_src"] = str(ip_src) + if mac: + query_params["mac"] = str(mac) if mfg: query_params["mfg"] = str(mfg) + if network: + query_params["network"] = str(network) if limit: query_params["limit"] = str(limit) if start: diff --git a/src/mistapi/api/v1/sites/clients.py b/src/mistapi/api/v1/sites/clients.py index 5ec1955..8a52fcc 100644 --- a/src/mistapi/api/v1/sites/clients.py +++ b/src/mistapi/api/v1/sites/clients.py @@ -301,16 +301,20 @@ def searchSiteWirelessClientEvents( def searchSiteWirelessClients( mist_session: _APISession, site_id: str, - mac: str | None = None, - ip: str | None = None, - hostname: str | None = None, + ap: str | None = None, + band: str | None = None, device: str | None = None, - os: str | None = None, + hostname: str | None = None, + ip: str | None = None, + mac: str | None = None, model: str | None = None, - ap: str | None = None, + os: str | None = None, + psk_id: str | None = None, + psk_name: str | None = None, ssid: str | None = None, text: str | None = None, - nacrule_id: str | None = None, + username: str | None = None, + vlan: str | None = None, limit: int | None = None, start: str | None = None, end: str | None = None, @@ -332,16 +336,20 @@ def searchSiteWirelessClients( QUERY PARAMS ------------ - mac : str - ip : str - hostname : str + ap : str + band : str device : str - os : str + hostname : str + ip : str + mac : str model : str - ap : str + os : str + psk_id : str + psk_name : str ssid : str text : str - nacrule_id : str + username : str + vlan : str limit : int, default: 100 start : str end : str @@ -357,26 +365,34 @@ def searchSiteWirelessClients( uri = f"/api/v1/sites/{site_id}/clients/search" query_params: dict[str, str] = {} - if mac: - query_params["mac"] = str(mac) - if ip: - query_params["ip"] = str(ip) - if hostname: - query_params["hostname"] = str(hostname) + if ap: + query_params["ap"] = str(ap) + if band: + query_params["band"] = str(band) if device: query_params["device"] = str(device) - if os: - query_params["os"] = str(os) + if hostname: + query_params["hostname"] = str(hostname) + if ip: + query_params["ip"] = str(ip) + if mac: + query_params["mac"] = str(mac) if model: query_params["model"] = str(model) - if ap: - query_params["ap"] = str(ap) + if os: + query_params["os"] = str(os) + if psk_id: + query_params["psk_id"] = str(psk_id) + if psk_name: + query_params["psk_name"] = str(psk_name) if ssid: query_params["ssid"] = str(ssid) if text: query_params["text"] = str(text) - if nacrule_id: - query_params["nacrule_id"] = str(nacrule_id) + if username: + query_params["username"] = str(username) + if vlan: + query_params["vlan"] = str(vlan) if limit: query_params["limit"] = str(limit) if start: diff --git a/src/mistapi/api/v1/sites/devices.py b/src/mistapi/api/v1/sites/devices.py index 6fcd1b9..026a4eb 100644 --- a/src/mistapi/api/v1/sites/devices.py +++ b/src/mistapi/api/v1/sites/devices.py @@ -851,39 +851,41 @@ def restoreSiteMultipleDeviceBackupVersion( def searchSiteDevices( mist_session: _APISession, site_id: str, - hostname: str | None = None, - type: str | None = None, - model: str | None = None, - mac: str | None = None, - ext_ip: str | None = None, - version: str | None = None, - power_constrained: bool | None = None, - ip: str | None = None, - mxtunnel_status: str | None = None, - mxedge_id: str | None = None, - mxedge_ids: list | None = None, - last_hostname: str | None = None, - last_config_status: str | None = None, - radius_stats: str | None = None, - cpu: str | None = None, - node0_mac: str | None = None, - clustered: bool | None = None, - t128agent_version: str | None = None, - node1_mac: str | None = None, - node: str | None = None, - evpntopo_id: str | None = None, - lldp_system_name: str | None = None, - lldp_system_desc: str | None = None, - lldp_port_id: str | None = None, - lldp_mgmt_addr: str | None = None, band_24_channel: int | None = None, band_5_channel: int | None = None, band_6_channel: int | None = None, band_24_bandwidth: int | None = None, band_5_bandwidth: int | None = None, band_6_bandwidth: int | None = None, + band_24_power: int | None = None, + band_5_power: int | None = None, + band_6_power: int | None = None, + clustered: bool | None = None, eth0_port_speed: int | None = None, + evpntopo_id: str | None = None, + ext_ip: str | None = None, + hostname: str | None = None, + ip: str | None = None, + last_config_status: str | None = None, + last_hostname: str | None = None, + lldp_mgmt_addr: str | None = None, + lldp_port_id: str | None = None, + lldp_system_desc: str | None = None, + lldp_system_name: str | None = None, + mac: str | None = None, + model: str | None = None, + mxedge_id: str | None = None, + mxedge_ids: str | None = None, + mxtunnel_status: str | None = None, + node: str | None = None, + node0_mac: str | None = None, + node1_mac: str | None = None, + power_constrained: bool | None = None, + radius_stats: str | None = None, stats: bool | None = None, + t128agent_version: str | None = None, + type: str | None = None, + version: str | None = None, limit: int | None = None, start: str | None = None, end: str | None = None, @@ -906,42 +908,44 @@ def searchSiteDevices( QUERY PARAMS ------------ - hostname : str - type : str{'ap', 'gateway', 'switch'}, default: ap - model : str - mac : str - ext_ip : str - version : str - power_constrained : bool - ip : str - mxtunnel_status : str{'down', 'up'} - For APs only, MxTunnel status, up / down. - mxedge_id : str - mxedge_ids : list - For APs only, list of Mist Edge id, if AP is connecting to a Mist Edge - last_hostname : str - last_config_status : str - radius_stats : str - cpu : str - node0_mac : str - clustered : bool - t128agent_version : str - node1_mac : str - node : str{'node0', 'node1'} - For Gateways only. enum: `node0`, `node1` - evpntopo_id : str - lldp_system_name : str - lldp_system_desc : str - lldp_port_id : str - lldp_mgmt_addr : str band_24_channel : int band_5_channel : int band_6_channel : int band_24_bandwidth : int band_5_bandwidth : int band_6_bandwidth : int + band_24_power : int + band_5_power : int + band_6_power : int + clustered : bool eth0_port_speed : int + evpntopo_id : str + ext_ip : str + hostname : str + ip : str + last_config_status : str + last_hostname : str + lldp_mgmt_addr : str + lldp_port_id : str + lldp_system_desc : str + lldp_system_name : str + mac : str + model : str + mxedge_id : str + mxedge_ids : str + mxtunnel_status : str{'down', 'up'} + When `type`==`ap`, MxTunnel status, up / down. + node : str{'node0', 'node1'} + When `type`==`gateway`. enum: `node0`, `node1` + node0_mac : str + node1_mac : str + power_constrained : bool + radius_stats : str stats : bool + t128agent_version : str + type : str{'ap', 'gateway', 'switch'}, default: ap + Type of device. enum: `ap`, `gateway`, `switch` + version : str limit : int, default: 100 start : str end : str @@ -960,56 +964,6 @@ def searchSiteDevices( uri = f"/api/v1/sites/{site_id}/devices/search" query_params: dict[str, str] = {} - if hostname: - query_params["hostname"] = str(hostname) - if type: - query_params["type"] = str(type) - if model: - query_params["model"] = str(model) - if mac: - query_params["mac"] = str(mac) - if ext_ip: - query_params["ext_ip"] = str(ext_ip) - if version: - query_params["version"] = str(version) - if power_constrained: - query_params["power_constrained"] = str(power_constrained) - if ip: - query_params["ip"] = str(ip) - if mxtunnel_status: - query_params["mxtunnel_status"] = str(mxtunnel_status) - if mxedge_id: - query_params["mxedge_id"] = str(mxedge_id) - if mxedge_ids: - query_params["mxedge_ids"] = str(mxedge_ids) - if last_hostname: - query_params["last_hostname"] = str(last_hostname) - if last_config_status: - query_params["last_config_status"] = str(last_config_status) - if radius_stats: - query_params["radius_stats"] = str(radius_stats) - if cpu: - query_params["cpu"] = str(cpu) - if node0_mac: - query_params["node0_mac"] = str(node0_mac) - if clustered: - query_params["clustered"] = str(clustered) - if t128agent_version: - query_params["t128agent_version"] = str(t128agent_version) - if node1_mac: - query_params["node1_mac"] = str(node1_mac) - if node: - query_params["node"] = str(node) - if evpntopo_id: - query_params["evpntopo_id"] = str(evpntopo_id) - if lldp_system_name: - query_params["lldp_system_name"] = str(lldp_system_name) - if lldp_system_desc: - query_params["lldp_system_desc"] = str(lldp_system_desc) - if lldp_port_id: - query_params["lldp_port_id"] = str(lldp_port_id) - if lldp_mgmt_addr: - query_params["lldp_mgmt_addr"] = str(lldp_mgmt_addr) if band_24_channel: query_params["band_24_channel"] = str(band_24_channel) if band_5_channel: @@ -1022,10 +976,64 @@ def searchSiteDevices( query_params["band_5_bandwidth"] = str(band_5_bandwidth) if band_6_bandwidth: query_params["band_6_bandwidth"] = str(band_6_bandwidth) + if band_24_power: + query_params["band_24_power"] = str(band_24_power) + if band_5_power: + query_params["band_5_power"] = str(band_5_power) + if band_6_power: + query_params["band_6_power"] = str(band_6_power) + if clustered: + query_params["clustered"] = str(clustered) if eth0_port_speed: query_params["eth0_port_speed"] = str(eth0_port_speed) + if evpntopo_id: + query_params["evpntopo_id"] = str(evpntopo_id) + if ext_ip: + query_params["ext_ip"] = str(ext_ip) + if hostname: + query_params["hostname"] = str(hostname) + if ip: + query_params["ip"] = str(ip) + if last_config_status: + query_params["last_config_status"] = str(last_config_status) + if last_hostname: + query_params["last_hostname"] = str(last_hostname) + if lldp_mgmt_addr: + query_params["lldp_mgmt_addr"] = str(lldp_mgmt_addr) + if lldp_port_id: + query_params["lldp_port_id"] = str(lldp_port_id) + if lldp_system_desc: + query_params["lldp_system_desc"] = str(lldp_system_desc) + if lldp_system_name: + query_params["lldp_system_name"] = str(lldp_system_name) + if mac: + query_params["mac"] = str(mac) + if model: + query_params["model"] = str(model) + if mxedge_id: + query_params["mxedge_id"] = str(mxedge_id) + if mxedge_ids: + query_params["mxedge_ids"] = str(mxedge_ids) + if mxtunnel_status: + query_params["mxtunnel_status"] = str(mxtunnel_status) + if node: + query_params["node"] = str(node) + if node0_mac: + query_params["node0_mac"] = str(node0_mac) + if node1_mac: + query_params["node1_mac"] = str(node1_mac) + if power_constrained: + query_params["power_constrained"] = str(power_constrained) + if radius_stats: + query_params["radius_stats"] = str(radius_stats) if stats: query_params["stats"] = str(stats) + if t128agent_version: + query_params["t128agent_version"] = str(t128agent_version) + if type: + query_params["type"] = str(type) + if version: + query_params["version"] = str(version) if limit: query_params["limit"] = str(limit) if start: diff --git a/src/mistapi/api/v1/sites/insights.py b/src/mistapi/api/v1/sites/insights.py index 9741b2c..3033df7 100644 --- a/src/mistapi/api/v1/sites/insights.py +++ b/src/mistapi/api/v1/sites/insights.py @@ -132,7 +132,7 @@ def getSiteInsightMetricsForDevice( return resp -def countOrgClientFingerprints( +def countSiteClientFingerprints( mist_session: _APISession, site_id: str, distinct: str | None = None, @@ -142,7 +142,7 @@ def countOrgClientFingerprints( limit: int | None = None, ) -> _APIResponse: """ - API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/orgs/nac-fingerprints/count-org-client-fingerprints + API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/nac-fingerprints/count-site-client-fingerprints PARAMS ----------- @@ -183,7 +183,7 @@ def countOrgClientFingerprints( return resp -def searchOrgClientFingerprints( +def searchSiteClientFingerprints( mist_session: _APISession, site_id: str, family: str | None = None, @@ -202,7 +202,7 @@ def searchOrgClientFingerprints( search_after: str | None = None, ) -> _APIResponse: """ - API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/orgs/nac-fingerprints/search-org-client-fingerprints + API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/nac-fingerprints/search-site-client-fingerprints PARAMS ----------- diff --git a/src/mistapi/api/v1/sites/nac_clients.py b/src/mistapi/api/v1/sites/nac_clients.py index 92d7bff..e5a0e81 100644 --- a/src/mistapi/api/v1/sites/nac_clients.py +++ b/src/mistapi/api/v1/sites/nac_clients.py @@ -334,6 +334,7 @@ def searchSiteNacClients( site_id: str, ap: str | None = None, auth_type: str | None = None, + cert_expiry_duration: str | None = None, edr_managed: bool | None = None, edr_provider: str | None = None, edr_status: str | None = None, @@ -341,15 +342,14 @@ def searchSiteNacClients( hostname: str | None = None, idp_id: str | None = None, mac: str | None = None, - mdm_managed: bool | None = None, mdm_compliance: str | None = None, mdm_provider: str | None = None, + mdm_managed: bool | None = None, mfg: str | None = None, model: str | None = None, - mxedge_id: str | None = None, + nacrule_name: str | None = None, nacrule_id: str | None = None, nacrule_matched: bool | None = None, - nacrule_name: str | None = None, nas_vendor: str | None = None, nas_ip: str | None = None, ingress_vlan: str | None = None, @@ -385,6 +385,7 @@ def searchSiteNacClients( ------------ ap : str auth_type : str + cert_expiry_duration : str edr_managed : bool edr_provider : str{'crowdstrike', 'sentinelone'} EDR provider of client's organization @@ -394,22 +395,21 @@ def searchSiteNacClients( hostname : str idp_id : str mac : str - mdm_managed : bool mdm_compliance : str mdm_provider : str + mdm_managed : bool mfg : str model : str - mxedge_id : str + nacrule_name : str nacrule_id : str nacrule_matched : bool - nacrule_name : str nas_vendor : str nas_ip : str ingress_vlan : str os : str ssid : str - status : str{'permitted', 'session_started', 'session_ended', 'denied'} - Connection status of client i.e "permitted", "denied, "session_ended" + status : str{'permitted', 'session_started', 'session_stopped', 'denied'} + Connection status of client i.e "permitted", "denied, "session_started", "session_stopped" text : str timestamp : float type : str @@ -436,6 +436,8 @@ def searchSiteNacClients( query_params["ap"] = str(ap) if auth_type: query_params["auth_type"] = str(auth_type) + if cert_expiry_duration: + query_params["cert_expiry_duration"] = str(cert_expiry_duration) if edr_managed: query_params["edr_managed"] = str(edr_managed) if edr_provider: @@ -450,24 +452,22 @@ def searchSiteNacClients( query_params["idp_id"] = str(idp_id) if mac: query_params["mac"] = str(mac) - if mdm_managed: - query_params["mdm_managed"] = str(mdm_managed) if mdm_compliance: query_params["mdm_compliance"] = str(mdm_compliance) if mdm_provider: query_params["mdm_provider"] = str(mdm_provider) + if mdm_managed: + query_params["mdm_managed"] = str(mdm_managed) if mfg: query_params["mfg"] = str(mfg) if model: query_params["model"] = str(model) - if mxedge_id: - query_params["mxedge_id"] = str(mxedge_id) + if nacrule_name: + query_params["nacrule_name"] = str(nacrule_name) if nacrule_id: query_params["nacrule_id"] = str(nacrule_id) if nacrule_matched: query_params["nacrule_matched"] = str(nacrule_matched) - if nacrule_name: - query_params["nacrule_name"] = str(nacrule_name) if nas_vendor: query_params["nas_vendor"] = str(nas_vendor) if nas_ip: diff --git a/src/mistapi/api/v1/sites/sle.py b/src/mistapi/api/v1/sites/sle.py index c7b5428..71c2676 100644 --- a/src/mistapi/api/v1/sites/sle.py +++ b/src/mistapi/api/v1/sites/sle.py @@ -10,15 +10,16 @@ -------------------------------------------------------------------------------- """ +import deprecation + from mistapi import APISession as _APISession from mistapi.__api_response import APIResponse as _APIResponse -import deprecation @deprecation.deprecated( deprecated_in="0.59.2", removed_in="0.65.0", - current_version="0.60.4", + current_version="0.61.0", details="function replaced with getSiteSleClassifierSummaryTrend", ) def getSiteSleClassifierDetails( @@ -72,57 +73,6 @@ def getSiteSleClassifierDetails( return resp -def getSiteSleClassifierDetails( - mist_session: _APISession, - site_id: str, - scope: str, - scope_id: str, - metric: str, - classifier: str, - start: str | None = None, - end: str | None = None, - duration: str | None = None, -) -> _APIResponse: - """ - API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/sles/get-site-sle-classifier-details - - PARAMS - ----------- - mistapi.APISession : mist_session - mistapi session including authentication and Mist host information - - PATH PARAMS - ----------- - site_id : str - scope : str{'ap', 'client', 'gateway', 'site', 'switch'} - scope_id : str - metric : str - classifier : str - - QUERY PARAMS - ------------ - start : str - end : str - duration : str, default: 1d - - RETURN - ----------- - mistapi.APIResponse - response from the API call - """ - - uri = f"/api/v1/sites/{site_id}/sle/{scope}/{scope_id}/metric/{metric}/classifier/{classifier}/summary" - query_params: dict[str, str] = {} - if start: - query_params["start"] = str(start) - if end: - query_params["end"] = str(end) - if duration: - query_params["duration"] = str(duration) - resp = mist_session.mist_get(uri=uri, query=query_params) - return resp - - def getSiteSleClassifierSummaryTrend( mist_session: _APISession, site_id: str, @@ -741,7 +691,7 @@ def listSiteSleImpactedWirelessClients( @deprecation.deprecated( deprecated_in="0.59.2", removed_in="0.65.0", - current_version="0.60.4", + current_version="0.61.0", details="function replaced with getSiteSleSummaryTrend", ) def getSiteSleSummary( @@ -793,55 +743,6 @@ def getSiteSleSummary( return resp -def getSiteSleSummary( - mist_session: _APISession, - site_id: str, - scope: str, - scope_id: str, - metric: str, - start: str | None = None, - end: str | None = None, - duration: str | None = None, -) -> _APIResponse: - """ - API doc: https://www.juniper.net/documentation/us/en/software/mist/api/http/api/sites/sles/get-site-sle-summary - - PARAMS - ----------- - mistapi.APISession : mist_session - mistapi session including authentication and Mist host information - - PATH PARAMS - ----------- - site_id : str - scope : str{'ap', 'client', 'gateway', 'site', 'switch'} - scope_id : str - metric : str - - QUERY PARAMS - ------------ - start : str - end : str - duration : str, default: 1d - - RETURN - ----------- - mistapi.APIResponse - response from the API call - """ - - uri = f"/api/v1/sites/{site_id}/sle/{scope}/{scope_id}/metric/{metric}/summary" - query_params: dict[str, str] = {} - if start: - query_params["start"] = str(start) - if end: - query_params["end"] = str(end) - if duration: - query_params["duration"] = str(duration) - resp = mist_session.mist_get(uri=uri, query=query_params) - return resp - - def getSiteSleSummaryTrend( mist_session: _APISession, site_id: str, diff --git a/src/mistapi/api/v1/sites/wan_clients.py b/src/mistapi/api/v1/sites/wan_clients.py index b3dc7f8..0b30029 100644 --- a/src/mistapi/api/v1/sites/wan_clients.py +++ b/src/mistapi/api/v1/sites/wan_clients.py @@ -147,10 +147,12 @@ def searchSiteWanClientEvents( def searchSiteWanClients( mist_session: _APISession, site_id: str, - mac: str | None = None, hostname: str | None = None, ip: str | None = None, + ip_src: str | None = None, + mac: str | None = None, mfg: str | None = None, + network: str | None = None, limit: int | None = None, start: str | None = None, end: str | None = None, @@ -172,10 +174,12 @@ def searchSiteWanClients( QUERY PARAMS ------------ - mac : str hostname : str ip : str + ip_src : str + mac : str mfg : str + network : str limit : int, default: 100 start : str end : str @@ -191,14 +195,18 @@ def searchSiteWanClients( uri = f"/api/v1/sites/{site_id}/wan_clients/search" query_params: dict[str, str] = {} - if mac: - query_params["mac"] = str(mac) if hostname: query_params["hostname"] = str(hostname) if ip: query_params["ip"] = str(ip) + if ip_src: + query_params["ip_src"] = str(ip_src) + if mac: + query_params["mac"] = str(mac) if mfg: query_params["mfg"] = str(mfg) + if network: + query_params["network"] = str(network) if limit: query_params["limit"] = str(limit) if start: diff --git a/src/mistapi/device_utils/__init__.py b/src/mistapi/device_utils/__init__.py new file mode 100644 index 0000000..dea7134 --- /dev/null +++ b/src/mistapi/device_utils/__init__.py @@ -0,0 +1,60 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Mist API Utilities +================== + +This package provides utility functions for interacting with Mist devices. + +Device-Specific Modules (Recommended) +-------------------------------------- +Import device-specific modules for a clean, organized API: + + from mistapi.utils import ap, ex, srx, ssr + + # Use device-specific utilities + ap.ping(session, site_id, device_id, host) + ex.cableTest(session, site_id, device_id, port_id) + ssr.showServicePath(session, site_id, device_id) + +Supported Devices: +- ap: Mist Access Points +- ex: Juniper EX Switches +- srx: Juniper SRX Firewalls +- ssr: Juniper Session Smart Routers + +Function-Based Modules (Legacy) +--------------------------------- +Original organization by function type (still available): + + from mistapi.utils import arp, bgp, dhcp, mac, port, routes, tools + +Available modules: arp, bgp, bpdu, dhcp, dns, dot1x, mac, policy, port, routes, + service_path, tools +""" + +# Device-specific modules (recommended) +# Function-based modules (legacy, still supported) +# Internal modules +from mistapi.device_utils import ( + ap, + ex, + srx, + ssr, +) + +__all__ = [ + # Device-specific modules (recommended) + "ap", + "ex", + "srx", + "ssr", +] diff --git a/src/mistapi/device_utils/__tools/__init__.py b/src/mistapi/device_utils/__tools/__init__.py new file mode 100644 index 0000000..1110f88 --- /dev/null +++ b/src/mistapi/device_utils/__tools/__init__.py @@ -0,0 +1,42 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Mist API Utilities +================== + +This package provides utility functions for interacting with Mist devices. + +Device-Specific Modules (Recommended) +-------------------------------------- +Import device-specific modules for a clean, organized API: + + from mistapi.utils import ap, ex, srx, ssr + + # Use device-specific utilities + ap.ping(session, site_id, device_id, host) + ex.cable_test(session, site_id, device_id, port_id) + ssr.show_service_path(session, site_id, device_id) + +Supported Devices: +- ap: Mist Access Points +- ex: Juniper EX Switches +- srx: Juniper SRX Firewalls +- ssr: Juniper Session Smart Routers + +Function-Based Modules (Legacy) +--------------------------------- +Original organization by function type (still available): + + from mistapi.utils import arp, bgp, dhcp, mac, port, routes, tools + +Available modules: arp, bgp, bpdu, dhcp, dns, dot1x, mac, policy, port, routes, + service_path, tools +""" diff --git a/src/mistapi/device_utils/__tools/__ws_wrapper.py b/src/mistapi/device_utils/__tools/__ws_wrapper.py new file mode 100644 index 0000000..a1d8d8a --- /dev/null +++ b/src/mistapi/device_utils/__tools/__ws_wrapper.py @@ -0,0 +1,254 @@ +import json +import threading +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession +from mistapi.__api_response import APIResponse as _APIResponse +from mistapi.__logger import logger as LOGGER + + +class TimerAction(Enum): + """ + TimerAction Enum for managing timer actions in WebSocketWrapper. + """ + + START = "start" + STOP = "stop" + RESET = "reset" + + +class Timer(Enum): + """ + Timer Enum for specifying different timer types in WebSocketWrapper. + """ + + TIMEOUT = "timeout" + FIRST_MESSAGE_TIMEOUT = "first_message_timeout" + MAX_DURATION = "max_duration" + + +class UtilResponse: + """ + A simple class to encapsulate the response from utility WebSocket functions. + This class can be extended in the future to include additional metadata or helper methods. + """ + + def __init__( + self, + api_response: _APIResponse, + ) -> None: + self.trigger_api_response = api_response + # This can be set to True if the WebSocket connection was successfully initiated + self.ws_required: bool = False + self.ws_data: list[str] = [] + self.ws_raw_events: list[str] = [] + + +class WebSocketWrapper: + """ + A wrapper class for managing WebSocket connections and events. + This class provides a simplified interface for connecting to WebSocket channels, + handling messages, and managing connection timeouts. + """ + + def __init__( + self, + apissession: APISession, + util_response: UtilResponse, + timeout: int = 10, + max_duration: int = 60, + on_message: Callable[[dict], None] | None = None, + ) -> None: + self.apissession = apissession + self.util_response = util_response + self.timers = { + Timer.TIMEOUT.value: { + "thread": None, + "duration": timeout, + }, + Timer.FIRST_MESSAGE_TIMEOUT.value: { + "thread": None, + "duration": 30, + }, + Timer.MAX_DURATION.value: { + "thread": None, + "duration": max_duration, + }, + } + self.received_messages = 0 + self.data = [] + self.raw_events = [] + self.ws = None + self.session_id: str | None = None + self.capture_id: str | None = None + self._on_message_cb = on_message + self._closed = threading.Event() + + LOGGER.debug( + "trigger response: %s", self.util_response.trigger_api_response.data + ) + if self.util_response.trigger_api_response.data and isinstance( + self.util_response.trigger_api_response.data, dict + ): + self.session_id = self.util_response.trigger_api_response.data.get( + "session", None + ) + self.capture_id = self.util_response.trigger_api_response.data.get( + "id", None + ) + LOGGER.debug("Extracted session_id: %s", self.session_id) + LOGGER.debug("Extracted capture_id: %s", self.capture_id) + + def _on_open(self): + LOGGER.info("WebSocket connection opened") + # Start the max duration timer + self._timeout_handler(Timer.MAX_DURATION, TimerAction.START) + + def _on_close(self, code, msg): + LOGGER.info("WebSocket closed: %s - %s", code, msg) + self._closed.set() + + ########################################################################## + ## Helper methods for managing timers + def _timeout_handler(self, timer_type: Timer, action: TimerAction): + duration = self.timers[timer_type.value]["duration"] + if action == TimerAction.STOP or action == TimerAction.RESET: + if self.timers[timer_type.value]["thread"]: + LOGGER.debug("Stopping %s timer", timer_type.value) + self.timers[timer_type.value]["thread"].cancel() + self.timers[timer_type.value]["thread"] = None + elif action == TimerAction.STOP: + # Only warn when explicitly stopping (not resetting) a non-active timer + LOGGER.warning("%s timer is not active to stop", timer_type.value) + if action == TimerAction.START or action == TimerAction.RESET: + if self.ws: + LOGGER.debug( + "Starting %s timer with duration: %s seconds", + timer_type.value, + duration, + ) + self.timers[timer_type.value]["thread"] = threading.Timer( + duration, self.ws.disconnect + ) + self.timers[timer_type.value]["thread"].start() + else: + LOGGER.warning( + "WebSocket is not available to start %s timer", timer_type.value + ) + + def _stop_all_timers(self): + for timer_info in self.timers.values(): + if timer_info["thread"]: + timer_info["thread"].cancel() + timer_info["thread"] = None + + ########################################################################## + ## WebSocket event handlers + + def _handle_message(self, msg): + if isinstance(msg, dict) and msg.get("event") == "channel_subscribed": + LOGGER.debug("channel_subscribed: %s", msg) + # Start the first message timeout timer when the channel is successfully subscribed + self._timeout_handler(Timer.FIRST_MESSAGE_TIMEOUT, TimerAction.START) + elif self._extract_session_id(msg): + # Stop the first message timeout timer on receiving the first message + self._timeout_handler(Timer.FIRST_MESSAGE_TIMEOUT, TimerAction.STOP) + LOGGER.debug("data: %s", msg) + raw = self._extract_raw(msg) + if raw: + self.data.append(raw) + if self._on_message_cb: + self._on_message_cb(raw) + self._timeout_handler(Timer.TIMEOUT, TimerAction.RESET) + + ########################################################################## + ## Message processing and WebSocket connection management + def _extract_session_id(self, message) -> bool: + """ + Extracts the session_id from the message and compares it to the expected session_id. + This method is designed to handle messages that may have the session_id nested at + different levels. + If the expected session_id is None, it will accept all messages. + """ + if not self.session_id and not self.capture_id: + LOGGER.debug("No session_id or capture_id provided, accepting all messages") + return True + if isinstance(message, str): + LOGGER.debug("Trying to decode message: %s", message) + try: + message = json.loads(message) + except json.JSONDecodeError: + LOGGER.warning("Failed to decode message as JSON: %s", message) + return False + if isinstance(message, dict): + if message.get("event") == "data" and message.get("data"): + LOGGER.debug( + "Checking nested data for session_id or capture_id: %s", + message["data"], + ) + return self._extract_session_id(message["data"]) + if message.get("session") == self.session_id: + LOGGER.info( + "Message session_id matches expected session_id: %s", + self.session_id, + ) + return True + if message.get("capture_id") == self.capture_id: + LOGGER.info( + "Message capture_id matches expected capture_id: %s", + self.capture_id, + ) + return True + return False + + def _extract_raw(self, message): + """ + Extracts the raw message from the given message. + This method is designed to handle messages that may have the raw message nested at + different levels. + Handles both command events (with "raw" field) and pcap events (with "pcap_dict" field). + """ + self.raw_events.append(message) + event = message + if isinstance(event, str): + try: + event = json.loads(event) + except json.JSONDecodeError: + LOGGER.warning("Failed to decode message as JSON: %s", message) + return None + if isinstance(event, dict): + if event.get("event") == "data" and event.get("data"): + return self._extract_raw(event["data"]) + if "raw" in event: + self.received_messages += 1 + LOGGER.debug("Extracted raw message: %s", event["raw"]) + return event["raw"] + if "pcap_dict" in event: + self.received_messages += 1 + LOGGER.debug("Extracted pcap data: %s", event["pcap_dict"]) + return event["pcap_dict"] + return None + + ########################################################################## + ## WebSocket connection management + def start(self, ws) -> UtilResponse: + """ + Start the WS connection, block until closed, return UtilResponse. + + PARAMS + ----------- + ws : _MistWebsocket + An already-constructed WebSocket channel object. + """ + self.ws = ws + ws.on_message(self._handle_message) + ws.on_error(lambda error: LOGGER.error("Error: %s", error)) + ws.on_close(self._on_close) + ws.on_open(self._on_open) + ws.connect(run_in_background=False) # blocks until _on_close fires + self._stop_all_timers() + self.util_response.ws_required = True + self.util_response.ws_data = self.data + self.util_response.ws_raw_events = self.raw_events + return self.util_response diff --git a/src/mistapi/device_utils/__tools/arp.py b/src/mistapi/device_utils/__tools/arp.py new file mode 100644 index 0000000..f9b3d6d --- /dev/null +++ b/src/mistapi/device_utils/__tools/arp.py @@ -0,0 +1,218 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + """Node Enum for specifying node information in ARP commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +def retrieve_ap_arp_table( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + timeout=1, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: AP + + Retrieves the ARP table from a Mist Access Point and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve the ARP table from. + node : Node, optional + Node information for the ARP table retrieval command. + timeout : int, optional + Timeout for the ARP table retrieval command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + # AP is returning RAW data + # SWITCH is returning ??? + # GATEWAY is returning JSON + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + trigger = devices.arpFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Show ARP command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" + ) # Give the show ARP command a moment to take effect + return util_response + + +def retrieve_ssr_arp_table( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + timeout=1, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SSR + + Retrieves the ARP table from a SSR Gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve the ARP table from. + node : Node, optional + Node information for the ARP table retrieval command. + timeout : int, optional + Timeout for the ARP table retrieval command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from + the WebSocket stream. + """ + # AP is returning RAW data + # SWITCH is returning ??? + # GATEWAY is returning JSON + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + trigger = devices.arpFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Show ARP command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" + ) # Give the show ARP command a moment to take effect + return util_response + + +def retrieve_junos_arp_table( + apissession: _APISession, + site_id: str, + device_id: str, + ip: str | None = None, + port_id: str | None = None, + vrf: str | None = None, + timeout=1, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: EX, SRX + + Retrieve the ARP table from a Junos device (EX / SRX) with optional filters for IP, port, + and VRF. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve the ARP table from. + ip : str, optional + IP address to filter the ARP table. + port_id : str, optional + Port ID to filter the ARP table. + vrf : str, optional + VRF to filter the ARP table. + timeout : int, optional + Timeout for the ARP table retrieval command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"duration": 1, "interval": 1} + if ip: + body["ip"] = ip + if vrf: + body["vrf"] = vrf + if port_id: + body["port_id"] = port_id + trigger = devices.showSiteDeviceArpTable( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Show ARP command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger show ARP command: {trigger.status_code} - {trigger.data}" + ) # Give the show ARP command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/__tools/bgp.py b/src/mistapi/device_utils/__tools/bgp.py new file mode 100644 index 0000000..f545c57 --- /dev/null +++ b/src/mistapi/device_utils/__tools/bgp.py @@ -0,0 +1,70 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +def summary( + apissession: _APISession, + site_id: str, + device_id: str, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: EX, SRX, SSR + + Shows BGP summary on a device (EX/ SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show BGP summary on. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"protocol": "bgp"} + trigger = devices.showSiteDeviceBgpSummary( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"BGP summary command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger BGP summary command: {trigger.status_code} - {trigger.data}" + ) # Give the BGP summary command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/__tools/bpdu.py b/src/mistapi/device_utils/__tools/bpdu.py new file mode 100644 index 0000000..0bdf96b --- /dev/null +++ b/src/mistapi/device_utils/__tools/bpdu.py @@ -0,0 +1,61 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse + + +async def clear_error( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + DEVICES: EX + + Clears BPDU error state on the specified ports of a switch. + + PARAMS + ----------- + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to clear BPDU errors on. + port_ids : list[str] + List of port IDs to clear BPDU errors on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearBpduErrorsFromPortsOnSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear BPDU error command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear BPDU error command: {trigger.status_code} - {trigger.data}" + ) # Give the clear BPDU error command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/__tools/dhcp.py b/src/mistapi/device_utils/__tools/dhcp.py new file mode 100644 index 0000000..e0ed5f0 --- /dev/null +++ b/src/mistapi/device_utils/__tools/dhcp.py @@ -0,0 +1,172 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + """Node Enum for specifying node information in DHCP commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +def release_dhcp_leases( + apissession: _APISession, + site_id: str, + device_id: str, + macs: list[str] | None = None, + network: str | None = None, + node: Node | None = None, + port_id: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: EX, SRX, SSR + + Releases DHCP leases on a device (EX/ SRX / SSR) and streams the results. + + valid combinations for EX are: + - network + macs + - network + port_id + - port_id + + valid combinations for SRX / SSR are: + - network + - network + macs + - network + port_id + - port_id + - port_id + macs + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to release DHCP leases on. + macs : list[str], optional + List of MAC addresses to release DHCP leases for. + network : str, optional + Network to release DHCP leases for. + node : Node, optional + Node information for the DHCP lease release command. + port_id : str, optional + Port ID to release DHCP leases for. + timeout : int, optional + Timeout for the release DHCP leases command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if macs: + body["macs"] = macs + if network: + body["network"] = network + if node: + body["node"] = node.value + if port_id: + body["port_id"] = port_id + trigger = devices.releaseSiteDeviceDhcpLease( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Release DHCP leases command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger release DHCP leases command: {trigger.status_code} - {trigger.data}" + ) # Give the release DHCP leases command a moment to take effect + return util_response + + +def retrieve_dhcp_leases( + apissession: _APISession, + site_id: str, + device_id: str, + network: str, + node: Node | None = None, + timeout=15, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Retrieves DHCP leases on a gateway (SRX / SSR) and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve DHCP leases from. + network : str + Network to release DHCP leases for. + node : Node, optional + Node information for the DHCP lease release command. + port_id : str, optional + Port ID to release DHCP leases for. + timeout : int, optional + Timeout for the release DHCP leases command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"network": network} + if node: + body["node"] = node.value + trigger = devices.showSiteDeviceDhcpLeases( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Retrieve DHCP leases command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger retrieve DHCP leases command: {trigger.status_code} - {trigger.data}" + ) # Give the release DHCP leases command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/__tools/dns.py b/src/mistapi/device_utils/__tools/dns.py new file mode 100644 index 0000000..4f5cca2 --- /dev/null +++ b/src/mistapi/device_utils/__tools/dns.py @@ -0,0 +1,84 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from enum import Enum + + +class Node(Enum): + """Node Enum for specifying node information in DNS commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +## NO DATA +# def test_resolution( +# apissession: _APISession, +# site_id: str, +# device_id: str, +# node: Node | None = None, +# hostname: str | None = None, +# timeout=5, +# on_message: Callable[[dict], None] | None = None, +# ) -> UtilResponse: +# """ +# DEVICES: SSR + +# Initiates a DNS resolution command on the gateway and streams the results. + +# PARAMS +# ----------- +# apissession : _APISession +# The API session to use for the request. +# site_id : str +# UUID of the site where the gateway is located. +# device_id : str +# UUID of the gateway to perform the DNS resolution command on. +# node : Node, optional +# Node information for the DNS resolution command. +# hostname : str, optional +# Hostname to resolve. +# timeout : int, optional +# Timeout for the command in seconds. +# on_message : Callable, optional +# Callback invoked with each extracted raw message as it arrives. + +# RETURNS +# ----------- +# UtilResponse +# A UtilResponse object containing the API response and a list of raw messages received +# from the WebSocket stream. +# """ +# body: dict[str, str | list | int] = {} +# if node: +# body["node"] = node.value +# if hostname: +# body["hostname"] = hostname +# trigger = devices.testSiteSsrDnsResolution( +# apissession, +# site_id=site_id, +# device_id=device_id, +# body=body, +# ) +# util_response = UtilResponse(trigger) +# if trigger.status_code == 200: +# LOGGER.info(trigger.data) +# print(f"SSR DNS resolution command triggered for device {device_id}") +# ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) +# util_response = WebSocketWrapper( +# apissession, util_response, timeout=timeout, on_message=on_message +# ).start(ws) +# else: +# LOGGER.error( +# f"Failed to trigger SSR DNS resolution command: {trigger.status_code} - {trigger.data}" +# ) # Give the SSR DNS resolution command a moment to take effect +# return util_response diff --git a/src/mistapi/device_utils/__tools/dot1x.py b/src/mistapi/device_utils/__tools/dot1x.py new file mode 100644 index 0000000..537e65d --- /dev/null +++ b/src/mistapi/device_utils/__tools/dot1x.py @@ -0,0 +1,60 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse + + +async def clear_sessions( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + DEVICES: EX + + Clears dot1x sessions on the specified ports of a switch (EX). + + PARAMS + ----------- + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to clear dot1x sessions on. + port_ids : list[str] + List of port IDs to clear dot1x sessions on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearAllLearnedMacsFromPortOnSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear learned MACs command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" + ) # Give the clear learned MACs command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/__tools/mac.py b/src/mistapi/device_utils/__tools/mac.py new file mode 100644 index 0000000..d68441a --- /dev/null +++ b/src/mistapi/device_utils/__tools/mac.py @@ -0,0 +1,199 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +def clear_mac_table( + apissession: _APISession, + site_id: str, + device_id: str, + mac_address: str | None = None, + port_id: str | None = None, + vlan_id: str | None = None, + # timeout=30, +) -> UtilResponse: + """ + DEVICES: EX + + Clears the MAC table on a switch (EX). + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to clear the MAC table from. + mac_address : str, optional + MAC address to clear from the MAC table. + port_id : str, optional + Port ID to clear from the MAC table. + vlan_id : str, optional + VLAN ID to clear from the MAC table. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + # AP is returning RAW data + # SWITCH is returning ??? + # GATEWAY is returning JSON + body: dict[str, str | list | int] = {} + if mac_address: + body["mac_address"] = mac_address + if port_id: + body["port_id"] = port_id + if vlan_id: + body["vlan_id"] = vlan_id + trigger = devices.clearSiteDeviceMacTable( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear MAC Table command triggered for device {device_id}") + # util_response = WebSocketWrapper( + # apissession, util_response, timeout=timeout, on_message=on_message + # ).start(ws) + else: + LOGGER.error( + f"Failed to trigger clear MAC Table command: {trigger.status_code} - {trigger.data}" + ) # Give the clear MAC Table command a moment to take effect + return util_response + + +def retrieve_mac_table( + apissession: _APISession, + site_id: str, + device_id: str, + mac_address: str | None = None, + port_id: str | None = None, + vlan_id: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: EX + + Retrieves the MAC Table table from a switch (EX) and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve the ARP table from. + mac_address : str, optional + MAC address to filter the ARP table retrieval. + port_id : str, optional + Port ID to filter the ARP table retrieval. + vlan_id : str, optional + VLAN ID to filter the ARP table retrieval. + timeout : int, optional + Timeout for the ARP table retrieval command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + # AP is returning RAW data + # SWITCH is returning ??? + # GATEWAY is returning JSON + body: dict[str, str | list | int] = {} + if mac_address: + body["mac_address"] = mac_address + if port_id: + body["port_id"] = port_id + if vlan_id: + body["vlan_id"] = vlan_id + trigger = devices.showSiteDeviceMacTable( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Show MAC Table command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger show MAC Table command: {trigger.status_code} - {trigger.data}" + ) # Give the show ARP command a moment to take effect + return util_response + + +def clear_learned_mac( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + DEVICES: EX + + Clears learned MAC addresses on the specified ports of a switch (EX). + + PARAMS + ----------- + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to clear learned MAC addresses on. + port_ids : list[str] + List of port IDs to clear learned MAC addresses on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearSiteDeviceDot1xSession( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear learned MACs command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" + ) # Give the clear learned MACs command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/__tools/miscellaneous.py b/src/mistapi/device_utils/__tools/miscellaneous.py new file mode 100644 index 0000000..ccc1bc4 --- /dev/null +++ b/src/mistapi/device_utils/__tools/miscellaneous.py @@ -0,0 +1,364 @@ +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.session import SessionWithUrl +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + """Node Enum for specifying node information in commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +class TracerouteProtocol(Enum): + """Enum for specifying protocol in traceroute command.""" + + ICMP = "icmp" + UDP = "udp" + + +def ping( + apissession: _APISession, + site_id: str, + device_id: str, + host: str, + count: int | None = None, + node: Node | None = None, + size: int | None = None, + vrf: str | None = None, + timeout: int = 3, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: AP, EX, SRX, SSR + + Initiates a ping command from a device (AP / EX/ SRX / SSR) to a specified host and + streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to initiate the ping from. + host : str + The host to ping. + count : int, optional + Number of ping requests to send. + node : None, optional + Node information for the ping command. + size : int, optional + Size of the ping packet. + vrf : str, optional + VRF to use for the ping command. + timeout : int, optional + Timeout for the ping command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if count: + body["count"] = count + if host: + body["host"] = host + if node: + body["node"] = node.value + if size: + body["size"] = size + if vrf: + body["vrf"] = vrf + trigger = devices.pingFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Ping command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger ping command: {trigger.status_code} - {trigger.data}" + ) # Give the ping command a moment to take effect + return util_response + + +## NO DATA +# def service_ping( +# apissession: _APISession, +# site_id: str, +# device_id: str, +# host: str, +# service: str, +# tenant: str, +# count: int | None = None, +# node: None | None = None, +# size: int | None = None, +# timeout: int = 3, +# on_message: Callable[[dict], None] | None = None, +# ) -> UtilResponse: +# """ +# DEVICES: SSR + +# Initiates a service ping command from a SSR to a specified host and streams the results. + +# PARAMS +# ----------- +# apissession : _APISession +# The API session to use for the request. +# site_id : str +# UUID of the site where the device is located. +# device_id : str +# UUID of the device to initiate the ping from. +# host : str +# The host to ping. +# service : str +# The service to ping. +# tenant : str +# Tenant to use for the ping command. +# count : int, optional +# Number of ping requests to send. +# node : None, optional +# Node information for the ping command. +# size : int, optional +# Size of the ping packet. +# timeout : int, optional +# Timeout for the ping command in seconds. +# on_message : Callable, optional +# Callback invoked with each extracted raw message as it arrives. + +# RETURNS +# ----------- +# UtilResponse +# A UtilResponse object containing the API response and a list of raw messages received +# from the WebSocket stream. +# """ +# body: dict[str, str | list | int] = {} +# if count: +# body["count"] = count +# if host: +# body["host"] = host +# if node: +# body["node"] = node.value +# if size: +# body["size"] = size +# if tenant: +# body["tenant"] = tenant +# if service: +# body["service"] = service +# trigger = devices.servicePingFromSsr( +# apissession, +# site_id=site_id, +# device_id=device_id, +# body=body, +# ) +# util_response = UtilResponse(trigger) +# if trigger.status_code == 200: +# LOGGER.info(f"Service Ping command triggered for device {device_id}") +# ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) +# util_response = WebSocketWrapper( +# apissession, util_response, timeout, on_message=on_message +# ).start(ws) +# else: +# LOGGER.error( +# f"Failed to trigger Service Ping command: {trigger.status_code} - {trigger.data}" +# ) # Give the ping command a moment to take effect +# return util_response + + +def traceroute( + apissession: _APISession, + site_id: str, + device_id: str, + host: str, + protocol: TracerouteProtocol = TracerouteProtocol.ICMP, + port: int | None = None, + timeout: int = 10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: AP, EX, SRX, SSR + + Initiates a traceroute command from a device (AP / EX/ SRX / SSR) to a specified host and + streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to initiate the traceroute from. + host : str + The host to traceroute. + protocol : TracerouteProtocol, optional + Protocol to use for the traceroute command (icmp or udp). + port : int, optional + Port to use for UDP traceroute. + timeout : int, optional + Timeout for the traceroute command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"host": host} + if protocol: + body["protocol"] = protocol.value + if port: + body["port"] = port + trigger = devices.tracerouteFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Traceroute command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger traceroute command: {trigger.status_code} - {trigger.data}" + ) # Give the traceroute command a moment to take effect + return util_response + + +def monitor_traffic( + apissession: _APISession, + site_id: str, + device_id: str, + port_id: str | None = None, + timeout=30, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: EX, SRX + + Initiates a monitor traffic command on the device and streams the results. + + * if `port_id` is provided, JUNOS uses cmd "monitor interface" to monitor traffic on particular + * if `port_id` is not provided, JUNOS uses cmd "monitor interface traffic" to monitor traffic + on all ports + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to monitor traffic on. + port_id : str, optional + Port ID to filter the traffic. + timeout : int, optional + Timeout for the monitor traffic command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | int] = {"duration": 60} + if port_id: + body["port"] = port_id + trigger = devices.monitorSiteDeviceTraffic( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Monitor traffic command triggered for device {device_id}") + ws = SessionWithUrl(apissession, url=trigger.data.get("url", "")) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger monitor traffic command: {trigger.status_code} - {trigger.data}" + ) # Give the monitor traffic command a moment to take effect + return util_response + + +## NO DATA +# def srx_top_command( +# apissession: _APISession, +# site_id: str, +# device_id: str, +# timeout=10, +# on_message: Callable[[dict], None] | None = None, +# ) -> UtilResponse: +# """ +# DEVICE: SRX + +# For SRX Only. Initiates a top command on the device and streams the results. + +# PARAMS +# ----------- +# apissession : _APISession +# The API session to use for the request. +# site_id : str +# UUID of the site where the device is located. +# device_id : str +# UUID of the device to run the top command on. +# timeout : int, optional +# Timeout for the top command in seconds. +# on_message : Callable, optional +# Callback invoked with each extracted raw message as it arrives. + +# RETURNS +# ----------- +# UtilResponse +# A UtilResponse object containing the API response and a list of raw messages received +# from the WebSocket stream. +# """ +# trigger = devices.runSiteSrxTopCommand( +# apissession, +# site_id=site_id, +# device_id=device_id, +# ) +# util_response = UtilResponse(trigger) +# if trigger.status_code == 200: +# LOGGER.info(trigger.data) +# print(f"Top command triggered for device {device_id}") +# ws = SessionWithUrl(apissession, url=trigger.data.get("url", "")) +# util_response = WebSocketWrapper( +# apissession, util_response, timeout=timeout, on_message=on_message +# ).start(ws) +# else: +# LOGGER.error( +# f"Failed to trigger top command: {trigger.status_code} - {trigger.data}" +# ) # Give the top command a moment to take effect +# return util_response diff --git a/src/mistapi/device_utils/__tools/ospf.py b/src/mistapi/device_utils/__tools/ospf.py new file mode 100644 index 0000000..09eda9e --- /dev/null +++ b/src/mistapi/device_utils/__tools/ospf.py @@ -0,0 +1,291 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + """Node Enum for specifying node information in OSPF commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +def show_database( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + self_originate: bool | None = None, + vrf: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF database on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF database on. + node : Node, optional + Node information for the show OSPF database command. + self_originate : bool, optional + Filter for self-originated routes in the OSPF database. + vrf : str, optional + VRF to filter the OSPF database. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if self_originate is not None: + body["self_originate"] = self_originate + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteGatewayOspfDatabase( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF database command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger OSPF database command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF database command a moment to take effect + return util_response + + +def show_interfaces( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + port_id: str | None = None, + vrf: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF interfaces on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF interfaces on. + node : Node, optional + Node information for the show OSPF interfaces command. + port_id : str, optional + Port ID to filter the OSPF interfaces. + vrf : str, optional + VRF to filter the OSPF interfaces. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if port_id: + body["port_id"] = port_id + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteGatewayOspfInterfaces( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF interfaces command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger OSPF interfaces command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF interfaces command a moment to take effect + return util_response + + +def show_neighbors( + apissession: _APISession, + site_id: str, + device_id: str, + neighbor: str | None = None, + node: Node | None = None, + port_id: str | None = None, + vrf: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF neighbors on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF neighbors on. + neighbor : str, optional + Neighbor IP address to filter the OSPF neighbors. + node : Node, optional + Node information for the show OSPF neighbors command. + port_id : str, optional + Port ID to filter the OSPF neighbors. + vrf : str, optional + VRF to filter the OSPF neighbors. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if port_id: + body["port_id"] = port_id + if vrf: + body["vrf"] = vrf + if neighbor: + body["neighbor"] = neighbor + trigger = devices.showSiteGatewayOspfNeighbors( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF neighbors command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger OSPF neighbors command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF neighbors command a moment to take effect + return util_response + + +def show_summary( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + vrf: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF summary on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF summary on. + node : Node, optional + Node information for the show OSPF summary command. + vrf : str, optional + VRF to filter the OSPF summary. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteGatewayOspfSummary( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF summary command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger OSPF summary command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF summary command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/__tools/policy.py b/src/mistapi/device_utils/__tools/policy.py new file mode 100644 index 0000000..2d57303 --- /dev/null +++ b/src/mistapi/device_utils/__tools/policy.py @@ -0,0 +1,62 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse + + +async def clear_hit_count( + apissession: _APISession, + site_id: str, + device_id: str, + policy_name: str, +) -> UtilResponse: + """ + DEVICE: EX + + Clears the policy hit count on a device. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to clear the policy hit count on. + policy_name : str + Name of the policy to clear the hit count for. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + trigger = devices.clearSiteDevicePolicyHitCount( + apissession, + site_id=site_id, + device_id=device_id, + body={"policy_name": policy_name}, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Clear policy hit count command triggered for device {device_id}") + # util_response = await WebSocketWrapper( + # apissession, util_response, timeout=timeout + # ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger clear policy hit count command: {trigger.status_code} - {trigger.data}" + ) # Give the clear policy hit count command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/__tools/port.py b/src/mistapi/device_utils/__tools/port.py new file mode 100644 index 0000000..d0c150e --- /dev/null +++ b/src/mistapi/device_utils/__tools/port.py @@ -0,0 +1,133 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +def bounce( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + timeout=60, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: EX, SRX, SSR + + Initiates a bounce command on the specified ports of a device (EX / SRX / SSR) and streams + the results. + + PARAMS + ----------- + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to perform the bounce command on. + port_ids : list[str] + List of port IDs to bounce. + timeout : int, default 5 + Timeout for the bounce command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if port_ids: + body["ports"] = port_ids + trigger = devices.bounceDevicePort( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info( + f"Bounce command triggered for ports {port_ids} on device {device_id}" + ) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger bounce command: {trigger.status_code} - {trigger.data}" + ) # Give the bounce command a moment to take effect + return util_response + + +def cable_test( + apissession: _APISession, + site_id: str, + device_id: str, + port_id: str, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: EX + + Initiates a cable test on a switch port and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to perform the cable test on. + port_id : str + Port ID to perform the cable test on. + timeout : int, optional + Timeout for the cable test command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"port": port_id} + trigger = devices.cableTestFromSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Cable test command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger cable test command: {trigger.status_code} - {trigger.data}" + ) # Give the cable test command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/__tools/remote_capture.py b/src/mistapi/device_utils/__tools/remote_capture.py new file mode 100644 index 0000000..90a438d --- /dev/null +++ b/src/mistapi/device_utils/__tools/remote_capture.py @@ -0,0 +1,444 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import pcaps +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import PcapEvents + + +def _build_pcap_body( + device_id: str, + port_ids: list[str], + device_key: str, + device_type: str, + tcpdump_expression: str | None, + duration: int, + max_pkt_len: int, + num_packets: int, + raw: bool | None = None, +) -> dict: + """Build the request body for remote pcap commands (SRX, SSR, EX).""" + mac = device_id.split("-")[-1] + body: dict = { + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + device_key: {mac: {"ports": {}}}, + "type": device_type, + "format": "stream", + } + if raw is not None: + body["raw"] = raw + for port_id in port_ids: + port_entry: dict = {} + if tcpdump_expression is not None: + port_entry["tcpdump_expression"] = tcpdump_expression + body[device_key][mac]["ports"][port_id] = port_entry + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + return body + + +def ap_remote_pcap_wireless( + apissession: _APISession, + site_id: str, + device_id: str, + band: str, + tcpdump_expression: str | None = None, + ssid: str | None = None, + ap_mac: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: AP + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + band : str + Comma-separated list of radio bands (24, 5, or 6). + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "type mgt or type ctl -vvv -tttt -en" + ssid : str, optional + SSID to filter the wireless traffic. + ap_mac : str, optional + AP MAC address to filter the wireless traffic. + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | int] = { + "band": band, + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + "type": "radiotap", + "format": "stream", + } + if ssid: + body["ssid"] = ssid + if ap_mac: + body["ap_mac"] = ap_mac + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +def ap_remote_pcap_wired( + apissession: _APISession, + site_id: str, + device_id: str, + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: AP + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | int] = { + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + "type": "wired", + "format": "stream", + } + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +def srx_remote_pcap( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: SRX + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + port_ids : list[str] + List of port IDs to monitor. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body = _build_pcap_body( + device_id, + port_ids, + "gateways", + "gateway", + tcpdump_expression, + duration, + max_pkt_len, + num_packets, + ) + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +def ssr_remote_pcap( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: SSR + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + port_ids : list[str] + List of port IDs to monitor. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body = _build_pcap_body( + device_id, + port_ids, + "gateways", + "gateway", + tcpdump_expression, + duration, + max_pkt_len, + num_packets, + raw=False, + ) + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +def ex_remote_pcap( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: EX + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + port_ids : list[str] + List of port IDs to monitor. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body = _build_pcap_body( + device_id, + port_ids, + "switches", + "switch", + tcpdump_expression, + duration, + max_pkt_len, + num_packets, + ) + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/__tools/routes.py b/src/mistapi/device_utils/__tools/routes.py new file mode 100644 index 0000000..6022f02 --- /dev/null +++ b/src/mistapi/device_utils/__tools/routes.py @@ -0,0 +1,113 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + NODE0 = "node0" + NODE1 = "node1" + + +class RouteProtocol(Enum): + ANY = "any" + BGP = "bgp" + DIRECT = "direct" + EVPN = "evpn" + OSPF = "ospf" + STATIC = "static" + + +def show( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + prefix: str | None = None, + protocol: RouteProtocol | None = None, + route_type: str | None = None, + vrf: str | None = None, + timeout=2, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: SSR, SRX + + Initiates a show routes command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show routes command on. + node : Node, optional + Node information for the show routes command. + prefix : str, optional + Prefix to filter the routes. + protocol : RouteProtocol, optional + Protocol to filter the routes. + route_type : str, optional + Type of the route to filter. + vrf : str, optional + VRF to filter the routes. + timeout : int, optional + Timeout for the command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if prefix: + body["prefix"] = prefix + if protocol: + body["protocol"] = protocol.value + if route_type: + body["route_type"] = route_type + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteSsrAndSrxRoutes( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Device Routes command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger Device Routes command: {trigger.status_code} - {trigger.data}" + ) # Give the Device Routes command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/__tools/service_path.py b/src/mistapi/device_utils/__tools/service_path.py new file mode 100644 index 0000000..5f53fc0 --- /dev/null +++ b/src/mistapi/device_utils/__tools/service_path.py @@ -0,0 +1,90 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + """Node Enum for specifying node information in service path commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +def show_service_path( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + service_name: str | None = None, + timeout: int = 5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SSR + + Initiates a show service path command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show service path command on. + node : Node, optional + Node information for the show service path command. + service_name : str, optional + Name of the service to show the path for. + timeout : int, optional + Timeout for the command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if service_name: + body["service_name"] = service_name + trigger = devices.showSiteSsrServicePath( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"SSR service path command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger SSR service path command: {trigger.status_code} - {trigger.data}" + ) # Give the SSR service path command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/__tools/sessions.py b/src/mistapi/device_utils/__tools/sessions.py new file mode 100644 index 0000000..c019f10 --- /dev/null +++ b/src/mistapi/device_utils/__tools/sessions.py @@ -0,0 +1,172 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + """Node Enum for specifying node information in session commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +def clear( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + service_name: str | None = None, + service_ids: list[str] | None = None, + vrf: str | None = None, + timeout=2, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: SSR, SRX + + Initiates a clear sessions command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show routes command on. + node : Node, optional + Node information for the show routes command. + prefix : str, optional + Prefix to filter the routes. + protocol : RouteProtocol, optional + Protocol to filter the routes. + route_type : str, optional + Type of the route to filter. + vrf : str, optional + VRF to filter the routes. + timeout : int, optional + Timeout for the command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if service_name: + body["service_name"] = service_name + if service_ids: + body["service_ids"] = service_ids + if vrf: + body["vrf"] = vrf + trigger = devices.clearSiteDeviceSession( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Device Sessions command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger Device Sessions command: {trigger.status_code} - {trigger.data}" + ) # Give the Device Sessions command a moment to take effect + return util_response + + +def show( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + service_name: str | None = None, + service_ids: list[str] | None = None, + timeout=2, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: SSR, SRX + + Initiates a show sessions command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show sessions command on. + node : Node, optional + Node information for the show sessions command. + service_name : str, optional + Name of the service to filter the sessions. + service_ids : list[str], optional + List of service IDs to filter the sessions. + timeout : int, optional + Timeout for the command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if service_name: + body["service_name"] = service_name + if service_ids: + body["service_ids"] = service_ids + trigger = devices.showSiteSsrAndSrxSessions( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Device Sessions command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger Device Sessions command: {trigger.status_code} - {trigger.data}" + ) # Give the Device Sessions command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/ap.py b/src/mistapi/device_utils/ap.py new file mode 100644 index 0000000..73e34df --- /dev/null +++ b/src/mistapi/device_utils/ap.py @@ -0,0 +1,31 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Utility functions for Mist Access Points (AP). + +This module provides a device-specific namespace for AP utilities. +All functions are imported from their respective functional modules. +""" + +# Re-export shared classes and types +from mistapi.device_utils.__tools.arp import retrieve_ap_arp_table as retrieveArpTable +from mistapi.device_utils.__tools.miscellaneous import ( + TracerouteProtocol, + ping, + traceroute, +) + +__all__ = [ + "ping", + "traceroute", + "TracerouteProtocol", + "retrieveArpTable", +] diff --git a/src/mistapi/device_utils/bgp.py b/src/mistapi/device_utils/bgp.py new file mode 100644 index 0000000..f545c57 --- /dev/null +++ b/src/mistapi/device_utils/bgp.py @@ -0,0 +1,70 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +def summary( + apissession: _APISession, + site_id: str, + device_id: str, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: EX, SRX, SSR + + Shows BGP summary on a device (EX/ SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show BGP summary on. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"protocol": "bgp"} + trigger = devices.showSiteDeviceBgpSummary( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"BGP summary command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger BGP summary command: {trigger.status_code} - {trigger.data}" + ) # Give the BGP summary command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/bpdu.py b/src/mistapi/device_utils/bpdu.py new file mode 100644 index 0000000..c565903 --- /dev/null +++ b/src/mistapi/device_utils/bpdu.py @@ -0,0 +1,61 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse + + +async def clearError( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + DEVICES: EX + + Clears BPDU error state on the specified ports of a switch. + + PARAMS + ----------- + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to clear BPDU errors on. + port_ids : list[str] + List of port IDs to clear BPDU errors on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearBpduErrorsFromPortsOnSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear BPDU error command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear BPDU error command: {trigger.status_code} - {trigger.data}" + ) # Give the clear BPDU error command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/dhcp.py b/src/mistapi/device_utils/dhcp.py new file mode 100644 index 0000000..c967c34 --- /dev/null +++ b/src/mistapi/device_utils/dhcp.py @@ -0,0 +1,172 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + """Node Enum for specifying node information in DHCP commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +def releaseDhcpLeases( + apissession: _APISession, + site_id: str, + device_id: str, + macs: list[str] | None = None, + network: str | None = None, + node: Node | None = None, + port_id: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: EX, SRX, SSR + + Releases DHCP leases on a device (EX/ SRX / SSR) and streams the results. + + valid combinations for EX are: + - network + macs + - network + port_id + - port_id + + valid combinations for SRX / SSR are: + - network + - network + macs + - network + port_id + - port_id + - port_id + macs + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to release DHCP leases on. + macs : list[str], optional + List of MAC addresses to release DHCP leases for. + network : str, optional + Network to release DHCP leases for. + node : Node, optional + Node information for the DHCP lease release command. + port_id : str, optional + Port ID to release DHCP leases for. + timeout : int, optional + Timeout for the release DHCP leases command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if macs: + body["macs"] = macs + if network: + body["network"] = network + if node: + body["node"] = node.value + if port_id: + body["port_id"] = port_id + trigger = devices.releaseSiteDeviceDhcpLease( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Release DHCP leases command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger release DHCP leases command: {trigger.status_code} - {trigger.data}" + ) # Give the release DHCP leases command a moment to take effect + return util_response + + +def retrieveDhcpLeases( + apissession: _APISession, + site_id: str, + device_id: str, + network: str, + node: Node | None = None, + timeout=15, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Retrieves DHCP leases on a gateway (SRX / SSR) and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to retrieve DHCP leases from. + network : str + Network to release DHCP leases for. + node : Node, optional + Node information for the DHCP lease release command. + port_id : str, optional + Port ID to release DHCP leases for. + timeout : int, optional + Timeout for the release DHCP leases command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"network": network} + if node: + body["node"] = node.value + trigger = devices.showSiteDeviceDhcpLeases( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Retrieve DHCP leases command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger retrieve DHCP leases command: {trigger.status_code} - {trigger.data}" + ) # Give the release DHCP leases command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/dot1x.py b/src/mistapi/device_utils/dot1x.py new file mode 100644 index 0000000..af5c322 --- /dev/null +++ b/src/mistapi/device_utils/dot1x.py @@ -0,0 +1,60 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse + + +async def clearSessions( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], +) -> UtilResponse: + """ + DEVICES: EX + + Clears dot1x sessions on the specified ports of a switch (EX). + + PARAMS + ----------- + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to clear dot1x sessions on. + port_ids : list[str] + List of port IDs to clear dot1x sessions on. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"ports": port_ids} + trigger = devices.clearAllLearnedMacsFromPortOnSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Clear learned MACs command triggered for device {device_id}") + else: + LOGGER.error( + f"Failed to trigger clear learned MACs command: {trigger.status_code} - {trigger.data}" + ) # Give the clear learned MACs command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/ex.py b/src/mistapi/device_utils/ex.py new file mode 100644 index 0000000..dd3680f --- /dev/null +++ b/src/mistapi/device_utils/ex.py @@ -0,0 +1,78 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Utility functions for Juniper EX Switches. + +This module provides a device-specific namespace for EX switch utilities. +All functions are imported from their respective functional modules. +""" + +# Re-export shared classes and types + +# ARP functions +from mistapi.device_utils.__tools.arp import ( + retrieve_junos_arp_table as retrieveArpTable, +) + +# BGP functions +from mistapi.device_utils.__tools.bgp import summary as retrieveBgpSummary + +# BPDU functions +from mistapi.device_utils.__tools.bpdu import clear_error as clearBpduError + +# DHCP functions +from mistapi.device_utils.__tools.dhcp import release_dhcp_leases as releaseDhcpLeases +from mistapi.device_utils.__tools.dhcp import retrieve_dhcp_leases as retrieveDhcpLeases + +# Dot1x functions +from mistapi.device_utils.__tools.dot1x import clear_sessions as clearDot1xSessions + +# MAC table functions +from mistapi.device_utils.__tools.mac import clear_learned_mac as clearLearnedMac +from mistapi.device_utils.__tools.mac import clear_mac_table as clearMacTable +from mistapi.device_utils.__tools.mac import retrieve_mac_table as retrieveMacTable + +# Tools (ping, monitor traffic) +from mistapi.device_utils.__tools.miscellaneous import monitor_traffic as monitorTraffic +from mistapi.device_utils.__tools.miscellaneous import ping + +# Policy functions +from mistapi.device_utils.__tools.policy import clear_hit_count as clearHitCount + +# Port functions +from mistapi.device_utils.__tools.port import bounce as bouncePort +from mistapi.device_utils.__tools.port import cable_test as cableTest + +__all__ = [ + # ARP + "retrieveArpTable", + # BGP + "retrieveBgpSummary", + # BPDU + "clearBpduError", + # DHCP + "retrieveDhcpLeases", + "releaseDhcpLeases", + # Dot1x + "clearDot1xSessions", + # MAC + "clearLearnedMac", + "clearMacTable", + "retrieveMacTable", + # Policy + "clearHitCount", + # Port + "bouncePort", + "cableTest", + # Tools + "monitorTraffic", + "ping", +] diff --git a/src/mistapi/device_utils/ospf.py b/src/mistapi/device_utils/ospf.py new file mode 100644 index 0000000..4903a52 --- /dev/null +++ b/src/mistapi/device_utils/ospf.py @@ -0,0 +1,291 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + """Node Enum for specifying node information in OSPF commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +def showDatabase( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + self_originate: bool | None = None, + vrf: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF database on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF database on. + node : Node, optional + Node information for the show OSPF database command. + self_originate : bool, optional + Filter for self-originated routes in the OSPF database. + vrf : str, optional + VRF to filter the OSPF database. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if self_originate is not None: + body["self_originate"] = self_originate + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteGatewayOspfDatabase( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF database command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger OSPF database command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF database command a moment to take effect + return util_response + + +def showInterfaces( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + port_id: str | None = None, + vrf: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF interfaces on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF interfaces on. + node : Node, optional + Node information for the show OSPF interfaces command. + port_id : str, optional + Port ID to filter the OSPF interfaces. + vrf : str, optional + VRF to filter the OSPF interfaces. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if port_id: + body["port_id"] = port_id + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteGatewayOspfInterfaces( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF interfaces command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger OSPF interfaces command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF interfaces command a moment to take effect + return util_response + + +def showNeighbors( + apissession: _APISession, + site_id: str, + device_id: str, + neighbor: str | None = None, + node: Node | None = None, + port_id: str | None = None, + vrf: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF neighbors on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF neighbors on. + neighbor : str, optional + Neighbor IP address to filter the OSPF neighbors. + node : Node, optional + Node information for the show OSPF neighbors command. + port_id : str, optional + Port ID to filter the OSPF neighbors. + vrf : str, optional + VRF to filter the OSPF neighbors. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if port_id: + body["port_id"] = port_id + if vrf: + body["vrf"] = vrf + if neighbor: + body["neighbor"] = neighbor + trigger = devices.showSiteGatewayOspfNeighbors( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF neighbors command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger OSPF neighbors command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF neighbors command a moment to take effect + return util_response + + +def showSummary( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + vrf: str | None = None, + timeout=5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SRX, SSR + + Shows OSPF summary on a device (SRX / SSR) and streams the results. + + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to show OSPF summary on. + node : Node, optional + Node information for the show OSPF summary command. + vrf : str, optional + VRF to filter the OSPF summary. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if vrf: + body["vrf"] = vrf + trigger = devices.showSiteGatewayOspfSummary( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"OSPF summary command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger OSPF summary command: {trigger.status_code} - {trigger.data}" + ) # Give the OSPF summary command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/policy.py b/src/mistapi/device_utils/policy.py new file mode 100644 index 0000000..ba8d606 --- /dev/null +++ b/src/mistapi/device_utils/policy.py @@ -0,0 +1,62 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse + + +async def clearHitCount( + apissession: _APISession, + site_id: str, + device_id: str, + policy_name: str, +) -> UtilResponse: + """ + DEVICE: EX + + Clears the policy hit count on a device. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to clear the policy hit count on. + policy_name : str + Name of the policy to clear the hit count for. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received from the WebSocket stream. + """ + trigger = devices.clearSiteDevicePolicyHitCount( + apissession, + site_id=site_id, + device_id=device_id, + body={"policy_name": policy_name}, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Clear policy hit count command triggered for device {device_id}") + # util_response = await WebSocketWrapper( + # apissession, util_response, timeout=timeout + # ).startCmdEvents(site_id, device_id) + else: + LOGGER.error( + f"Failed to trigger clear policy hit count command: {trigger.status_code} - {trigger.data}" + ) # Give the clear policy hit count command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/port.py b/src/mistapi/device_utils/port.py new file mode 100644 index 0000000..5757c0f --- /dev/null +++ b/src/mistapi/device_utils/port.py @@ -0,0 +1,133 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +def bounce( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + timeout=60, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: EX, SRX, SSR + + Initiates a bounce command on the specified ports of a device (EX / SRX / SSR) and streams + the results. + + PARAMS + ----------- + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to perform the bounce command on. + port_ids : list[str] + List of port IDs to bounce. + timeout : int, default 5 + Timeout for the bounce command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if port_ids: + body["ports"] = port_ids + trigger = devices.bounceDevicePort( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info( + f"Bounce command triggered for ports {port_ids} on device {device_id}" + ) + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger bounce command: {trigger.status_code} - {trigger.data}" + ) # Give the bounce command a moment to take effect + return util_response + + +def cableTest( + apissession: _APISession, + site_id: str, + device_id: str, + port_id: str, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: EX + + Initiates a cable test on a switch port and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the switch is located. + device_id : str + UUID of the switch to perform the cable test on. + port_id : str + Port ID to perform the cable test on. + timeout : int, optional + Timeout for the cable test command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"port": port_id} + trigger = devices.cableTestFromSwitch( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Cable test command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger cable test command: {trigger.status_code} - {trigger.data}" + ) # Give the cable test command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/service_path.py b/src/mistapi/device_utils/service_path.py new file mode 100644 index 0000000..2973c23 --- /dev/null +++ b/src/mistapi/device_utils/service_path.py @@ -0,0 +1,90 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + """Node Enum for specifying node information in service path commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +def showServicePath( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + service_name: str | None = None, + timeout: int = 5, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: SSR + + Initiates a show service path command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show service path command on. + node : Node, optional + Node information for the show service path command. + service_name : str, optional + Name of the service to show the path for. + timeout : int, optional + Timeout for the command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if service_name: + body["service_name"] = service_name + trigger = devices.showSiteSsrServicePath( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"SSR service path command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger SSR service path command: {trigger.status_code} - {trigger.data}" + ) # Give the SSR service path command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/sessions.py b/src/mistapi/device_utils/sessions.py new file mode 100644 index 0000000..c019f10 --- /dev/null +++ b/src/mistapi/device_utils/sessions.py @@ -0,0 +1,172 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.sites import DeviceCmdEvents + + +class Node(Enum): + """Node Enum for specifying node information in session commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +def clear( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + service_name: str | None = None, + service_ids: list[str] | None = None, + vrf: str | None = None, + timeout=2, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: SSR, SRX + + Initiates a clear sessions command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show routes command on. + node : Node, optional + Node information for the show routes command. + prefix : str, optional + Prefix to filter the routes. + protocol : RouteProtocol, optional + Protocol to filter the routes. + route_type : str, optional + Type of the route to filter. + vrf : str, optional + VRF to filter the routes. + timeout : int, optional + Timeout for the command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if service_name: + body["service_name"] = service_name + if service_ids: + body["service_ids"] = service_ids + if vrf: + body["vrf"] = vrf + trigger = devices.clearSiteDeviceSession( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Device Sessions command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger Device Sessions command: {trigger.status_code} - {trigger.data}" + ) # Give the Device Sessions command a moment to take effect + return util_response + + +def show( + apissession: _APISession, + site_id: str, + device_id: str, + node: Node | None = None, + service_name: str | None = None, + service_ids: list[str] | None = None, + timeout=2, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: SSR, SRX + + Initiates a show sessions command on the gateway and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the gateway is located. + device_id : str + UUID of the gateway to perform the show sessions command on. + node : Node, optional + Node information for the show sessions command. + service_name : str, optional + Name of the service to filter the sessions. + service_ids : list[str], optional + List of service IDs to filter the sessions. + timeout : int, optional + Timeout for the command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + + body: dict[str, str | list | int] = {} + if node: + body["node"] = node.value + if service_name: + body["service_name"] = service_name + if service_ids: + body["service_ids"] = service_ids + trigger = devices.showSiteSsrAndSrxSessions( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Device Sessions command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger Device Sessions command: {trigger.status_code} - {trigger.data}" + ) # Give the Device Sessions command a moment to take effect + return util_response diff --git a/src/mistapi/device_utils/srx.py b/src/mistapi/device_utils/srx.py new file mode 100644 index 0000000..a93f124 --- /dev/null +++ b/src/mistapi/device_utils/srx.py @@ -0,0 +1,69 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Utility functions for Juniper SRX Firewalls. + +This module provides a device-specific namespace for SRX firewall utilities. +All functions are imported from their respective functional modules. +""" + +# Re-export shared classes and types +from mistapi.device_utils.__tools.arp import Node + +# ARP functions +from mistapi.device_utils.__tools.arp import ( + retrieve_junos_arp_table as retrieveArpTable, +) + +# BGP functions +from mistapi.device_utils.__tools.bgp import summary as retrieveBgpSummary + +# DHCP functions +from mistapi.device_utils.__tools.dhcp import release_dhcp_leases as releaseDhcpLeases +from mistapi.device_utils.__tools.dhcp import retrieve_dhcp_leases as retrieveDhcpLeases + +# Tools (ping, monitor traffic) +from mistapi.device_utils.__tools.miscellaneous import monitor_traffic as monitorTraffic +from mistapi.device_utils.__tools.miscellaneous import ping + +# OSPF functions +from mistapi.device_utils.__tools.ospf import show_database as showDatabase +from mistapi.device_utils.__tools.ospf import show_interfaces as showInterfaces +from mistapi.device_utils.__tools.ospf import show_neighbors as showNeighbors + +# Port functions +from mistapi.device_utils.__tools.port import bounce as bouncePort + +# Route functions +from mistapi.device_utils.__tools.routes import show as retrieveRoutes + +__all__ = [ + # Classes/Enums + "Node", + # ARP + "retrieveArpTable", + # BGP + "retrieveBgpSummary", + # DHCP + "releaseDhcpLeases", + "retrieveDhcpLeases", + # OSPF + "showDatabase", + "showNeighbors", + "showInterfaces", + # Port + "bouncePort", + # Routes + "retrieveRoutes", + # Tools + "monitorTraffic", + "ping", +] diff --git a/src/mistapi/device_utils/ssr.py b/src/mistapi/device_utils/ssr.py new file mode 100644 index 0000000..0af8df1 --- /dev/null +++ b/src/mistapi/device_utils/ssr.py @@ -0,0 +1,78 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- + +Utility functions for Juniper Session Smart Routers (SSR). + +This module provides a device-specific namespace for SSR router utilities. +All functions are imported from their respective functional modules. +""" + +# Re-export shared classes and types +from mistapi.device_utils.__tools.arp import Node + +# ARP functions +from mistapi.device_utils.__tools.arp import ( + retrieve_ssr_arp_table as retrieveArpTable, +) + +# BGP functions +from mistapi.device_utils.__tools.bgp import summary as retrieveBgpSummary + +# DHCP functions +from mistapi.device_utils.__tools.dhcp import release_dhcp_leases as releaseDhcpLeases +from mistapi.device_utils.__tools.dhcp import retrieve_dhcp_leases as retrieveDhcpLeases + +# Tools (ping only - no monitor_traffic for SSR) +from mistapi.device_utils.__tools.miscellaneous import ping + +# DNS functions +# from mistapi.utils.dns import test_resolution as test_dns_resolution +# OSPF functions +from mistapi.device_utils.__tools.ospf import show_database as showDatabase +from mistapi.device_utils.__tools.ospf import show_interfaces as showInterfaces +from mistapi.device_utils.__tools.ospf import show_neighbors as showNeighbors + +# Port functions +from mistapi.device_utils.__tools.port import bounce as bouncePort + +# Route functions +from mistapi.device_utils.__tools.routes import show as retrieveRoutes + +# Service Path functions +from mistapi.device_utils.__tools.service_path import ( + show_service_path as showServicePath, +) + +__all__ = [ + # Classes/Enums + "Node", + # ARP + "retrieveArpTable", + # BGP + "retrieveBgpSummary", + # DHCP + "releaseDhcpLeases", + "retrieveDhcpLeases", + # DNS + # "test_dns_resolution", + # OSPF + "showDatabase", + "showNeighbors", + "showInterfaces", + # Port + "bouncePort", + # Routes + "retrieveRoutes", + # Service Path + "showServicePath", + # Tools + "ping", +] diff --git a/src/mistapi/device_utils/tools.py b/src/mistapi/device_utils/tools.py new file mode 100644 index 0000000..8a95822 --- /dev/null +++ b/src/mistapi/device_utils/tools.py @@ -0,0 +1,789 @@ +from collections.abc import Callable +from enum import Enum + +from mistapi import APISession as _APISession +from mistapi.__logger import logger as LOGGER +from mistapi.api.v1.sites import devices, pcaps +from mistapi.device_utils.__tools.__ws_wrapper import UtilResponse, WebSocketWrapper +from mistapi.websockets.session import SessionWithUrl +from mistapi.websockets.sites import DeviceCmdEvents, PcapEvents + + +class Node(Enum): + """Node Enum for specifying node information in commands.""" + + NODE0 = "node0" + NODE1 = "node1" + + +class TracerouteProtocol(Enum): + """Enum for specifying protocol in traceroute command.""" + + ICMP = "icmp" + UDP = "udp" + + +def _build_pcap_body( + device_id: str, + port_ids: list[str], + device_key: str, + device_type: str, + tcpdump_expression: str | None, + duration: int, + max_pkt_len: int, + num_packets: int, + raw: bool | None = None, +) -> dict: + """Build the request body for remote pcap commands (SRX, SSR, EX).""" + mac = device_id.split("-")[-1] + body: dict = { + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + device_key: {mac: {"ports": {}}}, + "type": device_type, + "format": "stream", + } + if raw is not None: + body["raw"] = raw + for port_id in port_ids: + port_entry: dict = {} + if tcpdump_expression is not None: + port_entry["tcpdump_expression"] = tcpdump_expression + body[device_key][mac]["ports"][port_id] = port_entry + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + return body + + +def ping( + apissession: _APISession, + site_id: str, + device_id: str, + host: str, + count: int | None = None, + node: Node | None = None, + size: int | None = None, + vrf: str | None = None, + timeout: int = 3, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: AP, EX, SRX, SSR + + Initiates a ping command from a device (AP / EX/ SRX / SSR) to a specified host and + streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to initiate the ping from. + host : str + The host to ping. + count : int, optional + Number of ping requests to send. + node : None, optional + Node information for the ping command. + size : int, optional + Size of the ping packet. + vrf : str, optional + VRF to use for the ping command. + timeout : int, optional + Timeout for the ping command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {} + if count: + body["count"] = count + if host: + body["host"] = host + if node: + body["node"] = node.value + if size: + body["size"] = size + if vrf: + body["vrf"] = vrf + trigger = devices.pingFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Ping command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger ping command: {trigger.status_code} - {trigger.data}" + ) # Give the ping command a moment to take effect + return util_response + + +## NO DATA +# def service_ping( +# apissession: _APISession, +# site_id: str, +# device_id: str, +# host: str, +# service: str, +# tenant: str, +# count: int | None = None, +# node: None | None = None, +# size: int | None = None, +# timeout: int = 3, +# on_message: Callable[[dict], None] | None = None, +# ) -> UtilResponse: +# """ +# DEVICES: SSR + +# Initiates a service ping command from a SSR to a specified host and streams the results. + +# PARAMS +# ----------- +# apissession : _APISession +# The API session to use for the request. +# site_id : str +# UUID of the site where the device is located. +# device_id : str +# UUID of the device to initiate the ping from. +# host : str +# The host to ping. +# service : str +# The service to ping. +# tenant : str +# Tenant to use for the ping command. +# count : int, optional +# Number of ping requests to send. +# node : None, optional +# Node information for the ping command. +# size : int, optional +# Size of the ping packet. +# timeout : int, optional +# Timeout for the ping command in seconds. +# on_message : Callable, optional +# Callback invoked with each extracted raw message as it arrives. + +# RETURNS +# ----------- +# UtilResponse +# A UtilResponse object containing the API response and a list of raw messages received +# from the WebSocket stream. +# """ +# body: dict[str, str | list | int] = {} +# if count: +# body["count"] = count +# if host: +# body["host"] = host +# if node: +# body["node"] = node.value +# if size: +# body["size"] = size +# if tenant: +# body["tenant"] = tenant +# if service: +# body["service"] = service +# trigger = devices.servicePingFromSsr( +# apissession, +# site_id=site_id, +# device_id=device_id, +# body=body, +# ) +# util_response = UtilResponse(trigger) +# if trigger.status_code == 200: +# LOGGER.info(f"Service Ping command triggered for device {device_id}") +# ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) +# util_response = WebSocketWrapper( +# apissession, util_response, timeout, on_message=on_message +# ).start(ws) +# else: +# LOGGER.error( +# f"Failed to trigger Service Ping command: {trigger.status_code} - {trigger.data}" +# ) # Give the ping command a moment to take effect +# return util_response + + +def traceroute( + apissession: _APISession, + site_id: str, + device_id: str, + host: str, + protocol: TracerouteProtocol = TracerouteProtocol.ICMP, + port: int | None = None, + timeout: int = 10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICES: AP, EX, SRX, SSR + + Initiates a traceroute command from a device (AP / EX/ SRX / SSR) to a specified host and + streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to initiate the traceroute from. + host : str + The host to traceroute. + protocol : TracerouteProtocol, optional + Protocol to use for the traceroute command (icmp or udp). + port : int, optional + Port to use for UDP traceroute. + timeout : int, optional + Timeout for the traceroute command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | list | int] = {"host": host} + if protocol: + body["protocol"] = protocol.value + if port: + body["port"] = port + trigger = devices.tracerouteFromDevice( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(f"Traceroute command triggered for device {device_id}") + ws = DeviceCmdEvents(apissession, site_id=site_id, device_ids=[device_id]) + util_response = WebSocketWrapper( + apissession, util_response, timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger traceroute command: {trigger.status_code} - {trigger.data}" + ) # Give the traceroute command a moment to take effect + return util_response + + +def monitorTraffic( + apissession: _APISession, + site_id: str, + device_id: str, + port_id: str | None = None, + timeout=30, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: EX, SRX + + Initiates a monitor traffic command on the device and streams the results. + + * if `port_id` is provided, JUNOS uses cmd "monitor interface" to monitor traffic on particular + * if `port_id` is not provided, JUNOS uses cmd "monitor interface traffic" to monitor traffic + on all ports + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to monitor traffic on. + port_id : str, optional + Port ID to filter the traffic. + timeout : int, optional + Timeout for the monitor traffic command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | int] = {"duration": 60} + if port_id: + body["port"] = port_id + trigger = devices.monitorSiteDeviceTraffic( + apissession, + site_id=site_id, + device_id=device_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Monitor traffic command triggered for device {device_id}") + ws = SessionWithUrl(apissession, url=trigger.data.get("url", "")) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger monitor traffic command: {trigger.status_code} - {trigger.data}" + ) # Give the monitor traffic command a moment to take effect + return util_response + + +def apRemotePcapWireless( + apissession: _APISession, + site_id: str, + device_id: str, + band: str, + tcpdump_expression: str | None = None, + ssid: str | None = None, + ap_mac: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: AP + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + band : str + Comma-separated list of radio bands (24, 5, or 6). + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "type mgt or type ctl -vvv -tttt -en" + ssid : str, optional + SSID to filter the wireless traffic. + ap_mac : str, optional + AP MAC address to filter the wireless traffic. + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | int] = { + "band": band, + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + "type": "radiotap", + "format": "stream", + } + if ssid: + body["ssid"] = ssid + if ap_mac: + body["ap_mac"] = ap_mac + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +def apRemotePcapWired( + apissession: _APISession, + site_id: str, + device_id: str, + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: AP + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body: dict[str, str | int] = { + "duration": duration, + "max_pkt_len": max_pkt_len, + "num_packets": num_packets, + "type": "wired", + "format": "stream", + } + if tcpdump_expression: + body["tcpdump_expression"] = tcpdump_expression + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +def srxRemotePcap( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: SRX + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + port_ids : list[str] + List of port IDs to monitor. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body = _build_pcap_body( + device_id, + port_ids, + "gateways", + "gateway", + tcpdump_expression, + duration, + max_pkt_len, + num_packets, + ) + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +def ssrRemotePcap( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: SSR + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + port_ids : list[str] + List of port IDs to monitor. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body = _build_pcap_body( + device_id, + port_ids, + "gateways", + "gateway", + tcpdump_expression, + duration, + max_pkt_len, + num_packets, + raw=False, + ) + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +def exRemotePcap( + apissession: _APISession, + site_id: str, + device_id: str, + port_ids: list[str], + tcpdump_expression: str | None = None, + duration: int = 600, + max_pkt_len: int = 512, + num_packets: int = 1024, + timeout=10, + on_message: Callable[[dict], None] | None = None, +) -> UtilResponse: + """ + DEVICE: EX + + Initiates a remote pcap command on the device and streams the results. + + PARAMS + ----------- + apissession : _APISession + The API session to use for the request. + site_id : str + UUID of the site where the device is located. + device_id : str + UUID of the device to run remote pcap on. + port_ids : list[str] + List of port IDs to monitor. + tcpdump_expression : str, optional + Tcpdump expression to filter the captured traffic. + e.g. "udp port 67 or udp port 68 -vvv -tttt -en" + duration : int, optional + Duration of the remote pcap in seconds (default: 600). + max_pkt_len : int, optional + Maximum packet length to capture (default: 512). + num_packets : int, optional + Maximum number of packets to capture (default: 1024). + timeout : int, optional + Timeout for the remote pcap command in seconds. + on_message : Callable, optional + Callback invoked with each extracted raw message as it arrives. + + RETURNS + ----------- + UtilResponse + A UtilResponse object containing the API response and a list of raw messages received + from the WebSocket stream. + """ + body = _build_pcap_body( + device_id, + port_ids, + "switches", + "switch", + tcpdump_expression, + duration, + max_pkt_len, + num_packets, + ) + trigger = pcaps.startSitePacketCapture( + apissession, + site_id=site_id, + body=body, + ) + util_response = UtilResponse(trigger) + if trigger.status_code == 200: + LOGGER.info(trigger.data) + print(f"Remote pcap command triggered for device {device_id}") + ws = PcapEvents(apissession, site_id=site_id) + util_response = WebSocketWrapper( + apissession, util_response, timeout=timeout, on_message=on_message + ).start(ws) + else: + LOGGER.error( + f"Failed to trigger remote pcap command: {trigger.status_code} - {trigger.data}" + ) # Give the remote pcap command a moment to take effect + return util_response + + +## NO DATA +# def srx_top_command( +# apissession: _APISession, +# site_id: str, +# device_id: str, +# timeout=10, +# on_message: Callable[[dict], None] | None = None, +# ) -> UtilResponse: +# """ +# DEVICE: SRX + +# For SRX Only. Initiates a top command on the device and streams the results. + +# PARAMS +# ----------- +# apissession : _APISession +# The API session to use for the request. +# site_id : str +# UUID of the site where the device is located. +# device_id : str +# UUID of the device to run the top command on. +# timeout : int, optional +# Timeout for the top command in seconds. +# on_message : Callable, optional +# Callback invoked with each extracted raw message as it arrives. + +# RETURNS +# ----------- +# UtilResponse +# A UtilResponse object containing the API response and a list of raw messages received +# from the WebSocket stream. +# """ +# trigger = devices.runSiteSrxTopCommand( +# apissession, +# site_id=site_id, +# device_id=device_id, +# ) +# util_response = UtilResponse(trigger) +# if trigger.status_code == 200: +# LOGGER.info(trigger.data) +# print(f"Top command triggered for device {device_id}") +# ws = SessionWithUrl(apissession, url=trigger.data.get("url", "")) +# util_response = WebSocketWrapper( +# apissession, util_response, timeout=timeout, on_message=on_message +# ).start(ws) +# else: +# LOGGER.error( +# f"Failed to trigger top command: {trigger.status_code} - {trigger.data}" +# ) # Give the top command a moment to take effect +# return util_response diff --git a/src/mistapi/websockets/__init__.py b/src/mistapi/websockets/__init__.py new file mode 100644 index 0000000..e269ee6 --- /dev/null +++ b/src/mistapi/websockets/__init__.py @@ -0,0 +1,20 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +""" + +from mistapi.websockets import location, orgs, session, sites + +__all__ = [ + "location", + "orgs", + "session", + "sites", +] diff --git a/src/mistapi/websockets/__ws_client.py b/src/mistapi/websockets/__ws_client.py new file mode 100644 index 0000000..237b056 --- /dev/null +++ b/src/mistapi/websockets/__ws_client.py @@ -0,0 +1,252 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +This module provides the _MistWebsocket base class for WebSocket connections +to the Mist API streaming endpoint (wss://{host}/api-ws/v1/stream). +""" + +import json +import queue +import ssl +import threading +from collections.abc import Callable, Generator +from typing import TYPE_CHECKING + +import websocket + +from mistapi.__logger import logger + +if TYPE_CHECKING: + from mistapi import APISession + + +class _MistWebsocket: + """ + Base class for Mist API WebSocket channels. + + Connects to wss://{host}/api-ws/v1/stream and subscribes to a channel + by sending {"subscribe": ""} on open. + + Auth is handled automatically: + - API token sessions use an Authorization header. + - Login/password sessions pass the requests Session cookies. + """ + + def __init__( + self, + mist_session: "APISession", + channels: list[str], + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + self._mist_session = mist_session + self._channels = channels + self._ping_interval = ping_interval + self._ping_timeout = ping_timeout + self._ws: websocket.WebSocketApp | None = None + self._thread: threading.Thread | None = None + self._queue: queue.Queue[dict | None] = queue.Queue() + self._connected = ( + threading.Event() + ) # tracks whether the WebSocket connection is currently open + self._on_message_cb: Callable[[dict], None] | None = None + self._on_error_cb: Callable[[Exception], None] | None = None + self._on_open_cb: Callable[[], None] | None = None + self._on_close_cb: Callable[[int, str], None] | None = None + + # ------------------------------------------------------------------ + # Auth / URL helpers + + def _build_ws_url(self) -> str: + return f"wss://{self._mist_session._cloud_uri.replace('api.', 'api-ws.')}/api-ws/v1/stream" + + def _get_headers(self) -> dict: + if self._mist_session._apitoken: + token = self._mist_session._apitoken[self._mist_session._apitoken_index] + return {"Authorization": f"Token {token}"} + return {} + + def _get_cookie(self) -> str | None: + cookies = self._mist_session._session.cookies + if cookies: + safe = [] + for c in cookies: + has_crlf = ( + "\r" in c.name + or "\n" in c.name + or (c.value and ("\r" in c.value or "\n" in c.value)) + ) + if has_crlf: + logger.warning( + "Skipping cookie %r: contains CRLF characters (possible header injection)", + c.name, + ) + continue + safe.append(f"{c.name}={c.value}") + return "; ".join(safe) if safe else None + return None + + def _build_sslopt(self) -> dict: + """Build SSL options from the APISession's requests.Session.""" + sslopt: dict = {} + session = self._mist_session._session + if session.verify is False: + sslopt["cert_reqs"] = ssl.CERT_NONE + elif isinstance(session.verify, str): + sslopt["ca_certs"] = session.verify + if session.cert: + if isinstance(session.cert, str): + sslopt["certfile"] = session.cert + elif isinstance(session.cert, tuple): + sslopt["certfile"] = session.cert[0] + if len(session.cert) > 1: + sslopt["keyfile"] = session.cert[1] + return sslopt + + # ------------------------------------------------------------------ + # Callback registration + + def on_message(self, callback: Callable[[dict], None]) -> None: + """Register a callback invoked for every incoming message.""" + self._on_message_cb = callback + + def on_error(self, callback: Callable[[Exception], None]) -> None: + """Register a callback invoked on WebSocket errors.""" + self._on_error_cb = callback + + def on_open(self, callback: Callable[[], None]) -> None: + """Register a callback invoked when the connection is established.""" + self._on_open_cb = callback + + def on_close(self, callback: Callable[[int, str], None]) -> None: + """Register a callback invoked when the connection closes.""" + self._on_close_cb = callback + + # ------------------------------------------------------------------ + # Internal WebSocketApp handlers + + def _handle_open(self, ws: websocket.WebSocketApp) -> None: + for channel in self._channels: + ws.send(json.dumps({"subscribe": channel})) + self._connected.set() + if self._on_open_cb: + self._on_open_cb() + + def _handle_message(self, ws: websocket.WebSocketApp, message: str) -> None: + try: + data = json.loads(message) + except json.JSONDecodeError: + data = {"raw": message} + self._queue.put(data) + if self._on_message_cb: + self._on_message_cb(data) + + def _handle_error(self, ws: websocket.WebSocketApp, error: Exception) -> None: + if self._on_error_cb: + self._on_error_cb(error) + + def _handle_close( + self, + ws: websocket.WebSocketApp, + close_status_code: int, + close_msg: str, + ) -> None: + self._connected.clear() + self._queue.put(None) # Signals receive() generator to stop + if self._on_close_cb: + self._on_close_cb(close_status_code, close_msg) + + # ------------------------------------------------------------------ + # Lifecycle + + def connect(self, run_in_background: bool = True) -> None: + """ + Open the WebSocket connection and subscribe to the channel. + + PARAMS + ----------- + run_in_background : bool, default True + If True, runs the WebSocket loop in a daemon thread (non-blocking). + If False, blocks the calling thread until disconnected. + """ + # Drain stale sentinel from previous connection + while not self._queue.empty(): + try: + self._queue.get_nowait() + except queue.Empty: + break + + self._ws = websocket.WebSocketApp( + self._build_ws_url(), + header=self._get_headers(), + cookie=self._get_cookie(), + on_open=self._handle_open, + on_message=self._handle_message, + on_error=self._handle_error, + on_close=self._handle_close, + ) + if run_in_background: + self._thread = threading.Thread(target=self._run_forever_safe, daemon=True) + self._thread.start() + else: + self._run_forever_safe() + + def _run_forever_safe(self) -> None: + if self._ws: + try: + sslopt = self._build_sslopt() + self._ws.run_forever( + ping_interval=self._ping_interval, + ping_timeout=self._ping_timeout, + sslopt=sslopt, + ) + except Exception as exc: + self._handle_error(self._ws, exc) + self._handle_close(self._ws, -1, str(exc)) + + def disconnect(self) -> None: + """Close the WebSocket connection.""" + if self._ws: + self._ws.close() + + def receive(self) -> Generator[dict, None, None]: + """ + Blocking generator that yields each incoming message as a dict. + + Exits cleanly when the connection closes (disconnect() is called or + the server closes the connection). + + Intended for use after connect(run_in_background=True). + """ + if not self._connected.wait(timeout=10): + return + while True: + try: + item = self._queue.get(timeout=1) + except queue.Empty: + if not self._connected.is_set() and self._queue.empty(): + break + continue + if item is None: + break + yield item + + # ------------------------------------------------------------------ + # Context manager + + def __enter__(self) -> "_MistWebsocket": + return self + + def __exit__(self, *args) -> None: + self.disconnect() + + def ready(self) -> bool | None: + """Returns True if the WebSocket connection is open and ready.""" + return self._ws is not None and self._ws.ready() diff --git a/src/mistapi/websockets/location.py b/src/mistapi/websockets/location.py new file mode 100644 index 0000000..2c40842 --- /dev/null +++ b/src/mistapi/websockets/location.py @@ -0,0 +1,325 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for Location events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class BleAssetsEvents(_MistWebsocket): + """WebSocket stream for location BLE assets events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/assets`` channel and delivers + real-time BLE assets events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : list[str] + UUIDs of the maps to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationBleAssetsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + site_id: str, + map_id: list[str], + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + channels = [f"/sites/{site_id}/stats/maps/{mid}/assets" for mid in map_id] + super().__init__( + mist_session, + channels=channels, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) + + +class ConnectedClientsEvents(_MistWebsocket): + """WebSocket stream for location connected clients events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/clients`` channel and delivers + real-time connected clients events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : list[str] + UUIDs of the maps to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationConnectedClientsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + site_id: str, + map_id: list[str], + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + channels = [f"/sites/{site_id}/stats/maps/{mid}/clients" for mid in map_id] + super().__init__( + mist_session, + channels=channels, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) + + +class SdkClientsEvents(_MistWebsocket): + """WebSocket stream for location SDK clients events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/sdkclients`` channel and delivers + real-time SDK clients events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : list[str] + UUIDs of the maps to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationSdkClientsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + site_id: str, + map_id: list[str], + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + channels = [f"/sites/{site_id}/stats/maps/{mid}/sdkclients" for mid in map_id] + super().__init__( + mist_session, + channels=channels, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) + + +class UnconnectedClientsEvents(_MistWebsocket): + """WebSocket stream for location unconnected clients events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/unconnected_clients`` channel and delivers + real-time unconnected clients events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : list[str] + UUIDs of the maps to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationUnconnectedClientsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + site_id: str, + map_id: list[str], + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + channels = [ + f"/sites/{site_id}/stats/maps/{mid}/unconnected_clients" for mid in map_id + ] + super().__init__( + mist_session, + channels=channels, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) + + +class DiscoveredBleAssetsEvents(_MistWebsocket): + """WebSocket stream for location discovered BLE assets events. + + Subscribes to the ``/sites/{site_id}/stats/maps/{map_id}/discovered_assets`` channel and delivers + real-time discovered BLE assets events for the given location. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + map_id : list[str] + UUIDs of the maps to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style (background thread):: + + ws = LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with LocationDiscoveredBleAssetsEvents(session, site_id="abc123", map_id="def456") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + site_id: str, + map_id: list[str], + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + channels = [ + f"/sites/{site_id}/stats/maps/{mid}/discovered_assets" for mid in map_id + ] + super().__init__( + mist_session, + channels=channels, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) diff --git a/src/mistapi/websockets/orgs.py b/src/mistapi/websockets/orgs.py new file mode 100644 index 0000000..d8e5520 --- /dev/null +++ b/src/mistapi/websockets/orgs.py @@ -0,0 +1,186 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for Org events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class InsightsEvents(_MistWebsocket): + """WebSocket stream for organization insights events. + + Subscribes to the ``orgs/{org_id}/insights/summary`` channel and delivers + real-time insights events for the given organization. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + org_id : str + UUID of the organization to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = OrgInsightsEvents(session, org_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = OrgInsightsEvents(session, org_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with OrgInsightsEvents(session, org_id="abc123") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + org_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + super().__init__( + mist_session, + channels=[f"/orgs/{org_id}/insights/summary"], + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) + + +class MxEdgesStatsEvents(_MistWebsocket): + """WebSocket stream for organization MX edges stats events. + + Subscribes to the ``orgs/{org_id}/stats/mxedges`` channel and delivers + real-time MX edges stats events for the given organization. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + org_id : str + UUID of the organization to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = OrgMxEdgesStatsEvents(session, org_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = OrgMxEdgesStatsEvents(session, org_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with OrgMxEdgesStatsEvents(session, org_id="abc123") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + org_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + super().__init__( + mist_session, + channels=[f"/orgs/{org_id}/stats/mxedges"], + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) + + +class MxEdgesEvents(_MistWebsocket): + """WebSocket stream for org MX edges events. + + Subscribes to the ``orgs/{org_id}/mxedges`` channel and delivers + real-time MX edges events for the given org. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + org_id : str + UUID of the org to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = MxEdgesEvents(session, org_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = MxEdgesEvents(session, org_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with MxEdgesEvents(session, org_id="abc123") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + org_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + super().__init__( + mist_session, + channels=[f"/orgs/{org_id}/mxedges"], + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) diff --git a/src/mistapi/websockets/session.py b/src/mistapi/websockets/session.py new file mode 100644 index 0000000..8b87801 --- /dev/null +++ b/src/mistapi/websockets/session.py @@ -0,0 +1,73 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for Remote Commands events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class SessionWithUrl(_MistWebsocket): + """WebSocket stream for remote commands events. + + Open a WebSocket connection to a custom channel URL for remote command events. + This is a base class that can be used to implement specific remote command event streams by providing the + appropriate channel URL. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + url : str + URL of the WebSocket channel to connect to. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = sessionWithUrl(session, url="wss://example.com/channel") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = sessionWithUrl(session, url="wss://example.com/channel") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with sessionWithUrl(session, url="wss://example.com/channel") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + url: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + super().__init__( + mist_session, + channels=[url], + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) diff --git a/src/mistapi/websockets/sites.py b/src/mistapi/websockets/sites.py new file mode 100644 index 0000000..c63910f --- /dev/null +++ b/src/mistapi/websockets/sites.py @@ -0,0 +1,432 @@ +""" +-------------------------------------------------------------------------------- +------------------------- Mist API Python CLI Session -------------------------- + + Written by: Thomas Munzer (tmunzer@juniper.net) + Github : https://github.com/tmunzer/mistapi_python + + This package is licensed under the MIT License. + +-------------------------------------------------------------------------------- +WebSocket channel for Site events. +""" + +from mistapi import APISession +from mistapi.websockets.__ws_client import _MistWebsocket + + +class ClientsStatsEvents(_MistWebsocket): + """WebSocket stream for site clients stats events. + + Subscribes to the ``sites/{site_id}/stats/clients`` channel and delivers + real-time clients stats events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_ids : list[str] + UUIDs of the sites to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SiteClientsStatsEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SiteClientsStatsEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SiteClientsStatsEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + site_ids: list[str], + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + channels = [f"/sites/{site_id}/stats/clients" for site_id in site_ids] + super().__init__( + mist_session, + channels=channels, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) + + +class DeviceCmdEvents(_MistWebsocket): + """WebSocket stream for site device command events. + + Subscribes to the ``sites/{site_id}/devices/{device_id}/cmd`` channel and delivers + real-time device command events for the given site and device. + + Device commands functions: + mistapi.api.v1.sites.devices.arpFromDevice + mistapi.api.v1.sites.devices.bounceDevicePort + mistapi.api.v1.sites.devices.cableTestFromSwitch + mistapi.api.v1.sites.devices.clearSiteDeviceMacTable + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + device_ids : list[str] + UUIDs of the devices to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SiteDeviceCmdEvents(session, site_id="abc123", device_id="def456") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + site_id: str, + device_ids: list[str], + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + channels = [ + f"/sites/{site_id}/devices/{device_id}/cmd" for device_id in device_ids + ] + super().__init__( + mist_session, + channels=channels, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) + + +class DeviceStatsEvents(_MistWebsocket): + """WebSocket stream for site device stats events. + + Subscribes to the ``sites/{site_id}/stats/devices`` channel and delivers + real-time device stats events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_ids : list[str] + UUIDs of the sites to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = SiteDeviceStatsEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = SiteDeviceStatsEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with SiteDeviceStatsEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + site_ids: list[str], + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + channels = [f"/sites/{site_id}/stats/devices" for site_id in site_ids] + super().__init__( + mist_session, + channels=channels, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) + + +class DeviceEvents(_MistWebsocket): + """WebSocket stream for site device events. + + Subscribes to the ``sites/{site_id}/devices`` channel and delivers + real-time device events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_ids : list[str] + UUIDs of the sites to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = DeviceEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = DeviceEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with DeviceEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + site_ids: list[str], + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + channels = [f"/sites/{site_id}/devices" for site_id in site_ids] + super().__init__( + mist_session, + channels=channels, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) + + +class MxEdgesStatsEvents(_MistWebsocket): + """WebSocket stream for site MX edges stats events. + + Subscribes to the ``sites/{site_id}/stats/mxedges`` channel and delivers + real-time MX edges stats events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_ids : list[str] + UUIDs of the sites to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = MxEdgesStatsEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = MxEdgesStatsEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with MxEdgesStatsEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + site_ids: list[str], + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + channels = [f"/sites/{site_id}/stats/mxedges" for site_id in site_ids] + super().__init__( + mist_session, + channels=channels, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) + + +class MxEdgesEvents(_MistWebsocket): + """WebSocket stream for site MX edges events. + + Subscribes to the ``sites/{site_id}/mxedges`` channel and delivers + real-time MX edges events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_ids : list[str] + UUIDs of the sites to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = MxEdgesEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = MxEdgesEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with MxEdgesEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + site_ids: list[str], + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + channels = [f"/sites/{site_id}/mxedges" for site_id in site_ids] + super().__init__( + mist_session, + channels=channels, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) + + +class PcapEvents(_MistWebsocket): + """WebSocket stream for site PCAP events. + + Subscribes to the ``sites/{site_id}/pcap`` channel and delivers + real-time PCAP events for the given site. + + PARAMS + ----------- + mist_session : mistapi.APISession + Authenticated API session. + site_id : str + UUID of the site to stream events from. + ping_interval : int, default 30 + Interval in seconds to send WebSocket ping frames (keep-alive). + ping_timeout : int, default 10 + Time in seconds to wait for a ping response before considering the connection dead. + + EXAMPLE + ----------- + Callback style (background thread):: + + ws = PcapEvents(session, site_id="abc123") + ws.on_message(lambda data: print(data)) + ws.connect() # non-blocking, runs in background thread + input("Press Enter to stop") + ws.disconnect() + + Generator style:: + + ws = PcapEvents(session, site_id="abc123") + ws.connect(run_in_background=True) + for msg in ws.receive(): + process(msg) + + Context manager:: + + with PcapEvents(session, site_id="abc123") as ws: + ws.on_message(my_handler) + ws.connect() # non-blocking, runs in background thread + time.sleep(60) + """ + + def __init__( + self, + mist_session: APISession, + site_id: str, + ping_interval: int = 30, + ping_timeout: int = 10, + ) -> None: + channels = [f"/sites/{site_id}/pcaps"] + super().__init__( + mist_session, + channels=channels, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) diff --git a/tests/unit/test_api_request.py b/tests/unit/test_api_request.py index e69de29..2e93fef 100644 --- a/tests/unit/test_api_request.py +++ b/tests/unit/test_api_request.py @@ -0,0 +1,794 @@ +# tests/unit/test_api_request.py +""" +Comprehensive unit tests for mistapi.__api_request.APIRequest. + +Tests cover URL generation, query string encoding, token rotation, +proxy logging, header sanitisation, rate-limit handling, the shared +retry wrapper, and each HTTP-method convenience function. +""" + +import json +from unittest.mock import Mock, patch + +import pytest +import requests +from requests.exceptions import HTTPError + +from mistapi.__api_request import APIRequest +from mistapi.__api_response import APIResponse + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_api_request(cloud_uri="api.mist.com", tokens=None): + """Create an APIRequest with a mocked session for isolated testing.""" + with patch("mistapi.__api_request.requests.session") as mock_session_cls: + mock_session = Mock() + mock_session.headers = {} + mock_session.proxies = {} + mock_session.cookies = {} + mock_session_cls.return_value = mock_session + + req = APIRequest() + req._session = mock_session + req._cloud_uri = cloud_uri + + if tokens: + req._apitoken = list(tokens) + req._apitoken_index = 0 + req._session.headers.update({"Authorization": "Token " + tokens[0]}) + return req + + +def _mock_response( + status_code=200, + json_data=None, + headers=None, + raise_for_status_effect=None, +): + """Build a mock requests.Response.""" + resp = Mock(spec=requests.Response) + resp.status_code = status_code + resp.headers = headers or {} + resp.json.return_value = json_data if json_data is not None else {} + resp.content = json.dumps(json_data or {}).encode() + + # For _remove_auth_from_headers — attach a mock PreparedRequest + prep = Mock() + prep.headers = {} + resp.request = prep + + if raise_for_status_effect: + resp.raise_for_status.side_effect = raise_for_status_effect + else: + resp.raise_for_status.return_value = None + return resp + + +# =========================================================================== +# Tests +# =========================================================================== + + +class TestUrl: + """APIRequest._url() builds the full URL from cloud_uri + uri.""" + + def test_basic_url(self): + req = _make_api_request("api.mist.com") + assert req._url("/api/v1/self") == "https://api.mist.com/api/v1/self" + + def test_empty_uri(self): + req = _make_api_request("api.mist.com") + assert req._url("") == "https://api.mist.com" + + def test_eu_host(self): + req = _make_api_request("api.eu.mist.com") + assert req._url("/api/v1/orgs") == "https://api.eu.mist.com/api/v1/orgs" + + def test_uri_with_path_segments(self): + req = _make_api_request("api.mist.com") + org_id = "203d3d02-dbc0-4c1b-9f41-76896a3330f4" + uri = f"/api/v1/orgs/{org_id}/sites" + assert req._url(uri) == f"https://api.mist.com{uri}" + + +class TestGenQuery: + """APIRequest._gen_query() builds URL-encoded query strings.""" + + def test_none_returns_empty(self): + req = _make_api_request() + assert req._gen_query(None) == "" + + def test_empty_dict_returns_empty(self): + req = _make_api_request() + assert req._gen_query({}) == "" + + def test_single_param(self): + req = _make_api_request() + assert req._gen_query({"page": "2"}) == "?page=2" + + def test_multiple_params(self): + req = _make_api_request() + result = req._gen_query({"page": "1", "limit": "100"}) + assert result.startswith("?") + assert "page=1" in result + assert "limit=100" in result + + def test_special_chars_encoded(self): + req = _make_api_request() + result = req._gen_query({"filter": "name=hello world&foo"}) + assert "?" in result + # urllib.parse.urlencode encodes spaces as + and & as %26 + assert "hello+world" in result or "hello%20world" in result + assert "%26" in result + + def test_preserves_insertion_order(self): + req = _make_api_request() + # dict preserves insertion order in Python 3.7+ + result = req._gen_query({"a": "1", "b": "2", "c": "3"}) + assert result == "?a=1&b=2&c=3" + + +class TestNextApiToken: + """APIRequest._next_apitoken() rotates through tokens and raises + RuntimeError when only one token is available.""" + + def test_rotates_to_next_token(self): + req = _make_api_request(tokens=["tok_aaa1", "tok_bbb2", "tok_ccc3"]) + req._apitoken_index = 0 + req._next_apitoken() + assert req._apitoken_index == 1 + assert req._session.headers["Authorization"] == "Token tok_bbb2" + + def test_wraps_around_to_first_token(self): + req = _make_api_request(tokens=["tok_aaa1", "tok_bbb2"]) + req._apitoken_index = 1 + req._next_apitoken() + assert req._apitoken_index == 0 + assert req._session.headers["Authorization"] == "Token tok_aaa1" + + def test_single_token_raises_runtime_error(self): + req = _make_api_request(tokens=["tok_only1"]) + req._apitoken_index = 0 + with pytest.raises(RuntimeError, match="API rate limit reached"): + req._next_apitoken() + + def test_rotation_cycle(self): + tokens = ["tok_aaa1", "tok_bbb2", "tok_ccc3"] + req = _make_api_request(tokens=tokens) + req._apitoken_index = 0 + # Rotate through all tokens and back + req._next_apitoken() + assert req._apitoken_index == 1 + req._next_apitoken() + assert req._apitoken_index == 2 + req._next_apitoken() + assert req._apitoken_index == 0 + + +class TestLogProxy: + """APIRequest._log_proxy() prints a masked proxy URL.""" + + def test_prints_masked_proxy(self, capsys): + req = _make_api_request() + req._session.proxies = { + "https": "http://user:secret_password@proxy.example.com:8080" + } + req._log_proxy() + captured = capsys.readouterr() + assert "proxy.example.com:8080" in captured.out + assert "*********" in captured.out + assert "secret_password" not in captured.out + + def test_no_proxy_does_nothing(self, capsys): + req = _make_api_request() + req._session.proxies = {} + req._log_proxy() + captured = capsys.readouterr() + assert captured.out == "" + + def test_proxy_without_password(self, capsys): + req = _make_api_request() + req._session.proxies = {"https": "http://proxy.example.com:8080"} + req._log_proxy() + captured = capsys.readouterr() + assert "proxy.example.com:8080" in captured.out + + +class TestRemoveAuthFromHeaders: + """APIRequest._remove_auth_from_headers() masks sensitive headers.""" + + def test_masks_authorization(self): + req = _make_api_request() + resp = Mock() + resp.request.headers = {"Authorization": "Token secret123"} + headers = req._remove_auth_from_headers(resp) + assert headers["Authorization"] == "***hidden***" + + def test_masks_csrf_token(self): + req = _make_api_request() + resp = Mock() + resp.request.headers = {"X-CSRFToken": "csrf_value"} + headers = req._remove_auth_from_headers(resp) + assert headers["X-CSRFToken"] == "***hidden***" + + def test_masks_cookie(self): + req = _make_api_request() + resp = Mock() + resp.request.headers = {"Cookie": "session=abc123"} + headers = req._remove_auth_from_headers(resp) + assert headers["Cookie"] == "***hidden***" + + def test_masks_all_three(self): + req = _make_api_request() + resp = Mock() + resp.request.headers = { + "Authorization": "Token x", + "X-CSRFToken": "y", + "Cookie": "z", + "Content-Type": "application/json", + } + headers = req._remove_auth_from_headers(resp) + assert headers["Authorization"] == "***hidden***" + assert headers["X-CSRFToken"] == "***hidden***" + assert headers["Cookie"] == "***hidden***" + # Non-sensitive header left untouched + assert headers["Content-Type"] == "application/json" + + def test_leaves_non_sensitive_headers(self): + req = _make_api_request() + resp = Mock() + resp.request.headers = { + "Accept": "application/json", + "User-Agent": "python-requests/2.32", + } + headers = req._remove_auth_from_headers(resp) + assert headers["Accept"] == "application/json" + assert headers["User-Agent"] == "python-requests/2.32" + + +class TestHandleRateLimit: + """APIRequest._handle_rate_limit() sleeps according to Retry-After + header or falls back to exponential backoff.""" + + @patch("mistapi.__api_request.time.sleep") + def test_uses_retry_after_header(self, mock_sleep): + req = _make_api_request() + resp = Mock() + resp.headers = {"Retry-After": "10"} + req._handle_rate_limit(resp, attempt=0) + mock_sleep.assert_called_once_with(10) + + @patch("mistapi.__api_request.time.sleep") + def test_exponential_backoff_when_no_header(self, mock_sleep): + req = _make_api_request() + resp = Mock() + resp.headers = {} + # attempt 0 => 5 * (2**0) = 5 + req._handle_rate_limit(resp, attempt=0) + mock_sleep.assert_called_once_with(5) + + @patch("mistapi.__api_request.time.sleep") + def test_exponential_backoff_attempt_1(self, mock_sleep): + req = _make_api_request() + resp = Mock() + resp.headers = {} + # attempt 1 => 5 * (2**1) = 10 + req._handle_rate_limit(resp, attempt=1) + mock_sleep.assert_called_once_with(10) + + @patch("mistapi.__api_request.time.sleep") + def test_exponential_backoff_attempt_2(self, mock_sleep): + req = _make_api_request() + resp = Mock() + resp.headers = {} + # attempt 2 => 5 * (2**2) = 20 + req._handle_rate_limit(resp, attempt=2) + mock_sleep.assert_called_once_with(20) + + @patch("mistapi.__api_request.time.sleep") + def test_invalid_retry_after_falls_back(self, mock_sleep): + req = _make_api_request() + resp = Mock() + resp.headers = {"Retry-After": "not-a-number"} + # attempt 0 => fallback 5 * (2**0) = 5 + req._handle_rate_limit(resp, attempt=0) + mock_sleep.assert_called_once_with(5) + + +class TestRequestWithRetrySuccess: + """_request_with_retry() on successful responses.""" + + def test_returns_api_response_on_200(self): + req = _make_api_request() + resp = _mock_response(status_code=200, json_data={"ok": True}) + fn = Mock(return_value=resp) + + result = req._request_with_retry("test", fn, "https://api.mist.com/api/v1/self") + + assert isinstance(result, APIResponse) + assert result.status_code == 200 + fn.assert_called_once() + + def test_increments_count(self): + req = _make_api_request() + assert req._count == 0 + resp = _mock_response(status_code=200) + fn = Mock(return_value=resp) + + req._request_with_retry("test", fn, "https://example.com") + assert req._count == 1 + + req._request_with_retry("test", fn, "https://example.com") + assert req._count == 2 + + +class TestRequestWithRetryProxyError: + """_request_with_retry() on ProxyError.""" + + def test_proxy_error_sets_flag(self): + req = _make_api_request() + fn = Mock(side_effect=requests.exceptions.ProxyError("proxy down")) + + result = req._request_with_retry("test", fn, "https://example.com") + + assert result.proxy_error is True + assert result.status_code is None + assert req._count == 1 + + +class TestRequestWithRetryConnectionError: + """_request_with_retry() on ConnectionError.""" + + def test_connection_error_returns_none_response(self): + req = _make_api_request() + fn = Mock(side_effect=requests.exceptions.ConnectionError("no route")) + + result = req._request_with_retry("test", fn, "https://example.com") + + assert result.status_code is None + assert result.proxy_error is False + assert req._count == 1 + + +class TestRequestWithRetryHTTPErrorNon429: + """_request_with_retry() on non-429 HTTPError.""" + + def test_non_429_stops_immediately(self): + req = _make_api_request() + resp = _mock_response(status_code=403, json_data={"error": "forbidden"}) + http_err = HTTPError(response=resp) + resp.raise_for_status.side_effect = http_err + + fn = Mock(return_value=resp) + result = req._request_with_retry("test", fn, "https://example.com") + + # Should only call request_fn once (no retries for non-429) + fn.assert_called_once() + assert result.status_code == 403 + assert req._count == 1 + + def test_500_error(self): + req = _make_api_request() + resp = _mock_response(status_code=500, json_data={"error": "server error"}) + http_err = HTTPError(response=resp) + resp.raise_for_status.side_effect = http_err + + fn = Mock(return_value=resp) + result = req._request_with_retry("test", fn, "https://example.com") + + fn.assert_called_once() + assert result.status_code == 500 + + +class TestRequestWithRetry429: + """_request_with_retry() on 429 rate-limit responses.""" + + @patch("mistapi.__api_request.time.sleep") + def test_retries_on_429(self, mock_sleep): + req = _make_api_request(tokens=["tok_aaa1", "tok_bbb2"]) + + # First call returns 429, second call succeeds + resp_429 = _mock_response(status_code=429, headers={"Retry-After": "1"}) + http_err = HTTPError(response=resp_429) + resp_429.raise_for_status.side_effect = http_err + + resp_ok = _mock_response(status_code=200, json_data={"ok": True}) + + fn = Mock(side_effect=[resp_429, resp_ok]) + result = req._request_with_retry("test", fn, "https://example.com") + + assert fn.call_count == 2 + assert result.status_code == 200 + mock_sleep.assert_called_once_with(1) + + @patch("mistapi.__api_request.time.sleep") + def test_rotates_token_on_429(self, mock_sleep): + req = _make_api_request(tokens=["tok_aaa1", "tok_bbb2"]) + + resp_429 = _mock_response(status_code=429, headers={"Retry-After": "1"}) + http_err = HTTPError(response=resp_429) + resp_429.raise_for_status.side_effect = http_err + + resp_ok = _mock_response(status_code=200) + fn = Mock(side_effect=[resp_429, resp_ok]) + req._request_with_retry("test", fn, "https://example.com") + + # Token should have rotated from index 0 to index 1 + assert req._apitoken_index == 1 + + @patch("mistapi.__api_request.time.sleep") + def test_429_exhausted_after_max_retries(self, mock_sleep): + req = _make_api_request(tokens=["tok_aaa1", "tok_bbb2", "tok_ccc3", "tok_ddd4"]) + + resp_429 = _mock_response(status_code=429, headers={"Retry-After": "1"}) + http_err = HTTPError(response=resp_429) + resp_429.raise_for_status.side_effect = http_err + + # All 4 calls (1 initial + 3 retries) return 429 + fn = Mock(return_value=resp_429) + result = req._request_with_retry("test", fn, "https://example.com") + + # MAX_429_RETRIES = 3, so total calls = 4 (attempt 0,1,2,3) + assert fn.call_count == 4 + assert result.status_code == 429 + assert mock_sleep.call_count == 3 + assert req._count == 1 + + @patch("mistapi.__api_request.time.sleep") + def test_429_single_token_still_retries_with_backoff(self, mock_sleep): + """Even with one token (RuntimeError on rotation), retry with backoff.""" + req = _make_api_request(tokens=["tok_only1"]) + + resp_429 = _mock_response(status_code=429, headers={}) + http_err = HTTPError(response=resp_429) + resp_429.raise_for_status.side_effect = http_err + + resp_ok = _mock_response(status_code=200) + fn = Mock(side_effect=[resp_429, resp_ok]) + result = req._request_with_retry("test", fn, "https://example.com") + + assert fn.call_count == 2 + assert result.status_code == 200 + # Backoff with attempt=0 => 5 * 1 = 5 + mock_sleep.assert_called_once_with(5) + + @patch("mistapi.__api_request.time.sleep") + def test_429_calls_handle_rate_limit(self, mock_sleep): + req = _make_api_request(tokens=["tok_aaa1", "tok_bbb2"]) + + resp_429 = _mock_response(status_code=429, headers={"Retry-After": "7"}) + http_err = HTTPError(response=resp_429) + resp_429.raise_for_status.side_effect = http_err + + resp_ok = _mock_response(status_code=200) + fn = Mock(side_effect=[resp_429, resp_ok]) + + with patch.object( + req, "_handle_rate_limit", wraps=req._handle_rate_limit + ) as wrapped: + req._request_with_retry("test", fn, "https://example.com") + wrapped.assert_called_once_with(resp_429, 0) + + mock_sleep.assert_called_once_with(7) + + +class TestRequestWithRetryGenericException: + """_request_with_retry() on unexpected exceptions.""" + + def test_generic_exception_breaks_loop(self): + req = _make_api_request() + fn = Mock(side_effect=ValueError("something unexpected")) + result = req._request_with_retry("test", fn, "https://example.com") + + fn.assert_called_once() + assert result.status_code is None + assert req._count == 1 + + +class TestMistGet: + """mist_get() delegates to _request_with_retry with correct URL+query.""" + + def test_get_without_query(self): + req = _make_api_request() + resp = _mock_response(status_code=200, json_data={"items": []}) + req._session.get.return_value = resp + + result = req.mist_get("/api/v1/self") + + assert result.status_code == 200 + req._session.get.assert_called_once_with("https://api.mist.com/api/v1/self") + + def test_get_with_query(self): + req = _make_api_request() + resp = _mock_response(status_code=200, json_data={}) + req._session.get.return_value = resp + + result = req.mist_get("/api/v1/orgs", query={"page": "2", "limit": "50"}) + + assert result.status_code == 200 + expected_url = "https://api.mist.com/api/v1/orgs?page=2&limit=50" + req._session.get.assert_called_once_with(expected_url) + + def test_get_increments_count(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.get.return_value = resp + + req.mist_get("/api/v1/self") + assert req.get_request_count() == 1 + + +class TestMistPost: + """mist_post() sends JSON body or raw string data.""" + + def test_post_dict_body(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + body = {"name": "Test Site"} + result = req.mist_post("/api/v1/sites", body=body) + + assert result.status_code == 200 + req._session.post.assert_called_once_with( + "https://api.mist.com/api/v1/sites", + json=body, + headers={"Content-Type": "application/json"}, + ) + + def test_post_string_body(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + body = '{"name": "Test Site"}' + result = req.mist_post("/api/v1/sites", body=body) + + assert result.status_code == 200 + req._session.post.assert_called_once_with( + "https://api.mist.com/api/v1/sites", + data=body, + headers={"Content-Type": "application/json"}, + ) + + def test_post_list_body(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + body = [{"name": "Site A"}, {"name": "Site B"}] + result = req.mist_post("/api/v1/sites", body=body) + + assert result.status_code == 200 + req._session.post.assert_called_once_with( + "https://api.mist.com/api/v1/sites", + json=body, + headers={"Content-Type": "application/json"}, + ) + + def test_post_none_body(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + result = req.mist_post("/api/v1/sites", body=None) + + assert result.status_code == 200 + req._session.post.assert_called_once_with( + "https://api.mist.com/api/v1/sites", + json=None, + headers={"Content-Type": "application/json"}, + ) + + +class TestMistPut: + """mist_put() sends JSON body or raw string data.""" + + def test_put_dict_body(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.put.return_value = resp + + body = {"name": "Updated Site"} + result = req.mist_put("/api/v1/sites/123", body=body) + + assert result.status_code == 200 + req._session.put.assert_called_once_with( + "https://api.mist.com/api/v1/sites/123", + json=body, + headers={"Content-Type": "application/json"}, + ) + + def test_put_string_body(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.put.return_value = resp + + body = '{"name": "Updated Site"}' + result = req.mist_put("/api/v1/sites/123", body=body) + + assert result.status_code == 200 + req._session.put.assert_called_once_with( + "https://api.mist.com/api/v1/sites/123", + data=body, + headers={"Content-Type": "application/json"}, + ) + + def test_put_none_body(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.put.return_value = resp + + result = req.mist_put("/api/v1/sites/123", body=None) + + assert result.status_code == 200 + req._session.put.assert_called_once_with( + "https://api.mist.com/api/v1/sites/123", + json=None, + headers={"Content-Type": "application/json"}, + ) + + +class TestMistDelete: + """mist_delete() delegates with correct URL+query.""" + + def test_delete_without_query(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.delete.return_value = resp + + result = req.mist_delete("/api/v1/sites/123") + + assert result.status_code == 200 + req._session.delete.assert_called_once_with( + "https://api.mist.com/api/v1/sites/123" + ) + + def test_delete_with_query(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.delete.return_value = resp + + result = req.mist_delete("/api/v1/sites/123", query={"force": "true"}) + + assert result.status_code == 200 + req._session.delete.assert_called_once_with( + "https://api.mist.com/api/v1/sites/123?force=true" + ) + + +class TestMistPostFile: + """mist_post_file() builds multipart form data and delegates to retry wrapper.""" + + def test_post_file_with_file_key(self, tmp_path): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + # Create a real temporary file + test_file = tmp_path / "upload.bin" + test_file.write_bytes(b"file content here") + + result = req.mist_post_file( + "/api/v1/sites/123/maps/import", + multipart_form_data={"file": str(test_file)}, + ) + + assert result.status_code == 200 + req._session.post.assert_called_once() + call_kwargs = req._session.post.call_args + files_arg = call_kwargs.kwargs.get("files") or call_kwargs[1].get("files") + assert "file" in files_arg + # Tuple structure: (filename, file_obj, content_type) + assert files_arg["file"][0] == "upload.bin" + assert files_arg["file"][2] == "application/octet-stream" + + def test_post_file_with_csv_key(self, tmp_path): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + csv_file = tmp_path / "data.csv" + csv_file.write_text("col1,col2\na,b\n") + + result = req.mist_post_file( + "/api/v1/orgs/123/inventory", + multipart_form_data={"csv": str(csv_file)}, + ) + + assert result.status_code == 200 + call_kwargs = req._session.post.call_args + files_arg = call_kwargs.kwargs.get("files") or call_kwargs[1].get("files") + assert "csv" in files_arg + assert files_arg["csv"][0] == "data.csv" + + def test_post_file_with_json_field(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + result = req.mist_post_file( + "/api/v1/sites/123/maps", + multipart_form_data={"json": {"name": "Floor 1"}}, + ) + + assert result.status_code == 200 + call_kwargs = req._session.post.call_args + files_arg = call_kwargs.kwargs.get("files") or call_kwargs[1].get("files") + assert "json" in files_arg + # Non-file keys produce (None, json_string) tuples + assert files_arg["json"][0] is None + assert json.loads(files_arg["json"][1]) == {"name": "Floor 1"} + + def test_post_file_none_defaults_to_empty(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + result = req.mist_post_file("/api/v1/sites/123/maps") + + assert result.status_code == 200 + call_kwargs = req._session.post.call_args + files_arg = call_kwargs.kwargs.get("files") or call_kwargs[1].get("files") + assert files_arg == {} + + def test_post_file_skips_falsy_values(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + result = req.mist_post_file( + "/api/v1/sites/123/maps", + multipart_form_data={"json": None, "file": ""}, + ) + + assert result.status_code == 200 + call_kwargs = req._session.post.call_args + files_arg = call_kwargs.kwargs.get("files") or call_kwargs[1].get("files") + assert files_arg == {} + + def test_post_file_missing_file_handled_gracefully(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.post.return_value = resp + + # Point to a file that does not exist + result = req.mist_post_file( + "/api/v1/sites/123/maps", + multipart_form_data={"file": "/nonexistent/path/file.bin"}, + ) + + assert result.status_code == 200 + # The OSError is caught silently; file key is not in the generated data + call_kwargs = req._session.post.call_args + files_arg = call_kwargs.kwargs.get("files") or call_kwargs[1].get("files") + assert "file" not in files_arg + + +class TestGetRequestCount: + """get_request_count() returns the cumulative request count.""" + + def test_initial_count_is_zero(self): + req = _make_api_request() + assert req.get_request_count() == 0 + + def test_count_after_requests(self): + req = _make_api_request() + resp = _mock_response(status_code=200) + req._session.get.return_value = resp + + req.mist_get("/api/v1/self") + req.mist_get("/api/v1/self") + req.mist_get("/api/v1/self") + + assert req.get_request_count() == 3 + + def test_count_increments_on_error(self): + req = _make_api_request() + fn = Mock(side_effect=requests.exceptions.ConnectionError("fail")) + req._request_with_retry("test", fn, "https://example.com") + assert req.get_request_count() == 1 diff --git a/tests/unit/test_api_response.py b/tests/unit/test_api_response.py index e69de29..1c283fa 100644 --- a/tests/unit/test_api_response.py +++ b/tests/unit/test_api_response.py @@ -0,0 +1,339 @@ +""" +Unit tests for mistapi.__api_response.APIResponse + +Tests cover: +- Construction with None response (default field values) +- Construction with a valid JSON response (data, status_code, headers) +- _check_next() when "next" key is present in response data +- _check_next() pagination via X-Page-Total / X-Page-Limit / X-Page-Page headers +- _check_next() on the last page (next stays None) +- _check_next() when the URL already contains a page= parameter +- Error responses (4xx status codes) +- proxy_error flag propagation +- Non-JSON responses (exception path in __init__) +""" + +import json +from unittest.mock import Mock + +import pytest + +from mistapi.__api_response import APIResponse + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_mock_response(status_code=200, data=None, headers=None, json_raises=False): + """Build a mock requests.Response with the given attributes.""" + mock = Mock() + mock.status_code = status_code + mock.headers = headers or {} + if json_raises: + mock.content = b"not-json" + mock.json.side_effect = ValueError("No JSON") + else: + payload = data if data is not None else {} + mock.content = json.dumps(payload).encode() + mock.json.return_value = payload + return mock + + +# --------------------------------------------------------------------------- +# Tests: construction / default fields +# --------------------------------------------------------------------------- + + +class TestAPIResponseConstruction: + """Tests for APIResponse.__init__ with different response inputs.""" + + def test_none_response_defaults(self): + """When response is None every field should keep its default value.""" + resp = APIResponse(response=None, url="https://api.mist.com/api/v1/test") + + assert resp.raw_data == "" + assert resp.data == {} + assert resp.url == "https://api.mist.com/api/v1/test" + assert resp.next is None + assert resp.headers is None + assert resp.status_code is None + assert resp.proxy_error is False + + def test_200_response_with_json(self, api_response_factory): + """A 200 response with JSON body should populate data, status_code, headers.""" + data = {"id": "abc-123", "name": "widget"} + headers = {"Content-Type": "application/json"} + + resp = api_response_factory(status_code=200, data=data, headers=headers) + + assert resp.status_code == 200 + assert resp.data == data + assert resp.headers == headers + assert resp.url == "https://api.mist.com/api/v1/test" + + def test_raw_data_is_string_of_content(self): + """raw_data should be str(response.content).""" + data = {"key": "value"} + mock = _make_mock_response(data=data) + resp = APIResponse(response=mock, url="https://host/api/v1/x") + + assert resp.raw_data == str(json.dumps(data).encode()) + + def test_proxy_error_true(self): + """proxy_error=True should be stored on the instance.""" + resp = APIResponse(response=None, url="https://host/api/v1/x", proxy_error=True) + + assert resp.proxy_error is True + + def test_proxy_error_false_by_default(self): + """proxy_error defaults to False.""" + resp = APIResponse(response=None, url="https://host/api/v1/x") + + assert resp.proxy_error is False + + def test_proxy_error_with_response(self): + """proxy_error flag should propagate even when a valid response is present.""" + mock = _make_mock_response(status_code=502, data={"error": "bad gateway"}) + resp = APIResponse(response=mock, url="https://host/api/v1/x", proxy_error=True) + + assert resp.proxy_error is True + assert resp.status_code == 502 + + +# --------------------------------------------------------------------------- +# Tests: error responses +# --------------------------------------------------------------------------- + + +class TestAPIResponseErrors: + """Tests for error HTTP status codes and error payloads.""" + + @pytest.mark.parametrize("status_code", [400, 401, 403, 404, 500, 502, 503]) + def test_error_status_codes_stored(self, api_response_factory, status_code): + """Error status codes should be stored without raising.""" + data = {"error": "something went wrong"} + resp = api_response_factory(status_code=status_code, data=data) + + assert resp.status_code == status_code + assert resp.data == data + + def test_error_key_in_200_response(self, api_response_factory): + """A 200 with an 'error' key in the body should still store data.""" + data = {"error": "unexpected error in body"} + resp = api_response_factory(status_code=200, data=data) + + assert resp.status_code == 200 + assert resp.data == data + + def test_non_json_response_handled_gracefully(self): + """When response.json() raises, the exception path should not propagate.""" + mock = _make_mock_response(json_raises=True) + # Should not raise + resp = APIResponse(response=mock, url="https://host/api/v1/x") + + # data stays at default because json() failed + assert resp.data == {} + assert resp.status_code == 200 + assert resp.next is None + + +# --------------------------------------------------------------------------- +# Tests: _check_next with "next" in data +# --------------------------------------------------------------------------- + + +class TestCheckNextFromData: + """Tests for _check_next() when the response body contains a 'next' key.""" + + def test_next_in_data(self, api_response_factory): + """When data contains 'next', self.next should be set from data.""" + data = {"results": [], "next": "/api/v1/test?page=2"} + resp = api_response_factory(data=data) + + assert resp.next == "/api/v1/test?page=2" + + def test_next_in_data_takes_precedence_over_headers(self): + """'next' in data should take precedence over pagination headers.""" + headers = { + "X-Page-Total": "100", + "X-Page-Limit": "10", + "X-Page-Page": "1", + } + data = {"next": "/api/v1/custom-next"} + mock = _make_mock_response(data=data, headers=headers) + resp = APIResponse(response=mock, url="https://host/api/v1/items") + + assert resp.next == "/api/v1/custom-next" + + def test_next_value_none_in_data(self, api_response_factory): + """When data['next'] is None, self.next should be set to None.""" + data = {"results": [], "next": None} + resp = api_response_factory(data=data) + + assert resp.next is None + + +# --------------------------------------------------------------------------- +# Tests: _check_next with pagination headers +# --------------------------------------------------------------------------- + + +class TestCheckNextFromHeaders: + """Tests for _check_next() computing the next URL from pagination headers.""" + + def _make_paginated_response(self, total, limit, page, url=None): + """Helper: build an APIResponse with pagination headers.""" + url = url or "https://api.mist.com/api/v1/sites" + headers = { + "X-Page-Total": str(total), + "X-Page-Limit": str(limit), + "X-Page-Page": str(page), + } + mock = _make_mock_response(data={"results": []}, headers=headers) + return APIResponse(response=mock, url=url) + + def test_next_page_computed_from_headers(self): + """When there are more pages, next should be computed from headers.""" + resp = self._make_paginated_response(total=50, limit=10, page=1) + + assert resp.next == "/api/v1/sites?page=2" + + def test_next_page_with_existing_query_string(self): + """When URL already has a query string, page should use '&'.""" + resp = self._make_paginated_response( + total=50, limit=10, page=1, url="https://api.mist.com/api/v1/sites?limit=10" + ) + + assert resp.next == "/api/v1/sites?limit=10&page=2" + + def test_last_page_next_is_none(self): + """On the last page (limit*page >= total), next should remain None.""" + resp = self._make_paginated_response(total=30, limit=10, page=3) + + assert resp.next is None + + def test_beyond_last_page_next_is_none(self): + """When limit*page > total, next should remain None.""" + resp = self._make_paginated_response(total=25, limit=10, page=3) + + assert resp.next is None + + def test_single_page_result(self): + """When all results fit in one page, next stays None.""" + resp = self._make_paginated_response(total=5, limit=10, page=1) + + assert resp.next is None + + def test_existing_page_param_replaced(self): + """When URL already contains page=N, it should be replaced.""" + resp = self._make_paginated_response( + total=100, + limit=10, + page=2, + url="https://api.mist.com/api/v1/sites?limit=10&page=2", + ) + + assert resp.next == "/api/v1/sites?limit=10&page=3" + + def test_existing_page_param_first_page(self): + """Replacing page=1 with page=2 when page param already in URL.""" + resp = self._make_paginated_response( + total=100, + limit=10, + page=1, + url="https://api.mist.com/api/v1/sites?page=1&limit=10", + ) + + assert resp.next == "/api/v1/sites?page=2&limit=10" + + def test_missing_total_header(self): + """When X-Page-Total is missing, next should remain None.""" + headers = { + "X-Page-Limit": "10", + "X-Page-Page": "1", + } + mock = _make_mock_response(data={"results": []}, headers=headers) + resp = APIResponse(response=mock, url="https://host/api/v1/items") + + assert resp.next is None + + def test_missing_limit_header(self): + """When X-Page-Limit is missing, next should remain None.""" + headers = { + "X-Page-Total": "50", + "X-Page-Page": "1", + } + mock = _make_mock_response(data={"results": []}, headers=headers) + resp = APIResponse(response=mock, url="https://host/api/v1/items") + + assert resp.next is None + + def test_missing_page_header(self): + """When X-Page-Page is missing, next should remain None.""" + headers = { + "X-Page-Total": "50", + "X-Page-Limit": "10", + } + mock = _make_mock_response(data={"results": []}, headers=headers) + resp = APIResponse(response=mock, url="https://host/api/v1/items") + + assert resp.next is None + + def test_non_numeric_headers_handled(self): + """Non-numeric pagination header values should not raise.""" + headers = { + "X-Page-Total": "abc", + "X-Page-Limit": "10", + "X-Page-Page": "1", + } + mock = _make_mock_response(data={"results": []}, headers=headers) + resp = APIResponse(response=mock, url="https://host/api/v1/items") + + assert resp.next is None + + def test_no_headers_at_all(self, api_response_factory): + """When there are no pagination headers and no 'next' in data, next is None.""" + resp = api_response_factory(data={"results": []}) + + assert resp.next is None + + def test_pagination_strips_host_prefix(self): + """The computed next URL should be a relative /api/... path, not absolute.""" + resp = self._make_paginated_response( + total=100, + limit=10, + page=1, + url="https://api.eu.mist.com/api/v1/orgs/abc/devices", + ) + + assert resp.next.startswith("/api/v1/") + assert "api.eu.mist.com" not in resp.next + assert resp.next == "/api/v1/orgs/abc/devices?page=2" + + +# --------------------------------------------------------------------------- +# Tests: data types preserved +# --------------------------------------------------------------------------- + + +class TestDataTypes: + """Verify that different JSON response shapes are handled correctly.""" + + def test_list_response(self): + """An API that returns a JSON list should store it in data.""" + data = [{"id": "a"}, {"id": "b"}] + mock = _make_mock_response(data=data) + resp = APIResponse(response=mock, url="https://host/api/v1/x") + + assert resp.data == data + # When data is a list, there is no 'next' key lookup issue + assert resp.next is None + + def test_empty_dict_response(self, api_response_factory): + """An empty dict body should result in data=={} and next==None.""" + resp = api_response_factory(data={}) + + assert resp.data == {} + assert resp.next is None diff --git a/tests/unit/test_api_session.py b/tests/unit/test_api_session.py index f3eed30..e28076c 100644 --- a/tests/unit/test_api_session.py +++ b/tests/unit/test_api_session.py @@ -403,3 +403,157 @@ def test_environment_variable_type_handling(self) -> None: # Assert - Should convert string to int assert session._console_log_level == 30 # int, not '30' assert session._logging_log_level == 20 # int, not '20' + + +class TestNewSession: + """Test _new_session() method""" + + def test_new_session_returns_session_with_headers( + self, authenticated_session + ) -> None: + """_new_session creates a requests.Session with correct Accept header""" + with patch("mistapi.__api_session.requests.session") as mock_session_cls: + mock_sess = Mock() + mock_sess.headers = {} + mock_sess.proxies = {} + mock_session_cls.return_value = mock_sess + + result = authenticated_session._new_session() + + assert ( + result.headers["Accept"] == "application/json, application/vnd.api+json" + ) + + def test_new_session_sets_auth_header(self, authenticated_session) -> None: + """_new_session includes Authorization header when API token is configured""" + with patch("mistapi.__api_session.requests.session") as mock_session_cls: + mock_sess = Mock() + mock_sess.headers = {} + mock_sess.proxies = {} + mock_session_cls.return_value = mock_sess + + result = authenticated_session._new_session() + + expected_token = authenticated_session._apitoken[ + authenticated_session._apitoken_index + ] + assert result.headers["Authorization"] == f"Token {expected_token}" + + def test_new_session_sets_proxies(self) -> None: + """_new_session applies proxies when configured""" + with patch.dict(os.environ, {}, clear=True): + with patch("mistapi.__api_session.requests.session") as mock_session_cls: + mock_sess = Mock() + mock_sess.headers = {} + mock_sess.proxies = {} + mock_session_cls.return_value = mock_sess + + session = APISession( + console_log_level=50, https_proxy="http://proxy:8080" + ) + session._apitoken = [] + + # Create new session - use a real dict for proxies so update works + new_mock = Mock() + new_mock.headers = {} + new_mock.proxies = {} + mock_session_cls.return_value = new_mock + + result = session._new_session() + assert result.proxies == {"https": "http://proxy:8080"} + + def test_new_session_no_auth_without_token(self, isolated_session) -> None: + """_new_session omits Authorization when no token configured""" + with patch("mistapi.__api_session.requests.session") as mock_session_cls: + mock_sess = Mock() + mock_sess.headers = {} + mock_sess.proxies = {} + mock_session_cls.return_value = mock_sess + + result = isolated_session._new_session() + + assert "Authorization" not in result.headers + + +class TestSetApiTokenValidation: + """Test set_api_token with validate=False""" + + def test_set_api_token_no_validate(self, isolated_session) -> None: + """set_api_token(validate=False) accepts tokens without calling _check_api_tokens""" + isolated_session.set_cloud("api.mist.com") + with patch.object(isolated_session, "_check_api_tokens") as mock_check: + isolated_session.set_api_token("token_abc_123", validate=False) + + mock_check.assert_not_called() + assert isolated_session._apitoken == ["token_abc_123"] + assert isolated_session._apitoken_index == 0 + + def test_set_api_token_validate_true_calls_check(self, isolated_session) -> None: + """set_api_token(validate=True) calls _check_api_tokens""" + isolated_session.set_cloud("api.mist.com") + with patch.object( + isolated_session, "_check_api_tokens", return_value=["token_abc"] + ) as mock_check: + isolated_session.set_api_token("token_abc", validate=True) + + mock_check.assert_called_once_with(["token_abc"]) + + +class TestDeleteApiToken: + """Test delete_api_token method""" + + def test_delete_api_token_calls_mist_delete(self, authenticated_session) -> None: + """delete_api_token delegates to mist_delete with correct URI""" + with patch.object(authenticated_session, "mist_delete") as mock_delete: + mock_resp = Mock() + mock_delete.return_value = mock_resp + + result = authenticated_session.delete_api_token("token-id-123") + + mock_delete.assert_called_once_with("/api/v1/self/apitokens/token-id-123") + assert result is mock_resp + + +class TestVaultAttrsCleanup: + """Test that vault attributes are cleaned up after init""" + + def test_no_vault_attrs_without_vault(self, isolated_session) -> None: + """Vault attributes are deleted when vault_path is not set""" + assert not hasattr(isolated_session, "_vault_url") + assert not hasattr(isolated_session, "_vault_token") + assert not hasattr(isolated_session, "_vault_path") + assert not hasattr(isolated_session, "_vault_mount_point") + + +class TestLoadEnvVault: + """Test _load_env vault variable loading""" + + def test_load_env_vault_vars(self) -> None: + """_load_env populates _vault_* from env when not already set""" + env_vars = { + "MIST_VAULT_URL": "https://vault.example.com", + "MIST_VAULT_PATH": "secret/data/mist", + "MIST_VAULT_MOUNT_POINT": "kv", + "MIST_VAULT_TOKEN": "vault-token-123", + } + with patch.dict(os.environ, env_vars, clear=True): + with patch("mistapi.__api_session.requests.session") as mock_cls: + mock_sess = Mock() + mock_sess.headers = {} + mock_sess.proxies = {} + mock_cls.return_value = mock_sess + with patch("mistapi.__api_session.hvac") as mock_hvac: + mock_client = Mock() + mock_client.is_authenticated.return_value = True + mock_client.secrets.kv.v2.read_secret.return_value = { + "data": {"data": {}} + } + mock_hvac.Client.return_value = mock_client + + session = APISession(console_log_level=50) + + # Vault was loaded since _vault_path was set from env + mock_hvac.Client.assert_called_once_with( + url="https://vault.example.com", + token="vault-token-123", + ) diff --git a/tests/unit/test_init.py b/tests/unit/test_init.py new file mode 100644 index 0000000..b8d4707 --- /dev/null +++ b/tests/unit/test_init.py @@ -0,0 +1,288 @@ +# tests/unit/test_init.py +""" +Unit tests for mistapi.__init__ module. + +Tests verify: +- Direct imports (APISession, get_all, get_next, __version__, __author__) +- Lazy subpackage loading (api, utils, websockets, cli) +- AttributeError for unknown attributes +""" + +import importlib +import types +from unittest.mock import patch + +import pytest + + +class TestDirectImports: + """Test that directly-imported names are available on the mistapi module.""" + + def test_apisession_is_available(self): + """APISession should be importable from mistapi.""" + from mistapi import APISession + + assert APISession is not None + + def test_apisession_is_the_real_class(self): + """APISession should be the class from mistapi.__api_session.""" + from mistapi import APISession + from mistapi.__api_session import APISession as RealAPISession + + assert APISession is RealAPISession + + def test_get_all_is_available(self): + """get_all should be importable from mistapi.""" + from mistapi import get_all + + assert callable(get_all) + + def test_get_all_is_the_real_function(self): + """get_all should be the function from mistapi.__pagination.""" + from mistapi import get_all + from mistapi.__pagination import get_all as real_get_all + + assert get_all is real_get_all + + def test_get_next_is_available(self): + """get_next should be importable from mistapi.""" + from mistapi import get_next + + assert callable(get_next) + + def test_get_next_is_the_real_function(self): + """get_next should be the function from mistapi.__pagination.""" + from mistapi import get_next + from mistapi.__pagination import get_next as real_get_next + + assert get_next is real_get_next + + def test_version_is_available(self): + """__version__ should be importable from mistapi.""" + from mistapi import __version__ + + assert isinstance(__version__, str) + assert len(__version__) > 0 + + def test_version_matches_version_module(self): + """__version__ should match the value in mistapi.__version.""" + import mistapi + from mistapi.__version import __version__ as real_version + + assert mistapi.__version__ == real_version + + def test_author_is_available(self): + """__author__ should be importable from mistapi.""" + from mistapi import __author__ + + assert isinstance(__author__, str) + assert len(__author__) > 0 + + def test_author_matches_version_module(self): + """__author__ should match the value in mistapi.__version.""" + import mistapi + from mistapi.__version import __author__ as real_author + + assert mistapi.__author__ == real_author + + +class TestLazyImportApi: + """Test lazy loading of the mistapi.api subpackage.""" + + def test_api_loads_on_access(self): + """Accessing mistapi.api should trigger lazy import and return a module.""" + import mistapi + + # Remove cached attribute if present so __getattr__ fires + mistapi.__dict__.pop("api", None) + + with patch("importlib.import_module", wraps=importlib.import_module) as spy: + result = mistapi.api + spy.assert_any_call("mistapi.api") + + assert isinstance(result, types.ModuleType) + assert result.__name__ == "mistapi.api" + + def test_api_is_cached_after_first_access(self): + """After first access, mistapi.api should be cached in globals.""" + import mistapi + + # Force a fresh lazy load + mistapi.__dict__.pop("api", None) + first = mistapi.api + second = mistapi.api + + assert first is second + assert "api" in mistapi.__dict__ + + +class TestLazyImportUtils: + """Test lazy loading of the mistapi.device_utils subpackage.""" + + def test_utils_loads_on_access(self): + """Accessing mistapi.device_utils should trigger lazy import and return a module.""" + import mistapi + + mistapi.__dict__.pop("device_utils", None) + + with patch("importlib.import_module", wraps=importlib.import_module) as spy: + result = mistapi.device_utils + spy.assert_any_call("mistapi.device_utils") + + assert isinstance(result, types.ModuleType) + assert result.__name__ == "mistapi.device_utils" + + def test_utils_is_cached_after_first_access(self): + """After first access, mistapi.device_utils should be cached in globals.""" + import mistapi + + mistapi.__dict__.pop("device_utils", None) + first = mistapi.device_utils + second = mistapi.device_utils + + assert first is second + assert "device_utils" in mistapi.__dict__ + + +class TestLazyImportWebsockets: + """Test lazy loading of the mistapi.websockets subpackage.""" + + def test_websockets_loads_on_access(self): + """Accessing mistapi.websockets should trigger lazy import and return a module.""" + import mistapi + + mistapi.__dict__.pop("websockets", None) + + with patch("importlib.import_module", wraps=importlib.import_module) as spy: + result = mistapi.websockets + spy.assert_any_call("mistapi.websockets") + + assert isinstance(result, types.ModuleType) + assert result.__name__ == "mistapi.websockets" + + def test_websockets_is_cached_after_first_access(self): + """After first access, mistapi.websockets should be cached in globals.""" + import mistapi + + mistapi.__dict__.pop("websockets", None) + first = mistapi.websockets + second = mistapi.websockets + + assert first is second + assert "websockets" in mistapi.__dict__ + + +class TestLazyImportCli: + """Test lazy loading of the mistapi.cli subpackage.""" + + def test_cli_loads_on_access(self): + """Accessing mistapi.cli should trigger lazy import and return a module.""" + import mistapi + + mistapi.__dict__.pop("cli", None) + + with patch("importlib.import_module", wraps=importlib.import_module) as spy: + result = mistapi.cli + spy.assert_any_call("mistapi.cli") + + assert isinstance(result, types.ModuleType) + assert result.__name__ == "mistapi.cli" + + def test_cli_is_cached_after_first_access(self): + """After first access, mistapi.cli should be cached in globals.""" + import mistapi + + mistapi.__dict__.pop("cli", None) + first = mistapi.cli + second = mistapi.cli + + assert first is second + assert "cli" in mistapi.__dict__ + + +class TestLazyImportMechanism: + """Test the __getattr__ mechanism in general.""" + + def test_lazy_subpackages_dict_has_expected_keys(self): + """_LAZY_SUBPACKAGES should contain the expected subpackage mappings.""" + import mistapi + + expected = {"api", "cli", "websockets", "device_utils"} + assert set(mistapi._LAZY_SUBPACKAGES.keys()) == expected + + def test_lazy_subpackages_values_are_dotted_paths(self): + """Each value in _LAZY_SUBPACKAGES should be a fully-qualified module path.""" + import mistapi + + for key, value in mistapi._LAZY_SUBPACKAGES.items(): + assert value == f"mistapi.{key}" + + def test_getattr_delegates_to_importlib(self): + """__getattr__ should call importlib.import_module for known subpackages.""" + import mistapi + + sentinel = types.ModuleType("mistapi.api") + mistapi.__dict__.pop("api", None) + + with patch("importlib.import_module", return_value=sentinel) as mock_import: + result = mistapi.__getattr__("api") + + mock_import.assert_called_once_with("mistapi.api") + assert result is sentinel + + def test_getattr_caches_result_in_globals(self): + """__getattr__ should store the imported module in the package globals.""" + import mistapi + + sentinel = types.ModuleType("mistapi.device_utils") + mistapi.__dict__.pop("device_utils", None) + + with patch("importlib.import_module", return_value=sentinel): + mistapi.__getattr__("device_utils") + + assert mistapi.__dict__["device_utils"] is sentinel + + +class TestInvalidAttribute: + """Test that accessing undefined attributes raises AttributeError.""" + + def test_unknown_attribute_raises_attribute_error(self): + """Accessing a non-existent attribute should raise AttributeError.""" + import mistapi + + with pytest.raises( + AttributeError, match=r"module 'mistapi' has no attribute 'nonexistent'" + ): + mistapi.__getattr__("nonexistent") + + def test_unknown_attribute_via_getattr_builtin(self): + """getattr on an unknown name without default should raise AttributeError.""" + import mistapi + + with pytest.raises(AttributeError): + getattr(mistapi, "totally_made_up_attribute_xyz") + + def test_unknown_attribute_with_default(self): + """getattr with a default should return the default for unknown names.""" + import mistapi + + result = getattr(mistapi, "totally_made_up_attribute_xyz", "fallback") + assert result == "fallback" + + def test_error_message_includes_attribute_name(self): + """The AttributeError message should include the missing attribute name.""" + import mistapi + + with pytest.raises(AttributeError, match="'bogus_name'"): + mistapi.__getattr__("bogus_name") + + @pytest.mark.parametrize( + "name", + ["foo", "bar", "API", "Websockets", "Api", "UTILS"], + ) + def test_various_invalid_names(self, name): + """Various invalid names should all raise AttributeError (case-sensitive).""" + import mistapi + + with pytest.raises(AttributeError): + mistapi.__getattr__(name) diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py new file mode 100644 index 0000000..0c543c9 --- /dev/null +++ b/tests/unit/test_logger.py @@ -0,0 +1,329 @@ +# tests/unit/test_logger.py +""" +Unit tests for the Console logger and LogSanitizer. + +These tests cover sensitive-field redaction, log-level gating, +the LogSanitizer filter, and the _set_log_level helper. +""" + +import logging + +import pytest + +from mistapi.__logger import Console, LogSanitizer, SENSITIVE_FIELDS, logger + + +# --------------------------------------------------------------------------- +# Sanitization +# --------------------------------------------------------------------------- +class TestConsoleSanitize: + """Tests for Console.sanitize() redaction logic.""" + + def test_redacts_password_double_quotes(self) -> None: + """A plain 'password' field in double quotes is redacted.""" + c = Console() + raw = '{"password": "s3cret!"}' + result = c.sanitize(raw) + assert "s3cret!" not in result + assert "******" in result + + def test_redacts_password_single_quotes(self) -> None: + """A 'password' field wrapped in single quotes is redacted.""" + c = Console() + raw = "{'password': 'mysecret'}" + result = c.sanitize(raw) + assert "mysecret" not in result + assert "******" in result + + @pytest.mark.parametrize( + "field", + SENSITIVE_FIELDS, + ids=SENSITIVE_FIELDS, + ) + def test_redacts_every_sensitive_field(self, field: str) -> None: + """Every entry in SENSITIVE_FIELDS must be redacted.""" + c = Console() + raw = f'{{"{field}": "topSecret123"}}' + result = c.sanitize(raw) + assert "topSecret123" not in result + assert "******" in result + + def test_redacts_case_insensitively(self) -> None: + """Field matching is case-insensitive.""" + c = Console() + raw = '{"PASSWORD": "abc"}' + result = c.sanitize(raw) + assert "abc" not in result + assert "******" in result + + def test_redacts_multiple_fields_in_one_string(self) -> None: + """Multiple sensitive fields in the same string are all redacted.""" + c = Console() + raw = '{"password": "pw1", "apitoken": "tok1", "key": "k1"}' + result = c.sanitize(raw) + assert "pw1" not in result + assert "tok1" not in result + assert "k1" not in result + + def test_no_sensitive_data_unchanged(self) -> None: + """A string with no sensitive fields is returned unchanged.""" + c = Console() + raw = '{"name": "Alice", "age": "30"}' + result = c.sanitize(raw) + assert result == raw + + def test_non_string_input_dict(self) -> None: + """A dict is JSON-serialised before sanitisation.""" + c = Console() + data = {"password": "hunter2", "user": "admin"} + result = c.sanitize(data) + assert "hunter2" not in result + assert "admin" in result + assert "******" in result + + def test_non_string_input_list(self) -> None: + """A list containing a dict with sensitive data is sanitised.""" + c = Console() + data = [{"apitoken": "secret_tok"}] + result = c.sanitize(data) + assert "secret_tok" not in result + assert "******" in result + + def test_non_string_input_int(self) -> None: + """An integer is serialised and returned as-is (no sensitive data).""" + c = Console() + result = c.sanitize(42) + assert result == "42" + + def test_empty_string(self) -> None: + """An empty string returns an empty string.""" + c = Console() + assert c.sanitize("") == "" + + def test_empty_value_still_redacted(self) -> None: + """A sensitive field whose value is empty is still redacted.""" + c = Console() + raw = '{"password": ""}' + result = c.sanitize(raw) + assert "******" in result + + +# --------------------------------------------------------------------------- +# Log-level methods (print gating) +# --------------------------------------------------------------------------- +class TestConsoleLogLevelMethods: + """Each log method should print only when the level threshold is met.""" + + # Mapping: (method_name, method_threshold) + # critical prints at level <= 50 + # error prints at level <= 40 + # warning prints at level <= 30 + # info prints at level <= 20 + # debug prints at level <= 10 + LOG_METHODS = [ + ("critical", 50), + ("error", 40), + ("warning", 30), + ("info", 20), + ("debug", 10), + ] + + @pytest.mark.parametrize( + "method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS] + ) + def test_prints_when_level_equals_threshold( + self, capsys, method_name, threshold + ) -> None: + """Method prints when console level == method threshold.""" + c = Console(level=threshold) + getattr(c, method_name)("hello") + captured = capsys.readouterr() + assert "hello" in captured.out + + @pytest.mark.parametrize( + "method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS] + ) + def test_prints_when_level_below_threshold( + self, capsys, method_name, threshold + ) -> None: + """Method prints when console level is below the method threshold.""" + c = Console(level=max(threshold - 10, 1)) + getattr(c, method_name)("below") + captured = capsys.readouterr() + assert "below" in captured.out + + @pytest.mark.parametrize( + "method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS] + ) + def test_silent_when_level_above_threshold( + self, capsys, method_name, threshold + ) -> None: + """Method is silent when console level exceeds the method threshold.""" + c = Console(level=threshold + 10) + getattr(c, method_name)("nope") + captured = capsys.readouterr() + assert captured.out == "" + + @pytest.mark.parametrize( + "method_name,threshold", LOG_METHODS, ids=[m for m, _ in LOG_METHODS] + ) + def test_silent_when_level_is_zero(self, capsys, method_name, threshold) -> None: + """No output at all when level is 0 (disabled).""" + c = Console(level=0) + getattr(c, method_name)("disabled") + captured = capsys.readouterr() + assert captured.out == "" + + def test_output_sanitised(self, capsys) -> None: + """Print output has sensitive data redacted.""" + c = Console(level=20) + c.info('{"password": "oops"}') + captured = capsys.readouterr() + assert "oops" not in captured.out + assert "******" in captured.out + + def test_output_has_bracket_prefix(self, capsys) -> None: + """All log lines are wrapped with a bracket prefix.""" + c = Console(level=10) + c.debug("test_msg") + captured = capsys.readouterr() + assert captured.out.startswith("[") + assert "test_msg" in captured.out + + +# --------------------------------------------------------------------------- +# Default level +# --------------------------------------------------------------------------- +class TestConsoleDefaults: + """Verify constructor defaults.""" + + def test_default_level_is_20(self) -> None: + c = Console() + assert c.level == 20 + + def test_custom_level(self) -> None: + c = Console(level=40) + assert c.level == 40 + + +# --------------------------------------------------------------------------- +# _set_log_level +# --------------------------------------------------------------------------- +class TestSetLogLevel: + """Tests for _set_log_level() which adjusts both console and logging levels.""" + + @pytest.fixture(autouse=True) + def _restore_logger_level(self): + """Save and restore the module-level logger level between tests.""" + original = logger.level + yield + logger.setLevel(original) + + def test_sets_console_level(self) -> None: + c = Console(level=20) + c._set_log_level(console_log_level=40, logging_log_level=10) + assert c.level == 40 + + def test_sets_logging_level(self) -> None: + c = Console(level=20) + c._set_log_level(console_log_level=20, logging_log_level=30) + assert logger.level == 30 + + def test_default_arguments(self) -> None: + c = Console(level=50) + c._set_log_level() + assert c.level == 20 + assert logger.level == 10 + + def test_set_level_to_zero_disables_console(self, capsys) -> None: + c = Console(level=20) + c._set_log_level(console_log_level=0) + c.critical("should_not_appear") + captured = capsys.readouterr() + assert captured.out == "" + + +# --------------------------------------------------------------------------- +# LogSanitizer filter +# --------------------------------------------------------------------------- +class TestLogSanitizer: + """Tests for the logging.Filter subclass that redacts sensitive data.""" + + def _make_record(self, msg: str, *args) -> logging.LogRecord: + """Create a minimal LogRecord for testing.""" + record = logging.LogRecord( + name="mistapi", + level=logging.INFO, + pathname="", + lineno=0, + msg=msg, + args=args if args else None, + exc_info=None, + ) + return record + + def test_filter_returns_true(self) -> None: + """The filter never drops records; it always returns True.""" + f = LogSanitizer() + record = self._make_record("safe message") + assert f.filter(record) is True + + def test_filter_sanitises_message(self) -> None: + """Sensitive data inside the message is replaced.""" + f = LogSanitizer() + record = self._make_record('{"password": "leak"}') + f.filter(record) + assert "leak" not in record.msg + assert "******" in record.msg + + def test_filter_clears_args(self) -> None: + """record.args is set to None after filtering.""" + f = LogSanitizer() + record = self._make_record("value is %s", "something") + assert record.args is not None + f.filter(record) + assert record.args is None + + def test_filter_handles_format_args(self) -> None: + """getMessage() expands %-formatting; filter sees the expanded string.""" + f = LogSanitizer() + record = self._make_record('{"apitoken": "%s"}', "my_secret_token") + # Before filtering, getMessage() should expand the arg + expanded = record.getMessage() + assert "my_secret_token" in expanded + + f.filter(record) + assert "my_secret_token" not in record.msg + assert "******" in record.msg + + def test_filter_safe_message_unchanged(self) -> None: + """A record without sensitive data passes through with its message intact.""" + f = LogSanitizer() + record = self._make_record("just a normal log line") + f.filter(record) + assert record.msg == "just a normal log line" + assert record.args is None + + +# --------------------------------------------------------------------------- +# Module-level singletons +# --------------------------------------------------------------------------- +class TestModuleLevelObjects: + """The module exposes pre-built console and logger objects.""" + + def test_module_console_exists(self) -> None: + from mistapi.__logger import console as mod_console + + assert isinstance(mod_console, Console) + + def test_module_logger_name(self) -> None: + assert logger.name == "mistapi" + + def test_module_logger_has_sanitizer_filter(self) -> None: + filters = logger.filters + assert any(isinstance(f, LogSanitizer) for f in filters) + + def test_module_logger_level_is_integer(self) -> None: + """Logger level is always a valid integer (may be mutated by other tests).""" + assert isinstance(logger.level, int) + assert logger.level >= 0 diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index e69de29..504b51f 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -0,0 +1,508 @@ +# tests/unit/test_models.py +""" +Unit tests for privilege model classes. + +These tests cover _Privilege and Privileges from mistapi.__models.privilege, +verifying construction, string representation, iteration, and field access. +""" + +import pytest + +from mistapi.__models.privilege import Privileges, _Privilege + + +# --------------------------------------------------------------------------- +# Constants re-declared locally so tests are self-contained within the file, +# but the sample_privileges fixture from conftest is reused for consistency. +# --------------------------------------------------------------------------- +TEST_ORG_ID = "203d3d02-dbc0-4c1b-9f41-76896a3330f4" +TEST_SITE_ID = "f5fcbee5-fbca-45b3-8bf1-1619ede87879" + + +# =================================================================== +# _Privilege tests +# =================================================================== +class TestPrivilegeCreation: + """Test _Privilege initialisation and field population""" + + def test_all_fields_populated_from_dict(self) -> None: + """All dict keys should become attributes on the _Privilege object""" + # Arrange + data = { + "scope": "org", + "role": "admin", + "org_id": TEST_ORG_ID, + "org_name": "Acme Corp", + "msp_id": "msp-id-1", + "msp_name": "MSP One", + "orggroup_ids": ["grp-1", "grp-2"], + "name": "Test Privilege", + "site_id": TEST_SITE_ID, + "sitegroup_ids": ["sg-1"], + "views": ["monitoring", "location"], + } + + # Act + priv = _Privilege(data) + + # Assert + assert priv.scope == "org" + assert priv.role == "admin" + assert priv.org_id == TEST_ORG_ID + assert priv.org_name == "Acme Corp" + assert priv.msp_id == "msp-id-1" + assert priv.msp_name == "MSP One" + assert priv.orggroup_ids == ["grp-1", "grp-2"] + assert priv.name == "Test Privilege" + assert priv.site_id == TEST_SITE_ID + assert priv.sitegroup_ids == ["sg-1"] + assert priv.views == ["monitoring", "location"] + + def test_partial_dict_sets_defaults_for_missing_fields(self) -> None: + """Fields not present in the dict should retain their defaults""" + # Arrange + data = {"scope": "org", "role": "read"} + + # Act + priv = _Privilege(data) + + # Assert — supplied values + assert priv.scope == "org" + assert priv.role == "read" + + # Assert — default values + assert priv.org_id == "" + assert priv.org_name == "" + assert priv.msp_id == "" + assert priv.msp_name == "" + assert priv.orggroup_ids == [] + assert priv.name == "" + assert priv.site_id == "" + assert priv.sitegroup_ids == [] + assert priv.views == [] + + def test_empty_dict_uses_all_defaults(self) -> None: + """An empty dict should produce a _Privilege with all default values""" + # Act + priv = _Privilege({}) + + # Assert + assert priv.scope == "" + assert priv.role == "" + assert priv.org_id == "" + assert priv.name == "" + + def test_extra_keys_become_attributes(self) -> None: + """Keys not in the predefined set should still be set as attributes""" + # Arrange + data = {"scope": "org", "custom_field": "custom_value"} + + # Act + priv = _Privilege(data) + + # Assert + assert priv.scope == "org" + assert priv.custom_field == "custom_value" + + def test_creation_from_sample_fixture(self, sample_privileges) -> None: + """Verify creation using the sample_privileges fixture data""" + # Act + priv_org = _Privilege(sample_privileges[0]) + priv_site = _Privilege(sample_privileges[1]) + + # Assert + assert priv_org.scope == "org" + assert priv_org.role == "admin" + assert priv_org.org_id == TEST_ORG_ID + assert priv_org.name == "Test Organisation" + + assert priv_site.scope == "site" + assert priv_site.role == "write" + assert priv_site.site_id == TEST_SITE_ID + assert priv_site.org_id == TEST_ORG_ID + assert priv_site.name == "Test Site" + + +class TestPrivilegeStr: + """Test _Privilege.__str__() output""" + + def test_str_includes_non_empty_fields(self) -> None: + """Non-empty string fields should appear in the output""" + # Arrange + data = { + "scope": "org", + "role": "admin", + "org_id": TEST_ORG_ID, + "name": "My Org", + } + + # Act + priv = _Privilege(data) + result = str(priv) + + # Assert + assert "scope: org" in result + assert "role: admin" in result + assert f"org_id: {TEST_ORG_ID}" in result + assert "name: My Org" in result + + def test_str_excludes_empty_string_fields(self) -> None: + """Empty string fields should not appear in the output""" + # Arrange + data = {"scope": "org", "role": "admin"} + + # Act + priv = _Privilege(data) + result = str(priv) + + # Assert — these fields are empty strings and should be absent + assert "org_id:" not in result + assert "org_name:" not in result + assert "msp_id:" not in result + assert "msp_name:" not in result + assert "site_id:" not in result + + def test_str_includes_non_empty_list_fields(self) -> None: + """Non-empty list fields (orggroup_ids, sitegroup_ids) should appear""" + # Arrange + data = { + "scope": "org", + "role": "admin", + "orggroup_ids": ["grp-1"], + "sitegroup_ids": ["sg-1"], + } + + # Act + priv = _Privilege(data) + result = str(priv) + + # Assert + assert "orggroup_ids:" in result + assert "sitegroup_ids:" in result + + def test_str_uses_crlf_separator(self) -> None: + """Each field line should end with ' \\r\\n'""" + # Arrange + data = {"scope": "org", "role": "admin"} + + # Act + priv = _Privilege(data) + result = str(priv) + + # Assert + assert "scope: org \r\n" in result + assert "role: admin \r\n" in result + + def test_str_empty_privilege(self) -> None: + """A privilege with all defaults should only contain list fields""" + # Act + priv = _Privilege({}) + result = str(priv) + + # Assert - empty strings are excluded but empty lists [] are not "" + # so orggroup_ids and sitegroup_ids will still appear + assert "scope:" not in result + assert "role:" not in result + assert "org_id:" not in result + + +class TestPrivilegeGet: + """Test _Privilege.get() method""" + + def test_get_returns_value_when_present(self) -> None: + """get() should return the attribute value when it exists and is truthy""" + # Arrange + data = {"scope": "org", "role": "admin", "org_id": TEST_ORG_ID} + priv = _Privilege(data) + + # Act & Assert + assert priv.get("scope") == "org" + assert priv.get("role") == "admin" + assert priv.get("org_id") == TEST_ORG_ID + + def test_get_returns_default_when_key_missing(self) -> None: + """get() should return the default when the attribute does not exist""" + # Arrange + priv = _Privilege({"scope": "org"}) + + # Act & Assert + assert priv.get("nonexistent") is None + assert priv.get("nonexistent", "fallback") == "fallback" + + def test_get_returns_default_when_value_is_empty_string(self) -> None: + """get() should return the default when the attribute is an empty string (falsy)""" + # Arrange + priv = _Privilege({}) # org_id defaults to "" + + # Act & Assert + assert priv.get("org_id") is None + assert priv.get("org_id", "default_org") == "default_org" + + def test_get_returns_default_when_value_is_empty_list(self) -> None: + """get() should return the default when the attribute is an empty list (falsy)""" + # Arrange + priv = _Privilege({}) # views defaults to [] + + # Act & Assert + assert priv.get("views") is None + assert priv.get("views", ["default_view"]) == ["default_view"] + + def test_get_returns_list_when_non_empty(self) -> None: + """get() should return the list value when it is non-empty""" + # Arrange + data = {"views": ["monitoring", "location"]} + priv = _Privilege(data) + + # Act & Assert + assert priv.get("views") == ["monitoring", "location"] + + def test_get_default_parameter_defaults_to_none(self) -> None: + """get() default parameter should be None when not specified""" + # Arrange + priv = _Privilege({}) + + # Act & Assert + assert priv.get("nonexistent") is None + + +# =================================================================== +# Privileges tests +# =================================================================== +class TestPrivilegesCreation: + """Test Privileges initialisation""" + + def test_creates_privilege_objects_from_list_of_dicts( + self, sample_privileges + ) -> None: + """Privileges should wrap each dict into a _Privilege object""" + # Act + privs = Privileges(sample_privileges) + + # Assert + assert len(privs.privileges) == 2 + assert all(isinstance(p, _Privilege) for p in privs.privileges) + + def test_first_entry_matches_source_data(self, sample_privileges) -> None: + """First _Privilege should carry the values from the first dict""" + # Act + privs = Privileges(sample_privileges) + + # Assert + first = privs.privileges[0] + assert first.scope == "org" + assert first.role == "admin" + assert first.org_id == TEST_ORG_ID + assert first.name == "Test Organisation" + + def test_second_entry_matches_source_data(self, sample_privileges) -> None: + """Second _Privilege should carry the values from the second dict""" + # Act + privs = Privileges(sample_privileges) + + # Assert + second = privs.privileges[1] + assert second.scope == "site" + assert second.role == "write" + assert second.site_id == TEST_SITE_ID + assert second.org_id == TEST_ORG_ID + assert second.name == "Test Site" + + def test_empty_list_produces_empty_privileges(self) -> None: + """An empty list should result in an empty privileges list""" + # Act + privs = Privileges([]) + + # Assert + assert privs.privileges == [] + + def test_single_privilege(self) -> None: + """A single-element list should produce exactly one _Privilege""" + # Arrange + data = [{"scope": "msp", "role": "admin", "msp_id": "msp-1"}] + + # Act + privs = Privileges(data) + + # Assert + assert len(privs.privileges) == 1 + assert privs.privileges[0].scope == "msp" + assert privs.privileges[0].msp_id == "msp-1" + + +class TestPrivilegesIter: + """Test Privileges.__iter__()""" + + def test_iter_yields_all_privileges(self, sample_privileges) -> None: + """Iterating should yield every _Privilege in order""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + result = list(privs) + + # Assert + assert len(result) == 2 + assert result[0].scope == "org" + assert result[1].scope == "site" + + def test_iter_returns_privilege_instances(self, sample_privileges) -> None: + """Each iterated element should be a _Privilege instance""" + # Arrange + privs = Privileges(sample_privileges) + + # Act & Assert + for priv in privs: + assert isinstance(priv, _Privilege) + + def test_iter_empty_privileges_returns_empty_iterator(self) -> None: + """Iterating over empty Privileges should produce no elements""" + # Arrange + privs = Privileges([]) + + # Act + result = list(privs) + + # Assert + assert result == [] + + def test_iter_can_be_used_in_for_loop(self, sample_privileges) -> None: + """Privileges should work naturally in a for loop""" + # Arrange + privs = Privileges(sample_privileges) + scopes = [] + + # Act + for priv in privs: + scopes.append(priv.scope) + + # Assert + assert scopes == ["org", "site"] + + def test_iter_supports_next(self, sample_privileges) -> None: + """The iterator returned by __iter__ should support next()""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + it = iter(privs) + first = next(it) + second = next(it) + + # Assert + assert first.scope == "org" + assert second.scope == "site" + + with pytest.raises(StopIteration): + next(it) + + def test_iter_supports_generator_expression(self, sample_privileges) -> None: + """Privileges should work with generator expressions (e.g. next(...))""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + found = next((p for p in privs if p.org_id == TEST_ORG_ID), None) + + # Assert + assert found is not None + assert found.org_id == TEST_ORG_ID + + +class TestPrivilegesStr: + """Test Privileges.__str__() output""" + + def test_str_produces_tabulate_table(self, sample_privileges) -> None: + """String output should be a tabulate-formatted table""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + result = str(privs) + + # Assert — column headers should be present + assert "scope" in result + assert "role" in result + assert "name" in result + assert "site_id" in result + assert "org_name" in result + assert "org_id" in result + assert "msp_name" in result + assert "msp_id" in result + assert "views" in result + + def test_str_contains_privilege_data(self, sample_privileges) -> None: + """The table should contain actual privilege field values""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + result = str(privs) + + # Assert + assert "org" in result + assert "admin" in result + assert "Test Organisation" in result + assert TEST_ORG_ID in result + assert "site" in result + assert "write" in result + assert "Test Site" in result + assert TEST_SITE_ID in result + + def test_str_empty_privileges(self) -> None: + """An empty Privileges should produce just headers (or an empty table)""" + # Arrange + privs = Privileges([]) + + # Act + result = str(privs) + + # Assert — tabulate with an empty table and headers produces header row(s) + # At minimum it should not raise and should be a string + assert isinstance(result, str) + + def test_str_is_consistent(self, sample_privileges) -> None: + """Calling str() multiple times should produce the same result""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + result1 = str(privs) + result2 = str(privs) + + # Assert + assert result1 == result2 + + +class TestPrivilegesDisplay: + """Test Privileges.display() method""" + + def test_display_returns_same_as_str(self, sample_privileges) -> None: + """display() should return exactly the same string as __str__()""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + display_result = privs.display() + str_result = str(privs) + + # Assert + assert display_result == str_result + + def test_display_returns_string_type(self, sample_privileges) -> None: + """display() should return a str""" + # Arrange + privs = Privileges(sample_privileges) + + # Act + result = privs.display() + + # Assert + assert isinstance(result, str) + + def test_display_empty_privileges(self) -> None: + """display() on empty Privileges should match str() on empty Privileges""" + # Arrange + privs = Privileges([]) + + # Act & Assert + assert privs.display() == str(privs) diff --git a/tests/unit/test_pagination.py b/tests/unit/test_pagination.py index e69de29..1cdfee5 100644 --- a/tests/unit/test_pagination.py +++ b/tests/unit/test_pagination.py @@ -0,0 +1,296 @@ +""" +Unit tests for mistapi.__pagination module. + +Tests the get_next() and get_all() pagination helper functions using +mocked APISession and APIResponse objects. +""" + +from unittest.mock import Mock + +from mistapi.__pagination import get_all, get_next + + +def _make_response(data, next_url=None): + """Create a mock APIResponse with the given data and next link.""" + response = Mock() + response.data = data + response.next = next_url + return response + + +class TestGetNext: + """Tests for get_next().""" + + def test_calls_mist_get_when_next_exists(self): + """get_next() should call mist_session.mist_get with response.next.""" + session = Mock() + next_url = "/api/v1/sites?page=2" + response = _make_response(data=[], next_url=next_url) + expected = _make_response(data=[{"id": "second"}]) + session.mist_get.return_value = expected + + result = get_next(session, response) + + session.mist_get.assert_called_once_with(next_url) + assert result is expected + + def test_returns_none_when_next_is_none(self): + """get_next() should return None when response.next is None.""" + session = Mock() + response = _make_response(data=[], next_url=None) + + result = get_next(session, response) + + assert result is None + session.mist_get.assert_not_called() + + def test_returns_none_when_next_is_empty_string(self): + """get_next() should return None when response.next is an empty string.""" + session = Mock() + response = _make_response(data=[], next_url="") + + result = get_next(session, response) + + assert result is None + session.mist_get.assert_not_called() + + +class TestGetAllList: + """Tests for get_all() when response.data is a list.""" + + def test_single_page_returns_data(self): + """get_all() should return the list data as-is when there is no next page.""" + session = Mock() + items = [{"id": "a"}, {"id": "b"}] + response = _make_response(data=items, next_url=None) + + result = get_all(session, response) + + assert result == items + session.mist_get.assert_not_called() + + def test_single_page_returns_copy(self): + """get_all() should return a new list, not the original reference.""" + session = Mock() + items = [{"id": "a"}] + response = _make_response(data=items, next_url=None) + + result = get_all(session, response) + + assert result == items + assert result is not items + + def test_multi_page_concatenates(self): + """get_all() should follow next links and concatenate all pages.""" + session = Mock() + + page1 = _make_response( + data=[{"id": "1"}, {"id": "2"}], + next_url="/api/v1/items?page=2", + ) + page2 = _make_response( + data=[{"id": "3"}, {"id": "4"}], + next_url="/api/v1/items?page=3", + ) + page3 = _make_response( + data=[{"id": "5"}], + next_url=None, + ) + + session.mist_get.side_effect = [page2, page3] + + result = get_all(session, page1) + + assert result == [ + {"id": "1"}, + {"id": "2"}, + {"id": "3"}, + {"id": "4"}, + {"id": "5"}, + ] + assert session.mist_get.call_count == 2 + + def test_empty_list_returns_empty(self): + """get_all() should return an empty list for an empty first page.""" + session = Mock() + response = _make_response(data=[], next_url=None) + + result = get_all(session, response) + + assert result == [] + + def test_empty_list_with_next_follows_links(self): + """get_all() should still follow next even when the first page is empty.""" + session = Mock() + + page1 = _make_response(data=[], next_url="/api/v1/items?page=2") + page2 = _make_response(data=[{"id": "1"}], next_url=None) + session.mist_get.return_value = page2 + + result = get_all(session, page1) + + assert result == [{"id": "1"}] + + +class TestGetAllDict: + """Tests for get_all() when response.data is a dict with 'results' key.""" + + def test_single_page_extracts_results(self): + """get_all() should extract and return the 'results' list from a dict response.""" + session = Mock() + items = [{"id": "a"}, {"id": "b"}] + response = _make_response( + data={"results": items, "total": 2, "limit": 100}, + next_url=None, + ) + + result = get_all(session, response) + + assert result == items + session.mist_get.assert_not_called() + + def test_single_page_returns_copy(self): + """get_all() should return a copy of results, not the original.""" + session = Mock() + items = [{"id": "a"}] + response = _make_response( + data={"results": items}, + next_url=None, + ) + + result = get_all(session, response) + + assert result == items + assert result is not items + + def test_multi_page_concatenates_results(self): + """get_all() should follow next links and concatenate results from dict responses.""" + session = Mock() + + page1 = _make_response( + data={"results": [{"id": "1"}], "total": 3, "limit": 1}, + next_url="/api/v1/items?page=2", + ) + page2 = _make_response( + data={"results": [{"id": "2"}], "total": 3, "limit": 1}, + next_url="/api/v1/items?page=3", + ) + page3 = _make_response( + data={"results": [{"id": "3"}], "total": 3, "limit": 1}, + next_url=None, + ) + + session.mist_get.side_effect = [page2, page3] + + result = get_all(session, page1) + + assert result == [{"id": "1"}, {"id": "2"}, {"id": "3"}] + assert session.mist_get.call_count == 2 + + def test_empty_results_returns_empty(self): + """get_all() should return an empty list when results is empty.""" + session = Mock() + response = _make_response( + data={"results": [], "total": 0}, + next_url=None, + ) + + result = get_all(session, response) + + assert result == [] + + def test_empty_results_with_next_follows_links(self): + """get_all() should follow next even when the first page results are empty.""" + session = Mock() + + page1 = _make_response( + data={"results": []}, + next_url="/api/v1/items?page=2", + ) + page2 = _make_response( + data={"results": [{"id": "1"}]}, + next_url=None, + ) + session.mist_get.return_value = page2 + + result = get_all(session, page1) + + assert result == [{"id": "1"}] + + +class TestGetAllEdgeCases: + """Tests for get_all() edge cases and unsupported data types.""" + + def test_dict_without_results_key_returns_empty(self): + """get_all() should return empty list when data is a dict without 'results'.""" + session = Mock() + response = _make_response( + data={"items": [1, 2, 3]}, + next_url=None, + ) + + result = get_all(session, response) + + assert result == [] + + def test_non_list_non_dict_returns_empty(self): + """get_all() should return empty list for unsupported data types.""" + session = Mock() + response = _make_response(data="some string", next_url=None) + + result = get_all(session, response) + + assert result == [] + + def test_none_data_returns_empty(self): + """get_all() should return empty list when data is None.""" + session = Mock() + response = _make_response(data=None, next_url=None) + + result = get_all(session, response) + + assert result == [] + + def test_get_next_returns_none_on_last_page(self): + """get_all() should stop when get_next returns None (no more pages).""" + session = Mock() + + # Simulate: page1 has next, page2 does not. + page1 = _make_response( + data=[{"id": "1"}], + next_url="/api/v1/items?page=2", + ) + page2 = _make_response( + data=[{"id": "2"}], + next_url=None, + ) + session.mist_get.return_value = page2 + + result = get_all(session, page1) + + assert result == [{"id": "1"}, {"id": "2"}] + session.mist_get.assert_called_once_with("/api/v1/items?page=2") + + def test_does_not_mutate_original_response(self): + """get_all() should not modify the original response object's data.""" + session = Mock() + original_items = [{"id": "1"}, {"id": "2"}] + response = _make_response(data=original_items, next_url=None) + + get_all(session, response) + + assert response.data == [{"id": "1"}, {"id": "2"}] + + def test_dict_does_not_mutate_original_results(self): + """get_all() should not mutate the original results list in a dict response.""" + session = Mock() + original_results = [{"id": "1"}] + response = _make_response( + data={"results": original_results}, + next_url=None, + ) + + result = get_all(session, response) + + assert original_results == [{"id": "1"}] + assert result is not original_results diff --git a/tests/unit/test_websocket_client.py b/tests/unit/test_websocket_client.py new file mode 100644 index 0000000..5069567 --- /dev/null +++ b/tests/unit/test_websocket_client.py @@ -0,0 +1,780 @@ +# tests/unit/test_websocket_client.py +""" +Unit tests for _MistWebsocket base class and public WebSocket channel classes. + +These tests cover URL building, authentication helpers, SSL options, +callback registration, internal handlers, connect/disconnect lifecycle, +the receive() generator, context-manager support, and the public API +surface of all channel classes (sites, orgs, location, session). +""" + +import json +import ssl +from unittest.mock import Mock, call, patch + +import pytest + +from mistapi.websockets.__ws_client import _MistWebsocket +from mistapi.websockets.location import ( + BleAssetsEvents, + ConnectedClientsEvents, + DiscoveredBleAssetsEvents, + SdkClientsEvents, + UnconnectedClientsEvents, +) +from mistapi.websockets.orgs import ( + InsightsEvents, + MxEdgesEvents, +) +from mistapi.websockets.orgs import MxEdgesStatsEvents as OrgMxEdgesStatsEvents +from mistapi.websockets.session import SessionWithUrl +from mistapi.websockets.sites import ( + ClientsStatsEvents, + DeviceCmdEvents, + DeviceEvents, + DeviceStatsEvents, + PcapEvents, +) +from mistapi.websockets.sites import ( + MxEdgesStatsEvents as SiteMxEdgesStatsEvents, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_session(): + """ + Lightweight mock of APISession with the internal attributes that + _MistWebsocket accesses directly. + """ + session = Mock() + session._cloud_uri = "api.mist.com" + session._apitoken = ["test_token"] + session._apitoken_index = 0 + + # requests.Session stand-in + requests_session = Mock() + requests_session.cookies = [] + requests_session.verify = True + requests_session.cert = None + session._session = requests_session + + return session + + +@pytest.fixture +def ws_client(mock_session): + """A _MistWebsocket wired to mock_session with two channels.""" + return _MistWebsocket( + mist_session=mock_session, + channels=["/test/channel1", "/test/channel2"], + ) + + +@pytest.fixture +def single_channel_client(mock_session): + """A _MistWebsocket with a single channel and custom ping settings.""" + return _MistWebsocket( + mist_session=mock_session, + channels=["/events"], + ping_interval=15, + ping_timeout=5, + ) + + +# --------------------------------------------------------------------------- +# URL building +# --------------------------------------------------------------------------- + + +class TestBuildWsUrl: + """Tests for _build_ws_url().""" + + def test_replaces_api_with_api_ws(self, ws_client) -> None: + url = ws_client._build_ws_url() + assert url == "wss://api-ws.mist.com/api-ws/v1/stream" + + def test_eu_cloud_url(self, mock_session) -> None: + mock_session._cloud_uri = "api.eu.mist.com" + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._build_ws_url() == "wss://api-ws.eu.mist.com/api-ws/v1/stream" + + def test_gc1_cloud_url(self, mock_session) -> None: + mock_session._cloud_uri = "api.gc1.mist.com" + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._build_ws_url() == "wss://api-ws.gc1.mist.com/api-ws/v1/stream" + + +# --------------------------------------------------------------------------- +# Headers (token auth) +# --------------------------------------------------------------------------- + + +class TestGetHeaders: + """Tests for _get_headers().""" + + def test_returns_authorization_header_with_token(self, ws_client) -> None: + headers = ws_client._get_headers() + assert headers == {"Authorization": "Token test_token"} + + def test_uses_correct_token_index(self, mock_session) -> None: + mock_session._apitoken = ["token_a", "token_b"] + mock_session._apitoken_index = 1 + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_headers() == {"Authorization": "Token token_b"} + + def test_returns_empty_dict_without_token(self, mock_session) -> None: + mock_session._apitoken = [] + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_headers() == {} + + def test_returns_empty_dict_when_token_is_none(self, mock_session) -> None: + mock_session._apitoken = None + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_headers() == {} + + +# --------------------------------------------------------------------------- +# Cookies (session auth) +# --------------------------------------------------------------------------- + + +class TestGetCookie: + """Tests for _get_cookie().""" + + def test_formats_cookies_as_semicolon_pairs(self, mock_session) -> None: + cookie1 = Mock(name="csrftoken", value="abc123") + cookie1.name = "csrftoken" + cookie1.value = "abc123" + cookie2 = Mock(name="sessionid", value="xyz789") + cookie2.name = "sessionid" + cookie2.value = "xyz789" + mock_session._session.cookies = [cookie1, cookie2] + + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_cookie() == "csrftoken=abc123; sessionid=xyz789" + + def test_returns_none_when_no_cookies(self, mock_session) -> None: + mock_session._session.cookies = [] + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_cookie() is None + + def test_filters_cookies_with_cr_in_name(self, mock_session) -> None: + good = Mock() + good.name = "ok" + good.value = "val" + bad = Mock() + bad.name = "bad\rname" + bad.value = "val" + mock_session._session.cookies = [good, bad] + + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_cookie() == "ok=val" + + def test_filters_cookies_with_lf_in_name(self, mock_session) -> None: + bad = Mock() + bad.name = "bad\nname" + bad.value = "val" + mock_session._session.cookies = [bad] + + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_cookie() is None + + def test_filters_cookies_with_cr_in_value(self, mock_session) -> None: + bad = Mock() + bad.name = "name" + bad.value = "bad\rvalue" + mock_session._session.cookies = [bad] + + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_cookie() is None + + def test_filters_cookies_with_lf_in_value(self, mock_session) -> None: + good = Mock() + good.name = "safe" + good.value = "clean" + bad = Mock() + bad.name = "name" + bad.value = "bad\nvalue" + mock_session._session.cookies = [good, bad] + + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_cookie() == "safe=clean" + + def test_returns_none_when_all_cookies_filtered(self, mock_session) -> None: + bad1 = Mock() + bad1.name = "a\r" + bad1.value = "v" + bad2 = Mock() + bad2.name = "b" + bad2.value = "v\n" + mock_session._session.cookies = [bad1, bad2] + + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._get_cookie() is None + + +# --------------------------------------------------------------------------- +# SSL options +# --------------------------------------------------------------------------- + + +class TestBuildSslopt: + """Tests for _build_sslopt().""" + + def test_defaults_returns_empty_dict(self, ws_client) -> None: + # verify=True, cert=None => empty sslopt + assert ws_client._build_sslopt() == {} + + def test_verify_false(self, mock_session) -> None: + mock_session._session.verify = False + mock_session._session.cert = None + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._build_sslopt() == {"cert_reqs": ssl.CERT_NONE} + + def test_verify_custom_ca_path(self, mock_session) -> None: + mock_session._session.verify = "/etc/ssl/custom-ca.pem" + mock_session._session.cert = None + client = _MistWebsocket(mock_session, channels=["/ch"]) + assert client._build_sslopt() == {"ca_certs": "/etc/ssl/custom-ca.pem"} + + def test_cert_as_string(self, mock_session) -> None: + mock_session._session.cert = "/path/to/client.pem" + client = _MistWebsocket(mock_session, channels=["/ch"]) + sslopt = client._build_sslopt() + assert sslopt["certfile"] == "/path/to/client.pem" + assert "keyfile" not in sslopt + + def test_cert_as_tuple(self, mock_session) -> None: + mock_session._session.cert = ("/path/cert.pem", "/path/key.pem") + client = _MistWebsocket(mock_session, channels=["/ch"]) + sslopt = client._build_sslopt() + assert sslopt["certfile"] == "/path/cert.pem" + assert sslopt["keyfile"] == "/path/key.pem" + + def test_cert_tuple_single_element(self, mock_session) -> None: + mock_session._session.cert = ("/path/cert.pem",) + client = _MistWebsocket(mock_session, channels=["/ch"]) + sslopt = client._build_sslopt() + assert sslopt["certfile"] == "/path/cert.pem" + assert "keyfile" not in sslopt + + def test_verify_false_with_cert_tuple(self, mock_session) -> None: + mock_session._session.verify = False + mock_session._session.cert = ("/path/cert.pem", "/path/key.pem") + client = _MistWebsocket(mock_session, channels=["/ch"]) + sslopt = client._build_sslopt() + assert sslopt == { + "cert_reqs": ssl.CERT_NONE, + "certfile": "/path/cert.pem", + "keyfile": "/path/key.pem", + } + + +# --------------------------------------------------------------------------- +# Callback registration +# --------------------------------------------------------------------------- + + +class TestCallbackRegistration: + """Tests for on_message / on_error / on_open / on_close setters.""" + + def test_on_message_stores_callback(self, ws_client) -> None: + cb = Mock() + ws_client.on_message(cb) + assert ws_client._on_message_cb is cb + + def test_on_error_stores_callback(self, ws_client) -> None: + cb = Mock() + ws_client.on_error(cb) + assert ws_client._on_error_cb is cb + + def test_on_open_stores_callback(self, ws_client) -> None: + cb = Mock() + ws_client.on_open(cb) + assert ws_client._on_open_cb is cb + + def test_on_close_stores_callback(self, ws_client) -> None: + cb = Mock() + ws_client.on_close(cb) + assert ws_client._on_close_cb is cb + + def test_callbacks_initially_none(self, ws_client) -> None: + assert ws_client._on_message_cb is None + assert ws_client._on_error_cb is None + assert ws_client._on_open_cb is None + assert ws_client._on_close_cb is None + + +# --------------------------------------------------------------------------- +# Internal handlers +# --------------------------------------------------------------------------- + + +class TestHandleOpen: + """Tests for _handle_open().""" + + def test_subscribes_to_each_channel(self, ws_client) -> None: + mock_ws = Mock() + ws_client._handle_open(mock_ws) + expected_calls = [ + call(json.dumps({"subscribe": "/test/channel1"})), + call(json.dumps({"subscribe": "/test/channel2"})), + ] + mock_ws.send.assert_has_calls(expected_calls) + assert mock_ws.send.call_count == 2 + + def test_sets_connected_event(self, ws_client) -> None: + mock_ws = Mock() + assert not ws_client._connected.is_set() + ws_client._handle_open(mock_ws) + assert ws_client._connected.is_set() + + def test_calls_on_open_callback(self, ws_client) -> None: + cb = Mock() + ws_client.on_open(cb) + ws_client._handle_open(Mock()) + cb.assert_called_once_with() + + def test_no_error_without_on_open_callback(self, ws_client) -> None: + ws_client._handle_open(Mock()) # Should not raise + + +class TestHandleMessage: + """Tests for _handle_message().""" + + def test_parses_valid_json_and_enqueues(self, ws_client) -> None: + payload = {"event": "device_update", "id": "abc"} + ws_client._handle_message(Mock(), json.dumps(payload)) + assert ws_client._queue.get_nowait() == payload + + def test_wraps_invalid_json_in_raw_key(self, ws_client) -> None: + ws_client._handle_message(Mock(), "not valid json {{{") + item = ws_client._queue.get_nowait() + assert item == {"raw": "not valid json {{{"} + + def test_calls_on_message_callback_with_parsed_data(self, ws_client) -> None: + cb = Mock() + ws_client.on_message(cb) + payload = {"type": "event"} + ws_client._handle_message(Mock(), json.dumps(payload)) + cb.assert_called_once_with(payload) + + def test_calls_on_message_callback_with_raw_fallback(self, ws_client) -> None: + cb = Mock() + ws_client.on_message(cb) + ws_client._handle_message(Mock(), "plain text") + cb.assert_called_once_with({"raw": "plain text"}) + + def test_no_error_without_on_message_callback(self, ws_client) -> None: + ws_client._handle_message(Mock(), '{"ok": true}') # Should not raise + + +class TestHandleError: + """Tests for _handle_error().""" + + def test_calls_on_error_callback(self, ws_client) -> None: + cb = Mock() + ws_client.on_error(cb) + exc = ConnectionError("lost connection") + ws_client._handle_error(Mock(), exc) + cb.assert_called_once_with(exc) + + def test_no_error_without_callback(self, ws_client) -> None: + ws_client._handle_error(Mock(), RuntimeError("boom")) # Should not raise + + +class TestHandleClose: + """Tests for _handle_close().""" + + def test_clears_connected_event(self, ws_client) -> None: + ws_client._connected.set() + ws_client._handle_close(Mock(), 1000, "normal closure") + assert not ws_client._connected.is_set() + + def test_puts_none_sentinel_on_queue(self, ws_client) -> None: + ws_client._handle_close(Mock(), 1000, "normal closure") + assert ws_client._queue.get_nowait() is None + + def test_calls_on_close_callback(self, ws_client) -> None: + cb = Mock() + ws_client.on_close(cb) + ws_client._handle_close(Mock(), 1001, "going away") + cb.assert_called_once_with(1001, "going away") + + def test_no_error_without_callback(self, ws_client) -> None: + ws_client._handle_close(Mock(), 1000, "") # Should not raise + + +# --------------------------------------------------------------------------- +# Connect / disconnect lifecycle +# --------------------------------------------------------------------------- + + +class TestConnect: + """Tests for connect() and disconnect().""" + + @patch("mistapi.websockets.__ws_client.websocket.WebSocketApp") + def test_connect_creates_websocket_app(self, mock_ws_cls, ws_client) -> None: + mock_ws_instance = Mock() + mock_ws_cls.return_value = mock_ws_instance + + ws_client.connect(run_in_background=False) + + mock_ws_cls.assert_called_once_with( + "wss://api-ws.mist.com/api-ws/v1/stream", + header={"Authorization": "Token test_token"}, + cookie=None, + on_open=ws_client._handle_open, + on_message=ws_client._handle_message, + on_error=ws_client._handle_error, + on_close=ws_client._handle_close, + ) + mock_ws_instance.run_forever.assert_called_once() + + @patch("mistapi.websockets.__ws_client.websocket.WebSocketApp") + def test_connect_drains_stale_queue_items(self, mock_ws_cls, ws_client) -> None: + # Pre-populate queue with stale sentinel and data + ws_client._queue.put(None) + ws_client._queue.put({"old": "data"}) + assert not ws_client._queue.empty() + + mock_ws_cls.return_value = Mock() + ws_client.connect(run_in_background=False) + + # Queue should have been drained before creating the WebSocketApp + assert ws_client._queue.empty() + + @patch("mistapi.websockets.__ws_client.websocket.WebSocketApp") + def test_connect_background_starts_thread(self, mock_ws_cls, ws_client) -> None: + mock_ws_instance = Mock() + mock_ws_cls.return_value = mock_ws_instance + + with patch( + "mistapi.websockets.__ws_client.threading.Thread" + ) as mock_thread_cls: + mock_thread = Mock() + mock_thread_cls.return_value = mock_thread + + ws_client.connect(run_in_background=True) + + mock_thread_cls.assert_called_once_with( + target=ws_client._run_forever_safe, daemon=True + ) + mock_thread.start.assert_called_once() + + @patch("mistapi.websockets.__ws_client.websocket.WebSocketApp") + def test_disconnect_calls_close(self, mock_ws_cls, ws_client) -> None: + mock_ws_instance = Mock() + mock_ws_cls.return_value = mock_ws_instance + + ws_client.connect(run_in_background=False) + ws_client.disconnect() + + mock_ws_instance.close.assert_called_once() + + def test_disconnect_without_connect_is_noop(self, ws_client) -> None: + ws_client.disconnect() # Should not raise + + +# --------------------------------------------------------------------------- +# _run_forever_safe +# --------------------------------------------------------------------------- + + +class TestRunForeverSafe: + """Tests for _run_forever_safe().""" + + def test_passes_ping_and_ssl_to_run_forever(self, single_channel_client) -> None: + mock_ws = Mock() + single_channel_client._ws = mock_ws + # verify=True, cert=None => empty sslopt dict + single_channel_client._run_forever_safe() + mock_ws.run_forever.assert_called_once_with( + ping_interval=15, + ping_timeout=5, + sslopt={}, + ) + + def test_passes_sslopt_when_verify_false(self, mock_session) -> None: + mock_session._session.verify = False + mock_session._session.cert = None + client = _MistWebsocket(mock_session, channels=["/ch"]) + mock_ws = Mock() + client._ws = mock_ws + client._run_forever_safe() + mock_ws.run_forever.assert_called_once_with( + ping_interval=30, + ping_timeout=10, + sslopt={"cert_reqs": ssl.CERT_NONE}, + ) + + def test_exception_triggers_error_and_close_handlers(self, ws_client) -> None: + mock_ws = Mock() + mock_ws.run_forever.side_effect = RuntimeError("connection failed") + ws_client._ws = mock_ws + + error_cb = Mock() + close_cb = Mock() + ws_client.on_error(error_cb) + ws_client.on_close(close_cb) + + ws_client._run_forever_safe() + + error_cb.assert_called_once() + assert isinstance(error_cb.call_args[0][0], RuntimeError) + close_cb.assert_called_once_with(-1, "connection failed") + + def test_noop_when_ws_is_none(self, ws_client) -> None: + ws_client._ws = None + ws_client._run_forever_safe() # Should not raise + + +# --------------------------------------------------------------------------- +# receive() generator +# --------------------------------------------------------------------------- + + +class TestReceive: + """Tests for the receive() generator.""" + + def test_yields_queued_messages(self, ws_client) -> None: + ws_client._connected.set() + ws_client._queue.put({"event": "a"}) + ws_client._queue.put({"event": "b"}) + ws_client._queue.put(None) # sentinel + + results = list(ws_client.receive()) + assert results == [{"event": "a"}, {"event": "b"}] + + def test_returns_immediately_when_not_connected_within_timeout( + self, ws_client + ) -> None: + # _connected is never set, so wait(timeout=10) returns False. + # Override timeout via monkey-patching for speed. + original_wait = ws_client._connected.wait + ws_client._connected.wait = lambda timeout=None: False + + results = list(ws_client.receive()) + assert results == [] + + ws_client._connected.wait = original_wait + + def test_stops_on_none_sentinel(self, ws_client) -> None: + ws_client._connected.set() + ws_client._queue.put({"first": True}) + ws_client._queue.put(None) + ws_client._queue.put({"should_not_appear": True}) + + results = list(ws_client.receive()) + assert results == [{"first": True}] + + def test_stops_when_disconnected_and_queue_empty(self, ws_client) -> None: + ws_client._connected.set() + ws_client._queue.put({"msg": 1}) + + gen = ws_client.receive() + assert next(gen) == {"msg": 1} + + # Simulate disconnect: clear connected, queue is empty + ws_client._connected.clear() + results = list(gen) + assert results == [] + + +# --------------------------------------------------------------------------- +# Context manager +# --------------------------------------------------------------------------- + + +class TestContextManager: + """Tests for __enter__ / __exit__.""" + + def test_enter_returns_self(self, ws_client) -> None: + assert ws_client.__enter__() is ws_client + + def test_exit_calls_disconnect(self, ws_client) -> None: + mock_ws = Mock() + ws_client._ws = mock_ws + ws_client.__exit__(None, None, None) + mock_ws.close.assert_called_once() + + def test_with_statement(self, mock_session) -> None: + with _MistWebsocket(mock_session, channels=["/ch"]) as client: + assert isinstance(client, _MistWebsocket) + # After exiting, disconnect should have been called (no-op here since _ws is None) + + @patch("mistapi.websockets.__ws_client.websocket.WebSocketApp") + def test_exit_disconnects_active_connection( + self, mock_ws_cls, mock_session + ) -> None: + mock_ws_instance = Mock() + mock_ws_cls.return_value = mock_ws_instance + + with _MistWebsocket(mock_session, channels=["/ch"]) as client: + client.connect(run_in_background=False) + + mock_ws_instance.close.assert_called_once() + + +# --------------------------------------------------------------------------- +# ready() +# --------------------------------------------------------------------------- + + +class TestReady: + """Tests for ready().""" + + def test_returns_false_when_ws_is_none(self, ws_client) -> None: + assert ws_client.ready() is False + + def test_returns_true_when_ws_reports_ready(self, ws_client) -> None: + mock_ws = Mock() + mock_ws.ready.return_value = True + ws_client._ws = mock_ws + assert ws_client.ready() is True + + def test_returns_false_when_ws_not_ready(self, ws_client) -> None: + mock_ws = Mock() + mock_ws.ready.return_value = False + ws_client._ws = mock_ws + assert ws_client.ready() is False + + +# --------------------------------------------------------------------------- +# Initialisation defaults +# --------------------------------------------------------------------------- + + +class TestInit: + """Tests for __init__ defaults.""" + + def test_default_ping_interval_and_timeout(self, ws_client) -> None: + assert ws_client._ping_interval == 30 + assert ws_client._ping_timeout == 10 + + def test_custom_ping_interval_and_timeout(self, single_channel_client) -> None: + assert single_channel_client._ping_interval == 15 + assert single_channel_client._ping_timeout == 5 + + def test_queue_starts_empty(self, ws_client) -> None: + assert ws_client._queue.empty() + + def test_connected_event_starts_unset(self, ws_client) -> None: + assert not ws_client._connected.is_set() + + def test_ws_starts_none(self, ws_client) -> None: + assert ws_client._ws is None + + def test_thread_starts_none(self, ws_client) -> None: + assert ws_client._thread is None + + +# --------------------------------------------------------------------------- +# Public WebSocket channel classes +# --------------------------------------------------------------------------- + + +class TestSiteChannels: + """Tests for public site-level WebSocket channel classes.""" + + def test_clients_stats_events_channels(self, mock_session) -> None: + ws = ClientsStatsEvents(mock_session, site_ids=["s1", "s2"]) + assert ws._channels == ["/sites/s1/stats/clients", "/sites/s2/stats/clients"] + + def test_device_cmd_events_channels(self, mock_session) -> None: + ws = DeviceCmdEvents(mock_session, site_id="s1", device_ids=["d1", "d2"]) + assert ws._channels == ["/sites/s1/devices/d1/cmd", "/sites/s1/devices/d2/cmd"] + + def test_device_stats_events_channels(self, mock_session) -> None: + ws = DeviceStatsEvents(mock_session, site_ids=["s1"]) + assert ws._channels == ["/sites/s1/stats/devices"] + + def test_device_upgrades_events_channels(self, mock_session) -> None: + ws = DeviceEvents(mock_session, site_ids=["s1"]) + assert ws._channels == ["/sites/s1/devices"] + + def test_site_mxedges_stats_events_channels(self, mock_session) -> None: + ws = SiteMxEdgesStatsEvents(mock_session, site_ids=["s1"]) + assert ws._channels == ["/sites/s1/stats/mxedges"] + + def test_pcap_events_channels(self, mock_session) -> None: + ws = PcapEvents(mock_session, site_id="s1") + assert ws._channels == ["/sites/s1/pcaps"] + + def test_custom_ping_settings(self, mock_session) -> None: + ws = DeviceStatsEvents( + mock_session, site_ids=["s1"], ping_interval=60, ping_timeout=20 + ) + assert ws._ping_interval == 60 + assert ws._ping_timeout == 20 + + def test_inherits_from_mist_websocket(self, mock_session) -> None: + ws = DeviceCmdEvents(mock_session, site_id="s1", device_ids=["d1"]) + assert isinstance(ws, _MistWebsocket) + + +class TestOrgChannels: + """Tests for public org-level WebSocket channel classes.""" + + def test_insights_events_channels(self, mock_session) -> None: + ws = InsightsEvents(mock_session, org_id="o1") + assert ws._channels == ["/orgs/o1/insights/summary"] + + def test_org_mxedges_stats_events_channels(self, mock_session) -> None: + ws = OrgMxEdgesStatsEvents(mock_session, org_id="o1") + assert ws._channels == ["/orgs/o1/stats/mxedges"] + + def test_mxedges_upgrades_events_channels(self, mock_session) -> None: + ws = MxEdgesEvents(mock_session, org_id="o1") + assert ws._channels == ["/orgs/o1/mxedges"] + + def test_inherits_from_mist_websocket(self, mock_session) -> None: + ws = InsightsEvents(mock_session, org_id="o1") + assert isinstance(ws, _MistWebsocket) + + +class TestLocationChannels: + """Tests for public location-level WebSocket channel classes.""" + + def test_ble_assets_events_channels(self, mock_session) -> None: + ws = BleAssetsEvents(mock_session, site_id="s1", map_id=["m1", "m2"]) + assert ws._channels == [ + "/sites/s1/stats/maps/m1/assets", + "/sites/s1/stats/maps/m2/assets", + ] + + def test_connected_clients_events_channels(self, mock_session) -> None: + ws = ConnectedClientsEvents(mock_session, site_id="s1", map_id=["m1"]) + assert ws._channels == ["/sites/s1/stats/maps/m1/clients"] + + def test_sdk_clients_events_channels(self, mock_session) -> None: + ws = SdkClientsEvents(mock_session, site_id="s1", map_id=["m1"]) + assert ws._channels == ["/sites/s1/stats/maps/m1/sdkclients"] + + def test_unconnected_clients_events_channels(self, mock_session) -> None: + ws = UnconnectedClientsEvents(mock_session, site_id="s1", map_id=["m1"]) + assert ws._channels == ["/sites/s1/stats/maps/m1/unconnected_clients"] + + def test_discovered_ble_assets_events_channels(self, mock_session) -> None: + ws = DiscoveredBleAssetsEvents(mock_session, site_id="s1", map_id=["m1"]) + assert ws._channels == ["/sites/s1/stats/maps/m1/discovered_assets"] + + def test_inherits_from_mist_websocket(self, mock_session) -> None: + ws = BleAssetsEvents(mock_session, site_id="s1", map_id=["m1"]) + assert isinstance(ws, _MistWebsocket) + + +class TestSessionChannel: + """Tests for the SessionWithUrl WebSocket channel class.""" + + def test_session_with_url_channels(self, mock_session) -> None: + ws = SessionWithUrl(mock_session, url="wss://example.com/custom") + assert ws._channels == ["wss://example.com/custom"] + + def test_inherits_from_mist_websocket(self, mock_session) -> None: + ws = SessionWithUrl(mock_session, url="wss://example.com/custom") + assert isinstance(ws, _MistWebsocket) diff --git a/uv.lock b/uv.lock index 2c27940..11c66d0 100644 --- a/uv.lock +++ b/uv.lock @@ -537,7 +537,7 @@ wheels = [ [[package]] name = "mistapi" -version = "0.60.4" +version = "0.61.0" source = { editable = "." } dependencies = [ { name = "deprecation" }, @@ -546,6 +546,7 @@ dependencies = [ { name = "python-dotenv" }, { name = "requests" }, { name = "tabulate" }, + { name = "websocket-client" }, ] [package.dev-dependencies] @@ -569,6 +570,7 @@ requires-dist = [ { name = "python-dotenv", specifier = ">=1.1.0" }, { name = "requests", specifier = ">=2.32.3" }, { name = "tabulate", specifier = ">=0.9.0" }, + { name = "websocket-client", specifier = ">=1.8.0" }, ] [package.metadata.requires-dev] @@ -1006,6 +1008,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6d/b9/4095b668ea3678bf6a0af005527f39de12fb026516fb3df17495a733b7f8/urllib3-2.6.2-py3-none-any.whl", hash = "sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd", size = 131182, upload-time = "2025-12-11T15:56:38.584Z" }, ] +[[package]] +name = "websocket-client" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/41/aa4bf9664e4cda14c3b39865b12251e8e7d239f4cd0e3cc1b6c2ccde25c1/websocket_client-1.9.0.tar.gz", hash = "sha256:9e813624b6eb619999a97dc7958469217c3176312b3a16a4bd1bc7e08a46ec98", size = 70576, upload-time = "2025-10-07T21:16:36.495Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/db/b10e48aa8fff7407e67470363eac595018441cf32d5e1001567a7aeba5d2/websocket_client-1.9.0-py3-none-any.whl", hash = "sha256:af248a825037ef591efbf6ed20cc5faa03d3b47b9e5a2230a529eeee1c1fc3ef", size = 82616, upload-time = "2025-10-07T21:16:34.951Z" }, +] + [[package]] name = "zipp" version = "3.23.0"