diff --git a/infrastructure/images/api-gateway-mock/resources/server.py b/infrastructure/images/api-gateway-mock/resources/server.py index 6208cad9..cf59a7a1 100644 --- a/infrastructure/images/api-gateway-mock/resources/server.py +++ b/infrastructure/images/api-gateway-mock/resources/server.py @@ -37,11 +37,16 @@ def forward_request(path_params): app.logger.info("received request with data: %s", request.get_data(as_text=True)) + x_correlation_id = request.headers.get("X-Correlation-ID") + forwarded_headers = {k.lower(): v for k, v in request.headers.items()} + forwarded_headers["nhsd-correlation-id"] = x_correlation_id + response = requests.post( "http://pathology-api:8080/2015-03-31" # NOSONAR python:S5332 "/functions/function/invocations", json={ "body": request.get_data(as_text=True).replace("\n", "").replace(" ", ""), + "headers": forwarded_headers, "requestContext": { "http": { "path": f"/{path_params}", diff --git a/pathology-api/lambda_handler.py b/pathology-api/lambda_handler.py index a769da89..57bdad03 100644 --- a/pathology-api/lambda_handler.py +++ b/pathology-api/lambda_handler.py @@ -13,6 +13,7 @@ from pathology_api.fhir.r4.resources import Bundle, OperationOutcome from pathology_api.handler import handle_request from pathology_api.logging import get_logger +from pathology_api.request_context import set_correlation_id _logger = get_logger(__name__) @@ -99,33 +100,48 @@ def handle_exception(exception: Exception) -> Response[str]: @app.get("/_status") def status() -> Response[str]: _logger.debug("Status check endpoint called") - return Response(status_code=200, body="OK", headers={"Content-Type": "text/plain"}) + return Response( + status_code=200, + body='{"status": "pass"}', + headers={"Content-Type": "application/json"}, + ) + + +_CORRELATION_ID_HEADER = "nhsd-correlation-id" @app.post("/FHIR/R4/Bundle") def post_result() -> Response[str]: - _logger.debug("Post result endpoint called.") + correlation_id = app.current_event.headers.get(_CORRELATION_ID_HEADER) + + if not correlation_id: + _logger.warning( + "no correlation id. Current event headers: %s", app.current_event.headers + ) + raise ValueError(f"Missing required header: {_CORRELATION_ID_HEADER}") + with set_correlation_id(correlation_id): + _logger.debug("Post result endpoint called.") - try: - payload = app.current_event.json_body - except JSONDecodeError as e: - raise ValidationError("Invalid payload provided.") from e + try: + payload = app.current_event.json_body + except JSONDecodeError as e: + raise ValidationError("Invalid payload provided.") from e - _logger.debug("Payload received: %s", payload) + _logger.debug("Payload received: %s", payload) - if payload is None: - raise ValidationError( - "Resources must be provided as a bundle of type 'document'" - ) + if payload is None: + raise ValidationError( + "Resources must be provided as a bundle of type 'document'" + ) - bundle = Bundle.model_validate(payload, by_alias=True) + bundle = Bundle.model_validate(payload, by_alias=True) - response = handle_request(bundle) + response = handle_request(bundle) - return _with_default_headers( - status_code=200, - body=response, - ) + return _with_default_headers( + status_code=200, + body=response, + ) def handler(data: dict[str, Any], context: LambdaContext) -> dict[str, Any]: diff --git a/pathology-api/src/pathology_api/logging.py b/pathology-api/src/pathology_api/logging.py index d094698c..fc59087e 100644 --- a/pathology-api/src/pathology_api/logging.py +++ b/pathology-api/src/pathology_api/logging.py @@ -1,7 +1,18 @@ +import logging from typing import Any, Protocol from aws_lambda_powertools import Logger +from pathology_api.request_context import get_correlation_id + + +class _CorrelationIdFilter(logging.Filter): + """Injects the current correlation ID into every log record.""" + + def filter(self, record: logging.LogRecord) -> bool: + record.correlation_id = get_correlation_id() + return True + class LogProvider(Protocol): """Protocol defining required contract for a logger.""" @@ -19,4 +30,6 @@ def exception(self, msg: str, *args: Any, **kwargs: Any) -> None: ... def get_logger(service: str) -> LogProvider: """Get a configured logger instance.""" - return Logger(service=service, level="DEBUG", serialize_stacktrace=True) + logger = Logger(service=service, level="DEBUG", serialize_stacktrace=True) + logger.addFilter(_CorrelationIdFilter()) + return logger diff --git a/pathology-api/src/pathology_api/request_context.py b/pathology-api/src/pathology_api/request_context.py new file mode 100644 index 00000000..df177fea --- /dev/null +++ b/pathology-api/src/pathology_api/request_context.py @@ -0,0 +1,20 @@ +from collections.abc import Generator +from contextlib import contextmanager +from contextvars import ContextVar + +_correlation_id: ContextVar[str] = ContextVar("correlation_id", default="") + + +@contextmanager +def set_correlation_id(value: str) -> Generator[None, None, None]: + """Set the correlation ID for the current request context.""" + _correlation_id.set(value) + try: + yield None + finally: + _correlation_id.set("") + + +def get_correlation_id() -> str: + """Get the correlation ID for the current request context.""" + return _correlation_id.get() diff --git a/pathology-api/src/pathology_api/test_request_context.py b/pathology-api/src/pathology_api/test_request_context.py new file mode 100644 index 00000000..d29ac646 --- /dev/null +++ b/pathology-api/src/pathology_api/test_request_context.py @@ -0,0 +1,9 @@ +from pathology_api.request_context import get_correlation_id, set_correlation_id + + +class TestSetAndGetCorrelationId: + def test_correlation_id_is_cleared_after_context_exit(self) -> None: + with set_correlation_id("round-trip-test-123"): + assert get_correlation_id() == "round-trip-test-123" + + assert get_correlation_id() == "" diff --git a/pathology-api/test_lambda_handler.py b/pathology-api/test_lambda_handler.py index 7f867aea..edaf5354 100644 --- a/pathology-api/test_lambda_handler.py +++ b/pathology-api/test_lambda_handler.py @@ -16,9 +16,11 @@ def _create_test_event( body: str | None = None, path_params: str | None = None, request_method: str | None = None, + headers: dict[str, str] | None = None, ) -> dict[str, Any]: return { "body": body, + "headers": headers or {}, "requestContext": { "http": { "path": f"/{path_params}", @@ -58,13 +60,15 @@ def test_create_test_result_success(self) -> None: body=bundle.model_dump_json(by_alias=True), path_params="FHIR/R4/Bundle", request_method="POST", + headers={"nhsd-correlation-id": "test-correlation-id"}, ) context = LambdaContext() response = handler(event, context) assert response["statusCode"] == 200 - assert response["headers"] == {"Content-Type": "application/fhir+json"} + assert response["headers"]["Content-Type"] == "application/fhir+json" + assert response["headers"]["nhsd-correlation-id"] == "test-correlation-id" response_body = response["body"] assert isinstance(response_body, str) @@ -76,16 +80,53 @@ def test_create_test_result_success(self) -> None: # A UUID value so can only check its presence. assert response_bundle.id is not None + def test_missing_correlation_id_header_returns_500(self) -> None: + bundle = Bundle.create( + type="document", + entry=[ + Bundle.Entry( + fullUrl="composition", + resource=Composition.create( + subject=LogicalReference( + PatientIdentifier.from_nhs_number("nhs_number") + ) + ), + ) + ], + ) + event = self._create_test_event( + body=bundle.model_dump_json(by_alias=True), + path_params="FHIR/R4/Bundle", + request_method="POST", + ) + context = LambdaContext() + + response = handler(event, context) + + assert response["statusCode"] == 500 + assert response["headers"] == {"Content-Type": "application/fhir+json"} + + returned_issue = self._parse_returned_issue(response["body"]) + assert returned_issue["severity"] == "fatal" + assert returned_issue["code"] == "exception" + assert ( + returned_issue["diagnostics"] + == "Missing required header: nhsd-correlation-id" + ) + def test_create_test_result_no_payload(self) -> None: event = self._create_test_event( - path_params="FHIR/R4/Bundle", request_method="POST" + path_params="FHIR/R4/Bundle", + request_method="POST", + headers={"nhsd-correlation-id": "test-correlation-id"}, ) context = LambdaContext() response = handler(event, context) assert response["statusCode"] == 400 - assert response["headers"] == {"Content-Type": "application/fhir+json"} + assert response["headers"]["Content-Type"] == "application/fhir+json" + assert response["headers"]["nhsd-correlation-id"] == "test-correlation-id" returned_issue = self._parse_returned_issue(response["body"]) @@ -98,14 +139,18 @@ def test_create_test_result_no_payload(self) -> None: def test_create_test_result_empty_payload(self) -> None: event = self._create_test_event( - body="{}", path_params="FHIR/R4/Bundle", request_method="POST" + body="{}", + path_params="FHIR/R4/Bundle", + request_method="POST", + headers={"nhsd-correlation-id": "test-correlation-id"}, ) context = LambdaContext() response = handler(event, context) assert response["statusCode"] == 400 - assert response["headers"] == {"Content-Type": "application/fhir+json"} + assert response["headers"]["Content-Type"] == "application/fhir+json" + assert response["headers"]["nhsd-correlation-id"] == "test-correlation-id" returned_issue = self._parse_returned_issue(response["body"]) @@ -118,14 +163,18 @@ def test_create_test_result_empty_payload(self) -> None: def test_create_test_result_invalid_json(self) -> None: event = self._create_test_event( - body="invalid json", path_params="FHIR/R4/Bundle", request_method="POST" + body="invalid json", + path_params="FHIR/R4/Bundle", + request_method="POST", + headers={"nhsd-correlation-id": "test-correlation-id"}, ) context = LambdaContext() response = handler(event, context) assert response["statusCode"] == 400 - assert response["headers"] == {"Content-Type": "application/fhir+json"} + assert response["headers"]["Content-Type"] == "application/fhir+json" + assert response["headers"]["nhsd-correlation-id"] == "test-correlation-id" returned_issue = self._parse_returned_issue(response["body"]) assert returned_issue["severity"] == "error" @@ -169,6 +218,7 @@ def test_create_test_result_processing_error( body=bundle.model_dump_json(by_alias=True), path_params="FHIR/R4/Bundle", request_method="POST", + headers={"nhsd-correlation-id": "test-correlation-id"}, ) context = LambdaContext() @@ -176,7 +226,8 @@ def test_create_test_result_processing_error( response = handler(event, context) assert response["statusCode"] == expected_status_code - assert response["headers"] == {"Content-Type": "application/fhir+json"} + assert response["headers"]["Content-Type"] == "application/fhir+json" + assert response["headers"]["nhsd-correlation-id"] == "test-correlation-id" returned_issue = self._parse_returned_issue(response["body"]) assert returned_issue == expected_issue @@ -207,6 +258,7 @@ def test_create_test_result_model_validate_error( body=bundle.model_dump_json(by_alias=True), path_params="FHIR/R4/Bundle", request_method="POST", + headers={"nhsd-correlation-id": "test-correlation-id"}, ) context = LambdaContext() @@ -217,7 +269,8 @@ def test_create_test_result_model_validate_error( response = handler(event, context) assert response["statusCode"] == 400 - assert response["headers"] == {"Content-Type": "application/fhir+json"} + assert response["headers"]["Content-Type"] == "application/fhir+json" + assert response["headers"]["nhsd-correlation-id"] == "test-correlation-id" returned_issue = self._parse_returned_issue(response["body"]) assert returned_issue["severity"] == "error" @@ -225,11 +278,16 @@ def test_create_test_result_model_validate_error( assert returned_issue["diagnostics"] == expected_diagnostic def test_status_success(self) -> None: - event = self._create_test_event(path_params="_status", request_method="GET") + event = self._create_test_event( + path_params="_status", + request_method="GET", + headers={"nhsd-correlation-id": "test-correlation-id"}, + ) context = LambdaContext() response = handler(event, context) assert response["statusCode"] == 200 - assert response["body"] == "OK" - assert response["headers"] == {"Content-Type": "text/plain"} + assert response["body"] == '{"status": "pass"}' + assert response["headers"]["Content-Type"] == "application/json" + assert response["headers"]["nhsd-correlation-id"] == "test-correlation-id" diff --git a/pathology-api/tests/conftest.py b/pathology-api/tests/conftest.py index 191c21d6..e1a21b8f 100644 --- a/pathology-api/tests/conftest.py +++ b/pathology-api/tests/conftest.py @@ -17,7 +17,11 @@ class Client(Protocol): """Protocol defining the interface for HTTP clients.""" def send( - self, data: str, path: str, request_method: _RequestMethod + self, + data: str, + path: str, + request_method: _RequestMethod, + headers: dict[str, str] | None = None, ) -> requests.Response: """ Send a request to the APIs with some given parameters. @@ -31,7 +35,10 @@ def send( ... def send_without_payload( - self, path: str, request_method: _RequestMethod + self, + path: str, + request_method: _RequestMethod, + headers: dict[str, str] | None = None, ) -> requests.Response: """ Send a request to the APIs without a payload. @@ -47,24 +54,47 @@ def send_without_payload( class LocalClient: """HTTP client that sends requests to the Lambda via the RIE (no auth headers).""" - def __init__(self, lambda_url: str, timeout: timedelta = timedelta(seconds=1)): + def __init__( + self, + lambda_url: str, + headers: dict[str, str] | None = None, + timeout: timedelta = timedelta(seconds=1), + ): self._lambda_url = lambda_url + self._default_headers = {"Content-Type": "application/fhir+json"} | ( + headers or {} + ) self._timeout = timeout.total_seconds() def send( - self, data: str, path: str, request_method: _RequestMethod + self, + data: str, + path: str, + request_method: _RequestMethod, + headers: dict[str, str] | None = None, ) -> requests.Response: return self._send( - data=data, path=path, include_payload=True, request_method=request_method + data=data, + path=path, + include_payload=True, + request_method=request_method, + headers=headers, ) def send_without_payload( - self, path: str, request_method: _RequestMethod + self, + path: str, + request_method: _RequestMethod, + headers: dict[str, str] | None = None, ) -> requests.Response: return self._send( - data=None, path=path, include_payload=False, request_method=request_method + data=None, + path=path, + include_payload=False, + request_method=request_method, + headers=headers, ) def _send( @@ -73,20 +103,24 @@ def _send( path: str, include_payload: bool, request_method: _RequestMethod, + headers: dict[str, str] | None = None, ) -> requests.Response: url = f"{self._lambda_url}/{path}" + merged_headers = self._default_headers | (headers or {}) match request_method: case "POST": return requests.post( url, data=data if include_payload else None, timeout=self._timeout, + headers=merged_headers, ) case "GET": return requests.get( url, data=data if include_payload else None, timeout=self._timeout, + headers=merged_headers, ) @@ -211,7 +245,9 @@ def client(request: pytest.FixtureRequest, base_url: str) -> Client: env = request.config.getoption("--env") if env == "local": - return LocalClient(lambda_url=base_url) + return LocalClient( + lambda_url=base_url, + ) elif env == "remote": return _create_remote_client(request) else: diff --git a/pathology-api/tests/integration/test_endpoints.py b/pathology-api/tests/integration/test_endpoints.py index c123dc7c..eda094f4 100644 --- a/pathology-api/tests/integration/test_endpoints.py +++ b/pathology-api/tests/integration/test_endpoints.py @@ -31,10 +31,15 @@ def test_bundle_returns_200(self, client: Client) -> None: data=bundle.model_dump_json(by_alias=True), path="FHIR/R4/Bundle", request_method="POST", + headers={"X-Correlation-ID": "test-correlation-id-555666777"}, ) assert response.status_code == 200 assert response.headers["Content-Type"] == "application/fhir+json" + assert response.headers["X-Correlation-ID"] == "test-correlation-id-555666777" + assert response.headers["nhsd-correlation-id"].startswith( + ".test-correlation-id-555666777" + ) response_data = response.json() response_bundle = Bundle.model_validate(response_data, by_alias=True) @@ -256,15 +261,29 @@ class TestStatusEndpoint: @pytest.mark.status_auth_headers def test_status_returns_200(self, client: Client) -> None: - response = client.send_without_payload(request_method="GET", path="_status") + response = client.send_without_payload( + request_method="GET", + path="_status", + headers={"X-Correlation-ID": "test-correlation-id-111222333"}, + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "application/json" + assert response.headers["nhsd-correlation-id"].startswith( + ".test-correlation-id-111222333.rrt-" + ) + + import json + import logging + + logger = logging.getLogger(__name__) + logger.warning( + "/_status response JSON: %s", json.dumps(response.json(), indent=2) + ) parsed = StatusResponse.model_validate(response.json()) assert parsed.status == "pass" assert parsed.checks.healthcheck.responseCode == 200 - assert parsed.checks.healthcheck.outcome == "OK" class StatusLinks(BaseModel): @@ -275,7 +294,7 @@ class HealthCheck(BaseModel): status: Literal["pass", "fail"] timeout: Literal["true", "false"] responseCode: int - outcome: str + outcome: dict[Any, Any] links: StatusLinks