From e0146a23ad7bb8d793bfce5b1b1b9dc2a51807ee Mon Sep 17 00:00:00 2001 From: "filip.chytil" Date: Wed, 18 Mar 2026 18:14:53 +0100 Subject: [PATCH] added new vector selector middleware & new message & small ref --- docs/api-reference.md | 1 + docs/concepts/memory.md | 3 +- docs/concepts/middleware.md | 120 +++++++++ docs/examples.md | 60 +++++ examples/agents/middleware/README.md | 2 + .../middleware/llm_tool_selector_example.py | 8 +- .../vector_tool_selector_example.py | 253 ++++++++++++++++++ examples/tool-usage/main.py | 9 +- .../src/tiny_anthropic/utils.py | 3 +- packages/tiny_gemini/src/tiny_gemini/utils.py | 3 +- .../src/tiny_mistralai/utils.py | 3 +- packages/tiny_openai/src/tiny_openai/utils.py | 3 +- pyproject.toml | 1 + tinygent/agents/map_agent.py | 15 +- tinygent/agents/middleware/__init__.py | 4 + .../agents/middleware/base_tool_selector.py | 117 ++++++++ .../agents/middleware/llm_tool_selector.py | 163 +++++------ tinygent/agents/middleware/register.py | 11 + .../agents/middleware/vector_tool_selector.py | 177 ++++++++++++ tinygent/agents/multi_step_agent.py | 5 +- tinygent/agents/squad_agent.py | 3 +- tinygent/core/datamodels/messages.py | 20 +- tinygent/core/datamodels/tool.py | 3 + uv.lock | 2 + 24 files changed, 875 insertions(+), 114 deletions(-) create mode 100644 examples/agents/middleware/vector_tool_selector_example.py create mode 100644 tinygent/agents/middleware/base_tool_selector.py create mode 100644 tinygent/agents/middleware/vector_tool_selector.py diff --git a/docs/api-reference.md b/docs/api-reference.md index 1765664..f7602d5 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -401,6 +401,7 @@ class MyMiddleware(TinyBaseMiddleware): ```python from tinygent.core.datamodels.messages import ( + TinyUserMessage, TinyHumanMessage, TinyChatMessage, TinySystemMessage, diff --git a/docs/concepts/memory.md b/docs/concepts/memory.md index 86ffcf6..e399d84 100644 --- a/docs/concepts/memory.md +++ b/docs/concepts/memory.md @@ 
-286,9 +286,10 @@ Tinygent supports multiple message types: ```python from tinygent.core.datamodels.messages import ( - TinyHumanMessage, # User messages + TinyHumanMessage, # Human messages TinyChatMessage, # AI responses TinySystemMessage, # System prompts + TinyUserMessage, # User prompts TinyPlanMessage, # Planning messages TinyToolMessage, # Tool results ) diff --git a/docs/concepts/middleware.md b/docs/concepts/middleware.md index 8118429..79fb51e 100644 --- a/docs/concepts/middleware.md +++ b/docs/concepts/middleware.md @@ -801,6 +801,126 @@ agent = TinyMultiStepAgent( --- +### TinyVectorToolSelectorMiddleware + +Selects the most relevant tools for each LLM call using semantic similarity between the user query and tool descriptions. No secondary LLM call required — selection is done purely via vector embeddings and cosine similarity. + +**Features:** +- Uses an embedder to compute cosine similarity between the query and each tool description +- Ranks tools by similarity and selects the top candidates +- Supports always-include list for critical tools +- Configurable maximum tools limit and minimum similarity threshold +- Customizable query and tool transform functions for fine-grained embedding control + +**How It Works:** +1. Before each LLM call, the last `TinyHumanMessage` is embedded as the query +2. Each tool's name and description is embedded +3. Cosine similarity is computed between the query and every tool embedding +4. 
Tools are ranked by similarity; only those above `similarity_threshold` (up to `max_tools`) are passed to the main agent + +**Basic Usage:** + +```python +from tinygent.agents.middleware import TinyVectorToolSelectorMiddleware +from tinygent.agents import TinyMultiStepAgent +from tinygent.core.factory import build_embedder, build_llm + +selector = TinyVectorToolSelectorMiddleware( + embedder=build_embedder('openai:text-embedding-3-small'), + similarity_threshold=0.5, + max_tools=5, +) + +agent = TinyMultiStepAgent( + llm=build_llm('openai:gpt-4o'), + tools=[search, calculator, weather, database, email, calendar, notes], + middleware=[selector], +) +``` + +**Always Include Critical Tools:** + +```python +selector = TinyVectorToolSelectorMiddleware( + embedder=build_embedder('openai:text-embedding-3-small'), + similarity_threshold=0.4, + max_tools=5, + always_include=[search], +) +``` + +**Custom Transform Functions:** + +```python +from tinygent.core.datamodels.tool import AbstractTool +from tinygent.core.types.io.llm_io_input import TinyLLMInput + +def query_transform(llm_input: TinyLLMInput) -> str: + # Embed the last 3 messages combined for richer context + recent = llm_input.messages[-3:] + return ' '.join(m.content for m in recent if hasattr(m, 'content')) + +def tool_transform(tool: AbstractTool) -> str: + # Repeat name to increase its weight in the embedding + return f'{tool.info.name} {tool.info.name}: {tool.info.description}' + +selector = TinyVectorToolSelectorMiddleware( + embedder=build_embedder('openai:text-embedding-3-small'), + similarity_threshold=0.45, + max_tools=4, + query_transform_fn=query_transform, + tool_transform_fn=tool_transform, +) +``` + +**Using Config Factory:** + +Transform functions and the similarity threshold can also be set directly on the config: + +```python +from tinygent.agents.middleware import TinyVectorToolSelectorMiddlewareConfig + +config = TinyVectorToolSelectorMiddlewareConfig( + 
embedder='openai:text-embedding-3-small', + similarity_threshold=0.5, + max_tools=5, + always_include=['search'], + query_transform_fn=query_transform, + tool_transform_fn=tool_transform, +) + +selector = config.build() +``` + +**Factory Configuration Options:** + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `type` | `Literal['vector_tool_classifier']` | `'vector_tool_classifier'` | Type identifier (frozen) | +| `embedder` | `AbstractEmbedderConfig \| AbstractEmbedder` | Required | Embedder used to compute similarity. Can be a string like `'openai:text-embedding-3-small'` or an embedder instance | +| `similarity_threshold` | `float \| None` | `None` | Minimum cosine similarity score for a tool to be selected. `None` = no threshold | +| `max_tools` | `int \| None` | `None` | Maximum number of tools to select. `None` = no limit | +| `always_include` | `list[str \| AbstractTool \| AbstractToolConfig] \| None` | `None` | List of tool names or tool instances to always include regardless of similarity score | +| `query_transform_fn` | `Callable[[TinyLLMInput], str] \| None` | `None` | Custom function to extract the query string from the LLM input. Defaults to last `TinyHumanMessage` found | +| `tool_transform_fn` | `Callable[[AbstractTool], str] \| None` | `None` | Custom function to produce the text embedded for each tool. Defaults to `"name - description"` | + +**LLM vs.
Vector Tool Selector:** + +| | `TinyLLMToolSelectorMiddleware` | `TinyVectorToolSelectorMiddleware` | +|---|---|---| +| Selection method | Secondary LLM call | Cosine similarity | +| Extra API cost | Yes (LLM tokens) | Yes (embeddings, cheaper) | +| Latency | Higher | Lower | +| Accuracy | Higher (understands context) | Good (semantic similarity) | +| Custom logic | Via prompt template | Via transform functions | + +**When to Use:** +- You have 10+ tools and want lower latency/cost than the LLM selector +- Tool descriptions are semantically distinct +- You want deterministic, reproducible selection behavior + +--- + ## Next Steps - **[Agents](agents.md)**: Use middleware with agents diff --git a/docs/examples.md b/docs/examples.md index 40072d3..38d6ce3 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -323,6 +323,66 @@ Three custom middleware examples: uv run examples/agents/middleware/main.py ``` +--- + +#### 6. LLM Tool Selector Middleware + +**Location**: `examples/agents/middleware/llm_tool_selector_example.py` + +Demonstrates intelligent tool selection using a secondary LLM before each agent call. + +**Run:** + +```bash +uv run examples/agents/middleware/llm_tool_selector_example.py +``` + +--- + +#### 7. Vector Tool Selector Middleware + +**Location**: `examples/agents/middleware/vector_tool_selector_example.py` + +Demonstrates tool selection using semantic similarity (embeddings + cosine similarity) — no secondary LLM call needed. 
+ +**Run:** + +```bash +uv run examples/agents/middleware/vector_tool_selector_example.py +``` + +**Highlights:** + +```python +from tinygent.agents.middleware import TinyVectorToolSelectorMiddlewareConfig +from tinygent.core.factory import build_embedder + +# Basic: embed query, rank tools by cosine similarity +selector = TinyVectorToolSelectorMiddlewareConfig( + embedder=build_embedder('openai:text-embedding-3-small'), + max_tools=4, +) + +# Always include a critical tool regardless of similarity +selector = TinyVectorToolSelectorMiddlewareConfig( + embedder=build_embedder('openai:text-embedding-3-small'), + max_tools=4, + always_include=[greet], +) + +# Custom transform functions for fine-grained embedding control +from tinygent.agents.middleware.vector_tool_selector import TinyVectorToolSelectorMiddleware + +selector = TinyVectorToolSelectorMiddleware( + embedder=build_embedder('openai:text-embedding-3-small'), + max_tools=4, + query_transform_fn=lambda llm_input: ' '.join( + m.content for m in llm_input.messages[-3:] if hasattr(m, 'content') + ), + tool_transform_fn=lambda tool: f'{tool.info.name}: {tool.info.description}', +) +``` + **Highlights:** ```python diff --git a/examples/agents/middleware/README.md b/examples/agents/middleware/README.md index f203c06..0ddb585 100644 --- a/examples/agents/middleware/README.md +++ b/examples/agents/middleware/README.md @@ -8,6 +8,8 @@ This example demonstrates how to use **middleware** in TinyGent agents. 
Middlewa uv sync --extra openai uv run examples/agents/middleware/main.py +uv run examples/agents/middleware/llm_tool_selector_example.py +uv run examples/agents/middleware/vector_tool_selector_example.py ``` ## Concept diff --git a/examples/agents/middleware/llm_tool_selector_example.py b/examples/agents/middleware/llm_tool_selector_example.py index 66426e1..53b76b5 100644 --- a/examples/agents/middleware/llm_tool_selector_example.py +++ b/examples/agents/middleware/llm_tool_selector_example.py @@ -108,7 +108,9 @@ def example_1_basic_selection() -> None: middleware=[selector], ) - result = agent.run('Greet Alice and then add 5 and 7') + result = agent.run( + 'Greet Alice and then tell her what is weather like in San Francisco' + ) print(f'Result: {result}\n') @@ -151,7 +153,7 @@ def example_3_always_include() -> None: selector = build_middleware( 'llm_tool_selector', llm=build_llm('openai:gpt-4o-mini'), - always_include=['greet'], + always_include=[greet], ) agent = build_agent( @@ -182,7 +184,7 @@ def example_4_combined_constraints() -> None: selector = TinyLLMToolSelectorMiddlewareConfig( llm=build_llm('openai:gpt-4o-mini'), max_tools=4, - always_include=['greet'], + always_include=[greet], ) agent = build_agent( diff --git a/examples/agents/middleware/vector_tool_selector_example.py b/examples/agents/middleware/vector_tool_selector_example.py new file mode 100644 index 0000000..60e0be5 --- /dev/null +++ b/examples/agents/middleware/vector_tool_selector_example.py @@ -0,0 +1,253 @@ +from typing import TypeGuard + +from pydantic import Field + +from tinygent.agents.middleware.vector_tool_selector import ( + TinyVectorToolSelectorMiddlewareConfig, +) +from tinygent.core.datamodels.tool import AbstractTool +from tinygent.core.datamodels.tool import AbstractToolConfig +from tinygent.core.factory import build_agent +from tinygent.core.factory import build_embedder +from tinygent.core.factory import build_middleware +from tinygent.core.types import TinyModel +from 
tinygent.logging import setup_logger +from tinygent.tools import reasoning_tool + +logging = setup_logger('info') + + +class GreetInput(TinyModel): + name: str = Field(..., description='The name of the person to greet.') + + +class CalculateInput(TinyModel): + a: int = Field(..., description='First number') + b: int = Field(..., description='Second number') + + +class WeatherInput(TinyModel): + location: str = Field(..., description='Location to get weather for') + + +class NewsInput(TinyModel): + topic: str = Field(..., description='News topic to search for') + + +class TranslateInput(TinyModel): + text: str = Field(..., description='Text to translate') + language: str = Field(..., description='Target language') + + +class SummarizeInput(TinyModel): + text: str = Field(..., description='Text to summarize') + + +@reasoning_tool +def greet(data: GreetInput) -> str: + """Return a simple greeting.""" + return f'Hello, {data.name}!' + + +@reasoning_tool +def add_numbers(data: CalculateInput) -> str: + """Add two numbers together.""" + result = data.a + data.b + return f'The sum of {data.a} and {data.b} is {result}' + + +@reasoning_tool +def multiply_numbers(data: CalculateInput) -> str: + """Multiply two numbers together.""" + result = data.a * data.b + return f'The product of {data.a} and {data.b} is {result}' + + +@reasoning_tool +def divide_numbers(data: CalculateInput) -> str: + """Divide two numbers.""" + if data.b == 0: + return 'Error: Division by zero' + result = data.a / data.b + return f'The quotient of {data.a} and {data.b} is {result}' + + +@reasoning_tool +def subtract_numbers(data: CalculateInput) -> str: + """Subtract two numbers.""" + result = data.a - data.b + return f'The difference of {data.a} and {data.b} is {result}' + + +@reasoning_tool +def get_weather(data: WeatherInput) -> str: + """Get weather information for a location (mock implementation).""" + return f'The weather in {data.location} is sunny with a temperature of 72F' + + +@reasoning_tool 
+def get_news(data: NewsInput) -> str: + """Get news about a topic (mock implementation).""" + return f'Latest news about {data.topic}: Sample news article content here' + + +@reasoning_tool +def translate_text(data: TranslateInput) -> str: + """Translate text to a target language (mock implementation).""" + return f'[{data.language}] {data.text}' + + +@reasoning_tool +def summarize_text(data: SummarizeInput) -> str: + """Summarize a long piece of text (mock implementation).""" + return f'Summary: {data.text[:50]}...' + + +ALL_TOOLS: list[dict | AbstractTool | AbstractToolConfig | str] = [ + greet, + add_numbers, + multiply_numbers, + divide_numbers, + subtract_numbers, + get_weather, + get_news, + translate_text, + summarize_text, +] + + +def example_1_basic_selection() -> None: + """Example 1: Use embeddings to select relevant tools from a large set.""" + print('\nEXAMPLE 1: Basic Vector Tool Selection') + print('9 tools available, embedder selects only semantically relevant ones\n') + + selector = TinyVectorToolSelectorMiddlewareConfig( + embedder=build_embedder('openai:text-embedding-3-small'), + similarity_threshold=0.1, + ) + + agent = build_agent( + 'multistep', + llm='openai:gpt-4o-mini', + tools=ALL_TOOLS, + middleware=[selector], + ) + + result = agent.run( + 'Greet Alice and then tell her what is weather like in San Francisco' + ) + print(f'Result: {result}\n') + + +def example_2_max_tools_limit() -> None: + """Example 2: Limit maximum number of tools selected by cosine similarity.""" + print('\nEXAMPLE 2: Limit Maximum Tools') + print('9 tools available, max_tools=3\n') + + selector = build_middleware( + 'vector_tool_classifier', + embedder=build_embedder('openai:text-embedding-3-small'), + max_tools=3, + ) + + agent = build_agent( + 'multistep', + llm='openai:gpt-4o-mini', + tools=ALL_TOOLS, + middleware=[selector], + ) + + result = agent.run('What is 15 divided by 3? 
Then multiply the result by 4.') + print(f'Result: {result}\n') + + +def example_3_always_include() -> None: + """Example 3: Always include specific tools regardless of similarity score.""" + print('\nEXAMPLE 3: Always Include Specific Tools') + print('Always include "greet" tool + vector selection for the rest\n') + + selector = build_middleware( + 'vector_tool_classifier', + embedder=build_embedder('openai:text-embedding-3-small'), + similarity_threshold=0.1, + always_include=[greet], + max_tools=3, + ) + + agent = build_agent( + 'multistep', + llm='openai:gpt-4o-mini', + tools=ALL_TOOLS, + middleware=[selector], + ) + + result = agent.run( + 'Summarize the following: "The quick brown fox jumps over the lazy dog."' + ) + print(f'Result: {result}\n') + + +def example_4_custom_transform_fns() -> None: + """Example 4: Custom query and tool transform functions.""" + print('\nEXAMPLE 4: Custom Transform Functions') + print('Custom functions control what gets embedded for query and tools\n') + + from tinygent.agents.middleware.vector_tool_selector import ( + TinyVectorToolSelectorMiddleware, + ) + from tinygent.core.datamodels.tool import AbstractTool + from tinygent.core.types.io.llm_io_input import TinyLLMInput + + class HasContent: + content: str + + def _has_content(m: object) -> TypeGuard['HasContent']: + return hasattr(m, 'content') + + def query_transform(llm_input: TinyLLMInput) -> str: + # Embed only the last 3 messages combined instead of just the last human message + + recent = ( + llm_input.messages[-3:] + if len(llm_input.messages) >= 3 + else llm_input.messages + ) + return ' '.join(m.content for m in recent if _has_content(m)) + + def tool_transform(tool: AbstractTool) -> str: + # Include tool name twice to give it more weight in the embedding + return f'{tool.info.name} {tool.info.name}: {tool.info.description}' + + selector = TinyVectorToolSelectorMiddleware( + embedder=build_embedder('openai:text-embedding-3-small'), + similarity_threshold=0.1, + 
max_tools=4, + query_transform_fn=query_transform, + tool_transform_fn=tool_transform, + ) + + agent = build_agent( + 'multistep', + llm='openai:gpt-4o-mini', + tools=ALL_TOOLS, + middleware=[selector], + ) + + result = agent.run('Translate "Hello world" to Spanish') + print(f'Result: {result}\n') + + +def main() -> None: + print('\n=== VECTOR TOOL SELECTOR MIDDLEWARE EXAMPLES ===') + print('Selects relevant tools using semantic similarity (embeddings)\n') + + example_1_basic_selection() + example_2_max_tools_limit() + example_3_always_include() + example_4_custom_transform_fns() + + print('=== ALL EXAMPLES COMPLETED ===\n') + + +if __name__ == '__main__': + main() diff --git a/examples/tool-usage/main.py b/examples/tool-usage/main.py index 6cf86ef..ad283df 100644 --- a/examples/tool-usage/main.py +++ b/examples/tool-usage/main.py @@ -159,6 +159,13 @@ def search(data: SearchInput) -> str: header_print('Reasoning Tool Execution') global_registry_search = registry.get_tool('search') - global_registry_print(global_registry_search({'query': 'TinyGent query', 'reasoning': 'Model-generated reasoning for this tool invocation'})) + global_registry_print( + global_registry_search( + { + 'query': 'TinyGent query', + 'reasoning': 'Model-generated reasoning for this tool invocation', + } + ) + ) # NOTE: count and async_count are not cachable, so their cache_info will be None diff --git a/packages/tiny_anthropic/src/tiny_anthropic/utils.py b/packages/tiny_anthropic/src/tiny_anthropic/utils.py index 8c810fe..f2e9526 100644 --- a/packages/tiny_anthropic/src/tiny_anthropic/utils.py +++ b/packages/tiny_anthropic/src/tiny_anthropic/utils.py @@ -13,6 +13,7 @@ from tinygent.core.datamodels.messages import TinySystemMessage from tinygent.core.datamodels.messages import TinyToolCall from tinygent.core.datamodels.messages import TinyToolResult +from tinygent.core.datamodels.messages import TinyUserMessage from tinygent.core.types.io.llm_io_chunks import TinyChatMessageChunk from 
tinygent.core.types.io.llm_io_chunks import TinyLLMResultChunk from tinygent.core.types.io.llm_io_result import TinyLLMResult @@ -29,7 +30,7 @@ def tiny_prompt_to_anthropic_params( system_message: str | None = None for msg in prompt.messages: - if isinstance(msg, TinyHumanMessage): + if isinstance(msg, TinyHumanMessage) or isinstance(msg, TinyUserMessage): params.append(MessageParam(role='user', content=msg.content)) elif isinstance(msg, TinySystemMessage): diff --git a/packages/tiny_gemini/src/tiny_gemini/utils.py b/packages/tiny_gemini/src/tiny_gemini/utils.py index ec183d0..a083f8f 100644 --- a/packages/tiny_gemini/src/tiny_gemini/utils.py +++ b/packages/tiny_gemini/src/tiny_gemini/utils.py @@ -29,6 +29,7 @@ from tinygent.core.datamodels.messages import TinySystemMessage from tinygent.core.datamodels.messages import TinyToolCall from tinygent.core.datamodels.messages import TinyToolResult +from tinygent.core.datamodels.messages import TinyUserMessage from tinygent.core.types.io.llm_io_chunks import TinyLLMResultChunk from tinygent.core.types.io.llm_io_result import TinyLLMResult @@ -101,7 +102,7 @@ def tiny_prompt_to_gemini_params( params: list[Content] = [] for msg in prompt.messages: - if isinstance(msg, TinyHumanMessage): + if isinstance(msg, TinyHumanMessage) or isinstance(msg, TinyUserMessage): params.append(UserContent(str(msg.content))) elif isinstance(msg, TinySystemMessage): diff --git a/packages/tiny_mistralai/src/tiny_mistralai/utils.py b/packages/tiny_mistralai/src/tiny_mistralai/utils.py index 905fbd9..d7b25cb 100644 --- a/packages/tiny_mistralai/src/tiny_mistralai/utils.py +++ b/packages/tiny_mistralai/src/tiny_mistralai/utils.py @@ -25,6 +25,7 @@ from tinygent.core.datamodels.messages import TinySystemMessage from tinygent.core.datamodels.messages import TinyToolCall from tinygent.core.datamodels.messages import TinyToolResult +from tinygent.core.datamodels.messages import TinyUserMessage from tinygent.core.types.io.llm_io_chunks import 
TinyLLMResultChunk from tinygent.core.types.io.llm_io_chunks import TinyToolCallChunk from tinygent.core.types.io.llm_io_result import TinyLLMResult @@ -49,7 +50,7 @@ def tiny_prompt_to_mistralai_params( params: list[ChatCompletionMessageParams] = [] for msg in prompt.messages: - if isinstance(msg, TinyHumanMessage): + if isinstance(msg, TinyHumanMessage) or isinstance(msg, TinyUserMessage): params.append(UserMessage(role='user', content=str(msg.content))) elif isinstance(msg, TinySystemMessage): diff --git a/packages/tiny_openai/src/tiny_openai/utils.py b/packages/tiny_openai/src/tiny_openai/utils.py index 6d454af..5be69e6 100644 --- a/packages/tiny_openai/src/tiny_openai/utils.py +++ b/packages/tiny_openai/src/tiny_openai/utils.py @@ -26,6 +26,7 @@ from tinygent.core.datamodels.messages import TinyToolCall from tinygent.core.datamodels.messages import TinyToolCallChunk from tinygent.core.datamodels.messages import TinyToolResult +from tinygent.core.datamodels.messages import TinyUserMessage from tinygent.core.types.io.llm_io_chunks import TinyLLMResultChunk from tinygent.core.types.io.llm_io_result import TinyLLMResult @@ -90,7 +91,7 @@ def tiny_prompt_to_openai_params( params: list[ChatCompletionMessageParam] = [] for msg in prompt.messages: - if isinstance(msg, TinyHumanMessage): + if isinstance(msg, TinyHumanMessage) or isinstance(msg, TinyUserMessage): params.append( ChatCompletionUserMessageParam(role='user', content=msg.content) ) diff --git a/pyproject.toml b/pyproject.toml index 888482d..21ee73b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "async-lru>=2.0.5", "jinja2>=3.1.6", "langchain-core>=0.3.74", + "numpy>=2.3.4", "opentelemetry-exporter-otlp>=1.38.0", "opentelemetry-sdk>=1.38.0", "pydantic>=2.11.7", diff --git a/tinygent/agents/map_agent.py b/tinygent/agents/map_agent.py index 8ba71d6..02e8a55 100644 --- a/tinygent/agents/map_agent.py +++ b/tinygent/agents/map_agent.py @@ -17,6 +17,7 @@ from 
tinygent.core.datamodels.messages import TinyChatMessage from tinygent.core.datamodels.messages import TinyHumanMessage from tinygent.core.datamodels.messages import TinySystemMessage +from tinygent.core.datamodels.messages import TinyUserMessage from tinygent.core.datamodels.middleware import AbstractMiddleware from tinygent.core.datamodels.tool import AbstractTool from tinygent.core.runtime.executors import run_async_in_executor @@ -203,7 +204,7 @@ class subgoal(TinyModel): messages = TinyLLMInput(messages=[*self.memory.copy_chat_messages()]) messages.add_at_end( - TinyHumanMessage( + TinyUserMessage( content=render_template( self.prompt_template.task_decomposer.user, {'question': input_txt, 'max_subquestions': self.max_plan_length}, @@ -258,7 +259,7 @@ async def _actor( formatted_proposals = [p.sum for p in prev_proposals] messages = TinyLLMInput(messages=[*self.memory.copy_chat_messages()]) messages.add_at_end( - TinyHumanMessage( + TinyUserMessage( content=render_template( prompt_templ.user, { @@ -282,7 +283,7 @@ async def _actor( ) ) messages.add_at_end( - TinyHumanMessage( + TinyUserMessage( content=render_template( fix_prompt_templ.user, {'question': subgoal, 'validation': '\n'.join(feedback)}, @@ -340,7 +341,7 @@ class _MonitorResult(TinyModel): formatted_proposals = [p.sum for p in prev_proposals] messages = TinyLLMInput() messages.add_at_end( - TinyHumanMessage( + TinyUserMessage( content=render_template( prompt_templ.user, { @@ -483,7 +484,7 @@ async def _predictor( messages = TinyLLMInput() messages.add_at_end( - TinyHumanMessage( + TinyUserMessage( content=render_template( self.prompt_template.predictor.user, {'state': state.next_state, 'proposed_action': action.sum}, @@ -591,7 +592,7 @@ async def _evaluator( messages = TinyLLMInput() messages.add_at_end( - TinyHumanMessage( + TinyUserMessage( content=render_template( self.prompt_template.action_proposal.actor.evaluator.user, {'state': state.next_state, 'subgoal': subgoal}, @@ -634,7 +635,7 @@ async 
def _orchestrator( messages = TinyLLMInput() messages.add_at_end( - TinyHumanMessage( + TinyUserMessage( content=render_template( self.prompt_template.orchestrator.user, {'question': subgoal, 'answer': state.next_state}, diff --git a/tinygent/agents/middleware/__init__.py b/tinygent/agents/middleware/__init__.py index 1f0f523..5f3608d 100644 --- a/tinygent/agents/middleware/__init__.py +++ b/tinygent/agents/middleware/__init__.py @@ -5,6 +5,8 @@ from .tool_limiter import TinyToolCallLimiterMiddleware from .tool_limiter import TinyToolCallLimiterMiddlewareConfig from .tool_limiter import ToolCallBlockedException +from .vector_tool_selector import TinyVectorToolSelectorMiddleware +from .vector_tool_selector import TinyVectorToolSelectorMiddlewareConfig __all__ = [ 'TinyBaseMiddleware', @@ -14,4 +16,6 @@ 'TinyLLMToolSelectorMiddlewareConfig', 'TinyToolCallLimiterMiddleware', 'TinyToolCallLimiterMiddlewareConfig', + 'TinyVectorToolSelectorMiddleware', + 'TinyVectorToolSelectorMiddlewareConfig', ] diff --git a/tinygent/agents/middleware/base_tool_selector.py b/tinygent/agents/middleware/base_tool_selector.py new file mode 100644 index 0000000..7bcb5ed --- /dev/null +++ b/tinygent/agents/middleware/base_tool_selector.py @@ -0,0 +1,117 @@ +from dataclasses import dataclass +import logging +from typing import Any +from typing import Generic +from typing import TypeVar +from typing import cast + +from pydantic import Field + +from tinygent.agents.middleware.base import TinyBaseMiddleware +from tinygent.agents.middleware.base import TinyBaseMiddlewareConfig +from tinygent.core.datamodels.tool import AbstractTool +from tinygent.core.datamodels.tool import AbstractToolConfig +from tinygent.core.factory import build_tool + +logger = logging.getLogger(__name__) + +T = TypeVar('T', bound='TinyBaseToolSelectorMiddleware') + + +@dataclass +class ToolSelectorCandidates: + selected_tools: set[AbstractTool] + mapped_tools: dict[str, AbstractTool] + tools: list[AbstractTool] + 
remaining_tools: list[AbstractTool] + remaining_space: int | None + + +class TinyBaseToolSelectorMiddlewareConfig( + TinyBaseMiddlewareConfig[T], + Generic[T], +): + """Configuration for BaseToolSelector""" + + max_tools: int | None = Field(default=None) + + always_include: list[str | AbstractTool | AbstractToolConfig] | None = Field( + default=None + ) + + def build_base_kwargs(self) -> dict: + always_include: list[AbstractTool] | None = None + + if self.always_include: + always_include = [ + t if isinstance(t, AbstractTool) else build_tool(t) + for t in self.always_include + ] + + return { + 'max_tools': self.max_tools, + 'always_include': always_include, + } + + +class TinyBaseToolSelectorMiddleware(TinyBaseMiddleware): + def __init__( + self, + max_tools: int | None = None, + always_include: list[AbstractTool] | None = None, + ) -> None: + self.max_tools = max_tools + self.always_include = always_include + + if always_include and max_tools and len(always_include) > max_tools: + logger.warning( + 'always_include contains %d items which exceeds max_tools=%d; ' + 'increasing max_tools to %d to fit always_include.', + len(always_include), + max_tools, + len(always_include), + ) + self.max_tools = len(always_include) + + def _prepare_candidates( + self, kwargs: dict[str, Any] + ) -> ToolSelectorCandidates | None: + """Build the initial candidate set from always_include and check max_tools. + + Returns None in two cases: + - No tools are present in kwargs (nothing to select from). + - max_tools is set and already exhausted by always_include tools alone + (kwargs['tools'] is updated in place before returning None). 
+ + Otherwise returns a ToolSelectorCandidates with: + - selected_tools: tools pinned via always_include + - mapped_tools: name → tool mapping for all available tools + - tools: full list of available tools + - remaining_space: slots left after always_include (None = unlimited) + """ + if not kwargs.get('tools'): + return None + + selected_tools: set[AbstractTool] = set() + tools = cast(list[AbstractTool], kwargs.get('tools', [])) + mapped_tools: dict[str, AbstractTool] = {t.info.name: t for t in tools} + + if self.always_include: + for tool in self.always_include: + if t := mapped_tools.get(tool.info.name): + selected_tools.add(t) + + remaining_space: int | None = None + if self.max_tools: + remaining_space = self.max_tools - len(selected_tools) + if remaining_space <= 0: + kwargs['tools'] = selected_tools + return None + + return ToolSelectorCandidates( + selected_tools=selected_tools, + mapped_tools=mapped_tools, + tools=tools, + remaining_tools=[t for t in tools if t not in selected_tools], + remaining_space=remaining_space, + ) diff --git a/tinygent/agents/middleware/llm_tool_selector.py b/tinygent/agents/middleware/llm_tool_selector.py index f3f35b6..e900cd1 100644 --- a/tinygent/agents/middleware/llm_tool_selector.py +++ b/tinygent/agents/middleware/llm_tool_selector.py @@ -3,12 +3,13 @@ import logging from typing import Any from typing import Literal -from typing import cast from pydantic import Field -from tinygent.agents.middleware.base import TinyBaseMiddleware -from tinygent.agents.middleware.base import TinyBaseMiddlewareConfig +from tinygent.agents.middleware.base_tool_selector import TinyBaseToolSelectorMiddleware +from tinygent.agents.middleware.base_tool_selector import ( + TinyBaseToolSelectorMiddlewareConfig, +) from tinygent.core.datamodels.llm import AbstractLLM from tinygent.core.datamodels.llm import AbstractLLMConfig from tinygent.core.datamodels.messages import TinyHumanMessage @@ -25,11 +26,9 @@ logger = logging.getLogger(__name__) 
-_DEFAULT_PROMPT = get_llm_tool_selector_prompt_template() - class TinyLLMToolSelectorMiddlewareConfig( - TinyBaseMiddlewareConfig['TinyLLMToolSelectorMiddleware'] + TinyBaseToolSelectorMiddlewareConfig['TinyLLMToolSelectorMiddleware'] ): """Configuration for LLMToolSelector Middleware.""" @@ -37,22 +36,19 @@ class TinyLLMToolSelectorMiddlewareConfig( llm: AbstractLLMConfig | AbstractLLM = Field(...) - prompt_template: LLMToolSelectorPromptTemplate = Field(default=_DEFAULT_PROMPT) - - max_tools: int | None = Field(default=None) - - always_include: list[str] | None = Field(default=None) + prompt_template: LLMToolSelectorPromptTemplate = Field( + default_factory=get_llm_tool_selector_prompt_template + ) def build(self) -> TinyLLMToolSelectorMiddleware: return TinyLLMToolSelectorMiddleware( llm=self.llm if isinstance(self.llm, AbstractLLM) else build_llm(self.llm), prompt_template=self.prompt_template, - max_tools=self.max_tools, - always_include=self.always_include, + **self.build_base_kwargs(), ) -class TinyLLMToolSelectorMiddleware(TinyBaseMiddleware): +class TinyLLMToolSelectorMiddleware(TinyBaseToolSelectorMiddleware): """Middleware that intelligently selects relevant tools using an LLM. 
Before each LLM call, this middleware uses a dedicated selection LLM to analyze @@ -67,31 +63,21 @@ class TinyLLMToolSelectorMiddleware(TinyBaseMiddleware): llm: LLM to use for tool selection (typically a fast, cost-effective model) prompt_template: Template for tool selection prompt (default provided) max_tools: Maximum number of tools to select (None = no limit) - always_include: List of tool names to always include in selection (None = no always-include list) + always_include: List of tools to always include in selection (None = no always-include list) """ def __init__( self, *, llm: AbstractLLM, - prompt_template: LLMToolSelectorPromptTemplate = _DEFAULT_PROMPT, + prompt_template: LLMToolSelectorPromptTemplate | None = None, max_tools: int | None = None, - always_include: list[str] | None = None, + always_include: list[AbstractTool] | None = None, ) -> None: + super().__init__(max_tools, always_include) + self.llm = llm - self.prompt_template = prompt_template - self.max_tools = max_tools - self.always_include = always_include - - if always_include and max_tools and len(always_include) > max_tools: - logger.warning( - 'always_include contains %d items which exceeds max_tools=%d; ' - 'increasing max_tools to %d to fit always_include.', - len(always_include), - max_tools, - len(always_include), - ) - self.max_tools = len(always_include) + self.prompt_template = prompt_template or get_llm_tool_selector_prompt_template() @staticmethod def _create_selection_model(tools: list[AbstractTool]) -> type[TinyModel]: @@ -106,91 +92,82 @@ class LocalSelectionModel(TinyModel): return LocalSelectionModel - @tiny_trace('tool_selector.before_llm_call') + @tiny_trace('llm_tool_selector.before_llm_call') async def before_llm_call( self, *, run_id: str, llm_input: TinyLLMInput, kwargs: dict[str, Any] ) -> None: - if not kwargs.get('tools'): + candidates = self._prepare_candidates(kwargs) + + if candidates is None: return - selected_tools = set() - tools = cast(list[AbstractTool], 
kwargs.get('tools', [])) - mapped_tools = {t.info.name: t for t in tools} + selected_tools = candidates.selected_tools + tools = candidates.tools + remaining_tools = candidates.remaining_tools + mapped_tools = candidates.mapped_tools set_tiny_attributes( { - 'tool_selector.available_tools': [t.info.name for t in tools], - 'tool_selector.available_tools.total': len(tools), - 'tool_selector.max_tools': str(self.max_tools), - 'tool_selector.always_include': str(self.always_include), + 'llm_tool_selector.available_tools': [t.info.name for t in tools], + 'llm_tool_selector.available_tools.total': len(tools), + 'llm_tool_selector.max_tools': str(self.max_tools), + 'llm_tool_selector.always_include': str(self.always_include), + 'llm_tool_selector.remaining_space': str(candidates.remaining_space), } ) - if self.always_include: - for name in self.always_include: - if t := mapped_tools.get(name): - selected_tools.add(t) - - remaining_space: int | None = None - if self.max_tools: - remaining_space = self.max_tools - len(selected_tools) - - set_tiny_attributes( - { - 'tool_selector.remaining_space': remaining_space, - 'tool_selector.early_exit': remaining_space <= 0, - } + if len(remaining_tools) > 0: + local_llm_input = llm_input.model_copy() + local_llm_input.add_at_end( + TinySystemMessage(content=self.prompt_template.system) ) - - if remaining_space <= 0: - kwargs['tools'] = selected_tools - return - - local_llm_input = llm_input.model_copy() - local_llm_input.add_at_end( - TinySystemMessage(content=self.prompt_template.system) - ) - local_llm_input.add_at_end( - TinyHumanMessage( - content=render_template( - self.prompt_template.user, - { - 'tools': '\n'.join( - [ - f'{t.info.name} - {t.info.description or "Description not provided"}' - for t in tools - ] - ) - }, + local_llm_input.add_at_end( + TinyHumanMessage( + content=render_template( + self.prompt_template.user, + { + 'tools': '\n'.join( + [ + f'{t.info.name} - {t.info.description or "Description not provided"}' + 
for t in remaining_tools + ] + ) + }, + ) ) ) - ) - result = await self.llm.agenerate_structured( - llm_input=local_llm_input, output_schema=self._create_selection_model(tools) - ) + result = await self.llm.agenerate_structured( + llm_input=local_llm_input, + output_schema=self._create_selection_model(remaining_tools), + ) - for selected_tool_name in result.selected_tools: # type: ignore - tool_obj = mapped_tools.get(selected_tool_name) + for selected_tool_name in result.selected_tools: # type: ignore + if self.max_tools and len(selected_tools) >= self.max_tools: + break - if tool_obj in selected_tools: - logger.warning("Tool '%s' already selected", selected_tool_name) - continue + tool_obj = mapped_tools.get(selected_tool_name) - if not tool_obj: - logger.error( - "Selected tool '%s' doesn't exist amongs available tools: %s", - selected_tool_name, - [t.info.name for t in tools], - ) - continue + if tool_obj in selected_tools: + logger.warning("Tool '%s' already selected", selected_tool_name) + continue + + if not tool_obj: + logger.error( + "Selected tool '%s' doesn't exist amongs available tools: %s", + selected_tool_name, + [t.info.name for t in remaining_tools], + ) + continue - selected_tools.add(tool_obj) + selected_tools.add(tool_obj) set_tiny_attributes( { - 'tool_selector.selected_tools': [t.info.name for t in selected_tools], - 'tool_selector.selected_tools.total': len(selected_tools), + 'llm_tool_selector.selected_tools': [ + t.info.name for t in selected_tools + ], + 'llm_tool_selector.selected_tools.total': len(selected_tools), } ) diff --git a/tinygent/agents/middleware/register.py b/tinygent/agents/middleware/register.py index b4c9f04..b6f0901 100644 --- a/tinygent/agents/middleware/register.py +++ b/tinygent/agents/middleware/register.py @@ -4,6 +4,12 @@ ) from tinygent.agents.middleware.tool_limiter import TinyToolCallLimiterMiddleware from tinygent.agents.middleware.tool_limiter import TinyToolCallLimiterMiddlewareConfig +from 
tinygent.agents.middleware.vector_tool_selector import ( + TinyVectorToolSelectorMiddleware, +) +from tinygent.agents.middleware.vector_tool_selector import ( + TinyVectorToolSelectorMiddlewareConfig, +) from tinygent.core.runtime.global_registry import GlobalRegistry @@ -20,6 +26,11 @@ def _register_middleware() -> None: TinyLLMToolSelectorMiddlewareConfig, TinyLLMToolSelectorMiddleware, ) + registry.register_middleware( + 'vector_tool_classifier', + TinyVectorToolSelectorMiddlewareConfig, + TinyVectorToolSelectorMiddleware, + ) _register_middleware() diff --git a/tinygent/agents/middleware/vector_tool_selector.py b/tinygent/agents/middleware/vector_tool_selector.py new file mode 100644 index 0000000..3e03ed8 --- /dev/null +++ b/tinygent/agents/middleware/vector_tool_selector.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import logging +from typing import Any +from typing import Callable +from typing import Literal + +import numpy as np +from pydantic import Field + +from tinygent.agents.middleware.base_tool_selector import TinyBaseToolSelectorMiddleware +from tinygent.agents.middleware.base_tool_selector import ( + TinyBaseToolSelectorMiddlewareConfig, +) +from tinygent.core.datamodels.embedder import AbstractEmbedder +from tinygent.core.datamodels.embedder import AbstractEmbedderConfig +from tinygent.core.datamodels.messages import TinyHumanMessage +from tinygent.core.datamodels.tool import AbstractTool +from tinygent.core.factory.embedder import build_embedder +from tinygent.core.telemetry.decorators import tiny_trace +from tinygent.core.telemetry.otel import set_tiny_attributes +from tinygent.core.types.io.llm_io_input import TinyLLMInput + +logger = logging.getLogger(__name__) + + +class TinyVectorToolSelectorMiddlewareConfig( + TinyBaseToolSelectorMiddlewareConfig['TinyVectorToolSelectorMiddleware'] +): + """Configuration for TinyVectorToolSelector Middleware.""" + + type: Literal['vector_tool_classifier'] = Field( + 
default='vector_tool_classifier', frozen=True + ) + + embedder: AbstractEmbedder | AbstractEmbedderConfig = Field(...) + + similarity_threshold: float | None = Field(default=None) + + query_transform_fn: Callable[[TinyLLMInput], str] | None = Field(default=None) + + tool_transform_fn: Callable[[AbstractTool], str] | None = Field(default=None) + + def build(self) -> TinyVectorToolSelectorMiddleware: + return TinyVectorToolSelectorMiddleware( + embedder=self.embedder + if isinstance(self.embedder, AbstractEmbedder) + else build_embedder(self.embedder), + similarity_threshold=self.similarity_threshold, + query_transform_fn=self.query_transform_fn, + tool_transform_fn=self.tool_transform_fn, + **self.build_base_kwargs(), + ) + + +class TinyVectorToolSelectorMiddleware(TinyBaseToolSelectorMiddleware): + def __init__( + self, + *, + embedder: AbstractEmbedder, + max_tools: int | None = None, + always_include: list[AbstractTool] | None = None, + similarity_threshold: float | None = None, + query_transform_fn: Callable[[TinyLLMInput], str] | None = None, + tool_transform_fn: Callable[[AbstractTool], str] | None = None, + ): + super().__init__(max_tools, always_include) + + if not query_transform_fn: + + def _default_query_transform_fn(llm_input: TinyLLMInput) -> str: + for m in reversed(llm_input.messages): + if isinstance(m, TinyHumanMessage): + return m.content + raise ValueError('No human message for embedding found in history') + + query_transform_fn = _default_query_transform_fn + + if not tool_transform_fn: + + def _default_tool_transform_fn(tool: AbstractTool) -> str: + return f'{tool.info.name} - {tool.info.description}' + + tool_transform_fn = _default_tool_transform_fn + + self.embedder = embedder + self.similarity_threshold = similarity_threshold + self.query_transform_fn = query_transform_fn + self.tool_transform_fn = tool_transform_fn + + def _cosine_similarity(self, a: list[float], b: list[float]) -> float: + a_np = np.array(a, dtype=float) + b_np = 
np.array(b, dtype=float) + + if a_np.shape != b_np.shape: + raise ValueError( + f'Vectors for cosine similarity must have the same shape a[{a_np.shape}], b[{b_np.shape}]' + ) + + norm_a = np.linalg.norm(a_np) + norm_b = np.linalg.norm(b_np) + + if norm_a == 0 or norm_b == 0: + return 0.0 + + return float(np.dot(a_np, b_np) / (norm_a * norm_b)) + + @tiny_trace('vector_tool_selector.before_llm_call') + async def before_llm_call( + self, *, run_id: str, llm_input: TinyLLMInput, kwargs: dict[str, Any] + ) -> None: + candidates = self._prepare_candidates(kwargs) + + if candidates is None: + return + + tools = candidates.tools + remaining_tools = candidates.remaining_tools + selected_tools = candidates.selected_tools + + set_tiny_attributes( + { + 'vector_tool_selector.available_tools': [t.info.name for t in tools], + 'vector_tool_selector.available_tools.total': len(tools), + 'vector_tool_selector.max_tools': str(self.max_tools), + 'vector_tool_selector.always_include': str(self.always_include), + 'vector_tool_selector.remaining_space': str(candidates.remaining_space), + } + ) + + transformed_query = self.query_transform_fn(llm_input) + transformed_tools = [self.tool_transform_fn(t) for t in remaining_tools] + + query_emb, *tool_embs = await self.embedder.aembed_batch( + [transformed_query] + transformed_tools + ) + + similarities = [ + self._cosine_similarity(query_emb, tool_emb) for tool_emb in tool_embs + ] + tool_sim_pairs = [ + (s, t, tt) + for s, t, tt in zip( + similarities, remaining_tools, transformed_tools, strict=True + ) + ] + tool_sim_pairs = sorted( + tool_sim_pairs, key=lambda x: x[0], reverse=True + ) # sort by highest similarity + + for s, t, tt in tool_sim_pairs: + set_tiny_attributes( + { + f'vector_tool_selector.{t.info.name}.transformed_description': tt, + f'vector_tool_selector.{t.info.name}.similarity_score': s, + } + ) + + if self.max_tools and len(selected_tools) >= self.max_tools: + continue + + if self.similarity_threshold is None or ( + 
s > self.similarity_threshold + ): + selected_tools.add(t) + + set_tiny_attributes( + { + 'vector_tool_selector.transformed_query': transformed_query, + 'vector_tool_selector.selected_tools': [ + t.info.name for t in selected_tools + ], + 'vector_tool_selector.selected_tools.total': len(selected_tools), + } + ) + + kwargs['tools'] = selected_tools diff --git a/tinygent/agents/multi_step_agent.py b/tinygent/agents/multi_step_agent.py index 1b5b707..6504c1a 100644 --- a/tinygent/agents/multi_step_agent.py +++ b/tinygent/agents/multi_step_agent.py @@ -21,6 +21,7 @@ from tinygent.core.datamodels.messages import TinyReasoningMessage from tinygent.core.datamodels.messages import TinySystemMessage from tinygent.core.datamodels.messages import TinyToolCall +from tinygent.core.datamodels.messages import TinyUserMessage from tinygent.core.datamodels.middleware import AbstractMiddleware from tinygent.core.datamodels.tool import AbstractTool from tinygent.core.runtime.executors import run_async_in_executor @@ -179,7 +180,7 @@ async def _stream_action( messages = TinyLLMInput( messages=[ *self.memory.copy_chat_messages(), - TinyHumanMessage( + TinyUserMessage( content=render_template( self.acter_prompt.final_answer, { @@ -215,7 +216,7 @@ async def _stream_fallback_answer( ] ) messages.add_at_beginning( - TinyHumanMessage( + TinyUserMessage( content=render_template( self.fallback_prompt.fallback_answer, { diff --git a/tinygent/agents/squad_agent.py b/tinygent/agents/squad_agent.py index a687661..3f08435 100644 --- a/tinygent/agents/squad_agent.py +++ b/tinygent/agents/squad_agent.py @@ -19,6 +19,7 @@ from tinygent.core.datamodels.llm import AbstractLLM from tinygent.core.datamodels.memory import AbstractMemory from tinygent.core.datamodels.messages import AllTinyMessages +from tinygent.core.datamodels.messages import TinyChatMessage from tinygent.core.datamodels.messages import TinyHumanMessage from tinygent.core.datamodels.messages import 
TinySquadMemberMessage from tinygent.core.datamodels.messages import TinySystemMessage @@ -281,7 +282,7 @@ async def _run_agent( result=final_answer, ) ) - self.memory.save_context(TinyHumanMessage(content=final_answer)) + self.memory.save_context(TinyChatMessage(content=final_answer)) except Exception as e: await self.on_error(run_id=run_id, e=e, kwargs={}) raise e diff --git a/tinygent/core/datamodels/messages.py b/tinygent/core/datamodels/messages.py index e04aab4..c8a9178 100644 --- a/tinygent/core/datamodels/messages.py +++ b/tinygent/core/datamodels/messages.py @@ -24,6 +24,7 @@ TinyMessageType = TypeVar( 'TinyMessageType', Literal['system'], + Literal['user'], Literal['squad_member'], Literal['chat'], Literal['tool'], @@ -55,7 +56,7 @@ def tiny_str(self) -> str: class TinySystemMessage(BaseMessage[Literal['system']]): - """Message representing system-level instructions.""" + """Message representing system-level instruction.""" type: Literal['system'] = 'system' """The type of the message.""" @@ -65,7 +66,21 @@ class TinySystemMessage(BaseMessage[Literal['system']]): @property def tiny_str(self) -> str: - return f'SYSTEM: {self.content}' + return f'SYSTEM (instruction): {self.content}' + + +class TinyUserMessage(BaseMessage[Literal['user']]): + """Message representing user-level instruction.""" + + type: Literal['user'] = 'user' + """The type of the message.""" + + content: str + """The content of the user message.""" + + @property + def tiny_str(self) -> str: + return f'USER (instruction): {self.content}' class TinyPlanMessage(BaseMessage[Literal['plan']]): @@ -283,6 +298,7 @@ def tiny_str(self) -> str: | TinySquadMemberMessage | TinyHumanMessage | TinySystemMessage + | TinyUserMessage | TinyToolResult | TinySummaryMessage ), diff --git a/tinygent/core/datamodels/tool.py b/tinygent/core/datamodels/tool.py index 957e050..0c8477a 100644 --- a/tinygent/core/datamodels/tool.py +++ b/tinygent/core/datamodels/tool.py @@ -60,3 +60,6 @@ def clear_cache(self) -> 
None: def cache_info(self) -> Any: """Get information about the tool's cache.""" pass + + def __repr__(self) -> str: + return self.__str__() diff --git a/uv.lock b/uv.lock index 2190ae6..35628d0 100644 --- a/uv.lock +++ b/uv.lock @@ -1646,6 +1646,7 @@ dependencies = [ { name = "async-lru" }, { name = "jinja2" }, { name = "langchain-core" }, + { name = "numpy" }, { name = "opentelemetry-exporter-otlp" }, { name = "opentelemetry-sdk" }, { name = "pydantic" }, @@ -1697,6 +1698,7 @@ requires-dist = [ { name = "mkdocs-material", marker = "extra == 'docs'", specifier = ">=9.7.1" }, { name = "mkdocstrings", extras = ["python"], marker = "extra == 'docs'", specifier = ">=1.0.2" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.17.1" }, + { name = "numpy", specifier = ">=2.3.4" }, { name = "opentelemetry-exporter-otlp", specifier = ">=1.38.0" }, { name = "opentelemetry-sdk", specifier = ">=1.38.0" }, { name = "pydantic", specifier = ">=2.11.7" },