Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 35 additions & 13 deletions experiments/code/ace/adaptation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from appworld_experiments.code.ace.logger import Logger

from appworld.evaluator import evaluate_task
from appworld_experiments.code.ace.hf_policy import HFPolicy

@dataclass
class ExecutionIO:
Expand All @@ -33,9 +34,18 @@ def __init__(
use_gt_code: bool = False,
):
self.generator_model = LiteLLMGenerator(**generator_model_config)
self.reflector_model = LiteLLMGenerator(**reflector_model_config)
#self.reflector_model = LiteLLMGenerator(**reflector_model_config)
self.curator_model = LiteLLMGenerator(**curator_model_config)

refl_cfg = reflector_model_config
self.reflector_model = HFPolicy(
refl_cfg["name"],
trainable_lora=True,
bf16=refl_cfg["bf16"],
lora_r=refl_cfg["lora_r"],
lora_alpha=refl_cfg["lora_alpha"],
lora_dropout=refl_cfg["lora_dropout"],
lora_target_modules=refl_cfg["lora_target_modules"],
)
self.messages: list[dict] = []
self.max_steps = max_steps
self.step_number = 0
Expand All @@ -58,14 +68,16 @@ def __init__(
self.playbook = ''
self.current_task_index = 0 # Global variable to track current task index
self.trained_playbook_file_path = None
self.num_retries = 5
self.trained_checkpoints = None
self.num_retries = 1
self.use_gt_code = use_gt_code

self.refl_cfg = refl_cfg

def initialize(self, world: AppWorld):
self.world = world
if self.log_lm_calls:
self.generator_model.log_calls_to(world=world)
self.reflector_model.log_calls_to(world=world)
#self.reflector_model.log_calls_to(world=world)
self.curator_model.log_calls_to(world=world)
self.cost_tracker.reset(world.task_id)
self.step_number = 0
Expand All @@ -88,6 +100,7 @@ def solve_task_with_gt(self, task_id: str, experiment_name: str | None = None):
task_success = False
reasoning_text = ""

curr_flips = 0
for retry_id in range(self.num_retries):
with AppWorld(
task_id=task_id, experiment_name=experiment_name, **self.appworld_config
Expand All @@ -100,6 +113,7 @@ def solve_task_with_gt(self, task_id: str, experiment_name: str | None = None):
raise ValueError(f"GT code not found for task: {task_id}")
print("---Max steps---: ", self.max_steps)
print("GT Code: \n", gt_code)

self.step_number = 0
for _ in range(self.max_steps):
self.step_number += 1
Expand All @@ -110,7 +124,7 @@ def solve_task_with_gt(self, task_id: str, experiment_name: str | None = None):

if reflection:
reflections.append(reflection)

if len(execution_inputs) != 0:
execution_outputs = [
ExecutionIO(
Expand All @@ -132,14 +146,19 @@ def solve_task_with_gt(self, task_id: str, experiment_name: str | None = None):
self.cost_tracker.add(task_id, cost)
self.log_cost()
if world.task_completed() or self.cost_tracker.exceeded():
self.curator_call()
self.playbook = self.curator_call()
test_tracker, self.test_report = evaluate_task(task_id, experiment_name)
if len(test_tracker.failures)>0:
reasoning_text = self.reflector_call()
if len(test_tracker.failures) > 0:
# call restem
print("test errors")
curr_flips, best_self_edit = self.restem_trainer(task_id, experiment_name, world, original_failures=len(test_tracker.failures))
self.playbook = self.curator_call(best_self_edit, self.playbook)
#reasoning_text = self.reflector_call()
else:
task_success = True
print(f"{task_id} passed unit tests in retry: {retry_id} and step_number: {self.step_number}")
break

if task_success:
break

Expand All @@ -148,6 +167,7 @@ def solve_task_with_gt(self, task_id: str, experiment_name: str | None = None):
self.save_playbook_snapshot()

self.logger.complete_task()
return curr_flips

def solve_task_wo_gt(self, task_id: str, experiment_name: str | None = None):
self.star_guide_idx = None
Expand Down Expand Up @@ -192,7 +212,7 @@ def solve_task_wo_gt(self, task_id: str, experiment_name: str | None = None):
self.log_cost()
if world.task_completed() or self.cost_tracker.exceeded():
test_tracker, self.test_report = evaluate_task(task_id, experiment_name)
self.curator_call()
self.playbook = self.curator_call()
break

# Save playbook every 30 tasks
Expand All @@ -206,7 +226,7 @@ def solve_task(self, task_id: str, experiment_name: str | None = None):
self.cost_tracker.reset(task_id)

if self.use_gt_code:
self.solve_task_with_gt(task_id, experiment_name)
return self.solve_task_with_gt(task_id, experiment_name)
else:
self.solve_task_wo_gt(task_id, experiment_name)

Expand All @@ -226,9 +246,11 @@ def solve_tasks(
num_processes=num_processes,
process_index=process_index,
)
num_flips = 0
for task_index, task_id in enumerate(task_ids):
self.current_task_index = task_index
self.solve_task(task_id, experiment_name)
num_flips += self.solve_task(task_id, experiment_name)
print("total flips ", num_flips)

def log_cost(self) -> None:
self.cost_tracker.save(os.path.join(self.world.output_misc_directory, "cost.txt"))
Expand All @@ -245,4 +267,4 @@ def save_playbook_snapshot(self):
raise ValueError("trained_playbook_file_path is not set")
with open(snapshot_file_path, "w") as file:
file.write(self.playbook)
print(f"Saved playbook snapshot at task {self.current_task_index + 1}: {snapshot_file_path}")
print(f"Saved playbook snapshot at task {self.current_task_index + 1}: {snapshot_file_path}")
Loading