diff --git a/tests/utilities/api_clients.py b/tests/utilities/api_clients.py index 762b1c1b4..ea9aa2da5 100644 --- a/tests/utilities/api_clients.py +++ b/tests/utilities/api_clients.py @@ -1,7 +1,9 @@ import json +from typing import Any, Callable import requests from pydantic import BaseModel +from requests import Response from nrlf.core.constants import Categories, PointerTypes from nrlf.core.model import ConnectionMetadata @@ -34,6 +36,35 @@ def add_pointer_type(self, pointer_type: PointerTypes): return self +def retry_if(status_codes: list[int]) -> Callable[..., Any]: + """ + Decorator to retry a function call if it returns certain errors + """ + + def wrapped_func(func: Callable[..., Response]) -> Callable[..., Response]: + def wrapper(*args: Any, **kwargs: Any) -> Any: + attempt_responses: list[Response] = [] + for attempt in range(2): + response = func(*args, **kwargs) + if not response.status_code or response.status_code not in status_codes: + return response + attempt_responses.append(response) + print( # noqa: T201 + f"Retrying due to {response.status_code} error in attempt {attempt + 1}..." + ) + + print( # noqa: T201 + f"All attempts failed with responses: {attempt_responses}" + ) + raise Exception( + f"Function failed after retries with responses: {attempt_responses}" + ) + + return wrapper + + return wrapped_func + + class ConsumerTestClient: def __init__(self, config: ClientConfig): @@ -60,14 +91,16 @@ def __init__(self, config: ClientConfig): self.request_headers.update(self.config.custom_headers) - def read(self, doc_ref_id: str): + @retry_if([502]) + def read(self, doc_ref_id: str) -> Response: return requests.get( f"{self.api_url}/DocumentReference/{doc_ref_id}", headers=self.request_headers, cert=self.config.client_cert, ) - def count(self, params: dict[str, str]): + @retry_if([502]) + def count(self, params: dict[str, str]) -> Response: return requests.get( f"{self.api_url}/DocumentReference/_count", params=params, @@ -75,6 +108,7 @@ def count(self, params: dict[str, str]): cert=self.config.client_cert, ) + @retry_if([502]) def search( self, nhs_number: str | None = None, @@ -82,7 +116,7 @@ def search( pointer_type: PointerTypes | None = None, category: Categories | None = None, extra_params: dict[str, str] | None = None, - ): + ) -> Response: params = {**(extra_params or {})} if nhs_number: @@ -114,6 +148,7 @@ def search( cert=self.config.client_cert, ) + @retry_if([502]) def search_post( self, nhs_number: str | None = None, @@ -121,7 +156,7 @@ def search_post( pointer_type: PointerTypes | None = None, category: Categories | None = None, extra_fields: dict[str, str] | None = None, - ): + ) -> Response: body = {**(extra_fields or {})} if nhs_number: @@ -156,7 +191,8 @@ def search_post( cert=self.config.client_cert, ) - def read_capability_statement(self): + @retry_if([502]) + def read_capability_statement(self) -> Response: return requests.get( f"{self.api_url}/metadata", headers=self.request_headers, @@ -189,7 +225,8 @@ def __init__(self, config: ClientConfig): self.request_headers.update(self.config.custom_headers) - def create(self, doc_ref): + @retry_if([502]) + def create(self, doc_ref) -> Response: return requests.post( f"{self.api_url}/DocumentReference", json=doc_ref, @@ -197,7 +234,8 @@ def create(self, doc_ref): cert=self.config.client_cert, ) - def create_text(self, doc_ref): + @retry_if([502]) + def create_text(self, doc_ref) -> Response: return requests.post( f"{self.api_url}/DocumentReference", data=doc_ref, @@ -205,7 +243,8 @@ def create_text(self, doc_ref): cert=self.config.client_cert, ) - def upsert(self, doc_ref): + @retry_if([502]) + def upsert(self, doc_ref) -> Response: return requests.put( f"{self.api_url}/DocumentReference", json=doc_ref, @@ -213,7 +252,8 @@ def upsert(self, doc_ref): cert=self.config.client_cert, ) - def upsert_text(self, doc_ref): + @retry_if([502]) + def upsert_text(self, doc_ref) -> Response: return requests.put( f"{self.api_url}/DocumentReference", data=doc_ref, @@ -221,7 +261,8 @@ def upsert_text(self, doc_ref): cert=self.config.client_cert, ) - def update(self, doc_ref, doc_ref_id: str): + @retry_if([502]) + def update(self, doc_ref, doc_ref_id: str) -> Response: return requests.put( f"{self.api_url}/DocumentReference/{doc_ref_id}", json=doc_ref, @@ -229,7 +270,8 @@ def update(self, doc_ref, doc_ref_id: str): cert=self.config.client_cert, ) - def update_text(self, doc_ref, doc_ref_id: str): + @retry_if([502]) + def update_text(self, doc_ref, doc_ref_id: str) -> Response: return requests.put( f"{self.api_url}/DocumentReference/{doc_ref_id}", data=doc_ref, @@ -237,26 +279,29 @@ def update_text(self, doc_ref, doc_ref_id: str): cert=self.config.client_cert, ) - def delete(self, doc_ref_id: str): + @retry_if([502]) + def delete(self, doc_ref_id: str) -> Response: return requests.delete( f"{self.api_url}/DocumentReference/{doc_ref_id}", headers=self.request_headers, cert=self.config.client_cert, ) - def read(self, doc_ref_id: str): + @retry_if([502]) + def read(self, doc_ref_id: str) -> Response: return requests.get( f"{self.api_url}/DocumentReference/{doc_ref_id}", headers=self.request_headers, cert=self.config.client_cert, ) + @retry_if([502]) def search( self, nhs_number: str | None = None, pointer_type: PointerTypes | None = None, extra_params: dict[str, str] | None = None, - ): + ) -> Response: params = {**(extra_params or {})} if nhs_number: @@ -277,12 +322,13 @@ def search( cert=self.config.client_cert, ) + @retry_if([502]) def search_post( self, nhs_number: str | None = None, pointer_type: PointerTypes | None = None, extra_fields: dict[str, str] | None = None, - ): + ) -> Response: body = {**(extra_fields or {})} if nhs_number: @@ -306,7 +352,8 @@ def search_post( cert=self.config.client_cert, ) - def read_capability_statement(self): + @retry_if([502]) + def read_capability_statement(self) -> Response: return requests.get( f"{self.api_url}/metadata", headers=self.request_headers,