Skip to content

Commit e23dac4

Browse files
Dynamic scaling of asyncio request (#721)
# Motivation Scales number of concurrent request dynamically based on whether or not rate limits are hit. # Content <!-- Please include a summary of the change --> # Testing <!-- How was the change tested? --> # Please check the following before marking your PR as ready for review - [ ] I have added tests for my changes - [ ] I have updated the documentation or added new documentation as needed
1 parent 4d4b079 commit e23dac4

File tree

1 file changed

+142
-23
lines changed
  • codegen-examples/examples/swebench_agent_run

1 file changed

+142
-23
lines changed

codegen-examples/examples/swebench_agent_run/run_eval.py

Lines changed: 142 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import uuid
66
import modal
77
import click
8-
from datetime import datetime
8+
import time
99
from codegen.extensions.swebench.harness import run_agent_on_entry
1010
from codegen.extensions.swebench.utils import SWEBenchDataset, SweBenchExample, get_swe_bench_examples
1111
from codegen.extensions.swebench.report import generate_report
@@ -17,28 +17,112 @@
1717
run_agent_modal = modal.Function.from_name(app_name="swebench-agent-run", name="run_agent_modal")
1818

1919

20-
async def process_batch_modal(examples: list[SweBenchExample], num_workers=10, max_retries=3):
21-
"""Process a batch of examples concurrently using a queue system.
20+
async def process_batch_modal(examples: list[SweBenchExample], num_workers=5, min_workers=1, max_retries=3):
21+
"""Process a batch of examples concurrently using a queue system with incremental worker scaling.
2222
2323
Args:
2424
examples: List of SweBenchExample objects to process
25-
num_workers: Number of examples to process concurrently
25+
num_workers: Initial number of examples to process concurrently
26+
min_workers: Minimum number of concurrent workers to maintain
2627
max_retries: Maximum number of retries for failed requests
2728
"""
2829
results = {}
2930
queue = asyncio.Queue()
3031

32+
# Shared state for worker management
33+
state = {
34+
"active_workers": num_workers,
35+
"success_streak": 0,
36+
"last_scaling_time": time.time(),
37+
"scaling_cooldown": 0, # seconds between scaling operations
38+
"worker_tasks": [],
39+
"running": True,
40+
}
41+
42+
# Use a lock to protect shared state during adjustments
43+
state_lock = asyncio.Lock()
44+
3145
# Initialize the queue with (example, attempt) tuples
3246
for example in examples:
3347
await queue.put((example, 0)) # 0 represents first attempt
3448

35-
async def process_example(example, attempt):
49+
async def scale_down_worker(task_to_cancel=None):
50+
"""Remove a single worker when rate limiting is detected"""
51+
async with state_lock:
52+
# Only scale if cooldown period has passed and we're above min_workers
53+
current_time = time.time()
54+
if current_time - state["last_scaling_time"] < state["scaling_cooldown"] or state["active_workers"] <= min_workers:
55+
return False
56+
57+
# Reset success streak when scaling down
58+
state["success_streak"] = 0
59+
state["last_scaling_time"] = current_time
60+
61+
# If a specific task was provided, cancel it
62+
if task_to_cancel and task_to_cancel in state["worker_tasks"]:
63+
print(f"Rate limiting detected! Removing 1 worker, going from {state['active_workers']} to {state['active_workers'] - 1}")
64+
state["worker_tasks"].remove(task_to_cancel)
65+
task_to_cancel.cancel()
66+
state["active_workers"] -= 1
67+
return True
68+
69+
# Otherwise, cancel the most recently added worker
70+
elif state["worker_tasks"]:
71+
print(f"Rate limiting detected! Removing 1 worker, going from {state['active_workers']} to {state['active_workers'] - 1}")
72+
task = state["worker_tasks"].pop()
73+
task.cancel()
74+
state["active_workers"] -= 1
75+
return True
76+
77+
return False
78+
79+
async def scale_up_worker():
80+
"""Add a single worker when operations have been consistently successful"""
81+
async with state_lock:
82+
# Only scale if cooldown period has passed and we're below num_workers
83+
current_time = time.time()
84+
if current_time - state["last_scaling_time"] < state["scaling_cooldown"] or state["active_workers"] >= num_workers:
85+
return False
86+
87+
# Add a worker after a streak of successful operations
88+
if state["success_streak"] >= 5:
89+
print(f"Operations succeeding! Adding 1 worker, going from {state['active_workers']} to {state['active_workers'] + 1}")
90+
91+
# Create new worker
92+
if state["running"]:
93+
new_task = asyncio.create_task(worker())
94+
state["worker_tasks"].append(new_task)
95+
state["active_workers"] += 1
96+
state["success_streak"] = 0
97+
state["last_scaling_time"] = current_time
98+
return True
99+
100+
return False
101+
102+
async def is_rate_limit_error(error):
103+
"""Determine if an error is due to rate limiting"""
104+
# Check for common rate limit error patterns
105+
if isinstance(error, modal.exception.Error):
106+
error_msg = str(error).lower()
107+
rate_limit_indicators = ["rate limit", "too many requests", "429", "throttle", "quota exceeded", "capacity", "limit exceeded"]
108+
return any(indicator in error_msg for indicator in rate_limit_indicators)
109+
return False
110+
111+
async def process_example(example, attempt, current_task):
36112
try:
37113
result = await run_agent_modal.remote.aio(example)
38114

39115
if result is None:
40116
print(f"Warning: Null result for {example.instance_id}")
41-
return {"instance_id": example.instance_id, "status": "error", "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}}
117+
return {"status": "error", "instance_id": example.instance_id, "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}}
118+
119+
# Increment success streak and potentially scale up
120+
async with state_lock:
121+
state["success_streak"] += 1
122+
123+
if state["success_streak"] % 5 == 0: # Check after every 5 successes
124+
await scale_up_worker()
125+
42126
return result
43127

44128
except Exception as e:
@@ -56,51 +140,86 @@ async def process_example(example, attempt):
56140
print(f"Error processing {example.instance_id} (attempt {attempt + 1}):")
57141
print(f"Type: {error_type}")
58142
print(f"Message: {str(e)}")
59-
print("Traceback:")
60-
print("".join(error_info["traceback"]))
143+
144+
# Check if this is a rate limit error
145+
if await is_rate_limit_error(e):
146+
print(f"Rate limit detected on task for {example.instance_id}")
147+
148+
# Scale down by removing this specific worker
149+
scaled_down = await scale_down_worker(current_task)
150+
151+
# If we're removing this worker, we need to requeue the task for another worker
152+
if scaled_down:
153+
# Requeue this example with the same attempt count (not incrementing)
154+
await queue.put((example, attempt))
155+
return None
156+
157+
# Otherwise add a small delay before retrying
158+
await asyncio.sleep(2 * (attempt + 1)) # Exponential backoff
61159

62160
if attempt < max_retries:
63161
await queue.put((example, attempt + 1))
64162
return None
65163

66-
return {"instance_id": example.instance_id, "status": "error", "error_info": error_info}
164+
return {"status": "error", "instance_id": example.instance_id, "error_info": error_info}
67165

68166
async def worker():
69-
while True:
167+
# Store this task reference to allow targeted cancellation
168+
current_task = asyncio.current_task()
169+
170+
while state["running"]:
70171
try:
71-
example, attempt = await queue.get()
172+
# Use a timeout to allow worker to check if it should exit
173+
try:
174+
example, attempt = await asyncio.wait_for(queue.get(), timeout=1.0)
175+
except asyncio.TimeoutError:
176+
continue
72177

73178
if example.instance_id in results:
74179
queue.task_done()
75180
continue
181+
print(f"Processing example {example.instance_id}")
182+
process_result = await process_example(example, attempt, current_task)
76183

77-
result = await process_example(example, attempt)
78-
79-
if result is not None:
80-
results[example.instance_id] = result
184+
# If we're still processing this task (not requeued due to rate limiting)
185+
if process_result is not None:
186+
results[example.instance_id] = {"instance_id": example.instance_id, **process_result}
187+
print(f"Processed example {example.instance_id}")
188+
queue.task_done()
81189

82-
queue.task_done()
190+
# If None is returned, the task was requeued due to rate limiting
191+
# and this worker is being shut down, so exit the loop
192+
else:
193+
print(f"Task for {example.instance_id} has been requeued")
194+
queue.task_done()
195+
if current_task not in state["worker_tasks"]:
196+
break
83197

198+
except asyncio.CancelledError:
199+
# Handle graceful cancellation
200+
print("Worker task cancelled")
201+
break
84202
except Exception as e:
85203
print(f"Worker error: {str(e)}")
86204
traceback.print_exc()
87205
queue.task_done()
88206

89-
# Start workers
90-
workers = [asyncio.create_task(worker()) for _ in range(num_workers)]
207+
# Start initial workers
208+
state["worker_tasks"] = [asyncio.create_task(worker()) for _ in range(num_workers)]
91209

92210
# Wait for queue to be fully processed
93211
await queue.join()
94212

95-
# Cancel workers
96-
for w in workers:
213+
# Mark as not running and cancel remaining workers
214+
state["running"] = False
215+
for w in state["worker_tasks"]:
97216
w.cancel()
98217

99218
# Wait for all workers to be cancelled
100-
await asyncio.gather(*workers, return_exceptions=True)
219+
await asyncio.gather(*state["worker_tasks"], return_exceptions=True)
101220

102221
# Return results in the same order as input examples
103-
return [results[example.instance_id] for example in examples]
222+
return [results.get(example.instance_id, {"instance_id": example.instance_id, "status": "missing"}) for example in examples]
104223

105224

106225
def process_batch_local(examples: list[SweBenchExample], num_workers=5, codebases: dict[str, Codebase] = {}):
@@ -171,7 +290,7 @@ async def run_eval(
171290
predictions_dir.mkdir(exist_ok=True, parents=True)
172291

173292
# Create a timestamp for this run
174-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
293+
timestamp = time.strftime("%Y-%m-%d %H:%M %Z", time.localtime(time.time()))
175294

176295
# Process all examples in parallel batches
177296
if local:

0 commit comments

Comments
 (0)