17 | 17 | run_agent_modal = modal.Function.from_name(app_name="swebench-agent-run", name="run_agent_modal") |
18 | 18 |
19 | 19 |
20 | | -async def process_batch_modal(examples: list[SweBenchExample], batch_size=10): |
21 | | - """Process a batch of examples concurrently. |
| 20 | +async def process_batch_modal(examples: list[SweBenchExample], num_workers=10, max_retries=3): |
| 21 | + """Process a batch of examples concurrently using a queue system. |
22 | 22 |
23 | 23 | Args: |
24 | 24 | examples: List of SweBenchExample objects to process |
25 | | - batch_size: Number of examples to process concurrently. |
26 | | - Default is 50 which provides good parallelization |
27 | | - while staying well within Modal's limits. |
| 25 | + num_workers: Number of examples to process concurrently |
| 26 | + max_retries: Maximum number of retries for failed requests |
28 | 27 | """ |
29 | | - results = [] |
30 | | - |
31 | | - # Process examples in batches |
32 | | - for i in range(0, len(examples), batch_size): |
33 | | - batch = examples[i : i + batch_size] |
| 28 | + results = {} |
| 29 | + queue = asyncio.Queue() |
34 | 30 |
35 | | - # Create tasks for this batch |
36 | | - batch_tasks = [run_agent_modal.remote.aio(example) for example in batch] |
37 | | - |
38 | | - # Wait for all tasks in this batch to complete |
39 | | - print(f"Processing batch {i // batch_size + 1}/{len(examples) // batch_size + 1} (examples {i + 1}-{min(i + batch_size, len(examples))})") |
| 31 | + # Initialize the queue with (example, attempt) tuples |
| 32 | + for example in examples: |
| 33 | + await queue.put((example, 0)) # 0 represents first attempt |
40 | 34 |
| 35 | + async def process_example(example, attempt): |
41 | 36 | try: |
42 | | - batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True) |
43 | | - |
44 | | - # Store results |
45 | | - for example, result in zip(batch, batch_results): |
46 | | - error_info = None |
47 | | - |
48 | | - if isinstance(result, Exception): |
49 | | - error_type = type(result).__name__ |
50 | | - error_info = { |
51 | | - "error_type": error_type, |
52 | | - "error_message": str(result), |
53 | | - "traceback": traceback.format_exception(type(result), result, result.__traceback__), |
54 | | - } |
55 | | - |
56 | | - if isinstance(result, modal.exception.Error): |
57 | | - error_info["modal_error_code"] = getattr(result, "code", None) |
58 | | - error_info["modal_error_details"] = getattr(result, "details", None) |
59 | | - |
60 | | - print(f"Error processing {example.instance_id}:") |
61 | | - print(f"Type: {error_type}") |
62 | | - print(f"Message: {str(result)}") |
63 | | - print("Traceback:") |
64 | | - print("".join(error_info["traceback"])) |
65 | | - |
66 | | - results.append({"instance_id": example.instance_id, "status": "error", "error_info": error_info}) |
67 | | - else: |
68 | | - if result is None: |
69 | | - print(f"Warning: Null result for {example.instance_id}") |
70 | | - results.append({"instance_id": example.instance_id, "status": "error", "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}}) |
71 | | - else: |
72 | | - results.append(result) |
| 37 | + result = await run_agent_modal.remote.aio(example) |
| 38 | + |
| 39 | + if result is None: |
| 40 | + print(f"Warning: Null result for {example.instance_id}") |
| 41 | + return {"instance_id": example.instance_id, "status": "error", "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}} |
| 42 | + return result |
73 | 43 |
74 | 44 | except Exception as e: |
75 | | - print("Batch processing error:") |
76 | | - print(f"Type: {type(e).__name__}") |
| 45 | + error_type = type(e).__name__ |
| 46 | + error_info = { |
| 47 | + "error_type": error_type, |
| 48 | + "error_message": str(e), |
| 49 | + "traceback": traceback.format_exception(type(e), e, e.__traceback__), |
| 50 | + } |
| 51 | + |
| 52 | + if isinstance(e, modal.exception.Error): |
| 53 | + error_info["modal_error_code"] = getattr(e, "code", None) |
| 54 | + error_info["modal_error_details"] = getattr(e, "details", None) |
| 55 | + |
| 56 | + print(f"Error processing {example.instance_id} (attempt {attempt + 1}):") |
| 57 | + print(f"Type: {error_type}") |
77 | 58 | print(f"Message: {str(e)}") |
78 | | - traceback.print_exc() |
79 | | - |
80 | | - # Mark all examples in the batch as failed |
81 | | - for example in batch: |
82 | | - results.append( |
83 | | - { |
84 | | - "instance_id": example.instance_id, |
85 | | - "status": "error", |
86 | | - "error_info": {"error_type": type(e).__name__, "error_message": str(e), "traceback": traceback.format_exc(), "batch_failure": True}, |
87 | | - } |
88 | | - ) |
| 59 | + print("Traceback:") |
| 60 | + print("".join(error_info["traceback"])) |
89 | 61 |
90 | | - return results |
| 62 | + if attempt < max_retries: |
| 63 | + await queue.put((example, attempt + 1)) |
| 64 | + return None |
| 65 | + |
| 66 | + return {"instance_id": example.instance_id, "status": "error", "error_info": error_info} |
| 67 | + |
| 68 | + async def worker(): |
| 69 | + while True: |
| 70 | + try: |
| 71 | + example, attempt = await queue.get() |
| 72 | + |
| 73 | + if example.instance_id in results: |
| 74 | + queue.task_done() |
| 75 | + continue |
| 76 | + |
| 77 | + result = await process_example(example, attempt) |
| 78 | + |
| 79 | + if result is not None: |
| 80 | + results[example.instance_id] = result |
| 81 | + |
| 82 | + queue.task_done() |
| 83 | + |
| 84 | + except Exception as e: |
| 85 | + print(f"Worker error: {str(e)}") |
| 86 | + traceback.print_exc() |
| 87 | + queue.task_done() |
| 88 | + |
| 89 | + # Start workers |
| 90 | + workers = [asyncio.create_task(worker()) for _ in range(num_workers)] |
| 91 | + |
| 92 | + # Wait for queue to be fully processed |
| 93 | + await queue.join() |
| 94 | + |
| 95 | + # Cancel workers |
| 96 | + for w in workers: |
| 97 | + w.cancel() |
| 98 | + |
| 99 | + # Wait for all workers to be cancelled |
| 100 | + await asyncio.gather(*workers, return_exceptions=True) |
| 101 | + |
| 102 | + # Return results in the same order as input examples |
| 103 | + return [results[example.instance_id] for example in examples] |
91 | 104 |
92 | 105 |
93 | | -def process_batch_local(examples: list[SweBenchExample], batch_size=10, codebases: dict[str, Codebase] = {}): |
| 106 | +def process_batch_local(examples: list[SweBenchExample], num_workers=5, codebases: dict[str, Codebase] = {}): |
94 | 107 | """Process a batch of examples synchronously. |
95 | 108 |
96 | 109 | Args: |
97 | 110 | examples: List of SweBenchExample objects to process |
98 | | - batch_size: Number of examples to process in each batch.
99 | | - Default is 10 to avoid overwhelming the system.
| 111 | + num_workers: Number of examples to process in each batch.
| 112 | + Default is 5 to avoid overwhelming the system.
100 | 113 | """ |
101 | 114 | results = [] |
102 | 115 |
103 | 116 | # Process examples in batches |
104 | | - for i in range(0, len(examples), batch_size): |
105 | | - batch = examples[i : i + batch_size] |
106 | | - print(f"Processing batch {i // batch_size + 1}/{len(examples) // batch_size + 1} (examples {i + 1}-{min(i + batch_size, len(examples))})") |
| 117 | + for i in range(0, len(examples), num_workers): |
| 118 | + batch = examples[i : i + num_workers] |
| 119 | + print(f"Processing batch {i // num_workers + 1}/{len(examples) // num_workers + 1} (examples {i + 1}-{min(i + num_workers, len(examples))})") |
107 | 120 |
108 | 121 | # Process each example in the batch |
109 | 122 | for example in batch: |
@@ -134,7 +147,9 @@ def process_batch_local(examples: list[SweBenchExample], batch_size=10, codebase |
134 | 147 | return results |
135 | 148 |
136 | 149 |
137 | | -async def run_eval(use_existing_preds: str | None, dataset: str, length: int, instance_id: str | None = None, local: bool = False, codebases: dict[str, Codebase] = {}, repo: str | None = None): |
| 150 | +async def run_eval( |
| 151 | + use_existing_preds: str | None, dataset: str, length: int, instance_id: str | None = None, local: bool = False, codebases: dict[str, Codebase] = {}, repo: str | None = None, num_workers: int = 5 |
| 152 | +): |
138 | 153 | run_id = use_existing_preds or str(uuid.uuid4()) |
139 | 154 | print(f"Run ID: {run_id}") |
140 | 155 | predictions_dir = PREDS_DNAME / f"results_{run_id}" |
@@ -162,7 +177,7 @@ async def run_eval(use_existing_preds: str | None, dataset: str, length: int, in |
162 | 177 | if local: |
163 | 178 | results = process_batch_local(examples, codebases=codebases) |
164 | 179 | else: |
165 | | - results = await process_batch_modal(examples) |
| 180 | + results = await process_batch_modal(examples, num_workers=num_workers) |
166 | 181 |
167 | 182 | # Save individual results |
168 | 183 | for result in results: |
@@ -218,9 +233,12 @@ async def run_eval(use_existing_preds: str | None, dataset: str, length: int, in |
218 | 233 | @click.option("--instance-id", help="The instance ID of the example to process.", type=str, default=None) |
219 | 234 | @click.option("--local", help="Run the evaluation locally.", is_flag=True, default=False) |
220 | 235 | @click.option("--repo", help="The repo to use.", type=str, default=None) |
221 | | -def run_eval_command(use_existing_preds, dataset, length, instance_id, local, repo): |
| 236 | +@click.option( |
| 237 | + "--num-workers", help="The number of workers to use. This is the number of examples that will be processed concurrently. A large number may lead to rate limiting issues.", type=int, default=5 |
| 238 | +) |
| 239 | +def run_eval_command(use_existing_preds, dataset, length, instance_id, local, repo, num_workers): |
222 | 240 | print(f"Repo: {repo}") |
223 | | - asyncio.run(run_eval(use_existing_preds=use_existing_preds, dataset=dataset, length=length, instance_id=instance_id, codebases=None, local=local, repo=repo)) |
| 241 | + asyncio.run(run_eval(use_existing_preds=use_existing_preds, dataset=dataset, length=length, instance_id=instance_id, codebases=None, local=local, repo=repo, num_workers=num_workers)) |
224 | 242 |
225 | 243 |
226 | 244 | if __name__ == "__main__": |
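For readers skimming the diff, the core of the change is the asyncio queue-plus-workers pattern with per-example retries. Below is a minimal, self-contained sketch of that pattern; it is not part of the commit, `do_work` is a stand-in for `run_agent_modal.remote.aio`, and the integer items are placeholders for SweBenchExample objects:

    import asyncio

    async def process_all(items, num_workers=5, max_retries=3):
        results = {}
        queue: asyncio.Queue = asyncio.Queue()
        for item in items:
            await queue.put((item, 0))  # 0 = first attempt

        async def do_work(item):
            return item * 2  # stand-in for the real remote call

        async def worker():
            while True:
                item, attempt = await queue.get()
                try:
                    results[item] = await do_work(item)
                except Exception:
                    if attempt < max_retries:
                        await queue.put((item, attempt + 1))  # re-queue for another try
                    else:
                        results[item] = None  # give up after max_retries
                finally:
                    queue.task_done()

        workers = [asyncio.create_task(worker()) for _ in range(num_workers)]
        await queue.join()  # every put() has been matched by a task_done()
        for w in workers:
            w.cancel()  # workers loop forever; stop them once the queue is drained
        await asyncio.gather(*workers, return_exceptions=True)
        return [results[item] for item in items]

    # e.g. asyncio.run(process_all(range(10), num_workers=3)) -> [0, 2, 4, ..., 18]

Unlike this sketch, the committed code records a structured error dict instead of None once retries are exhausted. From the CLI, the concurrency level is now controlled by the new --num-workers option (default 5).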