Skip to content
5 changes: 5 additions & 0 deletions infrastructure/images/api-gateway-mock/resources/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,16 @@
def forward_request(path_params):
app.logger.info("received request with data: %s", request.get_data(as_text=True))

x_correlation_id = request.headers.get("X-Correlation-ID")
forwarded_headers = {k.lower(): v for k, v in request.headers.items()}
forwarded_headers["nhsd-correlation-id"] = x_correlation_id

response = requests.post(
"http://pathology-api:8080/2015-03-31" # NOSONAR python:S5332
"/functions/function/invocations",
json={
"body": request.get_data(as_text=True).replace("\n", "").replace(" ", ""),
"headers": forwarded_headers,
"requestContext": {
"http": {
"path": f"/{path_params}",
Expand Down
47 changes: 30 additions & 17 deletions pathology-api/lambda_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pathology_api.fhir.r4.resources import Bundle, OperationOutcome
from pathology_api.handler import handle_request
from pathology_api.logging import get_logger
from pathology_api.request_context import set_correlation_id

_logger = get_logger(__name__)

Expand Down Expand Up @@ -99,33 +100,45 @@ def handle_exception(exception: Exception) -> Response[str]:
@app.get("/_status")
def status() -> Response[str]:
_logger.debug("Status check endpoint called")
return Response(status_code=200, body="OK", headers={"Content-Type": "text/plain"})
return Response(
status_code=200,
body='{"status": "pass"}',
headers={"Content-Type": "application/json"},
)


_CORRELATION_ID_HEADER = "nhsd-correlation-id"


@app.post("/FHIR/R4/Bundle")
def post_result() -> Response[str]:
_logger.debug("Post result endpoint called.")
correlation_id = app.current_event.headers.get(_CORRELATION_ID_HEADER)

try:
payload = app.current_event.json_body
except JSONDecodeError as e:
raise ValidationError("Invalid payload provided.") from e
if not correlation_id:
raise ValueError(f"Missing required header: {_CORRELATION_ID_HEADER}")
with set_correlation_id(correlation_id):
_logger.debug("Post result endpoint called.")

_logger.debug("Payload received: %s", payload)
try:
payload = app.current_event.json_body
except JSONDecodeError as e:
raise ValidationError("Invalid payload provided.") from e

if payload is None:
raise ValidationError(
"Resources must be provided as a bundle of type 'document'"
)
_logger.debug("Payload received: %s", payload)

bundle = Bundle.model_validate(payload, by_alias=True)
if payload is None:
raise ValidationError(
"Resources must be provided as a bundle of type 'document'"
)

response = handle_request(bundle)
bundle = Bundle.model_validate(payload, by_alias=True)

return _with_default_headers(
status_code=200,
body=response,
)
response = handle_request(bundle)

return _with_default_headers(
status_code=200,
body=response,
)


def handler(data: dict[str, Any], context: LambdaContext) -> dict[str, Any]:
Expand Down
15 changes: 14 additions & 1 deletion pathology-api/src/pathology_api/logging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import logging
from typing import Any, Protocol

from aws_lambda_powertools import Logger

from pathology_api.request_context import get_correlation_id


class _CorrelationIdFilter(logging.Filter):
"""Injects the current correlation ID into every log record."""

def filter(self, record: logging.LogRecord) -> bool:
record.correlation_id = get_correlation_id()
return True


class LogProvider(Protocol):
"""Protocol defining required contract for a logger."""
Expand All @@ -19,4 +30,6 @@ def exception(self, msg: str, *args: Any, **kwargs: Any) -> None: ...

def get_logger(service: str) -> LogProvider:
"""Get a configured logger instance."""
return Logger(service=service, level="DEBUG", serialize_stacktrace=True)
logger = Logger(service=service, level="DEBUG", serialize_stacktrace=True)
logger.addFilter(_CorrelationIdFilter())
return logger
20 changes: 20 additions & 0 deletions pathology-api/src/pathology_api/request_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar

_correlation_id: ContextVar[str] = ContextVar("correlation_id", default="")


@contextmanager
def set_correlation_id(value: str) -> Generator[None, None, None]:
"""Set the correlation ID for the current request context."""
_correlation_id.set(value)
try:
yield None
finally:
_correlation_id.set("")


def get_correlation_id() -> str:
"""Get the correlation ID for the current request context."""
return _correlation_id.get()
9 changes: 9 additions & 0 deletions pathology-api/src/pathology_api/test_request_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from pathology_api.request_context import get_correlation_id, set_correlation_id


class TestSetAndGetCorrelationId:
def test_correlation_id_is_cleared_after_context_exit(self) -> None:
with set_correlation_id("round-trip-test-123"):
assert get_correlation_id() == "round-trip-test-123"

assert get_correlation_id() == ""
82 changes: 70 additions & 12 deletions pathology-api/test_lambda_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ def _create_test_event(
body: str | None = None,
path_params: str | None = None,
request_method: str | None = None,
headers: dict[str, str] | None = None,
) -> dict[str, Any]:
return {
"body": body,
"headers": headers or {},
"requestContext": {
"http": {
"path": f"/{path_params}",
Expand Down Expand Up @@ -58,13 +60,15 @@ def test_create_test_result_success(self) -> None:
body=bundle.model_dump_json(by_alias=True),
path_params="FHIR/R4/Bundle",
request_method="POST",
headers={"nhsd-correlation-id": "test-correlation-id"},
)
context = LambdaContext()

response = handler(event, context)

assert response["statusCode"] == 200
assert response["headers"] == {"Content-Type": "application/fhir+json"}
assert response["headers"]["Content-Type"] == "application/fhir+json"
assert response["headers"]["nhsd-correlation-id"] == "test-correlation-id"

response_body = response["body"]
assert isinstance(response_body, str)
Expand All @@ -76,16 +80,53 @@ def test_create_test_result_success(self) -> None:
# A UUID value so can only check its presence.
assert response_bundle.id is not None

def test_missing_correlation_id_header_returns_500(self) -> None:
bundle = Bundle.create(
type="document",
entry=[
Bundle.Entry(
fullUrl="composition",
resource=Composition.create(
subject=LogicalReference(
PatientIdentifier.from_nhs_number("nhs_number")
)
),
)
],
)
event = self._create_test_event(
body=bundle.model_dump_json(by_alias=True),
path_params="FHIR/R4/Bundle",
request_method="POST",
)
context = LambdaContext()

response = handler(event, context)

assert response["statusCode"] == 500
assert response["headers"] == {"Content-Type": "application/fhir+json"}

returned_issue = self._parse_returned_issue(response["body"])
assert returned_issue["severity"] == "fatal"
assert returned_issue["code"] == "exception"
assert (
returned_issue["diagnostics"]
== "Missing required header: nhsd-correlation-id"
)

def test_create_test_result_no_payload(self) -> None:
event = self._create_test_event(
path_params="FHIR/R4/Bundle", request_method="POST"
path_params="FHIR/R4/Bundle",
request_method="POST",
headers={"nhsd-correlation-id": "test-correlation-id"},
)
context = LambdaContext()

response = handler(event, context)

assert response["statusCode"] == 400
assert response["headers"] == {"Content-Type": "application/fhir+json"}
assert response["headers"]["Content-Type"] == "application/fhir+json"
assert response["headers"]["nhsd-correlation-id"] == "test-correlation-id"

returned_issue = self._parse_returned_issue(response["body"])

Expand All @@ -98,14 +139,18 @@ def test_create_test_result_no_payload(self) -> None:

def test_create_test_result_empty_payload(self) -> None:
event = self._create_test_event(
body="{}", path_params="FHIR/R4/Bundle", request_method="POST"
body="{}",
path_params="FHIR/R4/Bundle",
request_method="POST",
headers={"nhsd-correlation-id": "test-correlation-id"},
)
context = LambdaContext()

response = handler(event, context)

assert response["statusCode"] == 400
assert response["headers"] == {"Content-Type": "application/fhir+json"}
assert response["headers"]["Content-Type"] == "application/fhir+json"
assert response["headers"]["nhsd-correlation-id"] == "test-correlation-id"

returned_issue = self._parse_returned_issue(response["body"])

Expand All @@ -118,14 +163,18 @@ def test_create_test_result_empty_payload(self) -> None:

def test_create_test_result_invalid_json(self) -> None:
event = self._create_test_event(
body="invalid json", path_params="FHIR/R4/Bundle", request_method="POST"
body="invalid json",
path_params="FHIR/R4/Bundle",
request_method="POST",
headers={"nhsd-correlation-id": "test-correlation-id"},
)
context = LambdaContext()

response = handler(event, context)

assert response["statusCode"] == 400
assert response["headers"] == {"Content-Type": "application/fhir+json"}
assert response["headers"]["Content-Type"] == "application/fhir+json"
assert response["headers"]["nhsd-correlation-id"] == "test-correlation-id"

returned_issue = self._parse_returned_issue(response["body"])
assert returned_issue["severity"] == "error"
Expand Down Expand Up @@ -169,14 +218,16 @@ def test_create_test_result_processing_error(
body=bundle.model_dump_json(by_alias=True),
path_params="FHIR/R4/Bundle",
request_method="POST",
headers={"nhsd-correlation-id": "test-correlation-id"},
)
context = LambdaContext()

with patch("lambda_handler.handle_request", side_effect=error):
response = handler(event, context)

assert response["statusCode"] == expected_status_code
assert response["headers"] == {"Content-Type": "application/fhir+json"}
assert response["headers"]["Content-Type"] == "application/fhir+json"
assert response["headers"]["nhsd-correlation-id"] == "test-correlation-id"

returned_issue = self._parse_returned_issue(response["body"])
assert returned_issue == expected_issue
Expand Down Expand Up @@ -207,6 +258,7 @@ def test_create_test_result_model_validate_error(
body=bundle.model_dump_json(by_alias=True),
path_params="FHIR/R4/Bundle",
request_method="POST",
headers={"nhsd-correlation-id": "test-correlation-id"},
)
context = LambdaContext()

Expand All @@ -217,19 +269,25 @@ def test_create_test_result_model_validate_error(
response = handler(event, context)

assert response["statusCode"] == 400
assert response["headers"] == {"Content-Type": "application/fhir+json"}
assert response["headers"]["Content-Type"] == "application/fhir+json"
assert response["headers"]["nhsd-correlation-id"] == "test-correlation-id"

returned_issue = self._parse_returned_issue(response["body"])
assert returned_issue["severity"] == "error"
assert returned_issue["code"] == "invalid"
assert returned_issue["diagnostics"] == expected_diagnostic

def test_status_success(self) -> None:
event = self._create_test_event(path_params="_status", request_method="GET")
event = self._create_test_event(
path_params="_status",
request_method="GET",
headers={"nhsd-correlation-id": "test-correlation-id"},
)
context = LambdaContext()

response = handler(event, context)

assert response["statusCode"] == 200
assert response["body"] == "OK"
assert response["headers"] == {"Content-Type": "text/plain"}
assert response["body"] == '{"status": "pass"}'
assert response["headers"]["Content-Type"] == "application/json"
assert response["headers"]["nhsd-correlation-id"] == "test-correlation-id"
Loading
Loading