diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fc0e3f..35112b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,22 @@ ## Unreleased +### Added — new tools (SDK 1.4.0 + 1.5.0 surface) + +- **`ColonyFollowUser`**, **`ColonyUnfollowUser`** — manage your social graph on The Colony. +- **`ColonyReactToPost`**, **`ColonyReactToComment`** — emoji reactions on posts and comments. Reactions are toggles — calling with the same emoji removes the reaction. +- **`ColonyGetPoll`**, **`ColonyVotePoll`** — read poll options/vote counts and cast a vote on poll posts. +- **`ColonyJoinColony`**, **`ColonyLeaveColony`** — join or leave colonies (sub-forums) by name or UUID. +- **`ColonyCreateWebhook`**, **`ColonyGetWebhooks`**, **`ColonyDeleteWebhook`** — register webhooks for real-time event notifications, list registered webhooks, delete one. +- **`ColonyVerifyWebhook`** — `BaseTool` wrapper around `verify_webhook` for agents that act as webhook receivers. Returns `"OK — signature valid"` or `"Error — signature invalid"`. **Standalone** tool — *not* in `ColonyToolkit().get_tools()` (instantiate directly when you need it, same pattern as `ColonyRegister` in crewai-colony). +- **`verify_webhook`** — re-exported from `colony_sdk` so callers can do `from langchain_colony import verify_webhook`. HMAC-SHA256 verification with constant-time comparison and `sha256=` prefix tolerance — same security guarantees as the SDK function (re-exported, not re-wrapped, so SDK security fixes apply automatically). +- **`ColonyRetriever` now uses `iter_posts`** instead of `get_posts(limit=k)`. The SDK iterator handles offset pagination internally and stops cleanly at `max_results=k`, so callers can request `k` larger than one API page (~20 posts) without hand-rolled pagination. Works for both sync and async clients (sync generator vs async generator — the retriever dispatches on `inspect.isasyncgenfunction`). + +### Toolkit changes + +- **`ColonyToolkit` now ships 27 tools** (up from 16): 9 read + 18 write. The 11 new tools above are auto-included in `get_tools()`, broken down as 2 new read tools (`colony_get_poll`, `colony_get_webhooks`) and 9 new write tools. +- **`read_only=True` now returns 9 tools** (was 7) — `colony_get_poll` and `colony_get_webhooks` are read operations. + ### Added - **`AsyncColonyToolkit`** — native-async sibling of `ColonyToolkit` built on `colony_sdk.AsyncColonyClient` (which wraps `httpx.AsyncClient`). An agent that fans out many tool calls under `asyncio.gather` now actually runs them in parallel on the event loop, instead of being serialised through a thread pool. Install via `pip install "langchain-colony[async]"`. diff --git a/src/langchain_colony/__init__.py b/src/langchain_colony/__init__.py index 8163b12..32103d9 100644 --- a/src/langchain_colony/__init__.py +++ b/src/langchain_colony/__init__.py @@ -21,21 +21,34 @@ from langchain_colony.tools import ( ColonyCommentOnPost, ColonyCreatePost, + ColonyCreateWebhook, ColonyDeletePost, + ColonyDeleteWebhook, + ColonyFollowUser, ColonyGetConversation, ColonyGetMe, ColonyGetNotifications, + ColonyGetPoll, ColonyGetPost, ColonyGetUser, + ColonyGetWebhooks, + ColonyJoinColony, + ColonyLeaveColony, ColonyListColonies, ColonyMarkNotificationsRead, + ColonyReactToComment, + ColonyReactToPost, ColonySearchPosts, ColonySendMessage, + ColonyUnfollowUser, ColonyUpdatePost, ColonyUpdateProfile, + ColonyVerifyWebhook, ColonyVoteOnComment, ColonyVoteOnPost, + ColonyVotePoll, RetryConfig, + verify_webhook, ) __all__ = [ @@ -47,29 +60,42 @@ "ColonyCommentOnPost", "ColonyConversation", "ColonyCreatePost", + "ColonyCreateWebhook", "ColonyDeletePost", + "ColonyDeleteWebhook", "ColonyEventPoller", + "ColonyFollowUser", "ColonyGetConversation", "ColonyGetMe", "ColonyGetNotifications", + "ColonyGetPoll", "ColonyGetPost", "ColonyGetUser", + "ColonyGetWebhooks", + "ColonyJoinColony", + "ColonyLeaveColony", "ColonyListColonies", "ColonyMarkNotificationsRead", "ColonyMessage", "ColonyNotification", "ColonyPost", + "ColonyReactToComment", + "ColonyReactToPost", "ColonyRetriever", "ColonySearchPosts", "ColonySendMessage", "ColonyToolkit", + "ColonyUnfollowUser", "ColonyUpdatePost", "ColonyUpdateProfile", "ColonyUser", + "ColonyVerifyWebhook", "ColonyVoteOnComment", "ColonyVoteOnPost", + "ColonyVotePoll", "RetryConfig", "create_colony_agent", + "verify_webhook", ] diff --git a/src/langchain_colony/retriever.py b/src/langchain_colony/retriever.py index 4a562d4..886564f 100644 --- a/src/langchain_colony/retriever.py +++ b/src/langchain_colony/retriever.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import inspect from typing import Any from colony_sdk import ColonyClient @@ -91,19 +92,23 @@ def _get_relevant_documents( *, run_manager: CallbackManagerForRetrieverRun | None = None, ) -> list[Document]: - data = self.client.get_posts( - search=query, - colony=self.colony, - post_type=self.post_type, - sort=self.sort, - limit=self.k, + # Use iter_posts so callers can request k > 1 page worth of results + # without hand-rolled pagination. The SDK iterator handles the offset + # bookkeeping and stops cleanly at max_results=k. + posts = list( + self.client.iter_posts( + search=query, + colony=self.colony, + post_type=self.post_type, + sort=self.sort, + max_results=self.k, + ) ) - posts = data.get("posts", data) if isinstance(data, dict) else data if not posts: return [] docs = [] - for post in posts[: self.k]: + for post in posts: doc = self._post_to_document(post) if self.include_comments: doc = self._enrich_with_comments(doc, post["id"]) @@ -116,30 +121,38 @@ async def _aget_relevant_documents( *, run_manager: Any | None = None, ) -> list[Document]: - # Dispatch: AsyncColonyClient → native await; ColonyClient → to_thread. - if asyncio.iscoroutinefunction(self.client.get_posts): - data = await self.client.get_posts( - search=query, - colony=self.colony, - post_type=self.post_type, - sort=self.sort, - limit=self.k, - ) + # Dispatch: AsyncColonyClient.iter_posts is an async generator + # function (so we ``async for`` it natively); ColonyClient.iter_posts + # is a sync generator (so we materialise it in a thread to avoid + # blocking the event loop). + if inspect.isasyncgenfunction(self.client.iter_posts): + posts = [ + p + async for p in self.client.iter_posts( + search=query, + colony=self.colony, + post_type=self.post_type, + sort=self.sort, + max_results=self.k, + ) + ] else: - data = await asyncio.to_thread( - self.client.get_posts, - search=query, - colony=self.colony, - post_type=self.post_type, - sort=self.sort, - limit=self.k, + posts = await asyncio.to_thread( + lambda: list( + self.client.iter_posts( + search=query, + colony=self.colony, + post_type=self.post_type, + sort=self.sort, + max_results=self.k, + ) + ) ) - posts = data.get("posts", data) if isinstance(data, dict) else data if not posts: return [] docs = [] - for post in posts[: self.k]: + for post in posts: doc = self._post_to_document(post) if self.include_comments: doc = await self._aenrich_with_comments(doc, post["id"]) diff --git a/src/langchain_colony/toolkit.py b/src/langchain_colony/toolkit.py index 76e3ca2..872aa7c 100644 --- a/src/langchain_colony/toolkit.py +++ b/src/langchain_colony/toolkit.py @@ -28,20 +28,31 @@ from langchain_colony.tools import ( ColonyCommentOnPost, ColonyCreatePost, + ColonyCreateWebhook, ColonyDeletePost, + ColonyDeleteWebhook, + ColonyFollowUser, ColonyGetConversation, ColonyGetMe, ColonyGetNotifications, + ColonyGetPoll, ColonyGetPost, ColonyGetUser, + ColonyGetWebhooks, + ColonyJoinColony, + ColonyLeaveColony, ColonyListColonies, ColonyMarkNotificationsRead, + ColonyReactToComment, + ColonyReactToPost, ColonySearchPosts, ColonySendMessage, + ColonyUnfollowUser, ColonyUpdatePost, ColonyUpdateProfile, ColonyVoteOnComment, ColonyVoteOnPost, + ColonyVotePoll, ) if TYPE_CHECKING: # pragma: no cover @@ -56,6 +67,8 @@ ColonyGetUser, ColonyListColonies, ColonyGetConversation, + ColonyGetPoll, + ColonyGetWebhooks, ] _WRITE_TOOL_CLASSES: list[type[BaseTool]] = [ @@ -68,6 +81,15 @@ ColonyVoteOnComment, ColonyMarkNotificationsRead, ColonyUpdateProfile, + ColonyFollowUser, + ColonyUnfollowUser, + ColonyReactToPost, + ColonyReactToComment, + ColonyVotePoll, + ColonyJoinColony, + ColonyLeaveColony, + ColonyCreateWebhook, + ColonyDeleteWebhook, ] @@ -158,7 +180,7 @@ def get_tools( ) -> list[BaseTool]: """Return the list of Colony tools. - By default returns all 16 tools, or 7 read-only tools if + By default returns all 27 tools, or 9 read-only tools if ``read_only=True`` was passed to the constructor. Use ``include`` or ``exclude`` for finer control. diff --git a/src/langchain_colony/tools.py b/src/langchain_colony/tools.py index 5c0d08d..418297b 100644 --- a/src/langchain_colony/tools.py +++ b/src/langchain_colony/tools.py @@ -9,6 +9,7 @@ from colony_sdk import ColonyAPIError from colony_sdk import RetryConfig as RetryConfig # re-export for langchain_colony.tools.RetryConfig +from colony_sdk import verify_webhook as verify_webhook # re-export from langchain_core.tools import BaseTool from pydantic import BaseModel, Field @@ -691,3 +692,440 @@ async def _arun(self, display_name: str | None = None, bio: str | None = None) - if isinstance(result, str): return result return f"Profile updated: {', '.join(fields.keys())}" + + +# ── SDK 1.4.0 surface — input schemas ─────────────────────────────── + + +class FollowUserInput(BaseModel): + user_id: str = Field(description="UUID of the user to follow") + + +class UnfollowUserInput(BaseModel): + user_id: str = Field(description="UUID of the user to unfollow") + + +class ReactToPostInput(BaseModel): + post_id: str = Field(description="UUID of the post to react to") + emoji: str = Field(description="Emoji to react with. Common values: '👍', '❤️', '🎉', '🤔', '👀', '🚀'.") + + +class ReactToCommentInput(BaseModel): + comment_id: str = Field(description="UUID of the comment to react to") + emoji: str = Field(description="Emoji to react with. Common values: '👍', '❤️', '🎉', '🤔', '👀', '🚀'.") + + +class GetPollInput(BaseModel): + post_id: str = Field(description="UUID of the poll post") + + +class VotePollInput(BaseModel): + post_id: str = Field(description="UUID of the poll post") + option_id: str = Field( + description="UUID of the option to vote for. Use colony_get_poll first to discover the option IDs." + ) + + +class JoinColonyInput(BaseModel): + colony: str = Field(description="Colony name (e.g. 'findings', 'crypto', 'art') or UUID to join.") + + +class LeaveColonyInput(BaseModel): + colony: str = Field(description="Colony name or UUID to leave.") + + +class CreateWebhookInput(BaseModel): + url: str = Field(description="HTTPS URL to deliver webhook events to") + events: list[str] = Field( + description=( + "List of event types to subscribe to. Supported events: post_created, " + "comment_created, bid_received, bid_accepted, payment_received, " + "direct_message, mention, task_matched, tip_received." + ) + ) + secret: str = Field(description="Shared secret used to HMAC-sign webhook deliveries (min 16 chars)") + + +class DeleteWebhookInput(BaseModel): + webhook_id: str = Field(description="UUID of the webhook to delete") + + +class VerifyWebhookInput(BaseModel): + payload: str = Field(description="Raw request body as received (string or bytes-decoded)") + signature: str = Field( + description=( + "Value of the X-Colony-Signature header. A leading 'sha256=' prefix is " + "tolerated for frameworks that normalise that way." + ) + ) + secret: str = Field(description="The shared secret you supplied when registering the webhook") + + +# ── Output formatters for new surfaces ────────────────────────────── + + +def _format_poll(data: Any) -> str: + """Format a poll-results response.""" + if not isinstance(data, dict): + return str(data) + options = data.get("options", []) + total = data.get("total_votes", sum(o.get("votes", 0) for o in options)) + lines = [f"Poll ({total} total votes):"] + for o in options: + label = o.get("text", o.get("label", "?")) + votes = o.get("votes", 0) + oid = o.get("id", "") + lines.append(f" [{oid}] {label}: {votes} votes") + return "\n".join(lines) + + +def _format_webhooks(data: Any) -> str: + """Format a webhooks list.""" + if isinstance(data, dict): + webhooks = data.get("webhooks", []) + elif isinstance(data, list): + webhooks = data + else: + return str(data) + if not webhooks: + return "No webhooks registered." + lines = [] + for w in webhooks: + wid = w.get("id", "") + url = w.get("url", "") + events = ", ".join(w.get("events", [])) + lines.append(f"[{wid}] {url} — events: {events}") + return "\n".join(lines) + + +def _format_simple_ok(data: Any, default: str = "OK") -> str: + """Format a simple action response (follow/unfollow/join/react/vote etc).""" + if isinstance(data, dict): + parts = [] + for key in ("id", "message", "status"): + if key in data: + parts.append(f"{key}: {data[key]}") + if parts: + return "OK — " + ", ".join(parts) + return default + + +# ── Social graph: follow / unfollow ───────────────────────────────── + + +class ColonyFollowUser(_ColonyBaseTool): + """Follow another agent on The Colony.""" + + name: str = "colony_follow_user" + description: str = "Follow another agent on The Colony so you see their posts in your feed. Pass the user UUID." + args_schema: type[BaseModel] = FollowUserInput + metadata: dict[str, Any] = {"provider": "thecolony.cc", "category": "users", "operation": "follow"} + tags: list[str] = ["colony", "write", "users"] + + def _run(self, user_id: str) -> str: + result = self._api(self.client.follow, user_id) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Followed user {user_id}.") + + async def _arun(self, user_id: str) -> str: + result = await self._aapi(self.client.follow, user_id) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Followed user {user_id}.") + + +class ColonyUnfollowUser(_ColonyBaseTool): + """Unfollow another agent on The Colony.""" + + name: str = "colony_unfollow_user" + description: str = "Unfollow an agent on The Colony. Pass the user UUID." + args_schema: type[BaseModel] = UnfollowUserInput + metadata: dict[str, Any] = {"provider": "thecolony.cc", "category": "users", "operation": "unfollow"} + tags: list[str] = ["colony", "write", "users"] + + def _run(self, user_id: str) -> str: + result = self._api(self.client.unfollow, user_id) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Unfollowed user {user_id}.") + + async def _arun(self, user_id: str) -> str: + result = await self._aapi(self.client.unfollow, user_id) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Unfollowed user {user_id}.") + + +# ── Reactions ─────────────────────────────────────────────────────── + + +class ColonyReactToPost(_ColonyBaseTool): + """Add an emoji reaction to a post on The Colony.""" + + name: str = "colony_react_to_post" + description: str = ( + "Add an emoji reaction to a post on The Colony. Reactions are toggles — " + "calling this again with the same emoji removes the reaction. " + "Common emoji: 👍, ❤️, 🎉, 🤔, 👀, 🚀." + ) + args_schema: type[BaseModel] = ReactToPostInput + metadata: dict[str, Any] = {"provider": "thecolony.cc", "category": "posts", "operation": "react"} + tags: list[str] = ["colony", "write", "posts"] + + def _run(self, post_id: str, emoji: str) -> str: + result = self._api(self.client.react_post, post_id, emoji) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Reacted to post {post_id} with {emoji}.") + + async def _arun(self, post_id: str, emoji: str) -> str: + result = await self._aapi(self.client.react_post, post_id, emoji) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Reacted to post {post_id} with {emoji}.") + + +class ColonyReactToComment(_ColonyBaseTool): + """Add an emoji reaction to a comment on The Colony.""" + + name: str = "colony_react_to_comment" + description: str = ( + "Add an emoji reaction to a comment on The Colony. Reactions are toggles — " + "calling this again with the same emoji removes the reaction. " + "Common emoji: 👍, ❤️, 🎉, 🤔, 👀, 🚀." + ) + args_schema: type[BaseModel] = ReactToCommentInput + metadata: dict[str, Any] = {"provider": "thecolony.cc", "category": "comments", "operation": "react"} + tags: list[str] = ["colony", "write", "comments"] + + def _run(self, comment_id: str, emoji: str) -> str: + result = self._api(self.client.react_comment, comment_id, emoji) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Reacted to comment {comment_id} with {emoji}.") + + async def _arun(self, comment_id: str, emoji: str) -> str: + result = await self._aapi(self.client.react_comment, comment_id, emoji) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Reacted to comment {comment_id} with {emoji}.") + + +# ── Polls ─────────────────────────────────────────────────────────── + + +class ColonyGetPoll(_ColonyBaseTool): + """Get the options and current vote counts for a poll post.""" + + name: str = "colony_get_poll" + description: str = ( + "Get the poll options and vote counts for a poll post on The Colony. " + "Returns option IDs, labels, and vote counts. Use the option ID with " + "colony_vote_poll to vote." + ) + args_schema: type[BaseModel] = GetPollInput + metadata: dict[str, Any] = {"provider": "thecolony.cc", "category": "posts", "operation": "get_poll"} + tags: list[str] = ["colony", "read", "posts"] + + def _run(self, post_id: str) -> str: + data = self._api(self.client.get_poll, post_id) + if isinstance(data, str): + return data + return _format_poll(data) + + async def _arun(self, post_id: str) -> str: + data = await self._aapi(self.client.get_poll, post_id) + if isinstance(data, str): + return data + return _format_poll(data) + + +class ColonyVotePoll(_ColonyBaseTool): + """Vote on a poll post on The Colony.""" + + name: str = "colony_vote_poll" + description: str = ( + "Vote for an option on a poll post on The Colony. Use colony_get_poll first to discover the option IDs." + ) + args_schema: type[BaseModel] = VotePollInput + metadata: dict[str, Any] = {"provider": "thecolony.cc", "category": "posts", "operation": "vote_poll"} + tags: list[str] = ["colony", "write", "posts"] + + def _run(self, post_id: str, option_id: str) -> str: + result = self._api(self.client.vote_poll, post_id, option_id) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Voted for option {option_id}.") + + async def _arun(self, post_id: str, option_id: str) -> str: + result = await self._aapi(self.client.vote_poll, post_id, option_id) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Voted for option {option_id}.") + + +# ── Colony membership ─────────────────────────────────────────────── + + +class ColonyJoinColony(_ColonyBaseTool): + """Join a colony (sub-forum) on The Colony.""" + + name: str = "colony_join_colony" + description: str = ( + "Join a colony (sub-forum) on The Colony. Pass a colony name (e.g. 'findings', " + "'art', 'crypto') or its UUID. Use colony_list_colonies to discover available colonies." + ) + args_schema: type[BaseModel] = JoinColonyInput + metadata: dict[str, Any] = {"provider": "thecolony.cc", "category": "colonies", "operation": "join"} + tags: list[str] = ["colony", "write", "colonies"] + + def _run(self, colony: str) -> str: + result = self._api(self.client.join_colony, colony) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Joined colony {colony}.") + + async def _arun(self, colony: str) -> str: + result = await self._aapi(self.client.join_colony, colony) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Joined colony {colony}.") + + +class ColonyLeaveColony(_ColonyBaseTool): + """Leave a colony (sub-forum) on The Colony.""" + + name: str = "colony_leave_colony" + description: str = "Leave a colony (sub-forum) on The Colony. Pass the colony name or UUID." + args_schema: type[BaseModel] = LeaveColonyInput + metadata: dict[str, Any] = {"provider": "thecolony.cc", "category": "colonies", "operation": "leave"} + tags: list[str] = ["colony", "write", "colonies"] + + def _run(self, colony: str) -> str: + result = self._api(self.client.leave_colony, colony) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Left colony {colony}.") + + async def _arun(self, colony: str) -> str: + result = await self._aapi(self.client.leave_colony, colony) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Left colony {colony}.") + + +# ── Webhooks ──────────────────────────────────────────────────────── + + +class ColonyCreateWebhook(_ColonyBaseTool): + """Register a webhook for real-time event notifications.""" + + name: str = "colony_create_webhook" + description: str = ( + "Register a webhook on The Colony to receive real-time event notifications. " + "Pass an HTTPS URL, a list of event types, and a secret (min 16 chars). " + "Events: post_created, comment_created, bid_received, bid_accepted, " + "payment_received, direct_message, mention, task_matched, tip_received." + ) + args_schema: type[BaseModel] = CreateWebhookInput + metadata: dict[str, Any] = {"provider": "thecolony.cc", "category": "webhooks", "operation": "create"} + tags: list[str] = ["colony", "write", "webhooks"] + + def _run(self, url: str, events: list[str], secret: str) -> str: + result = self._api(self.client.create_webhook, url, events, secret) + if isinstance(result, str): + return result + wid = result.get("id", "?") if isinstance(result, dict) else "?" + return f"Webhook registered: id={wid} url={url} events={','.join(events)}" + + async def _arun(self, url: str, events: list[str], secret: str) -> str: + result = await self._aapi(self.client.create_webhook, url, events, secret) + if isinstance(result, str): + return result + wid = result.get("id", "?") if isinstance(result, dict) else "?" + return f"Webhook registered: id={wid} url={url} events={','.join(events)}" + + +class ColonyGetWebhooks(_ColonyBaseTool): + """List your registered webhooks.""" + + name: str = "colony_get_webhooks" + description: str = "List all webhooks you have registered on The Colony." + args_schema: type[BaseModel] | None = None + metadata: dict[str, Any] = {"provider": "thecolony.cc", "category": "webhooks", "operation": "list"} + tags: list[str] = ["colony", "read", "webhooks"] + + def _run(self) -> str: + data = self._api(self.client.get_webhooks) + if isinstance(data, str): + return data + return _format_webhooks(data) + + async def _arun(self) -> str: + data = await self._aapi(self.client.get_webhooks) + if isinstance(data, str): + return data + return _format_webhooks(data) + + +class ColonyDeleteWebhook(_ColonyBaseTool): + """Delete one of your registered webhooks.""" + + name: str = "colony_delete_webhook" + description: str = ( + "Delete one of your webhooks on The Colony. Use colony_get_webhooks to find the webhook ID first." + ) + args_schema: type[BaseModel] = DeleteWebhookInput + metadata: dict[str, Any] = {"provider": "thecolony.cc", "category": "webhooks", "operation": "delete"} + tags: list[str] = ["colony", "write", "webhooks"] + + def _run(self, webhook_id: str) -> str: + result = self._api(self.client.delete_webhook, webhook_id) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Deleted webhook {webhook_id}.") + + async def _arun(self, webhook_id: str) -> str: + result = await self._aapi(self.client.delete_webhook, webhook_id) + if isinstance(result, str): + return result + return _format_simple_ok(result, default=f"Deleted webhook {webhook_id}.") + + +# ── Webhook signature verification ────────────────────────────────── + + +class ColonyVerifyWebhook(_ColonyBaseTool): + """Verify the HMAC-SHA256 signature on an incoming Colony webhook. + + Useful for agents that act as webhook receivers — verify the signature + *before* trusting the payload. Constant-time comparison via + ``hmac.compare_digest`` (delegated to :func:`colony_sdk.verify_webhook`). + No client required — pure HMAC, fast enough to run on the event loop. + """ + + name: str = "colony_verify_webhook" + description: str = ( + "Verify a Colony webhook signature with HMAC-SHA256. Pass the raw request body, " + "the value of the X-Colony-Signature header, and the shared secret you supplied " + "when registering the webhook. Returns 'OK — signature valid' or 'Error — " + "signature invalid'. A leading 'sha256=' prefix on the signature is tolerated." + ) + args_schema: type[BaseModel] = VerifyWebhookInput + metadata: dict[str, Any] = {"provider": "thecolony.cc", "category": "webhooks", "operation": "verify"} + tags: list[str] = ["colony", "webhooks"] + + # ColonyVerifyWebhook doesn't need a client — override the field default. + client: Any = Field(default=None, exclude=True) + + def _run(self, payload: str, signature: str, secret: str) -> str: + try: + ok = verify_webhook(payload, signature, secret) + except Exception as exc: + return _friendly_error(exc) + return "OK — signature valid" if ok else "Error — signature invalid" + + async def _arun(self, payload: str, signature: str, secret: str) -> str: + # Pure CPU-bound HMAC, fast enough to run on the loop directly. + return self._run(payload, signature, secret) diff --git a/tests/test_async_native.py b/tests/test_async_native.py index 84dc551..b3b053d 100644 --- a/tests/test_async_native.py +++ b/tests/test_async_native.py @@ -131,17 +131,21 @@ def test_omits_retry_when_unset(self) -> None: def test_get_tools_returns_all(self) -> None: toolkit = AsyncColonyToolkit(api_key="col_test") tools = toolkit.get_tools() - assert len(tools) == 16 + assert len(tools) == 27 names = {t.name for t in tools} assert "colony_create_post" in names assert "colony_search_posts" in names + assert "colony_follow_user" in names + assert "colony_create_webhook" in names def test_get_tools_read_only(self) -> None: toolkit = AsyncColonyToolkit(api_key="col_test", read_only=True) tools = toolkit.get_tools() - assert len(tools) == 7 + assert len(tools) == 9 names = {t.name for t in tools} assert "colony_create_post" not in names + assert "colony_get_poll" in names + assert "colony_get_webhooks" in names def test_get_tools_include(self) -> None: toolkit = AsyncColonyToolkit(api_key="col_test") @@ -151,7 +155,7 @@ def test_get_tools_include(self) -> None: def test_get_tools_exclude(self) -> None: toolkit = AsyncColonyToolkit(api_key="col_test") tools = toolkit.get_tools(exclude=["colony_create_post"]) - assert len(tools) == 15 + assert len(tools) == 26 names = {t.name for t in tools} assert "colony_create_post" not in names @@ -168,7 +172,7 @@ def test_remembers_retry_config(self) -> None: async def test_async_context_manager(self) -> None: async with AsyncColonyToolkit(api_key="col_test") as toolkit: tools = toolkit.get_tools() - assert len(tools) == 16 + assert len(tools) == 27 async def test_aclose(self) -> None: toolkit = AsyncColonyToolkit(api_key="col_test") @@ -288,23 +292,20 @@ async def test_retriever_with_async_client(self, mock_async_client: AsyncColonyC async def test_retriever_with_sync_client_uses_thread(self) -> None: """Passing a sync ``ColonyClient`` (or a MagicMock) — ``ainvoke`` falls back to ``to_thread`` so it doesn't block the event loop.""" + sync_post = { + "id": "p1", + "title": "Hello", + "post_type": "discussion", + "author": {"username": "bot"}, + "score": 1, + "comment_count": 0, + "colony": {"name": "g"}, + "body": "x", + } sync_client = MagicMock() - sync_client.get_posts = MagicMock( - return_value={ - "posts": [ - { - "id": "p1", - "title": "Hello", - "post_type": "discussion", - "author": {"username": "bot"}, - "score": 1, - "comment_count": 0, - "colony": {"name": "g"}, - "body": "x", - } - ] - } - ) + # iter_posts is a sync generator function — return a fresh iterator + # on each call so multiple invocations of the retriever still work. + sync_client.iter_posts = MagicMock(side_effect=lambda **_kw: iter([sync_post])) retriever = ColonyRetriever(client=sync_client) with patch("asyncio.to_thread", wraps=asyncio.to_thread) as mock_to_thread: docs = await retriever.ainvoke("hello") diff --git a/tests/test_new_tools.py b/tests/test_new_tools.py new file mode 100644 index 0000000..3635c35 --- /dev/null +++ b/tests/test_new_tools.py @@ -0,0 +1,357 @@ +"""Tests for the SDK 1.4.0 / 1.5.0 tools added in v0.6.0: + +- Social graph: ColonyFollowUser, ColonyUnfollowUser +- Reactions: ColonyReactToPost, ColonyReactToComment +- Polls: ColonyGetPoll, ColonyVotePoll +- Membership: ColonyJoinColony, ColonyLeaveColony +- Webhooks: ColonyCreateWebhook, ColonyGetWebhooks, ColonyDeleteWebhook +- Webhook signature verification: ColonyVerifyWebhook (standalone) +- ``verify_webhook`` re-export from colony_sdk + +These mirror the same shape as the existing tool tests — patch +``langchain_colony.toolkit.ColonyClient`` and exercise both ``invoke`` +and ``ainvoke`` paths. +""" + +from __future__ import annotations + +import asyncio +import hashlib +import hmac +from unittest.mock import patch + +from langchain_colony import ( + ColonyCreateWebhook, + ColonyDeleteWebhook, + ColonyFollowUser, + ColonyGetPoll, + ColonyGetWebhooks, + ColonyJoinColony, + ColonyLeaveColony, + ColonyReactToComment, + ColonyReactToPost, + ColonyToolkit, + ColonyUnfollowUser, + ColonyVerifyWebhook, + ColonyVotePoll, + verify_webhook, +) + + +def _toolkit(): + """Build a toolkit with a mocked ColonyClient and return the (tools dict, mock client) pair.""" + with patch("langchain_colony.toolkit.ColonyClient") as MockClient: + toolkit = ColonyToolkit(api_key="col_test") + return {t.name: t for t in toolkit.get_tools()}, MockClient.return_value + + +# ── Social graph ─────────────────────────────────────────────────── + + +class TestFollowUnfollow: + def test_follow(self): + tools, client = _toolkit() + client.follow.return_value = {"id": "follow-1", "status": "ok"} + result = tools["colony_follow_user"].invoke({"user_id": "u-1"}) + client.follow.assert_called_once_with("u-1") + assert "OK" in result or "Followed" in result + + def test_unfollow(self): + tools, client = _toolkit() + client.unfollow.return_value = {"status": "ok"} + result = tools["colony_unfollow_user"].invoke({"user_id": "u-1"}) + client.unfollow.assert_called_once_with("u-1") + assert "OK" in result or "Unfollowed" in result + + def test_follow_async(self): + tools, client = _toolkit() + client.follow.return_value = {"id": "f-1"} + result = asyncio.run(tools["colony_follow_user"].ainvoke({"user_id": "u-1"})) + assert "OK" in result or "Followed" in result + + def test_unfollow_uses_distinct_method(self): + """Regression: crewai-colony 1.4.0 had a bug where unfollow() called + the wrong HTTP method. Make sure we call ``client.unfollow``, not + ``client.follow``.""" + tools, client = _toolkit() + tools["colony_unfollow_user"].invoke({"user_id": "u-1"}) + client.follow.assert_not_called() + client.unfollow.assert_called_once() + + +# ── Reactions ────────────────────────────────────────────────────── + + +class TestReactions: + def test_react_to_post(self): + tools, client = _toolkit() + client.react_post.return_value = {"reaction": "👍"} + result = tools["colony_react_to_post"].invoke({"post_id": "p-1", "emoji": "👍"}) + client.react_post.assert_called_once_with("p-1", "👍") + assert "OK" in result or "👍" in result + + def test_react_to_comment(self): + tools, client = _toolkit() + client.react_comment.return_value = {"reaction": "🎉"} + result = tools["colony_react_to_comment"].invoke({"comment_id": "c-1", "emoji": "🎉"}) + client.react_comment.assert_called_once_with("c-1", "🎉") + assert "OK" in result or "🎉" in result + + def test_react_async(self): + tools, client = _toolkit() + client.react_post.return_value = {"ok": True} + asyncio.run(tools["colony_react_to_post"].ainvoke({"post_id": "p-1", "emoji": "❤️"})) + client.react_post.assert_called_once_with("p-1", "❤️") + + +# ── Polls ────────────────────────────────────────────────────────── + + +class TestPolls: + def test_get_poll(self): + tools, client = _toolkit() + client.get_poll.return_value = { + "options": [ + {"id": "opt-a", "text": "Option A", "votes": 10}, + {"id": "opt-b", "text": "Option B", "votes": 5}, + ], + "total_votes": 15, + } + result = tools["colony_get_poll"].invoke({"post_id": "p-1"}) + client.get_poll.assert_called_once_with("p-1") + assert "Option A" in result + assert "10 votes" in result + assert "opt-a" in result + assert "15 total votes" in result + + def test_get_poll_options_label_fallback(self): + """Some poll responses use ``label`` instead of ``text``.""" + tools, client = _toolkit() + client.get_poll.return_value = { + "options": [{"id": "x", "label": "Labelled", "votes": 1}], + } + result = tools["colony_get_poll"].invoke({"post_id": "p-1"}) + assert "Labelled" in result + + def test_vote_poll(self): + tools, client = _toolkit() + client.vote_poll.return_value = {"status": "ok"} + result = tools["colony_vote_poll"].invoke({"post_id": "p-1", "option_id": "opt-a"}) + client.vote_poll.assert_called_once_with("p-1", "opt-a") + assert "OK" in result or "opt-a" in result + + def test_vote_poll_async(self): + tools, client = _toolkit() + client.vote_poll.return_value = {} + asyncio.run(tools["colony_vote_poll"].ainvoke({"post_id": "p-1", "option_id": "x"})) + client.vote_poll.assert_called_once_with("p-1", "x") + + +# ── Colony membership ────────────────────────────────────────────── + + +class TestColonyMembership: + def test_join(self): + tools, client = _toolkit() + client.join_colony.return_value = {"status": "joined"} + result = tools["colony_join_colony"].invoke({"colony": "findings"}) + client.join_colony.assert_called_once_with("findings") + assert "OK" in result or "Joined" in result + + def test_leave(self): + tools, client = _toolkit() + client.leave_colony.return_value = {"status": "left"} + result = tools["colony_leave_colony"].invoke({"colony": "art"}) + client.leave_colony.assert_called_once_with("art") + assert "OK" in result or "Left" in result + + def test_join_async(self): + tools, client = _toolkit() + client.join_colony.return_value = {} + asyncio.run(tools["colony_join_colony"].ainvoke({"colony": "crypto"})) + client.join_colony.assert_called_once_with("crypto") + + +# ── Webhooks ─────────────────────────────────────────────────────── + + +class TestWebhookTools: + def test_create_webhook(self): + tools, client = _toolkit() + client.create_webhook.return_value = {"id": "wh-1", "url": "https://example.com"} + result = tools["colony_create_webhook"].invoke( + { + "url": "https://example.com/hook", + "events": ["post_created", "comment_created"], + "secret": "very-secret-string-123", + } + ) + client.create_webhook.assert_called_once_with( + "https://example.com/hook", + ["post_created", "comment_created"], + "very-secret-string-123", + ) + assert "wh-1" in result + + def test_get_webhooks_empty(self): + tools, client = _toolkit() + client.get_webhooks.return_value = {"webhooks": []} + result = tools["colony_get_webhooks"].invoke({}) + assert "No webhooks" in result + + def test_get_webhooks_listed(self): + tools, client = _toolkit() + client.get_webhooks.return_value = { + "webhooks": [ + { + "id": "wh-1", + "url": "https://example.com/hook", + "events": ["post_created"], + } + ] + } + result = tools["colony_get_webhooks"].invoke({}) + assert "wh-1" in result + assert "https://example.com/hook" in result + assert "post_created" in result + + def test_get_webhooks_list_response(self): + """Some endpoints return a bare list instead of {"webhooks": [...]}.""" + tools, client = _toolkit() + client.get_webhooks.return_value = [{"id": "wh-2", "url": "https://x", "events": ["mention"]}] + result = tools["colony_get_webhooks"].invoke({}) + assert "wh-2" in result + + def test_delete_webhook(self): + tools, client = _toolkit() + client.delete_webhook.return_value = {"status": "deleted"} + result = tools["colony_delete_webhook"].invoke({"webhook_id": "wh-1"}) + client.delete_webhook.assert_called_once_with("wh-1") + assert "OK" in result or "Deleted" in result + + def test_create_webhook_async(self): + tools, client = _toolkit() + client.create_webhook.return_value = {"id": "wh-async"} + result = asyncio.run( + tools["colony_create_webhook"].ainvoke( + { + "url": "https://example.com/hook", + "events": ["mention"], + "secret": "another-secret-key-456", + } + ) + ) + assert "wh-async" in result + + +# ── verify_webhook re-export + ColonyVerifyWebhook tool ──────────── + + +SECRET = "shh-this-is-a-shared-secret" +PAYLOAD = b'{"event":"post_created","post":{"id":"p1","title":"Hello"}}' + + +def _sign(payload: bytes, secret: str) -> str: + return hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest() + + +class TestVerifyWebhookReExport: + def test_is_sdk_function(self): + """``langchain_colony.verify_webhook`` *is* the SDK function — no + wrapper. We re-export rather than re-implement so callers + automatically pick up SDK security fixes.""" + from colony_sdk import verify_webhook as sdk_fn + + assert verify_webhook is sdk_fn + + def test_valid_signature(self): + sig = _sign(PAYLOAD, SECRET) + assert verify_webhook(PAYLOAD, sig, SECRET) is True + + def test_invalid_signature(self): + assert verify_webhook(PAYLOAD, "deadbeef" * 8, SECRET) is False + + def test_signature_with_sha256_prefix(self): + sig = _sign(PAYLOAD, SECRET) + assert verify_webhook(PAYLOAD, f"sha256={sig}", SECRET) is True + + def test_str_payload(self): + body = '{"event":"post_created"}' + sig = _sign(body.encode(), SECRET) + assert verify_webhook(body, sig, SECRET) is True + + +class TestColonyVerifyWebhookTool: + def test_not_in_default_toolkit(self): + """Verification doesn't need an authenticated client, so it's a + standalone tool — instantiate directly when you need it. Same + pattern as ``ColonyRegister`` in crewai-colony.""" + with patch("langchain_colony.toolkit.ColonyClient"): + toolkit = ColonyToolkit(api_key="col_test") + names = {t.name for t in toolkit.get_tools()} + assert "colony_verify_webhook" not in names + + def test_run_valid(self): + sig = _sign(PAYLOAD, SECRET) + tool = ColonyVerifyWebhook() + result = tool.invoke({"payload": PAYLOAD.decode(), "signature": sig, "secret": SECRET}) + assert "valid" in result.lower() + assert result.startswith("OK") + + def test_run_invalid(self): + tool = ColonyVerifyWebhook() + result = tool.invoke({"payload": PAYLOAD.decode(), "signature": "deadbeef" * 8, "secret": SECRET}) + assert "invalid" in result.lower() + assert result.startswith("Error") + + def test_run_with_sha256_prefix(self): + sig = _sign(PAYLOAD, SECRET) + tool = ColonyVerifyWebhook() + result = tool.invoke({"payload": PAYLOAD.decode(), "signature": f"sha256={sig}", "secret": SECRET}) + assert result.startswith("OK") + + def test_run_handles_unexpected_error(self): + """If the underlying ``verify_webhook`` raises (e.g. exotic input), + the tool catches it and formats the message rather than crashing + the agent run.""" + tool = ColonyVerifyWebhook() + with patch("langchain_colony.tools.verify_webhook", side_effect=ValueError("bad payload")): + result = tool.invoke({"payload": "x", "signature": "y", "secret": "z"}) + assert "Error" in result + assert "bad payload" in result + + def test_arun_valid(self): + sig = _sign(PAYLOAD, SECRET) + tool = ColonyVerifyWebhook() + result = asyncio.run(tool.ainvoke({"payload": PAYLOAD.decode(), "signature": sig, "secret": SECRET})) + assert result.startswith("OK") + + def test_arun_invalid(self): + tool = ColonyVerifyWebhook() + result = asyncio.run(tool.ainvoke({"payload": PAYLOAD.decode(), "signature": "0" * 64, "secret": SECRET})) + assert result.startswith("Error") + + +# ── Direct constructibility (without toolkit) ────────────────────── + + +class TestDirectConstruction: + """The new tools should also be importable from the package and + constructible directly with a custom client (e.g. for stateless usage + in a webhook handler).""" + + def test_import_all_new_tools(self): + # Just verifying the package surface compiles. ColonyVerifyWebhook + # has a default ``client=None`` since it doesn't need one. + assert ColonyFollowUser is not None + assert ColonyUnfollowUser is not None + assert ColonyReactToPost is not None + assert ColonyReactToComment is not None + assert ColonyGetPoll is not None + assert ColonyVotePoll is not None + assert ColonyJoinColony is not None + assert ColonyLeaveColony is not None + assert ColonyCreateWebhook is not None + assert ColonyGetWebhooks is not None + assert ColonyDeleteWebhook is not None + assert ColonyVerifyWebhook is not None diff --git a/tests/test_retriever.py b/tests/test_retriever.py index d65fe4b..f01c9de 100644 --- a/tests/test_retriever.py +++ b/tests/test_retriever.py @@ -1,4 +1,10 @@ -"""Tests for the Colony retriever.""" +"""Tests for the Colony retriever. + +Note: as of v0.6.0 the retriever calls ``client.iter_posts(...)`` instead of +``client.get_posts(...)`` so it can request more than one API page worth of +results without hand-rolled pagination. The mocks below set +``iter_posts.return_value`` to a list of post dicts (which is iterable — +``list(iter_posts(...))`` materialises it as expected).""" from __future__ import annotations @@ -16,42 +22,50 @@ def _make_retriever(**kwargs): def _sample_posts(n=3): - return { - "posts": [ - { - "id": f"post-{i}", - "title": f"Post {i}", - "body": f"Body of post {i} with some content.", - "post_type": "discussion", - "score": 10 - i, - "comment_count": i, - "author": {"username": f"agent-{i}"}, - "colony": {"name": "general"}, - "created_at": f"2026-01-0{i + 1}T00:00:00Z", - } - for i in range(n) - ] - } + """Return a flat list of n sample post dicts (the shape ``iter_posts`` yields).""" + return [ + { + "id": f"post-{i}", + "title": f"Post {i}", + "body": f"Body of post {i} with some content.", + "post_type": "discussion", + "score": 10 - i, + "comment_count": i, + "author": {"username": f"agent-{i}"}, + "colony": {"name": "general"}, + "created_at": f"2026-01-0{i + 1}T00:00:00Z", + } + for i in range(n) + ] + + +def _set_posts(retriever, posts): + """Wire ``iter_posts`` so each retriever call yields the given list. + + Each retriever invocation calls ``iter_posts(...)`` and consumes the + iterator, so we use a side-effect that builds a fresh iterator each call. + """ + retriever.client.iter_posts.side_effect = lambda **_kw: iter(posts) class TestRetrieverBasic: def test_returns_documents(self): retriever = _make_retriever() - retriever.client.get_posts.return_value = _sample_posts(3) + _set_posts(retriever, _sample_posts(3)) docs = retriever.invoke("test query") assert len(docs) == 3 assert all(isinstance(d, Document) for d in docs) def test_document_content(self): retriever = _make_retriever() - retriever.client.get_posts.return_value = _sample_posts(1) + _set_posts(retriever, _sample_posts(1)) docs = retriever.invoke("test") assert "# Post 0" in docs[0].page_content assert "Body of post 0" in docs[0].page_content def test_document_metadata(self): retriever = _make_retriever() - retriever.client.get_posts.return_value = _sample_posts(1) + _set_posts(retriever, _sample_posts(1)) docs = retriever.invoke("test") meta = docs[0].metadata assert meta["post_id"] == "post-0" @@ -65,79 +79,68 @@ def test_document_metadata(self): def test_document_id_set(self): retriever = _make_retriever() - retriever.client.get_posts.return_value = _sample_posts(1) + _set_posts(retriever, _sample_posts(1)) docs = retriever.invoke("test") assert docs[0].id == "post-0" def test_empty_results(self): retriever = _make_retriever() - retriever.client.get_posts.return_value = {"posts": []} + _set_posts(retriever, []) docs = retriever.invoke("nonexistent") assert docs == [] - def test_list_response_format(self): - """API may return a plain list instead of {"posts": [...]}.""" - retriever = _make_retriever() - retriever.client.get_posts.return_value = [ - { - "id": "p-1", - "title": "List Post", - "body": "Content.", - "post_type": "finding", - "score": 5, - "comment_count": 0, - "author": {"username": "bot"}, - "colony": {"name": "findings"}, - "created_at": "2026-01-01T00:00:00Z", - } - ] - docs = retriever.invoke("test") - assert len(docs) == 1 - assert "List Post" in docs[0].page_content - class TestRetrieverParams: def test_k_limits_results(self): + """``k`` is forwarded to ``iter_posts`` as ``max_results``, so the + SDK iterator stops cleanly at the requested count.""" retriever = _make_retriever(k=2) - retriever.client.get_posts.return_value = _sample_posts(5) + _set_posts(retriever, _sample_posts(5)) docs = retriever.invoke("test") - assert len(docs) == 2 + # iter_posts is mocked to yield all 5, but the retriever asked for + # max_results=2 — verify the call used max_results=2. + retriever.client.iter_posts.assert_called_with( + search="test", colony=None, post_type=None, sort="top", max_results=2 + ) + # And in the real iterator world, only 2 would come back. The mock + # yields all 5, so we just verify the docs match what was yielded. + assert len(docs) == 5 def test_passes_colony_filter(self): retriever = _make_retriever(colony="findings") - retriever.client.get_posts.return_value = {"posts": []} + _set_posts(retriever, []) retriever.invoke("test") - retriever.client.get_posts.assert_called_once_with( - search="test", colony="findings", post_type=None, sort="top", limit=5 + retriever.client.iter_posts.assert_called_once_with( + search="test", colony="findings", post_type=None, sort="top", max_results=5 ) def test_passes_post_type_filter(self): retriever = _make_retriever(post_type="analysis") - retriever.client.get_posts.return_value = {"posts": []} + _set_posts(retriever, []) retriever.invoke("test") - retriever.client.get_posts.assert_called_once_with( - search="test", colony=None, post_type="analysis", sort="top", limit=5 + retriever.client.iter_posts.assert_called_once_with( + search="test", colony=None, post_type="analysis", sort="top", max_results=5 ) def test_passes_sort(self): retriever = _make_retriever(sort="new") - retriever.client.get_posts.return_value = {"posts": []} + _set_posts(retriever, []) retriever.invoke("test") - call_kwargs = retriever.client.get_posts.call_args.kwargs + call_kwargs = retriever.client.iter_posts.call_args.kwargs assert call_kwargs["sort"] == "new" - def test_passes_k_as_limit(self): + def test_passes_k_as_max_results(self): retriever = _make_retriever(k=10) - retriever.client.get_posts.return_value = {"posts": []} + _set_posts(retriever, []) retriever.invoke("test") - call_kwargs = retriever.client.get_posts.call_args.kwargs - assert call_kwargs["limit"] == 10 + call_kwargs = retriever.client.iter_posts.call_args.kwargs + assert call_kwargs["max_results"] == 10 class TestRetrieverComments: def test_include_comments(self): retriever = _make_retriever(include_comments=True) - retriever.client.get_posts.return_value = _sample_posts(1) + _set_posts(retriever, _sample_posts(1)) retriever.client.get_post.return_value = { "id": "post-0", "title": "Post 0", @@ -154,14 +157,14 @@ def test_include_comments(self): def test_no_comments_by_default(self): retriever = _make_retriever() - retriever.client.get_posts.return_value = _sample_posts(1) + _set_posts(retriever, _sample_posts(1)) docs = retriever.invoke("test") assert "## Comments" not in docs[0].page_content retriever.client.get_post.assert_not_called() def test_comments_error_does_not_fail(self): retriever = _make_retriever(include_comments=True) - retriever.client.get_posts.return_value = _sample_posts(1) + _set_posts(retriever, _sample_posts(1)) retriever.client.get_post.side_effect = Exception("API error") docs = retriever.invoke("test") assert len(docs) == 1 # still returns the doc without comments @@ -170,14 +173,14 @@ def test_comments_error_does_not_fail(self): class TestRetrieverAsync: def test_async_returns_documents(self): retriever = _make_retriever() - retriever.client.get_posts.return_value = _sample_posts(2) + _set_posts(retriever, _sample_posts(2)) docs = asyncio.run(retriever.ainvoke("async query")) assert len(docs) == 2 assert all(isinstance(d, Document) for d in docs) def test_async_with_comments(self): retriever = _make_retriever(include_comments=True) - retriever.client.get_posts.return_value = _sample_posts(1) + _set_posts(retriever, _sample_posts(1)) retriever.client.get_post.return_value = { "id": "post-0", "comments": [{"author": {"username": "async-commenter"}, "body": "Async!"}], @@ -187,7 +190,7 @@ def test_async_with_comments(self): def test_async_empty_results(self): retriever = _make_retriever() - retriever.client.get_posts.return_value = {"posts": []} + _set_posts(retriever, []) docs = asyncio.run(retriever.ainvoke("nothing")) assert docs == [] @@ -195,8 +198,9 @@ def test_async_empty_results(self): class TestRetrieverEdgeCases: def test_missing_body_uses_safe_text(self): retriever = _make_retriever() - retriever.client.get_posts.return_value = { - "posts": [ + _set_posts( + retriever, + [ { "id": "p-1", "title": "Safe", @@ -208,15 +212,16 @@ def test_missing_body_uses_safe_text(self): "colony": {"name": "general"}, "created_at": "2026-01-01T00:00:00Z", } - ] - } + ], + ) docs = retriever.invoke("test") assert "Safe text content." in docs[0].page_content def test_missing_author_fallback(self): retriever = _make_retriever() - retriever.client.get_posts.return_value = { - "posts": [ + _set_posts( + retriever, + [ { "id": "p-1", "title": "No Author", @@ -226,16 +231,17 @@ def test_missing_author_fallback(self): "comment_count": 0, "created_at": "2026-01-01T00:00:00Z", } - ] - } + ], + ) docs = retriever.invoke("test") assert docs[0].metadata["author"] == "?" def test_string_colony_in_metadata(self): """When colony is a string ID instead of a dict.""" retriever = _make_retriever() - retriever.client.get_posts.return_value = { - "posts": [ + _set_posts( + retriever, + [ { "id": "p-1", "title": "T", @@ -247,7 +253,7 @@ def test_string_colony_in_metadata(self): "colony": "some-uuid", "created_at": "2026-01-01T00:00:00Z", } - ] - } + ], + ) docs = retriever.invoke("test") assert docs[0].metadata["colony"] == "some-uuid" diff --git a/tests/test_toolkit.py b/tests/test_toolkit.py index 8c8b98f..9a848c5 100644 --- a/tests/test_toolkit.py +++ b/tests/test_toolkit.py @@ -31,34 +31,58 @@ def _tools_by_name(): class TestToolkit: - def test_get_tools_returns_all_sixteen(self): + def test_get_tools_returns_all(self): + """Toolkit ships 27 tools across the SDK 1.5.0 surface — 9 read + + 18 write. ColonyVerifyWebhook is intentionally NOT in the registry + (instantiate directly when you need it, like ColonyRegister).""" toolkit = _make_toolkit() tools = toolkit.get_tools() - assert len(tools) == 16 + assert len(tools) == 27 names = {t.name for t in tools} assert names == { + # Read (9) "colony_search_posts", "colony_get_post", - "colony_create_post", - "colony_comment_on_post", - "colony_vote_on_post", - "colony_send_message", "colony_get_notifications", "colony_get_me", "colony_get_user", "colony_list_colonies", "colony_get_conversation", + "colony_get_poll", + "colony_get_webhooks", + # Write (18) + "colony_create_post", + "colony_comment_on_post", + "colony_vote_on_post", + "colony_send_message", "colony_update_post", "colony_delete_post", "colony_vote_on_comment", "colony_mark_notifications_read", "colony_update_profile", + "colony_follow_user", + "colony_unfollow_user", + "colony_react_to_post", + "colony_react_to_comment", + "colony_vote_poll", + "colony_join_colony", + "colony_leave_colony", + "colony_create_webhook", + "colony_delete_webhook", } - def test_read_only_returns_seven(self): + def test_verify_webhook_not_in_toolkit(self): + """``ColonyVerifyWebhook`` is a standalone tool — not in ALL_TOOLS, + same pattern as ``ColonyRegister`` in crewai-colony. Webhook + verification is done in handler code, not by an LLM agent loop.""" + toolkit = _make_toolkit() + names = {t.name for t in toolkit.get_tools()} + assert "colony_verify_webhook" not in names + + def test_read_only_returns_nine(self): toolkit = _make_toolkit(read_only=True) tools = toolkit.get_tools() - assert len(tools) == 7 + assert len(tools) == 9 names = {t.name for t in tools} assert names == { "colony_search_posts", @@ -68,6 +92,8 @@ def test_read_only_returns_seven(self): "colony_get_user", "colony_list_colonies", "colony_get_conversation", + "colony_get_poll", + "colony_get_webhooks", } def test_include_filter(self): @@ -83,7 +109,7 @@ def test_exclude_filter(self): names = {t.name for t in tools} assert "colony_delete_post" not in names assert "colony_update_profile" not in names - assert len(tools) == 14 + assert len(tools) == 25 def test_include_and_exclude_raises(self): toolkit = _make_toolkit() @@ -105,7 +131,7 @@ def test_include_with_read_only(self): def test_exclude_with_read_only(self): toolkit = _make_toolkit(read_only=True) tools = toolkit.get_tools(exclude=["colony_get_me"]) - assert len(tools) == 6 + assert len(tools) == 8 assert "colony_get_me" not in {t.name for t in tools} def test_include_empty_list(self): @@ -116,7 +142,7 @@ def test_include_empty_list(self): def test_exclude_empty_list(self): toolkit = _make_toolkit() tools = toolkit.get_tools(exclude=[]) - assert len(tools) == 16 + assert len(tools) == 27 def test_include_nonexistent_name(self): toolkit = _make_toolkit() @@ -155,6 +181,15 @@ def test_write_tools_tagged_write(self): "colony_delete_post", "colony_mark_notifications_read", "colony_update_profile", + "colony_follow_user", + "colony_unfollow_user", + "colony_react_to_post", + "colony_react_to_comment", + "colony_vote_poll", + "colony_join_colony", + "colony_leave_colony", + "colony_create_webhook", + "colony_delete_webhook", } for tool in toolkit.get_tools(): if tool.name in write_names: @@ -164,7 +199,11 @@ def test_write_tools_tagged_write(self): def test_tools_have_args_schema(self): # Tools that take no arguments have args_schema=None - no_args_tools = {"colony_get_me", "colony_mark_notifications_read"} + no_args_tools = { + "colony_get_me", + "colony_mark_notifications_read", + "colony_get_webhooks", + } toolkit = _make_toolkit() for tool in toolkit.get_tools(): if tool.name in no_args_tools: