Merged
Changes from all commits
29 commits
b60eab0
merge
carlos-irreverentlabs Jan 16, 2026
644927f
Merge remote-tracking branch 'upstream/main'
carlosgjs Jan 22, 2026
218f7aa
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 3, 2026
867201d
Update ML job counts in async case
carlosgjs Feb 6, 2026
90da389
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 10, 2026
cdf57ea
Update date picker version and tweak layout logic (#1105)
annavik Feb 6, 2026
6837ad6
fix: Properly handle async job state with celery tasks (#1114)
carlosgjs Feb 7, 2026
f1cd62d
PSv2: Implement queue clean-up upon job completion (#1113)
carlosgjs Feb 7, 2026
74df9ea
fix: PSv2: Workers should not try to fetch tasks from v1 jobs (#1118)
carlosgjs Feb 9, 2026
4a082d3
PSv2 cleanup: use is_complete() and dispatch_mode in job progress han…
mihow Feb 10, 2026
9d560cf
Merge branch 'main' into carlos/trackcounts
carlosgjs Feb 10, 2026
e43536b
track captures and failures
carlosgjs Feb 11, 2026
50df5f6
Update tests, CR feedback, log error images
carlosgjs Feb 11, 2026
3287fe2
CR feedback
carlosgjs Feb 11, 2026
a87b05a
fix type checking
carlosgjs Feb 11, 2026
89bf950
Merge remote-tracking branch 'origin/main' into carlos/trackcounts
mihow Feb 12, 2026
a5ff6f8
refactor: rename _get_progress to _commit_update in TaskStateManager
mihow Feb 12, 2026
337b7fc
fix: unify FAILURE_THRESHOLD and convert TaskProgress to dataclass
mihow Feb 12, 2026
8618d3c
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 13, 2026
4331dee
Merge branch 'main' into carlos/trackcounts
carlosgjs Feb 13, 2026
afee6e7
refactor: rename TaskProgress to JobStateProgress
mihow Feb 12, 2026
65d77cb
docs: update NATS todo and planning docs with session learnings
mihow Feb 13, 2026
8e8cd80
Rename TaskStateManager to AsyncJobStateManager
carlosgjs Feb 13, 2026
34af787
Merge branch 'carlos/trackcounts' of github.com:uw-ssec/antenna into …
carlosgjs Feb 13, 2026
afc4472
Track results counts in the job itself vs Redis
carlosgjs Feb 13, 2026
b6c3c6a
small simplification
carlosgjs Feb 13, 2026
b15024f
Reset counts to 0 on reset
carlosgjs Feb 13, 2026
b2e4a72
chore: remove local planning docs from PR branch
mihow Feb 17, 2026
a15ebda
docs: clarify three-layer job state architecture in docstrings
mihow Feb 17, 2026
21 changes: 18 additions & 3 deletions ami/jobs/models.py
@@ -109,7 +109,7 @@ def python_slugify(value: str) -> str:


class JobProgressSummary(pydantic.BaseModel):
"""Summary of all stages of a job"""
"""Top-level status and progress for a job, shown in the UI."""

status: JobState = JobState.CREATED
progress: float = 0
@@ -132,7 +132,17 @@ class JobProgressStageDetail(ConfigurableStage, JobProgressSummary):


class JobProgress(pydantic.BaseModel):
"""The full progress of a job and its stages."""
"""
The user-facing progress of a job, stored as JSONB on the Job model.

This is what the UI displays and what external APIs read. Contains named
stages ("process", "results") with per-stage params (progress percentage,
detections/classifications/captures counts, failed count).

For async (NATS) jobs, updated by _update_job_progress() in ami/jobs/tasks.py
which copies snapshots from the internal Redis-backed AsyncJobStateManager.
For sync jobs, updated directly in MLJob.process_images().
"""

summary: JobProgressSummary
stages: list[JobProgressStageDetail]
@@ -222,6 +232,10 @@ def reset(self, status: JobState = JobState.CREATED):
for stage in self.stages:
stage.progress = 0
stage.status = status
# Reset numeric param values to 0
for param in stage.params:
if isinstance(param.value, (int, float)):
param.value = 0

def is_complete(self) -> bool:
"""
@@ -561,7 +575,8 @@ def process_images(cls, job, images):

job.logger.info(f"All tasks completed for job {job.pk}")

FAILURE_THRESHOLD = 0.5
from ami.jobs.tasks import FAILURE_THRESHOLD

if image_count and (percent_successful < FAILURE_THRESHOLD):
job.progress.update_stage("process", status=JobState.FAILURE)
job.save()
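To make the models.py changes above concrete, here is a minimal, self-contained sketch of what the new reset() behaviour does to a stage's numeric params. Stage and StageParam are hypothetical simplifications of the real pydantic stage models; only the zero-on-reset logic mirrors the diff.

from dataclasses import dataclass, field

@dataclass
class StageParam:
    key: str
    value: object = None

@dataclass
class Stage:
    name: str
    progress: float = 0.0
    params: list[StageParam] = field(default_factory=list)

def reset_stage(stage: Stage) -> None:
    # Zero the progress bar and any numeric count params (e.g. detections,
    # classifications, captures, failed) so a re-run starts from a clean slate.
    stage.progress = 0.0
    for param in stage.params:
        if isinstance(param.value, (int, float)):
            param.value = 0

stage = Stage("results", 1.0, [StageParam("detections", 42), StageParam("notes", "keep")])
reset_stage(stage)
assert stage.params[0].value == 0       # numeric count reset
assert stage.params[1].value == "keep"  # non-numeric param left untouched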
131 changes: 117 additions & 14 deletions ami/jobs/tasks.py
@@ -3,18 +3,25 @@
import logging
import time
from collections.abc import Callable
from typing import TYPE_CHECKING

from asgiref.sync import async_to_sync
from celery.signals import task_failure, task_postrun, task_prerun
from django.db import transaction

from ami.ml.orchestration.async_job_state import AsyncJobStateManager
from ami.ml.orchestration.nats_queue import TaskQueueManager
from ami.ml.orchestration.task_state import TaskStateManager
from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse
from ami.tasks import default_soft_time_limit, default_time_limit
from config import celery_app

if TYPE_CHECKING:
from ami.jobs.models import JobState

logger = logging.getLogger(__name__)
# Shared failure threshold: async jobs are marked failed when the fraction of failed
# images exceeds this value; MLJob.process_images() fails sync jobs when the successful fraction falls below it.
FAILURE_THRESHOLD = 0.5


@celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit)
@@ -59,23 +66,27 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
result_data: Dictionary containing the pipeline result
reply_subject: NATS reply subject for acknowledgment
"""
from ami.jobs.models import Job # avoid circular import
from ami.jobs.models import Job, JobState # avoid circular import

_, t = log_time()

# Validate with Pydantic - check for error response first
error_result = None
if "error" in result_data:
error_result = PipelineResultsError(**result_data)
processed_image_ids = {str(error_result.image_id)} if error_result.image_id else set()
logger.error(f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}")
failed_image_ids = processed_image_ids # Same as processed for errors
pipeline_result = None
else:
pipeline_result = PipelineResultsResponse(**result_data)
processed_image_ids = {str(img.id) for img in pipeline_result.source_images}
failed_image_ids = set() # No failures for successful results

state_manager = TaskStateManager(job_id)
state_manager = AsyncJobStateManager(job_id)

progress_info = state_manager.update_state(processed_image_ids, stage="process", request_id=self.request.id)
progress_info = state_manager.update_state(
processed_image_ids, stage="process", request_id=self.request.id, failed_image_ids=failed_image_ids
)
if not progress_info:
logger.warning(
f"Another task is already processing results for job {job_id}. "
@@ -84,16 +95,31 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
raise self.retry(countdown=5, max_retries=10)

try:
_update_job_progress(job_id, "process", progress_info.percentage)
complete_state = JobState.SUCCESS
if progress_info.total > 0 and (progress_info.failed / progress_info.total) > FAILURE_THRESHOLD:
complete_state = JobState.FAILURE
_update_job_progress(
job_id,
"process",
progress_info.percentage,
complete_state=complete_state,
processed=progress_info.processed,
remaining=progress_info.remaining,
failed=progress_info.failed,
)

_, t = t(f"TIME: Updated job {job_id} progress in PROCESS stage progress to {progress_info.percentage*100}%")
job = Job.objects.get(pk=job_id)
job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}")
job.logger.info(
f" Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed "
f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {len(processed_image_ids)} just "
"processed"
f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {progress_info.failed} failed, "
f"{len(processed_image_ids)} just processed"
)
if error_result:
job.logger.error(
f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}"
)
except Job.DoesNotExist:
# don't raise; ack so that we don't retry, since the job doesn't exist
logger.error(f"Job {job_id} not found")
@@ -102,6 +128,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub

try:
# Save to database (this is the slow operation)
detections_count, classifications_count, captures_count = 0, 0, 0
if pipeline_result:
# should never happen since otherwise we could not be processing results here
assert job.pipeline is not None, "Job pipeline is None"
@@ -112,18 +139,41 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
f"Saved pipeline results to database with {len(pipeline_result.detections)} detections"
f", percentage: {progress_info.percentage*100}%"
)
# Calculate detection and classification counts from this result
detections_count = len(pipeline_result.detections)
classifications_count = sum(len(detection.classifications) for detection in pipeline_result.detections)
captures_count = len(pipeline_result.source_images)

_ack_task_via_nats(reply_subject, job.logger)
# Update job stage with calculated progress
progress_info = state_manager.update_state(processed_image_ids, stage="results", request_id=self.request.id)

progress_info = state_manager.update_state(
processed_image_ids,
stage="results",
request_id=self.request.id,
)

if not progress_info:
logger.warning(
f"Another task is already processing results for job {job_id}. "
f"Retrying task {self.request.id} in 5 seconds..."
)
raise self.retry(countdown=5, max_retries=10)
_update_job_progress(job_id, "results", progress_info.percentage)

# update complete state based on latest progress info after saving results
complete_state = JobState.SUCCESS
if progress_info.total > 0 and (progress_info.failed / progress_info.total) > FAILURE_THRESHOLD:
complete_state = JobState.FAILURE

_update_job_progress(
job_id,
"results",
progress_info.percentage,
complete_state=complete_state,
detections=detections_count,
classifications=classifications_count,
captures=captures_count,
)

except Exception as e:
job.logger.error(
Expand All @@ -149,19 +199,72 @@ async def ack_task():
# Don't fail the task if ACK fails - data is already saved


def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None:
def _get_current_counts_from_job_progress(job, stage: str) -> tuple[int, int, int]:
"""
Get current detections, classifications, and captures counts from job progress.

Args:
job: The Job instance
stage: The stage name to read counts from

Returns:
Tuple of (detections, classifications, captures) counts, defaulting to 0 if not found
"""
try:
stage_obj = job.progress.get_stage(stage)

# Initialize defaults
detections = 0
classifications = 0
captures = 0

# Search through the params list for our count values
for param in stage_obj.params:
if param.key == "detections":
detections = param.value or 0
elif param.key == "classifications":
classifications = param.value or 0
elif param.key == "captures":
captures = param.value or 0

return detections, classifications, captures
except (ValueError, AttributeError):
# Stage doesn't exist or doesn't have these attributes yet
return 0, 0, 0


def _update_job_progress(
job_id: int, stage: str, progress_percentage: float, complete_state: "JobState", **state_params
) -> None:
from ami.jobs.models import Job, JobState # avoid circular import

with transaction.atomic():
job = Job.objects.select_for_update().get(pk=job_id)

# For results stage, accumulate detections/classifications/captures counts
if stage == "results":
current_detections, current_classifications, current_captures = _get_current_counts_from_job_progress(
job, stage
)

# Add new counts to existing counts
new_detections = state_params.get("detections", 0)
new_classifications = state_params.get("classifications", 0)
new_captures = state_params.get("captures", 0)

state_params["detections"] = current_detections + new_detections
state_params["classifications"] = current_classifications + new_classifications
state_params["captures"] = current_captures + new_captures

job.progress.update_stage(
stage,
status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED,
status=complete_state if progress_percentage >= 1.0 else JobState.STARTED,
progress=progress_percentage,
**state_params,
)
if job.progress.is_complete():
job.status = JobState.SUCCESS
job.progress.summary.status = JobState.SUCCESS
job.status = complete_state
job.progress.summary.status = complete_state
job.finished_at = datetime.datetime.now() # Use naive datetime in local time
job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%")
job.save()
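The tasks.py changes above make two decisions per pipeline result: whether a completed job ends as SUCCESS or FAILURE (based on FAILURE_THRESHOLD), and how per-result counts are accumulated into the running totals stored on the job. The sketch below illustrates only those rules; ProgressInfo, completion_state, and accumulate_counts are hypothetical stand-ins, not the real AsyncJobStateManager snapshot or _update_job_progress().

from dataclasses import dataclass

FAILURE_THRESHOLD = 0.5  # same value the PR centralizes in ami/jobs/tasks.py

@dataclass
class ProgressInfo:
    processed: int
    failed: int
    total: int

def completion_state(info: ProgressInfo) -> str:
    # A finished job is marked FAILURE only when more than FAILURE_THRESHOLD
    # of all images failed; otherwise it completes as SUCCESS.
    if info.total > 0 and (info.failed / info.total) > FAILURE_THRESHOLD:
        return "FAILURE"
    return "SUCCESS"

def accumulate_counts(current: dict, new: dict) -> dict:
    # The "results" stage keeps running totals, so each result's counts are
    # added onto whatever is already stored for the job.
    keys = ("detections", "classifications", "captures")
    return {k: current.get(k, 0) + new.get(k, 0) for k in keys}

print(completion_state(ProgressInfo(processed=10, failed=3, total=10)))  # SUCCESS
print(completion_state(ProgressInfo(processed=10, failed=6, total=10)))  # FAILURE
print(accumulate_counts({"detections": 5, "captures": 2}, {"detections": 2, "classifications": 1, "captures": 1}))
# {'detections': 7, 'classifications': 1, 'captures': 3}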
14 changes: 7 additions & 7 deletions ami/jobs/test_tasks.py
@@ -17,7 +17,7 @@
from ami.jobs.tasks import process_nats_pipeline_result
from ami.main.models import Detection, Project, SourceImage, SourceImageCollection
from ami.ml.models import Pipeline
from ami.ml.orchestration.task_state import TaskStateManager, _lock_key
from ami.ml.orchestration.async_job_state import AsyncJobStateManager, _lock_key
from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse, SourceImageResponse
from ami.users.models import User

@@ -64,7 +64,7 @@ def setUp(self):

# Initialize state manager
self.image_ids = [str(img.pk) for img in self.images]
self.state_manager = TaskStateManager(self.job.pk)
self.state_manager = AsyncJobStateManager(self.job.pk)
self.state_manager.initialize_job(self.image_ids)

def tearDown(self):
@@ -90,7 +90,7 @@ def _assert_progress_updated(
self, job_id: int, expected_processed: int, expected_total: int, stage: str = "process"
):
"""Assert TaskStateManager state is correct."""
manager = TaskStateManager(job_id)
manager = AsyncJobStateManager(job_id)
progress = manager.get_progress(stage)
self.assertIsNotNone(progress, f"Progress not found for stage '{stage}'")
self.assertEqual(progress.processed, expected_processed)
@@ -157,7 +157,7 @@ def test_process_nats_pipeline_result_error_no_image_id(self, mock_manager_class

# Assert: Progress was NOT updated (empty set of processed images)
# Since no image_id was provided, processed_image_ids = set()
manager = TaskStateManager(self.job.pk)
manager = AsyncJobStateManager(self.job.pk)
progress = manager.get_progress("process")
self.assertEqual(progress.processed, 0) # No images marked as processed

@@ -208,7 +208,7 @@ def test_process_nats_pipeline_result_mixed_results(self, mock_manager_class):
)

# Assert: All 3 images marked as processed in AsyncJobStateManager
manager = TaskStateManager(self.job.pk)
manager = AsyncJobStateManager(self.job.pk)
process_progress = manager.get_progress("process")
self.assertIsNotNone(process_progress)
self.assertEqual(process_progress.processed, 3)
@@ -266,7 +266,7 @@ def test_process_nats_pipeline_result_error_concurrent_locking(self, mock_manage
)

# Assert: Progress was NOT updated (lock not acquired)
manager = TaskStateManager(self.job.pk)
manager = AsyncJobStateManager(self.job.pk)
progress = manager.get_progress("process")
self.assertEqual(progress.processed, 0)

@@ -342,7 +342,7 @@ def setUp(self):
)

# Initialize state manager
state_manager = TaskStateManager(self.job.pk)
state_manager = AsyncJobStateManager(self.job.pk)
state_manager.initialize_job([str(self.image.pk)])

def tearDown(self):