diff --git a/.agents/AGENTS.md b/.agents/AGENTS.md index 8acceeb4d..11aac41fe 100644 --- a/.agents/AGENTS.md +++ b/.agents/AGENTS.md @@ -107,29 +107,34 @@ docker compose restart django celeryworker ### Backend (Django) -Run tests: +Run tests (use `docker-compose.ci.yml` to avoid conflicts with the local dev stack): +```bash +docker compose -f docker-compose.ci.yml run --rm django python manage.py test +``` + +Run a specific test module: ```bash -docker compose run --rm django python manage.py test +docker compose -f docker-compose.ci.yml run --rm django python manage.py test ami.ml.orchestration.tests.test_nats_connection ``` Run specific test pattern: ```bash -docker compose run --rm django python manage.py test -k pattern +docker compose -f docker-compose.ci.yml run --rm django python manage.py test -k pattern ``` Run tests with debugger on failure: ```bash -docker compose run --rm django python manage.py test -k pattern --failfast --pdb +docker compose -f docker-compose.ci.yml run --rm django python manage.py test -k pattern --failfast --pdb ``` Speed up test development (reuse database): ```bash -docker compose run --rm django python manage.py test --keepdb +docker compose -f docker-compose.ci.yml run --rm django python manage.py test --keepdb ``` Run pytest (alternative test runner): ```bash -docker compose run --rm django pytest --ds=config.settings.test --reuse-db +docker compose -f docker-compose.ci.yml run --rm django pytest --ds=config.settings.test --reuse-db ``` Django shell: @@ -654,13 +659,13 @@ images = SourceImage.objects.annotate(det_count=Count('detections')) ```bash # Run specific test class -docker compose run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase +docker compose -f docker-compose.ci.yml run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase # Run specific test method -docker compose run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase.test_project_creation +docker compose -f docker-compose.ci.yml run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase.test_project_creation # Run with pattern matching -docker compose run --rm django python manage.py test -k test_detection +docker compose -f docker-compose.ci.yml run --rm django python manage.py test -k test_detection ``` ### Pre-commit Hooks diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 3548d0ea5..94a6a821a 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -135,8 +135,8 @@ def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None: try: async def ack_task(): - async with TaskQueueManager() as manager: - return await manager.acknowledge_task(reply_subject) + manager = TaskQueueManager() + return await manager.acknowledge_task(reply_subject) ack_success = async_to_sync(ack_task)() diff --git a/ami/jobs/test_tasks.py b/ami/jobs/test_tasks.py index fc291c8ba..03c797d9c 100644 --- a/ami/jobs/test_tasks.py +++ b/ami/jobs/test_tasks.py @@ -73,10 +73,8 @@ def tearDown(self): def _setup_mock_nats(self, mock_manager_class): """Helper to setup mock NATS manager.""" - mock_manager = AsyncMock() + mock_manager = mock_manager_class.return_value mock_manager.acknowledge_task = AsyncMock(return_value=True) - mock_manager_class.return_value.__aenter__.return_value = mock_manager - mock_manager_class.return_value.__aexit__.return_value = AsyncMock() return mock_manager def _create_error_result(self, image_id: str | None = None, error_msg: str = "Processing failed") -> dict: diff 
--git a/ami/jobs/views.py b/ami/jobs/views.py index dd8da01b2..f45361007 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -242,11 +242,11 @@ def tasks(self, request, pk=None): async def get_tasks(): tasks = [] - async with TaskQueueManager() as manager: - for _ in range(batch): - task = await manager.reserve_task(job.pk, timeout=0.1) - if task: - tasks.append(task.dict()) + manager = TaskQueueManager() + for _ in range(batch): + task = await manager.reserve_task(job.pk, timeout=0.1) + if task: + tasks.append(task.dict()) return tasks # Use async_to_sync to properly handle the async call diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index 9b4a577e9..0f40e83bc 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -40,8 +40,8 @@ def cleanup_async_job_resources(job: "Job") -> bool: # Cleanup NATS resources async def cleanup(): - async with TaskQueueManager() as manager: - return await manager.cleanup_job_resources(job.pk) + manager = TaskQueueManager() + return await manager.cleanup_job_resources(job.pk) try: nats_success = async_to_sync(cleanup)() @@ -96,22 +96,22 @@ async def queue_all_images(): successful_queues = 0 failed_queues = 0 - async with TaskQueueManager() as manager: - for image_pk, task in tasks: - try: - logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}") - success = await manager.publish_task( - job_id=job.pk, - data=task, - ) - except Exception as e: - logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}") - success = False - - if success: - successful_queues += 1 - else: - failed_queues += 1 + manager = TaskQueueManager() + for image_pk, task in tasks: + try: + logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}") + success = await manager.publish_task( + job_id=job.pk, + data=task, + ) + except Exception as e: + logger.exception(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}") + success = False + + if success: + successful_queues += 1 + else: + failed_queues += 1 return successful_queues, failed_queues diff --git a/ami/ml/orchestration/nats_connection.py b/ami/ml/orchestration/nats_connection.py new file mode 100644 index 000000000..2b63443ae --- /dev/null +++ b/ami/ml/orchestration/nats_connection.py @@ -0,0 +1,275 @@ +""" +NATS connection management for both Celery workers and Django processes. + +Provides a ConnectionPool keyed by event loop. The pool reuses a single NATS +connection for all async operations *within* one async_to_sync() boundary. +It does NOT provide reuse across separate async_to_sync() calls — each call +creates a new event loop, so a new connection is established. + +Call paths and connection reuse +------------------------------- + +1. Queue images (high-value reuse): + POST /api/v2/jobs/{id}/run/ → Celery run_job → MLJob.run() + → queue_images_to_nats() wraps 1000+ sequential publish_task() awaits in + a single async_to_sync() call. All share one event loop → one connection. + Without the pool each publish would open its own TCP connection. + +2. Reserve tasks (moderate reuse): + GET /api/v2/jobs/{id}/tasks/?batch=N → JobViewSet.tasks() + → async_to_sync() wraps N sequential reserve_task() calls. Typical N=5-10. + +3. Acknowledge (single-use, no reuse): + POST /api/v2/jobs/{id}/result/ → Celery process_nats_pipeline_result + → _ack_task_via_nats() wraps a single acknowledge_task() in its own + async_to_sync() call. Each ACK gets its own event loop → new connection. 
+ Pool overhead is negligible (one dict lookup). retry_on_connection_error + provides resilience. + +4. Cleanup (modest reuse): + Job completion → cleanup_async_job_resources() + → async_to_sync() wraps delete_consumer + delete_stream (2 ops, 1 conn). + +Concurrency model +----------------- + +All call paths use sequential for-loops, never asyncio.gather(). Within a +single async_to_sync() boundary there is only one coroutine running at a time. +This means: + +- The asyncio.Lock in ConnectionPool is defensive but never actually contends. +- reset() is only called from retry_on_connection_error between sequential + retries. No other coroutine races with it. +- Fast-path checks and state mutations outside the lock (lines ~87-94) are + safe because cooperative scheduling guarantees no preemption between + synchronous Python statements. +- Clearing self._lock in reset() is intentional — it ensures the replacement + lock is created bound to the current event loop state. + +Why NOT check is_reconnecting +----------------------------- + +Connections live inside short-lived async_to_sync() event loops. If the NATS +client enters RECONNECTING state, the event loop will typically be destroyed +before reconnection completes. Clearing the client and creating a fresh +connection is correct for this lifecycle. The retry_on_connection_error +decorator provides the real resilience layer, not nats.py's built-in +reconnection (which is designed for long-lived event loops). + +Why keyed by event loop +----------------------- + +asyncio.Lock and nats.Client are bound to the loop they were created on. +Sharing them across loops causes "attached to a different loop" errors. +Keying by loop ensures isolation. WeakKeyDictionary auto-cleans when loops +are garbage collected, so short-lived loops don't leak. + +Thread safety +------------- + +_pools_lock (threading.Lock) serializes access to the global _pools dict. +Multiple Celery worker threads can create pools concurrently — the lock +prevents races. Within a single event loop, the asyncio.Lock serializes +connection creation (though in practice it never contends, see above). + +Archived alternative +-------------------- + +ContextManagerConnection preserves the original pre-pool implementation +(one connection per `async with` block) as a drop-in fallback. +""" + +import asyncio +import logging +import threading +from typing import TYPE_CHECKING +from weakref import WeakKeyDictionary + +import nats +from django.conf import settings +from nats.js import JetStreamContext + +if TYPE_CHECKING: + from nats.aio.client import Client as NATSClient + +logger = logging.getLogger(__name__) + + +class ConnectionPool: + """ + Manages a single persistent NATS connection per event loop. + + This is safe because: + - asyncio.Lock and NATS Client are bound to the event loop they were created on + - Each event loop gets its own isolated connection and lock + - Works correctly with async_to_sync() which creates per-thread event loops + - Prevents "attached to a different loop" errors in Celery tasks and Django views + + Instantiating TaskQueueManager() is cheap — multiple instances share the same + underlying connection via this pool. 
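To make call path 1 concrete, here is a minimal sketch (illustrative only, not part of the patch) of a synchronous entry point that batches many publishes inside one `async_to_sync()` boundary, so every `publish_task()` reuses the single pooled connection; `queue_tasks` and its arguments are hypothetical names:

```python
from asgiref.sync import async_to_sync

from ami.ml.orchestration.nats_queue import TaskQueueManager


def queue_tasks(job_id, tasks):
    """Hypothetical sync caller (e.g. inside a Celery task)."""

    async def _queue_all():
        manager = TaskQueueManager()  # cheap: holds no connection state itself
        sent = 0
        for task in tasks:  # tasks: iterable of PipelineProcessingTask
            # Every await runs on the one event loop created by async_to_sync(),
            # so the pool hands back the same NATS connection each time.
            if await manager.publish_task(job_id=job_id, data=task):
                sent += 1
        return sent

    # One async_to_sync() call -> one event loop -> one pooled connection.
    return async_to_sync(_queue_all)()
```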
+ """ + + def __init__(self): + self._nc: "NATSClient | None" = None + self._js: JetStreamContext | None = None + self._lock: asyncio.Lock | None = None # Lazy-initialized when needed + + def _ensure_lock(self) -> asyncio.Lock: + """Lazily create lock bound to current event loop.""" + if self._lock is None: + self._lock = asyncio.Lock() + return self._lock + + async def get_connection(self) -> tuple["NATSClient", JetStreamContext]: + """ + Get or create the event loop's NATS connection. Checks connection health + and recreates if stale. + + Returns: + Tuple of (NATS connection, JetStream context) + Raises: + RuntimeError: If connection cannot be established + """ + # Fast path (no lock needed): connection exists, is open, and is connected. + # This is the hot path — most calls hit this and return immediately. + if self._nc is not None and self._js is not None and not self._nc.is_closed and self._nc.is_connected: + return self._nc, self._js + + # Connection is stale or doesn't exist — clear references before reconnecting + if self._nc is not None: + logger.warning("NATS connection is closed or disconnected, will reconnect") + self._nc = None + self._js = None + + # Slow path: acquire lock to prevent concurrent reconnection attempts + lock = self._ensure_lock() + async with lock: + # Double-check after acquiring lock (another coroutine may have reconnected) + if self._nc is not None and self._js is not None and not self._nc.is_closed and self._nc.is_connected: + return self._nc, self._js + + nats_url = settings.NATS_URL + try: + logger.info(f"Creating NATS connection to {nats_url}") + self._nc = await nats.connect(nats_url) + self._js = self._nc.jetstream() + logger.info(f"Successfully connected to NATS at {nats_url}") + return self._nc, self._js + except Exception as e: + logger.error(f"Failed to connect to NATS: {e}") + raise RuntimeError(f"Could not establish NATS connection: {e}") from e + + async def close(self): + """Close the NATS connection if it exists.""" + if self._nc is not None and not self._nc.is_closed: + logger.info("Closing NATS connection") + await self._nc.close() + self._nc = None + self._js = None + + async def reset(self): + """ + Close the current connection and clear all state so the next call to + get_connection() creates a fresh one. + + Called by retry_on_connection_error when an operation hits a connection + error (e.g. network blip, NATS restart). The lock is also cleared so it + gets recreated bound to the current event loop. + """ + logger.warning("Resetting NATS connection pool due to connection error") + if self._nc is not None: + try: + if not self._nc.is_closed: + await self._nc.close() + logger.debug("Successfully closed existing NATS connection during reset") + except Exception as e: + # Swallow errors - connection may already be broken + logger.debug(f"Error closing connection during reset (expected): {e}") + self._nc = None + self._js = None + self._lock = None # Clear lock so new one is created for fresh connection + + +class ContextManagerConnection: + """ + Archived pre-pool implementation: one NATS connection per `async with` block. + + This was the original approach before the connection pool was added. It creates + a fresh connection on get_connection() and expects the caller to close it when + done. There is no connection reuse and no retry logic at this layer. 
+ + Trade-offs vs ConnectionPool: + - Simpler: no shared state, no locking, no event-loop keying + - Expensive: ~1500 TCP connections per 1000-image job vs 1 with the pool + - No automatic reconnection — caller must handle connection failures + + Kept as a drop-in fallback. To switch, change the class used in + _create_pool() below from ConnectionPool to ContextManagerConnection. + """ + + async def get_connection(self) -> tuple["NATSClient", JetStreamContext]: + """Create a fresh NATS connection.""" + nats_url = settings.NATS_URL + try: + logger.debug(f"Creating per-operation NATS connection to {nats_url}") + nc = await nats.connect(nats_url) + js = nc.jetstream() + return nc, js + except Exception as e: + logger.error(f"Failed to connect to NATS: {e}") + raise RuntimeError(f"Could not establish NATS connection: {e}") from e + + async def close(self): + """No-op — connections are not tracked.""" + pass + + async def reset(self): + """No-op — connections are not tracked.""" + pass + + +# Event-loop-keyed pools: one ConnectionPool per event loop. +# WeakKeyDictionary automatically cleans up when event loops are garbage collected. +_pools: WeakKeyDictionary[asyncio.AbstractEventLoop, ConnectionPool] = WeakKeyDictionary() +_pools_lock = threading.Lock() + + +def _get_pool() -> ConnectionPool: + """Get or create the ConnectionPool for the current event loop.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + raise RuntimeError( + "get_connection() must be called from an async context with a running event loop. " + "If calling from sync code, use async_to_sync() to wrap the async function." + ) from None + + with _pools_lock: + if loop not in _pools: + _pools[loop] = ConnectionPool() + logger.debug(f"Created NATS connection pool for event loop {id(loop)}") + return _pools[loop] + + +async def get_connection() -> tuple["NATSClient", JetStreamContext]: + """ + Get or create a NATS connection for the current event loop. + + Returns: + Tuple of (NATS connection, JetStream context) + Raises: + RuntimeError: If called outside of an async context (no running event loop) + """ + pool = _get_pool() + return await pool.get_connection() + + +async def reset_connection() -> None: + """ + Reset the NATS connection for the current event loop. + + Closes the current connection and clears all state so the next call to + get_connection() creates a fresh one. + """ + pool = _get_pool() + await pool.reset() diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index fa7188627..3fe1be355 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -10,57 +10,129 @@ support the visibility timeout semantics we want or a disconnected mode of pulling and ACKing tasks. 
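The event-loop keying used by `_get_pool()` above can also be seen in a self-contained sketch (illustrative only; `_per_loop` and `pool_for_current_loop` are made-up names): per-loop state lives in a `WeakKeyDictionary`, so lookups on the same running loop reuse one object, while each new loop, like each `async_to_sync()` call, gets its own.

```python
import asyncio
from weakref import WeakKeyDictionary

# Stand-in for the per-loop ConnectionPool; a real pool would hold the NATS client.
_per_loop: "WeakKeyDictionary[asyncio.AbstractEventLoop, object]" = WeakKeyDictionary()


def pool_for_current_loop() -> object:
    loop = asyncio.get_running_loop()
    if loop not in _per_loop:
        _per_loop[loop] = object()
    return _per_loop[loop]


async def check_reuse() -> bool:
    # Two lookups on the same running loop return the same object.
    return pool_for_current_loop() is pool_for_current_loop()


assert asyncio.run(check_reuse())  # one loop -> one pool
assert asyncio.run(check_reuse())  # a new loop gets its own, separate pool
```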
""" +import asyncio +import functools import json import logging +from collections.abc import Callable +from typing import TYPE_CHECKING, TypeVar -import nats -from django.conf import settings +from nats import errors as nats_errors from nats.js import JetStreamContext from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy +from nats.js.errors import NotFoundError from ami.ml.schemas import PipelineProcessingTask -logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from collections.abc import Awaitable + from nats.aio.client import Client as NATSClient -async def get_connection(nats_url: str): - nc = await nats.connect(nats_url) - js = nc.jetstream() - return nc, js +logger = logging.getLogger(__name__) TASK_TTR = 300 # Default Time-To-Run (visibility timeout) in seconds +T = TypeVar("T") + + +def retry_on_connection_error(max_retries: int = 2, backoff_seconds: float = 0.5): + """ + Decorator that retries NATS operations on connection errors. When a connection error is detected: + 1. Resets the event-loop-local connection pool (clears stale connection and lock) + 2. Waits with exponential backoff + 3. Retries the operation (which will get a fresh connection from the same event loop) + + This works correctly with async_to_sync() because the pool is keyed by event loop, + ensuring each retry uses the connection bound to the current loop. + + Retried error types: + - ConnectionClosedError: server closed the connection + - NoServersError: cannot reach any NATS server + - TimeoutError: NATS operation timed out + - ConnectionReconnectingError: client is mid-reconnect + - StaleConnectionError: client detected a stale connection + - OSError: network-level failures — includes ConnectionRefusedError, + ConnectionResetError, BrokenPipeError, socket.timeout, and other + OS-level socket/DNS errors + + Args: + max_retries: Maximum number of retry attempts (default: 2) + backoff_seconds: Initial backoff time in seconds (default: 0.5) + Returns: + Decorated async function with retry logic + """ + + def decorator(func: Callable[..., "Awaitable[T]"]) -> Callable[..., "Awaitable[T]"]: + @functools.wraps(func) + async def wrapper(self, *args, **kwargs) -> T: + last_error = None + assert max_retries >= 0, "max_retries must be non-negative" + + for attempt in range(max_retries + 1): + try: + return await func(self, *args, **kwargs) + except ( + nats_errors.ConnectionClosedError, + nats_errors.NoServersError, + nats_errors.TimeoutError, + nats_errors.ConnectionReconnectingError, + nats_errors.StaleConnectionError, + OSError, # ConnectionRefusedError, ConnectionResetError, BrokenPipeError, etc. + ) as e: + last_error = e + # Don't retry on last attempt + if attempt == max_retries: + logger.error( + f"{func.__name__} failed after {max_retries + 1} attempts: {e}", + exc_info=True, + ) + break + # Reset the connection so next attempt gets a fresh one + from ami.ml.orchestration.nats_connection import reset_connection + + await reset_connection() + # Exponential backoff + wait_time = backoff_seconds * (2**attempt) + logger.warning( + f"{func.__name__} failed (attempt {attempt + 1}/{max_retries + 1}): {e}. " + f"Retrying in {wait_time}s..." + ) + await asyncio.sleep(wait_time) + # If we exhausted retries, raise the last error, guaranteed to be not None here + assert last_error is not None, "last_error should not be None if we exhausted retries" + raise last_error # type: ignore + + return wrapper + + return decorator + class TaskQueueManager: """ Manager for NATS JetStream task queue operations. 
- Use as an async context manager: - async with TaskQueueManager() as manager: - await manager.publish_task('job123', {'data': 'value'}) - task = await manager.reserve_task('job123') - await manager.acknowledge_task(task['reply_subject']) - """ + This class is a stateless wrapper — it holds no connection state itself. + All connections come from the event-loop-keyed provider in + nats_connection.py, so instantiating TaskQueueManager() is cheap + and multiple instances share the same underlying connection. - def __init__(self, nats_url: str | None = None): - self.nats_url = nats_url or getattr(settings, "NATS_URL", "nats://nats:4222") - self.nc: nats.NATS | None = None - self.js: JetStreamContext | None = None + Usage pattern (no context manager needed): + manager = TaskQueueManager() + await manager.publish_task(job_id, task) + await manager.acknowledge_task(reply_subject) - async def __aenter__(self): - """Create connection on enter.""" - self.nc, self.js = await get_connection(self.nats_url) - return self + Error handling: + The @retry_on_connection_error decorator on each method handles transient + connection failures by resetting the provider and retrying with backoff. + """ - async def __aexit__(self, exc_type, exc_val, exc_tb): - if self.js: - self.js = None - if self.nc and not self.nc.is_closed: - await self.nc.close() - self.nc = None + async def _get_connection(self) -> tuple["NATSClient", JetStreamContext]: + """Get connection from the event-loop-local pool.""" + from ami.ml.orchestration.nats_connection import get_connection - return False + return await get_connection() def _get_stream_name(self, job_id: int) -> str: """Get stream name from job_id.""" @@ -76,19 +148,18 @@ def _get_consumer_name(self, job_id: int) -> str: async def _ensure_stream(self, job_id: int): """Ensure stream exists for the given job.""" - if self.js is None: - raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + _, js = await self._get_connection() stream_name = self._get_stream_name(job_id) subject = self._get_subject(job_id) try: - await self.js.stream_info(stream_name) + await js.stream_info(stream_name) logger.debug(f"Stream {stream_name} already exists") - except Exception as e: + except NotFoundError as e: logger.warning(f"Stream {stream_name} does not exist: {e}") # Stream doesn't exist, create it - await self.js.add_stream( + await js.add_stream( name=stream_name, subjects=[subject], max_age=86400, # 24 hours retention @@ -97,19 +168,18 @@ async def _ensure_stream(self, job_id: int): async def _ensure_consumer(self, job_id: int): """Ensure consumer exists for the given job.""" - if self.js is None: - raise RuntimeError("Connection is not open. 
Use TaskQueueManager as an async context manager.") + _, js = await self._get_connection() stream_name = self._get_stream_name(job_id) consumer_name = self._get_consumer_name(job_id) subject = self._get_subject(job_id) try: - info = await self.js.consumer_info(stream_name, consumer_name) + info = await js.consumer_info(stream_name, consumer_name) logger.debug(f"Consumer {consumer_name} already exists: {info}") - except Exception: + except NotFoundError: # Consumer doesn't exist, create it - await self.js.add_consumer( + await js.add_consumer( stream=stream_name, config=ConsumerConfig( durable_name=consumer_name, @@ -123,178 +193,187 @@ async def _ensure_consumer(self, job_id: int): ) logger.info(f"Created consumer {consumer_name}") + @retry_on_connection_error(max_retries=2, backoff_seconds=0.5) async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool: """ - Publish a task to it's job queue. - + Publish a task to its job queue. + Automatically retries on connection errors with exponential backoff. Args: job_id: The job ID (integer primary key) data: PipelineProcessingTask object to be published - Returns: - bool: True if successful, False otherwise + bool: True if successful + Raises: + Connection errors are retried by decorator, other errors are raised """ - if self.js is None: - raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") - - try: - # Ensure stream and consumer exist - await self._ensure_stream(job_id) - await self._ensure_consumer(job_id) + _, js = await self._get_connection() - subject = self._get_subject(job_id) - # Convert Pydantic model to JSON - task_data = json.dumps(data.dict()) + # Ensure stream and consumer exist + await self._ensure_stream(job_id) + await self._ensure_consumer(job_id) - # Publish to JetStream - ack = await self.js.publish(subject, task_data.encode()) + subject = self._get_subject(job_id) + # Convert Pydantic model to JSON + task_data = json.dumps(data.dict()) - logger.info(f"Published task to stream for job '{job_id}', sequence {ack.seq}") - return True + # Publish to JetStream + # Note: JetStream publish() waits for PubAck, so it's implicitly flushed + ack = await js.publish(subject, task_data.encode()) - except Exception as e: - logger.error(f"Failed to publish task to stream for job '{job_id}': {e}") - return False + logger.info(f"Published task to stream for job '{job_id}', sequence {ack.seq}") + return True + @retry_on_connection_error(max_retries=2, backoff_seconds=0.5) async def reserve_task(self, job_id: int, timeout: float | None = None) -> PipelineProcessingTask | None: """ Reserve a task from the specified stream. - + Automatically retries on connection errors with exponential backoff. + Note: TimeoutError from fetch() (no messages) is NOT retried - only connection errors. Args: job_id: The job ID (integer primary key) to pull tasks from timeout: Timeout in seconds for reservation (default: 5 seconds) - Returns: PipelineProcessingTask with reply_subject set for acknowledgment, or None if no task available """ - if self.js is None: - raise RuntimeError("Connection is not open. 
Use TaskQueueManager as an async context manager.") + _, js = await self._get_connection() if timeout is None: timeout = 5 + # Ensure stream and consumer exist (let connection errors escape for retry) + await self._ensure_stream(job_id) + await self._ensure_consumer(job_id) + + consumer_name = self._get_consumer_name(job_id) + subject = self._get_subject(job_id) + + # Create ephemeral subscription for this pull + psub = await js.pull_subscribe(subject, consumer_name) + try: - # Ensure stream and consumer exist - await self._ensure_stream(job_id) - await self._ensure_consumer(job_id) + # Fetch a single message + msgs = await psub.fetch(1, timeout=timeout) - consumer_name = self._get_consumer_name(job_id) - subject = self._get_subject(job_id) + if msgs: + msg = msgs[0] + task_data = json.loads(msg.data.decode()) + metadata = msg.metadata - # Create ephemeral subscription for this pull - psub = await self.js.pull_subscribe(subject, consumer_name) + # Parse the task data into PipelineProcessingTask + task = PipelineProcessingTask(**task_data) + # Set the reply_subject for acknowledgment + task.reply_subject = msg.reply - try: - # Fetch a single message - msgs = await psub.fetch(1, timeout=timeout) - - if msgs: - msg = msgs[0] - task_data = json.loads(msg.data.decode()) - metadata = msg.metadata - - # Parse the task data into PipelineProcessingTask - task = PipelineProcessingTask(**task_data) - # Set the reply_subject for acknowledgment - task.reply_subject = msg.reply - - logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}") - return task - - except nats.errors.TimeoutError: - # No messages available - logger.debug(f"No tasks available in stream for job '{job_id}'") - return None - finally: - # Always unsubscribe - await psub.unsubscribe() + logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}") + return task - except Exception as e: - logger.error(f"Failed to reserve task from stream for job '{job_id}': {e}") return None + except nats_errors.TimeoutError: + # No messages available (expected behavior) + logger.debug(f"No tasks available in stream for job '{job_id}'") + return None + finally: + # Unsubscribe in its own try/except so it never masks exceptions from above + try: + await psub.unsubscribe() + except Exception as e: + logger.warning(f"Failed to unsubscribe pull subscription for job '{job_id}': {e}") + + @retry_on_connection_error(max_retries=2, backoff_seconds=0.5) async def acknowledge_task(self, reply_subject: str) -> bool: """ Acknowledge (delete) a completed task using its reply subject. - + Automatically retries on connection errors with exponential backoff. Args: reply_subject: The reply subject from reserve_task - Returns: bool: True if successful + Raises: + Connection errors are retried by decorator, other errors are raised """ - if self.nc is None: - raise RuntimeError("Connection is not open. 
Use TaskQueueManager as an async context manager.") + nc, _ = await self._get_connection() + # Don't catch connection errors - let retry decorator handle them + await nc.publish(reply_subject, b"+ACK") + + # CRITICAL: Flush to ensure ACK is sent immediately + # Without flush, ACKs may be buffered and not sent to NATS server try: - await self.nc.publish(reply_subject, b"+ACK") - logger.debug(f"Acknowledged task with reply subject {reply_subject}") - return True - except Exception as e: - logger.error(f"Failed to acknowledge task: {e}") - return False + await nc.flush(timeout=2) + except asyncio.TimeoutError as e: + # Flush timeout likely means connection is stale - re-raise to trigger retry + logger.warning(f"Flush timeout for ACK {reply_subject}, connection may be stale: {e}") + raise nats_errors.TimeoutError("Flush timeout") from e + logger.debug(f"Acknowledged task with reply subject {reply_subject}") + return True + + @retry_on_connection_error(max_retries=2, backoff_seconds=0.5) async def delete_consumer(self, job_id: int) -> bool: """ Delete the consumer for a job. - + Automatically retries on connection errors with exponential backoff. Args: job_id: The job ID (integer primary key) - Returns: - bool: True if successful, False otherwise + bool: True if successful + Raises: + Connection errors are retried by decorator, other errors are raised """ - if self.js is None: - raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + _, js = await self._get_connection() - try: - stream_name = self._get_stream_name(job_id) - consumer_name = self._get_consumer_name(job_id) + stream_name = self._get_stream_name(job_id) + consumer_name = self._get_consumer_name(job_id) - await self.js.delete_consumer(stream_name, consumer_name) - logger.info(f"Deleted consumer {consumer_name} for job '{job_id}'") - return True - except Exception as e: - logger.error(f"Failed to delete consumer for job '{job_id}': {e}") - return False + await js.delete_consumer(stream_name, consumer_name) + logger.info(f"Deleted consumer {consumer_name} for job '{job_id}'") + return True + @retry_on_connection_error(max_retries=2, backoff_seconds=0.5) async def delete_stream(self, job_id: int) -> bool: """ Delete the stream for a job. - + Automatically retries on connection errors with exponential backoff. Args: job_id: The job ID (integer primary key) - Returns: - bool: True if successful, False otherwise + bool: True if successful + Raises: + Connection errors are retried by decorator, other errors are raised """ - if self.js is None: - raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + _, js = await self._get_connection() - try: - stream_name = self._get_stream_name(job_id) + stream_name = self._get_stream_name(job_id) - await self.js.delete_stream(stream_name) - logger.info(f"Deleted stream {stream_name} for job '{job_id}'") - return True - except Exception as e: - logger.error(f"Failed to delete stream for job '{job_id}': {e}") - return False + await js.delete_stream(stream_name) + logger.info(f"Deleted stream {stream_name} for job '{job_id}'") + return True async def cleanup_job_resources(self, job_id: int) -> bool: """ Clean up all NATS resources (consumer and stream) for a job. - This should be called when a job completes or is cancelled. - + Best-effort cleanup - logs errors but doesn't fail if cleanup fails. 
Args: job_id: The job ID (integer primary key) - Returns: - bool: True if successful, False otherwise + bool: True if both cleanup operations succeeded, False otherwise """ - # Delete consumer first, then stream - consumer_deleted = await self.delete_consumer(job_id) - stream_deleted = await self.delete_stream(job_id) + consumer_deleted = False + stream_deleted = False + + # Delete consumer first, then stream (best-effort) + try: + await self.delete_consumer(job_id) + consumer_deleted = True + except Exception as e: + logger.warning(f"Failed to delete consumer for job {job_id} after retries: {e}") + + try: + await self.delete_stream(job_id) + stream_deleted = True + except Exception as e: + logger.warning(f"Failed to delete stream for job {job_id} after retries: {e}") return consumer_deleted and stream_deleted diff --git a/ami/ml/orchestration/tests/test_cleanup.py b/ami/ml/orchestration/tests/test_cleanup.py index ef8382d3d..a5fc6fcb7 100644 --- a/ami/ml/orchestration/tests/test_cleanup.py +++ b/ami/ml/orchestration/tests/test_cleanup.py @@ -68,25 +68,28 @@ def _verify_resources_created(self, job_id: int): # Verify NATS stream and consumer exist async def check_nats_resources(): - async with TaskQueueManager() as manager: - stream_name = manager._get_stream_name(job_id) - consumer_name = manager._get_consumer_name(job_id) - - # Try to get stream info - should succeed if created - stream_exists = True - try: - await manager.js.stream_info(stream_name) - except NotFoundError: - stream_exists = False - - # Try to get consumer info - should succeed if created - consumer_exists = True - try: - await manager.js.consumer_info(stream_name, consumer_name) - except NotFoundError: - consumer_exists = False - - return stream_exists, consumer_exists + manager = TaskQueueManager() + stream_name = manager._get_stream_name(job_id) + consumer_name = manager._get_consumer_name(job_id) + + # Get JetStream context + _, js = await manager._get_connection() + + # Try to get stream info - should succeed if created + stream_exists = True + try: + await js.stream_info(stream_name) + except NotFoundError: + stream_exists = False + + # Try to get consumer info - should succeed if created + consumer_exists = True + try: + await js.consumer_info(stream_name, consumer_name) + except NotFoundError: + consumer_exists = False + + return stream_exists, consumer_exists stream_exists, consumer_exists = async_to_sync(check_nats_resources)() @@ -134,25 +137,28 @@ def _verify_resources_cleaned(self, job_id: int): # Verify NATS stream and consumer are deleted async def check_nats_resources(): - async with TaskQueueManager() as manager: - stream_name = manager._get_stream_name(job_id) - consumer_name = manager._get_consumer_name(job_id) - - # Try to get stream info - should fail if deleted - stream_exists = True - try: - await manager.js.stream_info(stream_name) - except NotFoundError: - stream_exists = False - - # Try to get consumer info - should fail if deleted - consumer_exists = True - try: - await manager.js.consumer_info(stream_name, consumer_name) - except NotFoundError: - consumer_exists = False - - return stream_exists, consumer_exists + manager = TaskQueueManager() + stream_name = manager._get_stream_name(job_id) + consumer_name = manager._get_consumer_name(job_id) + + # Get JetStream context + _, js = await manager._get_connection() + + # Try to get stream info - should fail if deleted + stream_exists = True + try: + await js.stream_info(stream_name) + except NotFoundError: + stream_exists = False + + # Try to get 
consumer info - should fail if deleted + consumer_exists = True + try: + await js.consumer_info(stream_name, consumer_name) + except NotFoundError: + consumer_exists = False + + return stream_exists, consumer_exists stream_exists, consumer_exists = async_to_sync(check_nats_resources)() diff --git a/ami/ml/orchestration/tests/test_nats_connection.py b/ami/ml/orchestration/tests/test_nats_connection.py new file mode 100644 index 000000000..286995977 --- /dev/null +++ b/ami/ml/orchestration/tests/test_nats_connection.py @@ -0,0 +1,239 @@ +"""Unit tests for nats_connection module.""" + +import asyncio +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from ami.ml.orchestration.nats_connection import ConnectionPool + + +class TestConnectionPoolBehavior(unittest.IsolatedAsyncioTestCase): + """Test ConnectionPool lifecycle and connection reuse.""" + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_get_connection_creates_connection(self, mock_settings, mock_nats): + """Test that get_connection creates a connection on first call.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nc = MagicMock() + mock_nc.is_closed = False + mock_nc.is_connected = True + mock_nc.close = AsyncMock() + mock_js = MagicMock() + mock_nc.jetstream.return_value = mock_js + + mock_nats.connect = AsyncMock(return_value=mock_nc) + + pool = ConnectionPool() + nc, js = await pool.get_connection() + + self.assertIs(nc, mock_nc) + self.assertIs(js, mock_js) + mock_nats.connect.assert_called_once_with("nats://test:4222") + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_get_connection_reuses_existing_connection(self, mock_settings, mock_nats): + """Test that get_connection reuses connection on subsequent calls.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nc = MagicMock() + mock_nc.is_closed = False + mock_nc.is_connected = True + mock_nc.close = AsyncMock() + mock_js = MagicMock() + mock_nc.jetstream.return_value = mock_js + + mock_nats.connect = AsyncMock(return_value=mock_nc) + + pool = ConnectionPool() + + # First call + nc1, js1 = await pool.get_connection() + # Second call + nc2, js2 = await pool.get_connection() + + # Should only connect once + self.assertEqual(mock_nats.connect.call_count, 1) + self.assertIs(nc1, nc2) + self.assertIs(js1, js2) + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_get_connection_reconnects_if_closed(self, mock_settings, mock_nats): + """Test that get_connection reconnects if the connection is closed.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nc1 = MagicMock() + mock_nc1.is_closed = True + mock_nc1.is_connected = False + mock_nc1.close = AsyncMock() + + mock_nc2 = MagicMock() + mock_nc2.is_closed = False + mock_nc2.is_connected = True + mock_nc2.close = AsyncMock() + mock_js2 = MagicMock() + mock_nc2.jetstream.return_value = mock_js2 + + mock_nats.connect = AsyncMock(side_effect=[mock_nc2]) + + pool = ConnectionPool() + pool._nc = mock_nc1 + pool._js = MagicMock() + + # This should detect the connection is closed and reconnect + nc, js = await pool.get_connection() + + self.assertIs(nc, mock_nc2) + self.assertIs(js, mock_js2) + mock_nats.connect.assert_called_once() + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def 
test_get_connection_raises_on_connection_error(self, mock_settings, mock_nats): + """Test that get_connection raises RuntimeError on connection failure.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nats.connect = AsyncMock(side_effect=ConnectionError("Connection failed")) + + pool = ConnectionPool() + + with self.assertRaises(RuntimeError) as context: + await pool.get_connection() + + self.assertIn("Could not establish NATS connection", str(context.exception)) + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_close_closes_connection(self, mock_settings, mock_nats): + """Test that close() closes the connection.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nc = MagicMock() + mock_nc.is_closed = False + mock_nc.is_connected = True + mock_nc.close = AsyncMock() + mock_js = MagicMock() + mock_nc.jetstream.return_value = mock_js + + mock_nats.connect = AsyncMock(return_value=mock_nc) + + pool = ConnectionPool() + await pool.get_connection() + + await pool.close() + + mock_nc.close.assert_called_once() + self.assertIsNone(pool._nc) + self.assertIsNone(pool._js) + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_reset_closes_and_clears_state(self, mock_settings, mock_nats): + """Test that reset() closes connection and clears all state.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nc = MagicMock() + mock_nc.is_closed = False + mock_nc.is_connected = True + mock_nc.close = AsyncMock() + mock_js = MagicMock() + mock_nc.jetstream.return_value = mock_js + + mock_nats.connect = AsyncMock(return_value=mock_nc) + + pool = ConnectionPool() + await pool.get_connection() + + # Set the lock so we can verify it gets cleared + pool._lock = asyncio.Lock() + + await pool.reset() + + mock_nc.close.assert_called_once() + self.assertIsNone(pool._nc) + self.assertIsNone(pool._js) + self.assertIsNone(pool._lock) + + +class TestModuleLevelFunctions(unittest.IsolatedAsyncioTestCase): + """Test module-level get_connection() and reset_connection() functions.""" + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_get_connection_returns_connection(self, mock_settings, mock_nats): + """Test that module-level get_connection() returns a NATS connection.""" + from ami.ml.orchestration.nats_connection import _pools, get_connection + + mock_settings.NATS_URL = "nats://test:4222" + mock_nc = MagicMock() + mock_nc.is_closed = False + mock_nc.is_connected = True + mock_nc.close = AsyncMock() + mock_js = MagicMock() + mock_nc.jetstream.return_value = mock_js + mock_nats.connect = AsyncMock(return_value=mock_nc) + + # Clear pools to avoid leaking state between tests + _pools.clear() + + nc, js = await get_connection() + + self.assertIs(nc, mock_nc) + self.assertIs(js, mock_js) + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_get_connection_reuses_pool_for_same_loop(self, mock_settings, mock_nats): + """Test that repeated calls on the same loop reuse the same pool.""" + from ami.ml.orchestration.nats_connection import _pools, get_connection + + mock_settings.NATS_URL = "nats://test:4222" + mock_nc = MagicMock() + mock_nc.is_closed = False + mock_nc.is_connected = True + mock_nc.close = AsyncMock() + mock_js = MagicMock() + mock_nc.jetstream.return_value = mock_js + 
mock_nats.connect = AsyncMock(return_value=mock_nc) + + _pools.clear() + + await get_connection() + await get_connection() + + # Only one TCP connection should have been created + mock_nats.connect.assert_called_once() + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_reset_connection_clears_pool(self, mock_settings, mock_nats): + """Test that reset_connection() resets the pool for the current loop.""" + from ami.ml.orchestration.nats_connection import _pools, get_connection, reset_connection + + mock_settings.NATS_URL = "nats://test:4222" + mock_nc = MagicMock() + mock_nc.is_closed = False + mock_nc.is_connected = True + mock_nc.close = AsyncMock() + mock_js = MagicMock() + mock_nc.jetstream.return_value = mock_js + mock_nats.connect = AsyncMock(return_value=mock_nc) + + _pools.clear() + + await get_connection() + await reset_connection() + + mock_nc.close.assert_called_once() + + def test_get_connection_raises_without_event_loop(self): + """Test that _get_pool raises RuntimeError outside async context.""" + from ami.ml.orchestration.nats_connection import _get_pool + + with self.assertRaises(RuntimeError) as context: + _get_pool() + + self.assertIn("must be called from an async context", str(context.exception)) diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py index 0cd2c3bef..2c3f446fb 100644 --- a/ami/ml/orchestration/tests/test_nats_queue.py +++ b/ami/ml/orchestration/tests/test_nats_queue.py @@ -1,8 +1,11 @@ """Unit tests for TaskQueueManager.""" import unittest +from contextlib import contextmanager from unittest.mock import AsyncMock, MagicMock, patch +from nats.js.errors import NotFoundError + from ami.ml.orchestration.nats_queue import TaskQueueManager from ami.ml.schemas import PipelineProcessingTask @@ -18,8 +21,13 @@ def _create_sample_task(self): image_url="https://example.com/image.jpg", ) - def _create_mock_nats_connection(self): - """Helper to create mock NATS connection and JetStream context.""" + @contextmanager + def _mock_nats_setup(self): + """Helper to create and mock NATS connection with connection pool. 
+ + Yields: + tuple: (nc, js, mock_pool) - NATS connection, JetStream context, and mock pool + """ nc = MagicMock() nc.is_closed = False nc.close = AsyncMock() @@ -34,37 +42,27 @@ def _create_mock_nats_connection(self): js.delete_consumer = AsyncMock() js.delete_stream = AsyncMock() - return nc, js - - async def test_context_manager_lifecycle(self): - """Test that context manager properly opens and closes connections.""" - nc, js = self._create_mock_nats_connection() - - with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): - async with TaskQueueManager("nats://test:4222") as manager: - self.assertIsNotNone(manager.nc) - self.assertIsNotNone(manager.js) - - nc.close.assert_called_once() + with patch("ami.ml.orchestration.nats_connection.get_connection", new_callable=AsyncMock) as mock_get_conn: + mock_get_conn.return_value = (nc, js) + yield nc, js, mock_get_conn async def test_publish_task_creates_stream_and_consumer(self): """Test that publish_task ensures stream and consumer exist.""" - nc, js = self._create_mock_nats_connection() sample_task = self._create_sample_task() - js.stream_info.side_effect = Exception("Not found") - js.consumer_info.side_effect = Exception("Not found") - with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): - async with TaskQueueManager() as manager: - await manager.publish_task(456, sample_task) + with self._mock_nats_setup() as (_, js, _): + js.stream_info.side_effect = NotFoundError + js.consumer_info.side_effect = NotFoundError - js.add_stream.assert_called_once() - self.assertIn("job_456", str(js.add_stream.call_args)) - js.add_consumer.assert_called_once() + manager = TaskQueueManager() + await manager.publish_task(456, sample_task) + + js.add_stream.assert_called_once() + self.assertIn("job_456", str(js.add_stream.call_args)) + js.add_consumer.assert_called_once() async def test_reserve_task_success(self): """Test successful task reservation.""" - nc, js = self._create_mock_nats_connection() sample_task = self._create_sample_task() # Mock message with task data @@ -73,59 +71,72 @@ async def test_reserve_task_success(self): mock_msg.reply = "reply.subject.123" mock_msg.metadata = MagicMock(sequence=MagicMock(stream=1)) - mock_psub = MagicMock() - mock_psub.fetch = AsyncMock(return_value=[mock_msg]) - mock_psub.unsubscribe = AsyncMock() - js.pull_subscribe = AsyncMock(return_value=mock_psub) + with self._mock_nats_setup() as (_, js, _): + mock_psub = MagicMock() + mock_psub.fetch = AsyncMock(return_value=[mock_msg]) + mock_psub.unsubscribe = AsyncMock() + js.pull_subscribe = AsyncMock(return_value=mock_psub) - with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): - async with TaskQueueManager() as manager: - task = await manager.reserve_task(123) + manager = TaskQueueManager() + task = await manager.reserve_task(123) - self.assertIsNotNone(task) - self.assertEqual(task.id, sample_task.id) - self.assertEqual(task.reply_subject, "reply.subject.123") - mock_psub.unsubscribe.assert_called_once() + self.assertIsNotNone(task) + self.assertEqual(task.id, sample_task.id) + self.assertEqual(task.reply_subject, "reply.subject.123") + mock_psub.unsubscribe.assert_called_once() async def test_reserve_task_no_messages(self): """Test reserve_task when no messages are available.""" - nc, js = self._create_mock_nats_connection() + with self._mock_nats_setup() as (_, js, _): + mock_psub = MagicMock() + mock_psub.fetch = AsyncMock(return_value=[]) + 
mock_psub.unsubscribe = AsyncMock() + js.pull_subscribe = AsyncMock(return_value=mock_psub) + + manager = TaskQueueManager() + task = await manager.reserve_task(123) + + self.assertIsNone(task) + mock_psub.unsubscribe.assert_called_once() - mock_psub = MagicMock() - mock_psub.fetch = AsyncMock(return_value=[]) - mock_psub.unsubscribe = AsyncMock() - js.pull_subscribe = AsyncMock(return_value=mock_psub) + async def test_reserve_task_timeout(self): + """Test reserve_task when fetch raises TimeoutError (no messages available).""" + from nats import errors as nats_errors - with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): - async with TaskQueueManager() as manager: - task = await manager.reserve_task(123) + with self._mock_nats_setup() as (_, js, _): + mock_psub = MagicMock() + mock_psub.fetch = AsyncMock(side_effect=nats_errors.TimeoutError) + mock_psub.unsubscribe = AsyncMock() + js.pull_subscribe = AsyncMock(return_value=mock_psub) - self.assertIsNone(task) - mock_psub.unsubscribe.assert_called_once() + manager = TaskQueueManager() + task = await manager.reserve_task(123) + + self.assertIsNone(task) + mock_psub.unsubscribe.assert_called_once() async def test_acknowledge_task_success(self): """Test successful task acknowledgment.""" - nc, js = self._create_mock_nats_connection() - nc.publish = AsyncMock() + with self._mock_nats_setup() as (nc, _, _): + nc.publish = AsyncMock() + nc.flush = AsyncMock() - with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): - async with TaskQueueManager() as manager: - result = await manager.acknowledge_task("reply.subject.123") + manager = TaskQueueManager() + result = await manager.acknowledge_task("reply.subject.123") - self.assertTrue(result) - nc.publish.assert_called_once_with("reply.subject.123", b"+ACK") + self.assertTrue(result) + nc.publish.assert_called_once_with("reply.subject.123", b"+ACK") + nc.flush.assert_called_once() async def test_cleanup_job_resources(self): """Test cleanup of job resources (consumer and stream).""" - nc, js = self._create_mock_nats_connection() - - with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): - async with TaskQueueManager() as manager: - result = await manager.cleanup_job_resources(123) + with self._mock_nats_setup() as (_, js, _): + manager = TaskQueueManager() + result = await manager.cleanup_job_resources(123) - self.assertTrue(result) - js.delete_consumer.assert_called_once() - js.delete_stream.assert_called_once() + self.assertTrue(result) + js.delete_consumer.assert_called_once() + js.delete_stream.assert_called_once() async def test_naming_conventions(self): """Test stream, subject, and consumer naming conventions.""" @@ -135,16 +146,117 @@ async def test_naming_conventions(self): self.assertEqual(manager._get_subject(123), "job.123.tasks") self.assertEqual(manager._get_consumer_name(123), "job-123-consumer") - async def test_operations_without_connection_raise_error(self): - """Test that operations without connection raise RuntimeError.""" - manager = TaskQueueManager() - sample_task = self._create_sample_task() - with self.assertRaisesRegex(RuntimeError, "Connection is not open"): - await manager.publish_task(123, sample_task) +class TestRetryOnConnectionError(unittest.IsolatedAsyncioTestCase): + """Test retry_on_connection_error decorator behavior.""" + + def _create_sample_task(self): + """Helper to create a sample PipelineProcessingTask.""" + return PipelineProcessingTask( 
+ id="task-retry-test", + image_id="img-retry", + image_url="https://example.com/retry.jpg", + ) + + async def test_retry_resets_connection_on_error(self): + """On connection error, the decorator should call reset_connection() before retrying.""" + from nats.errors import ConnectionClosedError + + nc = MagicMock() + nc.is_closed = False + js = MagicMock() + js.stream_info = AsyncMock() + js.add_stream = AsyncMock() + js.consumer_info = AsyncMock() + js.add_consumer = AsyncMock() + # First publish fails with connection error, second succeeds + js.publish = AsyncMock(side_effect=[ConnectionClosedError(), MagicMock(seq=1)]) + js.pull_subscribe = AsyncMock() + + with ( + patch( + "ami.ml.orchestration.nats_connection.get_connection", + new_callable=AsyncMock, + return_value=(nc, js), + ), + patch( + "ami.ml.orchestration.nats_connection.reset_connection", + new_callable=AsyncMock, + ) as mock_reset, + ): + manager = TaskQueueManager() + sample_task = self._create_sample_task() + + # Should succeed after retry + with patch("ami.ml.orchestration.nats_queue.asyncio.sleep", new_callable=AsyncMock): + result = await manager.publish_task(456, sample_task) + + self.assertTrue(result) + # reset_connection() should have been called once (after first failure) + mock_reset.assert_called_once() + + async def test_retry_raises_after_max_retries(self): + """After exhausting retries, the last error should be raised.""" + from nats.errors import ConnectionClosedError - with self.assertRaisesRegex(RuntimeError, "Connection is not open"): - await manager.reserve_task(123) + nc = MagicMock() + nc.is_closed = False + js = MagicMock() + js.stream_info = AsyncMock() + js.add_stream = AsyncMock() + js.consumer_info = AsyncMock() + js.add_consumer = AsyncMock() + # All attempts fail + js.publish = AsyncMock(side_effect=ConnectionClosedError()) + + with ( + patch( + "ami.ml.orchestration.nats_connection.get_connection", + new_callable=AsyncMock, + return_value=(nc, js), + ), + patch( + "ami.ml.orchestration.nats_connection.reset_connection", + new_callable=AsyncMock, + ) as mock_reset, + ): + manager = TaskQueueManager() + sample_task = self._create_sample_task() + + with patch("ami.ml.orchestration.nats_queue.asyncio.sleep", new_callable=AsyncMock): + with self.assertRaises(ConnectionClosedError): + await manager.publish_task(456, sample_task) + + # reset_connection() called twice (max_retries=2, so 2 retries means 2 resets) + self.assertEqual(mock_reset.call_count, 2) + + async def test_non_connection_errors_are_not_retried(self): + """Non-connection errors (e.g. ValueError) should propagate immediately without retry.""" + nc = MagicMock() + nc.is_closed = False + js = MagicMock() + js.stream_info = AsyncMock() + js.add_stream = AsyncMock() + js.consumer_info = AsyncMock() + js.add_consumer = AsyncMock() + js.publish = AsyncMock(side_effect=ValueError("bad data")) + + with ( + patch( + "ami.ml.orchestration.nats_connection.get_connection", + new_callable=AsyncMock, + return_value=(nc, js), + ), + patch( + "ami.ml.orchestration.nats_connection.reset_connection", + new_callable=AsyncMock, + ) as mock_reset, + ): + manager = TaskQueueManager() + sample_task = self._create_sample_task() + + with self.assertRaises(ValueError): + await manager.publish_task(456, sample_task) - with self.assertRaisesRegex(RuntimeError, "Connection is not open"): - await manager.delete_stream(123) + # reset_connection() should NOT have been called + mock_reset.assert_not_called()