diff --git a/.agents/planning/async-job-status-handling.md b/.agents/planning/async-job-status-handling.md new file mode 100644 index 000000000..1a6024a87 --- /dev/null +++ b/.agents/planning/async-job-status-handling.md @@ -0,0 +1,408 @@ +# Plan: Fix Async Pipeline Job Status Handling + +**Date:** 2026-01-30 +**Status:** Ready for implementation + +## Problem Summary + +When `async_pipeline_workers` feature flag is enabled: +1. `queue_images_to_nats()` queues work and returns immediately (lines 400-408 in `ami/jobs/models.py`) +2. The Celery `run_job` task completes without exception +3. The `task_postrun` signal handler (`ami/jobs/tasks.py:175-192`) calls `job.update_status("SUCCESS")` +4. This prematurely marks the job as SUCCESS before async workers actually finish processing + +## Solution Overview + +Use a **progress-based approach** to determine job completion: +1. Add a generic `is_complete()` method to `JobProgress` that works for any job type +2. In `task_postrun`, only allow SUCCESS if all stages are complete +3. Allow FAILURE, REVOKED, and other terminal states to pass through immediately +4. Authoritative completion for async jobs stays in `_update_job_progress` + +--- + +## Background: Job Types and Stages + +Jobs in this app have different types, each with different stages: + +| Job Type | Stages | +|----------|--------| +| MLJob | delay (optional), collect, process, results | +| DataStorageSyncJob | data_storage_sync | +| SourceImageCollectionPopulateJob | populate_captures_collection | +| DataExportJob | exporting_data, uploading_snapshot | +| PostProcessingJob | post_processing | + +The `is_complete()` method must be generic and work for ALL job types by checking if all stages have finished. + +### Celery Signal States + +The `task_postrun` signal fires after every task with these states: +- **SUCCESS**: Task completed without exception +- **FAILURE**: Task raised exception (also triggers `task_failure`) +- **RETRY**: Task requested retry +- **REVOKED**: Task cancelled + +**Handling strategy:** +| State | Behavior | Rationale | +|-------|----------|-----------| +| SUCCESS | Guard - only set if `is_complete()` | Prevents premature success | +| FAILURE | Allow immediately | Job failed, user needs to know | +| REVOKED | Allow immediately | Job cancelled | +| RETRY | Allow immediately | Transient state | + +--- + +## Implementation Steps + +### Step 1: Add `is_complete()` method to JobProgress + +**File:** `ami/jobs/models.py` (in the `JobProgress` class, after `reset()` method around line 198) + +```python +def is_complete(self) -> bool: + """ + Check if all stages have finished processing. + + A job is considered complete when ALL of its stages have: + - progress >= 1.0 (fully processed) + - status in a final state (SUCCESS, FAILURE, or REVOKED) + + This method works for any job type regardless of which stages it has. + It's used by the Celery task_postrun signal to determine whether to + set the job's final SUCCESS status, or defer to async progress handlers. + + Related: Job.update_progress() (lines 924-947) calculates the aggregate + progress percentage across all stages for display purposes. This method + is a binary check for completion that considers both progress AND status. + + Returns: + True if all stages are complete, False otherwise. + Returns False if job has no stages (shouldn't happen in practice). 
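+
+    Illustrative example (hypothetical stage values; assumes each stage exposes
+    the progress and status fields used above):
+
+        # collect=SUCCESS@1.0, process=STARTED@0.6, results=CREATED@0.0
+        #   -> is_complete() is False (process/results are not final, progress < 1.0)
+        # collect=SUCCESS@1.0, process=SUCCESS@1.0, results=SUCCESS@1.0
+        #   -> is_complete() is True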
+ """ + if not self.stages: + return False + return all( + stage.progress >= 1.0 and stage.status in JobState.final_states() + for stage in self.stages + ) +``` + +### Step 2: Modify `task_postrun` signal handler + +**File:** `ami/jobs/tasks.py` (lines 175-192) + +Only guard SUCCESS state - let all other states pass through: + +```python +@task_postrun.connect(sender=run_job) +def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs): + from ami.jobs.models import Job, JobState + + job_id = task.request.kwargs["job_id"] + if job_id is None: + logger.error(f"Job id is None for task {task_id}") + return + try: + job = Job.objects.get(pk=job_id) + except Job.DoesNotExist: + try: + job = Job.objects.get(task_id=task_id) + except Job.DoesNotExist: + logger.error(f"No job found for task {task_id} or job_id {job_id}") + return + + # Guard only SUCCESS state - let FAILURE, REVOKED, RETRY pass through immediately + # SUCCESS should only be set when all stages are actually complete + # This prevents premature SUCCESS when async workers are still processing + if state == JobState.SUCCESS and not job.progress.is_complete(): + job.logger.info( + f"Job {job.pk} task completed but stages not finished - " + "deferring SUCCESS status to progress handler" + ) + return + + job.update_status(state) +``` + +### Step 3: Add cleanup to `_update_job_progress` + +**File:** `ami/jobs/tasks.py` (lines 151-166) + +When job completes, add cleanup for Redis and NATS resources: + +```python +def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None: + """ + Update job progress for a specific stage from async pipeline workers. + + This function is called by process_nats_pipeline_result when async workers + report progress. It updates the job's progress model and, when the job + completes, sets the final SUCCESS status. + + For async jobs, this is the authoritative place where SUCCESS status + is set - the Celery task_postrun signal defers to this function. + + Args: + job_id: The job primary key + stage: The processing stage (e.g., "process" or "results" for ML jobs) + progress_percentage: Progress as a float from 0.0 to 1.0 + """ + from ami.jobs.models import Job, JobState # avoid circular import + + with transaction.atomic(): + job = Job.objects.select_for_update().get(pk=job_id) + job.progress.update_stage( + stage, + status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, + progress=progress_percentage, + ) + + # Check if all stages are now complete + if job.progress.is_complete(): + job.status = JobState.SUCCESS + job.progress.summary.status = JobState.SUCCESS + job.finished_at = datetime.datetime.now() # Use naive datetime in local time + job.logger.info(f"Job {job_id} completed successfully - all stages finished") + + # Clean up job-specific resources (Redis state, NATS stream/consumer) + _cleanup_async_job_resources(job_id, job.logger) + + job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%") + job.save() + + +def _cleanup_async_job_resources(job_id: int, job_logger: logging.Logger) -> None: + """ + Clean up all async processing resources for a completed job. + + This function is called when an async job completes (all stages finished). + It cleans up: + + 1. 
Redis state (via TaskStateManager.cleanup): + - job:{job_id}:pending_images:process - tracks remaining images in process stage + - job:{job_id}:pending_images:results - tracks remaining images in results stage + - job:{job_id}:pending_images_total - total image count for progress calculation + + 2. NATS JetStream resources (via TaskQueueManager.cleanup_job_resources): + - Stream: job_{job_id} - the message stream that holds pending tasks + - Consumer: job-{job_id}-consumer - the durable consumer that tracks delivery + + Why cleanup is needed: + - Redis keys have a 7-day TTL but should be cleaned immediately when job completes + - NATS streams/consumers have 24-hour retention but consume server resources + - Cleaning up immediately prevents resource accumulation from many jobs + + Cleanup failures are logged but don't fail the job - data is already saved. + + Args: + job_id: The job primary key + job_logger: Logger instance for this job (writes to job's log file) + """ + # Cleanup Redis state tracking + state_manager = TaskStateManager(job_id) + state_manager.cleanup() + job_logger.info(f"Cleaned up Redis state for job {job_id}") + + # Cleanup NATS resources (stream and consumer) + try: + async def cleanup_nats(): + async with TaskQueueManager() as manager: + return await manager.cleanup_job_resources(job_id) + + success = async_to_sync(cleanup_nats)() + if success: + job_logger.info(f"Cleaned up NATS resources for job {job_id}") + else: + job_logger.warning(f"Failed to clean up NATS resources for job {job_id}") + except Exception as e: + job_logger.error(f"Error cleaning up NATS resources for job {job_id}: {e}") + # Don't fail the job if cleanup fails - job data is already saved +``` + +### Step 4: Verify existing error handling + +**File:** `ami/jobs/models.py` (lines 402-408) + +**No changes needed** - Existing code already handles FAILURE correctly: +```python +if not queued: + job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk) + job.progress.update_stage("collect", status=JobState.FAILURE) + job.update_status(JobState.FAILURE) + job.finished_at = datetime.datetime.now() + job.save() + return +``` + +--- + +## Files to Modify + +| File | Lines | Change | +|------|-------|--------| +| `ami/jobs/models.py` | ~198 (in JobProgress class) | Add generic `is_complete()` method | +| `ami/jobs/models.py` | ~700 (Job class docstring) | Add future improvements note | +| `ami/jobs/tasks.py` | 175-192 | Guard SUCCESS with `is_complete()` check | +| `ami/jobs/tasks.py` | 151-166 | Use `is_complete()`, add cleanup with docstring | +| `ami/jobs/tasks.py` | imports | Add import for `TaskStateManager` | + +**Related existing code (not modified):** +- `Job.update_progress()` at lines 924-947 - calculates aggregate progress from stages + +--- + +## Why This Approach Works + +| Job Type | Stages Complete When Task Ends? 
| `task_postrun` Behavior | +|----------|--------------------------------|------------------------| +| Sync ML job | Yes - all stages marked SUCCESS | Sets SUCCESS normally | +| Async ML job | No - stages incomplete | Skips SUCCESS, defers to async handler | +| DataStorageSyncJob | Yes - single stage | Sets SUCCESS normally | +| DataExportJob | Yes - all stages | Sets SUCCESS normally | +| Any job that fails | N/A | FAILURE passes through | +| Cancelled job | N/A | REVOKED passes through | + +--- + +## Race Condition Analysis: `update_progress()` vs `is_complete()` + +### How `update_progress()` works (lines 924-947) + +Called automatically by `save()` (line 958-959), it auto-corrects stage status/progress: + +```python +for stage in self.progress.stages: + if stage.progress > 0 and stage.status == JobState.CREATED: + stage.status = JobState.STARTED # Auto-upgrade CREATED→STARTED + elif stage.status in JobState.final_states() and stage.progress < 1: + stage.progress = 1 # If final, ensure progress=1 + elif stage.progress == 1 and stage.status not in JobState.final_states(): + stage.status = JobState.SUCCESS # Auto-upgrade to SUCCESS if progress=1 +``` + +**Key insight:** Line 939-941 auto-sets SUCCESS only if `progress == 1`. For async jobs, incomplete stages have `progress < 1`, so they won't be auto-upgraded. + +### When is `is_complete()` called? + +1. **In `task_postrun`:** Job loaded fresh from DB (already saved by `job.run()`) +2. **In `_update_job_progress`:** After `update_stage()` but before `save()` + +### Async job flow - no race condition: + +``` +run_job task: +1. job.run() → MLJob.run() +2. collect stage set to SUCCESS, progress=1.0 +3. job.save() called → update_progress() runs → no changes (only collect is 100%) +4. queue_images_to_nats() called, returns +5. job.run() returns successfully + +task_postrun fires: +6. job = Job.objects.get(pk=job_id) # Fresh from DB +7. Job state: collect=SUCCESS(1.0), process=CREATED(0), results=CREATED(0) +8. is_complete() → False (not all stages at 100% final) +9. SUCCESS status deferred ✓ + +Async workers process images: +10. process_nats_pipeline_result called for each result +11. _update_job_progress("process", 0.5) → is_complete() = False +12. _update_job_progress("process", 1.0) → is_complete() = False (results still 0) +13. _update_job_progress("results", 0.5) → is_complete() = False +14. _update_job_progress("results", 1.0) → is_complete() = True ✓ +15. Job marked SUCCESS, cleanup runs +``` + +### Why there's no race: + +1. `is_complete()` requires ALL stages to have `progress >= 1.0` AND `status in final_states()` +2. For async jobs, incomplete stages have `progress < 1.0`, so `update_progress()` won't auto-upgrade them +3. The async handler updates stages separately (process, then results), so completion only triggers when the LAST stage reaches 100% +4. `select_for_update()` in `_update_job_progress` prevents concurrent updates to the same job + +--- + +## Risk Analysis + +**Core work functions that must NOT be affected:** +1. `queue_images_to_nats()` - queues images to NATS for async processing +2. `process_nats_pipeline_result()` - processes results from async workers +3. `pipeline.save_results()` - saves detections/classifications to database +4. 
`MLJob.process_images()` - synchronous image processing path + +**Changes and their risk:** + +| Change | Risk to Core Work | Analysis | +|--------|------------------|----------| +| Add `is_complete()` to JobProgress | **None** | Pure read-only check, doesn't modify any state | +| Guard SUCCESS in `task_postrun` | **None** | Only affects status display, not actual processing. Returns early without modifying anything. | +| Use `is_complete()` in `_update_job_progress` | **None** | Replaces hardcoded `stage == "results" and progress >= 1.0` check with generic method. Called AFTER `pipeline.save_results()` completes. | +| Add `_cleanup_async_job_resources()` | **Minimal** | Called ONLY after job is marked SUCCESS. Cleanup failures are caught and logged, don't fail the job. Data is already saved at this point. | + +**Worst case scenario:** If `is_complete()` has a bug and always returns False: +- Work still completes normally (queuing, processing, saving all happen) +- Job status would stay STARTED instead of SUCCESS +- UI would show job as incomplete even though work finished +- **Data is safe** - this is a display/status issue, not a data loss risk + +**Existing function to note:** `Job.update_progress()` at lines 924-947 calculates total progress from stages. The new `is_complete()` method is a related concept - checking if all stages are done vs calculating aggregate progress. + +--- + +## Future Improvements (out of scope) + +Per user feedback, consider for follow-up: +- Split job types into different classes with clearer state management +- More robust state machine for job lifecycle +- But don't reinvent state tracking on top of existing bg task tools + +These improvements should be noted in the Job class docstring. + +--- + +## Additional Documentation + +### Add note to Job class docstring + +**File:** `ami/jobs/models.py` (Job class, around line ~700) + +Add to the class docstring: +```python +# Future improvements: +# - Consider splitting job types into subclasses with clearer state management +# - The progress/stages system (see JobProgress, update_progress()) was designed +# for UI display. The is_complete() method bridges this to actual completion logic. +# - Avoid reinventing state tracking on top of existing background task tools +``` + +--- + +## Verification + +1. **Unit test for `is_complete()`:** + - No stages -> False + - One stage at progress=0.5, STARTED -> False + - One stage at progress=1.0, STARTED -> False (not final state) + - One stage at progress=1.0, SUCCESS -> True + - Multiple stages, all SUCCESS -> True + - Multiple stages, one still STARTED -> False + +2. **Unit test for task_postrun behavior:** + - SUCCESS with incomplete stages -> status NOT changed + - SUCCESS with complete stages -> status changed to SUCCESS + - FAILURE -> passes through immediately + - REVOKED -> passes through immediately + +3. **Integration test for async ML job:** + - Create job with async_pipeline_workers enabled + - Queue images to NATS + - Verify job status stays STARTED after run_job completes + - Process all images via process_nats_pipeline_result + - Verify job status becomes SUCCESS after all stages finish + - Verify finished_at is set + - Verify Redis/NATS cleanup occurred + +4. 
**Regression test for sync jobs:** + - Run sync ML job -> SUCCESS after task completes + - Run DataStorageSyncJob -> SUCCESS after task completes + - Run DataExportJob -> SUCCESS after task completes diff --git a/.envs/.ci/.django b/.envs/.ci/.django index bec17501b..6577ebece 100644 --- a/.envs/.ci/.django +++ b/.envs/.ci/.django @@ -26,3 +26,7 @@ CELERY_BROKER_URL=amqp://rabbituser:rabbitpass@rabbitmq:5672/ CELERY_RESULT_BACKEND=rpc:// # Use RabbitMQ for results backend RABBITMQ_DEFAULT_USER=rabbituser RABBITMQ_DEFAULT_PASS=rabbitpass + +# NATS +# ------------------------------------------------------------------------------ +NATS_URL=nats://nats:4222 diff --git a/.envs/.local/.django b/.envs/.local/.django index 29780e680..8eb5610f7 100644 --- a/.envs/.local/.django +++ b/.envs/.local/.django @@ -12,6 +12,9 @@ DJANGO_SUPERUSER_PASSWORD=localadmin # Redis REDIS_URL=redis://redis:6379/0 +# NATS +NATS_URL=nats://nats:4222 + # Celery / Flower CELERY_FLOWER_USER=QSocnxapfMvzLqJXSsXtnEZqRkBtsmKT CELERY_FLOWER_PASSWORD=BEQgmCtgyrFieKNoGTsux9YIye0I7P5Q7vEgfJD2C4jxmtHDetFaE2jhS7K7rxaf diff --git a/.envs/.production/.django-example b/.envs/.production/.django-example index a54d4ae60..93737d527 100644 --- a/.envs/.production/.django-example +++ b/.envs/.production/.django-example @@ -65,3 +65,7 @@ WEB_CONCURRENCY=4 DEFAULT_PROCESSING_SERVICE_NAME="AMI Data Companion" DEFAULT_PROCESSING_SERVICE_ENDPOINT=https://ml.antenna.insectai.org/ DEFAULT_PIPELINES_ENABLED=global_moths_2024,quebec_vermont_moths_2023,panama_moths_2023,uk_denmark_moths_2023 + +# NATS +# ------------------------------------------------------------------------------ +NATS_URL=nats://nats:4222 diff --git a/README.md b/README.md index 16ddbb07f..4f1a00b2a 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,7 @@ docker compose -f processing_services/example/docker-compose.yml up -d - Django admin: http://localhost:8000/admin/ - OpenAPI / Swagger documentation: http://localhost:8000/api/v2/docs/ - Minio UI: http://minio:9001, Minio service: http://minio:9000 +- NATS dashboard: https://natsdashboard.com/ (Add localhost) NOTE: If one of these services is not working properly, it could be due another process is using the port. You can check for this with `lsof -i :`. diff --git a/ami/jobs/models.py b/ami/jobs/models.py index b94baa9a2..482d01a58 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -322,15 +322,13 @@ def run(cls, job: "Job"): """ Procedure for an ML pipeline as a job. """ + from ami.ml.orchestration.jobs import queue_images_to_nats + job.update_status(JobState.STARTED) job.started_at = datetime.datetime.now() job.finished_at = None job.save() - # Keep track of sub-tasks for saving results, pair with batch number - save_tasks: list[tuple[int, AsyncResult]] = [] - save_tasks_completed: list[tuple[int, AsyncResult]] = [] - if job.delay: update_interval_seconds = 2 last_update = time.time() @@ -365,7 +363,7 @@ def run(cls, job: "Job"): progress=0, ) - images = list( + images: list[SourceImage] = list( # @TODO return generator plus image count # @TODO pass to celery group chain? 
job.pipeline.collect_images( @@ -389,8 +387,6 @@ def run(cls, job: "Job"): images = images[: job.limit] image_count = len(images) job.progress.add_stage_param("collect", "Limit", image_count) - else: - image_count = source_image_count job.progress.update_stage( "collect", @@ -401,6 +397,24 @@ def run(cls, job: "Job"): # End image collection stage job.save() + if job.project.feature_flags.async_pipeline_workers: + queued = queue_images_to_nats(job, images) + if not queued: + job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk) + job.progress.update_stage("collect", status=JobState.FAILURE) + job.update_status(JobState.FAILURE) + job.finished_at = datetime.datetime.now() + job.save() + return + else: + cls.process_images(job, images) + + @classmethod + def process_images(cls, job, images): + image_count = len(images) + # Keep track of sub-tasks for saving results, pair with batch number + save_tasks: list[tuple[int, AsyncResult]] = [] + save_tasks_completed: list[tuple[int, AsyncResult]] = [] total_captures = 0 total_detections = 0 total_classifications = 0 diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 30c594141..6404512d1 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -1,7 +1,16 @@ +import datetime +import functools import logging +import time +from collections.abc import Callable +from asgiref.sync import async_to_sync from celery.signals import task_failure, task_postrun, task_prerun +from django.db import transaction +from ami.ml.orchestration.nats_queue import TaskQueueManager +from ami.ml.orchestration.task_state import TaskStateManager +from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse from ami.tasks import default_soft_time_limit, default_time_limit from config import celery_app @@ -29,6 +38,135 @@ def run_job(self, job_id: int) -> None: job.logger.info(f"Finished job {job}") +@celery_app.task( + bind=True, + max_retries=0, # don't retry since we already have retry logic in the NATS queue + soft_time_limit=300, # 5 minutes + time_limit=360, # 6 minutes +) +def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_subject: str) -> None: + """ + Process a single pipeline result asynchronously. + + This task: + 1. Deserializes the pipeline result + 2. Saves it to the database + 3. Updates progress by removing processed image IDs from Redis + 4. Acknowledges the task via NATS + + Args: + job_id: The job ID + result_data: Dictionary containing the pipeline result + reply_subject: NATS reply subject for acknowledgment + """ + from ami.jobs.models import Job # avoid circular import + + _, t = log_time() + + # Validate with Pydantic - check for error response first + if "error" in result_data: + error_result = PipelineResultsError(**result_data) + processed_image_ids = {str(error_result.image_id)} if error_result.image_id else set() + logger.error(f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}") + pipeline_result = None + else: + pipeline_result = PipelineResultsResponse(**result_data) + processed_image_ids = {str(img.id) for img in pipeline_result.source_images} + + state_manager = TaskStateManager(job_id) + + progress_info = state_manager.update_state(processed_image_ids, stage="process", request_id=self.request.id) + if not progress_info: + logger.warning( + f"Another task is already processing results for job {job_id}. " + f"Retrying task {self.request.id} in 5 seconds..." 
+ ) + raise self.retry(countdown=5, max_retries=10) + + try: + _update_job_progress(job_id, "process", progress_info.percentage) + + _, t = t(f"TIME: Updated job {job_id} progress in PROCESS stage progress to {progress_info.percentage*100}%") + job = Job.objects.get(pk=job_id) + job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}") + job.logger.info( + f" Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed " + f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {len(processed_image_ids)} just " + "processed" + ) + except Job.DoesNotExist: + # don't raise and ack so that we don't retry since the job doesn't exists + logger.error(f"Job {job_id} not found") + _ack_task_via_nats(reply_subject, logger) + return + + try: + # Save to database (this is the slow operation) + if pipeline_result: + # should never happen since otherwise we could not be processing results here + assert job.pipeline is not None, "Job pipeline is None" + job.pipeline.save_results(results=pipeline_result, job_id=job.pk) + job.logger.info(f"Successfully saved results for job {job_id}") + + _, t = t( + f"Saved pipeline results to database with {len(pipeline_result.detections)} detections" + f", percentage: {progress_info.percentage*100}%" + ) + + _ack_task_via_nats(reply_subject, job.logger) + # Update job stage with calculated progress + progress_info = state_manager.update_state(processed_image_ids, stage="results", request_id=self.request.id) + + if not progress_info: + logger.warning( + f"Another task is already processing results for job {job_id}. " + f"Retrying task {self.request.id} in 5 seconds..." + ) + raise self.retry(countdown=5, max_retries=10) + _update_job_progress(job_id, "results", progress_info.percentage) + + except Exception as e: + job.logger.error( + f"Failed to process pipeline result for job {job_id}: {e}. NATS will redeliver the task message." 
+ ) + + +def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None: + try: + + async def ack_task(): + async with TaskQueueManager() as manager: + return await manager.acknowledge_task(reply_subject) + + ack_success = async_to_sync(ack_task)() + + if ack_success: + job_logger.info(f"Successfully acknowledged task via NATS: {reply_subject}") + else: + job_logger.warning(f"Failed to acknowledge task via NATS: {reply_subject}") + except Exception as ack_error: + job_logger.error(f"Error acknowledging task via NATS: {ack_error}") + # Don't fail the task if ACK fails - data is already saved + + +def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None: + from ami.jobs.models import Job, JobState # avoid circular import + + with transaction.atomic(): + job = Job.objects.select_for_update().get(pk=job_id) + job.progress.update_stage( + stage, + status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, + progress=progress_percentage, + ) + if stage == "results" and progress_percentage >= 1.0: + job.status = JobState.SUCCESS + job.progress.summary.status = JobState.SUCCESS + job.finished_at = datetime.datetime.now() # Use naive datetime in local time + job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%") + job.save() + + @task_prerun.connect(sender=run_job) def pre_update_job_status(sender, task_id, task, **kwargs): # in the prerun signal, set the job status to PENDING @@ -65,3 +203,28 @@ def update_job_failure(sender, task_id, exception, *args, **kwargs): job.logger.error(f'Job #{job.pk} "{job.name}" failed: {exception}') job.save() + + +def log_time(start: float = 0, msg: str | None = None) -> tuple[float, Callable]: + """ + Small helper to measure time between calls. 
+ + Returns: elapsed time since the last call, and a partial function to measure from the current call + Usage: + + _, tlog = log_time() + # do something + _, tlog = tlog("Did something") # will log the time taken by 'something' + # do something else + t, tlog = tlog("Did something else") # will log the time taken by 'something else', returned as 't' + """ + + end = time.perf_counter() + if start == 0: + dur = 0.0 + else: + dur = end - start + if msg and start > 0: + logger.info(f"{msg}: {dur:.3f}s") + new_start = time.perf_counter() + return dur, functools.partial(log_time, new_start) diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 4d61a9ea4..8e04d9dd9 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -11,6 +11,7 @@ from ami.jobs.models import Job, JobProgress, JobState, MLJob, SourceImageCollectionPopulateJob from ami.main.models import Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline +from ami.ml.orchestration.jobs import queue_images_to_nats from ami.users.models import User logger = logging.getLogger(__name__) @@ -326,6 +327,15 @@ def test_search_jobs(self): def _task_batch_helper(self, value: Any, expected_status: int): pipeline = self._create_pipeline() job = self._create_ml_job("Job for batch test", pipeline) + images = [ + SourceImage.objects.create( + path=f"image_{i}.jpg", + public_base_url="http://example.com", + project=self.project, + ) + for i in range(8) # more than 5 since we test with batch=5 + ] + queue_images_to_nats(job, images) self.client.force_authenticate(user=self.user) tasks_url = reverse_with_params( @@ -390,10 +400,9 @@ def test_result_endpoint_stub(self): self.assertEqual(resp.status_code, 200) data = resp.json() - self.assertEqual(data["status"], "received") + self.assertEqual(data["status"], "accepted") self.assertEqual(data["job_id"], job.pk) - self.assertEqual(data["results_received"], 1) - self.assertIn("message", data) + self.assertEqual(data["results_queued"], 1) def test_result_endpoint_validation(self): """Test the result endpoint validates request data.""" diff --git a/ami/jobs/views.py b/ami/jobs/views.py index fb94fd60b..a2f087cf5 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -1,6 +1,7 @@ import logging import pydantic +from asgiref.sync import async_to_sync from django.db.models import Q from django.db.models.query import QuerySet from django.forms import IntegerField @@ -15,11 +16,10 @@ from ami.base.permissions import ObjectPermission from ami.base.views import ProjectMixin from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param +from ami.jobs.tasks import process_nats_pipeline_result from ami.main.api.schemas import project_id_doc_param - -# from ami.jobs.tasks import process_pipeline_result # TODO: Uncomment when available in main from ami.main.api.views import DefaultViewSet -from ami.ml.schemas import PipelineProcessingTask, PipelineTaskResult +from ami.ml.schemas import PipelineTaskResult from ami.utils.fields import url_boolean_param from .models import Job, JobState @@ -232,33 +232,34 @@ def tasks(self, request, pk=None): if not job.pipeline: raise ValidationError("This job does not have a pipeline configured") - # TODO: Implement task queue integration - logger.warning(f"Task queue endpoint called for job {job.pk} but the implementation is not yet available.") + # Get tasks from NATS JetStream + from ami.ml.orchestration.nats_queue import TaskQueueManager - dummy_task = PipelineProcessingTask( - id="1", - image_id="1", - 
image_url="http://example.com/image1", - queue_timestamp=timezone.now().isoformat(), - ) + async def get_tasks(): + tasks = [] + async with TaskQueueManager() as manager: + for _ in range(batch): + task = await manager.reserve_task(job.pk, timeout=0.1) + if task: + tasks.append(task.dict()) + return tasks - # @TODO when this gets fully implemented, use a Serializer or Pydantic schema - # for the full repsponse structure. - return Response({"tasks": [task.dict() for task in [dummy_task] * batch]}) + # Use async_to_sync to properly handle the async call + tasks = async_to_sync(get_tasks)() + + return Response({"tasks": tasks}) @action(detail=True, methods=["post"], name="result") def result(self, request, pk=None): """ - Submit pipeline results for asynchronous processing. + The request body should be a list of results: list[PipelineTaskResult] This endpoint accepts a list of pipeline results and queues them for - background processing. Each result will be validated and saved. - - The request body should be a list of results: list[PipelineTaskResult] + background processing. Each result will be validated, saved to the database, + and acknowledged via NATS in a Celery task. """ job = self.get_object() - job_id = job.pk # Validate request data is a list if isinstance(request.data, list): @@ -267,32 +268,55 @@ def result(self, request, pk=None): results = [request.data] try: - queued_tasks = [] + # Pre-validate all results before enqueuing any tasks + # This prevents partial queueing and duplicate task processing + validated_results = [] for item in results: task_result = PipelineTaskResult(**item) - # Stub: Log that we received the result but don't process it yet - logger.warning( - f"Result endpoint called for job {job_id} (reply_subject: {task_result.reply_subject}) " - "but result processing not yet available." 
+ validated_results.append(task_result) + + # All validation passed, now queue all tasks + queued_tasks = [] + for task_result in validated_results: + reply_subject = task_result.reply_subject + result_data = task_result.result + + # Queue the background task + # Convert Pydantic model to dict for JSON serialization + task = process_nats_pipeline_result.delay( + job_id=job.pk, result_data=result_data.dict(), reply_subject=reply_subject ) - # TODO: Implement result storage and processing queued_tasks.append( { - "reply_subject": task_result.reply_subject, - "status": "pending_implementation", - "message": "Result processing not yet implemented.", + "reply_subject": reply_subject, + "status": "queued", + "task_id": task.id, } ) + + logger.info( + f"Queued pipeline result processing for job {job.pk}, " + f"task_id: {task.id}, reply_subject: {reply_subject}" + ) + + return Response( + { + "status": "accepted", + "job_id": job.pk, + "results_queued": len([t for t in queued_tasks if t["status"] == "queued"]), + "tasks": queued_tasks, + } + ) except pydantic.ValidationError as e: raise ValidationError(f"Invalid result data: {e}") from e - return Response( - { - "status": "received", - "job_id": job_id, - "results_received": len(queued_tasks), - "tasks": queued_tasks, - "message": "Result processing not yet implemented.", - } - ) + except Exception as e: + logger.error(f"Failed to queue pipeline results for job {job.pk}: {e}") + return Response( + { + "status": "error", + "job_id": job.pk, + }, + status=500, + ) diff --git a/ami/main/models.py b/ami/main/models.py index 6f67b3e3c..1946ec3cf 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -218,6 +218,7 @@ class ProjectFeatureFlags(pydantic.BaseModel): default_filters: bool = False # Whether to show default filters form in UI # Feature flag for jobs to reprocess all images in the project, even if already processed reprocess_all_images: bool = False + async_pipeline_workers: bool = False # Whether to use async pipeline workers that pull tasks from a queue def get_default_feature_flags() -> ProjectFeatureFlags: diff --git a/ami/ml/orchestration/__init__.py b/ami/ml/orchestration/__init__.py index d05bbbd82..75c2ec3b5 100644 --- a/ami/ml/orchestration/__init__.py +++ b/ami/ml/orchestration/__init__.py @@ -1 +1,5 @@ -from .processing import * # noqa: F401, F403 +# cgjs: This creates a circular import: +# - ami.jobs.models imports ami.jobs.tasks.run_job +# - ami.jobs.tasks imports ami.ml.orchestration +# -.processing imports ami.jobs.models +# from .processing import * # noqa: F401, F403 diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py new file mode 100644 index 000000000..621d5f089 --- /dev/null +++ b/ami/ml/orchestration/jobs.py @@ -0,0 +1,114 @@ +from asgiref.sync import async_to_sync + +from ami.jobs.models import Job, JobState, logger +from ami.main.models import SourceImage +from ami.ml.orchestration.nats_queue import TaskQueueManager +from ami.ml.orchestration.task_state import TaskStateManager +from ami.ml.schemas import PipelineProcessingTask + + +# TODO CGJS: (Issue #1083) Call this once a job is fully complete (all images processed and saved) +def cleanup_nats_resources(job: "Job") -> bool: + """ + Clean up NATS JetStream resources (stream and consumer) for a completed job. 
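+
+    Illustrative usage (a sketch; assumes the JobProgress.is_complete() check
+    proposed in the planning doc in this change):
+
+        if job.progress.is_complete():
+            cleanup_nats_resources(job)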
+ + Args: + job: The Job instance + Returns: + bool: True if cleanup was successful, False otherwise + """ + + async def cleanup(): + async with TaskQueueManager() as manager: + return await manager.cleanup_job_resources(job.pk) + + return async_to_sync(cleanup)() + + +def queue_images_to_nats(job: "Job", images: list[SourceImage]): + """ + Queue all images for a job to a NATS JetStream stream for the job. + + Args: + job: The Job instance + images: List of SourceImage instances to queue + + Returns: + bool: True if all images were successfully queued, False otherwise + """ + job.logger.info(f"Queuing {len(images)} images to NATS stream for job '{job.pk}'") + + # Prepare all messages outside of async context to avoid Django ORM issues + tasks: list[tuple[int, PipelineProcessingTask]] = [] + image_ids = [] + skipped_count = 0 + for image in images: + image_id = str(image.pk) + image_url = image.url() if hasattr(image, "url") and image.url() else "" + if not image_url: + job.logger.warning(f"Image {image.pk} has no URL, skipping queuing to NATS for job '{job.pk}'") + skipped_count += 1 + continue + image_ids.append(image_id) + task = PipelineProcessingTask( + id=image_id, + image_id=image_id, + image_url=image_url, + ) + tasks.append((image.pk, task)) + + # Store all image IDs in Redis for progress tracking + state_manager = TaskStateManager(job.pk) + state_manager.initialize_job(image_ids) + job.logger.info(f"Initialized task state tracking for {len(image_ids)} images") + + async def queue_all_images(): + successful_queues = 0 + failed_queues = 0 + + async with TaskQueueManager() as manager: + for image_pk, task in tasks: + try: + logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}") + success = await manager.publish_task( + job_id=job.pk, + data=task, + ) + except Exception as e: + logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}") + success = False + + if success: + successful_queues += 1 + else: + failed_queues += 1 + + return successful_queues, failed_queues + + if tasks: + successful_queues, failed_queues = async_to_sync(queue_all_images)() + # Add skipped images to failed count + failed_queues += skipped_count + else: + # If no tasks but there are skipped images, mark as failed + if skipped_count > 0: + job.progress.update_stage("process", status=JobState.FAILURE, progress=1.0) + job.progress.update_stage("results", status=JobState.FAILURE, progress=1.0) + else: + job.progress.update_stage("process", status=JobState.SUCCESS, progress=1.0) + job.progress.update_stage("results", status=JobState.SUCCESS, progress=1.0) + job.save() + successful_queues, failed_queues = 0, skipped_count + + # Log results (back in sync context) + if successful_queues > 0: + job.logger.info(f"Successfully queued {successful_queues}/{len(images)} images to stream for job '{job.pk}'") + + if failed_queues > 0: + job.logger.warning( + f"Failed to queue {failed_queues}/{len(images)} images to stream for job '{job.pk}' (including " + f"{skipped_count} skipped images)" + ) + return False + + return True diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py new file mode 100644 index 000000000..fa7188627 --- /dev/null +++ b/ami/ml/orchestration/nats_queue.py @@ -0,0 +1,300 @@ +""" +NATS JetStream utility for task queue management in the antenna project. + +This module provides a TaskQueueManager that uses NATS JetStream for distributed +task queuing with acknowledgment support via reply subjects. 
This allows workers +to pull tasks over HTTP and acknowledge them later without maintaining a persistent +connection to NATS. + +Other queue systems were considered, such as RabbitMQ and Beanstalkd. However, they don't +support the visibility timeout semantics we want or a disconnected mode of pulling and ACKing tasks. +""" + +import json +import logging + +import nats +from django.conf import settings +from nats.js import JetStreamContext +from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy + +from ami.ml.schemas import PipelineProcessingTask + +logger = logging.getLogger(__name__) + + +async def get_connection(nats_url: str): + nc = await nats.connect(nats_url) + js = nc.jetstream() + return nc, js + + +TASK_TTR = 300 # Default Time-To-Run (visibility timeout) in seconds + + +class TaskQueueManager: + """ + Manager for NATS JetStream task queue operations. + + Use as an async context manager: + async with TaskQueueManager() as manager: + await manager.publish_task('job123', {'data': 'value'}) + task = await manager.reserve_task('job123') + await manager.acknowledge_task(task['reply_subject']) + """ + + def __init__(self, nats_url: str | None = None): + self.nats_url = nats_url or getattr(settings, "NATS_URL", "nats://nats:4222") + self.nc: nats.NATS | None = None + self.js: JetStreamContext | None = None + + async def __aenter__(self): + """Create connection on enter.""" + self.nc, self.js = await get_connection(self.nats_url) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.js: + self.js = None + if self.nc and not self.nc.is_closed: + await self.nc.close() + self.nc = None + + return False + + def _get_stream_name(self, job_id: int) -> str: + """Get stream name from job_id.""" + return f"job_{job_id}" + + def _get_subject(self, job_id: int) -> str: + """Get subject name from job_id.""" + return f"job.{job_id}.tasks" + + def _get_consumer_name(self, job_id: int) -> str: + """Get consumer name from job_id.""" + return f"job-{job_id}-consumer" + + async def _ensure_stream(self, job_id: int): + """Ensure stream exists for the given job.""" + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + stream_name = self._get_stream_name(job_id) + subject = self._get_subject(job_id) + + try: + await self.js.stream_info(stream_name) + logger.debug(f"Stream {stream_name} already exists") + except Exception as e: + logger.warning(f"Stream {stream_name} does not exist: {e}") + # Stream doesn't exist, create it + await self.js.add_stream( + name=stream_name, + subjects=[subject], + max_age=86400, # 24 hours retention + ) + logger.info(f"Created stream {stream_name}") + + async def _ensure_consumer(self, job_id: int): + """Ensure consumer exists for the given job.""" + if self.js is None: + raise RuntimeError("Connection is not open. 
Use TaskQueueManager as an async context manager.") + + stream_name = self._get_stream_name(job_id) + consumer_name = self._get_consumer_name(job_id) + subject = self._get_subject(job_id) + + try: + info = await self.js.consumer_info(stream_name, consumer_name) + logger.debug(f"Consumer {consumer_name} already exists: {info}") + except Exception: + # Consumer doesn't exist, create it + await self.js.add_consumer( + stream=stream_name, + config=ConsumerConfig( + durable_name=consumer_name, + ack_policy=AckPolicy.EXPLICIT, + ack_wait=TASK_TTR, # Visibility timeout (TTR) + max_deliver=5, # Max retry attempts + deliver_policy=DeliverPolicy.ALL, + max_ack_pending=100, # Max unacked messages + filter_subject=subject, + ), + ) + logger.info(f"Created consumer {consumer_name}") + + async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool: + """ + Publish a task to it's job queue. + + Args: + job_id: The job ID (integer primary key) + data: PipelineProcessingTask object to be published + + Returns: + bool: True if successful, False otherwise + """ + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + try: + # Ensure stream and consumer exist + await self._ensure_stream(job_id) + await self._ensure_consumer(job_id) + + subject = self._get_subject(job_id) + # Convert Pydantic model to JSON + task_data = json.dumps(data.dict()) + + # Publish to JetStream + ack = await self.js.publish(subject, task_data.encode()) + + logger.info(f"Published task to stream for job '{job_id}', sequence {ack.seq}") + return True + + except Exception as e: + logger.error(f"Failed to publish task to stream for job '{job_id}': {e}") + return False + + async def reserve_task(self, job_id: int, timeout: float | None = None) -> PipelineProcessingTask | None: + """ + Reserve a task from the specified stream. + + Args: + job_id: The job ID (integer primary key) to pull tasks from + timeout: Timeout in seconds for reservation (default: 5 seconds) + + Returns: + PipelineProcessingTask with reply_subject set for acknowledgment, or None if no task available + """ + if self.js is None: + raise RuntimeError("Connection is not open. 
Use TaskQueueManager as an async context manager.") + + if timeout is None: + timeout = 5 + + try: + # Ensure stream and consumer exist + await self._ensure_stream(job_id) + await self._ensure_consumer(job_id) + + consumer_name = self._get_consumer_name(job_id) + subject = self._get_subject(job_id) + + # Create ephemeral subscription for this pull + psub = await self.js.pull_subscribe(subject, consumer_name) + + try: + # Fetch a single message + msgs = await psub.fetch(1, timeout=timeout) + + if msgs: + msg = msgs[0] + task_data = json.loads(msg.data.decode()) + metadata = msg.metadata + + # Parse the task data into PipelineProcessingTask + task = PipelineProcessingTask(**task_data) + # Set the reply_subject for acknowledgment + task.reply_subject = msg.reply + + logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}") + return task + + except nats.errors.TimeoutError: + # No messages available + logger.debug(f"No tasks available in stream for job '{job_id}'") + return None + finally: + # Always unsubscribe + await psub.unsubscribe() + + except Exception as e: + logger.error(f"Failed to reserve task from stream for job '{job_id}': {e}") + return None + + async def acknowledge_task(self, reply_subject: str) -> bool: + """ + Acknowledge (delete) a completed task using its reply subject. + + Args: + reply_subject: The reply subject from reserve_task + + Returns: + bool: True if successful + """ + if self.nc is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + try: + await self.nc.publish(reply_subject, b"+ACK") + logger.debug(f"Acknowledged task with reply subject {reply_subject}") + return True + except Exception as e: + logger.error(f"Failed to acknowledge task: {e}") + return False + + async def delete_consumer(self, job_id: int) -> bool: + """ + Delete the consumer for a job. + + Args: + job_id: The job ID (integer primary key) + + Returns: + bool: True if successful, False otherwise + """ + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + try: + stream_name = self._get_stream_name(job_id) + consumer_name = self._get_consumer_name(job_id) + + await self.js.delete_consumer(stream_name, consumer_name) + logger.info(f"Deleted consumer {consumer_name} for job '{job_id}'") + return True + except Exception as e: + logger.error(f"Failed to delete consumer for job '{job_id}': {e}") + return False + + async def delete_stream(self, job_id: int) -> bool: + """ + Delete the stream for a job. + + Args: + job_id: The job ID (integer primary key) + + Returns: + bool: True if successful, False otherwise + """ + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + try: + stream_name = self._get_stream_name(job_id) + + await self.js.delete_stream(stream_name) + logger.info(f"Deleted stream {stream_name} for job '{job_id}'") + return True + except Exception as e: + logger.error(f"Failed to delete stream for job '{job_id}': {e}") + return False + + async def cleanup_job_resources(self, job_id: int) -> bool: + """ + Clean up all NATS resources (consumer and stream) for a job. + + This should be called when a job completes or is cancelled. 
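+
+        Illustrative usage (same context-manager pattern as the class docstring):
+
+            async with TaskQueueManager() as manager:
+                await manager.cleanup_job_resources(job_id)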
+ + Args: + job_id: The job ID (integer primary key) + + Returns: + bool: True if successful, False otherwise + """ + # Delete consumer first, then stream + consumer_deleted = await self.delete_consumer(job_id) + stream_deleted = await self.delete_stream(job_id) + + return consumer_deleted and stream_deleted diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/task_state.py new file mode 100644 index 000000000..483275453 --- /dev/null +++ b/ami/ml/orchestration/task_state.py @@ -0,0 +1,125 @@ +""" +Task state management for job progress tracking using Redis. +""" + +import logging +from collections import namedtuple + +from django.core.cache import cache + +logger = logging.getLogger(__name__) + + +# Define a namedtuple for a TaskProgress with the image counts +TaskProgress = namedtuple("TaskProgress", ["remaining", "total", "processed", "percentage"]) + + +class TaskStateManager: + """ + Manages job progress tracking state in Redis. + + Tracks pending images for jobs to calculate progress percentages + as workers process images asynchronously. + """ + + TIMEOUT = 86400 * 7 # 7 days in seconds + STAGES = ["process", "results"] + + def __init__(self, job_id: int): + """ + Initialize the task state manager for a specific job. + + Args: + job_id: The job primary key + """ + self.job_id = job_id + self._pending_key = f"job:{job_id}:pending_images" + self._total_key = f"job:{job_id}:pending_images_total" + + def initialize_job(self, image_ids: list[str]) -> None: + """ + Initialize job tracking with a list of image IDs to process. + + Args: + image_ids: List of image IDs that need to be processed + """ + for stage in self.STAGES: + cache.set(self._get_pending_key(stage), image_ids, timeout=self.TIMEOUT) + + cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) + + def _get_pending_key(self, stage: str) -> str: + return f"{self._pending_key}:{stage}" + + def update_state( + self, + processed_image_ids: set[str], + stage: str, + request_id: str, + ) -> None | TaskProgress: + """ + Update the task state with newly processed images. + + Args: + processed_image_ids: Set of image IDs that have just been processed + """ + # Create a unique lock key for this job + lock_key = f"job:{self.job_id}:process_results_lock" + lock_timeout = 360 # 6 minutes (matches task time_limit) + lock_acquired = cache.add(lock_key, request_id, timeout=lock_timeout) + if not lock_acquired: + return None + + try: + # Update progress tracking in Redis + progress_info = self._get_progress(processed_image_ids, stage) + return progress_info + finally: + # Always release the lock when done + current_lock_value = cache.get(lock_key) + # Only delete if we still own the lock (prevents race condition) + if current_lock_value == request_id: + cache.delete(lock_key) + logger.debug(f"Released lock for job {self.job_id}, task {request_id}") + + def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgress | None: + """ + Get current progress information for the job. 
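+
+        Note: this is not a pure read. The processed IDs are removed from the
+        stage's pending list and the reduced list is written back to Redis
+        with the standard TIMEOUT, so passing the same IDs again does not
+        change the counts.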
+
+        Returns:
+            TaskProgress namedtuple with fields:
+            - remaining: Number of images still pending (or None if not tracked)
+            - total: Total number of images (or None if not tracked)
+            - processed: Number of images processed (or None if not tracked)
+            - percentage: Progress as float 0.0-1.0 (or None if not tracked)
+        """
+        pending_images = cache.get(self._get_pending_key(stage))
+        total_images = cache.get(self._total_key)
+        if pending_images is None or total_images is None:
+            return None
+        remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids]
+        assert len(pending_images) >= len(remaining_images)
+        cache.set(self._get_pending_key(stage), remaining_images, timeout=self.TIMEOUT)
+
+        remaining = len(remaining_images)
+        processed = total_images - remaining
+        percentage = float(processed) / total_images if total_images > 0 else 1.0
+        logger.info(
+            f"Pending images from Redis for job {self.job_id} {stage}: "
+            f"{remaining}/{total_images}: {percentage*100}%"
+        )
+
+        return TaskProgress(
+            remaining=remaining,
+            total=total_images,
+            processed=processed,
+            percentage=percentage,
+        )
+
+    def cleanup(self) -> None:
+        """
+        Delete all Redis keys associated with this job.
+        """
+        for stage in self.STAGES:
+            cache.delete(self._get_pending_key(stage))
+        cache.delete(self._total_key)
diff --git a/ami/ml/orchestration/test_nats_queue.py b/ami/ml/orchestration/test_nats_queue.py
new file mode 100644
index 000000000..0cd2c3bef
--- /dev/null
+++ b/ami/ml/orchestration/test_nats_queue.py
@@ -0,0 +1,150 @@
+"""Unit tests for TaskQueueManager."""
+
+import unittest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from ami.ml.orchestration.nats_queue import TaskQueueManager
+from ami.ml.schemas import PipelineProcessingTask
+
+
+class TestTaskQueueManager(unittest.IsolatedAsyncioTestCase):
+    """Test suite for TaskQueueManager."""
+
+    def _create_sample_task(self):
+        """Helper to create a sample PipelineProcessingTask."""
+        return PipelineProcessingTask(
+            id="task-123",
+            image_id="img-456",
+            image_url="https://example.com/image.jpg",
+        )
+
+    def _create_mock_nats_connection(self):
+        """Helper to create mock NATS connection and JetStream context."""
+        nc = MagicMock()
+        nc.is_closed = False
+        nc.close = AsyncMock()
+
+        js = MagicMock()
+        js.stream_info = AsyncMock()
+        js.add_stream = AsyncMock()
+        js.add_consumer = AsyncMock()
+        js.consumer_info = AsyncMock()
+        js.publish = AsyncMock(return_value=MagicMock(seq=1))
+        js.pull_subscribe = AsyncMock()
+        js.delete_consumer = AsyncMock()
+        js.delete_stream = AsyncMock()
+
+        return nc, js
+
+    async def test_context_manager_lifecycle(self):
+        """Test that context manager properly opens and closes connections."""
+        nc, js = self._create_mock_nats_connection()
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager("nats://test:4222") as manager:
+                self.assertIsNotNone(manager.nc)
+                self.assertIsNotNone(manager.js)
+
+        nc.close.assert_called_once()
+
+    async def test_publish_task_creates_stream_and_consumer(self):
+        """Test that publish_task ensures stream and consumer exist."""
+        nc, js = self._create_mock_nats_connection()
+        sample_task = self._create_sample_task()
+        js.stream_info.side_effect = Exception("Not found")
+        js.consumer_info.side_effect = Exception("Not found")
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager() as manager:
+                await manager.publish_task(456, sample_task)
+
+        js.add_stream.assert_called_once()
+        self.assertIn("job_456", str(js.add_stream.call_args))
+        js.add_consumer.assert_called_once()
+
+    async def test_reserve_task_success(self):
+        """Test successful task reservation."""
+        nc, js = self._create_mock_nats_connection()
+        sample_task = self._create_sample_task()
+
+        # Mock message with task data
+        mock_msg = MagicMock()
+        mock_msg.data = sample_task.json().encode()
+        mock_msg.reply = "reply.subject.123"
+        mock_msg.metadata = MagicMock(sequence=MagicMock(stream=1))
+
+        mock_psub = MagicMock()
+        mock_psub.fetch = AsyncMock(return_value=[mock_msg])
+        mock_psub.unsubscribe = AsyncMock()
+        js.pull_subscribe = AsyncMock(return_value=mock_psub)
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager() as manager:
+                task = await manager.reserve_task(123)
+
+        self.assertIsNotNone(task)
+        self.assertEqual(task.id, sample_task.id)
+        self.assertEqual(task.reply_subject, "reply.subject.123")
+        mock_psub.unsubscribe.assert_called_once()
+
+    async def test_reserve_task_no_messages(self):
+        """Test reserve_task when no messages are available."""
+        nc, js = self._create_mock_nats_connection()
+
+        mock_psub = MagicMock()
+        mock_psub.fetch = AsyncMock(return_value=[])
+        mock_psub.unsubscribe = AsyncMock()
+        js.pull_subscribe = AsyncMock(return_value=mock_psub)
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager() as manager:
+                task = await manager.reserve_task(123)
+
+        self.assertIsNone(task)
+        mock_psub.unsubscribe.assert_called_once()
+
+    async def test_acknowledge_task_success(self):
+        """Test successful task acknowledgment."""
+        nc, js = self._create_mock_nats_connection()
+        nc.publish = AsyncMock()
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager() as manager:
+                result = await manager.acknowledge_task("reply.subject.123")
+
+        self.assertTrue(result)
+        nc.publish.assert_called_once_with("reply.subject.123", b"+ACK")
+
+    async def test_cleanup_job_resources(self):
+        """Test cleanup of job resources (consumer and stream)."""
+        nc, js = self._create_mock_nats_connection()
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager() as manager:
+                result = await manager.cleanup_job_resources(123)
+
+        self.assertTrue(result)
+        js.delete_consumer.assert_called_once()
+        js.delete_stream.assert_called_once()
+
+    async def test_naming_conventions(self):
+        """Test stream, subject, and consumer naming conventions."""
+        manager = TaskQueueManager()
+
+        self.assertEqual(manager._get_stream_name(123), "job_123")
+        self.assertEqual(manager._get_subject(123), "job.123.tasks")
+        self.assertEqual(manager._get_consumer_name(123), "job-123-consumer")
+
+    async def test_operations_without_connection_raise_error(self):
+        """Test that operations without connection raise RuntimeError."""
+        manager = TaskQueueManager()
+        sample_task = self._create_sample_task()
+
+        with self.assertRaisesRegex(RuntimeError, "Connection is not open"):
+            await manager.publish_task(123, sample_task)
+
+        with self.assertRaisesRegex(RuntimeError, "Connection is not open"):
+            await manager.reserve_task(123)
+
+        with self.assertRaisesRegex(RuntimeError, "Connection is not open"):
+            await manager.delete_stream(123)
diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py
index a70a89308..0de43497c 100644
--- a/ami/ml/schemas.py
+++ b/ami/ml/schemas.py
@@ -241,6 +241,13 @@ class PipelineResultsResponse(pydantic.BaseModel):
     errors: list | str | None = None
 
 
+class PipelineResultsError(pydantic.BaseModel):
+    """Error result when pipeline processing fails for an image."""
+
+    error: str
+    image_id: str | None = None
+
+
 class PipelineProcessingTask(pydantic.BaseModel):
     """
     A task representing a single image or detection to be processed in an async pipeline.
@@ -249,7 +256,6 @@ class PipelineProcessingTask(pydantic.BaseModel):
     id: str
     image_id: str
     image_url: str
-    queue_timestamp: str
     reply_subject: str | None = None  # The NATS subject to send the result to
     # TODO: Do we need these?
     # detections: list[DetectionRequest] | None = None
@@ -259,10 +265,12 @@
 class PipelineTaskResult(pydantic.BaseModel):
     """
     The result from processing a single PipelineProcessingTask.
+
+    Note: this schema is called `AntennaTaskResult` in the ADC worker processing service.
     """
 
     reply_subject: str  # The reply_subject from the PipelineProcessingTask
-    result: PipelineResultsResponse
+    result: PipelineResultsResponse | PipelineResultsError
 
 
 class PipelineStageParam(pydantic.BaseModel):
diff --git a/ami/ml/tests.py b/ami/ml/tests.py
index 14e4374f2..20e0368fe 100644
--- a/ami/ml/tests.py
+++ b/ami/ml/tests.py
@@ -855,3 +855,127 @@ def test_small_size_filter_assigns_not_identifiable(self):
             not_identifiable_taxon,
             f"Occurrence {occurrence.pk} should have its determination set to 'Not identifiable'.",
         )
+
+
+class TestTaskStateManager(TestCase):
+    """Test TaskStateManager for job progress tracking."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        from django.core.cache import cache
+
+        from ami.ml.orchestration.task_state import TaskStateManager
+
+        cache.clear()
+        self.job_id = 123
+        self.manager = TaskStateManager(self.job_id)
+        self.image_ids = ["img1", "img2", "img3", "img4", "img5"]
+
+    def _init_and_verify(self, image_ids):
+        """Helper to initialize job and verify initial state."""
+        self.manager.initialize_job(image_ids)
+        progress = self.manager._get_progress(set(), "process")
+        assert progress is not None
+        self.assertEqual(progress.total, len(image_ids))
+        self.assertEqual(progress.remaining, len(image_ids))
+        self.assertEqual(progress.processed, 0)
+        self.assertEqual(progress.percentage, 0.0)
+        return progress
+
+    def test_initialize_job(self):
+        """Test job initialization sets up tracking for all stages."""
+        self._init_and_verify(self.image_ids)
+
+        # Verify both stages are initialized
+        for stage in self.manager.STAGES:
+            progress = self.manager._get_progress(set(), stage)
+            assert progress is not None
+            self.assertEqual(progress.total, len(self.image_ids))
+
+    def test_progress_tracking(self):
+        """Test progress updates correctly as images are processed."""
+        self._init_and_verify(self.image_ids)
+
+        # Process 2 images
+        progress = self.manager._get_progress({"img1", "img2"}, "process")
+        assert progress is not None
+        self.assertEqual(progress.remaining, 3)
+        self.assertEqual(progress.processed, 2)
+        self.assertEqual(progress.percentage, 0.4)
+
+        # Process 2 more images
+        progress = self.manager._get_progress({"img3", "img4"}, "process")
+        assert progress is not None
+        self.assertEqual(progress.remaining, 1)
+        self.assertEqual(progress.processed, 4)
+        self.assertEqual(progress.percentage, 0.8)
+
+        # Process last image
+        progress = self.manager._get_progress({"img5"}, "process")
+        assert progress is not None
+        self.assertEqual(progress.remaining, 0)
+        self.assertEqual(progress.processed, 5)
+        self.assertEqual(progress.percentage, 1.0)
+
+    def test_update_state_with_locking(self):
+        """Test update_state acquires lock, updates progress, and releases lock."""
+        from django.core.cache import cache
+
+        self._init_and_verify(self.image_ids)
+
+        # First update should succeed
+        progress = self.manager.update_state({"img1", "img2"}, "process", "task1")
+        assert progress is not None
+        self.assertEqual(progress.processed, 2)
+
+        # Simulate concurrent update by holding the lock
+        lock_key = f"job:{self.job_id}:process_results_lock"
+        cache.set(lock_key, "other_task", timeout=60)
+
+        # Update should fail (lock held by another task)
+        progress = self.manager.update_state({"img3"}, "process", "task1")
+        self.assertIsNone(progress)
+
+        # Release the lock and retry
+        cache.delete(lock_key)
+        progress = self.manager.update_state({"img3"}, "process", "task1")
+        assert progress is not None
+        self.assertEqual(progress.processed, 3)
+
+    def test_stages_independent(self):
+        """Test that different stages track progress independently."""
+        self._init_and_verify(self.image_ids)
+
+        # Update process stage
+        self.manager._get_progress({"img1", "img2"}, "process")
+        progress_process = self.manager._get_progress(set(), "process")
+        assert progress_process is not None
+        self.assertEqual(progress_process.remaining, 3)
+
+        # Results stage should still have all images pending
+        progress_results = self.manager._get_progress(set(), "results")
+        assert progress_results is not None
+        self.assertEqual(progress_results.remaining, 5)
+
+    def test_empty_job(self):
+        """Test handling of job with no images."""
+        self.manager.initialize_job([])
+        progress = self.manager._get_progress(set(), "process")
+        assert progress is not None
+        self.assertEqual(progress.total, 0)
+        self.assertEqual(progress.percentage, 1.0)  # Empty job is 100% complete
+
+    def test_cleanup(self):
+        """Test cleanup removes all tracking keys."""
+        self._init_and_verify(self.image_ids)
+
+        # Verify keys exist
+        progress = self.manager._get_progress(set(), "process")
+        self.assertIsNotNone(progress)
+
+        # Cleanup
+        self.manager.cleanup()
+
+        # Verify keys are gone
+        progress = self.manager._get_progress(set(), "process")
+        self.assertIsNone(progress)
diff --git a/ami/utils/requests.py b/ami/utils/requests.py
index c4396b725..e4de57c0f 100644
--- a/ami/utils/requests.py
+++ b/ami/utils/requests.py
@@ -2,6 +2,8 @@
 
 import requests
 from django.forms import BooleanField, FloatField
+from drf_spectacular.types import OpenApiTypes
+from drf_spectacular.utils import OpenApiParameter
 from requests.adapters import HTTPAdapter
 from rest_framework.request import Request
 from urllib3.util import Retry
@@ -142,3 +144,30 @@ def get_default_classification_threshold(project: "Project | None" = None, reque
         return project.default_filters_score_threshold
     else:
         return default_threshold
+
+
+project_id_doc_param = OpenApiParameter(
+    name="project_id",
+    description="Filter by project ID",
+    required=False,
+    type=int,
+)
+
+ids_only_param = OpenApiParameter(
+    name="ids_only",
+    description="Return only job IDs instead of full job objects",
+    required=False,
+    type=OpenApiTypes.BOOL,
+)
+incomplete_only_param = OpenApiParameter(
+    name="incomplete_only",
+    description="Filter to only show incomplete jobs (excludes SUCCESS, FAILURE, REVOKED)",
+    required=False,
+    type=OpenApiTypes.BOOL,
+)
+batch_param = OpenApiParameter(
+    name="batch",
+    description="Number of tasks to pull in the batch",
+    required=False,
+    type=OpenApiTypes.INT,
+)
diff --git a/config/settings/base.py b/config/settings/base.py
index 8385c38f3..dad65ce21 100644
--- a/config/settings/base.py
+++ b/config/settings/base.py
@@ -263,6 +263,10 @@
 }
 REDIS_URL = env("REDIS_URL", default=None)
 
+# NATS
+# ------------------------------------------------------------------------------
+NATS_URL = env("NATS_URL", default="nats://localhost:4222")  # type: ignore[no-untyped-call]
+
 # ADMIN
 # ------------------------------------------------------------------------------
 # Django Admin URL.
diff --git a/docker-compose.ci.yml b/docker-compose.ci.yml
index 8e93b684d..57f6fbc9f 100644
--- a/docker-compose.ci.yml
+++ b/docker-compose.ci.yml
@@ -22,6 +22,7 @@ services:
       - minio-init
       - ml_backend
       - rabbitmq
+      - nats
     env_file:
       - ./.envs/.ci/.django
       - ./.envs/.ci/.postgres
@@ -39,6 +40,17 @@ services:
   redis:
     image: redis:6
 
+  nats:
+    image: nats:2.10-alpine
+    container_name: ami_ci_nats
+    hostname: nats
+    command: ["-js", "-m", "8222"] # Enable JetStream and monitoring
+    healthcheck:
+      test: ["CMD", "wget", "--spider", "-q", "http://localhost:8222/healthz"]
+      interval: 10s
+      timeout: 5s
+      retries: 3
+
   celeryworker:
     <<: *django
     depends_on:
@@ -58,7 +70,7 @@ services:
     env_file:
       - ./.envs/.ci/.django
     healthcheck:
-      test: [ "CMD", "mc", "ready", "local" ]
+      test: ["CMD", "mc", "ready", "local"]
       interval: 5s
       timeout: 5s
       retries: 5
diff --git a/docker-compose.staging.yml b/docker-compose.staging.yml
index 13525c044..684e50e67 100644
--- a/docker-compose.staging.yml
+++ b/docker-compose.staging.yml
@@ -5,7 +5,6 @@
 # 1. The database is a service in the Docker Compose configuration rather than external as in production.
 # 2. Redis is a service in the Docker Compose configuration rather than external as in production.
 # 3. Port 5001 is exposed for the Django application.
-version: "3"
 
 volumes:
   ami_local_postgres_data: {}
@@ -21,6 +20,7 @@ services:
     depends_on:
      - postgres
      - redis
+      # - nats
    env_file:
      - ./.envs/.production/.django
      - ./.envs/.local/.postgres
@@ -29,6 +29,7 @@
    ports:
      - "5001:5000"
    command: /start
+    restart: always
 
   postgres:
     build:
@@ -42,9 +43,11 @@
       - ./data/db/snapshots:/backups
     env_file:
       - ./.envs/.local/.postgres
+    restart: always
 
   redis:
     image: redis:6
+    restart: always
 
   celeryworker:
     <<: *django
@@ -62,3 +65,18 @@
     ports:
       - "5550:5555"
     command: /start-flower
+
+  nats:
+    image: nats:2.10-alpine
+    container_name: ami_local_nats
+    hostname: nats
+    ports:
+      - "4222:4222" # Client port
+      - "8222:8222" # HTTP monitoring port
+    command: ["-js", "-m", "8222"] # Enable JetStream and monitoring
+    healthcheck:
+      test: ["CMD", "wget", "--spider", "-q", "http://localhost:8222/healthz"]
+      interval: 10s
+      timeout: 5s
+      retries: 3
+    restart: always
diff --git a/docker-compose.yml b/docker-compose.yml
index e2ad3a100..703ecea0d 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -21,6 +21,7 @@ services:
     depends_on:
       - postgres
       - redis
+      - nats
       - minio-init
       - ml_backend
       - rabbitmq
@@ -75,7 +76,12 @@
     volumes:
       - ./.git:/app/.git:ro
       - ./ui:/app
-    entrypoint: ["sh", "-c", "yarn install && yarn start --debug --host 0.0.0.0 --port 4000"]
+    entrypoint:
+      [
+        "sh",
+        "-c",
+        "yarn install && yarn start --debug --host 0.0.0.0 --port 4000",
+      ]
     environment:
       - API_PROXY_TARGET=http://django:8000
       - CHOKIDAR_USEPOLLING=true
@@ -84,6 +90,20 @@
     image: redis:6
     container_name: ami_local_redis
 
+  nats:
+    image: nats:2.10-alpine
+    container_name: ami_local_nats
+    hostname: nats
+    ports:
+      - "4222:4222" # Client port
+      - "8222:8222" # HTTP monitoring port
+    command: ["-js", "-m", "8222"] # Enable JetStream and monitoring
+    healthcheck:
+      test: ["CMD", "wget", "--spider", "-q", "http://localhost:8222/healthz"]
+      interval: 10s
+      timeout: 5s
+      retries: 3
+
   celeryworker:
     <<: *django
     image: ami_local_celeryworker
diff --git a/requirements/base.txt b/requirements/base.txt
index 624b4a8e4..3b208e9df 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -8,6 +8,7 @@ celery==5.4.0 # pyup: < 6.0 # https://github.com/celery/celery
 django-celery-beat==2.5.0 # https://github.com/celery/django-celery-beat
 flower==2.0.1 # https://github.com/mher/flower
 kombu==5.4.2
+nats-py==2.10.0 # https://github.com/nats-io/nats.py
 uvicorn[standard]==0.22.0 # https://github.com/encode/uvicorn
 rich==13.5.0
 markdown==3.4.4
@@ -53,6 +54,7 @@ drf-nested-routers==0.94.1
 Werkzeug[watchdog]==2.3.6 # https://github.com/pallets/werkzeug
 ipdb==0.13.13 # https://github.com/gotcha/ipdb
 psycopg[binary]==3.1.9 # https://github.com/psycopg/psycopg
+# psycopg==3.1.9 # https://github.com/psycopg/psycopg # the non-binary version is needed for some platforms
 watchfiles==0.19.0 # https://github.com/samuelcolvin/watchfiles
 
 # Testing
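
---

For reference, a minimal sketch of how a worker-side consumer might drive the `TaskQueueManager` API that the unit tests above exercise (`publish_task`, `reserve_task`, `acknowledge_task`, the async context manager). The loop, the `drain_job_queue` name, and the default NATS URL are illustrative assumptions; the real consumer runs in the external ADC worker service and performs the actual ML processing where indicated.

```python
# Illustrative only - not part of this diff. Assumes the TaskQueueManager API
# shown in test_nats_queue.py; `drain_job_queue` is a hypothetical name.
import asyncio

from ami.ml.orchestration.nats_queue import TaskQueueManager


async def drain_job_queue(job_id: int, nats_url: str = "nats://localhost:4222") -> None:
    """Reserve queued tasks for one job and acknowledge each after handling it."""
    async with TaskQueueManager(nats_url) as manager:
        while True:
            task = await manager.reserve_task(job_id)
            if task is None:
                break  # No messages currently available for this job's consumer

            # ... run the ML pipeline on task.image_url here (omitted) ...

            if task.reply_subject:
                # Acknowledge via the reply subject so JetStream does not redeliver
                await manager.acknowledge_task(task.reply_subject)


if __name__ == "__main__":
    asyncio.run(drain_job_queue(job_id=123))
```

Acknowledging through the `reply_subject` attached by `reserve_task`, rather than directly on the JetStream message, matches the `b"+ACK"` publish asserted in `test_acknowledge_task_success`.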
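The `PipelineTaskResult.result` union added in `ami/ml/schemas.py` means downstream handlers must branch on success vs. error payloads. A small sketch of that branching (the helper name and log wording are assumptions; the real handling belongs in the NATS result-processing code):

```python
# Illustrative helper - not part of this diff.
from ami.ml.schemas import PipelineResultsError, PipelineTaskResult


def describe_task_result(task_result: PipelineTaskResult) -> str:
    """Return a one-line summary for a single task result."""
    if isinstance(task_result.result, PipelineResultsError):
        image = task_result.result.image_id or "unknown image"
        return f"Pipeline error for {image}: {task_result.result.error}"
    # Otherwise it is a full PipelineResultsResponse
    return f"Pipeline results received (reply subject: {task_result.reply_subject})"
```

One thing worth double-checking during implementation is how pydantic coerces this union when parsing incoming JSON: an error payload should not silently validate as a `PipelineResultsResponse`, so the member ordering (or a discriminator field) may matter.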
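Finally, a usage sketch for the Redis-backed `TaskStateManager` covered by `TestTaskStateManager` above. The call sequence mirrors the tests; the variable names and surrounding context are assumed.

```python
# Illustrative only - call signatures taken from the TestTaskStateManager tests.
from ami.ml.orchestration.task_state import TaskStateManager

manager = TaskStateManager(123)  # job primary key
manager.initialize_job(["img1", "img2", "img3", "img4", "img5"])

# A worker reports that two images finished the "process" stage.
progress = manager.update_state({"img1", "img2"}, "process", "worker-task-1")
if progress is None:
    # Another task currently holds the per-stage lock; retry this update later.
    pass
else:
    # processed = total - remaining; percentage = processed / total (1.0 for an empty job)
    print(f"{progress.processed}/{progress.total} images done ({progress.percentage:.0%})")

# Once every stage reports 100%, the completion handler can drop the Redis keys.
manager.cleanup()
```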