From b60eab0ff435e32b0264b182b4022864cdce71f0 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 16 Jan 2026 11:25:40 -0800 Subject: [PATCH 01/25] merge --- requirements/base.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/base.txt b/requirements/base.txt index dd9de69d5..ed40ea5f7 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -52,6 +52,7 @@ django-anymail[sendgrid]==10.0 # https://github.com/anymail/django-anymail Werkzeug[watchdog]==2.3.6 # https://github.com/pallets/werkzeug ipdb==0.13.13 # https://github.com/gotcha/ipdb psycopg[binary]==3.1.9 # https://github.com/psycopg/psycopg +# psycopg==3.1.9 # https://github.com/psycopg/psycopg # the non-binary version is needed for some platforms watchfiles==0.19.0 # https://github.com/samuelcolvin/watchfiles # Testing From 867201d92d8c0a1705021153bddc8355ae508b77 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 6 Feb 2026 15:58:29 -0800 Subject: [PATCH 02/25] Update ML job counts in async case --- ami/jobs/tasks.py | 35 +++++++-- ami/ml/orchestration/task_state.py | 37 +++++++++- ami/ml/tests.py | 112 +++++++++++++++++++++++++++++ 3 files changed, 177 insertions(+), 7 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 6404512d1..78d518d91 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -84,7 +84,13 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub raise self.retry(countdown=5, max_retries=10) try: - _update_job_progress(job_id, "process", progress_info.percentage) + _update_job_progress( + job_id, + "process", + progress_info.percentage, + processed=progress_info.processed, + remaining=progress_info.remaining, + ) _, t = t(f"TIME: Updated job {job_id} progress in PROCESS stage progress to {progress_info.percentage*100}%") job = Job.objects.get(pk=job_id) @@ -115,7 +121,20 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub _ack_task_via_nats(reply_subject, job.logger) # Update job stage with calculated progress - progress_info = state_manager.update_state(processed_image_ids, stage="results", request_id=self.request.id) + + # Calculate detection and classification counts from this result + detections_count = len(pipeline_result.detections) if pipeline_result else 0 + classifications_count = ( + sum(len(detection.classifications) for detection in pipeline_result.detections) if pipeline_result else 0 + ) + + progress_info = state_manager.update_state( + processed_image_ids, + stage="results", + request_id=self.request.id, + detections_count=detections_count, + classifications_count=classifications_count, + ) if not progress_info: logger.warning( @@ -123,7 +142,14 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub f"Retrying task {self.request.id} in 5 seconds..." 
) raise self.retry(countdown=5, max_retries=10) - _update_job_progress(job_id, "results", progress_info.percentage) + + _update_job_progress( + job_id, + "results", + progress_info.percentage, + detections=progress_info.detections, + classifications=progress_info.classifications, + ) except Exception as e: job.logger.error( @@ -149,7 +175,7 @@ async def ack_task(): # Don't fail the task if ACK fails - data is already saved -def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None: +def _update_job_progress(job_id: int, stage: str, progress_percentage: float, **state_params) -> None: from ami.jobs.models import Job, JobState # avoid circular import with transaction.atomic(): @@ -158,6 +184,7 @@ def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> stage, status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, progress=progress_percentage, + **state_params, ) if stage == "results" and progress_percentage >= 1.0: job.status = JobState.SUCCESS diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/task_state.py index 483275453..27c4bec07 100644 --- a/ami/ml/orchestration/task_state.py +++ b/ami/ml/orchestration/task_state.py @@ -11,7 +11,9 @@ # Define a namedtuple for a TaskProgress with the image counts -TaskProgress = namedtuple("TaskProgress", ["remaining", "total", "processed", "percentage"]) +TaskProgress = namedtuple( + "TaskProgress", ["remaining", "total", "processed", "percentage", "detections", "classifications"] +) class TaskStateManager: @@ -35,6 +37,8 @@ def __init__(self, job_id: int): self.job_id = job_id self._pending_key = f"job:{job_id}:pending_images" self._total_key = f"job:{job_id}:pending_images_total" + self._detections_key = f"job:{job_id}:total_detections" + self._classifications_key = f"job:{job_id}:total_classifications" def initialize_job(self, image_ids: list[str]) -> None: """ @@ -48,6 +52,10 @@ def initialize_job(self, image_ids: list[str]) -> None: cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) + # Initialize detection and classification counters + cache.set(self._detections_key, 0, timeout=self.TIMEOUT) + cache.set(self._classifications_key, 0, timeout=self.TIMEOUT) + def _get_pending_key(self, stage: str) -> str: return f"{self._pending_key}:{stage}" @@ -56,12 +64,18 @@ def update_state( processed_image_ids: set[str], stage: str, request_id: str, + detections_count: int = 0, + classifications_count: int = 0, ) -> None | TaskProgress: """ Update the task state with newly processed images. 
Args: processed_image_ids: Set of image IDs that have just been processed + stage: The processing stage ("process" or "results") + request_id: Unique identifier for this processing request + detections_count: Number of detections to add to cumulative count + classifications_count: Number of classifications to add to cumulative count """ # Create a unique lock key for this job lock_key = f"job:{self.job_id}:process_results_lock" @@ -72,7 +86,7 @@ def update_state( try: # Update progress tracking in Redis - progress_info = self._get_progress(processed_image_ids, stage) + progress_info = self._get_progress(processed_image_ids, stage, detections_count, classifications_count) return progress_info finally: # Always release the lock when done @@ -82,7 +96,9 @@ def update_state( cache.delete(lock_key) logger.debug(f"Released lock for job {self.job_id}, task {request_id}") - def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgress | None: + def _get_progress( + self, processed_image_ids: set[str], stage: str, detections_count: int = 0, classifications_count: int = 0 + ) -> TaskProgress | None: """ Get current progress information for the job. @@ -104,6 +120,17 @@ def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgre remaining = len(remaining_images) processed = total_images - remaining percentage = float(processed) / total_images if total_images > 0 else 1.0 + + # Update cumulative detection and classification counts + current_detections = cache.get(self._detections_key, 0) + current_classifications = cache.get(self._classifications_key, 0) + + new_detections = current_detections + detections_count + new_classifications = current_classifications + classifications_count + + cache.set(self._detections_key, new_detections, timeout=self.TIMEOUT) + cache.set(self._classifications_key, new_classifications, timeout=self.TIMEOUT) + logger.info( f"Pending images from Redis for job {self.job_id} {stage}: " f"{remaining}/{total_images}: {percentage*100}%" @@ -114,6 +141,8 @@ def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgre total=total_images, processed=processed, percentage=percentage, + detections=new_detections, + classifications=new_classifications, ) def cleanup(self) -> None: @@ -123,3 +152,5 @@ def cleanup(self) -> None: for stage in self.STAGES: cache.delete(self._get_pending_key(stage)) cache.delete(self._total_key) + cache.delete(self._detections_key) + cache.delete(self._classifications_key) diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 20e0368fe..538106c19 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -880,6 +880,8 @@ def _init_and_verify(self, image_ids): self.assertEqual(progress.remaining, len(image_ids)) self.assertEqual(progress.processed, 0) self.assertEqual(progress.percentage, 0.0) + self.assertEqual(progress.detections, 0) + self.assertEqual(progress.classifications, 0) return progress def test_initialize_job(self): @@ -891,6 +893,8 @@ def test_initialize_job(self): progress = self.manager._get_progress(set(), stage) assert progress is not None self.assertEqual(progress.total, len(self.image_ids)) + self.assertEqual(progress.detections, 0) + self.assertEqual(progress.classifications, 0) def test_progress_tracking(self): """Test progress updates correctly as images are processed.""" @@ -902,6 +906,8 @@ def test_progress_tracking(self): self.assertEqual(progress.remaining, 3) self.assertEqual(progress.processed, 2) self.assertEqual(progress.percentage, 0.4) + 
self.assertEqual(progress.detections, 0) # No counts added yet + self.assertEqual(progress.classifications, 0) # Process 2 more images progress = self.manager._get_progress({"img3", "img4"}, "process") @@ -927,6 +933,8 @@ def test_update_state_with_locking(self): progress = self.manager.update_state({"img1", "img2"}, "process", "task1") assert progress is not None self.assertEqual(progress.processed, 2) + self.assertEqual(progress.detections, 0) + self.assertEqual(progress.classifications, 0) # Simulate concurrent update by holding the lock lock_key = f"job:{self.job_id}:process_results_lock" @@ -964,6 +972,8 @@ def test_empty_job(self): assert progress is not None self.assertEqual(progress.total, 0) self.assertEqual(progress.percentage, 1.0) # Empty job is 100% complete + self.assertEqual(progress.detections, 0) + self.assertEqual(progress.classifications, 0) def test_cleanup(self): """Test cleanup removes all tracking keys.""" @@ -979,3 +989,105 @@ def test_cleanup(self): # Verify keys are gone progress = self.manager._get_progress(set(), "process") self.assertIsNone(progress) + + def test_cumulative_detection_counting(self): + """Test that detection counts accumulate correctly across updates.""" + self._init_and_verify(self.image_ids) + + # Process first batch with some detections + progress = self.manager._get_progress({"img1", "img2"}, "process", detections_count=3) + assert progress is not None + self.assertEqual(progress.detections, 3) + self.assertEqual(progress.classifications, 0) + + # Process second batch with more detections + progress = self.manager._get_progress({"img3"}, "process", detections_count=2) + assert progress is not None + self.assertEqual(progress.detections, 5) # Should be cumulative + self.assertEqual(progress.classifications, 0) + + # Process with both detections and classifications + progress = self.manager._get_progress({"img4"}, "results", detections_count=1, classifications_count=4) + assert progress is not None + self.assertEqual(progress.detections, 6) # Should accumulate + self.assertEqual(progress.classifications, 4) + + def test_cumulative_classification_counting(self): + """Test that classification counts accumulate correctly across updates.""" + self._init_and_verify(self.image_ids) + + # Process first batch with some classifications + progress = self.manager._get_progress({"img1"}, "results", classifications_count=5) + assert progress is not None + self.assertEqual(progress.detections, 0) + self.assertEqual(progress.classifications, 5) + + # Process second batch with more classifications + progress = self.manager._get_progress({"img2", "img3"}, "results", classifications_count=8) + assert progress is not None + self.assertEqual(progress.detections, 0) + self.assertEqual(progress.classifications, 13) # Should be cumulative + + def test_update_state_with_counts(self): + """Test update_state method properly handles detection and classification counts.""" + self._init_and_verify(self.image_ids) + + # Update with counts + progress = self.manager.update_state( + {"img1", "img2"}, "process", "task1", detections_count=4, classifications_count=8 + ) + assert progress is not None + self.assertEqual(progress.processed, 2) + self.assertEqual(progress.detections, 4) + self.assertEqual(progress.classifications, 8) + + # Update with more counts + progress = self.manager.update_state({"img3"}, "results", "task2", detections_count=2, classifications_count=6) + assert progress is not None + self.assertEqual(progress.detections, 6) # Should accumulate + 
self.assertEqual(progress.classifications, 14) # Should accumulate + + def test_counts_persist_across_stages(self): + """Test that detection and classification counts persist across different stages.""" + self._init_and_verify(self.image_ids) + + # Add counts during process stage + progress_process = self.manager._get_progress({"img1"}, "process", detections_count=3) + assert progress_process is not None + self.assertEqual(progress_process.detections, 3) + + # Verify counts are available in results stage + progress_results = self.manager._get_progress(set(), "results") + assert progress_results is not None + self.assertEqual(progress_results.detections, 3) # Should persist + self.assertEqual(progress_results.classifications, 0) + + # Add more counts in results stage + progress_results = self.manager._get_progress({"img2"}, "results", detections_count=1, classifications_count=5) + assert progress_results is not None + self.assertEqual(progress_results.detections, 4) # Should accumulate + self.assertEqual(progress_results.classifications, 5) + + def test_cleanup_removes_count_keys(self): + """Test that cleanup removes detection and classification count keys.""" + from django.core.cache import cache + + self._init_and_verify(self.image_ids) + + # Add some counts + self.manager._get_progress({"img1"}, "process", detections_count=5, classifications_count=10) + + # Verify count keys exist + detections = cache.get(self.manager._detections_key) + classifications = cache.get(self.manager._classifications_key) + self.assertEqual(detections, 5) + self.assertEqual(classifications, 10) + + # Cleanup + self.manager.cleanup() + + # Verify count keys are gone + detections = cache.get(self.manager._detections_key) + classifications = cache.get(self.manager._classifications_key) + self.assertIsNone(detections) + self.assertIsNone(classifications) From cdf57ea91f3d215df4f682f7e570bb7e8a5580e4 Mon Sep 17 00:00:00 2001 From: Anna Viklund Date: Fri, 6 Feb 2026 17:09:31 +0100 Subject: [PATCH 03/25] Update date picker version and tweak layout logic (#1105) * fix: update date picker version and tweak layout logic * feat: set start month based on selected date --- ui/package.json | 2 +- .../page-footer/page-footer.module.scss | 2 +- .../components/select/date-picker.tsx | 22 +++++- ui/yarn.lock | 67 +++++++++++++++---- 4 files changed, 74 insertions(+), 19 deletions(-) diff --git a/ui/package.json b/ui/package.json index 004b743a3..721d2b5d1 100644 --- a/ui/package.json +++ b/ui/package.json @@ -29,7 +29,7 @@ "leaflet": "^1.9.3", "lodash": "^4.17.21", "lucide-react": "^0.454.0", - "nova-ui-kit": "^1.1.32", + "nova-ui-kit": "^1.1.33", "plotly.js": "^2.25.2", "react": "^18.2.0", "react-dom": "^18.2.0", diff --git a/ui/src/design-system/components/page-footer/page-footer.module.scss b/ui/src/design-system/components/page-footer/page-footer.module.scss index 10bcafc28..d89ec8416 100644 --- a/ui/src/design-system/components/page-footer/page-footer.module.scss +++ b/ui/src/design-system/components/page-footer/page-footer.module.scss @@ -8,7 +8,7 @@ width: 100%; background-color: $color-generic-white; border-top: 1px solid $color-neutral-100; - z-index: 2; + z-index: 50; } .content { diff --git a/ui/src/design-system/components/select/date-picker.tsx b/ui/src/design-system/components/select/date-picker.tsx index e863dfffe..82ff58928 100644 --- a/ui/src/design-system/components/select/date-picker.tsx +++ b/ui/src/design-system/components/select/date-picker.tsx @@ -1,7 +1,7 @@ import { format } from 'date-fns' import { 
AlertCircleIcon, Calendar as CalendarIcon } from 'lucide-react' import { Button, Calendar, Popover } from 'nova-ui-kit' -import { useState } from 'react' +import { useEffect, useMemo, useState } from 'react' const dateToLabel = (date: Date) => { try { @@ -21,7 +21,15 @@ export const DatePicker = ({ value?: string }) => { const [open, setOpen] = useState(false) - const selected = value ? new Date(value) : undefined + const selected = useMemo(() => (value ? new Date(value) : undefined), [value]) + const [month, setMonth] = useState(selected) + + /* Reset start month on date picker close */ + useEffect(() => { + if (!open) { + setMonth(selected) + } + }, [open, selected]) const triggerLabel = (() => { if (!value) { @@ -50,10 +58,18 @@ export const DatePicker = ({ - + { if (date) { onValueChange(dateToLabel(date)) diff --git a/ui/yarn.lock b/ui/yarn.lock index 5ff92cdbf..292f08a76 100644 --- a/ui/yarn.lock +++ b/ui/yarn.lock @@ -668,6 +668,13 @@ __metadata: languageName: node linkType: hard +"@date-fns/tz@npm:^1.4.1": + version: 1.4.1 + resolution: "@date-fns/tz@npm:1.4.1" + checksum: 9033fdc4682fe3d4d147625ce04fa88a8792653594e2de8d5a438c8f3bfc0990ee28fe773f91cac6810b06d818b5b281ae0608752ba8337257d0279ded3f019a + languageName: node + linkType: hard + "@esbuild/android-arm64@npm:0.18.20": version: 0.18.20 resolution: "@esbuild/android-arm64@npm:0.18.20" @@ -3330,7 +3337,7 @@ __metadata: languageName: node linkType: hard -"@radix-ui/react-slot@npm:1.1.0, @radix-ui/react-slot@npm:^1.1.0": +"@radix-ui/react-slot@npm:1.1.0": version: 1.1.0 resolution: "@radix-ui/react-slot@npm:1.1.0" dependencies: @@ -3390,6 +3397,21 @@ __metadata: languageName: node linkType: hard +"@radix-ui/react-slot@npm:^1.2.4": + version: 1.2.4 + resolution: "@radix-ui/react-slot@npm:1.2.4" + dependencies: + "@radix-ui/react-compose-refs": "npm:1.1.2" + peerDependencies: + "@types/react": "*" + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + "@types/react": + optional: true + checksum: 8b719bb934f1ae5ac0e37214783085c17c2f1080217caf514c1c6cc3d9ca56c7e19d25470b26da79aa6e605ab36589edaade149b76f5fc0666f1063e2fc0a0dc + languageName: node + linkType: hard + "@radix-ui/react-switch@npm:^1.2.6": version: 1.2.6 resolution: "@radix-ui/react-switch@npm:1.2.6" @@ -5096,7 +5118,7 @@ __metadata: leaflet: "npm:^1.9.3" lodash: "npm:^4.17.21" lucide-react: "npm:^0.454.0" - nova-ui-kit: "npm:^1.1.32" + nova-ui-kit: "npm:^1.1.33" plotly.js: "npm:^2.25.2" postcss: "npm:^8.4.47" prettier: "npm:2.8.4" @@ -6434,6 +6456,13 @@ __metadata: languageName: node linkType: hard +"date-fns-jalali@npm:^4.1.0-0": + version: 4.1.0-0 + resolution: "date-fns-jalali@npm:4.1.0-0" + checksum: f9ad98d9f7e8e5abe0d070dc806b0c8baded2b1208626c42e92cbd2605b5171f5714d6b79b20cc2666267d821699244c9d0b5e93274106cf57d6232da77596ed + languageName: node + linkType: hard + "date-fns@npm:^3.6.0": version: 3.6.0 resolution: "date-fns@npm:3.6.0" @@ -6441,6 +6470,13 @@ __metadata: languageName: node linkType: hard +"date-fns@npm:^4.1.0": + version: 4.1.0 + resolution: "date-fns@npm:4.1.0" + checksum: b79ff32830e6b7faa009590af6ae0fb8c3fd9ffad46d930548fbb5acf473773b4712ae887e156ba91a7b3dc30591ce0f517d69fd83bd9c38650fdc03b4e0bac8 + languageName: node + linkType: hard + "debug@npm:2": version: 2.6.9 resolution: "debug@npm:2.6.9" @@ -10473,16 +10509,16 @@ __metadata: languageName: node linkType: hard -"nova-ui-kit@npm:^1.1.32": - version: 1.1.32 - resolution: "nova-ui-kit@npm:1.1.32" +"nova-ui-kit@npm:^1.1.33": + version: 1.1.33 + resolution: 
"nova-ui-kit@npm:1.1.33" dependencies: "@radix-ui/react-checkbox": "npm:^1.1.4" "@radix-ui/react-collapsible": "npm:^1.1.1" "@radix-ui/react-popover": "npm:^1.1.2" "@radix-ui/react-select": "npm:^2.1.2" "@radix-ui/react-slider": "npm:^1.2.1" - "@radix-ui/react-slot": "npm:^1.1.0" + "@radix-ui/react-slot": "npm:^1.2.4" "@radix-ui/react-switch": "npm:^1.2.6" "@radix-ui/react-tooltip": "npm:^1.1.6" class-variance-authority: "npm:^0.7.0" @@ -10491,11 +10527,11 @@ __metadata: date-fns: "npm:^3.6.0" lucide-react: "npm:^0.452.0" react: "npm:^18.3.1" - react-day-picker: "npm:^8.10.1" + react-day-picker: "npm:^9.13.0" react-dom: "npm:^18.3.1" tailwind-merge: "npm:^2.5.4" tailwindcss-animate: "npm:^1.0.7" - checksum: 8c356c554db489c1a0ba5aee2e232be5bc387ccadd2e0c624169089640afa4f78c92309edb214e7d580ca9a5b9bb611294d7d6278856e3fd73b04c9dbbea4a9b + checksum: a29d22052a4f83a6b997e6bb66f4226625b7dbd05faf4f03824af8b353cf75e7a25ffbe06aafe91b8d3cf6e51b5febba1f013cce3c60faae669278c7f4ccbae2 languageName: node linkType: hard @@ -11255,13 +11291,16 @@ __metadata: languageName: node linkType: hard -"react-day-picker@npm:^8.10.1": - version: 8.10.1 - resolution: "react-day-picker@npm:8.10.1" +"react-day-picker@npm:^9.13.0": + version: 9.13.0 + resolution: "react-day-picker@npm:9.13.0" + dependencies: + "@date-fns/tz": "npm:^1.4.1" + date-fns: "npm:^4.1.0" + date-fns-jalali: "npm:^4.1.0-0" peerDependencies: - date-fns: ^2.28.0 || ^3.0.0 - react: ^16.8.0 || ^17.0.0 || ^18.0.0 - checksum: a0ff28c4b61b3882e6a825b19e5679e2fdf3256cf1be8eb0a0c028949815c1ae5a6561474c2c19d231c010c8e0e0b654d3a322610881e0655abca05a2e03d9df + react: ">=16.8.0" + checksum: e176309a24697f6552c80c7fde257f3dc06506a4e5ec22b11b55bf136bd452ab98d2ca9c2a17632092d67d23d08b3b35daf77d99e95a937ad290d3322abc745c languageName: node linkType: hard From 6837ad630f6011da8683f34b93bffd2b786cf4fe Mon Sep 17 00:00:00 2001 From: carlosgjs Date: Fri, 6 Feb 2026 18:27:26 -0800 Subject: [PATCH 04/25] fix: Properly handle async job state with celery tasks (#1114) * merge * fix: Properly handle async job state with celery tasks * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Delete implemented plan --------- Co-authored-by: Carlos Garcia Jurado Suarez Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .agents/planning/async-job-status-handling.md | 408 ------------------ ami/jobs/models.py | 24 ++ ami/jobs/tasks.py | 11 +- ami/jobs/tests.py | 87 ++++ 4 files changed, 121 insertions(+), 409 deletions(-) delete mode 100644 .agents/planning/async-job-status-handling.md diff --git a/.agents/planning/async-job-status-handling.md b/.agents/planning/async-job-status-handling.md deleted file mode 100644 index 1a6024a87..000000000 --- a/.agents/planning/async-job-status-handling.md +++ /dev/null @@ -1,408 +0,0 @@ -# Plan: Fix Async Pipeline Job Status Handling - -**Date:** 2026-01-30 -**Status:** Ready for implementation - -## Problem Summary - -When `async_pipeline_workers` feature flag is enabled: -1. `queue_images_to_nats()` queues work and returns immediately (lines 400-408 in `ami/jobs/models.py`) -2. The Celery `run_job` task completes without exception -3. The `task_postrun` signal handler (`ami/jobs/tasks.py:175-192`) calls `job.update_status("SUCCESS")` -4. This prematurely marks the job as SUCCESS before async workers actually finish processing - -## Solution Overview - -Use a **progress-based approach** to determine job completion: -1. 
Add a generic `is_complete()` method to `JobProgress` that works for any job type -2. In `task_postrun`, only allow SUCCESS if all stages are complete -3. Allow FAILURE, REVOKED, and other terminal states to pass through immediately -4. Authoritative completion for async jobs stays in `_update_job_progress` - ---- - -## Background: Job Types and Stages - -Jobs in this app have different types, each with different stages: - -| Job Type | Stages | -|----------|--------| -| MLJob | delay (optional), collect, process, results | -| DataStorageSyncJob | data_storage_sync | -| SourceImageCollectionPopulateJob | populate_captures_collection | -| DataExportJob | exporting_data, uploading_snapshot | -| PostProcessingJob | post_processing | - -The `is_complete()` method must be generic and work for ALL job types by checking if all stages have finished. - -### Celery Signal States - -The `task_postrun` signal fires after every task with these states: -- **SUCCESS**: Task completed without exception -- **FAILURE**: Task raised exception (also triggers `task_failure`) -- **RETRY**: Task requested retry -- **REVOKED**: Task cancelled - -**Handling strategy:** -| State | Behavior | Rationale | -|-------|----------|-----------| -| SUCCESS | Guard - only set if `is_complete()` | Prevents premature success | -| FAILURE | Allow immediately | Job failed, user needs to know | -| REVOKED | Allow immediately | Job cancelled | -| RETRY | Allow immediately | Transient state | - ---- - -## Implementation Steps - -### Step 1: Add `is_complete()` method to JobProgress - -**File:** `ami/jobs/models.py` (in the `JobProgress` class, after `reset()` method around line 198) - -```python -def is_complete(self) -> bool: - """ - Check if all stages have finished processing. - - A job is considered complete when ALL of its stages have: - - progress >= 1.0 (fully processed) - - status in a final state (SUCCESS, FAILURE, or REVOKED) - - This method works for any job type regardless of which stages it has. - It's used by the Celery task_postrun signal to determine whether to - set the job's final SUCCESS status, or defer to async progress handlers. - - Related: Job.update_progress() (lines 924-947) calculates the aggregate - progress percentage across all stages for display purposes. This method - is a binary check for completion that considers both progress AND status. - - Returns: - True if all stages are complete, False otherwise. - Returns False if job has no stages (shouldn't happen in practice). 
- """ - if not self.stages: - return False - return all( - stage.progress >= 1.0 and stage.status in JobState.final_states() - for stage in self.stages - ) -``` - -### Step 2: Modify `task_postrun` signal handler - -**File:** `ami/jobs/tasks.py` (lines 175-192) - -Only guard SUCCESS state - let all other states pass through: - -```python -@task_postrun.connect(sender=run_job) -def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs): - from ami.jobs.models import Job, JobState - - job_id = task.request.kwargs["job_id"] - if job_id is None: - logger.error(f"Job id is None for task {task_id}") - return - try: - job = Job.objects.get(pk=job_id) - except Job.DoesNotExist: - try: - job = Job.objects.get(task_id=task_id) - except Job.DoesNotExist: - logger.error(f"No job found for task {task_id} or job_id {job_id}") - return - - # Guard only SUCCESS state - let FAILURE, REVOKED, RETRY pass through immediately - # SUCCESS should only be set when all stages are actually complete - # This prevents premature SUCCESS when async workers are still processing - if state == JobState.SUCCESS and not job.progress.is_complete(): - job.logger.info( - f"Job {job.pk} task completed but stages not finished - " - "deferring SUCCESS status to progress handler" - ) - return - - job.update_status(state) -``` - -### Step 3: Add cleanup to `_update_job_progress` - -**File:** `ami/jobs/tasks.py` (lines 151-166) - -When job completes, add cleanup for Redis and NATS resources: - -```python -def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None: - """ - Update job progress for a specific stage from async pipeline workers. - - This function is called by process_nats_pipeline_result when async workers - report progress. It updates the job's progress model and, when the job - completes, sets the final SUCCESS status. - - For async jobs, this is the authoritative place where SUCCESS status - is set - the Celery task_postrun signal defers to this function. - - Args: - job_id: The job primary key - stage: The processing stage (e.g., "process" or "results" for ML jobs) - progress_percentage: Progress as a float from 0.0 to 1.0 - """ - from ami.jobs.models import Job, JobState # avoid circular import - - with transaction.atomic(): - job = Job.objects.select_for_update().get(pk=job_id) - job.progress.update_stage( - stage, - status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, - progress=progress_percentage, - ) - - # Check if all stages are now complete - if job.progress.is_complete(): - job.status = JobState.SUCCESS - job.progress.summary.status = JobState.SUCCESS - job.finished_at = datetime.datetime.now() # Use naive datetime in local time - job.logger.info(f"Job {job_id} completed successfully - all stages finished") - - # Clean up job-specific resources (Redis state, NATS stream/consumer) - _cleanup_async_job_resources(job_id, job.logger) - - job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%") - job.save() - - -def _cleanup_async_job_resources(job_id: int, job_logger: logging.Logger) -> None: - """ - Clean up all async processing resources for a completed job. - - This function is called when an async job completes (all stages finished). - It cleans up: - - 1. 
Redis state (via TaskStateManager.cleanup): - - job:{job_id}:pending_images:process - tracks remaining images in process stage - - job:{job_id}:pending_images:results - tracks remaining images in results stage - - job:{job_id}:pending_images_total - total image count for progress calculation - - 2. NATS JetStream resources (via TaskQueueManager.cleanup_job_resources): - - Stream: job_{job_id} - the message stream that holds pending tasks - - Consumer: job-{job_id}-consumer - the durable consumer that tracks delivery - - Why cleanup is needed: - - Redis keys have a 7-day TTL but should be cleaned immediately when job completes - - NATS streams/consumers have 24-hour retention but consume server resources - - Cleaning up immediately prevents resource accumulation from many jobs - - Cleanup failures are logged but don't fail the job - data is already saved. - - Args: - job_id: The job primary key - job_logger: Logger instance for this job (writes to job's log file) - """ - # Cleanup Redis state tracking - state_manager = TaskStateManager(job_id) - state_manager.cleanup() - job_logger.info(f"Cleaned up Redis state for job {job_id}") - - # Cleanup NATS resources (stream and consumer) - try: - async def cleanup_nats(): - async with TaskQueueManager() as manager: - return await manager.cleanup_job_resources(job_id) - - success = async_to_sync(cleanup_nats)() - if success: - job_logger.info(f"Cleaned up NATS resources for job {job_id}") - else: - job_logger.warning(f"Failed to clean up NATS resources for job {job_id}") - except Exception as e: - job_logger.error(f"Error cleaning up NATS resources for job {job_id}: {e}") - # Don't fail the job if cleanup fails - job data is already saved -``` - -### Step 4: Verify existing error handling - -**File:** `ami/jobs/models.py` (lines 402-408) - -**No changes needed** - Existing code already handles FAILURE correctly: -```python -if not queued: - job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk) - job.progress.update_stage("collect", status=JobState.FAILURE) - job.update_status(JobState.FAILURE) - job.finished_at = datetime.datetime.now() - job.save() - return -``` - ---- - -## Files to Modify - -| File | Lines | Change | -|------|-------|--------| -| `ami/jobs/models.py` | ~198 (in JobProgress class) | Add generic `is_complete()` method | -| `ami/jobs/models.py` | ~700 (Job class docstring) | Add future improvements note | -| `ami/jobs/tasks.py` | 175-192 | Guard SUCCESS with `is_complete()` check | -| `ami/jobs/tasks.py` | 151-166 | Use `is_complete()`, add cleanup with docstring | -| `ami/jobs/tasks.py` | imports | Add import for `TaskStateManager` | - -**Related existing code (not modified):** -- `Job.update_progress()` at lines 924-947 - calculates aggregate progress from stages - ---- - -## Why This Approach Works - -| Job Type | Stages Complete When Task Ends? 
| `task_postrun` Behavior | -|----------|--------------------------------|------------------------| -| Sync ML job | Yes - all stages marked SUCCESS | Sets SUCCESS normally | -| Async ML job | No - stages incomplete | Skips SUCCESS, defers to async handler | -| DataStorageSyncJob | Yes - single stage | Sets SUCCESS normally | -| DataExportJob | Yes - all stages | Sets SUCCESS normally | -| Any job that fails | N/A | FAILURE passes through | -| Cancelled job | N/A | REVOKED passes through | - ---- - -## Race Condition Analysis: `update_progress()` vs `is_complete()` - -### How `update_progress()` works (lines 924-947) - -Called automatically by `save()` (line 958-959), it auto-corrects stage status/progress: - -```python -for stage in self.progress.stages: - if stage.progress > 0 and stage.status == JobState.CREATED: - stage.status = JobState.STARTED # Auto-upgrade CREATED→STARTED - elif stage.status in JobState.final_states() and stage.progress < 1: - stage.progress = 1 # If final, ensure progress=1 - elif stage.progress == 1 and stage.status not in JobState.final_states(): - stage.status = JobState.SUCCESS # Auto-upgrade to SUCCESS if progress=1 -``` - -**Key insight:** Line 939-941 auto-sets SUCCESS only if `progress == 1`. For async jobs, incomplete stages have `progress < 1`, so they won't be auto-upgraded. - -### When is `is_complete()` called? - -1. **In `task_postrun`:** Job loaded fresh from DB (already saved by `job.run()`) -2. **In `_update_job_progress`:** After `update_stage()` but before `save()` - -### Async job flow - no race condition: - -``` -run_job task: -1. job.run() → MLJob.run() -2. collect stage set to SUCCESS, progress=1.0 -3. job.save() called → update_progress() runs → no changes (only collect is 100%) -4. queue_images_to_nats() called, returns -5. job.run() returns successfully - -task_postrun fires: -6. job = Job.objects.get(pk=job_id) # Fresh from DB -7. Job state: collect=SUCCESS(1.0), process=CREATED(0), results=CREATED(0) -8. is_complete() → False (not all stages at 100% final) -9. SUCCESS status deferred ✓ - -Async workers process images: -10. process_nats_pipeline_result called for each result -11. _update_job_progress("process", 0.5) → is_complete() = False -12. _update_job_progress("process", 1.0) → is_complete() = False (results still 0) -13. _update_job_progress("results", 0.5) → is_complete() = False -14. _update_job_progress("results", 1.0) → is_complete() = True ✓ -15. Job marked SUCCESS, cleanup runs -``` - -### Why there's no race: - -1. `is_complete()` requires ALL stages to have `progress >= 1.0` AND `status in final_states()` -2. For async jobs, incomplete stages have `progress < 1.0`, so `update_progress()` won't auto-upgrade them -3. The async handler updates stages separately (process, then results), so completion only triggers when the LAST stage reaches 100% -4. `select_for_update()` in `_update_job_progress` prevents concurrent updates to the same job - ---- - -## Risk Analysis - -**Core work functions that must NOT be affected:** -1. `queue_images_to_nats()` - queues images to NATS for async processing -2. `process_nats_pipeline_result()` - processes results from async workers -3. `pipeline.save_results()` - saves detections/classifications to database -4. 
`MLJob.process_images()` - synchronous image processing path - -**Changes and their risk:** - -| Change | Risk to Core Work | Analysis | -|--------|------------------|----------| -| Add `is_complete()` to JobProgress | **None** | Pure read-only check, doesn't modify any state | -| Guard SUCCESS in `task_postrun` | **None** | Only affects status display, not actual processing. Returns early without modifying anything. | -| Use `is_complete()` in `_update_job_progress` | **None** | Replaces hardcoded `stage == "results" and progress >= 1.0` check with generic method. Called AFTER `pipeline.save_results()` completes. | -| Add `_cleanup_async_job_resources()` | **Minimal** | Called ONLY after job is marked SUCCESS. Cleanup failures are caught and logged, don't fail the job. Data is already saved at this point. | - -**Worst case scenario:** If `is_complete()` has a bug and always returns False: -- Work still completes normally (queuing, processing, saving all happen) -- Job status would stay STARTED instead of SUCCESS -- UI would show job as incomplete even though work finished -- **Data is safe** - this is a display/status issue, not a data loss risk - -**Existing function to note:** `Job.update_progress()` at lines 924-947 calculates total progress from stages. The new `is_complete()` method is a related concept - checking if all stages are done vs calculating aggregate progress. - ---- - -## Future Improvements (out of scope) - -Per user feedback, consider for follow-up: -- Split job types into different classes with clearer state management -- More robust state machine for job lifecycle -- But don't reinvent state tracking on top of existing bg task tools - -These improvements should be noted in the Job class docstring. - ---- - -## Additional Documentation - -### Add note to Job class docstring - -**File:** `ami/jobs/models.py` (Job class, around line ~700) - -Add to the class docstring: -```python -# Future improvements: -# - Consider splitting job types into subclasses with clearer state management -# - The progress/stages system (see JobProgress, update_progress()) was designed -# for UI display. The is_complete() method bridges this to actual completion logic. -# - Avoid reinventing state tracking on top of existing background task tools -``` - ---- - -## Verification - -1. **Unit test for `is_complete()`:** - - No stages -> False - - One stage at progress=0.5, STARTED -> False - - One stage at progress=1.0, STARTED -> False (not final state) - - One stage at progress=1.0, SUCCESS -> True - - Multiple stages, all SUCCESS -> True - - Multiple stages, one still STARTED -> False - -2. **Unit test for task_postrun behavior:** - - SUCCESS with incomplete stages -> status NOT changed - - SUCCESS with complete stages -> status changed to SUCCESS - - FAILURE -> passes through immediately - - REVOKED -> passes through immediately - -3. **Integration test for async ML job:** - - Create job with async_pipeline_workers enabled - - Queue images to NATS - - Verify job status stays STARTED after run_job completes - - Process all images via process_nats_pipeline_result - - Verify job status becomes SUCCESS after all stages finish - - Verify finished_at is set - - Verify Redis/NATS cleanup occurred - -4. 
**Regression test for sync jobs:** - - Run sync ML job -> SUCCESS after task completes - - Run DataStorageSyncJob -> SUCCESS after task completes - - Run DataExportJob -> SUCCESS after task completes diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 482d01a58..9f8aa197f 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -197,6 +197,30 @@ def reset(self, status: JobState = JobState.CREATED): stage.progress = 0 stage.status = status + def is_complete(self) -> bool: + """ + Check if all stages have finished processing. + + A job is considered complete when ALL of its stages have: + - progress >= 1.0 (fully processed) + - status in a final state (SUCCESS, FAILURE, or REVOKED) + + This method works for any job type regardless of which stages it has. + It's used by the Celery task_postrun signal to determine whether to + set the job's final SUCCESS status, or defer to async progress handlers. + + Related: Job.update_progress() calculates the aggregate + progress percentage across all stages for display purposes. This method + is a binary check for completion that considers both progress AND status. + + Returns: + True if all stages are complete, False otherwise. + Returns False if job has no stages (shouldn't happen in practice). + """ + if not self.stages: + return False + return all(stage.progress >= 1.0 and stage.status in JobState.final_states() for stage in self.stages) + class Config: use_enum_values = True as_dict = True diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 78d518d91..07716c5fa 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -202,7 +202,7 @@ def pre_update_job_status(sender, task_id, task, **kwargs): @task_postrun.connect(sender=run_job) def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs): - from ami.jobs.models import Job + from ami.jobs.models import Job, JobState job_id = task.request.kwargs["job_id"] if job_id is None: @@ -217,6 +217,15 @@ def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs): logger.error(f"No job found for task {task_id} or job_id {job_id}") return + # Guard only SUCCESS state - let FAILURE, REVOKED, RETRY pass through immediately + # SUCCESS should only be set when all stages are actually complete + # This prevents premature SUCCESS when async workers are still processing + if state == JobState.SUCCESS and not job.progress.is_complete(): + job.logger.info( + f"Job {job.pk} task completed but stages not finished - " "deferring SUCCESS status to progress handler" + ) + return + job.update_status(state) diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 8e04d9dd9..6017615eb 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -61,6 +61,93 @@ def test_create_job_with_delay(self): self.assertEqual(job.progress.stages[0].progress, 1) self.assertEqual(job.progress.stages[0].status, JobState.SUCCESS) + def test_job_status_guard_prevents_premature_success(self): + """ + Test that update_job_status guards against setting SUCCESS + when job stages are not complete. + + This tests the fix for race conditions where Celery task completes + but async workers are still processing stages. 
+ """ + from unittest.mock import Mock + + from ami.jobs.tasks import update_job_status + + # Create job with multiple stages + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job with incomplete stages", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + ) + + # Add stages that are NOT complete + job.progress.add_stage("detection") + job.progress.update_stage("detection", progress=0.5, status=JobState.STARTED) + job.progress.add_stage("classification") + job.progress.update_stage("classification", progress=0.0, status=JobState.CREATED) + job.save() + + # Verify stages are incomplete + self.assertFalse(job.progress.is_complete()) + + # Mock task object + mock_task = Mock() + mock_task.request.kwargs = {"job_id": job.pk} + initial_status = job.status + + # Attempt to set SUCCESS while stages are incomplete + update_job_status( + sender=mock_task, + task_id="test-task-id", + task=mock_task, + state=JobState.SUCCESS.value, # Pass string value, not enum + retval=None, + ) + + # Verify job status was NOT updated to SUCCESS (should remain CREATED) + job.refresh_from_db() + self.assertEqual(job.status, initial_status) + self.assertNotEqual(job.status, JobState.SUCCESS.value) + + def test_job_status_allows_failure_states_immediately(self): + """ + Test that FAILURE and REVOKED states bypass the completion guard + and are set immediately regardless of stage completion. + """ + from unittest.mock import Mock + + from ami.jobs.tasks import update_job_status + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test job for failure states", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + ) + + # Add incomplete stage + job.progress.add_stage("detection") + job.progress.update_stage("detection", progress=0.3, status=JobState.STARTED) + job.save() + + mock_task = Mock() + mock_task.request.kwargs = {"job_id": job.pk} + + # Test FAILURE state passes through even with incomplete stages + update_job_status( + sender=mock_task, + task_id="test-task-id", + task=mock_task, + state=JobState.FAILURE.value, # Pass string value, not enum + retval=None, + ) + + job.refresh_from_db() + self.assertEqual(job.status, JobState.FAILURE.value) + class TestJobView(APITestCase): """ From f1cd62d2565384ff1107d12c05e38304468bc538 Mon Sep 17 00:00:00 2001 From: carlosgjs Date: Fri, 6 Feb 2026 19:54:58 -0800 Subject: [PATCH 05/25] PSv2: Implement queue clean-up upon job completion (#1113) * merge * feat: PSv2 - Queue/redis clean-up upon job completion * fix: catch specific exception * chore: move tests to a subdir --------- Co-authored-by: Carlos Garcia Jurado Suarez Co-authored-by: Michael Bunsen --- ami/jobs/tasks.py | 30 +++ ami/ml/orchestration/jobs.py | 44 +++- ami/ml/orchestration/tests/__init__.py | 0 ami/ml/orchestration/tests/test_cleanup.py | 212 ++++++++++++++++++ .../{ => tests}/test_nats_queue.py | 0 5 files changed, 279 insertions(+), 7 deletions(-) create mode 100644 ami/ml/orchestration/tests/__init__.py create mode 100644 ami/ml/orchestration/tests/test_cleanup.py rename ami/ml/orchestration/{ => tests}/test_nats_queue.py (100%) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 07716c5fa..44515cd40 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -193,6 +193,29 @@ def _update_job_progress(job_id: int, stage: str, progress_percentage: float, ** job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%") job.save() + 
# Clean up async resources for completed jobs that use NATS/Redis + # Only ML jobs with async_pipeline_workers enabled use these resources + if stage == "results" and progress_percentage >= 1.0: + job = Job.objects.get(pk=job_id) # Re-fetch outside transaction + _cleanup_job_if_needed(job) + + +def _cleanup_job_if_needed(job) -> None: + """ + Clean up async resources (NATS/Redis) if this job type uses them. + + Only ML jobs with async_pipeline_workers enabled use NATS/Redis resources. + This function is safe to call for any job - it checks if cleanup is needed. + + Args: + job: The Job instance + """ + if job.job_type_key == "ml" and job.project and job.project.feature_flags.async_pipeline_workers: + # import here to avoid circular imports + from ami.ml.orchestration.jobs import cleanup_async_job_resources + + cleanup_async_job_resources(job) + @task_prerun.connect(sender=run_job) def pre_update_job_status(sender, task_id, task, **kwargs): @@ -228,6 +251,10 @@ def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs): job.update_status(state) + # Clean up async resources for revoked jobs + if state == JobState.REVOKED: + _cleanup_job_if_needed(job) + @task_failure.connect(sender=run_job, retry=False) def update_job_failure(sender, task_id, exception, *args, **kwargs): @@ -240,6 +267,9 @@ def update_job_failure(sender, task_id, exception, *args, **kwargs): job.save() + # Clean up async resources for failed jobs + _cleanup_job_if_needed(job) + def log_time(start: float = 0, msg: str | None = None) -> tuple[float, Callable]: """ diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index 621d5f089..9b4a577e9 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -1,28 +1,58 @@ +import logging + from asgiref.sync import async_to_sync -from ami.jobs.models import Job, JobState, logger +from ami.jobs.models import Job, JobState from ami.main.models import SourceImage from ami.ml.orchestration.nats_queue import TaskQueueManager from ami.ml.orchestration.task_state import TaskStateManager from ami.ml.schemas import PipelineProcessingTask +logger = logging.getLogger(__name__) + -# TODO CGJS: (Issue #1083) Call this once a job is fully complete (all images processed and saved) -def cleanup_nats_resources(job: "Job") -> bool: +def cleanup_async_job_resources(job: "Job") -> bool: """ - Clean up NATS JetStream resources (stream and consumer) for a completed job. + Clean up NATS JetStream and Redis resources for a completed job. + + This function cleans up: + 1. Redis state (via TaskStateManager.cleanup): + 2. NATS JetStream resources (via TaskQueueManager.cleanup_job_resources): + + Cleanup failures are logged but don't fail the job - data is already saved. 
Args: job: The Job instance Returns: - bool: True if cleanup was successful, False otherwise + bool: True if both cleanups succeeded, False otherwise """ - + redis_success = False + nats_success = False + + # Cleanup Redis state + try: + state_manager = TaskStateManager(job.pk) + state_manager.cleanup() + job.logger.info(f"Cleaned up Redis state for job {job.pk}") + redis_success = True + except Exception as e: + job.logger.error(f"Error cleaning up Redis state for job {job.pk}: {e}") + + # Cleanup NATS resources async def cleanup(): async with TaskQueueManager() as manager: return await manager.cleanup_job_resources(job.pk) - return async_to_sync(cleanup)() + try: + nats_success = async_to_sync(cleanup)() + if nats_success: + job.logger.info(f"Cleaned up NATS resources for job {job.pk}") + else: + job.logger.warning(f"Failed to clean up NATS resources for job {job.pk}") + except Exception as e: + job.logger.error(f"Error cleaning up NATS resources for job {job.pk}: {e}") + + return redis_success and nats_success def queue_images_to_nats(job: "Job", images: list[SourceImage]): diff --git a/ami/ml/orchestration/tests/__init__.py b/ami/ml/orchestration/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ami/ml/orchestration/tests/test_cleanup.py b/ami/ml/orchestration/tests/test_cleanup.py new file mode 100644 index 000000000..f7b686d44 --- /dev/null +++ b/ami/ml/orchestration/tests/test_cleanup.py @@ -0,0 +1,212 @@ +"""Integration tests for async job resource cleanup (NATS and Redis).""" + +from asgiref.sync import async_to_sync +from django.core.cache import cache +from django.test import TestCase +from nats.js.errors import NotFoundError + +from ami.jobs.models import Job, JobState, MLJob +from ami.jobs.tasks import _update_job_progress, update_job_failure, update_job_status +from ami.main.models import Project, ProjectFeatureFlags, SourceImage, SourceImageCollection +from ami.ml.models import Pipeline +from ami.ml.orchestration.jobs import queue_images_to_nats +from ami.ml.orchestration.nats_queue import TaskQueueManager +from ami.ml.orchestration.task_state import TaskStateManager + + +class TestCleanupAsyncJobResources(TestCase): + """Test cleanup of NATS and Redis resources for async ML jobs.""" + + def setUp(self): + """Set up test fixtures with async_pipeline_workers enabled.""" + # Create project with async_pipeline_workers feature flag enabled + self.project = Project.objects.create( + name="Test Cleanup Project", + feature_flags=ProjectFeatureFlags(async_pipeline_workers=True), + ) + + # Create pipeline + self.pipeline = Pipeline.objects.create( + name="Test Cleanup Pipeline", + slug="test-cleanup-pipeline", + description="Pipeline for cleanup tests", + ) + self.pipeline.projects.add(self.project) + + # Create source image collection with images + self.collection = SourceImageCollection.objects.create( + name="Test Cleanup Collection", + project=self.project, + ) + + # Create test images + self.images = [ + SourceImage.objects.create( + path=f"test_image_{i}.jpg", + public_base_url="https://example.com", + project=self.project, + ) + for i in range(3) + ] + for image in self.images: + self.collection.images.add(image) + + def _verify_resources_created(self, job_id: int): + """ + Verify that both Redis and NATS resources were created. 
+ + Args: + job_id: The job ID to check + """ + # Verify Redis keys exist + state_manager = TaskStateManager(job_id) + for stage in state_manager.STAGES: + pending_key = state_manager._get_pending_key(stage) + self.assertIsNotNone(cache.get(pending_key), f"Redis key {pending_key} should exist") + total_key = state_manager._total_key + self.assertIsNotNone(cache.get(total_key), f"Redis key {total_key} should exist") + + # Verify NATS stream and consumer exist + async def check_nats_resources(): + async with TaskQueueManager() as manager: + stream_name = manager._get_stream_name(job_id) + consumer_name = manager._get_consumer_name(job_id) + + # Try to get stream info - should succeed if created + stream_exists = True + try: + await manager.js.stream_info(stream_name) + except NotFoundError: + stream_exists = False + + # Try to get consumer info - should succeed if created + consumer_exists = True + try: + await manager.js.consumer_info(stream_name, consumer_name) + except NotFoundError: + consumer_exists = False + + return stream_exists, consumer_exists + + stream_exists, consumer_exists = async_to_sync(check_nats_resources)() + + self.assertTrue(stream_exists, f"NATS stream for job {job_id} should exist") + self.assertTrue(consumer_exists, f"NATS consumer for job {job_id} should exist") + + def _create_job_with_queued_images(self) -> Job: + """ + Helper to create an ML job and queue images to NATS/Redis. + + Returns: + Job instance with images queued to NATS and state initialized in Redis + """ + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Test Cleanup Job", + pipeline=self.pipeline, + source_image_collection=self.collection, + ) + + # Queue images to NATS (also initializes Redis state) + queue_images_to_nats(job, self.images) + + # Verify resources were actually created + self._verify_resources_created(job.pk) + + return job + + def _verify_resources_cleaned(self, job_id: int): + """ + Verify that both Redis and NATS resources are cleaned up. 
+ + Args: + job_id: The job ID to check + """ + # Verify Redis keys are deleted + state_manager = TaskStateManager(job_id) + for stage in state_manager.STAGES: + pending_key = state_manager._get_pending_key(stage) + self.assertIsNone(cache.get(pending_key), f"Redis key {pending_key} should be deleted") + total_key = state_manager._total_key + self.assertIsNone(cache.get(total_key), f"Redis key {total_key} should be deleted") + + # Verify NATS stream and consumer are deleted + async def check_nats_resources(): + async with TaskQueueManager() as manager: + stream_name = manager._get_stream_name(job_id) + consumer_name = manager._get_consumer_name(job_id) + + # Try to get stream info - should fail if deleted + stream_exists = True + try: + await manager.js.stream_info(stream_name) + except NotFoundError: + stream_exists = False + + # Try to get consumer info - should fail if deleted + consumer_exists = True + try: + await manager.js.consumer_info(stream_name, consumer_name) + except NotFoundError: + consumer_exists = False + + return stream_exists, consumer_exists + + stream_exists, consumer_exists = async_to_sync(check_nats_resources)() + + self.assertFalse(stream_exists, f"NATS stream for job {job_id} should be deleted") + self.assertFalse(consumer_exists, f"NATS consumer for job {job_id} should be deleted") + + def test_cleanup_on_job_completion(self): + """Test that resources are cleaned up when job completes successfully.""" + job = self._create_job_with_queued_images() + + # Simulate job completion by updating progress to 100% in results stage + _update_job_progress(job.pk, stage="results", progress_percentage=1.0) + + # Verify cleanup happened + self._verify_resources_cleaned(job.pk) + + def test_cleanup_on_job_failure(self): + """Test that resources are cleaned up when job fails.""" + job = self._create_job_with_queued_images() + + # Set task_id so the failure handler can find the job + job.task_id = "test-task-failure-123" + job.save() + + # Simulate job failure by calling the failure signal handler + update_job_failure( + sender=None, + task_id=job.task_id, + exception=Exception("Test failure"), + ) + + # Verify cleanup happened + self._verify_resources_cleaned(job.pk) + + def test_cleanup_on_job_revoked(self): + """Test that resources are cleaned up when job is revoked/cancelled.""" + job = self._create_job_with_queued_images() + + # Create a mock task request object for the signal handler + class MockRequest: + def __init__(self): + self.kwargs = {"job_id": job.pk} + + class MockTask: + def __init__(self, job_id): + self.request = MockRequest() + self.request.kwargs["job_id"] = job_id + + # Simulate job revocation by calling the postrun signal handler with REVOKED state + update_job_status( + sender=None, + task_id="test-task-revoked-456", + task=MockTask(job.pk), + state=JobState.REVOKED, + ) + + # Verify cleanup happened + self._verify_resources_cleaned(job.pk) diff --git a/ami/ml/orchestration/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py similarity index 100% rename from ami/ml/orchestration/test_nats_queue.py rename to ami/ml/orchestration/tests/test_nats_queue.py From 74df9ea35f0b32feb06badfc473ebbc4382196ca Mon Sep 17 00:00:00 2001 From: carlosgjs Date: Mon, 9 Feb 2026 15:32:04 -0800 Subject: [PATCH 06/25] fix: PSv2: Workers should not try to fetch tasks from v1 jobs (#1118) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces the dispatch_mode field on the Job model to track how each job dispatches its 
workload. This allows API clients (including the AMI worker) to filter jobs by dispatch mode — for example, fetching only async_api jobs so workers don't pull synchronous or internal jobs. JobDispatchMode enum (ami/jobs/models.py): internal — work handled entirely within the platform (Celery worker, no external calls). Default for all jobs. sync_api — worker calls an external processing service API synchronously and waits for each response. async_api — worker publishes items to NATS for external processing service workers to pick up independently. Database and Model Changes: Added dispatch_mode CharField with TextChoices, defaulting to internal, with the migration in ami/jobs/migrations/0019_job_dispatch_mode.py. ML jobs set dispatch_mode = async_api when the project's async_pipeline_workers feature flag is enabled. ML jobs set dispatch_mode = sync_api on the synchronous processing path (previously unset). API and Filtering: dispatch_mode is exposed (read-only) in job list and detail serializers. Filterable via query parameter: ?dispatch_mode=async_api The /tasks endpoint now returns 400 for non-async_api jobs, since only those have NATS tasks to fetch. Architecture doc: docs/claude/job-dispatch-modes.md documents the three modes, naming decisions, and per-job-type mapping. --------- Co-authored-by: Carlos Garcia Jurado Suarez Co-authored-by: Michael Bunsen Co-authored-by: Claude --- .dockerignore | 10 ++ ami/jobs/migrations/0019_job_dispatch_mode.py | 22 +++ ami/jobs/models.py | 36 +++++ ami/jobs/serializers.py | 2 + ami/jobs/tests.py | 132 +++++++++++++++++- ami/jobs/views.py | 7 +- docs/claude/job-dispatch-modes.md | 69 +++++++++ 7 files changed, 273 insertions(+), 5 deletions(-) create mode 100644 ami/jobs/migrations/0019_job_dispatch_mode.py create mode 100644 docs/claude/job-dispatch-modes.md diff --git a/.dockerignore b/.dockerignore index f0d718f04..95d6ac666 100644 --- a/.dockerignore +++ b/.dockerignore @@ -27,6 +27,10 @@ __pycache__/ *.egg *.whl +# IPython / Jupyter +.ipython/ +.jupyter/ + # Django / runtime artifacts *.log @@ -55,6 +59,12 @@ yarn-error.log .vscode/ *.iml +# Development / testing +.pytest_cache/ +.coverage +.tox/ +.cache/ + # OS cruft .DS_Store Thumbs.db diff --git a/ami/jobs/migrations/0019_job_dispatch_mode.py b/ami/jobs/migrations/0019_job_dispatch_mode.py new file mode 100644 index 000000000..1134c0144 --- /dev/null +++ b/ami/jobs/migrations/0019_job_dispatch_mode.py @@ -0,0 +1,22 @@ +# Generated by Django 4.2.10 on 2026-02-04 20:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("jobs", "0018_alter_job_job_type_key"), + ] + + operations = [ + migrations.AddField( + model_name="job", + name="dispatch_mode", + field=models.CharField( + choices=[("internal", "Internal"), ("sync_api", "Sync API"), ("async_api", "Async API")], + default="internal", + help_text="How the job dispatches its workload: internal, sync_api, or async_api.", + max_length=32, + ), + ), + ] diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 9f8aa197f..8e790cdcb 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -24,6 +24,32 @@ logger = logging.getLogger(__name__) +class JobDispatchMode(models.TextChoices): + """ + How a job dispatches its workload. + + Jobs are configured and launched by users in the UI, then dispatched to + Celery workers. This enum describes what the worker does with the work: + + - INTERNAL: All work happens within the platform (Celery worker handles it directly). 
+ - SYNC_API: Worker calls an external processing service API and waits for each response. + - ASYNC_API: Worker queues items to a message broker (NATS) for external processing + service workers to pick up and process independently. + """ + + # Work is handled entirely within the platform, no external service calls. + # e.g. DataStorageSyncJob, DataExportJob, SourceImageCollectionPopulateJob + INTERNAL = "internal", "Internal" + + # Worker loops over items, sends each to an external processing service + # endpoint synchronously, and waits for the response before continuing. + SYNC_API = "sync_api", "Sync API" + + # Worker publishes all items to a message broker (NATS). External processing + # service workers consume and process them independently, reporting results back. + ASYNC_API = "async_api", "Async API" + + class JobState(str, OrderedEnum): """ These come from Celery, except for CREATED, which is a custom state. @@ -422,6 +448,8 @@ def run(cls, job: "Job"): job.save() if job.project.feature_flags.async_pipeline_workers: + job.dispatch_mode = JobDispatchMode.ASYNC_API + job.save(update_fields=["dispatch_mode"]) queued = queue_images_to_nats(job, images) if not queued: job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk) @@ -431,6 +459,8 @@ def run(cls, job: "Job"): job.save() return else: + job.dispatch_mode = JobDispatchMode.SYNC_API + job.save(update_fields=["dispatch_mode"]) cls.process_images(job, images) @classmethod @@ -822,6 +852,12 @@ class Job(BaseModel): blank=True, related_name="jobs", ) + dispatch_mode = models.CharField( + max_length=32, + choices=JobDispatchMode.choices, + default=JobDispatchMode.INTERNAL, + help_text="How the job dispatches its workload: internal, sync_api, or async_api.", + ) def __str__(self) -> str: return f'#{self.pk} "{self.name}" ({self.status})' diff --git a/ami/jobs/serializers.py b/ami/jobs/serializers.py index 254242046..7a3471003 100644 --- a/ami/jobs/serializers.py +++ b/ami/jobs/serializers.py @@ -127,6 +127,7 @@ class Meta: "job_type", "job_type_key", "data_export", + "dispatch_mode", # "duration", # "duration_label", # "progress_label", @@ -141,6 +142,7 @@ class Meta: "started_at", "finished_at", "duration", + "dispatch_mode", ] diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 6017615eb..7902faeb1 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -8,7 +8,7 @@ from rest_framework.test import APIRequestFactory, APITestCase from ami.base.serializers import reverse_with_params -from ami.jobs.models import Job, JobProgress, JobState, MLJob, SourceImageCollectionPopulateJob +from ami.jobs.models import Job, JobDispatchMode, JobProgress, JobState, MLJob, SourceImageCollectionPopulateJob from ami.main.models import Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline from ami.ml.orchestration.jobs import queue_images_to_nats @@ -414,6 +414,8 @@ def test_search_jobs(self): def _task_batch_helper(self, value: Any, expected_status: int): pipeline = self._create_pipeline() job = self._create_ml_job("Job for batch test", pipeline) + job.dispatch_mode = JobDispatchMode.ASYNC_API + job.save(update_fields=["dispatch_mode"]) images = [ SourceImage.objects.create( path=f"image_{i}.jpg", @@ -447,12 +449,19 @@ def test_tasks_endpoint_with_invalid_batch(self): def test_tasks_endpoint_without_pipeline(self): """Test the tasks endpoint returns error when job has no pipeline.""" - # Use the existing job which doesn't have a pipeline - job_data = self._create_job("Job without pipeline", 
start_now=False) + # Create a job without a pipeline but with async_api dispatch mode + # so the dispatch_mode guard passes and the pipeline check is reached + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Job without pipeline", + source_image_collection=self.source_image_collection, + dispatch_mode=JobDispatchMode.ASYNC_API, + ) self.client.force_authenticate(user=self.user) tasks_url = reverse_with_params( - "api:job-tasks", args=[job_data["id"]], params={"project_id": self.project.pk, "batch": 1} + "api:job-tasks", args=[job.pk], params={"project_id": self.project.pk, "batch": 1} ) resp = self.client.get(tasks_url) @@ -510,3 +519,118 @@ def test_result_endpoint_validation(self): resp = self.client.post(result_url, invalid_data, format="json") self.assertEqual(resp.status_code, 400) self.assertIn("result", resp.json()[0].lower()) + + +class TestJobDispatchModeFiltering(APITestCase): + """Test job filtering by dispatch_mode.""" + + def setUp(self): + self.user = User.objects.create_user( # type: ignore + email="testuser-backend@insectai.org", + is_staff=True, + is_active=True, + is_superuser=True, + ) + self.project = Project.objects.create(name="Test Backend Project") + + # Create pipeline for ML jobs + self.pipeline = Pipeline.objects.create( + name="Test ML Pipeline", + slug="test-ml-pipeline", + description="Test ML pipeline for dispatch_mode filtering", + ) + self.pipeline.projects.add(self.project) + + # Create source image collection for jobs + self.source_image_collection = SourceImageCollection.objects.create( + name="Test Collection", + project=self.project, + ) + + # Give the user necessary permissions + assign_perm(Project.Permissions.VIEW_PROJECT, self.user, self.project) + + def test_dispatch_mode_filtering(self): + """Test that jobs can be filtered by dispatch_mode parameter.""" + # Create two ML jobs with different dispatch modes + sync_job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Sync API Job", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + dispatch_mode=JobDispatchMode.SYNC_API, + ) + + async_job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Async API Job", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + dispatch_mode=JobDispatchMode.ASYNC_API, + ) + + # Create a job with default dispatch_mode (should be "internal") + internal_job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Internal Job", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + ) + + self.client.force_authenticate(user=self.user) + jobs_list_url = reverse_with_params("api:job-list", params={"project_id": self.project.pk}) + + # Test filtering by sync_api dispatch_mode + resp = self.client.get(jobs_list_url, {"dispatch_mode": JobDispatchMode.SYNC_API}) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(data["count"], 1) + self.assertEqual(data["results"][0]["id"], sync_job.pk) + self.assertEqual(data["results"][0]["dispatch_mode"], JobDispatchMode.SYNC_API) + + # Test filtering by async_api dispatch_mode + resp = self.client.get(jobs_list_url, {"dispatch_mode": JobDispatchMode.ASYNC_API}) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(data["count"], 1) + self.assertEqual(data["results"][0]["id"], async_job.pk) + self.assertEqual(data["results"][0]["dispatch_mode"], JobDispatchMode.ASYNC_API) + + # 
Test filtering by invalid dispatch_mode (should return 400 due to choices validation) + resp = self.client.get(jobs_list_url, {"dispatch_mode": "non_existent_mode"}) + self.assertEqual(resp.status_code, 400) + + # Test without dispatch_mode filter (should return all jobs) + resp = self.client.get(jobs_list_url) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(data["count"], 3) # All three jobs + + # Verify the job IDs returned include all jobs + returned_ids = {job["id"] for job in data["results"]} + expected_ids = {sync_job.pk, async_job.pk, internal_job.pk} + self.assertEqual(returned_ids, expected_ids) + + def test_tasks_endpoint_rejects_non_async_jobs(self): + """Test that /tasks endpoint returns 400 for non-async_api jobs.""" + from ami.base.serializers import reverse_with_params + + sync_job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Sync Job for tasks test", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + dispatch_mode=JobDispatchMode.SYNC_API, + ) + + self.client.force_authenticate(user=self.user) + tasks_url = reverse_with_params( + "api:job-tasks", args=[sync_job.pk], params={"project_id": self.project.pk, "batch": 1} + ) + resp = self.client.get(tasks_url) + self.assertEqual(resp.status_code, 400) + self.assertIn("async_api", resp.json()[0].lower()) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index a2f087cf5..dd8da01b2 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -22,7 +22,7 @@ from ami.ml.schemas import PipelineTaskResult from ami.utils.fields import url_boolean_param -from .models import Job, JobState +from .models import Job, JobDispatchMode, JobState from .serializers import JobListSerializer, JobSerializer, MinimalJobSerializer logger = logging.getLogger(__name__) @@ -43,6 +43,7 @@ class Meta: "source_image_single", "pipeline", "job_type_key", + "dispatch_mode", ] @@ -228,6 +229,10 @@ def tasks(self, request, pk=None): except Exception as e: raise ValidationError({"batch": str(e)}) from e + # Only async_api jobs have tasks fetchable from NATS + if job.dispatch_mode != JobDispatchMode.ASYNC_API: + raise ValidationError("Only async_api jobs have fetchable tasks") + # Validate that the job has a pipeline if not job.pipeline: raise ValidationError("This job does not have a pipeline configured") diff --git a/docs/claude/job-dispatch-modes.md b/docs/claude/job-dispatch-modes.md new file mode 100644 index 000000000..26664be0b --- /dev/null +++ b/docs/claude/job-dispatch-modes.md @@ -0,0 +1,69 @@ +# Job Dispatch Modes + +## Overview + +The `Job` model is a user-facing CMS feature. Users configure jobs in the UI (selecting images, a processing pipeline, etc.), click start, and watch progress. The actual work is dispatched to background workers. The `dispatch_mode` field on `Job` describes *how* that work gets dispatched. + +## The Three Dispatch Modes + +### `internal` + +All work happens within the platform itself. A Celery worker picks up the job and handles it directly — no external service calls. + +**Job types using this mode:** +- `DataStorageSyncJob` — syncs files from S3 storage +- `SourceImageCollectionPopulateJob` — queries DB, populates a capture collection +- `DataExportJob` — generates export files +- `PostProcessingJob` — runs post-processing tasks + +### `sync_api` + +The Celery worker calls an external processing service API synchronously. It loops over items (e.g. 
batches of images), sends each batch to the processing service endpoint, waits for the response, saves results, and moves on. + +**Job types using this mode:** +- `MLJob` (default path) + +### `async_api` + +The Celery worker publishes all items to a message broker (NATS). External processing service workers consume items independently and report results back. The job monitors progress and completes when all items are processed. + +**Job types using this mode:** +- `MLJob` (when `project.feature_flags.async_pipeline_workers` is enabled) + +## Architecture Context + +``` +User (UI) + │ + ▼ +Job (Django model) ─── dispatch_mode: internal | sync_api | async_api + │ + ▼ +Celery Worker + │ + ├── internal ──────► Work done directly (DB queries, file ops, exports) + │ + ├── sync_api ──────► HTTP calls to Processing Service API (request/response loop) + │ + └── async_api ─────► Publish to NATS ──► External Processing Service workers + │ + ▼ + Results reported back +``` + +## Naming Decisions + +- **Why not `backend`?** Collides with Celery's "result backend" concept and the "ML backend" term used for processing services throughout the codebase. +- **Why not `task_backend`?** "Task backend" is specifically a Celery concept (where task results are stored). +- **Why not `local`?** Ambiguous with local development environments. +- **Why `internal`?** Clean contrast with the two external API modes. "Internal" means the work stays within the platform; `sync_api` and `async_api` both involve external processing services. +- **Why `dispatch_mode`?** The field describes *how* the Celery worker dispatches work to processing services, not how the job itself executes (all jobs execute via Celery). "Dispatch" is more precise than "execution" which is ambiguous. + +## Code Locations + +- Enum: `ami/jobs/models.py` — `JobDispatchMode` +- Field: `ami/jobs/models.py` — `Job.dispatch_mode` +- Serializer: `ami/jobs/serializers.py` — exposed in `JobListSerializer` and read-only +- API filter: `ami/jobs/views.py` — filterable via `?dispatch_mode=sync_api` +- Migration: `ami/jobs/migrations/0019_job_dispatch_mode.py` +- Tests: `ami/jobs/tests.py` — `TestJobDispatchModeFiltering` From 4a082d347fc03910fd62a558a1512520f99e3bcd Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 9 Feb 2026 16:15:57 -0800 Subject: [PATCH 07/25] PSv2 cleanup: use is_complete() and dispatch_mode in job progress handler (#1125) * refactor: use is_complete() and dispatch_mode in job progress handler Replace hardcoded `stage == "results"` check with `job.progress.is_complete()` which verifies ALL stages are done, making it work for any job type. Replace feature flag check in cleanup with `dispatch_mode == ASYNC_API` which is immutable for the job's lifetime and more correct than re-reading a mutable flag that could change between job creation and completion. Co-Authored-By: Claude * test: update cleanup tests for is_complete() and dispatch_mode checks Set dispatch_mode=ASYNC_API on test jobs to match the new cleanup guard. Complete all stages (collect, process, results) in the completion test since is_complete() correctly requires all stages to be done. 
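In effect, completion is now derived from the stages themselves rather than from a hardcoded
stage name. A minimal sketch of the idea (illustrative only — not the actual `JobProgress`
implementation in `ami/jobs/models.py`, whose stage objects and field names may differ):

```python
from dataclasses import dataclass, field


@dataclass
class StageProgress:
    key: str
    progress: float = 0.0  # fraction complete, 0.0-1.0


@dataclass
class ProgressSketch:
    stages: list[StageProgress] = field(default_factory=list)

    def is_complete(self) -> bool:
        # Complete only when every stage reports 100%, regardless of stage names,
        # so the same check works for any job type.
        return all(s.progress >= 1.0 for s in self.stages)


# "results" at 100% alone is no longer enough; all stages must be done.
stages = [StageProgress("collect", 1.0), StageProgress("process", 0.5), StageProgress("results", 1.0)]
assert not ProgressSketch(stages).is_complete()
```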
Co-Authored-By: Claude --------- Co-authored-by: Claude --- ami/jobs/tasks.py | 13 +++++++------ ami/ml/orchestration/tests/test_cleanup.py | 7 +++++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 44515cd40..c928fc12d 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -186,7 +186,7 @@ def _update_job_progress(job_id: int, stage: str, progress_percentage: float, ** progress=progress_percentage, **state_params, ) - if stage == "results" and progress_percentage >= 1.0: + if job.progress.is_complete(): job.status = JobState.SUCCESS job.progress.summary.status = JobState.SUCCESS job.finished_at = datetime.datetime.now() # Use naive datetime in local time @@ -194,23 +194,24 @@ def _update_job_progress(job_id: int, stage: str, progress_percentage: float, ** job.save() # Clean up async resources for completed jobs that use NATS/Redis - # Only ML jobs with async_pipeline_workers enabled use these resources - if stage == "results" and progress_percentage >= 1.0: + if job.progress.is_complete(): job = Job.objects.get(pk=job_id) # Re-fetch outside transaction _cleanup_job_if_needed(job) def _cleanup_job_if_needed(job) -> None: """ - Clean up async resources (NATS/Redis) if this job type uses them. + Clean up async resources (NATS/Redis) if this job uses them. - Only ML jobs with async_pipeline_workers enabled use NATS/Redis resources. + Only jobs with ASYNC_API dispatch mode use NATS/Redis resources. This function is safe to call for any job - it checks if cleanup is needed. Args: job: The Job instance """ - if job.job_type_key == "ml" and job.project and job.project.feature_flags.async_pipeline_workers: + from ami.jobs.models import JobDispatchMode + + if job.dispatch_mode == JobDispatchMode.ASYNC_API: # import here to avoid circular imports from ami.ml.orchestration.jobs import cleanup_async_job_resources diff --git a/ami/ml/orchestration/tests/test_cleanup.py b/ami/ml/orchestration/tests/test_cleanup.py index f7b686d44..ef8382d3d 100644 --- a/ami/ml/orchestration/tests/test_cleanup.py +++ b/ami/ml/orchestration/tests/test_cleanup.py @@ -5,7 +5,7 @@ from django.test import TestCase from nats.js.errors import NotFoundError -from ami.jobs.models import Job, JobState, MLJob +from ami.jobs.models import Job, JobDispatchMode, JobState, MLJob from ami.jobs.tasks import _update_job_progress, update_job_failure, update_job_status from ami.main.models import Project, ProjectFeatureFlags, SourceImage, SourceImageCollection from ami.ml.models import Pipeline @@ -106,6 +106,7 @@ def _create_job_with_queued_images(self) -> Job: name="Test Cleanup Job", pipeline=self.pipeline, source_image_collection=self.collection, + dispatch_mode=JobDispatchMode.ASYNC_API, ) # Queue images to NATS (also initializes Redis state) @@ -162,7 +163,9 @@ def test_cleanup_on_job_completion(self): """Test that resources are cleaned up when job completes successfully.""" job = self._create_job_with_queued_images() - # Simulate job completion by updating progress to 100% in results stage + # Simulate job completion: complete all stages (collect, process, then results) + _update_job_progress(job.pk, stage="collect", progress_percentage=1.0) + _update_job_progress(job.pk, stage="process", progress_percentage=1.0) _update_job_progress(job.pk, stage="results", progress_percentage=1.0) # Verify cleanup happened From e43536bb9243cbd65c3569d82411808f42c2a7ac Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Tue, 10 Feb 2026 16:50:47 -0800 Subject: [PATCH 08/25] 
track captures and failures --- ami/jobs/tasks.py | 47 +++++++++++++++++-------- ami/ml/orchestration/task_state.py | 49 +++++++++++++++++++++++--- ami/ml/tests.py | 55 ++++++++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 18 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index c928fc12d..f212c7116 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -3,6 +3,7 @@ import logging import time from collections.abc import Callable +from typing import Any from asgiref.sync import async_to_sync from celery.signals import task_failure, task_postrun, task_prerun @@ -59,7 +60,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub result_data: Dictionary containing the pipeline result reply_subject: NATS reply subject for acknowledgment """ - from ami.jobs.models import Job # avoid circular import + from ami.jobs.models import Job, JobState # avoid circular import _, t = log_time() @@ -67,15 +68,19 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub if "error" in result_data: error_result = PipelineResultsError(**result_data) processed_image_ids = {str(error_result.image_id)} if error_result.image_id else set() + failed_image_ids = processed_image_ids # Same as processed for errors logger.error(f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}") pipeline_result = None else: pipeline_result = PipelineResultsResponse(**result_data) processed_image_ids = {str(img.id) for img in pipeline_result.source_images} + failed_image_ids = set() # No failures for successful results state_manager = TaskStateManager(job_id) - progress_info = state_manager.update_state(processed_image_ids, stage="process", request_id=self.request.id) + progress_info = state_manager.update_state( + processed_image_ids, stage="process", request_id=self.request.id, failed_image_ids=failed_image_ids + ) if not progress_info: logger.warning( f"Another task is already processing results for job {job_id}. 
" @@ -84,12 +89,18 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub raise self.retry(countdown=5, max_retries=10) try: + FAILURE_THRESHOLD = 0.5 + complete_state = JobState.SUCCESS + if (progress_info.failed / progress_info.total) >= FAILURE_THRESHOLD: + complete_state = JobState.FAILURE _update_job_progress( job_id, "process", progress_info.percentage, + complete_state=complete_state, processed=progress_info.processed, remaining=progress_info.remaining, + failed=progress_info.failed, ) _, t = t(f"TIME: Updated job {job_id} progress in PROCESS stage progress to {progress_info.percentage*100}%") @@ -97,8 +108,8 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}") job.logger.info( f" Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed " - f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {len(processed_image_ids)} just " - "processed" + f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {progress_info.failed} failed, " + f"{len(processed_image_ids)} just processed" ) except Job.DoesNotExist: # don't raise and ack so that we don't retry since the job doesn't exists @@ -108,6 +119,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub try: # Save to database (this is the slow operation) + detections_count, classifications_count, captures_count = 0, 0, 0 if pipeline_result: # should never happen since otherwise we could not be processing results here assert job.pipeline is not None, "Job pipeline is None" @@ -118,22 +130,25 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub f"Saved pipeline results to database with {len(pipeline_result.detections)} detections" f", percentage: {progress_info.percentage*100}%" ) + # Calculate detection and classification counts from this result + detections_count = len(pipeline_result.detections) if pipeline_result else 0 + classifications_count = ( + sum(len(detection.classifications) for detection in pipeline_result.detections) + if pipeline_result + else 0 + ) + captures_count = len(pipeline_result.source_images) _ack_task_via_nats(reply_subject, job.logger) # Update job stage with calculated progress - # Calculate detection and classification counts from this result - detections_count = len(pipeline_result.detections) if pipeline_result else 0 - classifications_count = ( - sum(len(detection.classifications) for detection in pipeline_result.detections) if pipeline_result else 0 - ) - progress_info = state_manager.update_state( processed_image_ids, stage="results", request_id=self.request.id, detections_count=detections_count, classifications_count=classifications_count, + captures_count=captures_count, ) if not progress_info: @@ -147,8 +162,10 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub job_id, "results", progress_info.percentage, + complete_state=complete_state, detections=progress_info.detections, classifications=progress_info.classifications, + captures=progress_info.captures, ) except Exception as e: @@ -175,20 +192,22 @@ async def ack_task(): # Don't fail the task if ACK fails - data is already saved -def _update_job_progress(job_id: int, stage: str, progress_percentage: float, **state_params) -> None: +def _update_job_progress( + job_id: int, stage: str, progress_percentage: float, complete_state: Any, 
**state_params +) -> None: from ami.jobs.models import Job, JobState # avoid circular import with transaction.atomic(): job = Job.objects.select_for_update().get(pk=job_id) job.progress.update_stage( stage, - status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, + status=complete_state if progress_percentage >= 1.0 else JobState.STARTED, progress=progress_percentage, **state_params, ) if job.progress.is_complete(): - job.status = JobState.SUCCESS - job.progress.summary.status = JobState.SUCCESS + job.status = complete_state + job.progress.summary.status = complete_state job.finished_at = datetime.datetime.now() # Use naive datetime in local time job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%") job.save() diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/task_state.py index 27c4bec07..ac80040f8 100644 --- a/ami/ml/orchestration/task_state.py +++ b/ami/ml/orchestration/task_state.py @@ -12,7 +12,8 @@ # Define a namedtuple for a TaskProgress with the image counts TaskProgress = namedtuple( - "TaskProgress", ["remaining", "total", "processed", "percentage", "detections", "classifications"] + "TaskProgress", + ["remaining", "total", "processed", "percentage", "detections", "classifications", "captures", "failed"], ) @@ -37,8 +38,10 @@ def __init__(self, job_id: int): self.job_id = job_id self._pending_key = f"job:{job_id}:pending_images" self._total_key = f"job:{job_id}:pending_images_total" + self._failed_key = f"job:{job_id}:failed_images" self._detections_key = f"job:{job_id}:total_detections" self._classifications_key = f"job:{job_id}:total_classifications" + self._captures_key = f"job:{job_id}:total_captures" def initialize_job(self, image_ids: list[str]) -> None: """ @@ -50,11 +53,15 @@ def initialize_job(self, image_ids: list[str]) -> None: for stage in self.STAGES: cache.set(self._get_pending_key(stage), image_ids, timeout=self.TIMEOUT) + # Initialize failed images set for process stage only + cache.set(self._failed_key, set(), timeout=self.TIMEOUT) + cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) # Initialize detection and classification counters cache.set(self._detections_key, 0, timeout=self.TIMEOUT) cache.set(self._classifications_key, 0, timeout=self.TIMEOUT) + cache.set(self._captures_key, 0, timeout=self.TIMEOUT) def _get_pending_key(self, stage: str) -> str: return f"{self._pending_key}:{stage}" @@ -66,6 +73,8 @@ def update_state( request_id: str, detections_count: int = 0, classifications_count: int = 0, + captures_count: int = 0, + failed_image_ids: set[str] | None = None, ) -> None | TaskProgress: """ Update the task state with newly processed images. 
@@ -76,6 +85,8 @@ def update_state( request_id: Unique identifier for this processing request detections_count: Number of detections to add to cumulative count classifications_count: Number of classifications to add to cumulative count + captures_count: Number of captures to add to cumulative count + failed_image_ids: Set of image IDs that failed processing (optional) """ # Create a unique lock key for this job lock_key = f"job:{self.job_id}:process_results_lock" @@ -86,7 +97,9 @@ def update_state( try: # Update progress tracking in Redis - progress_info = self._get_progress(processed_image_ids, stage, detections_count, classifications_count) + progress_info = self._get_progress( + processed_image_ids, stage, detections_count, classifications_count, captures_count, failed_image_ids + ) return progress_info finally: # Always release the lock when done @@ -97,7 +110,13 @@ def update_state( logger.debug(f"Released lock for job {self.job_id}, task {request_id}") def _get_progress( - self, processed_image_ids: set[str], stage: str, detections_count: int = 0, classifications_count: int = 0 + self, + processed_image_ids: set[str], + stage: str, + detections_count: int = 0, + classifications_count: int = 0, + captures_count: int = 0, + failed_image_ids: set[str] | None = None, ) -> TaskProgress | None: """ Get current progress information for the job. @@ -108,6 +127,10 @@ def _get_progress( - total: Total number of images (or None if not tracked) - processed: Number of images processed (or None if not tracked) - percentage: Progress as float 0.0-1.0 (or None if not tracked) + - detections: Cumulative count of detections + - classifications: Cumulative count of classifications + - captures: Cumulative count of captures + - failed: Number of unique failed images """ pending_images = cache.get(self._get_pending_key(stage)) total_images = cache.get(self._total_key) @@ -121,15 +144,29 @@ def _get_progress( processed = total_images - remaining percentage = float(processed) / total_images if total_images > 0 else 1.0 - # Update cumulative detection and classification counts + # Update cumulative detection, classification, and capture counts current_detections = cache.get(self._detections_key, 0) current_classifications = cache.get(self._classifications_key, 0) + current_captures = cache.get(self._captures_key, 0) new_detections = current_detections + detections_count new_classifications = current_classifications + classifications_count + new_captures = current_captures + captures_count cache.set(self._detections_key, new_detections, timeout=self.TIMEOUT) cache.set(self._classifications_key, new_classifications, timeout=self.TIMEOUT) + cache.set(self._captures_key, new_captures, timeout=self.TIMEOUT) + + # Update failed images set if provided + if failed_image_ids: + existing_failed = cache.get(self._failed_key) or set() + updated_failed = existing_failed | failed_image_ids # Union to prevent duplicates + cache.set(self._failed_key, updated_failed, timeout=self.TIMEOUT) + failed_set = updated_failed + else: + failed_set = cache.get(self._failed_key) or set() + + failed_count = len(failed_set) logger.info( f"Pending images from Redis for job {self.job_id} {stage}: " @@ -143,6 +180,8 @@ def _get_progress( percentage=percentage, detections=new_detections, classifications=new_classifications, + captures=new_captures, + failed=failed_count, ) def cleanup(self) -> None: @@ -151,6 +190,8 @@ def cleanup(self) -> None: """ for stage in self.STAGES: cache.delete(self._get_pending_key(stage)) + 
cache.delete(self._failed_key) cache.delete(self._total_key) cache.delete(self._detections_key) cache.delete(self._classifications_key) + cache.delete(self._captures_key) diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 538106c19..9e06536a6 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -882,6 +882,7 @@ def _init_and_verify(self, image_ids): self.assertEqual(progress.percentage, 0.0) self.assertEqual(progress.detections, 0) self.assertEqual(progress.classifications, 0) + self.assertEqual(progress.failed, 0) return progress def test_initialize_job(self): @@ -895,6 +896,7 @@ def test_initialize_job(self): self.assertEqual(progress.total, len(self.image_ids)) self.assertEqual(progress.detections, 0) self.assertEqual(progress.classifications, 0) + self.assertEqual(progress.failed, 0) def test_progress_tracking(self): """Test progress updates correctly as images are processed.""" @@ -1091,3 +1093,56 @@ def test_cleanup_removes_count_keys(self): classifications = cache.get(self.manager._classifications_key) self.assertIsNone(detections) self.assertIsNone(classifications) + + def test_failed_image_tracking(self): + """Test basic failed image tracking with no double-counting on retries.""" + self._init_and_verify(self.image_ids) + + # Mark 2 images as failed in process stage + progress = self.manager._get_progress({"img1", "img2"}, "process", failed_image_ids={"img1", "img2"}) + assert progress is not None + self.assertEqual(progress.failed, 2) + + # Retry same 2 images (fail again) - should not double-count + progress = self.manager._get_progress(set(), "process", failed_image_ids={"img1", "img2"}) + assert progress is not None + self.assertEqual(progress.failed, 2) + + # Fail a different image + progress = self.manager._get_progress(set(), "process", failed_image_ids={"img3"}) + assert progress is not None + self.assertEqual(progress.failed, 3) + + def test_failed_and_processed_mixed(self): + """Test mixed successful and failed processing in same batch.""" + self._init_and_verify(self.image_ids) + + # Process 2 successfully, 2 fail, 1 remains pending + progress = self.manager._get_progress( + {"img1", "img2", "img3", "img4"}, "process", failed_image_ids={"img3", "img4"} + ) + assert progress is not None + self.assertEqual(progress.processed, 4) + self.assertEqual(progress.failed, 2) + self.assertEqual(progress.remaining, 1) + self.assertEqual(progress.percentage, 0.8) + + def test_cleanup_removes_failed_set(self): + """Test that cleanup removes failed image set.""" + from django.core.cache import cache + + self._init_and_verify(self.image_ids) + + # Add failed images + self.manager._get_progress({"img1", "img2"}, "process", failed_image_ids={"img1", "img2"}) + + # Verify failed set exists + failed_set = cache.get(self.manager._failed_key) + self.assertEqual(len(failed_set), 2) + + # Cleanup + self.manager.cleanup() + + # Verify failed set is gone + failed_set = cache.get(self.manager._failed_key) + self.assertIsNone(failed_set) From 50df5f68d5922fe6f6a884559678ea1fe6e36aa7 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 11 Feb 2026 10:53:00 -0800 Subject: [PATCH 09/25] Update tests, CR feedback, log error images --- ami/jobs/tasks.py | 15 ++++-- ami/ml/orchestration/tests/test_cleanup.py | 6 +-- ami/ml/tests.py | 58 ++++++---------------- 3 files changed, 31 insertions(+), 48 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index f212c7116..37047830a 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -16,6 +16,7 @@ from config import 
celery_app logger = logging.getLogger(__name__) +FAILURE_THRESHOLD = 0.5 # threshold for marking a job as failed based on the percentage of failed images. @celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit) @@ -65,11 +66,11 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub _, t = log_time() # Validate with Pydantic - check for error response first + error_result = None if "error" in result_data: error_result = PipelineResultsError(**result_data) processed_image_ids = {str(error_result.image_id)} if error_result.image_id else set() failed_image_ids = processed_image_ids # Same as processed for errors - logger.error(f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}") pipeline_result = None else: pipeline_result = PipelineResultsResponse(**result_data) @@ -89,9 +90,8 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub raise self.retry(countdown=5, max_retries=10) try: - FAILURE_THRESHOLD = 0.5 complete_state = JobState.SUCCESS - if (progress_info.failed / progress_info.total) >= FAILURE_THRESHOLD: + if progress_info.total > 0 and (progress_info.failed / progress_info.total) >= FAILURE_THRESHOLD: complete_state = JobState.FAILURE _update_job_progress( job_id, @@ -111,6 +111,10 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {progress_info.failed} failed, " f"{len(processed_image_ids)} just processed" ) + if error_result: + job.logger.error( + f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}" + ) except Job.DoesNotExist: # don't raise and ack so that we don't retry since the job doesn't exists logger.error(f"Job {job_id} not found") @@ -158,6 +162,11 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub ) raise self.retry(countdown=5, max_retries=10) + # update complete state based on latest progress info after saving results + complete_state = JobState.SUCCESS + if progress_info.total > 0 and (progress_info.failed / progress_info.total) >= FAILURE_THRESHOLD: + complete_state = JobState.FAILURE + _update_job_progress( job_id, "results", diff --git a/ami/ml/orchestration/tests/test_cleanup.py b/ami/ml/orchestration/tests/test_cleanup.py index ef8382d3d..00210a37a 100644 --- a/ami/ml/orchestration/tests/test_cleanup.py +++ b/ami/ml/orchestration/tests/test_cleanup.py @@ -164,9 +164,9 @@ def test_cleanup_on_job_completion(self): job = self._create_job_with_queued_images() # Simulate job completion: complete all stages (collect, process, then results) - _update_job_progress(job.pk, stage="collect", progress_percentage=1.0) - _update_job_progress(job.pk, stage="process", progress_percentage=1.0) - _update_job_progress(job.pk, stage="results", progress_percentage=1.0) + _update_job_progress(job.pk, stage="collect", progress_percentage=1.0, complete_state=JobState.SUCCESS) + _update_job_progress(job.pk, stage="process", progress_percentage=1.0, complete_state=JobState.SUCCESS) + _update_job_progress(job.pk, stage="results", progress_percentage=1.0, complete_state=JobState.SUCCESS) # Verify cleanup happened self._verify_resources_cleaned(job.pk) diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 9e06536a6..4a417b18b 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -992,7 +992,7 @@ def test_cleanup(self): progress = self.manager._get_progress(set(), 
"process") self.assertIsNone(progress) - def test_cumulative_detection_counting(self): + def test_cumulative_counting(self): """Test that detection counts accumulate correctly across updates.""" self._init_and_verify(self.image_ids) @@ -1001,53 +1001,23 @@ def test_cumulative_detection_counting(self): assert progress is not None self.assertEqual(progress.detections, 3) self.assertEqual(progress.classifications, 0) + self.assertEqual(progress.captures, 0) - # Process second batch with more detections - progress = self.manager._get_progress({"img3"}, "process", detections_count=2) + # Process second batch with more detections and a classification + progress = self.manager._get_progress({"img3"}, "process", detections_count=2, classifications_count=1) assert progress is not None self.assertEqual(progress.detections, 5) # Should be cumulative - self.assertEqual(progress.classifications, 0) - - # Process with both detections and classifications - progress = self.manager._get_progress({"img4"}, "results", detections_count=1, classifications_count=4) - assert progress is not None - self.assertEqual(progress.detections, 6) # Should accumulate - self.assertEqual(progress.classifications, 4) - - def test_cumulative_classification_counting(self): - """Test that classification counts accumulate correctly across updates.""" - self._init_and_verify(self.image_ids) - - # Process first batch with some classifications - progress = self.manager._get_progress({"img1"}, "results", classifications_count=5) - assert progress is not None - self.assertEqual(progress.detections, 0) - self.assertEqual(progress.classifications, 5) + self.assertEqual(progress.classifications, 1) + self.assertEqual(progress.captures, 0) - # Process second batch with more classifications - progress = self.manager._get_progress({"img2", "img3"}, "results", classifications_count=8) - assert progress is not None - self.assertEqual(progress.detections, 0) - self.assertEqual(progress.classifications, 13) # Should be cumulative - - def test_update_state_with_counts(self): - """Test update_state method properly handles detection and classification counts.""" - self._init_and_verify(self.image_ids) - - # Update with counts - progress = self.manager.update_state( - {"img1", "img2"}, "process", "task1", detections_count=4, classifications_count=8 + # Process with detections, classifications, and captures + progress = self.manager._get_progress( + {"img4"}, "results", detections_count=1, classifications_count=4, captures_count=1 ) assert progress is not None - self.assertEqual(progress.processed, 2) - self.assertEqual(progress.detections, 4) - self.assertEqual(progress.classifications, 8) - - # Update with more counts - progress = self.manager.update_state({"img3"}, "results", "task2", detections_count=2, classifications_count=6) - assert progress is not None self.assertEqual(progress.detections, 6) # Should accumulate - self.assertEqual(progress.classifications, 14) # Should accumulate + self.assertEqual(progress.classifications, 5) # Should accumulate + self.assertEqual(progress.captures, 1) def test_counts_persist_across_stages(self): """Test that detection and classification counts persist across different stages.""" @@ -1077,13 +1047,15 @@ def test_cleanup_removes_count_keys(self): self._init_and_verify(self.image_ids) # Add some counts - self.manager._get_progress({"img1"}, "process", detections_count=5, classifications_count=10) + self.manager._get_progress({"img1"}, "process", detections_count=5, classifications_count=10, 
captures_count=2) # Verify count keys exist detections = cache.get(self.manager._detections_key) classifications = cache.get(self.manager._classifications_key) + captures = cache.get(self.manager._captures_key) self.assertEqual(detections, 5) self.assertEqual(classifications, 10) + self.assertEqual(captures, 2) # Cleanup self.manager.cleanup() @@ -1091,8 +1063,10 @@ def test_cleanup_removes_count_keys(self): # Verify count keys are gone detections = cache.get(self.manager._detections_key) classifications = cache.get(self.manager._classifications_key) + captures = cache.get(self.manager._captures_key) self.assertIsNone(detections) self.assertIsNone(classifications) + self.assertIsNone(captures) def test_failed_image_tracking(self): """Test basic failed image tracking with no double-counting on retries.""" From 3287fe23292117ba9baa23eb65289563f1fd1033 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 11 Feb 2026 12:09:21 -0800 Subject: [PATCH 10/25] CR feedback --- ami/jobs/tasks.py | 15 +++++++-------- ami/ml/tests.py | 1 + 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 37047830a..0b92013d9 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -3,7 +3,7 @@ import logging import time from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING from asgiref.sync import async_to_sync from celery.signals import task_failure, task_postrun, task_prerun @@ -15,6 +15,9 @@ from ami.tasks import default_soft_time_limit, default_time_limit from config import celery_app +if TYPE_CHECKING: + from ami.jobs.models import JobState + logger = logging.getLogger(__name__) FAILURE_THRESHOLD = 0.5 # threshold for marking a job as failed based on the percentage of failed images. 
@@ -135,12 +138,8 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub f", percentage: {progress_info.percentage*100}%" ) # Calculate detection and classification counts from this result - detections_count = len(pipeline_result.detections) if pipeline_result else 0 - classifications_count = ( - sum(len(detection.classifications) for detection in pipeline_result.detections) - if pipeline_result - else 0 - ) + detections_count = len(pipeline_result.detections) + classifications_count = sum(len(detection.classifications) for detection in pipeline_result.detections) captures_count = len(pipeline_result.source_images) _ack_task_via_nats(reply_subject, job.logger) @@ -202,7 +201,7 @@ async def ack_task(): def _update_job_progress( - job_id: int, stage: str, progress_percentage: float, complete_state: Any, **state_params + job_id: int, stage: str, progress_percentage: float, complete_state: JobState, **state_params ) -> None: from ami.jobs.models import Job, JobState # avoid circular import diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 4a417b18b..f1bf42e4e 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -882,6 +882,7 @@ def _init_and_verify(self, image_ids): self.assertEqual(progress.percentage, 0.0) self.assertEqual(progress.detections, 0) self.assertEqual(progress.classifications, 0) + self.assertEqual(progress.captures, 0) self.assertEqual(progress.failed, 0) return progress From a87b05a19a6be3f7db1dbe9eae7541e5284de654 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 11 Feb 2026 12:15:23 -0800 Subject: [PATCH 11/25] fix type checking --- ami/jobs/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 0b92013d9..e5f84aa55 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -201,7 +201,7 @@ async def ack_task(): def _update_job_progress( - job_id: int, stage: str, progress_percentage: float, complete_state: JobState, **state_params + job_id: int, stage: str, progress_percentage: float, complete_state: "JobState", **state_params ) -> None: from ami.jobs.models import Job, JobState # avoid circular import From a5ff6f8aea9ec1b1d45940b00c22ed04bf18b70b Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 12 Feb 2026 15:19:53 -0800 Subject: [PATCH 12/25] refactor: rename _get_progress to _commit_update in TaskStateManager Clarify naming to distinguish mutating vs read-only methods: - _commit_update(): private, writes mutations to Redis, returns progress - get_progress(): public, read-only snapshot (added in #1129) - update_state(): public API, acquires lock, calls _commit_update() Co-Authored-By: Claude --- ami/ml/orchestration/task_state.py | 4 +-- ami/ml/tests.py | 50 ++++++++++++++++-------------- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/task_state.py index de05d1679..a1817ed83 100644 --- a/ami/ml/orchestration/task_state.py +++ b/ami/ml/orchestration/task_state.py @@ -101,7 +101,7 @@ def update_state( try: # Update progress tracking in Redis - progress_info = self._get_progress( + progress_info = self._commit_update( processed_image_ids, stage, detections_count, classifications_count, captures_count, failed_image_ids ) return progress_info @@ -134,7 +134,7 @@ def get_progress(self, stage: str) -> TaskProgress | None: failed=len(failed_set), ) - def _get_progress( + def _commit_update( self, processed_image_ids: set[str], stage: str, diff --git a/ami/ml/tests.py b/ami/ml/tests.py 
index f1bf42e4e..1c214931f 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -874,7 +874,7 @@ def setUp(self): def _init_and_verify(self, image_ids): """Helper to initialize job and verify initial state.""" self.manager.initialize_job(image_ids) - progress = self.manager._get_progress(set(), "process") + progress = self.manager._commit_update(set(), "process") assert progress is not None self.assertEqual(progress.total, len(image_ids)) self.assertEqual(progress.remaining, len(image_ids)) @@ -892,7 +892,7 @@ def test_initialize_job(self): # Verify both stages are initialized for stage in self.manager.STAGES: - progress = self.manager._get_progress(set(), stage) + progress = self.manager._commit_update(set(), stage) assert progress is not None self.assertEqual(progress.total, len(self.image_ids)) self.assertEqual(progress.detections, 0) @@ -904,7 +904,7 @@ def test_progress_tracking(self): self._init_and_verify(self.image_ids) # Process 2 images - progress = self.manager._get_progress({"img1", "img2"}, "process") + progress = self.manager._commit_update({"img1", "img2"}, "process") assert progress is not None self.assertEqual(progress.remaining, 3) self.assertEqual(progress.processed, 2) @@ -913,14 +913,14 @@ def test_progress_tracking(self): self.assertEqual(progress.classifications, 0) # Process 2 more images - progress = self.manager._get_progress({"img3", "img4"}, "process") + progress = self.manager._commit_update({"img3", "img4"}, "process") assert progress is not None self.assertEqual(progress.remaining, 1) self.assertEqual(progress.processed, 4) self.assertEqual(progress.percentage, 0.8) # Process last image - progress = self.manager._get_progress({"img5"}, "process") + progress = self.manager._commit_update({"img5"}, "process") assert progress is not None self.assertEqual(progress.remaining, 0) self.assertEqual(progress.processed, 5) @@ -958,20 +958,20 @@ def test_stages_independent(self): self._init_and_verify(self.image_ids) # Update process stage - self.manager._get_progress({"img1", "img2"}, "process") - progress_process = self.manager._get_progress(set(), "process") + self.manager._commit_update({"img1", "img2"}, "process") + progress_process = self.manager._commit_update(set(), "process") assert progress_process is not None self.assertEqual(progress_process.remaining, 3) # Results stage should still have all images pending - progress_results = self.manager._get_progress(set(), "results") + progress_results = self.manager._commit_update(set(), "results") assert progress_results is not None self.assertEqual(progress_results.remaining, 5) def test_empty_job(self): """Test handling of job with no images.""" self.manager.initialize_job([]) - progress = self.manager._get_progress(set(), "process") + progress = self.manager._commit_update(set(), "process") assert progress is not None self.assertEqual(progress.total, 0) self.assertEqual(progress.percentage, 1.0) # Empty job is 100% complete @@ -983,14 +983,14 @@ def test_cleanup(self): self._init_and_verify(self.image_ids) # Verify keys exist - progress = self.manager._get_progress(set(), "process") + progress = self.manager._commit_update(set(), "process") self.assertIsNotNone(progress) # Cleanup self.manager.cleanup() # Verify keys are gone - progress = self.manager._get_progress(set(), "process") + progress = self.manager._commit_update(set(), "process") self.assertIsNone(progress) def test_cumulative_counting(self): @@ -998,21 +998,21 @@ def test_cumulative_counting(self): self._init_and_verify(self.image_ids) # Process first 
batch with some detections - progress = self.manager._get_progress({"img1", "img2"}, "process", detections_count=3) + progress = self.manager._commit_update({"img1", "img2"}, "process", detections_count=3) assert progress is not None self.assertEqual(progress.detections, 3) self.assertEqual(progress.classifications, 0) self.assertEqual(progress.captures, 0) # Process second batch with more detections and a classification - progress = self.manager._get_progress({"img3"}, "process", detections_count=2, classifications_count=1) + progress = self.manager._commit_update({"img3"}, "process", detections_count=2, classifications_count=1) assert progress is not None self.assertEqual(progress.detections, 5) # Should be cumulative self.assertEqual(progress.classifications, 1) self.assertEqual(progress.captures, 0) # Process with detections, classifications, and captures - progress = self.manager._get_progress( + progress = self.manager._commit_update( {"img4"}, "results", detections_count=1, classifications_count=4, captures_count=1 ) assert progress is not None @@ -1025,18 +1025,20 @@ def test_counts_persist_across_stages(self): self._init_and_verify(self.image_ids) # Add counts during process stage - progress_process = self.manager._get_progress({"img1"}, "process", detections_count=3) + progress_process = self.manager._commit_update({"img1"}, "process", detections_count=3) assert progress_process is not None self.assertEqual(progress_process.detections, 3) # Verify counts are available in results stage - progress_results = self.manager._get_progress(set(), "results") + progress_results = self.manager._commit_update(set(), "results") assert progress_results is not None self.assertEqual(progress_results.detections, 3) # Should persist self.assertEqual(progress_results.classifications, 0) # Add more counts in results stage - progress_results = self.manager._get_progress({"img2"}, "results", detections_count=1, classifications_count=5) + progress_results = self.manager._commit_update( + {"img2"}, "results", detections_count=1, classifications_count=5 + ) assert progress_results is not None self.assertEqual(progress_results.detections, 4) # Should accumulate self.assertEqual(progress_results.classifications, 5) @@ -1048,7 +1050,9 @@ def test_cleanup_removes_count_keys(self): self._init_and_verify(self.image_ids) # Add some counts - self.manager._get_progress({"img1"}, "process", detections_count=5, classifications_count=10, captures_count=2) + self.manager._commit_update( + {"img1"}, "process", detections_count=5, classifications_count=10, captures_count=2 + ) # Verify count keys exist detections = cache.get(self.manager._detections_key) @@ -1074,17 +1078,17 @@ def test_failed_image_tracking(self): self._init_and_verify(self.image_ids) # Mark 2 images as failed in process stage - progress = self.manager._get_progress({"img1", "img2"}, "process", failed_image_ids={"img1", "img2"}) + progress = self.manager._commit_update({"img1", "img2"}, "process", failed_image_ids={"img1", "img2"}) assert progress is not None self.assertEqual(progress.failed, 2) # Retry same 2 images (fail again) - should not double-count - progress = self.manager._get_progress(set(), "process", failed_image_ids={"img1", "img2"}) + progress = self.manager._commit_update(set(), "process", failed_image_ids={"img1", "img2"}) assert progress is not None self.assertEqual(progress.failed, 2) # Fail a different image - progress = self.manager._get_progress(set(), "process", failed_image_ids={"img3"}) + progress = 
self.manager._commit_update(set(), "process", failed_image_ids={"img3"}) assert progress is not None self.assertEqual(progress.failed, 3) @@ -1093,7 +1097,7 @@ def test_failed_and_processed_mixed(self): self._init_and_verify(self.image_ids) # Process 2 successfully, 2 fail, 1 remains pending - progress = self.manager._get_progress( + progress = self.manager._commit_update( {"img1", "img2", "img3", "img4"}, "process", failed_image_ids={"img3", "img4"} ) assert progress is not None @@ -1109,7 +1113,7 @@ def test_cleanup_removes_failed_set(self): self._init_and_verify(self.image_ids) # Add failed images - self.manager._get_progress({"img1", "img2"}, "process", failed_image_ids={"img1", "img2"}) + self.manager._commit_update({"img1", "img2"}, "process", failed_image_ids={"img1", "img2"}) # Verify failed set exists failed_set = cache.get(self.manager._failed_key) From 337b7fc8f0764df12b537add72df58de0358a85f Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 12 Feb 2026 15:27:06 -0800 Subject: [PATCH 13/25] fix: unify FAILURE_THRESHOLD and convert TaskProgress to dataclass - Single FAILURE_THRESHOLD constant in tasks.py, imported by models.py - Fix async path to use `> FAILURE_THRESHOLD` (was `>=`) to match the sync path's boundary behavior at exactly 50% - Convert TaskProgress from namedtuple to dataclass with defaults, so new fields don't break existing callers Co-Authored-By: Claude --- ami/jobs/models.py | 3 ++- ami/jobs/tasks.py | 8 +++++--- ami/ml/orchestration/task_state.py | 19 +++++++++++++------ 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 8e790cdcb..1397410c8 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -561,7 +561,8 @@ def process_images(cls, job, images): job.logger.info(f"All tasks completed for job {job.pk}") - FAILURE_THRESHOLD = 0.5 + from ami.jobs.tasks import FAILURE_THRESHOLD + if image_count and (percent_successful < FAILURE_THRESHOLD): job.progress.update_stage("process", status=JobState.FAILURE) job.save() diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index e5f84aa55..9563be5c4 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -19,7 +19,9 @@ from ami.jobs.models import JobState logger = logging.getLogger(__name__) -FAILURE_THRESHOLD = 0.5 # threshold for marking a job as failed based on the percentage of failed images. +# Minimum success rate. Jobs with fewer than this fraction of images +# processed successfully are marked as failed. Also used in MLJob.process_images(). 
+FAILURE_THRESHOLD = 0.5 @celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit) @@ -94,7 +96,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub try: complete_state = JobState.SUCCESS - if progress_info.total > 0 and (progress_info.failed / progress_info.total) >= FAILURE_THRESHOLD: + if progress_info.total > 0 and (progress_info.failed / progress_info.total) > FAILURE_THRESHOLD: complete_state = JobState.FAILURE _update_job_progress( job_id, @@ -163,7 +165,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub # update complete state based on latest progress info after saving results complete_state = JobState.SUCCESS - if progress_info.total > 0 and (progress_info.failed / progress_info.total) >= FAILURE_THRESHOLD: + if progress_info.total > 0 and (progress_info.failed / progress_info.total) > FAILURE_THRESHOLD: complete_state = JobState.FAILURE _update_job_progress( diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/task_state.py index a1817ed83..98665ae86 100644 --- a/ami/ml/orchestration/task_state.py +++ b/ami/ml/orchestration/task_state.py @@ -3,18 +3,25 @@ """ import logging -from collections import namedtuple +from dataclasses import dataclass from django.core.cache import cache logger = logging.getLogger(__name__) -# Define a namedtuple for a TaskProgress with the image counts -TaskProgress = namedtuple( - "TaskProgress", - ["remaining", "total", "processed", "percentage", "detections", "classifications", "captures", "failed"], -) +@dataclass +class TaskProgress: + """Progress snapshot for a job stage tracked in Redis.""" + + remaining: int = 0 + total: int = 0 + processed: int = 0 + percentage: float = 0.0 + detections: int = 0 + classifications: int = 0 + captures: int = 0 + failed: int = 0 def _lock_key(job_id: int) -> str: From afee6e7ed6f23f63456325e9cf7340ebaa9c5e91 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 12 Feb 2026 15:34:11 -0800 Subject: [PATCH 14/25] refactor: rename TaskProgress to JobStateProgress Clarify that this dataclass tracks job-level progress in Redis, not individual task/image progress. Aligns with the naming of JobProgress (the Django/Pydantic model equivalent). Co-Authored-By: Claude --- ami/ml/orchestration/task_state.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/task_state.py index 98665ae86..ef829c6f0 100644 --- a/ami/ml/orchestration/task_state.py +++ b/ami/ml/orchestration/task_state.py @@ -11,7 +11,7 @@ @dataclass -class TaskProgress: +class JobStateProgress: """Progress snapshot for a job stage tracked in Redis.""" remaining: int = 0 @@ -86,7 +86,7 @@ def update_state( classifications_count: int = 0, captures_count: int = 0, failed_image_ids: set[str] | None = None, - ) -> None | TaskProgress: + ) -> None | JobStateProgress: """ Update the task state with newly processed images. @@ -120,7 +120,7 @@ def update_state( cache.delete(lock_key) logger.debug(f"Released lock for job {self.job_id}, task {request_id}") - def get_progress(self, stage: str) -> TaskProgress | None: + def get_progress(self, stage: str) -> JobStateProgress | None: """Read-only progress snapshot for the given stage. 
Does not acquire a lock or mutate state.""" pending_images = cache.get(self._get_pending_key(stage)) total_images = cache.get(self._total_key) @@ -130,7 +130,7 @@ def get_progress(self, stage: str) -> TaskProgress | None: processed = total_images - remaining percentage = float(processed) / total_images if total_images > 0 else 1.0 failed_set = cache.get(self._failed_key) or set() - return TaskProgress( + return JobStateProgress( remaining=remaining, total=total_images, processed=processed, @@ -149,7 +149,7 @@ def _commit_update( classifications_count: int = 0, captures_count: int = 0, failed_image_ids: set[str] | None = None, - ) -> TaskProgress | None: + ) -> JobStateProgress | None: """ Update pending images and return progress. Must be called under lock. @@ -196,7 +196,7 @@ def _commit_update( f"{remaining}/{total_images}: {percentage*100}%" ) - return TaskProgress( + return JobStateProgress( remaining=remaining, total=total_images, processed=processed, From 65d77cb4bb70234bd4092da82ec9c693b9ae90dd Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 12 Feb 2026 16:56:57 -0800 Subject: [PATCH 15/25] docs: update NATS todo and planning docs with session learnings Mark connection handling as done (PR #1130), add worktree/remote mapping and docker testing notes for future sessions. Co-Authored-By: Claude --- docs/claude/nats-todo.md | 159 ++++++++++++++++++ .../planning/pr-trackcounts-next-session.md | 67 ++++++++ 2 files changed, 226 insertions(+) create mode 100644 docs/claude/nats-todo.md create mode 100644 docs/claude/planning/pr-trackcounts-next-session.md diff --git a/docs/claude/nats-todo.md b/docs/claude/nats-todo.md new file mode 100644 index 000000000..127ede36a --- /dev/null +++ b/docs/claude/nats-todo.md @@ -0,0 +1,159 @@ +# NATS Infrastructure TODO + +Tracked improvements for the NATS JetStream setup used by async ML pipeline jobs. + +## Urgent + +### Add `NATS_URL` to worker-2 env file + +- **Status:** DONE (env var added, container reloaded, connection verified) +- **Root cause of job 2226 failure:** worker-2 was missing `NATS_URL` in `.envs/.production/.django`, so it defaulted to `nats://localhost:4222`. Every NATS ack from worker-2 failed with `Connect call failed ('127.0.0.1', 4222)`. +- **Fix applied in code:** Changed default in `config/settings/base.py:268` from `nats://localhost:4222` to `nats://nats:4222` (matches the hostname mapped via `extra_hosts` in all compose files). +- **Still needed on server:** + ```bash + ssh ami-cc "ssh ami-worker-2 'echo NATS_URL=nats://nats:4222 >> ~/ami-platform/.envs/.production/.django'" + ssh ami-cc "ssh ami-worker-2 'cd ~/ami-platform && docker compose -f docker-compose.worker.yml restart celeryworker'" + ``` + +## Error Handling + +### Don't retry permanent errors + +- **File:** `ami/ml/orchestration/nats_queue.py:118` +- **Current:** `max_deliver=5` retries every failed message, including permanent errors (404 image not found, malformed data, etc.) +- **Problem:** NATS has no way to distinguish transient vs permanent failures. If a task fails because the image URL is broken, it will be redelivered 5 times, wasting processing service time. +- **Proposed fix:** The error handling should happen in the celery task (`process_nats_pipeline_result`) and in the processing service, not in NATS redelivery. If the processing service returns an error result, the celery task should ack the NATS message (removing it from the queue) and record the error on the job. 
NATS redelivery should only cover the case where a consumer crashes mid-processing (no result posted at all). +- **Consider:** Reducing `max_deliver` to 2-3 since the only legitimate redelivery scenario is consumer crash/timeout, not application errors. + +### Detect and surface exhausted messages (dead letters) + +- **Current:** When a message hits `max_deliver`, NATS silently drops it. The job hangs forever with remaining images never processed and no error shown to the user. +- **Problem:** There's no feedback loop. The `process_nats_pipeline_result` celery task only runs when the processing service posts a result. If NATS stops delivering a message (because it hit `max_deliver`), no celery task fires, no log is written, and the job just stalls. +- **Proposed approach — poll consumer state from the celery job:** + The `run_job` celery task currently returns immediately after queuing images. Instead, it could poll the NATS consumer state periodically until the job completes or stalls: + + ```python + # In the run_job task or a separate watchdog task: + async with TaskQueueManager() as manager: + info = await js.consumer_info(stream_name, consumer_name) + delivered = info.num_delivered + ack_floor = info.ack_floor.stream_seq + pending = info.num_pending + ack_pending = info.num_ack_pending + + if pending == 0 and ack_pending == 0 and ack_floor < total_images: + # Messages exhausted max_deliver — they're dead + dead_count = total_images - ack_floor + job.logger.error( + f"{dead_count} tasks exceeded max delivery attempts " + f"and will not be retried. {ack_floor}/{total_images} completed." + ) + job.update_status(JobState.FAILURE) + ``` + + This would surface a clear error in the job logs visible in Django admin. + +- **Alternative — NATS advisory subscription:** + Subscribe to `$JS.EVENT.ADVISORY.CONSUMER.MAX_DELIVERIES.job_{id}.*` and log each dead message individually. More complex but gives per-message visibility. +- **Where to implement:** Either as a polling loop in `run_job` (simplest), or as a separate Celery Beat task that checks all active async jobs. +- **Files:** `ami/jobs/tasks.py` (run_job or new watchdog task), `ami/ml/orchestration/nats_queue.py` (add `get_consumer_info` method) + +## Infrastructure + +### Review NATS compose file on ami-redis-1 + +- **Location:** `docker-compose.yml` on ami-redis-1 +- **Current config is mostly good:** ports exposed (4222, 8222), healthcheck configured, restart=always, JetStream enabled +- **Missing: Persistent volume.** JetStream stores data in `/tmp/nats/jetstream` (container temp dir). Server logs warn: `Temporary storage directory used, data could be lost on system reboot`. Add a volume mount: + ```yaml + nats: + image: nats:2.10-alpine + volumes: + - nats-data:/data/jetstream + command: ["-js", "-m", "8222", "-sd", "/data/jetstream"] + volumes: + nats-data: + ``` +- **Consider:** Adding memory/storage limits to JetStream config (`-js --max_mem_store`, `--max_file_store`) to prevent unbounded growth +- **Consider:** Adding NATS config file instead of CLI flags for more control (auth, logging level, connection limits) + +### Clean up stale streams + +- **Current:** 95 streams exist on the server, all but one are empty (from old jobs) +- **Cleanup runs on job completion** (`cleanup_async_job_resources` in `ami/ml/orchestration/jobs.py`), but only if the job fully completes. Failed/stuck jobs leave orphan streams. 
+- **Proposed fix:** Add a periodic Celery Beat task to clean up streams older than 24h (matching the `max_age=86400` retention on streams). Or clean up streams for jobs that are in a final state (SUCCESS, FAILURE, REVOKED). + +### Expose NATS monitoring for dashboard access + +- **Port 8222 is already exposed** on ami-redis-1, so `http://192.168.123.176:8222` should work from the VPN +- **For browser dashboard** (https://natsdashboard.com/): Needs the monitoring endpoint reachable from your browser. Use SSH tunnel if not on VPN: + ```bash + ssh -L 8222:localhost:8222 ami-cc -t "ssh -L 8222:localhost:8222 ami-redis-1" + ``` + Then open https://natsdashboard.com/ with server URL `http://localhost:8222` + +## Reliability + +### Connection handling in ack path + +- **Status:** DONE (PR #1130, branch `carlosg/natsconn` on `uw-ssec` remote) +- **What was done:** Added `retry_on_connection_error` decorator with exponential backoff. Replaced connection pool with async context manager pattern — each `async_to_sync()` call scopes one connection. Added `reconnected_cb`/`disconnected_cb` logging callbacks. +- **Commit:** `c384199f` refactor: simplify NATS connection handling — keep retry decorator, drop pool + +### `check_processing_services_online` causing worker instability + +- **Observed on both ami-live and worker-2:** This periodic Beat task hits soft time limit (10s) and hard time limit (20s) on every run, causing ForkPoolWorker processes to be SIGKILL'd. +- **Service #13 "Zero Shot Detector Pipelines"** at `https://ml-zs.dev.insectai.org/` consistently times out. +- **Impact:** Worker pool instability, killed processes may have been mid-task. This could contribute to unreliable task processing. +- **Fix:** Either increase the time limit, skip known-offline services, or handle the timeout more gracefully. + +## Source Code References + +| File | Line | What | +| ------------------------------------ | ---------------------------------- | ---------------------------------------------------------- | +| `config/settings/base.py` | 268 | `NATS_URL` setting (default changed to `nats://nats:4222`) | +| `ami/ml/orchestration/nats_queue.py` | 97-108 | `TaskQueueManager` context manager + connection setup | +| `ami/ml/orchestration/nats_queue.py` | 42-94 | `retry_on_connection_error` decorator | +| `ami/ml/orchestration/nats_queue.py` | 191-215 | Consumer config (`max_deliver`, `ack_wait`) | +| `ami/jobs/tasks.py` | 134-149 | `_ack_task_via_nats()` | +| `ami/ml/orchestration/jobs.py` | 14-55 | `cleanup_async_job_resources()` | +| `ami/ml/tasks.py` | `check_processing_services_online` | Periodic health check (causing SIGKILL) | + +## DevOps + +### Move triage steps to `ami-devops` repo and create a Claude skill + +- Port `docs/claude/debugging/nats-triage.md` to the `ami-devops` repo +- Create a Claude Code skill (slash command) for NATS triage that can run the diagnostic commands interactively +- Skill should accept a job ID and walk through the triage steps: check job logs, inspect NATS stream state, test connectivity from workers + +## Other + +- Processing service failing on batches with different image sizes +- How can we mark an image/task as failed and say don't retry? 
+- Processing service still needs to batch classifications (like prevous methods) +- Nats jobs appear stuck if there are any task failures: https://antenna.insectai.org/projects/18/jobs/2228 +- If a task crashes, the whole worker seems to reset +- Then no tasks are found remaining for the job in NATS + 2026-02-09 18:23:49 [info ] No jobs found, sleeping for 5 seconds + 2026-02-09 18:23:54 [info ] Checking for jobs for pipeline panama_moths_2023 + 2026-02-09 18:23:55 [info ] Checking for jobs for pipeline panama_moths_2024 + 2026-02-09 18:23:55 [info ] Checking for jobs for pipeline quebec_vermont_moths_2023 + 2026-02-09 18:23:55 [info ] Processing job 2229 with pipeline quebec_vermont_moths_2023 + 2026-02-09 18:23:55 [info ] Worker 0/2 starting iteration for job 2229 + 2026-02-09 18:23:55 [info ] Worker 1/2 starting iteration for job 2229 + 2026-02-09 18:23:59 [info ] Worker 0: No more tasks for job 2229 + 2026-02-09 18:23:59 [info ] Worker 0: Iterator finished + 2026-02-09 18:24:03 [info ] Worker 1: No more tasks for job 2229 + 2026-02-09 18:24:03 [info ] Worker 1: Iterator finished + 2026-02-09 18:24:03 [info ] Done, detections: 0. Detecting time: 0.0, classification time: 0.0, dl time: 0.0, save time: 0.0 + +- Would love some logs like "no task has been picked up in X minutes" or "last seen", etc. +- Skip jobs that hbs no tasks in the initial query + +- test in a SLUM job! yeah! in Victoria? + +- jumps around between jobs - good thing? annoying? what about when there is only one job open? +- time for time estimates + +- bring back vectors asap diff --git a/docs/claude/planning/pr-trackcounts-next-session.md b/docs/claude/planning/pr-trackcounts-next-session.md new file mode 100644 index 000000000..d1c085df2 --- /dev/null +++ b/docs/claude/planning/pr-trackcounts-next-session.md @@ -0,0 +1,67 @@ +# Next Session: carlos/trackcounts PR Review + +## Context + +We're reviewing and cleaning up a PR that adds cumulative count tracking (detections, classifications, captures, failed images) to the async (NATS) job pipeline. The PR is on branch `carlos/trackcounts`. + +## Working with worktrees and remotes + +There are two related branches on different remotes/worktrees: + +| Branch | Remote | Worktree | PR | +|--------|--------|----------|-----| +| `carlos/trackcounts` | `origin` (RolnickLab) | `/home/michael/Projects/AMI/antenna` (main) | trackcounts PR | +| `carlosg/natsconn` | `uw-ssec` (uw-ssec) | `/home/michael/Projects/AMI/antenna-natsconn` | #1130 | + +**Key rules:** +- `carlosg/natsconn` pushes to `uw-ssec`, NOT `origin` +- The worktree at `antenna-natsconn` is detached HEAD — use `git push uw-ssec HEAD:refs/heads/carlosg/natsconn` +- To run tests from the worktree, mount only `ami/` over the main compose: `docker compose run --rm -v /home/michael/Projects/AMI/antenna-natsconn/ami:/app/ami:z django python manage.py test ... --keepdb` +- Container uses Pydantic v1 — use `.dict()` / `.json()`, NOT `.model_dump()` / `.model_dump_json()` + +## Completed commits (on top of main merge) + +1. `a5ff6f8a` — Rename `_get_progress` → `_commit_update` in TaskStateManager +2. `337b7fc8` — Unify `FAILURE_THRESHOLD` constant + convert TaskProgress to dataclass +3. `89111a05` — Rename `TaskProgress` → `JobStateProgress` + +## Remaining work + +### 1. Review: Is `JobStateProgress` the right abstraction? + +**Question:** What is `TaskStateManager` actually tracking, and does `JobStateProgress` reflect that correctly? 
+ +There are two parallel progress tracking systems: +- **Django side:** `Job.progress` → `JobProgress` (Pydantic model, persisted in DB) + - Has multiple stages (collect, process, results) + - Each stage has its own status, progress percentage, and arbitrary params + - `is_complete()` checks all stages +- **Redis side:** `TaskStateManager` → `JobStateProgress` (dataclass, ephemeral in Redis) + - Tracks pending image IDs per stage (process, results) + - Tracks cumulative counts: detections, classifications, captures, failed + - Single flat object — no per-stage breakdown of counts + +The disconnect: Redis tracks **per-stage pending images** (separate pending lists for "process" and "results" stages) but returns **job-wide cumulative counts** (one detections counter, one failed set). So `JobStateProgress` is a hybrid — stage-scoped for image completion, but job-scoped for counts. + +**Should counts be per-stage?** For example, "failed" in the process stage means images that errored during ML inference. But could there be failures in the results stage too (failed to save)? The sync path tracks `request_failed_images` (process failures) separately from `failed_save_tasks` (results failures). The async path currently lumps all failures into one set. + +**Key files to read:** +- `ami/ml/orchestration/task_state.py` — `TaskStateManager` and `JobStateProgress` +- `ami/jobs/tasks.py:62-185` — `process_nats_pipeline_result` (async path, uses TaskStateManager) +- `ami/jobs/models.py:466-582` — `MLJob.process_images` (sync path, tracks counts locally) +- `ami/jobs/models.py:134-248` — `JobProgress`, `JobProgressStageDetail`, `is_complete()` + +### 2. Remove `complete_state` parameter + +Plan is written at: `docs/claude/planning/pr-trackcounts-complete-state-removal.md` + +Summary: Remove `complete_state` from `_update_job_progress`. Jobs always complete as SUCCESS. Failure counts are tracked as stage params but don't affect overall status. "Completed with failures" state deferred to future PR. + +**Files to modify:** +- `ami/jobs/tasks.py:97-179, 205-223` — remove complete_state logic +- `ami/ml/orchestration/tests/test_cleanup.py:164-168` — update test calls +- `ami/jobs/test_tasks.py` — update any affected tests + +### 3. 
PR assessment doc + +Full review written at: `docs/claude/planning/pr-trackcounts-review.md` From 8e8cd80ef24ab93433630beca73c59c3a2640a70 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 13 Feb 2026 07:58:56 -0800 Subject: [PATCH 16/25] Rename TaskStateManager to AsyncJobStateManager --- ami/jobs/tasks.py | 4 ++-- ami/jobs/test_tasks.py | 14 +++++++------- .../{task_state.py => async_job_state.py} | 2 +- ami/ml/orchestration/jobs.py | 6 +++--- ami/ml/orchestration/tests/test_cleanup.py | 6 +++--- ami/ml/tests.py | 4 ++-- 6 files changed, 18 insertions(+), 18 deletions(-) rename ami/ml/orchestration/{task_state.py => async_job_state.py} (99%) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 9563be5c4..5be262f61 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -9,8 +9,8 @@ from celery.signals import task_failure, task_postrun, task_prerun from django.db import transaction +from ami.ml.orchestration.async_job_state import AsyncJobStateManager from ami.ml.orchestration.nats_queue import TaskQueueManager -from ami.ml.orchestration.task_state import TaskStateManager from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse from ami.tasks import default_soft_time_limit, default_time_limit from config import celery_app @@ -82,7 +82,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub processed_image_ids = {str(img.id) for img in pipeline_result.source_images} failed_image_ids = set() # No failures for successful results - state_manager = TaskStateManager(job_id) + state_manager = AsyncJobStateManager(job_id) progress_info = state_manager.update_state( processed_image_ids, stage="process", request_id=self.request.id, failed_image_ids=failed_image_ids diff --git a/ami/jobs/test_tasks.py b/ami/jobs/test_tasks.py index fc291c8ba..b37940cdd 100644 --- a/ami/jobs/test_tasks.py +++ b/ami/jobs/test_tasks.py @@ -17,7 +17,7 @@ from ami.jobs.tasks import process_nats_pipeline_result from ami.main.models import Detection, Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline -from ami.ml.orchestration.task_state import TaskStateManager, _lock_key +from ami.ml.orchestration.async_job_state import AsyncJobStateManager, _lock_key from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse, SourceImageResponse from ami.users.models import User @@ -64,7 +64,7 @@ def setUp(self): # Initialize state manager self.image_ids = [str(img.pk) for img in self.images] - self.state_manager = TaskStateManager(self.job.pk) + self.state_manager = AsyncJobStateManager(self.job.pk) self.state_manager.initialize_job(self.image_ids) def tearDown(self): @@ -90,7 +90,7 @@ def _assert_progress_updated( self, job_id: int, expected_processed: int, expected_total: int, stage: str = "process" ): """Assert TaskStateManager state is correct.""" - manager = TaskStateManager(job_id) + manager = AsyncJobStateManager(job_id) progress = manager.get_progress(stage) self.assertIsNotNone(progress, f"Progress not found for stage '{stage}'") self.assertEqual(progress.processed, expected_processed) @@ -157,7 +157,7 @@ def test_process_nats_pipeline_result_error_no_image_id(self, mock_manager_class # Assert: Progress was NOT updated (empty set of processed images) # Since no image_id was provided, processed_image_ids = set() - manager = TaskStateManager(self.job.pk) + manager = AsyncJobStateManager(self.job.pk) progress = manager.get_progress("process") self.assertEqual(progress.processed, 0) # No images marked as processed @@ 
-208,7 +208,7 @@ def test_process_nats_pipeline_result_mixed_results(self, mock_manager_class): ) # Assert: All 3 images marked as processed in TaskStateManager - manager = TaskStateManager(self.job.pk) + manager = AsyncJobStateManager(self.job.pk) process_progress = manager.get_progress("process") self.assertIsNotNone(process_progress) self.assertEqual(process_progress.processed, 3) @@ -266,7 +266,7 @@ def test_process_nats_pipeline_result_error_concurrent_locking(self, mock_manage ) # Assert: Progress was NOT updated (lock not acquired) - manager = TaskStateManager(self.job.pk) + manager = AsyncJobStateManager(self.job.pk) progress = manager.get_progress("process") self.assertEqual(progress.processed, 0) @@ -342,7 +342,7 @@ def setUp(self): ) # Initialize state manager - state_manager = TaskStateManager(self.job.pk) + state_manager = AsyncJobStateManager(self.job.pk) state_manager.initialize_job([str(self.image.pk)]) def tearDown(self): diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/async_job_state.py similarity index 99% rename from ami/ml/orchestration/task_state.py rename to ami/ml/orchestration/async_job_state.py index 98665ae86..ef7f54bdd 100644 --- a/ami/ml/orchestration/task_state.py +++ b/ami/ml/orchestration/async_job_state.py @@ -28,7 +28,7 @@ def _lock_key(job_id: int) -> str: return f"job:{job_id}:process_results_lock" -class TaskStateManager: +class AsyncJobStateManager: """ Manages job progress tracking state in Redis. diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index 9b4a577e9..ce54ecd1c 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -4,8 +4,8 @@ from ami.jobs.models import Job, JobState from ami.main.models import SourceImage +from ami.ml.orchestration.async_job_state import AsyncJobStateManager from ami.ml.orchestration.nats_queue import TaskQueueManager -from ami.ml.orchestration.task_state import TaskStateManager from ami.ml.schemas import PipelineProcessingTask logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ def cleanup_async_job_resources(job: "Job") -> bool: # Cleanup Redis state try: - state_manager = TaskStateManager(job.pk) + state_manager = AsyncJobStateManager(job.pk) state_manager.cleanup() job.logger.info(f"Cleaned up Redis state for job {job.pk}") redis_success = True @@ -88,7 +88,7 @@ def queue_images_to_nats(job: "Job", images: list[SourceImage]): tasks.append((image.pk, task)) # Store all image IDs in Redis for progress tracking - state_manager = TaskStateManager(job.pk) + state_manager = AsyncJobStateManager(job.pk) state_manager.initialize_job(image_ids) job.logger.info(f"Initialized task state tracking for {len(image_ids)} images") diff --git a/ami/ml/orchestration/tests/test_cleanup.py b/ami/ml/orchestration/tests/test_cleanup.py index 00210a37a..ccdfa2c49 100644 --- a/ami/ml/orchestration/tests/test_cleanup.py +++ b/ami/ml/orchestration/tests/test_cleanup.py @@ -9,9 +9,9 @@ from ami.jobs.tasks import _update_job_progress, update_job_failure, update_job_status from ami.main.models import Project, ProjectFeatureFlags, SourceImage, SourceImageCollection from ami.ml.models import Pipeline +from ami.ml.orchestration.async_job_state import AsyncJobStateManager from ami.ml.orchestration.jobs import queue_images_to_nats from ami.ml.orchestration.nats_queue import TaskQueueManager -from ami.ml.orchestration.task_state import TaskStateManager class TestCleanupAsyncJobResources(TestCase): @@ -59,7 +59,7 @@ def _verify_resources_created(self, job_id: int): job_id: The 
job ID to check """ # Verify Redis keys exist - state_manager = TaskStateManager(job_id) + state_manager = AsyncJobStateManager(job_id) for stage in state_manager.STAGES: pending_key = state_manager._get_pending_key(stage) self.assertIsNotNone(cache.get(pending_key), f"Redis key {pending_key} should exist") @@ -125,7 +125,7 @@ def _verify_resources_cleaned(self, job_id: int): job_id: The job ID to check """ # Verify Redis keys are deleted - state_manager = TaskStateManager(job_id) + state_manager = AsyncJobStateManager(job_id) for stage in state_manager.STAGES: pending_key = state_manager._get_pending_key(stage) self.assertIsNone(cache.get(pending_key), f"Redis key {pending_key} should be deleted") diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 1c214931f..ac9167455 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -864,11 +864,11 @@ def setUp(self): """Set up test fixtures.""" from django.core.cache import cache - from ami.ml.orchestration.task_state import TaskStateManager + from ami.ml.orchestration.async_job_state import AsyncJobStateManager cache.clear() self.job_id = 123 - self.manager = TaskStateManager(self.job_id) + self.manager = AsyncJobStateManager(self.job_id) self.image_ids = ["img1", "img2", "img3", "img4", "img5"] def _init_and_verify(self, image_ids): From afc44721f4bcd11de97add71167fcb3644e0f134 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 13 Feb 2026 09:00:47 -0800 Subject: [PATCH 17/25] Track results counts in the job itself vs Redis --- ami/jobs/tasks.py | 59 ++++++++++++++-- ami/ml/orchestration/async_job_state.py | 43 +----------- ami/ml/tests.py | 91 ------------------------- 3 files changed, 54 insertions(+), 139 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 5be262f61..35c09ed3c 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -151,9 +151,6 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub processed_image_ids, stage="results", request_id=self.request.id, - detections_count=detections_count, - classifications_count=classifications_count, - captures_count=captures_count, ) if not progress_info: @@ -173,9 +170,9 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub "results", progress_info.percentage, complete_state=complete_state, - detections=progress_info.detections, - classifications=progress_info.classifications, - captures=progress_info.captures, + detections=detections_count, + classifications=classifications_count, + captures=captures_count, ) except Exception as e: @@ -202,6 +199,40 @@ async def ack_task(): # Don't fail the task if ACK fails - data is already saved +def _get_current_counts_from_job_progress(job, stage: str) -> tuple[int, int, int]: + """ + Get current detections, classifications, and captures counts from job progress. 
+ + Args: + job: The Job instance + stage: The stage name to read counts from + + Returns: + Tuple of (detections, classifications, captures) counts, defaulting to 0 if not found + """ + try: + stage_obj = job.progress.get_stage(stage) + + # Initialize defaults + detections = 0 + classifications = 0 + captures = 0 + + # Search through the params list for our count values + for param in stage_obj.params: + if param.key == "detections": + detections = param.value or 0 + elif param.key == "classifications": + classifications = param.value or 0 + elif param.key == "captures": + captures = param.value or 0 + + return detections, classifications, captures + except (ValueError, AttributeError): + # Stage doesn't exist or doesn't have these attributes yet + return 0, 0, 0 + + def _update_job_progress( job_id: int, stage: str, progress_percentage: float, complete_state: "JobState", **state_params ) -> None: @@ -209,6 +240,22 @@ def _update_job_progress( with transaction.atomic(): job = Job.objects.select_for_update().get(pk=job_id) + + # For results stage, accumulate detections/classifications/captures counts + if stage == "results": + current_detections, current_classifications, current_captures = _get_current_counts_from_job_progress( + job, stage + ) + + # Add new counts to existing counts + new_detections = state_params.get("detections", 0) or 0 + new_classifications = state_params.get("classifications", 0) or 0 + new_captures = state_params.get("captures", 0) or 0 + + state_params["detections"] = current_detections + new_detections + state_params["classifications"] = current_classifications + new_classifications + state_params["captures"] = current_captures + new_captures + job.progress.update_stage( stage, status=complete_state if progress_percentage >= 1.0 else JobState.STARTED, diff --git a/ami/ml/orchestration/async_job_state.py b/ami/ml/orchestration/async_job_state.py index 21de70ec8..fe0f7eb74 100644 --- a/ami/ml/orchestration/async_job_state.py +++ b/ami/ml/orchestration/async_job_state.py @@ -18,9 +18,6 @@ class JobStateProgress: total: int = 0 processed: int = 0 percentage: float = 0.0 - detections: int = 0 - classifications: int = 0 - captures: int = 0 failed: int = 0 @@ -50,9 +47,6 @@ def __init__(self, job_id: int): self._pending_key = f"job:{job_id}:pending_images" self._total_key = f"job:{job_id}:pending_images_total" self._failed_key = f"job:{job_id}:failed_images" - self._detections_key = f"job:{job_id}:total_detections" - self._classifications_key = f"job:{job_id}:total_classifications" - self._captures_key = f"job:{job_id}:total_captures" def initialize_job(self, image_ids: list[str]) -> None: """ @@ -69,11 +63,6 @@ def initialize_job(self, image_ids: list[str]) -> None: cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) - # Initialize detection and classification counters - cache.set(self._detections_key, 0, timeout=self.TIMEOUT) - cache.set(self._classifications_key, 0, timeout=self.TIMEOUT) - cache.set(self._captures_key, 0, timeout=self.TIMEOUT) - def _get_pending_key(self, stage: str) -> str: return f"{self._pending_key}:{stage}" @@ -82,9 +71,6 @@ def update_state( processed_image_ids: set[str], stage: str, request_id: str, - detections_count: int = 0, - classifications_count: int = 0, - captures_count: int = 0, failed_image_ids: set[str] | None = None, ) -> None | JobStateProgress: """ @@ -108,9 +94,7 @@ def update_state( try: # Update progress tracking in Redis - progress_info = self._commit_update( - processed_image_ids, stage, detections_count, 
classifications_count, captures_count, failed_image_ids - ) + progress_info = self._commit_update(processed_image_ids, stage, failed_image_ids) return progress_info finally: # Always release the lock when done @@ -135,9 +119,6 @@ def get_progress(self, stage: str) -> JobStateProgress | None: total=total_images, processed=processed, percentage=percentage, - detections=cache.get(self._detections_key, 0), - classifications=cache.get(self._classifications_key, 0), - captures=cache.get(self._captures_key, 0), failed=len(failed_set), ) @@ -145,9 +126,6 @@ def _commit_update( self, processed_image_ids: set[str], stage: str, - detections_count: int = 0, - classifications_count: int = 0, - captures_count: int = 0, failed_image_ids: set[str] | None = None, ) -> JobStateProgress | None: """ @@ -167,19 +145,6 @@ def _commit_update( processed = total_images - remaining percentage = float(processed) / total_images if total_images > 0 else 1.0 - # Update cumulative detection, classification, and capture counts - current_detections = cache.get(self._detections_key, 0) - current_classifications = cache.get(self._classifications_key, 0) - current_captures = cache.get(self._captures_key, 0) - - new_detections = current_detections + detections_count - new_classifications = current_classifications + classifications_count - new_captures = current_captures + captures_count - - cache.set(self._detections_key, new_detections, timeout=self.TIMEOUT) - cache.set(self._classifications_key, new_classifications, timeout=self.TIMEOUT) - cache.set(self._captures_key, new_captures, timeout=self.TIMEOUT) - # Update failed images set if provided if failed_image_ids: existing_failed = cache.get(self._failed_key) or set() @@ -201,9 +166,6 @@ def _commit_update( total=total_images, processed=processed, percentage=percentage, - detections=new_detections, - classifications=new_classifications, - captures=new_captures, failed=failed_count, ) @@ -215,6 +177,3 @@ def cleanup(self) -> None: cache.delete(self._get_pending_key(stage)) cache.delete(self._failed_key) cache.delete(self._total_key) - cache.delete(self._detections_key) - cache.delete(self._classifications_key) - cache.delete(self._captures_key) diff --git a/ami/ml/tests.py b/ami/ml/tests.py index ac9167455..2f269a6f9 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -880,9 +880,6 @@ def _init_and_verify(self, image_ids): self.assertEqual(progress.remaining, len(image_ids)) self.assertEqual(progress.processed, 0) self.assertEqual(progress.percentage, 0.0) - self.assertEqual(progress.detections, 0) - self.assertEqual(progress.classifications, 0) - self.assertEqual(progress.captures, 0) self.assertEqual(progress.failed, 0) return progress @@ -895,8 +892,6 @@ def test_initialize_job(self): progress = self.manager._commit_update(set(), stage) assert progress is not None self.assertEqual(progress.total, len(self.image_ids)) - self.assertEqual(progress.detections, 0) - self.assertEqual(progress.classifications, 0) self.assertEqual(progress.failed, 0) def test_progress_tracking(self): @@ -909,8 +904,6 @@ def test_progress_tracking(self): self.assertEqual(progress.remaining, 3) self.assertEqual(progress.processed, 2) self.assertEqual(progress.percentage, 0.4) - self.assertEqual(progress.detections, 0) # No counts added yet - self.assertEqual(progress.classifications, 0) # Process 2 more images progress = self.manager._commit_update({"img3", "img4"}, "process") @@ -936,8 +929,6 @@ def test_update_state_with_locking(self): progress = self.manager.update_state({"img1", "img2"}, 
"process", "task1") assert progress is not None self.assertEqual(progress.processed, 2) - self.assertEqual(progress.detections, 0) - self.assertEqual(progress.classifications, 0) # Simulate concurrent update by holding the lock lock_key = f"job:{self.job_id}:process_results_lock" @@ -975,8 +966,6 @@ def test_empty_job(self): assert progress is not None self.assertEqual(progress.total, 0) self.assertEqual(progress.percentage, 1.0) # Empty job is 100% complete - self.assertEqual(progress.detections, 0) - self.assertEqual(progress.classifications, 0) def test_cleanup(self): """Test cleanup removes all tracking keys.""" @@ -993,86 +982,6 @@ def test_cleanup(self): progress = self.manager._commit_update(set(), "process") self.assertIsNone(progress) - def test_cumulative_counting(self): - """Test that detection counts accumulate correctly across updates.""" - self._init_and_verify(self.image_ids) - - # Process first batch with some detections - progress = self.manager._commit_update({"img1", "img2"}, "process", detections_count=3) - assert progress is not None - self.assertEqual(progress.detections, 3) - self.assertEqual(progress.classifications, 0) - self.assertEqual(progress.captures, 0) - - # Process second batch with more detections and a classification - progress = self.manager._commit_update({"img3"}, "process", detections_count=2, classifications_count=1) - assert progress is not None - self.assertEqual(progress.detections, 5) # Should be cumulative - self.assertEqual(progress.classifications, 1) - self.assertEqual(progress.captures, 0) - - # Process with detections, classifications, and captures - progress = self.manager._commit_update( - {"img4"}, "results", detections_count=1, classifications_count=4, captures_count=1 - ) - assert progress is not None - self.assertEqual(progress.detections, 6) # Should accumulate - self.assertEqual(progress.classifications, 5) # Should accumulate - self.assertEqual(progress.captures, 1) - - def test_counts_persist_across_stages(self): - """Test that detection and classification counts persist across different stages.""" - self._init_and_verify(self.image_ids) - - # Add counts during process stage - progress_process = self.manager._commit_update({"img1"}, "process", detections_count=3) - assert progress_process is not None - self.assertEqual(progress_process.detections, 3) - - # Verify counts are available in results stage - progress_results = self.manager._commit_update(set(), "results") - assert progress_results is not None - self.assertEqual(progress_results.detections, 3) # Should persist - self.assertEqual(progress_results.classifications, 0) - - # Add more counts in results stage - progress_results = self.manager._commit_update( - {"img2"}, "results", detections_count=1, classifications_count=5 - ) - assert progress_results is not None - self.assertEqual(progress_results.detections, 4) # Should accumulate - self.assertEqual(progress_results.classifications, 5) - - def test_cleanup_removes_count_keys(self): - """Test that cleanup removes detection and classification count keys.""" - from django.core.cache import cache - - self._init_and_verify(self.image_ids) - - # Add some counts - self.manager._commit_update( - {"img1"}, "process", detections_count=5, classifications_count=10, captures_count=2 - ) - - # Verify count keys exist - detections = cache.get(self.manager._detections_key) - classifications = cache.get(self.manager._classifications_key) - captures = cache.get(self.manager._captures_key) - self.assertEqual(detections, 5) - 
self.assertEqual(classifications, 10) - self.assertEqual(captures, 2) - - # Cleanup - self.manager.cleanup() - - # Verify count keys are gone - detections = cache.get(self.manager._detections_key) - classifications = cache.get(self.manager._classifications_key) - captures = cache.get(self.manager._captures_key) - self.assertIsNone(detections) - self.assertIsNone(classifications) - self.assertIsNone(captures) - def test_failed_image_tracking(self): """Test basic failed image tracking with no double-counting on retries.""" self._init_and_verify(self.image_ids) From b6c3c6a380aea4f296fe9ba377511bb333218eee Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 13 Feb 2026 09:05:31 -0800 Subject: [PATCH 18/25] small simplification --- ami/jobs/tasks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 35c09ed3c..0abf85dae 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -248,9 +248,9 @@ def _update_job_progress( ) # Add new counts to existing counts - new_detections = state_params.get("detections", 0) or 0 - new_classifications = state_params.get("classifications", 0) or 0 - new_captures = state_params.get("captures", 0) or 0 + new_detections = state_params.get("detections", 0) + new_classifications = state_params.get("classifications", 0) + new_captures = state_params.get("captures", 0) state_params["detections"] = current_detections + new_detections state_params["classifications"] = current_classifications + new_classifications From b15024fd914e101e2c15dfd205c8c1aea3fe38e0 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 13 Feb 2026 11:46:32 -0800 Subject: [PATCH 19/25] Reset counts to 0 on reset --- ami/jobs/models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 1397410c8..2b2f2b4bd 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -222,6 +222,10 @@ def reset(self, status: JobState = JobState.CREATED): for stage in self.stages: stage.progress = 0 stage.status = status + # Reset numeric param values to 0 + for param in stage.params: + if isinstance(param.value, (int, float)): + param.value = 0 def is_complete(self) -> bool: """ From b2e4a728911e8ed8b41f1f3b8d0c275da3599c79 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 16 Feb 2026 20:56:41 -0800 Subject: [PATCH 20/25] chore: remove local planning docs from PR branch Co-Authored-By: Claude --- docs/claude/nats-todo.md | 159 ------------------ .../planning/pr-trackcounts-next-session.md | 67 -------- 2 files changed, 226 deletions(-) delete mode 100644 docs/claude/nats-todo.md delete mode 100644 docs/claude/planning/pr-trackcounts-next-session.md diff --git a/docs/claude/nats-todo.md b/docs/claude/nats-todo.md deleted file mode 100644 index 127ede36a..000000000 --- a/docs/claude/nats-todo.md +++ /dev/null @@ -1,159 +0,0 @@ -# NATS Infrastructure TODO - -Tracked improvements for the NATS JetStream setup used by async ML pipeline jobs. - -## Urgent - -### Add `NATS_URL` to worker-2 env file - -- **Status:** DONE (env var added, container reloaded, connection verified) -- **Root cause of job 2226 failure:** worker-2 was missing `NATS_URL` in `.envs/.production/.django`, so it defaulted to `nats://localhost:4222`. Every NATS ack from worker-2 failed with `Connect call failed ('127.0.0.1', 4222)`. 
-- **Fix applied in code:** Changed default in `config/settings/base.py:268` from `nats://localhost:4222` to `nats://nats:4222` (matches the hostname mapped via `extra_hosts` in all compose files). -- **Still needed on server:** - ```bash - ssh ami-cc "ssh ami-worker-2 'echo NATS_URL=nats://nats:4222 >> ~/ami-platform/.envs/.production/.django'" - ssh ami-cc "ssh ami-worker-2 'cd ~/ami-platform && docker compose -f docker-compose.worker.yml restart celeryworker'" - ``` - -## Error Handling - -### Don't retry permanent errors - -- **File:** `ami/ml/orchestration/nats_queue.py:118` -- **Current:** `max_deliver=5` retries every failed message, including permanent errors (404 image not found, malformed data, etc.) -- **Problem:** NATS has no way to distinguish transient vs permanent failures. If a task fails because the image URL is broken, it will be redelivered 5 times, wasting processing service time. -- **Proposed fix:** The error handling should happen in the celery task (`process_nats_pipeline_result`) and in the processing service, not in NATS redelivery. If the processing service returns an error result, the celery task should ack the NATS message (removing it from the queue) and record the error on the job. NATS redelivery should only cover the case where a consumer crashes mid-processing (no result posted at all). -- **Consider:** Reducing `max_deliver` to 2-3 since the only legitimate redelivery scenario is consumer crash/timeout, not application errors. - -### Detect and surface exhausted messages (dead letters) - -- **Current:** When a message hits `max_deliver`, NATS silently drops it. The job hangs forever with remaining images never processed and no error shown to the user. -- **Problem:** There's no feedback loop. The `process_nats_pipeline_result` celery task only runs when the processing service posts a result. If NATS stops delivering a message (because it hit `max_deliver`), no celery task fires, no log is written, and the job just stalls. -- **Proposed approach — poll consumer state from the celery job:** - The `run_job` celery task currently returns immediately after queuing images. Instead, it could poll the NATS consumer state periodically until the job completes or stalls: - - ```python - # In the run_job task or a separate watchdog task: - async with TaskQueueManager() as manager: - info = await js.consumer_info(stream_name, consumer_name) - delivered = info.num_delivered - ack_floor = info.ack_floor.stream_seq - pending = info.num_pending - ack_pending = info.num_ack_pending - - if pending == 0 and ack_pending == 0 and ack_floor < total_images: - # Messages exhausted max_deliver — they're dead - dead_count = total_images - ack_floor - job.logger.error( - f"{dead_count} tasks exceeded max delivery attempts " - f"and will not be retried. {ack_floor}/{total_images} completed." - ) - job.update_status(JobState.FAILURE) - ``` - - This would surface a clear error in the job logs visible in Django admin. - -- **Alternative — NATS advisory subscription:** - Subscribe to `$JS.EVENT.ADVISORY.CONSUMER.MAX_DELIVERIES.job_{id}.*` and log each dead message individually. More complex but gives per-message visibility. -- **Where to implement:** Either as a polling loop in `run_job` (simplest), or as a separate Celery Beat task that checks all active async jobs. 
-- **Files:** `ami/jobs/tasks.py` (run_job or new watchdog task), `ami/ml/orchestration/nats_queue.py` (add `get_consumer_info` method) - -## Infrastructure - -### Review NATS compose file on ami-redis-1 - -- **Location:** `docker-compose.yml` on ami-redis-1 -- **Current config is mostly good:** ports exposed (4222, 8222), healthcheck configured, restart=always, JetStream enabled -- **Missing: Persistent volume.** JetStream stores data in `/tmp/nats/jetstream` (container temp dir). Server logs warn: `Temporary storage directory used, data could be lost on system reboot`. Add a volume mount: - ```yaml - nats: - image: nats:2.10-alpine - volumes: - - nats-data:/data/jetstream - command: ["-js", "-m", "8222", "-sd", "/data/jetstream"] - volumes: - nats-data: - ``` -- **Consider:** Adding memory/storage limits to JetStream config (`-js --max_mem_store`, `--max_file_store`) to prevent unbounded growth -- **Consider:** Adding NATS config file instead of CLI flags for more control (auth, logging level, connection limits) - -### Clean up stale streams - -- **Current:** 95 streams exist on the server, all but one are empty (from old jobs) -- **Cleanup runs on job completion** (`cleanup_async_job_resources` in `ami/ml/orchestration/jobs.py`), but only if the job fully completes. Failed/stuck jobs leave orphan streams. -- **Proposed fix:** Add a periodic Celery Beat task to clean up streams older than 24h (matching the `max_age=86400` retention on streams). Or clean up streams for jobs that are in a final state (SUCCESS, FAILURE, REVOKED). - -### Expose NATS monitoring for dashboard access - -- **Port 8222 is already exposed** on ami-redis-1, so `http://192.168.123.176:8222` should work from the VPN -- **For browser dashboard** (https://natsdashboard.com/): Needs the monitoring endpoint reachable from your browser. Use SSH tunnel if not on VPN: - ```bash - ssh -L 8222:localhost:8222 ami-cc -t "ssh -L 8222:localhost:8222 ami-redis-1" - ``` - Then open https://natsdashboard.com/ with server URL `http://localhost:8222` - -## Reliability - -### Connection handling in ack path - -- **Status:** DONE (PR #1130, branch `carlosg/natsconn` on `uw-ssec` remote) -- **What was done:** Added `retry_on_connection_error` decorator with exponential backoff. Replaced connection pool with async context manager pattern — each `async_to_sync()` call scopes one connection. Added `reconnected_cb`/`disconnected_cb` logging callbacks. -- **Commit:** `c384199f` refactor: simplify NATS connection handling — keep retry decorator, drop pool - -### `check_processing_services_online` causing worker instability - -- **Observed on both ami-live and worker-2:** This periodic Beat task hits soft time limit (10s) and hard time limit (20s) on every run, causing ForkPoolWorker processes to be SIGKILL'd. -- **Service #13 "Zero Shot Detector Pipelines"** at `https://ml-zs.dev.insectai.org/` consistently times out. -- **Impact:** Worker pool instability, killed processes may have been mid-task. This could contribute to unreliable task processing. -- **Fix:** Either increase the time limit, skip known-offline services, or handle the timeout more gracefully. 
- -## Source Code References - -| File | Line | What | -| ------------------------------------ | ---------------------------------- | ---------------------------------------------------------- | -| `config/settings/base.py` | 268 | `NATS_URL` setting (default changed to `nats://nats:4222`) | -| `ami/ml/orchestration/nats_queue.py` | 97-108 | `TaskQueueManager` context manager + connection setup | -| `ami/ml/orchestration/nats_queue.py` | 42-94 | `retry_on_connection_error` decorator | -| `ami/ml/orchestration/nats_queue.py` | 191-215 | Consumer config (`max_deliver`, `ack_wait`) | -| `ami/jobs/tasks.py` | 134-149 | `_ack_task_via_nats()` | -| `ami/ml/orchestration/jobs.py` | 14-55 | `cleanup_async_job_resources()` | -| `ami/ml/tasks.py` | `check_processing_services_online` | Periodic health check (causing SIGKILL) | - -## DevOps - -### Move triage steps to `ami-devops` repo and create a Claude skill - -- Port `docs/claude/debugging/nats-triage.md` to the `ami-devops` repo -- Create a Claude Code skill (slash command) for NATS triage that can run the diagnostic commands interactively -- Skill should accept a job ID and walk through the triage steps: check job logs, inspect NATS stream state, test connectivity from workers - -## Other - -- Processing service failing on batches with different image sizes -- How can we mark an image/task as failed and say don't retry? -- Processing service still needs to batch classifications (like prevous methods) -- Nats jobs appear stuck if there are any task failures: https://antenna.insectai.org/projects/18/jobs/2228 -- If a task crashes, the whole worker seems to reset -- Then no tasks are found remaining for the job in NATS - 2026-02-09 18:23:49 [info ] No jobs found, sleeping for 5 seconds - 2026-02-09 18:23:54 [info ] Checking for jobs for pipeline panama_moths_2023 - 2026-02-09 18:23:55 [info ] Checking for jobs for pipeline panama_moths_2024 - 2026-02-09 18:23:55 [info ] Checking for jobs for pipeline quebec_vermont_moths_2023 - 2026-02-09 18:23:55 [info ] Processing job 2229 with pipeline quebec_vermont_moths_2023 - 2026-02-09 18:23:55 [info ] Worker 0/2 starting iteration for job 2229 - 2026-02-09 18:23:55 [info ] Worker 1/2 starting iteration for job 2229 - 2026-02-09 18:23:59 [info ] Worker 0: No more tasks for job 2229 - 2026-02-09 18:23:59 [info ] Worker 0: Iterator finished - 2026-02-09 18:24:03 [info ] Worker 1: No more tasks for job 2229 - 2026-02-09 18:24:03 [info ] Worker 1: Iterator finished - 2026-02-09 18:24:03 [info ] Done, detections: 0. Detecting time: 0.0, classification time: 0.0, dl time: 0.0, save time: 0.0 - -- Would love some logs like "no task has been picked up in X minutes" or "last seen", etc. -- Skip jobs that hbs no tasks in the initial query - -- test in a SLUM job! yeah! in Victoria? - -- jumps around between jobs - good thing? annoying? what about when there is only one job open? -- time for time estimates - -- bring back vectors asap diff --git a/docs/claude/planning/pr-trackcounts-next-session.md b/docs/claude/planning/pr-trackcounts-next-session.md deleted file mode 100644 index d1c085df2..000000000 --- a/docs/claude/planning/pr-trackcounts-next-session.md +++ /dev/null @@ -1,67 +0,0 @@ -# Next Session: carlos/trackcounts PR Review - -## Context - -We're reviewing and cleaning up a PR that adds cumulative count tracking (detections, classifications, captures, failed images) to the async (NATS) job pipeline. The PR is on branch `carlos/trackcounts`. 
- -## Working with worktrees and remotes - -There are two related branches on different remotes/worktrees: - -| Branch | Remote | Worktree | PR | -|--------|--------|----------|-----| -| `carlos/trackcounts` | `origin` (RolnickLab) | `/home/michael/Projects/AMI/antenna` (main) | trackcounts PR | -| `carlosg/natsconn` | `uw-ssec` (uw-ssec) | `/home/michael/Projects/AMI/antenna-natsconn` | #1130 | - -**Key rules:** -- `carlosg/natsconn` pushes to `uw-ssec`, NOT `origin` -- The worktree at `antenna-natsconn` is detached HEAD — use `git push uw-ssec HEAD:refs/heads/carlosg/natsconn` -- To run tests from the worktree, mount only `ami/` over the main compose: `docker compose run --rm -v /home/michael/Projects/AMI/antenna-natsconn/ami:/app/ami:z django python manage.py test ... --keepdb` -- Container uses Pydantic v1 — use `.dict()` / `.json()`, NOT `.model_dump()` / `.model_dump_json()` - -## Completed commits (on top of main merge) - -1. `a5ff6f8a` — Rename `_get_progress` → `_commit_update` in TaskStateManager -2. `337b7fc8` — Unify `FAILURE_THRESHOLD` constant + convert TaskProgress to dataclass -3. `89111a05` — Rename `TaskProgress` → `JobStateProgress` - -## Remaining work - -### 1. Review: Is `JobStateProgress` the right abstraction? - -**Question:** What is `TaskStateManager` actually tracking, and does `JobStateProgress` reflect that correctly? - -There are two parallel progress tracking systems: -- **Django side:** `Job.progress` → `JobProgress` (Pydantic model, persisted in DB) - - Has multiple stages (collect, process, results) - - Each stage has its own status, progress percentage, and arbitrary params - - `is_complete()` checks all stages -- **Redis side:** `TaskStateManager` → `JobStateProgress` (dataclass, ephemeral in Redis) - - Tracks pending image IDs per stage (process, results) - - Tracks cumulative counts: detections, classifications, captures, failed - - Single flat object — no per-stage breakdown of counts - -The disconnect: Redis tracks **per-stage pending images** (separate pending lists for "process" and "results" stages) but returns **job-wide cumulative counts** (one detections counter, one failed set). So `JobStateProgress` is a hybrid — stage-scoped for image completion, but job-scoped for counts. - -**Should counts be per-stage?** For example, "failed" in the process stage means images that errored during ML inference. But could there be failures in the results stage too (failed to save)? The sync path tracks `request_failed_images` (process failures) separately from `failed_save_tasks` (results failures). The async path currently lumps all failures into one set. - -**Key files to read:** -- `ami/ml/orchestration/task_state.py` — `TaskStateManager` and `JobStateProgress` -- `ami/jobs/tasks.py:62-185` — `process_nats_pipeline_result` (async path, uses TaskStateManager) -- `ami/jobs/models.py:466-582` — `MLJob.process_images` (sync path, tracks counts locally) -- `ami/jobs/models.py:134-248` — `JobProgress`, `JobProgressStageDetail`, `is_complete()` - -### 2. Remove `complete_state` parameter - -Plan is written at: `docs/claude/planning/pr-trackcounts-complete-state-removal.md` - -Summary: Remove `complete_state` from `_update_job_progress`. Jobs always complete as SUCCESS. Failure counts are tracked as stage params but don't affect overall status. "Completed with failures" state deferred to future PR. 
- -**Files to modify:** -- `ami/jobs/tasks.py:97-179, 205-223` — remove complete_state logic -- `ami/ml/orchestration/tests/test_cleanup.py:164-168` — update test calls -- `ami/jobs/test_tasks.py` — update any affected tests - -### 3. PR assessment doc - -Full review written at: `docs/claude/planning/pr-trackcounts-review.md` From a15ebda30fdace9f3aa1589ebb0311777c6c6efa Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 16 Feb 2026 21:01:20 -0800 Subject: [PATCH 21/25] docs: clarify three-layer job state architecture in docstrings Explain the relationship between AsyncJobStateManager (Redis), JobProgress (JSONB), and JobState (enum). Clarify that all counts in JobStateProgress refer to source images (captures). Co-Authored-By: Claude --- ami/jobs/models.py | 14 ++++++- ami/ml/orchestration/async_job_state.py | 53 ++++++++++++++++++++----- 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 2b2f2b4bd..b4df41a04 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -109,7 +109,7 @@ def python_slugify(value: str) -> str: class JobProgressSummary(pydantic.BaseModel): - """Summary of all stages of a job""" + """Top-level status and progress for a job, shown in the UI.""" status: JobState = JobState.CREATED progress: float = 0 @@ -132,7 +132,17 @@ class JobProgressStageDetail(ConfigurableStage, JobProgressSummary): class JobProgress(pydantic.BaseModel): - """The full progress of a job and its stages.""" + """ + The user-facing progress of a job, stored as JSONB on the Job model. + + This is what the UI displays and what external APIs read. Contains named + stages ("process", "results") with per-stage params (progress percentage, + detections/classifications/captures counts, failed count). + + For async (NATS) jobs, updated by _update_job_progress() in ami/jobs/tasks.py + which copies snapshots from the internal Redis-backed AsyncJobStateManager. + For sync jobs, updated directly in MLJob.process_images(). + """ summary: JobProgressSummary stages: list[JobProgressStageDetail] diff --git a/ami/ml/orchestration/async_job_state.py b/ami/ml/orchestration/async_job_state.py index fe0f7eb74..5a300c12a 100644 --- a/ami/ml/orchestration/async_job_state.py +++ b/ami/ml/orchestration/async_job_state.py @@ -1,5 +1,27 @@ """ -Task state management for job progress tracking using Redis. +Internal progress tracking for async (NATS) job processing, backed by Redis. + +Multiple Celery workers process image batches concurrently and report progress +here using Redis for atomic updates with locking. This module is purely internal +— nothing outside the worker pipeline reads from it directly. + +How this relates to the Job model (ami/jobs/models.py): + + The **Job model** is the primary, external-facing record. It is what users see + in the UI, what external APIs interact with (listing open jobs, fetching tasks + to process), and what persists as history in the CMS. It has two relevant fields: + + - **Job.status** (JobState enum) — lifecycle state (CREATED → STARTED → SUCCESS/FAILURE) + - **Job.progress** (JobProgress JSONB) — detailed stage progress with params + like detections, classifications, captures counts + + This Redis layer exists only because concurrent NATS workers need atomic + counters that PostgreSQL row locks would serialize too aggressively. After each + batch, _update_job_progress() in ami/jobs/tasks.py copies the Redis snapshot + into the Job model, which is the source of truth for everything external. 
+ +Flow: NATS result → AsyncJobStateManager.update_state() (Redis, internal) + → _update_job_progress() (writes to Job model) → UI / API reads Job """ import logging @@ -12,13 +34,20 @@ @dataclass class JobStateProgress: - """Progress snapshot for a job stage tracked in Redis.""" + """ + Progress snapshot for a job stage, read from Redis. - remaining: int = 0 - total: int = 0 - processed: int = 0 - percentage: float = 0.0 - failed: int = 0 + All counts refer to source images (captures), not detections or occurrences. + Currently specific to ML pipeline jobs — if other job types are made available + for external processing, the unit of work ("source image") and failure semantics + may need to be generalized. + """ + + remaining: int = 0 # source images not yet processed in this stage + total: int = 0 # total source images in the job + processed: int = 0 # source images completed (success + failed) + percentage: float = 0.0 # processed / total + failed: int = 0 # source images that returned an error from the processing service def _lock_key(job_id: int) -> str: @@ -27,10 +56,14 @@ def _lock_key(job_id: int) -> str: class AsyncJobStateManager: """ - Manages job progress tracking state in Redis. + Manages real-time job progress in Redis for concurrent NATS workers. + + Each job has per-stage pending image lists and a shared failed image set. + Workers acquire a Redis lock before mutating state, ensuring atomic updates + even when multiple Celery tasks process batches in parallel. - Tracks pending images for jobs to calculate progress percentages - as workers process images asynchronously. + The results are ephemeral — _update_job_progress() in ami/jobs/tasks.py + copies each snapshot into the persistent Job.progress JSONB field. """ TIMEOUT = 86400 * 7 # 7 days in seconds From 14f6a636ddfe75236e298ba9ca5bb7be8fc924b9 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Tue, 17 Feb 2026 10:27:59 -0800 Subject: [PATCH 22/25] Add async ML backend diagrma --- docs/diagrams/async_ml_backend.md | 187 ++++++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 docs/diagrams/async_ml_backend.md diff --git a/docs/diagrams/async_ml_backend.md b/docs/diagrams/async_ml_backend.md new file mode 100644 index 000000000..d83c6ebc0 --- /dev/null +++ b/docs/diagrams/async_ml_backend.md @@ -0,0 +1,187 @@ +# Async ML Backend Architecture + +This document describes how async ML jobs work in the Antenna system, showing the flow of data between Django, Celery workers, NATS JetStream, and external ML processing workers. 
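To make the flow concrete before the component breakdown, the sketch below shows what an external ML worker's polling loop could look like against the two HTTP endpoints documented at the end of this page (`GET /jobs/{id}/tasks` and `POST /jobs/{id}/result`). This is a minimal illustration under stated assumptions, not the real processing service: the base URL, auth header, and `run_pipeline()` helper are placeholders, and only the endpoint paths, the `batch` parameter, and the `reply_subject` round-trip come from this document.

```python
import time

import requests  # assumed HTTP client; the real worker may use something else

BASE_URL = "https://antenna.example.org"  # placeholder host
JOB_ID = 123  # placeholder job id
HEADERS = {"Authorization": "Token <worker-token>"}  # assumed auth scheme


def run_pipeline(image_url: str) -> dict:
    """Placeholder for the actual ML inference (download image, detect, classify)."""
    raise NotImplementedError


def worker_loop() -> None:
    while True:
        # Reserve a batch of tasks from the per-job NATS stream via the Django API
        resp = requests.get(
            f"{BASE_URL}/jobs/{JOB_ID}/tasks", params={"batch": 10}, headers=HEADERS
        )
        tasks = resp.json().get("tasks", [])
        if not tasks:
            time.sleep(5)  # nothing pending; poll again later
            continue

        for task in tasks:
            result = run_pipeline(task["image_url"])
            # Echo reply_subject so Django/Celery can ACK the NATS message after saving
            requests.post(
                f"{BASE_URL}/jobs/{JOB_ID}/result",
                json={"reply_subject": task["reply_subject"], "result": result},
                headers=HEADERS,
            )
```

Because tasks auto-requeue after the visibility timeout described below, a real worker must post each result (or fail fast) well within that window.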
+ +## System Components + +- **Django**: Web application serving the REST API +- **Postgres**: Database for persistent storage +- **Celery Worker**: Background task processor for job orchestration +- **NATS JetStream**: Distributed task queue with acknowledgment support +- **ML Worker**: External processing service that runs ML models on images +- **Redis**: State management for job progress tracking + +## Async Job Flow + +```mermaid +sequenceDiagram + autonumber + + participant User + participant Django + participant Celery + participant Redis + participant NATS + participant Worker as ML Worker + participant DB as Postgres + + %% Job Creation and Setup + User->>Django: POST /jobs/ (create job) + Django->>DB: Create Job instance + User->>Django: POST /jobs/{id}/run/ + Django->>Celery: enqueue run_job.delay(job_id) + Note over Django,Celery: Celery task queued + + %% Job Execution and Image Queuing + Celery->>Celery: Collect images to process + Note over Celery: Set dispatch_mode = ASYNC_API + + Celery->>Redis: Initialize job state (image IDs) + + loop For each image + Celery->>NATS: publish_task(job_id, PipelineProcessingTask) + Note over NATS: Task contains:
- image_id
- image_url
- reply_subject (for ACK) + end + + Note over Celery: Celery task completes
(images queued, not processed) + + %% Worker Processing Loop + loop Worker polling loop + Worker->>Django: GET /jobs/{id}/tasks?batch=10 + Django->>NATS: reserve_task(job_id) x10 + NATS-->>Django: PipelineProcessingTask[] (with reply_subjects) + Django-->>Worker: {"tasks": [...]} + + %% Process each task + loop For each task + Worker->>Worker: Download image from image_url + Worker->>Worker: Run ML model (detection/classification) + Worker->>Django: POST /jobs/{id}/result/
PipelineTaskResult{
reply_subject,
result: PipelineResultsResponse
} + + %% Result Processing + Django->>Celery: process_nats_pipeline_result.delay(
job_id, result_data, reply_subject) + Django-->>Worker: {"status": "queued", "task_id": "..."} + + %% Celery processes result + Celery->>Redis: Acquire lock for job + Celery->>Redis: Update pending images (remove processed) + Celery->>Redis: Calculate progress percentage + Redis-->>Celery: JobStateProgress + + Celery->>DB: Update job.progress (process stage) + Celery->>DB: Save detections, classifications + Celery->>NATS: acknowledge_task(reply_subject) + Note over NATS: Task removed from queue
(won't be redelivered) + + Celery->>Redis: Update pending images (results stage) + Celery->>DB: Update job.progress (results stage) + Celery->>Redis: Release lock + end + end +``` + +## Key Design Decisions + +### 1. Asynchronous Task Queue (NATS JetStream) + +- **Why NATS?** Supports disconnected pull model - workers don't need persistent connections +- **Visibility Timeout (TTR)**: 300 seconds (5 minutes) - tasks auto-requeue if not ACK'd +- **Max Retries**: 5 attempts before giving up on a task +- **Per-Job Streams**: Each job gets its own stream (`job_{job_id}`) for isolation + +### 2. Redis-Based State Management + +- **Purpose**: Track pending images across distributed workers +- **Atomicity**: Uses distributed locks to prevent race conditions +- **Lock Duration**: 360 seconds (matches Celery task timeout) +- **Cleanup**: Automatic cleanup when job completes + +### 3. Reply Subject for Acknowledgment + +- NATS generates unique `reply_subject` for each reserved task +- Worker receives `reply_subject` in task data +- Worker includes `reply_subject` in result POST +- Celery acknowledges via NATS after successful save + +This pattern enables: + +- Workers don't need direct NATS access +- HTTP-only communication for workers +- Proper task acknowledgment through Django API + +### 4. Error Handling + +**Worker Errors:** + +- Worker posts `PipelineResultsError` instead of `PipelineResultsResponse` +- Error is logged but task is still ACK'd (prevents infinite retries for bad data) +- Failed images tracked separately in Redis + +**Database Errors:** + +- If `save_results()` fails, task is NOT ACK'd +- NATS will redeliver after visibility timeout +- Celery task has no retries (relies on NATS retry mechanism) + +**Job Cancellation:** + +- Celery task terminated immediately +- NATS stream and consumer deleted +- Redis state cleaned up + +## API Endpoints + +### GET /jobs/{id}/tasks + +Worker endpoint to fetch tasks from NATS queue. + +**Query Parameters:** + +- `batch`: Number of tasks to fetch (1-100) + +**Response:** + +```json +{ + "tasks": [ + { + "id": "123", + "image_id": "123", + "image_url": "https://minio:9000/...", + "reply_subject": "$JS.ACK.job_1.job-1-consumer.1.2.3" + } + ] +} +``` + +### POST /jobs/{id}/result + +Worker endpoint to post processing results. 
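+
+For example, a worker that has just finished one reserved task might submit its result as follows. This is a hedged sketch, not the reference worker: the field names match the request body documented below, while `BASE_URL` and the auth header are placeholders assumed for this example.
+
+```python
+# Minimal example of submitting a single result; not a complete worker.
+import requests
+
+BASE_URL = "https://antenna.example.org/api/v2"  # placeholder deployment URL
+HEADERS = {"Authorization": "Token <worker-token>"}  # assumed auth scheme
+
+
+def post_result(job_id: int, reply_subject: str, result: dict) -> dict:
+    """POST one pipeline result, echoing the task's reply_subject so the
+    backend can acknowledge the NATS message once the data is saved."""
+    resp = requests.post(
+        f"{BASE_URL}/jobs/{job_id}/result/",
+        json={"reply_subject": reply_subject, "result": result},
+        headers=HEADERS,
+        timeout=60,
+    )
+    resp.raise_for_status()
+    return resp.json()
+```
+
+The formal request and response schemas follow.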
+ +**Request Body:** + +```json +{ + "reply_subject": "$JS.ACK.job_1.job-1-consumer.1.2.3", + "result": { + // PipelineResultsResponse or PipelineResultsError + } +} +``` + +**Response:** + +```json +{ + "status": "accepted", + "job_id": 1, + "results_queued": 1, + "tasks": [ + { + "reply_subject": "$JS.ACK.job_1.job-1-consumer.1.2.3", + "status": "queued", + "task_id": "celery-task-uuid" + } + ] +} +``` From 1c559ce02c48cb19c10fc486ce2eca1fcde2a246 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Tue, 17 Feb 2026 10:42:23 -0800 Subject: [PATCH 23/25] cleanup --- docs/diagrams/async_ml_backend.md | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/docs/diagrams/async_ml_backend.md b/docs/diagrams/async_ml_backend.md index d83c6ebc0..a6217d388 100644 --- a/docs/diagrams/async_ml_backend.md +++ b/docs/diagrams/async_ml_backend.md @@ -116,6 +116,7 @@ This pattern enables: - Worker posts `PipelineResultsError` instead of `PipelineResultsResponse` - Error is logged but task is still ACK'd (prevents infinite retries for bad data) - Failed images tracked separately in Redis +- If worker crashes and never reports a result or error NATS will redeliver after visibility timeout **Database Errors:** @@ -137,7 +138,7 @@ Worker endpoint to fetch tasks from NATS queue. **Query Parameters:** -- `batch`: Number of tasks to fetch (1-100) +- `batch`: Number of tasks to fetch **Response:** @@ -168,20 +169,3 @@ Worker endpoint to post processing results. } } ``` - -**Response:** - -```json -{ - "status": "accepted", - "job_id": 1, - "results_queued": 1, - "tasks": [ - { - "reply_subject": "$JS.ACK.job_1.job-1-consumer.1.2.3", - "status": "queued", - "task_id": "celery-task-uuid" - } - ] -} -``` From dcd27c038ca0e38ddb8302a17151d6c370c83f25 Mon Sep 17 00:00:00 2001 From: carlosgjs Date: Tue, 17 Feb 2026 10:44:41 -0800 Subject: [PATCH 24/25] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/diagrams/async_ml_backend.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/diagrams/async_ml_backend.md b/docs/diagrams/async_ml_backend.md index a6217d388..3c2cd837b 100644 --- a/docs/diagrams/async_ml_backend.md +++ b/docs/diagrams/async_ml_backend.md @@ -40,7 +40,7 @@ sequenceDiagram loop For each image Celery->>NATS: publish_task(job_id, PipelineProcessingTask) - Note over NATS: Task contains:
- image_id
- image_url
- reply_subject (for ACK) + Note over NATS: Task contains:
- id
- image_id
- image_url

Note: reply_subject is added when the task is reserved via reserve_task(). end Note over Celery: Celery task completes
(images queued, not processed) @@ -60,7 +60,7 @@ sequenceDiagram %% Result Processing Django->>Celery: process_nats_pipeline_result.delay(
job_id, result_data, reply_subject) - Django-->>Worker: {"status": "queued", "task_id": "..."} + Django-->>Worker: {"status": "accepted", "job_id": "...", "results_queued": ..., "tasks": [...]} %% Celery processes result Celery->>Redis: Acquire lock for job @@ -155,7 +155,7 @@ Worker endpoint to fetch tasks from NATS queue. } ``` -### POST /jobs/{id}/result +### POST /api/v2/jobs/{id}/result/ Worker endpoint to post processing results. From 0830b789e2df91696e65d167cb77752a5b018e8d Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Tue, 17 Feb 2026 10:47:05 -0800 Subject: [PATCH 25/25] update path --- docs/diagrams/async_ml_backend.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/diagrams/async_ml_backend.md b/docs/diagrams/async_ml_backend.md index 3c2cd837b..8de0e58c1 100644 --- a/docs/diagrams/async_ml_backend.md +++ b/docs/diagrams/async_ml_backend.md @@ -132,7 +132,7 @@ This pattern enables: ## API Endpoints -### GET /jobs/{id}/tasks +### GET /api/v2/jobs/{id}/tasks Worker endpoint to fetch tasks from NATS queue.
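
Putting the two worker-facing endpoints together, a pull-based worker needs nothing beyond plain HTTP. The sketch below is illustrative only — the paths follow the `/api/v2/...` routes documented above, while the base URL, auth header, and `run_pipeline()` are placeholders — and a real worker would also handle batch sizing, retries, job-completion checks, and `PipelineResultsError` reporting.

```python
# Illustrative polling loop for an external ML worker; not the reference implementation.
import time

import requests

BASE_URL = "https://antenna.example.org/api/v2"  # placeholder deployment URL
HEADERS = {"Authorization": "Token <worker-token>"}  # assumed auth scheme


def run_pipeline(image_url: str) -> dict:
    """Placeholder: download the image, run detection/classification, and
    return a PipelineResultsResponse-shaped dict."""
    raise NotImplementedError


def poll_job(job_id: int, batch: int = 10) -> None:
    """Fetch reserved tasks, process them, and post results until the queue drains."""
    while True:
        resp = requests.get(
            f"{BASE_URL}/jobs/{job_id}/tasks",
            params={"batch": batch},
            headers=HEADERS,
            timeout=30,
        )
        resp.raise_for_status()
        tasks = resp.json().get("tasks", [])
        if not tasks:
            time.sleep(5)  # queue empty (or job finished); back off and re-check
            continue
        for task in tasks:
            result = run_pipeline(task["image_url"])
            # Echo the reply_subject so the backend can ACK the NATS message
            # only after the results have been saved.
            requests.post(
                f"{BASE_URL}/jobs/{job_id}/result/",
                json={"reply_subject": task["reply_subject"], "result": result},
                headers=HEADERS,
                timeout=60,
            ).raise_for_status()
```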