diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 17083fb84..3548d0ea5 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -159,7 +159,7 @@ def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED, progress=progress_percentage, ) - if stage == "results" and progress_percentage >= 1.0: + if job.progress.is_complete(): job.status = JobState.SUCCESS job.progress.summary.status = JobState.SUCCESS job.finished_at = datetime.datetime.now() # Use naive datetime in local time @@ -167,23 +167,24 @@ def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> job.save() # Clean up async resources for completed jobs that use NATS/Redis - # Only ML jobs with async_pipeline_workers enabled use these resources - if stage == "results" and progress_percentage >= 1.0: + if job.progress.is_complete(): job = Job.objects.get(pk=job_id) # Re-fetch outside transaction _cleanup_job_if_needed(job) def _cleanup_job_if_needed(job) -> None: """ - Clean up async resources (NATS/Redis) if this job type uses them. + Clean up async resources (NATS/Redis) if this job uses them. - Only ML jobs with async_pipeline_workers enabled use NATS/Redis resources. + Only jobs with ASYNC_API dispatch mode use NATS/Redis resources. This function is safe to call for any job - it checks if cleanup is needed. Args: job: The Job instance """ - if job.job_type_key == "ml" and job.project and job.project.feature_flags.async_pipeline_workers: + from ami.jobs.models import JobDispatchMode + + if job.dispatch_mode == JobDispatchMode.ASYNC_API: # import here to avoid circular imports from ami.ml.orchestration.jobs import cleanup_async_job_resources diff --git a/ami/ml/orchestration/tests/test_cleanup.py b/ami/ml/orchestration/tests/test_cleanup.py index f7b686d44..ef8382d3d 100644 --- a/ami/ml/orchestration/tests/test_cleanup.py +++ b/ami/ml/orchestration/tests/test_cleanup.py @@ -5,7 +5,7 @@ from django.test import TestCase from nats.js.errors import NotFoundError -from ami.jobs.models import Job, JobState, MLJob +from ami.jobs.models import Job, JobDispatchMode, JobState, MLJob from ami.jobs.tasks import _update_job_progress, update_job_failure, update_job_status from ami.main.models import Project, ProjectFeatureFlags, SourceImage, SourceImageCollection from ami.ml.models import Pipeline @@ -106,6 +106,7 @@ def _create_job_with_queued_images(self) -> Job: name="Test Cleanup Job", pipeline=self.pipeline, source_image_collection=self.collection, + dispatch_mode=JobDispatchMode.ASYNC_API, ) # Queue images to NATS (also initializes Redis state) @@ -162,7 +163,9 @@ def test_cleanup_on_job_completion(self): """Test that resources are cleaned up when job completes successfully.""" job = self._create_job_with_queued_images() - # Simulate job completion by updating progress to 100% in results stage + # Simulate job completion: complete all stages (collect, process, then results) + _update_job_progress(job.pk, stage="collect", progress_percentage=1.0) + _update_job_progress(job.pk, stage="process", progress_percentage=1.0) _update_job_progress(job.pk, stage="results", progress_percentage=1.0) # Verify cleanup happened