diff --git a/docs/source/reference/llms.rst b/docs/source/reference/llms.rst
index 327303fe338..717bd4d34f3 100644
--- a/docs/source/reference/llms.rst
+++ b/docs/source/reference/llms.rst
@@ -869,7 +869,8 @@ Transforms are the main way to extend ChatEnv with specific capabilities:
- **Reward computation**: :class:`~torchrl.envs.llm.transforms.KLRewardTransform` for KL divergence rewards
- **Tool execution**: :class:`~torchrl.envs.llm.transforms.PythonInterpreter` for Python code
- execution, :class:`~torchrl.envs.llm.transforms.MCPToolTransform` for general tool calling.
+ execution, :class:`~torchrl.envs.llm.transforms.ExecuteToolsInOrder` for order-preserving tool execution with
+ pluggable services, :class:`~torchrl.envs.llm.transforms.MCPToolTransform` for general tool calling (legacy).
- **Data loading**: :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer` for loading prompts from datasets
- **Thinking prompts**: :class:`~torchrl.envs.llm.transforms.AddThinkingPrompt` for chain-of-thought reasoning
- **Policy tracking**: :class:`~torchrl.envs.llm.transforms.PolicyVersion` for version control
@@ -929,6 +930,75 @@ Transforms are used to modify the data before it is passed to the LLM.
Tools are usually implemented as transforms, and appended to a base environment
such as :class:`~torchrl.envs.llm.ChatEnv`.
+Tool Service Library
+^^^^^^^^^^^^^^^^^^^^
+
+TorchRL provides a flexible tool service library for adding tool-calling capabilities to LLM agents.
+The :class:`~torchrl.envs.llm.transforms.ExecuteToolsInOrder` transform enables LLMs to call external tools
+(e.g., web search, calculators, databases) in a structured, order-preserving manner.
+
+**Key Features:**
+
+- **Order-Preserving Execution**: Tools execute in the order they appear in the LLM output, respecting the LLM's reasoning
+- **Pluggable Parsers**: Support for different tool-calling formats (XML, JSON, custom) via :class:`~torchrl.envs.llm.transforms.XMLBlockParser` and :class:`~torchrl.envs.llm.transforms.JSONCallParser`
+- **Protocol-Based Services**: Clean interfaces using :class:`~torchrl.envs.llm.transforms.ToolService` protocol for easy extension
+- **Service Registry**: :class:`~torchrl.envs.llm.transforms.ToolRegistry` manages available tool services
+- **State Passing**: Optionally pass a filtered view of the environment's TensorDict state to tools
+- **Error Handling**: Configurable fail-fast or continue-on-error behavior
+
+**Basic Usage:**
+
+.. code-block:: python
+
+ from torchrl.envs.llm.transforms import (
+ ExecuteToolsInOrder,
+ ToolRegistry,
+ XMLBlockParser,
+ )
+ from torchrl.envs.llm import ChatEnv
+ from torchrl.envs.transforms import TransformedEnv
+
+ # Define a tool service
+ class WebSearchService:
+ name = "search"
+ schema_in = {"query": str}
+ schema_out = {"results": list}
+
+ def __call__(self, query: str, **kwargs):
+ # Implement your search logic
+ return {"results": [...]}
+
+ # Create registry and parser
+ registry = ToolRegistry([WebSearchService()])
+ parser = XMLBlockParser() # Or JSONCallParser()
+
+ # Create environment with tool execution
+ env = ChatEnv(batch_size=(1,), input_mode="history")
+ env = TransformedEnv(
+ env,
+ ExecuteToolsInOrder(
+ registry=registry,
+ parser=parser,
+ stop_on_error=False,
+ pass_state_to_tools=True,
+ )
+ )
+
+The LLM generates responses with tool calls (e.g., ``<tool name="search">{"query": "TorchRL"}</tool>``),
+and the transform parses the calls, executes the matching tools in order, and injects the results back into the conversation history.
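+
+For instance, assuming the XML-style format described under *Parser Formats* below (the
+exact tag syntax and the ``tool`` role shown here are illustrative, not a verified
+contract), a single step could look like this:
+
+.. code-block:: python
+
+    # Hypothetical LLM output containing one tool call (XML-style format)
+    response = 'Let me look that up. <tool name="search">{"query": "TorchRL"}</tool>'
+
+    # After env.step(), ExecuteToolsInOrder parses the call, runs the "search"
+    # service, and appends its output to the history, e.g. as a message like:
+    #   {"role": "tool", "content": '{"results": [...]}'}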
+
+**Parser Formats:**
+
+- **XML Style** (:class:`~torchrl.envs.llm.transforms.XMLBlockParser`): Parses XML-style ``<tool name="...">...</tool>`` blocks carrying JSON arguments
+- **JSON Style** (:class:`~torchrl.envs.llm.transforms.JSONCallParser`): Parses ``{"message": "...", "tools": [...]}`` format
+- **Custom Parsers**: Implement the :class:`~torchrl.envs.llm.transforms.LLMToolParser` protocol for custom formats (see the sketch below)
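+
+As a sketch of such a custom parser (the ``parse`` method name, its signature, and the
+``ToolCall(name=..., args=...)`` fields are assumptions rather than the verified protocol;
+check the :class:`~torchrl.envs.llm.transforms.LLMToolParser` docstring for the actual
+interface):
+
+.. code-block:: python
+
+    import json
+    import re
+
+    from torchrl.envs.llm.transforms import ToolCall
+
+
+    class FencedBlockParser:
+        """Hypothetical parser for fenced ``tool`` blocks in markdown-style output."""
+
+        def parse(self, text: str) -> list:
+            # Match blocks of the form: ```tool {"name": ..., "args": {...}}```
+            pattern = r"```tool\s+(\{.*?\})\s*```"
+            calls = []
+            for payload in re.findall(pattern, text, re.DOTALL):
+                data = json.loads(payload)
+                # Assumed ToolCall fields; adjust to the real dataclass.
+                calls.append(ToolCall(name=data["name"], args=data["args"]))
+            return calls
+
+An instance can then be passed to :class:`~torchrl.envs.llm.transforms.ExecuteToolsInOrder`
+via its ``parser`` argument, exactly like the built-in parsers above.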
+
+For more details and examples, see ``examples/llm/tool_service_example.py`` and the comprehensive guide at
+``torchrl/envs/llm/transforms/TOOL_SERVICE_GUIDE.md``.
+
+Python Interpreter Example
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
An example of a tool transform is the :class:`~torchrl.envs.llm.transforms.PythonInterpreter` transform, which is used
to execute Python code in the context of the LLM.
@@ -1137,6 +1207,8 @@ By following these design principles, reward transforms can be effectively integ
AddThinkingPrompt
BrowserTransform
DataLoadingPrimer
+ ExecuteToolsInOrder
+ JSONCallParser
KLComputation
KLRewardTransform
MCPToolTransform
@@ -1147,6 +1219,10 @@ By following these design principles, reward transforms can be effectively integ
RetrieveLogProb
TemplateTransform
Tokenizer
+ ToolCall
+ ToolRegistry
+ ToolService
+ XMLBlockParser
as_nested_tensor
as_padded_tensor
diff --git a/torchrl/envs/llm/transforms/__init__.py b/torchrl/envs/llm/transforms/__init__.py
index 0a90f86daeb..51545ea1504 100644
--- a/torchrl/envs/llm/transforms/__init__.py
+++ b/torchrl/envs/llm/transforms/__init__.py
@@ -15,7 +15,7 @@
from .policy_version import PolicyVersion
from .reason import AddThinkingPrompt
from .tokenizer import Tokenizer
-from .tools import MCPToolTransform, PythonInterpreter
+from .tools import MCPToolTransform, PythonInterpreter, ToolTransformBase
__all__ = [
"BrowserTransform",
@@ -31,6 +31,7 @@
"KLComputation",
"TemplateTransform",
"Tokenizer",
+ "ToolTransformBase",
"as_nested_tensor",
"as_padded_tensor",
]
diff --git a/torchrl/envs/llm/transforms/tools.py b/torchrl/envs/llm/transforms/tools.py
index f0940b1b1aa..10a1aed829b 100644
--- a/torchrl/envs/llm/transforms/tools.py
+++ b/torchrl/envs/llm/transforms/tools.py
@@ -23,6 +23,226 @@
from torchrl.envs import Transform
+# --- Base Class for Tool Transforms ---
+
+
+class ToolTransformBase(Transform):
+ """Base class for tool transforms that parse and execute tools from LLM output.
+
+ This class handles all the common boilerplate for tool transforms:
+ - History extraction and validation
+ - Batch dimension flattening
+ - Result collection and padding
+ - History extension with tool results
+
+ Subclasses only need to implement:
+ - :meth:`_process_batch_item`: Extract and execute tools from one response
+ - :meth:`_format_result`: Format one tool result as string (optional)
+
+ Attributes:
+        use_step (bool): If ``True``, tool processing runs in ``_step()``;
+            if ``False``, it runs in ``_call()``. Defaults to ``True``.
+ tool_role (str): Role name for results in history. Defaults to "tool".
+
+ Examples:
+ >>> class SimpleCalculator(ToolTransformBase):
+ ... tool_role = "calculator"
+ ...
+ ... def _process_batch_item(self, content: str, index: int):
+ ... # Extract math expressions and evaluate
+ ... if "2+2" in content:
+ ... return ["2+2=4"]
+ ... return None
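+        >>>
+        >>> # Usage sketch: like any transform, append it to a ChatEnv running
+        >>> # in history mode (ChatEnv/TransformedEnv imports assumed).
+        >>> env = TransformedEnv(
+        ...     ChatEnv(batch_size=(1,), input_mode="history"),
+        ...     SimpleCalculator(),
+        ... )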
+ """
+
+    use_step: bool = True  # If True, process in _step(); otherwise in _call()
+ tool_role: str = "tool" # Role name for results in history
+
+ def _validate_and_extract_history(
+ self, next_tensordict: TensorDictBase
+ ) -> tuple[History, History]:
+ """Validate environment and extract history.
+
+ Args:
+ next_tensordict: The tensordict containing history.
+
+ Returns:
+ tuple: (full_history, local_history) where local_history is the last message.
+
+ Raises:
+ RuntimeError: If parent env doesn't exist or isn't in history mode.
+ """
+ # Check that base_env is in history mode
+ parent = self.parent
+ if parent is None:
+ raise RuntimeError(f"{self.__class__.__name__} must be used with a ChatEnv")
+ base_env = parent.base_env
+ if base_env.input_mode != "history":
+ raise RuntimeError(
+ f"{self.__class__.__name__} must be used with a ChatEnv in history mode"
+ )
+
+ # Get history and isolate last element (the LLM's response)
+ history = next_tensordict["history"].prompt
+ local_history = history[..., -1]
+
+ return history, local_history
+
+ def _process_batch_item(self, content: str, index: int) -> list[str] | None:
+ """Process one item in the batch to extract and execute tools.
+
+ This is the main method subclasses must implement.
+
+ Args:
+ content: The text content from the LLM response.
+ index: The index of this item in the batch.
+
+ Returns:
+ list[str] or None: List of result strings for each tool executed,
+ or None if no tools were found/executed.
+ """
+ raise NotImplementedError(
+ f"{self.__class__.__name__} must implement _process_batch_item()"
+ )
+
+ def _format_result(self, result: str) -> str:
+ """Format a single result string.
+
+ Override this to customize result formatting. Default is identity.
+
+ Args:
+ result: Raw result string from tool execution.
+
+ Returns:
+ str: Formatted result string.
+ """
+ return result
+
+ def _inject_results_to_history(
+ self,
+ history: History,
+ results: list[list[str] | None],
+ next_tensordict: TensorDictBase,
+ ) -> TensorDictBase:
+ """Inject tool results back into history with proper batching.
+
+ Args:
+ history: The full conversation history.
+ results: List of results per batch item (can contain None).
+ next_tensordict: The tensordict to update.
+
+ Returns:
+ TensorDictBase: Updated tensordict with results in history.
+ """
+ # Convert string results to History objects
+ procs = []
+ for batch_results in results:
+ if batch_results is None or len(batch_results) == 0:
+ procs.append(None)
+ else:
+ formatted_results = [self._format_result(r) for r in batch_results]
+ procs.append(
+ [
+ History(role=self.tool_role, content=result)
+ for result in formatted_results
+ ]
+ )
+
+ # If there are no tool responses, skip
+ if all(p is None for p in procs):
+ return next_tensordict
+
+ # Fill None entries with empty lists for consistent batching
+ if any(p is None for p in procs):
+ procs = [p if p is not None else [] for p in procs]
+
+ # Pad all results to same length (required for batching)
+ if len(procs) > 1 and not all(len(p) == len(procs[0]) for p in procs):
+
+ def fill_procs(proc: list[History], max_len: int) -> list[History]:
+ if len(proc) == max_len:
+ return proc
+ return proc + [History(role="", content="")] * (
+ max_len - len(proc)
+ )
+
+ max_len = max(len(p) for p in procs)
+ procs = [fill_procs(p, max_len) for p in procs]
+
+ # Stack and extend history
+ procs = lazy_stack([lazy_stack(p) for p in procs])
+ history.extend(procs, dim=-1)
+ next_tensordict["history"].prompt = history
+
+ return next_tensordict
+
+ def _process_tensordict(self, next_tensordict: TensorDictBase) -> TensorDictBase:
+ """Main processing logic for tool transforms.
+
+ Handles batch flattening, history extraction, tool processing, and result injection.
+
+ Args:
+ next_tensordict: The tensordict to process.
+
+ Returns:
+ TensorDictBase: Updated tensordict with tool results.
+ """
+ # Flatten batch dimensions if needed
+ if next_tensordict.batch_dims > 1:
+ with next_tensordict.view(-1) as next_tensordict_flat:
+ next_tensordict_flat = self._process_tensordict(next_tensordict_flat)
+ return next_tensordict
+
+ # Extract and validate history
+ history, local_history = self._validate_and_extract_history(next_tensordict)
+
+ # Handle content as string or list
+ content = local_history.content
+ if isinstance(content, str):
+ content = [content]
+
+ # Process each batch item
+ results = []
+ for i, text in enumerate(content):
+ batch_results = self._process_batch_item(text, i)
+ results.append(batch_results)
+
+ # Inject results back into history
+ return self._inject_results_to_history(history, results, next_tensordict)
+
+ def _step(
+ self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+ ) -> TensorDictBase:
+ """Handle step with tool processing.
+
+ Args:
+ tensordict: Input tensordict.
+ next_tensordict: Output tensordict.
+
+ Returns:
+ TensorDictBase: Updated next_tensordict.
+ """
+ if not self.use_step:
+ raise RuntimeError(
+ f"{self.__class__.__name__} uses _call(), not _step(). Set use_step=False."
+ )
+ return self._process_tensordict(next_tensordict)
+
+ def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
+ """Handle call with tool processing.
+
+ Args:
+ next_tensordict: The tensordict to process.
+
+ Returns:
+ TensorDictBase: Updated tensordict.
+ """
+ if self.use_step:
+ raise RuntimeError(
+ f"{self.__class__.__name__} uses _step(), not _call(). Set use_step=True."
+ )
+ return self._process_tensordict(next_tensordict)
+
+
class PersistentPythonProcess:
"""A persistent Python process that can execute code blocks."""
@@ -332,9 +552,12 @@ def __del__(self):
self.cleanup()
-class PythonInterpreter(Transform):
+class PythonInterpreter(ToolTransformBase):
r"""A transform that executes Python code in the LLM response.
+    This transform inherits from :class:`ToolTransformBase`, which handles all the
+    boilerplate for history extraction, batch processing, and result injection.
+
Args:
tokenizer: The tokenizer to use. Defaults to `None` (no tokenizer).
tool_name: The name of the tool in the chat history. Defaults to `"tool"`.
@@ -378,6 +601,8 @@ class PythonInterpreter(Transform):
'<|im_start|>assistant\n']
"""
+ use_step = True # Use _step() method
+
def __init__(
self,
tokenizer=None, # type: ignore
@@ -387,7 +612,7 @@ def __init__(
):
super().__init__()
self.tokenizer = tokenizer
- self.tool_name = tool_name
+ self.tool_role = tool_name # Set the role for history entries
self.persistent = persistent
# Initialize as empty list if persistent, None otherwise
self.processes: list[PersistentPythonProcess | None] = [] if persistent else []
@@ -403,7 +628,7 @@ def clone(self):
"""Clone the transform."""
return self.__class__(
tokenizer=self.tokenizer,
- tool_name=self.tool_name,
+ tool_name=self.tool_role, # tool_role is the instance attribute
persistent=self.persistent,
)
@@ -483,87 +708,55 @@ def _extract_python_code(self, text: str) -> list[str]:
matches = re.findall(pattern, text, re.DOTALL)
return matches
- def _process_llm_response(self, response: str, i: int) -> list[str]:
- """Process LLM response and execute any Python code found.
+ def _process_batch_item(self, content: str, index: int) -> list[str] | None:
+ """Process one batch item to extract and execute Python code.
+
+ This is the main method required by ToolTransformBase.
Args:
- response (str): The response from the LLM.
- i (int): The index of the response in the batch.
+ content: The text content from the LLM response.
+ index: The index of this item in the batch.
Returns:
- list[str]: A list of strings, each containing the result of the execution of the code block.
+ list[str] or None: List of result strings for each code block executed,
+ or None if no code blocks were found.
"""
- code_blocks = self._extract_python_code(response)
+ # Ensure we have enough processes for persistent mode
+ if self.persistent:
+ if index >= len(self.processes):
+ self._ensure_processes(index + 1)
+
+ # Extract code blocks
+ code_blocks = self._extract_python_code(content)
+ if not code_blocks:
+ return None
+ # Execute each code block
results = []
- for i, code in enumerate(code_blocks):
- result = self._execute_python_code(code, i)
+ for block_idx, code in enumerate(code_blocks):
+ result = self._execute_python_code(code, index)
if result["success"]:
results.append(
- f"Code block {i + 1} executed successfully:\n{result['stdout']}"
+ f"Code block {block_idx + 1} executed successfully:\n{result['stdout']}"
)
else:
- results.append(f"Code block {i + 1} failed:\n{result['stderr']}")
+ results.append(
+ f"Code block {block_idx + 1} failed:\n{result['stderr']}"
+ )
- return results
+ return results if results else None
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
- if next_tensordict.batch_dims > 1:
- with next_tensordict.view(-1) as next_tensordict_flat, tensordict.view(
- -1
- ) as tensordict_flat:
- # Call the transform on the flattened tensordict
- next_tensordict_flat = self._step(tensordict_flat, next_tensordict_flat)
- return next_tensordict
-
- # Ensure we have enough processes for the batch
- if self.persistent:
+ """Override to handle batch size management for persistent processes."""
+ # Ensure we have enough processes for the entire batch
+ if self.persistent and next_tensordict.batch_dims == 1:
self._ensure_processes(len(next_tensordict))
- # Convert text to a history
- history = next_tensordict["history"].prompt
- # Isolate last element, which should be our action
- local_history = history[..., -1]
-
- procs = []
- # Iterate over env batch-size
- content = local_history.content
- if isinstance(content, str):
- content = [content]
- for i, t in enumerate(content):
- results = self._process_llm_response(t, i)
- if len(results) == 0:
- procs.append(None)
- continue
- procs.append(
- [History(role=self.tool_name, content=result) for result in results]
- )
-
- # If there is no tool response, just skip entire batch
- if all(p is None for p in procs):
- return next_tensordict
- if any(p is None for p in procs):
- procs = [p if p is not None else [] for p in procs]
- # We need to have the same number of items for eache element in the batch
- if len(procs) > 1 and not all(len(p) == len(procs[0]) for p in procs):
-
- def fill_procs(proc: list[History], max_len: int) -> list[History]:
- if len(proc) == max_len:
- return proc
- return proc + [History(role="", content="")] * (
- max_len - len(proc)
- )
-
- max_len = max(len(p) for p in procs)
- procs = [fill_procs(p, max_len) for p in procs]
- # Procs has the shape of the batch-size. We can cat along dim=-1
- procs = lazy_stack([lazy_stack(p) for p in procs])
- history.extend(procs, dim=-1)
- next_tensordict["history"].prompt = history
- return next_tensordict
+ # Delegate to base class for all the heavy lifting
+ return super()._step(tensordict, next_tensordict)
def __del__(self):
"""Cleanup persistent processes on deletion."""
@@ -594,9 +787,12 @@ def _reset(
return tensordict_reset
-class MCPToolTransform(Transform):
+class MCPToolTransform(ToolTransformBase):
r"""A transform that executes MCP-style tools in response to LLM actions.
+    This transform inherits from :class:`ToolTransformBase`, which handles all the
+    boilerplate for history extraction, batch processing, and result injection.
+
This transform allows execution of tools following the Model Context Protocol (MCP) pattern,
where tools are defined with clear input/output schemas and executed in a controlled manner.
@@ -661,6 +857,8 @@ class MCPToolTransform(Transform):
'<|im_start|>assistant\n']
"""
+ use_step = False # Use _call() method
+
def __init__(
self,
tools: dict[str, callable],
@@ -673,7 +871,7 @@ def __init__(
self.tools = tools
self.tool_schemas = tool_schemas
self.tokenizer = tokenizer
- self.tool_name = tool_name
+ self.tool_role = tool_name # Set the role for history entries
self.timeout = timeout
def _extract_tool_calls(
@@ -741,17 +939,25 @@ def signal_handler(signum, frame):
"error": f"Tool execution failed: {str(e)}",
}
- def _process_llm_response(self, response: str) -> list[str]:
- """Process LLM response and execute any tool calls found.
+ def _process_batch_item(self, content: str, index: int) -> list[str] | None:
+ """Process one batch item to extract and execute MCP-style tools.
+
+ This is the main method required by ToolTransformBase.
Args:
- response (str): The response from the LLM.
+ content: The text content from the LLM response.
+ index: The index of this item in the batch (unused for MCP tools).
Returns:
- list[str]: A list of strings, each containing the result of a tool execution.
+ list[str] or None: List of result strings for each tool executed,
+ or None if no tools were found.
"""
- tool_calls = self._extract_tool_calls(response)
+ # Extract tool calls
+ tool_calls = self._extract_tool_calls(content)
+ if not tool_calls:
+ return None
+ # Execute each tool
results = []
for tool_name, args_json in tool_calls:
result = self._execute_tool(tool_name, args_json)
@@ -763,65 +969,4 @@ def _process_llm_response(self, response: str) -> list[str]:
else:
results.append(f"Tool {tool_name} failed:\n{result['error']}")
- return results
-
- def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
- if next_tensordict.batch_dims > 1:
- with next_tensordict.view(-1) as next_tensordict_flat:
- # Call the transform on the flattened tensordict
- next_tensordict_flat = self._call(next_tensordict_flat)
- return next_tensordict
-
- # Check that base_env is on history mode
- parent = self.parent
- if parent is None:
- raise RuntimeError("MCPToolTransform must be used with a ChatEnv")
- base_env = parent.base_env
- if base_env.input_mode != "history":
- raise RuntimeError(
- "MCPToolTransform must be used with a ChatEnv in history mode"
- )
-
- # Convert text to a history
- history = next_tensordict["history"].prompt
- # Isolate last element, which should be our action
- local_history = history[..., -1]
-
- procs = []
- # Iterate over env batch-size
- for t in local_history.content:
- results = self._process_llm_response(t)
- if len(results) == 0:
- procs.append(None)
- continue
- procs.append(
- [History(role=self.tool_name, content=result) for result in results]
- )
-
- # If there is no tool response, just skip entire batch
- if all(p is None for p in procs):
- return next_tensordict
- if any(p is None for p in procs):
- procs = [p if p is not None else [] for p in procs]
- # We need to have the same number of items for each element in the batch
- if len(procs) > 1 and not all(len(p) == len(procs[0]) for p in procs):
-
- def fill_procs(proc: list[History], max_len: int) -> list[History]:
- if len(proc) == max_len:
- return proc
- return proc + [History(role="", content="")] * (
- max_len - len(proc)
- )
-
- max_len = max(len(p) for p in procs)
- procs = [fill_procs(p, max_len) for p in procs]
- # Procs has the shape of the batch-size. We can cat along dim=-1
- procs = lazy_stack([lazy_stack(p) for p in procs])
- history.extend(procs, dim=-1)
- next_tensordict["history"].prompt = history
- return next_tensordict
-
- def _reset(
- self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
- ) -> TensorDictBase:
- return tensordict_reset
+ return results if results else None