Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ ml-intern --model anthropic/claude-opus-4-7 "your prompt" # requires ANTHROPIC
ml-intern --model openai/gpt-5.5 "your prompt" # requires OPENAI_API_KEY
ml-intern --model ollama/llama3.1:8b "your prompt"
ml-intern --model vllm/meta-llama/Llama-3.1-8B-Instruct "your prompt"
ml-intern --sandbox-tools "your prompt" # use HF Space sandbox tools
ml-intern --max-iterations 100 "your prompt"
ml-intern --no-stream "your prompt"
```
Expand Down Expand Up @@ -97,6 +98,30 @@ one shared local endpoint, or override a specific provider with its matching
`VLLM_API_KEY`. Provider-specific variables take precedence over the shared
local variables. Base URLs may include or omit `/v1`.

**CLI tool runtime:**

By default, the CLI runs `bash`, `read`, `write`, and `edit` on your local
filesystem. To use HF Space sandbox tools instead, including `sandbox_create`,
opt in with `--sandbox-tools`:

```bash
ml-intern --sandbox-tools "test this training script in a GPU sandbox"
ml-intern --model llamacpp/ggml-org/gemma-3-1b-it-GGUF --sandbox-tools
```

Sandbox tool runtime requires `HF_TOKEN`, even when the selected model is local,
because it creates private HF Spaces. You can also make sandbox tools your CLI
default in `~/.config/ml-intern/cli_agent_config.json`:

```json
{ "tool_runtime": "sandbox" }
```

Use the default local runtime when you want tools to inspect or edit files in
your checkout. Use sandbox runtime when you want the agent to create or replace
an HF Space sandbox, test code remotely, or request GPU sandbox hardware before
launching larger HF Jobs.

## Sharing Traces

Every session is auto-uploaded to your **own private Hugging Face dataset**
Expand Down
3 changes: 2 additions & 1 deletion agent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import re
from pathlib import Path
from typing import Any, Union
from typing import Any, Literal, Union

from dotenv import load_dotenv
from fastmcp.mcp_config import (
Expand Down Expand Up @@ -46,6 +46,7 @@ class Config(BaseModel):
# Permission control parameters
confirm_cpu_jobs: bool = True
auto_file_upload: bool = False
tool_runtime: Literal["local", "sandbox"] = "local"

# Reasoning effort *preference* — the ceiling the user wants. The probe
# on `/model` walks a cascade down from here (``max`` → ``xhigh`` → ``high``
Expand Down
95 changes: 94 additions & 1 deletion agent/core/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,55 @@
from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session
from agent.core.tools import ToolRouter
from agent.tools.jobs_tool import CPU_FLAVORS
from agent.tools.sandbox_tool import DEFAULT_CPU_SANDBOX_HARDWARE
from agent.tools.sandbox_tool import (
DEFAULT_CPU_SANDBOX_HARDWARE,
start_cpu_sandbox_preload,
teardown_session_sandbox,
)

logger = logging.getLogger(__name__)

ToolCall = ChatCompletionMessageToolCall

_MALFORMED_TOOL_PREFIX = "ERROR: Tool call to '"
_MALFORMED_TOOL_SUFFIX = "' had malformed JSON arguments"
_NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT = 2


def _unfinished_plan_items(session: Session) -> list[dict[str, str]]:
    """Return the plan entries on *session* that are not yet finished.

    Reads ``session.current_plan`` defensively (the attribute may be absent
    or ``None``) and keeps only dict entries whose ``status`` is still
    ``"pending"`` or ``"in_progress"``; non-dict entries are ignored.
    """
    plan_entries = getattr(session, "current_plan", None) or []
    return [
        entry
        for entry in plan_entries
        if isinstance(entry, dict)
        and entry.get("status") in ("pending", "in_progress")
    ]


def _format_plan_items_for_guard(items: list[dict[str, str]], limit: int = 4) -> str:
formatted = []
for item in items[:limit]:
item_id = item.get("id") or "?"
content = item.get("content") or "(unnamed task)"
status = item.get("status") or "unknown"
formatted.append(f"{item_id}. {content} [{status}]")
if len(items) > limit:
formatted.append(f"... and {len(items) - limit} more")
return "; ".join(formatted)


def _no_tool_incomplete_plan_prompt(items: list[dict[str, str]]) -> str:
    """Build the guard message injected when the model stops with open plan items.

    The message is sent as a synthetic user turn to push the model to resume
    work (with at least one tool call) instead of yielding control while the
    plan summarized from *items* is still incomplete.
    """
    plan_summary = _format_plan_items_for_guard(items)
    guard_message = (
        "[SYSTEM: CONTINUATION GUARD] Your previous response ended without any "
        "tool calls, but the task is not complete. The current plan still has "
        f"unfinished items: {plan_summary}. Do not return control to the user yet. "
        "Continue from the next unfinished item and make at least one tool call "
        "now. If you genuinely cannot continue, first use tools to inspect the "
        "state or verify the blocker."
    )
    return guard_message


def _malformed_tool_name(message: Message) -> str | None:
Expand Down Expand Up @@ -1153,6 +1194,7 @@ async def run_agent(
final_response = None
errored = False
max_iterations = session.config.max_iterations
no_tool_incomplete_plan_retries = 0

while max_iterations == -1 or iteration < max_iterations:
# ── Cancellation check: before LLM call ──
Expand Down Expand Up @@ -1301,6 +1343,51 @@ async def run_agent(

# If no tool calls, add assistant message and we're done
if not tool_calls:
unfinished_plan = _unfinished_plan_items(session)
if (
unfinished_plan
and no_tool_incomplete_plan_retries
< _NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT
):
logger.info(
"No tool calls with unfinished plan; retrying agent turn "
"(attempt %d/%d)",
no_tool_incomplete_plan_retries + 1,
_NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT,
)
if content:
assistant_msg = _assistant_message_from_result(
llm_result,
model_name=llm_params.get("model"),
)
session.context_manager.add_message(
assistant_msg, token_count
)
session.context_manager.add_message(
Message(
role="user",
content=_no_tool_incomplete_plan_prompt(
unfinished_plan
),
)
)
no_tool_incomplete_plan_retries += 1
await session.send_event(
Event(
event_type="tool_log",
data={
"tool": "system",
"log": (
"Plan still has unfinished items after a "
"text-only response — retrying instead of "
"returning to the prompt."
),
},
)
)
iteration += 1
continue

logger.debug(
"Agent loop ending: no tool calls. "
"finish_reason=%s, token_count=%d, "
Expand All @@ -1324,6 +1411,8 @@ async def run_agent(
final_response = content
break

no_tool_incomplete_plan_retries = 0

# Validate tool call args (one json.loads per call, once)
# and split into good vs bad
good_tools: list[tuple[ToolCall, str, dict]] = []
Expand Down Expand Up @@ -1940,6 +2029,8 @@ async def shutdown(session: Session) -> bool:
_ = session.save_and_upload_detached(repo_id)

session.is_running = False
if not getattr(session, "local_mode", False):
await teardown_session_sandbox(session)
await session.send_event(Event(event_type="shutdown"))
return True

Expand Down Expand Up @@ -2023,6 +2114,8 @@ async def submission_loop(
)
if session_holder is not None:
session_holder[0] = session
if not local_mode:
start_cpu_sandbox_preload(session)
logger.info("Agent loop started")

# Retry any failed uploads from previous sessions (fire-and-forget).
Expand Down
2 changes: 2 additions & 0 deletions agent/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
self.hf_token: Optional[str] = hf_token
self.user_id: Optional[str] = user_id
self.hf_username: Optional[str] = hf_username
self.local_mode = local_mode
self.persistence_store = persistence_store
self.tool_router = tool_router
self.stream = stream
Expand All @@ -117,6 +118,7 @@ def __init__(
self.session_id = session_id or str(uuid.uuid4())
self.config = config
self.is_running = True
self.current_plan: list[dict[str, str]] = []
self._cancelled = asyncio.Event()
self.pending_approval: Optional[dict[str, Any]] = None
self.sandbox = None
Expand Down
Loading
Loading