Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
1e099aa
free env
xingdi-eric-yuan Nov 25, 2025
9a52667
Update free_env.py
xingdi-eric-yuan Nov 25, 2025
8d58417
minor
xingdi-eric-yuan Nov 25, 2025
78d0cfe
remove default instructions
xingdi-eric-yuan Nov 25, 2025
ddd4e0a
minor updates
xingdi-eric-yuan Nov 25, 2025
972dca3
minor
xingdi-eric-yuan Nov 25, 2025
8d1e8a6
minor
xingdi-eric-yuan Nov 25, 2025
4718b21
minor
xingdi-eric-yuan Nov 25, 2025
8492f68
change prompts
xingdi-eric-yuan Nov 25, 2025
43c9677
minor
xingdi-eric-yuan Nov 25, 2025
f0aebf8
submit tool
xingdi-eric-yuan Nov 26, 2025
ad173f4
Update free_env.py
xingdi-eric-yuan Nov 26, 2025
dc5272b
add comments and docstrings
xingdi-eric-yuan Nov 26, 2025
560c86c
add tests
xingdi-eric-yuan Nov 26, 2025
db044e7
override sys prompt
xingdi-eric-yuan Nov 26, 2025
7bc995f
refactor free_env
xingdi-eric-yuan Nov 27, 2025
a8a7a10
Update config_free_env.yaml
xingdi-eric-yuan Nov 27, 2025
f94deb0
add in script the apply_eval: false
xingdi-eric-yuan Nov 27, 2025
25983c1
add back tree
xingdi-eric-yuan Nov 27, 2025
bf6a929
Update test_free_env.py
xingdi-eric-yuan Nov 27, 2025
f27d42f
Update run_free_env.py
xingdi-eric-yuan Nov 27, 2025
a2ce7c9
being able to reset
xingdi-eric-yuan Nov 27, 2025
e2b7b54
Merge branch 'main' into free_env
xingdi-eric-yuan Nov 27, 2025
cc6f22d
minor
xingdi-eric-yuan Nov 27, 2025
b292d92
Update __init__.py
xingdi-eric-yuan Nov 27, 2025
fb8dfac
Update __init__.py
xingdi-eric-yuan Nov 27, 2025
7ff692c
add terminal type in yaml
xingdi-eric-yuan Nov 27, 2025
6e686b5
Update __init__.py
xingdi-eric-yuan Nov 27, 2025
8329803
Merge branch 'main' into free_env
xingdi-eric-yuan Nov 27, 2025
d905a71
Update run_free_env.py
xingdi-eric-yuan Nov 27, 2025
6b1b7f8
Update config_free_env.yaml
xingdi-eric-yuan Nov 27, 2025
f447e34
Update openai.py
xingdi-eric-yuan Nov 28, 2025
4cc5fb0
Merge branch 'main' into free_env
MarcCote Nov 28, 2025
dc54990
Rename apply_eval to eval_on_submit
MarcCote Nov 28, 2025
391e9d3
update test terminal
xingdi-eric-yuan Nov 28, 2025
30108f3
Merge branch 'main' into free_env
xingdi-eric-yuan Nov 28, 2025
f22fdb9
Revert "Merge branch 'main' into free_env"
xingdi-eric-yuan Nov 28, 2025
afd61a5
Update kubernetes.py
xingdi-eric-yuan Nov 28, 2025
b177806
minor
xingdi-eric-yuan Nov 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ In addition to all [built-in Jinja filters](https://jinja.palletsprojects.com/en
- **`where`**: Specifies where to trim the message if it exceeds the limit. The default is `"middle"`, which trims from the middle of the message. Other options are `start` or `end`.

```jinja
{{ info.instructions | trim_message(max_length_percentage=0.1, where="end") }}
{{ info.dir_tree | trim_message(max_length_percentage=0.1, where="end") }}
```

#### Example Template
Expand All @@ -223,6 +223,9 @@ Task: {{ agent.system_prompt }}
Instructions:
{{ info.instructions }}

Directory Tree:
{{ info.dir_tree | trim_message(max_length=1000) }}

Current Breakpoints:
{{ info.current_breakpoints | to_pretty_json }}

Expand Down
2 changes: 2 additions & 0 deletions debug_gym/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from debug_gym.agents.base_agent import BaseAgent, register_agent
from debug_gym.agents.debug_agent import Debug_5_Agent, DebugAgent
from debug_gym.agents.free_agent import FreeAgent
from debug_gym.agents.rewrite_agent import RewriteAgent
from debug_gym.agents.solution_agent import AgentSolution
from debug_gym.agents.swe_agent import SWEAgent
221 changes: 124 additions & 97 deletions debug_gym/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import os
import subprocess
import uuid
from dataclasses import MISSING, asdict, dataclass, field, fields
from typing import Any, Dict
from os.path import join as pjoin

import numpy as np
from jinja2 import Environment, Template
Expand All @@ -18,55 +17,6 @@
AGENT_REGISTRY = {}


@dataclass
class AgentArgs:
random_seed: int
memory_size: int
max_steps: int
max_rewrite_steps: int
system_prompt_template_file: str | None = None
uuid: str = field(default_factory=lambda: str(uuid.uuid4()))
extras: Dict[str, Any] = field(default_factory=dict)

@classmethod
def from_dict(cls, config: Dict[str, Any]) -> "AgentArgs":
# Get all field names from the dataclass
field_names = {f.name for f in fields(cls)}

# Check for required fields (those without defaults)
required_fields = {
f.name
for f in fields(cls)
if f.default is MISSING and f.default_factory is MISSING
}
missing = required_fields - config.keys()
if missing:
raise ValueError(
f"Missing required agent config keys: {', '.join(sorted(missing))}"
)

# Separate known fields from extras
known_values = {k: v for k, v in config.items() if k in field_names}
extras = {k: v for k, v in config.items() if k not in field_names}

# Add extras if that field exists
if "extras" in field_names:
known_values["extras"] = extras

return cls(**known_values)

def get(self, key: str, default=None):
if key in self.__dataclass_fields__:
return getattr(self, key)
return self.extras.get(key, default)

def to_dict(self) -> Dict[str, Any]:
data = asdict(self)
extras = data.pop("extras", {})
data.update(extras)
return data


def register_agent(cls):
if not issubclass(cls, BaseAgent):
raise ValueError("agent_class must be a subclass of BaseAgent")
Expand All @@ -83,28 +33,36 @@ class BaseAgent:

def __init__(
self,
agent_args: AgentArgs | Dict[str, Any],
config: dict,
env: RepoEnv,
llm: LLM | None = None,
logger: DebugGymLogger | None = None,
):
self.args = (
AgentArgs.from_dict(agent_args)
if isinstance(agent_args, dict)
else agent_args
)
self.config = config
self.env = env
self.logger = logger or DebugGymLogger("debug-gym")
self.llm = llm
self._uuid = self.args.uuid
self.env = None
self._uuid = self.config.get("uuid", str(uuid.uuid4()))
self._output_path = pjoin(self.config["output_path"], self._uuid)

os.makedirs(self._output_path, exist_ok=True)

if "memory_size" not in self.config:
self.config["memory_size"] = self.config["max_steps"]

self.set_seed(self.args.random_seed)
self.history = HistoryTracker(self.args.memory_size)
self.set_seed(self.config["random_seed"])
self.history = HistoryTracker(self.config["memory_size"])

def set_seed(self, seed):
np.random.seed(seed)

def build_history_prompt(self):
return build_history_prompt(self.history, self.llm)
messages = build_history_prompt(
self.history,
self.llm,
self.config.get("reset_prompt_history_after_rewrite", False),
)
return messages

def parse_reasoning_model_response(self, response, reasoning_end_token):
# Strip the reasoning, e.g., in Deepseek r1, between <think> and </think>.
Expand All @@ -121,6 +79,45 @@ def _auto_eval_on_rewrite(self):
except KeyError:
return False # no eval tool

def _show_current_breakpoints(self):
"""Check if current breakpoints should be shown in the system prompt."""
return self.config.get("env_kwargs", {}).get("show_current_breakpoints", False)

def _show_directory_tree(self):
"""Check if directory tree should be shown in the system prompt."""
return self.config.get("env_kwargs", {}).get("show_directory_tree", False)

def shortcut_features(self):
features = []
if self._auto_eval_on_rewrite():
features.append(
"After successful rewrites, the environment will automatically "
"call the Eval tool to evaluate the rewritten code. Therefore, "
"you do not need to call the Eval tool yourself. The evaluation "
"output will be updated automatically in the system prompt."
)
if self._show_directory_tree():
features.append(
"The environment will show the directory tree of the repository in the system prompt."
)
if self.env.has_tool("pdb"):
if self._show_current_breakpoints():
features.append(
"The environment will show the current breakpoints in the system prompt."
)
if self.config.get("env_kwargs", {}).get("persistent_breakpoints"):
features.append(
"The environment will automatically restore existing breakpoints "
"when a new PDB session is started (e.g., after a rewrite)."
)
if self.config.get("env_kwargs", {}).get("auto_list"):
features.append(
"After every valid PDB tool calling, the environment will "
"automatically call the PDB tool again with a `list .` command, "
"which will show the code around the current frame."
)
return features

@staticmethod
def to_pretty_json(value):
"""Convert a value to a pretty JSON string."""
Expand Down Expand Up @@ -156,7 +153,7 @@ def _load_system_prompt_template(self) -> Template | None:
"""Load system prompt template from config if specified and register custom filters.
If no template is specified, return None.
"""
system_prompt_template = self.args.system_prompt_template_file
system_prompt_template = self.config.get("system_prompt_template_file")
if system_prompt_template:
if not os.path.isfile(system_prompt_template):
error_msg = (
Expand All @@ -179,8 +176,28 @@ def _default_system_prompt(self, info) -> str:

system_prompt_dict = {
"Overall task": self.system_prompt,
"Instructions": info.instructions,
}

if self._show_directory_tree():
system_prompt_dict["Repo directory tree"] = self.trim_message(
info.dir_tree, max_length_percentage=0.1, where="end"
)

if self._show_current_breakpoints():
system_prompt_dict["Current breakpoints"] = info.current_breakpoints

if self._auto_eval_on_rewrite():
system_prompt_dict["Evaluation output of current code"] = self.trim_message(
info.eval_observation.observation,
max_length_percentage=0.8,
where="middle",
)

shortcut_features = self.shortcut_features()
if shortcut_features:
system_prompt_dict["Shortcut features"] = shortcut_features

return self.to_pretty_json(system_prompt_dict)

def build_system_prompt(self, info):
Expand All @@ -206,20 +223,20 @@ def build_prompt(self, info):
messages.extend(self.build_question_prompt())
return messages

def run(self, env: RepoEnv, debug=False):
self.env = env
def run(self, task_name=None, debug=False):
step = 0
info = None
max_steps = self.args.max_steps
max_steps = self.config["max_steps"]
max_rewrite_steps = self.config.get("max_rewrite_steps")
try:
self.history.reset()
info = self.env.reset()
info = self.env.reset(options={"task_name": task_name})
# initial state does not have prompt and response
self.history.step(info, None)

if info.resolved is True:
self.logger.report_progress(
problem_id=env.task_name,
problem_id=task_name,
step=1,
total_steps=1,
score=info.score,
Expand All @@ -237,7 +254,7 @@ def run(self, env: RepoEnv, debug=False):
for step in range(max_steps):
self.logger.info(f"\n{'='*20} STEP {step+1} {'='*20}\n")
highscore = max(highscore, info.score)
msg = f"[{env.task_name[:10]:<10}] Step {step} | Score: {info.score}/{info.max_score or '-'} [Best: {highscore}]"
msg = f"[{task_name[:10]:<10}] Step {step} | Score: {info.score}/{info.max_score or '-'} [Best: {highscore}]"
self.logger.info(msg)

messages = self.build_prompt(info)
Expand All @@ -253,19 +270,24 @@ def run(self, env: RepoEnv, debug=False):
)
self.history.step(info, llm_response)

if (
info.terminated
or info.rewrite_counter >= self.args.max_rewrite_steps
):
reason = (
"terminated" if info.resolved else "max_rewrite_steps reached"
)
limit_reached = (
max_rewrite_steps is not None
and info.rewrite_counter >= max_rewrite_steps
)

if info.terminated or limit_reached:
if info.resolved:
reason = "terminated"
elif limit_reached:
reason = "max_rewrite_steps reached"
else:
reason = "terminated"
self.logger.info(
f"Step: {step} | Score: {info.score}/{info.max_score if info.max_score else '-'} | Reason: {reason}"
)
# early stop, set current step and total steps to be the same
self.logger.report_progress(
problem_id=env.task_name,
problem_id=task_name,
step=step + 1,
total_steps=step + 1,
score=info.score,
Expand All @@ -275,7 +297,7 @@ def run(self, env: RepoEnv, debug=False):
break
# keep progress bar running until max_steps is reached
self.logger.report_progress(
problem_id=env.task_name,
problem_id=task_name,
step=step + 1,
total_steps=max_steps + 1,
score=info.score,
Expand All @@ -284,7 +306,7 @@ def run(self, env: RepoEnv, debug=False):
)
# max_steps was reached, task was either resolved or unresolved
self.logger.report_progress(
problem_id=env.task_name,
problem_id=task_name,
step=step + 1,
total_steps=step + 1,
score=info.score,
Expand All @@ -295,11 +317,11 @@ def run(self, env: RepoEnv, debug=False):
except Exception:
# report any error that happens during the run
self.logger.report_progress(
problem_id=env.task_name,
problem_id=task_name,
step=step + 1,
total_steps=step + 1,
score=info.score if info else 0,
max_score=info.max_score,
max_score=info.max_score if info else None,
status="error",
)
raise
Expand All @@ -326,12 +348,22 @@ def apply_patch(self, patch_path: str) -> bool:
print("Error:", e.stderr)
return False

def build_trajectory(self, task_name) -> Dict[str, Any]:
"""Return the trajectory as a JSON-serializable dict without writing it."""
def save_patch(self, task_name="custom"):
os.makedirs(pjoin(self._output_path, task_name), exist_ok=True)
patch_path = pjoin(self._output_path, task_name, "debug_gym.patch")
with open(patch_path, "w") as f:
f.write(self.env.patch)

self.logger.debug(
f"Patch saved in {pjoin(self._output_path, task_name, 'debug_gym.patch')}"
)

def save_trajectory(self, task_name="custom"):
# Simple tools list.
tools = [f"{tool.name}({tool.arguments})" for tool in self.env.tools]
json_output = {
"problem": task_name,
"config": self.args.to_dict(),
"config": self.config,
"tools": self.llm.define_tools(self.env.tools) if self.llm else tools,
"uuid": self._uuid,
"success": self.env.resolved,
Expand All @@ -342,16 +374,15 @@ def build_trajectory(self, task_name) -> Dict[str, Any]:
for step_id in range(len(self.history)):
step_json = self.history.json(step_id)
json_output["log"].append(step_json)
return json_output
os.makedirs(pjoin(self._output_path, task_name), exist_ok=True)
json_file = pjoin(self._output_path, task_name, "trajectory.json")
with open(json_file, "w") as f:
json.dump(json_output, f, indent=4)

self.logger.debug(f"Trajectory saved in {json_file}")

def create_agent(
agent_type: str,
*,
agent_args: AgentArgs | Dict[str, Any] | None = None,
config: Dict[str, Any] | None = None,
**agent_kwargs,
):

def create_agent(agent_type: str, **agent_kwargs):
if agent_type in AGENT_REGISTRY:
agent_class = AGENT_REGISTRY[agent_type]
elif "." in agent_type:
Expand All @@ -367,9 +398,5 @@ def create_agent(
else:
raise ValueError(f"Unknown agent type: {agent_type}")

agent_args = agent_args or config
if agent_args is None:
raise ValueError("Either agent_args or config must be provided.")

agent = agent_class(args=agent_args, **agent_kwargs)
agent = agent_class(**agent_kwargs)
return agent
Loading
Loading