From a1ce24b7fce60e245b1366003805bd37b2fd95a1 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Wed, 6 May 2026 21:55:32 +0200
Subject: [PATCH 1/9] Resolve bare Python job scripts

Co-authored-by: OpenAI Codex
---
 agent/tools/sandbox_tool.py                  | 18 ++--
 tests/unit/test_sandbox_script_resolution.py | 60 ++++++++++++++++++++
 2 files changed, 72 insertions(+), 6 deletions(-)
 create mode 100644 tests/unit/test_sandbox_script_resolution.py

diff --git a/agent/tools/sandbox_tool.py b/agent/tools/sandbox_tool.py
index 4d643f4c..6f7041d6 100644
--- a/agent/tools/sandbox_tool.py
+++ b/agent/tools/sandbox_tool.py
@@ -60,15 +60,21 @@ def _get_sandbox_create_lock(owner: str) -> asyncio.Lock:
 
 def _looks_like_path(script: str) -> bool:
     """Return True if the script string looks like a file path (not inline code)."""
-    return (
+    if not (
         isinstance(script, str)
         and script.strip() == script
         and not any(c in script for c in "\r\n\0")
-        and (
-            script.startswith("/")
-            or script.startswith("./")
-            or script.startswith("../")
-        )
+    ):
+        return False
+
+    if script.startswith("http://") or script.startswith("https://"):
+        return False
+
+    return (
+        script.startswith("/")
+        or script.startswith("./")
+        or script.startswith("../")
+        or (script.endswith(".py") and not any(c.isspace() for c in script))
     )
diff --git a/tests/unit/test_sandbox_script_resolution.py b/tests/unit/test_sandbox_script_resolution.py
new file mode 100644
index 00000000..84cb5385
--- /dev/null
+++ b/tests/unit/test_sandbox_script_resolution.py
@@ -0,0 +1,60 @@
+from types import SimpleNamespace
+
+import pytest
+
+from agent.tools.sandbox_tool import resolve_sandbox_script
+
+
+class FakeSandbox:
+    def __init__(self):
+        self.read_paths = []
+
+    def read(self, path, *, limit):
+        self.read_paths.append((path, limit))
+        return SimpleNamespace(
+            success=True,
+            output="1\tprint('training')\n2\tprint('done')",
+            error="",
+        )
+
+
+@pytest.mark.asyncio
+async def test_resolve_sandbox_script_accepts_bare_python_filename():
+    sandbox = FakeSandbox()
+
+    content, error = await resolve_sandbox_script(sandbox, "train_smollm2.py")
+
+    assert error is None
+    assert content == "print('training')\nprint('done')"
+    assert sandbox.read_paths == [("train_smollm2.py", 100_000)]
+
+
+@pytest.mark.asyncio
+async def test_resolve_sandbox_script_accepts_relative_python_path():
+    sandbox = FakeSandbox()
+
+    content, error = await resolve_sandbox_script(sandbox, "scripts/train.py")
+
+    assert error is None
+    assert content == "print('training')\nprint('done')"
+    assert sandbox.read_paths == [("scripts/train.py", 100_000)]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "script",
+    [
+        "https://example.com/train.py",
+        "http://example.com/train.py",
+        "train_smollm2.py --epochs 1",
+        "print('hello')",
+    ],
+)
+async def test_resolve_sandbox_script_ignores_non_path_scripts(script):
+    sandbox = FakeSandbox()
+
+    content, error = await resolve_sandbox_script(sandbox, script)
+
+    assert content is None
+    assert error is None
+    assert sandbox.read_paths == []

From 4be6975b297c0302bce8aeae5ea05a5a8066eb29 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Thu, 7 May 2026 10:58:14 +0200
Subject: [PATCH 2/9] Update HF Jobs script description

Co-authored-by: OpenAI Codex
---
 agent/tools/jobs_tool.py                     | 2 +-
 tests/unit/test_sandbox_script_resolution.py | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/agent/tools/jobs_tool.py b/agent/tools/jobs_tool.py
index c058921a..05fb37ff 100644
--- a/agent/tools/jobs_tool.py
+++ b/agent/tools/jobs_tool.py
@@ -1156,7 +1156,7 @@ async def _resume_scheduled_job(self, args: Dict[str, Any]) -> ToolResult:
             "script": {
                 "type": "string",
                 "description": (
-                    "Python code or sandbox file path (e.g. '/app/train.py') or URL. "
+                    "Python code, sandbox file path (e.g. '/app/train.py', './train.py', or bare 'train.py'), or URL. "
                     "Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. "
                     "Mutually exclusive with 'command'."
                 ),
diff --git a/tests/unit/test_sandbox_script_resolution.py b/tests/unit/test_sandbox_script_resolution.py
index 84cb5385..3102d5a3 100644
--- a/tests/unit/test_sandbox_script_resolution.py
+++ b/tests/unit/test_sandbox_script_resolution.py
@@ -2,6 +2,7 @@
 
 import pytest
 
+from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC
 from agent.tools.sandbox_tool import resolve_sandbox_script
 
 
@@ -58,3 +59,11 @@ async def test_resolve_sandbox_script_ignores_non_path_scripts(script):
     assert content is None
     assert error is None
     assert sandbox.read_paths == []
+
+
+def test_hf_jobs_script_description_mentions_bare_python_filenames():
+    script_description = HF_JOBS_TOOL_SPEC["parameters"]["properties"]["script"][
+        "description"
+    ]
+
+    assert "bare 'train.py'" in script_description

From 7b4687bc8847d72dc981cedf7dc9590bf50d4d59 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Fri, 8 May 2026 09:59:04 +0200
Subject: [PATCH 3/9] Clarify HF Jobs script path prompt

Co-authored-by: OpenAI Codex
---
 agent/prompts/system_prompt_v3.yaml   | 6 ++++++
 tests/unit/test_sandbox_auto_start.py | 9 +++++++++
 2 files changed, 15 insertions(+)

diff --git a/agent/prompts/system_prompt_v3.yaml b/agent/prompts/system_prompt_v3.yaml
index 4543048f..5de61a86 100644
--- a/agent/prompts/system_prompt_v3.yaml
+++ b/agent/prompts/system_prompt_v3.yaml
@@ -102,6 +102,12 @@ system_prompt: |
 
   # When submitting a training job
 
+  Never pass a local machine path to hf_jobs.script, such as /Users/..., /home/..., /fsx/..., or a repo checkout path. HF Jobs runs in a fresh cloud environment where local files do not exist. For hf_jobs.script, use exactly one of:
+  - inline Python source code
+  - a file already written in the session sandbox, e.g. /app/train.py, ./train.py, or train.py
+  - a public/raw URL
+  If you wrote or tested a script locally, read the file content and submit it inline, or write it into the sandbox first.
+
   Before calling hf_jobs, output a pre-flight check:
   - Reference implementation: [which example you based this on]
   - Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
diff --git a/tests/unit/test_sandbox_auto_start.py b/tests/unit/test_sandbox_auto_start.py
index 1ad27fca..4cf67435 100644
--- a/tests/unit/test_sandbox_auto_start.py
+++ b/tests/unit/test_sandbox_auto_start.py
@@ -34,3 +34,12 @@ def test_prompt_and_tool_specs_do_not_require_cpu_sandbox_create():
         in tool_specs["sandbox_create"]
     )
     assert "started automatically for normal CPU work" in tool_specs["bash"]
+
+
+def test_prompt_rejects_local_machine_paths_for_hf_jobs_scripts():
+    prompt = Path("agent/prompts/system_prompt_v3.yaml").read_text()
+
+    assert "Never pass a local machine path to hf_jobs.script" in prompt
+    assert "/fsx/..." in prompt
+    assert "inline Python source code" in prompt
+    assert "a file already written in the session sandbox" in prompt

From e63b945171a47c1766ebfb22dda66a01ca4fadca Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Fri, 8 May 2026 10:31:25 +0200
Subject: [PATCH 4/9] Add CLI sandbox tool runtime

Co-authored-by: OpenAI Codex
---
 agent/config.py                       |   3 +-
 agent/core/agent_loop.py              |  10 +-
 agent/core/session.py                 |   1 +
 agent/main.py                         |  60 +++++++--
 agent/utils/terminal_display.py       |   7 +-
 configs/cli_agent_config.json         |   1 +
 tests/unit/test_cli_rendering.py      | 187 +++++++++++++++++++++++++-
 tests/unit/test_config.py             |  35 +++++
 tests/unit/test_sandbox_auto_start.py |  91 +++++++++++
 9 files changed, 377 insertions(+), 18 deletions(-)

diff --git a/agent/config.py b/agent/config.py
index 35b095c3..5ad8bd8a 100644
--- a/agent/config.py
+++ b/agent/config.py
@@ -2,7 +2,7 @@
 import os
 import re
 from pathlib import Path
-from typing import Any, Union
+from typing import Any, Literal, Union
 
 from dotenv import load_dotenv
 from fastmcp.mcp_config import (
@@ -46,6 +46,7 @@ class Config(BaseModel):
     # Permission control parameters
     confirm_cpu_jobs: bool = True
     auto_file_upload: bool = False
+    tool_runtime: Literal["local", "sandbox"] = "local"
 
     # Reasoning effort *preference* — the ceiling the user wants. The probe
     # on `/model` walks a cascade down from here (``max`` → ``xhigh`` → ``high``
diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py
index 0eaa6e9d..7c75e66f 100644
--- a/agent/core/agent_loop.py
+++ b/agent/core/agent_loop.py
@@ -32,7 +32,11 @@
 from agent.core.session import Event, OpType, Session
 from agent.core.tools import ToolRouter
 from agent.tools.jobs_tool import CPU_FLAVORS
-from agent.tools.sandbox_tool import DEFAULT_CPU_SANDBOX_HARDWARE
+from agent.tools.sandbox_tool import (
+    DEFAULT_CPU_SANDBOX_HARDWARE,
+    start_cpu_sandbox_preload,
+    teardown_session_sandbox,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -1926,6 +1930,8 @@ async def shutdown(session: Session) -> bool:
             _ = session.save_and_upload_detached(repo_id)
 
     session.is_running = False
+    if not getattr(session, "local_mode", False):
+        await teardown_session_sandbox(session)
     await session.send_event(Event(event_type="shutdown"))
     return True
 
@@ -1999,6 +2005,8 @@ async def submission_loop(
     )
     if session_holder is not None:
         session_holder[0] = session
+    if not local_mode:
+        start_cpu_sandbox_preload(session)
     start_session_artifact_collection_task(session, token=hf_token)
 
     logger.info("Agent loop started")
diff --git a/agent/core/session.py b/agent/core/session.py
index fb08c75f..b435b671 100644
--- a/agent/core/session.py
+++ b/agent/core/session.py
@@ -96,6 +96,7 @@ def __init__(
         self.hf_token: Optional[str] = hf_token
         self.user_id: Optional[str] = user_id
         self.hf_username: Optional[str] = hf_username
+        self.local_mode = local_mode
         self.persistence_store = persistence_store
         self.tool_router = tool_router
         self.stream = stream
diff --git a/agent/main.py b/agent/main.py
index a7262707..059cc3b7 100644
--- a/agent/main.py
+++ b/agent/main.py
@@ -57,6 +57,20 @@
 CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.json"
 
 
+def _apply_tool_runtime_override(config: Any, *, sandbox_tools: bool) -> str:
+    if sandbox_tools:
+        config.tool_runtime = "sandbox"
+    return getattr(config, "tool_runtime", "local")
+
+
+def _is_local_tool_runtime(config: Any) -> bool:
+    return getattr(config, "tool_runtime", "local") == "local"
+
+
+def _tool_runtime_label(local_mode: bool) -> str:
+    return "local filesystem" if local_mode else "HF sandbox"
+
+
 def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
     if tool_info.get("tool") != "hf_jobs":
         return False
@@ -840,6 +854,7 @@ async def _handle_slash_command(
         session = session_holder[0] if session_holder else None
         print(f"Model: {config.model_name}")
         print(f"Reasoning effort: {config.reasoning_effort or 'off'}")
+        print(f"Tool runtime: {_tool_runtime_label(_is_local_tool_runtime(config))}")
         if session:
             print(f"Turns: {session.turn_count}")
             print(f"Context items: {len(session.context_manager.items)}")
@@ -959,7 +974,7 @@ async def _handle_share_traces_command(arg: str, config, session) -> None:
     console.print(f"[green]Dataset is now {label}.[/green] {url}")
 
 
-async def main(model: str | None = None):
+async def main(model: str | None = None, sandbox_tools: bool = False):
     """Interactive chat with the agent"""
 
     # Clear screen
@@ -971,16 +986,23 @@
     config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
     if model:
         config.model_name = model
+    _apply_tool_runtime_override(config, sandbox_tools=sandbox_tools)
+    local_mode = _is_local_tool_runtime(config)
 
-    # HF token — required for Hub-backed models/tools, but not for local LLMs.
+    # HF token — required for Hub-backed models/tools and sandbox tools, but
+    # not for local LLMs using only local filesystem tools.
     hf_token = resolve_hf_token()
-    if not hf_token and not is_local_model_id(config.model_name):
+    if not hf_token and (not is_local_model_id(config.model_name) or not local_mode):
         hf_token = await _prompt_and_save_hf_token(prompt_session)
 
     # Resolve username for banner
     hf_user = _get_hf_user(hf_token)
 
-    print_banner(model=config.model_name, hf_user=hf_user)
+    print_banner(
+        model=config.model_name,
+        hf_user=hf_user,
+        tool_runtime=_tool_runtime_label(local_mode),
+    )
 
     # Pre-warm the HF router catalog in the background so /model switches
     # don't block on a network fetch.
@@ -999,8 +1021,10 @@
     notification_gateway = NotificationGateway(config.messaging)
     await notification_gateway.start()
 
-    # Create tool router with local mode
-    tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
+    # Create tool router with the selected CLI tool runtime.
+    tool_router = ToolRouter(
+        config.mcpServers, hf_token=hf_token, local_mode=local_mode
+    )
 
     # Session holder for interrupt/model/status access
     session_holder = [None]
@@ -1014,7 +1038,7 @@
             session_holder=session_holder,
             hf_token=hf_token,
             user_id=hf_user,
-            local_mode=True,
+            local_mode=local_mode,
             stream=True,
             notification_gateway=notification_gateway,
             notification_destinations=config.messaging.default_auto_destinations(),
@@ -1192,6 +1216,7 @@ async def headless_main(
     model: str | None = None,
     max_iterations: int | None = None,
     stream: bool = True,
+    sandbox_tools: bool = False,
 ) -> None:
     """Run a single prompt headlessly and exit."""
     import logging
@@ -1204,11 +1229,13 @@
     if model:
         config.model_name = model
+    _apply_tool_runtime_override(config, sandbox_tools=sandbox_tools)
+    local_mode = _is_local_tool_runtime(config)
 
     hf_token = resolve_hf_token()
-    if not hf_token and not is_local_model_id(config.model_name):
+    if not hf_token and (not is_local_model_id(config.model_name) or not local_mode):
         print(
-            "ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.",
+            "ERROR: No HF token found. Set HF_TOKEN or run `hf auth login`.",
             file=sys.stderr,
         )
         sys.exit(1)
@@ -1224,6 +1251,7 @@
     config.max_iterations = max_iterations
 
     print(f"Model: {config.model_name}", file=sys.stderr)
+    print(f"Tool runtime: {_tool_runtime_label(local_mode)}", file=sys.stderr)
     print(f"Max iterations: {config.max_iterations}", file=sys.stderr)
     print(f"Prompt: {prompt}", file=sys.stderr)
     print("---", file=sys.stderr)
@@ -1231,7 +1259,9 @@
     submission_queue: asyncio.Queue = asyncio.Queue()
     event_queue: asyncio.Queue = asyncio.Queue()
 
-    tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
+    tool_router = ToolRouter(
+        config.mcpServers, hf_token=hf_token, local_mode=local_mode
+    )
 
     session_holder: list = [None]
     agent_task = asyncio.create_task(
@@ -1243,7 +1273,7 @@
             session_holder=session_holder,
             hf_token=hf_token,
             user_id=hf_user,
-            local_mode=True,
+            local_mode=local_mode,
             stream=stream,
             notification_gateway=notification_gateway,
             notification_destinations=config.messaging.default_auto_destinations(),
@@ -1438,6 +1468,11 @@ def cli():
         action="store_true",
         help="Disable token streaming (use non-streaming LLM calls)",
     )
+    parser.add_argument(
+        "--sandbox-tools",
+        action="store_true",
+        help="Use HF Space sandbox tools instead of local filesystem tools",
+    )
 
     args = parser.parse_args()
     try:
@@ -1451,10 +1486,11 @@
                     model=args.model,
                     max_iterations=max_iter,
                     stream=not args.no_stream,
+                    sandbox_tools=args.sandbox_tools,
                 )
             )
         else:
-            asyncio.run(main(model=args.model))
+            asyncio.run(main(model=args.model, sandbox_tools=args.sandbox_tools))
     except KeyboardInterrupt:
         print("\n\nGoodbye!")
diff --git a/agent/utils/terminal_display.py b/agent/utils/terminal_display.py
index d464fd8a..e678ecb9 100644
--- a/agent/utils/terminal_display.py
+++ b/agent/utils/terminal_display.py
@@ -92,7 +92,11 @@ def get_console() -> Console:
 
 
 # ── Banner ─────────────────────────────────────────────────────────────
-def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
+def print_banner(
+    model: str | None = None,
+    hf_user: str | None = None,
+    tool_runtime: str | None = None,
+) -> None:
     """Print particle logo then CRT boot sequence with system info."""
     from agent.utils.particle_logo import run_particle_logo
     from agent.utils.crt_boot import run_boot_sequence
@@ -115,6 +119,7 @@ def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
         (f"{_I}Initializing agent runtime...", gold),
         (f"{_I} User: {user_label}", dim_gold),
         (f"{_I} Model: {model_label}", dim_gold),
+        (f"{_I} Tool runtime: {tool_runtime or 'local filesystem'}", dim_gold),
         (f"{_I} Tools: loading...", dim_gold),
         ("", ""),
         (f"{_I}/help for commands · /model to switch · /quit to exit", gold),
diff --git a/configs/cli_agent_config.json b/configs/cli_agent_config.json
index ed247998..48ba4dd8 100644
--- a/configs/cli_agent_config.json
+++ b/configs/cli_agent_config.json
@@ -7,6 +7,7 @@
   "yolo_mode": false,
   "confirm_cpu_jobs": true,
   "auto_file_upload": true,
+  "tool_runtime": "local",
   "messaging": {
     "enabled": false,
     "auto_event_types": ["approval_required", "error", "turn_complete"],
diff --git a/tests/unit/test_cli_rendering.py b/tests/unit/test_cli_rendering.py
index e94700bf..ab68c920 100644
--- a/tests/unit/test_cli_rendering.py
+++ b/tests/unit/test_cli_rendering.py
@@ -52,10 +52,11 @@ def _unexpected_future(*args, **kwargs):
 
 
 def test_cli_forwards_model_flag_to_interactive_main(monkeypatch):
-    seen: dict[str, str | None] = {}
+    seen: dict[str, object] = {}
 
-    async def fake_main(*, model=None):
+    async def fake_main(*, model=None, sandbox_tools=False):
         seen["model"] = model
+        seen["sandbox_tools"] = sandbox_tools
 
     monkeypatch.setattr(sys, "argv", ["ml-intern", "--model", "openai/gpt-5.5"])
     monkeypatch.setattr(main_mod, "main", fake_main)
 
     main_mod.cli()
 
     assert seen["model"] == "openai/gpt-5.5"
+    assert seen["sandbox_tools"] is False
+
+
+def test_cli_forwards_sandbox_flag_to_interactive_main(monkeypatch):
+    seen: dict[str, object] = {}
+
+    async def fake_main(*, model=None, sandbox_tools=False):
+        seen["model"] = model
+        seen["sandbox_tools"] = sandbox_tools
+
+    monkeypatch.setattr(sys, "argv", ["ml-intern", "--sandbox-tools"])
+    monkeypatch.setattr(main_mod, "main", fake_main)
+
+    main_mod.cli()
+
+    assert seen == {"model": None, "sandbox_tools": True}
+
+
+def test_cli_forwards_sandbox_flag_to_headless_main(monkeypatch):
+    seen: dict[str, object] = {}
+
+    async def fake_headless_main(
+        prompt,
+        *,
+        model=None,
+        max_iterations=None,
+        stream=True,
+        sandbox_tools=False,
+    ):
+        seen.update(
+            {
+                "prompt": prompt,
+                "model": model,
+                "max_iterations": max_iterations,
+                "stream": stream,
+                "sandbox_tools": sandbox_tools,
+            }
+        )
+
+    monkeypatch.setattr(
+        sys,
+        "argv",
+        ["ml-intern", "--sandbox-tools", "--no-stream", "train a model"],
+    )
+    monkeypatch.setattr(main_mod, "headless_main", fake_headless_main)
+
+    main_mod.cli()
+
+    assert seen == {
+        "prompt": "train a model",
+        "model": None,
+        "max_iterations": None,
+        "stream": False,
+        "sandbox_tools": True,
+    }
 
 
 @pytest.mark.asyncio
@@ -70,9 +126,10 @@ async def test_interactive_main_applies_model_override_before_banner(monkeypatch):
     class StopAfterBanner(Exception):
         pass
 
-    def fake_banner(*, model=None, hf_user=None):
+    def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
         assert model == "openai/gpt-5.5"
         assert hf_user == "tester"
+        assert tool_runtime == "local filesystem"
         raise StopAfterBanner
 
     monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
@@ -85,9 +142,133 @@ def fake_banner(*, model=None, hf_user=None):
     monkeypatch.setattr(
         main_mod,
         "load_config",
         lambda _path, **_kwargs: SimpleNamespace(
             model_name="moonshotai/Kimi-K2.6",
             mcpServers={},
+            tool_runtime="local",
         ),
     )
     monkeypatch.setattr(main_mod, "print_banner", fake_banner)
 
     with pytest.raises(StopAfterBanner):
         await main_mod.main(model="openai/gpt-5.5")
+
+
+@pytest.mark.asyncio
+async def test_local_model_local_runtime_skips_hf_token_prompt(monkeypatch):
+    class StopAfterBanner(Exception):
+        pass
+
+    async def fail_prompt(_prompt_session):
+        raise AssertionError("local model with local tools should not prompt")
+
+    def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
+        assert model == "llamacpp/model"
+        assert hf_user is None
+        assert tool_runtime == "local filesystem"
+        raise StopAfterBanner
+
+    monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
+    monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
+    monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: None)
+    monkeypatch.setattr(main_mod, "_prompt_and_save_hf_token", fail_prompt)
+    monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: None)
+    monkeypatch.setattr(
+        main_mod,
+        "load_config",
+        lambda _path, **_kwargs: SimpleNamespace(
+            model_name="llamacpp/model",
+            mcpServers={},
+            tool_runtime="local",
+        ),
+    )
+    monkeypatch.setattr(main_mod, "print_banner", fake_banner)
+
+    with pytest.raises(StopAfterBanner):
+        await main_mod.main()
+
+
+@pytest.mark.asyncio
+async def test_local_model_sandbox_runtime_prompts_for_hf_token(monkeypatch):
+    class StopAfterBanner(Exception):
+        pass
+
+    prompted = False
+
+    async def fake_prompt(_prompt_session):
+        nonlocal prompted
+        prompted = True
+        return "hf-token"
+
+    def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
+        assert model == "llamacpp/model"
+        assert hf_user == "tester"
+        assert tool_runtime == "HF sandbox"
+        raise StopAfterBanner
+
+    monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
+    monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
+    monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: None)
+    monkeypatch.setattr(main_mod, "_prompt_and_save_hf_token", fake_prompt)
+    monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: "tester")
+    monkeypatch.setattr(
+        main_mod,
+        "load_config",
+        lambda _path, **_kwargs: SimpleNamespace(
+            model_name="llamacpp/model",
+            mcpServers={},
+            tool_runtime="local",
+        ),
+    )
+    monkeypatch.setattr(main_mod, "print_banner", fake_banner)
+
+    with pytest.raises(StopAfterBanner):
+        await main_mod.main(sandbox_tools=True)
+
+    assert prompted is True
+
+
+@pytest.mark.asyncio
+async def test_interactive_main_passes_sandbox_runtime_to_tool_router(monkeypatch):
+    class StopAfterToolRouter(Exception):
+        pass
+
+    seen: dict[str, object] = {}
+
+    class FakeGateway:
+        def __init__(self, _config):
+            pass
+
+        async def start(self):
+            pass
+
+    class FakeToolRouter:
+        def __init__(self, mcp_servers, *, hf_token=None, local_mode=True):
+            seen["mcp_servers"] = mcp_servers
+            seen["hf_token"] = hf_token
+            seen["local_mode"] = local_mode
+            raise StopAfterToolRouter
+
+    from agent.core import hf_router_catalog
+
+    monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
+    monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
+    monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: "hf-token")
+    monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: "tester")
+    monkeypatch.setattr(main_mod, "print_banner", lambda **_kwargs: None)
+    monkeypatch.setattr(hf_router_catalog, "prewarm", lambda: None)
+    monkeypatch.setattr(
+        main_mod,
+        "load_config",
+        lambda _path, **_kwargs: SimpleNamespace(
+            model_name="llamacpp/model",
+            mcpServers={"server": object()},
+            messaging=SimpleNamespace(default_auto_destinations=lambda: []),
+            tool_runtime="local",
+        ),
+    )
+    monkeypatch.setattr(main_mod, "NotificationGateway", FakeGateway)
+    monkeypatch.setattr(main_mod, "ToolRouter", FakeToolRouter)
+
+    with pytest.raises(StopAfterToolRouter):
+        await main_mod.main(sandbox_tools=True)
+
+    assert seen["hf_token"] == "hf-token"
+    assert seen["local_mode"] is False
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index c99f05ee..da66baff 100644
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -1,5 +1,8 @@
 import json
 
+import pytest
+from pydantic import ValidationError
+
 from agent import config as config_module
 
 
@@ -121,3 +124,35 @@ def test_slack_user_defaults_can_be_disabled(tmp_path, monkeypatch):
 
     assert not config.messaging.enabled
     assert config.messaging.destinations == {}
+
+
+def test_tool_runtime_defaults_to_local(tmp_path):
+    config_path = tmp_path / "config.json"
+    _write_json(config_path, {"model_name": "moonshotai/Kimi-K2.6"})
+
+    config = config_module.load_config(str(config_path))
+
+    assert config.tool_runtime == "local"
+
+
+def test_user_config_can_set_sandbox_tool_runtime(tmp_path, monkeypatch):
+    config_path = tmp_path / "config.json"
+    user_config_path = tmp_path / "user-config.json"
+    _write_json(config_path, {"model_name": "moonshotai/Kimi-K2.6"})
+    _write_json(user_config_path, {"tool_runtime": "sandbox"})
+    monkeypatch.setenv("ML_INTERN_CLI_CONFIG", str(user_config_path))
+
+    config = config_module.load_config(str(config_path), include_user_defaults=True)
+
+    assert config.tool_runtime == "sandbox"
+
+
+def test_invalid_tool_runtime_is_rejected(tmp_path):
+    config_path = tmp_path / "config.json"
+    _write_json(
+        config_path,
+        {"model_name": "moonshotai/Kimi-K2.6", "tool_runtime": "hybrid"},
+    )
+
+    with pytest.raises(ValidationError):
+        config_module.load_config(str(config_path))
diff --git a/tests/unit/test_sandbox_auto_start.py b/tests/unit/test_sandbox_auto_start.py
index 4cf67435..5dc6f668 100644
--- a/tests/unit/test_sandbox_auto_start.py
+++ b/tests/unit/test_sandbox_auto_start.py
@@ -1,7 +1,14 @@
+import asyncio
 from types import SimpleNamespace
 from pathlib import Path
 
+import pytest
+
+from agent.config import Config
+from agent.core import agent_loop
 from agent.core.agent_loop import _needs_approval
+from agent.core.session import OpType
+from agent.core.tools import create_builtin_tools
 from agent.tools.sandbox_tool import get_sandbox_tools
@@ -43,3 +50,87 @@ def test_prompt_rejects_local_machine_paths_for_hf_jobs_scripts():
     assert "/fsx/..." in prompt
     assert "inline Python source code" in prompt
     assert "a file already written in the session sandbox" in prompt
+
+
+def test_local_tool_runtime_excludes_sandbox_create():
+    tool_names = {tool.name for tool in create_builtin_tools(local_mode=True)}
+
+    assert {"bash", "read", "write", "edit"} <= tool_names
+    assert "sandbox_create" not in tool_names
+
+
+def test_sandbox_tool_runtime_includes_sandbox_create():
+    tool_names = {tool.name for tool in create_builtin_tools(local_mode=False)}
+
+    assert {"sandbox_create", "bash", "read", "write", "edit"} <= tool_names
+
+
+@pytest.mark.asyncio
+async def test_cli_sandbox_runtime_preloads_and_tears_down_sandbox(monkeypatch):
+    started = []
+    torn_down = []
+
+    class FakeToolRouter:
+        tools = {}
+
+        def get_tool_specs_for_llm(self):
+            return []
+
+        async def __aenter__(self):
+            return self
+
+        async def __aexit__(self, exc_type, exc, tb):
+            return None
+
+    def fake_start_cpu_sandbox_preload(session):
+        started.append(session)
+        return None
+
+    async def fake_teardown_session_sandbox(session):
+        torn_down.append(session)
+
+    monkeypatch.setattr(
+        agent_loop,
+        "start_session_artifact_collection_task",
+        lambda *_args, **_kwargs: None,
+    )
+    monkeypatch.setattr(
+        agent_loop, "start_cpu_sandbox_preload", fake_start_cpu_sandbox_preload
+    )
+    monkeypatch.setattr(
+        agent_loop, "teardown_session_sandbox", fake_teardown_session_sandbox
+    )
+
+    submission_queue = asyncio.Queue()
+    event_queue = asyncio.Queue()
+    session_holder = [None]
+    config = Config.model_validate(
+        {"model_name": "openai/gpt-5.5", "save_sessions": False}
+    )
+
+    task = asyncio.create_task(
+        agent_loop.submission_loop(
+            submission_queue,
+            event_queue,
+            config=config,
+            tool_router=FakeToolRouter(),
+            session_holder=session_holder,
+            hf_token="hf-token",
+            user_id="tester",
+            local_mode=False,
+        )
+    )
+
+    ready = await asyncio.wait_for(event_queue.get(), timeout=1)
+    assert ready.event_type == "ready"
+    assert started == [session_holder[0]]
+    assert session_holder[0].local_mode is False
+
+    await submission_queue.put(
+        SimpleNamespace(
+            operation=SimpleNamespace(op_type=OpType.SHUTDOWN, data=None),
+        )
+    )
+    await asyncio.wait_for(task, timeout=1)
+
+    assert torn_down == [session_holder[0]]

From 105ca83826552858d8f3cf6c143f1d037b1c00aa Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Fri, 8 May 2026 11:14:00 +0200
Subject: [PATCH 5/9] Document CLI sandbox tool runtime

Co-authored-by: OpenAI Codex
---
 README.md | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/README.md b/README.md
index b9db1863..849aa326 100644
--- a/README.md
+++ b/README.md
@@ -63,6 +63,7 @@ ml-intern --model anthropic/claude-opus-4-7 "your prompt"  # requires ANTHROPIC
 ml-intern --model openai/gpt-5.5 "your prompt"  # requires OPENAI_API_KEY
 ml-intern --model ollama/llama3.1:8b "your prompt"
 ml-intern --model vllm/meta-llama/Llama-3.1-8B-Instruct "your prompt"
+ml-intern --sandbox-tools "your prompt"  # use HF Space sandbox tools
 ml-intern --max-iterations 100 "your prompt"
 ml-intern --no-stream "your prompt"
 ```
@@ -97,6 +98,30 @@ one shared local endpoint, or override a specific provider with its matching
 `VLLM_API_KEY`. Provider-specific variables take precedence over the shared
 local variables. Base URLs may include or omit `/v1`.
 
+**CLI tool runtime:**
+
+By default, the CLI runs `bash`, `read`, `write`, and `edit` on your local
+filesystem. To use HF Space sandbox tools instead, including `sandbox_create`,
+opt in with `--sandbox-tools`:
+
+```bash
+ml-intern --sandbox-tools "test this training script in a GPU sandbox"
+ml-intern --model llamacpp/ggml-org/gemma-3-1b-it-GGUF --sandbox-tools
+```
+
+The sandbox tool runtime requires `HF_TOKEN`, even when the selected model is
+local, because it creates private HF Spaces. You can also make sandbox tools
+your CLI default in `~/.config/ml-intern/cli_agent_config.json`:
+
+```json
+{ "tool_runtime": "sandbox" }
+```
+
+Use the default local runtime when you want tools to inspect or edit files in
+your checkout. Use the sandbox runtime when you want the agent to create or
+replace an HF Space sandbox, test code remotely, or request GPU sandbox
+hardware before launching larger HF Jobs.
+
 
 ## Sharing Traces
 
 Every session is auto-uploaded to your **own private Hugging Face dataset**

From b44a7de62e43d68beac7314a2d1922c6112603ec Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Fri, 8 May 2026 11:23:56 +0200
Subject: [PATCH 6/9] Wait for initial CLI sandbox preload

Co-authored-by: OpenAI Codex
---
 agent/main.py                    | 16 ++++++++++++++++
 tests/unit/test_cli_rendering.py | 18 ++++++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/agent/main.py b/agent/main.py
index d87e3ee7..9d7d4399 100644
--- a/agent/main.py
+++ b/agent/main.py
@@ -73,6 +73,20 @@ def _tool_runtime_label(local_mode: bool) -> str:
     return "local filesystem" if local_mode else "HF sandbox"
 
 
+async def _wait_for_initial_sandbox_preload(session_holder: list | None) -> None:
+    session = session_holder[0] if session_holder else None
+    task = getattr(session, "sandbox_preload_task", None)
+    if not task:
+        return
+    try:
+        await asyncio.shield(task)
+    except asyncio.CancelledError:
+        raise
+    except Exception:
+        # The sandbox tool will surface the stored preload error on first use.
+        return
+
+
 def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
     if tool_info.get("tool") != "hf_jobs":
         return False
@@ -1177,6 +1191,8 @@
     )
 
     await ready_event.wait()
+    if not local_mode:
+        await _wait_for_initial_sandbox_preload(session_holder)
 
     submission_id = [0]
 
     # Mirrors codex-rs/tui/src/bottom_pane/mod.rs:137
diff --git a/tests/unit/test_cli_rendering.py b/tests/unit/test_cli_rendering.py
index ab68c920..a7bdd4f6 100644
--- a/tests/unit/test_cli_rendering.py
+++ b/tests/unit/test_cli_rendering.py
@@ -1,5 +1,6 @@
 """Regression tests for interactive CLI rendering and research model routing."""
 
+import asyncio
 import sys
 from io import StringIO
 from types import SimpleNamespace
@@ -272,3 +273,20 @@
     assert seen["hf_token"] == "hf-token"
     assert seen["local_mode"] is False
+
+
+@pytest.mark.asyncio
+async def test_initial_sandbox_preload_waits_before_prompt():
+    waited = False
+
+    async def preload():
+        nonlocal waited
+        await asyncio.sleep(0)
+        waited = True
+
+    task = asyncio.create_task(preload())
+    await main_mod._wait_for_initial_sandbox_preload(
+        [SimpleNamespace(sandbox_preload_task=task)]
+    )
+
+    assert waited is True

From c6a618abc7eacea71d3393b07ca7d6cd4b3c76d2 Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Fri, 8 May 2026 14:20:30 +0200
Subject: [PATCH 7/9] Strengthen GPU sandbox preflight guidance

Co-authored-by: OpenAI Codex
---
 agent/prompts/system_prompt_v3.yaml          |  5 ++++-
 agent/tools/jobs_tool.py                     |  4 ++++
 tests/unit/test_sandbox_auto_start.py        | 12 ++++++++++++
 tests/unit/test_sandbox_script_resolution.py |  1 +
 4 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/agent/prompts/system_prompt_v3.yaml b/agent/prompts/system_prompt_v3.yaml
index 5de61a86..9aa45b7c 100644
--- a/agent/prompts/system_prompt_v3.yaml
+++ b/agent/prompts/system_prompt_v3.yaml
@@ -108,9 +108,12 @@ system_prompt: |
   - a public/raw URL
   If you wrote or tested a script locally, read the file content and submit it inline, or write it into the sandbox first.
 
+  GPU preflight is mandatory before hf_jobs when the job will run on GPU, or when the script loads a model, uses CUDA, bf16/fp16, quantization, flash attention, or torch.compile. First create a GPU sandbox with sandbox_create (t4-small minimum; choose larger hardware when VRAM requires it), run a tiny smoke test there using the same imports, model-loading path, training entrypoint, and a tiny dataset/subset, then fix failures before submitting. If you skip GPU sandbox preflight, state why before calling hf_jobs.
+
   Before calling hf_jobs, output a pre-flight check:
   - Reference implementation: [which example you based this on]
   - Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
+  - GPU sandbox smoke test: [hardware and result, or explicitly not applicable because ...]
   - push_to_hub=True and hub_model_id set
   - timeout: [value] (based on: [model size] on [hardware])
   - Trackio monitoring included and deploying metrics to a public Space
@@ -133,7 +136,7 @@ system_prompt: |
 
   Do NOT call sandbox_create before normal CPU work. Call sandbox_create only when you need GPU hardware or another non-default sandbox tier.
 
-  Use GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16, or model loading. CPU sandboxes cannot test GPU code paths.
+  Use a GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16/fp16, quantization, flash attention, torch.compile, or model loading. CPU sandboxes cannot test GPU code paths. If the available sandbox tiers cannot fit the full model path, test the largest useful smoke path, state what was not covered, and submit one HF job first.
 
   # When a task has 3+ steps
diff --git a/agent/tools/jobs_tool.py b/agent/tools/jobs_tool.py
index aa01c03b..46bd4e37 100644
--- a/agent/tools/jobs_tool.py
+++ b/agent/tools/jobs_tool.py
@@ -1112,6 +1112,9 @@ async def _resume_scheduled_job(self, args: Dict[str, Any]) -> ToolResult:
         "- You MUST have called github_find_examples + github_read_file to find a working reference implementation. "
         "Scripts based on your internal knowledge WILL use outdated APIs and fail.\n"
         "- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n"
+        "- If the job runs on GPU, or the script loads a model, uses CUDA, bf16/fp16, quantization, flash attention, "
+        "or torch.compile, you MUST create a GPU sandbox with sandbox_create first, run a tiny smoke test there, "
+        "and fix failures before submitting. If skipped, state why before calling hf_jobs.\n"
         "- Training config MUST include push_to_hub=True and hub_model_id. "
         "Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n"
         "- Include trackio monitoring and provide the dashboard URL to the user. "
@@ -1159,6 +1162,7 @@
                 "description": (
                     "Python code, sandbox file path (e.g. '/app/train.py', './train.py', or bare 'train.py'), or URL. "
                     "Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. "
+                    "For GPU/model-loading training scripts, smoke-test in a GPU sandbox before submission. "
                     "Mutually exclusive with 'command'."
                 ),
             },
diff --git a/tests/unit/test_sandbox_auto_start.py b/tests/unit/test_sandbox_auto_start.py
index 75c17eb7..c6051c17 100644
--- a/tests/unit/test_sandbox_auto_start.py
+++ b/tests/unit/test_sandbox_auto_start.py
@@ -9,6 +9,7 @@
 from agent.core.agent_loop import _needs_approval
 from agent.core.session import OpType
 from agent.core.tools import create_builtin_tools
+from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC
 from agent.tools.sandbox_tool import get_sandbox_tools
 
 
@@ -52,6 +53,17 @@ def test_prompt_rejects_local_machine_paths_for_hf_jobs_scripts():
     assert "a file already written in the session sandbox" in prompt
 
 
+def test_prompt_and_hf_jobs_spec_require_gpu_preflight_for_gpu_jobs():
+    prompt = Path("agent/prompts/system_prompt_v3.yaml").read_text()
+    jobs_description = HF_JOBS_TOOL_SPEC["description"]
+
+    assert "GPU preflight is mandatory before hf_jobs" in prompt
+    assert "GPU sandbox smoke test" in prompt
+    assert "If you skip GPU sandbox preflight" in prompt
+    assert "you MUST create a GPU sandbox with sandbox_create first" in jobs_description
+    assert "If skipped, state why before calling hf_jobs" in jobs_description
+
+
 def test_local_tool_runtime_excludes_sandbox_create():
     tool_names = {tool.name for tool in create_builtin_tools(local_mode=True)}
 
diff --git a/tests/unit/test_sandbox_script_resolution.py b/tests/unit/test_sandbox_script_resolution.py
index 3102d5a3..4bb6e42d 100644
--- a/tests/unit/test_sandbox_script_resolution.py
+++ b/tests/unit/test_sandbox_script_resolution.py
@@ -67,3 +67,4 @@ def test_hf_jobs_script_description_mentions_bare_python_filenames():
     ]
 
     assert "bare 'train.py'" in script_description
+    assert "smoke-test in a GPU sandbox before submission" in script_description

From 072ae12d9a028dffa3de9fb46aaa586f59fb2aee Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Fri, 8 May 2026 14:51:15 +0200
Subject: [PATCH 8/9] Guard against text-only stops with unfinished plans

Co-authored-by: OpenAI Codex
---
 agent/core/agent_loop.py                      |  85 ++++++++++
 agent/core/session.py                         |   1 +
 agent/tools/plan_tool.py                      |  12 +-
 tests/unit/test_no_tool_continuation_guard.py | 147 ++++++++++++++++++
 4 files changed, 241 insertions(+), 4 deletions(-)
 create mode 100644 tests/unit/test_no_tool_continuation_guard.py

diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py
index 56178b12..e6742efb 100644
--- a/agent/core/agent_loop.py
+++ b/agent/core/agent_loop.py
@@ -44,6 +44,43 @@
 _MALFORMED_TOOL_PREFIX = "ERROR: Tool call to '"
 _MALFORMED_TOOL_SUFFIX = "' had malformed JSON arguments"
 
+_NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT = 2
+
+
+def _unfinished_plan_items(session: Session) -> list[dict[str, str]]:
+    plan = getattr(session, "current_plan", None) or []
+    unfinished: list[dict[str, str]] = []
+    for item in plan:
+        if not isinstance(item, dict):
+            continue
+        status = item.get("status")
+        if status in {"pending", "in_progress"}:
+            unfinished.append(item)
+    return unfinished
+
+
+def _format_plan_items_for_guard(items: list[dict[str, str]], limit: int = 4) -> str:
+    formatted = []
+    for item in items[:limit]:
+        item_id = item.get("id") or "?"
+        content = item.get("content") or "(unnamed task)"
+        status = item.get("status") or "unknown"
+        formatted.append(f"{item_id}. {content} [{status}]")
+    if len(items) > limit:
+        formatted.append(f"... and {len(items) - limit} more")
+    return "; ".join(formatted)
+
+
+def _no_tool_incomplete_plan_prompt(items: list[dict[str, str]]) -> str:
+    summary = _format_plan_items_for_guard(items)
+    return (
+        "[SYSTEM: CONTINUATION GUARD] Your previous response ended without any "
+        "tool calls, but the task is not complete. The current plan still has "
+        f"unfinished items: {summary}. Do not return control to the user yet. "
+        "Continue from the next unfinished item and make at least one tool call "
+        "now. If you genuinely cannot continue, first use tools to inspect the "
+        "state or verify the blocker."
+    )
 
 
 def _malformed_tool_name(message: Message) -> str | None:
@@ -1157,6 +1194,7 @@ async def run_agent(
     final_response = None
     errored = False
     max_iterations = session.config.max_iterations
+    no_tool_incomplete_plan_retries = 0
 
     while max_iterations == -1 or iteration < max_iterations:
         # ── Cancellation check: before LLM call ──
@@ -1305,6 +1343,51 @@
 
         # If no tool calls, add assistant message and we're done
         if not tool_calls:
+            unfinished_plan = _unfinished_plan_items(session)
+            if (
+                unfinished_plan
+                and no_tool_incomplete_plan_retries
+                < _NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT
+            ):
+                logger.info(
+                    "No tool calls with unfinished plan; retrying agent turn "
+                    "(attempt %d/%d)",
+                    no_tool_incomplete_plan_retries + 1,
+                    _NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT,
+                )
+                if content:
+                    assistant_msg = _assistant_message_from_result(
+                        llm_result,
+                        model_name=llm_params.get("model"),
+                    )
+                    session.context_manager.add_message(
+                        assistant_msg, token_count
+                    )
+                session.context_manager.add_message(
+                    Message(
+                        role="user",
+                        content=_no_tool_incomplete_plan_prompt(
+                            unfinished_plan
+                        ),
+                    )
+                )
+                no_tool_incomplete_plan_retries += 1
+                await session.send_event(
+                    Event(
+                        event_type="tool_log",
+                        data={
+                            "tool": "system",
+                            "log": (
+                                "Plan still has unfinished items after a "
+                                "text-only response — retrying instead of "
+                                "returning to the prompt."
+                            ),
+                        },
+                    )
+                )
+                iteration += 1
+                continue
+
             logger.debug(
                 "Agent loop ending: no tool calls. "
                 "finish_reason=%s, token_count=%d, "
@@ -1328,6 +1411,8 @@
             final_response = content
             break
 
+        no_tool_incomplete_plan_retries = 0
+
         # Validate tool call args (one json.loads per call, once)
         # and split into good vs bad
         good_tools: list[tuple[ToolCall, str, dict]] = []
diff --git a/agent/core/session.py b/agent/core/session.py
index 59504b62..8999b818 100644
--- a/agent/core/session.py
+++ b/agent/core/session.py
@@ -118,6 +118,7 @@ def __init__(
         self.session_id = session_id or str(uuid.uuid4())
         self.config = config
         self.is_running = True
+        self.current_plan: list[dict[str, str]] = []
         self._cancelled = asyncio.Event()
         self.pending_approval: Optional[dict[str, Any]] = None
         self.sandbox = None
diff --git a/agent/tools/plan_tool.py b/agent/tools/plan_tool.py
index a923d53c..b85ae22c 100644
--- a/agent/tools/plan_tool.py
+++ b/agent/tools/plan_tool.py
@@ -54,20 +54,24 @@ async def execute(self, params: Dict[str, Any]) -> ToolResult:
                 "isError": True,
             }
 
-        # Store the raw todos structure in memory
-        _current_plan = todos
+        # Store a session-scoped copy so the runtime can tell whether a
+        # text-only model response is trying to stop while work remains.
+        stored_todos = [dict(todo) for todo in todos]
+        _current_plan = stored_todos
+        if self.session is not None:
+            self.session.current_plan = stored_todos
 
         # Emit plan update event if session is available
         if self.session:
             await self.session.send_event(
                 Event(
                     event_type="plan_update",
-                    data={"plan": todos},
+                    data={"plan": stored_todos},
                 )
             )
 
         # Format only for display using terminal_display utility
-        formatted_output = format_plan_tool_output(todos)
+        formatted_output = format_plan_tool_output(stored_todos)
 
         return {
             "formatted": formatted_output,
diff --git a/tests/unit/test_no_tool_continuation_guard.py b/tests/unit/test_no_tool_continuation_guard.py
new file mode 100644
index 00000000..fdd592d3
--- /dev/null
+++ b/tests/unit/test_no_tool_continuation_guard.py
@@ -0,0 +1,147 @@
+import asyncio
+import json
+
+import pytest
+
+from agent.config import Config
+from agent.core import agent_loop
+from agent.core.agent_loop import Handlers, LLMResult
+from agent.core.session import Session
+from agent.tools.plan_tool import PlanTool
+
+
+class FakeToolRouter:
+    def __init__(self):
+        self.calls = []
+
+    def get_tool_specs_for_llm(self):
+        return [
+            {
+                "type": "function",
+                "function": {
+                    "name": "plan_tool",
+                    "description": "Update plan",
+                    "parameters": {"type": "object"},
+                },
+            }
+        ]
+
+    async def call_tool(self, name, arguments, session=None, tool_call_id=None):
+        self.calls.append((name, arguments, tool_call_id))
+        if name == "plan_tool" and session is not None:
+            session.current_plan = [dict(todo) for todo in arguments["todos"]]
+        return "plan updated", True
+
+
+@pytest.mark.asyncio
+async def test_plan_tool_stores_session_scoped_plan():
+    events = []
+
+    class FakeSession:
+        current_plan = []
+
+        async def send_event(self, event):
+            events.append(event)
+
+    session = FakeSession()
+    todos = [{"id": "1", "content": "Smoke test", "status": "in_progress"}]
+
+    result = await PlanTool(session=session).execute({"todos": todos})
+
+    assert result["isError"] is False
+    assert session.current_plan == todos
+    assert events[0].event_type == "plan_update"
+    assert events[0].data == {"plan": todos}
+
+
+@pytest.mark.asyncio
+async def test_no_tool_response_retries_when_plan_is_incomplete(monkeypatch):
+    config = Config.model_validate(
+        {"model_name": "openai/test", "save_sessions": False}
+    )
+    event_queue = asyncio.Queue()
+    router = FakeToolRouter()
+    session = Session(
+        event_queue,
+        config,
+        tool_router=router,
+        stream=False,
+    )
+    session.current_plan = [
+        {
+            "id": "1",
+            "content": "Write and smoke-test training script",
+            "status": "in_progress",
+        },
+        {"id": "2", "content": "Launch full training job", "status": "pending"},
+    ]
+    calls = []
+
+    async def fake_call_llm_non_streaming(session, messages, tools, llm_params):
+        calls.append(messages)
+        if len(calls) == 1:
+            return LLMResult(
+                content="I should keep going, but I forgot to call a tool.",
+                tool_calls_acc={},
+                token_count=10,
+                finish_reason="stop",
+            )
+        if len(calls) == 2:
+            assert "CONTINUATION GUARD" in messages[-1].content
+            return LLMResult(
+                content=None,
+                tool_calls_acc={
+                    0: {
+                        "id": "call_1",
+                        "function": {
+                            "name": "plan_tool",
+                            "arguments": json.dumps(
+                                {
+                                    "todos": [
+                                        {
+                                            "id": "1",
+                                            "content": "Write and smoke-test training script",
+                                            "status": "completed",
+                                        },
+                                        {
+                                            "id": "2",
+                                            "content": "Launch full training job",
+                                            "status": "completed",
+                                        },
+                                    ]
+                                }
+                            ),
+                        },
+                    }
+                },
+                token_count=20,
+                finish_reason="tool_calls",
+            )
+        return LLMResult(
+            content="Done.",
+            tool_calls_acc={},
+            token_count=30,
+            finish_reason="stop",
+        )
+
+    monkeypatch.setattr(
+        agent_loop, "_resolve_llm_params", lambda *_, **__: {"model": "openai/test"}
+    )
+    monkeypatch.setattr(
+        agent_loop, "_call_llm_non_streaming", fake_call_llm_non_streaming
+    )
+
+    final = await Handlers.run_agent(session, "continue")
+
+    assert final == "Done."
+    assert len(calls) == 3
+    assert router.calls[0][0] == "plan_tool"
+    assert all(todo["status"] == "completed" for todo in session.current_plan)
+    events = []
+    while not event_queue.empty():
+        events.append(await event_queue.get())
+    assert any(
+        event.event_type == "tool_log"
+        and "text-only response" in (event.data or {}).get("log", "")
+        for event in events
+    )

From c9d05ee09a72e21d106cec5f9114171fcb21264e Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Fri, 8 May 2026 15:01:48 +0200
Subject: [PATCH 9/9] Route sandbox deletion logs through tool events

Co-authored-by: OpenAI Codex
---
 agent/tools/sandbox_client.py                  |   8 +-
 agent/tools/sandbox_tool.py                    |  34 +++--
 tests/unit/test_sandbox_private_spaces.py      | 127 +++++++++++++++++-
 .../unit/test_session_manager_persistence.py  |   2 +-
 4 files changed, 151 insertions(+), 20 deletions(-)

diff --git a/agent/tools/sandbox_client.py b/agent/tools/sandbox_client.py
index d1644fc9..91b889b8 100644
--- a/agent/tools/sandbox_client.py
+++ b/agent/tools/sandbox_client.py
@@ -776,21 +776,23 @@ def _wait_for_api(
             f"Last status: {last_status}, last error: {last_err}"
         )
 
-    def delete(self):
+    def delete(self, log: Callable[[str], object] | None = None):
        """Delete the Space. Only works if this Sandbox created it."""
         if not self._owns_space:
             raise RuntimeError(
                 f"This Sandbox did not create {self.space_id}. "
                 f"Use self._hf_api.delete_repo() directly if you're sure."
             )
-        print(f"Deleting sandbox: {self.space_id}...")
+        if log:
+            log(f"Deleting sandbox: {self.space_id}...")
         self._hf_api.delete_repo(self.space_id, repo_type="space")
         # Clear ownership so a second cleanup call (e.g. delete_session +
         # _run_session.finally both fire) early-returns instead of retrying
         # a 404 delete and emitting a spurious ERROR log.
         self._owns_space = False
         self._client.close()
-        print("Deleted.")
+        if log:
+            log("Deleted.")
 
     def pause(self):
         """Pause the Space (stops billing, preserves state)."""
diff --git a/agent/tools/sandbox_tool.py b/agent/tools/sandbox_tool.py
index f5b2c863..b5ad9b92 100644
--- a/agent/tools/sandbox_tool.py
+++ b/agent/tools/sandbox_tool.py
@@ -16,6 +16,7 @@
 import re
 import threading
 import weakref
+from collections.abc import Callable
 from datetime import datetime, timedelta, timezone
 from typing import Any
 
@@ -58,6 +59,24 @@ def _get_sandbox_create_lock(owner: str) -> asyncio.Lock:
     return lock
 
 
+def _session_tool_logger(
+    session: Any, *, tool: str = "sandbox"
+) -> Callable[[str], object] | None:
+    event_queue = getattr(session, "event_queue", None)
+    if event_queue is None:
+        return None
+
+    loop = asyncio.get_running_loop()
+
+    def _log(msg: str) -> None:
+        loop.call_soon_threadsafe(
+            event_queue.put_nowait,
+            Event(event_type="tool_log", data={"tool": tool, "log": msg}),
+        )
+
+    return _log
+
+
 def _looks_like_path(script: str) -> bool:
     """Return True if the script string looks like a file path (not inline code)."""
     if not (
@@ -309,14 +328,8 @@ async def _create_sandbox_locked(
             )
         )
 
-    # Thread-safe log callback: posts tool_log events from the worker thread
-    loop = asyncio.get_running_loop()
-
-    def _log(msg: str) -> None:
-        loop.call_soon_threadsafe(
-            session.event_queue.put_nowait,
-            Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}),
-        )
+    # Thread-safe log callback: posts tool_log events from worker threads.
+    _log = _session_tool_logger(session) or (lambda msg: None)
 
     # Bridge asyncio cancel event to a threading.Event for the blocking create call.
     # We poll session._cancelled from the main loop in a background task and set
@@ -358,7 +371,7 @@ async def _watch_cancel():
         if cancel_flag.is_set():
             if getattr(sb, "_owns_space", False):
                 try:
-                    await asyncio.to_thread(sb.delete)
+                    await asyncio.to_thread(sb.delete, log=_log)
                 except Exception as e:
                     logger.warning(
                         "Failed to delete cancelled sandbox %s: %s", sb.space_id, e
@@ -503,6 +516,7 @@ async def teardown_session_sandbox(session: Any) -> None:
         return
 
     space_id = getattr(sandbox, "space_id", None)
+    delete_log = _session_tool_logger(session)
     last_err: Exception | None = None
     for attempt in range(3):
         try:
             logger.info(
                 "Deleting session sandbox %s (attempt %d)",
                 space_id,
                 attempt + 1,
             )
-            await asyncio.to_thread(sandbox.delete)
+            await asyncio.to_thread(sandbox.delete, log=delete_log)
             from agent.core import telemetry
 
             await telemetry.record_sandbox_destroy(session, sandbox)
diff --git a/tests/unit/test_sandbox_private_spaces.py b/tests/unit/test_sandbox_private_spaces.py
index b05b0ab2..115a0cfa 100644
--- a/tests/unit/test_sandbox_private_spaces.py
+++ b/tests/unit/test_sandbox_private_spaces.py
@@ -91,6 +91,31 @@ def get_space_runtime(self, space_id):
     assert not any("sleep time" in log for log in logs)
 
 
+def test_sandbox_delete_uses_log_callback_without_stdout(monkeypatch, capsys):
+    deleted: list[tuple[str, str]] = []
+
+    class FakeApi:
+        def __init__(self, token=None):
+            self.token = token
+
+        def delete_repo(self, repo_id, repo_type):
+            deleted.append((repo_id, repo_type))
+
+    monkeypatch.setattr(sandbox_client, "HfApi", FakeApi)
+
+    sandbox = Sandbox("alice/sandbox-12345678", token="hf-token", _owns_space=True)
+    logs: list[str] = []
+
+    sandbox.delete(log=logs.append)
+
+    captured = capsys.readouterr()
+    assert captured.out == ""
+    assert captured.err == ""
+    assert deleted == [("alice/sandbox-12345678", "space")]
+    assert logs == ["Deleting sandbox: alice/sandbox-12345678...", "Deleted."]
+    assert sandbox._owns_space is False
+
+
 def test_sandbox_client_retries_transient_runtime_404(monkeypatch):
     runtime_calls = 0
@@ -395,6 +420,71 @@
     assert persisted[-1]["sandbox_status"] == "active"
 
 
+def test_cancelled_sandbox_creation_logs_delete_through_tool_log(monkeypatch):
+    deleted: list[str] = []
+
+    class FakeSession:
+        def __init__(self):
+            self.hf_token = "hf-token"
+            self.sandbox = None
+            self.event_queue = asyncio.Queue()
+            self._cancelled = asyncio.Event()
+
+        async def send_event(self, event):
+            await self.event_queue.put(event)
+
+    def fake_create(**kwargs):
+        def delete(log=None):
+            deleted.append("alice/sandbox-12345678")
+            if log:
+                log("Deleting sandbox: alice/sandbox-12345678...")
+                log("Deleted.")
+
+        return SimpleNamespace(
+            space_id="alice/sandbox-12345678",
+            url="https://huggingface.co/spaces/alice/sandbox-12345678",
+            _owns_space=True,
+            delete=delete,
+        )
+
+    monkeypatch.setattr(Sandbox, "create", staticmethod(fake_create))
+
+    async def run():
+        session = FakeSession()
+        cancel_event = threading.Event()
+        cancel_event.set()
+
+        sb, error = await sandbox_tool._create_sandbox_locked(
+            session,
+            api=SimpleNamespace(),
+            owner="alice",
+            hardware="cpu-basic",
+            cancel_event=cancel_event,
+        )
+        await asyncio.sleep(0)
+        events = []
+        while not session.event_queue.empty():
+            events.append(await session.event_queue.get())
+        return sb, error, events
+
+    sb, error, events = asyncio.run(run())
+
+    assert sb is None
+    assert error == "Sandbox creation cancelled by user."
+    assert deleted == ["alice/sandbox-12345678"]
+    assert [
+        event.data
+        for event in events
+        if event.event_type == "tool_log"
+        and event.data
+        and event.data.get("log")
+        in {"Deleting sandbox: alice/sandbox-12345678...", "Deleted."}
+    ] == [
+        {"tool": "sandbox", "log": "Deleting sandbox: alice/sandbox-12345678..."},
+        {"tool": "sandbox", "log": "Deleted."},
+    ]
+
+
 def test_sandbox_creation_is_serialized_per_owner(monkeypatch):
     active_creates = 0
     max_active_creates = 0
@@ -514,7 +604,7 @@ def __init__(self):
                 space_id="alice/sandbox-cpu",
                 url="https://huggingface.co/spaces/alice/sandbox-cpu",
                 _owns_space=True,
-                delete=lambda: deleted.append("alice/sandbox-cpu"),
+                delete=lambda log=None: deleted.append("alice/sandbox-cpu"),
             )
             self.sandbox_hardware = "cpu-basic"
             self.sandbox_preload_task = None
@@ -559,10 +649,11 @@ async def fake_record_sandbox_destroy(*args, **kwargs):
 
 def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
     deleted: list[str] = []
+    destroyed: list[str] = []
     persisted: list[dict] = []
 
-    async def fake_record_sandbox_destroy(*args, **kwargs):
-        pass
+    async def fake_record_sandbox_destroy(session, sandbox, *args, **kwargs):
+        destroyed.append(sandbox.space_id)
 
     monkeypatch.setattr(
         telemetry, "record_sandbox_destroy", fake_record_sandbox_destroy
     )
 
     async def run():
         cancel_event = threading.Event()
+        event_queue = asyncio.Queue()
 
         async def preload():
             await asyncio.sleep(0)
 
+        def delete(log=None):
+            deleted.append("alice/sandbox-12345678")
+            if log:
+                log("Deleting sandbox: alice/sandbox-12345678...")
+                log("Deleted.")
+
         session = SimpleNamespace(
             session_id="s1",
             sandbox=SimpleNamespace(
                 space_id="alice/sandbox-12345678",
                 _owns_space=True,
-                delete=lambda: deleted.append("alice/sandbox-12345678"),
+                delete=delete,
             ),
             sandbox_hardware="cpu-basic",
             sandbox_preload_task=asyncio.create_task(preload()),
             sandbox_preload_cancel_event=cancel_event,
+            event_queue=event_queue,
             persistence_store=SimpleNamespace(
                 update_session_fields=lambda session_id, **fields: _record_metadata(
                     session_id, fields
                 )
             ),
         )
 
         await sandbox_tool.teardown_session_sandbox(session)
-        return session, cancel_event
+        await asyncio.sleep(0)
+        events = []
+        while not event_queue.empty():
+            events.append(await event_queue.get())
+        return session, cancel_event, events
 
     async def _record_metadata(session_id, fields):
         persisted.append({"session_id": session_id, **fields})
 
-    session, cancel_event = asyncio.run(run())
+    session, cancel_event, events = asyncio.run(run())
 
     assert cancel_event.is_set()
     assert deleted == ["alice/sandbox-12345678"]
+    assert destroyed == ["alice/sandbox-12345678"]
     assert session.sandbox is None
     assert session.sandbox_hardware is None
+    assert [
+        event.data
+        for event in events
+        if event.event_type == "tool_log"
+        and event.data
+        and event.data.get("log")
+        in {"Deleting sandbox: alice/sandbox-12345678...", "Deleted."}
+    ] == [
+        {"tool": "sandbox", "log": "Deleting sandbox: alice/sandbox-12345678..."},
+        {"tool": "sandbox", "log": "Deleted."},
+    ]
     assert persisted[-1]["session_id"] == "s1"
     assert persisted[-1]["sandbox_space_id"] is None
    assert persisted[-1]["sandbox_status"] == "destroyed"
diff --git a/tests/unit/test_session_manager_persistence.py b/tests/unit/test_session_manager_persistence.py
index 0835d878..59016eca 100644
--- a/tests/unit/test_session_manager_persistence.py
+++ b/tests/unit/test_session_manager_persistence.py
@@ -207,7 +207,7 @@ async def preload():
     session.sandbox = SimpleNamespace(
         space_id="owner/sandbox-12345678",
         _owns_space=True,
-        delete=lambda: deleted.append("owner/sandbox-12345678"),
+        delete=lambda log=None: deleted.append("owner/sandbox-12345678"),
     )
     session.sandbox_hardware = "cpu-basic"
     session.sandbox_preload_cancel_event = preload_cancel_event
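
A note on the logging pattern in PATCH 9/9: the patch threads an optional `log` callback into the blocking `Sandbox.delete()` call so progress messages land on the session's asyncio event queue instead of stdout. The sketch below is a minimal, self-contained illustration of that pattern, not code from this repository; the helper names (`make_thread_safe_logger`, `blocking_delete`) and the event shape are assumptions chosen to mirror `_session_tool_logger`. The key constraint it demonstrates is that `asyncio.Queue.put_nowait` must not be called directly from a worker thread, so the callback hands each message back to the owning event loop via `loop.call_soon_threadsafe`.

    import asyncio

    def make_thread_safe_logger(loop: asyncio.AbstractEventLoop, queue: asyncio.Queue):
        # Hypothetical stand-in for _session_tool_logger: returns a callback
        # that marshals each log message onto the loop that owns the queue.
        def log(msg: str) -> None:
            loop.call_soon_threadsafe(
                queue.put_nowait, {"tool": "sandbox", "log": msg}
            )
        return log

    def blocking_delete(log=None):
        # Stand-in for Sandbox.delete(): runs in a worker thread via
        # asyncio.to_thread and reports progress only through the callback.
        if log:
            log("Deleting sandbox: demo/sandbox...")
            log("Deleted.")

    async def main() -> None:
        queue: asyncio.Queue = asyncio.Queue()
        log = make_thread_safe_logger(asyncio.get_running_loop(), queue)
        await asyncio.to_thread(blocking_delete, log=log)
        while not queue.empty():
            # e.g. {'tool': 'sandbox', 'log': 'Deleting sandbox: demo/sandbox...'}
            print(await queue.get())

    asyncio.run(main())

Under these assumptions, passing `log=None` (as the patched tests' fakes allow) keeps the same code path usable from plain CLI contexts where no event queue exists, which is why `_session_tool_logger` returns `None` rather than raising when the session has no `event_queue`.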