diff --git a/README.md b/README.md index 3421bfc5..117a288d 100644 --- a/README.md +++ b/README.md @@ -251,6 +251,10 @@ Supports basic HTTP authentication and Bearer token authentication via the SDK. [float](float): Comprehensive resource management and project scheduling integration with Float API for team capacity planning, time tracking, and project coordination. Supports full CRUD operations for team members (people) with roles, departments, rates, and availability management. Includes complete project lifecycle management with client associations, budgets, timelines, and team assignments. Features task/allocation scheduling across team members, time off management with leave types, logged time tracking with billable hours, and client relationship management. Provides access to organizational structure (departments, roles), account settings, project stages, phases, milestones, and expenses. Includes comprehensive reporting capabilities (people utilization, project analytics) with date range filtering. Features 60 actions covering all Float API v3 endpoints, custom API key authentication with required User-Agent header, connected account information display, pagination support (up to 200 items per page), rate limiting awareness (200 GET/min, 100 non-GET/min), field filtering, sorting, modified-since sync capabilities, and ActionResult return type for cost tracking. Ideal for resource planning, capacity management, project scheduling, time tracking workflows, and team utilization analysis. +### Salesforce + +[salesforce](salesforce): Salesforce is the world's leading CRM platform for managing sales pipelines, customer relationships, and activity tracking. This integration provides 7 focused actions covering record search via SOQL, single-record retrieval and update, task and event listing with filters, and human-readable summaries of task and event records. Supports OAuth 2.0 (platform) authentication. 
Ideal for sales automation, CRM data updates, and surfacing task and event activity within workflows. + ### Shopify Admin [shopify-admin](shopify-admin): Integrates with the Shopify Admin API for backend store management. Currently enables comprehensive customer lifecycle management including searching, creating, updating, and deleting customer records via the GraphQL API. diff --git a/salesforce/README.md b/salesforce/README.md new file mode 100644 index 00000000..f4a1b668 --- /dev/null +++ b/salesforce/README.md @@ -0,0 +1,55 @@ +# Salesforce + +Salesforce is the world's leading CRM platform, used to manage sales pipelines, customer relationships, tasks, events, and more. This integration provides 7 focused actions for searching and updating records, and for retrieving summaries of task and event activity. + +## Auth Setup + +This integration uses **OAuth 2.0** via a Salesforce Connected App. + +1. Log in to Salesforce and go to **Setup → App Manager → New Connected App**. +2. Enable **OAuth Settings** and set a callback URL. +3. Add the **`api`** scope under Selected OAuth Scopes. +4. Save and copy the **Consumer Key** (Client ID) and **Consumer Secret**. +5. Use those credentials to connect via the Autohive platform OAuth flow. 
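## SOQL on the wire

The `search_records` action takes a raw SOQL string and passes it to the REST query endpoint. The sketch below (stdlib only; `build_query_url`, `soql_quote`, and the instance URL are illustrative, not part of this integration) shows how a SOQL query is URL-encoded onto that endpoint, and how single quotes in user input should be escaped before being interpolated into a SOQL string literal:

```python
from urllib.parse import urlencode

API_VERSION = "v62.0"  # matches the version pinned by this integration

def build_query_url(instance_url: str, soql: str) -> str:
    """Build the REST query URL a SOQL search is sent to."""
    base = f"{instance_url.rstrip('/')}/services/data/{API_VERSION}"
    return f"{base}/query?{urlencode({'q': soql})}"

def soql_quote(value: str) -> str:
    """Escape single quotes so user input cannot break out of a SOQL string literal."""
    return value.replace("'", "\\'")

last_name = soql_quote("O'Brien")
url = build_query_url(
    "https://yourinstance.my.salesforce.com",  # hypothetical instance URL
    f"SELECT Id, Name FROM Contact WHERE LastName = '{last_name}' LIMIT 10",
)
print(url)  # .../services/data/v62.0/query?q=SELECT+Id%2C+Name+FROM+...
```

Sending a GET to that URL with an `Authorization: Bearer <token>` header returns JSON containing `records`, `totalSize`, and `done`, which is what `search_records` surfaces as its outputs.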
+
+## Actions
+
+| Action | Description | Key Inputs | Key Outputs |
+|--------|-------------|------------|-------------|
+| `search_records` | Run a SOQL query against any object | `soql` | `records`, `total_size` |
+| `get_record` | Fetch a single record by ID | `object_type`, `record_id` | `record` |
+| `update_record` | Update fields on a record | `object_type`, `record_id`, `fields` | `result`, `record_id` |
+| `list_tasks` | List Task records with optional filters | `status`, `due_date_from`, `due_date_to`, `limit` | `tasks`, `total_size` |
+| `list_events` | List Event records with optional date filters | `start_date_from`, `start_date_to`, `limit` | `events`, `total_size` |
+| `get_task_summary` | Get a readable summary of a Task | `task_id` | `summary`, `task` |
+| `get_event_summary` | Get a readable summary of an Event | `event_id` | `summary`, `event` |
+
+## API Info
+
+- **Base URL:** `{instance_url}/services/data/v62.0/` (the OAuth `instance_url` already includes the `https://` scheme)
+- **Docs:** https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/
+- **Rate limits:** Typically 15,000 API calls per rolling 24-hour window (varies by Salesforce edition)
+- **Query endpoint:** `GET /query?q={SOQL}`
+- **Record endpoint:** `GET /sobjects/{ObjectType}/{Id}`
+
+## Running Tests
+
+```bash
+cd salesforce/tests
+export SALESFORCE_TOKEN=your_access_token
+export SALESFORCE_INSTANCE_URL=https://yourinstance.salesforce.com
+# Optional: set record IDs to test get/update actions
+export SALESFORCE_RECORD_ID=003XXXXXXXXXXXXXXX
+export SALESFORCE_TASK_ID=00TXXXXXXXXXXXXXXX
+export SALESFORCE_EVENT_ID=00UXXXXXXXXXXXXXXX
+python test_salesforce.py
+```
+
+## Troubleshooting
+
+| Error | Cause | Fix |
+|-------|-------|-----|
+| `401 Unauthorized` | Expired or invalid access token | Re-authenticate via the Autohive OAuth flow |
+| `400 MALFORMED_QUERY` | Invalid SOQL syntax | Check field names; quote strings with single quotes |
+| `404 NOT_FOUND` | Record ID doesn't exist or wrong object type | Verify ID and object
type match | +| `REQUEST_LIMIT_EXCEEDED` | Daily API call limit hit | Wait until the 24-hour window resets | diff --git a/salesforce/__init__.py b/salesforce/__init__.py new file mode 100644 index 00000000..511c4e4d --- /dev/null +++ b/salesforce/__init__.py @@ -0,0 +1,3 @@ +from .salesforce import salesforce + +__all__ = ["salesforce"] diff --git a/salesforce/config.json b/salesforce/config.json new file mode 100644 index 00000000..465ab3a5 --- /dev/null +++ b/salesforce/config.json @@ -0,0 +1,247 @@ +{ + "name": "Salesforce", + "display_name": "Salesforce", + "version": "1.0.0", + "description": "Salesforce CRM integration for searching and updating records, and summarising task and event activity.", + "entry_point": "salesforce.py", + "auth": { + "type": "platform", + "provider": "salesforce", + "scopes": ["api"] + }, + "actions": { + "search_records": { + "display_name": "Search Records", + "description": "Run a SOQL query to search any Salesforce object (e.g. Contact, Lead, Opportunity, Account). Returns matching records.", + "input_schema": { + "type": "object", + "properties": { + "soql": { + "type": "string", + "description": "A valid SOQL query string, e.g. SELECT Id, Name, Email FROM Contact WHERE LastName = 'Smith' LIMIT 10" + } + }, + "required": ["soql"] + }, + "output_schema": { + "type": "object", + "properties": { + "result": { "type": "boolean" }, + "records": { + "type": "array", + "items": { "type": "object" }, + "description": "List of matching Salesforce records" + }, + "total_size": { + "type": "integer", + "description": "Total number of records matched" + }, + "done": { + "type": "boolean", + "description": "Whether all results have been returned" + } + } + } + }, + "get_record": { + "display_name": "Get Record", + "description": "Retrieve a single Salesforce record by its ID and object type (e.g. 
Contact, Lead, Opportunity).", + "input_schema": { + "type": "object", + "properties": { + "object_type": { + "type": "string", + "description": "Salesforce object type, e.g. Contact, Lead, Account, Opportunity" + }, + "record_id": { + "type": "string", + "description": "The Salesforce record ID (15 or 18 character)" + }, + "fields": { + "type": "string", + "description": "Comma-separated list of fields to return. If omitted, all fields are returned." + } + }, + "required": ["object_type", "record_id"] + }, + "output_schema": { + "type": "object", + "properties": { + "result": { "type": "boolean" }, + "record": { + "type": "object", + "description": "The Salesforce record fields" + } + } + } + }, + "update_record": { + "display_name": "Update Record", + "description": "Update one or more fields on an existing Salesforce record by ID and object type.", + "input_schema": { + "type": "object", + "properties": { + "object_type": { + "type": "string", + "description": "Salesforce object type, e.g. Contact, Lead, Account, Opportunity" + }, + "record_id": { + "type": "string", + "description": "The Salesforce record ID to update" + }, + "fields": { + "type": "object", + "description": "Key-value pairs of fields to update, e.g. {\"Phone\": \"0400000000\", \"Title\": \"Manager\"}" + } + }, + "required": ["object_type", "record_id", "fields"] + }, + "output_schema": { + "type": "object", + "properties": { + "result": { "type": "boolean" }, + "record_id": { "type": "string" }, + "object_type": { "type": "string" } + } + } + }, + "list_tasks": { + "display_name": "List Tasks", + "description": "List Salesforce Task records with optional filters for status, due date, and assigned user.", + "input_schema": { + "type": "object", + "properties": { + "status": { + "type": "string", + "description": "Filter by task status, e.g. 
Not Started, In Progress, Completed, Waiting on someone else, Deferred" + }, + "assigned_to_id": { + "type": "string", + "description": "Filter by assigned user ID (OwnerId)" + }, + "due_date_from": { + "type": "string", + "description": "Filter tasks due on or after this date (YYYY-MM-DD)" + }, + "due_date_to": { + "type": "string", + "description": "Filter tasks due on or before this date (YYYY-MM-DD)" + }, + "limit": { + "type": "integer", + "description": "Maximum number of tasks to return (default 25, max 200)", + "default": 25 + } + }, + "required": [] + }, + "output_schema": { + "type": "object", + "properties": { + "result": { "type": "boolean" }, + "tasks": { + "type": "array", + "items": { "type": "object" }, + "description": "List of Task records" + }, + "total_size": { "type": "integer" } + } + } + }, + "list_events": { + "display_name": "List Events", + "description": "List Salesforce Event (calendar) records with optional date range filters.", + "input_schema": { + "type": "object", + "properties": { + "start_date_from": { + "type": "string", + "description": "Return events starting on or after this date (YYYY-MM-DD)" + }, + "start_date_to": { + "type": "string", + "description": "Return events starting on or before this date (YYYY-MM-DD)" + }, + "assigned_to_id": { + "type": "string", + "description": "Filter by assigned user ID (OwnerId)" + }, + "limit": { + "type": "integer", + "description": "Maximum number of events to return (default 25, max 200)", + "default": 25 + } + }, + "required": [] + }, + "output_schema": { + "type": "object", + "properties": { + "result": { "type": "boolean" }, + "events": { + "type": "array", + "items": { "type": "object" }, + "description": "List of Event records" + }, + "total_size": { "type": "integer" } + } + } + }, + "get_task_summary": { + "display_name": "Get Task Summary", + "description": "Retrieve a single Salesforce Task record by ID and return a human-readable summary of its details.", + "input_schema": { + 
"type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "The Salesforce Task record ID" + } + }, + "required": ["task_id"] + }, + "output_schema": { + "type": "object", + "properties": { + "result": { "type": "boolean" }, + "summary": { + "type": "string", + "description": "Human-readable summary of the task" + }, + "task": { + "type": "object", + "description": "Raw task record fields" + } + } + } + }, + "get_event_summary": { + "display_name": "Get Event Summary", + "description": "Retrieve a single Salesforce Event record by ID and return a human-readable summary of its details.", + "input_schema": { + "type": "object", + "properties": { + "event_id": { + "type": "string", + "description": "The Salesforce Event record ID" + } + }, + "required": ["event_id"] + }, + "output_schema": { + "type": "object", + "properties": { + "result": { "type": "boolean" }, + "summary": { + "type": "string", + "description": "Human-readable summary of the event" + }, + "event": { + "type": "object", + "description": "Raw event record fields" + } + } + } + } + } +} diff --git a/salesforce/icon.png b/salesforce/icon.png new file mode 100644 index 00000000..9a6cb7c4 Binary files /dev/null and b/salesforce/icon.png differ diff --git a/salesforce/requirements.txt b/salesforce/requirements.txt new file mode 100644 index 00000000..1af9591f --- /dev/null +++ b/salesforce/requirements.txt @@ -0,0 +1 @@ +autohive-integrations-sdk~=2.0.0 diff --git a/salesforce/salesforce.py b/salesforce/salesforce.py new file mode 100644 index 00000000..70b9a9c5 --- /dev/null +++ b/salesforce/salesforce.py @@ -0,0 +1,288 @@ +from autohive_integrations_sdk import ( + Integration, + ExecutionContext, + ActionHandler, + ActionResult, +) +from typing import Any, Dict +import os + +salesforce = Integration.load() + +API_VERSION = "v62.0" + + +def _base_url(instance_url: str) -> str: + return f"{instance_url.rstrip('/')}/services/data/{API_VERSION}" + + +def _headers(token: str) -> 
Dict[str, str]: + return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + +def _get_token_and_instance(context: ExecutionContext): + credentials = context.auth.get("credentials", {}) + token = credentials.get("access_token", "") + instance_url = ( + credentials.get("instance_url") + or context.metadata.get("instance_url") + or os.environ.get("SALESFORCE_INSTANCE_URL", "") + ) + if not instance_url: + raise ValueError("Salesforce instance_url not found in credentials or metadata. Please reconnect.") + return token, instance_url + + +@salesforce.action("search_records") +class SearchRecordsAction(ActionHandler): + async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult: + try: + token, instance_url = _get_token_and_instance(context) + response = await context.fetch( + f"{_base_url(instance_url)}/query", + method="GET", + headers=_headers(token), + params={"q": inputs["soql"]}, + ) + return ActionResult( + data={ + "result": True, + "records": response.data.get("records", []), + "total_size": response.data.get("totalSize", 0), + "done": response.data.get("done", True), + }, + cost_usd=0.0, + ) + except Exception as e: + return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0) + + +@salesforce.action("get_record") +class GetRecordAction(ActionHandler): + async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult: + try: + token, instance_url = _get_token_and_instance(context) + object_type = inputs["object_type"] + record_id = inputs["record_id"] + url = f"{_base_url(instance_url)}/sobjects/{object_type}/{record_id}" + + params = {} + if inputs.get("fields"): + params["fields"] = inputs["fields"] + + response = await context.fetch(url, method="GET", headers=_headers(token), params=params) + return ActionResult(data={"result": True, "record": response.data}, cost_usd=0.0) + except Exception as e: + return ActionResult(data={"result": False, "error": str(e)}, 
cost_usd=0.0) + + +@salesforce.action("update_record") +class UpdateRecordAction(ActionHandler): + async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult: + try: + token, instance_url = _get_token_and_instance(context) + object_type = inputs["object_type"] + record_id = inputs["record_id"] + url = f"{_base_url(instance_url)}/sobjects/{object_type}/{record_id}" + + await context.fetch(url, method="PATCH", headers=_headers(token), json=inputs["fields"]) + return ActionResult( + data={ + "result": True, + "record_id": record_id, + "object_type": object_type, + }, + cost_usd=0.0, + ) + except Exception as e: + return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0) + + +def _build_task_query( # nosec B608 + status=None, + assigned_to_id=None, + due_date_from=None, + due_date_to=None, + limit=25, +) -> str: + limit = min(int(limit), 200) + conditions = [] + if status: + safe_status = status.replace("'", "\\'") + conditions.append(f"Status = '{safe_status}'") + if assigned_to_id: + conditions.append(f"OwnerId = '{assigned_to_id}'") + if due_date_from: + conditions.append(f"ActivityDate >= {due_date_from}") + if due_date_to: + conditions.append(f"ActivityDate <= {due_date_to}") + + where = f" WHERE {' AND '.join(conditions)}" if conditions else "" + fields = ( + "Id, Subject, Status, Priority, ActivityDate, Description, " + "OwnerId, WhoId, WhatId, CreatedDate, LastModifiedDate" + ) + return f"SELECT {fields} FROM Task{where} ORDER BY ActivityDate DESC LIMIT {limit}" # nosec B608 + + +def _build_event_query( # nosec B608 + start_date_from=None, + start_date_to=None, + assigned_to_id=None, + limit=25, +) -> str: + limit = min(int(limit), 200) + conditions = [] + if start_date_from: + conditions.append(f"StartDateTime >= {start_date_from}T00:00:00Z") + if start_date_to: + conditions.append(f"StartDateTime <= {start_date_to}T23:59:59Z") + if assigned_to_id: + conditions.append(f"OwnerId = '{assigned_to_id}'") + + where = 
f" WHERE {' AND '.join(conditions)}" if conditions else "" + fields = ( + "Id, Subject, StartDateTime, EndDateTime, Location, Description, " + "OwnerId, WhoId, WhatId, IsAllDayEvent, CreatedDate" + ) + return f"SELECT {fields} FROM Event{where} ORDER BY StartDateTime DESC LIMIT {limit}" # nosec B608 + + +@salesforce.action("list_tasks") +class ListTasksAction(ActionHandler): + async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult: + try: + token, instance_url = _get_token_and_instance(context) + soql = _build_task_query( + status=inputs.get("status"), + assigned_to_id=inputs.get("assigned_to_id"), + due_date_from=inputs.get("due_date_from"), + due_date_to=inputs.get("due_date_to"), + limit=inputs.get("limit", 25), + ) + response = await context.fetch( + f"{_base_url(instance_url)}/query", + method="GET", + headers=_headers(token), + params={"q": soql}, + ) + return ActionResult( + data={ + "result": True, + "tasks": response.data.get("records", []), + "total_size": response.data.get("totalSize", 0), + }, + cost_usd=0.0, + ) + except Exception as e: + return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0) + + +@salesforce.action("list_events") +class ListEventsAction(ActionHandler): + async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult: + try: + token, instance_url = _get_token_and_instance(context) + soql = _build_event_query( + start_date_from=inputs.get("start_date_from"), + start_date_to=inputs.get("start_date_to"), + assigned_to_id=inputs.get("assigned_to_id"), + limit=inputs.get("limit", 25), + ) + response = await context.fetch( + f"{_base_url(instance_url)}/query", + method="GET", + headers=_headers(token), + params={"q": soql}, + ) + return ActionResult( + data={ + "result": True, + "events": response.data.get("records", []), + "total_size": response.data.get("totalSize", 0), + }, + cost_usd=0.0, + ) + except Exception as e: + return ActionResult(data={"result": 
False, "error": str(e)}, cost_usd=0.0) + + +def _summarise_task(task: Dict[str, Any]) -> str: + subject = task.get("Subject") or "No subject" + status = task.get("Status") or "Unknown" + priority = task.get("Priority") or "Normal" + due = task.get("ActivityDate") or "No due date" + description = task.get("Description") or "No description" + return f"Task: {subject}\nStatus: {status} | Priority: {priority} | Due: {due}\nDescription: {description}" + + +def _summarise_event(event: Dict[str, Any]) -> str: + subject = event.get("Subject") or "No subject" + start = event.get("StartDateTime") or "Unknown start" + end = event.get("EndDateTime") or "Unknown end" + location = event.get("Location") or "No location" + description = event.get("Description") or "No description" + all_day = " (All day)" if event.get("IsAllDayEvent") else "" + return f"Event: {subject}{all_day}\nStart: {start} | End: {end} | Location: {location}\nDescription: {description}" + + +@salesforce.action("get_task_summary") +class GetTaskSummaryAction(ActionHandler): + async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult: + try: + token, instance_url = _get_token_and_instance(context) + task_id = inputs["task_id"] + fields = ( + "Id, Subject, Status, Priority, ActivityDate, Description, " + "OwnerId, WhoId, WhatId, CreatedDate, LastModifiedDate" + ) + soql = f"SELECT {fields} FROM Task WHERE Id = '{task_id}' LIMIT 1" # nosec B608 + response = await context.fetch( + f"{_base_url(instance_url)}/query", + method="GET", + headers=_headers(token), + params={"q": soql}, + ) + records = response.data.get("records", []) + if not records: + return ActionResult(data={"result": False, "error": "Task not found"}, cost_usd=0.0) + task = records[0] + return ActionResult( + data={"result": True, "summary": _summarise_task(task), "task": task}, + cost_usd=0.0, + ) + except Exception as e: + return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0) + + 
+@salesforce.action("get_event_summary") +class GetEventSummaryAction(ActionHandler): + async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult: + try: + token, instance_url = _get_token_and_instance(context) + event_id = inputs["event_id"] + fields = ( + "Id, Subject, StartDateTime, EndDateTime, Location, Description, " + "OwnerId, WhoId, WhatId, IsAllDayEvent, CreatedDate" + ) + soql = f"SELECT {fields} FROM Event WHERE Id = '{event_id}' LIMIT 1" # nosec B608 + response = await context.fetch( + f"{_base_url(instance_url)}/query", + method="GET", + headers=_headers(token), + params={"q": soql}, + ) + records = response.data.get("records", []) + if not records: + return ActionResult(data={"result": False, "error": "Event not found"}, cost_usd=0.0) + event = records[0] + return ActionResult( + data={ + "result": True, + "summary": _summarise_event(event), + "event": event, + }, + cost_usd=0.0, + ) + except Exception as e: + return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0) diff --git a/salesforce/tests/__init__.py b/salesforce/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/salesforce/tests/conftest.py b/salesforce/tests/conftest.py new file mode 100644 index 00000000..e669d95e --- /dev/null +++ b/salesforce/tests/conftest.py @@ -0,0 +1,4 @@ +import os +import sys + +sys.path.insert(0, os.path.dirname(__file__)) diff --git a/salesforce/tests/context.py b/salesforce/tests/context.py new file mode 100644 index 00000000..6b2fcba9 --- /dev/null +++ b/salesforce/tests/context.py @@ -0,0 +1,8 @@ +import os +import sys + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +deps_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../dependencies")) +sys.path.insert(0, parent_dir) +sys.path.insert(0, deps_dir) +from salesforce import salesforce # noqa diff --git a/salesforce/tests/test_salesforce.py b/salesforce/tests/test_salesforce.py new file mode 
100644
index 00000000..3ac3fde4
--- /dev/null
+++ b/salesforce/tests/test_salesforce.py
@@ -0,0 +1,101 @@
+"""
+Integration tests for Salesforce — require real API credentials.
+Run with: pytest salesforce/tests/test_salesforce.py -m integration
+"""
+
+import asyncio
+import os
+
+import pytest
+from autohive_integrations_sdk import ExecutionContext, IntegrationResult
+
+from context import salesforce  # noqa
+
+pytestmark = pytest.mark.integration
+
+# Read credentials from the environment only: under pytest, sys.argv holds
+# pytest's own arguments, so sys.argv[1] would never be an access token.
+ACCESS_TOKEN = os.getenv("SALESFORCE_TOKEN", "")
+INSTANCE_URL = os.getenv("SALESFORCE_INSTANCE_URL", "https://login.salesforce.com")
+TEST_AUTH = {"credentials": {"access_token": ACCESS_TOKEN, "instance_url": INSTANCE_URL}}
+
+RECORD_ID = os.getenv("SALESFORCE_RECORD_ID", "")
+TASK_ID = os.getenv("SALESFORCE_TASK_ID", "")
+EVENT_ID = os.getenv("SALESFORCE_EVENT_ID", "")
+
+
+async def test_search_records():
+    inputs = {"soql": "SELECT Id, Name FROM Contact LIMIT 5"}
+    async with ExecutionContext(auth=TEST_AUTH) as context:
+        result = await salesforce.execute_action("search_records", inputs, context)
+        assert isinstance(result, IntegrationResult)
+        data = result.result.data
+        assert data.get("result") is True
+        print(f"[OK] search_records: {len(data.get('records', []))} record(s)")
+
+
+async def test_list_tasks():
+    inputs = {"limit": 5}
+    async with ExecutionContext(auth=TEST_AUTH) as context:
+        result = await salesforce.execute_action("list_tasks", inputs, context)
+        assert isinstance(result, IntegrationResult)
+        data = result.result.data
+        assert data.get("result") is True
+        print(f"[OK] list_tasks: {len(data.get('tasks', []))} task(s)")
+
+
+async def test_list_events():
+    inputs = {"limit": 5}
+    async with ExecutionContext(auth=TEST_AUTH) as context:
+        result = await salesforce.execute_action("list_events", inputs, context)
+        assert isinstance(result, IntegrationResult)
+        data = result.result.data
+        assert data.get("result") is True
+        print(f"[OK] list_events:
{len(data.get('events', []))} event(s)") + + +async def test_get_record(): + if not RECORD_ID: + print("[SKIP] get_record: set SALESFORCE_RECORD_ID to test") + return + inputs = {"object_type": "Contact", "record_id": RECORD_ID} + async with ExecutionContext(auth=TEST_AUTH) as context: + result = await salesforce.execute_action("get_record", inputs, context) + assert isinstance(result, IntegrationResult) + data = result.result.data + assert data.get("result") is True + print(f"[OK] get_record: {data.get('record', {}).get('Id')}") + + +async def test_get_task_summary(): + if not TASK_ID: + print("[SKIP] get_task_summary: set SALESFORCE_TASK_ID to test") + return + inputs = {"task_id": TASK_ID} + async with ExecutionContext(auth=TEST_AUTH) as context: + result = await salesforce.execute_action("get_task_summary", inputs, context) + assert isinstance(result, IntegrationResult) + data = result.result.data + assert data.get("result") is True + print(f"[OK] get_task_summary:\n{data.get('summary')}") + + +async def test_get_event_summary(): + if not EVENT_ID: + print("[SKIP] get_event_summary: set SALESFORCE_EVENT_ID to test") + return + inputs = {"event_id": EVENT_ID} + async with ExecutionContext(auth=TEST_AUTH) as context: + result = await salesforce.execute_action("get_event_summary", inputs, context) + assert isinstance(result, IntegrationResult) + data = result.result.data + assert data.get("result") is True + print(f"[OK] get_event_summary:\n{data.get('summary')}") + + +if __name__ == "__main__": + asyncio.run(test_search_records()) + asyncio.run(test_list_tasks()) + asyncio.run(test_list_events()) + asyncio.run(test_get_record()) + asyncio.run(test_get_task_summary()) + asyncio.run(test_get_event_summary()) diff --git a/salesforce/tests/test_salesforce_unit.py b/salesforce/tests/test_salesforce_unit.py new file mode 100644 index 00000000..33501dbf --- /dev/null +++ b/salesforce/tests/test_salesforce_unit.py @@ -0,0 +1,549 @@ +""" +Unit tests for Salesforce 
integration. + +All tests are fully mocked — no real API credentials required. +Covers all 7 action handlers plus helper functions. +""" + +import json +import os +import sys + +import pytest +from unittest.mock import AsyncMock, MagicMock + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) + +from autohive_integrations_sdk import FetchResponse # noqa: E402 + +from salesforce.salesforce import ( # noqa: E402 + SearchRecordsAction, + GetRecordAction, + UpdateRecordAction, + ListTasksAction, + ListEventsAction, + GetTaskSummaryAction, + GetEventSummaryAction, + _build_task_query, + _build_event_query, + _summarise_task, + _summarise_event, + salesforce as salesforce_integration, +) + +pytestmark = pytest.mark.unit + +CONFIG_PATH = os.path.join(os.path.dirname(__file__), "..", "config.json") + +TEST_TOKEN = "test_access_token" # nosec B105 +TEST_INSTANCE = "https://test.salesforce.com" +TEST_AUTH = {"credentials": {"access_token": TEST_TOKEN, "instance_url": TEST_INSTANCE}} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_fetch_response(data: dict) -> MagicMock: + resp = MagicMock(spec=FetchResponse) + resp.data = data + return resp + + +@pytest.fixture +def mock_context(): + ctx = MagicMock(name="ExecutionContext") + ctx.fetch = AsyncMock(name="fetch") + ctx.auth = TEST_AUTH + return ctx + + +# --------------------------------------------------------------------------- +# Config validation +# --------------------------------------------------------------------------- + + +class TestConfigValidation: + def test_actions_match_handlers(self): + with open(CONFIG_PATH) as f: + config = json.load(f) + + defined = set(config.get("actions", {}).keys()) + registered = set(salesforce_integration._action_handlers.keys()) + + missing = defined - 
registered
+        extra = registered - defined
+
+        assert not missing, f"Missing handlers: {missing}"
+        assert not extra, f"Extra handlers without config: {extra}"
+
+    def test_auth_type_is_platform(self):
+        with open(CONFIG_PATH) as f:
+            config = json.load(f)
+        assert config["auth"]["type"] == "platform"
+        assert config["auth"]["provider"] == "salesforce"
+
+    def test_all_actions_have_output_schema(self):
+        with open(CONFIG_PATH) as f:
+            config = json.load(f)
+        for name, action in config["actions"].items():
+            assert "output_schema" in action, f"Action '{name}' missing output_schema"
+
+    def test_all_actions_have_result_in_output(self):
+        with open(CONFIG_PATH) as f:
+            config = json.load(f)
+        for name, action in config["actions"].items():
+            props = action.get("output_schema", {}).get("properties", {})
+            assert "result" in props, f"Action '{name}' output_schema missing 'result' field"
+
+
+# ---------------------------------------------------------------------------
+# Helper function tests
+# ---------------------------------------------------------------------------
+
+
+class TestBuildTaskQuery:
+    def test_no_filters(self):
+        q = _build_task_query()
+        assert "FROM Task" in q
+        assert "WHERE" not in q
+        assert "LIMIT 25" in q
+
+    def test_status_filter(self):
+        q = _build_task_query(status="Completed")
+        assert "Status = 'Completed'" in q
+        assert "WHERE" in q
+
+    def test_status_escapes_single_quote(self):
+        q = _build_task_query(status="Won't do")
+        assert "Won\\'t do" in q
+
+    def test_assigned_to_filter(self):
+        q = _build_task_query(assigned_to_id="005XXXX")
+        assert "OwnerId = '005XXXX'" in q
+
+    def test_due_date_range(self):
+        q = _build_task_query(due_date_from="2026-01-01", due_date_to="2026-12-31")
+        assert "ActivityDate >= 2026-01-01" in q
+        assert "ActivityDate <= 2026-12-31" in q
+
+    def test_limit_capped_at_200(self):
+        q = _build_task_query(limit=999)
+        assert "LIMIT 200" in q
+
+    def test_custom_limit(self):
+        q = _build_task_query(limit=10)
+        assert "LIMIT 10" in q
+
+    def test_multiple_conditions_use_and(self):
+        q = _build_task_query(status="Open", assigned_to_id="005XXX")
+        assert " AND " in q
+
+    def test_required_fields_in_select(self):
+        q = _build_task_query()
+        for field in ["Id", "Subject", "Status", "Priority", "ActivityDate", "Description"]:
+            assert field in q
+
+
+class TestBuildEventQuery:
+    def test_no_filters(self):
+        q = _build_event_query()
+        assert "FROM Event" in q
+        assert "WHERE" not in q
+        assert "LIMIT 25" in q
+
+    def test_start_date_range(self):
+        q = _build_event_query(start_date_from="2026-01-01", start_date_to="2026-01-31")
+        assert "StartDateTime >= 2026-01-01T00:00:00Z" in q
+        assert "StartDateTime <= 2026-01-31T23:59:59Z" in q
+
+    def test_assigned_to_filter(self):
+        q = _build_event_query(assigned_to_id="005XXX")
+        assert "OwnerId = '005XXX'" in q
+
+    def test_limit_capped_at_200(self):
+        q = _build_event_query(limit=500)
+        assert "LIMIT 200" in q
+
+    def test_required_fields_in_select(self):
+        q = _build_event_query()
+        for field in ["Id", "Subject", "StartDateTime", "EndDateTime", "Location", "Description"]:
+            assert field in q
+
+
+class TestSummariseTask:
+    def test_full_task(self):
+        task = {
+            "Subject": "Follow up call",
+            "Status": "Not Started",
+            "Priority": "High",
+            "ActivityDate": "2026-05-01",
+            "Description": "Call the client to follow up on the proposal.",
+        }
+        summary = _summarise_task(task)
+        assert "Follow up call" in summary
+        assert "Not Started" in summary
+        assert "High" in summary
+        assert "2026-05-01" in summary
+        assert "Call the client" in summary
+
+    def test_missing_fields_use_defaults(self):
+        summary = _summarise_task({})
+        assert "No subject" in summary
+        assert "Unknown" in summary
+        assert "No due date" in summary
+        assert "No description" in summary
+
+
+class TestSummariseEvent:
+    def test_full_event(self):
+        event = {
+            "Subject": "Quarterly review",
+            "StartDateTime": "2026-05-01T09:00:00Z",
+            "EndDateTime": "2026-05-01T10:00:00Z",
+            "Location": "Board Room",
+            "Description": "Q1 results discussion.",
+            "IsAllDayEvent": False,
+        }
+        summary = _summarise_event(event)
+        assert "Quarterly review" in summary
+        assert "2026-05-01T09:00:00Z" in summary
+        assert "Board Room" in summary
+        assert "Q1 results" in summary
+        assert "(All day)" not in summary
+
+    def test_all_day_event_label(self):
+        event = {"Subject": "Holiday", "IsAllDayEvent": True}
+        summary = _summarise_event(event)
+        assert "(All day)" in summary
+
+    def test_missing_fields_use_defaults(self):
+        summary = _summarise_event({})
+        assert "No subject" in summary
+        assert "No location" in summary
+        assert "No description" in summary
+
+
+# ---------------------------------------------------------------------------
+# Action handler tests
+# ---------------------------------------------------------------------------
+
+
+class TestSearchRecordsAction:
+    async def test_success(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response(
+            {"records": [{"Id": "003XX", "Name": "Jane Doe"}], "totalSize": 1, "done": True}
+        )
+        handler = SearchRecordsAction()
+        result = await handler.execute({"soql": "SELECT Id, Name FROM Contact LIMIT 1"}, mock_context)
+
+        assert result.data["result"] is True
+        assert len(result.data["records"]) == 1
+        assert result.data["records"][0]["Name"] == "Jane Doe"
+        assert result.data["total_size"] == 1
+        assert result.data["done"] is True
+
+    async def test_passes_soql_as_query_param(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"records": [], "totalSize": 0, "done": True})
+        handler = SearchRecordsAction()
+        soql = "SELECT Id FROM Lead LIMIT 5"
+        await handler.execute({"soql": soql}, mock_context)
+
+        call_kwargs = mock_context.fetch.call_args
+        assert call_kwargs.kwargs["params"]["q"] == soql
+
+    async def test_uses_bearer_auth_header(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"records": [], "totalSize": 0, "done": True})
+        handler = SearchRecordsAction()
+        await handler.execute({"soql": "SELECT Id FROM Contact"}, mock_context)
+
+        headers = mock_context.fetch.call_args.kwargs["headers"]
+        assert headers["Authorization"] == f"Bearer {TEST_TOKEN}"
+
+    async def test_error_returns_false(self, mock_context):
+        mock_context.fetch.side_effect = Exception("API error")
+        handler = SearchRecordsAction()
+        result = await handler.execute({"soql": "SELECT Id FROM Contact"}, mock_context)
+
+        assert result.data["result"] is False
+        assert "API error" in result.data["error"]
+
+    async def test_empty_results(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"records": [], "totalSize": 0, "done": True})
+        handler = SearchRecordsAction()
+        result = await handler.execute({"soql": "SELECT Id FROM Contact WHERE Name = 'Nobody'"}, mock_context)
+
+        assert result.data["result"] is True
+        assert result.data["records"] == []
+        assert result.data["total_size"] == 0
+
+
+class TestGetRecordAction:
+    async def test_success(self, mock_context):
+        record = {"Id": "003XX", "Name": "Jane Doe", "Email": "jane@example.com"}
+        mock_context.fetch.return_value = make_fetch_response(record)
+        handler = GetRecordAction()
+        result = await handler.execute({"object_type": "Contact", "record_id": "003XX"}, mock_context)
+
+        assert result.data["result"] is True
+        assert result.data["record"]["Name"] == "Jane Doe"
+
+    async def test_url_contains_object_type_and_id(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"Id": "003XX"})
+        handler = GetRecordAction()
+        await handler.execute({"object_type": "Contact", "record_id": "003XX"}, mock_context)
+
+        url = mock_context.fetch.call_args.args[0]
+        assert "/sobjects/Contact/003XX" in url
+
+    async def test_fields_param_passed_when_provided(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"Id": "003XX", "Name": "Jane"})
+        handler = GetRecordAction()
+        await handler.execute({"object_type": "Contact", "record_id": "003XX", "fields": "Id,Name"}, mock_context)
+
+        params = mock_context.fetch.call_args.kwargs["params"]
+        assert params["fields"] == "Id,Name"
+
+    async def test_no_fields_param_when_not_provided(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"Id": "003XX"})
+        handler = GetRecordAction()
+        await handler.execute({"object_type": "Contact", "record_id": "003XX"}, mock_context)
+
+        params = mock_context.fetch.call_args.kwargs.get("params", {})
+        assert "fields" not in params
+
+    async def test_error_returns_false(self, mock_context):
+        mock_context.fetch.side_effect = Exception("Not found")
+        handler = GetRecordAction()
+        result = await handler.execute({"object_type": "Contact", "record_id": "BAD"}, mock_context)
+
+        assert result.data["result"] is False
+        assert "Not found" in result.data["error"]
+
+
+class TestUpdateRecordAction:
+    async def test_success(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({})
+        handler = UpdateRecordAction()
+        result = await handler.execute(
+            {"object_type": "Contact", "record_id": "003XX", "fields": {"Phone": "0400000000"}},
+            mock_context,
+        )
+
+        assert result.data["result"] is True
+        assert result.data["record_id"] == "003XX"
+        assert result.data["object_type"] == "Contact"
+
+    async def test_uses_patch_method(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({})
+        handler = UpdateRecordAction()
+        await handler.execute(
+            {"object_type": "Lead", "record_id": "00QXX", "fields": {"Title": "Manager"}},
+            mock_context,
+        )
+
+        assert mock_context.fetch.call_args.kwargs["method"] == "PATCH"
+
+    async def test_fields_sent_as_json_body(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({})
+        handler = UpdateRecordAction()
+        fields = {"Phone": "0400000000", "Title": "Director"}
+        await handler.execute({"object_type": "Contact", "record_id": "003XX", "fields": fields}, mock_context)
+
+        assert mock_context.fetch.call_args.kwargs["json"] == fields
+
+    async def test_error_returns_false(self, mock_context):
+        mock_context.fetch.side_effect = Exception("Forbidden")
+        handler = UpdateRecordAction()
+        result = await handler.execute(
+            {"object_type": "Contact", "record_id": "003XX", "fields": {"Name": "X"}}, mock_context
+        )
+
+        assert result.data["result"] is False
+
+
+class TestListTasksAction:
+    async def test_success_no_filters(self, mock_context):
+        tasks = [{"Id": "00TXX", "Subject": "Call client", "Status": "Not Started"}]
+        mock_context.fetch.return_value = make_fetch_response({"records": tasks, "totalSize": 1})
+        handler = ListTasksAction()
+        result = await handler.execute({}, mock_context)
+
+        assert result.data["result"] is True
+        assert len(result.data["tasks"]) == 1
+        assert result.data["total_size"] == 1
+
+    async def test_status_filter_applied(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"records": [], "totalSize": 0})
+        handler = ListTasksAction()
+        await handler.execute({"status": "Completed"}, mock_context)
+
+        soql = mock_context.fetch.call_args.kwargs["params"]["q"]
+        assert "Status = 'Completed'" in soql
+
+    async def test_date_filter_applied(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"records": [], "totalSize": 0})
+        handler = ListTasksAction()
+        await handler.execute({"due_date_from": "2026-01-01", "due_date_to": "2026-06-30"}, mock_context)
+
+        soql = mock_context.fetch.call_args.kwargs["params"]["q"]
+        assert "ActivityDate >= 2026-01-01" in soql
+        assert "ActivityDate <= 2026-06-30" in soql
+
+    async def test_default_limit_is_25(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"records": [], "totalSize": 0})
+        handler = ListTasksAction()
+        await handler.execute({}, mock_context)
+
+        soql = mock_context.fetch.call_args.kwargs["params"]["q"]
+        assert "LIMIT 25" in soql
+
+    async def test_custom_limit(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"records": [], "totalSize": 0})
+        handler = ListTasksAction()
+        await handler.execute({"limit": 50}, mock_context)
+
+        soql = mock_context.fetch.call_args.kwargs["params"]["q"]
+        assert "LIMIT 50" in soql
+
+    async def test_error_returns_false(self, mock_context):
+        mock_context.fetch.side_effect = Exception("timeout")
+        handler = ListTasksAction()
+        result = await handler.execute({}, mock_context)
+
+        assert result.data["result"] is False
+
+
+class TestListEventsAction:
+    async def test_success_no_filters(self, mock_context):
+        events = [{"Id": "00UXX", "Subject": "Client meeting", "StartDateTime": "2026-05-01T09:00:00Z"}]
+        mock_context.fetch.return_value = make_fetch_response({"records": events, "totalSize": 1})
+        handler = ListEventsAction()
+        result = await handler.execute({}, mock_context)
+
+        assert result.data["result"] is True
+        assert len(result.data["events"]) == 1
+        assert result.data["total_size"] == 1
+
+    async def test_date_filter_applied(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"records": [], "totalSize": 0})
+        handler = ListEventsAction()
+        await handler.execute({"start_date_from": "2026-05-01", "start_date_to": "2026-05-31"}, mock_context)
+
+        soql = mock_context.fetch.call_args.kwargs["params"]["q"]
+        assert "StartDateTime >= 2026-05-01T00:00:00Z" in soql
+        assert "StartDateTime <= 2026-05-31T23:59:59Z" in soql
+
+    async def test_default_limit_is_25(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"records": [], "totalSize": 0})
+        handler = ListEventsAction()
+        await handler.execute({}, mock_context)
+
+        soql = mock_context.fetch.call_args.kwargs["params"]["q"]
+        assert "LIMIT 25" in soql
+
+    async def test_error_returns_false(self, mock_context):
+        mock_context.fetch.side_effect = Exception("network error")
+        handler = ListEventsAction()
+        result = await handler.execute({}, mock_context)
+
+        assert result.data["result"] is False
+
+
+class TestGetTaskSummaryAction:
+    async def test_success(self, mock_context):
+        task = {
+            "Id": "00TXX",
+            "Subject": "Follow up",
+            "Status": "In Progress",
+            "Priority": "High",
+            "ActivityDate": "2026-05-10",
+            "Description": "Check on contract status.",
+        }
+        mock_context.fetch.return_value = make_fetch_response({"records": [task], "totalSize": 1})
+        handler = GetTaskSummaryAction()
+        result = await handler.execute({"task_id": "00TXX"}, mock_context)
+
+        assert result.data["result"] is True
+        assert "Follow up" in result.data["summary"]
+        assert "In Progress" in result.data["summary"]
+        assert result.data["task"]["Id"] == "00TXX"
+
+    async def test_task_not_found(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"records": [], "totalSize": 0})
+        handler = GetTaskSummaryAction()
+        result = await handler.execute({"task_id": "00TBAD"}, mock_context)
+
+        assert result.data["result"] is False
+        assert "not found" in result.data["error"].lower()
+
+    async def test_soql_filters_by_task_id(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"records": [], "totalSize": 0})
+        handler = GetTaskSummaryAction()
+        await handler.execute({"task_id": "00TXX123"}, mock_context)
+
+        soql = mock_context.fetch.call_args.kwargs["params"]["q"]
+        assert "00TXX123" in soql
+        assert "FROM Task" in soql
+
+    async def test_error_returns_false(self, mock_context):
+        mock_context.fetch.side_effect = Exception("API error")
+        handler = GetTaskSummaryAction()
+        result = await handler.execute({"task_id": "00TXX"}, mock_context)
+
+        assert result.data["result"] is False
+
+
+class TestGetEventSummaryAction:
+    async def test_success(self, mock_context):
+        event = {
+            "Id": "00UXX",
+            "Subject": "Board meeting",
+            "StartDateTime": "2026-06-01T09:00:00Z",
+            "EndDateTime": "2026-06-01T11:00:00Z",
+            "Location": "HQ",
+            "Description": "Annual board review.",
+            "IsAllDayEvent": False,
+        }
+        mock_context.fetch.return_value = make_fetch_response({"records": [event], "totalSize": 1})
+        handler = GetEventSummaryAction()
+        result = await handler.execute({"event_id": "00UXX"}, mock_context)
+
+        assert result.data["result"] is True
+        assert "Board meeting" in result.data["summary"]
+        assert "HQ" in result.data["summary"]
+        assert result.data["event"]["Id"] == "00UXX"
+
+    async def test_event_not_found(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"records": [], "totalSize": 0})
+        handler = GetEventSummaryAction()
+        result = await handler.execute({"event_id": "00UBAD"}, mock_context)
+
+        assert result.data["result"] is False
+        assert "not found" in result.data["error"].lower()
+
+    async def test_all_day_event_in_summary(self, mock_context):
+        event = {"Id": "00UXX", "Subject": "Public Holiday", "IsAllDayEvent": True}
+        mock_context.fetch.return_value = make_fetch_response({"records": [event], "totalSize": 1})
+        handler = GetEventSummaryAction()
+        result = await handler.execute({"event_id": "00UXX"}, mock_context)
+
+        assert "(All day)" in result.data["summary"]
+
+    async def test_soql_filters_by_event_id(self, mock_context):
+        mock_context.fetch.return_value = make_fetch_response({"records": [], "totalSize": 0})
+        handler = GetEventSummaryAction()
+        await handler.execute({"event_id": "00UABC"}, mock_context)
+
+        soql = mock_context.fetch.call_args.kwargs["params"]["q"]
+        assert "00UABC" in soql
+        assert "FROM Event" in soql
+
+    async def test_error_returns_false(self, mock_context):
+        mock_context.fetch.side_effect = Exception("timeout")
+        handler = GetEventSummaryAction()
+        result = await handler.execute({"event_id": "00UXX"}, mock_context)
+
+        assert result.data["result"] is False
diff --git a/shotstack/shotstack.py b/shotstack/shotstack.py
new file mode 100644
index 00000000..b0826726
--- /dev/null
+++ b/shotstack/shotstack.py
@@ -0,0 +1,568 @@
+from autohive_integrations_sdk import Integration, ExecutionContext, ActionHandler, ActionResult
+from typing import Any, Dict
+from urllib.parse import quote
+import asyncio
+import base64
+import mimetypes
+import os
+
+
+config_path = os.path.join(os.path.dirname(__file__), "config.json")
+shotstack = Integration.load(config_path)
+
+EDIT_API_BASE = "https://api.shotstack.io/edit"
+INGEST_API_BASE = "https://api.shotstack.io/ingest"
+
+
+def _get_api_key(context: ExecutionContext) -> str:
+    return context.auth.get("credentials", {}).get("api_key", "")
+
+
+def _get_env(context: ExecutionContext) -> str:
+    return context.auth.get("credentials", {}).get("environment", "stage")
+
+
+def _get_headers(context: ExecutionContext) -> Dict[str, str]:
+    return {"x-api-key": _get_api_key(context), "Content-Type": "application/json"}
+
+
+async def _poll_render(context: ExecutionContext, render_id: str, max_wait: int = 300, poll_interval: int = 5) -> Dict[str, Any]:
+    env = _get_env(context)
+    elapsed = 0
+    while elapsed < max_wait:
+        response = await context.fetch(f"{EDIT_API_BASE}/{env}/render/{render_id}", method="GET", headers=_get_headers(context))
+        render_data = response.get("response", {})
+        status = render_data.get("status")
+        if status == "done":
+            return {"status": "done", "url": render_data.get("url"), "render": render_data}
+        elif status == "failed":
+            return {"status": "failed", "error": render_data.get("error", "Render failed"), "render": render_data}
+        await asyncio.sleep(poll_interval)
+        elapsed += poll_interval
+    return {"status": "timeout", "error": f"Render did not complete within {max_wait} seconds"}
+
+
+async def _poll_source(context: ExecutionContext, source_id: str, max_wait: int = 120, poll_interval: int = 3) -> Dict[str, Any]:
+    env = _get_env(context)
+    elapsed = 0
+    while elapsed < max_wait:
+        response = await context.fetch(f"{INGEST_API_BASE}/{env}/sources/{source_id}", method="GET", headers=_get_headers(context))
+        source_data = response.get("data", {})
+        attributes = source_data.get("attributes", {})
+        status = attributes.get("status")
+        if status == "ready":
+            return {"status": "ready", "source_url": attributes.get("source"), "source": source_data}
+        elif status == "failed":
+            return {"status": "failed", "error": source_data.get("error", "Source processing failed")}
+        await asyncio.sleep(poll_interval)
+        elapsed += poll_interval
+    return {"status": "timeout", "error": f"Source did not become ready within {max_wait} seconds"}
+
+
+async def _download_base64(context: ExecutionContext, url: str) -> Dict[str, Any]:
+    response = await context.fetch(url, method="GET", headers={"Accept": "*/*"}, raw_response=True)
+    content_type = response.get("content_type", "application/octet-stream")
+    if not content_type or content_type == "application/octet-stream":
+        guessed_type, _ = mimetypes.guess_type(url)
+        if guessed_type:
+            content_type = guessed_type
+    filename = url.split("/")[-1].split("?")[0] or "downloaded_file"
+    content_bytes = response.get("body", b"")
+    if isinstance(content_bytes, str):
+        content_bytes = content_bytes.encode("utf-8")
+    return {"content": base64.b64encode(content_bytes).decode("utf-8"), "content_type": content_type, "filename": filename, "size": len(content_bytes)}
+
+
+async def _get_media_info(context: ExecutionContext, url: str) -> Dict[str, Any]:
+    env = _get_env(context)
+    encoded_url = quote(url, safe="")
+    response = await context.fetch(f"{EDIT_API_BASE}/{env}/probe/{encoded_url}", method="GET", headers=_get_headers(context))
+    return response.get("response", {})
+
+
+def _position_to_offset(position: str) -> Dict[str, float]:
+    offsets = {
+        "center": {"x": 0, "y": 0}, "top": {"x": 0, "y": 0.4}, "topRight": {"x": 0.4, "y": 0.4},
+        "right": {"x": 0.4, "y": 0}, "bottomRight": {"x": 0.4, "y": -0.4}, "bottom": {"x": 0, "y": -0.4},
+        "bottomLeft": {"x": -0.4, "y": -0.4}, "left": {"x": -0.4, "y": 0}, "topLeft": {"x": -0.4, "y": 0.4},
+    }
+    return offsets.get(position, {"x": 0, "y": 0})
+
+
+def _build_timeline_from_clips(clips: list, background_color: str = "#000000") -> Dict[str, Any]:
+    timeline_clips = []
+    current_time = 0.0
+    for clip in clips:
+        url = clip.get("url")
+        duration = clip.get("duration")
+        start_from = clip.get("start_from", 0)
+        length = clip.get("length")
+        fit = clip.get("fit", "crop")
+        effect = clip.get("effect")
+        transition = clip.get("transition", {})
+        is_image = any(url.lower().endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"])
+        if is_image:
+            asset = {"type": "image", "src": url}
+            clip_length = duration or 5
+        else:
+            asset = {"type": "video", "src": url}
+            if start_from:
+                asset["trim"] = start_from
+            clip_length = length or duration
+        timeline_clip = {"asset": asset, "start": current_time, "fit": fit}
+        if clip_length:
+            timeline_clip["length"] = clip_length
+        if effect:
+            timeline_clip["effect"] = effect
+        if transition:
+            timeline_clip["transition"] = transition
+        timeline_clips.append(timeline_clip)
+        current_time += clip_length if clip_length else 5
+    return {"background": background_color, "tracks": [{"clips": timeline_clips}]}
+
+
+async def _submit_and_maybe_wait(context: ExecutionContext, payload: Dict[str, Any], wait: bool, max_wait: int = 300) -> ActionResult:
+    env = _get_env(context)
+    response = await context.fetch(f"{EDIT_API_BASE}/{env}/render", method="POST", headers=_get_headers(context), json=payload)
+    render_id = response.get("response", {}).get("id")
+    if not render_id:
+        return ActionResult(data={"result": False, "error": "Failed to submit render job"}, cost_usd=0.0)
+    if wait:
+        poll_result = await _poll_render(context, render_id, max_wait)
+        if poll_result["status"] == "done":
+            render_data = poll_result.get("render", {})
+            return ActionResult(data={"render_id": render_id, "status": "done", "url": poll_result["url"], "duration": render_data.get("duration"), "result": True}, cost_usd=0.0)
+        return ActionResult(data={"render_id": render_id, "status": poll_result["status"], "error": poll_result.get("error"), "result": False}, cost_usd=0.0)
+    return ActionResult(data={"render_id": render_id, "status": "queued", "result": True}, cost_usd=0.0)
+
+
+@shotstack.action("upload_file")
+class UploadFileAction(ActionHandler):
+    async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult:
+        try:
+            env = _get_env(context)
+            wait_for_ready = inputs.get("wait_for_ready", False)
+            file_obj = inputs.get("file")
+            if file_obj:
+                content_base64 = file_obj.get("content")
+                filename = file_obj.get("name")
+                content_type = file_obj.get("contentType")
+                file_url = file_obj.get("url")
+                if file_url and not content_base64:
+                    resp = await context.fetch(file_url, method="GET", raw_response=True)
+                    content_base64 = base64.b64encode(resp.get("body", b"")).decode("utf-8")
+            else:
+                content_base64 = inputs.get("content")
+                filename = inputs.get("filename")
+                content_type = inputs.get("content_type")
+            if not content_base64 or not filename:
+                return ActionResult(data={"result": False, "error": "Missing required file content or filename"}, cost_usd=0.0)
+            file_bytes = base64.b64decode(content_base64)
+            if not content_type:
+                guessed_type, _ = mimetypes.guess_type(filename)
+                content_type = guessed_type or "application/octet-stream"
+            response = await context.fetch(f"{INGEST_API_BASE}/{env}/upload", method="POST", headers=_get_headers(context))
+            upload_data = response.get("data", {})
+            attributes = upload_data.get("attributes", {})
+            presigned_url = attributes.get("url")
+            source_id = upload_data.get("id")
+            upload_headers = attributes.get("headers", {})
+            if not presigned_url:
+                return ActionResult(data={"result": False, "error": "Failed to get presigned upload URL"}, cost_usd=0.0)
+            put_headers = upload_headers if upload_headers else {}
+            await context.fetch(presigned_url, method="PUT", data=file_bytes, headers=put_headers)
+            if wait_for_ready:
+                poll_result = await _poll_source(context, source_id)
+                if poll_result["status"] == "ready":
+                    return ActionResult(data={"source_id": source_id, "source_url": poll_result["source_url"], "status": "ready", "result": True}, cost_usd=0.0)
+                return ActionResult(data={"source_id": source_id, "status": poll_result["status"], "error": poll_result.get("error"), "result": False}, cost_usd=0.0)
+            return ActionResult(data={"source_id": source_id, "status": "processing", "result": True}, cost_usd=0.0)
+        except Exception as e:
+            return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0)
+
+
+@shotstack.action("check_source_status")
+class CheckSourceStatusAction(ActionHandler):
+    async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult:
+        try:
+            env = _get_env(context)
+            source_id = inputs["source_id"]
+            response = await context.fetch(f"{INGEST_API_BASE}/{env}/sources/{source_id}", method="GET", headers=_get_headers(context))
+            source_data = response.get("data", {})
+            attributes = source_data.get("attributes", {})
+            status = attributes.get("status")
+            result_data: Dict[str, Any] = {"source_id": source_id, "status": status, "result": True}
+            if status == "ready":
+                result_data["source_url"] = attributes.get("source")
+                result_data["message"] = "File is ready to use in edits!"
+            elif status == "failed":
+                result_data["error"] = source_data.get("error", "Source processing failed")
+                result_data["result"] = False
+            else:
+                result_data["message"] = f"File is {status}. Check again in a few seconds."
+            return ActionResult(data=result_data, cost_usd=0.0)
+        except Exception as e:
+            return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0)
+
+
+@shotstack.action("get_upload_url")
+class GetUploadUrlAction(ActionHandler):
+    async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult:
+        try:
+            env = _get_env(context)
+            response = await context.fetch(f"{INGEST_API_BASE}/{env}/upload", method="POST", headers=_get_headers(context))
+            upload_data = response.get("data", {})
+            attributes = upload_data.get("attributes", {})
+            return ActionResult(data={"upload_url": attributes.get("url"), "source_id": upload_data.get("id"), "expires": attributes.get("expires"), "result": True}, cost_usd=0.0)
+        except Exception as e:
+            return ActionResult(data={"upload_url": None, "source_id": None, "result": False, "error": str(e)}, cost_usd=0.0)
+
+
+@shotstack.action("submit_render")
+class SubmitRenderAction(ActionHandler):
+    async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult:
+        try:
+            env = _get_env(context)
+            payload = {"timeline": inputs["timeline"], "output": inputs["output"]}
+            response = await context.fetch(f"{EDIT_API_BASE}/{env}/render", method="POST", headers=_get_headers(context), json=payload)
+            render_id = response.get("response", {}).get("id")
+            if not render_id:
+                return ActionResult(data={"result": False, "error": "Failed to submit render job"}, cost_usd=0.0)
+            return ActionResult(data={"render_id": render_id, "status": "queued", "message": "Render job submitted. Use check_render_status to poll for completion.", "result": True}, cost_usd=0.0)
+        except Exception as e:
+            return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0)
+
+
+@shotstack.action("check_render_status")
+class CheckRenderStatusAction(ActionHandler):
+    async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult:
+        try:
+            env = _get_env(context)
+            render_id = inputs["render_id"]
+            response = await context.fetch(f"{EDIT_API_BASE}/{env}/render/{render_id}", method="GET", headers=_get_headers(context))
+            render_data = response.get("response", {})
+            status = render_data.get("status")
+            result_data: Dict[str, Any] = {"render_id": render_id, "status": status, "result": True}
+            if status == "done":
+                result_data["url"] = render_data.get("url")
+                result_data["duration"] = render_data.get("duration")
+                result_data["message"] = "Render complete!"
+            elif status == "failed":
+                result_data["error"] = render_data.get("error", "Render failed")
+                result_data["result"] = False
+            else:
+                result_data["message"] = f"Render is {status}. Check again in a few seconds."
+            return ActionResult(data=result_data, cost_usd=0.0)
+        except Exception as e:
+            return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0)
+
+
+@shotstack.action("render_and_wait")
+class RenderAndWaitAction(ActionHandler):
+    async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult:
+        try:
+            payload = {"timeline": inputs["timeline"], "output": inputs["output"]}
+            max_wait = inputs.get("max_wait_seconds", 300)
+            poll_interval = inputs.get("poll_interval_seconds", 5)
+            env = _get_env(context)
+            response = await context.fetch(f"{EDIT_API_BASE}/{env}/render", method="POST", headers=_get_headers(context), json=payload)
+            render_id = response.get("response", {}).get("id")
+            if not render_id:
+                return ActionResult(data={"result": False, "error": "Failed to submit render job"}, cost_usd=0.0)
+            poll_result = await _poll_render(context, render_id, max_wait, poll_interval)
+            if poll_result["status"] == "done":
+                return ActionResult(data={"render_id": render_id, "status": "done", "url": poll_result["url"], "duration": poll_result.get("render", {}).get("duration"), "result": True}, cost_usd=0.0)
+            return ActionResult(data={"render_id": render_id, "status": poll_result["status"], "error": poll_result.get("error"), "result": False}, cost_usd=0.0)
+        except Exception as e:
+            return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0)
+
+
+@shotstack.action("download_render")
+class DownloadRenderAction(ActionHandler):
+    async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult:
+        try:
+            env = _get_env(context)
+            render_id = inputs.get("render_id")
+            url = inputs.get("url")
+            if render_id and not url:
+                response = await context.fetch(f"{EDIT_API_BASE}/{env}/render/{render_id}", method="GET", headers=_get_headers(context))
+                render_data = response.get("response", {})
+                status = render_data.get("status")
+                if status != "done":
+                    return ActionResult(data={"result": False, "error": f"Render is not complete. Status: {status}"}, cost_usd=0.0)
+                url = render_data.get("url")
+            if not url:
+                return ActionResult(data={"result": False, "error": "No URL available. Provide render_id or url."}, cost_usd=0.0)
+            dl = await _download_base64(context, url)
+            return ActionResult(data={"content": dl["content"], "content_type": dl["content_type"], "filename": dl["filename"], "size": dl["size"], "result": True}, cost_usd=0.0)
+        except Exception as e:
+            return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0)
+
+
+@shotstack.action("custom_edit")
+class CustomEditAction(ActionHandler):
+    async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult:
+        try:
+            wait = inputs.get("wait_for_completion", True)
+            max_wait = inputs.get("max_wait_seconds", 300)
+            payload = {"timeline": inputs["timeline"], "output": inputs["output"]}
+            return await _submit_and_maybe_wait(context, payload, wait, max_wait)
+        except Exception as e:
+            return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0)
+
+
+@shotstack.action("compose_video")
+class ComposeVideoAction(ActionHandler):
+    async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult:
+        try:
+            clips = inputs["clips"]
+            output = inputs.get("output", {"format": "mp4", "resolution": "hd"})
+            background_color = inputs.get("background_color", "#000000")
+            wait = inputs.get("wait_for_completion", True)
+            timeline = _build_timeline_from_clips(clips, background_color)
+            return await _submit_and_maybe_wait(context, {"timeline": timeline, "output": output}, wait)
+        except Exception as e:
+            return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0)
+
+
+@shotstack.action("add_text_overlay")
+class AddTextOverlayAction(ActionHandler):
+    async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult:
+        try:
+            video_url = inputs["video_url"]
+            text = inputs["text"]
+            style = inputs.get("style", "minimal")
+            position = inputs.get("position", "center")
+            start_time = inputs.get("start_time", 0)
+            duration = inputs.get("duration")
+            font_size = inputs.get("font_size", "medium")
+            color = inputs.get("color", "#ffffff")
+            background_color = inputs.get("background_color")
+            effect = inputs.get("effect")
+            transition = inputs.get("transition")
+            output = inputs.get("output", {"format": "mp4", "resolution": "hd"})
+            wait = inputs.get("wait_for_completion", True)
+            if not duration:
+                try:
+                    media_info = await _get_media_info(context, video_url)
+                    video_duration = media_info.get("metadata", {}).get("streams", [{}])[0].get("duration", 10)
+                    duration = video_duration - start_time
+                except Exception:
+                    duration = 10
+            title_asset: Dict[str, Any] = {"type": "title", "text": text, "style": style, "color": color, "size": font_size, "position": position}
+            if background_color:
+                title_asset["background"] = background_color
+            text_clip: Dict[str, Any] = {"asset": title_asset, "start": start_time, "length": duration}
+            if effect:
+                text_clip["effect"] = effect
+            if transition:
+                text_clip["transition"] = transition
+            timeline = {"tracks": [{"clips": [text_clip]}, {"clips": [{"asset": {"type": "video", "src": video_url}, "start": 0}]}]}
+            return await _submit_and_maybe_wait(context, {"timeline": timeline, "output": output}, wait)
+        except Exception as e:
+            return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0)
+
+
+@shotstack.action("add_logo_overlay")
+class AddLogoOverlayAction(ActionHandler):
+    async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult:
+        try:
+            video_url = inputs["video_url"]
+            logo_url = inputs["logo_url"]
+            position = inputs.get("position", "bottomRight")
+            scale = inputs.get("scale", 0.15)
+            opacity = inputs.get("opacity", 1)
+            offset_x = inputs.get("offset_x")
+            offset_y = inputs.get("offset_y")
+            start_time = inputs.get("start_time", 0)
+            duration = inputs.get("duration")
+            output = inputs.get("output", {"format": "mp4", "resolution": "hd"})
+            wait = inputs.get("wait_for_completion", True)
+            if not duration:
+                try:
+                    media_info = await _get_media_info(context, video_url)
+                    video_duration = media_info.get("metadata", {}).get("streams", [{}])[0].get("duration", 10)
+                    duration = video_duration - start_time
+                except Exception:
+                    duration = 10
+            logo_clip: Dict[str, Any] = {"asset": {"type": "image", "src": logo_url}, "start": start_time, "length": duration, "scale": scale, "position": position, "opacity": opacity}
+            if offset_x is not None or offset_y is not None:
+                offset = _position_to_offset(position)
+                if offset_x is not None:
+                    offset["x"] = offset_x
+                if offset_y is not None:
+                    offset["y"] = offset_y
+                logo_clip["offset"] = offset
+            timeline = {"tracks": [{"clips": [logo_clip]}, {"clips": [{"asset": {"type": "video", "src": video_url}, "start": 0}]}]}
+            return await _submit_and_maybe_wait(context, {"timeline": timeline, "output": output}, wait)
+        except Exception as e:
+            return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0)
+
+
+@shotstack.action("add_audio_track")
+class AddAudioTrackAction(ActionHandler):
+    async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult:
+        try:
+            video_url = inputs["video_url"]
+            audio_url = inputs["audio_url"]
+            volume = inputs.get("volume", 1)
+            start_time = inputs.get("start_time", 0)
+            trim_from = inputs.get("trim_from", 0)
+            trim_duration = inputs.get("trim_duration")
+            fade_in = inputs.get("fade_in")
+            fade_out = inputs.get("fade_out")
+            mix_mode = inputs.get("mix_mode", "mix")
+            output = inputs.get("output", {"format": "mp4", "resolution": "hd"})
+            wait = inputs.get("wait_for_completion", True)
+            video_asset: Dict[str, Any] = {"type": "video", "src": video_url}
+            if mix_mode == "replace":
+                video_asset["volume"] = 0
+            audio_asset: Dict[str, Any] = {"type": "audio", "src": audio_url, "volume": volume}
+            if trim_from:
+                audio_asset["trim"] = trim_from
+            if fade_in and fade_out:
+                audio_asset["effect"] = "fadeInFadeOut"
+            elif fade_in:
+                audio_asset["effect"] = "fadeIn"
+            elif fade_out:
+                audio_asset["effect"] = "fadeOut"
+            audio_clip: Dict[str, Any] = {"asset": audio_asset, "start": start_time}
+            if trim_duration:
+                audio_clip["length"] = trim_duration
+            timeline = {"tracks": [{"clips": [{"asset": video_asset, "start": 0}]}, {"clips": [audio_clip]}]}
+            return await _submit_and_maybe_wait(context, {"timeline": timeline, "output": output}, wait)
+        except Exception as e:
+            return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0)
+
+
+@shotstack.action("trim_video")
+class TrimVideoAction(ActionHandler):
+    async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult:
+        try:
+            video_url = inputs["video_url"]
+            start_time = inputs["start_time"]
+            end_time = inputs.get("end_time")
+            duration = inputs.get("duration")
+            output = inputs.get("output", {"format": "mp4", "resolution": "hd"})
+            wait = inputs.get("wait_for_completion", True)
+            if end_time is not None:
+                length = end_time - start_time
+            elif duration is not None:
+                length = duration
+            else:
+                return ActionResult(data={"result": False, "error": "Either end_time or duration is required"}, cost_usd=0.0)
+            timeline = {"tracks": [{"clips": [{"asset": {"type": "video", "src": video_url, "trim": start_time}, "start": 0, "length": length}]}]}
+            return await _submit_and_maybe_wait(context, {"timeline": timeline, "output": output}, wait)
+        except Exception as e:
+            return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0)
+
+
+@shotstack.action("concatenate_videos")
+class ConcatenateVideosAction(ActionHandler):
+    async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult:
+        try:
+            videos = inputs["videos"]
+            transition = inputs.get("transition")
+            output = inputs.get("output", {"format": "mp4", "resolution": "hd"})
+            wait = inputs.get("wait_for_completion", True)
+            clips = []
+
for i, video_url in enumerate(videos): + clip: Dict[str, Any] = {"asset": {"type": "video", "src": video_url}, "start": 0} + if transition and transition != "none": + trans = {} + if i > 0: + trans["in"] = transition + if i < len(videos) - 1: + trans["out"] = transition + if trans: + clip["transition"] = trans + clips.append(clip) + timeline = {"tracks": [{"clips": clips}]} + return await _submit_and_maybe_wait(context, {"timeline": timeline, "output": output}, wait) + except Exception as e: + return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0) + + +@shotstack.action("add_captions") +class AddCaptionsAction(ActionHandler): + async def execute(self, inputs: Dict[str, Any], context: ExecutionContext) -> ActionResult: + try: + video_url = inputs["video_url"] + subtitle_url = inputs.get("subtitle_url") + auto_generate = inputs.get("auto_generate", True) + font_family = inputs.get("font_family") + font_size = inputs.get("font_size", 16) + font_color = inputs.get("font_color", "#ffffff") + line_height = inputs.get("line_height") + stroke_color = inputs.get("stroke_color") + stroke_width = inputs.get("stroke_width") + background_color = inputs.get("background_color") + background_opacity = inputs.get("background_opacity", 0.8) + background_padding = inputs.get("background_padding", 10) + background_border_radius = inputs.get("background_border_radius", 4) + position = inputs.get("position", "bottom") + margin_top = inputs.get("margin_top") + margin_bottom = inputs.get("margin_bottom", 0.1) + margin_left = inputs.get("margin_left") + margin_right = inputs.get("margin_right") + caption_width = inputs.get("width") + caption_height = inputs.get("height") + output = inputs.get("output", {"format": "mp4", "resolution": "hd"}) + wait = inputs.get("wait_for_completion", True) + max_wait = inputs.get("max_wait_seconds", 300) + video_clip: Dict[str, Any] = {"asset": {"type": "video", "src": video_url}, "start": 0, "length": "auto"} + if auto_generate and not 
subtitle_url: + video_clip["alias"] = "main_video" + caption_asset: Dict[str, Any] = {"type": "caption"} + if subtitle_url: + caption_asset["src"] = subtitle_url + elif auto_generate: + caption_asset["src"] = "alias://main_video" + else: + return ActionResult(data={"result": False, "error": "Either subtitle_url or auto_generate=True is required"}, cost_usd=0.0) + font: Dict[str, Any] = {} + if font_family: + font["family"] = font_family + if font_size: + font["size"] = font_size + if font_color: + font["color"] = font_color + if line_height: + font["lineHeight"] = line_height + if stroke_color: + font["stroke"] = stroke_color + if stroke_width: + font["strokeWidth"] = stroke_width + if font: + caption_asset["font"] = font + if background_color: + bg: Dict[str, Any] = {"color": background_color} + if background_opacity is not None: + bg["opacity"] = background_opacity + if background_padding is not None: + bg["padding"] = background_padding + if background_border_radius is not None: + bg["borderRadius"] = background_border_radius + caption_asset["background"] = bg + if position: + caption_asset["position"] = position + margin: Dict[str, Any] = {} + if margin_top is not None: + margin["top"] = margin_top + if margin_bottom is not None: + margin["bottom"] = margin_bottom + if margin_left is not None: + margin["left"] = margin_left + if margin_right is not None: + margin["right"] = margin_right + if margin: + caption_asset["margin"] = margin + if caption_width: + caption_asset["width"] = caption_width + if caption_height: + caption_asset["height"] = caption_height + caption_clip = {"asset": caption_asset, "start": 0, "length": "end"} + timeline = {"tracks": [{"clips": [caption_clip]}, {"clips": [video_clip]}]} + return await _submit_and_maybe_wait(context, {"timeline": timeline, "output": output}, wait, max_wait) + except Exception as e: + return ActionResult(data={"result": False, "error": str(e)}, cost_usd=0.0)
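For reference, the `trim_video` action above boils down to one Shotstack edit payload: trim the source with `asset.trim` and bound the clip with `length`. A standalone sketch of that construction (the helper name `build_trim_edit` is illustrative, not part of the integration; the default `output` mirrors the defaults used above):

```python
from typing import Any, Dict, Optional


def build_trim_edit(video_url: str, start_time: float,
                    end_time: Optional[float] = None,
                    duration: Optional[float] = None) -> Dict[str, Any]:
    """Mirror the timeline construction from trim_video above."""
    if end_time is not None:
        length = end_time - start_time
    elif duration is not None:
        length = duration
    else:
        raise ValueError("Either end_time or duration is required")
    clip = {
        "asset": {"type": "video", "src": video_url, "trim": start_time},
        "start": 0,        # trimmed clip starts at the beginning of the output
        "length": length,  # seconds of source video kept after the trim point
    }
    return {
        "timeline": {"tracks": [{"clips": [clip]}]},
        "output": {"format": "mp4", "resolution": "hd"},
    }


edit = build_trim_edit("https://example.com/in.mp4", 5, end_time=12)
# the single clip keeps 7 seconds of source, starting 5 seconds in
```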
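The transition handling in `concatenate_videos` above follows a simple rule: every clip except the first gets an incoming transition, every clip except the last gets an outgoing one, so a single-clip timeline gets none at all. A standalone sketch of that loop (the helper name `build_concat_clips` is illustrative, not part of the integration):

```python
from typing import Any, Dict, List, Optional


def build_concat_clips(videos: List[str],
                       transition: Optional[str]) -> List[Dict[str, Any]]:
    """Mirror the clip-building loop from concatenate_videos above."""
    clips: List[Dict[str, Any]] = []
    for i, video_url in enumerate(videos):
        clip: Dict[str, Any] = {"asset": {"type": "video", "src": video_url}, "start": 0}
        if transition and transition != "none":
            trans: Dict[str, str] = {}
            if i > 0:
                trans["in"] = transition   # transition in from the previous clip
            if i < len(videos) - 1:
                trans["out"] = transition  # transition out into the next clip
            if trans:
                clip["transition"] = trans
        clips.append(clip)
    return clips


clips = build_concat_clips(["https://example.com/a.mp4",
                            "https://example.com/b.mp4"], "fade")
# first clip carries only an "out" transition, second only an "in"
```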