diff --git a/src/litai/llm.py b/src/litai/llm.py index a70e940..3d451ea 100644 --- a/src/litai/llm.py +++ b/src/litai/llm.py @@ -31,6 +31,7 @@ if TYPE_CHECKING: from langchain_core.tools import StructuredTool + from pydantic import BaseModel CLOUDY_MODELS = { "openai/gpt-4o", @@ -358,7 +359,7 @@ def chat( # noqa: D417 @staticmethod def call_tool( response: Union[List[dict], dict, str], tools: Optional[Sequence[Union[LitTool, "StructuredTool"]]] = None - ) -> Optional[str]: + ) -> Optional[Union[str, "BaseModel", list["BaseModel"]]]: """Calls a tool with the given response.""" if tools is None: raise ValueError("No tools provided") diff --git a/src/litai/tools.py b/src/litai/tools.py index 9598f19..86c23a8 100644 --- a/src/litai/tools.py +++ b/src/litai/tools.py @@ -117,6 +117,25 @@ def _extract_parameters(self) -> Dict[str, Any]: return LangchainTool() + @classmethod + def from_model(cls, model: type[BaseModel]) -> "LitTool": + """Create a LitTool that exposes a Pydantic model as a structured schema.""" + + class ModelTool(LitTool): + def setup(self) -> None: + super().setup() + self.name = model.__name__ + self.description = model.__doc__ or "" + + def run(self, *args, **kwargs) -> Any: # type: ignore + # Default implementation: validate & return an instance + return model(*args, **kwargs) + + def _extract_parameters(self) -> Dict[str, Any]: + return model.model_json_schema() + + return ModelTool() + @classmethod def convert_tools(cls, tools: Optional[Sequence[Union["LitTool", "StructuredTool"]]]) -> List["LitTool"]: """Convert a list of tools into LitTool instances. diff --git a/tests/test_tools.py b/tests/test_tools.py index c642071..0f425f0 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -15,10 +15,21 @@ import pytest from langchain_core.tools import tool as langchain_tool +from pydantic import BaseModel from litai import LitTool, tool +@pytest.fixture +def weather_tool_model(): + class WeatherRequest(BaseModel): + """Get weather for location.""" + + location: str + + return WeatherRequest + + @pytest.fixture def basic_tool_class(): class TestTool(LitTool): @@ -226,3 +237,45 @@ def get_weather(city: str) -> str: with pytest.raises(TypeError, match="Unsupported tool type: "): LitTool.convert_tools([get_weather]) + + +def test_tool_from_model_with_no_description(weather_tool_model): + weather_tool_model.__doc__ = None + + lit_tool = LitTool.from_model(weather_tool_model) + + assert isinstance(lit_tool, LitTool) + assert lit_tool.name == "WeatherRequest" + assert lit_tool.description == "" + + assert lit_tool.as_tool() == { + "type": "function", + "function": { + "name": "WeatherRequest", + "description": "", + "parameters": weather_tool_model.model_json_schema(), + }, + } + + +def test_tool_run_from_model(weather_tool_model): + lit_tool = LitTool.from_model(weather_tool_model) + + assert lit_tool.run(location="NYC") == weather_tool_model(location="NYC") + + +def test_tool_from_model_with_description(weather_tool_model): + lit_tool = LitTool.from_model(weather_tool_model) + + assert isinstance(lit_tool, LitTool) + assert lit_tool.name == "WeatherRequest" + assert lit_tool.description == "Get weather for location." + + assert lit_tool.as_tool() == { + "type": "function", + "function": { + "name": "WeatherRequest", + "description": "Get weather for location.", + "parameters": weather_tool_model.model_json_schema(), + }, + }