diff --git a/ami/ml/models/pipeline.py b/ami/ml/models/pipeline.py
index a7470ec6a..6486e26d9 100644
--- a/ami/ml/models/pipeline.py
+++ b/ami/ml/models/pipeline.py
@@ -51,7 +51,7 @@
     SourceImageResponse,
 )
 from ami.ml.tasks import celery_app, create_detection_images
-from ami.utils.requests import create_session
+from ami.utils.requests import create_session, extract_error_message_from_response

 logger = logging.getLogger(__name__)

@@ -242,10 +242,10 @@ def process_images(
     session = create_session()
     resp = session.post(endpoint_url, json=request_data.dict())
     if not resp.ok:
-        try:
-            msg = resp.json()["detail"]
-        except (ValueError, KeyError):
-            msg = str(resp.content)
+        summary = request_data.summary()
+        error_msg = extract_error_message_from_response(resp)
+        msg = f"Failed to process {summary}: {error_msg}"
+
         if job:
             job.logger.error(msg)
         else:
@@ -1060,17 +1060,18 @@ def choose_processing_service_for_pipeline(
         f"{[processing_service.name for processing_service in processing_services]}"
     )

-    # check the status of all processing services
-    timeout = 5 * 60.0  # 5 minutes
-    lowest_latency = timeout
+    # check the status of all processing services and pick the one with the lowest latency
+    lowest_latency = float("inf")
     processing_services_online = False

     for processing_service in processing_services:
-        status_response = processing_service.get_status()  # @TODO pass timeout to get_status()
-        if status_response.server_live:
+        if processing_service.last_checked_live:
             processing_services_online = True
-            if status_response.latency < lowest_latency:
-                lowest_latency = status_response.latency
+            if (
+                processing_service.last_checked_latency
+                and processing_service.last_checked_latency < lowest_latency
+            ):
+                lowest_latency = processing_service.last_checked_latency
                 # pick the processing service that has lowest latency
                 processing_service_lowest_latency = processing_service

diff --git a/ami/ml/models/processing_service.py b/ami/ml/models/processing_service.py
index e350d34a8..4711c5e73 100644
--- a/ami/ml/models/processing_service.py
+++ b/ami/ml/models/processing_service.py
@@ -18,6 +18,7 @@
     ProcessingServiceInfoResponse,
     ProcessingServiceStatusResponse,
 )
+from ami.utils.requests import create_session

 logger = logging.getLogger(__name__)

@@ -137,10 +138,18 @@ def create_pipelines(
             algorithms_created=algorithms_created,
         )

-    def get_status(self, timeout=6):
+    def get_status(self, timeout=90) -> ProcessingServiceStatusResponse:
         """
         Check the status of the processing service.
         This is a simple health check that pings the /readyz endpoint of the service.
+
+        Uses urllib3 Retry with exponential backoff to handle cold starts and transient failures.
+        The timeout is set to 90s per attempt to accommodate serverless cold starts, especially for
+        services that need to load multiple models into memory. With automatic retries, transient
+        connection errors are handled gracefully.
+
+        Args:
+            timeout: Request timeout in seconds per attempt (default: 90s for serverless cold starts)
         """
         ready_check_url = urljoin(self.endpoint_url, "readyz")
         start_time = time.time()
@@ -151,11 +160,17 @@ def get_status(self, timeout=6):
         self.last_checked = timestamp
         resp = None

+        # Create session with retry logic for connection errors and timeouts
+        session = create_session(
+            retries=3,
+            backoff_factor=2,  # exponential backoff between retries
+            status_forcelist=(500, 502, 503, 504),
+        )
+
         try:
-            resp = requests.get(ready_check_url, timeout=timeout)
+            resp = session.get(ready_check_url, timeout=timeout)
             resp.raise_for_status()
             self.last_checked_live = True
-            latency = time.time() - start_time
         except requests.exceptions.RequestException as e:
             error = f"Error connecting to {ready_check_url}: {e}"
             logger.error(error)
@@ -176,6 +191,7 @@ def get_status(self, timeout=6):
         # but the intention is to show which ones are loaded into memory and ready to use.
         # @TODO: this may be overkill, but it is displayed in the UI now.
         try:
+            assert resp is not None
             pipelines_online: list[str] = resp.json().get("status", [])
         except (ValueError, KeyError) as e:
             error = f"Error parsing pipeline statuses from {ready_check_url}: {e}"
diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py
index c555ef943..49e5efd8f 100644
--- a/ami/ml/schemas.py
+++ b/ami/ml/schemas.py
@@ -176,6 +176,25 @@ class PipelineRequest(pydantic.BaseModel):
     detections: list[DetectionRequest] | None = None
     config: PipelineRequestConfigParameters | dict | None = None

+    def summary(self) -> str:
+        """
+        Return a human-friendly summary string of the request's key details
+        (number of images, pipeline name, number of detections, etc.).
+
+        e.g. "pipeline request with 10 images and 25 detections to pipeline 'panama_moths_2023'"
+
+        Returns:
+            str: A summary string.
+        """
+
+        num_images = len(self.source_images)
+        num_detections = len(self.detections) if self.detections else 0
+        return (
+            f"pipeline request with {num_images} image{'s' if num_images != 1 else ''} "
+            f"and {num_detections} detection{'s' if num_detections != 1 else ''} "
+            f"to pipeline '{self.pipeline}'"
+        )
+

 class PipelineResultsResponse(pydantic.BaseModel):
     # pipeline: PipelineChoice
diff --git a/ami/ml/tests.py b/ami/ml/tests.py
index f88bfbbc0..14e4374f2 100644
--- a/ami/ml/tests.py
+++ b/ami/ml/tests.py
@@ -117,6 +117,40 @@ def test_processing_service_pipeline_registration(self):


 class TestPipelineWithProcessingService(TestCase):
+    def test_run_pipeline_with_errors_from_processing_service(self):
+        """
+        Run a real pipeline and verify that if an error occurs for one image, the error is logged in job.logs.stderr.
+        """
+        from ami.jobs.models import Job
+
+        # Setup test project, images, and job
+        project, deployment = setup_test_project()
+        captures = create_captures_from_files(deployment, skip_existing=False)
+        test_images = [image for image, frame in captures]
+        processing_service_instance = create_processing_service(project)
+        pipeline = processing_service_instance.pipelines.all().get(slug="constant")
+        job = Job.objects.create(project=project, name="Test Job Real Pipeline Error Handling", pipeline=pipeline)
+
+        # Simulate an error by passing an invalid image (e.g., missing file or corrupt)
+        # Here, we manually set the path of one image to a non-existent file
+        error_image = test_images[0]
+        error_image.path = "/tmp/nonexistent_image.jpg"
+        error_image.save()
+        images = [error_image] + test_images[1:2]  # Only two images for brevity
+
+        # Run the pipeline and catch any error
+        try:
+            pipeline.process_images(images, job_id=job.pk, project_id=project.pk)
+        except Exception:
+            pass  # Expected if the backend raises
+
+        job.refresh_from_db()
+        stderr_logs = job.logs.stderr
+        # Check that an error message mentioning the failed image is present
+        assert any(
+            "Failed to process" in log for log in stderr_logs
+        ), f"Expected error message in job.logs.stderr, got: {stderr_logs}"
+
     def setUp(self):
         self.project, self.deployment = setup_test_project()
         self.captures = create_captures_from_files(self.deployment, skip_existing=False)
diff --git a/ami/utils/requests.py b/ami/utils/requests.py
index dca9c5c43..aff13209d 100644
--- a/ami/utils/requests.py
+++ b/ami/utils/requests.py
@@ -41,6 +41,47 @@ def create_session(
     return session


+def extract_error_message_from_response(resp: requests.Response) -> str:
+    """
+    Extract detailed error information from an HTTP response.
+
+    Prioritizes the "detail" field from JSON responses (FastAPI standard),
+    falls back to other fields, text content, or raw bytes.
+
+    Args:
+        resp: The HTTP response object
+
+    Returns:
+        A formatted error message string
+    """
+    error_details = [f"HTTP {resp.status_code}: {resp.reason}"]
+
+    try:
+        # Try to parse JSON response
+        resp_json = resp.json()
+        if isinstance(resp_json, dict):
+            # Check for the standard "detail" field first
+            if "detail" in resp_json:
+                error_details.append(f"Detail: {resp_json['detail']}")
+            else:
+                # Fallback: add all fields from the error response
+                for key, value in resp_json.items():
+                    error_details.append(f"{key}: {value}")
+        else:
+            error_details.append(f"Response: {resp_json}")
+    except (ValueError, KeyError):
+        # If JSON parsing fails, try to get text content
+        try:
+            content_text = resp.text
+            if content_text:
+                error_details.append(f"Response text: {content_text[:500]}")  # Limit to first 500 chars
+        except Exception:
+            # Last resort: raw content
+            error_details.append(f"Response content: {resp.content[:500]}")
+
+    return " | ".join(error_details)
+
+
 def get_active_classification_threshold(request: Request) -> float:
     """
     Get the active classification threshold from request parameters.
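Note: create_session() itself is not modified in this diff; only its new call site in ProcessingService.get_status() is shown above. Below is a minimal sketch, not the project's actual implementation, of a session factory compatible with that call, assuming urllib3's Retry mounted on a requests.Session via HTTPAdapter. The retries, backoff_factor, and status_forcelist parameter names are taken from the call site; everything else is illustrative.

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry


def create_session(
    retries: int = 3,
    backoff_factor: float = 0,
    status_forcelist: tuple[int, ...] = (500, 502, 503, 504),
) -> requests.Session:
    # Retry transient failures (connection errors and the listed status codes)
    # with exponential backoff between attempts.
    retry = Retry(
        total=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    adapter = HTTPAdapter(max_retries=retry)
    session = requests.Session()
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session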
diff --git a/ami/utils/tests.py b/ami/utils/tests.py
index 36d3c5bb6..73ddbc110 100644
--- a/ami/utils/tests.py
+++ b/ami/utils/tests.py
@@ -1,5 +1,8 @@
 import datetime
 from unittest import TestCase
+from unittest.mock import Mock, PropertyMock
+
+import requests


 class TestUtils(TestCase):
@@ -32,3 +35,33 @@ def test_extract_timestamps(self):
             self.assertEqual(
                 result, expected_date, f"Failed for {filename}: expected {expected_date}, got {result}"
             )
+
+    def test_extract_error_message_from_response(self):
+        """Test extracting error messages from HTTP responses."""
+        from ami.utils.requests import extract_error_message_from_response
+
+        # Test with standard 'detail' field (FastAPI)
+        mock_response = Mock(spec=requests.Response)
+        mock_response.status_code = 500
+        mock_response.reason = "Internal Server Error"
+        mock_response.json.return_value = {"detail": "CUDA out of memory"}
+        result = extract_error_message_from_response(mock_response)
+        self.assertEqual(result, "HTTP 500: Internal Server Error | Detail: CUDA out of memory")
+
+        # Test fallback to non-standard fields
+        mock_response.json.return_value = {"error": "Invalid input"}
+        result = extract_error_message_from_response(mock_response)
+        self.assertIn("error: Invalid input", result)
+
+        # Test fallback to text when JSON fails
+        mock_response.json.side_effect = ValueError("No JSON")
+        mock_response.text = "Service unavailable"
+        result = extract_error_message_from_response(mock_response)
+        self.assertIn("Response text: Service unavailable", result)
+
+        # Test fallback to raw bytes when text access fails
+        mock_response.json.side_effect = ValueError("404 Not Found: Could not fetch image")
+        type(mock_response).text = PropertyMock(side_effect=Exception("text error"))
+        mock_response.content = b"Raw error bytes"
+        result = extract_error_message_from_response(mock_response)
+        self.assertIn("Response content: b'Raw error bytes'", result)
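For reference, this is roughly how the pieces above compose into the single log line that the new pipeline-error test asserts on. The summary text, pipeline slug, and error body are assumed example values for illustration, not output captured from a real run.

# Illustrative only: the summary comes from PipelineRequest.summary() and the error
# string from extract_error_message_from_response(); the values below are assumed.
summary = "pipeline request with 2 images and 0 detections to pipeline 'constant'"
error_msg = "HTTP 500: Internal Server Error | Detail: CUDA out of memory"
msg = f"Failed to process {summary}: {error_msg}"
# -> "Failed to process pipeline request with 2 images and 0 detections to pipeline 'constant':
#     HTTP 500: Internal Server Error | Detail: CUDA out of memory"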