diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 8e790cdcb..b4df41a04 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -109,7 +109,7 @@ def python_slugify(value: str) -> str: class JobProgressSummary(pydantic.BaseModel): - """Summary of all stages of a job""" + """Top-level status and progress for a job, shown in the UI.""" status: JobState = JobState.CREATED progress: float = 0 @@ -132,7 +132,17 @@ class JobProgressStageDetail(ConfigurableStage, JobProgressSummary): class JobProgress(pydantic.BaseModel): - """The full progress of a job and its stages.""" + """ + The user-facing progress of a job, stored as JSONB on the Job model. + + This is what the UI displays and what external APIs read. Contains named + stages ("process", "results") with per-stage params (progress percentage, + detections/classifications/captures counts, failed count). + + For async (NATS) jobs, updated by _update_job_progress() in ami/jobs/tasks.py + which copies snapshots from the internal Redis-backed AsyncJobStateManager. + For sync jobs, updated directly in MLJob.process_images(). + """ summary: JobProgressSummary stages: list[JobProgressStageDetail] @@ -222,6 +232,10 @@ def reset(self, status: JobState = JobState.CREATED): for stage in self.stages: stage.progress = 0 stage.status = status + # Reset numeric param values to 0 + for param in stage.params: + if isinstance(param.value, (int, float)): + param.value = 0 def is_complete(self) -> bool: """ @@ -561,7 +575,8 @@ def process_images(cls, job, images): job.logger.info(f"All tasks completed for job {job.pk}") - FAILURE_THRESHOLD = 0.5 + from ami.jobs.tasks import FAILURE_THRESHOLD + if image_count and (percent_successful < FAILURE_THRESHOLD): job.progress.update_stage("process", status=JobState.FAILURE) job.save() diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 3548d0ea5..0abf85dae 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -3,18 +3,25 @@ import logging import time from collections.abc import Callable +from typing import TYPE_CHECKING from asgiref.sync import async_to_sync from celery.signals import task_failure, task_postrun, task_prerun from django.db import transaction +from ami.ml.orchestration.async_job_state import AsyncJobStateManager from ami.ml.orchestration.nats_queue import TaskQueueManager -from ami.ml.orchestration.task_state import TaskStateManager from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse from ami.tasks import default_soft_time_limit, default_time_limit from config import celery_app +if TYPE_CHECKING: + from ami.jobs.models import JobState + logger = logging.getLogger(__name__) +# Minimum success rate. Jobs with fewer than this fraction of images +# processed successfully are marked as failed. Also used in MLJob.process_images(). 
+FAILURE_THRESHOLD = 0.5 @celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit) @@ -59,23 +66,27 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub result_data: Dictionary containing the pipeline result reply_subject: NATS reply subject for acknowledgment """ - from ami.jobs.models import Job # avoid circular import + from ami.jobs.models import Job, JobState # avoid circular import _, t = log_time() # Validate with Pydantic - check for error response first + error_result = None if "error" in result_data: error_result = PipelineResultsError(**result_data) processed_image_ids = {str(error_result.image_id)} if error_result.image_id else set() - logger.error(f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}") + failed_image_ids = processed_image_ids # Same as processed for errors pipeline_result = None else: pipeline_result = PipelineResultsResponse(**result_data) processed_image_ids = {str(img.id) for img in pipeline_result.source_images} + failed_image_ids = set() # No failures for successful results - state_manager = TaskStateManager(job_id) + state_manager = AsyncJobStateManager(job_id) - progress_info = state_manager.update_state(processed_image_ids, stage="process", request_id=self.request.id) + progress_info = state_manager.update_state( + processed_image_ids, stage="process", request_id=self.request.id, failed_image_ids=failed_image_ids + ) if not progress_info: logger.warning( f"Another task is already processing results for job {job_id}. " @@ -84,16 +95,31 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub raise self.retry(countdown=5, max_retries=10) try: - _update_job_progress(job_id, "process", progress_info.percentage) + complete_state = JobState.SUCCESS + if progress_info.total > 0 and (progress_info.failed / progress_info.total) > FAILURE_THRESHOLD: + complete_state = JobState.FAILURE + _update_job_progress( + job_id, + "process", + progress_info.percentage, + complete_state=complete_state, + processed=progress_info.processed, + remaining=progress_info.remaining, + failed=progress_info.failed, + ) _, t = t(f"TIME: Updated job {job_id} progress in PROCESS stage progress to {progress_info.percentage*100}%") job = Job.objects.get(pk=job_id) job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}") job.logger.info( f" Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed " - f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {len(processed_image_ids)} just " - "processed" + f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {progress_info.failed} failed, " + f"{len(processed_image_ids)} just processed" ) + if error_result: + job.logger.error( + f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}" + ) except Job.DoesNotExist: # don't raise and ack so that we don't retry since the job doesn't exists logger.error(f"Job {job_id} not found") @@ -102,6 +128,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub try: # Save to database (this is the slow operation) + detections_count, classifications_count, captures_count = 0, 0, 0 if pipeline_result: # should never happen since otherwise we could not be processing results here assert job.pipeline is not None, "Job pipeline is None" @@ -112,10 +139,19 @@ def process_nats_pipeline_result(self, 
job_id: int, result_data: dict, reply_sub f"Saved pipeline results to database with {len(pipeline_result.detections)} detections" f", percentage: {progress_info.percentage*100}%" ) + # Calculate detection and classification counts from this result + detections_count = len(pipeline_result.detections) + classifications_count = sum(len(detection.classifications) for detection in pipeline_result.detections) + captures_count = len(pipeline_result.source_images) _ack_task_via_nats(reply_subject, job.logger) # Update job stage with calculated progress - progress_info = state_manager.update_state(processed_image_ids, stage="results", request_id=self.request.id) + + progress_info = state_manager.update_state( + processed_image_ids, + stage="results", + request_id=self.request.id, + ) if not progress_info: logger.warning( @@ -123,7 +159,21 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub f"Retrying task {self.request.id} in 5 seconds..." ) raise self.retry(countdown=5, max_retries=10) - _update_job_progress(job_id, "results", progress_info.percentage) + + # update complete state based on latest progress info after saving results + complete_state = JobState.SUCCESS + if progress_info.total > 0 and (progress_info.failed / progress_info.total) > FAILURE_THRESHOLD: + complete_state = JobState.FAILURE + + _update_job_progress( + job_id, + "results", + progress_info.percentage, + complete_state=complete_state, + detections=detections_count, + classifications=classifications_count, + captures=captures_count, + ) except Exception as e: job.logger.error( @@ -149,19 +199,72 @@ async def ack_task(): # Don't fail the task if ACK fails - data is already saved -def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None: +def _get_current_counts_from_job_progress(job, stage: str) -> tuple[int, int, int]: + """ + Get current detections, classifications, and captures counts from job progress. 
+ + Args: + job: The Job instance + stage: The stage name to read counts from + + Returns: + Tuple of (detections, classifications, captures) counts, defaulting to 0 if not found + """ + try: + stage_obj = job.progress.get_stage(stage) + + # Initialize defaults + detections = 0 + classifications = 0 + captures = 0 + + # Search through the params list for our count values + for param in stage_obj.params: + if param.key == "detections": + detections = param.value or 0 + elif param.key == "classifications": + classifications = param.value or 0 + elif param.key == "captures": + captures = param.value or 0 + + return detections, classifications, captures + except (ValueError, AttributeError): + # Stage doesn't exist or doesn't have these attributes yet + return 0, 0, 0 + + +def _update_job_progress( + job_id: int, stage: str, progress_percentage: float, complete_state: "JobState", **state_params +) -> None: from ami.jobs.models import Job, JobState # avoid circular import with transaction.atomic(): job = Job.objects.select_for_update().get(pk=job_id) + + # For results stage, accumulate detections/classifications/captures counts + if stage == "results": + current_detections, current_classifications, current_captures = _get_current_counts_from_job_progress( + job, stage + ) + + # Add new counts to existing counts + new_detections = state_params.get("detections", 0) + new_classifications = state_params.get("classifications", 0) + new_captures = state_params.get("captures", 0) + + state_params["detections"] = current_detections + new_detections + state_params["classifications"] = current_classifications + new_classifications + state_params["captures"] = current_captures + new_captures + job.progress.update_stage( stage, - status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, + status=complete_state if progress_percentage >= 1.0 else JobState.STARTED, progress=progress_percentage, + **state_params, ) if job.progress.is_complete(): - job.status = JobState.SUCCESS - job.progress.summary.status = JobState.SUCCESS + job.status = complete_state + job.progress.summary.status = complete_state job.finished_at = datetime.datetime.now() # Use naive datetime in local time job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%") job.save() diff --git a/ami/jobs/test_tasks.py b/ami/jobs/test_tasks.py index fc291c8ba..b37940cdd 100644 --- a/ami/jobs/test_tasks.py +++ b/ami/jobs/test_tasks.py @@ -17,7 +17,7 @@ from ami.jobs.tasks import process_nats_pipeline_result from ami.main.models import Detection, Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline -from ami.ml.orchestration.task_state import TaskStateManager, _lock_key +from ami.ml.orchestration.async_job_state import AsyncJobStateManager, _lock_key from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse, SourceImageResponse from ami.users.models import User @@ -64,7 +64,7 @@ def setUp(self): # Initialize state manager self.image_ids = [str(img.pk) for img in self.images] - self.state_manager = TaskStateManager(self.job.pk) + self.state_manager = AsyncJobStateManager(self.job.pk) self.state_manager.initialize_job(self.image_ids) def tearDown(self): @@ -90,7 +90,7 @@ def _assert_progress_updated( self, job_id: int, expected_processed: int, expected_total: int, stage: str = "process" ): """Assert TaskStateManager state is correct.""" - manager = TaskStateManager(job_id) + manager = AsyncJobStateManager(job_id) progress = manager.get_progress(stage) 
self.assertIsNotNone(progress, f"Progress not found for stage '{stage}'") self.assertEqual(progress.processed, expected_processed) @@ -157,7 +157,7 @@ def test_process_nats_pipeline_result_error_no_image_id(self, mock_manager_class # Assert: Progress was NOT updated (empty set of processed images) # Since no image_id was provided, processed_image_ids = set() - manager = TaskStateManager(self.job.pk) + manager = AsyncJobStateManager(self.job.pk) progress = manager.get_progress("process") self.assertEqual(progress.processed, 0) # No images marked as processed @@ -208,7 +208,7 @@ def test_process_nats_pipeline_result_mixed_results(self, mock_manager_class): ) # Assert: All 3 images marked as processed in TaskStateManager - manager = TaskStateManager(self.job.pk) + manager = AsyncJobStateManager(self.job.pk) process_progress = manager.get_progress("process") self.assertIsNotNone(process_progress) self.assertEqual(process_progress.processed, 3) @@ -266,7 +266,7 @@ def test_process_nats_pipeline_result_error_concurrent_locking(self, mock_manage ) # Assert: Progress was NOT updated (lock not acquired) - manager = TaskStateManager(self.job.pk) + manager = AsyncJobStateManager(self.job.pk) progress = manager.get_progress("process") self.assertEqual(progress.processed, 0) @@ -342,7 +342,7 @@ def setUp(self): ) # Initialize state manager - state_manager = TaskStateManager(self.job.pk) + state_manager = AsyncJobStateManager(self.job.pk) state_manager.initialize_job([str(self.image.pk)]) def tearDown(self): diff --git a/ami/ml/orchestration/async_job_state.py b/ami/ml/orchestration/async_job_state.py new file mode 100644 index 000000000..5a300c12a --- /dev/null +++ b/ami/ml/orchestration/async_job_state.py @@ -0,0 +1,212 @@ +""" +Internal progress tracking for async (NATS) job processing, backed by Redis. + +Multiple Celery workers process image batches concurrently and report progress +here using Redis for atomic updates with locking. This module is purely internal +— nothing outside the worker pipeline reads from it directly. + +How this relates to the Job model (ami/jobs/models.py): + + The **Job model** is the primary, external-facing record. It is what users see + in the UI, what external APIs interact with (listing open jobs, fetching tasks + to process), and what persists as history in the CMS. It has two relevant fields: + + - **Job.status** (JobState enum) — lifecycle state (CREATED → STARTED → SUCCESS/FAILURE) + - **Job.progress** (JobProgress JSONB) — detailed stage progress with params + like detections, classifications, captures counts + + This Redis layer exists only because concurrent NATS workers need atomic + counters that PostgreSQL row locks would serialize too aggressively. After each + batch, _update_job_progress() in ami/jobs/tasks.py copies the Redis snapshot + into the Job model, which is the source of truth for everything external. + +Flow: NATS result → AsyncJobStateManager.update_state() (Redis, internal) + → _update_job_progress() (writes to Job model) → UI / API reads Job +""" + +import logging +from dataclasses import dataclass + +from django.core.cache import cache + +logger = logging.getLogger(__name__) + + +@dataclass +class JobStateProgress: + """ + Progress snapshot for a job stage, read from Redis. + + All counts refer to source images (captures), not detections or occurrences. 
+ Currently specific to ML pipeline jobs — if other job types are made available + for external processing, the unit of work ("source image") and failure semantics + may need to be generalized. + """ + + remaining: int = 0 # source images not yet processed in this stage + total: int = 0 # total source images in the job + processed: int = 0 # source images completed (success + failed) + percentage: float = 0.0 # processed / total + failed: int = 0 # source images that returned an error from the processing service + + +def _lock_key(job_id: int) -> str: + return f"job:{job_id}:process_results_lock" + + +class AsyncJobStateManager: + """ + Manages real-time job progress in Redis for concurrent NATS workers. + + Each job has per-stage pending image lists and a shared failed image set. + Workers acquire a Redis lock before mutating state, ensuring atomic updates + even when multiple Celery tasks process batches in parallel. + + The results are ephemeral — _update_job_progress() in ami/jobs/tasks.py + copies each snapshot into the persistent Job.progress JSONB field. + """ + + TIMEOUT = 86400 * 7 # 7 days in seconds + STAGES = ["process", "results"] + + def __init__(self, job_id: int): + """ + Initialize the async job state manager for a specific job. + + Args: + job_id: The job primary key + """ + self.job_id = job_id + self._pending_key = f"job:{job_id}:pending_images" + self._total_key = f"job:{job_id}:pending_images_total" + self._failed_key = f"job:{job_id}:failed_images" + + def initialize_job(self, image_ids: list[str]) -> None: + """ + Initialize job tracking with a list of image IDs to process. + + Args: + image_ids: List of image IDs that need to be processed + """ + for stage in self.STAGES: + cache.set(self._get_pending_key(stage), image_ids, timeout=self.TIMEOUT) + + # Failed images are tracked in a single set shared across stages + cache.set(self._failed_key, set(), timeout=self.TIMEOUT) + + cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) + + def _get_pending_key(self, stage: str) -> str: + return f"{self._pending_key}:{stage}" + + def update_state( + self, + processed_image_ids: set[str], + stage: str, + request_id: str, + failed_image_ids: set[str] | None = None, + ) -> None | JobStateProgress: + """ + Update the async job state with newly processed images. 
+ + Args: + processed_image_ids: Set of image IDs that have just been processed + stage: The processing stage ("process" or "results") + request_id: Unique identifier for this processing request + failed_image_ids: Set of image IDs that failed processing (optional) + + Returns: + JobStateProgress snapshot for the stage, or None if the lock could not be acquired + """ + # Create a unique lock key for this job + lock_key = _lock_key(self.job_id) + lock_timeout = 360 # 6 minutes (matches task time_limit) + lock_acquired = cache.add(lock_key, request_id, timeout=lock_timeout) + if not lock_acquired: + return None + + try: + # Update progress tracking in Redis + progress_info = self._commit_update(processed_image_ids, stage, failed_image_ids) + return progress_info + finally: + # Always release the lock when done + current_lock_value = cache.get(lock_key) + # Only delete if we still own the lock (prevents race condition) + if current_lock_value == request_id: + cache.delete(lock_key) + logger.debug(f"Released lock for job {self.job_id}, task {request_id}") + + def get_progress(self, stage: str) -> JobStateProgress | None: + """Read-only progress snapshot for the given stage. Does not acquire a lock or mutate state.""" + pending_images = cache.get(self._get_pending_key(stage)) + total_images = cache.get(self._total_key) + if pending_images is None or total_images is None: + return None + remaining = len(pending_images) + processed = total_images - remaining + percentage = float(processed) / total_images if total_images > 0 else 1.0 + failed_set = cache.get(self._failed_key) or set() + return JobStateProgress( + remaining=remaining, + total=total_images, + processed=processed, + percentage=percentage, + failed=len(failed_set), + ) + + def _commit_update( + self, + processed_image_ids: set[str], + stage: str, + failed_image_ids: set[str] | None = None, + ) -> JobStateProgress | None: + """ + Update pending images and return progress. Must be called under lock. + + Removes processed_image_ids from the pending set and persists the update. + """ + pending_images = cache.get(self._get_pending_key(stage)) + total_images = cache.get(self._total_key) + if pending_images is None or total_images is None: + return None + remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids] + assert len(pending_images) >= len(remaining_images) + cache.set(self._get_pending_key(stage), remaining_images, timeout=self.TIMEOUT) + + remaining = len(remaining_images) + processed = total_images - remaining + percentage = float(processed) / total_images if total_images > 0 else 1.0 + + # Update failed images set if provided + if failed_image_ids: + existing_failed = cache.get(self._failed_key) or set() + updated_failed = existing_failed | failed_image_ids # Union to prevent duplicates + cache.set(self._failed_key, updated_failed, timeout=self.TIMEOUT) + failed_set = updated_failed + else: + failed_set = cache.get(self._failed_key) or set() + + failed_count = len(failed_set) + + logger.info( + f"Pending images from Redis for job {self.job_id} {stage}: " + f"{remaining}/{total_images}: {percentage*100}%" + ) + + return JobStateProgress( + remaining=remaining, + total=total_images, + processed=processed, + percentage=percentage, + failed=failed_count, + ) + + def cleanup(self) -> None: + """ + Delete all Redis keys associated with this job. 
+ """ + for stage in self.STAGES: + cache.delete(self._get_pending_key(stage)) + cache.delete(self._failed_key) + cache.delete(self._total_key) diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index 9b4a577e9..ce54ecd1c 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -4,8 +4,8 @@ from ami.jobs.models import Job, JobState from ami.main.models import SourceImage +from ami.ml.orchestration.async_job_state import AsyncJobStateManager from ami.ml.orchestration.nats_queue import TaskQueueManager -from ami.ml.orchestration.task_state import TaskStateManager from ami.ml.schemas import PipelineProcessingTask logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ def cleanup_async_job_resources(job: "Job") -> bool: # Cleanup Redis state try: - state_manager = TaskStateManager(job.pk) + state_manager = AsyncJobStateManager(job.pk) state_manager.cleanup() job.logger.info(f"Cleaned up Redis state for job {job.pk}") redis_success = True @@ -88,7 +88,7 @@ def queue_images_to_nats(job: "Job", images: list[SourceImage]): tasks.append((image.pk, task)) # Store all image IDs in Redis for progress tracking - state_manager = TaskStateManager(job.pk) + state_manager = AsyncJobStateManager(job.pk) state_manager.initialize_job(image_ids) job.logger.info(f"Initialized task state tracking for {len(image_ids)} images") diff --git a/ami/ml/orchestration/task_state.py b/ami/ml/orchestration/task_state.py deleted file mode 100644 index b05760e68..000000000 --- a/ami/ml/orchestration/task_state.py +++ /dev/null @@ -1,135 +0,0 @@ -""" -Task state management for job progress tracking using Redis. -""" - -import logging -from collections import namedtuple - -from django.core.cache import cache - -logger = logging.getLogger(__name__) - - -# Define a namedtuple for a TaskProgress with the image counts -TaskProgress = namedtuple("TaskProgress", ["remaining", "total", "processed", "percentage"]) - - -def _lock_key(job_id: int) -> str: - return f"job:{job_id}:process_results_lock" - - -class TaskStateManager: - """ - Manages job progress tracking state in Redis. - - Tracks pending images for jobs to calculate progress percentages - as workers process images asynchronously. - """ - - TIMEOUT = 86400 * 7 # 7 days in seconds - STAGES = ["process", "results"] - - def __init__(self, job_id: int): - """ - Initialize the task state manager for a specific job. - - Args: - job_id: The job primary key - """ - self.job_id = job_id - self._pending_key = f"job:{job_id}:pending_images" - self._total_key = f"job:{job_id}:pending_images_total" - - def initialize_job(self, image_ids: list[str]) -> None: - """ - Initialize job tracking with a list of image IDs to process. - - Args: - image_ids: List of image IDs that need to be processed - """ - for stage in self.STAGES: - cache.set(self._get_pending_key(stage), image_ids, timeout=self.TIMEOUT) - - cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT) - - def _get_pending_key(self, stage: str) -> str: - return f"{self._pending_key}:{stage}" - - def update_state( - self, - processed_image_ids: set[str], - stage: str, - request_id: str, - ) -> None | TaskProgress: - """ - Update the task state with newly processed images. 
- - Args: - processed_image_ids: Set of image IDs that have just been processed - """ - # Create a unique lock key for this job - lock_key = _lock_key(self.job_id) - lock_timeout = 360 # 6 minutes (matches task time_limit) - lock_acquired = cache.add(lock_key, request_id, timeout=lock_timeout) - if not lock_acquired: - return None - - try: - # Update progress tracking in Redis - progress_info = self._get_progress(processed_image_ids, stage) - return progress_info - finally: - # Always release the lock when done - current_lock_value = cache.get(lock_key) - # Only delete if we still own the lock (prevents race condition) - if current_lock_value == request_id: - cache.delete(lock_key) - logger.debug(f"Released lock for job {self.job_id}, task {request_id}") - - def get_progress(self, stage: str) -> TaskProgress | None: - """Read-only progress snapshot for the given stage. Does not acquire a lock or mutate state.""" - pending_images = cache.get(self._get_pending_key(stage)) - total_images = cache.get(self._total_key) - if pending_images is None or total_images is None: - return None - remaining = len(pending_images) - processed = total_images - remaining - percentage = float(processed) / total_images if total_images > 0 else 1.0 - return TaskProgress(remaining=remaining, total=total_images, processed=processed, percentage=percentage) - - def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgress | None: - """ - Update pending images and return progress. Must be called under lock. - - Removes processed_image_ids from the pending set and persists the update. - """ - pending_images = cache.get(self._get_pending_key(stage)) - total_images = cache.get(self._total_key) - if pending_images is None or total_images is None: - return None - remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids] - assert len(pending_images) >= len(remaining_images) - cache.set(self._get_pending_key(stage), remaining_images, timeout=self.TIMEOUT) - - remaining = len(remaining_images) - processed = total_images - remaining - percentage = float(processed) / total_images if total_images > 0 else 1.0 - logger.info( - f"Pending images from Redis for job {self.job_id} {stage}: " - f"{remaining}/{total_images}: {percentage*100}%" - ) - - return TaskProgress( - remaining=remaining, - total=total_images, - processed=processed, - percentage=percentage, - ) - - def cleanup(self) -> None: - """ - Delete all Redis keys associated with this job. 
- """ - for stage in self.STAGES: - cache.delete(self._get_pending_key(stage)) - cache.delete(self._total_key) diff --git a/ami/ml/orchestration/tests/test_cleanup.py b/ami/ml/orchestration/tests/test_cleanup.py index ef8382d3d..ccdfa2c49 100644 --- a/ami/ml/orchestration/tests/test_cleanup.py +++ b/ami/ml/orchestration/tests/test_cleanup.py @@ -9,9 +9,9 @@ from ami.jobs.tasks import _update_job_progress, update_job_failure, update_job_status from ami.main.models import Project, ProjectFeatureFlags, SourceImage, SourceImageCollection from ami.ml.models import Pipeline +from ami.ml.orchestration.async_job_state import AsyncJobStateManager from ami.ml.orchestration.jobs import queue_images_to_nats from ami.ml.orchestration.nats_queue import TaskQueueManager -from ami.ml.orchestration.task_state import TaskStateManager class TestCleanupAsyncJobResources(TestCase): @@ -59,7 +59,7 @@ def _verify_resources_created(self, job_id: int): job_id: The job ID to check """ # Verify Redis keys exist - state_manager = TaskStateManager(job_id) + state_manager = AsyncJobStateManager(job_id) for stage in state_manager.STAGES: pending_key = state_manager._get_pending_key(stage) self.assertIsNotNone(cache.get(pending_key), f"Redis key {pending_key} should exist") @@ -125,7 +125,7 @@ def _verify_resources_cleaned(self, job_id: int): job_id: The job ID to check """ # Verify Redis keys are deleted - state_manager = TaskStateManager(job_id) + state_manager = AsyncJobStateManager(job_id) for stage in state_manager.STAGES: pending_key = state_manager._get_pending_key(stage) self.assertIsNone(cache.get(pending_key), f"Redis key {pending_key} should be deleted") @@ -164,9 +164,9 @@ def test_cleanup_on_job_completion(self): job = self._create_job_with_queued_images() # Simulate job completion: complete all stages (collect, process, then results) - _update_job_progress(job.pk, stage="collect", progress_percentage=1.0) - _update_job_progress(job.pk, stage="process", progress_percentage=1.0) - _update_job_progress(job.pk, stage="results", progress_percentage=1.0) + _update_job_progress(job.pk, stage="collect", progress_percentage=1.0, complete_state=JobState.SUCCESS) + _update_job_progress(job.pk, stage="process", progress_percentage=1.0, complete_state=JobState.SUCCESS) + _update_job_progress(job.pk, stage="results", progress_percentage=1.0, complete_state=JobState.SUCCESS) # Verify cleanup happened self._verify_resources_cleaned(job.pk) diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 20e0368fe..2f269a6f9 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -864,22 +864,23 @@ def setUp(self): """Set up test fixtures.""" from django.core.cache import cache - from ami.ml.orchestration.task_state import TaskStateManager + from ami.ml.orchestration.async_job_state import AsyncJobStateManager cache.clear() self.job_id = 123 - self.manager = TaskStateManager(self.job_id) + self.manager = AsyncJobStateManager(self.job_id) self.image_ids = ["img1", "img2", "img3", "img4", "img5"] def _init_and_verify(self, image_ids): """Helper to initialize job and verify initial state.""" self.manager.initialize_job(image_ids) - progress = self.manager._get_progress(set(), "process") + progress = self.manager._commit_update(set(), "process") assert progress is not None self.assertEqual(progress.total, len(image_ids)) self.assertEqual(progress.remaining, len(image_ids)) self.assertEqual(progress.processed, 0) self.assertEqual(progress.percentage, 0.0) + self.assertEqual(progress.failed, 0) return progress def test_initialize_job(self): @@ 
-888,30 +889,31 @@ def test_initialize_job(self): # Verify both stages are initialized for stage in self.manager.STAGES: - progress = self.manager._get_progress(set(), stage) + progress = self.manager._commit_update(set(), stage) assert progress is not None self.assertEqual(progress.total, len(self.image_ids)) + self.assertEqual(progress.failed, 0) def test_progress_tracking(self): """Test progress updates correctly as images are processed.""" self._init_and_verify(self.image_ids) # Process 2 images - progress = self.manager._get_progress({"img1", "img2"}, "process") + progress = self.manager._commit_update({"img1", "img2"}, "process") assert progress is not None self.assertEqual(progress.remaining, 3) self.assertEqual(progress.processed, 2) self.assertEqual(progress.percentage, 0.4) # Process 2 more images - progress = self.manager._get_progress({"img3", "img4"}, "process") + progress = self.manager._commit_update({"img3", "img4"}, "process") assert progress is not None self.assertEqual(progress.remaining, 1) self.assertEqual(progress.processed, 4) self.assertEqual(progress.percentage, 0.8) # Process last image - progress = self.manager._get_progress({"img5"}, "process") + progress = self.manager._commit_update({"img5"}, "process") assert progress is not None self.assertEqual(progress.remaining, 0) self.assertEqual(progress.processed, 5) @@ -947,20 +949,20 @@ def test_stages_independent(self): self._init_and_verify(self.image_ids) # Update process stage - self.manager._get_progress({"img1", "img2"}, "process") - progress_process = self.manager._get_progress(set(), "process") + self.manager._commit_update({"img1", "img2"}, "process") + progress_process = self.manager._commit_update(set(), "process") assert progress_process is not None self.assertEqual(progress_process.remaining, 3) # Results stage should still have all images pending - progress_results = self.manager._get_progress(set(), "results") + progress_results = self.manager._commit_update(set(), "results") assert progress_results is not None self.assertEqual(progress_results.remaining, 5) def test_empty_job(self): """Test handling of job with no images.""" self.manager.initialize_job([]) - progress = self.manager._get_progress(set(), "process") + progress = self.manager._commit_update(set(), "process") assert progress is not None self.assertEqual(progress.total, 0) self.assertEqual(progress.percentage, 1.0) # Empty job is 100% complete @@ -970,12 +972,65 @@ def test_cleanup(self): self._init_and_verify(self.image_ids) # Verify keys exist - progress = self.manager._get_progress(set(), "process") + progress = self.manager._commit_update(set(), "process") self.assertIsNotNone(progress) # Cleanup self.manager.cleanup() # Verify keys are gone - progress = self.manager._get_progress(set(), "process") + progress = self.manager._commit_update(set(), "process") self.assertIsNone(progress) + + def test_failed_image_tracking(self): + """Test basic failed image tracking with no double-counting on retries.""" + self._init_and_verify(self.image_ids) + + # Mark 2 images as failed in process stage + progress = self.manager._commit_update({"img1", "img2"}, "process", failed_image_ids={"img1", "img2"}) + assert progress is not None + self.assertEqual(progress.failed, 2) + + # Retry same 2 images (fail again) - should not double-count + progress = self.manager._commit_update(set(), "process", failed_image_ids={"img1", "img2"}) + assert progress is not None + self.assertEqual(progress.failed, 2) + + # Fail a different image + progress = 
self.manager._commit_update(set(), "process", failed_image_ids={"img3"}) + assert progress is not None + self.assertEqual(progress.failed, 3) + + def test_failed_and_processed_mixed(self): + """Test mixed successful and failed processing in same batch.""" + self._init_and_verify(self.image_ids) + + # Process 2 successfully, 2 fail, 1 remains pending + progress = self.manager._commit_update( + {"img1", "img2", "img3", "img4"}, "process", failed_image_ids={"img3", "img4"} + ) + assert progress is not None + self.assertEqual(progress.processed, 4) + self.assertEqual(progress.failed, 2) + self.assertEqual(progress.remaining, 1) + self.assertEqual(progress.percentage, 0.8) + + def test_cleanup_removes_failed_set(self): + """Test that cleanup removes failed image set.""" + from django.core.cache import cache + + self._init_and_verify(self.image_ids) + + # Add failed images + self.manager._commit_update({"img1", "img2"}, "process", failed_image_ids={"img1", "img2"}) + + # Verify failed set exists + failed_set = cache.get(self.manager._failed_key) + self.assertEqual(len(failed_set), 2) + + # Cleanup + self.manager.cleanup() + + # Verify failed set is gone + failed_set = cache.get(self.manager._failed_key) + self.assertIsNone(failed_set)