# Pause/stop control methods of ``Agent`` (mini_agent/agent.py).
# They operate on four instance flags: ``_stop_requested``, ``_stop_notified``,
# ``_paused`` and ``_resume_step``.


def request_stop(self):
    """Ask the run loop to halt at its next stop checkpoint.

    Also re-arms the one-shot notification so the pause message is printed
    again for this new request.
    """
    self._stop_notified = False
    self._stop_requested = True


def cancel_pause(self):
    """Drop any pending stop/pause state and rewind the resume step to 0."""
    self._resume_step = 0
    self._paused = self._stop_notified = self._stop_requested = False


def _check_stop_requested(self, current_step: int) -> bool:
    """Report whether a stop is pending, recording the pause if so.

    Args:
        current_step: Step index to store in ``_resume_step`` when pausing.

    Returns:
        ``True`` when a stop was requested (the instance is then marked as
        paused), ``False`` otherwise. The user-facing pause message is
        printed at most once per request.
    """
    if self._stop_requested:
        if not self._stop_notified:
            print(f"\n{Colors.BRIGHT_YELLOW}⏸️ Agent paused by user (press Enter to continue interacting).{Colors.RESET}\n")
            self._stop_notified = True
        self._paused, self._resume_step = True, current_step
        return True
    return False


def is_paused(self) -> bool:
    """Return the ``_paused`` flag (set when a stop request halted the run)."""
    return self._paused
{self.logger.get_log_file_path()}{Colors.RESET}") + else: + print(f"{Colors.DIM}πŸ“ Resuming run (log: {self.logger.get_log_file_path()}){Colors.RESET}") + + step = self._resume_step if resuming else 0 + self._stop_requested = False + self._stop_notified = False + self._paused = False + if not resuming: + self._resume_step = 0 while step < self.max_steps: + if self._check_stop_requested(step): + return "Agent run interrupted by user." + # Check and summarize message history to prevent context overflow await self._summarize_messages() @@ -296,8 +340,12 @@ async def run(self) -> str: else: error_msg = f"LLM call failed: {str(e)}" print(f"\n{Colors.BRIGHT_RED}❌ Error:{Colors.RESET} {error_msg}") + self.cancel_pause() return error_msg + if self._check_stop_requested(step): + return "Agent run interrupted by user." + # Log LLM response self.logger.log_response( content=response.content, @@ -327,9 +375,10 @@ async def run(self) -> str: # Check if task is complete (no tool calls) if not response.tool_calls: + self.cancel_pause() return response.content - # Execute tool calls + # Execute all tool calls before checking for stop requests again for tool_call in response.tool_calls: tool_call_id = tool_call.id function_name = tool_call.function.name @@ -403,10 +452,15 @@ async def run(self) -> str: self.messages.append(tool_msg) step += 1 + self._resume_step = step + + if self._check_stop_requested(step): + return "Agent run interrupted by user." # Max steps reached error_msg = f"Task couldn't be completed after {self.max_steps} steps." 
class EscapeKeyListener:
    """Async context manager that maps a bare ESC keypress to ``agent.request_stop()``.

    On entry, stdin is switched to cbreak mode and a reader callback is
    registered on the running event loop. On exit the callback is removed and
    the saved terminal settings are restored. The listener silently degrades
    to a no-op when ``termios``/``tty`` are unavailable, when stdin is not a
    TTY, or when the event loop cannot watch file descriptors.
    """

    _ESC = 0x1B
    # Bytes that follow ESC at the start of terminal control sequences
    # (CSI "ESC [" and SS3 "ESC O"), as sent by arrow and function keys.
    _SEQUENCE_INTRODUCERS = frozenset(b"[O")
    _READ_CHUNK = 1024

    def __init__(self, agent: "Agent"):
        """Store the agent whose ``request_stop()`` is called on ESC."""
        self.agent = agent
        self._loop: asyncio.AbstractEventLoop | None = None
        self._fd: int | None = None
        self._old_settings = None
        self._cbreak_enabled = False
        self._reader_registered = False

    async def __aenter__(self):
        """Enable cbreak mode and start watching stdin; returns ``self``."""
        self._loop = asyncio.get_running_loop()
        await self._loop.run_in_executor(None, self._enable_cbreak_mode)
        if not self._cbreak_enabled or self._fd is None:
            return self

        if hasattr(self._loop, "add_reader"):
            try:
                self._loop.add_reader(self._fd, self._handle_keypress)
            except (NotImplementedError, OSError, ValueError):
                # Loop cannot watch this descriptor; fall through and undo
                # cbreak mode below.
                pass
            else:
                self._reader_registered = True

        if not self._reader_registered:
            # __aexit__ does not run when __aenter__ fails, and cbreak mode
            # without a listener only disables echo/line editing for nothing,
            # so restore the terminal right away.
            await self._loop.run_in_executor(None, self._restore_terminal)
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Stop watching stdin and restore the terminal; never suppresses errors."""
        try:
            self._unregister_reader()
        finally:
            if self._loop is not None:
                await self._loop.run_in_executor(None, self._restore_terminal)

    def _enable_cbreak_mode(self):
        """Put stdin into cbreak mode, remembering the previous settings.

        Best effort: leaves ``_cbreak_enabled`` False when the platform has no
        ``termios``/``tty``, stdin is missing or not a TTY, or the terminal
        calls fail.
        """
        if termios is None or tty is None:
            return
        stdin = sys.stdin
        if stdin is None:
            return
        try:
            if not stdin.isatty():
                return
            fd = stdin.fileno()
            old_settings = termios.tcgetattr(fd)
            tty.setcbreak(fd)
        except (termios.error, OSError, ValueError):
            self._cbreak_enabled = False
            return
        self._fd = fd
        self._old_settings = old_settings
        self._cbreak_enabled = True

    def _restore_terminal(self):
        """Reapply the terminal settings saved by ``_enable_cbreak_mode``."""
        if (
            not self._cbreak_enabled
            or self._fd is None
            or self._old_settings is None
            or termios is None
        ):
            return
        try:
            termios.tcsetattr(self._fd, termios.TCSADRAIN, self._old_settings)
        finally:
            # Clear the flag even if tcsetattr raised so the keypress handler
            # stops treating the terminal as being in cbreak mode.
            self._cbreak_enabled = False

    def _unregister_reader(self):
        """Remove the stdin reader callback if it is currently registered."""
        if self._reader_registered and self._loop is not None and self._fd is not None:
            self._loop.remove_reader(self._fd)
        self._reader_registered = False

    @classmethod
    def _is_bare_escape(cls, data: bytes) -> bool:
        """Return True if *data* contains an ESC that does not start a key sequence.

        An ESC byte counts as "bare" when it is the last byte of the chunk or
        is followed by anything other than ``[`` or ``O``. Arrow/function keys
        (``ESC [ A``, ``ESC O P``...) therefore do not count; Alt+<key>
        (``ESC <key>``) does.
        """
        last_index = len(data) - 1
        for index, byte in enumerate(data):
            if byte != cls._ESC:
                continue
            if index == last_index or data[index + 1] not in cls._SEQUENCE_INTRODUCERS:
                return True
        return False

    def _handle_keypress(self):
        """Event-loop callback: drain pending input and request a stop on ESC."""
        if not self._cbreak_enabled or self._fd is None:
            return
        # Local import keeps this block self-contained; the module is cached
        # after the first call.
        import os

        try:
            # Read straight from the descriptor instead of sys.stdin.read(1):
            # the buffered text layer would hand back only the leading ESC of
            # an arrow-key sequence and keep the remaining bytes in Python's
            # buffer, where the event loop's selector cannot see them.
            data = os.read(self._fd, self._READ_CHUNK)
        except OSError:
            return
        if not data:
            # EOF: the descriptor would stay readable forever, so stop
            # listening rather than spin on this callback.
            self._unregister_reader()
            return
        if self._is_bare_escape(data):
            print(f"\n{Colors.BRIGHT_YELLOW}⏹️ Escape detected, requesting agent pause...{Colors.RESET}")
            self.agent.request_stop()
            # One stop request per run is enough; ignore further keys.
            self._unregister_reader()
await agent.run() + + resume_pending = False + # 9. Interactive loop while True: try: @@ -507,7 +593,14 @@ def _(event): user_input = user_input.strip() if not user_input: - continue + if resume_pending: + result = await invoke_agent_run() + resume_pending = agent.is_paused() + if not resume_pending: + print(f"\n{Colors.DIM}{'─' * 60}{Colors.RESET}\n") + continue + else: + continue # Handle commands if user_input.startswith("/"): @@ -520,6 +613,8 @@ def _(event): elif command == "/help": print_help() + resume_pending = False + agent.cancel_pause() continue elif command == "/clear": @@ -527,19 +622,27 @@ def _(event): old_count = len(agent.messages) agent.messages = [agent.messages[0]] # Keep only system message print(f"{Colors.GREEN}βœ… Cleared {old_count - 1} messages, starting new session{Colors.RESET}\n") + resume_pending = False + agent.cancel_pause() continue elif command == "/history": print(f"\n{Colors.BRIGHT_CYAN}Current session message count: {len(agent.messages)}{Colors.RESET}\n") + resume_pending = False + agent.cancel_pause() continue elif command == "/stats": print_stats(agent, session_start) + resume_pending = False + agent.cancel_pause() continue else: print(f"{Colors.RED}❌ Unknown command: {user_input}{Colors.RESET}") print(f"{Colors.DIM}Type /help to see available commands{Colors.RESET}\n") + resume_pending = False + agent.cancel_pause() continue # Normal conversation - exit check @@ -549,12 +652,16 @@ def _(event): break # Run Agent - print(f"\n{Colors.BRIGHT_BLUE}Agent{Colors.RESET} {Colors.DIM}β€Ί{Colors.RESET} {Colors.DIM}Thinking...{Colors.RESET}\n") + if resume_pending: + agent.cancel_pause() + resume_pending = False agent.add_user_message(user_input) - _ = await agent.run() + result = await invoke_agent_run() + resume_pending = agent.is_paused() - # Visual separation - keep it simple like the reference code - print(f"\n{Colors.DIM}{'─' * 60}{Colors.RESET}\n") + if not resume_pending: + # Visual separation - keep it simple like the reference 
"""Tests for Agent interrupt/pause behavior."""

import asyncio

from mini_agent.agent import Agent
from mini_agent.schema import FunctionCall, LLMResponse, ToolCall
from mini_agent.tools.base import Tool, ToolResult

# Upper bound (seconds) on waiting for the first tool to start, so a broken
# agent fails the test instead of hanging the suite.
_EVENT_TIMEOUT = 5


class StubLLM:
    """Fake LLM client: first call requests two tools, later calls finish."""

    def __init__(self):
        # Number of generate() invocations so far.
        self.call_count = 0

        tool_calls = [
            ToolCall(
                id="call-1",
                type="function",
                function=FunctionCall(name="tool_a", arguments={}),
            ),
            ToolCall(
                id="call-2",
                type="function",
                function=FunctionCall(name="tool_b", arguments={}),
            ),
        ]
        self.tool_response = LLMResponse(
            content="",
            thinking=None,
            tool_calls=tool_calls,
            finish_reason="tool_calls",
        )
        self.final_response = LLMResponse(
            content="All done!",
            thinking=None,
            tool_calls=None,
            finish_reason="stop",
        )

    async def generate(self, messages, tools=None):
        """Return the tool-call response on the first call, the final one after."""
        # Count every call (not just the first) so tests can assert how many
        # LLM round-trips a scenario needed.
        self.call_count += 1
        if self.call_count == 1:
            return self.tool_response
        return self.final_response


class NotifyingTool(Tool):
    """Simple tool that records executions and can trigger an event."""

    def __init__(self, name: str, event: asyncio.Event | None = None):
        self._name = name
        self._event = event
        self.executions = 0

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        return f"Test tool {self._name}"

    @property
    def parameters(self) -> dict:
        return {"type": "object", "properties": {}}

    async def execute(self, **kwargs) -> ToolResult:
        self.executions += 1
        if self._event and not self._event.is_set():
            self._event.set()
            await asyncio.sleep(0)  # Ensure the caller can observe the event
        return ToolResult(success=True, content=f"{self._name}-result")


def _build_agent(tmp_path):
    """Create an agent wired to StubLLM and two tools; tool_a signals an event."""
    llm = StubLLM()
    first_tool_event = asyncio.Event()
    tool_a = NotifyingTool("tool_a", event=first_tool_event)
    tool_b = NotifyingTool("tool_b")
    agent = Agent(
        llm_client=llm,
        system_prompt="System prompt",
        tools=[tool_a, tool_b],
        max_steps=3,
        workspace_dir=str(tmp_path),
    )
    agent.add_user_message("Run tools")
    return agent, llm, tool_a, tool_b, first_tool_event


async def _interrupt_during_first_tool(agent, first_tool_event):
    """Start agent.run(), request a stop once tool_a has run, return the result."""
    run_task = asyncio.create_task(agent.run())
    await asyncio.wait_for(first_tool_event.wait(), timeout=_EVENT_TIMEOUT)
    agent.request_stop()
    return await run_task


async def _run_interrupt_preserves_tool_results(tmp_path):
    """Requesting stop mid-tool loop should still record all tool outputs."""
    agent, _llm, tool_a, tool_b, first_tool_event = _build_agent(tmp_path)

    result = await _interrupt_during_first_tool(agent, first_tool_event)
    assert "interrupted" in result.lower()
    assert agent.is_paused()

    # Ensure both tool results exist in message history
    tool_messages = [m for m in agent.messages if m.role == "tool"]
    assert len(tool_messages) == 2
    assert {m.name for m in tool_messages} == {"tool_a", "tool_b"}
    assert (tool_a.executions, tool_b.executions) == (1, 1)


async def _run_agent_can_resume_after_interrupt(tmp_path):
    """Agent should continue the same run after being interrupted."""
    agent, llm, tool_a, tool_b, first_tool_event = _build_agent(tmp_path)

    first_result = await _interrupt_during_first_tool(agent, first_tool_event)
    # Without these checks the test would also pass if the stop request were
    # ignored, because a fresh second run returns "All done!" as well.
    assert "interrupted" in first_result.lower()
    assert agent.is_paused()

    # Resume run without adding new messages
    final_result = await agent.run()
    assert final_result == "All done!"

    tool_messages = [m for m in agent.messages if m.role == "tool"]
    assert len(tool_messages) == 2
    assert not agent.is_paused()
    # Resuming must not re-run tools, and needs exactly one more LLM call.
    assert (tool_a.executions, tool_b.executions) == (1, 1)
    assert llm.call_count == 2


def test_interrupt_preserves_tool_results(tmp_path):
    asyncio.run(_run_interrupt_preserves_tool_results(tmp_path))


def test_agent_can_resume_after_interrupt(tmp_path):
    asyncio.run(_run_agent_can_resume_after_interrupt(tmp_path))