Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 64 additions & 16 deletions tests/utilities/api_clients.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import json
import logging
from typing import Any, Callable

import requests
from pydantic import BaseModel
from requests import Response

from nrlf.core.constants import Categories, PointerTypes
from nrlf.core.model import ConnectionMetadata

logger = logging.getLogger(__name__)


class ClientConfig(BaseModel):
base_url: str
Expand Down Expand Up @@ -34,6 +39,33 @@ def add_pointer_type(self, pointer_type: PointerTypes):
return self


def retry_if(status_codes: list[int]) -> Callable[..., Any]:
"""
Decorator to retry a function call if it returns certain errors
"""

def wrapped_func(func: Callable[..., Response]) -> Callable[..., Response]:
def wrapper(*args: Any, **kwargs: Any) -> Any:
attempt_responses: list[Response] = []
for attempt in range(3):
response = func(*args, **kwargs)
if not response.status_code or response.status_code not in status_codes:
return response
attempt_responses.append(response)
logger.warning(
f"Attempt {attempt + 1} failed with status code {response.status_code}"
)

logger.error(f"All attempts failed with responses: {attempt_responses}")
raise RuntimeError(
f"Function failed after retries with responses: {attempt_responses}"
)

return wrapper

return wrapped_func


class ConsumerTestClient:

def __init__(self, config: ClientConfig):
Expand All @@ -60,29 +92,32 @@ def __init__(self, config: ClientConfig):

self.request_headers.update(self.config.custom_headers)

def read(self, doc_ref_id: str):
@retry_if([502])
def read(self, doc_ref_id: str) -> Response:
return requests.get(
f"{self.api_url}/DocumentReference/{doc_ref_id}",
headers=self.request_headers,
cert=self.config.client_cert,
)

def count(self, params: dict[str, str]):
@retry_if([502])
def count(self, params: dict[str, str]) -> Response:
return requests.get(
f"{self.api_url}/DocumentReference/_count",
params=params,
headers=self.request_headers,
cert=self.config.client_cert,
)

@retry_if([502])
def search(
self,
nhs_number: str | None = None,
custodian: str | None = None,
pointer_type: PointerTypes | None = None,
category: Categories | None = None,
extra_params: dict[str, str] | None = None,
):
) -> Response:
params = {**(extra_params or {})}

if nhs_number:
Expand Down Expand Up @@ -114,14 +149,15 @@ def search(
cert=self.config.client_cert,
)

@retry_if([502])
def search_post(
self,
nhs_number: str | None = None,
custodian: str | None = None,
pointer_type: PointerTypes | None = None,
category: Categories | None = None,
extra_fields: dict[str, str] | None = None,
):
) -> Response:
body = {**(extra_fields or {})}

if nhs_number:
Expand Down Expand Up @@ -156,7 +192,8 @@ def search_post(
cert=self.config.client_cert,
)

def read_capability_statement(self):
@retry_if([502])
def read_capability_statement(self) -> Response:
return requests.get(
f"{self.api_url}/metadata",
headers=self.request_headers,
Expand Down Expand Up @@ -189,74 +226,83 @@ def __init__(self, config: ClientConfig):

self.request_headers.update(self.config.custom_headers)

def create(self, doc_ref):
@retry_if([502])
def create(self, doc_ref) -> Response:
return requests.post(
f"{self.api_url}/DocumentReference",
json=doc_ref,
headers=self.request_headers,
cert=self.config.client_cert,
)

def create_text(self, doc_ref):
@retry_if([502])
def create_text(self, doc_ref) -> Response:
return requests.post(
f"{self.api_url}/DocumentReference",
data=doc_ref,
headers=self.request_headers,
cert=self.config.client_cert,
)

def upsert(self, doc_ref):
@retry_if([502])
def upsert(self, doc_ref) -> Response:
return requests.put(
f"{self.api_url}/DocumentReference",
json=doc_ref,
headers=self.request_headers,
cert=self.config.client_cert,
)

def upsert_text(self, doc_ref):
@retry_if([502])
def upsert_text(self, doc_ref) -> Response:
return requests.put(
f"{self.api_url}/DocumentReference",
data=doc_ref,
headers=self.request_headers,
cert=self.config.client_cert,
)

def update(self, doc_ref, doc_ref_id: str):
@retry_if([502])
def update(self, doc_ref, doc_ref_id: str) -> Response:
return requests.put(
f"{self.api_url}/DocumentReference/{doc_ref_id}",
json=doc_ref,
headers=self.request_headers,
cert=self.config.client_cert,
)

def update_text(self, doc_ref, doc_ref_id: str):
@retry_if([502])
def update_text(self, doc_ref, doc_ref_id: str) -> Response:
return requests.put(
f"{self.api_url}/DocumentReference/{doc_ref_id}",
data=doc_ref,
headers=self.request_headers,
cert=self.config.client_cert,
)

def delete(self, doc_ref_id: str):
@retry_if([502])
def delete(self, doc_ref_id: str) -> Response:
return requests.delete(
f"{self.api_url}/DocumentReference/{doc_ref_id}",
headers=self.request_headers,
cert=self.config.client_cert,
)

def read(self, doc_ref_id: str):
@retry_if([502])
def read(self, doc_ref_id: str) -> Response:
return requests.get(
f"{self.api_url}/DocumentReference/{doc_ref_id}",
headers=self.request_headers,
cert=self.config.client_cert,
)

@retry_if([502])
def search(
self,
nhs_number: str | None = None,
pointer_type: PointerTypes | None = None,
extra_params: dict[str, str] | None = None,
):
) -> Response:
params = {**(extra_params or {})}

if nhs_number:
Expand All @@ -277,12 +323,13 @@ def search(
cert=self.config.client_cert,
)

@retry_if([502])
def search_post(
self,
nhs_number: str | None = None,
pointer_type: PointerTypes | None = None,
extra_fields: dict[str, str] | None = None,
):
) -> Response:
body = {**(extra_fields or {})}

if nhs_number:
Expand All @@ -306,7 +353,8 @@ def search_post(
cert=self.config.client_cert,
)

def read_capability_statement(self):
@retry_if([502])
def read_capability_statement(self) -> Response:
return requests.get(
f"{self.api_url}/metadata",
headers=self.request_headers,
Expand Down
Loading