Changes from all commits (21 commits)
- b60eab0: merge (carlos-irreverentlabs, Jan 16, 2026)
- 644927f: Merge remote-tracking branch 'upstream/main' (carlosgjs, Jan 22, 2026)
- 218f7aa: Merge remote-tracking branch 'upstream/main' (carlosgjs, Feb 3, 2026)
- 02aa6fa: Accept and log `processing_service_name` parameter from workers (carlosgjs, Feb 4, 2026)
- 90d729f: refactor (carlosgjs, Feb 5, 2026)
- 335229d: Clean up (carlosgjs, Feb 5, 2026)
- 117057f: Merge branch 'main' into carlosg/servname2 (carlosgjs, Feb 6, 2026)
- 5b53380: Address CR feedback (carlosgjs, Feb 6, 2026)
- 90da389: Merge remote-tracking branch 'upstream/main' (carlosgjs, Feb 10, 2026)
- 8618d3c: Merge remote-tracking branch 'upstream/main' (carlosgjs, Feb 13, 2026)
- bd1be5f: Merge remote-tracking branch 'upstream/main' (carlosgjs, Feb 17, 2026)
- ba74747: fix: Properly handle async job state with celery tasks (#1114) (carlosgjs, Feb 7, 2026)
- 595f4c9: PSv2: Implement queue clean-up upon job completion (#1113) (carlosgjs, Feb 7, 2026)
- 0d8c649: fix: PSv2: Workers should not try to fetch tasks from v1 jobs (#1118) (carlosgjs, Feb 9, 2026)
- 751573e: PSv2 cleanup: use is_complete() and dispatch_mode in job progress han… (mihow, Feb 10, 2026)
- 7a28477: Tests for async result processing (#1129) (carlosgjs, Feb 12, 2026)
- 2554950: PSv2: Track and display image count progress and state (#1121) (carlosgjs, Feb 17, 2026)
- 9ae607c: PSV2: API endpoint for external processing services to register pipel… (carlosgjs, Feb 17, 2026)
- 3f4729b: Merge branch 'main' into carlosg/servname2 (carlosgjs, Feb 18, 2026)
- cae6ff3: update tests (carlosgjs, Feb 18, 2026)
- 9cb886b: clean up (carlosgjs, Feb 18, 2026)
7 changes: 7 additions & 0 deletions ami/jobs/schemas.py
@@ -20,3 +20,10 @@
required=False,
type=int,
)

processing_service_name_param = OpenApiParameter(
name="processing_service_name",
description="Name of the calling processing service",
required=False,
type=str,
)
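
For context on how a worker would use this new parameter: it is an optional query-string value, so a caller simply appends it to its requests. A minimal sketch of a worker-side call, assuming a hypothetical base URL, auth token, and job ID (none of which appear in this PR):

```python
import requests

# Illustrative values only; the deployment URL, token scheme, and job ID
# below are assumptions, not part of this PR.
BASE_URL = "https://antenna.example.org/api/v2"
TOKEN = "worker-api-token"
SERVICE_NAME = "my-processing-service"

# Fetch a batch of tasks, identifying the calling service by name.
resp = requests.get(
    f"{BASE_URL}/jobs/123/tasks/",
    params={"batch": 5, "processing_service_name": SERVICE_NAME},
    headers={"Authorization": f"Token {TOKEN}"},
    timeout=30,
)
resp.raise_for_status()
tasks = resp.json()["tasks"]
```

The server treats the name as purely informational: per the views.py changes below, it is logged via `_log_processing_service_name` and otherwise ignored.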
53 changes: 49 additions & 4 deletions ami/jobs/tests.py
@@ -238,14 +238,15 @@ def _create_pipeline(self, name: str = "Test Pipeline", slug: str = "test-pipeli
self.pipeline = pipeline
return pipeline

def _create_ml_job(self, name: str, pipeline: Pipeline) -> Job:
def _create_ml_job(self, name: str, pipeline: Pipeline, **kwargs) -> Job:
"""Helper to create an ML job with a pipeline."""
return Job.objects.create(
job_type_key=MLJob.key,
project=self.project,
name=name,
pipeline=pipeline,
source_image_collection=self.source_image_collection,
**kwargs,
)

def test_create_job(self):
@@ -413,9 +414,7 @@ def test_search_jobs(self):

def _task_batch_helper(self, value: Any, expected_status: int):
pipeline = self._create_pipeline()
job = self._create_ml_job("Job for batch test", pipeline)
job.dispatch_mode = JobDispatchMode.ASYNC_API
job.save(update_fields=["dispatch_mode"])
job = self._create_ml_job("Job for batch test", pipeline, dispatch_mode=JobDispatchMode.ASYNC_API)
images = [
SourceImage.objects.create(
path=f"image_{i}.jpg",
@@ -520,6 +519,52 @@ def test_result_endpoint_validation(self):
self.assertEqual(resp.status_code, 400)
self.assertIn("result", resp.json()[0].lower())

def test_processing_service_name_parameter(self):
"""Test that processing_service_name parameter is accepted on job endpoints."""
self.client.force_authenticate(user=self.user)
test_service_name = "Test Service"

# Test list endpoint
list_url = reverse_with_params(
"api:job-list", params={"project_id": self.project.pk, "processing_service_name": test_service_name}
)
resp = self.client.get(list_url)
self.assertEqual(resp.status_code, 200)

# Test tasks endpoint (requires job with pipeline)
pipeline = self._create_pipeline()
job = self._create_ml_job("Job for service name test", pipeline, dispatch_mode=JobDispatchMode.ASYNC_API)

tasks_url = reverse_with_params(
"api:job-tasks",
args=[job.pk],
params={"project_id": self.project.pk, "batch": 1, "processing_service_name": test_service_name},
)
resp = self.client.get(tasks_url)
self.assertEqual(resp.status_code, 200)

# Test result endpoint
result_url = reverse_with_params(
"api:job-result",
args=[job.pk],
params={"project_id": self.project.pk, "processing_service_name": test_service_name},
)
result_data = [
{
"reply_subject": "test.reply.1",
"result": {
"pipeline": "test-pipeline",
"algorithms": {},
"total_time": 1.5,
"source_images": [],
"detections": [],
"errors": None,
},
}
]
resp = self.client.post(result_url, result_data, format="json")
self.assertEqual(resp.status_code, 200)


class TestJobDispatchModeFiltering(APITestCase):
"""Test job filtering by dispatch_mode."""
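
(Aside: `reverse_with_params` is the project's own test helper; its implementation is not shown in this diff. A rough equivalent, inferred only from how the tests above call it, might look like the following sketch.)

```python
from urllib.parse import urlencode

from django.urls import reverse


def reverse_with_params(viewname: str, args=None, params: dict | None = None) -> str:
    """Resolve a named URL and append query parameters.

    Sketch inferred from call sites such as
    reverse_with_params("api:job-tasks", args=[job.pk], params={"batch": 1});
    the real helper in the codebase may differ.
    """
    url = reverse(viewname, args=args)
    if params:
        url = f"{url}?{urlencode(params)}"
    return url
```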
34 changes: 32 additions & 2 deletions ami/jobs/views.py
@@ -15,7 +15,7 @@

from ami.base.permissions import ObjectPermission
from ami.base.views import ProjectMixin
from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param
from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param, processing_service_name_param
from ami.jobs.tasks import process_nats_pipeline_result
from ami.main.api.schemas import project_id_doc_param
from ami.main.api.views import DefaultViewSet
@@ -204,13 +204,16 @@ def get_queryset(self) -> QuerySet:
project_id_doc_param,
ids_only_param,
incomplete_only_param,
processing_service_name_param,
]
)
def list(self, request, *args, **kwargs):
_ = _log_processing_service_name(request, "list requested", logger)

return super().list(request, *args, **kwargs)

@extend_schema(
parameters=[batch_param],
parameters=[batch_param, processing_service_name_param],
responses={200: dict},
)
@action(detail=True, methods=["get"], name="tasks")
@@ -229,6 +232,7 @@ def tasks(self, request, pk=None):
except Exception as e:
raise ValidationError({"batch": str(e)}) from e

_ = _log_processing_service_name(request, f"tasks ({batch}) requested for job {job.pk}", job.logger)
# Only async_api jobs have tasks fetchable from NATS
if job.dispatch_mode != JobDispatchMode.ASYNC_API:
raise ValidationError("Only async_api jobs have fetchable tasks")
@@ -254,6 +258,9 @@ async def get_tasks():

return Response({"tasks": tasks})

@extend_schema(
parameters=[processing_service_name_param],
)
@action(detail=True, methods=["post"], name="result")
def result(self, request, pk=None):
"""
@@ -266,6 +273,8 @@ def result(self, request, pk=None):

job = self.get_object()

_ = _log_processing_service_name(request, f"result received for job {job.pk}", job.logger)

# Validate request data is a list
if isinstance(request.data, list):
results = request.data
@@ -325,3 +334,24 @@ def result(self, request, pk=None):
},
status=500,
)


def _log_processing_service_name(request, context: str, logger: logging.Logger) -> str | None:
"""
Log the processing_service_name from query parameters.

Args:
request: The HTTP request object
context: A string describing the operation (e.g., "tasks requested", "result received")
logger: A logging.Logger instance to use for logging
Returns:
The processing_service_name if provided, otherwise None
"""
processing_service_name = request.query_params.get("processing_service_name", None)

if processing_service_name:
logger.info(f"Jobs {context} by processing service: {processing_service_name}")
else:
logger.debug(f"Jobs {context} without processing service name")
Comment on lines +350 to +355
⚠️ Potential issue | 🟡 Minor

Sanitize user-supplied processing_service_name before embedding in log messages.

processing_service_name is an unvalidated, user-controlled query parameter. Embedding it verbatim in log messages enables log injection: a compromised worker can supply a value containing \n followed by fabricated log records (e.g., "svc\nCRITICAL:ami.jobs:Fake security event"), which can poison structured log aggregators, SIEM pipelines, or flat log files. While callers must be authenticated, any valid processing service can exploit this.

Additionally, f-strings force eager string evaluation regardless of the effective log level; use %s lazy formatting instead.

🛡️ Proposed fix – sanitize input and use lazy log formatting
-    if processing_service_name:
-        logger.info(f"Jobs {context} by processing service: {processing_service_name}")
-    else:
-        logger.debug(f"Jobs {context} without processing service name")
+    if processing_service_name:
+        # Sanitize to prevent log injection via newlines / carriage returns
+        sanitized_name = processing_service_name.replace("\n", "\\n").replace("\r", "\\r")
+        logger.info("Jobs %s by processing service: %s", context, sanitized_name)
+    else:
+        logger.debug("Jobs %s without processing service name", context)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ami/jobs/views.py` around lines 350-355: the log embeds the untrusted query param `processing_service_name` via f-strings (`logger.info`/`logger.debug`), which allows log injection and forces eager formatting. Sanitize `processing_service_name` by removing control and newline characters (e.g., strip or regex-remove `\r`, `\n`, and other non-printables), then pass it to the logger using lazy formatting: `logger.info("Jobs %s by processing service: %s", context, safe_processing_service_name)` and `logger.debug("Jobs %s without processing service name", context)`. That is, replace the f-strings around `processing_service_name` with the sanitized value and percent-style argument logging.


return processing_service_name
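
Pulling the reviewer's proposed fix out of diff form, the sanitized helper would read roughly as follows. The escape strategy (backslash-escaping newlines) is the reviewer's suggestion; stripping non-printables entirely, as the AI prompt above suggests, would be an equally reasonable choice.

```python
import logging


def _log_processing_service_name(request, context: str, logger: logging.Logger) -> str | None:
    """Log the processing_service_name query parameter, if provided.

    Newlines and carriage returns are escaped so a caller cannot forge
    extra log records, and lazy %s formatting defers string building
    until the log level is known to be enabled.
    """
    processing_service_name = request.query_params.get("processing_service_name", None)

    if processing_service_name:
        sanitized_name = processing_service_name.replace("\n", "\\n").replace("\r", "\\r")
        logger.info("Jobs %s by processing service: %s", context, sanitized_name)
    else:
        logger.debug("Jobs %s without processing service name", context)

    return processing_service_name
```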
29 changes: 0 additions & 29 deletions ami/utils/requests.py
@@ -2,8 +2,6 @@

import requests
from django.forms import BooleanField, FloatField
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter
from requests.adapters import HTTPAdapter
from rest_framework.request import Request
from urllib3.util import Retry
@@ -144,30 +142,3 @@ def get_default_classification_threshold(project: "Project | None" = None, reque
return project.default_filters_score_threshold
else:
return default_threshold


project_id_doc_param = OpenApiParameter(
name="project_id",
description="Filter by project ID",
required=False,
type=int,
)

ids_only_param = OpenApiParameter(
name="ids_only",
description="Return only job IDs instead of full job objects",
required=False,
type=OpenApiTypes.BOOL,
)
incomplete_only_param = OpenApiParameter(
name="incomplete_only",
description="Filter to only show incomplete jobs (excludes SUCCESS, FAILURE, REVOKED)",
required=False,
type=OpenApiTypes.BOOL,
)
batch_param = OpenApiParameter(
name="batch",
description="Number of tasks to pull in the batch",
required=False,
type=OpenApiTypes.INT,
)