diff --git a/ami/jobs/schemas.py b/ami/jobs/schemas.py index 0e1ea4ac7..2225c3799 100644 --- a/ami/jobs/schemas.py +++ b/ami/jobs/schemas.py @@ -20,3 +20,10 @@ required=False, type=int, ) + +processing_service_name_param = OpenApiParameter( + name="processing_service_name", + description="Name of the calling processing service", + required=False, + type=str, +) diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 7902faeb1..278438448 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -238,7 +238,7 @@ def _create_pipeline(self, name: str = "Test Pipeline", slug: str = "test-pipeli self.pipeline = pipeline return pipeline - def _create_ml_job(self, name: str, pipeline: Pipeline) -> Job: + def _create_ml_job(self, name: str, pipeline: Pipeline, **kwargs) -> Job: """Helper to create an ML job with a pipeline.""" return Job.objects.create( job_type_key=MLJob.key, @@ -246,6 +246,7 @@ def _create_ml_job(self, name: str, pipeline: Pipeline) -> Job: name=name, pipeline=pipeline, source_image_collection=self.source_image_collection, + **kwargs, ) def test_create_job(self): @@ -413,9 +414,7 @@ def test_search_jobs(self): def _task_batch_helper(self, value: Any, expected_status: int): pipeline = self._create_pipeline() - job = self._create_ml_job("Job for batch test", pipeline) - job.dispatch_mode = JobDispatchMode.ASYNC_API - job.save(update_fields=["dispatch_mode"]) + job = self._create_ml_job("Job for batch test", pipeline, dispatch_mode=JobDispatchMode.ASYNC_API) images = [ SourceImage.objects.create( path=f"image_{i}.jpg", @@ -520,6 +519,52 @@ def test_result_endpoint_validation(self): self.assertEqual(resp.status_code, 400) self.assertIn("result", resp.json()[0].lower()) + def test_processing_service_name_parameter(self): + """Test that processing_service_name parameter is accepted on job endpoints.""" + self.client.force_authenticate(user=self.user) + test_service_name = "Test Service" + + # Test list endpoint + list_url = reverse_with_params( + "api:job-list", params={"project_id": self.project.pk, "processing_service_name": test_service_name} + ) + resp = self.client.get(list_url) + self.assertEqual(resp.status_code, 200) + + # Test tasks endpoint (requires job with pipeline) + pipeline = self._create_pipeline() + job = self._create_ml_job("Job for service name test", pipeline, dispatch_mode=JobDispatchMode.ASYNC_API) + + tasks_url = reverse_with_params( + "api:job-tasks", + args=[job.pk], + params={"project_id": self.project.pk, "batch": 1, "processing_service_name": test_service_name}, + ) + resp = self.client.get(tasks_url) + self.assertEqual(resp.status_code, 200) + + # Test result endpoint + result_url = reverse_with_params( + "api:job-result", + args=[job.pk], + params={"project_id": self.project.pk, "processing_service_name": test_service_name}, + ) + result_data = [ + { + "reply_subject": "test.reply.1", + "result": { + "pipeline": "test-pipeline", + "algorithms": {}, + "total_time": 1.5, + "source_images": [], + "detections": [], + "errors": None, + }, + } + ] + resp = self.client.post(result_url, result_data, format="json") + self.assertEqual(resp.status_code, 200) + class TestJobDispatchModeFiltering(APITestCase): """Test job filtering by dispatch_mode.""" diff --git a/ami/jobs/views.py b/ami/jobs/views.py index dd8da01b2..eb3ab258c 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -15,7 +15,7 @@ from ami.base.permissions import ObjectPermission from ami.base.views import ProjectMixin -from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param +from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param, processing_service_name_param from ami.jobs.tasks import process_nats_pipeline_result from ami.main.api.schemas import project_id_doc_param from ami.main.api.views import DefaultViewSet @@ -204,13 +204,16 @@ def get_queryset(self) -> QuerySet: project_id_doc_param, ids_only_param, incomplete_only_param, + processing_service_name_param, ] ) def list(self, request, *args, **kwargs): + _ = _log_processing_service_name(request, "list requested", logger) + return super().list(request, *args, **kwargs) @extend_schema( - parameters=[batch_param], + parameters=[batch_param, processing_service_name_param], responses={200: dict}, ) @action(detail=True, methods=["get"], name="tasks") @@ -229,6 +232,7 @@ def tasks(self, request, pk=None): except Exception as e: raise ValidationError({"batch": str(e)}) from e + _ = _log_processing_service_name(request, f"tasks ({batch}) requested for job {job.pk}", job.logger) # Only async_api jobs have tasks fetchable from NATS if job.dispatch_mode != JobDispatchMode.ASYNC_API: raise ValidationError("Only async_api jobs have fetchable tasks") @@ -254,6 +258,9 @@ async def get_tasks(): return Response({"tasks": tasks}) + @extend_schema( + parameters=[processing_service_name_param], + ) @action(detail=True, methods=["post"], name="result") def result(self, request, pk=None): """ @@ -266,6 +273,8 @@ def result(self, request, pk=None): job = self.get_object() + _ = _log_processing_service_name(request, f"result received for job {job.pk}", job.logger) + # Validate request data is a list if isinstance(request.data, list): results = request.data @@ -325,3 +334,24 @@ def result(self, request, pk=None): }, status=500, ) + + +def _log_processing_service_name(request, context: str, logger: logging.Logger) -> str | None: + """ + Log the processing_service_name from query parameters. + + Args: + request: The HTTP request object + context: A string describing the operation (e.g., "tasks requested", "result received") + logger: A logging.Logger instance to use for logging + Returns: + The processing_service_name if provided, otherwise None + """ + processing_service_name = request.query_params.get("processing_service_name", None) + + if processing_service_name: + logger.info(f"Jobs {context} by processing service: {processing_service_name}") + else: + logger.debug(f"Jobs {context} without processing service name") + + return processing_service_name diff --git a/ami/utils/requests.py b/ami/utils/requests.py index e4de57c0f..c4396b725 100644 --- a/ami/utils/requests.py +++ b/ami/utils/requests.py @@ -2,8 +2,6 @@ import requests from django.forms import BooleanField, FloatField -from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import OpenApiParameter from requests.adapters import HTTPAdapter from rest_framework.request import Request from urllib3.util import Retry @@ -144,30 +142,3 @@ def get_default_classification_threshold(project: "Project | None" = None, reque return project.default_filters_score_threshold else: return default_threshold - - -project_id_doc_param = OpenApiParameter( - name="project_id", - description="Filter by project ID", - required=False, - type=int, -) - -ids_only_param = OpenApiParameter( - name="ids_only", - description="Return only job IDs instead of full job objects", - required=False, - type=OpenApiTypes.BOOL, -) -incomplete_only_param = OpenApiParameter( - name="incomplete_only", - description="Filter to only show incomplete jobs (excludes SUCCESS, FAILURE, REVOKED)", - required=False, - type=OpenApiTypes.BOOL, -) -batch_param = OpenApiParameter( - name="batch", - description="Number of tasks to pull in the batch", - required=False, - type=OpenApiTypes.INT, -)