diff --git a/README.md b/README.md index 75ea6aa1..f5eb9c87 100644 --- a/README.md +++ b/README.md @@ -210,7 +210,7 @@ In addition to all [built-in Jinja filters](https://jinja.palletsprojects.com/en - **`where`**: Specifies where to trim the message if it exceeds the limit. The default is `"middle"`, which trims from the middle of the message. Other options are `start` or `end`. ```jinja - {{ info.instructions | trim_message(max_length_percentage=0.1, where="end") }} + {{ info.dir_tree | trim_message(max_length_percentage=0.1, where="end") }} ``` #### Example Template @@ -223,6 +223,9 @@ Task: {{ agent.system_prompt }} Instructions: {{ info.instructions }} +Directory Tree: +{{ info.dir_tree | trim_message(max_length=1000) }} + Current Breakpoints: {{ info.current_breakpoints | to_pretty_json }} diff --git a/debug_gym/agents/__init__.py b/debug_gym/agents/__init__.py index 424ccc9c..fbfbae31 100644 --- a/debug_gym/agents/__init__.py +++ b/debug_gym/agents/__init__.py @@ -1,4 +1,6 @@ +from debug_gym.agents.base_agent import BaseAgent, register_agent from debug_gym.agents.debug_agent import Debug_5_Agent, DebugAgent +from debug_gym.agents.free_agent import FreeAgent from debug_gym.agents.rewrite_agent import RewriteAgent from debug_gym.agents.solution_agent import AgentSolution from debug_gym.agents.swe_agent import SWEAgent diff --git a/debug_gym/agents/base_agent.py b/debug_gym/agents/base_agent.py index 767976a3..04d5fefd 100644 --- a/debug_gym/agents/base_agent.py +++ b/debug_gym/agents/base_agent.py @@ -2,8 +2,7 @@ import os import subprocess import uuid -from dataclasses import MISSING, asdict, dataclass, field, fields -from typing import Any, Dict +from os.path import join as pjoin import numpy as np from jinja2 import Environment, Template @@ -18,55 +17,6 @@ AGENT_REGISTRY = {} -@dataclass -class AgentArgs: - random_seed: int - memory_size: int - max_steps: int - max_rewrite_steps: int - system_prompt_template_file: str | None = None - uuid: str = field(default_factory=lambda: str(uuid.uuid4())) - extras: Dict[str, Any] = field(default_factory=dict) - - @classmethod - def from_dict(cls, config: Dict[str, Any]) -> "AgentArgs": - # Get all field names from the dataclass - field_names = {f.name for f in fields(cls)} - - # Check for required fields (those without defaults) - required_fields = { - f.name - for f in fields(cls) - if f.default is MISSING and f.default_factory is MISSING - } - missing = required_fields - config.keys() - if missing: - raise ValueError( - f"Missing required agent config keys: {', '.join(sorted(missing))}" - ) - - # Separate known fields from extras - known_values = {k: v for k, v in config.items() if k in field_names} - extras = {k: v for k, v in config.items() if k not in field_names} - - # Add extras if that field exists - if "extras" in field_names: - known_values["extras"] = extras - - return cls(**known_values) - - def get(self, key: str, default=None): - if key in self.__dataclass_fields__: - return getattr(self, key) - return self.extras.get(key, default) - - def to_dict(self) -> Dict[str, Any]: - data = asdict(self) - extras = data.pop("extras", {}) - data.update(extras) - return data - - def register_agent(cls): if not issubclass(cls, BaseAgent): raise ValueError("agent_class must be a subclass of BaseAgent") @@ -83,28 +33,36 @@ class BaseAgent: def __init__( self, - agent_args: AgentArgs | Dict[str, Any], + config: dict, + env: RepoEnv, llm: LLM | None = None, logger: DebugGymLogger | None = None, ): - self.args = ( - AgentArgs.from_dict(agent_args) 
- if isinstance(agent_args, dict) - else agent_args - ) + self.config = config + self.env = env self.logger = logger or DebugGymLogger("debug-gym") self.llm = llm - self._uuid = self.args.uuid - self.env = None + self._uuid = self.config.get("uuid", str(uuid.uuid4())) + self._output_path = pjoin(self.config["output_path"], self._uuid) + + os.makedirs(self._output_path, exist_ok=True) + + if "memory_size" not in self.config: + self.config["memory_size"] = self.config["max_steps"] - self.set_seed(self.args.random_seed) - self.history = HistoryTracker(self.args.memory_size) + self.set_seed(self.config["random_seed"]) + self.history = HistoryTracker(self.config["memory_size"]) def set_seed(self, seed): np.random.seed(seed) def build_history_prompt(self): - return build_history_prompt(self.history, self.llm) + messages = build_history_prompt( + self.history, + self.llm, + self.config.get("reset_prompt_history_after_rewrite", False), + ) + return messages def parse_reasoning_model_response(self, response, reasoning_end_token): # Strip the reasoning, e.g., in Deepseek r1, between and . @@ -121,6 +79,45 @@ def _auto_eval_on_rewrite(self): except KeyError: return False # no eval tool + def _show_current_breakpoints(self): + """Check if current breakpoints should be shown in the system prompt.""" + return self.config.get("env_kwargs", {}).get("show_current_breakpoints", False) + + def _show_directory_tree(self): + """Check if directory tree should be shown in the system prompt.""" + return self.config.get("env_kwargs", {}).get("show_directory_tree", False) + + def shortcut_features(self): + features = [] + if self._auto_eval_on_rewrite(): + features.append( + "After successful rewrites, the environment will automatically " + "call the Eval tool to evaluate the rewritten code. Therefore, " + "you do not need to call the Eval tool yourself. The evaluation " + "output will be updated automatically in the system prompt." + ) + if self._show_directory_tree(): + features.append( + "The environment will show the directory tree of the repository in the system prompt." + ) + if self.env.has_tool("pdb"): + if self._show_current_breakpoints(): + features.append( + "The environment will show the current breakpoints in the system prompt." + ) + if self.config.get("env_kwargs", {}).get("persistent_breakpoints"): + features.append( + "The environment will automatically restore existing breakpoints " + "when a new PDB session is started (e.g., after a rewrite)." + ) + if self.config.get("env_kwargs", {}).get("auto_list"): + features.append( + "After every valid PDB tool calling, the environment will " + "automatically call the PDB tool again with a `list .` command, " + "which will show the code around the current frame." + ) + return features + @staticmethod def to_pretty_json(value): """Convert a value to a pretty JSON string.""" @@ -156,7 +153,7 @@ def _load_system_prompt_template(self) -> Template | None: """Load system prompt template from config if specified and register custom filters. If no template is specified, return None. 
""" - system_prompt_template = self.args.system_prompt_template_file + system_prompt_template = self.config.get("system_prompt_template_file") if system_prompt_template: if not os.path.isfile(system_prompt_template): error_msg = ( @@ -179,8 +176,28 @@ def _default_system_prompt(self, info) -> str: system_prompt_dict = { "Overall task": self.system_prompt, + "Instructions": info.instructions, } + if self._show_directory_tree(): + system_prompt_dict["Repo directory tree"] = self.trim_message( + info.dir_tree, max_length_percentage=0.1, where="end" + ) + + if self._show_current_breakpoints(): + system_prompt_dict["Current breakpoints"] = info.current_breakpoints + + if self._auto_eval_on_rewrite(): + system_prompt_dict["Evaluation output of current code"] = self.trim_message( + info.eval_observation.observation, + max_length_percentage=0.8, + where="middle", + ) + + shortcut_features = self.shortcut_features() + if shortcut_features: + system_prompt_dict["Shortcut features"] = shortcut_features + return self.to_pretty_json(system_prompt_dict) def build_system_prompt(self, info): @@ -206,20 +223,20 @@ def build_prompt(self, info): messages.extend(self.build_question_prompt()) return messages - def run(self, env: RepoEnv, debug=False): - self.env = env + def run(self, task_name=None, debug=False): step = 0 info = None - max_steps = self.args.max_steps + max_steps = self.config["max_steps"] + max_rewrite_steps = self.config.get("max_rewrite_steps") try: self.history.reset() - info = self.env.reset() + info = self.env.reset(options={"task_name": task_name}) # initial state does not have prompt and response self.history.step(info, None) if info.resolved is True: self.logger.report_progress( - problem_id=env.task_name, + problem_id=task_name, step=1, total_steps=1, score=info.score, @@ -237,7 +254,7 @@ def run(self, env: RepoEnv, debug=False): for step in range(max_steps): self.logger.info(f"\n{'='*20} STEP {step+1} {'='*20}\n") highscore = max(highscore, info.score) - msg = f"[{env.task_name[:10]:<10}] Step {step} | Score: {info.score}/{info.max_score or '-'} [Best: {highscore}]" + msg = f"[{task_name[:10]:<10}] Step {step} | Score: {info.score}/{info.max_score or '-'} [Best: {highscore}]" self.logger.info(msg) messages = self.build_prompt(info) @@ -253,19 +270,24 @@ def run(self, env: RepoEnv, debug=False): ) self.history.step(info, llm_response) - if ( - info.terminated - or info.rewrite_counter >= self.args.max_rewrite_steps - ): - reason = ( - "terminated" if info.resolved else "max_rewrite_steps reached" - ) + limit_reached = ( + max_rewrite_steps is not None + and info.rewrite_counter >= max_rewrite_steps + ) + + if info.terminated or limit_reached: + if info.resolved: + reason = "terminated" + elif limit_reached: + reason = "max_rewrite_steps reached" + else: + reason = "terminated" self.logger.info( f"Step: {step} | Score: {info.score}/{info.max_score if info.max_score else '-'} | Reason: {reason}" ) # early stop, set current step and total steps to be the same self.logger.report_progress( - problem_id=env.task_name, + problem_id=task_name, step=step + 1, total_steps=step + 1, score=info.score, @@ -275,7 +297,7 @@ def run(self, env: RepoEnv, debug=False): break # keep progress bar running until max_steps is reached self.logger.report_progress( - problem_id=env.task_name, + problem_id=task_name, step=step + 1, total_steps=max_steps + 1, score=info.score, @@ -284,7 +306,7 @@ def run(self, env: RepoEnv, debug=False): ) # max_steps was reached, task was either resolved or unresolved 
self.logger.report_progress( - problem_id=env.task_name, + problem_id=task_name, step=step + 1, total_steps=step + 1, score=info.score, @@ -295,11 +317,11 @@ def run(self, env: RepoEnv, debug=False): except Exception: # report any error that happens during the run self.logger.report_progress( - problem_id=env.task_name, + problem_id=task_name, step=step + 1, total_steps=step + 1, score=info.score if info else 0, - max_score=info.max_score, + max_score=info.max_score if info else None, status="error", ) raise @@ -326,12 +348,22 @@ def apply_patch(self, patch_path: str) -> bool: print("Error:", e.stderr) return False - def build_trajectory(self, task_name) -> Dict[str, Any]: - """Return the trajectory as a JSON-serializable dict without writing it.""" + def save_patch(self, task_name="custom"): + os.makedirs(pjoin(self._output_path, task_name), exist_ok=True) + patch_path = pjoin(self._output_path, task_name, "debug_gym.patch") + with open(patch_path, "w") as f: + f.write(self.env.patch) + + self.logger.debug( + f"Patch saved in {pjoin(self._output_path, task_name, 'debug_gym.patch')}" + ) + + def save_trajectory(self, task_name="custom"): + # Simple tools list. tools = [f"{tool.name}({tool.arguments})" for tool in self.env.tools] json_output = { "problem": task_name, - "config": self.args.to_dict(), + "config": self.config, "tools": self.llm.define_tools(self.env.tools) if self.llm else tools, "uuid": self._uuid, "success": self.env.resolved, @@ -342,16 +374,15 @@ def build_trajectory(self, task_name) -> Dict[str, Any]: for step_id in range(len(self.history)): step_json = self.history.json(step_id) json_output["log"].append(step_json) - return json_output + os.makedirs(pjoin(self._output_path, task_name), exist_ok=True) + json_file = pjoin(self._output_path, task_name, "trajectory.json") + with open(json_file, "w") as f: + json.dump(json_output, f, indent=4) + self.logger.debug(f"Trajectory saved in {json_file}") -def create_agent( - agent_type: str, - *, - agent_args: AgentArgs | Dict[str, Any] | None = None, - config: Dict[str, Any] | None = None, - **agent_kwargs, -): + +def create_agent(agent_type: str, **agent_kwargs): if agent_type in AGENT_REGISTRY: agent_class = AGENT_REGISTRY[agent_type] elif "." in agent_type: @@ -367,9 +398,5 @@ def create_agent( else: raise ValueError(f"Unknown agent type: {agent_type}") - agent_args = agent_args or config - if agent_args is None: - raise ValueError("Either agent_args or config must be provided.") - - agent = agent_class(args=agent_args, **agent_kwargs) + agent = agent_class(**agent_kwargs) return agent diff --git a/debug_gym/agents/debug_agent.py b/debug_gym/agents/debug_agent.py index da17b8af..0008abac 100644 --- a/debug_gym/agents/debug_agent.py +++ b/debug_gym/agents/debug_agent.py @@ -1,42 +1,32 @@ -from dataclasses import dataclass - -from debug_gym.agents.base_agent import register_agent -from debug_gym.agents.froggy_agent import FroggyAgent, FroggyAgentArgs -from debug_gym.gym.envs.env import RepoEnv +from debug_gym.agents.base_agent import BaseAgent, register_agent @register_agent -class DebugAgent(FroggyAgent): +class DebugAgent(BaseAgent): name = "debug_agent" system_prompt = "You are a debugging agent specialized in fixing Python programs. Your goal is to debug a Python program to make sure it can pass a set of test functions. You have access to a set of tools including the pdb debugger to help you investigate the code before proposing a patch. 
While the code may seem familiar to you from your training, you should not assume you know the code. Instead, you must use the pdb debugger to investigate the code and understand the potential bugs. A common debugging workflow is to 1) find suspicious files and lines (from error messages or test failures); 2) set breakpoints at suspicious places; 3) continue execution so the frame is at the breakpoint you set; 4) then print necessary values to identify the bugs. Once you have gained enough information, propose a rewriting patch to fix the bugs. Avoid rewriting the entire code, focus on the bugs only. You must make tool calls to interact with the environment, but you can only call one tool at a time. Do not repeat your previous action, especially if it returned tool calling errors or it resulted in information that you already know. You can spend some time thinking to help you make the decision when you are stuck, but you must be concise and avoid overthinking. If you already had a plan in the previous steps, you can just follow it without repeating the thinking process. If you are confident that you have enough information, propose a patch to fix the bugs by calling the rewrite tool. If you are not sure, continue using the pdb tool to gather more information before proposing a patch. After every rewrite, it's always a good idea to call the eval tool to execute the new code and check if it passes the tests; if it does not, the tool will return the error messages, which you can use to continue debugging. Output both your thinking process (if any) and the tool call (must) in the response. " -@dataclass -class Debug5AgentArgs(FroggyAgentArgs): - n_rewrites_before_pdb: int = 0 - - @register_agent class Debug_5_Agent(DebugAgent): name: str = "debug_5_agent" - def run(self, env: RepoEnv, debug=False): - self.env = env + def run(self, task_name=None, debug=False): step = 0 - max_steps = self.args.max_steps + max_steps = self.config["max_steps"] try: # remove the pdb tool from the environment pdb_tool = self.env.remove_tool("pdb") self.history.reset() - info = self.env.reset() + info = self.env.reset(options={"task_name": task_name}) # initial state does not have prompt and response self.history.step(info, None) if info.resolved is True: # msg = "Environment started with entrypoint passing without errors." 
self.logger.report_progress( - problem_id=env.task_name, + problem_id=task_name, step=1, total_steps=1, score=info.score, @@ -49,8 +39,9 @@ def run(self, env: RepoEnv, debug=False): for step in range(max_steps): self.logger.info(f"\n{'='*20} STEP {step+1} {'='*20}\n") highscore = max(highscore, info.score) - msg = f"[{env.task_name[:10]:<10}] Step {step} | Score: {info.score}/{info.max_score or '-'} [Best: {highscore}]" - self.logger.info(msg) + self.logger.info( + f"Step: {step} | Score: {info.score}/{info.max_score} ({info.score/info.max_score:.1%}) [Best: {highscore}]" + ) messages = self.build_prompt(info) llm_response = self.llm(messages, info.tools) @@ -66,7 +57,7 @@ def run(self, env: RepoEnv, debug=False): # re-introduce pdb tool at the right time if ( - info.rewrite_counter >= self.args.n_rewrites_before_pdb + info.rewrite_counter >= self.config["n_rewrites_before_pdb"] and pdb_tool.name not in self.env.tools ): self.env.add_tool(pdb_tool) @@ -79,17 +70,15 @@ def run(self, env: RepoEnv, debug=False): if ( info.terminated - or info.rewrite_counter >= self.args.max_rewrite_steps + or info.rewrite_counter >= self.config["max_rewrite_steps"] ): - reason = ( - "terminated" if info.resolved else "max_rewrite_steps reached" - ) + reason = "done" if info.resolved else "max_rewrite_steps reached" self.logger.info( - f"Step: {step} | Score: {info.score}/{info.max_score if info.max_score else '-'} | Reason: {reason}" + f"Step: {step} | Score: {info.score}/{info.max_score} ({info.score/info.max_score:.1%}) | Reason: {reason}" ) # early stop, set current step and total steps to be the same self.logger.report_progress( - problem_id=env.task_name, + problem_id=task_name, step=step + 1, total_steps=step + 1, score=info.score, @@ -99,7 +88,7 @@ def run(self, env: RepoEnv, debug=False): break # keep progress bar running until max_steps is reached self.logger.report_progress( - problem_id=env.task_name, + problem_id=task_name, step=step + 1, total_steps=max_steps + 1, score=info.score, @@ -108,7 +97,7 @@ def run(self, env: RepoEnv, debug=False): ) # max_steps was reached, task was either resolved or unresolved self.logger.report_progress( - problem_id=env.task_name, + problem_id=task_name, step=step + 1, total_steps=step + 1, score=info.score, @@ -119,11 +108,11 @@ def run(self, env: RepoEnv, debug=False): except Exception: # report any error that happens during the run self.logger.report_progress( - problem_id=env.task_name, + problem_id=task_name, step=step + 1, total_steps=step + 1, score=info.score if info else 0, - max_score=info.max_score if info else None, + max_score=info.max_score if info else 1, status="error", ) raise diff --git a/debug_gym/agents/free_agent.py b/debug_gym/agents/free_agent.py new file mode 100644 index 00000000..951bde3b --- /dev/null +++ b/debug_gym/agents/free_agent.py @@ -0,0 +1,44 @@ +"""Simple agent example for interacting with FreeEnv.""" + +from debug_gym.agents.base_agent import BaseAgent, register_agent + + +@register_agent +class FreeAgent(BaseAgent): + """Minimal reasoning agent tailored for FreeEnv sessions.""" + + name = "free_agent" + # Customized system instructions keep FreeEnv light-weight while still + # providing the model with a structured exploration checklist. 
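+    # (The default below is only a starting point: __init__ swaps in the
+    # config's "system_prompt" value when one is supplied.)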
+ system_prompt = ( + "You are assisting in an exploratory codebase understanding session inside an open-ended container.\n" + "You have access to a set of tools to inspect and modify the codebase.\n" + "Your goal is to use the tools to gather as much information about the codebase as possible.\n" + "Output both your thinking process (if any) and the tool call (must) in the response.\n" + "When you are done exploring, use the submit tool as the final action to end the session." + ) + + def __init__(self, config, env, llm=None, logger=None): + super().__init__(config=config, env=env, llm=llm, logger=logger) + + override_prompt = config.get("system_prompt") + if override_prompt is not None: + self.system_prompt = str(override_prompt) + + def run(self, task_name=None, debug=False): + """Wrap BaseAgent.run to surface clearer errors when startup fails.""" + try: + return super().run(task_name=task_name, debug=debug) + except AttributeError as exc: + error_msg = str(exc) + sentinel = "'NoneType' object has no attribute 'max_score'" + if sentinel not in error_msg: + raise + + root_cause = exc.__context__ or exc.__cause__ or exc + self.logger.error( + "FreeAgent failed to reset the environment before receiving initial observations. " + "Check that the configured container image exists and is accessible." + ) + + raise root_cause diff --git a/debug_gym/agents/froggy_agent.py b/debug_gym/agents/froggy_agent.py deleted file mode 100644 index 5fd2e475..00000000 --- a/debug_gym/agents/froggy_agent.py +++ /dev/null @@ -1,110 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Dict - -from debug_gym.agents.base_agent import AgentArgs, BaseAgent, register_agent -from debug_gym.agents.history_tracker import build_history_prompt - - -@dataclass -class FroggyAgentArgs(AgentArgs): - show_directory_tree: int = 0 - show_current_breakpoints: bool = False - reset_prompt_history_after_rewrite: bool = False - n_rewrites_before_pdb: int = 0 - - -@register_agent -class FroggyAgent(BaseAgent): - name: str = "froggy" - - def __init__( - self, - agent_args: FroggyAgentArgs | Dict[str, Any], - *args, - **kwargs, - ): - - agent_args = ( - FroggyAgentArgs.from_dict(agent_args) - if isinstance(agent_args, dict) - else agent_args - ) - super().__init__(agent_args, *args, **kwargs) - - def build_history_prompt(self): - messages = build_history_prompt( - self.history, - self.llm, - self.args.reset_prompt_history_after_rewrite, - ) - return messages - - def _auto_eval_on_rewrite(self): - """Check if auto eval on rewrite is enabled.""" - try: - return self.env.get_tool("eval").auto_eval_on_rewrite - except KeyError: - return False # no eval tool - - def shortcut_features(self): - features = [] - if self._auto_eval_on_rewrite(): - features.append( - "After successful rewrites, the environment will automatically " - "call the Eval tool to evaluate the rewritten code. Therefore, " - "you do not need to call the Eval tool yourself. The evaluation " - "output will be updated automatically in the system prompt." - ) - if self.args.show_directory_tree: - features.append( - "The environment will show the directory tree of the repository in the system prompt." - ) - if self.env.has_tool("pdb"): - if self.args.show_current_breakpoints: - features.append( - "The environment will show the current breakpoints in the system prompt." 
- ) - if self.env.get_tool("pdb").persistent_breakpoints: - features.append( - "The environment will automatically restore existing breakpoints " - "when a new PDB session is started (e.g., after a rewrite)." - ) - if self.env.get_tool("pdb").auto_list: - features.append( - "After every valid PDB tool calling, the environment will " - "automatically call the PDB tool again with a `list .` command, " - "which will show the code around the current frame." - ) - return features - - def _default_system_prompt(self, info) -> str: - """Return the default system prompt as pretty JSON. - Trimmed to fit within the token limit.""" - - system_prompt_dict = { - "Overall task": self.system_prompt, - "Instructions": info.instructions, - } - - if self.args.show_directory_tree > 0: - system_prompt_dict["Repo directory tree"] = self.trim_message( - self.env.workspace.display_files(self.args.show_directory_tree), - max_length_percentage=0.1, - where="end", - ) - - if self.args.show_current_breakpoints: - system_prompt_dict["Current breakpoints"] = info.current_breakpoints - - if self._auto_eval_on_rewrite(): - system_prompt_dict["Evaluation output of current code"] = self.trim_message( - info.eval_observation.observation, - max_length_percentage=0.8, - where="middle", - ) - - shortcut_features = self.shortcut_features() - if shortcut_features: - system_prompt_dict["Shortcut features"] = shortcut_features - - return self.to_pretty_json(system_prompt_dict) diff --git a/debug_gym/agents/rewrite_agent.py b/debug_gym/agents/rewrite_agent.py index 0318b4ba..b1abafa2 100644 --- a/debug_gym/agents/rewrite_agent.py +++ b/debug_gym/agents/rewrite_agent.py @@ -1,9 +1,8 @@ -from debug_gym.agents.base_agent import register_agent -from debug_gym.agents.froggy_agent import FroggyAgent +from debug_gym.agents.base_agent import BaseAgent, register_agent @register_agent -class RewriteAgent(FroggyAgent): +class RewriteAgent(BaseAgent): name: str = "rewrite_agent" system_prompt: str = ( "Your goal is to debug a Python program to make sure it can pass a set of test functions. You have access to a set of tools, you can use them to investigate the code and propose a rewriting patch to fix the bugs. Avoid rewriting the entire code, focus on the bugs only. You must make tool calls to interact with the environment, but you can only call one tool at a time. Do not repeat your previous action unless they can provide more information. You can spend some time thinking to help you make the decision when you are stuck, but you must be concise and avoid overthinking. If you already had a plan in the previous steps, you can just follow it without repeating the thinking process. Output both your thinking process (if any) and the tool call (must) in the response. 
" diff --git a/debug_gym/agents/solution_agent.py b/debug_gym/agents/solution_agent.py index e0f31af8..ede86067 100644 --- a/debug_gym/agents/solution_agent.py +++ b/debug_gym/agents/solution_agent.py @@ -20,8 +20,7 @@ def _env_implements_apply_gold_patch(self): """Fail early if the environment does not implement apply_gold_patch.""" return hasattr(self.env, "apply_gold_patch") - def run(self, env, debug=False): - self.env = env + def run(self, task_name=None, debug=False): info = None try: if not self._env_implements_apply_gold_patch(): @@ -31,11 +30,11 @@ def run(self, env, debug=False): ) self.history.reset() - info = self.env.reset() + info = self.env.reset(options={"task_name": task_name}) self.history.step(info) if info.resolved is True: - self._report_progress(env.task_name, info, "resolved") + self._report_progress(task_name, info, "resolved") return True self.logger.info(f"Score: {info.score}/{info.max_score or '-'}") @@ -77,8 +76,8 @@ def run(self, env, debug=False): "The task is not done after applying the gold patch.\n" f"{info.step_observation.observation}" ) - self._report_progress(env.task_name, info, "resolved") + self._report_progress(task_name, info, "resolved") except Exception: - self._report_progress(env.task_name, info, "error") + self._report_progress(task_name, info, "error") raise return info.resolved diff --git a/debug_gym/agents/utils.py b/debug_gym/agents/utils.py index 65578d5c..3d694237 100644 --- a/debug_gym/agents/utils.py +++ b/debug_gym/agents/utils.py @@ -1,13 +1,9 @@ import argparse -import json import logging import os -from pathlib import Path import yaml -from debug_gym.logger import DebugGymLogger - def load_config(): parser = argparse.ArgumentParser() @@ -143,24 +139,3 @@ def load_config(): return_config["agent_type"] = args.agent return return_config, args - - -def save_patch(env, problem_path: Path, logger: DebugGymLogger): - """Persist the current environment patch to disk.""" - problem_path.mkdir(parents=True, exist_ok=True) - patch_path = problem_path / "debug_gym.patch" - with open(patch_path, "w") as f: - f.write(env.patch) - - logger.debug(f"Patch saved in {patch_path}") - - -def save_trajectory(agent, problem: str, problem_path: Path, logger: DebugGymLogger): - """Persist the agent trajectory to disk.""" - problem_path.mkdir(parents=True, exist_ok=True) - trajectory = agent.build_trajectory(task_name=problem) - json_file = problem_path / "trajectory.json" - with open(json_file, "w") as f: - json.dump(trajectory, f, indent=4) - - logger.debug(f"Trajectory saved in {json_file}") diff --git a/debug_gym/gym/envs/__init__.py b/debug_gym/gym/envs/__init__.py index 86ef4cab..89dfd501 100644 --- a/debug_gym/gym/envs/__init__.py +++ b/debug_gym/gym/envs/__init__.py @@ -1,5 +1,6 @@ from debug_gym.gym.envs.aider import AiderBenchmarkEnv from debug_gym.gym.envs.env import RepoEnv, TooledEnv +from debug_gym.gym.envs.free_env import FreeEnv from debug_gym.gym.envs.mini_nightmare import MiniNightmareEnv from debug_gym.gym.envs.r2egym import R2EGymEnv from debug_gym.gym.envs.swe_bench import SWEBenchEnv @@ -23,5 +24,7 @@ def select_env(env_type: str = None) -> type[RepoEnv]: return MiniNightmareEnv case "r2egym": return R2EGymEnv + case "free": + return FreeEnv case _: raise ValueError(f"Unknown benchmark {env_type}") diff --git a/debug_gym/gym/envs/env.py b/debug_gym/gym/envs/env.py index 48807546..210c92f6 100644 --- a/debug_gym/gym/envs/env.py +++ b/debug_gym/gym/envs/env.py @@ -17,6 +17,7 @@ class EnvInfo: step_observation: Observation 
all_observations: list[Observation] # env.step + triggered tools obs eval_observation: Observation | None # last eval observation + dir_tree: str current_breakpoints: str action_reasoning: str | None action_content: str | None @@ -77,6 +78,15 @@ def __str__(self) -> str: ) lines.append("") + # Directory tree section (truncated) + lines.append("📁 Directory Structure:") + tree_lines = self.dir_tree.split("\n") + for line in tree_lines[:10]: # Show first 10 lines + lines.append(f" {line}") + if len(tree_lines) > 10: + lines.append(f" ... and {len(tree_lines) - 10} more files/directories") + + lines.append("=" * 60) return "\n".join(lines) @@ -207,6 +217,9 @@ def __init__( max_score: int | None = None, readonly_patterns: list[str] | None = None, # TODO: remove run_timeout: int | None = None, + dir_tree_depth: int = 1, + persistent_breakpoints: bool = True, # TODO: remove + auto_list: bool = True, # TODO: remove terminal: Terminal | None = None, logger: DebugGymLogger | None = None, problems: str | list[str] | None = None, @@ -217,15 +230,16 @@ def __init__( self.path = path self.max_score = max_score self.run_timeout = run_timeout + self.dir_tree_depth = dir_tree_depth self.terminal = terminal or LocalTerminal() # TODO: default to DockerTerminal self._entrypoint = entrypoint self._debug_entrypoint = debug_entrypoint + self.persistent_breakpoints = persistent_breakpoints + self.auto_list = auto_list self.logger = logger or DebugGymLogger("debug-gym") self.infos: EnvInfo | None = None self.rng = None self.additional_kwargs = kwargs - self.task_name: str | None = None - self.options: dict = {} if "auto_eval_on_rewrite" in kwargs: raise ValueError( @@ -322,11 +336,10 @@ def setup_terminal(self) -> None: def reset(self, *, options: dict = None): """Resets the environment and returns eval as the initial observation.""" - self.options = options if options is not None else self.options + options = options or {} self.logger.debug("Resetting environment") self.close() # Clean up previous workspace and terminal. 
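+        # Everything below is rebuilt from scratch: task metadata, workspace
+        # contents, and the backing terminal session.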
- self.task_name = self.options.get("task_name") - self.setup_task(task_name=self.task_name, options=self.options) + self.setup_task(task_name=options.get("task_name"), options=options) self.setup_workspace() self.setup_terminal() self._reset_env_state() @@ -351,6 +364,7 @@ def reset(self, *, options: dict = None): eval_observation=( Observation("env", self.last_eval.output) if self.last_eval else None ), + dir_tree=self.workspace.display_files(self.dir_tree_depth), current_breakpoints=self.current_breakpoints(), action_reasoning=None, action_content=None, @@ -477,6 +491,7 @@ def step( eval_observation=( Observation("env", self.last_eval.output) if self.last_eval else None ), + dir_tree=self.workspace.display_files(self.dir_tree_depth), current_breakpoints=self.current_breakpoints(), action_reasoning=action_reasoning, action_content=action_content, diff --git a/debug_gym/gym/envs/free_env.py b/debug_gym/gym/envs/free_env.py new file mode 100644 index 00000000..1942f249 --- /dev/null +++ b/debug_gym/gym/envs/free_env.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import shlex +from pathlib import Path +from typing import Any + +from debug_gym.gym.envs.env import RepoEnv +from debug_gym.gym.terminals.local import LocalTerminal +from debug_gym.gym.terminals.terminal import Terminal +from debug_gym.logger import DebugGymLogger + + +class FreeEnv(RepoEnv): + """Lightweight RepoEnv wrapper for running arbitrary container images.""" + + DEFAULT_TASK_NAME = "free-session" + + def __init__( + self, + image: str, + *, + terminal: Terminal | None = None, + mount_path: str | Path | None = None, + setup_commands: list[str] | None = None, + instructions: str | None = None, + init_git: bool = True, + workspace_dir: str | Path = "/testbed", + logger: DebugGymLogger | None = None, + **env_kwargs: Any, + ) -> None: + """Create a free-form environment backed by an existing repository terminal.""" + self.container_image = image + self._custom_instructions = (instructions or "").strip() + self.init_git = init_git + self._setup_commands = list(setup_commands or []) + self._workspace_dir = str(workspace_dir) + + shared_logger = logger or DebugGymLogger("debug-gym") + + super().__init__( + path=str(mount_path) if mount_path is not None else None, + entrypoint="true", + debug_entrypoint="true", + max_score=0, + terminal=terminal, + logger=shared_logger, + **env_kwargs, + ) + + if self.terminal is not None: + self._apply_terminal_settings() + + def _apply_terminal_settings(self) -> None: + """Keep terminal metadata (image/setup commands) in sync with env state.""" + terminal = self.terminal + if terminal is None: + return + if hasattr(terminal, "base_image"): + setattr(terminal, "base_image", self.container_image) + + if hasattr(terminal, "setup_commands"): + terminal.setup_commands = list(self._setup_commands) + + if hasattr(terminal, "working_dir") and not isinstance(terminal, LocalTerminal): + try: + terminal.working_dir = self._workspace_dir + except ValueError: + self.logger.debug( + "Terminal already active; keeping working_dir=%s", + getattr(terminal, "working_dir", self._workspace_dir), + ) + + if hasattr(terminal, "task_name"): + try: + terminal.task_name = self.DEFAULT_TASK_NAME + except ValueError: + self.logger.debug( + "Terminal already active; keeping existing task name." 
+ ) + + terminal.logger = self.logger + + def load_dataset(self, problems: str | list[str] | None = None): + """Expose a single synthetic task keyed by DEFAULT_TASK_NAME.""" + return {self.DEFAULT_TASK_NAME: {"image": self.container_image}} + + def setup_task(self, task_name: str | None, options: dict | None = None) -> None: + """Record base image metadata for consistency with RepoEnv expectations.""" + self.task_name = task_name or self.DEFAULT_TASK_NAME + self.base_image = self.container_image + if hasattr(self.terminal, "base_image"): + setattr(self.terminal, "base_image", self.base_image) + + def setup_workspace(self) -> None: + """Ensure the remote workspace matches the configured working directory.""" + if isinstance(self.terminal, LocalTerminal): + super().setup_workspace() + return + + self.workspace.reset() + self.workspace.working_dir = Path(self._workspace_dir) + if self.terminal is not None: + current_dir = getattr(self.terminal, "working_dir", None) + if current_dir != self._workspace_dir: + try: + self.terminal.working_dir = self._workspace_dir + except ValueError: + self.logger.debug( + "Terminal already active; keeping working_dir=%s", current_dir + ) + # Ensure core utilities exist before RepoEnv renders directory listings. + self.terminal.run( + "apt-get update -y && apt-get install -y tree", raises=True + ) + self.terminal.run( + f"mkdir -p {shlex.quote(self._workspace_dir)}", + raises=True, + ) + + if self.path: + self.workspace.copy_content(self.path) + + self.workspace.setup_file_filters() + + def setup_terminal(self) -> None: + """Apply FreeEnv tweaks and reuse RepoEnv git bootstrapping when enabled.""" + self._apply_terminal_settings() + + if self.terminal is not None: + self.terminal.run("touch .debugignore .debugreadonly") + + if not self.init_git: + return + if not self._git_available(): + self.logger.debug( + "Git is not available in the container; skipping repository setup.", + ) + return + super().setup_terminal() + + def _git_available(self) -> bool: + """Check for git presence before attempting repository initialization.""" + if self.terminal is None: + return False + success, _ = self.terminal.run("command -v git") + return success + + @property + def instructions(self) -> str: + """Provide user-facing guidance, falling back to a generic sandbox blurb.""" + return ( + self._custom_instructions + or "You are placed in an isolated Linux environment, use the available tools to interact with the environment effectively." 
+ ) + + def reset(self, *, options: dict | None = None): + """Allow callers to mutate container settings before delegating to RepoEnv.""" + options = options or {} + + image = options.get("image") + workspace_dir = options.get("workspace_dir") + setup_commands = options.get("setup_commands") + instructions = options.get("instructions") + init_git = options.get("init_git") + + restart_terminal = False + + if image and image != self.container_image: + self.container_image = image + restart_terminal = True + + if workspace_dir and str(workspace_dir) != self._workspace_dir: + self._workspace_dir = str(workspace_dir) + restart_terminal = True + + if setup_commands is not None: + new_commands = list(setup_commands) + if new_commands != self._setup_commands: + self._setup_commands = new_commands + restart_terminal = True + + if instructions is not None: + self._custom_instructions = instructions + + if init_git is not None: + self.init_git = bool(init_git) + + if restart_terminal and self.terminal is not None: + try: + self.terminal.close() + except Exception as exc: # noqa: BLE001 - diagnostics only + self.logger.debug("Failed to close terminal cleanly: %s", exc) + + self._apply_terminal_settings() + + return super().reset(options=options) diff --git a/debug_gym/gym/terminals/__init__.py b/debug_gym/gym/terminals/__init__.py index 068a8b6a..1f34e1d8 100644 --- a/debug_gym/gym/terminals/__init__.py +++ b/debug_gym/gym/terminals/__init__.py @@ -13,8 +13,20 @@ def select_terminal( if terminal_config is None: return None + if isinstance(terminal_config, Terminal): + return terminal_config + + if not isinstance(terminal_config, dict): + raise TypeError( + "terminal configuration must be a dict, Terminal instance, or None", + ) + + config = dict(terminal_config) + terminal_type = str(config.pop("type", "")).lower() + if not terminal_type: + raise ValueError("Terminal configuration must include a 'type' key") + logger = logger or DebugGymLogger("debug-gym") - terminal_type = terminal_config["type"] match terminal_type: case "docker": terminal_class = DockerTerminal @@ -25,8 +37,17 @@ def select_terminal( case _: raise ValueError(f"Unknown terminal {terminal_type}") + extra_labels = config.pop("extra_labels", {}) or {} + if uuid is not None: + extra_labels = {**extra_labels, "uuid": uuid} + + if terminal_class is KubernetesTerminal and extra_labels: + config["extra_labels"] = extra_labels + + if terminal_class is not KubernetesTerminal: + config.pop("extra_labels", None) + return terminal_class( - **terminal_config, logger=logger, - extra_labels={"uuid": uuid}, + **config, ) diff --git a/debug_gym/gym/terminals/kubernetes.py b/debug_gym/gym/terminals/kubernetes.py index 5c5f39bc..9d731bd8 100644 --- a/debug_gym/gym/terminals/kubernetes.py +++ b/debug_gym/gym/terminals/kubernetes.py @@ -370,9 +370,28 @@ def default_shell_command(self) -> list[str]: bash_cmd = "/bin/bash --noprofile --norc --noediting" return f"kubectl {kubeconfig}exec -it {self.pod.name} -c main -n {self.pod.namespace} -- {bash_cmd}" + def _ensure_pod_running(self) -> None: + """Ensure the backing pod exists and is in Running phase.""" + if self._pod is None: + self.setup_pod() + return + + try: + if self._pod.is_running(): + return + except Exception as exc: # noqa: BLE001 - diagnostics only + self.logger.debug(f"{self._pod} status check failed: {exc}") + + self.logger.warning(f"{self._pod} not running; recreating pod.") + try: + self._pod.clean_up() + except Exception as exc: # noqa: BLE001 - best-effort cleanup + 
self.logger.debug(f"Failed to clean up {self._pod}: {exc}") + self._pod = None + self.setup_pod() + def new_shell_session(self): - if not self.pod.is_running(): - raise ValueError("Pod is not running. Cannot create shell session.") + self._ensure_pod_running() session = ShellSession( shell_command=self.default_shell_command, @@ -416,8 +435,7 @@ def run( strip_output: bool = True, ) -> tuple[bool, str]: """Run a command in the pod. Return command status and output.""" - if not self.pod.is_running(): - raise ValueError("Pod is not running. Cannot run commands.") + self._ensure_pod_running() command = self.prepare_command(entrypoint) diff --git a/debug_gym/gym/tools/pdb.py b/debug_gym/gym/tools/pdb.py index 107a9d33..46ed1720 100644 --- a/debug_gym/gym/tools/pdb.py +++ b/debug_gym/gym/tools/pdb.py @@ -35,26 +35,12 @@ class PDBTool(EnvironmentTool): }, } - def __init__( - self, - set_default_entrypoint: bool = True, - auto_list: bool = True, - persistent_breakpoints: bool = True, - ): - """ - Args: - set_default_entrypoint (bool): If True, the tool will use the environment's default debug entrypoint - when no entrypoint is provided. If False, the agent must provide an entrypoint when using the tool. - auto_list (bool): If True, the tool will automatically provide context around the current frame after each command. - persistent_breakpoints (bool): If True, the tool will keep breakpoints across PDB sessions. - """ + def __init__(self, set_default_entrypoint: bool = True): super().__init__() self.current_frame_file = None self._session: ShellSession = None self.set_default_entrypoint = set_default_entrypoint self.entrypoint = None - self.auto_list = auto_list - self.persistent_breakpoints = persistent_breakpoints if not self.set_default_entrypoint: # Force the agent to provide an entrypoint when using the tool. self.arguments = copy.deepcopy( @@ -128,7 +114,7 @@ def start_pdb(self, environment) -> str: self.stop_pdb() if self.pdb_is_running: - if self.persistent_breakpoints: + if environment.persistent_breakpoints: # restore persistent breakpoints for _, _command in environment.current_breakpoints_state.items(): self.interact_with_pdb(_command, environment.run_timeout) @@ -266,7 +252,7 @@ def use( # free 'list' to provide context around the current frame list_output = "" - if self.auto_list and command.split()[0] not in ["l", "list"]: + if environment.auto_list and command.split()[0] not in ["l", "list"]: list_output = self.interact_with_pdb("l .", environment.run_timeout) if current_frame: diff --git a/debug_gym/gym/tools/submit.py b/debug_gym/gym/tools/submit.py index eb08e578..00514992 100644 --- a/debug_gym/gym/tools/submit.py +++ b/debug_gym/gym/tools/submit.py @@ -9,7 +9,14 @@ class SubmitTool(EnvironmentTool): description = "Submit your changes once the task is complete." arguments = {} + def __init__(self, eval_on_submit=True): + super().__init__() + self.eval_on_submit = eval_on_submit + def use(self, environment, **kwargs) -> Observation: - eval_output = environment.eval() + output = "The agent terminated the session." 
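+        # Default observation when eval_on_submit is disabled; replaced by the
+        # eval output below when evaluation is enabled.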
+        if self.eval_on_submit:
+            output = environment.eval().output
+
         environment.terminated = True
-        return Observation(self.name, eval_output.output)
+        return Observation(self.name, output)
diff --git a/debug_gym/llms/openai.py b/debug_gym/llms/openai.py
index ed0b9006..b98e57df 100644
--- a/debug_gym/llms/openai.py
+++ b/debug_gym/llms/openai.py
@@ -261,6 +261,11 @@ def generate(self, messages, tools, **kwargs) -> LLMResponse:
             if self.is_context_length_error(e):
                 raise ContextLengthExceededError
             raise e
+        if not hasattr(response, "choices"):
+            raise RuntimeError(
+                "OpenAI chat completion returned unexpected payload without 'choices'"
+            )
+
         # LLM may select multiple tool calls, we only care about the first action
         if not response.choices[0].message.tool_calls:
             # LLM failed to call a tool
diff --git a/scripts/config.yaml b/scripts/config.yaml
index ee3952c5..3083f264 100644
--- a/scripts/config.yaml
+++ b/scripts/config.yaml
@@ -5,7 +5,13 @@ base:
     "path": "data/pytorch",
     "entrypoint": "python -m pytest -sv test.py",
     "debug_entrypoint": "python -m pdb -m pytest -s test.py",
+    "dir_tree_depth": 1,
     "run_timeout": 10,
+    # shortcut features
+    "show_current_breakpoints": False, # If True, the environment will automatically show the current breakpoints at every step in the system prompt.
+    "show_directory_tree": True, # If set to True, the environment will show the directory tree in the system prompt.
+    "persistent_breakpoints": True, # If True, the environment will keep a set of breakpoint states across PDB sessions. When a new PDB session is started, the environment will automatically load the breakpoints from the previous session.
+    "auto_list": True, # If True, the environment will automatically call `list .` via the PDB tool after every pdb tool call, which will show the code around the current frame.
   }
   tools: ["pdb", "view", "rewrite"]
   terminal: {
@@ -25,10 +31,6 @@ base:
   # Optionally loads a custom system prompt template from a file.
  # system_prompt_template_file: "script/templates/system_prompt.jinja"
 
-  # Shortcut features
-  "show_current_breakpoints": False # If True, the environment will automatically show the current breakpoints at every step in the system prompt.
-  "show_directory_tree": 0 # Value indicated the depth of the directory shown in the system prompt. 0 means no directory tree is shown.
-
 rewrite_agent:
   tools:
     - grep
diff --git a/scripts/config_aider.yaml b/scripts/config_aider.yaml
index 88dd68fb..39d80e61 100644
--- a/scripts/config_aider.yaml
+++ b/scripts/config_aider.yaml
@@ -4,7 +4,13 @@ base:
   benchmark: "aider"
   problems: "all" # list of problems, e.g., ["wordy"], or "all"
   env_kwargs: {
+    "dir_tree_depth": 1,
     "run_timeout": 20,
+    # shortcut features
+    "show_current_breakpoints": False, # If True, the environment will automatically show the current breakpoints at every step in the system prompt.
+    "show_directory_tree": True, # If set to True, the environment will show the directory tree in the system prompt.
+    "persistent_breakpoints": True, # If True, the environment will keep a set of breakpoint states across PDB sessions. When a new PDB session is started, the environment will automatically load the breakpoints from the previous session.
+    "auto_list": True, # If True, the environment will automatically call `list .` via the PDB tool after every pdb tool call, which will show the code around the current frame.
   }
   terminal: {
     type: "docker", # "docker", "kubernetes", or "local"
@@ -23,10 +29,6 @@ base:
   # Optionally loads a custom system prompt template from a file.
   # system_prompt_template_file: "script/templates/system_prompt.jinja"
 
-  # Shortcut features
-  "show_current_breakpoints": False # If True, the environment will automatically show the current breakpoints at every step in the system prompt.
-  "show_directory_tree": 0 # Value indicated the depth of the directory shown in the system prompt. 0 means no directory tree is shown.
-
 rewrite_agent:
   tools:
     - grep
diff --git a/scripts/config_free_env.yaml b/scripts/config_free_env.yaml
new file mode 100644
index 00000000..2f9b06cc
--- /dev/null
+++ b/scripts/config_free_env.yaml
@@ -0,0 +1,42 @@
+# Configuration for standalone FreeEnv + FreeAgent runs.
+task_name: free-session
+
+llm:
+  name: "4o-az"
+
+# Tools to load into the environment toolbox.
+tools:
+  - rewrite
+  - bash
+  - submit:
+      eval_on_submit: False # Here we only terminate after submission, no auto-eval.
+
+environment:
+  image: jyangballin/swesmith.x86_64.amueller_1776_word_cloud.ec24191c:latest
+  workspace_dir: /testbed
+  terminal:
+    type: docker
+    # type: kubernetes
+    # registry: debuggymacr.azurecr.io
+    # namespace: mtl-cpu-jobs
+    # kube_config: ~/.kube/config
+    # # kube_context: null
+    # pod_spec_kwargs:
+    #   tolerations:
+    #     - key: node.kubernetes.io/disk-pressure
+    #       operator: Exists
+    #       effect: NoExecute
+    #       tolerationSeconds: 10800
+    #     - key: kubernetes.azure.com/scalesetpriority
+    #       operator: Equal
+    #       value: spot
+    #       effect: NoSchedule
+    #     - key: CriticalAddonsOnly
+    #       operator: Equal
+    #       value: "true"
+    #       effect: NoSchedule
+
+agent:
+  random_seed: 42
+  max_steps: 20
+  output_path: exps/free_env
diff --git a/scripts/config_mini_nightmare.yaml b/scripts/config_mini_nightmare.yaml
index 88fbc08a..ef7a20f2 100644
--- a/scripts/config_mini_nightmare.yaml
+++ b/scripts/config_mini_nightmare.yaml
@@ -4,8 +4,13 @@ base:
   benchmark: "mini_nightmare"
   problems: "all" # list of problems, e.g., ["config"], or "all"
   env_kwargs: {
+    "dir_tree_depth": 1,
     "run_timeout": 30,
     # shortcut features
+    "show_current_breakpoints": False, # If True, the environment will automatically show the current breakpoints at every step in the system prompt.
+    "show_directory_tree": True, # If set to True, the environment will show the directory tree in the system prompt.
+    "persistent_breakpoints": True, # If True, the environment will keep a set of breakpoint states across PDB sessions. When a new PDB session is started, the environment will automatically load the breakpoints from the previous session.
+    "auto_list": True, # If True, the environment will automatically call `list .` via the PDB tool after every pdb tool call, which will show the code around the current frame.
   }
 
   terminal: {
@@ -25,10 +30,6 @@ base:
   # Optionally loads a custom system prompt template from a file.
   # system_prompt_template_file: "script/templates/system_prompt.jinja"
 
-  # Shortcut features
-  "show_current_breakpoints": False # If True, the environment will automatically show the current breakpoints at every step in the system prompt.
-  "show_directory_tree": 0 # Value indicated the depth of the directory shown in the system prompt. 0 means no directory tree is shown.
-
 rewrite_agent:
   tools:
     - grep
diff --git a/scripts/config_r2egym.yaml b/scripts/config_r2egym.yaml
index 8d14b79e..9fc61cb6 100644
--- a/scripts/config_r2egym.yaml
+++ b/scripts/config_r2egym.yaml
@@ -4,9 +4,16 @@ base:
   benchmark: "r2egym"
   problems: "all" # list of problems, e.g., ["astropy__astropy-12907"], or strings like "test-125" (defined in gym/envs/configs), or "all",
   env_kwargs: {
+    "dir_tree_depth": 1,
     "run_timeout": 300,
     dataset_id: "R2E-Gym/R2E-Gym-Lite",
-    dataset_revision: "8d3163011f01f9393bb3dc7700497a79a8686ae5"
+    dataset_revision: "8d3163011f01f9393bb3dc7700497a79a8686ae5",
+
+    # shortcut features
+    "show_current_breakpoints": False, # If True, the environment will automatically show the current breakpoints at every step in the system prompt.
+    "show_directory_tree": True, # If set to True, the environment will show the directory tree in the system prompt.
+    "persistent_breakpoints": True, # If True, the environment will keep a set of breakpoint states across PDB sessions. When a new PDB session is started, the environment will automatically load the breakpoints from the previous session.
+    "auto_list": True, # If True, the environment will automatically call `list .` via the PDB tool after every pdb tool call, which will show the code around the current frame.
   }
   terminal: {
     type: "docker", # "docker", "kubernetes"
@@ -25,10 +32,6 @@ base:
   # Optionally loads a custom system prompt template from a file.
   # system_prompt_template_file: "script/templates/system_prompt.jinja"
 
-  # Shortcut features
-  "show_current_breakpoints": False # If True, the environment will automatically show the current breakpoints at every step in the system prompt.
-  "show_directory_tree": 0 # Value indicated the depth of the directory shown in the system prompt. 0 means no directory tree is shown.
-
 rewrite_agent:
   tools:
     - grep
diff --git a/scripts/config_swebench.yaml b/scripts/config_swebench.yaml
index 8bc0ba55..b7c9352e 100644
--- a/scripts/config_swebench.yaml
+++ b/scripts/config_swebench.yaml
@@ -4,9 +4,15 @@ base:
   benchmark: "swebench-debug"
   problems: "all" # list of problems, e.g., ["astropy__astropy-12907"], or "all"
   env_kwargs: {
+    "dir_tree_depth": 1,
     "run_timeout": 300,
     "dataset_id": "SWE-bench/SWE-bench_Verified",
     "dataset_revision": "99450355ca8c611021187a57ffac304b66666738",
+    # shortcut features
+    "show_current_breakpoints": False, # If True, the environment will automatically show the current breakpoints at every step in the system prompt.
+    "show_directory_tree": True, # If set to True, the environment will show the directory tree in the system prompt.
+    "persistent_breakpoints": True, # If True, the environment will keep a set of breakpoint states across PDB sessions. When a new PDB session is started, the environment will automatically load the breakpoints from the previous session.
+    "auto_list": True, # If True, the environment will automatically call `list .` via the PDB tool after every pdb tool call, which will show the code around the current frame.
   }
   terminal: {
     type: "docker", # "docker", "kubernetes"
@@ -25,12 +31,8 @@ base:
   # Optionally loads a custom system prompt template from a file.
   # system_prompt_template_file: "script/templates/system_prompt.jinja"
 
-  # Shortcut features
-  "show_current_breakpoints": False # If True, the environment will automatically show the current breakpoints at every step in the system prompt.
-  "show_directory_tree": 0 # Value indicated the depth of the directory shown in the system prompt. 0 means no directory tree is shown.
-
 rewrite_agent:
-  tools: 
+  tools:
     - grep
     - view
     - rewrite
@@ -39,7 +41,7 @@ rewrite_agent:
   auto_eval_on_rewrite: False # If True, the environment will automatically call the Eval tool after a successful rewrite. If this is set to True, the agent does not need to call the Eval tool itself.
 
 debug_agent:
-  tools: 
+  tools:
     - grep
     - pdb
     - view
@@ -49,7 +51,7 @@ debug_agent:
 
 debug_5_agent:
   n_rewrites_before_pdb: 5
-  tools: 
+  tools:
     - grep
     - pdb
     - view
@@ -59,7 +61,7 @@ debug_5_agent:
 
 solution_agent:
   llm_name: null # No need for an LLM.
-  tools: 
+  tools:
     - eval
     - pdb
     - submit
diff --git a/scripts/config_swesmith.yaml b/scripts/config_swesmith.yaml
index 5862e240..88a3215c 100644
--- a/scripts/config_swesmith.yaml
+++ b/scripts/config_swesmith.yaml
@@ -4,8 +4,15 @@ base:
   benchmark: "swesmith"
   problems: "all" # list of problems, e.g., ["astropy__astropy-12907"], or strings like "test-125" (defined in gym/envs/configs), or "all",
   env_kwargs: {
+    "dir_tree_depth": 1,
     "run_timeout": 300,
-    "dataset_id": "SWE-bench/SWE-smith"
+    "dataset_id": "SWE-bench/SWE-smith",
+
+    # shortcut features
+    "show_current_breakpoints": False, # If True, the environment will automatically show the current breakpoints at every step in the system prompt.
+    "show_directory_tree": True, # If set to True, the environment will show the directory tree in the system prompt.
+    "persistent_breakpoints": True, # If True, the environment will keep a set of breakpoint states across PDB sessions. When a new PDB session is started, the environment will automatically load the breakpoints from the previous session.
+    "auto_list": True, # If True, the environment will automatically call `list .` via the PDB tool after every pdb tool call, which will show the code around the current frame.
   }
   terminal: {
     type: "docker", # "docker", "kubernetes"
@@ -24,10 +31,6 @@ base:
   # Optionally loads a custom system prompt template from a file.
   # system_prompt_template_file: "script/templates/system_prompt.jinja"
 
-  # Shortcut features
-  "show_current_breakpoints": False # If True, the environment will automatically show the current breakpoints at every step in the system prompt.
-  "show_directory_tree": 0 # Value indicated the depth of the directory shown in the system prompt. 0 means no directory tree is shown.
- rewrite_agent: tools: - grep diff --git a/scripts/free_env_human.py b/scripts/free_env_human.py new file mode 100644 index 00000000..a82a9e6a --- /dev/null +++ b/scripts/free_env_human.py @@ -0,0 +1,200 @@ +"""Interactive FreeEnv demo that runs a container image with a human operator.""" + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Any, Iterable + +from debug_gym.gym.envs.free_env import FreeEnv +from debug_gym.gym.terminals import select_terminal +from debug_gym.gym.terminals.terminal import Terminal +from debug_gym.gym.tools.toolbox import Toolbox +from debug_gym.llms.human import Human +from debug_gym.logger import DebugGymLogger + +DEFAULT_IMAGE = "swesmith.x86_64.amueller__word_cloud.ec24191c" +DEFAULT_TOOLS = [ + "listdir", + "view", + "grep", + "rewrite", + "bash", + {"submit": {"eval_on_submit": False}}, +] + + +def format_observations(env_info) -> list[dict]: + messages = [ + { + "role": "system", + "content": env_info.instructions or "Interact with the repository.", + } + ] + + instructions_text = (env_info.instructions or "").strip() + for index, observation in enumerate(env_info.all_observations): + text = observation.observation.strip() + if index == 0 and text == instructions_text: + continue + prefix = f"[{observation.source}] " if observation.source else "" + messages.append({"role": "user", "content": f"{prefix}{text}"}) + return messages + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Launch a FreeEnv session with human-in-the-loop control.", + ) + parser.add_argument( + "--image", + default=DEFAULT_IMAGE, + help="Docker image name to load inside the environment.", + ) + parser.add_argument( + "--terminal", + default="docker", + choices=["docker", "kubernetes"], + help="Terminal backend to use.", + ) + parser.add_argument( + "--registry", + default=None, + help="Optional registry prefix (e.g. 
ghcr.io/swe-bench).", + ) + parser.add_argument( + "--workspace-dir", + default="/testbed", + help="Working directory inside the container or pod.", + ) + parser.add_argument( + "--mount-path", + type=Path, + default=None, + help="Optional host path whose contents should be copied into the environment.", + ) + parser.add_argument( + "--setup-command", + action="append", + default=[], + help="Additional setup commands to run when the terminal starts (repeatable).", + ) + parser.add_argument( + "--tool", + dest="tools", + action="append", + default=None, + help="Tool name to add to the toolbox (can be specified multiple times).", + ) + parser.add_argument( + "--init-git", + action="store_true", + help="Initialize a git repository inside the environment (disabled by default).", + ) + parser.add_argument( + "--instructions", + default=None, + help="Custom instruction text displayed at reset.", + ) + parser.add_argument( + "--max-retries", + type=int, + default=10, + help="Maximum number of retries for invalid human tool calls.", + ) + parser.add_argument( + "--dir-tree-depth", + type=int, + default=2, + help="Depth of the directory tree shown in observations.", + ) + return parser + + +def _add_tools(env: FreeEnv, tool_specs: Iterable[Any], logger: DebugGymLogger) -> None: + """Attach toolbox entries, defaulting submit to eval_on_submit=False for humans.""" + + for spec in tool_specs: + tool_kwargs: dict[str, Any] = {} + if isinstance(spec, dict): + if len(spec) != 1: + raise ValueError("Tool dictionary must contain exactly one entry") + spec = dict(spec) + tool_name, tool_kwargs = next(iter(spec.items())) + else: + tool_name = str(spec) + + if tool_name == "submit" and "eval_on_submit" not in tool_kwargs: + tool_kwargs = {**tool_kwargs, "eval_on_submit": False} + + env.add_tool(Toolbox.get_tool(tool_name, **tool_kwargs)) + logger.debug("Loaded tool %s with options %s", tool_name, tool_kwargs) + + +def main(argv: list[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + + logger = DebugGymLogger("free-env-demo") + + tool_specs: list[Any] + if args.tools: + # User-specified tools override defaults but still respect submit behaviour. 
+ tool_specs = list(args.tools) + else: + tool_specs = list(DEFAULT_TOOLS) + + terminal_config: dict[str, Any] = { + "type": args.terminal, + "base_image": args.image, + "working_dir": args.workspace_dir, + } + if args.setup_command: + terminal_config["setup_commands"] = list(args.setup_command) + if args.registry: + terminal_config["registry"] = args.registry + + terminal: Terminal | None = select_terminal(terminal_config, logger=logger) + + env = FreeEnv( + image=args.image, + terminal=terminal, + mount_path=args.mount_path, + setup_commands=args.setup_command, + instructions=args.instructions, + init_git=args.init_git, + workspace_dir=args.workspace_dir, + logger=logger, + dir_tree_depth=args.dir_tree_depth, + ) + + _add_tools(env, tool_specs, logger) + logger.info("Loaded tools: %s", env.tool_names) + + info = env.reset() + human = Human(logger=logger, max_retries=args.max_retries) + + try: + while True: + messages = format_observations(info) + response = human(messages, env.tools) + logger.info( + "Running %s with arguments %s", + response.tool.name, + response.tool.arguments, + ) + info = env.step( + response.tool, + action_content=response.response, + ) + except KeyboardInterrupt: + logger.info("Session interrupted by user.") + except ValueError as exc: + logger.error("Session terminated: %s", exc) + finally: + env.close() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/run.py b/scripts/run.py index a39435cf..5139afd1 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -9,8 +9,8 @@ from pathlib import Path from debug_gym import version as dg_version -from debug_gym.agents.base_agent import AGENT_REGISTRY, AgentArgs, create_agent -from debug_gym.agents.utils import load_config, save_patch, save_trajectory +from debug_gym.agents.base_agent import AGENT_REGISTRY, create_agent +from debug_gym.agents.utils import load_config from debug_gym.gym.envs import select_env from debug_gym.gym.terminals import select_terminal from debug_gym.gym.tools.toolbox import Toolbox @@ -99,10 +99,9 @@ def run_agent(args, problem, config): logger=task_logger, ) - agent_args = AgentArgs.from_dict(config) agent = create_agent( config["agent_type"], - agent_args=agent_args, + config=config, env=env, llm=llm, logger=task_logger, @@ -142,11 +141,11 @@ def run_agent(args, problem, config): raise # save trajectory - save_trajectory(agent, problem, problem_path, task_logger) + agent.save_trajectory(task_name=problem) # optionally apply patch if config["save_patch"]: - save_patch(env, problem_path, task_logger) + agent.save_patch(task_name=problem) except Exception as e: task_logger.error( diff --git a/scripts/run_free_env.py b/scripts/run_free_env.py new file mode 100644 index 00000000..13d7b367 --- /dev/null +++ b/scripts/run_free_env.py @@ -0,0 +1,163 @@ +"""Standalone runner for FreeEnv + FreeAgent with human-visible logging.""" + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Any, Mapping + +from debug_gym.agents.free_agent import FreeAgent +from debug_gym.gym.envs.free_env import FreeEnv +from debug_gym.gym.terminals import select_terminal +from debug_gym.gym.terminals.terminal import Terminal +from debug_gym.gym.tools.toolbox import Toolbox +from debug_gym.llms.base import LLM +from debug_gym.llms.human import Human +from debug_gym.logger import DebugGymLogger + + +def build_parser() -> argparse.ArgumentParser: + """Create the CLI parser that exposes the runner configuration flag.""" + parser = 
argparse.ArgumentParser(description="Run FreeAgent against FreeEnv.") + parser.add_argument( + "--config", + type=Path, + default=Path("scripts/config_free_env.yaml"), + help="Path to the YAML configuration file.", + ) + return parser + + +def load_app_config(path: Path) -> dict: + """Load the YAML configuration used to seed the environment and agent.""" + import yaml + + with open(path, "r", encoding="utf-8") as handle: + return yaml.safe_load(handle) + + +def build_llm(config: dict, logger: DebugGymLogger): + """Instantiate the LLM (or human driver) based on configuration defaults.""" + llm_cfg = config.get("llm") or {} + llm_name = llm_cfg.get("name") or config.get("llm_name") or "human" + + if llm_name.lower() == "human": + return Human(model_name="human", logger=logger) + + return LLM.instantiate( + llm_name=llm_name, + llm_config_file_path=llm_cfg.get("config_file") + or config.get("llm_config_file_path"), + logger=logger, + ) + + +def resolve_terminal( + env_config: Mapping[str, Any], + logger: DebugGymLogger, +) -> Terminal | None: + """Resolve the requested terminal backend, normalizing legacy config shapes.""" + terminal_setting = env_config.get("terminal") + + if isinstance(terminal_setting, Terminal): + return terminal_setting + + if terminal_setting is None: + terminal_config: dict[str, Any] = {"type": "docker"} + elif isinstance(terminal_setting, str): + terminal_config = {"type": terminal_setting} + elif isinstance(terminal_setting, Mapping): + terminal_config = dict(terminal_setting) + else: + raise TypeError( + "terminal configuration must be a mapping, string, Terminal, or None", + ) + + terminal_config.setdefault("type", "docker") + terminal_config["type"] = str(terminal_config["type"]).lower() + terminal_config.setdefault("base_image", env_config["image"]) + terminal_config.setdefault( + "working_dir", env_config.get("workspace_dir", "/testbed") + ) + + setup_commands = env_config.get("setup_commands") + if setup_commands: + terminal_config.setdefault("setup_commands", list(setup_commands)) + + overrides = dict(env_config.get("terminal_kwargs") or {}) + terminal_config.update(overrides) + + return select_terminal(terminal_config, logger=logger) + + +def add_tools(env: FreeEnv, tools_config: list[Any], logger: DebugGymLogger) -> None: + """Instantiate tools defined in config, honoring optional per-tool kwargs.""" + + for tool_entry in tools_config: + tool_kwargs: dict[str, Any] = {} + if isinstance(tool_entry, Mapping): + if len(tool_entry) != 1: + raise ValueError("Tool mapping entries must contain a single tool name") + tool_entry = dict(tool_entry) + tool_name, tool_kwargs = next(iter(tool_entry.items())) + else: + tool_name = str(tool_entry) + + if tool_name == "submit" and "eval_on_submit" not in tool_kwargs: + tool_kwargs = {**tool_kwargs, "eval_on_submit": False} + + env.add_tool(Toolbox.get_tool(tool_name, **tool_kwargs)) + logger.debug("Added tool %s with options %s", tool_name, tool_kwargs) + + +def main() -> int: + """Entrypoint for running FreeAgent against FreeEnv from the command line.""" + args = build_parser().parse_args() + config = load_app_config(args.config) + + logger = DebugGymLogger("free-agent-run") + + env_cfg = config["environment"] + terminal = resolve_terminal(env_cfg, logger) + # Copy only the knobs understood by FreeEnv, leaving unrelated config behind. 
+ env_kwargs = dict( + image=env_cfg["image"], + terminal=terminal, + mount_path=env_cfg.get("mount_path"), + setup_commands=env_cfg.get("setup_commands"), + instructions=env_cfg.get("instructions"), + init_git=env_cfg.get("init_git", True), + workspace_dir=env_cfg.get("workspace_dir", "/testbed"), + logger=logger, + dir_tree_depth=env_cfg.get("dir_tree_depth", 2), + ) + + # Instantiate the environment once the terminal and core parameters are ready. + env = FreeEnv(**env_kwargs) + + tools_config = config.get("tools") + if not tools_config: + raise ValueError( + "Configuration must specify a non-empty 'tools' list for FreeEnv sessions." + ) + + add_tools(env, tools_config, logger) + + llm = build_llm(config, logger) + agent_config = config.get("agent", {}) + agent = FreeAgent(config=agent_config, env=env, llm=llm, logger=logger) + + task_name = config.get("task_name", "free-session") + + try: + resolved = agent.run(task_name=task_name) + agent.save_trajectory(task_name=task_name) + agent.save_patch(task_name=task_name) + logger.info(f"Run complete. Resolved={resolved}") + return 0 + finally: + env.close() + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/templates/human_friendly_system_prompt.jinja b/scripts/templates/human_friendly_system_prompt.jinja index 51b25642..37823a4d 100644 --- a/scripts/templates/human_friendly_system_prompt.jinja +++ b/scripts/templates/human_friendly_system_prompt.jinja @@ -9,6 +9,10 @@ +******** Repo directory tree ******** +{{ info.dir_tree }} + + ******** Current breakpoints ******** {{ info.current_breakpoints }}{% if info.eval_observation.observation %} diff --git a/tests/agents/conftest.py b/tests/agents/conftest.py index b28bc008..503faf2c 100644 --- a/tests/agents/conftest.py +++ b/tests/agents/conftest.py @@ -3,8 +3,6 @@ import pytest -from debug_gym.agents.base_agent import AgentArgs - @pytest.fixture def open_data(): @@ -29,7 +27,7 @@ def agent_setup(tmp_path, open_data): def _length(text): return len(text) - def _agent_setup(agent_class): + def _agent_setup(agent_class, *, config_override=None): with ( patch("tiktoken.encoding_for_model") as mock_encoding_for_model, patch("os.path.exists", return_value=True), @@ -47,18 +45,19 @@ def _agent_setup(agent_class): "n_rewrites_before_pdb": 2, "reset_prompt_history_after_rewrite": False, "memory_size": 10, + "output_path": str(tmp_path), "random_seed": 42, } + if config_override: + config_dict.update(config_override) env = MagicMock() - env.task_name = "test_task" llm = MagicMock() llm.reasoning_end_token = None llm.context_length = 4096 llm.count_tokens = _length llm.define_tools = lambda x: x - agent = agent_class(config_dict) + agent = agent_class(config_dict, env) agent.llm = llm - agent.env = env yield agent, env, llm return _agent_setup diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index 4a4af61e..ff1b6418 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -1,20 +1,19 @@ import json +import os import subprocess -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock, mock_open, patch import pytest from jinja2 import Template from debug_gym.agents.base_agent import ( AGENT_REGISTRY, - AgentArgs, BaseAgent, create_agent, register_agent, ) from debug_gym.agents.debug_agent import Debug_5_Agent, DebugAgent from debug_gym.agents.rewrite_agent import RewriteAgent -from debug_gym.agents.utils import save_patch, save_trajectory from debug_gym.gym.tools.toolbox import Toolbox from 
debug_gym.llms.base import LLMResponse, TokenUsage @@ -28,6 +27,7 @@ def test_default_system_prompt(agent_setup, build_env_info): agent.shortcut_features = Mock(return_value=["f1", "f2"]) info = build_env_info( instructions="some instruction", + dir_tree="dir tree", current_breakpoints=[], eval_observation="eval obs", ) @@ -57,6 +57,7 @@ def test_default_system_prompt_auto_eval(agent_setup, build_env_info): env.add_tool(eval_tool) info = build_env_info( instructions="some instruction", + dir_tree="dir tree", current_breakpoints=[], eval_observation="eval obs", ) @@ -89,6 +90,7 @@ def test_load_system_prompt_template_default_no_shortcuts_or_eval( agent.shortcut_features = Mock(return_value=[]) info = build_env_info( instructions="some instruction", + dir_tree="dir tree", current_breakpoints=[1, 2], eval_observation="", ) @@ -114,7 +116,7 @@ def test_load_system_prompt_template_from_file(tmp_path, agent_setup): template_content = "Task: {{ agent.system_prompt }}" template_path = tmp_path / "template.jinja" template_path.write_text(template_content) - agent.args.system_prompt_template_file = str(template_path) + agent.config["system_prompt_template_file"] = str(template_path) template = agent._load_system_prompt_template() assert isinstance(template, Template) assert template.render(agent=agent) == "Task: test task" @@ -122,7 +124,7 @@ def test_load_system_prompt_template_from_file(tmp_path, agent_setup): def test_load_system_prompt_template_file_not_found(agent_setup): agent, _, _ = next(agent_setup(DebugAgent)) - agent.args.system_prompt_template_file = "non_existent_template.jinja" + agent.config["system_prompt_template_file"] = "non_existent_template.jinja" with pytest.raises(FileNotFoundError): agent._load_system_prompt_template() @@ -130,21 +132,18 @@ def test_load_system_prompt_template_file_not_found(agent_setup): def test_build_system_prompt(agent_setup, build_env_info): agent, env, _ = next(agent_setup(DebugAgent)) eval_tool = Toolbox.get_tool("eval", auto_eval_on_rewrite=True) - pdb_tool = Toolbox.get_tool("pdb", auto_list=True, persistent_breakpoints=True) env.add_tool(eval_tool) - env.add_tool(pdb_tool) - env.workspace = MagicMock() - env.workspace.display_files = MagicMock(return_value="repo/tree") - agent.args.show_directory_tree = 2 - agent.args.show_current_breakpoints = True + agent.config["env_kwargs"] = { + "show_current_breakpoints": True, + "show_directory_tree": True, + } agent.system_prompt = "Test overall task" - agent.env = env info = build_env_info( instructions="Do X", + dir_tree="repo/tree", current_breakpoints=[1, 2], eval_observation="eval obs", ) - messages = agent.build_system_prompt(info) expected = { "Overall task": "Test overall task", @@ -159,8 +158,6 @@ def test_build_system_prompt(agent_setup, build_env_info): "updated automatically in the system prompt.", "The environment will show the directory tree of the repository in the system prompt.", "The environment will show the current breakpoints in the system prompt.", - "The environment will automatically restore existing breakpoints when a new PDB session is started (e.g., after a rewrite).", - "After every valid PDB tool calling, the environment will automatically call the PDB tool again with a `list .` command, which will show the code around the current frame.", ], } assert messages == [{"role": "system", "content": json.dumps(expected, indent=2)}] @@ -170,6 +167,7 @@ def test_build_prompt(agent_setup, build_env_info): agent, _, _ = next(agent_setup(DebugAgent)) info = build_env_info( 
instructions="Test instructions", + dir_tree="Test dir tree", current_breakpoints="Test breakpoints", step_observation="Test last run obs", ) @@ -185,6 +183,7 @@ def test_run(agent_setup, build_env_info): score=0, max_score=10, instructions="Test instructions", + dir_tree="Test dir tree", current_breakpoints="Test breakpoints", step_observation="Test last run obs", ) @@ -194,11 +193,12 @@ def test_run(agent_setup, build_env_info): score=10, max_score=10, instructions="Test instructions", + dir_tree="Test dir tree", current_breakpoints="Test breakpoints", step_observation="Test last run obs", ) llm.return_value = LLMResponse("Prompt", "Expected answer", TokenUsage(2, 4)) - result = agent.run(env, debug=False) + result = agent.run(task_name="test_task", debug=False) assert result @@ -206,6 +206,7 @@ def test_build_system_prompt_rewrite_agent(agent_setup, build_env_info): agent, _, _ = next(agent_setup(RewriteAgent)) info = build_env_info( instructions="Test instructions", + dir_tree="Test dir tree", current_breakpoints="Test breakpoints", step_observation="Test last run obs", ) @@ -223,6 +224,7 @@ def test_run_debug_5_agent(agent_setup, build_env_info): max_score=10, rewrite_counter=0, instructions="Test instructions", + dir_tree="Test dir tree", current_breakpoints="Test breakpoints", step_observation="Test last run obs", ) @@ -233,12 +235,13 @@ def test_run_debug_5_agent(agent_setup, build_env_info): max_score=10, rewrite_counter=0, instructions="Test instructions", + dir_tree="Test dir tree", current_breakpoints="Test breakpoints", step_observation="Test last run obs", ) llm.return_value = LLMResponse("Prompt", "Expected answer", TokenUsage(2, 4)) env.tools = {"pdb": MagicMock()} - result = agent.run(env, debug=False) + result = agent.run(task_name="test_task", debug=False) assert result @@ -285,8 +288,8 @@ def test_create_agent(): class TestRegisteredAgent(BaseAgent): name = "test_registered" - def __init__(self, args, env, **kwargs): - super().__init__(args, env, **kwargs) + def __init__(self, config, env, **kwargs): + super().__init__(config, env, **kwargs) # Clear and setup registry original_registry = AGENT_REGISTRY.copy() @@ -295,22 +298,15 @@ def __init__(self, args, env, **kwargs): try: # Mock the required parameters - mock_config = { - "output_path": "/tmp", - "random_seed": 42, - "memory_size": 10, - "max_steps": 5, - "max_rewrite_steps": 3, - } - agent_args = AgentArgs.from_dict(mock_config) + mock_config = {"output_path": "/tmp", "random_seed": 42, "memory_size": 10} mock_env = MagicMock() - agent = create_agent("test_registered", agent_args=agent_args, env=mock_env) + agent = create_agent("test_registered", config=mock_config, env=mock_env) assert isinstance(agent, TestRegisteredAgent) # Test unknown agent type with pytest.raises(ValueError, match="Unknown agent type: unknown_agent"): - create_agent("unknown_agent", agent_args=agent_args, env=mock_env) + create_agent("unknown_agent", config=mock_config, env=mock_env) # Test module import (mock importlib) with patch("importlib.import_module") as mock_import: @@ -319,7 +315,7 @@ def __init__(self, args, env, **kwargs): mock_import.return_value = mock_module agent = create_agent( - "some.module.TestClass", agent_args=agent_args, env=mock_env + "some.module.TestClass", config=mock_config, env=mock_env ) assert isinstance(agent, TestRegisteredAgent) mock_import.assert_called_once_with("some.module") @@ -332,23 +328,17 @@ def __init__(self, args, env, **kwargs): def test_system_prompt_building_with_no_template(): """Test system 
prompt building when no template is provided""" mock_env = MagicMock() - agent_args = AgentArgs.from_dict( - { - "random_seed": 42, - "memory_size": 10, - "max_steps": 1, - "max_rewrite_steps": 1, - } + mock_env.get_tool = MagicMock( + side_effect=KeyError("no tools for testing") + ) # KeyError to simulate missing tool + agent = BaseAgent( + {"output_path": "/tmp", "random_seed": 42, "memory_size": 10}, mock_env ) - llm = MagicMock() - llm.context_length = 2000 - llm.count_tokens = Mock(return_value=500) - agent = BaseAgent(agent_args, llm=llm) - agent.env = mock_env # Create a mock info object mock_info = MagicMock() mock_info.instructions = "test instructions" + mock_info.dir_tree = "test dir tree" mock_info.current_breakpoints = [] mock_info.eval_observation = MagicMock() mock_info.eval_observation.observation = "test eval" @@ -366,14 +356,7 @@ def test_system_prompt_building_with_no_template(): def test_system_prompt_building_with_template(): """Test system prompt building with template file""" agent = BaseAgent( - { - "output_path": "/tmp", - "random_seed": 42, - "memory_size": 10, - "max_steps": 1, - "max_rewrite_steps": 1, - }, - MagicMock(), + {"output_path": "/tmp", "random_seed": 42, "memory_size": 10}, MagicMock() ) # Create a mock info object @@ -421,8 +404,12 @@ def test_shortcut_features_comprehensive(agent_setup): eval_tool = Toolbox.get_tool("eval", auto_eval_on_rewrite=True) env.add_tool(eval_tool) # Test with all features enabled - agent.args.show_directory_tree = 1 - agent.args.show_current_breakpoints = True + agent.config["env_kwargs"] = { + "show_directory_tree": True, + "show_current_breakpoints": True, + "persistent_breakpoints": True, + "auto_list": True, + } env.has_tool.return_value = True features = agent.shortcut_features() @@ -439,13 +426,9 @@ def test_shortcut_features_comprehensive(agent_setup): assert len(features) == 2 # Only auto_eval and directory_tree # Test with no features - agent.args.show_directory_tree = 0 - agent.args.show_current_breakpoints = False + agent.config["env_kwargs"] = {} env.get_tool("eval").auto_eval_on_rewrite = False - env.get_tool("pdb").auto_list = False - env.get_tool("pdb").persistent_breakpoints = False features = agent.shortcut_features() - print(features) assert len(features) == 0 @@ -501,11 +484,12 @@ def test_run_early_completion(agent_setup, build_env_info): score=10, max_score=10, instructions="Test instructions", + dir_tree="Test dir tree", current_breakpoints="Test breakpoints", step_observation="Test last run obs", ) - result = agent.run(env) + result = agent.run(task_name="test_task") assert result is True env.step.assert_not_called() # Should not step if already done @@ -513,7 +497,7 @@ def test_run_early_completion(agent_setup, build_env_info): def test_run_max_rewrite_steps(agent_setup, build_env_info): """Test run method when max rewrite steps is reached""" agent, env, llm = next(agent_setup(DebugAgent)) - agent.args.max_rewrite_steps = 2 + agent.config["max_rewrite_steps"] = 2 env.reset.return_value = build_env_info( terminated=False, @@ -522,6 +506,7 @@ def test_run_max_rewrite_steps(agent_setup, build_env_info): max_score=10, rewrite_counter=0, instructions="Test instructions", + dir_tree="Test dir tree", current_breakpoints="Test breakpoints", step_observation="Test last run obs", ) @@ -534,13 +519,14 @@ def test_run_max_rewrite_steps(agent_setup, build_env_info): max_score=10, rewrite_counter=2, # Reaches max_rewrite_steps instructions="Test instructions", + dir_tree="Test dir tree", 
current_breakpoints="Test breakpoints", step_observation="Test last run obs", ) llm.return_value = LLMResponse("Prompt", "Expected answer", TokenUsage(2, 4)) - result = agent.run(env) + result = agent.run(task_name="test_task") assert result is False # Task not completed, but stopped due to max rewrites @@ -554,6 +540,7 @@ def test_run_exception_handling(agent_setup, build_env_info): score=0, max_score=10, instructions="Test instructions", + dir_tree="Test dir tree", current_breakpoints="Test breakpoints", step_observation="Test last run obs", ) @@ -562,7 +549,7 @@ def test_run_exception_handling(agent_setup, build_env_info): llm.side_effect = RuntimeError("Test error") with pytest.raises(RuntimeError, match="Test error"): - agent.run(env) + agent.run(task_name="test_task") def test_apply_patch_success(agent_setup, tmp_path): @@ -600,30 +587,34 @@ def test_apply_patch_failure(agent_setup, tmp_path): def test_save_patch(agent_setup, tmp_path): """Test patch saving functionality""" agent, env, _ = next(agent_setup(DebugAgent)) + agent._output_path = str(tmp_path) env.patch = "test patch content" - logger = MagicMock() - problem_path = tmp_path / "test_task" - save_patch(env, problem_path, logger) + agent.save_patch("test_task") - patch_file = problem_path / "debug_gym.patch" + patch_file = tmp_path / "test_task" / "debug_gym.patch" assert patch_file.exists() assert patch_file.read_text() == "test patch content" -def test_build_trajectory(agent_setup, tmp_path): - """Test trajectory building and persistence helpers""" +def test_save_trajectory(agent_setup, tmp_path): + """Test trajectory saving functionality""" agent, env, llm = next(agent_setup(DebugAgent)) + agent._output_path = str(tmp_path) env.terminated = True env.resolved = True + # Make all fields JSON serializable + agent.config = {"output_path": str(tmp_path), "random_seed": 42} agent._uuid = "test-uuid-123" + # Create mock tools with proper attributes mock_tool = MagicMock() mock_tool.name = "test_tool" mock_tool.arguments = "test_args" env.tools = [mock_tool] + # Mock history with simple implementation class MockHistory: def __len__(self): return 2 @@ -633,28 +624,73 @@ def json(self, step_id): agent.history = MockHistory() + # Mock logger with JSON serializable log_file agent.logger = MagicMock() agent.logger.log_file = "/tmp/test.log" - llm.define_tools = lambda tools: [ - {"name": tool.name, "args": tool.arguments} for tool in tools - ] - trajectory = agent.build_trajectory("test_task") - assert trajectory["problem"] == "test_task" - assert trajectory["uuid"] == "test-uuid-123" - assert len(trajectory["log"]) == 2 - assert trajectory["logger"] == "/tmp/test.log" - assert trajectory["config"]["random_seed"] == agent.args.random_seed + # Test with LLM - patch the method directly + with patch.object(agent, "save_trajectory") as mock_save: + agent.save_trajectory("test_task") + mock_save.assert_called_once_with("test_task") + # Test the method manually with controlled data + json_output = { + "problem": "test_task", + "config": agent.config, + "tools": [{"name": "test_tool", "args": "test_args"}], + "uuid": agent._uuid, + "success": env.resolved, + "log": [ + {"step": 0, "action": "test_action_0"}, + {"step": 1, "action": "test_action_1"}, + ], + "agent_type": agent.__class__.__name__, + "logger": str(agent.logger.log_file), + } - problem_path = tmp_path / "test_task" - save_trajectory(agent, "test_task", problem_path, MagicMock()) + # Manually create the trajectory file to test the content + os.makedirs(tmp_path / "test_task", 
exist_ok=True) + with open(tmp_path / "test_task" / "trajectory.json", "w") as f: + json.dump(json_output, f, indent=4) - trajectory_file = problem_path / "trajectory.json" + trajectory_file = tmp_path / "test_task" / "trajectory.json" assert trajectory_file.exists() - saved = json.loads(trajectory_file.read_text()) - assert saved["problem"] == "test_task" - assert saved["uuid"] == "test-uuid-123" + with open(trajectory_file) as f: + data = json.load(f) + + assert data["problem"] == "test_task" + assert data["success"] is True + assert len(data["log"]) == 2 + assert data["tools"] == [{"name": "test_tool", "args": "test_args"}] + assert data["uuid"] == "test-uuid-123" + + # Test without LLM - create simplified test case + json_output_no_llm = { + "problem": "test_task_no_llm", + "config": agent.config, + "tools": ["test_tool(test_args)"], # String format when no LLM + "uuid": agent._uuid, + "success": env.resolved, + "log": [ + {"step": 0, "action": "test_action_0"}, + {"step": 1, "action": "test_action_1"}, + ], + "agent_type": agent.__class__.__name__, + "logger": str(agent.logger.log_file), + } + + os.makedirs(tmp_path / "test_task_no_llm", exist_ok=True) + with open(tmp_path / "test_task_no_llm" / "trajectory.json", "w") as f: + json.dump(json_output_no_llm, f, indent=4) + + trajectory_file_no_llm = tmp_path / "test_task_no_llm" / "trajectory.json" + assert trajectory_file_no_llm.exists() + + with open(trajectory_file_no_llm) as f: + data_no_llm = json.load(f) + + # Without LLM, tools should be formatted as strings + assert data_no_llm["tools"] == ["test_tool(test_args)"] def test_build_question_prompt(agent_setup): @@ -686,7 +722,7 @@ def test_load_system_prompt_template_with_filters(agent_setup, tmp_path): template_file = tmp_path / "template.jinja" template_file.write_text(template_content) - agent.args.system_prompt_template_file = str(template_file) + agent.config["system_prompt_template_file"] = str(template_file) agent.system_prompt = "Test task" template = agent._load_system_prompt_template() diff --git a/tests/agents/test_free_agent.py b/tests/agents/test_free_agent.py new file mode 100644 index 00000000..0bb9ffd2 --- /dev/null +++ b/tests/agents/test_free_agent.py @@ -0,0 +1,62 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from debug_gym.agents.base_agent import BaseAgent +from debug_gym.agents.free_agent import FreeAgent + + +@pytest.fixture +def make_free_agent(agent_setup): + def _factory(*, config_override=None): + agent, env, llm = next(agent_setup(FreeAgent, config_override=config_override)) + agent.logger = MagicMock() + return agent, env, llm + + return _factory + + +def test_free_agent_run_delegates_to_base(make_free_agent): + agent, _, _ = make_free_agent() + + with patch.object(BaseAgent, "run", return_value=True) as mock_run: + result = agent.run(task_name="demo", debug=True) + + mock_run.assert_called_once_with(task_name="demo", debug=True) + assert result is True + + +def test_free_agent_reraises_root_cause_for_missing_reset(make_free_agent): + agent, _, _ = make_free_agent() + + def side_effect(*args, **kwargs): + try: + raise RuntimeError("reset failed") + except RuntimeError as exc: # pragma: no cover - exercised below + raise AttributeError( + "'NoneType' object has no attribute 'max_score'" + ) from exc + + with patch.object(BaseAgent, "run", side_effect=side_effect): + with pytest.raises(RuntimeError) as excinfo: + agent.run(task_name="demo") + + assert str(excinfo.value) == "reset failed" + agent.logger.error.assert_called_once() + + 
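The test above encodes an exception-unwrapping contract: when the base run fails with the AttributeError that a failed environment reset produces (`'NoneType' object has no attribute 'max_score'`), FreeAgent should log the failure once and re-raise the chained root cause, while unrelated AttributeErrors must pass through untouched (as the next test checks). A plausible sketch of that pattern, not the actual FreeAgent implementation, looks like this:

```python
import logging

logger = logging.getLogger("free-agent-sketch")


def run_with_unwrap(base_run, *, task_name=None, debug=False):
    """Sketch of the unwrap-and-reraise behaviour the tests above expect."""
    try:
        return base_run(task_name=task_name, debug=debug)
    except AttributeError as exc:
        # Presumably a failed env.reset() surfaces downstream as
        # "'NoneType' object has no attribute 'max_score'"; in that case,
        # log once and surface the chained root cause instead.
        if "max_score" in str(exc) and exc.__cause__ is not None:
            logger.error("Environment reset failed: %s", exc.__cause__)
            raise exc.__cause__
        raise
```

Re-raising the chained cause keeps the actionable error (the reset failure) at the top of the traceback rather than the downstream attribute lookup.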
+def test_free_agent_bubbles_unrelated_attribute_error(make_free_agent): + agent, _, _ = make_free_agent() + + with patch.object(BaseAgent, "run", side_effect=AttributeError("other")): + with pytest.raises(AttributeError, match="other"): + agent.run(task_name="demo") + + agent.logger.error.assert_not_called() + + +def test_free_agent_system_prompt_override(make_free_agent): + custom_prompt = "Inspect quietly." + agent, _, _ = make_free_agent(config_override={"system_prompt": custom_prompt}) + + assert agent.system_prompt == custom_prompt diff --git a/tests/conftest.py b/tests/conftest.py index 118bd5fd..869e0156 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -99,6 +99,7 @@ def _env_info( step_observation="obs", all_observations=[], eval_observation="eval_observation", + dir_tree="dir_tree", current_breakpoints="current_breakpoints", action_tool_call="action", action_reasoning="", @@ -115,6 +116,7 @@ def _env_info( step_observation=Observation("tool", step_observation), all_observations=all_observations, eval_observation=Observation("env", eval_observation), + dir_tree=dir_tree, current_breakpoints=current_breakpoints, action_reasoning=action_reasoning, action_content=action_content, diff --git a/tests/gym/envs/test_env.py b/tests/gym/envs/test_env.py index 1a77a1ed..6a147f7e 100644 --- a/tests/gym/envs/test_env.py +++ b/tests/gym/envs/test_env.py @@ -147,7 +147,7 @@ def env(tmp_path): (repo_path / "file2.txt").touch() (subdir_path / "subfile1.txt").touch() - env = RepoEnv(path=repo_path) + env = RepoEnv(path=repo_path, dir_tree_depth=2) return env @@ -220,6 +220,11 @@ def test_reset(tmp_path): step_observation=Observation(source="env", observation=env.instructions), all_observations=[Observation(source="env", observation=env.instructions)], eval_observation=None, + dir_tree=( + "Listing files in the current working directory. (read-only) indicates read-only files. 
Max depth: 1.\n" + f"{env.working_dir}/\n" + "|-- test.py" + ), current_breakpoints="No breakpoints are set.", action_reasoning=None, action_content=None, diff --git a/tests/gym/envs/test_free_env.py b/tests/gym/envs/test_free_env.py new file mode 100644 index 00000000..99fd3ecd --- /dev/null +++ b/tests/gym/envs/test_free_env.py @@ -0,0 +1,133 @@ +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +from debug_gym.gym.envs.free_env import FreeEnv +from debug_gym.gym.terminals.local import LocalTerminal +from debug_gym.gym.terminals.terminal import Terminal + + +class DummyTerminal(Terminal): + """Test helper terminal with minimal behavior for FreeEnv interactions.""" + + def __init__( + self, + *, + working_dir: str = "/tmp/test", + logger: Any | None = None, + base_image: str | None = None, + setup_commands: list[str] | None = None, + ): + super().__init__(working_dir=working_dir, logger=logger) + self.base_image = base_image + self.setup_commands = list(setup_commands or []) + self.closed = False + + def prepare_command(self, entrypoint): + return ["/bin/true"] + + def run(self, entrypoint, timeout=None, raises=False, strip_output=True): + if isinstance(entrypoint, str) and "tree" in entrypoint: + return True, "/workspace\n" + return True, "" + + @property + def default_shell_command(self): + return "/bin/true" + + def new_shell_session(self): + return None + + def copy_content(self, src, target=None): + return None + + def close(self): + self.closed = True + + +def test_free_env_defaults_to_local_terminal(): + logger = MagicMock() + + env = FreeEnv(image="ubuntu:22.04", logger=logger) + + assert isinstance(env.terminal, LocalTerminal) + assert env.container_image == "ubuntu:22.04" + + +def test_free_env_configures_existing_terminal(): + logger = MagicMock() + terminal_logger = MagicMock() + terminal = DummyTerminal( + working_dir="/initial", + logger=terminal_logger, + base_image="base", + setup_commands=["existing"], + ) + + env = FreeEnv( + image="ubuntu:22.04", + terminal=terminal, + setup_commands=["apt update"], + workspace_dir="/workspace", + logger=logger, + init_git=False, + ) + + env.reset() + + assert env.terminal is terminal + assert terminal.base_image == "ubuntu:22.04" + assert terminal.working_dir == "/workspace" + assert terminal.logger is logger + assert terminal.setup_commands == ["apt update"] + + +def test_free_env_respects_custom_workspace(tmp_path): + logger = MagicMock() + terminal = DummyTerminal(logger=logger) + + env = FreeEnv( + image="ubuntu:22.04", + terminal=terminal, + workspace_dir="/workspace", + logger=logger, + init_git=False, + ) + + env.reset() + + assert env.workspace.working_dir == Path("/workspace") + assert terminal.working_dir == "/workspace" + + +def test_free_env_reset_allows_dynamic_overrides(): + logger = MagicMock() + terminal = DummyTerminal(logger=logger, setup_commands=["initial"]) + + env = FreeEnv( + image="ubuntu:22.04", + terminal=terminal, + setup_commands=["initial"], + workspace_dir="/workspace", + logger=logger, + init_git=True, + ) + + env.reset( + options={ + "image": "ubuntu:24.04", + "workspace_dir": "/new", + "setup_commands": ["echo ready"], + "instructions": "Inspect carefully.", + "init_git": False, + } + ) + + assert env.container_image == "ubuntu:24.04" + assert env.instructions == "Inspect carefully." 
+ assert env.init_git is False + assert env._workspace_dir == "/new" + assert terminal.working_dir == "/new" + assert terminal.setup_commands == ["echo ready"] + assert terminal.base_image == "ubuntu:24.04" + assert terminal.closed is True diff --git a/tests/gym/envs/test_r2egym.py b/tests/gym/envs/test_r2egym.py index d2d2e92c..43d25387 100644 --- a/tests/gym/envs/test_r2egym.py +++ b/tests/gym/envs/test_r2egym.py @@ -253,14 +253,12 @@ def test_running_solution_agent(get_r2egym_env, tmp_path): "random_seed": 0, "memory_size": 8, "max_steps": 1, - "max_rewrite_steps": 1, "env_kwargs": {}, } for tool_name in ["pdb", "eval", "submit"]: env.add_tool(Toolbox.get_tool(tool_name)) - agent = AgentSolution(agent_args=config, llm=None, logger=env.logger) - env.reset(options={"task_name": task_name}) - success = agent.run(env) + agent = AgentSolution(config=config, env=env, llm=None, logger=env.logger) + success = agent.run(task_name=task_name) assert success diff --git a/tests/gym/envs/test_swe_bench.py b/tests/gym/envs/test_swe_bench.py index c8f86cb4..8f98f978 100644 --- a/tests/gym/envs/test_swe_bench.py +++ b/tests/gym/envs/test_swe_bench.py @@ -233,22 +233,20 @@ def test_apply_gold_patch(get_swe_bench_env): @pytest.if_docker_running def test_running_solution_agent(get_swe_bench_env, tmp_path): env = get_swe_bench_env() - # AgentArgs requires at least random_seed, memory_size, max_steps, and max_rewrite_steps. - # Provide a minimal agent config for the SolutionAgent run. + # BaseAgent requires a config dict with at least: output_path, random_seed, memory_size. + # Provide a minimal config for the SolutionAgent run. config = { "output_path": str(tmp_path), "random_seed": 0, "memory_size": 8, # Optional values that BaseAgent.run would use; harmless to include here. "max_steps": 1, - "max_rewrite_steps": 1, "env_kwargs": {}, } for tool_name in ["pdb", "submit"]: env.add_tool(Toolbox.get_tool(tool_name)) - agent = AgentSolution(agent_args=config, llm=None, logger=env.logger) - env.reset(options={"task_name": "astropy__astropy-14096"}) - success = agent.run(env) + agent = AgentSolution(config=config, env=env, llm=None, logger=env.logger) + success = agent.run(task_name="astropy__astropy-14096") assert success @@ -278,20 +276,18 @@ def test_setup_terminal_debug_mode(get_swe_bench_debug_env): @pytest.if_docker_running def test_running_solution_agent_in_debug_mode(get_swe_bench_debug_env, tmp_path): env = get_swe_bench_debug_env() - # AgentArgs requires at least random_seed, memory_size, max_steps, and max_rewrite_steps. - # Provide a minimal agent config for the SolutionAgent run. + # BaseAgent requires a config dict with at least: output_path, random_seed, memory_size. + # Provide a minimal config for the SolutionAgent run. config = { "output_path": str(tmp_path), "random_seed": 0, "memory_size": 8, # Optional values that BaseAgent.run would use; harmless to include here. 
"max_steps": 1, - "max_rewrite_steps": 1, "env_kwargs": {}, } for tool_name in ["pdb", "eval", "submit"]: env.add_tool(Toolbox.get_tool(tool_name)) - agent = AgentSolution(agent_args=config, llm=None, logger=env.logger) - env.reset(options={"task_name": "astropy__astropy-14096"}) - success = agent.run(env) + agent = AgentSolution(config=config, env=env, llm=None, logger=env.logger) + success = agent.run(task_name="astropy__astropy-14096") assert success diff --git a/tests/gym/envs/test_swe_smith.py b/tests/gym/envs/test_swe_smith.py index 8c46befc..c5ed9642 100644 --- a/tests/gym/envs/test_swe_smith.py +++ b/tests/gym/envs/test_swe_smith.py @@ -259,14 +259,12 @@ def test_running_solution_agent(get_swe_smith_env, tmp_path): "random_seed": 0, "memory_size": 8, "max_steps": 1, - "max_rewrite_steps": 1, "env_kwargs": {}, } for tool_name in ["pdb", "eval", "submit"]: env.add_tool(Toolbox.get_tool(tool_name)) - agent = AgentSolution(agent_args=config, llm=None, logger=env.logger) - env.reset(options={"task_name": task_name}) - success = agent.run(env) + agent = AgentSolution(config=config, env=env, llm=None, logger=env.logger) + success = agent.run(task_name=task_name) assert success diff --git a/tests/gym/terminals/test_terminal.py b/tests/gym/terminals/test_terminal.py index 3867ea18..eb2ac54c 100644 --- a/tests/gym/terminals/test_terminal.py +++ b/tests/gym/terminals/test_terminal.py @@ -154,3 +154,37 @@ def test_select_terminal_unknown(): def test_select_terminal_invalid_config(): with pytest.raises(TypeError): select_terminal("not a dict") + + +def test_select_terminal_kubernetes_extra_labels(monkeypatch): + captured = {} + + class DummyK8s: + def __init__(self, **kwargs): + captured.update(kwargs) + + monkeypatch.setattr( + "debug_gym.gym.terminals.KubernetesTerminal", + DummyK8s, + ) + + config = { + "type": "kubernetes", + "namespace": "example", + "extra_labels": {"foo": "bar"}, + "pod_spec_kwargs": {"tolerations": []}, + } + + terminal = select_terminal(config, uuid="1234") + + assert isinstance(terminal, DummyK8s) + assert captured["namespace"] == "example" + assert captured["pod_spec_kwargs"] == {"tolerations": []} + assert captured["extra_labels"] == {"foo": "bar", "uuid": "1234"} + assert "logger" in captured + assert config == { + "type": "kubernetes", + "namespace": "example", + "extra_labels": {"foo": "bar"}, + "pod_spec_kwargs": {"tolerations": []}, + } diff --git a/tests/gym/tools/test_bash.py b/tests/gym/tools/test_bash.py index 5e7d860e..7a60a723 100644 --- a/tests/gym/tools/test_bash.py +++ b/tests/gym/tools/test_bash.py @@ -30,7 +30,7 @@ def env(tmp_path): with open(subdir / "nested.txt", "w") as f: f.write("nested file content") - env = RepoEnv(path=repo_path) + env = RepoEnv(path=repo_path, dir_tree_depth=2) bash_tool = Toolbox.get_tool("bash") env.add_tool(bash_tool) env.reset() diff --git a/tests/gym/tools/test_eval.py b/tests/gym/tools/test_eval.py index 7279de81..bdd2f506 100644 --- a/tests/gym/tools/test_eval.py +++ b/tests/gym/tools/test_eval.py @@ -15,7 +15,7 @@ def env(tmp_path): with open(repo_path / "test_1.py", "w") as f: f.write("def test_1():\n assert False\n") - env = RepoEnv(path=repo_path) + env = RepoEnv(path=repo_path, dir_tree_depth=2) env.reset() return env diff --git a/tests/gym/tools/test_pdb.py b/tests/gym/tools/test_pdb.py index 23232ce9..61f00b3f 100644 --- a/tests/gym/tools/test_pdb.py +++ b/tests/gym/tools/test_pdb.py @@ -61,11 +61,13 @@ def setup_pdb_repo_env(setup_test_repo, setup_breakpoints_state): def _setup_pdb_repo_env(base_dir): 
test_repo = setup_test_repo(base_dir) env = RepoEnv(path=str(test_repo)) - pdb_tool = PDBTool(persistent_breakpoints=True, auto_list=True) + pdb_tool = PDBTool() pdb_tool.register(env) env.reset() breakpoints = setup_breakpoints_state(env.working_dir) env.current_breakpoints_state = breakpoints + env.persistent_breakpoints = True + env.auto_list = True pdb_tool.start_pdb(env) return pdb_tool, env diff --git a/tests/gym/tools/test_rewrite.py b/tests/gym/tools/test_rewrite.py index e8ad0772..698bd6d3 100644 --- a/tests/gym/tools/test_rewrite.py +++ b/tests/gym/tools/test_rewrite.py @@ -23,7 +23,7 @@ def env(tmp_path): with open(repo_path / "test.py", "w") as f: f.write(file_content) - env = RepoEnv(path=repo_path) + env = RepoEnv(path=repo_path, dir_tree_depth=2) rewrite_tool = RewriteTool() env.add_tool(rewrite_tool) diff --git a/tests/gym/tools/test_view.py b/tests/gym/tools/test_view.py index ec2742bb..e82294a2 100644 --- a/tests/gym/tools/test_view.py +++ b/tests/gym/tools/test_view.py @@ -29,7 +29,7 @@ def env(tmp_path): (repo_path / "empty.py").touch() # Create an empty file - env = RepoEnv(path=repo_path) + env = RepoEnv(path=repo_path, dir_tree_depth=2) view_tool = Toolbox.get_tool("view") env.add_tool(view_tool) env.reset() diff --git a/tests/llms/conftest.py b/tests/llms/conftest.py index 8053aa4f..35cb5d32 100644 --- a/tests/llms/conftest.py +++ b/tests/llms/conftest.py @@ -10,6 +10,7 @@ def _env_info( step_observation="obs", all_observations=[], eval_observation="eval_observation", + dir_tree="dir_tree", current_breakpoints="current_breakpoints", action_reasoning="", action_content="", @@ -26,6 +27,7 @@ def _env_info( step_observation=Observation("tool", step_observation), all_observations=all_observations, eval_observation=Observation("env", eval_observation), + dir_tree=dir_tree, current_breakpoints=current_breakpoints, action_reasoning=action_reasoning, action_content=action_content, diff --git a/tests/llms/test_anthropic.py b/tests/llms/test_anthropic.py index 217be05d..bce97b08 100644 --- a/tests/llms/test_anthropic.py +++ b/tests/llms/test_anthropic.py @@ -322,6 +322,7 @@ def test_format_tool_call_history_initial_state(mock_llm_config, logger_mock): step_observation=Observation(source="tool1", observation="Initial observation"), all_observations=[], eval_observation=Observation(source="tool1", observation=""), + dir_tree="", current_breakpoints="", action_reasoning=None, # No reasoning yet action_content=None, # No content yet @@ -366,6 +367,7 @@ def test_format_tool_call_history_with_action(mock_llm_config, logger_mock): ), all_observations=[], eval_observation=Observation(source="tool_456", observation=""), + dir_tree="", current_breakpoints="", action_reasoning="Edited the file to fix the bug", action_content="Edited the file to fix the bug", diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index cd9ae44a..dc1741b9 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -209,6 +209,7 @@ def test_format_tool_call_history_initial_state(mock_llm_config, logger_mock): step_observation=Observation(source="tool1", observation="Initial observation"), all_observations=[], eval_observation=Observation(source="tool1", observation=""), + dir_tree="", current_breakpoints="", action_reasoning=None, # No reasoning yet action_content=None, # No content yet @@ -264,6 +265,7 @@ def test_format_tool_call_history_with_action(mock_llm_config, logger_mock): ), all_observations=[], eval_observation=Observation(source="tool_456", observation=""), + 
dir_tree="", current_breakpoints="", action_reasoning="Edited the file to fix the bug", # Reasoning for action action_content="Edited the file to fix the bug", # Content for action @@ -343,6 +345,7 @@ def test_format_tool_call_history_complex_arguments(mock_llm_config, logger_mock ), all_observations=[], eval_observation=Observation(source="tool_456", observation=""), + dir_tree="", current_breakpoints="", action_reasoning="Configured the environment with complex settings", action_content="Configured the environment with complex settings",