From b60eab0ff435e32b0264b182b4022864cdce71f0 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 16 Jan 2026 11:25:40 -0800 Subject: [PATCH 01/14] merge --- requirements/base.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/base.txt b/requirements/base.txt index dd9de69d5..ed40ea5f7 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -52,6 +52,7 @@ django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail Werkzeug[watchdog]==2.3.6 # https://github.com/pallets/werkzeug ipdb==0.13.13 # https://github.com/gotcha/ipdb psycopg[binary]==3.1.9 # https://github.com/psycopg/psycopg +# psycopg==3.1.9 # https://github.com/psycopg/psycopg # the non-binary version is needed for some platforms watchfiles==0.19.0 # https://github.com/samuelcolvin/watchfiles # Testing From 02aa6fa3de39a1a5ecaf7b03cd4a7ba2bf88a48a Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 4 Feb 2026 14:52:51 -0800 Subject: [PATCH 02/14] Accept and log `processing_service_name` parameter from workers --- ami/jobs/schemas.py | 7 +++++++ ami/jobs/views.py | 29 +++++++++++++++++++++++++++-- ami/utils/requests.py | 29 ----------------------------- 3 files changed, 34 insertions(+), 31 deletions(-) diff --git a/ami/jobs/schemas.py b/ami/jobs/schemas.py index 0e1ea4ac7..71726eb8c 100644 --- a/ami/jobs/schemas.py +++ b/ami/jobs/schemas.py @@ -20,3 +20,10 @@ required=False, type=int, ) + +processing_service_name_param = OpenApiParameter( + name="processing_service_name", + description="Inform the name of the calling processing service", + required=False, + type=str, +) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index a2f087cf5..9606582d3 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -15,7 +15,7 @@ from ami.base.permissions import ObjectPermission from ami.base.views import ProjectMixin -from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param +from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param, processing_service_name_param from ami.jobs.tasks import process_nats_pipeline_result from ami.main.api.schemas import project_id_doc_param from ami.main.api.views import DefaultViewSet @@ -203,13 +203,21 @@ def get_queryset(self) -> QuerySet: project_id_doc_param, ids_only_param, incomplete_only_param, + processing_service_name_param, ] ) def list(self, request, *args, **kwargs): + # Get optional processing_service_name parameter + processing_service_name = request.query_params.get("processing_service_name", None) + if processing_service_name: + logger.info(f"Jobs list requested by processing service: {processing_service_name}") + else: + logger.debug("Jobs list requested without processing service name") + return super().list(request, *args, **kwargs) @extend_schema( - parameters=[batch_param], + parameters=[batch_param, processing_service_name_param], responses={200: dict}, ) @action(detail=True, methods=["get"], name="tasks") @@ -228,6 +236,13 @@ def tasks(self, request, pk=None): except Exception as e: raise ValidationError({"batch": str(e)}) from e + # Get optional processing_service_name parameter + processing_service_name = request.query_params.get("processing_service_name", None) + if processing_service_name: + job.logger.info(f"Job {job.pk} tasks ({batch}) requested by processing service: {processing_service_name}") + else: + job.logger.warning(f"Job {job.pk} tasks ({batch}) requested without processing service name") + # Validate that the job has a pipeline if not job.pipeline: raise ValidationError("This job does not have a pipeline configured") @@ -249,6 +264,9 @@ async def get_tasks(): return Response({"tasks": tasks}) + @extend_schema( + parameters=[processing_service_name_param], + ) @action(detail=True, methods=["post"], name="result") def result(self, request, pk=None): """ @@ -261,6 +279,13 @@ def result(self, request, pk=None): job = self.get_object() + # Get optional processing_service_name parameter + processing_service_name = request.query_params.get("processing_service_name", None) + if processing_service_name: + job.logger.info(f"Job {job.pk} result received from processing service: {processing_service_name}") + else: + job.logger.warning(f"Job {job.pk} result received without processing service name") + # Validate request data is a list if isinstance(request.data, list): results = request.data diff --git a/ami/utils/requests.py b/ami/utils/requests.py index e4de57c0f..c4396b725 100644 --- a/ami/utils/requests.py +++ b/ami/utils/requests.py @@ -2,8 +2,6 @@ import requests from django.forms import BooleanField, FloatField -from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import OpenApiParameter from requests.adapters import HTTPAdapter from rest_framework.request import Request from urllib3.util import Retry @@ -144,30 +142,3 @@ def get_default_classification_threshold(project: "Project | None" = None, reque return project.default_filters_score_threshold else: return default_threshold - - -project_id_doc_param = OpenApiParameter( - name="project_id", - description="Filter by project ID", - required=False, - type=int, -) - -ids_only_param = OpenApiParameter( - name="ids_only", - description="Return only job IDs instead of full job objects", - required=False, - type=OpenApiTypes.BOOL, -) -incomplete_only_param = OpenApiParameter( - name="incomplete_only", - description="Filter to only show incomplete jobs (excludes SUCCESS, FAILURE, REVOKED)", - required=False, - type=OpenApiTypes.BOOL, -) -batch_param = OpenApiParameter( - name="batch", - description="Number of tasks to pull in the batch", - required=False, - type=OpenApiTypes.INT, -) From 90d729f36dc72e39a921710664cbc902a34f8f99 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 4 Feb 2026 17:13:26 -0800 Subject: [PATCH 03/14] refactor --- ami/jobs/views.py | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 9606582d3..ccf57cfbb 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -207,12 +207,8 @@ def get_queryset(self) -> QuerySet: ] ) def list(self, request, *args, **kwargs): - # Get optional processing_service_name parameter - processing_service_name = request.query_params.get("processing_service_name", None) - if processing_service_name: - logger.info(f"Jobs list requested by processing service: {processing_service_name}") - else: - logger.debug("Jobs list requested without processing service name") + # Get optional processing_service_name parameter and log it + _ = _log_processing_service_name(request, "list requested", logger) return super().list(request, *args, **kwargs) @@ -236,12 +232,8 @@ def tasks(self, request, pk=None): except Exception as e: raise ValidationError({"batch": str(e)}) from e - # Get optional processing_service_name parameter - processing_service_name = request.query_params.get("processing_service_name", None) - if processing_service_name: - job.logger.info(f"Job {job.pk} tasks ({batch}) requested by processing service: {processing_service_name}") - else: - job.logger.warning(f"Job {job.pk} tasks ({batch}) requested without processing service name") + # Get optional processing_service_name parameter and log it + _ = _log_processing_service_name(request, f"tasks ({batch}) requested for job {job.pk}", job.logger) # Validate that the job has a pipeline if not job.pipeline: @@ -279,12 +271,8 @@ def result(self, request, pk=None): job = self.get_object() - # Get optional processing_service_name parameter - processing_service_name = request.query_params.get("processing_service_name", None) - if processing_service_name: - job.logger.info(f"Job {job.pk} result received from processing service: {processing_service_name}") - else: - job.logger.warning(f"Job {job.pk} result received without processing service name") + # Get optional processing_service_name parameter and log it + _ = _log_processing_service_name(request, f"result received for job {job.pk}", job.logger) # Validate request data is a list if isinstance(request.data, list): @@ -345,3 +333,22 @@ def result(self, request, pk=None): }, status=500, ) + + +def _log_processing_service_name(request, context: str, logger: logging.Logger) -> str: + """ + Log the processing_service_name from query parameters. + + Args: + request: The HTTP request object + context: A string describing the operation (e.g., "tasks requested", "result received") + logger: A logging.Logger instance to use for logging + """ + processing_service_name = request.query_params.get("processing_service_name", None) + + if processing_service_name: + logger.info(f"Jobs {context} by processing service: {processing_service_name}") + else: + logger.debug(f"Jobs {context} without processing service name") + + return processing_service_name From 335229decca1b84384f98f0659f87c16b9bc3987 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 4 Feb 2026 17:14:52 -0800 Subject: [PATCH 04/14] Clean up --- ami/jobs/views.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index ccf57cfbb..00b024669 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -207,7 +207,6 @@ def get_queryset(self) -> QuerySet: ] ) def list(self, request, *args, **kwargs): - # Get optional processing_service_name parameter and log it _ = _log_processing_service_name(request, "list requested", logger) return super().list(request, *args, **kwargs) @@ -232,7 +231,6 @@ def tasks(self, request, pk=None): except Exception as e: raise ValidationError({"batch": str(e)}) from e - # Get optional processing_service_name parameter and log it _ = _log_processing_service_name(request, f"tasks ({batch}) requested for job {job.pk}", job.logger) # Validate that the job has a pipeline @@ -271,7 +269,6 @@ def result(self, request, pk=None): job = self.get_object() - # Get optional processing_service_name parameter and log it _ = _log_processing_service_name(request, f"result received for job {job.pk}", job.logger) # Validate request data is a list From 5b53380db51de66d1a1c633088b3f3d816ad55a7 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 6 Feb 2026 15:17:33 -0800 Subject: [PATCH 05/14] Address CR feedback --- ami/jobs/schemas.py | 2 +- ami/jobs/tests.py | 46 +++++++++++++++++++++++++++++++++++++++++++++ ami/jobs/views.py | 4 +++- 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/ami/jobs/schemas.py b/ami/jobs/schemas.py index 71726eb8c..2225c3799 100644 --- a/ami/jobs/schemas.py +++ b/ami/jobs/schemas.py @@ -23,7 +23,7 @@ processing_service_name_param = OpenApiParameter( name="processing_service_name", - description="Inform the name of the calling processing service", + description="Name of the calling processing service", required=False, type=str, ) diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 8e04d9dd9..6590fda1d 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -423,3 +423,49 @@ def test_result_endpoint_validation(self): resp = self.client.post(result_url, invalid_data, format="json") self.assertEqual(resp.status_code, 400) self.assertIn("result", resp.json()[0].lower()) + + def test_processing_service_name_parameter(self): + """Test that processing_service_name parameter is accepted on job endpoints.""" + self.client.force_authenticate(user=self.user) + test_service_name = "Test Service" + + # Test list endpoint + list_url = reverse_with_params( + "api:job-list", params={"project_id": self.project.pk, "processing_service_name": test_service_name} + ) + resp = self.client.get(list_url) + self.assertEqual(resp.status_code, 200) + + # Test tasks endpoint (requires job with pipeline) + pipeline = self._create_pipeline() + job = self._create_ml_job("Job for service name test", pipeline) + + tasks_url = reverse_with_params( + "api:job-tasks", + args=[job.pk], + params={"project_id": self.project.pk, "batch": 1, "processing_service_name": test_service_name}, + ) + resp = self.client.get(tasks_url) + self.assertEqual(resp.status_code, 200) + + # Test result endpoint + result_url = reverse_with_params( + "api:job-result", + args=[job.pk], + params={"project_id": self.project.pk, "processing_service_name": test_service_name}, + ) + result_data = [ + { + "reply_subject": "test.reply.1", + "result": { + "pipeline": "test-pipeline", + "algorithms": {}, + "total_time": 1.5, + "source_images": [], + "detections": [], + "errors": None, + }, + } + ] + resp = self.client.post(result_url, result_data, format="json") + self.assertEqual(resp.status_code, 200) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 00b024669..cc57ecb3c 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -332,7 +332,7 @@ def result(self, request, pk=None): ) -def _log_processing_service_name(request, context: str, logger: logging.Logger) -> str: +def _log_processing_service_name(request, context: str, logger: logging.Logger) -> str | None: """ Log the processing_service_name from query parameters. @@ -340,6 +340,8 @@ def _log_processing_service_name(request, context: str, logger: logging.Logger) request: The HTTP request object context: A string describing the operation (e.g., "tasks requested", "result received") logger: A logging.Logger instance to use for logging + Returns: + The processing_service_name if provided, otherwise None """ processing_service_name = request.query_params.get("processing_service_name", None) From ba747474e5ee487fc4772df139c6c5f0f2a3d2db Mon Sep 17 00:00:00 2001 From: carlosgjs Date: Fri, 6 Feb 2026 18:27:26 -0800 Subject: [PATCH 06/14] fix: Properly handle async job state with celery tasks (#1114) * merge * fix: Properly handle async job state with celery tasks * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Delete implemented plan --------- Co-authored-by: Carlos Garcia Jurado Suarez Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .agents/planning/async-job-status-handling.md | 408 ------------------ ami/jobs/models.py | 24 ++ ami/jobs/tasks.py | 11 +- ami/jobs/tests.py | 87 ++++ 4 files changed, 121 insertions(+), 409 deletions(-) delete mode 100644 .agents/planning/async-job-status-handling.md diff --git a/.agents/planning/async-job-status-handling.md b/.agents/planning/async-job-status-handling.md deleted file mode 100644 index 1a6024a87..000000000 --- a/.agents/planning/async-job-status-handling.md +++ /dev/null @@ -1,408 +0,0 @@ -# Plan: Fix Async Pipeline Job Status Handling - -**Date:** 2026-01-30 -**Status:** Ready for implementation - -## Problem Summary - -When `async_pipeline_workers` feature flag is enabled: -1. `queue_images_to_nats()` queues work and returns immediately (lines 400-408 in `ami/jobs/models.py`) -2. The Celery `run_job` task completes without exception -3. The `task_postrun` signal handler (`ami/jobs/tasks.py:175-192`) calls `job.update_status("SUCCESS")` -4. This prematurely marks the job as SUCCESS before async workers actually finish processing - -## Solution Overview - -Use a **progress-based approach** to determine job completion: -1. Add a generic `is_complete()` method to `JobProgress` that works for any job type -2. In `task_postrun`, only allow SUCCESS if all stages are complete -3. Allow FAILURE, REVOKED, and other terminal states to pass through immediately -4. Authoritative completion for async jobs stays in `_update_job_progress` - ---- - -## Background: Job Types and Stages - -Jobs in this app have different types, each with different stages: - -| Job Type | Stages | -|----------|--------| -| MLJob | delay (optional), collect, process, results | -| DataStorageSyncJob | data_storage_sync | -| SourceImageCollectionPopulateJob | populate_captures_collection | -| DataExportJob | exporting_data, uploading_snapshot | -| PostProcessingJob | post_processing | - -The `is_complete()` method must be generic and work for ALL job types by checking if all stages have finished. - -### Celery Signal States - -The `task_postrun` signal fires after every task with these states: -- **SUCCESS**: Task completed without exception -- **FAILURE**: Task raised exception (also triggers `task_failure`) -- **RETRY**: Task requested retry -- **REVOKED**: Task cancelled - -**Handling strategy:** -| State | Behavior | Rationale | -|-------|----------|-----------| -| SUCCESS | Guard - only set if `is_complete()` | Prevents premature success | -| FAILURE | Allow immediately | Job failed, user needs to know | -| REVOKED | Allow immediately | Job cancelled | -| RETRY | Allow immediately | Transient state | - ---- - -## Implementation Steps - -### Step 1: Add `is_complete()` method to JobProgress - -**File:** `ami/jobs/models.py` (in the `JobProgress` class, after `reset()` method around line 198) - -```python -def is_complete(self) -> bool: - """ - Check if all stages have finished processing. - - A job is considered complete when ALL of its stages have: - - progress >= 1.0 (fully processed) - - status in a final state (SUCCESS, FAILURE, or REVOKED) - - This method works for any job type regardless of which stages it has. - It's used by the Celery task_postrun signal to determine whether to - set the job's final SUCCESS status, or defer to async progress handlers. - - Related: Job.update_progress() (lines 924-947) calculates the aggregate - progress percentage across all stages for display purposes. This method - is a binary check for completion that considers both progress AND status. - - Returns: - True if all stages are complete, False otherwise. - Returns False if job has no stages (shouldn't happen in practice). - """ - if not self.stages: - return False - return all( - stage.progress >= 1.0 and stage.status in JobState.final_states() - for stage in self.stages - ) -``` - -### Step 2: Modify `task_postrun` signal handler - -**File:** `ami/jobs/tasks.py` (lines 175-192) - -Only guard SUCCESS state - let all other states pass through: - -```python -@task_postrun.connect(sender=run_job) -def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs): - from ami.jobs.models import Job, JobState - - job_id = task.request.kwargs["job_id"] - if job_id is None: - logger.error(f"Job id is None for task {task_id}") - return - try: - job = Job.objects.get(pk=job_id) - except Job.DoesNotExist: - try: - job = Job.objects.get(task_id=task_id) - except Job.DoesNotExist: - logger.error(f"No job found for task {task_id} or job_id {job_id}") - return - - # Guard only SUCCESS state - let FAILURE, REVOKED, RETRY pass through immediately - # SUCCESS should only be set when all stages are actually complete - # This prevents premature SUCCESS when async workers are still processing - if state == JobState.SUCCESS and not job.progress.is_complete(): - job.logger.info( - f"Job {job.pk} task completed but stages not finished - " - "deferring SUCCESS status to progress handler" - ) - return - - job.update_status(state) -``` - -### Step 3: Add cleanup to `_update_job_progress` - -**File:** `ami/jobs/tasks.py` (lines 151-166) - -When job completes, add cleanup for Redis and NATS resources: - -```python -def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None: - """ - Update job progress for a specific stage from async pipeline workers. - - This function is called by process_nats_pipeline_result when async workers - report progress. It updates the job's progress model and, when the job - completes, sets the final SUCCESS status. - - For async jobs, this is the authoritative place where SUCCESS status - is set - the Celery task_postrun signal defers to this function. - - Args: - job_id: The job primary key - stage: The processing stage (e.g., "process" or "results" for ML jobs) - progress_percentage: Progress as a float from 0.0 to 1.0 - """ - from ami.jobs.models import Job, JobState # avoid circular import - - with transaction.atomic(): - job = Job.objects.select_for_update().get(pk=job_id) - job.progress.update_stage( - stage, - status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, - progress=progress_percentage, - ) - - # Check if all stages are now complete - if job.progress.is_complete(): - job.status = JobState.SUCCESS - job.progress.summary.status = JobState.SUCCESS - job.finished_at = datetime.datetime.now() # Use naive datetime in local time - job.logger.info(f"Job {job_id} completed successfully - all stages finished") - - # Clean up job-specific resources (Redis state, NATS stream/consumer) - _cleanup_async_job_resources(job_id, job.logger) - - job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%") - job.save() - - -def _cleanup_async_job_resources(job_id: int, job_logger: logging.Logger) -> None: - """ - Clean up all async processing resources for a completed job. - - This function is called when an async job completes (all stages finished). - It cleans up: - - 1. Redis state (via TaskStateManager.cleanup): - - job:{job_id}:pending_images:process - tracks remaining images in process stage - - job:{job_id}:pending_images:results - tracks remaining images in results stage - - job:{job_id}:pending_images_total - total image count for progress calculation - - 2. NATS JetStream resources (via TaskQueueManager.cleanup_job_resources): - - Stream: job_{job_id} - the message stream that holds pending tasks - - Consumer: job-{job_id}-consumer - the durable consumer that tracks delivery - - Why cleanup is needed: - - Redis keys have a 7-day TTL but should be cleaned immediately when job completes - - NATS streams/consumers have 24-hour retention but consume server resources - - Cleaning up immediately prevents resource accumulation from many jobs - - Cleanup failures are logged but don't fail the job - data is already saved. - - Args: - job_id: The job primary key - job_logger: Logger instance for this job (writes to job's log file) - """ - # Cleanup Redis state tracking - state_manager = TaskStateManager(job_id) - state_manager.cleanup() - job_logger.info(f"Cleaned up Redis state for job {job_id}") - - # Cleanup NATS resources (stream and consumer) - try: - async def cleanup_nats(): - async with TaskQueueManager() as manager: - return await manager.cleanup_job_resources(job_id) - - success = async_to_sync(cleanup_nats)() - if success: - job_logger.info(f"Cleaned up NATS resources for job {job_id}") - else: - job_logger.warning(f"Failed to clean up NATS resources for job {job_id}") - except Exception as e: - job_logger.error(f"Error cleaning up NATS resources for job {job_id}: {e}") - # Don't fail the job if cleanup fails - job data is already saved -``` - -### Step 4: Verify existing error handling - -**File:** `ami/jobs/models.py` (lines 402-408) - -**No changes needed** - Existing code already handles FAILURE correctly: -```python -if not queued: - job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk) - job.progress.update_stage("collect", status=JobState.FAILURE) - job.update_status(JobState.FAILURE) - job.finished_at = datetime.datetime.now() - job.save() - return -``` - ---- - -## Files to Modify - -| File | Lines | Change | -|------|-------|--------| -| `ami/jobs/models.py` | ~198 (in JobProgress class) | Add generic `is_complete()` method | -| `ami/jobs/models.py` | ~700 (Job class docstring) | Add future improvements note | -| `ami/jobs/tasks.py` | 175-192 | Guard SUCCESS with `is_complete()` check | -| `ami/jobs/tasks.py` | 151-166 | Use `is_complete()`, add cleanup with docstring | -| `ami/jobs/tasks.py` | imports | Add import for `TaskStateManager` | - -**Related existing code (not modified):** -- `Job.update_progress()` at lines 924-947 - calculates aggregate progress from stages - ---- - -## Why This Approach Works - -| Job Type | Stages Complete When Task Ends? | `task_postrun` Behavior | -|----------|--------------------------------|------------------------| -| Sync ML job | Yes - all stages marked SUCCESS | Sets SUCCESS normally | -| Async ML job | No - stages incomplete | Skips SUCCESS, defers to async handler | -| DataStorageSyncJob | Yes - single stage | Sets SUCCESS normally | -| DataExportJob | Yes - all stages | Sets SUCCESS normally | -| Any job that fails | N/A | FAILURE passes through | -| Cancelled job | N/A | REVOKED passes through | - ---- - -## Race Condition Analysis: `update_progress()` vs `is_complete()` - -### How `update_progress()` works (lines 924-947) - -Called automatically by `save()` (line 958-959), it auto-corrects stage status/progress: - -```python -for stage in self.progress.stages: - if stage.progress > 0 and stage.status == JobState.CREATED: - stage.status = JobState.STARTED # Auto-upgrade CREATED→STARTED - elif stage.status in JobState.final_states() and stage.progress < 1: - stage.progress = 1 # If final, ensure progress=1 - elif stage.progress == 1 and stage.status not in JobState.final_states(): - stage.status = JobState.SUCCESS # Auto-upgrade to SUCCESS if progress=1 -``` - -**Key insight:** Line 939-941 auto-sets SUCCESS only if `progress == 1`. For async jobs, incomplete stages have `progress < 1`, so they won't be auto-upgraded. - -### When is `is_complete()` called? - -1. **In `task_postrun`:** Job loaded fresh from DB (already saved by `job.run()`) -2. **In `_update_job_progress`:** After `update_stage()` but before `save()` - -### Async job flow - no race condition: - -``` -run_job task: -1. job.run() → MLJob.run() -2. collect stage set to SUCCESS, progress=1.0 -3. job.save() called → update_progress() runs → no changes (only collect is 100%) -4. queue_images_to_nats() called, returns -5. job.run() returns successfully - -task_postrun fires: -6. job = Job.objects.get(pk=job_id) # Fresh from DB -7. Job state: collect=SUCCESS(1.0), process=CREATED(0), results=CREATED(0) -8. is_complete() → False (not all stages at 100% final) -9. SUCCESS status deferred ✓ - -Async workers process images: -10. process_nats_pipeline_result called for each result -11. _update_job_progress("process", 0.5) → is_complete() = False -12. _update_job_progress("process", 1.0) → is_complete() = False (results still 0) -13. _update_job_progress("results", 0.5) → is_complete() = False -14. _update_job_progress("results", 1.0) → is_complete() = True ✓ -15. Job marked SUCCESS, cleanup runs -``` - -### Why there's no race: - -1. `is_complete()` requires ALL stages to have `progress >= 1.0` AND `status in final_states()` -2. For async jobs, incomplete stages have `progress < 1.0`, so `update_progress()` won't auto-upgrade them -3. The async handler updates stages separately (process, then results), so completion only triggers when the LAST stage reaches 100% -4. `select_for_update()` in `_update_job_progress` prevents concurrent updates to the same job - ---- - -## Risk Analysis - -**Core work functions that must NOT be affected:** -1. `queue_images_to_nats()` - queues images to NATS for async processing -2. `process_nats_pipeline_result()` - processes results from async workers -3. `pipeline.save_results()` - saves detections/classifications to database -4. `MLJob.process_images()` - synchronous image processing path - -**Changes and their risk:** - -| Change | Risk to Core Work | Analysis | -|--------|------------------|----------| -| Add `is_complete()` to JobProgress | **None** | Pure read-only check, doesn't modify any state | -| Guard SUCCESS in `task_postrun` | **None** | Only affects status display, not actual processing. Returns early without modifying anything. | -| Use `is_complete()` in `_update_job_progress` | **None** | Replaces hardcoded `stage == "results" and progress >= 1.0` check with generic method. Called AFTER `pipeline.save_results()` completes. | -| Add `_cleanup_async_job_resources()` | **Minimal** | Called ONLY after job is marked SUCCESS. Cleanup failures are caught and logged, don't fail the job. Data is already saved at this point. | - -**Worst case scenario:** If `is_complete()` has a bug and always returns False: -- Work still completes normally (queuing, processing, saving all happen) -- Job status would stay STARTED instead of SUCCESS -- UI would show job as incomplete even though work finished -- **Data is safe** - this is a display/status issue, not a data loss risk - -**Existing function to note:** `Job.update_progress()` at lines 924-947 calculates total progress from stages. The new `is_complete()` method is a related concept - checking if all stages are done vs calculating aggregate progress. - ---- - -## Future Improvements (out of scope) - -Per user feedback, consider for follow-up: -- Split job types into different classes with clearer state management -- More robust state machine for job lifecycle -- But don't reinvent state tracking on top of existing bg task tools - -These improvements should be noted in the Job class docstring. - ---- - -## Additional Documentation - -### Add note to Job class docstring - -**File:** `ami/jobs/models.py` (Job class, around line ~700) - -Add to the class docstring: -```python -# Future improvements: -# - Consider splitting job types into subclasses with clearer state management -# - The progress/stages system (see JobProgress, update_progress()) was designed -# for UI display. The is_complete() method bridges this to actual completion logic. -# - Avoid reinventing state tracking on top of existing background task tools -``` - ---- - -## Verification - -1. **Unit test for `is_complete()`:** - - No stages -> False - - One stage at progress=0.5, STARTED -> False - - One stage at progress=1.0, STARTED -> False (not final state) - - One stage at progress=1.0, SUCCESS -> True - - Multiple stages, all SUCCESS -> True - - Multiple stages, one still STARTED -> False - -2. **Unit test for task_postrun behavior:** - - SUCCESS with incomplete stages -> status NOT changed - - SUCCESS with complete stages -> status changed to SUCCESS - - FAILURE -> passes through immediately - - REVOKED -> passes through immediately - -3. **Integration test for async ML job:** - - Create job with async_pipeline_workers enabled - - Queue images to NATS - - Verify job status stays STARTED after run_job completes - - Process all images via process_nats_pipeline_result - - Verify job status becomes SUCCESS after all stages finish - - Verify finished_at is set - - Verify Redis/NATS cleanup occurred - -4. **Regression test for sync jobs:** - - Run sync ML job -> SUCCESS after task completes - - Run DataStorageSyncJob -> SUCCESS after task completes - - Run DataExportJob -> SUCCESS after task completes diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 482d01a58..9f8aa197f 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -197,6 +197,30 @@ def reset(self, status: JobState = JobState.CREATED): stage.progress = 0 stage.status = status + def is_complete(self) -> bool: + """ + Check if all stages have finished processing. + + A job is considered complete when ALL of its stages have: + - progress >= 1.0 (fully processed) + - status in a final state (SUCCESS, FAILURE, or REVOKED) + + This method works for any job type regardless of which stages it has. + It's used by the Celery task_postrun signal to determine whether to + set the job's final SUCCESS status, or defer to async progress handlers. + + Related: Job.update_progress() calculates the aggregate + progress percentage across all stages for display purposes. This method + is a binary check for completion that considers both progress AND status. + + Returns: + True if all stages are complete, False otherwise. + Returns False if job has no stages (shouldn't happen in practice). + """ + if not self.stages: + return False + return all(stage.progress >= 1.0 and stage.status in JobState.final_states() for stage in self.stages) + class Config: use_enum_values = True as_dict = True diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 6404512d1..4f70fd89a 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -175,7 +175,7 @@ def pre_update_job_status(sender, task_id, task, **kwargs): @task_postrun.connect(sender=run_job) def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs): - from ami.jobs.models import Job + from ami.jobs.models import Job, JobState job_id = task.request.kwargs["job_id"] if job_id is None: @@ -190,6 +190,15 @@ def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs): logger.error(f"No job found for task {task_id} or job_id {job_id}") return + # Guard only SUCCESS state - let FAILURE, REVOKED, RETRY pass through immediately + # SUCCESS should only be set when all stages are actually complete + # This prevents premature SUCCESS when async workers are still processing + if state == JobState.SUCCESS and not job.progress.is_complete(): + job.logger.info( + f"Job {job.pk} task completed but stages not finished - " "deferring SUCCESS status to progress handler" + ) + return + job.update_status(state) diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 6590fda1d..4bfd1c97d 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -61,6 +61,93 @@ def test_create_job_with_delay(self): self.assertEqual(job.progress.stages[0].progress, 1) self.assertEqual(job.progress.stages[0].status, JobState.SUCCESS) + def test_job_status_guard_prevents_premature_success(self): + """ + Test that update_job_status guards against setting SUCCESS + when job stages are not complete. + + This tests the fix for race conditions where Celery task completes + but async workers are still processing stages. + """ + from unittest.mock import Mock + + from ami.jobs.tasks import update_job_status + + # Create job with multiple stages + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job with incomplete stages", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + ) + + # Add stages that are NOT complete + job.progress.add_stage("detection") + job.progress.update_stage("detection", progress=0.5, status=JobState.STARTED) + job.progress.add_stage("classification") + job.progress.update_stage("classification", progress=0.0, status=JobState.CREATED) + job.save() + + # Verify stages are incomplete + self.assertFalse(job.progress.is_complete()) + + # Mock task object + mock_task = Mock() + mock_task.request.kwargs = {"job_id": job.pk} + initial_status = job.status + + # Attempt to set SUCCESS while stages are incomplete + update_job_status( + sender=mock_task, + task_id="test-task-id", + task=mock_task, + state=JobState.SUCCESS.value, # Pass string value, not enum + retval=None, + ) + + # Verify job status was NOT updated to SUCCESS (should remain CREATED) + job.refresh_from_db() + self.assertEqual(job.status, initial_status) + self.assertNotEqual(job.status, JobState.SUCCESS.value) + + def test_job_status_allows_failure_states_immediately(self): + """ + Test that FAILURE and REVOKED states bypass the completion guard + and are set immediately regardless of stage completion. + """ + from unittest.mock import Mock + + from ami.jobs.tasks import update_job_status + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job for failure states", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + ) + + # Add incomplete stage + job.progress.add_stage("detection") + job.progress.update_stage("detection", progress=0.3, status=JobState.STARTED) + job.save() + + mock_task = Mock() + mock_task.request.kwargs = {"job_id": job.pk} + + # Test FAILURE state passes through even with incomplete stages + update_job_status( + sender=mock_task, + task_id="test-task-id", + task=mock_task, + state=JobState.FAILURE.value, # Pass string value, not enum + retval=None, + ) + + job.refresh_from_db() + self.assertEqual(job.status, JobState.FAILURE.value) + class TestJobView(APITestCase): """ From 595f4c97f33c29429fbfe27133bc0b7c0415b359 Mon Sep 17 00:00:00 2001 From: carlosgjs Date: Fri, 6 Feb 2026 19:54:58 -0800 Subject: [PATCH 07/14] PSv2: Implement queue clean-up upon job completion (#1113) * merge * feat: PSv2 - Queue/redis clean-up upon job completion * fix: catch specific exception * chore: move tests to a subdir --------- Co-authored-by: Carlos Garcia Jurado Suarez Co-authored-by: Michael Bunsen --- ami/jobs/tasks.py | 30 +++ ami/ml/orchestration/jobs.py | 44 +++- ami/ml/orchestration/tests/__init__.py | 0 ami/ml/orchestration/tests/test_cleanup.py | 212 ++++++++++++++++++ .../{ => tests}/test_nats_queue.py | 0 5 files changed, 279 insertions(+), 7 deletions(-) create mode 100644 ami/ml/orchestration/tests/__init__.py create mode 100644 ami/ml/orchestration/tests/test_cleanup.py rename ami/ml/orchestration/{ => tests}/test_nats_queue.py (100%) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 4f70fd89a..17083fb84 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -166,6 +166,29 @@ def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%") job.save() + # Clean up async resources for completed jobs that use NATS/Redis + # Only ML jobs with async_pipeline_workers enabled use these resources + if stage == "results" and progress_percentage >= 1.0: + job = Job.objects.get(pk=job_id) # Re-fetch outside transaction + _cleanup_job_if_needed(job) + + +def _cleanup_job_if_needed(job) -> None: + """ + Clean up async resources (NATS/Redis) if this job type uses them. + + Only ML jobs with async_pipeline_workers enabled use NATS/Redis resources. + This function is safe to call for any job - it checks if cleanup is needed. + + Args: + job: The Job instance + """ + if job.job_type_key == "ml" and job.project and job.project.feature_flags.async_pipeline_workers: + # import here to avoid circular imports + from ami.ml.orchestration.jobs import cleanup_async_job_resources + + cleanup_async_job_resources(job) + @task_prerun.connect(sender=run_job) def pre_update_job_status(sender, task_id, task, **kwargs): @@ -201,6 +224,10 @@ def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs): job.update_status(state) + # Clean up async resources for revoked jobs + if state == JobState.REVOKED: + _cleanup_job_if_needed(job) + @task_failure.connect(sender=run_job, retry=False) def update_job_failure(sender, task_id, exception, *args, **kwargs): @@ -213,6 +240,9 @@ def update_job_failure(sender, task_id, exception, *args, **kwargs): job.save() + # Clean up async resources for failed jobs + _cleanup_job_if_needed(job) + def log_time(start: float = 0, msg: str | None = None) -> tuple[float, Callable]: """ diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index 621d5f089..9b4a577e9 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -1,28 +1,58 @@ +import logging + from asgiref.sync import async_to_sync -from ami.jobs.models import Job, JobState, logger +from ami.jobs.models import Job, JobState from ami.main.models import SourceImage from ami.ml.orchestration.nats_queue import TaskQueueManager from ami.ml.orchestration.task_state import TaskStateManager from ami.ml.schemas import PipelineProcessingTask +logger = logging.getLogger(__name__) + -# TODO CGJS: (Issue #1083) Call this once a job is fully complete (all images processed and saved) -def cleanup_nats_resources(job: "Job") -> bool: +def cleanup_async_job_resources(job: "Job") -> bool: """ - Clean up NATS JetStream resources (stream and consumer) for a completed job. + Clean up NATS JetStream and Redis resources for a completed job. + + This function cleans up: + 1. Redis state (via TaskStateManager.cleanup): + 2. NATS JetStream resources (via TaskQueueManager.cleanup_job_resources): + + Cleanup failures are logged but don't fail the job - data is already saved. Args: job: The Job instance Returns: - bool: True if cleanup was successful, False otherwise + bool: True if both cleanups succeeded, False otherwise """ - + redis_success = False + nats_success = False + + # Cleanup Redis state + try: + state_manager = TaskStateManager(job.pk) + state_manager.cleanup() + job.logger.info(f"Cleaned up Redis state for job {job.pk}") + redis_success = True + except Exception as e: + job.logger.error(f"Error cleaning up Redis state for job {job.pk}: {e}") + + # Cleanup NATS resources async def cleanup(): async with TaskQueueManager() as manager: return await manager.cleanup_job_resources(job.pk) - return async_to_sync(cleanup)() + try: + nats_success = async_to_sync(cleanup)() + if nats_success: + job.logger.info(f"Cleaned up NATS resources for job {job.pk}") + else: + job.logger.warning(f"Failed to clean up NATS resources for job {job.pk}") + except Exception as e: + job.logger.error(f"Error cleaning up NATS resources for job {job.pk}: {e}") + + return redis_success and nats_success def queue_images_to_nats(job: "Job", images: list[SourceImage]): diff --git a/ami/ml/orchestration/tests/__init__.py b/ami/ml/orchestration/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ami/ml/orchestration/tests/test_cleanup.py b/ami/ml/orchestration/tests/test_cleanup.py new file mode 100644 index 000000000..f7b686d44 --- /dev/null +++ b/ami/ml/orchestration/tests/test_cleanup.py @@ -0,0 +1,212 @@ +"""Integration tests for async job resource cleanup (NATS and Redis).""" + +from asgiref.sync import async_to_sync +from django.core.cache import cache +from django.test import TestCase +from nats.js.errors import NotFoundError + +from ami.jobs.models import Job, JobState, MLJob +from ami.jobs.tasks import _update_job_progress, update_job_failure, update_job_status +from ami.main.models import Project, ProjectFeatureFlags, SourceImage, SourceImageCollection +from ami.ml.models import Pipeline +from ami.ml.orchestration.jobs import queue_images_to_nats +from ami.ml.orchestration.nats_queue import TaskQueueManager +from ami.ml.orchestration.task_state import TaskStateManager + + +class TestCleanupAsyncJobResources(TestCase): + """Test cleanup of NATS and Redis resources for async ML jobs.""" + + def setUp(self): + """Set up test fixtures with async_pipeline_workers enabled.""" + # Create project with async_pipeline_workers feature flag enabled + self.project = Project.objects.create( + name="Test Cleanup Project", + feature_flags=ProjectFeatureFlags(async_pipeline_workers=True), + ) + + # Create pipeline + self.pipeline = Pipeline.objects.create( + name="Test Cleanup Pipeline", + slug="test-cleanup-pipeline", + description="Pipeline for cleanup tests", + ) + self.pipeline.projects.add(self.project) + + # Create source image collection with images + self.collection = SourceImageCollection.objects.create( + name="Test Cleanup Collection", + project=self.project, + ) + + # Create test images + self.images = [ + SourceImage.objects.create( + path=f"test_image_{i}.jpg", + public_base_url="https://example.com", + project=self.project, + ) + for i in range(3) + ] + for image in self.images: + self.collection.images.add(image) + + def _verify_resources_created(self, job_id: int): + """ + Verify that both Redis and NATS resources were created. + + Args: + job_id: The job ID to check + """ + # Verify Redis keys exist + state_manager = TaskStateManager(job_id) + for stage in state_manager.STAGES: + pending_key = state_manager._get_pending_key(stage) + self.assertIsNotNone(cache.get(pending_key), f"Redis key {pending_key} should exist") + total_key = state_manager._total_key + self.assertIsNotNone(cache.get(total_key), f"Redis key {total_key} should exist") + + # Verify NATS stream and consumer exist + async def check_nats_resources(): + async with TaskQueueManager() as manager: + stream_name = manager._get_stream_name(job_id) + consumer_name = manager._get_consumer_name(job_id) + + # Try to get stream info - should succeed if created + stream_exists = True + try: + await manager.js.stream_info(stream_name) + except NotFoundError: + stream_exists = False + + # Try to get consumer info - should succeed if created + consumer_exists = True + try: + await manager.js.consumer_info(stream_name, consumer_name) + except NotFoundError: + consumer_exists = False + + return stream_exists, consumer_exists + + stream_exists, consumer_exists = async_to_sync(check_nats_resources)() + + self.assertTrue(stream_exists, f"NATS stream for job {job_id} should exist") + self.assertTrue(consumer_exists, f"NATS consumer for job {job_id} should exist") + + def _create_job_with_queued_images(self) -> Job: + """ + Helper to create an ML job and queue images to NATS/Redis. + + Returns: + Job instance with images queued to NATS and state initialized in Redis + """ + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test Cleanup Job", + pipeline=self.pipeline, + source_image_collection=self.collection, + ) + + # Queue images to NATS (also initializes Redis state) + queue_images_to_nats(job, self.images) + + # Verify resources were actually created + self._verify_resources_created(job.pk) + + return job + + def _verify_resources_cleaned(self, job_id: int): + """ + Verify that both Redis and NATS resources are cleaned up. + + Args: + job_id: The job ID to check + """ + # Verify Redis keys are deleted + state_manager = TaskStateManager(job_id) + for stage in state_manager.STAGES: + pending_key = state_manager._get_pending_key(stage) + self.assertIsNone(cache.get(pending_key), f"Redis key {pending_key} should be deleted") + total_key = state_manager._total_key + self.assertIsNone(cache.get(total_key), f"Redis key {total_key} should be deleted") + + # Verify NATS stream and consumer are deleted + async def check_nats_resources(): + async with TaskQueueManager() as manager: + stream_name = manager._get_stream_name(job_id) + consumer_name = manager._get_consumer_name(job_id) + + # Try to get stream info - should fail if deleted + stream_exists = True + try: + await manager.js.stream_info(stream_name) + except NotFoundError: + stream_exists = False + + # Try to get consumer info - should fail if deleted + consumer_exists = True + try: + await manager.js.consumer_info(stream_name, consumer_name) + except NotFoundError: + consumer_exists = False + + return stream_exists, consumer_exists + + stream_exists, consumer_exists = async_to_sync(check_nats_resources)() + + self.assertFalse(stream_exists, f"NATS stream for job {job_id} should be deleted") + self.assertFalse(consumer_exists, f"NATS consumer for job {job_id} should be deleted") + + def test_cleanup_on_job_completion(self): + """Test that resources are cleaned up when job completes successfully.""" + job = self._create_job_with_queued_images() + + # Simulate job completion by updating progress to 100% in results stage + _update_job_progress(job.pk, stage="results", progress_percentage=1.0) + + # Verify cleanup happened + self._verify_resources_cleaned(job.pk) + + def test_cleanup_on_job_failure(self): + """Test that resources are cleaned up when job fails.""" + job = self._create_job_with_queued_images() + + # Set task_id so the failure handler can find the job + job.task_id = "test-task-failure-123" + job.save() + + # Simulate job failure by calling the failure signal handler + update_job_failure( + sender=None, + task_id=job.task_id, + exception=Exception("Test failure"), + ) + + # Verify cleanup happened + self._verify_resources_cleaned(job.pk) + + def test_cleanup_on_job_revoked(self): + """Test that resources are cleaned up when job is revoked/cancelled.""" + job = self._create_job_with_queued_images() + + # Create a mock task request object for the signal handler + class MockRequest: + def __init__(self): + self.kwargs = {"job_id": job.pk} + + class MockTask: + def __init__(self, job_id): + self.request = MockRequest() + self.request.kwargs["job_id"] = job_id + + # Simulate job revocation by calling the postrun signal handler with REVOKED state + update_job_status( + sender=None, + task_id="test-task-revoked-456", + task=MockTask(job.pk), + state=JobState.REVOKED, + ) + + # Verify cleanup happened + self._verify_resources_cleaned(job.pk) diff --git a/ami/ml/orchestration/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py similarity index 100% rename from ami/ml/orchestration/test_nats_queue.py rename to ami/ml/orchestration/tests/test_nats_queue.py From 0d8c6494d92a4196f67708f0d0bc9f01012e6c2c Mon Sep 17 00:00:00 2001 From: carlosgjs Date: Mon, 9 Feb 2026 15:32:04 -0800 Subject: [PATCH 08/14] fix: PSv2: Workers should not try to fetch tasks from v1 jobs (#1118) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces the dispatch_mode field on the Job model to track how each job dispatches its workload. This allows API clients (including the AMI worker) to filter jobs by dispatch mode — for example, fetching only async_api jobs so workers don't pull synchronous or internal jobs. JobDispatchMode enum (ami/jobs/models.py): internal — work handled entirely within the platform (Celery worker, no external calls). Default for all jobs. sync_api — worker calls an external processing service API synchronously and waits for each response. async_api — worker publishes items to NATS for external processing service workers to pick up independently. Database and Model Changes: Added dispatch_mode CharField with TextChoices, defaulting to internal, with the migration in ami/jobs/migrations/0019_job_dispatch_mode.py. ML jobs set dispatch_mode = async_api when the project's async_pipeline_workers feature flag is enabled. ML jobs set dispatch_mode = sync_api on the synchronous processing path (previously unset). API and Filtering: dispatch_mode is exposed (read-only) in job list and detail serializers. Filterable via query parameter: ?dispatch_mode=async_api The /tasks endpoint now returns 400 for non-async_api jobs, since only those have NATS tasks to fetch. Architecture doc: docs/claude/job-dispatch-modes.md documents the three modes, naming decisions, and per-job-type mapping. --------- Co-authored-by: Carlos Garcia Jurado Suarez Co-authored-by: Michael Bunsen Co-authored-by: Claude --- .dockerignore | 10 ++ ami/jobs/migrations/0019_job_dispatch_mode.py | 22 +++ ami/jobs/models.py | 36 +++++ ami/jobs/serializers.py | 2 + ami/jobs/tests.py | 132 +++++++++++++++++- ami/jobs/views.py | 6 +- docs/claude/job-dispatch-modes.md | 69 +++++++++ 7 files changed, 272 insertions(+), 5 deletions(-) create mode 100644 ami/jobs/migrations/0019_job_dispatch_mode.py create mode 100644 docs/claude/job-dispatch-modes.md diff --git a/.dockerignore b/.dockerignore index f0d718f04..95d6ac666 100644 --- a/.dockerignore +++ b/.dockerignore @@ -27,6 +27,10 @@ __pycache__/ *.egg *.whl +# IPython / Jupyter +.ipython/ +.jupyter/ + # Django / runtime artifacts *.log @@ -55,6 +59,12 @@ yarn-error.log .vscode/ *.iml +# Development / testing +.pytest_cache/ +.coverage +.tox/ +.cache/ + # OS cruft .DS_Store Thumbs.db diff --git a/ami/jobs/migrations/0019_job_dispatch_mode.py b/ami/jobs/migrations/0019_job_dispatch_mode.py new file mode 100644 index 000000000..1134c0144 --- /dev/null +++ b/ami/jobs/migrations/0019_job_dispatch_mode.py @@ -0,0 +1,22 @@ +# Generated by Django 4.2.10 on 2026-02-04 20:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("jobs", "0018_alter_job_job_type_key"), + ] + + operations = [ + migrations.AddField( + model_name="job", + name="dispatch_mode", + field=models.CharField( + choices=[("internal", "Internal"), ("sync_api", "Sync API"), ("async_api", "Async API")], + default="internal", + help_text="How the job dispatches its workload: internal, sync_api, or async_api.", + max_length=32, + ), + ), + ] diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 9f8aa197f..8e790cdcb 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -24,6 +24,32 @@ logger = logging.getLogger(__name__) +class JobDispatchMode(models.TextChoices): + """ + How a job dispatches its workload. + + Jobs are configured and launched by users in the UI, then dispatched to + Celery workers. This enum describes what the worker does with the work: + + - INTERNAL: All work happens within the platform (Celery worker handles it directly). + - SYNC_API: Worker calls an external processing service API and waits for each response. + - ASYNC_API: Worker queues items to a message broker (NATS) for external processing + service workers to pick up and process independently. + """ + + # Work is handled entirely within the platform, no external service calls. + # e.g. DataStorageSyncJob, DataExportJob, SourceImageCollectionPopulateJob + INTERNAL = "internal", "Internal" + + # Worker loops over items, sends each to an external processing service + # endpoint synchronously, and waits for the response before continuing. + SYNC_API = "sync_api", "Sync API" + + # Worker publishes all items to a message broker (NATS). External processing + # service workers consume and process them independently, reporting results back. + ASYNC_API = "async_api", "Async API" + + class JobState(str, OrderedEnum): """ These come from Celery, except for CREATED, which is a custom state. @@ -422,6 +448,8 @@ def run(cls, job: "Job"): job.save() if job.project.feature_flags.async_pipeline_workers: + job.dispatch_mode = JobDispatchMode.ASYNC_API + job.save(update_fields=["dispatch_mode"]) queued = queue_images_to_nats(job, images) if not queued: job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk) @@ -431,6 +459,8 @@ def run(cls, job: "Job"): job.save() return else: + job.dispatch_mode = JobDispatchMode.SYNC_API + job.save(update_fields=["dispatch_mode"]) cls.process_images(job, images) @classmethod @@ -822,6 +852,12 @@ class Job(BaseModel): blank=True, related_name="jobs", ) + dispatch_mode = models.CharField( + max_length=32, + choices=JobDispatchMode.choices, + default=JobDispatchMode.INTERNAL, + help_text="How the job dispatches its workload: internal, sync_api, or async_api.", + ) def __str__(self) -> str: return f'#{self.pk} "{self.name}" ({self.status})' diff --git a/ami/jobs/serializers.py b/ami/jobs/serializers.py index 254242046..7a3471003 100644 --- a/ami/jobs/serializers.py +++ b/ami/jobs/serializers.py @@ -127,6 +127,7 @@ class Meta: "job_type", "job_type_key", "data_export", + "dispatch_mode", # "duration", # "duration_label", # "progress_label", @@ -141,6 +142,7 @@ class Meta: "started_at", "finished_at", "duration", + "dispatch_mode", ] diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 4bfd1c97d..4d4edb9fa 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -8,7 +8,7 @@ from rest_framework.test import APIRequestFactory, APITestCase from ami.base.serializers import reverse_with_params -from ami.jobs.models import Job, JobProgress, JobState, MLJob, SourceImageCollectionPopulateJob +from ami.jobs.models import Job, JobDispatchMode, JobProgress, JobState, MLJob, SourceImageCollectionPopulateJob from ami.main.models import Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline from ami.ml.orchestration.jobs import queue_images_to_nats @@ -414,6 +414,8 @@ def test_search_jobs(self): def _task_batch_helper(self, value: Any, expected_status: int): pipeline = self._create_pipeline() job = self._create_ml_job("Job for batch test", pipeline) + job.dispatch_mode = JobDispatchMode.ASYNC_API + job.save(update_fields=["dispatch_mode"]) images = [ SourceImage.objects.create( path=f"image_{i}.jpg", @@ -447,12 +449,19 @@ def test_tasks_endpoint_with_invalid_batch(self): def test_tasks_endpoint_without_pipeline(self): """Test the tasks endpoint returns error when job has no pipeline.""" - # Use the existing job which doesn't have a pipeline - job_data = self._create_job("Job without pipeline", start_now=False) + # Create a job without a pipeline but with async_api dispatch mode + # so the dispatch_mode guard passes and the pipeline check is reached + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Job without pipeline", + source_image_collection=self.source_image_collection, + dispatch_mode=JobDispatchMode.ASYNC_API, + ) self.client.force_authenticate(user=self.user) tasks_url = reverse_with_params( - "api:job-tasks", args=[job_data["id"]], params={"project_id": self.project.pk, "batch": 1} + "api:job-tasks", args=[job.pk], params={"project_id": self.project.pk, "batch": 1} ) resp = self.client.get(tasks_url) @@ -556,3 +565,118 @@ def test_processing_service_name_parameter(self): ] resp = self.client.post(result_url, result_data, format="json") self.assertEqual(resp.status_code, 200) + + +class TestJobDispatchModeFiltering(APITestCase): + """Test job filtering by dispatch_mode.""" + + def setUp(self): + self.user = User.objects.create_user( # type: ignore + email="testuser-backend@insectai.org", + is_staff=True, + is_active=True, + is_superuser=True, + ) + self.project = Project.objects.create(name="Test Backend Project") + + # Create pipeline for ML jobs + self.pipeline = Pipeline.objects.create( + name="Test ML Pipeline", + slug="test-ml-pipeline", + description="Test ML pipeline for dispatch_mode filtering", + ) + self.pipeline.projects.add(self.project) + + # Create source image collection for jobs + self.source_image_collection = SourceImageCollection.objects.create( + name="Test Collection", + project=self.project, + ) + + # Give the user necessary permissions + assign_perm(Project.Permissions.VIEW_PROJECT, self.user, self.project) + + def test_dispatch_mode_filtering(self): + """Test that jobs can be filtered by dispatch_mode parameter.""" + # Create two ML jobs with different dispatch modes + sync_job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Sync API Job", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + dispatch_mode=JobDispatchMode.SYNC_API, + ) + + async_job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Async API Job", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + dispatch_mode=JobDispatchMode.ASYNC_API, + ) + + # Create a job with default dispatch_mode (should be "internal") + internal_job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Internal Job", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + ) + + self.client.force_authenticate(user=self.user) + jobs_list_url = reverse_with_params("api:job-list", params={"project_id": self.project.pk}) + + # Test filtering by sync_api dispatch_mode + resp = self.client.get(jobs_list_url, {"dispatch_mode": JobDispatchMode.SYNC_API}) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(data["count"], 1) + self.assertEqual(data["results"][0]["id"], sync_job.pk) + self.assertEqual(data["results"][0]["dispatch_mode"], JobDispatchMode.SYNC_API) + + # Test filtering by async_api dispatch_mode + resp = self.client.get(jobs_list_url, {"dispatch_mode": JobDispatchMode.ASYNC_API}) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(data["count"], 1) + self.assertEqual(data["results"][0]["id"], async_job.pk) + self.assertEqual(data["results"][0]["dispatch_mode"], JobDispatchMode.ASYNC_API) + + # Test filtering by invalid dispatch_mode (should return 400 due to choices validation) + resp = self.client.get(jobs_list_url, {"dispatch_mode": "non_existent_mode"}) + self.assertEqual(resp.status_code, 400) + + # Test without dispatch_mode filter (should return all jobs) + resp = self.client.get(jobs_list_url) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(data["count"], 3) # All three jobs + + # Verify the job IDs returned include all jobs + returned_ids = {job["id"] for job in data["results"]} + expected_ids = {sync_job.pk, async_job.pk, internal_job.pk} + self.assertEqual(returned_ids, expected_ids) + + def test_tasks_endpoint_rejects_non_async_jobs(self): + """Test that /tasks endpoint returns 400 for non-async_api jobs.""" + from ami.base.serializers import reverse_with_params + + sync_job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Sync Job for tasks test", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + dispatch_mode=JobDispatchMode.SYNC_API, + ) + + self.client.force_authenticate(user=self.user) + tasks_url = reverse_with_params( + "api:job-tasks", args=[sync_job.pk], params={"project_id": self.project.pk, "batch": 1} + ) + resp = self.client.get(tasks_url) + self.assertEqual(resp.status_code, 400) + self.assertIn("async_api", resp.json()[0].lower()) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index cc57ecb3c..eb3ab258c 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -22,7 +22,7 @@ from ami.ml.schemas import PipelineTaskResult from ami.utils.fields import url_boolean_param -from .models import Job, JobState +from .models import Job, JobDispatchMode, JobState from .serializers import JobListSerializer, JobSerializer, MinimalJobSerializer logger = logging.getLogger(__name__) @@ -43,6 +43,7 @@ class Meta: "source_image_single", "pipeline", "job_type_key", + "dispatch_mode", ] @@ -232,6 +233,9 @@ def tasks(self, request, pk=None): raise ValidationError({"batch": str(e)}) from e _ = _log_processing_service_name(request, f"tasks ({batch}) requested for job {job.pk}", job.logger) + # Only async_api jobs have tasks fetchable from NATS + if job.dispatch_mode != JobDispatchMode.ASYNC_API: + raise ValidationError("Only async_api jobs have fetchable tasks") # Validate that the job has a pipeline if not job.pipeline: diff --git a/docs/claude/job-dispatch-modes.md b/docs/claude/job-dispatch-modes.md new file mode 100644 index 000000000..26664be0b --- /dev/null +++ b/docs/claude/job-dispatch-modes.md @@ -0,0 +1,69 @@ +# Job Dispatch Modes + +## Overview + +The `Job` model is a user-facing CMS feature. Users configure jobs in the UI (selecting images, a processing pipeline, etc.), click start, and watch progress. The actual work is dispatched to background workers. The `dispatch_mode` field on `Job` describes *how* that work gets dispatched. + +## The Three Dispatch Modes + +### `internal` + +All work happens within the platform itself. A Celery worker picks up the job and handles it directly — no external service calls. + +**Job types using this mode:** +- `DataStorageSyncJob` — syncs files from S3 storage +- `SourceImageCollectionPopulateJob` — queries DB, populates a capture collection +- `DataExportJob` — generates export files +- `PostProcessingJob` — runs post-processing tasks + +### `sync_api` + +The Celery worker calls an external processing service API synchronously. It loops over items (e.g. batches of images), sends each batch to the processing service endpoint, waits for the response, saves results, and moves on. + +**Job types using this mode:** +- `MLJob` (default path) + +### `async_api` + +The Celery worker publishes all items to a message broker (NATS). External processing service workers consume items independently and report results back. The job monitors progress and completes when all items are processed. + +**Job types using this mode:** +- `MLJob` (when `project.feature_flags.async_pipeline_workers` is enabled) + +## Architecture Context + +``` +User (UI) + │ + ▼ +Job (Django model) ─── dispatch_mode: internal | sync_api | async_api + │ + ▼ +Celery Worker + │ + ├── internal ──────► Work done directly (DB queries, file ops, exports) + │ + ├── sync_api ──────► HTTP calls to Processing Service API (request/response loop) + │ + └── async_api ─────► Publish to NATS ──► External Processing Service workers + │ + ▼ + Results reported back +``` + +## Naming Decisions + +- **Why not `backend`?** Collides with Celery's "result backend" concept and the "ML backend" term used for processing services throughout the codebase. +- **Why not `task_backend`?** "Task backend" is specifically a Celery concept (where task results are stored). +- **Why not `local`?** Ambiguous with local development environments. +- **Why `internal`?** Clean contrast with the two external API modes. "Internal" means the work stays within the platform; `sync_api` and `async_api` both involve external processing services. +- **Why `dispatch_mode`?** The field describes *how* the Celery worker dispatches work to processing services, not how the job itself executes (all jobs execute via Celery). "Dispatch" is more precise than "execution" which is ambiguous. + +## Code Locations + +- Enum: `ami/jobs/models.py` — `JobDispatchMode` +- Field: `ami/jobs/models.py` — `Job.dispatch_mode` +- Serializer: `ami/jobs/serializers.py` — exposed in `JobListSerializer` and read-only +- API filter: `ami/jobs/views.py` — filterable via `?dispatch_mode=sync_api` +- Migration: `ami/jobs/migrations/0019_job_dispatch_mode.py` +- Tests: `ami/jobs/tests.py` — `TestJobDispatchModeFiltering` From 751573e3f53c17f1971c5cff0cd4f0b5e8fa31e6 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 9 Feb 2026 16:15:57 -0800 Subject: [PATCH 09/14] PSv2 cleanup: use is_complete() and dispatch_mode in job progress handler (#1125) * refactor: use is_complete() and dispatch_mode in job progress handler Replace hardcoded `stage == "results"` check with `job.progress.is_complete()` which verifies ALL stages are done, making it work for any job type. Replace feature flag check in cleanup with `dispatch_mode == ASYNC_API` which is immutable for the job's lifetime and more correct than re-reading a mutable flag that could change between job creation and completion. Co-Authored-By: Claude * test: update cleanup tests for is_complete() and dispatch_mode checks Set dispatch_mode=ASYNC_API on test jobs to match the new cleanup guard. Complete all stages (collect, process, results) in the completion test since is_complete() correctly requires all stages to be done. Co-Authored-By: Claude --------- Co-authored-by: Claude --- ami/jobs/tasks.py | 13 +++++++------ ami/ml/orchestration/tests/test_cleanup.py | 7 +++++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 17083fb84..3548d0ea5 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -159,7 +159,7 @@ def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, progress=progress_percentage, ) - if stage == "results" and progress_percentage >= 1.0: + if job.progress.is_complete(): job.status = JobState.SUCCESS job.progress.summary.status = JobState.SUCCESS job.finished_at = datetime.datetime.now() # Use naive datetime in local time @@ -167,23 +167,24 @@ def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> job.save() # Clean up async resources for completed jobs that use NATS/Redis - # Only ML jobs with async_pipeline_workers enabled use these resources - if stage == "results" and progress_percentage >= 1.0: + if job.progress.is_complete(): job = Job.objects.get(pk=job_id) # Re-fetch outside transaction _cleanup_job_if_needed(job) def _cleanup_job_if_needed(job) -> None: """ - Clean up async resources (NATS/Redis) if this job type uses them. + Clean up async resources (NATS/Redis) if this job uses them. - Only ML jobs with async_pipeline_workers enabled use NATS/Redis resources. + Only jobs with ASYNC_API dispatch mode use NATS/Redis resources. This function is safe to call for any job - it checks if cleanup is needed. Args: job: The Job instance """ - if job.job_type_key == "ml" and job.project and job.project.feature_flags.async_pipeline_workers: + from ami.jobs.models import JobDispatchMode + + if job.dispatch_mode == JobDispatchMode.ASYNC_API: # import here to avoid circular imports from ami.ml.orchestration.jobs import cleanup_async_job_resources diff --git a/ami/ml/orchestration/tests/test_cleanup.py b/ami/ml/orchestration/tests/test_cleanup.py index f7b686d44..ef8382d3d 100644 --- a/ami/ml/orchestration/tests/test_cleanup.py +++ b/ami/ml/orchestration/tests/test_cleanup.py @@ -5,7 +5,7 @@ from django.test import TestCase from nats.js.errors import NotFoundError -from ami.jobs.models import Job, JobState, MLJob +from ami.jobs.models import Job, JobDispatchMode, JobState, MLJob from ami.jobs.tasks import _update_job_progress, update_job_failure, update_job_status from ami.main.models import Project, ProjectFeatureFlags, SourceImage, SourceImageCollection from ami.ml.models import Pipeline @@ -106,6 +106,7 @@ def _create_job_with_queued_images(self) -> Job: name="Test Cleanup Job", pipeline=self.pipeline, source_image_collection=self.collection, + dispatch_mode=JobDispatchMode.ASYNC_API, ) # Queue images to NATS (also initializes Redis state) @@ -162,7 +163,9 @@ def test_cleanup_on_job_completion(self): """Test that resources are cleaned up when job completes successfully.""" job = self._create_job_with_queued_images() - # Simulate job completion by updating progress to 100% in results stage + # Simulate job completion: complete all stages (collect, process, then results) + _update_job_progress(job.pk, stage="collect", progress_percentage=1.0) + _update_job_progress(job.pk, stage="process", progress_percentage=1.0) _update_job_progress(job.pk, stage="results", progress_percentage=1.0) # Verify cleanup happened From 7a2847725479b6d22b973dfe3fe202527656e357 Mon Sep 17 00:00:00 2001 From: carlosgjs Date: Thu, 12 Feb 2026 15:00:06 -0800 Subject: [PATCH 10/14] Tests for async result processing (#1129) * merge * Tests for async result processing * fix formatting * CR feedback * refactor: add public get_progress() to TaskStateManager Add a read-only get_progress(stage) method that returns a progress snapshot without acquiring a lock or mutating state. Use it in test_tasks.py instead of calling the private _get_progress() directly. Co-Authored-By: Claude * docs: clarify Celery .delay() vs .apply_async() calling convention Three reviewers were confused by how mock.call_args works here. .delay(**kw) passes ((), kw) as two positional args to apply_async, which is different from apply_async(kwargs=kw). Co-Authored-By: Claude --------- Co-authored-by: Carlos Garcia Jurado Suarez Co-authored-by: Michael Bunsen Co-authored-by: Claude --- .vscode/launch.json | 16 ++ ami/jobs/test_tasks.py | 406 +++++++++++++++++++++++++++++ ami/ml/orchestration/task_state.py | 28 +- scripts/debug_tests.sh | 8 + 4 files changed, 449 insertions(+), 9 deletions(-) create mode 100644 ami/jobs/test_tasks.py create mode 100755 scripts/debug_tests.sh diff --git a/.vscode/launch.json b/.vscode/launch.json index 558ab2678..024962ed7 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -40,6 +40,22 @@ ], "justMyCode": true }, + { + "name": "Attach: Tests", + "type": "debugpy", + "request": "attach", + "connect": { + "host": "localhost", + "port": 5680 + }, + "pathMappings": [ + { + "localRoot": "${workspaceFolder}", + "remoteRoot": "/app" + } + ], + "justMyCode": true + }, { "name": "Current File", "type": "debugpy", diff --git a/ami/jobs/test_tasks.py b/ami/jobs/test_tasks.py new file mode 100644 index 000000000..fc291c8ba --- /dev/null +++ b/ami/jobs/test_tasks.py @@ -0,0 +1,406 @@ +""" +E2E tests for ami.jobs.tasks, focusing on error handling in process_nats_pipeline_result. + +This test suite verifies the critical error handling path when PipelineResultsError +is received instead of successful pipeline results. +""" + +import logging +from unittest.mock import AsyncMock, MagicMock, patch + +from django.core.cache import cache +from django.test import TestCase +from rest_framework.test import APITestCase + +from ami.base.serializers import reverse_with_params +from ami.jobs.models import Job, JobDispatchMode, JobState, MLJob +from ami.jobs.tasks import process_nats_pipeline_result +from ami.main.models import Detection, Project, SourceImage, SourceImageCollection +from ami.ml.models import Pipeline +from ami.ml.orchestration.task_state import TaskStateManager, _lock_key +from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse, SourceImageResponse +from ami.users.models import User + +logger = logging.getLogger(__name__) + + +class TestProcessNatsPipelineResultError(TestCase): + """E2E tests for process_nats_pipeline_result with error handling.""" + + def setUp(self): + """Setup test fixtures.""" + cache.clear() # Critical: clear Redis between tests + + self.project = Project.objects.create(name="Error Test Project") + self.pipeline = Pipeline.objects.create( + name="Test Pipeline", + slug="test-pipeline", + ) + self.pipeline.projects.add(self.project) + + self.collection = SourceImageCollection.objects.create( + name="Test Collection", + project=self.project, + ) + + self.job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test Error Handling Job", + pipeline=self.pipeline, + source_image_collection=self.collection, + dispatch_mode=JobDispatchMode.ASYNC_API, + ) + + # Create test images + self.images = [ + SourceImage.objects.create( + path=f"test_image_{i}.jpg", + public_base_url="http://example.com", + project=self.project, + ) + for i in range(3) + ] + + # Initialize state manager + self.image_ids = [str(img.pk) for img in self.images] + self.state_manager = TaskStateManager(self.job.pk) + self.state_manager.initialize_job(self.image_ids) + + def tearDown(self): + """Clean up after tests.""" + cache.clear() + + def _setup_mock_nats(self, mock_manager_class): + """Helper to setup mock NATS manager.""" + mock_manager = AsyncMock() + mock_manager.acknowledge_task = AsyncMock(return_value=True) + mock_manager_class.return_value.__aenter__.return_value = mock_manager + mock_manager_class.return_value.__aexit__.return_value = AsyncMock() + return mock_manager + + def _create_error_result(self, image_id: str | None = None, error_msg: str = "Processing failed") -> dict: + res_err = PipelineResultsError( + error=error_msg, + image_id=image_id, + ) + return res_err.dict() + + def _assert_progress_updated( + self, job_id: int, expected_processed: int, expected_total: int, stage: str = "process" + ): + """Assert TaskStateManager state is correct.""" + manager = TaskStateManager(job_id) + progress = manager.get_progress(stage) + self.assertIsNotNone(progress, f"Progress not found for stage '{stage}'") + self.assertEqual(progress.processed, expected_processed) + self.assertEqual(progress.total, expected_total) + + @patch("ami.jobs.tasks.TaskQueueManager") + def test_process_nats_pipeline_result_with_error(self, mock_manager_class): + """ + Test that PipelineResultsError is properly handled without saving to DB. + """ + mock_manager = self._setup_mock_nats(mock_manager_class) + + # Create error result data for first image + error_data = self._create_error_result( + image_id=str(self.images[0].pk), error_msg="Failed to process image: invalid format" + ) + reply_subject = "tasks.reply.test123" + + # Verify no detections exist before + initial_detection_count = Detection.objects.count() + + # Execute task using .apply() for synchronous testing + # This properly handles the bind=True decorator + process_nats_pipeline_result.apply( + kwargs={"job_id": self.job.pk, "result_data": error_data, "reply_subject": reply_subject} + ) + + # Assert: Progress was updated (1 of 3 images processed) + self._assert_progress_updated(self.job.pk, expected_processed=1, expected_total=3, stage="process") + self._assert_progress_updated(self.job.pk, expected_processed=1, expected_total=3, stage="results") + + # Assert: Job progress increased + self.job.refresh_from_db() + process_stage = next((s for s in self.job.progress.stages if s.key == "process"), None) + self.assertIsNotNone(process_stage) + self.assertGreater(process_stage.progress, 0) + self.assertLess(process_stage.progress, 1.0) # Not complete yet + + # Assert: Job status is still STARTED (not SUCCESS with incomplete stages) + self.assertNotEqual(self.job.status, JobState.SUCCESS.value) + + # Assert: NO detections were saved to database + self.assertEqual(Detection.objects.count(), initial_detection_count) + + mock_manager.acknowledge_task.assert_called_once_with(reply_subject) + + @patch("ami.jobs.tasks.TaskQueueManager") + def test_process_nats_pipeline_result_error_no_image_id(self, mock_manager_class): + """ + Test error handling when image_id is None. + + This tests the fallback: processed_image_ids = set() when no image_id. + """ + mock_manager = self._setup_mock_nats(mock_manager_class) + + # Create error result without image_id + error_data = self._create_error_result(error_msg="General pipeline failure", image_id=None) + reply_subject = "tasks.reply.test456" + + # Execute task using .apply() + process_nats_pipeline_result.apply( + kwargs={"job_id": self.job.pk, "result_data": error_data, "reply_subject": reply_subject} + ) + + # Assert: Progress was NOT updated (empty set of processed images) + # Since no image_id was provided, processed_image_ids = set() + manager = TaskStateManager(self.job.pk) + progress = manager.get_progress("process") + self.assertEqual(progress.processed, 0) # No images marked as processed + + mock_manager.acknowledge_task.assert_called_once_with(reply_subject) + + # Assert: No detections saved for this job's images + detections_for_job = Detection.objects.filter(source_image__in=self.images) + self.assertEqual(detections_for_job.count(), 0) + + @patch("ami.jobs.tasks.TaskQueueManager") + def test_process_nats_pipeline_result_mixed_results(self, mock_manager_class): + """ + Test realistic scenario with some images succeeding and others failing. + + Simulates processing batch where: + - Image 1: Error (PipelineResultsError) + - Image 2: Success with detections + - Image 3: Error (PipelineResultsError) + """ + mock_manager = self._setup_mock_nats(mock_manager_class) + + # For this test, we just want to verify progress tracking works with mixed results + # We'll skip checking final job completion status since that depends on all stages + + # Process error for image 1 + error_data_1 = self._create_error_result(image_id=str(self.images[0].pk), error_msg="Image 1 failed") + process_nats_pipeline_result.apply( + kwargs={"job_id": self.job.pk, "result_data": error_data_1, "reply_subject": "reply.1"} + ) + + # Process success for image 2 (simplified - just tracking progress without actual detections) + success_data_2 = PipelineResultsResponse( + pipeline="test-pipeline", + algorithms={}, + total_time=1.5, + source_images=[SourceImageResponse(id=str(self.images[1].pk), url="http://example.com/test_image_1.jpg")], + detections=[], + errors=None, + ).dict() + process_nats_pipeline_result.apply( + kwargs={"job_id": self.job.pk, "result_data": success_data_2, "reply_subject": "reply.2"} + ) + + # Process error for image 3 + error_data_3 = self._create_error_result(image_id=str(self.images[2].pk), error_msg="Image 3 failed") + process_nats_pipeline_result.apply( + kwargs={"job_id": self.job.pk, "result_data": error_data_3, "reply_subject": "reply.3"} + ) + + # Assert: All 3 images marked as processed in TaskStateManager + manager = TaskStateManager(self.job.pk) + process_progress = manager.get_progress("process") + self.assertIsNotNone(process_progress) + self.assertEqual(process_progress.processed, 3) + self.assertEqual(process_progress.total, 3) + self.assertEqual(process_progress.percentage, 1.0) + + results_progress = manager.get_progress("results") + self.assertIsNotNone(results_progress) + self.assertEqual(results_progress.processed, 3) + self.assertEqual(results_progress.total, 3) + self.assertEqual(results_progress.percentage, 1.0) + + # Assert: Job progress stages updated + self.job.refresh_from_db() + process_stage = next((s for s in self.job.progress.stages if s.key == "process"), None) + results_stage = next((s for s in self.job.progress.stages if s.key == "results"), None) + + self.assertIsNotNone(process_stage, "Process stage not found in job progress") + self.assertIsNotNone(results_stage, "Results stage not found in job progress") + + # Both should be at 100% + self.assertEqual(process_stage.progress, 1.0) + self.assertEqual(results_stage.progress, 1.0) + + # Assert: All tasks acknowledged + self.assertEqual(mock_manager.acknowledge_task.call_count, 3) + + @patch("ami.jobs.tasks.TaskQueueManager") + def test_process_nats_pipeline_result_error_concurrent_locking(self, mock_manager_class): + """ + Test that error results respect locking mechanism. + + Verifies race condition handling when multiple workers + process error results simultaneously. + """ + # Simulate lock held by another task + lock_key = _lock_key(self.job.pk) + cache.set(lock_key, "other-task-id", timeout=60) + + # Create error result + error_data = self._create_error_result(image_id=str(self.images[0].pk)) + reply_subject = "tasks.reply.test789" + + # Task should raise retry exception when lock not acquired + # The task internally calls self.retry() which raises a Retry exception + from celery.exceptions import Retry + + with self.assertRaises(Retry): + process_nats_pipeline_result.apply( + kwargs={ + "job_id": self.job.pk, + "result_data": error_data, + "reply_subject": reply_subject, + } + ) + + # Assert: Progress was NOT updated (lock not acquired) + manager = TaskStateManager(self.job.pk) + progress = manager.get_progress("process") + self.assertEqual(progress.processed, 0) + + @patch("ami.jobs.tasks.TaskQueueManager") + def test_process_nats_pipeline_result_error_job_not_found(self, mock_manager_class): + """ + Test graceful handling when job is deleted before error processed. + + From tasks.py lines 97-101, should log error and acknowledge without raising. + """ + mock_manager = self._setup_mock_nats(mock_manager_class) + + # Create error result + error_data = self._create_error_result(image_id=str(self.images[0].pk)) + reply_subject = "tasks.reply.test999" + + # Delete the job + deleted_job_id = self.job.pk + self.job.delete() + + # Should NOT raise exception - task should handle gracefully + process_nats_pipeline_result.apply( + kwargs={ + "job_id": deleted_job_id, + "result_data": error_data, + "reply_subject": reply_subject, + } + ) + + # Assert: Task was acknowledged despite missing job + mock_manager.acknowledge_task.assert_called_once_with(reply_subject) + + +class TestResultEndpointWithError(APITestCase): + """Integration test for the result API endpoint with error results.""" + + def setUp(self): + """Setup test fixtures.""" + cache.clear() + + self.user = User.objects.create_user( # type: ignore + email="testuser-error@insectai.org", + is_staff=True, + is_active=True, + is_superuser=True, + ) + + self.project = Project.objects.create(name="Error API Test Project") + self.pipeline = Pipeline.objects.create( + name="Test Pipeline for Errors", + slug="test-error-pipeline", + ) + self.pipeline.projects.add(self.project) + + self.collection = SourceImageCollection.objects.create( + name="Test Collection", + project=self.project, + ) + + self.job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="API Test Error Job", + pipeline=self.pipeline, + source_image_collection=self.collection, + dispatch_mode=JobDispatchMode.ASYNC_API, + ) + + self.image = SourceImage.objects.create( + path="test_error_image.jpg", + public_base_url="http://example.com", + project=self.project, + ) + + # Initialize state manager + state_manager = TaskStateManager(self.job.pk) + state_manager.initialize_job([str(self.image.pk)]) + + def tearDown(self): + """Clean up after tests.""" + cache.clear() + + @patch("ami.jobs.tasks.process_nats_pipeline_result.apply_async") + def test_result_endpoint_with_error_result(self, mock_apply_async): + """ + E2E test through the API endpoint that queues the task. + + Tests the full flow: API -> Celery task -> Error handling + """ + # Configure mock to return a proper task-like object with serializable id + + mock_result = MagicMock() + mock_result.id = "test-task-id-123" + mock_apply_async.return_value = mock_result + + self.client.force_authenticate(user=self.user) + result_url = reverse_with_params("api:job-result", args=[self.job.pk], params={"project_id": self.project.pk}) + + # Create error result data + result_data = [ + { + "reply_subject": "test.reply.error.1", + "result": { + "error": "Image processing timeout", + "image_id": str(self.image.pk), + }, + } + ] + + # POST error result to API + resp = self.client.post(result_url, result_data, format="json") + + # Assert: API accepted the error result + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(data["status"], "accepted") + self.assertEqual(data["job_id"], self.job.pk) + self.assertEqual(data["results_queued"], 1) + self.assertEqual(len(data["tasks"]), 1) + self.assertEqual(data["tasks"][0]["task_id"], "test-task-id-123") + self.assertEqual(data["tasks"][0]["status"], "queued") + + # Assert: Celery task was queued + mock_apply_async.assert_called_once() + + # Verify the task was called with correct arguments. + # NOTE on Celery calling convention: + # .delay(k1=v1, k2=v2) calls .apply_async((), {k1: v1, k2: v2}) + # i.e. two *positional* args to apply_async: an empty args tuple and a kwargs dict. + # This is NOT the same as apply_async(kwargs={...}) which uses a keyword argument. + # So mock.call_args[0] == ((), {task kwargs}) — a 2-element tuple. + call_args = mock_apply_async.call_args[0] + self.assertEqual(len(call_args), 2, "apply_async should be called with (args, kwargs)") + task_kwargs = call_args[1] # Second positional arg is the kwargs dict + self.assertEqual(task_kwargs["job_id"], self.job.pk) + self.assertEqual(task_kwargs["reply_subject"], "test.reply.error.1") + self.assertIn("error", task_kwargs["result_data"]) diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/task_state.py index 483275453..b05760e68 100644 --- a/ami/ml/orchestration/task_state.py +++ b/ami/ml/orchestration/task_state.py @@ -14,6 +14,10 @@ TaskProgress = namedtuple("TaskProgress", ["remaining", "total", "processed", "percentage"]) +def _lock_key(job_id: int) -> str: + return f"job:{job_id}:process_results_lock" + + class TaskStateManager: """ Manages job progress tracking state in Redis. @@ -64,7 +68,7 @@ def update_state( processed_image_ids: Set of image IDs that have just been processed """ # Create a unique lock key for this job - lock_key = f"job:{self.job_id}:process_results_lock" + lock_key = _lock_key(self.job_id) lock_timeout = 360 # 6 minutes (matches task time_limit) lock_acquired = cache.add(lock_key, request_id, timeout=lock_timeout) if not lock_acquired: @@ -82,16 +86,22 @@ def update_state( cache.delete(lock_key) logger.debug(f"Released lock for job {self.job_id}, task {request_id}") + def get_progress(self, stage: str) -> TaskProgress | None: + """Read-only progress snapshot for the given stage. Does not acquire a lock or mutate state.""" + pending_images = cache.get(self._get_pending_key(stage)) + total_images = cache.get(self._total_key) + if pending_images is None or total_images is None: + return None + remaining = len(pending_images) + processed = total_images - remaining + percentage = float(processed) / total_images if total_images > 0 else 1.0 + return TaskProgress(remaining=remaining, total=total_images, processed=processed, percentage=percentage) + def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgress | None: """ - Get current progress information for the job. - - Returns: - TaskProgress namedtuple with fields: - - remaining: Number of images still pending (or None if not tracked) - - total: Total number of images (or None if not tracked) - - processed: Number of images processed (or None if not tracked) - - percentage: Progress as float 0.0-1.0 (or None if not tracked) + Update pending images and return progress. Must be called under lock. + + Removes processed_image_ids from the pending set and persists the update. """ pending_images = cache.get(self._get_pending_key(stage)) total_images = cache.get(self._total_key) diff --git a/scripts/debug_tests.sh b/scripts/debug_tests.sh new file mode 100755 index 000000000..858c5729f --- /dev/null +++ b/scripts/debug_tests.sh @@ -0,0 +1,8 @@ +#!/bin/bash +# small helper to launch the test with debugpy enabled, allowing you to attach a VS Code debugger to the test run +# See the "Debug Tests" launch configuration in .vscode/launch.json + +# e.g. scripts/debug_tests.sh ami.jobs.test_tasks --keepdb +docker compose run --rm -p 5680:5680 django \ + python -m debugpy --listen 0.0.0.0:5680 --wait-for-client \ + manage.py test "$@" From 2554950b691cc007591ae0a340fe6c67608c8139 Mon Sep 17 00:00:00 2001 From: carlosgjs Date: Mon, 16 Feb 2026 21:20:11 -0800 Subject: [PATCH 11/14] PSv2: Track and display image count progress and state (#1121) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * merge * Update ML job counts in async case * Update date picker version and tweak layout logic (#1105) * fix: update date picker version and tweak layout logic * feat: set start month based on selected date * fix: Properly handle async job state with celery tasks (#1114) * merge * fix: Properly handle async job state with celery tasks * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Delete implemented plan --------- Co-authored-by: Carlos Garcia Jurado Suarez Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * PSv2: Implement queue clean-up upon job completion (#1113) * merge * feat: PSv2 - Queue/redis clean-up upon job completion * fix: catch specific exception * chore: move tests to a subdir --------- Co-authored-by: Carlos Garcia Jurado Suarez Co-authored-by: Michael Bunsen * fix: PSv2: Workers should not try to fetch tasks from v1 jobs (#1118) Introduces the dispatch_mode field on the Job model to track how each job dispatches its workload. This allows API clients (including the AMI worker) to filter jobs by dispatch mode — for example, fetching only async_api jobs so workers don't pull synchronous or internal jobs. JobDispatchMode enum (ami/jobs/models.py): internal — work handled entirely within the platform (Celery worker, no external calls). Default for all jobs. sync_api — worker calls an external processing service API synchronously and waits for each response. async_api — worker publishes items to NATS for external processing service workers to pick up independently. Database and Model Changes: Added dispatch_mode CharField with TextChoices, defaulting to internal, with the migration in ami/jobs/migrations/0019_job_dispatch_mode.py. ML jobs set dispatch_mode = async_api when the project's async_pipeline_workers feature flag is enabled. ML jobs set dispatch_mode = sync_api on the synchronous processing path (previously unset). API and Filtering: dispatch_mode is exposed (read-only) in job list and detail serializers. Filterable via query parameter: ?dispatch_mode=async_api The /tasks endpoint now returns 400 for non-async_api jobs, since only those have NATS tasks to fetch. Architecture doc: docs/claude/job-dispatch-modes.md documents the three modes, naming decisions, and per-job-type mapping. --------- Co-authored-by: Carlos Garcia Jurado Suarez Co-authored-by: Michael Bunsen Co-authored-by: Claude * PSv2 cleanup: use is_complete() and dispatch_mode in job progress handler (#1125) * refactor: use is_complete() and dispatch_mode in job progress handler Replace hardcoded `stage == "results"` check with `job.progress.is_complete()` which verifies ALL stages are done, making it work for any job type. Replace feature flag check in cleanup with `dispatch_mode == ASYNC_API` which is immutable for the job's lifetime and more correct than re-reading a mutable flag that could change between job creation and completion. Co-Authored-By: Claude * test: update cleanup tests for is_complete() and dispatch_mode checks Set dispatch_mode=ASYNC_API on test jobs to match the new cleanup guard. Complete all stages (collect, process, results) in the completion test since is_complete() correctly requires all stages to be done. Co-Authored-By: Claude --------- Co-authored-by: Claude * track captures and failures * Update tests, CR feedback, log error images * CR feedback * fix type checking * refactor: rename _get_progress to _commit_update in TaskStateManager Clarify naming to distinguish mutating vs read-only methods: - _commit_update(): private, writes mutations to Redis, returns progress - get_progress(): public, read-only snapshot (added in #1129) - update_state(): public API, acquires lock, calls _commit_update() Co-Authored-By: Claude * fix: unify FAILURE_THRESHOLD and convert TaskProgress to dataclass - Single FAILURE_THRESHOLD constant in tasks.py, imported by models.py - Fix async path to use `> FAILURE_THRESHOLD` (was `>=`) to match the sync path's boundary behavior at exactly 50% - Convert TaskProgress from namedtuple to dataclass with defaults, so new fields don't break existing callers Co-Authored-By: Claude * refactor: rename TaskProgress to JobStateProgress Clarify that this dataclass tracks job-level progress in Redis, not individual task/image progress. Aligns with the naming of JobProgress (the Django/Pydantic model equivalent). Co-Authored-By: Claude * docs: update NATS todo and planning docs with session learnings Mark connection handling as done (PR #1130), add worktree/remote mapping and docker testing notes for future sessions. Co-Authored-By: Claude * Rename TaskStateManager to AsyncJobStateManager * Track results counts in the job itself vs Redis * small simplification * Reset counts to 0 on reset * chore: remove local planning docs from PR branch Co-Authored-By: Claude * docs: clarify three-layer job state architecture in docstrings Explain the relationship between AsyncJobStateManager (Redis), JobProgress (JSONB), and JobState (enum). Clarify that all counts in JobStateProgress refer to source images (captures). Co-Authored-By: Claude --------- Co-authored-by: Carlos Garcia Jurado Suarez Co-authored-by: Anna Viklund Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Michael Bunsen Co-authored-by: Claude --- ami/jobs/models.py | 21 +- ami/jobs/tasks.py | 131 +++++++++++-- ami/jobs/test_tasks.py | 14 +- ami/ml/orchestration/async_job_state.py | 212 +++++++++++++++++++++ ami/ml/orchestration/jobs.py | 6 +- ami/ml/orchestration/task_state.py | 135 ------------- ami/ml/orchestration/tests/test_cleanup.py | 12 +- ami/ml/tests.py | 81 ++++++-- 8 files changed, 431 insertions(+), 181 deletions(-) create mode 100644 ami/ml/orchestration/async_job_state.py delete mode 100644 ami/ml/orchestration/task_state.py diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 8e790cdcb..b4df41a04 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -109,7 +109,7 @@ def python_slugify(value: str) -> str: class JobProgressSummary(pydantic.BaseModel): - """Summary of all stages of a job""" + """Top-level status and progress for a job, shown in the UI.""" status: JobState = JobState.CREATED progress: float = 0 @@ -132,7 +132,17 @@ class JobProgressStageDetail(ConfigurableStage, JobProgressSummary): class JobProgress(pydantic.BaseModel): - """The full progress of a job and its stages.""" + """ + The user-facing progress of a job, stored as JSONB on the Job model. + + This is what the UI displays and what external APIs read. Contains named + stages ("process", "results") with per-stage params (progress percentage, + detections/classifications/captures counts, failed count). + + For async (NATS) jobs, updated by _update_job_progress() in ami/jobs/tasks.py + which copies snapshots from the internal Redis-backed AsyncJobStateManager. + For sync jobs, updated directly in MLJob.process_images(). + """ summary: JobProgressSummary stages: list[JobProgressStageDetail] @@ -222,6 +232,10 @@ def reset(self, status: JobState = JobState.CREATED): for stage in self.stages: stage.progress = 0 stage.status = status + # Reset numeric param values to 0 + for param in stage.params: + if isinstance(param.value, (int, float)): + param.value = 0 def is_complete(self) -> bool: """ @@ -561,7 +575,8 @@ def process_images(cls, job, images): job.logger.info(f"All tasks completed for job {job.pk}") - FAILURE_THRESHOLD = 0.5 + from ami.jobs.tasks import FAILURE_THRESHOLD + if image_count and (percent_successful < FAILURE_THRESHOLD): job.progress.update_stage("process", status=JobState.FAILURE) job.save() diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 3548d0ea5..0abf85dae 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -3,18 +3,25 @@ import logging import time from collections.abc import Callable +from typing import TYPE_CHECKING from asgiref.sync import async_to_sync from celery.signals import task_failure, task_postrun, task_prerun from django.db import transaction +from ami.ml.orchestration.async_job_state import AsyncJobStateManager from ami.ml.orchestration.nats_queue import TaskQueueManager -from ami.ml.orchestration.task_state import TaskStateManager from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse from ami.tasks import default_soft_time_limit, default_time_limit from config import celery_app +if TYPE_CHECKING: + from ami.jobs.models import JobState + logger = logging.getLogger(__name__) +# Minimum success rate. Jobs with fewer than this fraction of images +# processed successfully are marked as failed. Also used in MLJob.process_images(). +FAILURE_THRESHOLD = 0.5 @celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit) @@ -59,23 +66,27 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub result_data: Dictionary containing the pipeline result reply_subject: NATS reply subject for acknowledgment """ - from ami.jobs.models import Job # avoid circular import + from ami.jobs.models import Job, JobState # avoid circular import _, t = log_time() # Validate with Pydantic - check for error response first + error_result = None if "error" in result_data: error_result = PipelineResultsError(**result_data) processed_image_ids = {str(error_result.image_id)} if error_result.image_id else set() - logger.error(f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}") + failed_image_ids = processed_image_ids # Same as processed for errors pipeline_result = None else: pipeline_result = PipelineResultsResponse(**result_data) processed_image_ids = {str(img.id) for img in pipeline_result.source_images} + failed_image_ids = set() # No failures for successful results - state_manager = TaskStateManager(job_id) + state_manager = AsyncJobStateManager(job_id) - progress_info = state_manager.update_state(processed_image_ids, stage="process", request_id=self.request.id) + progress_info = state_manager.update_state( + processed_image_ids, stage="process", request_id=self.request.id, failed_image_ids=failed_image_ids + ) if not progress_info: logger.warning( f"Another task is already processing results for job {job_id}. " @@ -84,16 +95,31 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub raise self.retry(countdown=5, max_retries=10) try: - _update_job_progress(job_id, "process", progress_info.percentage) + complete_state = JobState.SUCCESS + if progress_info.total > 0 and (progress_info.failed / progress_info.total) > FAILURE_THRESHOLD: + complete_state = JobState.FAILURE + _update_job_progress( + job_id, + "process", + progress_info.percentage, + complete_state=complete_state, + processed=progress_info.processed, + remaining=progress_info.remaining, + failed=progress_info.failed, + ) _, t = t(f"TIME: Updated job {job_id} progress in PROCESS stage progress to {progress_info.percentage*100}%") job = Job.objects.get(pk=job_id) job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}") job.logger.info( f" Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed " - f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {len(processed_image_ids)} just " - "processed" + f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {progress_info.failed} failed, " + f"{len(processed_image_ids)} just processed" ) + if error_result: + job.logger.error( + f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}" + ) except Job.DoesNotExist: # don't raise and ack so that we don't retry since the job doesn't exists logger.error(f"Job {job_id} not found") @@ -102,6 +128,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub try: # Save to database (this is the slow operation) + detections_count, classifications_count, captures_count = 0, 0, 0 if pipeline_result: # should never happen since otherwise we could not be processing results here assert job.pipeline is not None, "Job pipeline is None" @@ -112,10 +139,19 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub f"Saved pipeline results to database with {len(pipeline_result.detections)} detections" f", percentage: {progress_info.percentage*100}%" ) + # Calculate detection and classification counts from this result + detections_count = len(pipeline_result.detections) + classifications_count = sum(len(detection.classifications) for detection in pipeline_result.detections) + captures_count = len(pipeline_result.source_images) _ack_task_via_nats(reply_subject, job.logger) # Update job stage with calculated progress - progress_info = state_manager.update_state(processed_image_ids, stage="results", request_id=self.request.id) + + progress_info = state_manager.update_state( + processed_image_ids, + stage="results", + request_id=self.request.id, + ) if not progress_info: logger.warning( @@ -123,7 +159,21 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub f"Retrying task {self.request.id} in 5 seconds..." ) raise self.retry(countdown=5, max_retries=10) - _update_job_progress(job_id, "results", progress_info.percentage) + + # update complete state based on latest progress info after saving results + complete_state = JobState.SUCCESS + if progress_info.total > 0 and (progress_info.failed / progress_info.total) > FAILURE_THRESHOLD: + complete_state = JobState.FAILURE + + _update_job_progress( + job_id, + "results", + progress_info.percentage, + complete_state=complete_state, + detections=detections_count, + classifications=classifications_count, + captures=captures_count, + ) except Exception as e: job.logger.error( @@ -149,19 +199,72 @@ async def ack_task(): # Don't fail the task if ACK fails - data is already saved -def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None: +def _get_current_counts_from_job_progress(job, stage: str) -> tuple[int, int, int]: + """ + Get current detections, classifications, and captures counts from job progress. + + Args: + job: The Job instance + stage: The stage name to read counts from + + Returns: + Tuple of (detections, classifications, captures) counts, defaulting to 0 if not found + """ + try: + stage_obj = job.progress.get_stage(stage) + + # Initialize defaults + detections = 0 + classifications = 0 + captures = 0 + + # Search through the params list for our count values + for param in stage_obj.params: + if param.key == "detections": + detections = param.value or 0 + elif param.key == "classifications": + classifications = param.value or 0 + elif param.key == "captures": + captures = param.value or 0 + + return detections, classifications, captures + except (ValueError, AttributeError): + # Stage doesn't exist or doesn't have these attributes yet + return 0, 0, 0 + + +def _update_job_progress( + job_id: int, stage: str, progress_percentage: float, complete_state: "JobState", **state_params +) -> None: from ami.jobs.models import Job, JobState # avoid circular import with transaction.atomic(): job = Job.objects.select_for_update().get(pk=job_id) + + # For results stage, accumulate detections/classifications/captures counts + if stage == "results": + current_detections, current_classifications, current_captures = _get_current_counts_from_job_progress( + job, stage + ) + + # Add new counts to existing counts + new_detections = state_params.get("detections", 0) + new_classifications = state_params.get("classifications", 0) + new_captures = state_params.get("captures", 0) + + state_params["detections"] = current_detections + new_detections + state_params["classifications"] = current_classifications + new_classifications + state_params["captures"] = current_captures + new_captures + job.progress.update_stage( stage, - status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, + status=complete_state if progress_percentage >= 1.0 else JobState.STARTED, progress=progress_percentage, + **state_params, ) if job.progress.is_complete(): - job.status = JobState.SUCCESS - job.progress.summary.status = JobState.SUCCESS + job.status = complete_state + job.progress.summary.status = complete_state job.finished_at = datetime.datetime.now() # Use naive datetime in local time job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%") job.save() diff --git a/ami/jobs/test_tasks.py b/ami/jobs/test_tasks.py index fc291c8ba..b37940cdd 100644 --- a/ami/jobs/test_tasks.py +++ b/ami/jobs/test_tasks.py @@ -17,7 +17,7 @@ from ami.jobs.tasks import process_nats_pipeline_result from ami.main.models import Detection, Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline -from ami.ml.orchestration.task_state import TaskStateManager, _lock_key +from ami.ml.orchestration.async_job_state import AsyncJobStateManager, _lock_key from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse, SourceImageResponse from ami.users.models import User @@ -64,7 +64,7 @@ def setUp(self): # Initialize state manager self.image_ids = [str(img.pk) for img in self.images] - self.state_manager = TaskStateManager(self.job.pk) + self.state_manager = AsyncJobStateManager(self.job.pk) self.state_manager.initialize_job(self.image_ids) def tearDown(self): @@ -90,7 +90,7 @@ def _assert_progress_updated( self, job_id: int, expected_processed: int, expected_total: int, stage: str = "process" ): """Assert TaskStateManager state is correct.""" - manager = TaskStateManager(job_id) + manager = AsyncJobStateManager(job_id) progress = manager.get_progress(stage) self.assertIsNotNone(progress, f"Progress not found for stage '{stage}'") self.assertEqual(progress.processed, expected_processed) @@ -157,7 +157,7 @@ def test_process_nats_pipeline_result_error_no_image_id(self, mock_manager_class # Assert: Progress was NOT updated (empty set of processed images) # Since no image_id was provided, processed_image_ids = set() - manager = TaskStateManager(self.job.pk) + manager = AsyncJobStateManager(self.job.pk) progress = manager.get_progress("process") self.assertEqual(progress.processed, 0) # No images marked as processed @@ -208,7 +208,7 @@ def test_process_nats_pipeline_result_mixed_results(self, mock_manager_class): ) # Assert: All 3 images marked as processed in TaskStateManager - manager = TaskStateManager(self.job.pk) + manager = AsyncJobStateManager(self.job.pk) process_progress = manager.get_progress("process") self.assertIsNotNone(process_progress) self.assertEqual(process_progress.processed, 3) @@ -266,7 +266,7 @@ def test_process_nats_pipeline_result_error_concurrent_locking(self, mock_manage ) # Assert: Progress was NOT updated (lock not acquired) - manager = TaskStateManager(self.job.pk) + manager = AsyncJobStateManager(self.job.pk) progress = manager.get_progress("process") self.assertEqual(progress.processed, 0) @@ -342,7 +342,7 @@ def setUp(self): ) # Initialize state manager - state_manager = TaskStateManager(self.job.pk) + state_manager = AsyncJobStateManager(self.job.pk) state_manager.initialize_job([str(self.image.pk)]) def tearDown(self): diff --git a/ami/ml/orchestration/async_job_state.py b/ami/ml/orchestration/async_job_state.py new file mode 100644 index 000000000..5a300c12a --- /dev/null +++ b/ami/ml/orchestration/async_job_state.py @@ -0,0 +1,212 @@ +""" +Internal progress tracking for async (NATS) job processing, backed by Redis. + +Multiple Celery workers process image batches concurrently and report progress +here using Redis for atomic updates with locking. This module is purely internal +— nothing outside the worker pipeline reads from it directly. + +How this relates to the Job model (ami/jobs/models.py): + + The **Job model** is the primary, external-facing record. It is what users see + in the UI, what external APIs interact with (listing open jobs, fetching tasks + to process), and what persists as history in the CMS. It has two relevant fields: + + - **Job.status** (JobState enum) — lifecycle state (CREATED → STARTED → SUCCESS/FAILURE) + - **Job.progress** (JobProgress JSONB) — detailed stage progress with params + like detections, classifications, captures counts + + This Redis layer exists only because concurrent NATS workers need atomic + counters that PostgreSQL row locks would serialize too aggressively. After each + batch, _update_job_progress() in ami/jobs/tasks.py copies the Redis snapshot + into the Job model, which is the source of truth for everything external. + +Flow: NATS result → AsyncJobStateManager.update_state() (Redis, internal) + → _update_job_progress() (writes to Job model) → UI / API reads Job +""" + +import logging +from dataclasses import dataclass + +from django.core.cache import cache + +logger = logging.getLogger(__name__) + + +@dataclass +class JobStateProgress: + """ + Progress snapshot for a job stage, read from Redis. + + All counts refer to source images (captures), not detections or occurrences. + Currently specific to ML pipeline jobs — if other job types are made available + for external processing, the unit of work ("source image") and failure semantics + may need to be generalized. + """ + + remaining: int = 0 # source images not yet processed in this stage + total: int = 0 # total source images in the job + processed: int = 0 # source images completed (success + failed) + percentage: float = 0.0 # processed / total + failed: int = 0 # source images that returned an error from the processing service + + +def _lock_key(job_id: int) -> str: + return f"job:{job_id}:process_results_lock" + + +class AsyncJobStateManager: + """ + Manages real-time job progress in Redis for concurrent NATS workers. + + Each job has per-stage pending image lists and a shared failed image set. + Workers acquire a Redis lock before mutating state, ensuring atomic updates + even when multiple Celery tasks process batches in parallel. + + The results are ephemeral — _update_job_progress() in ami/jobs/tasks.py + copies each snapshot into the persistent Job.progress JSONB field. + """ + + TIMEOUT = 86400 * 7 # 7 days in seconds + STAGES = ["process", "results"] + + def __init__(self, job_id: int): + """ + Initialize the task state manager for a specific job. + + Args: + job_id: The job primary key + """ + self.job_id = job_id + self._pending_key = f"job:{job_id}:pending_images" + self._total_key = f"job:{job_id}:pending_images_total" + self._failed_key = f"job:{job_id}:failed_images" + + def initialize_job(self, image_ids: list[str]) -> None: + """ + Initialize job tracking with a list of image IDs to process. + + Args: + image_ids: List of image IDs that need to be processed + """ + for stage in self.STAGES: + cache.set(self._get_pending_key(stage), image_ids, timeout=self.TIMEOUT) + + # Initialize failed images set for process stage only + cache.set(self._failed_key, set(), timeout=self.TIMEOUT) + + cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) + + def _get_pending_key(self, stage: str) -> str: + return f"{self._pending_key}:{stage}" + + def update_state( + self, + processed_image_ids: set[str], + stage: str, + request_id: str, + failed_image_ids: set[str] | None = None, + ) -> None | JobStateProgress: + """ + Update the task state with newly processed images. + + Args: + processed_image_ids: Set of image IDs that have just been processed + stage: The processing stage ("process" or "results") + request_id: Unique identifier for this processing request + detections_count: Number of detections to add to cumulative count + classifications_count: Number of classifications to add to cumulative count + captures_count: Number of captures to add to cumulative count + failed_image_ids: Set of image IDs that failed processing (optional) + """ + # Create a unique lock key for this job + lock_key = _lock_key(self.job_id) + lock_timeout = 360 # 6 minutes (matches task time_limit) + lock_acquired = cache.add(lock_key, request_id, timeout=lock_timeout) + if not lock_acquired: + return None + + try: + # Update progress tracking in Redis + progress_info = self._commit_update(processed_image_ids, stage, failed_image_ids) + return progress_info + finally: + # Always release the lock when done + current_lock_value = cache.get(lock_key) + # Only delete if we still own the lock (prevents race condition) + if current_lock_value == request_id: + cache.delete(lock_key) + logger.debug(f"Released lock for job {self.job_id}, task {request_id}") + + def get_progress(self, stage: str) -> JobStateProgress | None: + """Read-only progress snapshot for the given stage. Does not acquire a lock or mutate state.""" + pending_images = cache.get(self._get_pending_key(stage)) + total_images = cache.get(self._total_key) + if pending_images is None or total_images is None: + return None + remaining = len(pending_images) + processed = total_images - remaining + percentage = float(processed) / total_images if total_images > 0 else 1.0 + failed_set = cache.get(self._failed_key) or set() + return JobStateProgress( + remaining=remaining, + total=total_images, + processed=processed, + percentage=percentage, + failed=len(failed_set), + ) + + def _commit_update( + self, + processed_image_ids: set[str], + stage: str, + failed_image_ids: set[str] | None = None, + ) -> JobStateProgress | None: + """ + Update pending images and return progress. Must be called under lock. + + Removes processed_image_ids from the pending set and persists the update. + """ + pending_images = cache.get(self._get_pending_key(stage)) + total_images = cache.get(self._total_key) + if pending_images is None or total_images is None: + return None + remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids] + assert len(pending_images) >= len(remaining_images) + cache.set(self._get_pending_key(stage), remaining_images, timeout=self.TIMEOUT) + + remaining = len(remaining_images) + processed = total_images - remaining + percentage = float(processed) / total_images if total_images > 0 else 1.0 + + # Update failed images set if provided + if failed_image_ids: + existing_failed = cache.get(self._failed_key) or set() + updated_failed = existing_failed | failed_image_ids # Union to prevent duplicates + cache.set(self._failed_key, updated_failed, timeout=self.TIMEOUT) + failed_set = updated_failed + else: + failed_set = cache.get(self._failed_key) or set() + + failed_count = len(failed_set) + + logger.info( + f"Pending images from Redis for job {self.job_id} {stage}: " + f"{remaining}/{total_images}: {percentage*100}%" + ) + + return JobStateProgress( + remaining=remaining, + total=total_images, + processed=processed, + percentage=percentage, + failed=failed_count, + ) + + def cleanup(self) -> None: + """ + Delete all Redis keys associated with this job. + """ + for stage in self.STAGES: + cache.delete(self._get_pending_key(stage)) + cache.delete(self._failed_key) + cache.delete(self._total_key) diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index 9b4a577e9..ce54ecd1c 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -4,8 +4,8 @@ from ami.jobs.models import Job, JobState from ami.main.models import SourceImage +from ami.ml.orchestration.async_job_state import AsyncJobStateManager from ami.ml.orchestration.nats_queue import TaskQueueManager -from ami.ml.orchestration.task_state import TaskStateManager from ami.ml.schemas import PipelineProcessingTask logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ def cleanup_async_job_resources(job: "Job") -> bool: # Cleanup Redis state try: - state_manager = TaskStateManager(job.pk) + state_manager = AsyncJobStateManager(job.pk) state_manager.cleanup() job.logger.info(f"Cleaned up Redis state for job {job.pk}") redis_success = True @@ -88,7 +88,7 @@ def queue_images_to_nats(job: "Job", images: list[SourceImage]): tasks.append((image.pk, task)) # Store all image IDs in Redis for progress tracking - state_manager = TaskStateManager(job.pk) + state_manager = AsyncJobStateManager(job.pk) state_manager.initialize_job(image_ids) job.logger.info(f"Initialized task state tracking for {len(image_ids)} images") diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/task_state.py deleted file mode 100644 index b05760e68..000000000 --- a/ami/ml/orchestration/task_state.py +++ /dev/null @@ -1,135 +0,0 @@ -""" -Task state management for job progress tracking using Redis. -""" - -import logging -from collections import namedtuple - -from django.core.cache import cache - -logger = logging.getLogger(__name__) - - -# Define a namedtuple for a TaskProgress with the image counts -TaskProgress = namedtuple("TaskProgress", ["remaining", "total", "processed", "percentage"]) - - -def _lock_key(job_id: int) -> str: - return f"job:{job_id}:process_results_lock" - - -class TaskStateManager: - """ - Manages job progress tracking state in Redis. - - Tracks pending images for jobs to calculate progress percentages - as workers process images asynchronously. - """ - - TIMEOUT = 86400 * 7 # 7 days in seconds - STAGES = ["process", "results"] - - def __init__(self, job_id: int): - """ - Initialize the task state manager for a specific job. - - Args: - job_id: The job primary key - """ - self.job_id = job_id - self._pending_key = f"job:{job_id}:pending_images" - self._total_key = f"job:{job_id}:pending_images_total" - - def initialize_job(self, image_ids: list[str]) -> None: - """ - Initialize job tracking with a list of image IDs to process. - - Args: - image_ids: List of image IDs that need to be processed - """ - for stage in self.STAGES: - cache.set(self._get_pending_key(stage), image_ids, timeout=self.TIMEOUT) - - cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) - - def _get_pending_key(self, stage: str) -> str: - return f"{self._pending_key}:{stage}" - - def update_state( - self, - processed_image_ids: set[str], - stage: str, - request_id: str, - ) -> None | TaskProgress: - """ - Update the task state with newly processed images. - - Args: - processed_image_ids: Set of image IDs that have just been processed - """ - # Create a unique lock key for this job - lock_key = _lock_key(self.job_id) - lock_timeout = 360 # 6 minutes (matches task time_limit) - lock_acquired = cache.add(lock_key, request_id, timeout=lock_timeout) - if not lock_acquired: - return None - - try: - # Update progress tracking in Redis - progress_info = self._get_progress(processed_image_ids, stage) - return progress_info - finally: - # Always release the lock when done - current_lock_value = cache.get(lock_key) - # Only delete if we still own the lock (prevents race condition) - if current_lock_value == request_id: - cache.delete(lock_key) - logger.debug(f"Released lock for job {self.job_id}, task {request_id}") - - def get_progress(self, stage: str) -> TaskProgress | None: - """Read-only progress snapshot for the given stage. Does not acquire a lock or mutate state.""" - pending_images = cache.get(self._get_pending_key(stage)) - total_images = cache.get(self._total_key) - if pending_images is None or total_images is None: - return None - remaining = len(pending_images) - processed = total_images - remaining - percentage = float(processed) / total_images if total_images > 0 else 1.0 - return TaskProgress(remaining=remaining, total=total_images, processed=processed, percentage=percentage) - - def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgress | None: - """ - Update pending images and return progress. Must be called under lock. - - Removes processed_image_ids from the pending set and persists the update. - """ - pending_images = cache.get(self._get_pending_key(stage)) - total_images = cache.get(self._total_key) - if pending_images is None or total_images is None: - return None - remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids] - assert len(pending_images) >= len(remaining_images) - cache.set(self._get_pending_key(stage), remaining_images, timeout=self.TIMEOUT) - - remaining = len(remaining_images) - processed = total_images - remaining - percentage = float(processed) / total_images if total_images > 0 else 1.0 - logger.info( - f"Pending images from Redis for job {self.job_id} {stage}: " - f"{remaining}/{total_images}: {percentage*100}%" - ) - - return TaskProgress( - remaining=remaining, - total=total_images, - processed=processed, - percentage=percentage, - ) - - def cleanup(self) -> None: - """ - Delete all Redis keys associated with this job. - """ - for stage in self.STAGES: - cache.delete(self._get_pending_key(stage)) - cache.delete(self._total_key) diff --git a/ami/ml/orchestration/tests/test_cleanup.py b/ami/ml/orchestration/tests/test_cleanup.py index ef8382d3d..ccdfa2c49 100644 --- a/ami/ml/orchestration/tests/test_cleanup.py +++ b/ami/ml/orchestration/tests/test_cleanup.py @@ -9,9 +9,9 @@ from ami.jobs.tasks import _update_job_progress, update_job_failure, update_job_status from ami.main.models import Project, ProjectFeatureFlags, SourceImage, SourceImageCollection from ami.ml.models import Pipeline +from ami.ml.orchestration.async_job_state import AsyncJobStateManager from ami.ml.orchestration.jobs import queue_images_to_nats from ami.ml.orchestration.nats_queue import TaskQueueManager -from ami.ml.orchestration.task_state import TaskStateManager class TestCleanupAsyncJobResources(TestCase): @@ -59,7 +59,7 @@ def _verify_resources_created(self, job_id: int): job_id: The job ID to check """ # Verify Redis keys exist - state_manager = TaskStateManager(job_id) + state_manager = AsyncJobStateManager(job_id) for stage in state_manager.STAGES: pending_key = state_manager._get_pending_key(stage) self.assertIsNotNone(cache.get(pending_key), f"Redis key {pending_key} should exist") @@ -125,7 +125,7 @@ def _verify_resources_cleaned(self, job_id: int): job_id: The job ID to check """ # Verify Redis keys are deleted - state_manager = TaskStateManager(job_id) + state_manager = AsyncJobStateManager(job_id) for stage in state_manager.STAGES: pending_key = state_manager._get_pending_key(stage) self.assertIsNone(cache.get(pending_key), f"Redis key {pending_key} should be deleted") @@ -164,9 +164,9 @@ def test_cleanup_on_job_completion(self): job = self._create_job_with_queued_images() # Simulate job completion: complete all stages (collect, process, then results) - _update_job_progress(job.pk, stage="collect", progress_percentage=1.0) - _update_job_progress(job.pk, stage="process", progress_percentage=1.0) - _update_job_progress(job.pk, stage="results", progress_percentage=1.0) + _update_job_progress(job.pk, stage="collect", progress_percentage=1.0, complete_state=JobState.SUCCESS) + _update_job_progress(job.pk, stage="process", progress_percentage=1.0, complete_state=JobState.SUCCESS) + _update_job_progress(job.pk, stage="results", progress_percentage=1.0, complete_state=JobState.SUCCESS) # Verify cleanup happened self._verify_resources_cleaned(job.pk) diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 20e0368fe..2f269a6f9 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -864,22 +864,23 @@ def setUp(self): """Set up test fixtures.""" from django.core.cache import cache - from ami.ml.orchestration.task_state import TaskStateManager + from ami.ml.orchestration.async_job_state import AsyncJobStateManager cache.clear() self.job_id = 123 - self.manager = TaskStateManager(self.job_id) + self.manager = AsyncJobStateManager(self.job_id) self.image_ids = ["img1", "img2", "img3", "img4", "img5"] def _init_and_verify(self, image_ids): """Helper to initialize job and verify initial state.""" self.manager.initialize_job(image_ids) - progress = self.manager._get_progress(set(), "process") + progress = self.manager._commit_update(set(), "process") assert progress is not None self.assertEqual(progress.total, len(image_ids)) self.assertEqual(progress.remaining, len(image_ids)) self.assertEqual(progress.processed, 0) self.assertEqual(progress.percentage, 0.0) + self.assertEqual(progress.failed, 0) return progress def test_initialize_job(self): @@ -888,30 +889,31 @@ def test_initialize_job(self): # Verify both stages are initialized for stage in self.manager.STAGES: - progress = self.manager._get_progress(set(), stage) + progress = self.manager._commit_update(set(), stage) assert progress is not None self.assertEqual(progress.total, len(self.image_ids)) + self.assertEqual(progress.failed, 0) def test_progress_tracking(self): """Test progress updates correctly as images are processed.""" self._init_and_verify(self.image_ids) # Process 2 images - progress = self.manager._get_progress({"img1", "img2"}, "process") + progress = self.manager._commit_update({"img1", "img2"}, "process") assert progress is not None self.assertEqual(progress.remaining, 3) self.assertEqual(progress.processed, 2) self.assertEqual(progress.percentage, 0.4) # Process 2 more images - progress = self.manager._get_progress({"img3", "img4"}, "process") + progress = self.manager._commit_update({"img3", "img4"}, "process") assert progress is not None self.assertEqual(progress.remaining, 1) self.assertEqual(progress.processed, 4) self.assertEqual(progress.percentage, 0.8) # Process last image - progress = self.manager._get_progress({"img5"}, "process") + progress = self.manager._commit_update({"img5"}, "process") assert progress is not None self.assertEqual(progress.remaining, 0) self.assertEqual(progress.processed, 5) @@ -947,20 +949,20 @@ def test_stages_independent(self): self._init_and_verify(self.image_ids) # Update process stage - self.manager._get_progress({"img1", "img2"}, "process") - progress_process = self.manager._get_progress(set(), "process") + self.manager._commit_update({"img1", "img2"}, "process") + progress_process = self.manager._commit_update(set(), "process") assert progress_process is not None self.assertEqual(progress_process.remaining, 3) # Results stage should still have all images pending - progress_results = self.manager._get_progress(set(), "results") + progress_results = self.manager._commit_update(set(), "results") assert progress_results is not None self.assertEqual(progress_results.remaining, 5) def test_empty_job(self): """Test handling of job with no images.""" self.manager.initialize_job([]) - progress = self.manager._get_progress(set(), "process") + progress = self.manager._commit_update(set(), "process") assert progress is not None self.assertEqual(progress.total, 0) self.assertEqual(progress.percentage, 1.0) # Empty job is 100% complete @@ -970,12 +972,65 @@ def test_cleanup(self): self._init_and_verify(self.image_ids) # Verify keys exist - progress = self.manager._get_progress(set(), "process") + progress = self.manager._commit_update(set(), "process") self.assertIsNotNone(progress) # Cleanup self.manager.cleanup() # Verify keys are gone - progress = self.manager._get_progress(set(), "process") + progress = self.manager._commit_update(set(), "process") self.assertIsNone(progress) + + def test_failed_image_tracking(self): + """Test basic failed image tracking with no double-counting on retries.""" + self._init_and_verify(self.image_ids) + + # Mark 2 images as failed in process stage + progress = self.manager._commit_update({"img1", "img2"}, "process", failed_image_ids={"img1", "img2"}) + assert progress is not None + self.assertEqual(progress.failed, 2) + + # Retry same 2 images (fail again) - should not double-count + progress = self.manager._commit_update(set(), "process", failed_image_ids={"img1", "img2"}) + assert progress is not None + self.assertEqual(progress.failed, 2) + + # Fail a different image + progress = self.manager._commit_update(set(), "process", failed_image_ids={"img3"}) + assert progress is not None + self.assertEqual(progress.failed, 3) + + def test_failed_and_processed_mixed(self): + """Test mixed successful and failed processing in same batch.""" + self._init_and_verify(self.image_ids) + + # Process 2 successfully, 2 fail, 1 remains pending + progress = self.manager._commit_update( + {"img1", "img2", "img3", "img4"}, "process", failed_image_ids={"img3", "img4"} + ) + assert progress is not None + self.assertEqual(progress.processed, 4) + self.assertEqual(progress.failed, 2) + self.assertEqual(progress.remaining, 1) + self.assertEqual(progress.percentage, 0.8) + + def test_cleanup_removes_failed_set(self): + """Test that cleanup removes failed image set.""" + from django.core.cache import cache + + self._init_and_verify(self.image_ids) + + # Add failed images + self.manager._commit_update({"img1", "img2"}, "process", failed_image_ids={"img1", "img2"}) + + # Verify failed set exists + failed_set = cache.get(self.manager._failed_key) + self.assertEqual(len(failed_set), 2) + + # Cleanup + self.manager.cleanup() + + # Verify failed set is gone + failed_set = cache.get(self.manager._failed_key) + self.assertIsNone(failed_set) From 9ae607c86861ba197c3a9e1ab9bf5a9d8119fe95 Mon Sep 17 00:00:00 2001 From: carlosgjs Date: Mon, 16 Feb 2026 21:34:28 -0800 Subject: [PATCH 12/14] PSV2: API endpoint for external processing services to register pipelines (#1076) * RFC: V2 endpoint to register pipeliens * merge * Allow null enpoint_url for processing services * Add tests * Add processing_service_name * Tests for pipeline registration * Add default (None) for endpoint_url Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * Better assert Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * CR feedback * Simplify registration payload * CR feedback * fix test * chore: remove old comment * chore: update display name * feat: filter with db query instead of python, update docstring * docs: add plan for migrating to an existing DRF pattern * Refactor pipeline registration to nested DRF route (#10) * feat: add ProjectNestedPermission for nested project routes Reusable permission class for nested routes under /projects/{pk}/. Allows read access to any user, write access to project owners, superusers, and ProjectManagers. Designed for pipelines, tags, taxa lists, and similar nested resources. Co-Authored-By: Claude * refactor: move pipeline registration to nested DRF route Replace the pipelines action on ProjectViewSet with a proper nested ViewSet at /api/v2/projects/{project_pk}/pipelines/. Adds GET (list) and POST (register) using standard DRF patterns with SchemaField for pydantic validation, transaction.atomic() for DB ops, and idempotent re-registration. - Add PipelineRegistrationSerializer in ami/ml/serializers.py - Add ProjectPipelineViewSet in ami/ml/views.py - Register nested route in config/api_router.py - Remove old pipelines action + unused imports from ProjectViewSet Co-Authored-By: Claude * test: update pipeline registration tests for nested route - Payload uses flat {processing_service_name, pipelines} format - Success returns 201 instead of 200 - Re-registration is now idempotent (no longer returns 400) - Add test_list_pipelines for GET endpoint Co-Authored-By: Claude * feat: add guardian permissions for ProjectPipelineConfig Add CREATE/UPDATE/DELETE_PROJECT_PIPELINE_CONFIG to Project.Permissions and Meta.permissions. Assign create permission to MLDataManager and all three to ProjectManager, enabling granular access control for pipeline registration instead of the coarse ProjectManager.has_role() check. Co-Authored-By: Claude * refactor: replace ProjectNestedPermission with ProjectPipelineConfigPermission Replace the generic ProjectNestedPermission with ProjectPipelineConfigPermission following the UserMembershipPermission pattern. The new class extends ObjectPermission and creates a temporary ProjectPipelineConfig instance to leverage BaseModel.check_permission(), which handles draft project visibility and guardian permission checks. Update ProjectPipelineViewSet to use ProjectMixin with require_project=True instead of manual kwargs lookups. Co-Authored-By: Claude * test: add permission tests for project pipelines endpoint Update setUp to use create_roles_for_project and guardian permissions instead of is_staff=True. Add tests for draft project access (403 for non-members), unauthenticated writes (401/403), and public project reads (200 for non-members). Co-Authored-By: Claude * test: verify list pipelines response contains project pipelines Co-Authored-By: Claude --------- Co-authored-by: Claude * chore: remove planning doc from PR branch Co-Authored-By: Claude --------- Co-authored-by: Carlos Garcia Jurado Suarez Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Michael Bunsen Co-authored-by: Claude --- ami/base/permissions.py | 26 ++++ ami/main/models.py | 15 +- ami/main/tests.py | 137 ++++++++++++++++++ ...rocessing_service_endpoint_url_nullable.py | 17 +++ ami/ml/models/processing_service.py | 21 ++- ami/ml/schemas.py | 11 +- ami/ml/serializers.py | 6 + ami/ml/tasks.py | 7 +- ami/ml/tests.py | 45 ++++++ ami/ml/views.py | 79 +++++++++- ami/users/roles.py | 4 + config/api_router.py | 5 + 12 files changed, 364 insertions(+), 9 deletions(-) create mode 100644 ami/ml/migrations/0026_make_processing_service_endpoint_url_nullable.py diff --git a/ami/base/permissions.py b/ami/base/permissions.py index bb143be23..3c4b6c053 100644 --- a/ami/base/permissions.py +++ b/ami/base/permissions.py @@ -89,6 +89,32 @@ def has_object_permission(self, request, view, obj: BaseModel): return obj.check_permission(request.user, view.action) +class ProjectPipelineConfigPermission(ObjectPermission): + """ + Permission for the nested project pipelines route (/projects/{pk}/pipelines/). + + Extends ObjectPermission to handle list/create actions where no object exists yet. + Creates a temporary ProjectPipelineConfig instance to leverage BaseModel.check_permission(), + which handles draft project visibility and guardian permission checks automatically. + + Follows the same pattern as UserMembershipPermission. + """ + + def has_permission(self, request, view): + from ami.ml.models.project_pipeline_config import ProjectPipelineConfig + + if view.action in ("list", "create"): + project = view.get_active_project() + if not project: + return False + + config = ProjectPipelineConfig(project=project) + action = "retrieve" if view.action == "list" else "create" + return config.check_permission(request.user, action) + + return super().has_permission(request, view) + + class UserMembershipPermission(ObjectPermission): """ Custom permission for UserProjectMembershipViewSet. diff --git a/ami/main/models.py b/ami/main/models.py index 1946ec3cf..60a7dc0f4 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -336,7 +336,7 @@ def check_custom_permission(self, user, action: str) -> bool: Charts is treated as a read-only operation, so it follows the same permission logic as 'retrieve'. """ - from ami.users.roles import BasicMember + from ami.users.roles import BasicMember, ProjectManager if action == "charts": # Same permission logic as retrieve action @@ -345,6 +345,10 @@ def check_custom_permission(self, user, action: str) -> bool: return BasicMember.has_role(user, self) or user == self.owner or user.is_superuser return True + if action == "pipelines": + # Pipeline registration requires project management permissions + return ProjectManager.has_role(user, self) or user == self.owner or user.is_superuser + # Fall back to default permission checking for other actions return super().check_custom_permission(user, action) @@ -422,6 +426,11 @@ class Permissions: UPDATE_DATA_EXPORT = "update_dataexport" DELETE_DATA_EXPORT = "delete_dataexport" + # Pipeline configuration permissions + CREATE_PROJECT_PIPELINE_CONFIG = "create_projectpipelineconfig" + UPDATE_PROJECT_PIPELINE_CONFIG = "update_projectpipelineconfig" + DELETE_PROJECT_PIPELINE_CONFIG = "delete_projectpipelineconfig" + # Other permissions VIEW_PRIVATE_DATA = "view_private_data" DELETE_OCCURRENCES = "delete_occurrences" @@ -485,6 +494,10 @@ class Meta: ("create_dataexport", "Can create a data export"), ("update_dataexport", "Can update a data export"), ("delete_dataexport", "Can delete a data export"), + # Pipeline configuration permissions + ("create_projectpipelineconfig", "Can register pipelines for the project"), + ("update_projectpipelineconfig", "Can update pipeline configurations"), + ("delete_projectpipelineconfig", "Can remove pipelines from the project"), # Other permissions ("view_private_data", "Can view private data"), ] diff --git a/ami/main/tests.py b/ami/main/tests.py index a6be324f4..ca6fb53cc 100644 --- a/ami/main/tests.py +++ b/ami/main/tests.py @@ -3443,3 +3443,140 @@ def test_taxon_detail_visible_when_excluded_from_list(self): detail_url = f"/api/v2/taxa/{excluded_taxon.id}/?project_id={self.project.pk}" res = self.client.get(detail_url) self.assertEqual(res.status_code, status.HTTP_200_OK) + + +class TestProjectPipelinesAPI(APITestCase): + """Test the project pipelines API endpoint.""" + + def setUp(self): + from ami.users.roles import ProjectManager, create_roles_for_project + + self.user = User.objects.create_user(email="test@example.com") # type: ignore + self.other_user = User.objects.create_user(email="other@example.com") # type: ignore + + # Create projects with explicit ownership + self.project = Project.objects.create(name="Test Project", owner=self.user, create_defaults=True) + self.other_project = Project.objects.create(name="Other Project", owner=self.other_user, create_defaults=True) + + # Create role groups and assign permissions + create_roles_for_project(self.project) + create_roles_for_project(self.other_project) + ProjectManager.assign_user(self.user, self.project) + + def _get_pipelines_url(self, project_id): + """Get the pipelines API URL for a project.""" + return f"/api/v2/projects/{project_id}/pipelines/" + + def _get_test_payload(self, service_name: str): + """Get a minimal test payload for pipeline registration.""" + return { + "processing_service_name": service_name, + "pipelines": [], + } + + def test_create_new_service_success(self): + """Test creating a new processing service if it doesn't exist.""" + url = self._get_pipelines_url(self.project.pk) + payload = self._get_test_payload("NewService") + + self.client.force_authenticate(user=self.user) + response = self.client.post(url, payload, format="json") + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + # Verify service was created and associated + service = ProcessingService.objects.get(name="NewService") + self.assertIn(self.project, service.projects.all()) + + def test_reregistration_is_idempotent(self): + """Test that re-registering a service already associated with the project succeeds.""" + # Create and associate service + service = ProcessingService.objects.create(name="ExistingService") + service.projects.add(self.project) + + url = self._get_pipelines_url(self.project.pk) + payload = self._get_test_payload("ExistingService") + + self.client.force_authenticate(user=self.user) + response = self.client.post(url, payload, format="json") + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + def test_associate_existing_service_success(self): + """Test associating existing service with project when not yet associated.""" + # Create service but don't associate with project + service = ProcessingService.objects.create(name="UnassociatedService") + + url = self._get_pipelines_url(self.project.pk) + payload = self._get_test_payload("UnassociatedService") + + self.client.force_authenticate(user=self.user) + response = self.client.post(url, payload, format="json") + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertIn(self.project, service.projects.all()) + + def test_unauthorized_project_access_returns_403(self): + """Test 403 when user doesn't have write access to project.""" + url = self._get_pipelines_url(self.other_project.pk) + payload = self._get_test_payload("UnauthorizedService") + + self.client.force_authenticate(user=self.user) + response = self.client.post(url, payload, format="json") + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_invalid_payload_returns_400(self): + """Test 400 when payload is invalid.""" + url = self._get_pipelines_url(self.project.pk) + invalid_payload = {"invalid": "data"} + + self.client.force_authenticate(user=self.user) + response = self.client.post(url, invalid_payload, format="json") + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_list_pipelines(self): + """Test listing pipelines for a project returns the project's enabled pipelines.""" + url = self._get_pipelines_url(self.project.pk) + self.client.force_authenticate(user=self.user) + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + results = response.json()["results"] + self.assertGreater(len(results), 0) + + # All returned pipelines should belong to this project + project_pipeline_names = set( + Pipeline.objects.filter(projects=self.project, project_pipeline_configs__enabled=True) + .values_list("name", flat=True) + .distinct() + ) + response_names = {p["name"] for p in results} + self.assertEqual(response_names, project_pipeline_names) + + def test_list_pipelines_draft_project_non_member(self): + """Non-members cannot list pipelines on draft projects.""" + self.project.draft = True + self.project.save() + + non_member = User.objects.create_user(email="nonmember@example.com") # type: ignore + url = self._get_pipelines_url(self.project.pk) + self.client.force_authenticate(user=non_member) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_unauthenticated_write_returns_401(self): + """Unauthenticated users cannot register pipelines.""" + url = self._get_pipelines_url(self.project.pk) + payload = self._get_test_payload("AnonService") + response = self.client.post(url, payload, format="json") + self.assertIn(response.status_code, [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]) + + def test_list_pipelines_public_project_non_member(self): + """Non-members can list pipelines on public projects.""" + non_member = User.objects.create_user(email="reader@example.com") # type: ignore + url = self._get_pipelines_url(self.project.pk) + self.client.force_authenticate(user=non_member) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/ami/ml/migrations/0026_make_processing_service_endpoint_url_nullable.py b/ami/ml/migrations/0026_make_processing_service_endpoint_url_nullable.py new file mode 100644 index 000000000..af7905b0d --- /dev/null +++ b/ami/ml/migrations/0026_make_processing_service_endpoint_url_nullable.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.10 on 2026-01-16 17:36 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0025_alter_algorithm_task_type"), + ] + + operations = [ + migrations.AlterField( + model_name="processingservice", + name="endpoint_url", + field=models.CharField(blank=True, max_length=1024, null=True), + ), + ] diff --git a/ami/ml/models/processing_service.py b/ami/ml/models/processing_service.py index 4711c5e73..ec7516d39 100644 --- a/ami/ml/models/processing_service.py +++ b/ami/ml/models/processing_service.py @@ -39,7 +39,7 @@ class ProcessingService(BaseModel): name = models.CharField(max_length=255) description = models.TextField(blank=True) projects = models.ManyToManyField("main.Project", related_name="processing_services", blank=True) - endpoint_url = models.CharField(max_length=1024) + endpoint_url = models.CharField(max_length=1024, null=True, blank=True) pipelines = models.ManyToManyField("ml.Pipeline", related_name="processing_services", blank=True) last_checked = models.DateTimeField(null=True) last_checked_live = models.BooleanField(null=True) @@ -48,7 +48,8 @@ class ProcessingService(BaseModel): objects = ProcessingServiceManager() def __str__(self): - return f'#{self.pk} "{self.name}" at {self.endpoint_url}' + endpoint_display = self.endpoint_url or "async" + return f'#{self.pk} "{self.name}" ({endpoint_display})' class Meta: verbose_name = "Processing Service" @@ -151,6 +152,19 @@ def get_status(self, timeout=90) -> ProcessingServiceStatusResponse: Args: timeout: Request timeout in seconds per attempt (default: 90s for serverless cold starts) """ + # If no endpoint URL is configured, return a no-op response + if self.endpoint_url is None: + return ProcessingServiceStatusResponse( + timestamp=datetime.datetime.now(), + request_successful=False, + server_live=None, + pipelines_online=[], + pipeline_configs=[], + endpoint_url=self.endpoint_url, + error="No endpoint URL configured - service operates in pull mode", + latency=0.0, + ) + ready_check_url = urljoin(self.endpoint_url, "readyz") start_time = time.time() error = None @@ -215,6 +229,9 @@ def get_pipeline_configs(self, timeout=6): Get the pipeline configurations from the processing service. This can be a long response as it includes the full category map for each algorithm. """ + if self.endpoint_url is None: + return [] + info_url = urljoin(self.endpoint_url, "info") resp = requests.get(info_url, timeout=timeout) resp.raise_for_status() diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py index 0de43497c..f63e6e1a1 100644 --- a/ami/ml/schemas.py +++ b/ami/ml/schemas.py @@ -331,7 +331,7 @@ class ProcessingServiceStatusResponse(pydantic.BaseModel): error: str | None = None server_live: bool | None = None pipelines_online: list[str] = [] - endpoint_url: str + endpoint_url: str | None = None latency: float @@ -342,3 +342,12 @@ class PipelineRegistrationResponse(pydantic.BaseModel): pipelines: list[PipelineConfigResponse] = [] pipelines_created: list[str] = [] algorithms_created: list[str] = [] + + +class AsyncPipelineRegistrationRequest(pydantic.BaseModel): + """ + Request to register pipelines from an async processing service + """ + + processing_service_name: str + pipelines: list[PipelineConfigResponse] = [] diff --git a/ami/ml/serializers.py b/ami/ml/serializers.py index 8ebe50e0a..3217c519b 100644 --- a/ami/ml/serializers.py +++ b/ami/ml/serializers.py @@ -8,6 +8,7 @@ from .models.pipeline import Pipeline, PipelineStage from .models.processing_service import ProcessingService from .models.project_pipeline_config import ProjectPipelineConfig +from .schemas import PipelineConfigResponse class AlgorithmCategoryMapSerializer(DefaultSerializer): @@ -164,3 +165,8 @@ def create(self, validated_data): instance.projects.add(project) return instance + + +class PipelineRegistrationSerializer(serializers.Serializer): + processing_service_name = serializers.CharField() + pipelines = SchemaField(schema=list[PipelineConfigResponse], default=[]) diff --git a/ami/ml/tasks.py b/ami/ml/tasks.py index 23188db4c..68e9603bd 100644 --- a/ami/ml/tasks.py +++ b/ami/ml/tasks.py @@ -98,15 +98,16 @@ def remove_duplicate_classifications(project_id: int | None = None, dry_run: boo @celery_app.task(soft_time_limit=10, time_limit=20) def check_processing_services_online(): """ - Check the status of all processing services and update last checked. + Check the status of all v1 synchronous processing services and update the last_seen field. + We will update last_seen for asynchronous services when we receive a request from them. @TODO make this async to check all services in parallel """ from ami.ml.models import ProcessingService - logger.info("Checking if processing services are online.") + logger.info("Checking which synchronous processing services are online.") - services = ProcessingService.objects.all() + services = ProcessingService.objects.exclude(endpoint_url__isnull=True).exclude(endpoint_url__exact="").all() for service in services: logger.info(f"Checking service {service}") diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 2f269a6f9..6d029492b 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -115,6 +115,51 @@ def test_processing_service_pipeline_registration(self): self.assertEqual(pipelines_queryset.count(), len(response["pipelines"])) + def test_create_processing_service_without_endpoint_url(self): + """Test creating a ProcessingService without endpoint_url (pull mode)""" + processing_services_create_url = reverse_with_params("api:processingservice-list") + self.client.force_authenticate(user=self.user) + processing_service_data = { + "project": self.project.pk, + "name": "Pull Mode Service", + "description": "Service without endpoint", + } + resp = self.client.post(processing_services_create_url, processing_service_data) + self.client.force_authenticate(user=None) + + self.assertEqual(resp.status_code, 201) + data = resp.json() + + # Check that endpoint_url is null + self.assertIsNone(data["instance"]["endpoint_url"]) + + # Check that status indicates no endpoint configured + self.assertFalse(data["status"]["request_successful"]) + self.assertIn("No endpoint URL configured", data["status"]["error"]) + self.assertIsNone(data["status"]["endpoint_url"]) + + def test_get_status_with_null_endpoint_url(self): + """Test get_status method when endpoint_url is None""" + service = ProcessingService.objects.create(name="Pull Mode Service", endpoint_url=None) + service.projects.add(self.project) + + status = service.get_status() + + self.assertFalse(status.request_successful) + self.assertIsNone(status.server_live) + self.assertIsNone(status.endpoint_url) + self.assertIsNotNone(status.error) + self.assertIn("No endpoint URL configured", (status.error or "")) + self.assertEqual(status.pipelines_online, []) + + def test_get_pipeline_configs_with_null_endpoint_url(self): + """Test get_pipeline_configs method when endpoint_url is None""" + service = ProcessingService.objects.create(name="Pull Mode Service", endpoint_url=None) + + configs = service.get_pipeline_configs() + + self.assertEqual(configs, []) + class TestPipelineWithProcessingService(TestCase): def test_run_pipeline_with_errors_from_processing_service(self): diff --git a/ami/ml/views.py b/ami/ml/views.py index 4c0699b70..edcf0517c 100644 --- a/ami/ml/views.py +++ b/ami/ml/views.py @@ -1,19 +1,22 @@ import logging +from django.db import transaction from django.db.models import Prefetch from django.db.models.query import QuerySet from django.utils.text import slugify from drf_spectacular.utils import extend_schema from rest_framework import exceptions as api_exceptions -from rest_framework import status +from rest_framework import mixins, status, viewsets from rest_framework.decorators import action from rest_framework.request import Request from rest_framework.response import Response +from ami.base.permissions import ProjectPipelineConfigPermission from ami.base.views import ProjectMixin from ami.main.api.schemas import project_id_doc_param from ami.main.api.views import DefaultViewSet -from ami.main.models import SourceImage +from ami.main.models import Project, SourceImage +from ami.ml.schemas import PipelineRegistrationResponse from .models.algorithm import Algorithm, AlgorithmCategoryMap from .models.pipeline import Pipeline @@ -22,6 +25,7 @@ from .serializers import ( AlgorithmCategoryMapSerializer, AlgorithmSerializer, + PipelineRegistrationSerializer, PipelineSerializer, ProcessingServiceSerializer, ) @@ -188,3 +192,74 @@ def register_pipelines(self, request: Request, pk=None) -> Response: response = processing_service.create_pipelines() processing_service.save() return Response(response.dict()) + + +class ProjectPipelineViewSet(ProjectMixin, mixins.ListModelMixin, mixins.CreateModelMixin, viewsets.GenericViewSet): + """Pipelines for a specific project. GET lists, POST registers.""" + + queryset = Pipeline.objects.none() + serializer_class = PipelineSerializer + permission_classes = [ProjectPipelineConfigPermission] + require_project = True + + def get_queryset(self) -> QuerySet: + project = self.get_active_project() + return ( + Pipeline.objects.filter(projects=project, project_pipeline_configs__enabled=True) + .prefetch_related( + "algorithms", + Prefetch( + "processing_services", + queryset=ProcessingService.objects.filter(projects=project), + ), + Prefetch( + "project_pipeline_configs", + queryset=ProjectPipelineConfig.objects.filter(project=project), + ), + ) + .distinct() + ) + + def get_serializer_class(self): + if self.action == "create": + return PipelineRegistrationSerializer + return PipelineSerializer + + @extend_schema( + operation_id="projects_pipelines_list", + summary="List pipelines for a project", + responses={200: PipelineSerializer(many=True)}, + tags=["projects"], + ) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + + @extend_schema( + operation_id="projects_pipelines_create", + summary="Register pipelines for a project", + description=( + "Receive pipeline registrations for a project. This endpoint is called by the " + "V2 ML processing services to register available pipelines for a project." + ), + request=PipelineRegistrationSerializer, + responses={201: PipelineRegistrationResponse}, + tags=["projects"], + ) + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + project = self.get_active_project() + + with transaction.atomic(): + processing_service, _ = ProcessingService.objects.get_or_create( + name=serializer.validated_data["processing_service_name"], + defaults={"endpoint_url": None}, + ) + processing_service.projects.add(project) + + response = processing_service.create_pipelines( + pipeline_configs=serializer.validated_data["pipelines"], + projects=Project.objects.filter(pk=project.pk), + ) + + return Response(response.dict(), status=status.HTTP_201_CREATED) diff --git a/ami/users/roles.py b/ami/users/roles.py index 69862229c..79cad22ae 100644 --- a/ami/users/roles.py +++ b/ami/users/roles.py @@ -144,6 +144,7 @@ class MLDataManager(Role): Project.Permissions.RUN_DATA_EXPORT_JOB, Project.Permissions.DELETE_JOB, Project.Permissions.DELETE_OCCURRENCES, + Project.Permissions.CREATE_PROJECT_PIPELINE_CONFIG, } @@ -190,6 +191,9 @@ class ProjectManager(Role): Project.Permissions.CREATE_USER_PROJECT_MEMBERSHIP, Project.Permissions.UPDATE_USER_PROJECT_MEMBERSHIP, Project.Permissions.DELETE_USER_PROJECT_MEMBERSHIP, + Project.Permissions.CREATE_PROJECT_PIPELINE_CONFIG, + Project.Permissions.UPDATE_PROJECT_PIPELINE_CONFIG, + Project.Permissions.DELETE_PROJECT_PIPELINE_CONFIG, } ) diff --git a/config/api_router.py b/config/api_router.py index 13d4026e5..52541cde0 100644 --- a/config/api_router.py +++ b/config/api_router.py @@ -23,6 +23,11 @@ UserProjectMembershipViewSet, basename="project-members", ) +projects_router.register( + r"pipelines", + ml_views.ProjectPipelineViewSet, + basename="project-pipelines", +) router.register(r"deployments/devices", views.DeviceViewSet) router.register(r"deployments/sites", views.SiteViewSet) From cae6ff3d8a42de08356d6e12aac627048f679afd Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 18 Feb 2026 10:08:06 -0800 Subject: [PATCH 13/14] update tests --- ami/jobs/tests.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 4d4edb9fa..d041ef0ac 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -238,7 +238,7 @@ def _create_pipeline(self, name: str = "Test Pipeline", slug: str = "test-pipeli self.pipeline = pipeline return pipeline - def _create_ml_job(self, name: str, pipeline: Pipeline) -> Job: + def _create_ml_job(self, name: str, pipeline: Pipeline, **kwargs) -> Job: """Helper to create an ML job with a pipeline.""" return Job.objects.create( job_type_key=MLJob.key, @@ -246,6 +246,7 @@ def _create_ml_job(self, name: str, pipeline: Pipeline) -> Job: name=name, pipeline=pipeline, source_image_collection=self.source_image_collection, + **kwargs, ) def test_create_job(self): @@ -413,8 +414,7 @@ def test_search_jobs(self): def _task_batch_helper(self, value: Any, expected_status: int): pipeline = self._create_pipeline() - job = self._create_ml_job("Job for batch test", pipeline) - job.dispatch_mode = JobDispatchMode.ASYNC_API + job = self._create_ml_job("Job for batch test", pipeline, dispatch_mode=JobDispatchMode.ASYNC_API) job.save(update_fields=["dispatch_mode"]) images = [ SourceImage.objects.create( @@ -534,7 +534,7 @@ def test_processing_service_name_parameter(self): # Test tasks endpoint (requires job with pipeline) pipeline = self._create_pipeline() - job = self._create_ml_job("Job for service name test", pipeline) + job = self._create_ml_job("Job for service name test", pipeline, dispatch_mode=JobDispatchMode.ASYNC_API) tasks_url = reverse_with_params( "api:job-tasks", From 9cb886b4d513e2c9374894176783f13059fa8e03 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 18 Feb 2026 10:21:10 -0800 Subject: [PATCH 14/14] clean up --- ami/jobs/tests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index d041ef0ac..278438448 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -415,7 +415,6 @@ def test_search_jobs(self): def _task_batch_helper(self, value: Any, expected_status: int): pipeline = self._create_pipeline() job = self._create_ml_job("Job for batch test", pipeline, dispatch_mode=JobDispatchMode.ASYNC_API) - job.save(update_fields=["dispatch_mode"]) images = [ SourceImage.objects.create( path=f"image_{i}.jpg",