diff --git a/autogpt_platform/backend/backend/executor/cluster_lock.py b/autogpt_platform/backend/backend/executor/cluster_lock.py new file mode 100644 index 000000000000..ad6362b5355b --- /dev/null +++ b/autogpt_platform/backend/backend/executor/cluster_lock.py @@ -0,0 +1,115 @@ +"""Redis-based distributed locking for cluster coordination.""" + +import logging +import time +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from redis import Redis + +logger = logging.getLogger(__name__) + + +class ClusterLock: + """Simple Redis-based distributed lock for preventing duplicate execution.""" + + def __init__(self, redis: "Redis", key: str, owner_id: str, timeout: int = 300): + self.redis = redis + self.key = key + self.owner_id = owner_id + self.timeout = timeout + self._last_refresh = 0.0 + + def try_acquire(self) -> str | None: + """Try to acquire the lock. + + Returns: + - owner_id (self.owner_id) if successfully acquired + - different owner_id if someone else holds the lock + - None if Redis is unavailable or other error + """ + try: + success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout) + if success: + self._last_refresh = time.time() + return self.owner_id # Successfully acquired + + # Failed to acquire, get current owner + current_value = self.redis.get(self.key) + if current_value: + current_owner = ( + current_value.decode("utf-8") + if isinstance(current_value, bytes) + else str(current_value) + ) + return current_owner + + # Key doesn't exist but we failed to set it - race condition or Redis issue + return None + + except Exception as e: + logger.error(f"ClusterLock.try_acquire failed for key {self.key}: {e}") + return None + + def refresh(self) -> bool: + """Refresh lock TTL if we still own it. + + Rate limited to at most once every timeout/10 seconds (minimum 1 second). + During rate limiting, still verifies lock existence but skips TTL extension. + Setting _last_refresh to 0 bypasses rate limiting for testing. 
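+
+        Returns:
+            True if the lock still exists and is owned by this instance (the
+            TTL is extended unless the call is rate limited); False otherwise,
+            including when the key is missing, owned by another instance, or
+            Redis is unavailable.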
+ """ + # Calculate refresh interval: max(timeout // 10, 1) + refresh_interval = max(self.timeout // 10, 1) + current_time = time.time() + + # Check if we're within the rate limit period + # _last_refresh == 0 forces a refresh (bypasses rate limiting for testing) + is_rate_limited = ( + self._last_refresh > 0 + and (current_time - self._last_refresh) < refresh_interval + ) + + try: + # Always verify lock existence, even during rate limiting + current_value = self.redis.get(self.key) + if not current_value: + self._last_refresh = 0 + return False + + stored_owner = ( + current_value.decode("utf-8") + if isinstance(current_value, bytes) + else str(current_value) + ) + if stored_owner != self.owner_id: + self._last_refresh = 0 + return False + + # If rate limited, return True but don't update TTL or timestamp + if is_rate_limited: + return True + + # Perform actual refresh + if self.redis.expire(self.key, self.timeout): + self._last_refresh = current_time + return True + + self._last_refresh = 0 + return False + + except Exception as e: + logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}") + self._last_refresh = 0 + return False + + def release(self): + """Release the lock.""" + if self._last_refresh == 0: + return + + try: + self.redis.delete(self.key) + except Exception: + pass + + self._last_refresh = 0.0 diff --git a/autogpt_platform/backend/backend/executor/cluster_lock_test.py b/autogpt_platform/backend/backend/executor/cluster_lock_test.py new file mode 100644 index 000000000000..c5d8965f0f9e --- /dev/null +++ b/autogpt_platform/backend/backend/executor/cluster_lock_test.py @@ -0,0 +1,507 @@ +""" +Integration tests for ClusterLock - Redis-based distributed locking. + +Tests the complete lock lifecycle without mocking Redis to ensure +real-world behavior is correct. Covers acquisition, refresh, expiry, +contention, and error scenarios. 
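+
+These tests require a reachable Redis instance; connection settings are taken
+from backend.data.redis_client (HOST, PORT, PASSWORD), matching the backend.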
+"""
+
+import logging
+import time
+import uuid
+from threading import Thread
+
+import pytest
+import redis
+
+from .cluster_lock import ClusterLock
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture
+def redis_client():
+    """Get Redis client for testing using same config as backend."""
+    from backend.data.redis_client import HOST, PASSWORD, PORT
+
+    # Use same config as backend but without decode_responses since ClusterLock needs raw bytes
+    client = redis.Redis(
+        host=HOST,
+        port=PORT,
+        password=PASSWORD,
+        decode_responses=False,  # ClusterLock needs raw bytes for ownership verification
+    )
+
+    # Clean up any existing test keys
+    try:
+        for key in client.scan_iter(match="test_lock:*"):
+            client.delete(key)
+    except Exception:
+        pass  # Ignore cleanup errors
+
+    return client
+
+
+@pytest.fixture
+def lock_key():
+    """Generate a unique lock key for each test."""
+    return f"test_lock:{uuid.uuid4()}"
+
+
+@pytest.fixture
+def owner_id():
+    """Generate a unique owner ID for each test."""
+    return str(uuid.uuid4())
+
+
+class TestClusterLockBasic:
+    """Basic lock acquisition and release functionality."""
+
+    def test_lock_acquisition_success(self, redis_client, lock_key, owner_id):
+        """Test basic lock acquisition succeeds."""
+        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
+
+        # Lock should be acquired successfully
+        result = lock.try_acquire()
+        assert result == owner_id  # Returns our owner_id when successfully acquired
+        assert lock._last_refresh > 0
+
+        # Lock key should exist in Redis
+        assert redis_client.exists(lock_key) == 1
+        assert redis_client.get(lock_key).decode("utf-8") == owner_id
+
+    def test_lock_acquisition_contention(self, redis_client, lock_key):
+        """Test second acquisition fails when lock is held."""
+        owner1 = str(uuid.uuid4())
+        owner2 = str(uuid.uuid4())
+
+        lock1 = ClusterLock(redis_client, lock_key, owner1, timeout=60)
+        lock2 = ClusterLock(redis_client, lock_key, owner2, timeout=60)
+
+        # First lock should succeed
+        result1 = lock1.try_acquire()
+        assert result1 == owner1  # Successfully acquired, returns our owner_id
+
+        # Second lock should fail and return the first owner
+        result2 = lock2.try_acquire()
+        assert result2 == owner1  # Returns the current owner (first owner)
+        assert lock2._last_refresh == 0
+
+    def test_lock_release_deletes_redis_key(self, redis_client, lock_key, owner_id):
+        """Test lock release deletes the Redis key and marks the lock as released locally."""
+        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
+
+        lock.try_acquire()
+        assert lock._last_refresh > 0
+        assert redis_client.exists(lock_key) == 1
+
+        # Release should delete the Redis key and mark the lock as released locally
+        lock.release()
+        assert lock._last_refresh == 0.0
+
+        # Redis key should be deleted for immediate release
+        assert redis_client.exists(lock_key) == 0
+
+        # Another lock should be able to acquire immediately
+        new_owner_id = str(uuid.uuid4())
+        new_lock = ClusterLock(redis_client, lock_key, new_owner_id, timeout=60)
+        assert new_lock.try_acquire() == new_owner_id
+
+
+class TestClusterLockRefresh:
+    """Lock refresh and TTL management."""
+
+    def test_lock_refresh_success(self, redis_client, lock_key, owner_id):
+        """Test lock refresh extends the TTL."""
+        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
+
+        lock.try_acquire()
+        original_ttl = redis_client.ttl(lock_key)
+
+        # Wait a bit, then refresh
+        time.sleep(1)
+        lock._last_refresh = 0  # Force refresh past rate limit
+        assert lock.refresh() is True
+
+
# TTL should be reset to full timeout (allow for small timing differences) + new_ttl = redis_client.ttl(lock_key) + assert new_ttl >= original_ttl or new_ttl >= 58 # Allow for timing variance + + def test_lock_refresh_rate_limiting(self, redis_client, lock_key, owner_id): + """Test refresh is rate-limited to timeout/10.""" + lock = ClusterLock( + redis_client, lock_key, owner_id, timeout=100 + ) # 100s timeout + + lock.try_acquire() + + # First refresh should work + assert lock.refresh() is True + first_refresh_time = lock._last_refresh + + # Immediate second refresh should be skipped (rate limited) but verify key exists + assert lock.refresh() is True # Returns True but skips actual refresh + assert lock._last_refresh == first_refresh_time # Time unchanged + + def test_lock_refresh_verifies_existence_during_rate_limit( + self, redis_client, lock_key, owner_id + ): + """Test refresh verifies lock existence even during rate limiting.""" + lock = ClusterLock(redis_client, lock_key, owner_id, timeout=100) + + lock.try_acquire() + + # Manually delete the key (simulates expiry or external deletion) + redis_client.delete(lock_key) + + # Refresh should detect missing key even during rate limit period + assert lock.refresh() is False + assert lock._last_refresh == 0 + + def test_lock_refresh_ownership_lost(self, redis_client, lock_key, owner_id): + """Test refresh fails when ownership is lost.""" + lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60) + + lock.try_acquire() + + # Simulate another process taking the lock + different_owner = str(uuid.uuid4()) + redis_client.set(lock_key, different_owner, ex=60) + + # Force refresh past rate limit and verify it fails + lock._last_refresh = 0 # Force refresh past rate limit + assert lock.refresh() is False + assert lock._last_refresh == 0 + + def test_lock_refresh_when_not_acquired(self, redis_client, lock_key, owner_id): + """Test refresh fails when lock was never acquired.""" + lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60) + + # Refresh without acquiring should fail + assert lock.refresh() is False + + +class TestClusterLockExpiry: + """Lock expiry and timeout behavior.""" + + def test_lock_natural_expiry(self, redis_client, lock_key, owner_id): + """Test lock expires naturally via Redis TTL.""" + lock = ClusterLock( + redis_client, lock_key, owner_id, timeout=2 + ) # 2 second timeout + + lock.try_acquire() + assert redis_client.exists(lock_key) == 1 + + # Wait for expiry + time.sleep(3) + assert redis_client.exists(lock_key) == 0 + + # New lock with same key should succeed + new_lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60) + assert new_lock.try_acquire() == owner_id + + def test_lock_refresh_prevents_expiry(self, redis_client, lock_key, owner_id): + """Test refreshing prevents lock from expiring.""" + lock = ClusterLock( + redis_client, lock_key, owner_id, timeout=3 + ) # 3 second timeout + + lock.try_acquire() + + # Wait and refresh before expiry + time.sleep(1) + lock._last_refresh = 0 # Force refresh past rate limit + assert lock.refresh() is True + + # Wait beyond original timeout + time.sleep(2.5) + assert redis_client.exists(lock_key) == 1 # Should still exist + + +class TestClusterLockConcurrency: + """Concurrent access patterns.""" + + def test_multiple_threads_contention(self, redis_client, lock_key): + """Test multiple threads competing for same lock.""" + num_threads = 5 + successful_acquisitions = [] + + def try_acquire_lock(thread_id): + owner_id = f"thread_{thread_id}" + lock = 
ClusterLock(redis_client, lock_key, owner_id, timeout=60) + if lock.try_acquire() == owner_id: + successful_acquisitions.append(thread_id) + time.sleep(0.1) # Hold lock briefly + lock.release() + + threads = [] + for i in range(num_threads): + thread = Thread(target=try_acquire_lock, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # Only one thread should have acquired the lock + assert len(successful_acquisitions) == 1 + + def test_sequential_lock_reuse(self, redis_client, lock_key): + """Test lock can be reused after natural expiry.""" + owners = [str(uuid.uuid4()) for _ in range(3)] + + for i, owner_id in enumerate(owners): + lock = ClusterLock(redis_client, lock_key, owner_id, timeout=1) # 1 second + + assert lock.try_acquire() == owner_id + time.sleep(1.5) # Wait for expiry + + # Verify lock expired + assert redis_client.exists(lock_key) == 0 + + def test_refresh_during_concurrent_access(self, redis_client, lock_key): + """Test lock refresh works correctly during concurrent access attempts.""" + owner1 = str(uuid.uuid4()) + owner2 = str(uuid.uuid4()) + + lock1 = ClusterLock(redis_client, lock_key, owner1, timeout=5) + lock2 = ClusterLock(redis_client, lock_key, owner2, timeout=5) + + # Thread 1 holds lock and refreshes + assert lock1.try_acquire() == owner1 + + def refresh_continuously(): + for _ in range(10): + lock1._last_refresh = 0 # Force refresh + lock1.refresh() + time.sleep(0.1) + + def try_acquire_continuously(): + attempts = 0 + while attempts < 20: + if lock2.try_acquire() == owner2: + return True + time.sleep(0.1) + attempts += 1 + return False + + refresh_thread = Thread(target=refresh_continuously) + acquire_thread = Thread(target=try_acquire_continuously) + + refresh_thread.start() + acquire_thread.start() + + refresh_thread.join() + acquire_thread.join() + + # Lock1 should still own the lock due to refreshes + assert lock1._last_refresh > 0 + assert lock2._last_refresh == 0 + + +class TestClusterLockErrorHandling: + """Error handling and edge cases.""" + + def test_redis_connection_failure_on_acquire(self, lock_key, owner_id): + """Test graceful handling when Redis is unavailable during acquisition.""" + # Use invalid Redis connection + bad_redis = redis.Redis( + host="invalid_host", port=1234, socket_connect_timeout=1 + ) + lock = ClusterLock(bad_redis, lock_key, owner_id, timeout=60) + + # Should return None for Redis connection failures + result = lock.try_acquire() + assert result is None # Returns None when Redis fails + assert lock._last_refresh == 0 + + def test_redis_connection_failure_on_refresh( + self, redis_client, lock_key, owner_id + ): + """Test graceful handling when Redis fails during refresh.""" + lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60) + + # Acquire normally + assert lock.try_acquire() == owner_id + + # Replace Redis client with failing one + lock.redis = redis.Redis( + host="invalid_host", port=1234, socket_connect_timeout=1 + ) + + # Refresh should fail gracefully + lock._last_refresh = 0 # Force refresh + assert lock.refresh() is False + assert lock._last_refresh == 0 + + def test_invalid_lock_parameters(self, redis_client): + """Test validation of lock parameters.""" + owner_id = str(uuid.uuid4()) + + # All parameters are now simple - no validation needed + # Just test basic construction works + lock = ClusterLock(redis_client, "test_key", owner_id, timeout=60) + assert lock.key == "test_key" + assert lock.owner_id == owner_id + assert lock.timeout == 60 + + def 
test_refresh_after_redis_key_deleted(self, redis_client, lock_key, owner_id):
+        """Test refresh behavior when the Redis key is manually deleted."""
+        lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60)
+
+        lock.try_acquire()
+
+        # Manually delete the key (simulates external deletion)
+        redis_client.delete(lock_key)
+
+        # Refresh should fail and mark the lock as not acquired
+        lock._last_refresh = 0  # Force refresh
+        assert lock.refresh() is False
+        assert lock._last_refresh == 0
+
+
+class TestClusterLockDynamicRefreshInterval:
+    """Dynamic refresh interval based on timeout."""
+
+    def test_refresh_interval_calculation(self, redis_client, lock_key, owner_id):
+        """Test refresh interval is calculated as max(timeout // 10, 1)."""
+        test_cases = [
+            (5, 1),  # 5 // 10 = 0, but minimum is 1
+            (10, 1),  # 10 // 10 = 1
+            (30, 3),  # 30 // 10 = 3
+            (100, 10),  # 100 // 10 = 10
+            (200, 20),  # 200 // 10 = 20
+            (1000, 100),  # 1000 // 10 = 100
+        ]
+
+        for timeout, expected_interval in test_cases:
+            lock = ClusterLock(
+                redis_client, f"{lock_key}_{timeout}", owner_id, timeout=timeout
+            )
+            lock.try_acquire()
+
+            # Calculate expected interval using the same logic as the implementation
+            refresh_interval = max(timeout // 10, 1)
+            assert refresh_interval == expected_interval
+
+            # Test rate limiting works with the calculated interval
+            assert lock.refresh() is True
+            first_refresh_time = lock._last_refresh
+
+            # Sleep less than the interval - should be rate limited
+            time.sleep(0.1)
+            assert lock.refresh() is True
+            assert lock._last_refresh == first_refresh_time  # No actual refresh
+
+
+class TestClusterLockRealWorldScenarios:
+    """Real-world usage patterns."""
+
+    def test_execution_coordination_simulation(self, redis_client):
+        """Simulate graph execution coordination across multiple pods."""
+        graph_exec_id = str(uuid.uuid4())
+        lock_key = f"execution:{graph_exec_id}"
+
+        # Simulate 3 pods trying to execute the same graph
+        pods = [f"pod_{i}" for i in range(3)]
+        execution_results = {}
+
+        def execute_graph(pod_id):
+            """Simulate graph execution with a cluster lock."""
+            lock = ClusterLock(redis_client, lock_key, pod_id, timeout=300)
+
+            if lock.try_acquire() == pod_id:
+                # Simulate execution work
+                execution_results[pod_id] = "executed"
+                time.sleep(0.1)
+                lock.release()
+            else:
+                execution_results[pod_id] = "rejected"
+
+        threads = []
+        for pod_id in pods:
+            thread = Thread(target=execute_graph, args=(pod_id,))
+            threads.append(thread)
+            thread.start()
+
+        for thread in threads:
+            thread.join()
+
+        # Only one pod should have executed
+        executed_count = sum(
+            1 for result in execution_results.values() if result == "executed"
+        )
+        rejected_count = sum(
+            1 for result in execution_results.values() if result == "rejected"
+        )
+
+        assert executed_count == 1
+        assert rejected_count == 2
+
+    def test_long_running_execution_with_refresh(
+        self, redis_client, lock_key, owner_id
+    ):
+        """Test lock maintains ownership during long execution with periodic refresh."""
+        lock = ClusterLock(
+            redis_client, lock_key, owner_id, timeout=30
+        )  # 30 second timeout, refresh interval = max(30 // 10, 1) = 3 seconds
+
+        def long_execution_with_refresh():
+            """Simulate long-running execution with periodic refresh."""
+            assert lock.try_acquire() == owner_id
+
+            # Simulate 10 seconds of work, attempting a refresh every 2 seconds.
+            # With the 3-second rate limit, only some attempts actually extend the
+            # TTL (roughly at the 4s and 8s marks); rate-limited calls still return True.
+            try:
+                for i in range(5):  # 5 iterations * 2 seconds = 10 seconds total
+                    time.sleep(2)
+                    refresh_success = lock.refresh()
+                    assert refresh_success is True, f"Refresh 
failed at iteration {i}" + return "completed" + finally: + lock.release() + + # Should complete successfully without losing lock + result = long_execution_with_refresh() + assert result == "completed" + + def test_graceful_degradation_pattern(self, redis_client, lock_key): + """Test graceful degradation when Redis becomes unavailable.""" + owner_id = str(uuid.uuid4()) + lock = ClusterLock( + redis_client, lock_key, owner_id, timeout=3 + ) # Use shorter timeout + + # Normal operation + assert lock.try_acquire() == owner_id + lock._last_refresh = 0 # Force refresh past rate limit + assert lock.refresh() is True + + # Simulate Redis becoming unavailable + original_redis = lock.redis + lock.redis = redis.Redis( + host="invalid_host", + port=1234, + socket_connect_timeout=1, + decode_responses=False, + ) + + # Should degrade gracefully + lock._last_refresh = 0 # Force refresh past rate limit + assert lock.refresh() is False + assert lock._last_refresh == 0 + + # Restore Redis and verify can acquire again + lock.redis = original_redis + # Wait for original lock to expire (use longer wait for 3s timeout) + time.sleep(4) + + new_lock = ClusterLock(redis_client, lock_key, owner_id, timeout=60) + assert new_lock.try_acquire() == owner_id + + +if __name__ == "__main__": + # Run specific test for quick validation + pytest.main([__file__, "-v"]) diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index 43696d7bceb9..95ce5699a984 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -3,6 +3,7 @@ import os import threading import time +import uuid from collections import defaultdict from concurrent.futures import Future, ThreadPoolExecutor from contextlib import asynccontextmanager @@ -10,31 +11,11 @@ from pika.adapters.blocking_connection import BlockingChannel from pika.spec import Basic, BasicProperties -from redis.asyncio.lock import Lock as RedisLock - -from backend.blocks.io import AgentOutputBlock -from backend.data.model import GraphExecutionStats, NodeExecutionStats -from backend.data.notifications import ( - AgentRunData, - LowBalanceData, - NotificationEventModel, - NotificationType, - ZeroBalanceData, -) -from backend.data.rabbitmq import SyncRabbitMQ -from backend.executor.activity_status_generator import ( - generate_activity_status_for_execution, -) -from backend.executor.utils import LogMetadata -from backend.notifications.notifications import queue_notification -from backend.util.exceptions import InsufficientBalanceError, ModerationError - -if TYPE_CHECKING: - from backend.executor import DatabaseManagerClient, DatabaseManagerAsyncClient - from prometheus_client import Gauge, start_http_server +from redis.asyncio.lock import Lock as AsyncRedisLock from backend.blocks.agent import AgentExecutorBlock +from backend.blocks.io import AgentOutputBlock from backend.data import redis_client as redis from backend.data.block import ( BlockInput, @@ -55,12 +36,25 @@ UserContext, ) from backend.data.graph import Link, Node +from backend.data.model import GraphExecutionStats, NodeExecutionStats +from backend.data.notifications import ( + AgentRunData, + LowBalanceData, + NotificationEventModel, + NotificationType, + ZeroBalanceData, +) +from backend.data.rabbitmq import SyncRabbitMQ +from backend.executor.activity_status_generator import ( + generate_activity_status_for_execution, +) from backend.executor.utils import ( GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS, 
GRAPH_EXECUTION_CANCEL_QUEUE_NAME, GRAPH_EXECUTION_QUEUE_NAME, CancelExecutionEvent, ExecutionOutputEntry, + LogMetadata, NodeExecutionProgress, block_usage_cost, create_execution_queue_config, @@ -69,6 +63,7 @@ validate_exec, ) from backend.integrations.creds_manager import IntegrationCredentialsManager +from backend.notifications.notifications import queue_notification from backend.server.v2.AutoMod.manager import automod_manager from backend.util import json from backend.util.clients import ( @@ -84,6 +79,7 @@ error_logged, time_measured, ) +from backend.util.exceptions import InsufficientBalanceError, ModerationError from backend.util.file import clean_exec_files from backend.util.logging import TruncatedLogger, configure_logging from backend.util.metrics import DiscordChannel @@ -91,6 +87,12 @@ from backend.util.retry import continuous_retry, func_retry from backend.util.settings import Settings +from .cluster_lock import ClusterLock + +if TYPE_CHECKING: + from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient + + _logger = logging.getLogger(__name__) logger = TruncatedLogger(_logger, prefix="[GraphExecutor]") settings = Settings() @@ -106,6 +108,7 @@ "Ratio of active graph runs to max graph workers", ) + # Thread-local storage for ExecutionProcessor instances _tls = threading.local() @@ -117,10 +120,14 @@ def init_worker(): def execute_graph( - graph_exec_entry: "GraphExecutionEntry", cancel_event: threading.Event + graph_exec_entry: "GraphExecutionEntry", + cancel_event: threading.Event, + cluster_lock: ClusterLock, ): """Execute graph using thread-local ExecutionProcessor instance""" - return _tls.processor.on_graph_execution(graph_exec_entry, cancel_event) + return _tls.processor.on_graph_execution( + graph_exec_entry, cancel_event, cluster_lock + ) T = TypeVar("T") @@ -583,6 +590,7 @@ def on_graph_execution( self, graph_exec: GraphExecutionEntry, cancel: threading.Event, + cluster_lock: ClusterLock, ): log_metadata = LogMetadata( logger=_logger, @@ -641,6 +649,7 @@ def on_graph_execution( cancel=cancel, log_metadata=log_metadata, execution_stats=exec_stats, + cluster_lock=cluster_lock, ) exec_stats.walltime += timing_info.wall_time exec_stats.cputime += timing_info.cpu_time @@ -742,6 +751,7 @@ def _on_graph_execution( cancel: threading.Event, log_metadata: LogMetadata, execution_stats: GraphExecutionStats, + cluster_lock: ClusterLock, ) -> ExecutionStatus: """ Returns: @@ -927,7 +937,7 @@ def _on_graph_execution( and execution_queue.empty() and (running_node_execution or running_node_evaluation) ): - # There is nothing to execute, and no output to process, let's relax for a while. 
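+                        # Nothing to execute and no output to process yet:
+                        # keep the cluster-wide execution lock alive while idling.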
+ cluster_lock.refresh() time.sleep(0.1) # loop done -------------------------------------------------- @@ -1219,6 +1229,7 @@ def __init__(self): super().__init__() self.pool_size = settings.config.num_graph_workers self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {} + self.executor_id = str(uuid.uuid4()) self._executor = None self._stop_consuming = None @@ -1228,6 +1239,8 @@ def __init__(self): self._run_thread = None self._run_client = None + self._execution_locks = {} + @property def cancel_thread(self) -> threading.Thread: if self._cancel_thread is None: @@ -1435,17 +1448,46 @@ def _ack_message(reject: bool, requeue: bool): logger.info( f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}" ) + + # Check for local duplicate execution first if graph_exec_id in self.active_graph_runs: - # TODO: Make this check cluster-wide, prevent duplicate runs across executor pods. - logger.error( - f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run." + logger.warning( + f"[{self.service_name}] Graph {graph_exec_id} already running locally; rejecting duplicate." ) - _ack_message(reject=True, requeue=False) + _ack_message(reject=True, requeue=True) return + # Try to acquire cluster-wide execution lock + cluster_lock = ClusterLock( + redis=redis.get_redis(), + key=f"exec_lock:{graph_exec_id}", + owner_id=self.executor_id, + timeout=settings.config.cluster_lock_timeout, + ) + current_owner = cluster_lock.try_acquire() + if current_owner != self.executor_id: + # Either someone else has it or Redis is unavailable + if current_owner is not None: + logger.warning( + f"[{self.service_name}] Graph {graph_exec_id} already running on pod {current_owner}" + ) + else: + logger.warning( + f"[{self.service_name}] Could not acquire lock for {graph_exec_id} - Redis unavailable" + ) + _ack_message(reject=True, requeue=True) + return + self._execution_locks[graph_exec_id] = cluster_lock + + logger.info( + f"[{self.service_name}] Acquired cluster lock for {graph_exec_id} with executor {self.executor_id}" + ) + cancel_event = threading.Event() - future = self.executor.submit(execute_graph, graph_exec_entry, cancel_event) + future = self.executor.submit( + execute_graph, graph_exec_entry, cancel_event, cluster_lock + ) self.active_graph_runs[graph_exec_id] = (future, cancel_event) self._update_prompt_metrics() @@ -1464,6 +1506,10 @@ def _on_run_done(f: Future): f"[{self.service_name}] Error in run completion callback: {e}" ) finally: + # Release the cluster-wide execution lock + if graph_exec_id in self._execution_locks: + self._execution_locks[graph_exec_id].release() + del self._execution_locks[graph_exec_id] self._cleanup_completed_runs() future.add_done_callback(_on_run_done) @@ -1546,6 +1592,10 @@ def cleanup(self): f"{prefix} ⏳ Still waiting for {len(self.active_graph_runs)} executions: {ids}" ) + for graph_exec_id in self.active_graph_runs: + if lock := self._execution_locks.get(graph_exec_id): + lock.refresh() + time.sleep(wait_interval) waited += wait_interval @@ -1563,6 +1613,15 @@ def cleanup(self): except Exception as e: logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}") + # Release remaining execution locks + try: + for lock in self._execution_locks.values(): + lock.release() + self._execution_locks.clear() + logger.info(f"{prefix} ✅ Released execution locks") + except Exception as e: + logger.warning(f"{prefix} ⚠️ Failed to release all locks: {e}") + # Disconnect the run execution consumer 
self._stop_message_consumers( self.run_thread, @@ -1668,9 +1727,9 @@ def update_graph_execution_state( @asynccontextmanager -async def synchronized(key: str, timeout: int = 60): +async def synchronized(key: str, timeout: int = settings.config.cluster_lock_timeout): r = await redis.get_redis_async() - lock: RedisLock = r.lock(f"lock:{key}", timeout=timeout) + lock: AsyncRedisLock = r.lock(f"lock:{key}", timeout=timeout) try: await lock.acquire() yield diff --git a/autogpt_platform/backend/backend/util/settings.py b/autogpt_platform/backend/backend/util/settings.py index cac358d42d1c..49c3790df49d 100644 --- a/autogpt_platform/backend/backend/util/settings.py +++ b/autogpt_platform/backend/backend/util/settings.py @@ -127,6 +127,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings): default=5 * 60, description="Time in seconds after which the execution stuck on QUEUED status is considered late.", ) + cluster_lock_timeout: int = Field( + default=300, + description="Cluster lock timeout in seconds for graph execution coordination.", + ) execution_late_notification_checkrange_secs: int = Field( default=60 * 60, description="Time in seconds for how far back to check for the late executions.",
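A minimal usage sketch of the ClusterLock API added above, assuming a locally reachable Redis instance; the connection parameters, the example execution ID, and the polling loop are illustrative, while the exec_lock: key prefix and 300-second timeout mirror the manager.py changes.

# Usage sketch: acquire the cluster-wide lock, keep it alive during work, release it.
import time
import uuid

import redis

from backend.executor.cluster_lock import ClusterLock

client = redis.Redis(host="localhost", port=6379, decode_responses=False)
graph_exec_id = "example-graph-exec-id"  # placeholder; normally a real execution ID

lock = ClusterLock(
    redis=client,
    key=f"exec_lock:{graph_exec_id}",
    owner_id=str(uuid.uuid4()),
    timeout=300,
)

owner = lock.try_acquire()
if owner == lock.owner_id:
    try:
        for _ in range(10):  # simulated long-running work
            time.sleep(1)
            if not lock.refresh():  # rate limited internally; False means the lock was lost
                break
    finally:
        lock.release()
elif owner is None:
    print("Redis unavailable; skipping execution")
else:
    print(f"Execution already running under owner {owner}")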