Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions experiments/code/gepa/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# ruff: noqa: F401
from appworld_experiments.code.gepa.gepa_agent import GEPAAgent
from appworld_experiments.code.gepa.gepa_react import GEPAReActAgent
80 changes: 80 additions & 0 deletions experiments/code/gepa/gepa_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from appworld import AppWorld
from appworld.common.constants import DEFAULT_EXPERIMENT_NAME
from appworld_experiments.code.ace.evaluation_agent import Agent, ExecutionIO

from appworld.evaluator import evaluate_task

class GEPAAgent(Agent):
    """AppWorld agent used by GEPA.

    Runs a bounded ReAct-style loop: repeatedly asks the generator for the
    next code to execute, runs it inside the AppWorld environment, and stops
    once the task reports completion or the cost budget is exhausted.
    """

    def __init__(
        self,
        generator_model_config: dict,
        appworld_config: dict | None = None,
        logger_config: dict | None = None,
        max_steps: int = 10,
        max_cost_overall: float = 3000,
        max_cost_per_task: float = 10,
        log_lm_calls: bool = False,
    ):
        super().__init__(
            generator_model_config=generator_model_config,
            appworld_config=appworld_config,
            logger_config=logger_config,
            max_steps=max_steps,
            max_cost_overall=max_cost_overall,
            max_cost_per_task=max_cost_per_task,
            log_lm_calls=log_lm_calls,
        )

    def solve_task(self, task_id: str, experiment_name: str | None = None):
        """Solve a single AppWorld task.

        Args:
            task_id: Identifier of the AppWorld task to run.
            experiment_name: Experiment under which results are stored;
                defaults to ``DEFAULT_EXPERIMENT_NAME``.

        Returns:
            The test tracker from ``evaluate_task`` when the task completed
            (or the cost budget was exceeded); otherwise the list of raw
            execution-output strings from the final step.
        """
        experiment_name = experiment_name or DEFAULT_EXPERIMENT_NAME
        self.cost_tracker.reset(task_id)

        # Per-task bookkeeping; reset for every new task.
        self.initial_code_idx = None
        self.previous_code_idx = None
        self.previous_error_idx = None
        reflections = []
        test_tracker = None

        with AppWorld(
            task_id=task_id, experiment_name=experiment_name, **self.appworld_config
        ) as world:
            execution_outputs: list[ExecutionIO] = []
            self.initialize(world)

            for _ in range(self.max_steps):
                self.step_number += 1
                execution_inputs, cost, reflection = self.next_execution_inputs_and_cost(
                    execution_outputs, ""
                )
                if reflection:
                    reflections.append(reflection)

                if execution_inputs:
                    execution_outputs = [
                        ExecutionIO(
                            content=world.execute(execution_input.content),
                            metadata=execution_input.metadata,
                        )
                        for execution_input in execution_inputs
                    ]

                    # Surface non-empty environment outputs via the logger.
                    for output in execution_outputs:
                        if output.content.strip():
                            self.logger.show_message(
                                role="environment",
                                message=output.content,
                                step_number=self.step_number,
                            )

                self.cost_tracker.add(task_id, cost)
                self.log_cost()

                if world.task_completed() or self.cost_tracker.exceeded():
                    test_tracker, _ = evaluate_task(task_id, experiment_name)
                    break

        if test_tracker is None:
            # Loop exhausted max_steps without completion: fall back to the
            # raw outputs of the last executed step.
            test_tracker = [execution_output.content for execution_output in execution_outputs]

        self.logger.complete_task()
        return test_tracker
206 changes: 206 additions & 0 deletions experiments/code/gepa/gepa_react.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import copy
import json
import os
import re
from typing import Any

from jinja2 import Template

from appworld import AppWorld
from appworld.common.utils import read_file
from appworld_experiments.code.ace.evaluation_agent import Agent, ExecutionIO
from appworld_experiments.code.gepa.gepa_agent import GEPAAgent

@GEPAAgent.register("gepa_react")
class GEPAReActAgent(GEPAAgent):
    """ReAct-style GEPA agent.

    Alternates model-generated ```python code blocks with environment
    outputs, keeping the transcript within prompt/output length budgets by
    trimming old observations and history blocks.
    """

    def __init__(
        self,
        generator_prompt_file_path: str | None = None,
        trained_playbook_file_path: str | None = None,
        ignore_multiple_calls: bool = True,
        max_prompt_length: int | None = None,
        max_output_length: int = 400000,
        **kwargs: Any,
    ):
        """Initialize the ReAct agent.

        Args:
            generator_prompt_file_path: Path to the Jinja2 prompt template.
                Required in practice; a clear ``ValueError`` is raised if
                missing (previously this crashed with an AttributeError).
            trained_playbook_file_path: Optional path to a trained playbook.
            ignore_multiple_calls: If True, only the first code block in a
                completion is executed and trailing text is discarded.
            max_prompt_length: Optional cap on the rendered prompt length.
            max_output_length: Cap on the serialized transcript length used
                by ``trimmed_messages``.
            **kwargs: Forwarded to ``GEPAAgent.__init__``.
        """
        super().__init__(**kwargs)
        if generator_prompt_file_path is None:
            raise ValueError("generator_prompt_file_path is required for GEPAReActAgent.")
        self.generator_prompt_template = read_file(
            generator_prompt_file_path.replace("/", os.sep)
        ).lstrip()
        self.trained_playbook_file_path = trained_playbook_file_path
        self.max_prompt_length = max_prompt_length
        self.max_output_length = max_output_length
        self.ignore_multiple_calls = ignore_multiple_calls
        # Matches an unterminated ```python block at the end of the text.
        self.partial_code_regex = r".*```python\n(.*)"
        # Matches complete ```python ... ``` fenced blocks.
        self.full_code_regex = r"```python\n(.*?)```"

        self.playbook = None
        # GEPA-optimized prompt prefix; must be set via replace_gepa_prompt()
        # before initialize() is called (asserted there).
        self.gepa_prompt_replace = None

    def replace_gepa_prompt(self, prompt: str):
        """Set the GEPA-optimized prompt prefix prepended to the first message."""
        self.gepa_prompt_replace = prompt

    def initialize(self, world: AppWorld):
        """Render the prompt template for this task and build self.messages."""
        super().initialize(world)
        template = Template(self.generator_prompt_template)
        app_descriptions = json.dumps(
            [{"name": k, "description": v} for (k, v) in world.task.app_descriptions.items()],
            indent=1,
        )
        template_params = {
            "input_str": world.task.instruction,
            "main_user": world.task.supervisor,
            "app_descriptions": app_descriptions,
            "relevant_apis": str(world.task.ground_truth.required_apis),
            "playbook": self.playbook,
        }
        output_str = template.render(template_params)
        output_str = self.truncate_input(output_str) + "\n\n"
        self.messages = self.text_to_messages(output_str)
        self.num_instruction_messages = len(self.messages)
        assert self.gepa_prompt_replace is not None
        # Prepend the GEPA-optimized prompt to the first (instruction) message.
        self.messages[0]['content'] = self.gepa_prompt_replace + self.messages[0]['content']

    def next_execution_inputs_and_cost(
        self, last_execution_outputs: list[ExecutionIO], world_gt_code: str | None = None
    ) -> tuple[list[ExecutionIO], float, str | None]:
        """Append the last observation, query the LM, and return next code.

        Returns:
            A one-element list of ExecutionIO with the extracted code, the
            LM call cost, and None (no reflection is produced by ReAct).
        """
        if last_execution_outputs:
            assert (
                len(last_execution_outputs) == 1
            ), "React expects exactly one last_execution_output."
            last_execution_output_content = (
                "Output:\n```\n"
                + self.truncate_output(last_execution_outputs[0].content)
                + "```\n\n"
            )
            self.messages.append({"role": "user", "content": last_execution_output_content})
        messages = self.trimmed_messages
        output = self.language_model.generate(messages=messages)
        code, fixed_output_content = self.extract_code_and_fix_content(output["content"])
        self.messages.append({"role": "assistant", "content": fixed_output_content + "\n\n"})
        self.logger.show_message(
            role="agent", message=fixed_output_content, step_number=self.step_number
        )
        return [ExecutionIO(content=code)], output["cost"], None

    def extract_code_and_fix_content(self, text: str) -> tuple[str, str]:
        """Extract python code block(s) from an LM completion.

        Returns:
            ``(code, fixed_text)`` where ``fixed_text`` is the completion
            truncated after the first code block (when ignore_multiple_calls)
            and/or repaired with a terminating ``` fence.
        """
        if text is None:
            return "", ""
        original_text = text
        output_code = ""
        match_end = 0
        # Collect all complete ```python ... ``` blocks.
        for re_match in re.finditer(self.full_code_regex, original_text, flags=re.DOTALL):
            code = re_match.group(1).strip()
            if self.ignore_multiple_calls:
                # Keep only the first block and drop any trailing text.
                text = original_text[: re_match.end()]
                return code, text
            output_code += code + "\n"
            match_end = re_match.end()
        # Check for a partial code block at the end (no terminating ```)
        # following the last complete match, e.g. when generation stopped on
        # a stop condition before closing the fence.
        partial_match = re.match(
            self.partial_code_regex, original_text[match_end:], flags=re.DOTALL
        )
        if partial_match:
            output_code += partial_match.group(1).strip()
            # Repair the missing closing fence so the transcript stays well-formed.
            if not text.endswith("\n"):
                text = text + "\n"
            text = text + "```"
        if not output_code:
            return "", text
        return output_code, text

    def truncate_input(self, input_str: str) -> str:
        """Trim the middle of the prompt so it fits max_prompt_length.

        The instruction part up to (and including) the "Task:" line is kept
        verbatim; the remainder keeps only its most recent suffix, aligned to
        an "ASSISTANT:" boundary when one is present.

        Raises:
            ValueError: If no "Task:" marker is found, or the instruction
                part alone already exceeds the limit.
        """
        if self.max_prompt_length is None:
            return input_str
        max_prompt_length = self.max_prompt_length
        goal_index = input_str.rfind("Task:")
        if goal_index == -1:
            raise ValueError(f"No goal found in input string:\n{input_str}")
        next_new_line_index = input_str.find("\n", goal_index) + 1
        init_prompt = input_str[:next_new_line_index]
        prompt = input_str[next_new_line_index:]
        if len(init_prompt) > max_prompt_length:
            raise ValueError("Input prompt longer than max allowed length")
        if len(prompt) > max_prompt_length - len(init_prompt):
            new_prompt = prompt[-(max_prompt_length - len(init_prompt)) :]
            cmd_index = new_prompt.find("ASSISTANT:") if "ASSISTANT:" in new_prompt else 0
            prompt = "\n[TRIMMED HISTORY]\n\n" + new_prompt[cmd_index:]
        return init_prompt + prompt

    def truncate_output(self, execution_output_content: str, max_length: int = 20000) -> str:
        """Cap a single execution output at ``max_length`` characters."""
        if len(execution_output_content) > max_length:
            execution_output_content = (
                execution_output_content[:max_length] + "\n[REST NOT SHOWN FOR BREVITY]"
            )
        return execution_output_content

    def text_to_messages(self, input_str: str) -> list[dict]:
        """Split "USER:/ASSISTANT:/SYSTEM:"-delimited text into chat messages.

        Raises:
            ValueError: If text appears before the first role marker.
        """
        messages_json = []
        last_start = 0
        for m in re.finditer("(USER|ASSISTANT|SYSTEM):\n", input_str, flags=re.IGNORECASE):
            last_end = m.span()[0]
            if len(messages_json) == 0:
                if last_end != 0:
                    raise ValueError(
                        f"Start of the prompt has no assigned role: {input_str[:last_end]}"
                    )
            else:
                messages_json[-1]["content"] = input_str[last_start:last_end]
            role = m.group(1).lower()
            messages_json.append({"role": role, "content": None})
            last_start = m.span()[1]
        messages_json[-1]["content"] = input_str[last_start:]
        return messages_json

    def messages_to_text(self, messages: list[dict]) -> str:
        """Serialize chat messages back into "ROLE:\\n<content>" text.

        Raises:
            ValueError: If a message has an unknown role.
        """
        output_str = ""
        for message in messages:
            role = message["role"]
            # Bug fix: the "system" case used to be a separate `if`, so a
            # system message appended its text and then fell into the `else`
            # branch and raised ValueError. One if/elif chain is correct.
            if role == "system":
                output_str += "SYSTEM:\n" + message["content"]
            elif role == "assistant":
                output_str += "ASSISTANT:\n" + message["content"]
            elif role == "user":
                output_str += "USER:\n" + message["content"]
            else:
                raise ValueError(f"Unknown message role {role} in: {message}")
        return output_str

    @property
    def trimmed_messages(self) -> list[dict]:
        """Return a copy of self.messages trimmed to fit max_output_length.

        Old observation blocks are blanked first (never the last 5 messages);
        when none remain, whole history blocks are dropped from the front.

        Raises:
            ValueError: If nothing is left to trim but the text still exceeds
                the limit.
        """
        messages = copy.deepcopy(self.messages)
        pre_messages = messages[: self.num_instruction_messages - 1]
        post_messages = messages[self.num_instruction_messages - 1 :]
        output_str = self.messages_to_text(post_messages)
        remove_prefix = output_str[: output_str.index("Task: ") + 6]
        output_str = output_str.removeprefix(
            remove_prefix
        )  # not needed, it's only to match the original code
        observation_index = 0
        while len(output_str) > self.max_output_length:
            found_block = False
            # Don't remove observations from the last 5 blocks.
            if observation_index < len(post_messages) - 5:
                # Find the next observation block to remove.
                for message_index, message in enumerate(post_messages[observation_index:]):
                    # Only keep the code blocks and remove observations.
                    if message["role"] == "user" and message["content"].startswith("Output:"):
                        message["content"] = "Output:\n```\n[NOT SHOWN FOR BREVITY]```\n\n"
                        found_block = True
                        observation_index += message_index + 1
                        break
                if not found_block:
                    observation_index = len(post_messages)
            # If no observation block left to trim, we need to start removing
            # complete history blocks from the front.
            if not found_block and len(post_messages):
                first_post_message = copy.deepcopy(post_messages[0])
                if not first_post_message["content"].endswith("[TRIMMED HISTORY]\n\n"):
                    first_post_message["content"] += "[TRIMMED HISTORY]\n\n"
                post_messages = [first_post_message] + post_messages[2:]
                found_block = True
            if not found_block:
                raise ValueError(f"No blocks found to be removed!\n{post_messages}")
            output_str = self.messages_to_text(
                post_messages
            )  # not needed, it's only to match the original code
            output_str = output_str.removeprefix(remove_prefix)
        messages = pre_messages + post_messages
        return messages
46 changes: 46 additions & 0 deletions experiments/configs/GEPA_offline_with_GT_adaptation.jsonnet
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Not used for appworld training run, only used for initialization of agent in GEPA
local project_home_path = std.extVar("APPWORLD_PROJECT_PATH");
# Standard locations inside the project checkout.
local experiment_prompts_path = project_home_path + "/experiments/prompts";
local experiment_playbooks_path = project_home_path + "/experiments/playbooks";
local experiment_configs_path = project_home_path + "/experiments/configs";
local experiment_code_path = project_home_path + "/experiments/code";

# Request/sampling settings for the code-generating language model.
local generator_model_config = {
  "name": "DeepSeek-V3.1",
  "provider": "sambanova",
  "temperature": 0,  # deterministic decoding
  "seed": 100,
  "stop": ["<|endoftext|>", "<|eot_id|>", "<|start_header_id|>"],
  "logprobs": false,
  "top_logprobs": null,
  "frequency_penalty": 0,
  "presence_penalty": 0,
  "n": 1,
  "response_format": {"type": "text"},
  "retry_after_n_seconds": 10,  # wait between retries on failure
  "use_cache": true,
  "max_retries": 50,
};

{
  "type": "gepa",
  "config": {
    "agent": {
      # Registered name of the agent class (GEPAReActAgent).
      "type": "gepa_react",
      "generator_model_config": generator_model_config,
      "appworld_config": {
        "random_seed": 123,
      },
      "logger_config": {
        "color": true,
        "verbose": true,
      },
      "generator_prompt_file_path": experiment_prompts_path + "/appworld_react_gepa_prompt.txt",
      "ignore_multiple_calls": true,
      "max_steps": 40,
      "max_cost_overall": 1000,
      "max_cost_per_task": 10,
      "log_lm_calls": true,
    }
  }
}
Loading