Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 69 additions & 19 deletions aixplain/v1/modules/model/rlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,27 @@

The REPL environment is initialized with:
1. A `context` variable that contains extremely important information about your query. You should check the content of the `context` variable to understand what you are working with. Make sure you look through it sufficiently as you answer your query.
2. A `llm_query` function that allows you to query an LLM (that can handle around 500K chars) inside your REPL environment.
2. A `llm_query` function that allows you to query an LLM with a context window of {worker_context_window} inside your REPL environment. You must take this context window into consideration when deciding how much text to pass in each call.
3. The ability to use `print()` statements to view the output of your REPL code and continue your reasoning.

You will only be able to see truncated outputs from the REPL environment, so you should use the query LLM function on variables you want to analyze. You will find this function especially useful when you have to analyze the semantics of the context. Use these variables as buffers to build up your final answer.
Make sure to explicitly look through the entire context in REPL before answering your query. An example strategy is to first look at the context and figure out a chunking strategy, then break up the context into smart chunks, and query an LLM per chunk with a particular question and save the answers to a buffer, then query an LLM with all the buffers to produce your final answer.

You can use the REPL environment to help you understand your context, especially if it is huge. Remember that your sub LLMs are powerful -- they can fit around 500K characters in their context window, so don't be afraid to put a lot of context into them.
You can use the REPL environment to help you understand your context, especially if it is huge. Remember that your sub LLMs are powerful -- they have a context window of {worker_context_window}, so don't be afraid to put a lot of context into them.

When you want to execute Python code in the REPL environment, wrap it in triple backticks with the 'repl' language identifier:
```repl
# your Python code here
chunk = context[:10000]
answer = llm_query(f"What is the key finding in this text?\\n{chunk}")
answer = llm_query(f"What is the key finding in this text?\\n{{chunk}}")
print(answer)
```

IMPORTANT: When you are done with the iterative process, you MUST provide a final answer using one of these two forms (NOT inside a code block):
1. FINAL(your final answer here) — to provide the answer as literal text
2. FINAL_VAR(variable_name) — to return a variable you created in the REPL as your final answer
1. FINAL(your final answer here) — to provide the answer as literal text. Use `FINAL(...)` only when you are completely finished: you will make no further REPL calls, need no further inspection of REPL output, and are not including any REPL code in the same response.
2. FINAL_VAR(variable_name) — to return a variable you created in the REPL as your final answer. Use `FINAL_VAR(...)` only when that variable already contains your completed final answer and you will make no further REPL calls.

Do not use `FINAL(...)` or `FINAL_VAR(...)` for intermediate status updates, plans, requests to inspect REPL output, statements such as needing more information, or any response that also includes REPL code to be executed first; those must be written as normal assistant text instead.

Think step by step carefully, plan, and execute this plan immediately — do not just say what you will do.
"""
Expand Down Expand Up @@ -107,8 +109,8 @@
# Prompt Helpers


def _build_system_messages() -> List[Dict[str, str]]:
return [{"role": "system", "content": _SYSTEM_PROMPT}]
def _build_system_messages(worker_context_window: str) -> List[Dict[str, str]]:
    """Return the initial conversation: one system message rendered from the prompt template.

    Args:
        worker_context_window: Human-readable description of the worker
            model's context window (e.g. ``"200K tokens"``), substituted
            into the ``{worker_context_window}`` placeholder of the
            system prompt.
    """
    content = _SYSTEM_PROMPT.format(worker_context_window=worker_context_window)
    return [{"role": "system", "content": content}]


def _next_action_message(query: str, iteration: int, force_final: bool = False) -> Dict[str, str]:
Expand Down Expand Up @@ -238,6 +240,28 @@ def __init__(
self._session_id: Optional[str] = None
self._sandbox_tool: Optional[Model] = None
self._messages: List[Dict[str, str]] = []
self._used_credits: float = 0.0

# Worker Context Window

def _get_worker_context_window(self) -> str:
"""Return a human-readable description of the worker model's context window."""
attributes = getattr(self.worker, "additional_info", {}).get("attributes", [])
raw = next(
(attr["code"] for attr in attributes if attr.get("name") == "max_context_length"),
None,
)
if raw is not None:
try:
tokens = int(raw)
if tokens >= 1_000_000:
return f"{tokens / 1_000_000:.1f}M tokens"
if tokens >= 1_000:
return f"{tokens / 1_000:.0f}K tokens"
return f"{tokens} tokens"
except (ValueError, TypeError):
return str(raw)
return "a large context window"

# Context Resolution

Expand Down Expand Up @@ -422,7 +446,10 @@ def _setup_repl(self, context: Union[str, dict, list]) -> None:
import time as __time
import json as __json

_total_llm_query_credits = 0.0

def llm_query(prompt):
global _total_llm_query_credits
_headers = {{"x-api-key": "{self.api_key}", "Content-Type": "application/json"}}
_payload = __json.dumps({{"data": prompt, "max_tokens": 8192}})
try:
Expand All @@ -437,6 +464,7 @@ def llm_query(prompt):
_r = __requests.get(_poll_url, headers=_headers, timeout=30)
_result = _r.json()
_wait = min(_wait * 1.1, 60)
_total_llm_query_credits += float(_result.get("usedCredits", 0) or 0)
return str(_result.get("data", "Error: no data in worker response"))
except Exception as _e:
return f"Error: llm_query failed — {{_e}}"
Expand All @@ -445,12 +473,14 @@ def llm_query(prompt):
self._run_sandbox(llm_query_code)
logging.debug("RLM: llm_query injected into sandbox.")

def _run_sandbox(self, code: str) -> None:
"""Execute code in the sandbox, ignoring the output (used for setup steps)."""
self._sandbox_tool.run(
def _run_sandbox(self, code: str) -> ModelResponse:
    """Execute *code* in the current sandbox session and return the raw response.

    Any ``used_credits`` reported on the response is folded into this
    run's accumulated credit total.
    """
    response = self._sandbox_tool.run(
        action="run",
        inputs={"sessionId": self._session_id, "code": code},
    )
    reported = getattr(response, "used_credits", 0) or 0
    self._used_credits += float(reported)
    return response

# Code Execution

Expand All @@ -468,10 +498,7 @@ def _execute_code(self, code: str) -> str:
Formatted string combining stdout and stderr. Returns "No output"
if both are empty.
"""
result = self._sandbox_tool.run(
inputs={"code": code, "sessionId": self._session_id},
action="run",
)
result = self._run_sandbox(code)
stdout = result.data.get("stdout", "") if isinstance(result.data, dict) else ""
stderr = result.data.get("stderr", "") if isinstance(result.data, dict) else ""

Expand All @@ -497,10 +524,7 @@ def _get_repl_variable(self, variable_name: str) -> Optional[str]:
String representation of the variable, or None if not found or on error.
"""
var = variable_name.strip().strip("\"'")
result = self._sandbox_tool.run(
inputs={"code": f"print(str({var}))", "sessionId": self._session_id},
action="run",
)
result = self._run_sandbox(f"print(str({var}))")
stdout = result.data.get("stdout", "") if isinstance(result.data, dict) else ""
stderr = result.data.get("stderr", "") if isinstance(result.data, dict) else ""

Expand All @@ -509,6 +533,23 @@ def _get_repl_variable(self, variable_name: str) -> Optional[str]:
return None
return stdout.strip() if stdout else None

# Credit Tracking

def _collect_llm_query_credits(self) -> None:
"""Retrieve accumulated ``llm_query`` worker credits from the sandbox.

The injected ``llm_query`` function tracks per-call ``usedCredits``
from the worker model API in a global ``_total_llm_query_credits``
variable inside the sandbox session. This method reads that variable
and adds it to ``self._used_credits``.
"""
try:
raw = self._get_repl_variable("_total_llm_query_credits")
if raw is not None:
self._used_credits += float(raw)
except Exception:
logging.debug("RLM: could not retrieve llm_query credits from sandbox.")

# Orchestrator

def _orchestrator_completion(self, messages: List[Dict[str, str]]) -> str:
Expand All @@ -532,6 +573,7 @@ def _orchestrator_completion(self, messages: List[Dict[str, str]]) -> str:
# response = self.orchestrator.run(data={"messages": messages})
prompt = _messages_to_prompt(messages)
response = self.orchestrator.run(data=prompt, max_tokens=8192)
self._used_credits += float(getattr(response, "used_credits", 0) or 0)
if response.get("completed") or response["status"] == ResponseStatus.SUCCESS:
return str(response["data"])
raise RuntimeError(f"Orchestrator model failed: {response.get('error_message', 'Unknown error')}")
Expand Down Expand Up @@ -585,6 +627,9 @@ def run(
- ``data``: The final answer string.
- ``completed``: True on success.
- ``run_time``: Total elapsed seconds.
- ``used_credits``: Total credits consumed across all
orchestrator calls, sandbox executions, and worker
``llm_query()`` invocations.
- ``iterations_used``: Number of orchestrator iterations (via
``response["iterations_used"]``).

Expand Down Expand Up @@ -630,13 +675,14 @@ def run(
iterations_used = 0
final_answer = None
repl_logs: List[Dict] = []
self._used_credits = 0.0

# Normalize context: resolve file paths and pathlib.Path objects
context = self._resolve_context(context)

# Initialize sandbox and conversation
self._setup_repl(context)
self._messages = _build_system_messages()
self._messages = _build_system_messages(self._get_worker_context_window())

try:
for iteration in range(self.max_iterations):
Expand Down Expand Up @@ -692,14 +738,17 @@ def run(
except Exception as e:
error_msg = f"RLM run error: {str(e)}"
logging.error(error_msg)
self._collect_llm_query_credits()
return ModelResponse(
status=ResponseStatus.FAILED,
completed=True,
error_message=error_msg,
run_time=time.time() - start_time,
used_credits=self._used_credits,
iterations_used=iterations_used,
)

self._collect_llm_query_credits()
run_time = time.time() - start_time
logging.info(f"RLM '{name}': done in {iterations_used} iterations, {run_time:.1f}s.")

Expand All @@ -708,6 +757,7 @@ def run(
data=final_answer or "",
completed=True,
run_time=run_time,
used_credits=self._used_credits,
iterations_used=iterations_used,
)

Expand Down
Loading
Loading