From b60eab0ff435e32b0264b182b4022864cdce71f0 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 16 Jan 2026 11:25:40 -0800 Subject: [PATCH 01/16] merge --- requirements/base.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/base.txt b/requirements/base.txt index dd9de69d5..ed40ea5f7 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -52,6 +52,7 @@ django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail Werkzeug[watchdog]==2.3.6 # https://github.com/pallets/werkzeug ipdb==0.13.13 # https://github.com/gotcha/ipdb psycopg[binary]==3.1.9 # https://github.com/psycopg/psycopg +# psycopg==3.1.9 # https://github.com/psycopg/psycopg # the non-binary version is needed for some platforms watchfiles==0.19.0 # https://github.com/samuelcolvin/watchfiles # Testing From 842e9b3890fff3218f25898a8babfeb5390e7a74 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 11 Feb 2026 09:55:20 -0800 Subject: [PATCH 02/16] PSv2: Use connection pooling and retries for NATS --- ami/jobs/tasks.py | 4 +- ami/jobs/views.py | 10 +- ami/ml/orchestration/jobs.py | 36 +-- ami/ml/orchestration/nats_connection_pool.py | 112 +++++++ ami/ml/orchestration/nats_queue.py | 323 ++++++++++++------- 5 files changed, 338 insertions(+), 147 deletions(-) create mode 100644 ami/ml/orchestration/nats_connection_pool.py diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 3548d0ea5..94a6a821a 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -135,8 +135,8 @@ def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None: try: async def ack_task(): - async with TaskQueueManager() as manager: - return await manager.acknowledge_task(reply_subject) + manager = TaskQueueManager() + return await manager.acknowledge_task(reply_subject) ack_success = async_to_sync(ack_task)() diff --git a/ami/jobs/views.py b/ami/jobs/views.py index dd8da01b2..f45361007 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -242,11 +242,11 @@ def tasks(self, request, pk=None): async def get_tasks(): tasks = [] - async with TaskQueueManager() as manager: - for _ in range(batch): - task = await manager.reserve_task(job.pk, timeout=0.1) - if task: - tasks.append(task.dict()) + manager = TaskQueueManager() + for _ in range(batch): + task = await manager.reserve_task(job.pk, timeout=0.1) + if task: + tasks.append(task.dict()) return tasks # Use async_to_sync to properly handle the async call diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index 9b4a577e9..b4b9171b7 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -40,8 +40,8 @@ def cleanup_async_job_resources(job: "Job") -> bool: # Cleanup NATS resources async def cleanup(): - async with TaskQueueManager() as manager: - return await manager.cleanup_job_resources(job.pk) + manager = TaskQueueManager() + return await manager.cleanup_job_resources(job.pk) try: nats_success = async_to_sync(cleanup)() @@ -96,22 +96,22 @@ async def queue_all_images(): successful_queues = 0 failed_queues = 0 - async with TaskQueueManager() as manager: - for image_pk, task in tasks: - try: - logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}") - success = await manager.publish_task( - job_id=job.pk, - data=task, - ) - except Exception as e: - logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}") - success = False - - if success: - successful_queues += 1 - else: - failed_queues += 1 + manager = 
TaskQueueManager() + for image_pk, task in tasks: + try: + logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}") + success = await manager.publish_task( + job_id=job.pk, + data=task, + ) + except Exception as e: + logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}") + success = False + + if success: + successful_queues += 1 + else: + failed_queues += 1 return successful_queues, failed_queues diff --git a/ami/ml/orchestration/nats_connection_pool.py b/ami/ml/orchestration/nats_connection_pool.py new file mode 100644 index 000000000..5769fda19 --- /dev/null +++ b/ami/ml/orchestration/nats_connection_pool.py @@ -0,0 +1,112 @@ +""" +NATS connection pool for both Celery workers and Django processes. + +Maintains a persistent NATS connection per process to avoid +the overhead of creating/closing connections for every operation. + +The connection pool is lazily initialized on first use and shared +across all operations in the same process. +""" + +import asyncio +import logging +from typing import TYPE_CHECKING + +import nats +from django.conf import settings +from nats.js import JetStreamContext + +if TYPE_CHECKING: + from nats.aio.client import Client as NATSClient + +logger = logging.getLogger(__name__) + + +class ConnectionPool: + """ + Manages a single NATS connection per process (Celery worker or Django web worker). + + This is safe because: + - Each process gets its own isolated connection + - NATS connections are async-safe (can be used by multiple coroutines) + - Works in both Celery prefork and Django WSGI/ASGI contexts + """ + + def __init__(self): + self._nc: "NATSClient | None" = None + self._js: JetStreamContext | None = None + self._nats_url: str | None = None + self._lock = asyncio.Lock() + + async def get_connection(self) -> tuple["NATSClient", JetStreamContext]: + """ + Get or create the worker's NATS connection. Checks connection health and recreates if stale. + + Returns: + Tuple of (NATS connection, JetStream context) + Raises: + RuntimeError: If connection cannot be established + """ + # Fast path: connection exists, is open, and is connected + if self._nc is not None and not self._nc.is_closed and self._nc.is_connected: + return self._nc, self._js # type: ignore + + # Connection is stale or doesn't exist + if self._nc is not None: + logger.warning("NATS connection is closed or disconnected, will reconnect") + self._nc = None + self._js = None + + # Slow path: need to create/recreate connection + async with self._lock: + # Double-check after acquiring lock + if self._nc is not None and not self._nc.is_closed and self._nc.is_connected: + return self._nc, self._js # type: ignore + + # Get NATS URL from settings + if self._nats_url is None: + self._nats_url = getattr(settings, "NATS_URL", "nats://nats:4222") + + try: + logger.info(f"Creating NATS connection to {self._nats_url}") + self._nc = await nats.connect(self._nats_url) + self._js = self._nc.jetstream() + logger.info(f"Successfully connected to NATS at {self._nats_url}") + return self._nc, self._js + except Exception as e: + logger.error(f"Failed to connect to NATS: {e}") + raise RuntimeError(f"Could not establish NATS connection: {e}") from e + + async def close(self): + """Close the NATS connection if it exists.""" + if self._nc is not None and not self._nc.is_closed: + logger.info("Closing NATS connection") + await self._nc.close() + self._nc = None + self._js = None + + def reset(self): + """ + Reset the connection pool (mark connection as stale). 
+ + This should be called when a connection error is detected. + The next call to get_connection() will create a fresh connection. + """ + logger.warning("Resetting NATS connection pool due to connection error") + self._nc = None + self._js = None + + +# Global pool instance - one per process (Celery worker or Django process) +_connection_pool: ConnectionPool | None = None + + +def get_pool() -> ConnectionPool: + """ + Get the process-local connection pool. + """ + global _connection_pool + if _connection_pool is None: + _connection_pool = ConnectionPool() + logger.debug("Lazily initialized NATS connection pool") + return _connection_pool diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index fa7188627..4647892e5 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -10,57 +10,111 @@ support the visibility timeout semantics we want or a disconnected mode of pulling and ACKing tasks. """ +import asyncio +import functools import json import logging +from collections.abc import Callable +from typing import TYPE_CHECKING, TypeVar -import nats -from django.conf import settings +from nats import errors as nats_errors from nats.js import JetStreamContext from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy from ami.ml.schemas import PipelineProcessingTask -logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from collections.abc import Awaitable + from nats.aio.client import Client as NATSClient -async def get_connection(nats_url: str): - nc = await nats.connect(nats_url) - js = nc.jetstream() - return nc, js +logger = logging.getLogger(__name__) TASK_TTR = 300 # Default Time-To-Run (visibility timeout) in seconds +T = TypeVar("T") -class TaskQueueManager: + +def retry_on_connection_error(max_retries: int = 2, backoff_seconds: float = 0.5): """ - Manager for NATS JetStream task queue operations. + Decorator that retries NATS operations on connection errors. + + When a connection error is detected: + 1. Resets the connection pool (clears stale connection) + 2. Waits with exponential backoff + 3. 
Retries the operation (which will get a fresh connection) + + Args: + max_retries: Maximum number of retry attempts (default: 2) + backoff_seconds: Initial backoff time in seconds (default: 0.5) - Use as an async context manager: - async with TaskQueueManager() as manager: - await manager.publish_task('job123', {'data': 'value'}) - task = await manager.reserve_task('job123') - await manager.acknowledge_task(task['reply_subject']) + Returns: + Decorated async function with retry logic """ - def __init__(self, nats_url: str | None = None): - self.nats_url = nats_url or getattr(settings, "NATS_URL", "nats://nats:4222") - self.nc: nats.NATS | None = None - self.js: JetStreamContext | None = None + def decorator(func: Callable[..., "Awaitable[T]"]) -> Callable[..., "Awaitable[T]"]: + @functools.wraps(func) + async def wrapper(self, *args, **kwargs) -> T: + last_error = None + + for attempt in range(max_retries + 1): + try: + return await func(self, *args, **kwargs) + except ( + nats_errors.ConnectionClosedError, + nats_errors.NoServersError, + nats_errors.TimeoutError, + nats_errors.ConnectionReconnectingError, + OSError, # Network errors + ) as e: + last_error = e + + # Don't retry on last attempt + if attempt == max_retries: + logger.error( + f"{func.__name__} failed after {max_retries + 1} attempts: {e}", + exc_info=True, + ) + break + + # Reset the connection pool so next attempt gets a fresh connection + from ami.ml.orchestration.nats_connection_pool import get_pool + + pool = get_pool() + pool.reset() + + # Exponential backoff + wait_time = backoff_seconds * (2**attempt) + logger.warning( + f"{func.__name__} failed (attempt {attempt + 1}/{max_retries + 1}): {e}. " + f"Retrying in {wait_time}s..." + ) + await asyncio.sleep(wait_time) + + # If we exhausted retries, raise the last error + raise last_error # type: ignore + + return wrapper + + return decorator - async def __aenter__(self): - """Create connection on enter.""" - self.nc, self.js = await get_connection(self.nats_url) - return self - async def __aexit__(self, exc_type, exc_val, exc_tb): - if self.js: - self.js = None - if self.nc and not self.nc.is_closed: - await self.nc.close() - self.nc = None +class TaskQueueManager: + """ + Manager for NATS JetStream task queue operations. - return False + Always uses the process-local connection pool for efficiency. + Note: The connection pool is shared across all instances in the same process, + so there's no overhead creating multiple TaskQueueManager instances. + """ + + async def _get_connection(self) -> tuple["NATSClient", JetStreamContext]: + """Get connection from the process-local pool.""" + from ami.ml.orchestration.nats_connection_pool import get_pool + + pool = get_pool() + return await pool.get_connection() def _get_stream_name(self, job_id: int) -> str: """Get stream name from job_id.""" @@ -76,19 +130,18 @@ def _get_consumer_name(self, job_id: int) -> str: async def _ensure_stream(self, job_id: int): """Ensure stream exists for the given job.""" - if self.js is None: - raise RuntimeError("Connection is not open. 
Use TaskQueueManager as an async context manager.") + _, js = await self._get_connection() stream_name = self._get_stream_name(job_id) subject = self._get_subject(job_id) try: - await self.js.stream_info(stream_name) + await js.stream_info(stream_name) logger.debug(f"Stream {stream_name} already exists") except Exception as e: logger.warning(f"Stream {stream_name} does not exist: {e}") # Stream doesn't exist, create it - await self.js.add_stream( + await js.add_stream( name=stream_name, subjects=[subject], max_age=86400, # 24 hours retention @@ -97,19 +150,18 @@ async def _ensure_stream(self, job_id: int): async def _ensure_consumer(self, job_id: int): """Ensure consumer exists for the given job.""" - if self.js is None: - raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + _, js = await self._get_connection() stream_name = self._get_stream_name(job_id) consumer_name = self._get_consumer_name(job_id) subject = self._get_subject(job_id) try: - info = await self.js.consumer_info(stream_name, consumer_name) + info = await js.consumer_info(stream_name, consumer_name) logger.debug(f"Consumer {consumer_name} already exists: {info}") except Exception: # Consumer doesn't exist, create it - await self.js.add_consumer( + await js.add_consumer( stream=stream_name, config=ConsumerConfig( durable_name=consumer_name, @@ -123,43 +175,48 @@ async def _ensure_consumer(self, job_id: int): ) logger.info(f"Created consumer {consumer_name}") + @retry_on_connection_error(max_retries=2, backoff_seconds=0.5) async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool: """ Publish a task to it's job queue. + Automatically retries on connection errors with exponential backoff. + Args: job_id: The job ID (integer primary key) data: PipelineProcessingTask object to be published Returns: bool: True if successful, False otherwise + + Raises: + Connection errors are retried by decorator, other errors are raised """ - if self.js is None: - raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + _, js = await self._get_connection() - try: - # Ensure stream and consumer exist - await self._ensure_stream(job_id) - await self._ensure_consumer(job_id) + # Ensure stream and consumer exist + await self._ensure_stream(job_id) + await self._ensure_consumer(job_id) - subject = self._get_subject(job_id) - # Convert Pydantic model to JSON - task_data = json.dumps(data.dict()) + subject = self._get_subject(job_id) + # Convert Pydantic model to JSON + task_data = json.dumps(data.dict()) - # Publish to JetStream - ack = await self.js.publish(subject, task_data.encode()) + # Publish to JetStream + # Note: JetStream publish() waits for PubAck, so it's implicitly flushed + ack = await js.publish(subject, task_data.encode()) - logger.info(f"Published task to stream for job '{job_id}', sequence {ack.seq}") - return True - - except Exception as e: - logger.error(f"Failed to publish task to stream for job '{job_id}': {e}") - return False + logger.info(f"Published task to stream for job '{job_id}', sequence {ack.seq}") + return True + @retry_on_connection_error(max_retries=2, backoff_seconds=0.5) async def reserve_task(self, job_id: int, timeout: float | None = None) -> PipelineProcessingTask | None: """ Reserve a task from the specified stream. + Automatically retries on connection errors with exponential backoff. + Note: TimeoutError from fetch() (no messages) is NOT retried - only connection errors. 
+ Args: job_id: The job ID (integer primary key) to pull tasks from timeout: Timeout in seconds for reservation (default: 5 seconds) @@ -167,134 +224,156 @@ async def reserve_task(self, job_id: int, timeout: float | None = None) -> Pipel Returns: PipelineProcessingTask with reply_subject set for acknowledgment, or None if no task available """ - if self.js is None: - raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + _, js = await self._get_connection() if timeout is None: timeout = 5 - try: - # Ensure stream and consumer exist - await self._ensure_stream(job_id) - await self._ensure_consumer(job_id) + # Ensure stream and consumer exist (let connection errors escape for retry) + await self._ensure_stream(job_id) + await self._ensure_consumer(job_id) - consumer_name = self._get_consumer_name(job_id) - subject = self._get_subject(job_id) - - # Create ephemeral subscription for this pull - psub = await self.js.pull_subscribe(subject, consumer_name) + consumer_name = self._get_consumer_name(job_id) + subject = self._get_subject(job_id) - try: - # Fetch a single message - msgs = await psub.fetch(1, timeout=timeout) + # Create ephemeral subscription for this pull + psub = await js.pull_subscribe(subject, consumer_name) - if msgs: - msg = msgs[0] - task_data = json.loads(msg.data.decode()) - metadata = msg.metadata + try: + # Fetch a single message + msgs = await psub.fetch(1, timeout=timeout) - # Parse the task data into PipelineProcessingTask - task = PipelineProcessingTask(**task_data) - # Set the reply_subject for acknowledgment - task.reply_subject = msg.reply + if msgs: + msg = msgs[0] + task_data = json.loads(msg.data.decode()) + metadata = msg.metadata - logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}") - return task + # Parse the task data into PipelineProcessingTask + task = PipelineProcessingTask(**task_data) + # Set the reply_subject for acknowledgment + task.reply_subject = msg.reply - except nats.errors.TimeoutError: - # No messages available - logger.debug(f"No tasks available in stream for job '{job_id}'") - return None - finally: - # Always unsubscribe - await psub.unsubscribe() + logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}") + return task - except Exception as e: - logger.error(f"Failed to reserve task from stream for job '{job_id}': {e}") + except nats_errors.TimeoutError: + # No messages available (expected behavior) + logger.debug(f"No tasks available in stream for job '{job_id}'") return None + finally: + # Always unsubscribe + await psub.unsubscribe() + @retry_on_connection_error(max_retries=2, backoff_seconds=0.5) async def acknowledge_task(self, reply_subject: str) -> bool: """ Acknowledge (delete) a completed task using its reply subject. + Automatically retries on connection errors with exponential backoff. + Uses a lock to serialize ACK operations to prevent concurrent access issues. + Args: reply_subject: The reply subject from reserve_task Returns: bool: True if successful + + Raises: + Connection errors are retried by decorator, other errors are logged """ - if self.nc is None: - raise RuntimeError("Connection is not open. 
Use TaskQueueManager as an async context manager.") + nc, _ = await self._get_connection() + + # Don't catch connection errors - let retry decorator handle them + await nc.publish(reply_subject, b"+ACK") + # CRITICAL: Flush to ensure ACK is sent immediately + # Without flush, ACKs may be buffered and not sent to NATS server try: - await self.nc.publish(reply_subject, b"+ACK") - logger.debug(f"Acknowledged task with reply subject {reply_subject}") - return True - except Exception as e: - logger.error(f"Failed to acknowledge task: {e}") - return False + await nc.flush(timeout=2) + except asyncio.TimeoutError as e: + # Flush timeout likely means connection is stale - re-raise to trigger retry + logger.warning(f"Flush timeout for ACK {reply_subject}, connection may be stale: {e}") + raise nats_errors.TimeoutError("Flush timeout") from e + + logger.debug(f"Acknowledged task with reply subject {reply_subject}") + return True + @retry_on_connection_error(max_retries=2, backoff_seconds=0.5) async def delete_consumer(self, job_id: int) -> bool: """ Delete the consumer for a job. + Automatically retries on connection errors with exponential backoff. + Args: job_id: The job ID (integer primary key) Returns: - bool: True if successful, False otherwise + bool: True if successful + + Raises: + Connection errors are retried by decorator, other errors are raised """ - if self.js is None: - raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + _, js = await self._get_connection() - try: - stream_name = self._get_stream_name(job_id) - consumer_name = self._get_consumer_name(job_id) + stream_name = self._get_stream_name(job_id) + consumer_name = self._get_consumer_name(job_id) - await self.js.delete_consumer(stream_name, consumer_name) - logger.info(f"Deleted consumer {consumer_name} for job '{job_id}'") - return True - except Exception as e: - logger.error(f"Failed to delete consumer for job '{job_id}': {e}") - return False + await js.delete_consumer(stream_name, consumer_name) + logger.info(f"Deleted consumer {consumer_name} for job '{job_id}'") + return True + @retry_on_connection_error(max_retries=2, backoff_seconds=0.5) async def delete_stream(self, job_id: int) -> bool: """ Delete the stream for a job. + Automatically retries on connection errors with exponential backoff. + Args: job_id: The job ID (integer primary key) Returns: - bool: True if successful, False otherwise + bool: True if successful + + Raises: + Connection errors are retried by decorator, other errors are raised """ - if self.js is None: - raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + _, js = await self._get_connection() - try: - stream_name = self._get_stream_name(job_id) + stream_name = self._get_stream_name(job_id) - await self.js.delete_stream(stream_name) - logger.info(f"Deleted stream {stream_name} for job '{job_id}'") - return True - except Exception as e: - logger.error(f"Failed to delete stream for job '{job_id}': {e}") - return False + await js.delete_stream(stream_name) + logger.info(f"Deleted stream {stream_name} for job'{job_id}'") + return True async def cleanup_job_resources(self, job_id: int) -> bool: """ Clean up all NATS resources (consumer and stream) for a job. This should be called when a job completes or is cancelled. + Best-effort cleanup - logs errors but doesn't fail if cleanup fails. 
Args: job_id: The job ID (integer primary key) Returns: - bool: True if successful, False otherwise + bool: True if both cleanup operations succeeded, False otherwise """ - # Delete consumer first, then stream - consumer_deleted = await self.delete_consumer(job_id) - stream_deleted = await self.delete_stream(job_id) + consumer_deleted = False + stream_deleted = False + + # Delete consumer first, then stream (best-effort) + try: + await self.delete_consumer(job_id) + consumer_deleted = True + except Exception as e: + logger.warning(f"Failed to delete consumer for job {job_id} after retries: {e}") + + try: + await self.delete_stream(job_id) + stream_deleted = True + except Exception as e: + logger.warning(f"Failed to delete stream for job {job_id} after retries: {e}") return consumer_deleted and stream_deleted From 227a8dbefd560e6fa49baed9674a314e54f333a5 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 11 Feb 2026 10:10:40 -0800 Subject: [PATCH 03/16] Refactor and fix nats tests --- ami/ml/orchestration/tests/test_nats_queue.py | 131 ++++++++---------- 1 file changed, 56 insertions(+), 75 deletions(-) diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py index 0cd2c3bef..def90ba68 100644 --- a/ami/ml/orchestration/tests/test_nats_queue.py +++ b/ami/ml/orchestration/tests/test_nats_queue.py @@ -1,6 +1,7 @@ """Unit tests for TaskQueueManager.""" import unittest +from contextlib import contextmanager from unittest.mock import AsyncMock, MagicMock, patch from ami.ml.orchestration.nats_queue import TaskQueueManager @@ -18,8 +19,13 @@ def _create_sample_task(self): image_url="https://example.com/image.jpg", ) - def _create_mock_nats_connection(self): - """Helper to create mock NATS connection and JetStream context.""" + @contextmanager + def _mock_nats_setup(self): + """Helper to create and mock NATS connection with connection pool. 
+ + Yields: + tuple: (nc, js, mock_pool) - NATS connection, JetStream context, and mock pool + """ nc = MagicMock() nc.is_closed = False nc.close = AsyncMock() @@ -34,37 +40,29 @@ def _create_mock_nats_connection(self): js.delete_consumer = AsyncMock() js.delete_stream = AsyncMock() - return nc, js - - async def test_context_manager_lifecycle(self): - """Test that context manager properly opens and closes connections.""" - nc, js = self._create_mock_nats_connection() - - with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): - async with TaskQueueManager("nats://test:4222") as manager: - self.assertIsNotNone(manager.nc) - self.assertIsNotNone(manager.js) - - nc.close.assert_called_once() + with patch("ami.ml.orchestration.nats_connection_pool.get_pool") as mock_get_pool: + mock_pool = MagicMock() + mock_pool.get_connection = AsyncMock(return_value=(nc, js)) + mock_get_pool.return_value = mock_pool + yield nc, js, mock_pool async def test_publish_task_creates_stream_and_consumer(self): """Test that publish_task ensures stream and consumer exist.""" - nc, js = self._create_mock_nats_connection() sample_task = self._create_sample_task() - js.stream_info.side_effect = Exception("Not found") - js.consumer_info.side_effect = Exception("Not found") - with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): - async with TaskQueueManager() as manager: - await manager.publish_task(456, sample_task) + with self._mock_nats_setup() as (_, js, _): + js.stream_info.side_effect = Exception("Not found") + js.consumer_info.side_effect = Exception("Not found") - js.add_stream.assert_called_once() - self.assertIn("job_456", str(js.add_stream.call_args)) - js.add_consumer.assert_called_once() + manager = TaskQueueManager() + await manager.publish_task(456, sample_task) + + js.add_stream.assert_called_once() + self.assertIn("job_456", str(js.add_stream.call_args)) + js.add_consumer.assert_called_once() async def test_reserve_task_success(self): """Test successful task reservation.""" - nc, js = self._create_mock_nats_connection() sample_task = self._create_sample_task() # Mock message with task data @@ -73,59 +71,56 @@ async def test_reserve_task_success(self): mock_msg.reply = "reply.subject.123" mock_msg.metadata = MagicMock(sequence=MagicMock(stream=1)) - mock_psub = MagicMock() - mock_psub.fetch = AsyncMock(return_value=[mock_msg]) - mock_psub.unsubscribe = AsyncMock() - js.pull_subscribe = AsyncMock(return_value=mock_psub) + with self._mock_nats_setup() as (_, js, _): + mock_psub = MagicMock() + mock_psub.fetch = AsyncMock(return_value=[mock_msg]) + mock_psub.unsubscribe = AsyncMock() + js.pull_subscribe = AsyncMock(return_value=mock_psub) - with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): - async with TaskQueueManager() as manager: - task = await manager.reserve_task(123) + manager = TaskQueueManager() + task = await manager.reserve_task(123) - self.assertIsNotNone(task) - self.assertEqual(task.id, sample_task.id) - self.assertEqual(task.reply_subject, "reply.subject.123") - mock_psub.unsubscribe.assert_called_once() + self.assertIsNotNone(task) + self.assertEqual(task.id, sample_task.id) + self.assertEqual(task.reply_subject, "reply.subject.123") + mock_psub.unsubscribe.assert_called_once() async def test_reserve_task_no_messages(self): """Test reserve_task when no messages are available.""" - nc, js = self._create_mock_nats_connection() - - mock_psub = MagicMock() - 
mock_psub.fetch = AsyncMock(return_value=[]) - mock_psub.unsubscribe = AsyncMock() - js.pull_subscribe = AsyncMock(return_value=mock_psub) + with self._mock_nats_setup() as (_, js, _): + mock_psub = MagicMock() + mock_psub.fetch = AsyncMock(return_value=[]) + mock_psub.unsubscribe = AsyncMock() + js.pull_subscribe = AsyncMock(return_value=mock_psub) - with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): - async with TaskQueueManager() as manager: - task = await manager.reserve_task(123) + manager = TaskQueueManager() + task = await manager.reserve_task(123) - self.assertIsNone(task) - mock_psub.unsubscribe.assert_called_once() + self.assertIsNone(task) + mock_psub.unsubscribe.assert_called_once() async def test_acknowledge_task_success(self): """Test successful task acknowledgment.""" - nc, js = self._create_mock_nats_connection() - nc.publish = AsyncMock() + with self._mock_nats_setup() as (nc, _, _): + nc.publish = AsyncMock() + nc.flush = AsyncMock() - with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): - async with TaskQueueManager() as manager: - result = await manager.acknowledge_task("reply.subject.123") + manager = TaskQueueManager() + result = await manager.acknowledge_task("reply.subject.123") - self.assertTrue(result) - nc.publish.assert_called_once_with("reply.subject.123", b"+ACK") + self.assertTrue(result) + nc.publish.assert_called_once_with("reply.subject.123", b"+ACK") + nc.flush.assert_called_once() async def test_cleanup_job_resources(self): """Test cleanup of job resources (consumer and stream).""" - nc, js = self._create_mock_nats_connection() - - with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): - async with TaskQueueManager() as manager: - result = await manager.cleanup_job_resources(123) + with self._mock_nats_setup() as (_, js, _): + manager = TaskQueueManager() + result = await manager.cleanup_job_resources(123) - self.assertTrue(result) - js.delete_consumer.assert_called_once() - js.delete_stream.assert_called_once() + self.assertTrue(result) + js.delete_consumer.assert_called_once() + js.delete_stream.assert_called_once() async def test_naming_conventions(self): """Test stream, subject, and consumer naming conventions.""" @@ -134,17 +129,3 @@ async def test_naming_conventions(self): self.assertEqual(manager._get_stream_name(123), "job_123") self.assertEqual(manager._get_subject(123), "job.123.tasks") self.assertEqual(manager._get_consumer_name(123), "job-123-consumer") - - async def test_operations_without_connection_raise_error(self): - """Test that operations without connection raise RuntimeError.""" - manager = TaskQueueManager() - sample_task = self._create_sample_task() - - with self.assertRaisesRegex(RuntimeError, "Connection is not open"): - await manager.publish_task(123, sample_task) - - with self.assertRaisesRegex(RuntimeError, "Connection is not open"): - await manager.reserve_task(123) - - with self.assertRaisesRegex(RuntimeError, "Connection is not open"): - await manager.delete_stream(123) From 2acf62016746571a77413abd49029a146dc7399e Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 11 Feb 2026 10:19:42 -0800 Subject: [PATCH 04/16] Tighten formatting --- ami/ml/orchestration/nats_queue.py | 30 +----------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 4647892e5..7687264c1 100644 --- 
a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -38,17 +38,13 @@ def retry_on_connection_error(max_retries: int = 2, backoff_seconds: float = 0.5): """ - Decorator that retries NATS operations on connection errors. - - When a connection error is detected: + Decorator that retries NATS operations on connection errors. When a connection error is detected: 1. Resets the connection pool (clears stale connection) 2. Waits with exponential backoff 3. Retries the operation (which will get a fresh connection) - Args: max_retries: Maximum number of retry attempts (default: 2) backoff_seconds: Initial backoff time in seconds (default: 0.5) - Returns: Decorated async function with retry logic """ @@ -103,7 +99,6 @@ async def wrapper(self, *args, **kwargs) -> T: class TaskQueueManager: """ Manager for NATS JetStream task queue operations. - Always uses the process-local connection pool for efficiency. Note: The connection pool is shared across all instances in the same process, so there's no overhead creating multiple TaskQueueManager instances. @@ -179,16 +174,12 @@ async def _ensure_consumer(self, job_id: int): async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool: """ Publish a task to it's job queue. - Automatically retries on connection errors with exponential backoff. - Args: job_id: The job ID (integer primary key) data: PipelineProcessingTask object to be published - Returns: bool: True if successful, False otherwise - Raises: Connection errors are retried by decorator, other errors are raised """ @@ -213,14 +204,11 @@ async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool: async def reserve_task(self, job_id: int, timeout: float | None = None) -> PipelineProcessingTask | None: """ Reserve a task from the specified stream. - Automatically retries on connection errors with exponential backoff. Note: TimeoutError from fetch() (no messages) is NOT retried - only connection errors. - Args: job_id: The job ID (integer primary key) to pull tasks from timeout: Timeout in seconds for reservation (default: 5 seconds) - Returns: PipelineProcessingTask with reply_subject set for acknowledgment, or None if no task available """ @@ -268,16 +256,11 @@ async def reserve_task(self, job_id: int, timeout: float | None = None) -> Pipel async def acknowledge_task(self, reply_subject: str) -> bool: """ Acknowledge (delete) a completed task using its reply subject. - Automatically retries on connection errors with exponential backoff. - Uses a lock to serialize ACK operations to prevent concurrent access issues. - Args: reply_subject: The reply subject from reserve_task - Returns: bool: True if successful - Raises: Connection errors are retried by decorator, other errors are logged """ @@ -302,15 +285,11 @@ async def acknowledge_task(self, reply_subject: str) -> bool: async def delete_consumer(self, job_id: int) -> bool: """ Delete the consumer for a job. - Automatically retries on connection errors with exponential backoff. - Args: job_id: The job ID (integer primary key) - Returns: bool: True if successful - Raises: Connection errors are retried by decorator, other errors are raised """ @@ -327,15 +306,11 @@ async def delete_consumer(self, job_id: int) -> bool: async def delete_stream(self, job_id: int) -> bool: """ Delete the stream for a job. - Automatically retries on connection errors with exponential backoff. 
- Args: job_id: The job ID (integer primary key) - Returns: bool: True if successful - Raises: Connection errors are retried by decorator, other errors are raised """ @@ -350,13 +325,10 @@ async def delete_stream(self, job_id: int) -> bool: async def cleanup_job_resources(self, job_id: int) -> bool: """ Clean up all NATS resources (consumer and stream) for a job. - This should be called when a job completes or is cancelled. Best-effort cleanup - logs errors but doesn't fail if cleanup fails. - Args: job_id: The job ID (integer primary key) - Returns: bool: True if both cleanup operations succeeded, False otherwise """ From 0632ce0b8d9af682a857ee116ce1014d1bc8d9b9 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 11 Feb 2026 10:20:55 -0800 Subject: [PATCH 05/16] format --- ami/ml/orchestration/nats_queue.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 7687264c1..8edccd772 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -65,7 +65,6 @@ async def wrapper(self, *args, **kwargs) -> T: OSError, # Network errors ) as e: last_error = e - # Don't retry on last attempt if attempt == max_retries: logger.error( @@ -73,13 +72,11 @@ async def wrapper(self, *args, **kwargs) -> T: exc_info=True, ) break - # Reset the connection pool so next attempt gets a fresh connection from ami.ml.orchestration.nats_connection_pool import get_pool pool = get_pool() pool.reset() - # Exponential backoff wait_time = backoff_seconds * (2**attempt) logger.warning( @@ -87,7 +84,6 @@ async def wrapper(self, *args, **kwargs) -> T: f"Retrying in {wait_time}s..." ) await asyncio.sleep(wait_time) - # If we exhausted retries, raise the last error raise last_error # type: ignore From c5f81060426e170708e5f8758658f6190455b438 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 11 Feb 2026 11:58:09 -0800 Subject: [PATCH 06/16] CR feedback --- ami/ml/orchestration/nats_connection_pool.py | 83 +++++++++++++++----- ami/ml/orchestration/nats_queue.py | 25 ++++-- ami/ml/orchestration/tests/test_cleanup.py | 82 ++++++++++--------- 3 files changed, 123 insertions(+), 67 deletions(-) diff --git a/ami/ml/orchestration/nats_connection_pool.py b/ami/ml/orchestration/nats_connection_pool.py index 5769fda19..a8352d156 100644 --- a/ami/ml/orchestration/nats_connection_pool.py +++ b/ami/ml/orchestration/nats_connection_pool.py @@ -1,16 +1,18 @@ """ NATS connection pool for both Celery workers and Django processes. -Maintains a persistent NATS connection per process to avoid +Maintains a persistent NATS connection per event loop to avoid the overhead of creating/closing connections for every operation. -The connection pool is lazily initialized on first use and shared -across all operations in the same process. +The connection pool is lazily initialized on first use and keyed by event loop +to prevent "attached to a different loop" errors when using async_to_sync(). """ import asyncio import logging +import threading from typing import TYPE_CHECKING +from weakref import WeakKeyDictionary import nats from django.conf import settings @@ -24,23 +26,30 @@ class ConnectionPool: """ - Manages a single NATS connection per process (Celery worker or Django web worker). + Manages a single NATS connection per event loop. 
This is safe because: - - Each process gets its own isolated connection - - NATS connections are async-safe (can be used by multiple coroutines) - - Works in both Celery prefork and Django WSGI/ASGI contexts + - asyncio.Lock and NATS Client are bound to the event loop they were created on + - Each event loop gets its own isolated connection and lock + - Works correctly with async_to_sync() which creates per-thread event loops + - Prevents "attached to a different loop" errors in Celery tasks and Django views """ def __init__(self): self._nc: "NATSClient | None" = None self._js: JetStreamContext | None = None self._nats_url: str | None = None - self._lock = asyncio.Lock() + self._lock: asyncio.Lock | None = None # Lazy-initialized when needed + + def _ensure_lock(self) -> asyncio.Lock: + """Lazily create lock bound to current event loop.""" + if self._lock is None: + self._lock = asyncio.Lock() + return self._lock async def get_connection(self) -> tuple["NATSClient", JetStreamContext]: """ - Get or create the worker's NATS connection. Checks connection health and recreates if stale. + Get or create the event loop's NATS connection. Checks connection health and recreates if stale. Returns: Tuple of (NATS connection, JetStream context) @@ -58,7 +67,8 @@ async def get_connection(self) -> tuple["NATSClient", JetStreamContext]: self._js = None # Slow path: need to create/recreate connection - async with self._lock: + lock = self._ensure_lock() + async with lock: # Double-check after acquiring lock if self._nc is not None and not self._nc.is_closed and self._nc.is_connected: return self._nc, self._js # type: ignore @@ -85,28 +95,59 @@ async def close(self): self._nc = None self._js = None - def reset(self): + async def reset(self): """ - Reset the connection pool (mark connection as stale). + Async version of reset that properly closes the connection before clearing references. - This should be called when a connection error is detected. + This should be called when a connection error is detected from an async context. The next call to get_connection() will create a fresh connection. """ logger.warning("Resetting NATS connection pool due to connection error") + if self._nc is not None: + try: + # Attempt to close the connection gracefully + if not self._nc.is_closed: + await self._nc.close() + logger.debug("Successfully closed existing NATS connection during reset") + except Exception as e: + # Swallow errors - connection may already be broken + logger.debug(f"Error closing connection during reset (expected): {e}") self._nc = None self._js = None + self._lock = None # Clear lock so new one is created for fresh connection -# Global pool instance - one per process (Celery worker or Django process) -_connection_pool: ConnectionPool | None = None +# Event-loop-keyed pools: one ConnectionPool per event loop +# WeakKeyDictionary automatically cleans up when event loops are garbage collected +_pools: WeakKeyDictionary[asyncio.AbstractEventLoop, ConnectionPool] = WeakKeyDictionary() +_pools_lock = threading.Lock() def get_pool() -> ConnectionPool: """ - Get the process-local connection pool. + Get or create the connection pool for the current event loop. + + Each event loop gets its own ConnectionPool to prevent "attached to a different loop" errors. + This is critical when using async_to_sync() in Celery tasks or Django views, as each call + may run on a different event loop. 
+ + Returns: + ConnectionPool bound to the current event loop + + Raises: + RuntimeError: If called outside of an async context (no running event loop) """ - global _connection_pool - if _connection_pool is None: - _connection_pool = ConnectionPool() - logger.debug("Lazily initialized NATS connection pool") - return _connection_pool + try: + loop = asyncio.get_running_loop() + except RuntimeError: + raise RuntimeError( + "get_pool() must be called from an async context with a running event loop. " + "If calling from sync code, use async_to_sync() to wrap the async function." + ) + + # Thread-safe lookup/creation of pool for this event loop + with _pools_lock: + if loop not in _pools: + _pools[loop] = ConnectionPool() + logger.debug(f"Created NATS connection pool for event loop {id(loop)}") + return _pools[loop] diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 8edccd772..fddd4ec9f 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -39,9 +39,13 @@ def retry_on_connection_error(max_retries: int = 2, backoff_seconds: float = 0.5): """ Decorator that retries NATS operations on connection errors. When a connection error is detected: - 1. Resets the connection pool (clears stale connection) + 1. Resets the event-loop-local connection pool (clears stale connection and lock) 2. Waits with exponential backoff - 3. Retries the operation (which will get a fresh connection) + 3. Retries the operation (which will get a fresh connection from the same event loop) + + This works correctly with async_to_sync() because the pool is keyed by event loop, + ensuring each retry uses the connection bound to the current loop. + Args: max_retries: Maximum number of retry attempts (default: 2) backoff_seconds: Initial backoff time in seconds (default: 0.5) @@ -53,6 +57,7 @@ def decorator(func: Callable[..., "Awaitable[T]"]) -> Callable[..., "Awaitable[T @functools.wraps(func) async def wrapper(self, *args, **kwargs) -> T: last_error = None + assert max_retries >= 0, "max_retries must be non-negative" for attempt in range(max_retries + 1): try: @@ -76,7 +81,7 @@ async def wrapper(self, *args, **kwargs) -> T: from ami.ml.orchestration.nats_connection_pool import get_pool pool = get_pool() - pool.reset() + await pool.reset() # Exponential backoff wait_time = backoff_seconds * (2**attempt) logger.warning( @@ -84,7 +89,8 @@ async def wrapper(self, *args, **kwargs) -> T: f"Retrying in {wait_time}s..." ) await asyncio.sleep(wait_time) - # If we exhausted retries, raise the last error + # If we exhausted retries, raise the last error, guaranteed to be not None here + assert last_error is not None, "last_error should not be None if we exhausted retries" raise last_error # type: ignore return wrapper @@ -95,13 +101,16 @@ async def wrapper(self, *args, **kwargs) -> T: class TaskQueueManager: """ Manager for NATS JetStream task queue operations. - Always uses the process-local connection pool for efficiency. - Note: The connection pool is shared across all instances in the same process, - so there's no overhead creating multiple TaskQueueManager instances. + Always uses the event-loop-local connection pool for efficiency. + + Note: The connection pool is keyed by event loop, so each event loop gets its own + connection. This prevents "attached to a different loop" errors when using async_to_sync() + in Celery tasks or Django views. There's no overhead creating multiple TaskQueueManager + instances as they all share the same event-loop-keyed pool. 
""" async def _get_connection(self) -> tuple["NATSClient", JetStreamContext]: - """Get connection from the process-local pool.""" + """Get connection from the event-loop-local pool.""" from ami.ml.orchestration.nats_connection_pool import get_pool pool = get_pool() diff --git a/ami/ml/orchestration/tests/test_cleanup.py b/ami/ml/orchestration/tests/test_cleanup.py index ef8382d3d..a5fc6fcb7 100644 --- a/ami/ml/orchestration/tests/test_cleanup.py +++ b/ami/ml/orchestration/tests/test_cleanup.py @@ -68,25 +68,28 @@ def _verify_resources_created(self, job_id: int): # Verify NATS stream and consumer exist async def check_nats_resources(): - async with TaskQueueManager() as manager: - stream_name = manager._get_stream_name(job_id) - consumer_name = manager._get_consumer_name(job_id) - - # Try to get stream info - should succeed if created - stream_exists = True - try: - await manager.js.stream_info(stream_name) - except NotFoundError: - stream_exists = False - - # Try to get consumer info - should succeed if created - consumer_exists = True - try: - await manager.js.consumer_info(stream_name, consumer_name) - except NotFoundError: - consumer_exists = False - - return stream_exists, consumer_exists + manager = TaskQueueManager() + stream_name = manager._get_stream_name(job_id) + consumer_name = manager._get_consumer_name(job_id) + + # Get JetStream context + _, js = await manager._get_connection() + + # Try to get stream info - should succeed if created + stream_exists = True + try: + await js.stream_info(stream_name) + except NotFoundError: + stream_exists = False + + # Try to get consumer info - should succeed if created + consumer_exists = True + try: + await js.consumer_info(stream_name, consumer_name) + except NotFoundError: + consumer_exists = False + + return stream_exists, consumer_exists stream_exists, consumer_exists = async_to_sync(check_nats_resources)() @@ -134,25 +137,28 @@ def _verify_resources_cleaned(self, job_id: int): # Verify NATS stream and consumer are deleted async def check_nats_resources(): - async with TaskQueueManager() as manager: - stream_name = manager._get_stream_name(job_id) - consumer_name = manager._get_consumer_name(job_id) - - # Try to get stream info - should fail if deleted - stream_exists = True - try: - await manager.js.stream_info(stream_name) - except NotFoundError: - stream_exists = False - - # Try to get consumer info - should fail if deleted - consumer_exists = True - try: - await manager.js.consumer_info(stream_name, consumer_name) - except NotFoundError: - consumer_exists = False - - return stream_exists, consumer_exists + manager = TaskQueueManager() + stream_name = manager._get_stream_name(job_id) + consumer_name = manager._get_consumer_name(job_id) + + # Get JetStream context + _, js = await manager._get_connection() + + # Try to get stream info - should fail if deleted + stream_exists = True + try: + await js.stream_info(stream_name) + except NotFoundError: + stream_exists = False + + # Try to get consumer info - should fail if deleted + consumer_exists = True + try: + await js.consumer_info(stream_name, consumer_name) + except NotFoundError: + consumer_exists = False + + return stream_exists, consumer_exists stream_exists, consumer_exists = async_to_sync(check_nats_resources)() From 8805dbe0b160a4315479c51a189ff21894d52bae Mon Sep 17 00:00:00 2001 From: carlosgjs Date: Wed, 11 Feb 2026 12:01:21 -0800 Subject: [PATCH 07/16] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- 
ami/ml/orchestration/nats_queue.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index fddd4ec9f..974b12b8a 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -178,13 +178,13 @@ async def _ensure_consumer(self, job_id: int): @retry_on_connection_error(max_retries=2, backoff_seconds=0.5) async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool: """ - Publish a task to it's job queue. + Publish a task to its job queue. Automatically retries on connection errors with exponential backoff. Args: job_id: The job ID (integer primary key) data: PipelineProcessingTask object to be published Returns: - bool: True if successful, False otherwise + bool: True if successful Raises: Connection errors are retried by decorator, other errors are raised """ @@ -267,7 +267,7 @@ async def acknowledge_task(self, reply_subject: str) -> bool: Returns: bool: True if successful Raises: - Connection errors are retried by decorator, other errors are logged + Connection errors are retried by decorator, other errors are raised """ nc, _ = await self._get_connection() @@ -324,7 +324,7 @@ async def delete_stream(self, job_id: int) -> bool: stream_name = self._get_stream_name(job_id) await js.delete_stream(stream_name) - logger.info(f"Deleted stream {stream_name} for job'{job_id}'") + logger.info(f"Deleted stream {stream_name} for job '{job_id}'") return True async def cleanup_job_resources(self, job_id: int) -> bool: From c384199f20b234652752446ac6acb7a92fcbbb4d Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 12 Feb 2026 15:55:01 -0800 Subject: [PATCH 08/16] =?UTF-8?q?refactor:=20simplify=20NATS=20connection?= =?UTF-8?q?=20handling=20=E2=80=94=20keep=20retry=20decorator,=20drop=20po?= =?UTF-8?q?ol?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the event-loop-keyed WeakKeyDictionary connection pool with a straightforward async context manager on TaskQueueManager. Each async_to_sync() call now scopes one connection for its block of operations (e.g. queue_all_images reuses one connection for all publishes, _ack_task_via_nats gets one for the single ACK). The retry decorator is preserved — on connection error it closes the stale connection so the next _get_connection() creates a fresh one. Also adds reconnected_cb/disconnected_cb logging callbacks to nats.connect() and narrows bare except clauses to NotFoundError. 
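For reference, a minimal sketch of the connection setup described above (illustrative only, not taken from the diff below; the helper and callback names are placeholders, though disconnected_cb/reconnected_cb are standard nats-py connect() keyword arguments):

    import logging
    import nats

    logger = logging.getLogger(__name__)

    async def _connect(nats_url: str):
        # Log connection state changes so dropped/restored connections are visible
        async def on_disconnect():
            logger.warning("NATS connection lost")

        async def on_reconnect():
            logger.info("NATS connection re-established")

        nc = await nats.connect(
            nats_url,
            disconnected_cb=on_disconnect,
            reconnected_cb=on_reconnect,
        )
        return nc, nc.jetstream()

Callers then scope one connection per async_to_sync() block, as in the updated tasks.py:

    async with TaskQueueManager() as manager:
        return await manager.acknowledge_task(reply_subject)
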
Co-Authored-By: Claude --- ami/jobs/tasks.py | 4 +- ami/jobs/views.py | 10 +- ami/ml/orchestration/jobs.py | 36 ++--- ami/ml/orchestration/nats_connection_pool.py | 153 ------------------ ami/ml/orchestration/nats_queue.py | 95 +++++++---- ami/ml/orchestration/tests/test_cleanup.py | 80 ++++----- ami/ml/orchestration/tests/test_nats_queue.py | 57 +++---- 7 files changed, 162 insertions(+), 273 deletions(-) delete mode 100644 ami/ml/orchestration/nats_connection_pool.py diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 94a6a821a..3548d0ea5 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -135,8 +135,8 @@ def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None: try: async def ack_task(): - manager = TaskQueueManager() - return await manager.acknowledge_task(reply_subject) + async with TaskQueueManager() as manager: + return await manager.acknowledge_task(reply_subject) ack_success = async_to_sync(ack_task)() diff --git a/ami/jobs/views.py b/ami/jobs/views.py index f45361007..dd8da01b2 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -242,11 +242,11 @@ def tasks(self, request, pk=None): async def get_tasks(): tasks = [] - manager = TaskQueueManager() - for _ in range(batch): - task = await manager.reserve_task(job.pk, timeout=0.1) - if task: - tasks.append(task.dict()) + async with TaskQueueManager() as manager: + for _ in range(batch): + task = await manager.reserve_task(job.pk, timeout=0.1) + if task: + tasks.append(task.dict()) return tasks # Use async_to_sync to properly handle the async call diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index b4b9171b7..9b4a577e9 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -40,8 +40,8 @@ def cleanup_async_job_resources(job: "Job") -> bool: # Cleanup NATS resources async def cleanup(): - manager = TaskQueueManager() - return await manager.cleanup_job_resources(job.pk) + async with TaskQueueManager() as manager: + return await manager.cleanup_job_resources(job.pk) try: nats_success = async_to_sync(cleanup)() @@ -96,22 +96,22 @@ async def queue_all_images(): successful_queues = 0 failed_queues = 0 - manager = TaskQueueManager() - for image_pk, task in tasks: - try: - logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}") - success = await manager.publish_task( - job_id=job.pk, - data=task, - ) - except Exception as e: - logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}") - success = False - - if success: - successful_queues += 1 - else: - failed_queues += 1 + async with TaskQueueManager() as manager: + for image_pk, task in tasks: + try: + logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}") + success = await manager.publish_task( + job_id=job.pk, + data=task, + ) + except Exception as e: + logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}") + success = False + + if success: + successful_queues += 1 + else: + failed_queues += 1 return successful_queues, failed_queues diff --git a/ami/ml/orchestration/nats_connection_pool.py b/ami/ml/orchestration/nats_connection_pool.py deleted file mode 100644 index a8352d156..000000000 --- a/ami/ml/orchestration/nats_connection_pool.py +++ /dev/null @@ -1,153 +0,0 @@ -""" -NATS connection pool for both Celery workers and Django processes. - -Maintains a persistent NATS connection per event loop to avoid -the overhead of creating/closing connections for every operation. 
- -The connection pool is lazily initialized on first use and keyed by event loop -to prevent "attached to a different loop" errors when using async_to_sync(). -""" - -import asyncio -import logging -import threading -from typing import TYPE_CHECKING -from weakref import WeakKeyDictionary - -import nats -from django.conf import settings -from nats.js import JetStreamContext - -if TYPE_CHECKING: - from nats.aio.client import Client as NATSClient - -logger = logging.getLogger(__name__) - - -class ConnectionPool: - """ - Manages a single NATS connection per event loop. - - This is safe because: - - asyncio.Lock and NATS Client are bound to the event loop they were created on - - Each event loop gets its own isolated connection and lock - - Works correctly with async_to_sync() which creates per-thread event loops - - Prevents "attached to a different loop" errors in Celery tasks and Django views - """ - - def __init__(self): - self._nc: "NATSClient | None" = None - self._js: JetStreamContext | None = None - self._nats_url: str | None = None - self._lock: asyncio.Lock | None = None # Lazy-initialized when needed - - def _ensure_lock(self) -> asyncio.Lock: - """Lazily create lock bound to current event loop.""" - if self._lock is None: - self._lock = asyncio.Lock() - return self._lock - - async def get_connection(self) -> tuple["NATSClient", JetStreamContext]: - """ - Get or create the event loop's NATS connection. Checks connection health and recreates if stale. - - Returns: - Tuple of (NATS connection, JetStream context) - Raises: - RuntimeError: If connection cannot be established - """ - # Fast path: connection exists, is open, and is connected - if self._nc is not None and not self._nc.is_closed and self._nc.is_connected: - return self._nc, self._js # type: ignore - - # Connection is stale or doesn't exist - if self._nc is not None: - logger.warning("NATS connection is closed or disconnected, will reconnect") - self._nc = None - self._js = None - - # Slow path: need to create/recreate connection - lock = self._ensure_lock() - async with lock: - # Double-check after acquiring lock - if self._nc is not None and not self._nc.is_closed and self._nc.is_connected: - return self._nc, self._js # type: ignore - - # Get NATS URL from settings - if self._nats_url is None: - self._nats_url = getattr(settings, "NATS_URL", "nats://nats:4222") - - try: - logger.info(f"Creating NATS connection to {self._nats_url}") - self._nc = await nats.connect(self._nats_url) - self._js = self._nc.jetstream() - logger.info(f"Successfully connected to NATS at {self._nats_url}") - return self._nc, self._js - except Exception as e: - logger.error(f"Failed to connect to NATS: {e}") - raise RuntimeError(f"Could not establish NATS connection: {e}") from e - - async def close(self): - """Close the NATS connection if it exists.""" - if self._nc is not None and not self._nc.is_closed: - logger.info("Closing NATS connection") - await self._nc.close() - self._nc = None - self._js = None - - async def reset(self): - """ - Async version of reset that properly closes the connection before clearing references. - - This should be called when a connection error is detected from an async context. - The next call to get_connection() will create a fresh connection. 
- """ - logger.warning("Resetting NATS connection pool due to connection error") - if self._nc is not None: - try: - # Attempt to close the connection gracefully - if not self._nc.is_closed: - await self._nc.close() - logger.debug("Successfully closed existing NATS connection during reset") - except Exception as e: - # Swallow errors - connection may already be broken - logger.debug(f"Error closing connection during reset (expected): {e}") - self._nc = None - self._js = None - self._lock = None # Clear lock so new one is created for fresh connection - - -# Event-loop-keyed pools: one ConnectionPool per event loop -# WeakKeyDictionary automatically cleans up when event loops are garbage collected -_pools: WeakKeyDictionary[asyncio.AbstractEventLoop, ConnectionPool] = WeakKeyDictionary() -_pools_lock = threading.Lock() - - -def get_pool() -> ConnectionPool: - """ - Get or create the connection pool for the current event loop. - - Each event loop gets its own ConnectionPool to prevent "attached to a different loop" errors. - This is critical when using async_to_sync() in Celery tasks or Django views, as each call - may run on a different event loop. - - Returns: - ConnectionPool bound to the current event loop - - Raises: - RuntimeError: If called outside of an async context (no running event loop) - """ - try: - loop = asyncio.get_running_loop() - except RuntimeError: - raise RuntimeError( - "get_pool() must be called from an async context with a running event loop. " - "If calling from sync code, use async_to_sync() to wrap the async function." - ) - - # Thread-safe lookup/creation of pool for this event loop - with _pools_lock: - if loop not in _pools: - _pools[loop] = ConnectionPool() - logger.debug(f"Created NATS connection pool for event loop {id(loop)}") - return _pools[loop] diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 974b12b8a..582de0d55 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -17,9 +17,12 @@ from collections.abc import Callable from typing import TYPE_CHECKING, TypeVar +import nats +from django.conf import settings from nats import errors as nats_errors from nats.js import JetStreamContext from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy +from nats.js.errors import NotFoundError from ami.ml.schemas import PipelineProcessingTask @@ -38,13 +41,9 @@ def retry_on_connection_error(max_retries: int = 2, backoff_seconds: float = 0.5): """ - Decorator that retries NATS operations on connection errors. When a connection error is detected: - 1. Resets the event-loop-local connection pool (clears stale connection and lock) - 2. Waits with exponential backoff - 3. Retries the operation (which will get a fresh connection from the same event loop) - - This works correctly with async_to_sync() because the pool is keyed by event loop, - ensuring each retry uses the connection bound to the current loop. + Decorator that retries NATS operations on connection errors. When a connection error + is detected, the manager's connection is cleared so the next call to _get_connection() + creates a fresh one. 
Args: max_retries: Maximum number of retry attempts (default: 2) @@ -55,7 +54,7 @@ def retry_on_connection_error(max_retries: int = 2, backoff_seconds: float = 0.5 def decorator(func: Callable[..., "Awaitable[T]"]) -> Callable[..., "Awaitable[T]"]: @functools.wraps(func) - async def wrapper(self, *args, **kwargs) -> T: + async def wrapper(self: "TaskQueueManager", *args, **kwargs) -> T: last_error = None assert max_retries >= 0, "max_retries must be non-negative" @@ -77,11 +76,8 @@ async def wrapper(self, *args, **kwargs) -> T: exc_info=True, ) break - # Reset the connection pool so next attempt gets a fresh connection - from ami.ml.orchestration.nats_connection_pool import get_pool - - pool = get_pool() - await pool.reset() + # Close stale connection so next attempt creates a fresh one + await self._close_connection() # Exponential backoff wait_time = backoff_seconds * (2**attempt) logger.warning( @@ -101,20 +97,66 @@ async def wrapper(self, *args, **kwargs) -> T: class TaskQueueManager: """ Manager for NATS JetStream task queue operations. - Always uses the event-loop-local connection pool for efficiency. - Note: The connection pool is keyed by event loop, so each event loop gets its own - connection. This prevents "attached to a different loop" errors when using async_to_sync() - in Celery tasks or Django views. There's no overhead creating multiple TaskQueueManager - instances as they all share the same event-loop-keyed pool. + Use as an async context manager to scope a single connection to a block of operations: + + async with TaskQueueManager() as manager: + await manager.publish_task(job_id, task) + + The connection is created on entry and closed on exit. Within the block, the retry + decorator handles transient connection errors by clearing and recreating the connection. 
""" - async def _get_connection(self) -> tuple["NATSClient", JetStreamContext]: - """Get connection from the event-loop-local pool.""" - from ami.ml.orchestration.nats_connection_pool import get_pool + def __init__(self): + self._nc: "NATSClient | None" = None + self._js: JetStreamContext | None = None - pool = get_pool() - return await pool.get_connection() + async def __aenter__(self) -> "TaskQueueManager": + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self._close_connection() + + async def _get_connection(self) -> tuple["NATSClient", JetStreamContext]: + """Get or create a NATS connection.""" + # Fast path: connection exists and is healthy + if self._nc is not None and not self._nc.is_closed and self._nc.is_connected: + return self._nc, self._js # type: ignore + + # Connection is stale, clear it + if self._nc is not None: + logger.warning("NATS connection is closed or disconnected, will reconnect") + await self._close_connection() + + nats_url = getattr(settings, "NATS_URL", "nats://nats:4222") + logger.info(f"Connecting to NATS at {nats_url}") + self._nc = await nats.connect( + nats_url, + reconnected_cb=self._on_reconnected, + disconnected_cb=self._on_disconnected, + ) + self._js = self._nc.jetstream() + logger.info(f"Connected to NATS at {nats_url}") + return self._nc, self._js + + async def _close_connection(self) -> None: + """Close the NATS connection if open.""" + if self._nc is not None: + try: + if not self._nc.is_closed: + await self._nc.close() + except Exception as e: + logger.debug(f"Error closing NATS connection (expected during error recovery): {e}") + self._nc = None + self._js = None + + @staticmethod + async def _on_reconnected(): + logger.info("NATS client reconnected") + + @staticmethod + async def _on_disconnected(): + logger.warning("NATS client disconnected") def _get_stream_name(self, job_id: int) -> str: """Get stream name from job_id.""" @@ -138,9 +180,7 @@ async def _ensure_stream(self, job_id: int): try: await js.stream_info(stream_name) logger.debug(f"Stream {stream_name} already exists") - except Exception as e: - logger.warning(f"Stream {stream_name} does not exist: {e}") - # Stream doesn't exist, create it + except NotFoundError: await js.add_stream( name=stream_name, subjects=[subject], @@ -159,8 +199,7 @@ async def _ensure_consumer(self, job_id: int): try: info = await js.consumer_info(stream_name, consumer_name) logger.debug(f"Consumer {consumer_name} already exists: {info}") - except Exception: - # Consumer doesn't exist, create it + except NotFoundError: await js.add_consumer( stream=stream_name, config=ConsumerConfig( diff --git a/ami/ml/orchestration/tests/test_cleanup.py b/ami/ml/orchestration/tests/test_cleanup.py index a5fc6fcb7..e201b988f 100644 --- a/ami/ml/orchestration/tests/test_cleanup.py +++ b/ami/ml/orchestration/tests/test_cleanup.py @@ -68,28 +68,28 @@ def _verify_resources_created(self, job_id: int): # Verify NATS stream and consumer exist async def check_nats_resources(): - manager = TaskQueueManager() - stream_name = manager._get_stream_name(job_id) - consumer_name = manager._get_consumer_name(job_id) + async with TaskQueueManager() as manager: + stream_name = manager._get_stream_name(job_id) + consumer_name = manager._get_consumer_name(job_id) - # Get JetStream context - _, js = await manager._get_connection() + # Get JetStream context + _, js = await manager._get_connection() - # Try to get stream info - should succeed if created - stream_exists = True - try: - await 
js.stream_info(stream_name) - except NotFoundError: - stream_exists = False + # Try to get stream info - should succeed if created + stream_exists = True + try: + await js.stream_info(stream_name) + except NotFoundError: + stream_exists = False - # Try to get consumer info - should succeed if created - consumer_exists = True - try: - await js.consumer_info(stream_name, consumer_name) - except NotFoundError: - consumer_exists = False + # Try to get consumer info - should succeed if created + consumer_exists = True + try: + await js.consumer_info(stream_name, consumer_name) + except NotFoundError: + consumer_exists = False - return stream_exists, consumer_exists + return stream_exists, consumer_exists stream_exists, consumer_exists = async_to_sync(check_nats_resources)() @@ -137,28 +137,28 @@ def _verify_resources_cleaned(self, job_id: int): # Verify NATS stream and consumer are deleted async def check_nats_resources(): - manager = TaskQueueManager() - stream_name = manager._get_stream_name(job_id) - consumer_name = manager._get_consumer_name(job_id) - - # Get JetStream context - _, js = await manager._get_connection() - - # Try to get stream info - should fail if deleted - stream_exists = True - try: - await js.stream_info(stream_name) - except NotFoundError: - stream_exists = False - - # Try to get consumer info - should fail if deleted - consumer_exists = True - try: - await js.consumer_info(stream_name, consumer_name) - except NotFoundError: - consumer_exists = False - - return stream_exists, consumer_exists + async with TaskQueueManager() as manager: + stream_name = manager._get_stream_name(job_id) + consumer_name = manager._get_consumer_name(job_id) + + # Get JetStream context + _, js = await manager._get_connection() + + # Try to get stream info - should fail if deleted + stream_exists = True + try: + await js.stream_info(stream_name) + except NotFoundError: + stream_exists = False + + # Try to get consumer info - should fail if deleted + consumer_exists = True + try: + await js.consumer_info(stream_name, consumer_name) + except NotFoundError: + consumer_exists = False + + return stream_exists, consumer_exists stream_exists, consumer_exists = async_to_sync(check_nats_resources)() diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py index def90ba68..f36e08b3b 100644 --- a/ami/ml/orchestration/tests/test_nats_queue.py +++ b/ami/ml/orchestration/tests/test_nats_queue.py @@ -1,7 +1,7 @@ """Unit tests for TaskQueueManager.""" import unittest -from contextlib import contextmanager +from contextlib import asynccontextmanager from unittest.mock import AsyncMock, MagicMock, patch from ami.ml.orchestration.nats_queue import TaskQueueManager @@ -19,15 +19,16 @@ def _create_sample_task(self): image_url="https://example.com/image.jpg", ) - @contextmanager - def _mock_nats_setup(self): - """Helper to create and mock NATS connection with connection pool. + @asynccontextmanager + async def _mock_nats_setup(self): + """Helper to create and mock NATS connection. 
Yields: - tuple: (nc, js, mock_pool) - NATS connection, JetStream context, and mock pool + tuple: (nc, js) - mock NATS client and JetStream context """ nc = MagicMock() nc.is_closed = False + nc.is_connected = True nc.close = AsyncMock() js = MagicMock() @@ -40,22 +41,24 @@ def _mock_nats_setup(self): js.delete_consumer = AsyncMock() js.delete_stream = AsyncMock() - with patch("ami.ml.orchestration.nats_connection_pool.get_pool") as mock_get_pool: - mock_pool = MagicMock() - mock_pool.get_connection = AsyncMock(return_value=(nc, js)) - mock_get_pool.return_value = mock_pool - yield nc, js, mock_pool + mock_nc = AsyncMock(return_value=nc) + nc.jetstream.return_value = js + + with patch("ami.ml.orchestration.nats_queue.nats.connect", mock_nc): + yield nc, js async def test_publish_task_creates_stream_and_consumer(self): """Test that publish_task ensures stream and consumer exist.""" + from nats.js.errors import NotFoundError + sample_task = self._create_sample_task() - with self._mock_nats_setup() as (_, js, _): - js.stream_info.side_effect = Exception("Not found") - js.consumer_info.side_effect = Exception("Not found") + async with self._mock_nats_setup() as (_, js): + js.stream_info.side_effect = NotFoundError + js.consumer_info.side_effect = NotFoundError - manager = TaskQueueManager() - await manager.publish_task(456, sample_task) + async with TaskQueueManager() as manager: + await manager.publish_task(456, sample_task) js.add_stream.assert_called_once() self.assertIn("job_456", str(js.add_stream.call_args)) @@ -71,14 +74,14 @@ async def test_reserve_task_success(self): mock_msg.reply = "reply.subject.123" mock_msg.metadata = MagicMock(sequence=MagicMock(stream=1)) - with self._mock_nats_setup() as (_, js, _): + async with self._mock_nats_setup() as (_, js): mock_psub = MagicMock() mock_psub.fetch = AsyncMock(return_value=[mock_msg]) mock_psub.unsubscribe = AsyncMock() js.pull_subscribe = AsyncMock(return_value=mock_psub) - manager = TaskQueueManager() - task = await manager.reserve_task(123) + async with TaskQueueManager() as manager: + task = await manager.reserve_task(123) self.assertIsNotNone(task) self.assertEqual(task.id, sample_task.id) @@ -87,26 +90,26 @@ async def test_reserve_task_success(self): async def test_reserve_task_no_messages(self): """Test reserve_task when no messages are available.""" - with self._mock_nats_setup() as (_, js, _): + async with self._mock_nats_setup() as (_, js): mock_psub = MagicMock() mock_psub.fetch = AsyncMock(return_value=[]) mock_psub.unsubscribe = AsyncMock() js.pull_subscribe = AsyncMock(return_value=mock_psub) - manager = TaskQueueManager() - task = await manager.reserve_task(123) + async with TaskQueueManager() as manager: + task = await manager.reserve_task(123) self.assertIsNone(task) mock_psub.unsubscribe.assert_called_once() async def test_acknowledge_task_success(self): """Test successful task acknowledgment.""" - with self._mock_nats_setup() as (nc, _, _): + async with self._mock_nats_setup() as (nc, _): nc.publish = AsyncMock() nc.flush = AsyncMock() - manager = TaskQueueManager() - result = await manager.acknowledge_task("reply.subject.123") + async with TaskQueueManager() as manager: + result = await manager.acknowledge_task("reply.subject.123") self.assertTrue(result) nc.publish.assert_called_once_with("reply.subject.123", b"+ACK") @@ -114,9 +117,9 @@ async def test_acknowledge_task_success(self): async def test_cleanup_job_resources(self): """Test cleanup of job resources (consumer and stream).""" - with self._mock_nats_setup() as 
(_, js, _): - manager = TaskQueueManager() - result = await manager.cleanup_job_resources(123) + async with self._mock_nats_setup() as (_, js): + async with TaskQueueManager() as manager: + result = await manager.cleanup_job_resources(123) self.assertTrue(result) js.delete_consumer.assert_called_once() From cf42506416a18c8a11f8f201b4a6abdaa1350fdf Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 13 Feb 2026 12:02:41 -0800 Subject: [PATCH 09/16] =?UTF-8?q?revert:=20restore=20NATS=20connection=20p?= =?UTF-8?q?ool=20=E2=80=94=20avoid=20per-operation=20connection=20churn?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reverts c384199f which replaced the event-loop-keyed connection pool with a plain async context manager. The context manager approach opened and closed a TCP connection per async block, causing ~1500 connections per 1000-image job (250-500 for task fetches, 1000 for ACKs). The connection pool keeps one persistent connection per event loop and reuses it across all TaskQueueManager operations. Co-Authored-By: Claude --- ami/jobs/tasks.py | 4 +- ami/jobs/views.py | 10 +- ami/ml/orchestration/jobs.py | 36 ++--- ami/ml/orchestration/nats_connection_pool.py | 153 ++++++++++++++++++ ami/ml/orchestration/nats_queue.py | 95 ++++------- ami/ml/orchestration/tests/test_cleanup.py | 80 ++++----- ami/ml/orchestration/tests/test_nats_queue.py | 57 ++++--- 7 files changed, 273 insertions(+), 162 deletions(-) create mode 100644 ami/ml/orchestration/nats_connection_pool.py diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 3548d0ea5..94a6a821a 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -135,8 +135,8 @@ def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None: try: async def ack_task(): - async with TaskQueueManager() as manager: - return await manager.acknowledge_task(reply_subject) + manager = TaskQueueManager() + return await manager.acknowledge_task(reply_subject) ack_success = async_to_sync(ack_task)() diff --git a/ami/jobs/views.py b/ami/jobs/views.py index dd8da01b2..f45361007 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -242,11 +242,11 @@ def tasks(self, request, pk=None): async def get_tasks(): tasks = [] - async with TaskQueueManager() as manager: - for _ in range(batch): - task = await manager.reserve_task(job.pk, timeout=0.1) - if task: - tasks.append(task.dict()) + manager = TaskQueueManager() + for _ in range(batch): + task = await manager.reserve_task(job.pk, timeout=0.1) + if task: + tasks.append(task.dict()) return tasks # Use async_to_sync to properly handle the async call diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index 9b4a577e9..b4b9171b7 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -40,8 +40,8 @@ def cleanup_async_job_resources(job: "Job") -> bool: # Cleanup NATS resources async def cleanup(): - async with TaskQueueManager() as manager: - return await manager.cleanup_job_resources(job.pk) + manager = TaskQueueManager() + return await manager.cleanup_job_resources(job.pk) try: nats_success = async_to_sync(cleanup)() @@ -96,22 +96,22 @@ async def queue_all_images(): successful_queues = 0 failed_queues = 0 - async with TaskQueueManager() as manager: - for image_pk, task in tasks: - try: - logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}") - success = await manager.publish_task( - job_id=job.pk, - data=task, - ) - except Exception as e: - logger.error(f"Failed to queue 
image {image_pk} to stream for job '{job.pk}': {e}") - success = False - - if success: - successful_queues += 1 - else: - failed_queues += 1 + manager = TaskQueueManager() + for image_pk, task in tasks: + try: + logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}") + success = await manager.publish_task( + job_id=job.pk, + data=task, + ) + except Exception as e: + logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}") + success = False + + if success: + successful_queues += 1 + else: + failed_queues += 1 return successful_queues, failed_queues diff --git a/ami/ml/orchestration/nats_connection_pool.py b/ami/ml/orchestration/nats_connection_pool.py new file mode 100644 index 000000000..a8352d156 --- /dev/null +++ b/ami/ml/orchestration/nats_connection_pool.py @@ -0,0 +1,153 @@ +""" +NATS connection pool for both Celery workers and Django processes. + +Maintains a persistent NATS connection per event loop to avoid +the overhead of creating/closing connections for every operation. + +The connection pool is lazily initialized on first use and keyed by event loop +to prevent "attached to a different loop" errors when using async_to_sync(). +""" + +import asyncio +import logging +import threading +from typing import TYPE_CHECKING +from weakref import WeakKeyDictionary + +import nats +from django.conf import settings +from nats.js import JetStreamContext + +if TYPE_CHECKING: + from nats.aio.client import Client as NATSClient + +logger = logging.getLogger(__name__) + + +class ConnectionPool: + """ + Manages a single NATS connection per event loop. + + This is safe because: + - asyncio.Lock and NATS Client are bound to the event loop they were created on + - Each event loop gets its own isolated connection and lock + - Works correctly with async_to_sync() which creates per-thread event loops + - Prevents "attached to a different loop" errors in Celery tasks and Django views + """ + + def __init__(self): + self._nc: "NATSClient | None" = None + self._js: JetStreamContext | None = None + self._nats_url: str | None = None + self._lock: asyncio.Lock | None = None # Lazy-initialized when needed + + def _ensure_lock(self) -> asyncio.Lock: + """Lazily create lock bound to current event loop.""" + if self._lock is None: + self._lock = asyncio.Lock() + return self._lock + + async def get_connection(self) -> tuple["NATSClient", JetStreamContext]: + """ + Get or create the event loop's NATS connection. Checks connection health and recreates if stale. 
+ + Returns: + Tuple of (NATS connection, JetStream context) + Raises: + RuntimeError: If connection cannot be established + """ + # Fast path: connection exists, is open, and is connected + if self._nc is not None and not self._nc.is_closed and self._nc.is_connected: + return self._nc, self._js # type: ignore + + # Connection is stale or doesn't exist + if self._nc is not None: + logger.warning("NATS connection is closed or disconnected, will reconnect") + self._nc = None + self._js = None + + # Slow path: need to create/recreate connection + lock = self._ensure_lock() + async with lock: + # Double-check after acquiring lock + if self._nc is not None and not self._nc.is_closed and self._nc.is_connected: + return self._nc, self._js # type: ignore + + # Get NATS URL from settings + if self._nats_url is None: + self._nats_url = getattr(settings, "NATS_URL", "nats://nats:4222") + + try: + logger.info(f"Creating NATS connection to {self._nats_url}") + self._nc = await nats.connect(self._nats_url) + self._js = self._nc.jetstream() + logger.info(f"Successfully connected to NATS at {self._nats_url}") + return self._nc, self._js + except Exception as e: + logger.error(f"Failed to connect to NATS: {e}") + raise RuntimeError(f"Could not establish NATS connection: {e}") from e + + async def close(self): + """Close the NATS connection if it exists.""" + if self._nc is not None and not self._nc.is_closed: + logger.info("Closing NATS connection") + await self._nc.close() + self._nc = None + self._js = None + + async def reset(self): + """ + Async version of reset that properly closes the connection before clearing references. + + This should be called when a connection error is detected from an async context. + The next call to get_connection() will create a fresh connection. + """ + logger.warning("Resetting NATS connection pool due to connection error") + if self._nc is not None: + try: + # Attempt to close the connection gracefully + if not self._nc.is_closed: + await self._nc.close() + logger.debug("Successfully closed existing NATS connection during reset") + except Exception as e: + # Swallow errors - connection may already be broken + logger.debug(f"Error closing connection during reset (expected): {e}") + self._nc = None + self._js = None + self._lock = None # Clear lock so new one is created for fresh connection + + +# Event-loop-keyed pools: one ConnectionPool per event loop +# WeakKeyDictionary automatically cleans up when event loops are garbage collected +_pools: WeakKeyDictionary[asyncio.AbstractEventLoop, ConnectionPool] = WeakKeyDictionary() +_pools_lock = threading.Lock() + + +def get_pool() -> ConnectionPool: + """ + Get or create the connection pool for the current event loop. + + Each event loop gets its own ConnectionPool to prevent "attached to a different loop" errors. + This is critical when using async_to_sync() in Celery tasks or Django views, as each call + may run on a different event loop. + + Returns: + ConnectionPool bound to the current event loop + + Raises: + RuntimeError: If called outside of an async context (no running event loop) + """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + raise RuntimeError( + "get_pool() must be called from an async context with a running event loop. " + "If calling from sync code, use async_to_sync() to wrap the async function." 
+ ) + + # Thread-safe lookup/creation of pool for this event loop + with _pools_lock: + if loop not in _pools: + _pools[loop] = ConnectionPool() + logger.debug(f"Created NATS connection pool for event loop {id(loop)}") + return _pools[loop] diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 582de0d55..974b12b8a 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -17,12 +17,9 @@ from collections.abc import Callable from typing import TYPE_CHECKING, TypeVar -import nats -from django.conf import settings from nats import errors as nats_errors from nats.js import JetStreamContext from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy -from nats.js.errors import NotFoundError from ami.ml.schemas import PipelineProcessingTask @@ -41,9 +38,13 @@ def retry_on_connection_error(max_retries: int = 2, backoff_seconds: float = 0.5): """ - Decorator that retries NATS operations on connection errors. When a connection error - is detected, the manager's connection is cleared so the next call to _get_connection() - creates a fresh one. + Decorator that retries NATS operations on connection errors. When a connection error is detected: + 1. Resets the event-loop-local connection pool (clears stale connection and lock) + 2. Waits with exponential backoff + 3. Retries the operation (which will get a fresh connection from the same event loop) + + This works correctly with async_to_sync() because the pool is keyed by event loop, + ensuring each retry uses the connection bound to the current loop. Args: max_retries: Maximum number of retry attempts (default: 2) @@ -54,7 +55,7 @@ def retry_on_connection_error(max_retries: int = 2, backoff_seconds: float = 0.5 def decorator(func: Callable[..., "Awaitable[T]"]) -> Callable[..., "Awaitable[T]"]: @functools.wraps(func) - async def wrapper(self: "TaskQueueManager", *args, **kwargs) -> T: + async def wrapper(self, *args, **kwargs) -> T: last_error = None assert max_retries >= 0, "max_retries must be non-negative" @@ -76,8 +77,11 @@ async def wrapper(self: "TaskQueueManager", *args, **kwargs) -> T: exc_info=True, ) break - # Close stale connection so next attempt creates a fresh one - await self._close_connection() + # Reset the connection pool so next attempt gets a fresh connection + from ami.ml.orchestration.nats_connection_pool import get_pool + + pool = get_pool() + await pool.reset() # Exponential backoff wait_time = backoff_seconds * (2**attempt) logger.warning( @@ -97,66 +101,20 @@ async def wrapper(self: "TaskQueueManager", *args, **kwargs) -> T: class TaskQueueManager: """ Manager for NATS JetStream task queue operations. + Always uses the event-loop-local connection pool for efficiency. - Use as an async context manager to scope a single connection to a block of operations: - - async with TaskQueueManager() as manager: - await manager.publish_task(job_id, task) - - The connection is created on entry and closed on exit. Within the block, the retry - decorator handles transient connection errors by clearing and recreating the connection. + Note: The connection pool is keyed by event loop, so each event loop gets its own + connection. This prevents "attached to a different loop" errors when using async_to_sync() + in Celery tasks or Django views. There's no overhead creating multiple TaskQueueManager + instances as they all share the same event-loop-keyed pool. 
""" - def __init__(self): - self._nc: "NATSClient | None" = None - self._js: JetStreamContext | None = None - - async def __aenter__(self) -> "TaskQueueManager": - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - await self._close_connection() - async def _get_connection(self) -> tuple["NATSClient", JetStreamContext]: - """Get or create a NATS connection.""" - # Fast path: connection exists and is healthy - if self._nc is not None and not self._nc.is_closed and self._nc.is_connected: - return self._nc, self._js # type: ignore - - # Connection is stale, clear it - if self._nc is not None: - logger.warning("NATS connection is closed or disconnected, will reconnect") - await self._close_connection() - - nats_url = getattr(settings, "NATS_URL", "nats://nats:4222") - logger.info(f"Connecting to NATS at {nats_url}") - self._nc = await nats.connect( - nats_url, - reconnected_cb=self._on_reconnected, - disconnected_cb=self._on_disconnected, - ) - self._js = self._nc.jetstream() - logger.info(f"Connected to NATS at {nats_url}") - return self._nc, self._js - - async def _close_connection(self) -> None: - """Close the NATS connection if open.""" - if self._nc is not None: - try: - if not self._nc.is_closed: - await self._nc.close() - except Exception as e: - logger.debug(f"Error closing NATS connection (expected during error recovery): {e}") - self._nc = None - self._js = None - - @staticmethod - async def _on_reconnected(): - logger.info("NATS client reconnected") - - @staticmethod - async def _on_disconnected(): - logger.warning("NATS client disconnected") + """Get connection from the event-loop-local pool.""" + from ami.ml.orchestration.nats_connection_pool import get_pool + + pool = get_pool() + return await pool.get_connection() def _get_stream_name(self, job_id: int) -> str: """Get stream name from job_id.""" @@ -180,7 +138,9 @@ async def _ensure_stream(self, job_id: int): try: await js.stream_info(stream_name) logger.debug(f"Stream {stream_name} already exists") - except NotFoundError: + except Exception as e: + logger.warning(f"Stream {stream_name} does not exist: {e}") + # Stream doesn't exist, create it await js.add_stream( name=stream_name, subjects=[subject], @@ -199,7 +159,8 @@ async def _ensure_consumer(self, job_id: int): try: info = await js.consumer_info(stream_name, consumer_name) logger.debug(f"Consumer {consumer_name} already exists: {info}") - except NotFoundError: + except Exception: + # Consumer doesn't exist, create it await js.add_consumer( stream=stream_name, config=ConsumerConfig( diff --git a/ami/ml/orchestration/tests/test_cleanup.py b/ami/ml/orchestration/tests/test_cleanup.py index e201b988f..a5fc6fcb7 100644 --- a/ami/ml/orchestration/tests/test_cleanup.py +++ b/ami/ml/orchestration/tests/test_cleanup.py @@ -68,28 +68,28 @@ def _verify_resources_created(self, job_id: int): # Verify NATS stream and consumer exist async def check_nats_resources(): - async with TaskQueueManager() as manager: - stream_name = manager._get_stream_name(job_id) - consumer_name = manager._get_consumer_name(job_id) + manager = TaskQueueManager() + stream_name = manager._get_stream_name(job_id) + consumer_name = manager._get_consumer_name(job_id) - # Get JetStream context - _, js = await manager._get_connection() + # Get JetStream context + _, js = await manager._get_connection() - # Try to get stream info - should succeed if created - stream_exists = True - try: - await js.stream_info(stream_name) - except NotFoundError: - stream_exists = False + # Try to get 
stream info - should succeed if created + stream_exists = True + try: + await js.stream_info(stream_name) + except NotFoundError: + stream_exists = False - # Try to get consumer info - should succeed if created - consumer_exists = True - try: - await js.consumer_info(stream_name, consumer_name) - except NotFoundError: - consumer_exists = False + # Try to get consumer info - should succeed if created + consumer_exists = True + try: + await js.consumer_info(stream_name, consumer_name) + except NotFoundError: + consumer_exists = False - return stream_exists, consumer_exists + return stream_exists, consumer_exists stream_exists, consumer_exists = async_to_sync(check_nats_resources)() @@ -137,28 +137,28 @@ def _verify_resources_cleaned(self, job_id: int): # Verify NATS stream and consumer are deleted async def check_nats_resources(): - async with TaskQueueManager() as manager: - stream_name = manager._get_stream_name(job_id) - consumer_name = manager._get_consumer_name(job_id) - - # Get JetStream context - _, js = await manager._get_connection() - - # Try to get stream info - should fail if deleted - stream_exists = True - try: - await js.stream_info(stream_name) - except NotFoundError: - stream_exists = False - - # Try to get consumer info - should fail if deleted - consumer_exists = True - try: - await js.consumer_info(stream_name, consumer_name) - except NotFoundError: - consumer_exists = False - - return stream_exists, consumer_exists + manager = TaskQueueManager() + stream_name = manager._get_stream_name(job_id) + consumer_name = manager._get_consumer_name(job_id) + + # Get JetStream context + _, js = await manager._get_connection() + + # Try to get stream info - should fail if deleted + stream_exists = True + try: + await js.stream_info(stream_name) + except NotFoundError: + stream_exists = False + + # Try to get consumer info - should fail if deleted + consumer_exists = True + try: + await js.consumer_info(stream_name, consumer_name) + except NotFoundError: + consumer_exists = False + + return stream_exists, consumer_exists stream_exists, consumer_exists = async_to_sync(check_nats_resources)() diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py index f36e08b3b..def90ba68 100644 --- a/ami/ml/orchestration/tests/test_nats_queue.py +++ b/ami/ml/orchestration/tests/test_nats_queue.py @@ -1,7 +1,7 @@ """Unit tests for TaskQueueManager.""" import unittest -from contextlib import asynccontextmanager +from contextlib import contextmanager from unittest.mock import AsyncMock, MagicMock, patch from ami.ml.orchestration.nats_queue import TaskQueueManager @@ -19,16 +19,15 @@ def _create_sample_task(self): image_url="https://example.com/image.jpg", ) - @asynccontextmanager - async def _mock_nats_setup(self): - """Helper to create and mock NATS connection. + @contextmanager + def _mock_nats_setup(self): + """Helper to create and mock NATS connection with connection pool. 
Yields: - tuple: (nc, js) - mock NATS client and JetStream context + tuple: (nc, js, mock_pool) - NATS connection, JetStream context, and mock pool """ nc = MagicMock() nc.is_closed = False - nc.is_connected = True nc.close = AsyncMock() js = MagicMock() @@ -41,24 +40,22 @@ async def _mock_nats_setup(self): js.delete_consumer = AsyncMock() js.delete_stream = AsyncMock() - mock_nc = AsyncMock(return_value=nc) - nc.jetstream.return_value = js - - with patch("ami.ml.orchestration.nats_queue.nats.connect", mock_nc): - yield nc, js + with patch("ami.ml.orchestration.nats_connection_pool.get_pool") as mock_get_pool: + mock_pool = MagicMock() + mock_pool.get_connection = AsyncMock(return_value=(nc, js)) + mock_get_pool.return_value = mock_pool + yield nc, js, mock_pool async def test_publish_task_creates_stream_and_consumer(self): """Test that publish_task ensures stream and consumer exist.""" - from nats.js.errors import NotFoundError - sample_task = self._create_sample_task() - async with self._mock_nats_setup() as (_, js): - js.stream_info.side_effect = NotFoundError - js.consumer_info.side_effect = NotFoundError + with self._mock_nats_setup() as (_, js, _): + js.stream_info.side_effect = Exception("Not found") + js.consumer_info.side_effect = Exception("Not found") - async with TaskQueueManager() as manager: - await manager.publish_task(456, sample_task) + manager = TaskQueueManager() + await manager.publish_task(456, sample_task) js.add_stream.assert_called_once() self.assertIn("job_456", str(js.add_stream.call_args)) @@ -74,14 +71,14 @@ async def test_reserve_task_success(self): mock_msg.reply = "reply.subject.123" mock_msg.metadata = MagicMock(sequence=MagicMock(stream=1)) - async with self._mock_nats_setup() as (_, js): + with self._mock_nats_setup() as (_, js, _): mock_psub = MagicMock() mock_psub.fetch = AsyncMock(return_value=[mock_msg]) mock_psub.unsubscribe = AsyncMock() js.pull_subscribe = AsyncMock(return_value=mock_psub) - async with TaskQueueManager() as manager: - task = await manager.reserve_task(123) + manager = TaskQueueManager() + task = await manager.reserve_task(123) self.assertIsNotNone(task) self.assertEqual(task.id, sample_task.id) @@ -90,26 +87,26 @@ async def test_reserve_task_success(self): async def test_reserve_task_no_messages(self): """Test reserve_task when no messages are available.""" - async with self._mock_nats_setup() as (_, js): + with self._mock_nats_setup() as (_, js, _): mock_psub = MagicMock() mock_psub.fetch = AsyncMock(return_value=[]) mock_psub.unsubscribe = AsyncMock() js.pull_subscribe = AsyncMock(return_value=mock_psub) - async with TaskQueueManager() as manager: - task = await manager.reserve_task(123) + manager = TaskQueueManager() + task = await manager.reserve_task(123) self.assertIsNone(task) mock_psub.unsubscribe.assert_called_once() async def test_acknowledge_task_success(self): """Test successful task acknowledgment.""" - async with self._mock_nats_setup() as (nc, _): + with self._mock_nats_setup() as (nc, _, _): nc.publish = AsyncMock() nc.flush = AsyncMock() - async with TaskQueueManager() as manager: - result = await manager.acknowledge_task("reply.subject.123") + manager = TaskQueueManager() + result = await manager.acknowledge_task("reply.subject.123") self.assertTrue(result) nc.publish.assert_called_once_with("reply.subject.123", b"+ACK") @@ -117,9 +114,9 @@ async def test_acknowledge_task_success(self): async def test_cleanup_job_resources(self): """Test cleanup of job resources (consumer and stream).""" - async with 
self._mock_nats_setup() as (_, js): - async with TaskQueueManager() as manager: - result = await manager.cleanup_job_resources(123) + with self._mock_nats_setup() as (_, js, _): + manager = TaskQueueManager() + result = await manager.cleanup_job_resources(123) self.assertTrue(result) js.delete_consumer.assert_called_once() From dc798eac72681501acd9b8b4454349aa6344a2cf Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 13 Feb 2026 12:05:25 -0800 Subject: [PATCH 10/16] refactor: add switchable NATS connection strategies Extract connection pool to a pluggable design with two strategies: - "pool" (default): persistent connection reuse per event loop - "per_operation": fresh TCP connection each time, for debugging Controlled by NATS_CONNECTION_STRATEGY Django setting. Both strategies implement the same interface (get_connection, reset, close) so TaskQueueManager is agnostic to which one is active. Changes: - Rename nats_connection_pool.py to nats_connection.py - Rename get_pool() to get_provider() - Use settings.NATS_URL directly instead of getattr with divergent defaults - Narrow except clauses in _ensure_stream/_ensure_consumer to NotFoundError - Add _js guard to fast path, add strategy logging - Enhanced module and class docstrings Co-Authored-By: Claude --- ami/ml/orchestration/nats_connection.py | 250 +++++++++++++ ami/ml/orchestration/nats_connection_pool.py | 153 -------- ami/ml/orchestration/nats_queue.py | 39 +- .../tests/test_nats_connection.py | 336 ++++++++++++++++++ ami/ml/orchestration/tests/test_nats_queue.py | 116 +++++- config/settings/base.py | 2 + 6 files changed, 721 insertions(+), 175 deletions(-) create mode 100644 ami/ml/orchestration/nats_connection.py delete mode 100644 ami/ml/orchestration/nats_connection_pool.py create mode 100644 ami/ml/orchestration/tests/test_nats_connection.py diff --git a/ami/ml/orchestration/nats_connection.py b/ami/ml/orchestration/nats_connection.py new file mode 100644 index 000000000..7edc29cc7 --- /dev/null +++ b/ami/ml/orchestration/nats_connection.py @@ -0,0 +1,250 @@ +""" +NATS connection management for both Celery workers and Django processes. + +Provides two connection strategies, selectable via NATS_CONNECTION_STRATEGY setting: + + "pool" (default): + Maintains a persistent NATS connection per event loop and reuses it across + all TaskQueueManager operations. A 1000-image job generates ~1500+ NATS + operations (1 for queuing, 250-500 for task fetches, 1000 for ACKs). The + pool keeps one connection alive per event loop and reuses it for all of them. + + "per_operation": + Creates a fresh TCP connection for every get_connection() call. Simple but + expensive — the same 1000-image job opens ~1500 TCP connections. Use this + only for debugging connection issues or when pooling causes problems. + +Why keyed by event loop (pool strategy): + Django views and Celery tasks use async_to_sync(), which creates a new event + loop per thread. asyncio.Lock and nats.Client are bound to the loop they were + created on, so sharing them across loops causes "attached to a different loop" + errors. Keying by loop ensures isolation. WeakKeyDictionary auto-cleans when + loops are garbage collected, so short-lived loops don't leak. 
+ +Connection lifecycle (pool strategy): + - Created lazily on first use within an event loop + - Reused for all subsequent operations on that loop + - On connection error: retry decorator calls pool.reset() to close the stale + connection; next operation creates a fresh one (see retry_on_connection_error + in nats_queue.py) + - Cleaned up automatically when the event loop is garbage collected + +Both strategies implement the same interface (get_connection, reset, close) so +TaskQueueManager is agnostic to which one is active. +""" + +import asyncio +import logging +import threading +import typing +from typing import TYPE_CHECKING +from weakref import WeakKeyDictionary + +import nats +from django.conf import settings +from nats.js import JetStreamContext + +if TYPE_CHECKING: + from nats.aio.client import Client as NATSClient + +logger = logging.getLogger(__name__) + + +class ConnectionProvider(typing.Protocol): + """Interface that all NATS connection strategies must implement.""" + + async def get_connection(self) -> tuple["NATSClient", JetStreamContext]: + ... + + async def reset(self) -> None: + ... + + async def close(self) -> None: + ... + + +class ConnectionPool: + """ + Manages a single persistent NATS connection per event loop. + + This is safe because: + - asyncio.Lock and NATS Client are bound to the event loop they were created on + - Each event loop gets its own isolated connection and lock + - Works correctly with async_to_sync() which creates per-thread event loops + - Prevents "attached to a different loop" errors in Celery tasks and Django views + + Instantiating TaskQueueManager() is cheap — multiple instances share the same + underlying connection via this pool. + """ + + def __init__(self): + self._nc: "NATSClient | None" = None + self._js: JetStreamContext | None = None + self._lock: asyncio.Lock | None = None # Lazy-initialized when needed + + def _ensure_lock(self) -> asyncio.Lock: + """Lazily create lock bound to current event loop.""" + if self._lock is None: + self._lock = asyncio.Lock() + return self._lock + + async def get_connection(self) -> tuple["NATSClient", JetStreamContext]: + """ + Get or create the event loop's NATS connection. Checks connection health + and recreates if stale. + + Returns: + Tuple of (NATS connection, JetStream context) + Raises: + RuntimeError: If connection cannot be established + """ + # Fast path (no lock needed): connection exists, is open, and is connected. + # This is the hot path — most calls hit this and return immediately. 
+ if self._nc is not None and self._js is not None and not self._nc.is_closed and self._nc.is_connected: + return self._nc, self._js + + # Connection is stale or doesn't exist — clear references before reconnecting + if self._nc is not None: + logger.warning("NATS connection is closed or disconnected, will reconnect") + self._nc = None + self._js = None + + # Slow path: acquire lock to prevent concurrent reconnection attempts + lock = self._ensure_lock() + async with lock: + # Double-check after acquiring lock (another coroutine may have reconnected) + if self._nc is not None and self._js is not None and not self._nc.is_closed and self._nc.is_connected: + return self._nc, self._js + + nats_url = settings.NATS_URL + try: + logger.info(f"Creating NATS connection to {nats_url}") + self._nc = await nats.connect(nats_url) + self._js = self._nc.jetstream() + logger.info(f"Successfully connected to NATS at {nats_url}") + return self._nc, self._js + except Exception as e: + logger.error(f"Failed to connect to NATS: {e}") + raise RuntimeError(f"Could not establish NATS connection: {e}") from e + + async def close(self): + """Close the NATS connection if it exists.""" + if self._nc is not None and not self._nc.is_closed: + logger.info("Closing NATS connection") + await self._nc.close() + self._nc = None + self._js = None + + async def reset(self): + """ + Close the current connection and clear all state so the next call to + get_connection() creates a fresh one. + + Called by retry_on_connection_error when an operation hits a connection + error (e.g. network blip, NATS restart). The lock is also cleared so it + gets recreated bound to the current event loop. + """ + logger.warning("Resetting NATS connection pool due to connection error") + if self._nc is not None: + try: + if not self._nc.is_closed: + await self._nc.close() + logger.debug("Successfully closed existing NATS connection during reset") + except Exception as e: + # Swallow errors - connection may already be broken + logger.debug(f"Error closing connection during reset (expected): {e}") + self._nc = None + self._js = None + self._lock = None # Clear lock so new one is created for fresh connection + + +class PerOperationConnection: + """ + Creates a fresh NATS connection on every get_connection() call. + + Each call closes the previous connection (if any) and opens a new one. + This avoids any shared state but is expensive: ~1500 TCP round-trips per + 1000-image job vs. 1 with the pool strategy. + + Use for debugging connection lifecycle issues or when the pool causes + problems (e.g. event loop mismatch edge cases). 
+ """ + + def __init__(self): + self._nc: "NATSClient | None" = None + self._js: JetStreamContext | None = None + + async def get_connection(self) -> tuple["NATSClient", JetStreamContext]: + """Create a fresh NATS connection, closing any previous one.""" + # Close previous connection if it exists + await self.close() + + nats_url = settings.NATS_URL + try: + logger.debug(f"Creating transient NATS connection to {nats_url}") + self._nc = await nats.connect(nats_url) + self._js = self._nc.jetstream() + return self._nc, self._js + except Exception as e: + logger.error(f"Failed to connect to NATS: {e}") + raise RuntimeError(f"Could not establish NATS connection: {e}") from e + + async def close(self): + """Close the current connection if it exists.""" + if self._nc is not None and not self._nc.is_closed: + try: + await self._nc.close() + except Exception as e: + logger.debug(f"Error closing transient connection (expected): {e}") + self._nc = None + self._js = None + + async def reset(self): + """Close and clear — next get_connection() creates a fresh one.""" + await self.close() + + +# Event-loop-keyed providers: one per event loop. +# WeakKeyDictionary automatically cleans up when event loops are garbage collected. +_providers: WeakKeyDictionary[asyncio.AbstractEventLoop, ConnectionProvider] = WeakKeyDictionary() +_providers_lock = threading.Lock() + + +def _create_provider() -> ConnectionProvider: + """Create a connection provider based on the NATS_CONNECTION_STRATEGY setting.""" + strategy = settings.NATS_CONNECTION_STRATEGY + if strategy == "per_operation": + logger.info("Using NATS connection strategy: per_operation") + return PerOperationConnection() + logger.info("Using NATS connection strategy: pool (persistent)") + return ConnectionPool() + + +def get_provider() -> ConnectionProvider: + """ + Get or create the connection provider for the current event loop. + + Each event loop gets its own provider to prevent "attached to a different loop" + errors. The provider type is determined by the NATS_CONNECTION_STRATEGY setting: + - "pool" (default): persistent connection reuse via ConnectionPool + - "per_operation": fresh connection each time via PerOperationConnection + + Returns: + ConnectionProvider bound to the current event loop + + Raises: + RuntimeError: If called outside of an async context (no running event loop) + """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + raise RuntimeError( + "get_provider() must be called from an async context with a running event loop. " + "If calling from sync code, use async_to_sync() to wrap the async function." + ) + + with _providers_lock: + if loop not in _providers: + _providers[loop] = _create_provider() + logger.debug(f"Created NATS connection provider for event loop {id(loop)}") + return _providers[loop] diff --git a/ami/ml/orchestration/nats_connection_pool.py b/ami/ml/orchestration/nats_connection_pool.py deleted file mode 100644 index a8352d156..000000000 --- a/ami/ml/orchestration/nats_connection_pool.py +++ /dev/null @@ -1,153 +0,0 @@ -""" -NATS connection pool for both Celery workers and Django processes. - -Maintains a persistent NATS connection per event loop to avoid -the overhead of creating/closing connections for every operation. - -The connection pool is lazily initialized on first use and keyed by event loop -to prevent "attached to a different loop" errors when using async_to_sync(). 
-""" - -import asyncio -import logging -import threading -from typing import TYPE_CHECKING -from weakref import WeakKeyDictionary - -import nats -from django.conf import settings -from nats.js import JetStreamContext - -if TYPE_CHECKING: - from nats.aio.client import Client as NATSClient - -logger = logging.getLogger(__name__) - - -class ConnectionPool: - """ - Manages a single NATS connection per event loop. - - This is safe because: - - asyncio.Lock and NATS Client are bound to the event loop they were created on - - Each event loop gets its own isolated connection and lock - - Works correctly with async_to_sync() which creates per-thread event loops - - Prevents "attached to a different loop" errors in Celery tasks and Django views - """ - - def __init__(self): - self._nc: "NATSClient | None" = None - self._js: JetStreamContext | None = None - self._nats_url: str | None = None - self._lock: asyncio.Lock | None = None # Lazy-initialized when needed - - def _ensure_lock(self) -> asyncio.Lock: - """Lazily create lock bound to current event loop.""" - if self._lock is None: - self._lock = asyncio.Lock() - return self._lock - - async def get_connection(self) -> tuple["NATSClient", JetStreamContext]: - """ - Get or create the event loop's NATS connection. Checks connection health and recreates if stale. - - Returns: - Tuple of (NATS connection, JetStream context) - Raises: - RuntimeError: If connection cannot be established - """ - # Fast path: connection exists, is open, and is connected - if self._nc is not None and not self._nc.is_closed and self._nc.is_connected: - return self._nc, self._js # type: ignore - - # Connection is stale or doesn't exist - if self._nc is not None: - logger.warning("NATS connection is closed or disconnected, will reconnect") - self._nc = None - self._js = None - - # Slow path: need to create/recreate connection - lock = self._ensure_lock() - async with lock: - # Double-check after acquiring lock - if self._nc is not None and not self._nc.is_closed and self._nc.is_connected: - return self._nc, self._js # type: ignore - - # Get NATS URL from settings - if self._nats_url is None: - self._nats_url = getattr(settings, "NATS_URL", "nats://nats:4222") - - try: - logger.info(f"Creating NATS connection to {self._nats_url}") - self._nc = await nats.connect(self._nats_url) - self._js = self._nc.jetstream() - logger.info(f"Successfully connected to NATS at {self._nats_url}") - return self._nc, self._js - except Exception as e: - logger.error(f"Failed to connect to NATS: {e}") - raise RuntimeError(f"Could not establish NATS connection: {e}") from e - - async def close(self): - """Close the NATS connection if it exists.""" - if self._nc is not None and not self._nc.is_closed: - logger.info("Closing NATS connection") - await self._nc.close() - self._nc = None - self._js = None - - async def reset(self): - """ - Async version of reset that properly closes the connection before clearing references. - - This should be called when a connection error is detected from an async context. - The next call to get_connection() will create a fresh connection. 
- """ - logger.warning("Resetting NATS connection pool due to connection error") - if self._nc is not None: - try: - # Attempt to close the connection gracefully - if not self._nc.is_closed: - await self._nc.close() - logger.debug("Successfully closed existing NATS connection during reset") - except Exception as e: - # Swallow errors - connection may already be broken - logger.debug(f"Error closing connection during reset (expected): {e}") - self._nc = None - self._js = None - self._lock = None # Clear lock so new one is created for fresh connection - - -# Event-loop-keyed pools: one ConnectionPool per event loop -# WeakKeyDictionary automatically cleans up when event loops are garbage collected -_pools: WeakKeyDictionary[asyncio.AbstractEventLoop, ConnectionPool] = WeakKeyDictionary() -_pools_lock = threading.Lock() - - -def get_pool() -> ConnectionPool: - """ - Get or create the connection pool for the current event loop. - - Each event loop gets its own ConnectionPool to prevent "attached to a different loop" errors. - This is critical when using async_to_sync() in Celery tasks or Django views, as each call - may run on a different event loop. - - Returns: - ConnectionPool bound to the current event loop - - Raises: - RuntimeError: If called outside of an async context (no running event loop) - """ - try: - loop = asyncio.get_running_loop() - except RuntimeError: - raise RuntimeError( - "get_pool() must be called from an async context with a running event loop. " - "If calling from sync code, use async_to_sync() to wrap the async function." - ) - - # Thread-safe lookup/creation of pool for this event loop - with _pools_lock: - if loop not in _pools: - _pools[loop] = ConnectionPool() - logger.debug(f"Created NATS connection pool for event loop {id(loop)}") - return _pools[loop] diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 974b12b8a..29f17f08f 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -20,6 +20,7 @@ from nats import errors as nats_errors from nats.js import JetStreamContext from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy +from nats.js.errors import NotFoundError from ami.ml.schemas import PipelineProcessingTask @@ -77,11 +78,11 @@ async def wrapper(self, *args, **kwargs) -> T: exc_info=True, ) break - # Reset the connection pool so next attempt gets a fresh connection - from ami.ml.orchestration.nats_connection_pool import get_pool + # Reset the connection provider so next attempt gets a fresh connection + from ami.ml.orchestration.nats_connection import get_provider - pool = get_pool() - await pool.reset() + provider = get_provider() + await provider.reset() # Exponential backoff wait_time = backoff_seconds * (2**attempt) logger.warning( @@ -101,20 +102,28 @@ async def wrapper(self, *args, **kwargs) -> T: class TaskQueueManager: """ Manager for NATS JetStream task queue operations. - Always uses the event-loop-local connection pool for efficiency. - Note: The connection pool is keyed by event loop, so each event loop gets its own - connection. This prevents "attached to a different loop" errors when using async_to_sync() - in Celery tasks or Django views. There's no overhead creating multiple TaskQueueManager - instances as they all share the same event-loop-keyed pool. + This class is a stateless wrapper — it holds no connection state itself. 
+ All connections come from the event-loop-keyed provider in + nats_connection.py, so instantiating TaskQueueManager() is cheap + and multiple instances share the same underlying connection. + + Usage pattern (no context manager needed): + manager = TaskQueueManager() + await manager.publish_task(job_id, task) + await manager.acknowledge_task(reply_subject) + + Error handling: + The @retry_on_connection_error decorator on each method handles transient + connection failures by resetting the provider and retrying with backoff. """ async def _get_connection(self) -> tuple["NATSClient", JetStreamContext]: - """Get connection from the event-loop-local pool.""" - from ami.ml.orchestration.nats_connection_pool import get_pool + """Get connection from the event-loop-local provider.""" + from ami.ml.orchestration.nats_connection import get_provider - pool = get_pool() - return await pool.get_connection() + provider = get_provider() + return await provider.get_connection() def _get_stream_name(self, job_id: int) -> str: """Get stream name from job_id.""" @@ -138,7 +147,7 @@ async def _ensure_stream(self, job_id: int): try: await js.stream_info(stream_name) logger.debug(f"Stream {stream_name} already exists") - except Exception as e: + except NotFoundError as e: logger.warning(f"Stream {stream_name} does not exist: {e}") # Stream doesn't exist, create it await js.add_stream( @@ -159,7 +168,7 @@ async def _ensure_consumer(self, job_id: int): try: info = await js.consumer_info(stream_name, consumer_name) logger.debug(f"Consumer {consumer_name} already exists: {info}") - except Exception: + except NotFoundError: # Consumer doesn't exist, create it await js.add_consumer( stream=stream_name, diff --git a/ami/ml/orchestration/tests/test_nats_connection.py b/ami/ml/orchestration/tests/test_nats_connection.py new file mode 100644 index 000000000..250a7a980 --- /dev/null +++ b/ami/ml/orchestration/tests/test_nats_connection.py @@ -0,0 +1,336 @@ +"""Unit tests for nats_connection module.""" + +import asyncio +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from ami.ml.orchestration.nats_connection import ConnectionPool, PerOperationConnection, _create_provider + + +class TestCreateProvider(unittest.TestCase): + """Test _create_provider() returns the correct strategy based on settings.""" + + @patch("ami.ml.orchestration.nats_connection.settings") + def test_default_strategy_returns_connection_pool(self, mock_settings): + """Test that default (unspecified) strategy returns ConnectionPool.""" + mock_settings.NATS_CONNECTION_STRATEGY = "pool" + provider = _create_provider() + self.assertIsInstance(provider, ConnectionPool) + + @patch("ami.ml.orchestration.nats_connection.settings") + def test_per_operation_strategy_returns_per_operation_connection(self, mock_settings): + """Test that per_operation strategy returns PerOperationConnection.""" + mock_settings.NATS_CONNECTION_STRATEGY = "per_operation" + provider = _create_provider() + self.assertIsInstance(provider, PerOperationConnection) + + @patch("ami.ml.orchestration.nats_connection.settings") + def test_unknown_strategy_falls_back_to_pool(self, mock_settings): + """Test that unknown strategy falls back to ConnectionPool.""" + mock_settings.NATS_CONNECTION_STRATEGY = "unknown_value" + provider = _create_provider() + self.assertIsInstance(provider, ConnectionPool) + + @patch("ami.ml.orchestration.nats_connection.settings") + def test_empty_string_strategy_falls_back_to_pool(self, mock_settings): + """Test that empty string strategy 
falls back to ConnectionPool.""" + mock_settings.NATS_CONNECTION_STRATEGY = "" + provider = _create_provider() + self.assertIsInstance(provider, ConnectionPool) + + +class TestConnectionPoolBehavior(unittest.IsolatedAsyncioTestCase): + """Test ConnectionPool lifecycle and connection reuse.""" + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_get_connection_creates_connection(self, mock_settings, mock_nats): + """Test that get_connection creates a connection on first call.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nc = MagicMock() + mock_nc.is_closed = False + mock_nc.is_connected = True + mock_nc.close = AsyncMock() + mock_js = MagicMock() + mock_nc.jetstream.return_value = mock_js + + mock_nats.connect = AsyncMock(return_value=mock_nc) + + pool = ConnectionPool() + nc, js = await pool.get_connection() + + self.assertIs(nc, mock_nc) + self.assertIs(js, mock_js) + mock_nats.connect.assert_called_once_with("nats://test:4222") + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_get_connection_reuses_existing_connection(self, mock_settings, mock_nats): + """Test that get_connection reuses connection on subsequent calls.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nc = MagicMock() + mock_nc.is_closed = False + mock_nc.is_connected = True + mock_nc.close = AsyncMock() + mock_js = MagicMock() + mock_nc.jetstream.return_value = mock_js + + mock_nats.connect = AsyncMock(return_value=mock_nc) + + pool = ConnectionPool() + + # First call + nc1, js1 = await pool.get_connection() + # Second call + nc2, js2 = await pool.get_connection() + + # Should only connect once + self.assertEqual(mock_nats.connect.call_count, 1) + self.assertIs(nc1, nc2) + self.assertIs(js1, js2) + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_get_connection_reconnects_if_closed(self, mock_settings, mock_nats): + """Test that get_connection reconnects if the connection is closed.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nc1 = MagicMock() + mock_nc1.is_closed = True + mock_nc1.is_connected = False + mock_nc1.close = AsyncMock() + + mock_nc2 = MagicMock() + mock_nc2.is_closed = False + mock_nc2.is_connected = True + mock_nc2.close = AsyncMock() + mock_js2 = MagicMock() + mock_nc2.jetstream.return_value = mock_js2 + + mock_nats.connect = AsyncMock(side_effect=[mock_nc2]) + + pool = ConnectionPool() + pool._nc = mock_nc1 + pool._js = MagicMock() + + # This should detect the connection is closed and reconnect + nc, js = await pool.get_connection() + + self.assertIs(nc, mock_nc2) + self.assertIs(js, mock_js2) + mock_nats.connect.assert_called_once() + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_get_connection_raises_on_connection_error(self, mock_settings, mock_nats): + """Test that get_connection raises RuntimeError on connection failure.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nats.connect = AsyncMock(side_effect=ConnectionError("Connection failed")) + + pool = ConnectionPool() + + with self.assertRaises(RuntimeError) as context: + await pool.get_connection() + + self.assertIn("Could not establish NATS connection", str(context.exception)) + + @patch("ami.ml.orchestration.nats_connection.nats") + 
@patch("ami.ml.orchestration.nats_connection.settings") + async def test_close_closes_connection(self, mock_settings, mock_nats): + """Test that close() closes the connection.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nc = MagicMock() + mock_nc.is_closed = False + mock_nc.is_connected = True + mock_nc.close = AsyncMock() + mock_js = MagicMock() + mock_nc.jetstream.return_value = mock_js + + mock_nats.connect = AsyncMock(return_value=mock_nc) + + pool = ConnectionPool() + await pool.get_connection() + + await pool.close() + + mock_nc.close.assert_called_once() + self.assertIsNone(pool._nc) + self.assertIsNone(pool._js) + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_reset_closes_and_clears_state(self, mock_settings, mock_nats): + """Test that reset() closes connection and clears all state.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nc = MagicMock() + mock_nc.is_closed = False + mock_nc.is_connected = True + mock_nc.close = AsyncMock() + mock_js = MagicMock() + mock_nc.jetstream.return_value = mock_js + + mock_nats.connect = AsyncMock(return_value=mock_nc) + + pool = ConnectionPool() + await pool.get_connection() + + # Set the lock so we can verify it gets cleared + pool._lock = asyncio.Lock() + + await pool.reset() + + mock_nc.close.assert_called_once() + self.assertIsNone(pool._nc) + self.assertIsNone(pool._js) + self.assertIsNone(pool._lock) + + +class TestPerOperationConnectionBehavior(unittest.IsolatedAsyncioTestCase): + """Test PerOperationConnection lifecycle behavior.""" + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_get_connection_creates_connection(self, mock_settings, mock_nats): + """Test that get_connection creates a fresh connection.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nc = MagicMock() + mock_nc.is_closed = False + mock_nc.close = AsyncMock() + mock_js = MagicMock() + mock_nc.jetstream.return_value = mock_js + + mock_nats.connect = AsyncMock(return_value=mock_nc) + + conn = PerOperationConnection() + nc, js = await conn.get_connection() + + self.assertIs(nc, mock_nc) + self.assertIs(js, mock_js) + mock_nats.connect.assert_called_once_with("nats://test:4222") + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_get_connection_closes_previous(self, mock_settings, mock_nats): + """Test that each get_connection() call closes the previous connection.""" + mock_settings.NATS_URL = "nats://test:4222" + + # Create two mock connections (MagicMock because jetstream() is sync) + mock_nc1 = MagicMock() + mock_nc1.is_closed = False + mock_nc1.close = AsyncMock() + mock_nc1.jetstream.return_value = MagicMock() + + mock_nc2 = MagicMock() + mock_nc2.is_closed = False + mock_nc2.close = AsyncMock() + mock_nc2.jetstream.return_value = MagicMock() + + mock_nats.connect = AsyncMock(side_effect=[mock_nc1, mock_nc2]) + + conn = PerOperationConnection() + + # First call + nc1, _ = await conn.get_connection() + self.assertIs(nc1, mock_nc1) + + # Second call should close the first connection + nc2, _ = await conn.get_connection() + self.assertIs(nc2, mock_nc2) + + # Verify first connection was closed + mock_nc1.close.assert_called_once() + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_get_connection_handles_close_errors(self, 
mock_settings, mock_nats): + """Test that get_connection handles errors when closing previous connection.""" + mock_settings.NATS_URL = "nats://test:4222" + + # First connection throws error on close + mock_nc1 = MagicMock() + mock_nc1.is_closed = False + mock_nc1.close = AsyncMock(side_effect=RuntimeError("Close error")) + mock_nc1.jetstream.return_value = MagicMock() + + # Second connection succeeds + mock_nc2 = MagicMock() + mock_nc2.is_closed = False + mock_nc2.close = AsyncMock() + mock_nc2.jetstream.return_value = MagicMock() + + mock_nats.connect = AsyncMock(side_effect=[mock_nc1, mock_nc2]) + + conn = PerOperationConnection() + + # First call + await conn.get_connection() + + # Second call should not raise even though closing first connection fails + nc2, _ = await conn.get_connection() + self.assertIs(nc2, mock_nc2) + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_reset_closes_connection(self, mock_settings, mock_nats): + """Test that reset() closes the current connection.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nc = MagicMock() + mock_nc.is_closed = False + mock_nc.close = AsyncMock() + mock_nc.jetstream.return_value = MagicMock() + + mock_nats.connect = AsyncMock(return_value=mock_nc) + + conn = PerOperationConnection() + await conn.get_connection() + + await conn.reset() + + mock_nc.close.assert_called_once() + # After reset, internal state should be cleared + self.assertIsNone(conn._nc) + self.assertIsNone(conn._js) + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_close_closes_connection(self, mock_settings, mock_nats): + """Test that close() closes the connection.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nc = MagicMock() + mock_nc.is_closed = False + mock_nc.close = AsyncMock() + mock_nc.jetstream.return_value = MagicMock() + + mock_nats.connect = AsyncMock(return_value=mock_nc) + + conn = PerOperationConnection() + await conn.get_connection() + + await conn.close() + + mock_nc.close.assert_called_once() + self.assertIsNone(conn._nc) + self.assertIsNone(conn._js) + + @patch("ami.ml.orchestration.nats_connection.nats") + @patch("ami.ml.orchestration.nats_connection.settings") + async def test_get_connection_raises_on_connection_error(self, mock_settings, mock_nats): + """Test that get_connection raises RuntimeError on connection failure.""" + mock_settings.NATS_URL = "nats://test:4222" + + mock_nats.connect = AsyncMock(side_effect=ConnectionError("Connection failed")) + + conn = PerOperationConnection() + + with self.assertRaises(RuntimeError) as context: + await conn.get_connection() + + self.assertIn("Could not establish NATS connection", str(context.exception)) diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py index def90ba68..7367d7b42 100644 --- a/ami/ml/orchestration/tests/test_nats_queue.py +++ b/ami/ml/orchestration/tests/test_nats_queue.py @@ -4,6 +4,8 @@ from contextlib import contextmanager from unittest.mock import AsyncMock, MagicMock, patch +from nats.js.errors import NotFoundError + from ami.ml.orchestration.nats_queue import TaskQueueManager from ami.ml.schemas import PipelineProcessingTask @@ -40,19 +42,19 @@ def _mock_nats_setup(self): js.delete_consumer = AsyncMock() js.delete_stream = AsyncMock() - with patch("ami.ml.orchestration.nats_connection_pool.get_pool") as mock_get_pool: - mock_pool = MagicMock() 
- mock_pool.get_connection = AsyncMock(return_value=(nc, js)) - mock_get_pool.return_value = mock_pool - yield nc, js, mock_pool + with patch("ami.ml.orchestration.nats_connection.get_provider") as mock_get_provider: + mock_provider = MagicMock() + mock_provider.get_connection = AsyncMock(return_value=(nc, js)) + mock_get_provider.return_value = mock_provider + yield nc, js, mock_provider async def test_publish_task_creates_stream_and_consumer(self): """Test that publish_task ensures stream and consumer exist.""" sample_task = self._create_sample_task() with self._mock_nats_setup() as (_, js, _): - js.stream_info.side_effect = Exception("Not found") - js.consumer_info.side_effect = Exception("Not found") + js.stream_info.side_effect = NotFoundError + js.consumer_info.side_effect = NotFoundError manager = TaskQueueManager() await manager.publish_task(456, sample_task) @@ -129,3 +131,103 @@ async def test_naming_conventions(self): self.assertEqual(manager._get_stream_name(123), "job_123") self.assertEqual(manager._get_subject(123), "job.123.tasks") self.assertEqual(manager._get_consumer_name(123), "job-123-consumer") + + +class TestRetryOnConnectionError(unittest.IsolatedAsyncioTestCase): + """Test retry_on_connection_error decorator behavior.""" + + def _create_sample_task(self): + """Helper to create a sample PipelineProcessingTask.""" + return PipelineProcessingTask( + id="task-retry-test", + image_id="img-retry", + image_url="https://example.com/retry.jpg", + ) + + async def test_retry_resets_provider_on_connection_error(self): + """On connection error, the decorator should call provider.reset() before retrying.""" + from nats.errors import ConnectionClosedError + + nc = MagicMock() + nc.is_closed = False + js = MagicMock() + js.stream_info = AsyncMock() + js.add_stream = AsyncMock() + js.consumer_info = AsyncMock() + js.add_consumer = AsyncMock() + # First publish fails with connection error, second succeeds + js.publish = AsyncMock(side_effect=[ConnectionClosedError(), MagicMock(seq=1)]) + js.pull_subscribe = AsyncMock() + + with patch("ami.ml.orchestration.nats_connection.get_provider") as mock_get_provider: + mock_provider = MagicMock() + mock_provider.get_connection = AsyncMock(return_value=(nc, js)) + mock_provider.reset = AsyncMock() + mock_get_provider.return_value = mock_provider + + manager = TaskQueueManager() + sample_task = self._create_sample_task() + + # Should succeed after retry + with patch("ami.ml.orchestration.nats_queue.asyncio.sleep", new_callable=AsyncMock): + result = await manager.publish_task(456, sample_task) + + self.assertTrue(result) + # provider.reset() should have been called once (after first failure) + mock_provider.reset.assert_called_once() + + async def test_retry_raises_after_max_retries(self): + """After exhausting retries, the last error should be raised.""" + from nats.errors import ConnectionClosedError + + nc = MagicMock() + nc.is_closed = False + js = MagicMock() + js.stream_info = AsyncMock() + js.add_stream = AsyncMock() + js.consumer_info = AsyncMock() + js.add_consumer = AsyncMock() + # All attempts fail + js.publish = AsyncMock(side_effect=ConnectionClosedError()) + + with patch("ami.ml.orchestration.nats_connection.get_provider") as mock_get_provider: + mock_provider = MagicMock() + mock_provider.get_connection = AsyncMock(return_value=(nc, js)) + mock_provider.reset = AsyncMock() + mock_get_provider.return_value = mock_provider + + manager = TaskQueueManager() + sample_task = self._create_sample_task() + + with 
patch("ami.ml.orchestration.nats_queue.asyncio.sleep", new_callable=AsyncMock): + with self.assertRaises(ConnectionClosedError): + await manager.publish_task(456, sample_task) + + # reset() called twice (max_retries=2, so 2 retries means 2 resets) + self.assertEqual(mock_provider.reset.call_count, 2) + + async def test_non_connection_errors_are_not_retried(self): + """Non-connection errors (e.g. ValueError) should propagate immediately without retry.""" + nc = MagicMock() + nc.is_closed = False + js = MagicMock() + js.stream_info = AsyncMock() + js.add_stream = AsyncMock() + js.consumer_info = AsyncMock() + js.add_consumer = AsyncMock() + js.publish = AsyncMock(side_effect=ValueError("bad data")) + + with patch("ami.ml.orchestration.nats_connection.get_provider") as mock_get_provider: + mock_provider = MagicMock() + mock_provider.get_connection = AsyncMock(return_value=(nc, js)) + mock_provider.reset = AsyncMock() + mock_get_provider.return_value = mock_provider + + manager = TaskQueueManager() + sample_task = self._create_sample_task() + + with self.assertRaises(ValueError): + await manager.publish_task(456, sample_task) + + # reset() should NOT have been called + mock_provider.reset.assert_not_called() diff --git a/config/settings/base.py b/config/settings/base.py index dad65ce21..31f34d748 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -266,6 +266,8 @@ # NATS # ------------------------------------------------------------------------------ NATS_URL = env("NATS_URL", default="nats://localhost:4222") # type: ignore[no-untyped-call] +# Connection strategy: "pool" (persistent, default) or "per_operation" (fresh connection each time) +NATS_CONNECTION_STRATEGY = env("NATS_CONNECTION_STRATEGY", default="pool") # type: ignore[no-untyped-call] # ADMIN # ------------------------------------------------------------------------------ From 4d66c07e3ed6cd73a9033777f9fae848acac02d5 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 13 Feb 2026 14:29:31 -0800 Subject: [PATCH 11/16] =?UTF-8?q?refactor:=20simplify=20NATS=20connection?= =?UTF-8?q?=20module=20=E2=80=94=20pool-only,=20archive=20original?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove the switchable strategy pattern (Protocol + factory + Django setting) and expose the connection pool directly via module-level get_connection() and reset_connection() functions. The PerOperationConnection is archived as ContextManagerConnection for reference. Remove NATS_CONNECTION_STRATEGY setting. Co-Authored-By: Claude --- ami/ml/orchestration/nats_connection.py | 160 +++++++--------- ami/ml/orchestration/nats_queue.py | 27 ++- .../tests/test_nats_connection.py | 173 ++++-------------- ami/ml/orchestration/tests/test_nats_queue.py | 75 ++++---- config/settings/base.py | 2 - 5 files changed, 166 insertions(+), 271 deletions(-) diff --git a/ami/ml/orchestration/nats_connection.py b/ami/ml/orchestration/nats_connection.py index 7edc29cc7..b2afc7699 100644 --- a/ami/ml/orchestration/nats_connection.py +++ b/ami/ml/orchestration/nats_connection.py @@ -1,42 +1,34 @@ """ NATS connection management for both Celery workers and Django processes. -Provides two connection strategies, selectable via NATS_CONNECTION_STRATEGY setting: +Uses a persistent connection pool (one NATS connection per event loop) for all +TaskQueueManager operations. A 1000-image job generates ~1500+ NATS operations +(1 for queuing, 250-500 for task fetches, 1000 for ACKs). 
The pool keeps one +connection alive per event loop and reuses it for all of them. - "pool" (default): - Maintains a persistent NATS connection per event loop and reuses it across - all TaskQueueManager operations. A 1000-image job generates ~1500+ NATS - operations (1 for queuing, 250-500 for task fetches, 1000 for ACKs). The - pool keeps one connection alive per event loop and reuses it for all of them. - - "per_operation": - Creates a fresh TCP connection for every get_connection() call. Simple but - expensive — the same 1000-image job opens ~1500 TCP connections. Use this - only for debugging connection issues or when pooling causes problems. - -Why keyed by event loop (pool strategy): +Why keyed by event loop: Django views and Celery tasks use async_to_sync(), which creates a new event loop per thread. asyncio.Lock and nats.Client are bound to the loop they were created on, so sharing them across loops causes "attached to a different loop" errors. Keying by loop ensures isolation. WeakKeyDictionary auto-cleans when loops are garbage collected, so short-lived loops don't leak. -Connection lifecycle (pool strategy): +Connection lifecycle: - Created lazily on first use within an event loop - Reused for all subsequent operations on that loop - - On connection error: retry decorator calls pool.reset() to close the stale - connection; next operation creates a fresh one (see retry_on_connection_error - in nats_queue.py) + - On connection error: retry decorator calls reset_connection() to close the + stale connection; next operation creates a fresh one (see + retry_on_connection_error in nats_queue.py) - Cleaned up automatically when the event loop is garbage collected -Both strategies implement the same interface (get_connection, reset, close) so -TaskQueueManager is agnostic to which one is active. +Archived alternative: + ContextManagerConnection preserves the original pre-pool implementation + (one connection per `async with` block) as a drop-in fallback. """ import asyncio import logging import threading -import typing from typing import TYPE_CHECKING from weakref import WeakKeyDictionary @@ -50,19 +42,6 @@ logger = logging.getLogger(__name__) -class ConnectionProvider(typing.Protocol): - """Interface that all NATS connection strategies must implement.""" - - async def get_connection(self) -> tuple["NATSClient", JetStreamContext]: - ... - - async def reset(self) -> None: - ... - - async def close(self) -> None: - ... - - class ConnectionPool: """ Manages a single persistent NATS connection per event loop. @@ -158,93 +137,86 @@ async def reset(self): self._lock = None # Clear lock so new one is created for fresh connection -class PerOperationConnection: +class ContextManagerConnection: """ - Creates a fresh NATS connection on every get_connection() call. + Archived pre-pool implementation: one NATS connection per `async with` block. - Each call closes the previous connection (if any) and opens a new one. - This avoids any shared state but is expensive: ~1500 TCP round-trips per - 1000-image job vs. 1 with the pool strategy. + This was the original approach before the connection pool was added. It creates + a fresh connection on get_connection() and expects the caller to close it when + done. There is no connection reuse and no retry logic at this layer. - Use for debugging connection lifecycle issues or when the pool causes - problems (e.g. event loop mismatch edge cases). 
- """ + Trade-offs vs ConnectionPool: + - Simpler: no shared state, no locking, no event-loop keying + - Expensive: ~1500 TCP connections per 1000-image job vs 1 with the pool + - No automatic reconnection — caller must handle connection failures - def __init__(self): - self._nc: "NATSClient | None" = None - self._js: JetStreamContext | None = None + Kept as a drop-in fallback. To switch, change the class used in + _create_pool() below from ConnectionPool to ContextManagerConnection. + """ async def get_connection(self) -> tuple["NATSClient", JetStreamContext]: - """Create a fresh NATS connection, closing any previous one.""" - # Close previous connection if it exists - await self.close() - + """Create a fresh NATS connection.""" nats_url = settings.NATS_URL try: - logger.debug(f"Creating transient NATS connection to {nats_url}") - self._nc = await nats.connect(nats_url) - self._js = self._nc.jetstream() - return self._nc, self._js + logger.debug(f"Creating per-operation NATS connection to {nats_url}") + nc = await nats.connect(nats_url) + js = nc.jetstream() + return nc, js except Exception as e: logger.error(f"Failed to connect to NATS: {e}") raise RuntimeError(f"Could not establish NATS connection: {e}") from e async def close(self): - """Close the current connection if it exists.""" - if self._nc is not None and not self._nc.is_closed: - try: - await self._nc.close() - except Exception as e: - logger.debug(f"Error closing transient connection (expected): {e}") - self._nc = None - self._js = None + """No-op — connections are not tracked.""" + pass async def reset(self): - """Close and clear — next get_connection() creates a fresh one.""" - await self.close() + """No-op — connections are not tracked.""" + pass -# Event-loop-keyed providers: one per event loop. +# Event-loop-keyed pools: one ConnectionPool per event loop. # WeakKeyDictionary automatically cleans up when event loops are garbage collected. -_providers: WeakKeyDictionary[asyncio.AbstractEventLoop, ConnectionProvider] = WeakKeyDictionary() -_providers_lock = threading.Lock() +_pools: WeakKeyDictionary[asyncio.AbstractEventLoop, ConnectionPool] = WeakKeyDictionary() +_pools_lock = threading.Lock() -def _create_provider() -> ConnectionProvider: - """Create a connection provider based on the NATS_CONNECTION_STRATEGY setting.""" - strategy = settings.NATS_CONNECTION_STRATEGY - if strategy == "per_operation": - logger.info("Using NATS connection strategy: per_operation") - return PerOperationConnection() - logger.info("Using NATS connection strategy: pool (persistent)") - return ConnectionPool() +def _get_pool() -> ConnectionPool: + """Get or create the ConnectionPool for the current event loop.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + raise RuntimeError( + "get_connection() must be called from an async context with a running event loop. " + "If calling from sync code, use async_to_sync() to wrap the async function." + ) + with _pools_lock: + if loop not in _pools: + _pools[loop] = ConnectionPool() + logger.debug(f"Created NATS connection pool for event loop {id(loop)}") + return _pools[loop] -def get_provider() -> ConnectionProvider: - """ - Get or create the connection provider for the current event loop. - Each event loop gets its own provider to prevent "attached to a different loop" - errors. 
The provider type is determined by the NATS_CONNECTION_STRATEGY setting: - - "pool" (default): persistent connection reuse via ConnectionPool - - "per_operation": fresh connection each time via PerOperationConnection +async def get_connection() -> tuple["NATSClient", JetStreamContext]: + """ + Get or create a NATS connection for the current event loop. Returns: - ConnectionProvider bound to the current event loop - + Tuple of (NATS connection, JetStream context) Raises: RuntimeError: If called outside of an async context (no running event loop) """ - try: - loop = asyncio.get_running_loop() - except RuntimeError: - raise RuntimeError( - "get_provider() must be called from an async context with a running event loop. " - "If calling from sync code, use async_to_sync() to wrap the async function." - ) + pool = _get_pool() + return await pool.get_connection() - with _providers_lock: - if loop not in _providers: - _providers[loop] = _create_provider() - logger.debug(f"Created NATS connection provider for event loop {id(loop)}") - return _providers[loop] + +async def reset_connection() -> None: + """ + Reset the NATS connection for the current event loop. + + Closes the current connection and clears all state so the next call to + get_connection() creates a fresh one. + """ + pool = _get_pool() + await pool.reset() diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 29f17f08f..d481d29a6 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -47,6 +47,16 @@ def retry_on_connection_error(max_retries: int = 2, backoff_seconds: float = 0.5 This works correctly with async_to_sync() because the pool is keyed by event loop, ensuring each retry uses the connection bound to the current loop. + Retried error types: + - ConnectionClosedError: server closed the connection + - NoServersError: cannot reach any NATS server + - TimeoutError: NATS operation timed out + - ConnectionReconnectingError: client is mid-reconnect + - StaleConnectionError: client detected a stale connection + - OSError: network-level failures — includes ConnectionRefusedError, + ConnectionResetError, BrokenPipeError, socket.timeout, and other + OS-level socket/DNS errors + Args: max_retries: Maximum number of retry attempts (default: 2) backoff_seconds: Initial backoff time in seconds (default: 0.5) @@ -68,7 +78,8 @@ async def wrapper(self, *args, **kwargs) -> T: nats_errors.NoServersError, nats_errors.TimeoutError, nats_errors.ConnectionReconnectingError, - OSError, # Network errors + nats_errors.StaleConnectionError, + OSError, # ConnectionRefusedError, ConnectionResetError, BrokenPipeError, etc. 
) as e: last_error = e # Don't retry on last attempt @@ -78,11 +89,10 @@ async def wrapper(self, *args, **kwargs) -> T: exc_info=True, ) break - # Reset the connection provider so next attempt gets a fresh connection - from ami.ml.orchestration.nats_connection import get_provider + # Reset the connection so next attempt gets a fresh one + from ami.ml.orchestration.nats_connection import reset_connection - provider = get_provider() - await provider.reset() + await reset_connection() # Exponential backoff wait_time = backoff_seconds * (2**attempt) logger.warning( @@ -119,11 +129,10 @@ class TaskQueueManager: """ async def _get_connection(self) -> tuple["NATSClient", JetStreamContext]: - """Get connection from the event-loop-local provider.""" - from ami.ml.orchestration.nats_connection import get_provider + """Get connection from the event-loop-local pool.""" + from ami.ml.orchestration.nats_connection import get_connection - provider = get_provider() - return await provider.get_connection() + return await get_connection() def _get_stream_name(self, job_id: int) -> str: """Get stream name from job_id.""" diff --git a/ami/ml/orchestration/tests/test_nats_connection.py b/ami/ml/orchestration/tests/test_nats_connection.py index 250a7a980..286995977 100644 --- a/ami/ml/orchestration/tests/test_nats_connection.py +++ b/ami/ml/orchestration/tests/test_nats_connection.py @@ -4,39 +4,7 @@ import unittest from unittest.mock import AsyncMock, MagicMock, patch -from ami.ml.orchestration.nats_connection import ConnectionPool, PerOperationConnection, _create_provider - - -class TestCreateProvider(unittest.TestCase): - """Test _create_provider() returns the correct strategy based on settings.""" - - @patch("ami.ml.orchestration.nats_connection.settings") - def test_default_strategy_returns_connection_pool(self, mock_settings): - """Test that default (unspecified) strategy returns ConnectionPool.""" - mock_settings.NATS_CONNECTION_STRATEGY = "pool" - provider = _create_provider() - self.assertIsInstance(provider, ConnectionPool) - - @patch("ami.ml.orchestration.nats_connection.settings") - def test_per_operation_strategy_returns_per_operation_connection(self, mock_settings): - """Test that per_operation strategy returns PerOperationConnection.""" - mock_settings.NATS_CONNECTION_STRATEGY = "per_operation" - provider = _create_provider() - self.assertIsInstance(provider, PerOperationConnection) - - @patch("ami.ml.orchestration.nats_connection.settings") - def test_unknown_strategy_falls_back_to_pool(self, mock_settings): - """Test that unknown strategy falls back to ConnectionPool.""" - mock_settings.NATS_CONNECTION_STRATEGY = "unknown_value" - provider = _create_provider() - self.assertIsInstance(provider, ConnectionPool) - - @patch("ami.ml.orchestration.nats_connection.settings") - def test_empty_string_strategy_falls_back_to_pool(self, mock_settings): - """Test that empty string strategy falls back to ConnectionPool.""" - mock_settings.NATS_CONNECTION_STRATEGY = "" - provider = _create_provider() - self.assertIsInstance(provider, ConnectionPool) +from ami.ml.orchestration.nats_connection import ConnectionPool class TestConnectionPoolBehavior(unittest.IsolatedAsyncioTestCase): @@ -190,147 +158,82 @@ async def test_reset_closes_and_clears_state(self, mock_settings, mock_nats): self.assertIsNone(pool._lock) -class TestPerOperationConnectionBehavior(unittest.IsolatedAsyncioTestCase): - """Test PerOperationConnection lifecycle behavior.""" +class TestModuleLevelFunctions(unittest.IsolatedAsyncioTestCase): + 
"""Test module-level get_connection() and reset_connection() functions.""" @patch("ami.ml.orchestration.nats_connection.nats") @patch("ami.ml.orchestration.nats_connection.settings") - async def test_get_connection_creates_connection(self, mock_settings, mock_nats): - """Test that get_connection creates a fresh connection.""" - mock_settings.NATS_URL = "nats://test:4222" + async def test_get_connection_returns_connection(self, mock_settings, mock_nats): + """Test that module-level get_connection() returns a NATS connection.""" + from ami.ml.orchestration.nats_connection import _pools, get_connection + mock_settings.NATS_URL = "nats://test:4222" mock_nc = MagicMock() mock_nc.is_closed = False + mock_nc.is_connected = True mock_nc.close = AsyncMock() mock_js = MagicMock() mock_nc.jetstream.return_value = mock_js - mock_nats.connect = AsyncMock(return_value=mock_nc) - conn = PerOperationConnection() - nc, js = await conn.get_connection() + # Clear pools to avoid leaking state between tests + _pools.clear() + + nc, js = await get_connection() self.assertIs(nc, mock_nc) self.assertIs(js, mock_js) - mock_nats.connect.assert_called_once_with("nats://test:4222") - - @patch("ami.ml.orchestration.nats_connection.nats") - @patch("ami.ml.orchestration.nats_connection.settings") - async def test_get_connection_closes_previous(self, mock_settings, mock_nats): - """Test that each get_connection() call closes the previous connection.""" - mock_settings.NATS_URL = "nats://test:4222" - - # Create two mock connections (MagicMock because jetstream() is sync) - mock_nc1 = MagicMock() - mock_nc1.is_closed = False - mock_nc1.close = AsyncMock() - mock_nc1.jetstream.return_value = MagicMock() - - mock_nc2 = MagicMock() - mock_nc2.is_closed = False - mock_nc2.close = AsyncMock() - mock_nc2.jetstream.return_value = MagicMock() - - mock_nats.connect = AsyncMock(side_effect=[mock_nc1, mock_nc2]) - - conn = PerOperationConnection() - - # First call - nc1, _ = await conn.get_connection() - self.assertIs(nc1, mock_nc1) - - # Second call should close the first connection - nc2, _ = await conn.get_connection() - self.assertIs(nc2, mock_nc2) - - # Verify first connection was closed - mock_nc1.close.assert_called_once() @patch("ami.ml.orchestration.nats_connection.nats") @patch("ami.ml.orchestration.nats_connection.settings") - async def test_get_connection_handles_close_errors(self, mock_settings, mock_nats): - """Test that get_connection handles errors when closing previous connection.""" - mock_settings.NATS_URL = "nats://test:4222" - - # First connection throws error on close - mock_nc1 = MagicMock() - mock_nc1.is_closed = False - mock_nc1.close = AsyncMock(side_effect=RuntimeError("Close error")) - mock_nc1.jetstream.return_value = MagicMock() - - # Second connection succeeds - mock_nc2 = MagicMock() - mock_nc2.is_closed = False - mock_nc2.close = AsyncMock() - mock_nc2.jetstream.return_value = MagicMock() - - mock_nats.connect = AsyncMock(side_effect=[mock_nc1, mock_nc2]) - - conn = PerOperationConnection() - - # First call - await conn.get_connection() + async def test_get_connection_reuses_pool_for_same_loop(self, mock_settings, mock_nats): + """Test that repeated calls on the same loop reuse the same pool.""" + from ami.ml.orchestration.nats_connection import _pools, get_connection - # Second call should not raise even though closing first connection fails - nc2, _ = await conn.get_connection() - self.assertIs(nc2, mock_nc2) - - @patch("ami.ml.orchestration.nats_connection.nats") - 
@patch("ami.ml.orchestration.nats_connection.settings") - async def test_reset_closes_connection(self, mock_settings, mock_nats): - """Test that reset() closes the current connection.""" mock_settings.NATS_URL = "nats://test:4222" - mock_nc = MagicMock() mock_nc.is_closed = False + mock_nc.is_connected = True mock_nc.close = AsyncMock() - mock_nc.jetstream.return_value = MagicMock() - + mock_js = MagicMock() + mock_nc.jetstream.return_value = mock_js mock_nats.connect = AsyncMock(return_value=mock_nc) - conn = PerOperationConnection() - await conn.get_connection() + _pools.clear() - await conn.reset() + await get_connection() + await get_connection() - mock_nc.close.assert_called_once() - # After reset, internal state should be cleared - self.assertIsNone(conn._nc) - self.assertIsNone(conn._js) + # Only one TCP connection should have been created + mock_nats.connect.assert_called_once() @patch("ami.ml.orchestration.nats_connection.nats") @patch("ami.ml.orchestration.nats_connection.settings") - async def test_close_closes_connection(self, mock_settings, mock_nats): - """Test that close() closes the connection.""" - mock_settings.NATS_URL = "nats://test:4222" + async def test_reset_connection_clears_pool(self, mock_settings, mock_nats): + """Test that reset_connection() resets the pool for the current loop.""" + from ami.ml.orchestration.nats_connection import _pools, get_connection, reset_connection + mock_settings.NATS_URL = "nats://test:4222" mock_nc = MagicMock() mock_nc.is_closed = False + mock_nc.is_connected = True mock_nc.close = AsyncMock() - mock_nc.jetstream.return_value = MagicMock() - + mock_js = MagicMock() + mock_nc.jetstream.return_value = mock_js mock_nats.connect = AsyncMock(return_value=mock_nc) - conn = PerOperationConnection() - await conn.get_connection() + _pools.clear() - await conn.close() + await get_connection() + await reset_connection() mock_nc.close.assert_called_once() - self.assertIsNone(conn._nc) - self.assertIsNone(conn._js) - @patch("ami.ml.orchestration.nats_connection.nats") - @patch("ami.ml.orchestration.nats_connection.settings") - async def test_get_connection_raises_on_connection_error(self, mock_settings, mock_nats): - """Test that get_connection raises RuntimeError on connection failure.""" - mock_settings.NATS_URL = "nats://test:4222" - - mock_nats.connect = AsyncMock(side_effect=ConnectionError("Connection failed")) - - conn = PerOperationConnection() + def test_get_connection_raises_without_event_loop(self): + """Test that _get_pool raises RuntimeError outside async context.""" + from ami.ml.orchestration.nats_connection import _get_pool with self.assertRaises(RuntimeError) as context: - await conn.get_connection() + _get_pool() - self.assertIn("Could not establish NATS connection", str(context.exception)) + self.assertIn("must be called from an async context", str(context.exception)) diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py index 7367d7b42..7ad7c281d 100644 --- a/ami/ml/orchestration/tests/test_nats_queue.py +++ b/ami/ml/orchestration/tests/test_nats_queue.py @@ -42,11 +42,9 @@ def _mock_nats_setup(self): js.delete_consumer = AsyncMock() js.delete_stream = AsyncMock() - with patch("ami.ml.orchestration.nats_connection.get_provider") as mock_get_provider: - mock_provider = MagicMock() - mock_provider.get_connection = AsyncMock(return_value=(nc, js)) - mock_get_provider.return_value = mock_provider - yield nc, js, mock_provider + with 
patch("ami.ml.orchestration.nats_connection.get_connection", new_callable=AsyncMock) as mock_get_conn: + mock_get_conn.return_value = (nc, js) + yield nc, js, mock_get_conn async def test_publish_task_creates_stream_and_consumer(self): """Test that publish_task ensures stream and consumer exist.""" @@ -144,8 +142,8 @@ def _create_sample_task(self): image_url="https://example.com/retry.jpg", ) - async def test_retry_resets_provider_on_connection_error(self): - """On connection error, the decorator should call provider.reset() before retrying.""" + async def test_retry_resets_connection_on_error(self): + """On connection error, the decorator should call reset_connection() before retrying.""" from nats.errors import ConnectionClosedError nc = MagicMock() @@ -159,12 +157,17 @@ async def test_retry_resets_provider_on_connection_error(self): js.publish = AsyncMock(side_effect=[ConnectionClosedError(), MagicMock(seq=1)]) js.pull_subscribe = AsyncMock() - with patch("ami.ml.orchestration.nats_connection.get_provider") as mock_get_provider: - mock_provider = MagicMock() - mock_provider.get_connection = AsyncMock(return_value=(nc, js)) - mock_provider.reset = AsyncMock() - mock_get_provider.return_value = mock_provider - + with ( + patch( + "ami.ml.orchestration.nats_connection.get_connection", + new_callable=AsyncMock, + return_value=(nc, js), + ), + patch( + "ami.ml.orchestration.nats_connection.reset_connection", + new_callable=AsyncMock, + ) as mock_reset, + ): manager = TaskQueueManager() sample_task = self._create_sample_task() @@ -173,8 +176,8 @@ async def test_retry_resets_provider_on_connection_error(self): result = await manager.publish_task(456, sample_task) self.assertTrue(result) - # provider.reset() should have been called once (after first failure) - mock_provider.reset.assert_called_once() + # reset_connection() should have been called once (after first failure) + mock_reset.assert_called_once() async def test_retry_raises_after_max_retries(self): """After exhausting retries, the last error should be raised.""" @@ -190,12 +193,17 @@ async def test_retry_raises_after_max_retries(self): # All attempts fail js.publish = AsyncMock(side_effect=ConnectionClosedError()) - with patch("ami.ml.orchestration.nats_connection.get_provider") as mock_get_provider: - mock_provider = MagicMock() - mock_provider.get_connection = AsyncMock(return_value=(nc, js)) - mock_provider.reset = AsyncMock() - mock_get_provider.return_value = mock_provider - + with ( + patch( + "ami.ml.orchestration.nats_connection.get_connection", + new_callable=AsyncMock, + return_value=(nc, js), + ), + patch( + "ami.ml.orchestration.nats_connection.reset_connection", + new_callable=AsyncMock, + ) as mock_reset, + ): manager = TaskQueueManager() sample_task = self._create_sample_task() @@ -203,8 +211,8 @@ async def test_retry_raises_after_max_retries(self): with self.assertRaises(ConnectionClosedError): await manager.publish_task(456, sample_task) - # reset() called twice (max_retries=2, so 2 retries means 2 resets) - self.assertEqual(mock_provider.reset.call_count, 2) + # reset_connection() called twice (max_retries=2, so 2 retries means 2 resets) + self.assertEqual(mock_reset.call_count, 2) async def test_non_connection_errors_are_not_retried(self): """Non-connection errors (e.g. 
ValueError) should propagate immediately without retry.""" @@ -217,17 +225,22 @@ async def test_non_connection_errors_are_not_retried(self): js.add_consumer = AsyncMock() js.publish = AsyncMock(side_effect=ValueError("bad data")) - with patch("ami.ml.orchestration.nats_connection.get_provider") as mock_get_provider: - mock_provider = MagicMock() - mock_provider.get_connection = AsyncMock(return_value=(nc, js)) - mock_provider.reset = AsyncMock() - mock_get_provider.return_value = mock_provider - + with ( + patch( + "ami.ml.orchestration.nats_connection.get_connection", + new_callable=AsyncMock, + return_value=(nc, js), + ), + patch( + "ami.ml.orchestration.nats_connection.reset_connection", + new_callable=AsyncMock, + ) as mock_reset, + ): manager = TaskQueueManager() sample_task = self._create_sample_task() with self.assertRaises(ValueError): await manager.publish_task(456, sample_task) - # reset() should NOT have been called - mock_provider.reset.assert_not_called() + # reset_connection() should NOT have been called + mock_reset.assert_not_called() diff --git a/config/settings/base.py b/config/settings/base.py index 31f34d748..dad65ce21 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -266,8 +266,6 @@ # NATS # ------------------------------------------------------------------------------ NATS_URL = env("NATS_URL", default="nats://localhost:4222") # type: ignore[no-untyped-call] -# Connection strategy: "pool" (persistent, default) or "per_operation" (fresh connection each time) -NATS_CONNECTION_STRATEGY = env("NATS_CONNECTION_STRATEGY", default="pool") # type: ignore[no-untyped-call] # ADMIN # ------------------------------------------------------------------------------ From 41bbeb32b92abd01abe7588fe9dd0011d4b09f08 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 13 Feb 2026 15:13:02 -0800 Subject: [PATCH 12/16] docs: clarify where connection pool provides reuse vs. single-use The docstring previously implied the pool reused connections across all async_to_sync() calls. In practice, each async_to_sync() creates a new event loop, so reuse only happens within a single boundary. Updated to be explicit about where the pool helps (bulk publishes, batch reserves) and where it doesn't (single-operation calls like ACKs). Co-Authored-By: Claude --- ami/ml/orchestration/nats_connection.py | 39 ++++++++++++++----------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/ami/ml/orchestration/nats_connection.py b/ami/ml/orchestration/nats_connection.py index b2afc7699..83ac9b3cb 100644 --- a/ami/ml/orchestration/nats_connection.py +++ b/ami/ml/orchestration/nats_connection.py @@ -1,25 +1,30 @@ """ NATS connection management for both Celery workers and Django processes. -Uses a persistent connection pool (one NATS connection per event loop) for all -TaskQueueManager operations. A 1000-image job generates ~1500+ NATS operations -(1 for queuing, 250-500 for task fetches, 1000 for ACKs). The pool keeps one -connection alive per event loop and reuses it for all of them. +Provides a ConnectionPool keyed by event loop. The pool reuses a single NATS +connection for all async operations *within* one async_to_sync() boundary. +It does NOT provide reuse across separate async_to_sync() calls — each call +creates a new event loop, so a new connection is established. + +Where the pool helps: + The main beneficiary is queue_images_to_nats() in jobs.py, which wraps + 1000+ publish_task() awaits in a single async_to_sync() call. 
All of those + awaits share one event loop and therefore one NATS connection. Without the + pool, each publish would open its own TCP connection (~1500 per job). + Similarly, JobViewSet.tasks() batches multiple reserve_task() calls in one + async_to_sync() boundary. + +Where it doesn't help: + Single-operation boundaries like _ack_task_via_nats() (one ACK per call) + get no reuse — the pool is effectively single-use there. The overhead is + negligible (one dict lookup), and the retry_on_connection_error decorator + provides resilience regardless. Why keyed by event loop: - Django views and Celery tasks use async_to_sync(), which creates a new event - loop per thread. asyncio.Lock and nats.Client are bound to the loop they were - created on, so sharing them across loops causes "attached to a different loop" - errors. Keying by loop ensures isolation. WeakKeyDictionary auto-cleans when - loops are garbage collected, so short-lived loops don't leak. - -Connection lifecycle: - - Created lazily on first use within an event loop - - Reused for all subsequent operations on that loop - - On connection error: retry decorator calls reset_connection() to close the - stale connection; next operation creates a fresh one (see - retry_on_connection_error in nats_queue.py) - - Cleaned up automatically when the event loop is garbage collected + asyncio.Lock and nats.Client are bound to the loop they were created on. + Sharing them across loops causes "attached to a different loop" errors. + Keying by loop ensures isolation. WeakKeyDictionary auto-cleans when loops + are garbage collected, so short-lived loops don't leak. Archived alternative: ContextManagerConnection preserves the original pre-pool implementation From ead53d1c8408c146bd32f9e3934162141599fabe Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 13 Feb 2026 15:15:04 -0800 Subject: [PATCH 13/16] fix: use `from None` to suppress noisy exception chain in _get_pool Co-Authored-By: Claude --- ami/ml/orchestration/nats_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ami/ml/orchestration/nats_connection.py b/ami/ml/orchestration/nats_connection.py index 83ac9b3cb..30a78dc69 100644 --- a/ami/ml/orchestration/nats_connection.py +++ b/ami/ml/orchestration/nats_connection.py @@ -194,7 +194,7 @@ def _get_pool() -> ConnectionPool: raise RuntimeError( "get_connection() must be called from an async context with a running event loop. " "If calling from sync code, use async_to_sync() to wrap the async function." 
- ) + ) from None with _pools_lock: if loop not in _pools: From 9737301bbf277adbceb62ecfa67f3f1e313f26dd Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 13 Feb 2026 15:15:13 -0800 Subject: [PATCH 14/16] docs: update AGENTS.md test commands to use docker-compose.ci.yml Co-Authored-By: Claude --- .agents/AGENTS.md | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/.agents/AGENTS.md b/.agents/AGENTS.md index 8acceeb4d..11aac41fe 100644 --- a/.agents/AGENTS.md +++ b/.agents/AGENTS.md @@ -107,29 +107,34 @@ docker compose restart django celeryworker ### Backend (Django) -Run tests: +Run tests (use `docker-compose.ci.yml` to avoid conflicts with the local dev stack): +```bash +docker compose -f docker-compose.ci.yml run --rm django python manage.py test +``` + +Run a specific test module: ```bash -docker compose run --rm django python manage.py test +docker compose -f docker-compose.ci.yml run --rm django python manage.py test ami.ml.orchestration.tests.test_nats_connection ``` Run specific test pattern: ```bash -docker compose run --rm django python manage.py test -k pattern +docker compose -f docker-compose.ci.yml run --rm django python manage.py test -k pattern ``` Run tests with debugger on failure: ```bash -docker compose run --rm django python manage.py test -k pattern --failfast --pdb +docker compose -f docker-compose.ci.yml run --rm django python manage.py test -k pattern --failfast --pdb ``` Speed up test development (reuse database): ```bash -docker compose run --rm django python manage.py test --keepdb +docker compose -f docker-compose.ci.yml run --rm django python manage.py test --keepdb ``` Run pytest (alternative test runner): ```bash -docker compose run --rm django pytest --ds=config.settings.test --reuse-db +docker compose -f docker-compose.ci.yml run --rm django pytest --ds=config.settings.test --reuse-db ``` Django shell: @@ -654,13 +659,13 @@ images = SourceImage.objects.annotate(det_count=Count('detections')) ```bash # Run specific test class -docker compose run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase +docker compose -f docker-compose.ci.yml run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase # Run specific test method -docker compose run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase.test_project_creation +docker compose -f docker-compose.ci.yml run --rm django python manage.py test ami.main.tests.test_models.ProjectTestCase.test_project_creation # Run with pattern matching -docker compose run --rm django python manage.py test -k test_detection +docker compose -f docker-compose.ci.yml run --rm django python manage.py test -k test_detection ``` ### Pre-commit Hooks From fa0f84b892674bab63801b6fb015a7f78bd106c7 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 13 Feb 2026 17:57:47 -0800 Subject: [PATCH 15/16] fix: correct mock setup in NATS task tests to match plain instantiation The _setup_mock_nats helper was configuring TaskQueueManager as an async context manager (__aenter__/__aexit__), but _ack_task_via_nats uses plain instantiation. The await on a non-awaitable MagicMock failed silently in the except clause, causing acknowledge_task assertions to always fail. 
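A minimal sketch of that failure mode (illustrative only, not part of the patch; the
mock target and reply subject are placeholders): under the old setup the code under
test received a plain MagicMock from the patched class, so the awaited call raised
TypeError, which the except clause in the production code hid, and the configured
AsyncMock was never reached.

    from unittest.mock import MagicMock

    async def demo():
        manager = MagicMock()  # what TaskQueueManager() returned under the old mock setup
        try:
            await manager.acknowledge_task("reply.subject")
        except TypeError:
            # "object MagicMock can't be used in 'await' expression" is raised here;
            # the production code's except clause hid it, so tests saw no ACK call.
            pass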
Co-Authored-By: Claude --- ami/jobs/test_tasks.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ami/jobs/test_tasks.py b/ami/jobs/test_tasks.py index fc291c8ba..03c797d9c 100644 --- a/ami/jobs/test_tasks.py +++ b/ami/jobs/test_tasks.py @@ -73,10 +73,8 @@ def tearDown(self): def _setup_mock_nats(self, mock_manager_class): """Helper to setup mock NATS manager.""" - mock_manager = AsyncMock() + mock_manager = mock_manager_class.return_value mock_manager.acknowledge_task = AsyncMock(return_value=True) - mock_manager_class.return_value.__aenter__.return_value = mock_manager - mock_manager_class.return_value.__aexit__.return_value = AsyncMock() return mock_manager def _create_error_result(self, image_id: str | None = None, error_msg: str = "Processing failed") -> dict: From c7b201480536c7ca49cc863496e2cd518443bd86 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 16 Feb 2026 17:01:20 -0800 Subject: [PATCH 16/16] fix: address PR review feedback for NATS connection module - Use logger.exception instead of logger.error in jobs.py for stack traces - Add explicit return None in reserve_task for empty message list - Wrap psub.unsubscribe() in try/except to prevent exception masking - Add test for TimeoutError path in reserve_task - Expand module docstring with call paths, concurrency model, and design rationale for is_reconnecting and lock clearing decisions Co-Authored-By: Claude --- ami/ml/orchestration/jobs.py | 2 +- ami/ml/orchestration/nats_connection.py | 94 ++++++++++++++----- ami/ml/orchestration/nats_queue.py | 9 +- ami/ml/orchestration/tests/test_nats_queue.py | 16 ++++ 4 files changed, 95 insertions(+), 26 deletions(-) diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index b4b9171b7..0f40e83bc 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -105,7 +105,7 @@ async def queue_all_images(): data=task, ) except Exception as e: - logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}") + logger.exception(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}") success = False if success: diff --git a/ami/ml/orchestration/nats_connection.py b/ami/ml/orchestration/nats_connection.py index 30a78dc69..2b63443ae 100644 --- a/ami/ml/orchestration/nats_connection.py +++ b/ami/ml/orchestration/nats_connection.py @@ -6,29 +6,77 @@ It does NOT provide reuse across separate async_to_sync() calls — each call creates a new event loop, so a new connection is established. -Where the pool helps: - The main beneficiary is queue_images_to_nats() in jobs.py, which wraps - 1000+ publish_task() awaits in a single async_to_sync() call. All of those - awaits share one event loop and therefore one NATS connection. Without the - pool, each publish would open its own TCP connection (~1500 per job). - Similarly, JobViewSet.tasks() batches multiple reserve_task() calls in one - async_to_sync() boundary. - -Where it doesn't help: - Single-operation boundaries like _ack_task_via_nats() (one ACK per call) - get no reuse — the pool is effectively single-use there. The overhead is - negligible (one dict lookup), and the retry_on_connection_error decorator - provides resilience regardless. - -Why keyed by event loop: - asyncio.Lock and nats.Client are bound to the loop they were created on. - Sharing them across loops causes "attached to a different loop" errors. - Keying by loop ensures isolation. WeakKeyDictionary auto-cleans when loops - are garbage collected, so short-lived loops don't leak. 
- -Archived alternative: - ContextManagerConnection preserves the original pre-pool implementation - (one connection per `async with` block) as a drop-in fallback. +Call paths and connection reuse +------------------------------- + +1. Queue images (high-value reuse): + POST /api/v2/jobs/{id}/run/ → Celery run_job → MLJob.run() + → queue_images_to_nats() wraps 1000+ sequential publish_task() awaits in + a single async_to_sync() call. All share one event loop → one connection. + Without the pool each publish would open its own TCP connection. + +2. Reserve tasks (moderate reuse): + GET /api/v2/jobs/{id}/tasks/?batch=N → JobViewSet.tasks() + → async_to_sync() wraps N sequential reserve_task() calls. Typical N=5-10. + +3. Acknowledge (single-use, no reuse): + POST /api/v2/jobs/{id}/result/ → Celery process_nats_pipeline_result + → _ack_task_via_nats() wraps a single acknowledge_task() in its own + async_to_sync() call. Each ACK gets its own event loop → new connection. + Pool overhead is negligible (one dict lookup). retry_on_connection_error + provides resilience. + +4. Cleanup (modest reuse): + Job completion → cleanup_async_job_resources() + → async_to_sync() wraps delete_consumer + delete_stream (2 ops, 1 conn). + +Concurrency model +----------------- + +All call paths use sequential for-loops, never asyncio.gather(). Within a +single async_to_sync() boundary there is only one coroutine running at a time. +This means: + +- The asyncio.Lock in ConnectionPool is defensive but never actually contends. +- reset() is only called from retry_on_connection_error between sequential + retries. No other coroutine races with it. +- Fast-path checks and state mutations outside the lock (lines ~87-94) are + safe because cooperative scheduling guarantees no preemption between + synchronous Python statements. +- Clearing self._lock in reset() is intentional — it ensures the replacement + lock is created bound to the current event loop state. + +Why NOT check is_reconnecting +----------------------------- + +Connections live inside short-lived async_to_sync() event loops. If the NATS +client enters RECONNECTING state, the event loop will typically be destroyed +before reconnection completes. Clearing the client and creating a fresh +connection is correct for this lifecycle. The retry_on_connection_error +decorator provides the real resilience layer, not nats.py's built-in +reconnection (which is designed for long-lived event loops). + +Why keyed by event loop +----------------------- + +asyncio.Lock and nats.Client are bound to the loop they were created on. +Sharing them across loops causes "attached to a different loop" errors. +Keying by loop ensures isolation. WeakKeyDictionary auto-cleans when loops +are garbage collected, so short-lived loops don't leak. + +Thread safety +------------- + +_pools_lock (threading.Lock) serializes access to the global _pools dict. +Multiple Celery worker threads can create pools concurrently — the lock +prevents races. Within a single event loop, the asyncio.Lock serializes +connection creation (though in practice it never contends, see above). + +Archived alternative +-------------------- + +ContextManagerConnection preserves the original pre-pool implementation +(one connection per `async with` block) as a drop-in fallback. 
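To make this concrete, a minimal sketch of a caller using the module-level helpers
directly (get_connection / reset_connection); it assumes it is already running on an
event loop (e.g. inside an async_to_sync() boundary), and the subject and payload are
placeholders; illustrative only, not part of the patch:

    from ami.ml.orchestration.nats_connection import get_connection, reset_connection

    async def publish_once(subject: str, payload: bytes) -> None:
        try:
            _nc, js = await get_connection()  # reused for every await on this event loop
            await js.publish(subject, payload)
        except OSError:
            await reset_connection()  # drop the stale client; the next call reconnects
            raise  # let the caller (or retry_on_connection_error) decide whether to retry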
""" import asyncio diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index d481d29a6..3fe1be355 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -267,13 +267,18 @@ async def reserve_task(self, job_id: int, timeout: float | None = None) -> Pipel logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}") return task + return None + except nats_errors.TimeoutError: # No messages available (expected behavior) logger.debug(f"No tasks available in stream for job '{job_id}'") return None finally: - # Always unsubscribe - await psub.unsubscribe() + # Unsubscribe in its own try/except so it never masks exceptions from above + try: + await psub.unsubscribe() + except Exception as e: + logger.warning(f"Failed to unsubscribe pull subscription for job '{job_id}': {e}") @retry_on_connection_error(max_retries=2, backoff_seconds=0.5) async def acknowledge_task(self, reply_subject: str) -> bool: diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py index 7ad7c281d..2c3f446fb 100644 --- a/ami/ml/orchestration/tests/test_nats_queue.py +++ b/ami/ml/orchestration/tests/test_nats_queue.py @@ -99,6 +99,22 @@ async def test_reserve_task_no_messages(self): self.assertIsNone(task) mock_psub.unsubscribe.assert_called_once() + async def test_reserve_task_timeout(self): + """Test reserve_task when fetch raises TimeoutError (no messages available).""" + from nats import errors as nats_errors + + with self._mock_nats_setup() as (_, js, _): + mock_psub = MagicMock() + mock_psub.fetch = AsyncMock(side_effect=nats_errors.TimeoutError) + mock_psub.unsubscribe = AsyncMock() + js.pull_subscribe = AsyncMock(return_value=mock_psub) + + manager = TaskQueueManager() + task = await manager.reserve_task(123) + + self.assertIsNone(task) + mock_psub.unsubscribe.assert_called_once() + async def test_acknowledge_task_success(self): """Test successful task acknowledgment.""" with self._mock_nats_setup() as (nc, _, _):