From b556d303e54963f8c18234800f3a606a9f91a950 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 24 Oct 2025 09:17:07 -0700
Subject: [PATCH] [Refactor] Refactor tool transforms

---
 docs/source/reference/llms.rst          |  78 ++++-
 torchrl/envs/llm/transforms/__init__.py |   3 +-
 torchrl/envs/llm/transforms/tools.py    | 413 ++++++++++++++++--------
 3 files changed, 358 insertions(+), 136 deletions(-)

diff --git a/docs/source/reference/llms.rst b/docs/source/reference/llms.rst
index 327303fe338..717bd4d34f3 100644
--- a/docs/source/reference/llms.rst
+++ b/docs/source/reference/llms.rst
@@ -869,7 +869,8 @@ Transforms are the main way to extend ChatEnv with specific capabilities:
 - **Reward computation**: :class:`~torchrl.envs.llm.transforms.KLRewardTransform` for KL divergence rewards
 - **Tool execution**: :class:`~torchrl.envs.llm.transforms.PythonInterpreter` for Python code
-  execution, :class:`~torchrl.envs.llm.transforms.MCPToolTransform` for general tool calling.
+  execution, :class:`~torchrl.envs.llm.transforms.ExecuteToolsInOrder` for order-preserving tool execution with
+  pluggable services, :class:`~torchrl.envs.llm.transforms.MCPToolTransform` for general tool calling (legacy).
 - **Data loading**: :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer` for loading prompts from datasets
 - **Thinking prompts**: :class:`~torchrl.envs.llm.transforms.AddThinkingPrompt` for chain-of-thought reasoning
 - **Policy tracking**: :class:`~torchrl.envs.llm.transforms.PolicyVersion` for version control
@@ -929,6 +930,75 @@ Transforms are used to modify the data before it is passed to the LLM.

Tools are usually implemented as transforms, and appended to a base environment such as
:class:`~torchrl.envs.llm.ChatEnv`.

Tool Service Library
^^^^^^^^^^^^^^^^^^^^

TorchRL provides a flexible tool service library for adding tool-calling capabilities to LLM agents.
The :class:`~torchrl.envs.llm.transforms.ExecuteToolsInOrder` transform enables LLMs to call external tools
(e.g., web search, calculators, databases) in a structured, order-preserving manner.

**Key Features:**

- **Order-Preserving Execution**: Tools execute in the order they appear in the LLM output, respecting the LLM's reasoning
- **Pluggable Parsers**: Support for different tool-calling formats (XML, JSON, custom) via :class:`~torchrl.envs.llm.transforms.XMLBlockParser` and :class:`~torchrl.envs.llm.transforms.JSONCallParser`
- **Protocol-Based Services**: Clean interfaces using the :class:`~torchrl.envs.llm.transforms.ToolService` protocol for easy extension
- **Service Registry**: :class:`~torchrl.envs.llm.transforms.ToolRegistry` manages the available tool services
- **State Passing**: Optional, filtered access to the TensorDict state for tools
- **Error Handling**: Configurable fail-fast or continue-on-error behavior

**Basic Usage:**
.. code-block:: python

    from torchrl.envs.llm.transforms import (
        ExecuteToolsInOrder,
        ToolRegistry,
        XMLBlockParser,
    )
    from torchrl.envs.llm import ChatEnv
    from torchrl.envs.transforms import TransformedEnv

    # Define a tool service
    class WebSearchService:
        name = "search"
        schema_in = {"query": str}
        schema_out = {"results": list}

        def __call__(self, query: str, **kwargs):
            # Implement your search logic here
            return {"results": [...]}

    # Create registry and parser
    registry = ToolRegistry([WebSearchService()])
    parser = XMLBlockParser()  # Or JSONCallParser()

    # Create environment with tool execution
    env = ChatEnv(batch_size=(1,), input_mode="history")
    env = TransformedEnv(
        env,
        ExecuteToolsInOrder(
            registry=registry,
            parser=parser,
            stop_on_error=False,
            pass_state_to_tools=True,
        ),
    )

The LLM generates responses containing tool calls (e.g., a ``<tool>`` block wrapping
``{"query": "TorchRL"}`` when the XML parser is used), and the transform automatically
parses the calls, executes them in order, and injects the results back into the
conversation history.

**Parser Formats:**

- **XML Style** (:class:`~torchrl.envs.llm.transforms.XMLBlockParser`): Parses XML-style ``<tool>`` blocks whose bodies carry JSON arguments
- **JSON Style** (:class:`~torchrl.envs.llm.transforms.JSONCallParser`): Parses the ``{"message": "...", "tools": [...]}`` format
- **Custom Parsers**: Implement the :class:`~torchrl.envs.llm.transforms.LLMToolParser` protocol for custom formats

For more details and examples, see ``examples/llm/tool_service_example.py`` and the comprehensive guide at
``torchrl/envs/llm/transforms/TOOL_SERVICE_GUIDE.md``.

Python Interpreter Example
^^^^^^^^^^^^^^^^^^^^^^^^^^

An example of a tool transform is the :class:`~torchrl.envs.llm.transforms.PythonInterpreter` transform,
which is used to execute Python code in the context of the LLM.

@@ -1137,6 +1207,8 @@ By following these design principles, reward transforms can be effectively integ
     AddThinkingPrompt
     BrowserTransform
     DataLoadingPrimer
+    ExecuteToolsInOrder
+    JSONCallParser
     KLComputation
     KLRewardTransform
     MCPToolTransform
@@ -1147,6 +1219,10 @@ By following these design principles, reward transforms can be effectively integ
     RetrieveLogProb
     TemplateTransform
     Tokenizer
+    ToolCall
+    ToolRegistry
+    ToolService
+    XMLBlockParser

     as_nested_tensor
     as_padded_tensor

diff --git a/torchrl/envs/llm/transforms/__init__.py b/torchrl/envs/llm/transforms/__init__.py
index 0a90f86daeb..51545ea1504 100644
--- a/torchrl/envs/llm/transforms/__init__.py
+++ b/torchrl/envs/llm/transforms/__init__.py
@@ -15,7 +15,7 @@
 from .policy_version import PolicyVersion
 from .reason import AddThinkingPrompt
 from .tokenizer import Tokenizer
-from .tools import MCPToolTransform, PythonInterpreter
+from .tools import MCPToolTransform, PythonInterpreter, ToolTransformBase

 __all__ = [
     "BrowserTransform",
@@ -31,6 +31,7 @@
     "KLComputation",
     "TemplateTransform",
     "Tokenizer",
+    "ToolTransformBase",
     "as_nested_tensor",
     "as_padded_tensor",
 ]

diff --git a/torchrl/envs/llm/transforms/tools.py b/torchrl/envs/llm/transforms/tools.py
index f0940b1b1aa..10a1aed829b 100644
--- a/torchrl/envs/llm/transforms/tools.py
+++ b/torchrl/envs/llm/transforms/tools.py
@@ -23,6 +23,226 @@ from torchrl.envs import Transform

# --- Base Class for Tool Transforms ---


class ToolTransformBase(Transform):
    """Base class for tool transforms that parse and execute tools from LLM output.
+ + This class handles all the common boilerplate for tool transforms: + - History extraction and validation + - Batch dimension flattening + - Result collection and padding + - History extension with tool results + + Subclasses only need to implement: + - :meth:`_process_batch_item`: Extract and execute tools from one response + - :meth:`_format_result`: Format one tool result as string (optional) + + Attributes: + use_step (bool): Whether to use _step() vs _call(). Defaults to True. + tool_role (str): Role name for results in history. Defaults to "tool". + + Examples: + >>> class SimpleCalculator(ToolTransformBase): + ... tool_role = "calculator" + ... + ... def _process_batch_item(self, content: str, index: int): + ... # Extract math expressions and evaluate + ... if "2+2" in content: + ... return ["2+2=4"] + ... return None + """ + + use_step: bool = True # Use _step() vs _call() + tool_role: str = "tool" # Role name for results in history + + def _validate_and_extract_history( + self, next_tensordict: TensorDictBase + ) -> tuple[History, History]: + """Validate environment and extract history. + + Args: + next_tensordict: The tensordict containing history. + + Returns: + tuple: (full_history, local_history) where local_history is the last message. + + Raises: + RuntimeError: If parent env doesn't exist or isn't in history mode. + """ + # Check that base_env is in history mode + parent = self.parent + if parent is None: + raise RuntimeError(f"{self.__class__.__name__} must be used with a ChatEnv") + base_env = parent.base_env + if base_env.input_mode != "history": + raise RuntimeError( + f"{self.__class__.__name__} must be used with a ChatEnv in history mode" + ) + + # Get history and isolate last element (the LLM's response) + history = next_tensordict["history"].prompt + local_history = history[..., -1] + + return history, local_history + + def _process_batch_item(self, content: str, index: int) -> list[str] | None: + """Process one item in the batch to extract and execute tools. + + This is the main method subclasses must implement. + + Args: + content: The text content from the LLM response. + index: The index of this item in the batch. + + Returns: + list[str] or None: List of result strings for each tool executed, + or None if no tools were found/executed. + """ + raise NotImplementedError( + f"{self.__class__.__name__} must implement _process_batch_item()" + ) + + def _format_result(self, result: str) -> str: + """Format a single result string. + + Override this to customize result formatting. Default is identity. + + Args: + result: Raw result string from tool execution. + + Returns: + str: Formatted result string. + """ + return result + + def _inject_results_to_history( + self, + history: History, + results: list[list[str] | None], + next_tensordict: TensorDictBase, + ) -> TensorDictBase: + """Inject tool results back into history with proper batching. + + Args: + history: The full conversation history. + results: List of results per batch item (can contain None). + next_tensordict: The tensordict to update. + + Returns: + TensorDictBase: Updated tensordict with results in history. 
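        Note:
            Batch items that produced fewer tool results than the longest item are
            padded with empty ``History(role="", content="")`` entries so that the
            per-item result lists can be stacked into a single batch.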
+ """ + # Convert string results to History objects + procs = [] + for batch_results in results: + if batch_results is None or len(batch_results) == 0: + procs.append(None) + else: + formatted_results = [self._format_result(r) for r in batch_results] + procs.append( + [ + History(role=self.tool_role, content=result) + for result in formatted_results + ] + ) + + # If there are no tool responses, skip + if all(p is None for p in procs): + return next_tensordict + + # Fill None entries with empty lists for consistent batching + if any(p is None for p in procs): + procs = [p if p is not None else [] for p in procs] + + # Pad all results to same length (required for batching) + if len(procs) > 1 and not all(len(p) == len(procs[0]) for p in procs): + + def fill_procs(proc: list[History], max_len: int) -> list[History]: + if len(proc) == max_len: + return proc + return proc + [History(role="", content="")] * ( + max_len - len(proc) + ) + + max_len = max(len(p) for p in procs) + procs = [fill_procs(p, max_len) for p in procs] + + # Stack and extend history + procs = lazy_stack([lazy_stack(p) for p in procs]) + history.extend(procs, dim=-1) + next_tensordict["history"].prompt = history + + return next_tensordict + + def _process_tensordict(self, next_tensordict: TensorDictBase) -> TensorDictBase: + """Main processing logic for tool transforms. + + Handles batch flattening, history extraction, tool processing, and result injection. + + Args: + next_tensordict: The tensordict to process. + + Returns: + TensorDictBase: Updated tensordict with tool results. + """ + # Flatten batch dimensions if needed + if next_tensordict.batch_dims > 1: + with next_tensordict.view(-1) as next_tensordict_flat: + next_tensordict_flat = self._process_tensordict(next_tensordict_flat) + return next_tensordict + + # Extract and validate history + history, local_history = self._validate_and_extract_history(next_tensordict) + + # Handle content as string or list + content = local_history.content + if isinstance(content, str): + content = [content] + + # Process each batch item + results = [] + for i, text in enumerate(content): + batch_results = self._process_batch_item(text, i) + results.append(batch_results) + + # Inject results back into history + return self._inject_results_to_history(history, results, next_tensordict) + + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: + """Handle step with tool processing. + + Args: + tensordict: Input tensordict. + next_tensordict: Output tensordict. + + Returns: + TensorDictBase: Updated next_tensordict. + """ + if not self.use_step: + raise RuntimeError( + f"{self.__class__.__name__} uses _call(), not _step(). Set use_step=False." + ) + return self._process_tensordict(next_tensordict) + + def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: + """Handle call with tool processing. + + Args: + next_tensordict: The tensordict to process. + + Returns: + TensorDictBase: Updated tensordict. + """ + if self.use_step: + raise RuntimeError( + f"{self.__class__.__name__} uses _step(), not _call(). Set use_step=True." + ) + return self._process_tensordict(next_tensordict) + + class PersistentPythonProcess: """A persistent Python process that can execute code blocks.""" @@ -332,9 +552,12 @@ def __del__(self): self.cleanup() -class PythonInterpreter(Transform): +class PythonInterpreter(ToolTransformBase): r"""A transform that executes Python code in the LLM response. 
+ This transform inherits from :class:`ToolTransformBase` and handles all the + boilerplate for history extraction, batch processing, and result injection. + Args: tokenizer: The tokenizer to use. Defaults to `None` (no tokenizer). tool_name: The name of the tool in the chat history. Defaults to `"tool"`. @@ -378,6 +601,8 @@ class PythonInterpreter(Transform): '<|im_start|>assistant\n'] """ + use_step = True # Use _step() method + def __init__( self, tokenizer=None, # type: ignore @@ -387,7 +612,7 @@ def __init__( ): super().__init__() self.tokenizer = tokenizer - self.tool_name = tool_name + self.tool_role = tool_name # Set the role for history entries self.persistent = persistent # Initialize as empty list if persistent, None otherwise self.processes: list[PersistentPythonProcess | None] = [] if persistent else [] @@ -403,7 +628,7 @@ def clone(self): """Clone the transform.""" return self.__class__( tokenizer=self.tokenizer, - tool_name=self.tool_name, + tool_name=self.tool_role, # tool_role is the instance attribute persistent=self.persistent, ) @@ -483,87 +708,55 @@ def _extract_python_code(self, text: str) -> list[str]: matches = re.findall(pattern, text, re.DOTALL) return matches - def _process_llm_response(self, response: str, i: int) -> list[str]: - """Process LLM response and execute any Python code found. + def _process_batch_item(self, content: str, index: int) -> list[str] | None: + """Process one batch item to extract and execute Python code. + + This is the main method required by ToolTransformBase. Args: - response (str): The response from the LLM. - i (int): The index of the response in the batch. + content: The text content from the LLM response. + index: The index of this item in the batch. Returns: - list[str]: A list of strings, each containing the result of the execution of the code block. + list[str] or None: List of result strings for each code block executed, + or None if no code blocks were found. 
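            For example, a response containing a single fenced code block that
            prints ``4`` yields ``["Code block 1 executed successfully:\n4\n"]``.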
""" - code_blocks = self._extract_python_code(response) + # Ensure we have enough processes for persistent mode + if self.persistent: + if index >= len(self.processes): + self._ensure_processes(index + 1) + + # Extract code blocks + code_blocks = self._extract_python_code(content) + if not code_blocks: + return None + # Execute each code block results = [] - for i, code in enumerate(code_blocks): - result = self._execute_python_code(code, i) + for block_idx, code in enumerate(code_blocks): + result = self._execute_python_code(code, index) if result["success"]: results.append( - f"Code block {i + 1} executed successfully:\n{result['stdout']}" + f"Code block {block_idx + 1} executed successfully:\n{result['stdout']}" ) else: - results.append(f"Code block {i + 1} failed:\n{result['stderr']}") + results.append( + f"Code block {block_idx + 1} failed:\n{result['stderr']}" + ) - return results + return results if results else None def _step( self, tensordict: TensorDictBase, next_tensordict: TensorDictBase ) -> TensorDictBase: - if next_tensordict.batch_dims > 1: - with next_tensordict.view(-1) as next_tensordict_flat, tensordict.view( - -1 - ) as tensordict_flat: - # Call the transform on the flattened tensordict - next_tensordict_flat = self._step(tensordict_flat, next_tensordict_flat) - return next_tensordict - - # Ensure we have enough processes for the batch - if self.persistent: + """Override to handle batch size management for persistent processes.""" + # Ensure we have enough processes for the entire batch + if self.persistent and next_tensordict.batch_dims == 1: self._ensure_processes(len(next_tensordict)) - # Convert text to a history - history = next_tensordict["history"].prompt - # Isolate last element, which should be our action - local_history = history[..., -1] - - procs = [] - # Iterate over env batch-size - content = local_history.content - if isinstance(content, str): - content = [content] - for i, t in enumerate(content): - results = self._process_llm_response(t, i) - if len(results) == 0: - procs.append(None) - continue - procs.append( - [History(role=self.tool_name, content=result) for result in results] - ) - - # If there is no tool response, just skip entire batch - if all(p is None for p in procs): - return next_tensordict - if any(p is None for p in procs): - procs = [p if p is not None else [] for p in procs] - # We need to have the same number of items for eache element in the batch - if len(procs) > 1 and not all(len(p) == len(procs[0]) for p in procs): - - def fill_procs(proc: list[History], max_len: int) -> list[History]: - if len(proc) == max_len: - return proc - return proc + [History(role="", content="")] * ( - max_len - len(proc) - ) - - max_len = max(len(p) for p in procs) - procs = [fill_procs(p, max_len) for p in procs] - # Procs has the shape of the batch-size. We can cat along dim=-1 - procs = lazy_stack([lazy_stack(p) for p in procs]) - history.extend(procs, dim=-1) - next_tensordict["history"].prompt = history - return next_tensordict + # Delegate to base class for all the heavy lifting + return super()._step(tensordict, next_tensordict) def __del__(self): """Cleanup persistent processes on deletion.""" @@ -594,9 +787,12 @@ def _reset( return tensordict_reset -class MCPToolTransform(Transform): +class MCPToolTransform(ToolTransformBase): r"""A transform that executes MCP-style tools in response to LLM actions. 
+ This transform inherits from :class:`ToolTransformBase` and handles all the + boilerplate for history extraction, batch processing, and result injection. + This transform allows execution of tools following the Mission Control Protocol pattern, where tools are defined with clear input/output schemas and executed in a controlled manner. @@ -661,6 +857,8 @@ class MCPToolTransform(Transform): '<|im_start|>assistant\n'] """ + use_step = False # Use _call() method + def __init__( self, tools: dict[str, callable], @@ -673,7 +871,7 @@ def __init__( self.tools = tools self.tool_schemas = tool_schemas self.tokenizer = tokenizer - self.tool_name = tool_name + self.tool_role = tool_name # Set the role for history entries self.timeout = timeout def _extract_tool_calls( @@ -741,17 +939,25 @@ def signal_handler(signum, frame): "error": f"Tool execution failed: {str(e)}", } - def _process_llm_response(self, response: str) -> list[str]: - """Process LLM response and execute any tool calls found. + def _process_batch_item(self, content: str, index: int) -> list[str] | None: + """Process one batch item to extract and execute MCP-style tools. + + This is the main method required by ToolTransformBase. Args: - response (str): The response from the LLM. + content: The text content from the LLM response. + index: The index of this item in the batch (unused for MCP tools). Returns: - list[str]: A list of strings, each containing the result of a tool execution. + list[str] or None: List of result strings for each tool executed, + or None if no tools were found. """ - tool_calls = self._extract_tool_calls(response) + # Extract tool calls + tool_calls = self._extract_tool_calls(content) + if not tool_calls: + return None + # Execute each tool results = [] for tool_name, args_json in tool_calls: result = self._execute_tool(tool_name, args_json) @@ -763,65 +969,4 @@ def _process_llm_response(self, response: str) -> list[str]: else: results.append(f"Tool {tool_name} failed:\n{result['error']}") - return results - - def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: - if next_tensordict.batch_dims > 1: - with next_tensordict.view(-1) as next_tensordict_flat: - # Call the transform on the flattened tensordict - next_tensordict_flat = self._call(next_tensordict_flat) - return next_tensordict - - # Check that base_env is on history mode - parent = self.parent - if parent is None: - raise RuntimeError("MCPToolTransform must be used with a ChatEnv") - base_env = parent.base_env - if base_env.input_mode != "history": - raise RuntimeError( - "MCPToolTransform must be used with a ChatEnv in history mode" - ) - - # Convert text to a history - history = next_tensordict["history"].prompt - # Isolate last element, which should be our action - local_history = history[..., -1] - - procs = [] - # Iterate over env batch-size - for t in local_history.content: - results = self._process_llm_response(t) - if len(results) == 0: - procs.append(None) - continue - procs.append( - [History(role=self.tool_name, content=result) for result in results] - ) - - # If there is no tool response, just skip entire batch - if all(p is None for p in procs): - return next_tensordict - if any(p is None for p in procs): - procs = [p if p is not None else [] for p in procs] - # We need to have the same number of items for each element in the batch - if len(procs) > 1 and not all(len(p) == len(procs[0]) for p in procs): - - def fill_procs(proc: list[History], max_len: int) -> list[History]: - if len(proc) == max_len: - return proc - return 
proc + [History(role="", content="")] * ( - max_len - len(proc) - ) - - max_len = max(len(p) for p in procs) - procs = [fill_procs(p, max_len) for p in procs] - # Procs has the shape of the batch-size. We can cat along dim=-1 - procs = lazy_stack([lazy_stack(p) for p in procs]) - history.extend(procs, dim=-1) - next_tensordict["history"].prompt = history - return next_tensordict - - def _reset( - self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase - ) -> TensorDictBase: - return tensordict_reset + return results if results else None
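Example (illustrative, not part of the diff): a minimal sketch of a custom tool
transform built on the new ``ToolTransformBase``. The ``CalculatorTransform`` name
and the ``<calc>`` tag format are hypothetical; ``use_step``, ``tool_role``, and
``_process_batch_item`` are the extension points introduced by this patch.

.. code-block:: python

    import re

    from torchrl.envs.llm.transforms import ToolTransformBase


    class CalculatorTransform(ToolTransformBase):
        """Evaluate arithmetic expressions wrapped in <calc>...</calc> tags."""

        use_step = True           # process results during env.step(), like PythonInterpreter
        tool_role = "calculator"  # role under which results enter the history

        def _process_batch_item(self, content: str, index: int) -> list[str] | None:
            # Find all <calc>...</calc> spans in the LLM response (hypothetical format).
            exprs = re.findall(r"<calc>(.*?)</calc>", content, re.DOTALL)
            if not exprs:
                return None  # no tool call: the base class skips this batch item
            results = []
            for expr in exprs:
                try:
                    # eval keeps the sketch short; a real tool should use a safe parser.
                    value = eval(expr, {"__builtins__": {}})
                    results.append(f"{expr.strip()} = {value}")
                except Exception as exc:
                    results.append(f"Failed to evaluate {expr.strip()!r}: {exc}")
            return results

Appending this transform to a :class:`~torchrl.envs.llm.ChatEnv`, as in the
PythonInterpreter example above, makes the evaluated results visible to the LLM
as ``calculator`` messages on the next turn.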