Commit 839f11d

chore: speeding up benchmark with concurrent requests (#720)
# Motivation

Leverages asyncio to optimize API rate-limit usage: examples are processed concurrently by a pool of workers instead of in sequential batches.
1 parent 5c3ccf8 commit 839f11d

File tree: 1 file changed, +87 −69 lines

  • codegen-examples/examples/swebench_agent_run

codegen-examples/examples/swebench_agent_run/run_eval.py

Lines changed: 87 additions & 69 deletions
```diff
@@ -17,93 +17,106 @@
 run_agent_modal = modal.Function.from_name(app_name="swebench-agent-run", name="run_agent_modal")
 
 
-async def process_batch_modal(examples: list[SweBenchExample], batch_size=10):
-    """Process a batch of examples concurrently.
+async def process_batch_modal(examples: list[SweBenchExample], num_workers=10, max_retries=3):
+    """Process a batch of examples concurrently using a queue system.
 
     Args:
         examples: List of SweBenchExample objects to process
-        batch_size: Number of examples to process concurrently.
-            Default is 50 which provides good parallelization
-            while staying well within Modal's limits.
+        num_workers: Number of examples to process concurrently
+        max_retries: Maximum number of retries for failed requests
     """
-    results = []
-
-    # Process examples in batches
-    for i in range(0, len(examples), batch_size):
-        batch = examples[i : i + batch_size]
+    results = {}
+    queue = asyncio.Queue()
 
-        # Create tasks for this batch
-        batch_tasks = [run_agent_modal.remote.aio(example) for example in batch]
-
-        # Wait for all tasks in this batch to complete
-        print(f"Processing batch {i // batch_size + 1}/{len(examples) // batch_size + 1} (examples {i + 1}-{min(i + batch_size, len(examples))})")
+    # Initialize the queue with (example, attempt) tuples
+    for example in examples:
+        await queue.put((example, 0))  # 0 represents first attempt
 
+    async def process_example(example, attempt):
         try:
-            batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
-
-            # Store results
-            for example, result in zip(batch, batch_results):
-                error_info = None
-
-                if isinstance(result, Exception):
-                    error_type = type(result).__name__
-                    error_info = {
-                        "error_type": error_type,
-                        "error_message": str(result),
-                        "traceback": traceback.format_exception(type(result), result, result.__traceback__),
-                    }
-
-                    if isinstance(result, modal.exception.Error):
-                        error_info["modal_error_code"] = getattr(result, "code", None)
-                        error_info["modal_error_details"] = getattr(result, "details", None)
-
-                    print(f"Error processing {example.instance_id}:")
-                    print(f"Type: {error_type}")
-                    print(f"Message: {str(result)}")
-                    print("Traceback:")
-                    print("".join(error_info["traceback"]))
-
-                    results.append({"instance_id": example.instance_id, "status": "error", "error_info": error_info})
-                else:
-                    if result is None:
-                        print(f"Warning: Null result for {example.instance_id}")
-                        results.append({"instance_id": example.instance_id, "status": "error", "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}})
-                    else:
-                        results.append(result)
+            result = await run_agent_modal.remote.aio(example)
+
+            if result is None:
+                print(f"Warning: Null result for {example.instance_id}")
+                return {"instance_id": example.instance_id, "status": "error", "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}}
+            return result
 
         except Exception as e:
-            print("Batch processing error:")
-            print(f"Type: {type(e).__name__}")
+            error_type = type(e).__name__
+            error_info = {
+                "error_type": error_type,
+                "error_message": str(e),
+                "traceback": traceback.format_exception(type(e), e, e.__traceback__),
+            }
+
+            if isinstance(e, modal.exception.Error):
+                error_info["modal_error_code"] = getattr(e, "code", None)
+                error_info["modal_error_details"] = getattr(e, "details", None)
+
+            print(f"Error processing {example.instance_id} (attempt {attempt + 1}):")
+            print(f"Type: {error_type}")
             print(f"Message: {str(e)}")
-            traceback.print_exc()
-
-            # Mark all examples in the batch as failed
-            for example in batch:
-                results.append(
-                    {
-                        "instance_id": example.instance_id,
-                        "status": "error",
-                        "error_info": {"error_type": type(e).__name__, "error_message": str(e), "traceback": traceback.format_exc(), "batch_failure": True},
-                    }
-                )
+            print("Traceback:")
+            print("".join(error_info["traceback"]))
 
-    return results
+            if attempt < max_retries:
+                await queue.put((example, attempt + 1))
+                return None
+
+            return {"instance_id": example.instance_id, "status": "error", "error_info": error_info}
+
+    async def worker():
+        while True:
+            try:
+                example, attempt = await queue.get()
+
+                if example.instance_id in results:
+                    queue.task_done()
+                    continue
+
+                result = await process_example(example, attempt)
+
+                if result is not None:
+                    results[example.instance_id] = result
+
+                queue.task_done()
+
+            except Exception as e:
+                print(f"Worker error: {str(e)}")
+                traceback.print_exc()
+                queue.task_done()
+
+    # Start workers
+    workers = [asyncio.create_task(worker()) for _ in range(num_workers)]
+
+    # Wait for queue to be fully processed
+    await queue.join()
+
+    # Cancel workers
+    for w in workers:
+        w.cancel()
+
+    # Wait for all workers to be cancelled
+    await asyncio.gather(*workers, return_exceptions=True)
+
+    # Return results in the same order as input examples
+    return [results[example.instance_id] for example in examples]
 
 
-def process_batch_local(examples: list[SweBenchExample], batch_size=10, codebases: dict[str, Codebase] = {}):
+def process_batch_local(examples: list[SweBenchExample], num_workers=5, codebases: dict[str, Codebase] = {}):
     """Process a batch of examples synchronously.
 
     Args:
         examples: List of SweBenchExample objects to process
-        batch_size: Number of examples to process in each batch.
+        num_workers: Number of examples to process in each batch.
             Default is 10 to avoid overwhelming the system.
     """
     results = []
 
     # Process examples in batches
-    for i in range(0, len(examples), batch_size):
-        batch = examples[i : i + batch_size]
-        print(f"Processing batch {i // batch_size + 1}/{len(examples) // batch_size + 1} (examples {i + 1}-{min(i + batch_size, len(examples))})")
+    for i in range(0, len(examples), num_workers):
+        batch = examples[i : i + num_workers]
+        print(f"Processing batch {i // num_workers + 1}/{len(examples) // num_workers + 1} (examples {i + 1}-{min(i + num_workers, len(examples))})")
 
         # Process each example in the batch
         for example in batch:
```
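To make the new control flow easier to follow, here is a minimal, self-contained sketch of the same queue-and-worker pattern, assuming Python 3.10+; `fake_api_call` and `process_all` are hypothetical names standing in for `run_agent_modal.remote.aio` and `process_batch_modal`, not code from this repo:

```python
import asyncio
import random

async def fake_api_call(item: int) -> str:
    """Hypothetical stand-in for run_agent_modal.remote.aio(example)."""
    if random.random() < 0.3:  # simulate a transient API failure
        raise RuntimeError(f"transient failure for item {item}")
    return f"result-{item}"

async def process_all(items: list[int], num_workers: int = 3, max_retries: int = 2) -> list[str | None]:
    results: dict[int, str | None] = {}
    queue: asyncio.Queue[tuple[int, int]] = asyncio.Queue()

    for item in items:
        queue.put_nowait((item, 0))  # (item, attempt); 0 is the first try

    async def worker() -> None:
        while True:
            item, attempt = await queue.get()
            try:
                results[item] = await fake_api_call(item)
            except Exception:
                if attempt < max_retries:
                    queue.put_nowait((item, attempt + 1))  # requeue before task_done()
                else:
                    results[item] = None  # give up after max_retries
            finally:
                queue.task_done()

    workers = [asyncio.create_task(worker()) for _ in range(num_workers)]
    await queue.join()  # returns once every put() has a matching task_done()
    for w in workers:
        w.cancel()  # workers loop forever, so cancel them explicitly
    await asyncio.gather(*workers, return_exceptions=True)

    # Indexing the dict with the original list preserves input order
    return [results[item] for item in items]

if __name__ == "__main__":
    print(asyncio.run(process_all(list(range(10)))))
```

The ordering matters: a retry is re-queued before `task_done()` runs, so `queue.join()` cannot return while a retry is still outstanding, and plain cancellation is then enough to shut the workers down.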
```diff
@@ -134,7 +147,9 @@ def process_batch_local(examples: list[SweBenchExample], batch_size=10, codebases: dict[str, Codebase] = {}):
     return results
 
 
-async def run_eval(use_existing_preds: str | None, dataset: str, length: int, instance_id: str | None = None, local: bool = False, codebases: dict[str, Codebase] = {}, repo: str | None = None):
+async def run_eval(
+    use_existing_preds: str | None, dataset: str, length: int, instance_id: str | None = None, local: bool = False, codebases: dict[str, Codebase] = {}, repo: str | None = None, num_workers: int = 5
+):
     run_id = use_existing_preds or str(uuid.uuid4())
     print(f"Run ID: {run_id}")
     predictions_dir = PREDS_DNAME / f"results_{run_id}"
```
```diff
@@ -162,7 +177,7 @@ async def run_eval(use_existing_preds: str | None, dataset: str, length: int, instance_id: str | None = None, local: bool = False, codebases: dict[str, Codebase] = {}, repo: str | None = None):
     if local:
         results = process_batch_local(examples, codebases=codebases)
     else:
-        results = await process_batch_modal(examples)
+        results = await process_batch_modal(examples, num_workers=num_workers)
 
     # Save individual results
     for result in results:
```
```diff
@@ -218,9 +233,12 @@ async def run_eval(use_existing_preds: str | None, dataset: str, length: int, instance_id: str | None = None, local: bool = False, codebases: dict[str, Codebase] = {}, repo: str | None = None):
 @click.option("--instance-id", help="The instance ID of the example to process.", type=str, default=None)
 @click.option("--local", help="Run the evaluation locally.", is_flag=True, default=False)
 @click.option("--repo", help="The repo to use.", type=str, default=None)
-def run_eval_command(use_existing_preds, dataset, length, instance_id, local, repo):
+@click.option(
+    "--num-workers", help="The number of workers to use. This is the number of examples that will be processed concurrently. A large number may lead to rate limiting issues.", type=int, default=5
+)
+def run_eval_command(use_existing_preds, dataset, length, instance_id, local, repo, num_workers):
     print(f"Repo: {repo}")
-    asyncio.run(run_eval(use_existing_preds=use_existing_preds, dataset=dataset, length=length, instance_id=instance_id, codebases=None, local=local, repo=repo))
+    asyncio.run(run_eval(use_existing_preds=use_existing_preds, dataset=dataset, length=length, instance_id=instance_id, codebases=None, local=local, repo=repo, num_workers=num_workers))
 
 
 if __name__ == "__main__":
```
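For completeness, one way to exercise the new flag from Python is click's test runner. This is only an illustrative sketch: the `--dataset` and `--length` option names and values are assumptions inferred from `run_eval`'s parameters, since their declarations fall outside this hunk:

```python
from click.testing import CliRunner

from run_eval import run_eval_command  # assumes run_eval.py is importable as a module

runner = CliRunner()
# --num-workers is the option added in this commit; 10 doubles the default
# concurrency of 5, at the cost of a higher chance of rate limiting.
# "lite" and "10" are placeholder values for the assumed --dataset/--length options.
result = runner.invoke(run_eval_command, ["--dataset", "lite", "--length", "10", "--num-workers", "10"])
print(result.output)
```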
