Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ session_logs/
hf-agent-leaderboard/
skills/
.claude/
.omc/
.omx/
*.jsonl
*.csv

Expand Down
136 changes: 70 additions & 66 deletions agent/core/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import asyncio
import json
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Any
Expand Down Expand Up @@ -917,7 +916,6 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
token_count = response.usage.total_tokens if response.usage else 0
thinking_blocks, reasoning_content = _extract_thinking_state(message)

# Build tool_calls_acc in the same format as streaming
tool_calls_acc: dict[int, dict] = {}
if message.tool_calls:
for idx, tc in enumerate(message.tool_calls):
Expand All @@ -930,7 +928,6 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
},
}

# Emit the full message as a single event
if content:
await session.send_event(
Event(event_type="assistant_message", data={"content": content})
Expand Down Expand Up @@ -1306,37 +1303,40 @@ async def _exec_tool(
)
return (tc, name, args, out, ok)

gather_task = asyncio.ensure_future(asyncio.gather(
*[
_exec_tool(tc, name, args, decision, valid, err)
for tc, name, args, decision, valid, err in parsed_tools
]
))
cancel_task = asyncio.ensure_future(session._cancelled.wait())

done, _ = await asyncio.wait(
[gather_task, cancel_task],
return_when=asyncio.FIRST_COMPLETED,
)

if cancel_task in done:
gather_task.cancel()
try:
await gather_task
except asyncio.CancelledError:
pass
# Notify frontend that in-flight tools were cancelled
for tc, name, _args, _decision, valid, _ in parsed_tools:
if valid:
await session.send_event(Event(
event_type="tool_state_change",
data={"tool_call_id": tc.id, "tool": name, "state": "cancelled"},
))
await _cleanup_on_cancel(session)
break
session.is_in_tool_call = True
try:
gather_task = asyncio.ensure_future(asyncio.gather(
*[
_exec_tool(tc, name, args, decision, valid, err)
for tc, name, args, decision, valid, err in parsed_tools
]
))
cancel_task = asyncio.ensure_future(session._cancelled.wait())

done, _ = await asyncio.wait(
[gather_task, cancel_task],
return_when=asyncio.FIRST_COMPLETED,
)

cancel_task.cancel()
results = gather_task.result()
if cancel_task in done:
gather_task.cancel()
try:
await gather_task
except asyncio.CancelledError:
pass
for tc, name, _args, _decision, valid, _ in parsed_tools:
if valid:
await session.send_event(Event(
event_type="tool_state_change",
data={"tool_call_id": tc.id, "tool": name, "state": "cancelled"},
))
await _cleanup_on_cancel(session)
break

cancel_task.cancel()
results = gather_task.result()
finally:
session.is_in_tool_call = False

# 4. Record results and send outputs (order preserved)
for tc, tool_name, tool_args, output, success in results:
Expand Down Expand Up @@ -1610,40 +1610,44 @@ async def execute_tool(tc, tool_name, tool_args, was_edited):

# Execute all approved tools concurrently (cancellable)
if approved_tasks:
gather_task = asyncio.ensure_future(asyncio.gather(
*[
execute_tool(tc, tool_name, tool_args, was_edited)
for tc, tool_name, tool_args, was_edited in approved_tasks
],
return_exceptions=True,
))
cancel_task = asyncio.ensure_future(session._cancelled.wait())

done, _ = await asyncio.wait(
[gather_task, cancel_task],
return_when=asyncio.FIRST_COMPLETED,
)
session.is_in_tool_call = True
try:
gather_task = asyncio.ensure_future(asyncio.gather(
*[
execute_tool(tc, tool_name, tool_args, was_edited)
for tc, tool_name, tool_args, was_edited in approved_tasks
],
return_exceptions=True,
))
cancel_task = asyncio.ensure_future(session._cancelled.wait())

if cancel_task in done:
gather_task.cancel()
try:
await gather_task
except asyncio.CancelledError:
pass
# Notify frontend that approved tools were cancelled
for tc, tool_name, _args, _was_edited in approved_tasks:
await session.send_event(Event(
event_type="tool_state_change",
data={"tool_call_id": tc.id, "tool": tool_name, "state": "cancelled"},
))
await _cleanup_on_cancel(session)
await session.send_event(Event(event_type="interrupted"))
session.increment_turn()
await session.auto_save_if_needed()
return

cancel_task.cancel()
results = gather_task.result()
done, _ = await asyncio.wait(
[gather_task, cancel_task],
return_when=asyncio.FIRST_COMPLETED,
)

if cancel_task in done:
gather_task.cancel()
try:
await gather_task
except asyncio.CancelledError:
pass
# Notify frontend that approved tools were cancelled
for tc, tool_name, _args, _was_edited in approved_tasks:
await session.send_event(Event(
event_type="tool_state_change",
data={"tool_call_id": tc.id, "tool": tool_name, "state": "cancelled"},
))
await _cleanup_on_cancel(session)
await session.send_event(Event(event_type="interrupted"))
session.increment_turn()
await session.auto_save_if_needed()
return

cancel_task.cancel()
results = gather_task.result()
finally:
session.is_in_tool_call = False

# Process results and add to context
for result in results:
Expand Down
1 change: 1 addition & 0 deletions agent/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
self.config = config
self.is_running = True
self._cancelled = asyncio.Event()
self.is_in_tool_call: bool = False
self.pending_approval: Optional[dict[str, Any]] = None
self.sandbox = None
self._running_job_ids: set[str] = set() # HF job IDs currently executing
Expand Down
Loading
Loading