From 5aa37f5a496d776839924413a4a3a3466292d06c Mon Sep 17 00:00:00 2001 From: Abdelrahman Elsheikh Date: Thu, 9 Apr 2026 15:04:59 +0200 Subject: [PATCH 1/4] Add used_credits to response and context_window information to sys prompt --- aixplain/v1/modules/model/rlm.py | 80 ++++-- aixplain/v2/rlm.py | 72 ++++- tests/unit/rlm_test.py | 441 +++++++++++++++++++++++++++++++ 3 files changed, 571 insertions(+), 22 deletions(-) create mode 100644 tests/unit/rlm_test.py diff --git a/aixplain/v1/modules/model/rlm.py b/aixplain/v1/modules/model/rlm.py index aa95961a..eb53a004 100644 --- a/aixplain/v1/modules/model/rlm.py +++ b/aixplain/v1/modules/model/rlm.py @@ -57,13 +57,13 @@ The REPL environment is initialized with: 1. A `context` variable that contains extremely important information about your query. You should check the content of the `context` variable to understand what you are working with. Make sure you look through it sufficiently as you answer your query. -2. A `llm_query` function that allows you to query an LLM (that can handle around 500K chars) inside your REPL environment. +2. A `llm_query` function that allows you to query an LLM with a context window of {worker_context_window} inside your REPL environment. You must take this context window into consideration when deciding how much text to pass in each call. 3. The ability to use `print()` statements to view the output of your REPL code and continue your reasoning. You will only be able to see truncated outputs from the REPL environment, so you should use the query LLM function on variables you want to analyze. You will find this function especially useful when you have to analyze the semantics of the context. Use these variables as buffers to build up your final answer. Make sure to explicitly look through the entire context in REPL before answering your query. An example strategy is to first look at the context and figure out a chunking strategy, then break up the context into smart chunks, and query an LLM per chunk with a particular question and save the answers to a buffer, then query an LLM with all the buffers to produce your final answer. -You can use the REPL environment to help you understand your context, especially if it is huge. Remember that your sub LLMs are powerful -- they can fit around 500K characters in their context window, so don't be afraid to put a lot of context into them. +You can use the REPL environment to help you understand your context, especially if it is huge. Remember that your sub LLMs are powerful -- they have a context window of {worker_context_window}, so don't be afraid to put a lot of context into them. When you want to execute Python code in the REPL environment, wrap it in triple backticks with the 'repl' language identifier: ```repl @@ -107,8 +107,8 @@ # Prompt Helpers -def _build_system_messages() -> List[Dict[str, str]]: - return [{"role": "system", "content": _SYSTEM_PROMPT}] +def _build_system_messages(worker_context_window: str) -> List[Dict[str, str]]: + return [{"role": "system", "content": _SYSTEM_PROMPT.format(worker_context_window=worker_context_window)}] def _next_action_message(query: str, iteration: int, force_final: bool = False) -> Dict[str, str]: @@ -238,6 +238,28 @@ def __init__( self._session_id: Optional[str] = None self._sandbox_tool: Optional[Model] = None self._messages: List[Dict[str, str]] = [] + self._used_credits: float = 0.0 + + # Worker Context Window + + def _get_worker_context_window(self) -> str: + """Return a human-readable description of the worker model's context window.""" + attributes = getattr(self.worker, "additional_info", {}).get("attributes", []) + raw = next( + (attr["code"] for attr in attributes if attr.get("name") == "max_context_length"), + None, + ) + if raw is not None: + try: + tokens = int(raw) + if tokens >= 1_000_000: + return f"{tokens / 1_000_000:.1f}M tokens" + if tokens >= 1_000: + return f"{tokens / 1_000:.0f}K tokens" + return f"{tokens} tokens" + except (ValueError, TypeError): + return str(raw) + return "a large context window" # Context Resolution @@ -422,7 +444,10 @@ def _setup_repl(self, context: Union[str, dict, list]) -> None: import time as __time import json as __json +_total_llm_query_credits = 0.0 + def llm_query(prompt): + global _total_llm_query_credits _headers = {{"x-api-key": "{self.api_key}", "Content-Type": "application/json"}} _payload = __json.dumps({{"data": prompt, "max_tokens": 8192}}) try: @@ -437,6 +462,7 @@ def llm_query(prompt): _r = __requests.get(_poll_url, headers=_headers, timeout=30) _result = _r.json() _wait = min(_wait * 1.1, 60) + _total_llm_query_credits += float(_result.get("usedCredits", 0) or 0) return str(_result.get("data", "Error: no data in worker response")) except Exception as _e: return f"Error: llm_query failed — {{_e}}" @@ -445,12 +471,14 @@ def llm_query(prompt): self._run_sandbox(llm_query_code) logging.debug("RLM: llm_query injected into sandbox.") - def _run_sandbox(self, code: str) -> None: - """Execute code in the sandbox, ignoring the output (used for setup steps).""" - self._sandbox_tool.run( + def _run_sandbox(self, code: str) -> ModelResponse: + """Execute code in the sandbox and return the raw response.""" + result = self._sandbox_tool.run( inputs={"code": code, "sessionId": self._session_id}, action="run", ) + self._used_credits += float(getattr(result, "used_credits", 0) or 0) + return result # Code Execution @@ -468,10 +496,7 @@ def _execute_code(self, code: str) -> str: Formatted string combining stdout and stderr. Returns "No output" if both are empty. """ - result = self._sandbox_tool.run( - inputs={"code": code, "sessionId": self._session_id}, - action="run", - ) + result = self._run_sandbox(code) stdout = result.data.get("stdout", "") if isinstance(result.data, dict) else "" stderr = result.data.get("stderr", "") if isinstance(result.data, dict) else "" @@ -497,10 +522,7 @@ def _get_repl_variable(self, variable_name: str) -> Optional[str]: String representation of the variable, or None if not found or on error. """ var = variable_name.strip().strip("\"'") - result = self._sandbox_tool.run( - inputs={"code": f"print(str({var}))", "sessionId": self._session_id}, - action="run", - ) + result = self._run_sandbox(f"print(str({var}))") stdout = result.data.get("stdout", "") if isinstance(result.data, dict) else "" stderr = result.data.get("stderr", "") if isinstance(result.data, dict) else "" @@ -509,6 +531,23 @@ def _get_repl_variable(self, variable_name: str) -> Optional[str]: return None return stdout.strip() if stdout else None + # Credit Tracking + + def _collect_llm_query_credits(self) -> None: + """Retrieve accumulated ``llm_query`` worker credits from the sandbox. + + The injected ``llm_query`` function tracks per-call ``usedCredits`` + from the worker model API in a global ``_total_llm_query_credits`` + variable inside the sandbox session. This method reads that variable + and adds it to ``self._used_credits``. + """ + try: + raw = self._get_repl_variable("_total_llm_query_credits") + if raw is not None: + self._used_credits += float(raw) + except Exception: + logging.debug("RLM: could not retrieve llm_query credits from sandbox.") + # Orchestrator def _orchestrator_completion(self, messages: List[Dict[str, str]]) -> str: @@ -532,6 +571,7 @@ def _orchestrator_completion(self, messages: List[Dict[str, str]]) -> str: # response = self.orchestrator.run(data={"messages": messages}) prompt = _messages_to_prompt(messages) response = self.orchestrator.run(data=prompt, max_tokens=8192) + self._used_credits += float(getattr(response, "used_credits", 0) or 0) if response.get("completed") or response["status"] == ResponseStatus.SUCCESS: return str(response["data"]) raise RuntimeError(f"Orchestrator model failed: {response.get('error_message', 'Unknown error')}") @@ -585,6 +625,9 @@ def run( - ``data``: The final answer string. - ``completed``: True on success. - ``run_time``: Total elapsed seconds. + - ``used_credits``: Total credits consumed across all + orchestrator calls, sandbox executions, and worker + ``llm_query()`` invocations. - ``iterations_used``: Number of orchestrator iterations (via ``response["iterations_used"]``). @@ -630,13 +673,14 @@ def run( iterations_used = 0 final_answer = None repl_logs: List[Dict] = [] + self._used_credits = 0.0 # Normalize context: resolve file paths and pathlib.Path objects context = self._resolve_context(context) # Initialize sandbox and conversation self._setup_repl(context) - self._messages = _build_system_messages() + self._messages = _build_system_messages(self._get_worker_context_window()) try: for iteration in range(self.max_iterations): @@ -692,14 +736,17 @@ def run( except Exception as e: error_msg = f"RLM run error: {str(e)}" logging.error(error_msg) + self._collect_llm_query_credits() return ModelResponse( status=ResponseStatus.FAILED, completed=True, error_message=error_msg, run_time=time.time() - start_time, + used_credits=self._used_credits, iterations_used=iterations_used, ) + self._collect_llm_query_credits() run_time = time.time() - start_time logging.info(f"RLM '{name}': done in {iterations_used} iterations, {run_time:.1f}s.") @@ -708,6 +755,7 @@ def run( data=final_answer or "", completed=True, run_time=run_time, + used_credits=self._used_credits, iterations_used=iterations_used, ) diff --git a/aixplain/v2/rlm.py b/aixplain/v2/rlm.py index 2db27f08..b23b9dc5 100644 --- a/aixplain/v2/rlm.py +++ b/aixplain/v2/rlm.py @@ -46,13 +46,13 @@ The REPL environment is initialized with: 1. A `context` variable that contains extremely important information about your query. You should check the content of the `context` variable to understand what you are working with. Make sure you look through it sufficiently as you answer your query. -2. A `llm_query` function that allows you to query an LLM (that can handle around 500K chars) inside your REPL environment. +2. A `llm_query` function that allows you to query an LLM with a context window of {worker_context_window} inside your REPL environment. You must take this context window into consideration when deciding how much text to pass in each call. 3. The ability to use `print()` statements to view the output of your REPL code and continue your reasoning. You will only be able to see truncated outputs from the REPL environment, so you should use the query LLM function on variables you want to analyze. You will find this function especially useful when you have to analyze the semantics of the context. Use these variables as buffers to build up your final answer. Make sure to explicitly look through the entire context in REPL before answering your query. An example strategy is to first look at the context and figure out a chunking strategy, then break up the context into smart chunks, and query an LLM per chunk with a particular question and save the answers to a buffer, then query an LLM with all the buffers to produce your final answer. -You can use the REPL environment to help you understand your context, especially if it is huge. Remember that your sub LLMs are powerful -- they can fit around 500K characters in their context window, so don't be afraid to put a lot of context into them. +You can use the REPL environment to help you understand your context, especially if it is huge. Remember that your sub LLMs are powerful -- they have a context window of {worker_context_window}, so don't be afraid to put a lot of context into them. When you want to execute Python code in the REPL environment, wrap it in triple backticks with the 'repl' language identifier: ```repl @@ -96,8 +96,8 @@ # Prompt Helpers -def _build_system_messages() -> List[Dict[str, str]]: - return [{"role": "system", "content": _SYSTEM_PROMPT}] +def _build_system_messages(worker_context_window: str) -> List[Dict[str, str]]: + return [{"role": "system", "content": _SYSTEM_PROMPT.format(worker_context_window=worker_context_window)}] def _next_action_message(query: str, iteration: int, force_final: bool = False) -> Dict[str, str]: @@ -151,11 +151,14 @@ class RLMResult(Result): Attributes: iterations_used: Number of orchestrator iterations consumed. + used_credits: Total credits consumed across all orchestrator calls, + sandbox executions, and worker ``llm_query()`` invocations. repl_logs: Per-iteration REPL execution log (excluded from serialization; present only on live instances). """ iterations_used: int = field(default=0) + used_credits: float = field(default=0.0, metadata=dj_config(field_name="usedCredits")) repl_logs: List[Dict] = field( default_factory=list, repr=False, @@ -257,6 +260,13 @@ class RLM(BaseResource, ToolableMixin): metadata=dj_config(exclude=lambda x: True), init=False, ) + _used_credits: float = field( + default=0.0, + repr=False, + compare=False, + metadata=dj_config(exclude=lambda x: True), + init=False, + ) def __post_init__(self) -> None: """Auto-assign a UUID when no id is provided.""" @@ -343,6 +353,25 @@ def _get_sandbox(self) -> Any: logger.debug("RLM: sandbox tool resolved.") return self._sandbox_tool + # Worker Context Window + + def _get_worker_context_window(self) -> str: + """Return a human-readable description of the worker model's context window.""" + worker = self._get_worker() + attrs = getattr(worker, "attributes", None) or {} + raw = attrs.get("max_context_length", None) + if raw is not None: + try: + tokens = int(raw) + if tokens >= 1_000_000: + return f"{tokens / 1_000_000:.1f}M tokens" + if tokens >= 1_000: + return f"{tokens / 1_000:.0f}K tokens" + return f"{tokens} tokens" + except (ValueError, TypeError): + return str(raw) + return "a large context window" + # Sandbox Setup def _setup_repl(self, context: Union[str, dict, list]) -> None: @@ -470,7 +499,10 @@ def _setup_repl(self, context: Union[str, dict, list]) -> None: import time as __time import json as __json +_total_llm_query_credits = 0.0 + def llm_query(prompt): + global _total_llm_query_credits _headers = {{"x-api-key": "{self.context.api_key}", "Content-Type": "application/json"}} _payload = __json.dumps({{"data": prompt, "max_tokens": 8192}}) try: @@ -485,6 +517,7 @@ def llm_query(prompt): _r = __requests.get(_poll_url, headers=_headers, timeout=30) _result = _r.json() _wait = min(_wait * 1.1, 60) + _total_llm_query_credits += float(_result.get("usedCredits", 0) or 0) return str(_result.get("data", "Error: no data in worker response")) except Exception as _e: return f"Error: llm_query failed \u2014 {{_e}}" @@ -496,10 +529,12 @@ def llm_query(prompt): def _run_sandbox(self, sandbox: Any, code: str) -> Any: """Execute code in the sandbox and return the raw ToolResult.""" - return sandbox.run( + result = sandbox.run( data={"code": code, "sessionId": self._session_id}, action="run", ) + self._used_credits += float(getattr(result, "used_credits", 0) or 0) + return result def _execute_code(self, code: str) -> str: """Execute a code block in the sandbox and return formatted output. @@ -551,6 +586,23 @@ def _get_repl_variable(self, variable_name: str) -> Optional[str]: return None return stdout.strip() if stdout else None + # Credit Tracking + + def _collect_llm_query_credits(self) -> None: + """Retrieve accumulated ``llm_query`` worker credits from the sandbox. + + The injected ``llm_query`` function tracks per-call ``usedCredits`` + from the worker model API in a global ``_total_llm_query_credits`` + variable inside the sandbox session. This method reads that variable + and adds it to ``self._used_credits``. + """ + try: + raw = self._get_repl_variable("_total_llm_query_credits") + if raw is not None: + self._used_credits += float(raw) + except Exception: + logger.debug("RLM: could not retrieve llm_query credits from sandbox.") + # Orchestrator def _orchestrator_completion(self, messages: List[Dict[str, str]]) -> str: @@ -569,6 +621,7 @@ def _orchestrator_completion(self, messages: List[Dict[str, str]]) -> str: ResourceError: If the orchestrator model call fails or returns an error. """ response = self._get_orchestrator().run(text=_messages_to_prompt(messages), max_tokens=8192) + self._used_credits += float(getattr(response, "used_credits", 0) or 0) if response.completed or response.status == "SUCCESS": return str(response.data) raise ResourceError( @@ -617,6 +670,8 @@ def run( - ``data``: Final answer string. - ``status``: ``"SUCCESS"`` or ``"FAILED"``. - ``completed``: ``True``. + - ``used_credits``: Total credits consumed across all orchestrator + calls, sandbox executions, and worker ``llm_query()`` invocations. - ``iterations_used``: Number of orchestrator iterations consumed. - ``repl_logs``: Per-iteration execution log (not serialized). @@ -655,11 +710,12 @@ def run( iterations_used = 0 final_answer: Optional[str] = None repl_logs: List[Dict] = [] + self._used_credits = 0.0 # Resolve file-path context, initialise sandbox + conversation context = self._resolve_context(context) self._setup_repl(context) - self._messages = _build_system_messages() + self._messages = _build_system_messages(self._get_worker_context_window()) try: for iteration in range(self.max_iterations): @@ -715,6 +771,7 @@ def run( except Exception as exc: error_msg = f"RLM run error: {exc}" logger.error(error_msg) + self._collect_llm_query_credits() result = RLMResult( status="FAILED", completed=True, @@ -722,9 +779,11 @@ def run( data=None, ) result.iterations_used = iterations_used + result.used_credits = self._used_credits result.repl_logs = repl_logs return result + self._collect_llm_query_credits() run_time = time.time() - start_time logger.info(f"RLM '{name}': done in {iterations_used} iteration(s), {run_time:.1f}s.") @@ -734,6 +793,7 @@ def run( data=final_answer or "", ) result.iterations_used = iterations_used + result.used_credits = self._used_credits result.repl_logs = repl_logs result._raw_data = {"run_time": run_time} return result diff --git a/tests/unit/rlm_test.py b/tests/unit/rlm_test.py new file mode 100644 index 00000000..b40c9d7b --- /dev/null +++ b/tests/unit/rlm_test.py @@ -0,0 +1,441 @@ +"""Unit tests for RLM context resolution, sandbox setup, credit tracking, and context window.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from aixplain.v1.modules.model.rlm import RLM as RLMV1 +from aixplain.v2.rlm import RLM as RLMV2, RLMResult + + +# Parametrize over both implementations +RLM_IMPLS = [ + pytest.param(RLMV1, id="v1"), + pytest.param(RLMV2, id="v2"), +] + + +# _resolve_context +class TestResolveContext: + @pytest.mark.parametrize("RLM", RLM_IMPLS) + def test_local_text_file(self, RLM, tmp_path): + p = tmp_path / "doc.txt" + p.write_text("file content", encoding="utf-8") + assert RLM._resolve_context(str(p)) == "file content" + + @pytest.mark.parametrize("RLM", RLM_IMPLS) + def test_local_json_file(self, RLM, tmp_path): + data = {"a": 1} + p = tmp_path / "data.json" + p.write_text(json.dumps(data), encoding="utf-8") + assert RLM._resolve_context(str(p)) == data + + @pytest.mark.parametrize("RLM", RLM_IMPLS) + def test_pathlib_path(self, RLM, tmp_path): + p = tmp_path / "doc.txt" + p.write_text("pathlib content", encoding="utf-8") + assert RLM._resolve_context(p) == "pathlib content" + + @pytest.mark.parametrize("RLM", RLM_IMPLS) + def test_raw_string(self, RLM): + assert RLM._resolve_context("just raw text") == "just raw text" + + @pytest.mark.parametrize("RLM", RLM_IMPLS) + def test_dict_passthrough(self, RLM): + d = {"x": 1} + assert RLM._resolve_context(d) is d + + @pytest.mark.parametrize("RLM", RLM_IMPLS) + def test_list_passthrough(self, RLM): + lst = [1, 2, 3] + assert RLM._resolve_context(lst) is lst + + @pytest.mark.parametrize("RLM", RLM_IMPLS) + def test_non_string_fallback(self, RLM): + assert RLM._resolve_context(42) == "42" + + @pytest.mark.parametrize("RLM", RLM_IMPLS) + def test_http_url_passes_through_unchanged(self, RLM): + url = "http://example.com/doc.txt" + assert RLM._resolve_context(url) == url + + @pytest.mark.parametrize("RLM", RLM_IMPLS) + def test_https_url_passes_through_unchanged(self, RLM): + url = "https://example.com/data.json" + assert RLM._resolve_context(url) == url + + +# _setup_repl — URL branch +def _make_v1_rlm() -> RLMV1: + """Minimal v1 RLM with stubbed models.""" + rlm = RLMV1.__new__(RLMV1) + rlm.api_key = "test-key" + rlm.orchestrator = MagicMock() + rlm.worker = MagicMock() + rlm.worker.url = "https://models.aixplain.com/api/v2/execute" + rlm.worker.id = "worker-id" + rlm.worker.additional_info = {} + rlm._session_id = None + rlm._sandbox_tool = None + rlm._messages = [] + rlm._used_credits = 0.0 + return rlm + + +def _make_v2_rlm() -> RLMV2: + """Minimal v2 RLM with stubbed context client.""" + rlm = RLMV2.__new__(RLMV2) + rlm.orchestrator_id = "orch-id" + rlm.worker_id = "worker-id" + rlm.max_iterations = 10 + rlm.timeout = 600.0 + rlm._session_id = None + rlm._sandbox_tool = None + rlm._orchestrator = None + rlm._worker = None + rlm._messages = [] + rlm._used_credits = 0.0 + client = MagicMock() + client.backend_url = "https://platform-api.aixplain.com" + client.api_key = "test-key" + client.model_url = "https://models.aixplain.com/api/v2/execute" + rlm.context = client + return rlm + + +class TestSetupReplURLPath: + def test_v1_url_skips_file_factory(self): + rlm = _make_v1_rlm() + sandbox_mock = MagicMock() + + with ( + patch("aixplain.factories.tool_factory.ToolFactory") as mock_tf, + patch("aixplain.factories.file_factory.FileFactory") as mock_ff, + ): + mock_tf.get.return_value = sandbox_mock + rlm._setup_repl("https://example.com/doc.txt") + + mock_ff.create.assert_not_called() + + def test_v1_url_sandbox_code_contains_url(self): + rlm = _make_v1_rlm() + sandbox_mock = MagicMock() + captured = [] + + def capture_run(inputs, action): + captured.append(inputs["code"]) + return MagicMock(used_credits=0) + + sandbox_mock.run.side_effect = capture_run + + with patch("aixplain.factories.tool_factory.ToolFactory") as mock_tf: + mock_tf.get.return_value = sandbox_mock + rlm._setup_repl("https://example.com/doc.txt") + + context_code = captured[0] + assert "https://example.com/doc.txt" in context_code + assert "_content_type" in context_code + assert "_is_json" in context_code + assert "__json.load" in context_code + + def test_v2_url_skips_file_uploader(self): + rlm = _make_v2_rlm() + sandbox_mock = MagicMock() + rlm._sandbox_tool = sandbox_mock + + with patch("aixplain.v2.rlm.FileUploader") as mock_uploader: + rlm._setup_repl("https://example.com/doc.txt") + + mock_uploader.assert_not_called() + + def test_v2_url_sandbox_code_contains_url(self): + rlm = _make_v2_rlm() + sandbox_mock = MagicMock() + rlm._sandbox_tool = sandbox_mock + captured = [] + + def capture_run(data, action): + captured.append(data["code"]) + return MagicMock(used_credits=0) + + sandbox_mock.run.side_effect = capture_run + + rlm._setup_repl("https://example.com/doc.txt") + + context_code = captured[0] + assert "https://example.com/doc.txt" in context_code + assert "_content_type" in context_code + assert "_is_json" in context_code + assert "__json.load" in context_code + + +# Credit tracking +def _sandbox_result(stdout="", stderr="", used_credits=0.0): + """Create a mock sandbox result.""" + r = MagicMock() + r.data = {"stdout": stdout, "stderr": stderr} + r.used_credits = used_credits + return r + + +def _model_response_v1(data="response text", used_credits=0.0, completed=True, status="SUCCESS"): + """Create a mock v1 model response.""" + r = MagicMock() + r.data = data + r.used_credits = used_credits + r.get = lambda k, default=None: {"completed": completed, "data": data, "status": status, "error_message": ""}.get( + k, default + ) + r.__getitem__ = lambda self_, k: {"completed": completed, "data": data, "status": status}.get(k) + return r + + +class TestV1CreditTracking: + def test_orchestrator_credits_accumulated(self): + rlm = _make_v1_rlm() + rlm._used_credits = 0.0 + rlm.orchestrator.run.return_value = _model_response_v1(used_credits=0.05) + + rlm._orchestrator_completion([{"role": "user", "content": "test"}]) + + assert rlm._used_credits == pytest.approx(0.05) + + def test_sandbox_credits_accumulated(self): + rlm = _make_v1_rlm() + rlm._used_credits = 0.0 + rlm._sandbox_tool = MagicMock() + rlm._sandbox_tool.run.return_value = _sandbox_result(used_credits=0.01) + rlm._session_id = "test-session" + + rlm._run_sandbox("print('hello')") + + assert rlm._used_credits == pytest.approx(0.01) + + def test_execute_code_credits_accumulated(self): + rlm = _make_v1_rlm() + rlm._used_credits = 0.0 + rlm._sandbox_tool = MagicMock() + rlm._sandbox_tool.run.return_value = _sandbox_result(stdout="done", used_credits=0.02) + rlm._session_id = "test-session" + + output = rlm._execute_code("x = 1\nprint('done')") + + assert "done" in output + assert rlm._used_credits == pytest.approx(0.02) + + def test_collect_llm_query_credits(self): + rlm = _make_v1_rlm() + rlm._used_credits = 1.0 + rlm._sandbox_tool = MagicMock() + rlm._session_id = "test-session" + rlm._sandbox_tool.run.return_value = _sandbox_result(stdout="0.35", used_credits=0.0) + + rlm._collect_llm_query_credits() + + assert rlm._used_credits == pytest.approx(1.35) + + def test_multiple_calls_accumulate(self): + rlm = _make_v1_rlm() + rlm._used_credits = 0.0 + rlm._session_id = "test-session" + rlm._sandbox_tool = MagicMock() + + rlm.orchestrator.run.return_value = _model_response_v1(used_credits=0.1) + rlm._orchestrator_completion([{"role": "user", "content": "a"}]) + rlm._orchestrator_completion([{"role": "user", "content": "b"}]) + + rlm._sandbox_tool.run.return_value = _sandbox_result(stdout="ok", used_credits=0.05) + rlm._execute_code("pass") + rlm._execute_code("pass") + + assert rlm._used_credits == pytest.approx(0.3) + + +class TestV2CreditTracking: + def test_orchestrator_credits_accumulated(self): + rlm = _make_v2_rlm() + rlm._used_credits = 0.0 + mock_model = MagicMock() + resp = MagicMock() + resp.completed = True + resp.status = "SUCCESS" + resp.data = "answer" + resp.used_credits = 0.07 + mock_model.run.return_value = resp + rlm._orchestrator = mock_model + + rlm._orchestrator_completion([{"role": "user", "content": "test"}]) + + assert rlm._used_credits == pytest.approx(0.07) + + def test_sandbox_credits_accumulated(self): + rlm = _make_v2_rlm() + rlm._used_credits = 0.0 + sandbox = MagicMock() + sandbox.run.return_value = _sandbox_result(used_credits=0.03) + rlm._session_id = "test-session" + + rlm._run_sandbox(sandbox, "print('hi')") + + assert rlm._used_credits == pytest.approx(0.03) + + def test_execute_code_credits_accumulated(self): + rlm = _make_v2_rlm() + rlm._used_credits = 0.0 + sandbox = MagicMock() + sandbox.run.return_value = _sandbox_result(stdout="done", used_credits=0.04) + rlm._sandbox_tool = sandbox + rlm._session_id = "test-session" + + output = rlm._execute_code("print('done')") + + assert "done" in output + assert rlm._used_credits == pytest.approx(0.04) + + def test_collect_llm_query_credits(self): + rlm = _make_v2_rlm() + rlm._used_credits = 2.0 + sandbox = MagicMock() + sandbox.run.return_value = _sandbox_result(stdout="0.50", used_credits=0.0) + rlm._sandbox_tool = sandbox + rlm._session_id = "test-session" + + rlm._collect_llm_query_credits() + + assert rlm._used_credits == pytest.approx(2.50) + + def test_used_credits_field_on_rlm_result(self): + result = RLMResult(status="SUCCESS", completed=True, data="answer") + result.used_credits = 1.23 + result.iterations_used = 5 + + assert result.used_credits == pytest.approx(1.23) + serialized = result.to_dict() + assert serialized["usedCredits"] == pytest.approx(1.23) + + +class TestLlmQueryCodeCreditsTracking: + def test_v1_llm_query_code_accumulates_credits(self): + rlm = _make_v1_rlm() + sandbox_mock = MagicMock() + captured = [] + + def capture_run(inputs, action): + captured.append(inputs["code"]) + return MagicMock(used_credits=0) + + sandbox_mock.run.side_effect = capture_run + + with ( + patch("aixplain.factories.tool_factory.ToolFactory") as mock_tf, + patch("aixplain.factories.file_factory.FileFactory") as mock_ff, + ): + mock_tf.get.return_value = sandbox_mock + mock_ff.create.return_value = "https://storage.example.com/ctx.txt" + rlm._setup_repl("raw text context") + + llm_query_code = captured[-1] + assert "_total_llm_query_credits" in llm_query_code + assert "global _total_llm_query_credits" in llm_query_code + assert "usedCredits" in llm_query_code + + def test_v2_llm_query_code_accumulates_credits(self): + rlm = _make_v2_rlm() + sandbox_mock = MagicMock() + captured = [] + + def capture_run(data, action): + captured.append(data["code"]) + return MagicMock(used_credits=0) + + sandbox_mock.run.side_effect = capture_run + rlm._sandbox_tool = sandbox_mock + + with patch("aixplain.v2.rlm.FileUploader") as mock_uploader: + uploader_instance = MagicMock() + uploader_instance.upload.return_value = "https://storage.example.com/ctx.txt" + mock_uploader.return_value = uploader_instance + rlm._setup_repl("raw text context") + + llm_query_code = captured[-1] + assert "_total_llm_query_credits" in llm_query_code + assert "global _total_llm_query_credits" in llm_query_code + assert "usedCredits" in llm_query_code + + +# Worker context window +class TestV1WorkerContextWindow: + def test_returns_formatted_k_tokens(self): + rlm = _make_v1_rlm() + rlm.worker.additional_info = {"attributes": [{"name": "max_context_length", "code": "128000"}]} + assert rlm._get_worker_context_window() == "128K tokens" + + def test_returns_formatted_m_tokens(self): + rlm = _make_v1_rlm() + rlm.worker.additional_info = {"attributes": [{"name": "max_context_length", "code": "1048576"}]} + assert rlm._get_worker_context_window() == "1.0M tokens" + + def test_returns_small_token_count(self): + rlm = _make_v1_rlm() + rlm.worker.additional_info = {"attributes": [{"name": "max_context_length", "code": "512"}]} + assert rlm._get_worker_context_window() == "512 tokens" + + def test_fallback_when_no_attributes(self): + rlm = _make_v1_rlm() + rlm.worker.additional_info = {} + assert rlm._get_worker_context_window() == "a large context window" + + def test_fallback_when_attribute_missing(self): + rlm = _make_v1_rlm() + rlm.worker.additional_info = {"attributes": [{"name": "other_attr", "code": "100"}]} + assert rlm._get_worker_context_window() == "a large context window" + + def test_non_numeric_returns_raw_string(self): + rlm = _make_v1_rlm() + rlm.worker.additional_info = {"attributes": [{"name": "max_context_length", "code": "unlimited"}]} + assert rlm._get_worker_context_window() == "unlimited" + + +class TestV2WorkerContextWindow: + def test_returns_formatted_k_tokens(self): + rlm = _make_v2_rlm() + mock_worker = MagicMock() + mock_worker.attributes = {"max_context_length": "200000"} + rlm._worker = mock_worker + assert rlm._get_worker_context_window() == "200K tokens" + + def test_returns_formatted_m_tokens(self): + rlm = _make_v2_rlm() + mock_worker = MagicMock() + mock_worker.attributes = {"max_context_length": "2000000"} + rlm._worker = mock_worker + assert rlm._get_worker_context_window() == "2.0M tokens" + + def test_fallback_when_no_attributes(self): + rlm = _make_v2_rlm() + mock_worker = MagicMock() + mock_worker.attributes = {} + rlm._worker = mock_worker + assert rlm._get_worker_context_window() == "a large context window" + + def test_fallback_when_attributes_none(self): + rlm = _make_v2_rlm() + mock_worker = MagicMock() + mock_worker.attributes = None + rlm._worker = mock_worker + assert rlm._get_worker_context_window() == "a large context window" + + def test_non_numeric_returns_raw_string(self): + rlm = _make_v2_rlm() + mock_worker = MagicMock() + mock_worker.attributes = {"max_context_length": "very_large"} + rlm._worker = mock_worker + assert rlm._get_worker_context_window() == "very_large" + + def test_integer_attribute_value(self): + rlm = _make_v2_rlm() + mock_worker = MagicMock() + mock_worker.attributes = {"max_context_length": 32000} + rlm._worker = mock_worker + assert rlm._get_worker_context_window() == "32K tokens" From cae54cf935bf33c41e6c0f007cba7a6e0dfa81a6 Mon Sep 17 00:00:00 2001 From: Abdelrahman Elsheikh Date: Thu, 9 Apr 2026 15:17:15 +0200 Subject: [PATCH 2/4] Escape literal braces in the system prompt --- aixplain/v1/modules/model/rlm.py | 2 +- aixplain/v2/rlm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aixplain/v1/modules/model/rlm.py b/aixplain/v1/modules/model/rlm.py index eb53a004..286c6951 100644 --- a/aixplain/v1/modules/model/rlm.py +++ b/aixplain/v1/modules/model/rlm.py @@ -69,7 +69,7 @@ ```repl # your Python code here chunk = context[:10000] -answer = llm_query(f"What is the key finding in this text?\\n{chunk}") +answer = llm_query(f"What is the key finding in this text?\\n{{chunk}}") print(answer) ``` diff --git a/aixplain/v2/rlm.py b/aixplain/v2/rlm.py index b23b9dc5..05c76707 100644 --- a/aixplain/v2/rlm.py +++ b/aixplain/v2/rlm.py @@ -58,7 +58,7 @@ ```repl # your Python code here chunk = context[:10000] -answer = llm_query(f"What is the key finding in this text?\\n{chunk}") +answer = llm_query(f"What is the key finding in this text?\\n{{chunk}}") print(answer) ``` From 051b0214e7b01466763cf168907f064c693cb9a7 Mon Sep 17 00:00:00 2001 From: Abdelrahman Elsheikh Date: Thu, 9 Apr 2026 16:48:55 +0200 Subject: [PATCH 3/4] Improve RLM sys prompt --- aixplain/v1/modules/model/rlm.py | 4 +++- aixplain/v2/rlm.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/aixplain/v1/modules/model/rlm.py b/aixplain/v1/modules/model/rlm.py index 286c6951..45b9175c 100644 --- a/aixplain/v1/modules/model/rlm.py +++ b/aixplain/v1/modules/model/rlm.py @@ -74,9 +74,11 @@ ``` IMPORTANT: When you are done with the iterative process, you MUST provide a final answer using one of these two forms (NOT inside a code block): -1. FINAL(your final answer here) — to provide the answer as literal text +1. FINAL(your final answer here) — to provide the answer as literal text. Use `FINAL(...)` only when you are completely finished and will make no further REPL calls, request no further inspection, and need no additional intermediate outputs. 2. FINAL_VAR(variable_name) — to return a variable you created in the REPL as your final answer +Do not use `FINAL(...)` or `FINAL_VAR(...)` for intermediate status updates, plans, requests to inspect REPL output, or statements such as needing more information; those must be written as normal assistant text instead. + Think step by step carefully, plan, and execute this plan immediately — do not just say what you will do. """ diff --git a/aixplain/v2/rlm.py b/aixplain/v2/rlm.py index 05c76707..b1962084 100644 --- a/aixplain/v2/rlm.py +++ b/aixplain/v2/rlm.py @@ -63,9 +63,11 @@ ``` IMPORTANT: When you are done with the iterative process, you MUST provide a final answer using one of these two forms (NOT inside a code block): -1. FINAL(your final answer here) — to provide the answer as literal text +1. FINAL(your final answer here) — to provide the answer as literal text. Use `FINAL(...)` only when you are completely finished and will make no further REPL calls, request no further inspection, and need no additional intermediate outputs. 2. FINAL_VAR(variable_name) — to return a variable you created in the REPL as your final answer +Do not use `FINAL(...)` or `FINAL_VAR(...)` for intermediate status updates, plans, requests to inspect REPL output, or statements such as needing more information; those must be written as normal assistant text instead. + Think step by step carefully, plan, and execute this plan immediately — do not just say what you will do. """ From d9001c02d09ba0808c8a74bc125c635a4c822257 Mon Sep 17 00:00:00 2001 From: Abdelrahman Elsheikh Date: Thu, 9 Apr 2026 17:10:51 +0200 Subject: [PATCH 4/4] Optimize the system prompt of RLM --- aixplain/v1/modules/model/rlm.py | 6 +++--- aixplain/v2/rlm.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/aixplain/v1/modules/model/rlm.py b/aixplain/v1/modules/model/rlm.py index 45b9175c..f7262b29 100644 --- a/aixplain/v1/modules/model/rlm.py +++ b/aixplain/v1/modules/model/rlm.py @@ -74,10 +74,10 @@ ``` IMPORTANT: When you are done with the iterative process, you MUST provide a final answer using one of these two forms (NOT inside a code block): -1. FINAL(your final answer here) — to provide the answer as literal text. Use `FINAL(...)` only when you are completely finished and will make no further REPL calls, request no further inspection, and need no additional intermediate outputs. -2. FINAL_VAR(variable_name) — to return a variable you created in the REPL as your final answer +1. FINAL(your final answer here) — to provide the answer as literal text. Use `FINAL(...)` only when you are completely finished: you will make no further REPL calls, need no further inspection of REPL output, and are not including any REPL code in the same response. +2. FINAL_VAR(variable_name) — to return a variable you created in the REPL as your final answer. Use `FINAL_VAR(...)` only when that variable already contains your completed final answer and you will make no further REPL calls. -Do not use `FINAL(...)` or `FINAL_VAR(...)` for intermediate status updates, plans, requests to inspect REPL output, or statements such as needing more information; those must be written as normal assistant text instead. +Do not use `FINAL(...)` or `FINAL_VAR(...)` for intermediate status updates, plans, requests to inspect REPL output, statements such as needing more information, or any response that also includes REPL code to be executed first; those must be written as normal assistant text instead. Think step by step carefully, plan, and execute this plan immediately — do not just say what you will do. """ diff --git a/aixplain/v2/rlm.py b/aixplain/v2/rlm.py index b1962084..b12c7ee2 100644 --- a/aixplain/v2/rlm.py +++ b/aixplain/v2/rlm.py @@ -63,10 +63,10 @@ ``` IMPORTANT: When you are done with the iterative process, you MUST provide a final answer using one of these two forms (NOT inside a code block): -1. FINAL(your final answer here) — to provide the answer as literal text. Use `FINAL(...)` only when you are completely finished and will make no further REPL calls, request no further inspection, and need no additional intermediate outputs. -2. FINAL_VAR(variable_name) — to return a variable you created in the REPL as your final answer +1. FINAL(your final answer here) — to provide the answer as literal text. Use `FINAL(...)` only when you are completely finished: you will make no further REPL calls, need no further inspection of REPL output, and are not including any REPL code in the same response. +2. FINAL_VAR(variable_name) — to return a variable you created in the REPL as your final answer. Use `FINAL_VAR(...)` only when that variable already contains your completed final answer and you will make no further REPL calls. -Do not use `FINAL(...)` or `FINAL_VAR(...)` for intermediate status updates, plans, requests to inspect REPL output, or statements such as needing more information; those must be written as normal assistant text instead. +Do not use `FINAL(...)` or `FINAL_VAR(...)` for intermediate status updates, plans, requests to inspect REPL output, statements such as needing more information, or any response that also includes REPL code to be executed first; those must be written as normal assistant text instead. Think step by step carefully, plan, and execute this plan immediately — do not just say what you will do. """