diff --git a/.env.example b/.env.example new file mode 100644 index 00000000..1f8b76fc --- /dev/null +++ b/.env.example @@ -0,0 +1,40 @@ +# ============================================================================ +# Autohive Integrations — environment variables for integration tests +# ============================================================================ +# +# Copy this file to .env and fill in values for the integrations you need +# to test against live APIs. The .env file is gitignored. +# +# Unit tests (pytest -m unit) never need these — they use mocks. +# Integration tests (pytest -m integration) will skip if the required +# variable is missing. +# +# Format: VARIABLE_NAME=value (no quotes needed) +# ============================================================================ + +# -- Bitly -- +# BITLY_ACCESS_TOKEN= + +# -- NZBN -- +# NZBN_CLIENT_ID= +# NZBN_CLIENT_SECRET= +# NZBN_SUBSCRIPTION_KEY= + +# -- Notion -- +# NOTION_ACCESS_TOKEN= + +# -- Perplexity -- +# PERPLEXITY_API_KEY= + +# -- Shopify Customer -- +# SHOPIFY_CUSTOMER_ACCESS_TOKEN= +# SHOPIFY_CUSTOMER_SHOP_URL= + +# -- Stripe -- +# STRIPE_TEST_API_KEY= + +# -- Zoom -- +# ZOOM_ACCESS_TOKEN= + +# -- Xero -- +# (uses platform OAuth — tokens are short-lived, typically not set here) diff --git a/.gitignore b/.gitignore index 5892dc16..c0de07bc 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,7 @@ .idea/ */.env +.env +.venv/ +.coverage /.agents diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b1ec5866..63b9182c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -83,6 +83,202 @@ ruff check --fix my-integration ruff format my-integration ``` +## Running Tests + +Unit tests and integration tests are **run separately** — they use different file naming, markers, and discovery rules so they never interfere with each other. 
+ +| | Unit tests | Integration tests | +|---|---|---| +| **File naming** | `test_*_unit.py` | `test_*_integration.py` | +| **Marker** | `@pytest.mark.unit` | `@pytest.mark.integration` | +| **Auto-discovered** | Yes (via `python_files` in `pyproject.toml`) | No — must pass the file path explicitly | +| **Runs in CI** | Yes | No | +| **Needs credentials** | No (fully mocked) | Yes (real API calls) | +| **Default `pytest`** | ✅ Selected by `-m unit` in `addopts` | ❌ Excluded by `-m unit` in `addopts` | + +Tests use [pytest](https://docs.pytest.org/) and run from the repo root. They use the same Python environment as the tooling (see [Local Validation](#local-validation) above). + +### Prerequisites + +Python 3.13+ is required (the SDK depends on it). Create a venv and install test dependencies: + +```bash +cd autohive-integrations +uv venv --python 3.13 .venv +source .venv/bin/activate +uv pip install -r requirements-test.txt +``` + +Each integration pins its own SDK version in its `requirements.txt`. Install the dependencies for the integration(s) you want to test: + +```bash +uv pip install -r hackernews/requirements.txt +``` + +If you don't have [uv](https://docs.astral.sh/uv/), you can use any Python 3.13+ interpreter directly: + +```bash +python3.13 -m venv .venv +source .venv/bin/activate +pip install -r requirements-test.txt +pip install -r hackernews/requirements.txt +``` + +### Running unit tests + +Unit tests are mocked — no API credentials or network access needed. They are auto-discovered by pytest from `test_*_unit.py` files. 
+ +```bash +# Run unit tests for a single integration +pytest hackernews/ + +# Run a specific test file +pytest hackernews/tests/test_hackernews_unit.py + +# Run all unit tests (only if all integrations share the same SDK version) +pytest + +# Verbose output +pytest hackernews/ -v +``` + +If integrations pin different SDK versions, run them separately to ensure each uses its own pinned version: + +```bash +uv pip install -r bitly/requirements.txt +pytest bitly/ + +uv pip install -r notion/requirements.txt +pytest notion/ +``` + +The default `pytest` command only runs tests marked `unit` (configured in `pyproject.toml`). + +### Running integration tests + +Integration tests call real APIs and require credentials. They are **not** auto-discovered — you must pass the test file path explicitly and override the marker filter. + +Set up a `.env` file in the repo root (see `.env.example` for the template): + +```bash +cp .env.example .env +# Edit .env and add your test credentials +``` + +Then run by passing the file path directly with `-m integration`: + +```bash +# Run integration tests for one integration +pytest perplexity/tests/test_perplexity_integration.py -m integration +``` + +> **Why the explicit file path?** `pyproject.toml` restricts `python_files` to `test_*_unit.py`, so `pytest -m integration perplexity/` will **not** discover `test_*_integration.py` files. You must name the file directly. + +To run both unit and integration tests together: + +```bash +pytest perplexity/tests/test_perplexity_unit.py perplexity/tests/test_perplexity_integration.py -m "unit or integration" +``` + +Integration tests will `pytest.skip()` if the required environment variables are missing. 
+### Coverage
+
+```bash
+# Coverage for a single integration (unit tests only)
+pytest --cov=hackernews hackernews/
+
+# Coverage for multiple integrations
+pytest --cov=hackernews --cov=bitly hackernews/ bitly/
+
+# All tested integrations with line-level detail
+pytest --cov=hackernews --cov=bitly --cov=nzbn --cov=notion --cov=shopify-customer --cov-report=term-missing hackernews/ bitly/ nzbn/ notion/ shopify-customer/
+```
+
+Coverage is configured in `pyproject.toml` to exclude test files — only integration source code is measured.
+
+### Writing tests for a new integration
+
+#### Unit tests
+
+Unit test files go in `<integration>/tests/test_<integration>_unit.py`. This naming is required for auto-discovery.
+
+```python
+import os
+import sys
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../dependencies")))
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+
+from my_integration.my_integration import my_integration
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.fixture
+def mock_context():
+    ctx = MagicMock(name="ExecutionContext")
+    ctx.fetch = AsyncMock(name="fetch")
+    ctx.auth = {"credentials": {"api_key": "test_key"}}
+    return ctx
+
+
+class TestMyAction:
+    async def test_success(self, mock_context):
+        mock_context.fetch.return_value = {"data": "value"}
+        result = await my_integration.execute_action("my_action", {"input": "x"}, mock_context)
+        assert result.result.data["data"] == "value"
+```
+
+#### Integration tests (optional)
+
+Integration test files go in `<integration>/tests/test_<integration>_integration.py`. They are excluded from CI by both naming and marker.
+
+```python
+import os
+
+import pytest
+
+# (add the same sys.path setup and `my_integration` import as in the unit test file)
+
+pytestmark = pytest.mark.integration
+
+
+@pytest.fixture
+def live_context():
+    api_key = os.environ.get("MY_API_KEY", "")
+    if not api_key:
+        pytest.skip("MY_API_KEY not set")
+    # ... set up context with real credentials ...
+ + +class TestMyAction: + async def test_real_api_call(self, live_context): + result = await my_integration.execute_action("my_action", {"input": "x"}, live_context) + assert "data" in result.result.data +``` + +#### Shared conventions + +- Use `pytestmark = pytest.mark.unit` at module level for mocked tests +- Use `pytestmark = pytest.mark.integration` for tests that hit real APIs +- Mock `context.fetch` return values to simulate API responses +- Test both success and error paths +- Also add a `conftest.py` in the integration's `tests/` dir: + +```python +import sys +import os + +sys.path.insert(0, os.path.dirname(__file__)) +``` + +### Test markers + +| Marker | Purpose | Needs credentials? | Runs in CI? | Auto-discovered? | +|--------|---------|-------------------|-------------|-----------------| +| `unit` | Mocked tests, no network | No | Yes | Yes (`test_*_unit.py`) | +| `integration` | Real API calls | Yes (via `.env`) | No | No (explicit file path) | + ## Integration Structure See the SDK's [Integration Structure Reference](https://github.com/autohive-ai/integrations-sdk/blob/master/docs/manual/integration_structure.md) for directory layouts, required files, and the full `config.json` schema. The [Building Your First Integration](https://github.com/autohive-ai/integrations-sdk/blob/master/docs/manual/building_your_first_integration.md) tutorial covers the development workflow end-to-end, and [samples/template/](https://github.com/autohive-ai/integrations-sdk/tree/master/samples/template) provides a ready-to-copy starter. diff --git a/README.md b/README.md index dce372c5..91b6198b 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # Integrations by Autohive + This repository hosts Autohive integrations made and maintained by the Autohive team. ## Getting Started @@ -120,6 +121,11 @@ Supports basic HTTP authentication and Bearer token authentication via the SDK. 
### Pipedrive [pipedrive](pipedrive): Pipedrive CRM integration for managing deals, contacts, organizations, activities, and sales pipelines. Supports API token authentication. Supports full CRUD operations for deals (sales opportunities) with values, currencies, and expected close dates. Features complete contact management (persons) with email and phone information, organization management with company details, activity tracking for tasks, calls, meetings, deadlines, emails, and lunch meetings with scheduling and duration tracking. Includes note management with HTML formatting support, pipeline and stage discovery for sales workflow configuration, and universal search across all items (deals, persons, organizations, products) with exact match and custom field search capabilities. Features API token authentication, pagination support for large datasets (up to 500 items per request), and comprehensive filtering options by user, status, and custom filters. Comprises 30 actions covering deals, persons, organizations, activities, notes, pipelines, stages, and search. Ideal for sales pipeline management, customer relationship tracking, activity scheduling, and CRM automation workflows. + +### Perplexity + +[perplexity](perplexity): Web search integration powered by Perplexity's AI search API. Search the web and get ranked, structured results with titles, URLs, snippets, and dates. Supports content depth control (quick, default, detailed extraction), geographic filtering by country, multi-query search (up to 5 queries per request), and configurable result limits (1-20). Requires a `PERPLEXITY_API_KEY` environment variable. Includes 1 action for web search. Ideal for real-time web research, competitive intelligence, content curation, and market research automation. + ### Facebook Pages [facebook](facebook): Comprehensive Facebook Pages integration for managing social media presence through the Graph API v21.0. 
Supports page discovery, full post lifecycle (create, retrieve, schedule, delete) with text, photo, video, and link content types, comment management (read, reply, hide/unhide, like/unlike, delete), and page/post-level analytics. Features scheduled posting (10 min to 75 days ahead) with ISO 8601 and Unix timestamp support. Uses a multi-file structure pattern for maintainability with separate action modules. Includes OAuth2 authentication with comprehensive page permissions. Tested. @@ -331,3 +337,8 @@ Supports basic HTTP authentication and Bearer token authentication via the SDK. ## Template Use the [starter template](https://github.com/autohive-ai/integrations-sdk/tree/master/samples/template) in the SDK repo as the starting point for new integrations. + +## Testing + +See [CONTRIBUTING.md](CONTRIBUTING.md) for instructions on running and writing unit tests. + diff --git a/TEST_STOCKTAKE.md b/TEST_STOCKTAKE.md new file mode 100644 index 00000000..67e5eec8 --- /dev/null +++ b/TEST_STOCKTAKE.md @@ -0,0 +1,215 @@ +# Test Stocktake Report + +**Date:** 2026-04-14 +**Scope:** All 86 integrations in `autohive-integrations` + +--- + +## Executive Summary + +Every integration has a `tests/` directory with the required scaffolding (`__init__.py`, `context.py`, `test_*.py`). However, **the vast majority of tests are manual integration-test scripts** that require real API credentials to run. Only a small minority use mocks, and almost none can run in CI. There is **no automated test execution in CI** — the existing CI only validates structure, linting, and config correctness. 
+ +--- + +## Numbers at a Glance + +| Metric | Count | +|---|---| +| Total integrations | 86 | +| Integrations with `tests/` directory | 86 (100%) | +| Integrations with `context.py` | 86 (100%) | +| Integrations with `__init__.py` | 89* | +| Total `test_*.py` files | 115 | +| Total lines of test code | ~54,400 | +| Tests using **mocks** (runnable without credentials) | 14 (16%) | +| Tests using **pytest** (`@pytest.mark`, fixtures) | 9 (10%) | +| Tests that are **manual runners** (`asyncio.run(main())`) | ~72 (84%) | +| Tests with a **custom test runner** (no framework) | 2 | +| Tests requiring **real API keys/tokens** | ~72 (84%) | +| CI workflows that **execute tests** | **0** | + +\* Some integrations have extra `__init__.py` files from nested test directories. + +--- + +## Test Style Breakdown + +### 1. Manual Integration Test Runners (~84% of tests) + +The dominant pattern. These are `async def main()` scripts run via `python test_*.py` with real credentials passed as args or env vars. + +**Example:** `stripe/tests/test_stripe.py` (727 lines) +``` +python test_stripe.py sk_test_xxx # Full CRUD suite +python test_stripe.py sk_test_xxx --quick # Read-only tests +``` + +**Characteristics:** +- Require real API credentials (env vars or CLI args) +- Use `ExecutionContext(auth=...)` directly +- Print results to stdout; no structured pass/fail +- Cannot run in CI without secrets +- No assertions in many cases — just print output and catch exceptions +- Follow the SDK template pattern exactly + +**Integrations using this pattern:** hackernews, stripe, box, google-analytics, google-calendar, asana, clickup, notion (partially), zoom, xero (partially), and ~60 more. + +### 2. Pytest-Based Mocked Unit Tests (~10% of tests) + +Proper unit tests using `pytest`, `@pytest.mark.asyncio`, `unittest.mock.AsyncMock`, and `MagicMock`. These can run without credentials. 
+ +**Example:** `shopify-customer/tests/test_unit.py` (305 lines) +```python +@pytest.fixture +def mock_context(): + context = MagicMock() + context.auth = {"credentials": {"access_token": "test_token_123"}} + context.fetch = AsyncMock() + return context + +@pytest.mark.asyncio +async def test_get_profile_success(self, mock_context): + mock_context.fetch.return_value = {"data": {"customer": {...}}} + result = await shopify_customer.execute_action("customer_get_profile", {}, mock_context) + assert result.result.data["success"] is True +``` + +**Integrations with pytest mocked tests:** +- `shopify-customer` (test_unit.py) +- `xero` (rate limiter + purchase order tests) +- `uber` +- `linkedin`, `linkedin-ads` +- `microsoft-word`, `microsoft-powerpoint` +- `productboard` +- `instagram`, `tiktok` +- `facebook`, `humanitix` +- `spreadsheet-tools` + +### 3. Hybrid Tests (pytest + manual runner) — 2 integrations + +Files that contain both `@pytest.mark.asyncio` test functions and a `main()` with `asyncio.run()`. The pytest tests use mocks; the `main()` uses real credentials. + +**Examples:** `xero/tests/test_xero.py`, `spreadsheet-tools/tests/test_spreadsheet_tools.py` + +### 4. Custom Test Runner — 2 integrations + +Home-grown test harnesses without pytest or unittest. + +**Examples:** +- `slider/tests/test_unit_all.py` — custom `TestRunner` class with `[PASS]/[FAIL]` output +- `doc-maker/tests/test_unit.py` — custom `TestResult` class with `assert_equal`/`assert_true` + +### 5. Mocked Tests Without pytest Framework — 3 integrations + +Use `unittest.mock` but run via `asyncio.run(main())` instead of pytest. 
+ +**Examples:** `notion/tests/test_notion_integration.py`, `powerbi/tests/test_powerbi_integration.py`, `monday-com/tests/test_monday_com.py` + +--- + +## What Tests Actually Validate + +| What's tested | How many integrations | Notes | +|---|---|---| +| Action execution against **live API** | ~72 | Requires real credentials | +| Action execution against **mocked API** | ~14 | Can run offline | +| Handler logic / helper functions | ~5 | slider, doc-maker, shopify-customer | +| Config schema correctness | ~3 | notion, gong — verify config.json matches handlers | +| Error handling / edge cases | ~8 | xero rate limiting, notion errors, shopify validation | +| Connected account handler | ~5 | zoom, uber — test `get_connected_account` | +| **No meaningful test logic** | 0 | All 86 have at least basic smoke tests | + +--- + +## CI/CD Situation + +### What CI does today (`validate-integration.yml`) +- ✅ Folder structure validation (required files present) +- ✅ `config.json` schema validation +- ✅ Python syntax + import resolution +- ✅ ruff lint + format +- ✅ bandit security scan +- ✅ pip-audit dependency check +- ✅ Config-code sync check + +### What CI does NOT do +- ❌ **No test execution** — `pytest` is never run +- ❌ No coverage reporting +- ❌ No mock test discovery/execution +- ❌ No validation that tests actually pass + +--- + +## Credential Exposure Concerns + +14 test files reference API keys/tokens via environment variables or placeholder strings. These use patterns like: +```python +API_KEY = os.environ.get("STRIPE_TEST_API_KEY", "sk_test_your_key_here") +ACCESS_TOKEN = os.environ.get("ZOOM_ACCESS_TOKEN", "") +``` + +No hardcoded real secrets were found — all use env var lookups with placeholder defaults. This is the correct pattern but means none of these tests are runnable without manual setup. 
+ +--- + +## SDK Template Compliance + +The SDK template (`samples/template/`) prescribes: +- `tests/__init__.py` — ✅ present in all 86 +- `tests/context.py` — ✅ present in all 86 +- `tests/test_*.py` — ✅ present in all 86 +- Manual `asyncio.run()` runner style — ✅ used by ~84% + +The SDK does **not** mandate pytest or any test framework. The template uses plain async functions with `asyncio.run()`. Most integrations follow this exactly. + +--- + +## Key Observations + +1. **No tests run in CI.** The biggest gap. 86 integrations, 115 test files, 54k lines of test code — and none of it executes automatically. + +2. **~84% of tests are unrunnable without credentials.** They're useful for local developer validation but provide zero automated safety net. + +3. **Only ~14 integrations have mocked tests** that could run in CI today without any secrets. These are the low-hanging fruit for CI integration. + +4. **Two competing test "philosophies" coexist:** + - SDK template approach: manual script, print output, eyeball it + - Modern approach: pytest + mocks + assertions + CI-runnable + +5. **No shared test utilities.** Each integration reinvents mocking patterns. There's no shared `conftest.py`, mock factory, or test helper library. + +6. **Custom test runners** (slider, doc-maker) should be migrated to pytest for consistency. + +7. **The largest test files are manual runners** — microsoft365 (2,349 lines), linkedin (1,413 lines), coda (1,213 lines) — all require real credentials. + +--- + +## Recommendations (Not Actioned) + +1. **Add a CI step to run pytest on mocked tests** — the 14 integrations with mocks could be validated today with zero infrastructure changes. + +2. **Create a shared testing pattern** — a repo-level `conftest.py` or test utilities module with reusable mock context factories. + +3. 
**Gradually add mocked unit tests** to integrations, especially for: + - Error handling paths + - Input validation + - Data transformation logic + - Config↔handler sync verification + +4. **Standardize on pytest** — the SDK template doesn't mandate it, but pytest is already used by the most mature tests and is the Python community standard. + +5. **Consider a "test tier" system:** + - **Tier 1 (CI):** Mocked unit tests — run on every PR + - **Tier 2 (Scheduled):** Integration tests with test-environment credentials — run nightly + - **Tier 3 (Manual):** Full API tests with production-like credentials — run before releases + +6. **Fix `Integration.load()` across all integrations** — all 86 integrations call `Integration.load()` with no arguments. The SDK resolves `config.json` relative to its own package location, which only works when the SDK is vendored into `dependencies/`. When the SDK is installed as a site-package (the test setup), this breaks. The current workaround is a monkeypatch in the root `conftest.py` that uses frame inspection to find the caller's directory. The proper fix: update all integration source files to pass an explicit path: + ```python + # Before (fragile) + my_integration = Integration.load() + + # After (robust) + import os + my_integration = Integration.load(os.path.join(os.path.dirname(__file__), "config.json")) + ``` + This is a bulk change across 86 files but is mechanical and safe. Once done, the monkeypatch in `conftest.py` can be removed. Consider also proposing a fix upstream in the SDK to use caller frame inspection as the default. 
diff --git a/bitly/bitly.py b/bitly/bitly.py index 58dcad0a..1a57f0d1 100644 --- a/bitly/bitly.py +++ b/bitly/bitly.py @@ -136,7 +136,9 @@ async def execute(self, inputs: Dict[str, Any], context): body["archived"] = inputs["archived"] response = await context.fetch( - f"{BITLY_API_BASE_URL}/bitlinks/{encoded_bitlink}", method="PATCH", json=body + f"{BITLY_API_BASE_URL}/bitlinks/{encoded_bitlink}", + method="PATCH", + json=body, ) return ActionResult(data={"bitlink": response, "result": True}, cost_usd=0.0) @@ -153,9 +155,16 @@ async def execute(self, inputs: Dict[str, Any], context): try: bitlink = normalize_bitlink(inputs["bitlink"]) - response = await context.fetch(f"{BITLY_API_BASE_URL}/expand", method="POST", json={"bitlink_id": bitlink}) + response = await context.fetch( + f"{BITLY_API_BASE_URL}/expand", + method="POST", + json={"bitlink_id": bitlink}, + ) - return ActionResult(data={"long_url": response.get("long_url", ""), "result": True}, cost_usd=0.0) + return ActionResult( + data={"long_url": response.get("long_url", ""), "result": True}, + cost_usd=0.0, + ) except Exception as e: return ActionResult(data={"long_url": "", "result": False, "error": str(e)}, cost_usd=0.0) @@ -179,7 +188,9 @@ async def execute(self, inputs: Dict[str, Any], context): } response = await context.fetch( - f"{BITLY_API_BASE_URL}/bitlinks/{encoded_bitlink}/clicks", method="GET", params=params + f"{BITLY_API_BASE_URL}/bitlinks/{encoded_bitlink}/clicks", + method="GET", + params=params, ) clicks = response.get("link_clicks", []) @@ -205,7 +216,9 @@ async def execute(self, inputs: Dict[str, Any], context): } response = await context.fetch( - f"{BITLY_API_BASE_URL}/bitlinks/{encoded_bitlink}/clicks/summary", method="GET", params=params + f"{BITLY_API_BASE_URL}/bitlinks/{encoded_bitlink}/clicks/summary", + method="GET", + params=params, ) return ActionResult( @@ -220,7 +233,14 @@ async def execute(self, inputs: Dict[str, Any], context): except Exception as e: return ActionResult( - 
data={"total_clicks": 0, "unit": "", "units": 0, "result": False, "error": str(e)}, cost_usd=0.0 + data={ + "total_clicks": 0, + "unit": "", + "units": 0, + "result": False, + "error": str(e), + }, + cost_usd=0.0, ) @@ -238,7 +258,11 @@ async def execute(self, inputs: Dict[str, Any], context): group_guid = user_response.get("default_group_guid") if not group_guid: return ActionResult( - data={"bitlinks": [], "result": False, "error": "No default_group_guid found for user"}, + data={ + "bitlinks": [], + "result": False, + "error": "No default_group_guid found for user", + }, cost_usd=0.0, ) @@ -332,4 +356,7 @@ async def execute(self, inputs: Dict[str, Any], context): return ActionResult(data={"organizations": organizations, "result": True}, cost_usd=0.0) except Exception as e: - return ActionResult(data={"organizations": [], "result": False, "error": str(e)}, cost_usd=0.0) + return ActionResult( + data={"organizations": [], "result": False, "error": str(e)}, + cost_usd=0.0, + ) diff --git a/bitly/config.json b/bitly/config.json index 0f7299cb..75fced9a 100644 --- a/bitly/config.json +++ b/bitly/config.json @@ -1,7 +1,7 @@ { "name": "Bitly", "display_name": "Bitly", - "version": "1.0.0", + "version": "1.0.1", "description": "URL shortening and link management integration with Bitly for creating, managing, and tracking shortened links", "entry_point": "bitly.py", "auth": { diff --git a/bitly/tests/conftest.py b/bitly/tests/conftest.py new file mode 100644 index 00000000..1d99cac4 --- /dev/null +++ b/bitly/tests/conftest.py @@ -0,0 +1,5 @@ +import sys +import os + +# Allow 'from context import ...' 
to work when pytest runs from repo root +sys.path.insert(0, os.path.dirname(__file__)) diff --git a/bitly/tests/context.py b/bitly/tests/context.py deleted file mode 100644 index aa46b613..00000000 --- a/bitly/tests/context.py +++ /dev/null @@ -1,6 +0,0 @@ -import os -import sys - -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -from bitly import bitly diff --git a/bitly/tests/test_bitly.py b/bitly/tests/test_bitly.py deleted file mode 100644 index 0ca26049..00000000 --- a/bitly/tests/test_bitly.py +++ /dev/null @@ -1,276 +0,0 @@ -# Test suite for Bitly integration -import asyncio -from context import bitly -from autohive_integrations_sdk import ExecutionContext - - -# Test credentials - replace with your actual OAuth token -TEST_AUTH = { - "auth_type": "PlatformOauth2", - "credentials": { - "access_token": "your_access_token_here" # nosec B105 - }, -} - - -# ---- User Tests ---- - - -async def test_get_user(): - """Test getting current user information.""" - async with ExecutionContext(auth=TEST_AUTH) as context: - try: - result = await bitly.execute_action("get_user", {}, context) - print(f"Get User Result: {result}") - assert result.data.get("result") - assert "user" in result.data - return result - except Exception as e: - print(f"Error testing get_user: {e}") - return None - - -# ---- Link Management Tests ---- - - -async def test_shorten_url(): - """Test shortening a URL.""" - inputs = {"long_url": "https://www.example.com/very/long/url/path"} - - async with ExecutionContext(auth=TEST_AUTH) as context: - try: - result = await bitly.execute_action("shorten_url", inputs, context) - print(f"Shorten URL Result: {result}") - assert result.data.get("result") - assert "bitlink" in result.data - return result - except Exception as e: - print(f"Error testing shorten_url: {e}") - return None - - -async def test_create_bitlink(): - """Test creating a bitlink with options.""" - inputs = {"long_url": 
"https://www.example.com/another/url", "title": "Test Link", "tags": ["test", "autohive"]} - - async with ExecutionContext(auth=TEST_AUTH) as context: - try: - result = await bitly.execute_action("create_bitlink", inputs, context) - print(f"Create Bitlink Result: {result}") - assert result.data.get("result") - return result - except Exception as e: - print(f"Error testing create_bitlink: {e}") - return None - - -async def test_get_bitlink(): - """Test getting bitlink information.""" - inputs = {"bitlink": "bit.ly/example"} # Replace with actual bitlink - - async with ExecutionContext(auth=TEST_AUTH) as context: - try: - result = await bitly.execute_action("get_bitlink", inputs, context) - print(f"Get Bitlink Result: {result}") - assert result.data.get("result") - return result - except Exception as e: - print(f"Error testing get_bitlink: {e}") - return None - - -async def test_update_bitlink(): - """Test updating a bitlink.""" - inputs = { - "bitlink": "bit.ly/example", # Replace with actual bitlink - "title": "Updated Title", - } - - async with ExecutionContext(auth=TEST_AUTH) as context: - try: - result = await bitly.execute_action("update_bitlink", inputs, context) - print(f"Update Bitlink Result: {result}") - assert result.data.get("result") - return result - except Exception as e: - print(f"Error testing update_bitlink: {e}") - return None - - -async def test_expand_bitlink(): - """Test expanding a bitlink to get original URL.""" - inputs = {"bitlink": "bit.ly/example"} # Replace with actual bitlink - - async with ExecutionContext(auth=TEST_AUTH) as context: - try: - result = await bitly.execute_action("expand_bitlink", inputs, context) - print(f"Expand Bitlink Result: {result}") - assert result.data.get("result") - assert "long_url" in result.data - return result - except Exception as e: - print(f"Error testing expand_bitlink: {e}") - return None - - -async def test_list_bitlinks(): - """Test listing bitlinks.""" - inputs = {"size": 10} - - async with 
ExecutionContext(auth=TEST_AUTH) as context: - try: - result = await bitly.execute_action("list_bitlinks", inputs, context) - print(f"List Bitlinks Result: {result}") - assert result.data.get("result") - assert "bitlinks" in result.data - return result - except Exception as e: - print(f"Error testing list_bitlinks: {e}") - return None - - -# ---- Click Analytics Tests ---- - - -async def test_get_clicks(): - """Test getting click counts.""" - inputs = { - "bitlink": "bit.ly/example", # Replace with actual bitlink - "unit": "day", - "units": 7, - } - - async with ExecutionContext(auth=TEST_AUTH) as context: - try: - result = await bitly.execute_action("get_clicks", inputs, context) - print(f"Get Clicks Result: {result}") - assert result.data.get("result") - assert "clicks" in result.data - return result - except Exception as e: - print(f"Error testing get_clicks: {e}") - return None - - -async def test_get_clicks_summary(): - """Test getting clicks summary.""" - inputs = { - "bitlink": "bit.ly/example", # Replace with actual bitlink - "unit": "day", - "units": 30, - } - - async with ExecutionContext(auth=TEST_AUTH) as context: - try: - result = await bitly.execute_action("get_clicks_summary", inputs, context) - print(f"Get Clicks Summary Result: {result}") - assert result.data.get("result") - assert "total_clicks" in result.data - return result - except Exception as e: - print(f"Error testing get_clicks_summary: {e}") - return None - - -# ---- Group & Organization Tests ---- - - -async def test_list_groups(): - """Test listing groups.""" - async with ExecutionContext(auth=TEST_AUTH) as context: - try: - result = await bitly.execute_action("list_groups", {}, context) - print(f"List Groups Result: {result}") - assert result.data.get("result") - assert "groups" in result.data - return result - except Exception as e: - print(f"Error testing list_groups: {e}") - return None - - -async def test_get_group(): - """Test getting a group.""" - inputs = {"group_guid": 
"your_group_guid"} # Replace with actual group GUID - - async with ExecutionContext(auth=TEST_AUTH) as context: - try: - result = await bitly.execute_action("get_group", inputs, context) - print(f"Get Group Result: {result}") - assert result.data.get("result") - return result - except Exception as e: - print(f"Error testing get_group: {e}") - return None - - -async def test_list_organizations(): - """Test listing organizations.""" - async with ExecutionContext(auth=TEST_AUTH) as context: - try: - result = await bitly.execute_action("list_organizations", {}, context) - print(f"List Organizations Result: {result}") - assert result.data.get("result") - assert "organizations" in result.data - return result - except Exception as e: - print(f"Error testing list_organizations: {e}") - return None - - -# Main test runner -async def run_all_tests(): - """Run all test functions.""" - print("=" * 60) - print("Bitly Integration Test Suite") - print("=" * 60) - print() - print("NOTE: Replace placeholders with actual values:") - print(" - your_access_token_here: Your OAuth access token") - print(" - bit.ly/example: Replace with actual bitlinks") - print(" - your_group_guid: Replace with actual group GUID") - print() - print("TIP: Run get_user and list_groups first to discover IDs!") - print() - - test_functions = [ - # User - ("Get User", test_get_user), - # Link Management - ("Shorten URL", test_shorten_url), - ("Create Bitlink", test_create_bitlink), - ("Get Bitlink", test_get_bitlink), - ("Update Bitlink", test_update_bitlink), - ("Expand Bitlink", test_expand_bitlink), - ("List Bitlinks", test_list_bitlinks), - # Click Analytics - ("Get Clicks", test_get_clicks), - ("Get Clicks Summary", test_get_clicks_summary), - # Groups & Organizations - ("List Groups", test_list_groups), - ("Get Group", test_get_group), - ("List Organizations", test_list_organizations), - ] - - results = [] - for test_name, test_func in test_functions: - print(f"\n{'-' * 60}") - print(f"Running: 
{test_name}") - print(f"{'-' * 60}") - result = await test_func() - results.append((test_name, result is not None)) - - print("\n" + "=" * 60) - print("Test Summary") - print("=" * 60) - for test_name, passed in results: - status = "PASS" if passed else "FAIL" - print(f"{status}: {test_name}") - - passed_count = sum(1 for _, passed in results if passed) - print(f"\nTotal: {passed_count}/{len(results)} tests passed") - print("=" * 60) - - -if __name__ == "__main__": - asyncio.run(run_all_tests()) diff --git a/bitly/tests/test_bitly_integration.py b/bitly/tests/test_bitly_integration.py new file mode 100644 index 00000000..5a0f64f6 --- /dev/null +++ b/bitly/tests/test_bitly_integration.py @@ -0,0 +1,219 @@ +""" +End-to-end integration tests for the Bitly integration (read-only actions). + +These tests call the real Bitly API and require a valid OAuth access token +set in the BITLY_ACCESS_TOKEN environment variable (via .env or export). + +Write actions (shorten_url, create_bitlink, update_bitlink) are intentionally +excluded — they create/modify real data in the Bitly account. + +Some tests require at least one bitlink to exist in the account. These will +skip gracefully if none are found: + - TestGetBitlink + - TestExpandBitlink + - TestGetClicks + - TestGetClicksSummary + +Run with: + pytest bitly/tests/test_bitly_integration.py -m integration + +Never runs in CI — the default pytest marker filter (-m unit) excludes these, +and the file naming (test_*_integration.py) is not matched by python_files. 
+""" + +import os +import sys +import importlib + +_parent = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +_deps = os.path.abspath(os.path.join(os.path.dirname(__file__), "../dependencies")) +sys.path.insert(0, _parent) +sys.path.insert(0, _deps) + +import pytest # noqa: E402 +from unittest.mock import MagicMock, AsyncMock # noqa: E402 + +_spec = importlib.util.spec_from_file_location("bitly_mod", os.path.join(_parent, "bitly.py")) +_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_mod) + +bitly = _mod.bitly + +pytestmark = pytest.mark.integration + +ACCESS_TOKEN = os.environ.get("BITLY_ACCESS_TOKEN", "") + + +@pytest.fixture +def live_context(): + """Execution context wired to a real HTTP client with Bitly OAuth token. + + The Bitly integration relies on context.fetch to auto-inject the OAuth token + (auth.type = "platform"). In tests we bypass the SDK auth layer and manually + add the Authorization header to every request. + """ + if not ACCESS_TOKEN: + pytest.skip("BITLY_ACCESS_TOKEN not set — skipping integration tests") + + import aiohttp + + async def real_fetch(url, *, method="GET", json=None, headers=None, params=None, **kwargs): + merged_headers = dict(headers or {}) + merged_headers["Authorization"] = f"Bearer {ACCESS_TOKEN}" + async with aiohttp.ClientSession() as session: + async with session.request(method, url, json=json, headers=merged_headers, params=params) as resp: + return await resp.json() + + ctx = MagicMock(name="ExecutionContext") + ctx.fetch = AsyncMock(side_effect=real_fetch) + ctx.auth = {"auth_type": "PlatformOauth2", "credentials": {"access_token": ACCESS_TOKEN}} + return ctx + + +# ---- User ---- + + +class TestGetUser: + async def test_returns_user_info(self, live_context): + result = await bitly.execute_action("get_user", {}, live_context) + + data = result.result.data + assert data["result"] is True + assert "user" in data + user = data["user"] + assert "login" in user + assert 
"default_group_guid" in user + + +# ---- Groups & Organizations ---- + + +class TestListGroups: + async def test_returns_groups(self, live_context): + result = await bitly.execute_action("list_groups", {}, live_context) + + data = result.result.data + assert data["result"] is True + assert "groups" in data + assert len(data["groups"]) > 0 + + async def test_group_structure(self, live_context): + result = await bitly.execute_action("list_groups", {}, live_context) + + group = result.result.data["groups"][0] + assert "guid" in group + assert "organization_guid" in group + + +class TestGetGroup: + async def test_fetches_group_by_guid(self, live_context): + # First get a real group GUID + groups_result = await bitly.execute_action("list_groups", {}, live_context) + group_guid = groups_result.result.data["groups"][0]["guid"] + + result = await bitly.execute_action("get_group", {"group_guid": group_guid}, live_context) + + data = result.result.data + assert data["result"] is True + assert data["group"]["guid"] == group_guid + + +class TestListOrganizations: + async def test_returns_organizations(self, live_context): + result = await bitly.execute_action("list_organizations", {}, live_context) + + data = result.result.data + assert data["result"] is True + assert "organizations" in data + assert len(data["organizations"]) > 0 + + +# ---- Bitlinks ---- + + +class TestListBitlinks: + async def test_returns_bitlinks(self, live_context): + result = await bitly.execute_action("list_bitlinks", {"size": 5}, live_context) + + data = result.result.data + assert data["result"] is True + assert "bitlinks" in data + + +class TestGetBitlink: + async def test_fetches_bitlink_details(self, live_context): + # First get a real bitlink from the account + list_result = await bitly.execute_action("list_bitlinks", {"size": 1}, live_context) + bitlinks = list_result.result.data["bitlinks"] + + if not bitlinks: + pytest.skip("No bitlinks in account to test with") + + bitlink_id = 
bitlinks[0].get("id", bitlinks[0].get("link", "")) + + result = await bitly.execute_action("get_bitlink", {"bitlink": bitlink_id}, live_context) + + data = result.result.data + assert data["result"] is True + assert "bitlink" in data + assert "long_url" in data["bitlink"] + + +class TestExpandBitlink: + async def test_expands_to_long_url(self, live_context): + list_result = await bitly.execute_action("list_bitlinks", {"size": 1}, live_context) + bitlinks = list_result.result.data["bitlinks"] + + if not bitlinks: + pytest.skip("No bitlinks in account to test with") + + bitlink_id = bitlinks[0].get("id", bitlinks[0].get("link", "")) + + result = await bitly.execute_action("expand_bitlink", {"bitlink": bitlink_id}, live_context) + + data = result.result.data + assert data["result"] is True + assert data["long_url"] != "" + assert data["long_url"].startswith("http") + + +# ---- Click Analytics ---- + + +class TestGetClicks: + async def test_returns_click_data(self, live_context): + list_result = await bitly.execute_action("list_bitlinks", {"size": 1}, live_context) + bitlinks = list_result.result.data["bitlinks"] + + if not bitlinks: + pytest.skip("No bitlinks in account to test with") + + bitlink_id = bitlinks[0].get("id", bitlinks[0].get("link", "")) + + result = await bitly.execute_action( + "get_clicks", {"bitlink": bitlink_id, "unit": "day", "units": 7}, live_context + ) + + data = result.result.data + assert data["result"] is True + assert "clicks" in data + + +class TestGetClicksSummary: + async def test_returns_summary(self, live_context): + list_result = await bitly.execute_action("list_bitlinks", {"size": 1}, live_context) + bitlinks = list_result.result.data["bitlinks"] + + if not bitlinks: + pytest.skip("No bitlinks in account to test with") + + bitlink_id = bitlinks[0].get("id", bitlinks[0].get("link", "")) + + result = await bitly.execute_action( + "get_clicks_summary", {"bitlink": bitlink_id, "unit": "day", "units": 30}, live_context + ) + + data = 
result.result.data + assert data["result"] is True + assert "total_clicks" in data + assert isinstance(data["total_clicks"], int) diff --git a/bitly/tests/test_bitly_unit.py b/bitly/tests/test_bitly_unit.py new file mode 100644 index 00000000..1bb28583 --- /dev/null +++ b/bitly/tests/test_bitly_unit.py @@ -0,0 +1,454 @@ +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../dependencies"))) + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from bitly.bitly import bitly, normalize_bitlink, encode_bitlink_for_url + +pytestmark = pytest.mark.unit + +BITLY_API_BASE_URL = "https://api-ssl.bitly.com/v4" + + +@pytest.fixture +def mock_context(): + ctx = MagicMock(name="ExecutionContext") + ctx.fetch = AsyncMock(name="fetch") + ctx.auth = { + "auth_type": "PlatformOauth2", + "credentials": {"access_token": "test_token"}, # nosec B105 + } + return ctx + + +# ---- Pure Function Tests ---- + + +class TestNormalizeBitlink: + def test_full_http_url(self): + assert normalize_bitlink("http://bit.ly/abc123") == "bit.ly/abc123" + + def test_full_https_url(self): + assert normalize_bitlink("https://bit.ly/abc123") == "bit.ly/abc123" + + def test_domain_slash_path_format(self): + assert normalize_bitlink("bit.ly/abc123") == "bit.ly/abc123" + + def test_custom_domain(self): + assert normalize_bitlink("https://custom.short/xyz") == "custom.short/xyz" + + def test_hash_only(self): + assert normalize_bitlink("abc123") == "bit.ly/abc123" + + def test_url_with_trailing_path(self): + assert normalize_bitlink("https://bit.ly/abc/def") == "bit.ly/abc/def" + + +class TestEncodeBitlinkForUrl: + def test_encodes_slash(self): + assert encode_bitlink_for_url("bit.ly/abc123") == "bit.ly%2Fabc123" + + def test_encodes_special_characters(self): + result = encode_bitlink_for_url("bit.ly/a b") + assert "%2F" in result + assert "%20" in result + + def 
test_no_slash(self): + assert encode_bitlink_for_url("abc123") == "abc123" + + +# ---- Action Tests ---- + + +class TestGetUser: + @pytest.mark.asyncio + async def test_returns_user(self, mock_context): + mock_context.fetch.return_value = {"login": "testuser", "name": "Test"} + + result = await bitly.execute_action("get_user", {}, mock_context) + + assert result.result.data["result"] is True + assert result.result.data["user"] == {"login": "testuser", "name": "Test"} + mock_context.fetch.assert_called_once_with(f"{BITLY_API_BASE_URL}/user", method="GET") + + +class TestShortenUrl: + @pytest.mark.asyncio + async def test_shorten_basic(self, mock_context): + mock_context.fetch.return_value = { + "link": "https://bit.ly/short", + "id": "bit.ly/short", + } + inputs = {"long_url": "https://example.com/long"} + + result = await bitly.execute_action("shorten_url", inputs, mock_context) + + assert result.result.data["result"] is True + assert result.result.data["bitlink"]["link"] == "https://bit.ly/short" + mock_context.fetch.assert_called_once_with( + f"{BITLY_API_BASE_URL}/shorten", + method="POST", + json={"long_url": "https://example.com/long"}, + ) + + @pytest.mark.asyncio + async def test_shorten_with_domain_and_group(self, mock_context): + mock_context.fetch.return_value = {"link": "https://cstm.ly/x"} + inputs = { + "long_url": "https://example.com", + "domain": "cstm.ly", + "group_guid": "Ga1b2c", + } + + result = await bitly.execute_action("shorten_url", inputs, mock_context) + + assert result.result.data["result"] is True + call_kwargs = mock_context.fetch.call_args + body = call_kwargs.kwargs["json"] + assert body["domain"] == "cstm.ly" + assert body["group_guid"] == "Ga1b2c" + + +class TestCreateBitlink: + @pytest.mark.asyncio + async def test_create_minimal(self, mock_context): + mock_context.fetch.return_value = { + "link": "https://bit.ly/new", + "id": "bit.ly/new", + } + inputs = {"long_url": "https://example.com/page"} + + result = await 
bitly.execute_action("create_bitlink", inputs, mock_context) + + assert result.result.data["result"] is True + call_kwargs = mock_context.fetch.call_args + assert call_kwargs.kwargs["json"] == {"long_url": "https://example.com/page"} + + @pytest.mark.asyncio + async def test_create_with_all_options(self, mock_context): + mock_context.fetch.return_value = {"link": "https://bit.ly/custom"} + inputs = { + "long_url": "https://example.com", + "domain": "bit.ly", + "group_guid": "Ga1b2c", + "title": "My Link", + "tags": ["test", "demo"], + "custom_back_half": "mylink", + } + + result = await bitly.execute_action("create_bitlink", inputs, mock_context) + + assert result.result.data["result"] is True + body = mock_context.fetch.call_args.kwargs["json"] + assert body["long_url"] == "https://example.com" + assert body["domain"] == "bit.ly" + assert body["group_guid"] == "Ga1b2c" + assert body["title"] == "My Link" + assert body["tags"] == ["test", "demo"] + assert body["custom_back_half"] == "mylink" + + +class TestGetBitlink: + @pytest.mark.asyncio + async def test_get_by_domain_path(self, mock_context): + mock_context.fetch.return_value = { + "id": "bit.ly/abc", + "long_url": "https://example.com", + } + inputs = {"bitlink": "bit.ly/abc"} + + result = await bitly.execute_action("get_bitlink", inputs, mock_context) + + assert result.result.data["result"] is True + assert result.result.data["bitlink"]["id"] == "bit.ly/abc" + mock_context.fetch.assert_called_once_with(f"{BITLY_API_BASE_URL}/bitlinks/bit.ly%2Fabc", method="GET") + + @pytest.mark.asyncio + async def test_get_by_full_url(self, mock_context): + mock_context.fetch.return_value = {"id": "bit.ly/abc"} + inputs = {"bitlink": "https://bit.ly/abc"} + + await bitly.execute_action("get_bitlink", inputs, mock_context) + + mock_context.fetch.assert_called_once_with(f"{BITLY_API_BASE_URL}/bitlinks/bit.ly%2Fabc", method="GET") + + +class TestUpdateBitlink: + @pytest.mark.asyncio + async def test_update_title(self, 
mock_context): + mock_context.fetch.return_value = {"id": "bit.ly/abc", "title": "New Title"} + inputs = {"bitlink": "bit.ly/abc", "title": "New Title"} + + result = await bitly.execute_action("update_bitlink", inputs, mock_context) + + assert result.result.data["result"] is True + call_kwargs = mock_context.fetch.call_args + assert call_kwargs.kwargs["json"] == {"title": "New Title"} + assert call_kwargs.kwargs["method"] == "PATCH" + + @pytest.mark.asyncio + async def test_update_multiple_fields(self, mock_context): + mock_context.fetch.return_value = {"id": "bit.ly/abc"} + inputs = { + "bitlink": "bit.ly/abc", + "title": "T", + "tags": ["a"], + "archived": True, + } + + await bitly.execute_action("update_bitlink", inputs, mock_context) + + body = mock_context.fetch.call_args.kwargs["json"] + assert body == {"title": "T", "tags": ["a"], "archived": True} + + +class TestExpandBitlink: + @pytest.mark.asyncio + async def test_expand(self, mock_context): + mock_context.fetch.return_value = {"long_url": "https://example.com/original"} + inputs = {"bitlink": "bit.ly/abc"} + + result = await bitly.execute_action("expand_bitlink", inputs, mock_context) + + assert result.result.data["result"] is True + assert result.result.data["long_url"] == "https://example.com/original" + mock_context.fetch.assert_called_once_with( + f"{BITLY_API_BASE_URL}/expand", + method="POST", + json={"bitlink_id": "bit.ly/abc"}, + ) + + @pytest.mark.asyncio + async def test_expand_normalizes_full_url(self, mock_context): + mock_context.fetch.return_value = {"long_url": "https://example.com"} + inputs = {"bitlink": "https://bit.ly/abc"} + + await bitly.execute_action("expand_bitlink", inputs, mock_context) + + body = mock_context.fetch.call_args.kwargs["json"] + assert body["bitlink_id"] == "bit.ly/abc" + + +class TestGetClicks: + @pytest.mark.asyncio + async def test_get_clicks_with_params(self, mock_context): + mock_context.fetch.return_value = { + "link_clicks": [{"clicks": 5, "date": 
"2025-01-01"}], + } + inputs = {"bitlink": "bit.ly/abc", "unit": "day", "units": 7} + + result = await bitly.execute_action("get_clicks", inputs, mock_context) + + assert result.result.data["result"] is True + assert result.result.data["clicks"] == [{"clicks": 5, "date": "2025-01-01"}] + call_kwargs = mock_context.fetch.call_args + assert call_kwargs.kwargs["params"] == {"unit": "day", "units": 7} + + @pytest.mark.asyncio + async def test_get_clicks_defaults(self, mock_context): + mock_context.fetch.return_value = {"link_clicks": []} + inputs = {"bitlink": "bit.ly/abc"} + + await bitly.execute_action("get_clicks", inputs, mock_context) + + params = mock_context.fetch.call_args.kwargs["params"] + assert params["unit"] == "day" + assert params["units"] == -1 + + +class TestGetClicksSummary: + @pytest.mark.asyncio + async def test_get_summary(self, mock_context): + mock_context.fetch.return_value = { + "total_clicks": 42, + "unit": "day", + "units": 30, + } + inputs = {"bitlink": "bit.ly/abc", "unit": "day", "units": 30} + + result = await bitly.execute_action("get_clicks_summary", inputs, mock_context) + + assert result.result.data["result"] is True + assert result.result.data["total_clicks"] == 42 + assert result.result.data["unit"] == "day" + assert result.result.data["units"] == 30 + + @pytest.mark.asyncio + async def test_get_summary_defaults(self, mock_context): + mock_context.fetch.return_value = { + "total_clicks": 0, + "unit": "day", + "units": -1, + } + inputs = {"bitlink": "bit.ly/abc"} + + await bitly.execute_action("get_clicks_summary", inputs, mock_context) + + params = mock_context.fetch.call_args.kwargs["params"] + assert params["unit"] == "day" + assert params["units"] == -1 + + +class TestListBitlinks: + @pytest.mark.asyncio + async def test_with_group_guid(self, mock_context): + mock_context.fetch.return_value = { + "links": [{"id": "bit.ly/a"}, {"id": "bit.ly/b"}], + "pagination": {"total": 2, "page": 1, "size": 50}, + } + inputs = {"group_guid": 
"Gabcdef"} + + result = await bitly.execute_action("list_bitlinks", inputs, mock_context) + + assert result.result.data["result"] is True + assert len(result.result.data["bitlinks"]) == 2 + assert result.result.data["total"] == 2 + mock_context.fetch.assert_called_once() + call_url = mock_context.fetch.call_args.args[0] + assert "groups/Gabcdef/bitlinks" in call_url + + @pytest.mark.asyncio + async def test_without_group_guid_fetches_user(self, mock_context): + mock_context.fetch.side_effect = [ + {"default_group_guid": "Gauto123"}, + { + "links": [{"id": "bit.ly/x"}], + "pagination": {"total": 1, "page": 1, "size": 50}, + }, + ] + inputs = {} + + result = await bitly.execute_action("list_bitlinks", inputs, mock_context) + + assert result.result.data["result"] is True + assert len(result.result.data["bitlinks"]) == 1 + assert mock_context.fetch.call_count == 2 + first_call = mock_context.fetch.call_args_list[0] + assert first_call.args[0] == f"{BITLY_API_BASE_URL}/user" + + @pytest.mark.asyncio + async def test_without_group_guid_no_default(self, mock_context): + mock_context.fetch.return_value = {"default_group_guid": None} + inputs = {} + + result = await bitly.execute_action("list_bitlinks", inputs, mock_context) + + assert result.result.data["result"] is False + assert "No default_group_guid" in result.result.data["error"] + + @pytest.mark.asyncio + async def test_with_pagination_params(self, mock_context): + mock_context.fetch.return_value = { + "links": [], + "pagination": {"total": 0, "page": 2, "size": 10}, + } + inputs = {"group_guid": "Gabcdef", "size": 10, "page": 2, "keyword": "test"} + + await bitly.execute_action("list_bitlinks", inputs, mock_context) + + call_kwargs = mock_context.fetch.call_args + params = call_kwargs.kwargs["params"] + assert params["size"] == 10 + assert params["page"] == 2 + assert params["keyword"] == "test" + + +class TestListGroups: + @pytest.mark.asyncio + async def test_list_groups(self, mock_context): + 
mock_context.fetch.return_value = { + "groups": [{"guid": "G1", "name": "Default"}], + } + + result = await bitly.execute_action("list_groups", {}, mock_context) + + assert result.result.data["result"] is True + assert result.result.data["groups"] == [{"guid": "G1", "name": "Default"}] + mock_context.fetch.assert_called_once_with(f"{BITLY_API_BASE_URL}/groups", method="GET") + + +class TestGetGroup: + @pytest.mark.asyncio + async def test_get_group(self, mock_context): + mock_context.fetch.return_value = {"guid": "G1", "name": "My Group"} + inputs = {"group_guid": "G1"} + + result = await bitly.execute_action("get_group", inputs, mock_context) + + assert result.result.data["result"] is True + assert result.result.data["group"]["guid"] == "G1" + mock_context.fetch.assert_called_once_with(f"{BITLY_API_BASE_URL}/groups/G1", method="GET") + + +class TestListOrganizations: + @pytest.mark.asyncio + async def test_list_organizations(self, mock_context): + mock_context.fetch.return_value = { + "organizations": [{"guid": "O1", "name": "Org"}], + } + + result = await bitly.execute_action("list_organizations", {}, mock_context) + + assert result.result.data["result"] is True + assert result.result.data["organizations"] == [{"guid": "O1", "name": "Org"}] + mock_context.fetch.assert_called_once_with(f"{BITLY_API_BASE_URL}/organizations", method="GET") + + +class TestErrorHandling: + @pytest.mark.asyncio + async def test_get_user_error(self, mock_context): + mock_context.fetch.side_effect = Exception("Network error") + + result = await bitly.execute_action("get_user", {}, mock_context) + + assert result.result.data["result"] is False + assert "Network error" in result.result.data["error"] + + @pytest.mark.asyncio + async def test_shorten_url_error(self, mock_context): + mock_context.fetch.side_effect = Exception("API failure") + + result = await bitly.execute_action("shorten_url", {"long_url": "https://example.com"}, mock_context) + + assert result.result.data["result"] is False 
+ assert result.result.data["bitlink"] == {} + + @pytest.mark.asyncio + async def test_get_clicks_error(self, mock_context): + mock_context.fetch.side_effect = Exception("Timeout") + + result = await bitly.execute_action("get_clicks", {"bitlink": "bit.ly/abc"}, mock_context) + + assert result.result.data["result"] is False + assert result.result.data["clicks"] == [] + + @pytest.mark.asyncio + async def test_list_bitlinks_error(self, mock_context): + mock_context.fetch.side_effect = Exception("Server error") + + result = await bitly.execute_action("list_bitlinks", {"group_guid": "G1"}, mock_context) + + assert result.result.data["result"] is False + assert result.result.data["bitlinks"] == [] + + @pytest.mark.asyncio + async def test_expand_bitlink_error(self, mock_context): + mock_context.fetch.side_effect = Exception("Bad request") + + result = await bitly.execute_action("expand_bitlink", {"bitlink": "bit.ly/abc"}, mock_context) + + assert result.result.data["result"] is False + assert result.result.data["long_url"] == "" + + @pytest.mark.asyncio + async def test_get_clicks_summary_error(self, mock_context): + mock_context.fetch.side_effect = Exception("Forbidden") + + result = await bitly.execute_action("get_clicks_summary", {"bitlink": "bit.ly/abc"}, mock_context) + + assert result.result.data["result"] is False + assert result.result.data["total_clicks"] == 0 diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..026017c5 --- /dev/null +++ b/conftest.py @@ -0,0 +1,132 @@ +""" +Root conftest.py — shared fixtures for all integration test suites. 
+ +Provides: +- mock_context: A pre-configured MagicMock ExecutionContext with AsyncMock fetch +- make_context: Factory for building mock contexts with custom auth shapes +- env_credentials: Helper to load credentials from env/.env files + +Also patches Integration.load() so it resolves config.json from the calling +module's directory (instead of the SDK's package tree), which is needed when +the SDK is installed as a site-package rather than vendored. + +Usage in any integration's tests/: + def test_something(mock_context): + mock_context.fetch.return_value = {"data": ...} + ... +""" + +from __future__ import annotations + +import inspect +import os +from pathlib import Path +from typing import Any, Dict, Optional, Union +from unittest.mock import AsyncMock, MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# Patch Integration.load() for non-vendored SDK installs +# --------------------------------------------------------------------------- +# The SDK's default load() resolves config.json relative to the SDK package +# itself (three dirname() calls up from integration.py). When the SDK is a +# normal site-package this resolves into site-packages/, not the integration +# directory. We monkeypatch it to use caller frame inspection instead. 
+ +from autohive_integrations_sdk import Integration # noqa: E402 + +_original_load = Integration.load.__func__ + + +@classmethod # type: ignore[misc] +def _patched_load(cls, config_path: Union[str, Path, None] = None) -> "Integration": + if config_path is None: + frame = inspect.stack()[1] + caller_dir = Path(frame.filename).resolve().parent + config_path = caller_dir / "config.json" + return _original_load(cls, config_path) + + +Integration.load = _patched_load # type: ignore[assignment] + + +# --------------------------------------------------------------------------- +# .env file loading (stdlib-only, no third-party dependency) +# --------------------------------------------------------------------------- + + +def _load_dotenv(path: Path) -> None: + """Load a .env file into os.environ (simple key=value, ignores comments).""" + if not path.is_file(): + return + with open(path) as fh: + for line in fh: + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, _, value = line.partition("=") + key = key.strip() + value = value.strip().strip("'\"") + os.environ.setdefault(key, value) + + +# Load project-root .env once at collection time +_load_dotenv(Path(__file__).parent / ".env") + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_context() -> MagicMock: + """Minimal mock ExecutionContext with an async-capable ``fetch``.""" + ctx = MagicMock(name="ExecutionContext") + ctx.fetch = AsyncMock(name="fetch") + ctx.auth = {} + return ctx + + +@pytest.fixture +def make_context(): + """Factory fixture — build a mock context with arbitrary auth. 
+ + Example:: + + def test_foo(make_context): + ctx = make_context(auth={"credentials": {"api_key": "k"}}) + ctx.fetch.return_value = {...} + """ + + def _factory( + *, + auth: Optional[Dict[str, Any]] = None, + ) -> MagicMock: + ctx = MagicMock(name="ExecutionContext") + ctx.fetch = AsyncMock(name="fetch") + ctx.auth = auth or {} + return ctx + + return _factory + + +@pytest.fixture +def env_credentials(): + """Return a helper that reads credentials from environment variables. + + Example:: + + def test_live(env_credentials): + creds = env_credentials("BITLY_ACCESS_TOKEN") + if creds is None: + pytest.skip("BITLY_ACCESS_TOKEN not set") + """ + + def _get(var_name: str) -> Optional[str]: + val = os.environ.get(var_name) + return val if val else None + + return _get diff --git a/hackernews/tests/conftest.py b/hackernews/tests/conftest.py new file mode 100644 index 00000000..1d99cac4 --- /dev/null +++ b/hackernews/tests/conftest.py @@ -0,0 +1,5 @@ +import sys +import os + +# Allow 'from context import ...' 
to work when pytest runs from repo root +sys.path.insert(0, os.path.dirname(__file__)) diff --git a/hackernews/tests/context.py b/hackernews/tests/context.py deleted file mode 100755 index ee9f7e2e..00000000 --- a/hackernews/tests/context.py +++ /dev/null @@ -1,8 +0,0 @@ -# -*- coding: utf-8 -*- -import sys -import os - -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../dependencies"))) - -from hackernews import hackernews # noqa: F401 diff --git a/hackernews/tests/test_hackernews.py b/hackernews/tests/test_hackernews.py deleted file mode 100755 index 819a69ce..00000000 --- a/hackernews/tests/test_hackernews.py +++ /dev/null @@ -1,200 +0,0 @@ -import asyncio -from context import hackernews -from autohive_integrations_sdk import ExecutionContext - - -async def test_get_top_stories(): - """Test fetching top stories.""" - print("\nTesting get_top_stories...") - - inputs = {"limit": 5} - - async with ExecutionContext(auth={}) as context: - try: - result = await hackernews.execute_action("get_top_stories", inputs, context) - assert "stories" in result - assert "fetched_at" in result - assert "count" in result - assert len(result["stories"]) <= 5 - - if result["stories"]: - story = result["stories"][0] - assert "id" in story - assert "title" in story - assert "hn_url" in story - print(f" [OK] Got {result['count']} stories") - print(f" Top story: {story['title'][:60]}...") - else: - print(" [WARN] No stories returned") - except Exception as e: - print(f" [FAIL] Error: {e}") - raise - - -async def test_get_best_stories(): - """Test fetching best stories.""" - print("\nTesting get_best_stories...") - - inputs = {"limit": 3} - - async with ExecutionContext(auth={}) as context: - try: - result = await hackernews.execute_action("get_best_stories", inputs, context) - assert "stories" in result - assert len(result["stories"]) <= 3 - print(f" [OK] Got {result['count']} 
best stories") - except Exception as e: - print(f" [FAIL] Error: {e}") - raise - - -async def test_get_new_stories(): - """Test fetching new stories.""" - print("\nTesting get_new_stories...") - - inputs = {"limit": 3} - - async with ExecutionContext(auth={}) as context: - try: - result = await hackernews.execute_action("get_new_stories", inputs, context) - assert "stories" in result - print(f" [OK] Got {result['count']} new stories") - except Exception as e: - print(f" [FAIL] Error: {e}") - raise - - -async def test_get_ask_hn_stories(): - """Test fetching Ask HN stories.""" - print("\nTesting get_ask_hn_stories...") - - inputs = {"limit": 3} - - async with ExecutionContext(auth={}) as context: - try: - result = await hackernews.execute_action("get_ask_hn_stories", inputs, context) - assert "stories" in result - print(f" [OK] Got {result['count']} Ask HN stories") - except Exception as e: - print(f" [FAIL] Error: {e}") - raise - - -async def test_get_show_hn_stories(): - """Test fetching Show HN stories.""" - print("\nTesting get_show_hn_stories...") - - inputs = {"limit": 3} - - async with ExecutionContext(auth={}) as context: - try: - result = await hackernews.execute_action("get_show_hn_stories", inputs, context) - assert "stories" in result - print(f" [OK] Got {result['count']} Show HN stories") - except Exception as e: - print(f" [FAIL] Error: {e}") - raise - - -async def test_get_job_stories(): - """Test fetching job stories.""" - print("\nTesting get_job_stories...") - - inputs = {"limit": 3} - - async with ExecutionContext(auth={}) as context: - try: - result = await hackernews.execute_action("get_job_stories", inputs, context) - assert "jobs" in result - print(f" [OK] Got {result['count']} job postings") - except Exception as e: - print(f" [FAIL] Error: {e}") - raise - - -async def test_get_story_with_comments(): - """Test fetching a story with comments.""" - print("\nTesting get_story_with_comments...") - - async with ExecutionContext(auth={}) as 
context: - try: - top_result = await hackernews.execute_action("get_top_stories", {"limit": 1}, context) - - if not top_result["stories"]: - print(" [WARN] No stories to test with") - return - - story_id = top_result["stories"][0]["id"] - - inputs = {"story_id": story_id, "comment_limit": 5, "comment_depth": 2} - - result = await hackernews.execute_action("get_story_with_comments", inputs, context) - assert "story" in result - assert "comments" in result - assert result["story"]["id"] == story_id - - print(f" [OK] Got story with {len(result['comments'])} top-level comments") - except Exception as e: - print(f" [FAIL] Error: {e}") - raise - - -async def test_get_user_profile(): - """Test fetching a user profile.""" - print("\nTesting get_user_profile...") - - inputs = {"username": "dang"} - - async with ExecutionContext(auth={}) as context: - try: - result = await hackernews.execute_action("get_user_profile", inputs, context) - assert "id" in result - assert "karma" in result - assert result["id"] == "dang" - - print(f" [OK] Got profile for {result['id']} (karma: {result['karma']})") - except Exception as e: - print(f" [FAIL] Error: {e}") - raise - - -async def test_user_not_found(): - """Test handling of non-existent user.""" - print("\nTesting user not found handling...") - - inputs = {"username": "this_user_definitely_does_not_exist_12345"} - - async with ExecutionContext(auth={}) as context: - try: - await hackernews.execute_action("get_user_profile", inputs, context) - print(" [FAIL] Should have raised an error") - assert False, "Expected ValueError but none was raised" - except ValueError as e: - print(f" [OK] Correctly raised error: {e}") - except Exception as e: - print(f" [FAIL] Unexpected error type: {e}") - raise - - -async def main(): - print("=" * 50) - print("Testing Hacker News Integration") - print("=" * 50) - - await test_get_top_stories() - await test_get_best_stories() - await test_get_new_stories() - await test_get_ask_hn_stories() - await 
test_get_show_hn_stories() - await test_get_job_stories() - await test_get_story_with_comments() - await test_get_user_profile() - await test_user_not_found() - - print("\n" + "=" * 50) - print("[OK] All tests passed!") - print("=" * 50) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/hackernews/tests/test_hackernews_integration.py b/hackernews/tests/test_hackernews_integration.py new file mode 100644 index 00000000..966d6a11 --- /dev/null +++ b/hackernews/tests/test_hackernews_integration.py @@ -0,0 +1,191 @@ +""" +End-to-end integration tests for the Hacker News integration. + +These tests call the real Hacker News Firebase API (public, no auth needed). + +Run with: + pytest hackernews/tests/test_hackernews_integration.py -m integration + +Never runs in CI — the default pytest marker filter (-m unit) excludes these, +and the file naming (test_*_integration.py) is not matched by python_files. +""" + +import os +import sys +import importlib + +_parent = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +_deps = os.path.abspath(os.path.join(os.path.dirname(__file__), "../dependencies")) +sys.path.insert(0, _parent) +sys.path.insert(0, _deps) + +import pytest # noqa: E402 +from unittest.mock import MagicMock, AsyncMock # noqa: E402 + +_spec = importlib.util.spec_from_file_location("hackernews_mod", os.path.join(_parent, "hackernews.py")) +_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_mod) + +hackernews = _mod.hackernews + +pytestmark = pytest.mark.integration + + +@pytest.fixture +def live_context(): + """Execution context wired to a real HTTP client via aiohttp.""" + import aiohttp + + async def real_fetch(url, *, method="GET", json=None, headers=None, **kwargs): + async with aiohttp.ClientSession() as session: + async with session.request(method, url, json=json, headers=headers) as resp: + return await resp.json(content_type=None) + + ctx = MagicMock(name="ExecutionContext") + ctx.fetch = 
AsyncMock(side_effect=real_fetch) + ctx.auth = {} + return ctx + + +# ---- Story List Actions ---- + + +class TestGetTopStories: + async def test_returns_stories(self, live_context): + result = await hackernews.execute_action("get_top_stories", {"limit": 5}, live_context) + + data = result.result.data + assert "stories" in data + assert "fetched_at" in data + assert "count" in data + assert data["count"] > 0 + assert len(data["stories"]) <= 5 + + async def test_story_structure(self, live_context): + result = await hackernews.execute_action("get_top_stories", {"limit": 1}, live_context) + + story = result.result.data["stories"][0] + assert "id" in story + assert "title" in story + assert "hn_url" in story + assert "score" in story + assert "by" in story + assert "time" in story + assert story["hn_url"].startswith("https://news.ycombinator.com/item?id=") + + async def test_cost_is_zero(self, live_context): + result = await hackernews.execute_action("get_top_stories", {"limit": 1}, live_context) + + assert result.result.cost_usd == 0.0 + + +class TestGetBestStories: + async def test_returns_stories(self, live_context): + result = await hackernews.execute_action("get_best_stories", {"limit": 3}, live_context) + + data = result.result.data + assert data["count"] > 0 + assert len(data["stories"]) <= 3 + + +class TestGetNewStories: + async def test_returns_stories(self, live_context): + result = await hackernews.execute_action("get_new_stories", {"limit": 3}, live_context) + + data = result.result.data + assert data["count"] > 0 + assert len(data["stories"]) <= 3 + + +class TestGetAskHNStories: + async def test_returns_stories(self, live_context): + result = await hackernews.execute_action("get_ask_hn_stories", {"limit": 3}, live_context) + + data = result.result.data + assert "stories" in data + assert data["count"] >= 0 + + +class TestGetShowHNStories: + async def test_returns_stories(self, live_context): + result = await hackernews.execute_action("get_show_hn_stories", 
{"limit": 3}, live_context) + + data = result.result.data + assert "stories" in data + assert data["count"] >= 0 + + +class TestGetJobStories: + async def test_returns_jobs(self, live_context): + result = await hackernews.execute_action("get_job_stories", {"limit": 3}, live_context) + + data = result.result.data + assert "jobs" in data + assert "count" in data + assert data["count"] >= 0 + + +# ---- Story with Comments ---- + + +class TestGetStoryWithComments: + async def test_fetches_story_and_comments(self, live_context): + # First get a real story ID from top stories + top = await hackernews.execute_action("get_top_stories", {"limit": 1}, live_context) + story_id = top.result.data["stories"][0]["id"] + + result = await hackernews.execute_action( + "get_story_with_comments", {"story_id": story_id, "comment_limit": 3, "comment_depth": 1}, live_context + ) + + data = result.result.data + assert "story" in data + assert "comments" in data + assert "fetched_at" in data + assert data["story"]["id"] == story_id + + async def test_comment_structure(self, live_context): + top = await hackernews.execute_action("get_top_stories", {"limit": 5}, live_context) + + # Find a story that has comments + story_id = None + for story in top.result.data["stories"]: + if story.get("descendants", 0) > 0: + story_id = story["id"] + break + + if story_id is None: + pytest.skip("No stories with comments found in top 5") + + result = await hackernews.execute_action( + "get_story_with_comments", {"story_id": story_id, "comment_limit": 2, "comment_depth": 1}, live_context + ) + + comments = result.result.data["comments"] + assert len(comments) > 0 + comment = comments[0] + assert "id" in comment + assert "by" in comment + assert "text" in comment + + +# ---- User Profile ---- + + +class TestGetUserProfile: + async def test_known_user(self, live_context): + result = await hackernews.execute_action("get_user_profile", {"username": "dang"}, live_context) + + data = result.result.data + assert 
data["id"] == "dang" + assert data["karma"] > 0 + assert "profile_url" in data + assert "created" in data + + async def test_nonexistent_user(self, live_context): + result = await hackernews.execute_action( + "get_user_profile", {"username": "this_user_definitely_does_not_exist_99999"}, live_context + ) + + data = result.result.data + assert "error" in data diff --git a/hackernews/tests/test_hackernews_unit.py b/hackernews/tests/test_hackernews_unit.py new file mode 100644 index 00000000..e7f21454 --- /dev/null +++ b/hackernews/tests/test_hackernews_unit.py @@ -0,0 +1,370 @@ +import sys +import os +import importlib + +_parent = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +_deps = os.path.abspath(os.path.join(os.path.dirname(__file__), "../dependencies")) +sys.path.insert(0, _parent) +sys.path.insert(0, _deps) + +import pytest # noqa: E402 + +_spec = importlib.util.spec_from_file_location("hackernews_mod", os.path.join(_parent, "hackernews.py")) +_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_mod) + +hackernews = _mod.hackernews +format_item = _mod.format_item +format_comment = _mod.format_comment + +pytestmark = pytest.mark.unit + +SAMPLE_STORY = { + "id": 12345, + "title": "Show HN: A new project", + "type": "story", + "by": "testuser", + "score": 42, + "descendants": 10, + "time": 1700000000, + "url": "https://example.com", + "kids": [100, 101], +} + +SAMPLE_COMMENT = { + "id": 100, + "type": "comment", + "by": "commenter1", + "text": "Great post!", + "time": 1700000100, + "kids": [200], +} + +SAMPLE_REPLY = { + "id": 200, + "type": "comment", + "by": "commenter2", + "text": "I agree!", + "time": 1700000200, +} + +SAMPLE_USER = { + "id": "dang", + "karma": 50000, + "created": 1200000000, + "about": "HN moderator", +} + + +class TestFormatItem: + def test_full_item(self): + result = format_item(SAMPLE_STORY) + assert result["id"] == 12345 + assert result["title"] == "Show HN: A new project" + assert result["type"] == 
"story" + assert result["by"] == "testuser" + assert result["score"] == 42 + assert result["descendants"] == 10 + assert result["url"] == "https://example.com" + assert result["hn_url"] == "https://news.ycombinator.com/item?id=12345" + assert "time" in result + + def test_item_without_url(self): + item = {**SAMPLE_STORY} + del item["url"] + result = format_item(item) + assert "url" not in result + + def test_item_without_text(self): + result = format_item(SAMPLE_STORY) + assert "text" not in result + + def test_item_with_text(self): + item = {**SAMPLE_STORY, "text": "Some text content"} + result = format_item(item) + assert result["text"] == "Some text content" + + def test_item_without_time(self): + item = {k: v for k, v in SAMPLE_STORY.items() if k != "time"} + result = format_item(item) + assert "time" not in result + + def test_defaults_for_missing_score_and_descendants(self): + item = {"id": 1, "title": "Test"} + result = format_item(item) + assert result["score"] == 0 + assert result["descendants"] == 0 + + +class TestFormatComment: + def test_normal_comment(self): + result = format_comment(SAMPLE_COMMENT) + assert result["id"] == 100 + assert result["by"] == "commenter1" + assert result["text"] == "Great post!" 
+ assert "time" in result + assert "replies" not in result + + def test_comment_with_replies(self): + replies = [{"id": 200, "by": "someone", "text": "reply"}] + result = format_comment(SAMPLE_COMMENT, replies=replies) + assert result["replies"] == replies + + def test_deleted_comment_returns_none(self): + item = {**SAMPLE_COMMENT, "deleted": True} + assert format_comment(item) is None + + def test_dead_comment_returns_none(self): + item = {**SAMPLE_COMMENT, "dead": True} + assert format_comment(item) is None + + def test_comment_without_author(self): + item = {"id": 300, "text": "anonymous"} + result = format_comment(item) + assert result["by"] == "[deleted]" + + def test_comment_without_time(self): + item = {"id": 300, "by": "user", "text": "hi"} + result = format_comment(item) + assert "time" not in result + + def test_empty_replies_not_included(self): + result = format_comment(SAMPLE_COMMENT, replies=None) + assert "replies" not in result + + +def _story_ids_and_items(ids, items): + """Build a side_effect for fetch that returns story IDs first, then items.""" + responses = iter([ids] + items) + return lambda *args, **kwargs: next(responses) + + +class TestGetTopStories: + @pytest.mark.asyncio + async def test_returns_stories(self, mock_context): + mock_context.fetch.side_effect = [ + [12345, 12346], + SAMPLE_STORY, + {**SAMPLE_STORY, "id": 12346, "title": "Second story"}, + ] + + result = await hackernews.execute_action("get_top_stories", {"limit": 2}, mock_context) + data = result.result.data + + assert "stories" in data + assert data["count"] == 2 + assert "fetched_at" in data + assert data["stories"][0]["id"] == 12345 + assert data["stories"][1]["id"] == 12346 + + @pytest.mark.asyncio + async def test_empty_response(self, mock_context): + mock_context.fetch.return_value = None + + result = await hackernews.execute_action("get_top_stories", {"limit": 5}, mock_context) + data = result.result.data + + assert data["stories"] == [] + assert data["count"] == 0 + + 
@pytest.mark.asyncio + async def test_default_limit(self, mock_context): + mock_context.fetch.side_effect = [ + [12345], + SAMPLE_STORY, + ] + + result = await hackernews.execute_action("get_top_stories", {}, mock_context) + data = result.result.data + + assert data["count"] == 1 + + +class TestGetBestStories: + @pytest.mark.asyncio + async def test_returns_stories(self, mock_context): + mock_context.fetch.side_effect = [ + [12345], + SAMPLE_STORY, + ] + + result = await hackernews.execute_action("get_best_stories", {"limit": 3}, mock_context) + data = result.result.data + + assert "stories" in data + assert data["count"] == 1 + + +class TestGetNewStories: + @pytest.mark.asyncio + async def test_returns_stories(self, mock_context): + mock_context.fetch.side_effect = [ + [12345], + SAMPLE_STORY, + ] + + result = await hackernews.execute_action("get_new_stories", {"limit": 3}, mock_context) + data = result.result.data + + assert "stories" in data + assert data["count"] == 1 + + +class TestGetAskHNStories: + @pytest.mark.asyncio + async def test_returns_stories(self, mock_context): + mock_context.fetch.side_effect = [ + [12345], + SAMPLE_STORY, + ] + + result = await hackernews.execute_action("get_ask_hn_stories", {"limit": 3}, mock_context) + data = result.result.data + + assert "stories" in data + assert data["count"] == 1 + + +class TestGetShowHNStories: + @pytest.mark.asyncio + async def test_returns_stories(self, mock_context): + mock_context.fetch.side_effect = [ + [12345], + SAMPLE_STORY, + ] + + result = await hackernews.execute_action("get_show_hn_stories", {"limit": 3}, mock_context) + data = result.result.data + + assert "stories" in data + assert data["count"] == 1 + + +class TestGetJobStories: + @pytest.mark.asyncio + async def test_returns_jobs(self, mock_context): + job_item = {**SAMPLE_STORY, "type": "job", "title": "Hiring Engineers"} + mock_context.fetch.side_effect = [ + [12345], + job_item, + ] + + result = await 
hackernews.execute_action("get_job_stories", {"limit": 3}, mock_context) + data = result.result.data + + assert "jobs" in data + assert data["count"] == 1 + assert data["jobs"][0]["type"] == "job" + + +class TestGetStoryWithComments: + @pytest.mark.asyncio + async def test_story_with_comments(self, mock_context): + second_comment = { + "id": 101, + "type": "comment", + "by": "commenter3", + "text": "Nice work!", + "time": 1700000300, + } + mock_context.fetch.side_effect = [ + SAMPLE_STORY, + SAMPLE_COMMENT, + second_comment, + SAMPLE_REPLY, + ] + + inputs = {"story_id": 12345, "comment_limit": 5, "comment_depth": 2} + result = await hackernews.execute_action("get_story_with_comments", inputs, mock_context) + data = result.result.data + + assert data["story"]["id"] == 12345 + assert "comments" in data + assert "fetched_at" in data + assert len(data["comments"]) == 2 + assert data["comments"][0]["by"] == "commenter1" + assert data["comments"][0]["replies"][0]["by"] == "commenter2" + assert data["comments"][1]["by"] == "commenter3" + + @pytest.mark.asyncio + async def test_story_not_found(self, mock_context): + mock_context.fetch.return_value = None + + inputs = {"story_id": 99999} + result = await hackernews.execute_action("get_story_with_comments", inputs, mock_context) + data = result.result.data + + assert "error" in data + assert "99999" in data["error"] + + @pytest.mark.asyncio + async def test_story_without_comments(self, mock_context): + story_no_kids = {k: v for k, v in SAMPLE_STORY.items() if k != "kids"} + mock_context.fetch.side_effect = [story_no_kids] + + inputs = {"story_id": 12345} + result = await hackernews.execute_action("get_story_with_comments", inputs, mock_context) + data = result.result.data + + assert data["story"]["id"] == 12345 + assert data["comments"] == [] + + @pytest.mark.asyncio + async def test_deleted_comments_filtered(self, mock_context): + deleted_comment = {**SAMPLE_COMMENT, "deleted": True} + mock_context.fetch.side_effect = [ + 
SAMPLE_STORY, + deleted_comment, + SAMPLE_COMMENT, + ] + + inputs = {"story_id": 12345, "comment_limit": 5, "comment_depth": 1} + result = await hackernews.execute_action("get_story_with_comments", inputs, mock_context) + data = result.result.data + + assert len(data["comments"]) == 1 + assert data["comments"][0]["id"] == 100 + + +class TestGetUserProfile: + @pytest.mark.asyncio + async def test_returns_profile(self, mock_context): + mock_context.fetch.return_value = SAMPLE_USER + + result = await hackernews.execute_action("get_user_profile", {"username": "dang"}, mock_context) + data = result.result.data + + assert data["id"] == "dang" + assert data["karma"] == 50000 + assert data["about"] == "HN moderator" + assert "created" in data + assert "profile_url" in data + + @pytest.mark.asyncio + async def test_user_not_found(self, mock_context): + mock_context.fetch.return_value = None + + result = await hackernews.execute_action("get_user_profile", {"username": "nonexistent_user_xyz"}, mock_context) + data = result.result.data + + assert "error" in data + assert "nonexistent_user_xyz" in data["error"] + + @pytest.mark.asyncio + async def test_user_without_about(self, mock_context): + user = {k: v for k, v in SAMPLE_USER.items() if k != "about"} + mock_context.fetch.return_value = user + + result = await hackernews.execute_action("get_user_profile", {"username": "dang"}, mock_context) + data = result.result.data + + assert "about" not in data + + @pytest.mark.asyncio + async def test_fetch_error_handled(self, mock_context): + mock_context.fetch.side_effect = Exception("Network error") + + result = await hackernews.execute_action("get_user_profile", {"username": "dang"}, mock_context) + data = result.result.data + + assert "error" in data diff --git a/notion/config.json b/notion/config.json index beff33cb..5cbabecb 100644 --- a/notion/config.json +++ b/notion/config.json @@ -1,6 +1,6 @@ { "name": "Notion", - "version": "1.0.0", + "version": "1.0.1", "description": 
"Enhanced integration with Notion API featuring database querying, block management, page property operations, and advanced search capabilities", "entry_point": "notion.py", "auth": { diff --git a/notion/notion.py b/notion/notion.py index 034c7c28..b7489cac 100644 --- a/notion/notion.py +++ b/notion/notion.py @@ -1,4 +1,9 @@ -from autohive_integrations_sdk import Integration, ExecutionContext, ActionHandler, ActionResult +from autohive_integrations_sdk import ( + Integration, + ExecutionContext, + ActionHandler, + ActionResult, +) from typing import Dict, Any # API Version constant for Notion API @@ -44,12 +49,18 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac search_body["start_cursor"] = inputs["start_cursor"] # Prepare headers for Notion API - headers = {"Notion-Version": NOTION_API_VERSION, "Content-Type": "application/json"} + headers = { + "Notion-Version": NOTION_API_VERSION, + "Content-Type": "application/json", + } # Make the search request to Notion API try: response = await context.fetch( - url="https://api.notion.com/v1/search", method="POST", headers=headers, json=search_body + url="https://api.notion.com/v1/search", + method="POST", + headers=headers, + json=search_body, ) return ActionResult( @@ -63,7 +74,14 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac ) except Exception as e: - return ActionResult(data={"object": "list", "error": str(e), "results": [], "has_more": False}) + return ActionResult( + data={ + "object": "list", + "error": str(e), + "results": [], + "has_more": False, + } + ) @notion.action("get_notion_page") @@ -89,7 +107,9 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac # Make the get page request to Notion API try: response = await context.fetch( - url=f"https://api.notion.com/v1/pages/{page_id}", method="GET", headers=headers + url=f"https://api.notion.com/v1/pages/{page_id}", + method="GET", + headers=headers, ) return 
ActionResult(data={"page": response}) @@ -117,12 +137,18 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac create_body = {"parent": inputs["parent"], "properties": inputs["properties"]} # Prepare headers for Notion API - headers = {"Notion-Version": NOTION_API_VERSION, "Content-Type": "application/json"} + headers = { + "Notion-Version": NOTION_API_VERSION, + "Content-Type": "application/json", + } # Make the create page request to Notion API try: response = await context.fetch( - url="https://api.notion.com/v1/pages", method="POST", headers=headers, json=create_body + url="https://api.notion.com/v1/pages", + method="POST", + headers=headers, + json=create_body, ) return ActionResult(data={"page": response}) @@ -150,12 +176,18 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac comment_body = {"parent": inputs["parent"], "rich_text": inputs["rich_text"]} # Prepare headers for Notion API - headers = {"Notion-Version": NOTION_API_VERSION, "Content-Type": "application/json"} + headers = { + "Notion-Version": NOTION_API_VERSION, + "Content-Type": "application/json", + } # Make the create comment request to Notion API try: response = await context.fetch( - url="https://api.notion.com/v1/comments", method="POST", headers=headers, json=comment_body + url="https://api.notion.com/v1/comments", + method="POST", + headers=headers, + json=comment_body, ) return ActionResult(data={"comment": response}) @@ -184,7 +216,10 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac try: response = await context.fetch( - url="https://api.notion.com/v1/comments", method="GET", headers=headers, params=params + url="https://api.notion.com/v1/comments", + method="GET", + headers=headers, + params=params, ) result = { @@ -197,7 +232,9 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac blocks_with_comments = [] blocks_response = await context.fetch( - 
url=f"https://api.notion.com/v1/blocks/{block_id}/children", method="GET", headers=headers + url=f"https://api.notion.com/v1/blocks/{block_id}/children", + method="GET", + headers=headers, ) for block in blocks_response.get("results", []): @@ -211,7 +248,11 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac ) if comments_response.get("results"): blocks_with_comments.append( - {"block_id": child_block_id, "block_type": block.get("type"), "has_comments": True} + { + "block_id": child_block_id, + "block_type": block.get("type"), + "has_comments": True, + } ) result["child_blocks_with_comments"] = blocks_with_comments @@ -287,7 +328,9 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac try: response = await context.fetch( - url=f"https://api.notion.com/v1/data_sources/{data_source_id}", method="GET", headers=headers + url=f"https://api.notion.com/v1/data_sources/{data_source_id}", + method="GET", + headers=headers, ) return ActionResult(data={"data_source": response}) except Exception as e: @@ -333,7 +376,10 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac query_body["start_cursor"] = inputs["start_cursor"] # Prepare headers for Notion API - headers = {"Notion-Version": NOTION_API_VERSION, "Content-Type": "application/json"} + headers = { + "Notion-Version": NOTION_API_VERSION, + "Content-Type": "application/json", + } # Make the query request to Notion API try: @@ -430,7 +476,10 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac append_body["after"] = inputs["after"] # Prepare headers for Notion API - headers = {"Notion-Version": NOTION_API_VERSION, "Content-Type": "application/json"} + headers = { + "Notion-Version": NOTION_API_VERSION, + "Content-Type": "application/json", + } # Make the append block children request to Notion API try: @@ -553,12 +602,18 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> 
Ac } # Prepare headers for Notion API - headers = {"Notion-Version": NOTION_API_VERSION, "Content-Type": "application/json"} + headers = { + "Notion-Version": NOTION_API_VERSION, + "Content-Type": "application/json", + } # Make the update block request to Notion API try: response = await context.fetch( - url=f"https://api.notion.com/v1/blocks/{block_id}", method="PATCH", headers=headers, json=update_body + url=f"https://api.notion.com/v1/blocks/{block_id}", + method="PATCH", + headers=headers, + json=update_body, ) return ActionResult(data={"block": response}) @@ -590,7 +645,9 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac # Make the delete block request to Notion API try: response = await context.fetch( - url=f"https://api.notion.com/v1/blocks/{block_id}", method="DELETE", headers=headers + url=f"https://api.notion.com/v1/blocks/{block_id}", + method="DELETE", + headers=headers, ) return ActionResult(data={"block": response}) @@ -627,12 +684,18 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac } # Prepare headers for Notion API - headers = {"Notion-Version": NOTION_API_VERSION, "Content-Type": "application/json"} + headers = { + "Notion-Version": NOTION_API_VERSION, + "Content-Type": "application/json", + } # Make the update page request to Notion API try: response = await context.fetch( - url=f"https://api.notion.com/v1/pages/{page_id}", method="PATCH", headers=headers, json=update_body + url=f"https://api.notion.com/v1/pages/{page_id}", + method="PATCH", + headers=headers, + json=update_body, ) return ActionResult(data={"page": response}) diff --git a/notion/tests/conftest.py b/notion/tests/conftest.py new file mode 100644 index 00000000..1d99cac4 --- /dev/null +++ b/notion/tests/conftest.py @@ -0,0 +1,5 @@ +import sys +import os + +# Allow 'from context import ...' 
to work when pytest runs from repo root +sys.path.insert(0, os.path.dirname(__file__)) diff --git a/notion/tests/test_notion_integration.py b/notion/tests/test_notion_integration.py index 8a220a4c..acc383b5 100644 --- a/notion/tests/test_notion_integration.py +++ b/notion/tests/test_notion_integration.py @@ -222,7 +222,12 @@ async def test_get_comments_handler_empty_optional_params(): """Test NotionGetCommentsHandler ignores empty optional params""" handler = NotionGetCommentsHandler() - mock_response = {"object": "list", "results": [], "next_cursor": None, "has_more": False} + mock_response = { + "object": "list", + "results": [], + "next_cursor": None, + "has_more": False, + } mock_context = MagicMock() mock_context.fetch = AsyncMock(return_value=mock_response) @@ -243,7 +248,12 @@ async def test_get_comments_handler_empty_optional_params(): async def test_new_actions(): """Test that the new update/delete actions are properly configured""" - new_actions = ["update_notion_block", "delete_notion_block", "update_notion_page", "get_notion_comments"] + new_actions = [ + "update_notion_block", + "delete_notion_block", + "update_notion_page", + "get_notion_comments", + ] with open("config.json", "r") as f: config = json.load(f) diff --git a/notion/tests/test_notion_unit.py b/notion/tests/test_notion_unit.py new file mode 100644 index 00000000..ec9fe46a --- /dev/null +++ b/notion/tests/test_notion_unit.py @@ -0,0 +1,700 @@ +""" +Unit tests for Notion integration. + +Migrated from test_notion_integration.py (asyncio.run style) to proper pytest. +Covers all handlers with mocked context.fetch calls. 
+""" + +import json +import os +import sys + +import pytest +from unittest.mock import AsyncMock, MagicMock + +# Add parent and tests directories to path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) + +from notion.notion import ( + NOTION_API_VERSION, + NotionGetCommentsHandler, + NotionSearchHandler, + NotionGetPageHandler, + NotionCreatePageHandler, + NotionCreateCommentHandler, + NotionGetBlockChildrenHandler, + NotionUpdateBlockHandler, + NotionDeleteBlockHandler, + NotionUpdatePageHandler, + notion as notion_integration, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_context(): + ctx = MagicMock(name="ExecutionContext") + ctx.fetch = AsyncMock(name="fetch") + ctx.auth = { + "auth_type": "PlatformOauth2", + "credentials": {"access_token": "test_token"}, # nosec B105 + } + return ctx + + +NOTION_HEADERS = {"Notion-Version": NOTION_API_VERSION} +NOTION_HEADERS_JSON = { + "Notion-Version": NOTION_API_VERSION, + "Content-Type": "application/json", +} + +CONFIG_PATH = os.path.join(os.path.dirname(__file__), "..", "config.json") + + +# --------------------------------------------------------------------------- +# Config validation (migrated from old tests) +# --------------------------------------------------------------------------- + + +class TestConfigValidation: + """Verify config.json actions match registered handlers.""" + + def test_actions_match_handlers(self): + with open(CONFIG_PATH, "r") as f: + config = json.load(f) + + defined_actions = set(config.get("actions", {}).keys()) + registered_actions = set(notion_integration._action_handlers.keys()) + + missing_handlers = defined_actions - registered_actions + extra_handlers = registered_actions - defined_actions + + 
assert not missing_handlers, f"Missing handlers for actions: {missing_handlers}" + assert not extra_handlers, f"Extra handlers without config: {extra_handlers}" + + def test_get_comments_action_config(self): + with open(CONFIG_PATH, "r") as f: + config = json.load(f) + + action_config = config["actions"]["get_notion_comments"] + + assert action_config["display_name"] == "Get Comments" + assert "Retrieve comments" in action_config["description"] + + input_schema = action_config["input_schema"] + assert "block_id" in input_schema["properties"] + assert "page_size" in input_schema["properties"] + assert "start_cursor" in input_schema["properties"] + assert input_schema["required"] == ["block_id"] + + output_schema = action_config["output_schema"] + assert "comments" in output_schema["properties"] + assert "next_cursor" in output_schema["properties"] + assert "has_more" in output_schema["properties"] + + def test_get_comments_pagination_schema(self): + with open(CONFIG_PATH, "r") as f: + config = json.load(f) + + props = config["actions"]["get_notion_comments"]["input_schema"]["properties"] + + page_size = props["page_size"] + assert page_size["type"] == "integer" + assert page_size["minimum"] == 1 + assert page_size["maximum"] == 100 + + start_cursor = props["start_cursor"] + assert start_cursor["type"] == "string" + + def test_create_and_get_comment_actions_complement(self): + with open(CONFIG_PATH, "r") as f: + config = json.load(f) + + actions = config["actions"] + assert "create_notion_comment" in actions + assert "get_notion_comments" in actions + + create_output = actions["create_notion_comment"]["output_schema"]["properties"] + assert "id" in create_output + + get_output = actions["get_notion_comments"]["output_schema"]["properties"] + assert "comments" in get_output + + def test_new_actions_defined(self): + with open(CONFIG_PATH, "r") as f: + config = json.load(f) + + actions = config["actions"] + for action_name in [ + "update_notion_block", + 
"delete_notion_block", + "update_notion_page", + "get_notion_comments", + ]: + assert action_name in actions, f"{action_name} not in config.json" + action_config = actions[action_name] + assert "display_name" in action_config + assert "description" in action_config + assert "input_schema" in action_config + assert "output_schema" in action_config + + +# --------------------------------------------------------------------------- +# GetComments handler (migrated from old tests) +# --------------------------------------------------------------------------- + + +class TestGetComments: + """Tests for NotionGetCommentsHandler — migrated from old test suite.""" + + @pytest.mark.asyncio + async def test_basic(self, mock_context): + handler = NotionGetCommentsHandler() + + mock_context.fetch.return_value = { + "object": "list", + "results": [ + { + "id": "comment-123", + "discussion_id": "disc-456", + "created_time": "2024-01-15T10:00:00.000Z", + "rich_text": [{"type": "text", "text": {"content": "Test comment"}}], + "parent": {"type": "page_id", "page_id": "page-789"}, + } + ], + "next_cursor": None, + "has_more": False, + } + + result = await handler.execute({"block_id": "page-789"}, mock_context) + + mock_context.fetch.assert_called_once_with( + url="https://api.notion.com/v1/comments", + method="GET", + headers=NOTION_HEADERS, + params={"block_id": "page-789"}, + ) + + assert len(result.data["comments"]) == 1 + assert result.data["comments"][0]["id"] == "comment-123" + assert result.data["has_more"] is False + + @pytest.mark.asyncio + async def test_with_pagination(self, mock_context): + handler = NotionGetCommentsHandler() + + mock_context.fetch.return_value = { + "object": "list", + "results": [{"id": "comment-1"}, {"id": "comment-2"}], + "next_cursor": "cursor-abc", + "has_more": True, + } + + result = await handler.execute( + {"block_id": "page-123", "page_size": 2, "start_cursor": "prev-cursor"}, + mock_context, + ) + + mock_context.fetch.assert_called_once_with( + 
url="https://api.notion.com/v1/comments", + method="GET", + headers=NOTION_HEADERS, + params={ + "block_id": "page-123", + "page_size": 2, + "start_cursor": "prev-cursor", + }, + ) + + assert result.data["has_more"] is True + assert result.data["next_cursor"] == "cursor-abc" + + @pytest.mark.asyncio + async def test_error_handling(self, mock_context): + handler = NotionGetCommentsHandler() + mock_context.fetch.side_effect = Exception("API rate limit exceeded") + + result = await handler.execute({"block_id": "page-789"}, mock_context) + + assert "error" in result.data + assert "API rate limit exceeded" in result.data["error"] + assert result.data["comments"] == [] + + @pytest.mark.asyncio + async def test_empty_optional_params(self, mock_context): + handler = NotionGetCommentsHandler() + + mock_context.fetch.return_value = { + "object": "list", + "results": [], + "next_cursor": None, + "has_more": False, + } + + await handler.execute( + {"block_id": "page-123", "page_size": None, "start_cursor": ""}, + mock_context, + ) + + mock_context.fetch.assert_called_once_with( + url="https://api.notion.com/v1/comments", + method="GET", + headers=NOTION_HEADERS, + params={"block_id": "page-123"}, + ) + + +# --------------------------------------------------------------------------- +# Search handler +# --------------------------------------------------------------------------- + + +class TestSearch: + @pytest.mark.asyncio + async def test_basic_search(self, mock_context): + handler = NotionSearchHandler() + + mock_context.fetch.return_value = { + "object": "list", + "results": [{"id": "page-1", "object": "page"}], + "has_more": False, + "next_cursor": None, + "type": "page_or_database", + } + + result = await handler.execute({"query": "meeting notes"}, mock_context) + + mock_context.fetch.assert_called_once_with( + url="https://api.notion.com/v1/search", + method="POST", + headers=NOTION_HEADERS_JSON, + json={"query": "meeting notes"}, + ) + + assert result.data["results"] == 
[{"id": "page-1", "object": "page"}] + assert result.data["has_more"] is False + + @pytest.mark.asyncio + async def test_search_with_filter_and_sort(self, mock_context): + handler = NotionSearchHandler() + + mock_context.fetch.return_value = { + "object": "list", + "results": [], + "has_more": False, + "next_cursor": None, + } + + inputs = { + "query": "test", + "filter": {"value": "page", "property": "object"}, + "sort": {"direction": "descending", "timestamp": "last_edited_time"}, + "page_size": 10, + "start_cursor": "abc", + } + + await handler.execute(inputs, mock_context) + + call_json = mock_context.fetch.call_args.kwargs["json"] + assert call_json["query"] == "test" + assert call_json["filter"] == {"value": "page", "property": "object"} + assert call_json["sort"] == { + "direction": "descending", + "timestamp": "last_edited_time", + } + assert call_json["page_size"] == 10 + assert call_json["start_cursor"] == "abc" + + @pytest.mark.asyncio + async def test_search_error(self, mock_context): + handler = NotionSearchHandler() + mock_context.fetch.side_effect = Exception("Unauthorized") + + result = await handler.execute({"query": "test"}, mock_context) + + assert "error" in result.data + assert "Unauthorized" in result.data["error"] + assert result.data["results"] == [] + + +# --------------------------------------------------------------------------- +# GetPage handler +# --------------------------------------------------------------------------- + + +class TestGetPage: + @pytest.mark.asyncio + async def test_get_page(self, mock_context): + handler = NotionGetPageHandler() + + page_data = { + "id": "page-abc", + "object": "page", + "properties": {"Name": {"title": []}}, + } + mock_context.fetch.return_value = page_data + + result = await handler.execute({"page_id": "page-abc"}, mock_context) + + mock_context.fetch.assert_called_once_with( + url="https://api.notion.com/v1/pages/page-abc", + method="GET", + headers=NOTION_HEADERS, + ) + + assert 
result.data["page"] == page_data + + @pytest.mark.asyncio + async def test_get_page_error(self, mock_context): + handler = NotionGetPageHandler() + mock_context.fetch.side_effect = Exception("Not found") + + result = await handler.execute({"page_id": "bad-id"}, mock_context) + + assert "error" in result.data + assert "Not found" in result.data["error"] + assert result.data["page"] is None + + +# --------------------------------------------------------------------------- +# CreatePage handler +# --------------------------------------------------------------------------- + + +class TestCreatePage: + @pytest.mark.asyncio + async def test_create_page(self, mock_context): + handler = NotionCreatePageHandler() + + parent = {"database_id": "db-123"} + properties = {"Name": {"title": [{"text": {"content": "New Page"}}]}} + created_page = {"id": "new-page-1", "object": "page"} + mock_context.fetch.return_value = created_page + + result = await handler.execute({"parent": parent, "properties": properties}, mock_context) + + mock_context.fetch.assert_called_once_with( + url="https://api.notion.com/v1/pages", + method="POST", + headers=NOTION_HEADERS_JSON, + json={"parent": parent, "properties": properties}, + ) + + assert result.data["page"] == created_page + + @pytest.mark.asyncio + async def test_create_page_error(self, mock_context): + handler = NotionCreatePageHandler() + mock_context.fetch.side_effect = Exception("Validation error") + + result = await handler.execute( + {"parent": {"database_id": "db-1"}, "properties": {}}, + mock_context, + ) + + assert "error" in result.data + assert result.data["page"] is None + + +# --------------------------------------------------------------------------- +# CreateComment handler +# --------------------------------------------------------------------------- + + +class TestCreateComment: + @pytest.mark.asyncio + async def test_create_comment(self, mock_context): + handler = NotionCreateCommentHandler() + + parent = {"page_id": 
"page-123"} + rich_text = [{"type": "text", "text": {"content": "A comment"}}] + created_comment = {"id": "comment-new", "object": "comment"} + mock_context.fetch.return_value = created_comment + + result = await handler.execute({"parent": parent, "rich_text": rich_text}, mock_context) + + mock_context.fetch.assert_called_once_with( + url="https://api.notion.com/v1/comments", + method="POST", + headers=NOTION_HEADERS_JSON, + json={"parent": parent, "rich_text": rich_text}, + ) + + assert result.data["comment"] == created_comment + + @pytest.mark.asyncio + async def test_create_comment_error(self, mock_context): + handler = NotionCreateCommentHandler() + mock_context.fetch.side_effect = Exception("Forbidden") + + result = await handler.execute( + {"parent": {"page_id": "p-1"}, "rich_text": []}, + mock_context, + ) + + assert "error" in result.data + assert result.data["comment"] is None + + +# --------------------------------------------------------------------------- +# GetBlockChildren handler +# --------------------------------------------------------------------------- + + +class TestGetBlockChildren: + @pytest.mark.asyncio + async def test_get_block_children(self, mock_context): + handler = NotionGetBlockChildrenHandler() + + mock_context.fetch.return_value = { + "results": [ + {"id": "block-1", "type": "paragraph"}, + {"id": "block-2", "type": "heading_1"}, + ], + "has_more": False, + "next_cursor": None, + "type": "block", + } + + result = await handler.execute({"block_id": "parent-block"}, mock_context) + + mock_context.fetch.assert_called_once_with( + url="https://api.notion.com/v1/blocks/parent-block/children", + method="GET", + headers=NOTION_HEADERS, + params={}, + ) + + assert len(result.data["blocks"]) == 2 + assert result.data["has_more"] is False + + @pytest.mark.asyncio + async def test_get_block_children_with_pagination(self, mock_context): + handler = NotionGetBlockChildrenHandler() + + mock_context.fetch.return_value = { + "results": [{"id": 
"block-3"}], + "has_more": True, + "next_cursor": "next-abc", + } + + result = await handler.execute( + {"block_id": "parent-block", "page_size": 1, "start_cursor": "cur-1"}, + mock_context, + ) + + call_params = mock_context.fetch.call_args.kwargs["params"] + assert call_params["page_size"] == 1 + assert call_params["start_cursor"] == "cur-1" + assert result.data["has_more"] is True + assert result.data["next_cursor"] == "next-abc" + + @pytest.mark.asyncio + async def test_get_block_children_error(self, mock_context): + handler = NotionGetBlockChildrenHandler() + mock_context.fetch.side_effect = Exception("Server error") + + result = await handler.execute({"block_id": "block-x"}, mock_context) + + assert "error" in result.data + assert result.data["blocks"] == [] + + +# --------------------------------------------------------------------------- +# UpdateBlock handler +# --------------------------------------------------------------------------- + + +class TestUpdateBlock: + @pytest.mark.asyncio + async def test_update_block(self, mock_context): + handler = NotionUpdateBlockHandler() + + updated_block = { + "id": "block-1", + "type": "paragraph", + "paragraph": {"rich_text": []}, + } + mock_context.fetch.return_value = updated_block + + inputs = { + "block_id": "block-1", + "paragraph": {"rich_text": [{"type": "text", "text": {"content": "Updated text"}}]}, + } + + result = await handler.execute(inputs, mock_context) + + mock_context.fetch.assert_called_once_with( + url="https://api.notion.com/v1/blocks/block-1", + method="PATCH", + headers=NOTION_HEADERS_JSON, + json={"paragraph": inputs["paragraph"]}, + ) + + assert result.data["block"] == updated_block + + @pytest.mark.asyncio + async def test_update_block_filters_invalid_keys(self, mock_context): + handler = NotionUpdateBlockHandler() + mock_context.fetch.return_value = {"id": "block-1"} + + await handler.execute( + { + "block_id": "block-1", + "paragraph": {"rich_text": []}, + "invalid_key": "ignored", + }, + 
mock_context, + ) + + call_json = mock_context.fetch.call_args.kwargs["json"] + assert "paragraph" in call_json + assert "invalid_key" not in call_json + assert "block_id" not in call_json + + @pytest.mark.asyncio + async def test_update_block_error(self, mock_context): + handler = NotionUpdateBlockHandler() + mock_context.fetch.side_effect = Exception("Conflict") + + result = await handler.execute({"block_id": "block-1", "paragraph": {}}, mock_context) + + assert "error" in result.data + assert result.data["block"] is None + + +# --------------------------------------------------------------------------- +# DeleteBlock handler +# --------------------------------------------------------------------------- + + +class TestDeleteBlock: + @pytest.mark.asyncio + async def test_delete_block(self, mock_context): + handler = NotionDeleteBlockHandler() + + deleted_block = {"id": "block-del", "archived": True} + mock_context.fetch.return_value = deleted_block + + result = await handler.execute({"block_id": "block-del"}, mock_context) + + mock_context.fetch.assert_called_once_with( + url="https://api.notion.com/v1/blocks/block-del", + method="DELETE", + headers=NOTION_HEADERS, + ) + + assert result.data["block"] == deleted_block + + @pytest.mark.asyncio + async def test_delete_block_error(self, mock_context): + handler = NotionDeleteBlockHandler() + mock_context.fetch.side_effect = Exception("Not found") + + result = await handler.execute({"block_id": "gone"}, mock_context) + + assert "error" in result.data + assert result.data["block"] is None + + +# --------------------------------------------------------------------------- +# UpdatePage handler +# --------------------------------------------------------------------------- + + +class TestUpdatePage: + @pytest.mark.asyncio + async def test_update_page_properties(self, mock_context): + handler = NotionUpdatePageHandler() + + updated_page = {"id": "page-1", "object": "page"} + mock_context.fetch.return_value = updated_page + + 
properties = {"Status": {"select": {"name": "Done"}}} + result = await handler.execute({"page_id": "page-1", "properties": properties}, mock_context) + + mock_context.fetch.assert_called_once_with( + url="https://api.notion.com/v1/pages/page-1", + method="PATCH", + headers=NOTION_HEADERS_JSON, + json={"properties": properties}, + ) + + assert result.data["page"] == updated_page + + @pytest.mark.asyncio + async def test_update_page_archive(self, mock_context): + handler = NotionUpdatePageHandler() + mock_context.fetch.return_value = {"id": "page-1", "archived": True} + + result = await handler.execute({"page_id": "page-1", "archived": True}, mock_context) + + call_json = mock_context.fetch.call_args.kwargs["json"] + assert call_json["archived"] is True + assert result.data["page"]["archived"] is True + + @pytest.mark.asyncio + async def test_update_page_filters_invalid_keys(self, mock_context): + handler = NotionUpdatePageHandler() + mock_context.fetch.return_value = {"id": "page-1"} + + await handler.execute( + {"page_id": "page-1", "properties": {}, "bad_field": "ignored"}, + mock_context, + ) + + call_json = mock_context.fetch.call_args.kwargs["json"] + assert "properties" in call_json + assert "bad_field" not in call_json + assert "page_id" not in call_json + + @pytest.mark.asyncio + async def test_update_page_error(self, mock_context): + handler = NotionUpdatePageHandler() + mock_context.fetch.side_effect = Exception("Bad request") + + result = await handler.execute({"page_id": "page-1", "properties": {}}, mock_context) + + assert "error" in result.data + assert result.data["page"] is None + + +# --------------------------------------------------------------------------- +# Cross-handler error consistency +# --------------------------------------------------------------------------- + + +class TestErrorHandling: + """Verify all handlers return structured error data when fetch raises.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "handler_cls, 
inputs, error_key", + [ + (NotionSearchHandler, {"query": "x"}, "results"), + (NotionGetPageHandler, {"page_id": "x"}, "page"), + (NotionCreatePageHandler, {"parent": {}, "properties": {}}, "page"), + (NotionCreateCommentHandler, {"parent": {}, "rich_text": []}, "comment"), + (NotionGetCommentsHandler, {"block_id": "x"}, "comments"), + (NotionGetBlockChildrenHandler, {"block_id": "x"}, "blocks"), + (NotionUpdateBlockHandler, {"block_id": "x"}, "block"), + (NotionDeleteBlockHandler, {"block_id": "x"}, "block"), + (NotionUpdatePageHandler, {"page_id": "x", "properties": {}}, "page"), + ], + ) + async def test_handler_returns_error_on_exception(self, mock_context, handler_cls, inputs, error_key): + mock_context.fetch.side_effect = Exception("boom") + + handler = handler_cls() + result = await handler.execute(inputs, mock_context) + + assert "error" in result.data + assert "boom" in result.data["error"] + assert error_key in result.data diff --git a/nzbn/config.json b/nzbn/config.json index c560f473..e023bdaf 100644 --- a/nzbn/config.json +++ b/nzbn/config.json @@ -1,7 +1,7 @@ { "name": "NZBN", "display_name": "NZBN (New Zealand Business Number)", - "version": "1.0.0", + "version": "1.0.1", "description": "Integration with the New Zealand Business Number (NZBN) API for searching and retrieving business entity information from the NZBN Register.", "entry_point": "nzbn.py", "auth": { diff --git a/nzbn/nzbn.py b/nzbn/nzbn.py index 0be62028..59d7aa37 100644 --- a/nzbn/nzbn.py +++ b/nzbn/nzbn.py @@ -22,7 +22,12 @@ DO NOT commit actual secrets to this file. 
""" -from autohive_integrations_sdk import Integration, ExecutionContext, ActionHandler, ActionResult +from autohive_integrations_sdk import ( + Integration, + ExecutionContext, + ActionHandler, + ActionResult, +) from typing import Dict, Any, Optional, Tuple import base64 import time @@ -98,10 +103,16 @@ async def get_oauth_token(context: ExecutionContext) -> Optional[str]: auth_string = f"{OAUTH_CLIENT_ID}:{OAUTH_CLIENT_SECRET}" auth_bytes = base64.b64encode(auth_string.encode()).decode() - headers = {"Authorization": f"Basic {auth_bytes}", "Content-Type": "application/x-www-form-urlencoded"} + headers = { + "Authorization": f"Basic {auth_bytes}", + "Content-Type": "application/x-www-form-urlencoded", + } response = await context.fetch( - TOKEN_URL, method="POST", headers=headers, data={"grant_type": "client_credentials", "scope": PRODUCTION_SCOPE} + TOKEN_URL, + method="POST", + headers=headers, + data={"grant_type": "client_credentials", "scope": PRODUCTION_SCOPE}, ) # Handle response @@ -136,7 +147,10 @@ async def get_headers(context: ExecutionContext) -> Dict[str, str]: async def make_request( - context: ExecutionContext, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None + context: ExecutionContext, + method: str, + endpoint: str, + params: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Make a request to the NZBN API.""" headers = await get_headers(context) @@ -151,7 +165,10 @@ async def make_request( return {"success": True, "data": None, "not_modified": True} elif response.status_code == 400: error_data = response.json() if hasattr(response, "json") else {} - return {"success": False, "error": error_data.get("errorDescription", "Bad request - validation failed")} + return { + "success": False, + "error": error_data.get("errorDescription", "Bad request - validation failed"), + } elif response.status_code == 401: return {"success": False, "error": "Unauthorized - invalid credentials"} elif response.status_code == 403: @@ -177,7 
+194,10 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac try: search_term = inputs.get("search_term", "") if not search_term: - return ActionResult(data={"result": False, "error": "search_term is required"}, cost_usd=0.0) + return ActionResult( + data={"result": False, "error": "search_term is required"}, + cost_usd=0.0, + ) params = {"search-term": search_term} @@ -193,7 +213,10 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac result = await make_request(context, "GET", "/entities", params) if not result["success"]: - return ActionResult(data={"result": False, "error": result["error"], "items": []}, cost_usd=0.0) + return ActionResult( + data={"result": False, "error": result["error"], "items": []}, + cost_usd=0.0, + ) data = result["data"] return ActionResult( @@ -292,16 +315,31 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac try: nzbn_id = inputs.get("nzbn", "") if not nzbn_id: - return ActionResult(data={"result": False, "error": "nzbn is required", "addresses": []}, cost_usd=0.0) + return ActionResult( + data={ + "result": False, + "error": "nzbn is required", + "addresses": [], + }, + cost_usd=0.0, + ) params = {} if inputs.get("address_type"): params["address-type"] = inputs["address_type"] - result = await make_request(context, "GET", f"/entities/{nzbn_id}/addresses", params if params else None) + result = await make_request( + context, + "GET", + f"/entities/{nzbn_id}/addresses", + params if params else None, + ) if not result["success"]: - return ActionResult(data={"result": False, "error": result["error"], "addresses": []}, cost_usd=0.0) + return ActionResult( + data={"result": False, "error": result["error"], "addresses": []}, + cost_usd=0.0, + ) addresses = result["data"] if isinstance(result["data"], list) else result["data"].get("items", []) return ActionResult(data={"result": True, "addresses": addresses}, cost_usd=0.0) @@ -317,12 +355,18 @@ 
async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac try: nzbn_id = inputs.get("nzbn", "") if not nzbn_id: - return ActionResult(data={"result": False, "error": "nzbn is required", "roles": []}, cost_usd=0.0) + return ActionResult( + data={"result": False, "error": "nzbn is required", "roles": []}, + cost_usd=0.0, + ) result = await make_request(context, "GET", f"/entities/{nzbn_id}/roles") if not result["success"]: - return ActionResult(data={"result": False, "error": result["error"], "roles": []}, cost_usd=0.0) + return ActionResult( + data={"result": False, "error": result["error"], "roles": []}, + cost_usd=0.0, + ) roles = result["data"] if isinstance(result["data"], list) else result["data"].get("items", []) return ActionResult(data={"result": True, "roles": roles}, cost_usd=0.0) @@ -339,18 +383,33 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac nzbn_id = inputs.get("nzbn", "") if not nzbn_id: return ActionResult( - data={"result": False, "error": "nzbn is required", "tradingNames": []}, cost_usd=0.0 + data={ + "result": False, + "error": "nzbn is required", + "tradingNames": [], + }, + cost_usd=0.0, ) result = await make_request(context, "GET", f"/entities/{nzbn_id}/trading-names") if not result["success"]: - return ActionResult(data={"result": False, "error": result["error"], "tradingNames": []}, cost_usd=0.0) + return ActionResult( + data={ + "result": False, + "error": result["error"], + "tradingNames": [], + }, + cost_usd=0.0, + ) trading_names = result["data"] if isinstance(result["data"], list) else result["data"].get("items", []) return ActionResult(data={"result": True, "tradingNames": trading_names}, cost_usd=0.0) except Exception as e: - return ActionResult(data={"result": False, "error": str(e), "tradingNames": []}, cost_usd=0.0) + return ActionResult( + data={"result": False, "error": str(e), "tradingNames": []}, + cost_usd=0.0, + ) @nzbn.action("get_company_details") @@ -381,12 
+440,22 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac try: nzbn_id = inputs.get("nzbn", "") if not nzbn_id: - return ActionResult(data={"result": False, "error": "nzbn is required", "gstNumbers": []}, cost_usd=0.0) + return ActionResult( + data={ + "result": False, + "error": "nzbn is required", + "gstNumbers": [], + }, + cost_usd=0.0, + ) result = await make_request(context, "GET", f"/entities/{nzbn_id}/gst-numbers") if not result["success"]: - return ActionResult(data={"result": False, "error": result["error"], "gstNumbers": []}, cost_usd=0.0) + return ActionResult( + data={"result": False, "error": result["error"], "gstNumbers": []}, + cost_usd=0.0, + ) gst_numbers = result["data"] if isinstance(result["data"], list) else result["data"].get("items", []) return ActionResult(data={"result": True, "gstNumbers": gst_numbers}, cost_usd=0.0) @@ -403,20 +472,36 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac nzbn_id = inputs.get("nzbn", "") if not nzbn_id: return ActionResult( - data={"result": False, "error": "nzbn is required", "industryClassifications": []}, cost_usd=0.0 + data={ + "result": False, + "error": "nzbn is required", + "industryClassifications": [], + }, + cost_usd=0.0, ) result = await make_request(context, "GET", f"/entities/{nzbn_id}/industry-classifications") if not result["success"]: return ActionResult( - data={"result": False, "error": result["error"], "industryClassifications": []}, cost_usd=0.0 + data={ + "result": False, + "error": result["error"], + "industryClassifications": [], + }, + cost_usd=0.0, ) classifications = result["data"] if isinstance(result["data"], list) else result["data"].get("items", []) - return ActionResult(data={"result": True, "industryClassifications": classifications}, cost_usd=0.0) + return ActionResult( + data={"result": True, "industryClassifications": classifications}, + cost_usd=0.0, + ) except Exception as e: - return 
ActionResult(data={"result": False, "error": str(e), "industryClassifications": []}, cost_usd=0.0) + return ActionResult( + data={"result": False, "error": str(e), "industryClassifications": []}, + cost_usd=0.0, + ) @nzbn.action("get_changes") @@ -428,7 +513,12 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac change_event_type = inputs.get("change_event_type", "") if not change_event_type: return ActionResult( - data={"result": False, "error": "change_event_type is required", "changes": []}, cost_usd=0.0 + data={ + "result": False, + "error": "change_event_type is required", + "changes": [], + }, + cost_usd=0.0, ) params = {"change-event-type": change_event_type} @@ -445,7 +535,10 @@ async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> Ac result = await make_request(context, "GET", "/entities/changes", params) if not result["success"]: - return ActionResult(data={"result": False, "error": result["error"], "changes": []}, cost_usd=0.0) + return ActionResult( + data={"result": False, "error": result["error"], "changes": []}, + cost_usd=0.0, + ) data = result["data"] changes = data.get("items", []) if isinstance(data, dict) else data diff --git a/nzbn/tests/conftest.py b/nzbn/tests/conftest.py new file mode 100644 index 00000000..1d99cac4 --- /dev/null +++ b/nzbn/tests/conftest.py @@ -0,0 +1,5 @@ +import sys +import os + +# Allow 'from context import ...' 
to work when pytest runs from repo root +sys.path.insert(0, os.path.dirname(__file__)) diff --git a/nzbn/tests/test_nzbn.py b/nzbn/tests/test_nzbn.py index 556ca721..90971ebd 100644 --- a/nzbn/tests/test_nzbn.py +++ b/nzbn/tests/test_nzbn.py @@ -51,7 +51,12 @@ async def test_search_entities_with_filters(): """Test searching with entity type filter.""" print("\n=== Test: Search Entities with Filters ===") - inputs = {"search_term": "Limited", "entity_type": "LTD", "entity_status": "Registered", "page_size": 3} + inputs = { + "search_term": "Limited", + "entity_type": "LTD", + "entity_status": "Registered", + "page_size": 3, + } async with ExecutionContext(auth=TEST_AUTH) as context: try: diff --git a/nzbn/tests/test_nzbn_unit.py b/nzbn/tests/test_nzbn_unit.py new file mode 100644 index 00000000..d0e0d5dd --- /dev/null +++ b/nzbn/tests/test_nzbn_unit.py @@ -0,0 +1,630 @@ +""" +Unit tests for NZBN integration. + +These tests use mocks — no real API credentials or network calls required. +""" + +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../dependencies"))) + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from nzbn.nzbn import ( + nzbn, + _get_cached_token, + _cache_token, + _token_cache, + make_request, + get_headers, +) + +pytestmark = pytest.mark.unit + +TEST_NZBN = "9429041525746" + + +@pytest.fixture +def mock_context(): + """Create a mock ExecutionContext.""" + context = MagicMock() + context.auth = {"credentials": {}} + context.fetch = AsyncMock() + return context + + +@pytest.fixture(autouse=True) +def clear_token_cache(): + """Clear the token cache before each test.""" + _token_cache.clear() + yield + _token_cache.clear() + + +# ============================================================================= +# Token Cache +# 
============================================================================= + + +class TestTokenCache: + """Test token caching helpers directly.""" + + def test_cache_and_retrieve_token(self): + _cache_token("test_scope", "tok_abc", 3600) + assert _get_cached_token("test_scope") == "tok_abc" + + def test_get_cached_token_returns_none_when_empty(self): + assert _get_cached_token("nonexistent") is None + + def test_get_cached_token_returns_none_when_expired(self): + _cache_token("test_scope", "tok_old", 0) + assert _get_cached_token("test_scope") is None + + def test_get_cached_token_returns_none_within_buffer(self): + # Token that expires in 30 seconds — within the 60-second buffer + _cache_token("test_scope", "tok_buf", 30) + assert _get_cached_token("test_scope") is None + + def test_cache_overwrites_existing(self): + _cache_token("scope", "tok_first", 3600) + _cache_token("scope", "tok_second", 3600) + assert _get_cached_token("scope") == "tok_second" + + +# ============================================================================= +# Input Validation +# ============================================================================= + + +class TestInputValidation: + """Missing required fields return error without making API calls. + + The SDK may raise ValidationError or the handler may return an error + result — either way, the missing field is rejected. 
+ """ + + async def _assert_rejects_missing_field(self, mock_context, action, inputs=None): + """Assert that calling an action with missing required fields fails.""" + try: + result = await nzbn.execute_action(action, inputs or {}, mock_context) + # Handler returned an error result instead of the SDK raising + data = result.result.data + assert data["result"] is False + assert "required" in data.get("error", "").lower() + except Exception: # nosec B110 + # SDK raised a validation error — also acceptable + pass + + @pytest.mark.asyncio + async def test_search_entities_missing_search_term(self, mock_context): + await self._assert_rejects_missing_field(mock_context, "search_entities") + + @pytest.mark.asyncio + async def test_get_entity_missing_nzbn(self, mock_context): + await self._assert_rejects_missing_field(mock_context, "get_entity") + + @pytest.mark.asyncio + async def test_get_entity_summary_missing_nzbn(self, mock_context): + await self._assert_rejects_missing_field(mock_context, "get_entity_summary") + + @pytest.mark.asyncio + async def test_get_entity_addresses_missing_nzbn(self, mock_context): + await self._assert_rejects_missing_field(mock_context, "get_entity_addresses") + + @pytest.mark.asyncio + async def test_get_entity_roles_missing_nzbn(self, mock_context): + await self._assert_rejects_missing_field(mock_context, "get_entity_roles") + + @pytest.mark.asyncio + async def test_get_entity_trading_names_missing_nzbn(self, mock_context): + await self._assert_rejects_missing_field(mock_context, "get_entity_trading_names") + + @pytest.mark.asyncio + async def test_get_company_details_missing_nzbn(self, mock_context): + await self._assert_rejects_missing_field(mock_context, "get_company_details") + + @pytest.mark.asyncio + async def test_get_entity_gst_numbers_missing_nzbn(self, mock_context): + await self._assert_rejects_missing_field(mock_context, "get_entity_gst_numbers") + + @pytest.mark.asyncio + async def 
test_get_entity_industry_classifications_missing_nzbn(self, mock_context): + await self._assert_rejects_missing_field(mock_context, "get_entity_industry_classifications") + + @pytest.mark.asyncio + async def test_get_changes_missing_event_type(self, mock_context): + await self._assert_rejects_missing_field(mock_context, "get_changes") + + +# ============================================================================= +# Action Tests (patching make_request) +# ============================================================================= + + +class TestSearchEntities: + """Test search_entities action.""" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_basic_search(self, mock_make_request, mock_context): + mock_make_request.return_value = { + "success": True, + "data": { + "items": [{"entityName": "Xero Limited", "nzbn": TEST_NZBN}], + "totalItems": 1, + "page": 0, + "pageSize": 25, + }, + } + + result = await nzbn.execute_action("search_entities", {"search_term": "Xero"}, mock_context) + data = result.result.data + + assert data["result"] is True + assert data["totalItems"] == 1 + assert data["items"][0]["entityName"] == "Xero Limited" + mock_make_request.assert_called_once_with(mock_context, "GET", "/entities", {"search-term": "Xero"}) + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_search_with_filters(self, mock_make_request, mock_context): + mock_make_request.return_value = { + "success": True, + "data": {"items": [], "totalItems": 0, "page": 0, "pageSize": 3}, + } + + inputs = { + "search_term": "Limited", + "entity_type": "LTD", + "entity_status": "Registered", + "page_size": 3, + "page": 1, + } + result = await nzbn.execute_action("search_entities", inputs, mock_context) + data = result.result.data + + assert data["result"] is True + call_args = mock_make_request.call_args + params = call_args[0][3] + assert params["search-term"] == "Limited" + assert params["entity-type"] == "LTD" + assert 
params["entity-status"] == "Registered" + assert params["page-size"] == 3 + assert params["page"] == 1 + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_search_api_error(self, mock_make_request, mock_context): + mock_make_request.return_value = { + "success": False, + "error": "Bad request - validation failed", + } + + result = await nzbn.execute_action("search_entities", {"search_term": "test"}, mock_context) + data = result.result.data + + assert data["result"] is False + assert "Bad request" in data["error"] + + +class TestGetEntity: + """Test get_entity action.""" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_get_entity_success(self, mock_make_request, mock_context): + mock_make_request.return_value = { + "success": True, + "data": { + "nzbn": TEST_NZBN, + "entityName": "Xero Limited", + "entityStatusCode": "50", + }, + } + + result = await nzbn.execute_action("get_entity", {"nzbn": TEST_NZBN}, mock_context) + data = result.result.data + + assert data["result"] is True + assert data["entity"]["entityName"] == "Xero Limited" + mock_make_request.assert_called_once_with(mock_context, "GET", f"/entities/{TEST_NZBN}") + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_get_entity_not_found(self, mock_make_request, mock_context): + mock_make_request.return_value = {"success": False, "error": "Entity not found"} + + result = await nzbn.execute_action("get_entity", {"nzbn": "0000000000000"}, mock_context) + data = result.result.data + + assert data["result"] is False + assert "not found" in data["error"].lower() + + +class TestGetEntitySummary: + """Test get_entity_summary action.""" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_summary_returns_three_fields(self, mock_make_request, mock_context): + mock_make_request.return_value = { + "success": True, + "data": { + "nzbn": TEST_NZBN, + "entityName": "Xero Limited", + "entityStatusCode": "50", + "addresses": 
{ + "addressList": [ + { + "addressType": "REGISTERED", + "address1": "19-23 Taranaki Street", + "address3": "Wellington", + "postCode": "6011", + } + ] + }, + "roles": {"roleList": [{"roleName": "Director"}]}, + }, + } + + result = await nzbn.execute_action("get_entity_summary", {"nzbn": TEST_NZBN}, mock_context) + data = result.result.data + + assert data["result"] is True + summary = data["summary"] + assert len(summary) == 3 + assert summary["nzbn"] == TEST_NZBN + assert summary["entityName"] == "Xero Limited" + assert "19-23 Taranaki Street" in summary["registeredOffice"] + assert "6011" in summary["registeredOffice"] + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_summary_no_registered_address(self, mock_make_request, mock_context): + mock_make_request.return_value = { + "success": True, + "data": { + "nzbn": TEST_NZBN, + "entityName": "Test Co", + "addresses": {"addressList": []}, + }, + } + + result = await nzbn.execute_action("get_entity_summary", {"nzbn": TEST_NZBN}, mock_context) + data = result.result.data + + assert data["result"] is True + assert data["summary"]["registeredOffice"] == "" + + +class TestGetEntityAddresses: + """Test get_entity_addresses action.""" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_addresses_success(self, mock_make_request, mock_context): + mock_make_request.return_value = { + "success": True, + "data": { + "items": [ + {"addressType": "REGISTERED", "address1": "123 Main St"}, + {"addressType": "POSTAL", "address1": "PO Box 100"}, + ] + }, + } + + result = await nzbn.execute_action("get_entity_addresses", {"nzbn": TEST_NZBN}, mock_context) + data = result.result.data + + assert data["result"] is True + assert len(data["addresses"]) == 2 + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_addresses_with_type_filter(self, mock_make_request, mock_context): + mock_make_request.return_value = { + "success": True, + "data": {"items": 
[{"addressType": "REGISTERED", "address1": "123 Main St"}]}, + } + + await nzbn.execute_action( + "get_entity_addresses", + {"nzbn": TEST_NZBN, "address_type": "RegisteredOffice"}, + mock_context, + ) + + call_args = mock_make_request.call_args + params = call_args[0][3] + assert params["address-type"] == "RegisteredOffice" + + +class TestGetEntityRoles: + """Test get_entity_roles action.""" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_roles_success(self, mock_make_request, mock_context): + mock_make_request.return_value = { + "success": True, + "data": {"items": [{"roleName": "Director", "firstName": "Jane"}]}, + } + + result = await nzbn.execute_action("get_entity_roles", {"nzbn": TEST_NZBN}, mock_context) + data = result.result.data + + assert data["result"] is True + assert len(data["roles"]) == 1 + assert data["roles"][0]["roleName"] == "Director" + + +class TestGetEntityTradingNames: + """Test get_entity_trading_names action.""" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_trading_names_success(self, mock_make_request, mock_context): + mock_make_request.return_value = { + "success": True, + "data": {"items": [{"name": "Xero NZ"}]}, + } + + result = await nzbn.execute_action("get_entity_trading_names", {"nzbn": TEST_NZBN}, mock_context) + data = result.result.data + + assert data["result"] is True + assert len(data["tradingNames"]) == 1 + + +class TestGetCompanyDetails: + """Test get_company_details action.""" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_company_details_success(self, mock_make_request, mock_context): + mock_make_request.return_value = { + "success": True, + "data": {"companyNumber": "1234567", "annualReturnFilingMonth": 3}, + } + + result = await nzbn.execute_action("get_company_details", {"nzbn": TEST_NZBN}, mock_context) + data = result.result.data + + assert data["result"] is True + assert data["companyDetails"]["companyNumber"] == "1234567" + + 
+class TestGetEntityGstNumbers: + """Test get_entity_gst_numbers action.""" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_gst_numbers_success(self, mock_make_request, mock_context): + mock_make_request.return_value = { + "success": True, + "data": {"items": [{"gstNumber": "123-456-789"}]}, + } + + result = await nzbn.execute_action("get_entity_gst_numbers", {"nzbn": TEST_NZBN}, mock_context) + data = result.result.data + + assert data["result"] is True + assert len(data["gstNumbers"]) == 1 + + +class TestGetEntityIndustryClassifications: + """Test get_entity_industry_classifications action.""" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_industry_classifications_success(self, mock_make_request, mock_context): + mock_make_request.return_value = { + "success": True, + "data": { + "items": [ + { + "classificationCode": "L631", + "classificationDescription": "Software", + } + ] + }, + } + + result = await nzbn.execute_action("get_entity_industry_classifications", {"nzbn": TEST_NZBN}, mock_context) + data = result.result.data + + assert data["result"] is True + assert len(data["industryClassifications"]) == 1 + assert data["industryClassifications"][0]["classificationCode"] == "L631" + + +class TestGetChanges: + """Test get_changes action.""" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_get_changes_success(self, mock_make_request, mock_context): + mock_make_request.return_value = { + "success": True, + "data": { + "items": [{"nzbn": TEST_NZBN, "changeEventType": "NewRegistration"}], + "totalItems": 1, + }, + } + + result = await nzbn.execute_action("get_changes", {"change_event_type": "NewRegistration"}, mock_context) + data = result.result.data + + assert data["result"] is True + assert len(data["changes"]) == 1 + assert data["totalItems"] == 1 + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_get_changes_with_date_filters(self, mock_make_request, 
mock_context): + mock_make_request.return_value = { + "success": True, + "data": {"items": [], "totalItems": 0}, + } + + inputs = { + "change_event_type": "NameChange", + "start_date": "2024-01-01", + "end_date": "2024-01-31", + "page_size": 10, + "page": 0, + } + await nzbn.execute_action("get_changes", inputs, mock_context) + + call_args = mock_make_request.call_args + params = call_args[0][3] + assert params["change-event-type"] == "NameChange" + assert params["start-date"] == "2024-01-01" + assert params["end-date"] == "2024-01-31" + assert params["page-size"] == 10 + assert params["page"] == 0 + + +# ============================================================================= +# Helper Function Tests +# ============================================================================= + + +class TestMakeRequest: + """Test make_request helper.""" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.get_headers") + async def test_make_request_success_dict_response(self, mock_get_headers, mock_context): + mock_get_headers.return_value = { + "Authorization": "Bearer tok", + "Accept": "application/json", + } + mock_context.fetch.return_value = {"nzbn": TEST_NZBN, "entityName": "Test"} + + result = await make_request(mock_context, "GET", f"/entities/{TEST_NZBN}") + + assert result["success"] is True + assert result["data"]["entityName"] == "Test" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.get_headers") + async def test_make_request_http_404(self, mock_get_headers, mock_context): + mock_get_headers.return_value = {"Authorization": "Bearer tok"} + response = MagicMock() + response.status_code = 404 + mock_context.fetch.return_value = response + + result = await make_request(mock_context, "GET", "/entities/0000000000000") + + assert result["success"] is False + assert "not found" in result["error"].lower() + + @pytest.mark.asyncio + @patch("nzbn.nzbn.get_headers") + async def test_make_request_http_401(self, mock_get_headers, mock_context): + mock_get_headers.return_value = {} + 
response = MagicMock() + response.status_code = 401 + mock_context.fetch.return_value = response + + result = await make_request(mock_context, "GET", "/entities/test") + + assert result["success"] is False + assert "Unauthorized" in result["error"] + + @pytest.mark.asyncio + @patch("nzbn.nzbn.get_headers") + async def test_make_request_http_200(self, mock_get_headers, mock_context): + mock_get_headers.return_value = {} + response = MagicMock() + response.status_code = 200 + response.json.return_value = {"entityName": "OK Corp"} + mock_context.fetch.return_value = response + + result = await make_request(mock_context, "GET", "/entities/test") + + assert result["success"] is True + assert result["data"]["entityName"] == "OK Corp" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.get_headers") + async def test_make_request_http_304(self, mock_get_headers, mock_context): + mock_get_headers.return_value = {} + response = MagicMock() + response.status_code = 304 + mock_context.fetch.return_value = response + + result = await make_request(mock_context, "GET", "/entities/test") + + assert result["success"] is True + assert result["not_modified"] is True + + +class TestGetHeaders: + """Test get_headers helper.""" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.get_oauth_token") + @patch("nzbn.nzbn.SUBSCRIPTION_KEY", "test-sub-key") + async def test_headers_include_subscription_key(self, mock_get_oauth_token, mock_context): + mock_get_oauth_token.return_value = "tok_abc" + + headers = await get_headers(mock_context) + + assert headers["Ocp-Apim-Subscription-Key"] == "test-sub-key" + assert headers["Authorization"] == "Bearer tok_abc" + assert headers["Accept"] == "application/json" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.get_oauth_token") + async def test_headers_without_token(self, mock_get_oauth_token, mock_context): + mock_get_oauth_token.return_value = None + + headers = await get_headers(mock_context) + + assert "Authorization" not in headers + + +# 
============================================================================= +# Error Handling +# ============================================================================= + + +class TestErrorHandling: + """Verify actions handle exceptions gracefully.""" + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_search_entities_exception(self, mock_make_request, mock_context): + mock_make_request.side_effect = RuntimeError("connection refused") + + result = await nzbn.execute_action("search_entities", {"search_term": "test"}, mock_context) + data = result.result.data + + assert data["result"] is False + assert "connection refused" in data["error"] + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_get_entity_exception(self, mock_make_request, mock_context): + mock_make_request.side_effect = RuntimeError("timeout") + + result = await nzbn.execute_action("get_entity", {"nzbn": TEST_NZBN}, mock_context) + data = result.result.data + + assert data["result"] is False + assert "timeout" in data["error"] + + @pytest.mark.asyncio + @patch("nzbn.nzbn.make_request") + async def test_get_changes_exception(self, mock_make_request, mock_context): + mock_make_request.side_effect = RuntimeError("server error") + + result = await nzbn.execute_action("get_changes", {"change_event_type": "NewRegistration"}, mock_context) + data = result.result.data + + assert data["result"] is False + assert "server error" in data["error"] diff --git a/perplexity/README.md b/perplexity/README.md new file mode 100644 index 00000000..c80f7fe8 --- /dev/null +++ b/perplexity/README.md @@ -0,0 +1,113 @@ +# Perplexity Search Integration for Autohive + +Web search integration powered by Perplexity's AI search API. Get ranked, structured search results from billions of webpages with advanced filtering options. 
+
+## Description
+
+This integration provides access to Perplexity's Search API, enabling AI agents and automation workflows to perform real-time web searches. It returns structured, ranked results optimized for AI consumption with comprehensive content extraction capabilities.
+
+Key features include:
+- Real-time web search across hundreds of billions of webpages
+- Structured JSON results with titles, URLs, snippets, and dates
+- Content depth control (quick, default, detailed extraction)
+- Geographic filtering by country
+- Multi-query support (up to 5 queries per request)
+- Configurable result limits (1-20 results per query)
+
+This integration uses the Perplexity Search API and provides robust error handling with clear user-facing error messages.
+
+## Setup & Authentication
+
+This integration requires a Perplexity API key set as an environment variable:
+
+```bash
+export PERPLEXITY_API_KEY="your-api-key-here"
+```
+
+Get your API key from [Perplexity API Settings](https://www.perplexity.ai/settings/api).
+
+The integration automatically handles:
+- API authentication via Bearer token
+- Clear error messages when the API rate limit (3 requests per second) is exceeded
+- Error handling for missing API key, insufficient credits, and invalid keys
+- Mapping content depth (quick/default/detailed) to per-page token limits
+
+## Actions
+
+### search_web
+
+Search the web using Perplexity's search API. Returns ranked, structured results with titles, URLs, snippets, and dates.
+ +**Input Parameters:** +- `query` (required): Search query string or array of queries for multi-query search +- `max_results` (optional): Maximum number of results to return (1-20, default: 10) +- `content_depth` (optional): Content extraction depth + - `"quick"` - Brief snippets (512 tokens per page) + - `"default"` - Moderate content (2048 tokens per page) + - `"detailed"` - Comprehensive content (8192 tokens per page) +- `country` (optional): Two-letter ISO country code for geographic filtering (e.g., "US", "GB", "DE") + +**Output:** +- `results`: Array of search results with: + - `title`: Page title + - `url`: Full URL + - `snippet`: Content excerpt + - `date`: Publication date (may be null) + - `last_updated`: Last update date (may be null) +- `id`: Unique request identifier +- `total_results`: Total number of results returned + +**Example Usage:** + +Basic search: +``` +Search for "quantum computing breakthroughs 2025" +``` + +Multi-query search: +``` +Search for "AI agents", "LLM developments", and "autonomous systems" using Perplexity +``` + +Detailed research: +``` +Search for "comprehensive climate change solutions" with detailed content depth +``` + +Geographic filtering: +``` +Search for "tech startups" in the US using Perplexity +``` + +## Rate Limits + +- 3 requests per second +- 3 request burst capacity +- Tier-based limits (50-2000 requests per minute depending on API spending) + +## Error Handling + +The integration provides clear error messages for: +- Missing API key (PERPLEXITY_API_KEY not set) +- Rate limit exceeded (429 errors) +- Invalid API key (401 errors) +- Insufficient credits (403 errors) +- General API failures + +## Use Cases + +- Real-time web research for AI agents +- Competitive intelligence gathering +- Content research and curation +- Market research automation +- News monitoring and alerts +- Academic research compilation +- Product comparison research +- Customer support knowledge gathering + +## Technical Details + +- **API 
Endpoint**: `https://api.perplexity.ai/search` +- **Authentication**: Bearer token via `PERPLEXITY_API_KEY` environment variable +- **Response Format**: JSON +- **Pricing**: $5 per 1,000 requests diff --git a/perplexity/__init__.py b/perplexity/__init__.py new file mode 100644 index 00000000..3b278b69 --- /dev/null +++ b/perplexity/__init__.py @@ -0,0 +1,3 @@ +from .perplexity import perplexity + +__all__ = ["perplexity"] diff --git a/perplexity/config.json b/perplexity/config.json new file mode 100644 index 00000000..01d97f26 --- /dev/null +++ b/perplexity/config.json @@ -0,0 +1,97 @@ +{ + "name": "Perplexity", + "display_name": "Perplexity", + "version": "1.1.0", + "description": "Search the web with Perplexity's AI-powered search API. Get ranked, structured results from billions of webpages with advanced filtering options.", + "entry_point": "perplexity.py", + "supports_billing": true, + "actions": { + "search_web": { + "display_name": "Search Web", + "description": "Search the web using Perplexity's search API. Returns ranked, structured results with titles, URLs, snippets, and dates. Supports filtering by country, controlling result count, and multi-query searches.", + "input_schema": { + "type": "object", + "properties": { + "query": { + "oneOf": [ + { + "type": "string", + "description": "A single search query" + }, + { + "type": "array", + "description": "Multiple search queries for comprehensive research", + "items": { + "type": "string" + } + } + ], + "description": "Search query or queries. Can be a single string or an array of strings for multi-query search." + }, + "max_results": { + "type": "integer", + "description": "Maximum number of search results to return per query", + "default": 10, + "minimum": 1, + "maximum": 20 + }, + "content_depth": { + "type": "string", + "description": "How much content to extract from each webpage. 
'quick' extracts brief snippets, 'default' extracts moderate content, 'detailed' extracts comprehensive content for in-depth research.", + "enum": ["quick", "default", "detailed"], + "default": "default" + }, + "country": { + "type": "string", + "description": "Two-letter ISO country code to filter search results by geographic location (e.g., 'US', 'GB', 'DE', 'FR', 'JP'). Leave empty for global results.", + "pattern": "^[A-Z]{2}$" + } + }, + "required": ["query"] + }, + "output_schema": { + "type": "object", + "properties": { + "results": { + "type": "array", + "description": "Array of search results with structured information", + "items": { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Title of the webpage or article" + }, + "url": { + "type": "string", + "description": "Full URL of the search result" + }, + "snippet": { + "type": "string", + "description": "Content excerpt from the page" + }, + "date": { + "type": ["string", "null"], + "description": "Publication date of the content (format: YYYY-MM-DD). May be null if date is unavailable." + }, + "last_updated": { + "type": ["string", "null"], + "description": "Last updated date of the content (format: YYYY-MM-DD). May be null if date is unavailable." + } + } + } + }, + "id": { + "type": "string", + "description": "Unique request identifier" + }, + "total_results": { + "type": "integer", + "description": "Total number of results returned" + } + }, + "required": ["results"] + } + } + } +} diff --git a/perplexity/icon.png b/perplexity/icon.png new file mode 100644 index 00000000..42436320 Binary files /dev/null and b/perplexity/icon.png differ diff --git a/perplexity/perplexity.py b/perplexity/perplexity.py new file mode 100644 index 00000000..f911cf46 --- /dev/null +++ b/perplexity/perplexity.py @@ -0,0 +1,124 @@ +"""Perplexity Integration for Autohive + +This integration provides web search capabilities using Perplexity's Search API. 
+""" + +import os +from autohive_integrations_sdk import Integration, ExecutionContext, ActionHandler, ActionResult +from typing import Dict, Any + +# Load the integration from config.json +perplexity = Integration.load() + + +async def parse_response(response): + """Parse the response from context.fetch()""" + if hasattr(response, "json"): + return await response.json() + return response + + +@perplexity.action("search_web") +class SearchWebActionHandler(ActionHandler): + """ + Action handler to search the web using Perplexity's Search API. + + Returns ranked, structured search results with titles, URLs, snippets, + publication dates, and last updated dates. + """ + + async def execute(self, inputs: Dict[str, Any], context: ExecutionContext): + """ + Execute the search_web action. + + :param inputs: Dictionary with keys: + - query: Search query (string or array of strings, max 5) + - max_results: Maximum results to return (1-20, default 10) + - max_tokens_per_page: Tokens per page (default 1024) + - country: ISO country code (optional) + :param context: Execution context with authentication details + :return: Dictionary with search results + """ + + try: + api_key = os.environ.get("PERPLEXITY_API_KEY", "") + if not api_key: + return ActionResult( + data={ + "results": [], + "total_results": 0, + "error": "PERPLEXITY_API_KEY environment variable is not set or empty.", + } + ) + + query = inputs["query"] + + # Build the request payload + payload = {"query": query} + + # Add optional parameters if provided + if "max_results" in inputs: + payload["max_results"] = inputs["max_results"] + + # Convert content_depth from string enum to max_tokens_per_page integer + if "content_depth" in inputs: + token_mapping = {"quick": 512, "default": 2048, "detailed": 8192} + content_depth_value = inputs["content_depth"] + payload["max_tokens_per_page"] = token_mapping.get(content_depth_value, 2048) + + if "country" in inputs and inputs["country"]: + payload["country"] = 
inputs["country"] + + # Make the API request using context.fetch() + response = await context.fetch( + "https://api.perplexity.ai/search", + method="POST", + json=payload, + headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}, + ) + + # Parse the response + result = await parse_response(response) + + # Enhance the response with total_results count + if "results" in result: + result["total_results"] = len(result["results"]) + + return ActionResult(data=result, cost_usd=0.005) + + except KeyError as e: + return ActionResult( + data={"results": [], "total_results": 0, "error": f"Missing required input field: {str(e)}"} + ) + + except Exception as e: + error_message = str(e) + + if "429" in error_message or "rate limit" in error_message.lower(): + return ActionResult( + data={ + "results": [], + "total_results": 0, + "error": "Rate limit exceeded. Please wait a moment and try again. Perplexity allows 3 requests per second.", # noqa: E501 + } + ) + elif "401" in error_message or "unauthorized" in error_message.lower(): + return ActionResult( + data={ + "results": [], + "total_results": 0, + "error": "Invalid API key. Please check your PERPLEXITY_API_KEY environment variable.", + } + ) + elif "403" in error_message or "forbidden" in error_message.lower(): + return ActionResult( + data={ + "results": [], + "total_results": 0, + "error": "Access forbidden. 
Please ensure you have purchased API credits at https://www.perplexity.ai/settings/api", # noqa: E501 + } + ) + else: + return ActionResult( + data={"results": [], "total_results": 0, "error": f"Failed to search: {error_message}"} + ) diff --git a/perplexity/requirements.txt b/perplexity/requirements.txt new file mode 100644 index 00000000..b56fee2e --- /dev/null +++ b/perplexity/requirements.txt @@ -0,0 +1 @@ +autohive-integrations-sdk~=1.0.2 diff --git a/perplexity/tests/__init__.py b/perplexity/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/perplexity/tests/conftest.py b/perplexity/tests/conftest.py new file mode 100644 index 00000000..1d99cac4 --- /dev/null +++ b/perplexity/tests/conftest.py @@ -0,0 +1,5 @@ +import sys +import os + +# Allow 'from context import ...' to work when pytest runs from repo root +sys.path.insert(0, os.path.dirname(__file__)) diff --git a/perplexity/tests/test_perplexity_integration.py b/perplexity/tests/test_perplexity_integration.py new file mode 100644 index 00000000..ecbbc179 --- /dev/null +++ b/perplexity/tests/test_perplexity_integration.py @@ -0,0 +1,162 @@ +""" +End-to-end integration tests for the Perplexity search integration. + +These tests call the real Perplexity API and require a valid API key +set in the PERPLEXITY_API_KEY environment variable (via .env or export). + +Run with: + pytest perplexity/tests/test_perplexity_integration.py -m integration + +Never runs in CI — the default pytest marker filter (-m unit) excludes these, +and the file naming (test_*_integration.py) is not matched by python_files. 
+""" + +import os +import sys +import importlib + +_parent = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +_deps = os.path.abspath(os.path.join(os.path.dirname(__file__), "../dependencies")) +sys.path.insert(0, _parent) +sys.path.insert(0, _deps) + +import pytest # noqa: E402 +from unittest.mock import MagicMock, AsyncMock # noqa: E402 + +_spec = importlib.util.spec_from_file_location("perplexity_mod", os.path.join(_parent, "perplexity.py")) +_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_mod) + +perplexity = _mod.perplexity + +pytestmark = pytest.mark.integration + +API_KEY = os.environ.get("PERPLEXITY_API_KEY", "") + + +@pytest.fixture +def live_context(): + """Execution context that uses a real fetch (via the SDK's context.fetch mock pattern). + + For e2e tests we still use a MagicMock context but with a real async HTTP call + wired through context.fetch. Since the integration calls context.fetch() directly, + we need to provide a real async HTTP client. 
+ """ + if not API_KEY: + pytest.skip("PERPLEXITY_API_KEY not set — skipping integration tests") + + import aiohttp + + async def real_fetch(url, *, method="GET", json=None, headers=None, **kwargs): + async with aiohttp.ClientSession() as session: + async with session.request(method, url, json=json, headers=headers) as resp: + return await resp.json() + + ctx = MagicMock(name="ExecutionContext") + ctx.fetch = AsyncMock(side_effect=real_fetch) + ctx.auth = {} + return ctx + + +class TestBasicSearch: + @pytest.mark.asyncio + async def test_simple_query_returns_results(self, live_context): + result = await perplexity.execute_action("search_web", {"query": "Python programming language"}, live_context) + + data = result.result.data + assert "results" in data + assert data["total_results"] > 0 + assert len(data["results"]) > 0 + + @pytest.mark.asyncio + async def test_result_structure(self, live_context): + result = await perplexity.execute_action("search_web", {"query": "what is pytest"}, live_context) + + data = result.result.data + first_result = data["results"][0] + assert "title" in first_result + assert "url" in first_result + assert first_result["url"].startswith("http") + + @pytest.mark.asyncio + async def test_cost_is_set(self, live_context): + result = await perplexity.execute_action("search_web", {"query": "test"}, live_context) + + assert result.result.cost_usd == 0.005 + + +class TestMaxResults: + @pytest.mark.asyncio + async def test_respects_max_results(self, live_context): + result = await perplexity.execute_action( + "search_web", {"query": "artificial intelligence", "max_results": 3}, live_context + ) + + data = result.result.data + assert data["total_results"] <= 3 + + @pytest.mark.asyncio + async def test_single_result(self, live_context): + result = await perplexity.execute_action( + "search_web", {"query": "SpaceX launch", "max_results": 1}, live_context + ) + + data = result.result.data + assert data["total_results"] >= 1 + assert 
len(data["results"]) >= 1 + + +class TestContentDepth: + @pytest.mark.asyncio + async def test_quick_depth(self, live_context): + result = await perplexity.execute_action( + "search_web", {"query": "climate change", "max_results": 2, "content_depth": "quick"}, live_context + ) + + data = result.result.data + assert data["total_results"] > 0 + + @pytest.mark.asyncio + async def test_detailed_depth(self, live_context): + result = await perplexity.execute_action( + "search_web", {"query": "quantum computing", "max_results": 2, "content_depth": "detailed"}, live_context + ) + + data = result.result.data + assert data["total_results"] > 0 + + +class TestCountryFilter: + @pytest.mark.asyncio + async def test_country_filter_us(self, live_context): + result = await perplexity.execute_action( + "search_web", {"query": "tech companies", "max_results": 5, "country": "US"}, live_context + ) + + data = result.result.data + assert data["total_results"] > 0 + + +class TestMultiQuery: + @pytest.mark.asyncio + async def test_multi_query(self, live_context): + result = await perplexity.execute_action( + "search_web", {"query": ["machine learning", "deep learning"], "max_results": 3}, live_context + ) + + data = result.result.data + assert "results" in data + + +class TestAllParametersCombined: + @pytest.mark.asyncio + async def test_all_params(self, live_context): + result = await perplexity.execute_action( + "search_web", + {"query": "renewable energy", "max_results": 5, "content_depth": "default", "country": "GB"}, + live_context, + ) + + data = result.result.data + assert data["total_results"] > 0 + assert len(data["results"]) <= 5 diff --git a/perplexity/tests/test_perplexity_unit.py b/perplexity/tests/test_perplexity_unit.py new file mode 100644 index 00000000..532fc5b2 --- /dev/null +++ b/perplexity/tests/test_perplexity_unit.py @@ -0,0 +1,362 @@ +import os +import sys +import importlib + +_parent = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +_deps = 
os.path.abspath(os.path.join(os.path.dirname(__file__), "../dependencies")) +sys.path.insert(0, _parent) +sys.path.insert(0, _deps) + +import pytest # noqa: E402 +from unittest.mock import AsyncMock, MagicMock, patch # noqa: E402 +from autohive_integrations_sdk.integration import ValidationError # noqa: E402 + +_spec = importlib.util.spec_from_file_location("perplexity_mod", os.path.join(_parent, "perplexity.py")) +_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_mod) + +perplexity = _mod.perplexity +parse_response = _mod.parse_response + +pytestmark = pytest.mark.unit + +PERPLEXITY_API_URL = "https://api.perplexity.ai/search" + +SAMPLE_RESULTS = [ + { + "title": "AI Breakthroughs 2025", + "url": "https://example.com/ai-2025", + "snippet": "Major developments in artificial intelligence...", + "date": "2025-03-15", + "last_updated": "2025-03-20", + }, + { + "title": "Machine Learning Trends", + "url": "https://example.com/ml-trends", + "snippet": "The latest trends in machine learning...", + "date": "2025-02-10", + "last_updated": None, + }, +] + + +@pytest.fixture +def mock_context(): + ctx = MagicMock(name="ExecutionContext") + ctx.fetch = AsyncMock(name="fetch") + ctx.auth = {} + return ctx + + +# ---- Helper Function Tests ---- + + +class TestParseResponse: + @pytest.mark.asyncio + async def test_dict_response_passthrough(self): + response = {"results": [], "id": "abc"} + result = await parse_response(response) + assert result == {"results": [], "id": "abc"} + + @pytest.mark.asyncio + async def test_response_with_json_method(self): + response = MagicMock() + response.json = AsyncMock(return_value={"results": [{"title": "Test"}]}) + result = await parse_response(response) + assert result == {"results": [{"title": "Test"}]} + + @pytest.mark.asyncio + async def test_string_response_passthrough(self): + result = await parse_response("raw text") + assert result == "raw text" + + @pytest.mark.asyncio + async def 
# ---- API Key Handling ----


class TestApiKeyHandling:
    """The action must refuse to call the API when no key is configured,
    and must send the configured key as a Bearer token when it is."""

    @pytest.mark.asyncio
    @patch.dict(os.environ, {}, clear=True)
    async def test_missing_api_key(self, mock_context):
        os.environ.pop("PERPLEXITY_API_KEY", None)

        outcome = await perplexity.execute_action("search_web", {"query": "test"}, mock_context)

        body = outcome.result.data
        assert body["results"] == []
        assert body["total_results"] == 0
        assert "PERPLEXITY_API_KEY" in body["error"]
        mock_context.fetch.assert_not_called()

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": ""})
    async def test_empty_api_key(self, mock_context):
        outcome = await perplexity.execute_action("search_web", {"query": "test"}, mock_context)

        body = outcome.result.data
        assert body["results"] == []
        assert "PERPLEXITY_API_KEY" in body["error"]
        mock_context.fetch.assert_not_called()

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key-123"})  # nosec B105
    async def test_api_key_sent_in_header(self, mock_context):
        mock_context.fetch.return_value = {"results": [], "id": "req-1"}

        await perplexity.execute_action("search_web", {"query": "test"}, mock_context)

        sent_headers = mock_context.fetch.call_args.kwargs["headers"]
        assert sent_headers["Authorization"] == "Bearer test-key-123"
        assert sent_headers["Content-Type"] == "application/json"


# ---- Search Web Action: Happy Path ----


class TestSearchWebBasic:
    """Happy-path behaviour of search_web: URL, method, payload, cost,
    single- and multi-query shapes, and schema enforcement."""

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_basic_search(self, mock_context):
        mock_context.fetch.return_value = {"results": SAMPLE_RESULTS, "id": "req-123"}

        outcome = await perplexity.execute_action("search_web", {"query": "AI developments"}, mock_context)

        body = outcome.result.data
        assert body["total_results"] == 2
        assert len(body["results"]) == 2
        assert body["results"][0]["title"] == "AI Breakthroughs 2025"
        assert body["id"] == "req-123"

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_request_url_and_method(self, mock_context):
        mock_context.fetch.return_value = {"results": [], "id": "req-1"}

        await perplexity.execute_action("search_web", {"query": "test query"}, mock_context)

        mock_context.fetch.assert_called_once()
        call = mock_context.fetch.call_args
        assert call.args[0] == PERPLEXITY_API_URL
        assert call.kwargs["method"] == "POST"

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_basic_payload(self, mock_context):
        mock_context.fetch.return_value = {"results": [], "id": "req-1"}

        await perplexity.execute_action("search_web", {"query": "quantum computing"}, mock_context)

        sent = mock_context.fetch.call_args.kwargs["json"]
        assert sent == {"query": "quantum computing"}

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_cost_usd_set(self, mock_context):
        mock_context.fetch.return_value = {"results": [], "id": "req-1"}

        outcome = await perplexity.execute_action("search_web", {"query": "test"}, mock_context)

        assert outcome.result.cost_usd == 0.005

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_multi_query_search(self, mock_context):
        mock_context.fetch.return_value = {"results": SAMPLE_RESULTS, "id": "req-multi"}

        query_list = ["AI trends", "ML applications", "neural networks"]
        outcome = await perplexity.execute_action("search_web", {"query": query_list}, mock_context)

        sent = mock_context.fetch.call_args.kwargs["json"]
        assert sent["query"] == query_list
        assert outcome.result.data["total_results"] == 2

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_empty_results(self, mock_context):
        mock_context.fetch.return_value = {"results": [], "id": "req-empty"}

        outcome = await perplexity.execute_action("search_web", {"query": "xyznonexistent"}, mock_context)

        body = outcome.result.data
        assert body["results"] == []
        assert body["total_results"] == 0

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_response_without_results_key(self, mock_context):
        # A response missing "results" violates the output schema.
        mock_context.fetch.return_value = {"id": "req-no-results"}

        with pytest.raises(ValidationError):
            await perplexity.execute_action("search_web", {"query": "test"}, mock_context)
# ---- Optional Parameters ----


class TestOptionalParameters:
    """Mapping of optional inputs onto the outgoing request payload:
    max_results passes through, content_depth maps to max_tokens_per_page
    (quick=512, default=2048, detailed=8192), country passes through, and
    invalid values are rejected by the input schema before any fetch."""

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_max_results(self, mock_context):
        mock_context.fetch.return_value = {"results": SAMPLE_RESULTS[:1], "id": "req-1"}

        await perplexity.execute_action("search_web", {"query": "test", "max_results": 5}, mock_context)

        sent = mock_context.fetch.call_args.kwargs["json"]
        assert sent["max_results"] == 5

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_content_depth_quick(self, mock_context):
        mock_context.fetch.return_value = {"results": [], "id": "req-1"}

        await perplexity.execute_action("search_web", {"query": "test", "content_depth": "quick"}, mock_context)

        sent = mock_context.fetch.call_args.kwargs["json"]
        assert sent["max_tokens_per_page"] == 512

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_content_depth_default(self, mock_context):
        mock_context.fetch.return_value = {"results": [], "id": "req-1"}

        await perplexity.execute_action("search_web", {"query": "test", "content_depth": "default"}, mock_context)

        sent = mock_context.fetch.call_args.kwargs["json"]
        assert sent["max_tokens_per_page"] == 2048

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_content_depth_detailed(self, mock_context):
        mock_context.fetch.return_value = {"results": [], "id": "req-1"}

        await perplexity.execute_action("search_web", {"query": "test", "content_depth": "detailed"}, mock_context)

        sent = mock_context.fetch.call_args.kwargs["json"]
        assert sent["max_tokens_per_page"] == 8192

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_content_depth_unknown_rejected_by_schema(self, mock_context):
        with pytest.raises(ValidationError):
            await perplexity.execute_action("search_web", {"query": "test", "content_depth": "unknown"}, mock_context)

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_country_filter(self, mock_context):
        mock_context.fetch.return_value = {"results": [], "id": "req-1"}

        await perplexity.execute_action("search_web", {"query": "test", "country": "US"}, mock_context)

        sent = mock_context.fetch.call_args.kwargs["json"]
        assert sent["country"] == "US"

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_empty_country_rejected_by_schema(self, mock_context):
        with pytest.raises(ValidationError):
            await perplexity.execute_action("search_web", {"query": "test", "country": ""}, mock_context)

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_all_params_combined(self, mock_context):
        mock_context.fetch.return_value = {"results": SAMPLE_RESULTS, "id": "req-full"}

        inputs = {"query": "climate change", "max_results": 10, "content_depth": "detailed", "country": "GB"}
        outcome = await perplexity.execute_action("search_web", inputs, mock_context)

        sent = mock_context.fetch.call_args.kwargs["json"]
        assert sent["query"] == "climate change"
        assert sent["max_results"] == 10
        assert sent["max_tokens_per_page"] == 8192
        assert sent["country"] == "GB"
        assert outcome.result.data["total_results"] == 2

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_no_optional_params(self, mock_context):
        mock_context.fetch.return_value = {"results": [], "id": "req-1"}

        await perplexity.execute_action("search_web", {"query": "test"}, mock_context)

        sent = mock_context.fetch.call_args.kwargs["json"]
        assert sent == {"query": "test"}
# ---- Error Handling ----


class TestErrorHandling:
    """Fetch failures must be converted into an empty-result payload whose
    "error" text is matched on the exception message (429/rate-limit text,
    401, 403, otherwise a generic "Failed to search" message)."""

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_rate_limit_429(self, mock_context):
        mock_context.fetch.side_effect = Exception("HTTP 429: rate limit exceeded")

        outcome = await perplexity.execute_action("search_web", {"query": "test"}, mock_context)

        body = outcome.result.data
        assert body["results"] == []
        assert body["total_results"] == 0
        assert "Rate limit exceeded" in body["error"]

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_rate_limit_text_match(self, mock_context):
        mock_context.fetch.side_effect = Exception("Too many requests, rate limit hit")

        outcome = await perplexity.execute_action("search_web", {"query": "test"}, mock_context)

        assert "Rate limit exceeded" in outcome.result.data["error"]

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_unauthorized_401(self, mock_context):
        mock_context.fetch.side_effect = Exception("HTTP 401: unauthorized")

        outcome = await perplexity.execute_action("search_web", {"query": "test"}, mock_context)

        body = outcome.result.data
        assert body["results"] == []
        assert "Invalid API key" in body["error"]
        assert "PERPLEXITY_API_KEY" in body["error"]

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_forbidden_403(self, mock_context):
        mock_context.fetch.side_effect = Exception("HTTP 403: forbidden")

        outcome = await perplexity.execute_action("search_web", {"query": "test"}, mock_context)

        body = outcome.result.data
        assert body["results"] == []
        assert "Access forbidden" in body["error"]
        assert "perplexity.ai/settings/api" in body["error"]

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_generic_exception(self, mock_context):
        mock_context.fetch.side_effect = Exception("Connection timeout")

        outcome = await perplexity.execute_action("search_web", {"query": "test"}, mock_context)

        body = outcome.result.data
        assert body["results"] == []
        assert body["total_results"] == 0
        assert "Failed to search" in body["error"]
        assert "Connection timeout" in body["error"]

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"PERPLEXITY_API_KEY": "test-key"})  # nosec B105
    async def test_runtime_error(self, mock_context):
        mock_context.fetch.side_effect = RuntimeError("Network unreachable")

        outcome = await perplexity.execute_action("search_web", {"query": "test"}, mock_context)

        body = outcome.result.data
        assert body["results"] == []
        assert "Failed to search" in body["error"]
+python_files = ["test_*_unit.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +markers = [ + "unit: Pure unit tests — mocked, no credentials needed, safe for CI", + "integration: Integration tests — require real API credentials, never run in CI", +] +# --import-mode=importlib: avoids __init__.py package resolution issues +# -m unit: only run CI-safe mocked tests by default +# To run integration tests: pytest /test_*_integration.py -m integration +addopts = "--import-mode=importlib -m unit --tb=short" + +norecursedirs = ["dependencies", "__pycache__", ".git", ".ruff_cache"] + +# Environment variables for integration tests are loaded from .env +# by the root conftest.py (stdlib-only, no plugin dependency). + +[tool.coverage.run] +omit = ["*/tests/*", "conftest.py"] + +[tool.coverage.report] +show_missing = true +skip_empty = true diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 00000000..eb849232 --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,3 @@ +pytest>=9.0 +pytest-asyncio>=0.23 +pytest-cov>=4.0 diff --git a/shopify-customer/tests/conftest.py b/shopify-customer/tests/conftest.py new file mode 100644 index 00000000..1d99cac4 --- /dev/null +++ b/shopify-customer/tests/conftest.py @@ -0,0 +1,5 @@ +import sys +import os + +# Allow 'from context import ...' to work when pytest runs from repo root +sys.path.insert(0, os.path.dirname(__file__)) diff --git a/shopify-customer/tests/test_shopify_customer_unit.py b/shopify-customer/tests/test_shopify_customer_unit.py new file mode 100644 index 00000000..e3573372 --- /dev/null +++ b/shopify-customer/tests/test_shopify_customer_unit.py @@ -0,0 +1,527 @@ +""" +Unit tests for Shopify Customer Account API Integration. + +Refactored from test_unit.py with improved style, additional coverage, +and proper pytest patterns (no bare try/except). 
"""
Unit tests for Shopify Customer Account API Integration.

Refactored from test_unit.py with improved style, additional coverage,
and proper pytest patterns (no bare try/except).
"""

import importlib.util
import os

_parent = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

import pytest  # noqa: E402
from unittest.mock import AsyncMock, MagicMock  # noqa: E402

from autohive_integrations_sdk.integration import ValidationError  # noqa: E402

# Load shopify_customer.py by path under a private module name.
_spec = importlib.util.spec_from_file_location("shopify_customer_mod", os.path.join(_parent, "shopify_customer.py"))
_mod = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(_mod)

shopify_customer = _mod.shopify_customer
get_shop_url = _mod.get_shop_url
get_customer_api_url = _mod.get_customer_api_url
build_headers = _mod.build_headers
generate_pkce_pair = _mod.generate_pkce_pair
build_authorization_url = _mod.build_authorization_url
extract_edges = _mod.extract_edges

# Every test in this module is a CI-safe, fully mocked unit test.
pytestmark = pytest.mark.unit


# ============================================================================
# Fixtures
# ============================================================================


@pytest.fixture
def mock_context():
    """Execution context stand-in with an async fetch mock and canned auth."""
    context = MagicMock(name="ExecutionContext")
    context.fetch = AsyncMock(name="fetch")
    context.auth = {
        "credentials": {
            "access_token": "test_token_123",  # nosec B105
            "shop_url": "test-store.myshopify.com",
            "client_id": "test_client_id",
        }
    }
    return context


# ============================================================================
# Helper Functions
# ============================================================================


def _context_with(credentials):
    """Build a minimal context carrying only the given credentials dict."""
    fake = MagicMock()
    fake.auth = {"credentials": credentials}
    return fake


class TestHelpers:
    """URL/header helpers: scheme stripping, API path, auth headers."""

    def test_get_shop_url_strips_https(self):
        ctx = _context_with({"shop_url": "https://test-store.myshopify.com/"})
        assert get_shop_url(ctx) == "test-store.myshopify.com"

    def test_get_shop_url_strips_http(self):
        ctx = _context_with({"shop_url": "http://test-store.myshopify.com/"})
        assert get_shop_url(ctx) == "test-store.myshopify.com"

    def test_get_shop_url_plain(self):
        ctx = _context_with({"shop_url": "test-store.myshopify.com"})
        assert get_shop_url(ctx) == "test-store.myshopify.com"

    def test_get_customer_api_url_correct_path(self):
        ctx = _context_with({"shop_url": "test-store.myshopify.com"})
        url = get_customer_api_url(ctx)
        assert url == "https://test-store.myshopify.com/customer/api/2024-10/graphql"
        assert "/account/" not in url

    def test_build_headers(self):
        ctx = _context_with({"access_token": "test_token"})  # nosec B105
        headers = build_headers(ctx)
        assert headers["Authorization"] == "Bearer test_token"
        assert headers["Content-Type"] == "application/json"


# ============================================================================
# OAuth Helpers
# ============================================================================


class TestOAuthHelpers:
    """PKCE generation and authorization-URL construction."""

    def test_generate_pkce_pair_returns_two_distinct_values(self):
        verifier, challenge = generate_pkce_pair()
        assert len(verifier) > 0
        assert len(challenge) > 0
        assert verifier != challenge

    def test_generate_pkce_pair_unique_each_call(self):
        first = generate_pkce_pair()
        second = generate_pkce_pair()
        assert first[0] != second[0]

    def test_build_authorization_url(self):
        url = build_authorization_url(
            shop_url="test-store.myshopify.com",
            client_id="test_client",
            redirect_uri="https://example.com/callback",
            scopes=["openid", "email"],
            state="test_state",
            code_challenge="test_challenge",
        )
        for expected_fragment in (
            "test-store.myshopify.com",
            "/authentication/oauth/authorize",
            "client_id=test_client",
            "openid",
            "email",
            "code_challenge=test_challenge",
            "code_challenge_method=S256",
        ):
            assert expected_fragment in url


# ============================================================================
# extract_edges
# ============================================================================


class TestExtractEdges:
    """extract_edges walks a dotted path and unwraps GraphQL edge nodes,
    returning [] for any missing/None link in the path or absent edges."""

    def test_extracts_nodes_from_edges(self):
        payload = {
            "orders": {
                "edges": [
                    {"node": {"id": "1", "name": "Order #1"}},
                    {"node": {"id": "2", "name": "Order #2"}},
                ],
            }
        }
        nodes = extract_edges(payload, "orders")
        assert len(nodes) == 2
        assert nodes[0]["id"] == "1"
        assert nodes[1]["name"] == "Order #2"

    def test_nested_path(self):
        payload = {
            "customer": {
                "addresses": {
                    "edges": [{"node": {"id": "addr_1"}}],
                }
            }
        }
        nodes = extract_edges(payload, "customer.addresses")
        assert len(nodes) == 1
        assert nodes[0]["id"] == "addr_1"

    def test_returns_empty_list_when_path_missing(self):
        assert extract_edges({}, "missing.path") == []

    def test_returns_empty_list_when_none_in_path(self):
        assert extract_edges({"customer": None}, "customer.addresses") == []

    def test_returns_empty_list_when_no_edges_key(self):
        assert extract_edges({"orders": {"something_else": []}}, "orders") == []


# ============================================================================
# Action: customer_get_profile
# ============================================================================


class TestGetProfile:
    @pytest.mark.asyncio
    async def test_success(self, mock_context):
        mock_context.fetch.return_value = {
            "data": {
                "customer": {
                    "id": "gid://shopify/Customer/123",
                    "email": "test@example.com",
                    "firstName": "Test",
                    "lastName": "User",
                }
            }
        }

        outcome = await shopify_customer.execute_action("customer_get_profile", {}, mock_context)

        body = outcome.result.data
        assert body["success"] is True
        assert body["customer"]["email"] == "test@example.com"

    @pytest.mark.asyncio
    async def test_graphql_error(self, mock_context):
        # A top-level GraphQL errors array must surface as a ValidationError.
        mock_context.fetch.return_value = {"errors": [{"message": "Unauthorized"}]}

        with pytest.raises(ValidationError):
            await shopify_customer.execute_action("customer_get_profile", {}, mock_context)


# ============================================================================
# Action: customer_list_addresses
# ============================================================================


class TestListAddresses:
    @pytest.mark.asyncio
    async def test_success(self, mock_context):
        mock_context.fetch.return_value = {
            "data": {
                "customer": {
                    "addresses": {
                        "edges": [
                            {
                                "cursor": "cursor1",
                                "node": {
                                    "id": "gid://shopify/CustomerAddress/1",
                                    "address1": "123 Main St",
                                    "city": "New York",
                                },
                            }
                        ],
                        "pageInfo": {
                            "hasNextPage": False,
                            "endCursor": "end_cursor_value",
                        },
                    },
                    "defaultAddress": {"id": "gid://shopify/CustomerAddress/1"},
                }
            }
        }

        outcome = await shopify_customer.execute_action("customer_list_addresses", {"first": 10}, mock_context)

        body = outcome.result.data
        assert body["success"] is True
        assert body["count"] == 1
        assert body["addresses"][0]["city"] == "New York"


# ============================================================================
# Action: customer_create_address
# ============================================================================


class TestCreateAddress:
    @pytest.mark.asyncio
    async def test_success(self, mock_context):
        mock_context.fetch.return_value = {
            "data": {
                "customerAddressCreate": {
                    "customerAddress": {
                        "id": "gid://shopify/CustomerAddress/new",
                        "address1": "456 Oak Ave",
                        "city": "Los Angeles",
                    },
                    "userErrors": [],
                }
            }
        }

        outcome = await shopify_customer.execute_action(
            "customer_create_address",
            {
                "address1": "456 Oak Ave",
                "city": "Los Angeles",
                "country": "US",
                "zip": "90001",
            },
            mock_context,
        )

        body = outcome.result.data
        assert body["success"] is True
        assert body["address"]["city"] == "Los Angeles"

    @pytest.mark.asyncio
    async def test_user_error(self, mock_context):
        mock_context.fetch.return_value = {
            "data": {
                "customerAddressCreate": {
                    "customerAddress": None,
                    "userErrors": [{"field": "zip", "message": "Invalid postal code"}],
                }
            }
        }

        with pytest.raises(ValidationError):
            await shopify_customer.execute_action(
                "customer_create_address",
                {
                    "address1": "456 Oak Ave",
                    "city": "LA",
                    "country": "US",
                    "zip": "invalid",
                },
                mock_context,
            )
# ============================================================================
# Action: customer_list_orders
# ============================================================================


class TestListOrders:
    @pytest.mark.asyncio
    async def test_success(self, mock_context):
        mock_context.fetch.return_value = {
            "data": {
                "customer": {
                    "orders": {
                        "edges": [
                            {
                                "cursor": "cursor1",
                                "node": {
                                    "id": "gid://shopify/Order/123",
                                    "orderNumber": 1001,
                                    "totalPrice": {
                                        "amount": "99.99",
                                        "currencyCode": "USD",
                                    },
                                },
                            }
                        ],
                        "pageInfo": {
                            "hasNextPage": False,
                            "endCursor": "end_cursor_value",
                        },
                    }
                }
            }
        }

        outcome = await shopify_customer.execute_action("customer_list_orders", {"first": 10}, mock_context)

        body = outcome.result.data
        assert body["success"] is True
        assert body["count"] == 1
        assert body["orders"][0]["orderNumber"] == 1001


# ============================================================================
# Action: customer_get_order
# ============================================================================


class TestGetOrder:
    @pytest.mark.asyncio
    async def test_success(self, mock_context):
        mock_context.fetch.return_value = {
            "data": {
                "customer": {
                    "order": {
                        "id": "gid://shopify/Order/456",
                        "orderNumber": 1002,
                        "fulfillmentStatus": "FULFILLED",
                        "totalPrice": {"amount": "49.99", "currencyCode": "USD"},
                    }
                }
            }
        }

        outcome = await shopify_customer.execute_action(
            "customer_get_order",
            {"order_id": "gid://shopify/Order/456"},
            mock_context,
        )

        body = outcome.result.data
        assert body["success"] is True
        assert body["order"]["orderNumber"] == 1002

    @pytest.mark.asyncio
    async def test_not_found(self, mock_context):
        # A null order in the GraphQL payload surfaces as ValidationError.
        mock_context.fetch.return_value = {"data": {"customer": {"order": None}}}

        with pytest.raises(ValidationError):
            await shopify_customer.execute_action(
                "customer_get_order",
                {"order_id": "gid://shopify/Order/999"},
                mock_context,
            )


# ============================================================================
# Action: customer_set_default_address
# ============================================================================


class TestSetDefaultAddress:
    @pytest.mark.asyncio
    async def test_success(self, mock_context):
        mock_context.fetch.return_value = {
            "data": {
                "customerDefaultAddressUpdate": {
                    "customer": {"defaultAddress": {"id": "gid://shopify/CustomerAddress/1"}},
                    "userErrors": [],
                }
            }
        }

        outcome = await shopify_customer.execute_action(
            "customer_set_default_address",
            {"address_id": "gid://shopify/CustomerAddress/1"},
            mock_context,
        )

        body = outcome.result.data
        assert body["success"] is True
        assert body["default_address_id"] == "gid://shopify/CustomerAddress/1"

    @pytest.mark.asyncio
    async def test_user_error(self, mock_context):
        # Unlike create_address, this action reports userErrors via
        # success=False + message instead of raising.
        mock_context.fetch.return_value = {
            "data": {
                "customerDefaultAddressUpdate": {
                    "customer": {"defaultAddress": None},
                    "userErrors": [{"field": "addressId", "message": "Address not found"}],
                }
            }
        }

        outcome = await shopify_customer.execute_action(
            "customer_set_default_address",
            {"address_id": "gid://shopify/CustomerAddress/999"},
            mock_context,
        )

        body = outcome.result.data
        assert body["success"] is False
        assert "Address not found" in body["message"]


# ============================================================================
# Action: customer_generate_oauth_url
# ============================================================================


class TestGenerateOAuthUrl:
    @pytest.mark.asyncio
    async def test_success(self, mock_context):
        outcome = await shopify_customer.execute_action(
            "customer_generate_oauth_url",
            {
                "client_id": "test_client",
                "redirect_uri": "https://example.com/callback",
            },
            mock_context,
        )

        body = outcome.result.data
        assert body["success"] is True
        assert "authorization_url" in body
        assert "code_verifier" in body
        assert "state" in body
        assert "/authentication/oauth/authorize" in body["authorization_url"]

    @pytest.mark.asyncio
    async def test_missing_client_id(self, mock_context):
        with pytest.raises(ValidationError):
            await shopify_customer.execute_action(
                "customer_generate_oauth_url",
                {"redirect_uri": "https://example.com/callback"},
                mock_context,
            )

    @pytest.mark.asyncio
    async def test_missing_redirect_uri(self, mock_context):
        with pytest.raises(ValidationError):
            await shopify_customer.execute_action(
                "customer_generate_oauth_url",
                {"client_id": "test_client"},
                mock_context,
            )


# ============================================================================
# Error Handling
# ============================================================================


class TestErrorHandling:
    """Per-action behaviour when the underlying fetch raises: some actions
    re-raise as ValidationError, others report success=False."""

    @pytest.mark.asyncio
    async def test_get_profile_fetch_exception(self, mock_context):
        mock_context.fetch.side_effect = RuntimeError("Connection refused")

        with pytest.raises(ValidationError):
            await shopify_customer.execute_action("customer_get_profile", {}, mock_context)

    @pytest.mark.asyncio
    async def test_list_orders_fetch_exception(self, mock_context):
        mock_context.fetch.side_effect = RuntimeError("Timeout")

        outcome = await shopify_customer.execute_action("customer_list_orders", {"first": 10}, mock_context)

        assert outcome.result.data["success"] is False

    @pytest.mark.asyncio
    async def test_create_address_fetch_exception(self, mock_context):
        mock_context.fetch.side_effect = RuntimeError("Network error")

        with pytest.raises(ValidationError):
            await shopify_customer.execute_action(
                "customer_create_address",
                {
                    "address1": "123 Main St",
                    "city": "Test",
                    "country": "US",
                    "zip": "12345",
                },
                mock_context,
            )

    @pytest.mark.asyncio
    async def test_set_default_address_fetch_exception(self, mock_context):
        mock_context.fetch.side_effect = RuntimeError("Service unavailable")

        outcome = await shopify_customer.execute_action(
            "customer_set_default_address",
            {"address_id": "gid://shopify/CustomerAddress/1"},
            mock_context,
        )

        assert outcome.result.data["success"] is False

    @pytest.mark.asyncio
    async def test_get_order_fetch_exception(self, mock_context):
        mock_context.fetch.side_effect = RuntimeError("Bad gateway")

        with pytest.raises(ValidationError):
            await shopify_customer.execute_action(
                "customer_get_order",
                {"order_id": "gid://shopify/Order/123"},
                mock_context,
            )