diff --git a/providers/openfeature-provider-ofrep/src/openfeature/contrib/provider/ofrep/__init__.py b/providers/openfeature-provider-ofrep/src/openfeature/contrib/provider/ofrep/__init__.py index 69372e12..9993ed66 100644 --- a/providers/openfeature-provider-ofrep/src/openfeature/contrib/provider/ofrep/__init__.py +++ b/providers/openfeature-provider-ofrep/src/openfeature/contrib/provider/ofrep/__init__.py @@ -1,12 +1,13 @@ +import asyncio import re from collections.abc import Mapping, Sequence from datetime import datetime, timedelta, timezone from email.utils import parsedate_to_datetime +from json import JSONDecodeError from typing import Any, Callable, NoReturn, Optional, Union from urllib.parse import urljoin -import requests -from requests.exceptions import JSONDecodeError +import httpx from openfeature.evaluation_context import EvaluationContext from openfeature.exception import ( @@ -55,7 +56,23 @@ def __init__( self.headers_factory = headers_factory self.timeout = timeout self.retry_after: Optional[datetime] = None - self.session = requests.Session() + + self.client = httpx.Client() + self.client_async = httpx.AsyncClient() + self._client_async_is_entered = False + + def initialize(self, evaluation_context: EvaluationContext) -> None: + self.client.__enter__() + + def shutdown(self) -> None: + self.client.__exit__(None, None, None) + + try: + # TODO(someday): support non asyncio runtimes here + asyncio.get_running_loop().create_task(self.client_async.__aexit__(None, None, None)) + self._client_async_is_entered = False + except Exception: + pass def get_metadata(self) -> Metadata: return Metadata(name="OpenFeature Remote Evaluation Protocol Provider") @@ -73,6 +90,16 @@ def resolve_boolean_details( FlagType.BOOLEAN, flag_key, default_value, evaluation_context ) + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return await self._resolve_async( + FlagType.BOOLEAN, flag_key, default_value, evaluation_context + ) + def resolve_string_details( self, flag_key: str, @@ -83,6 +110,16 @@ def resolve_string_details( FlagType.STRING, flag_key, default_value, evaluation_context ) + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + return await self._resolve_async( + FlagType.STRING, flag_key, default_value, evaluation_context + ) + def resolve_integer_details( self, flag_key: str, @@ -93,6 +130,16 @@ def resolve_integer_details( FlagType.INTEGER, flag_key, default_value, evaluation_context ) + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + return await self._resolve_async( + FlagType.INTEGER, flag_key, default_value, evaluation_context + ) + def resolve_float_details( self, flag_key: str, @@ -103,6 +150,17 @@ def resolve_float_details( FlagType.FLOAT, flag_key, default_value, evaluation_context ) + + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + return await self._resolve_async( + FlagType.FLOAT, flag_key, default_value, evaluation_context + ) + def resolve_object_details( self, flag_key: str, @@ -115,6 +173,16 @@ def resolve_object_details( FlagType.OBJECT, flag_key, default_value, evaluation_context ) + async def resolve_object_details_async( + self, + flag_key: str, + default_value: Union[Sequence[FlagValueType], Mapping[str, FlagValueType]], + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[Sequence[FlagValueType] | Mapping[str, FlagValueType]]: + return await self._resolve_async( + FlagType.OBJECT, flag_key, default_value, evaluation_context + ) + def _get_ofrep_api_url(self, api_version: str = "v1") -> str: ofrep_base_url = ( self.base_url if self.base_url.endswith("/") else f"{self.base_url}/" @@ -146,7 +214,7 @@ def _resolve( self.retry_after = None try: - response = self.session.post( + response = self.client.post( urljoin(self._get_ofrep_api_url(), f"evaluate/flags/{flag_key}"), json=_build_request_data(evaluation_context), timeout=self.timeout, @@ -154,7 +222,7 @@ def _resolve( ) response.raise_for_status() - except requests.RequestException as e: + except httpx.HTTPError as e: self._handle_error(e) try: @@ -171,11 +239,66 @@ def _resolve( flag_metadata=data.get("metadata", {}), ) - def _handle_error(self, exception: requests.RequestException) -> NoReturn: - response = exception.response - if response is None: + async def _resolve_async( + self, + flag_type: FlagType, + flag_key: str, + default_value: Union[ + bool, + str, + int, + float, + dict, + list, + Sequence[FlagValueType], + Mapping[str, FlagValueType], + ], + evaluation_context: Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[Any]: + if not self._client_async_is_entered: + await self.client_async.__aenter__() + self._client_async_is_entered = True + + now = datetime.now(timezone.utc) + if self.retry_after and now <= self.retry_after: + raise GeneralError( + f"OFREP evaluation paused due to TooManyRequests until {self.retry_after}" + ) + elif self.retry_after: + self.retry_after = None + + try: + response = await self.client_async.post( + urljoin(self._get_ofrep_api_url(), f"evaluate/flags/{flag_key}"), + json=_build_request_data(evaluation_context), + timeout=self.timeout, + headers=self.headers_factory() if self.headers_factory else None, + ) + response.raise_for_status() + + except httpx.HTTPError as e: + self._handle_error(e) + + try: + data = response.json() + except JSONDecodeError as e: + raise ParseError(str(e)) from e + + _typecheck_flag_value(data["value"], flag_type) + + return FlagResolutionDetails( + value=data["value"], + reason=Reason[data["reason"]], + variant=data["variant"], + flag_metadata=data.get("metadata", {}), + ) + + def _handle_error(self, exception: httpx.HTTPError) -> NoReturn: + if not isinstance(exception, httpx.HTTPStatusError): raise GeneralError(str(exception)) from exception + response = exception.response + if response.status_code == 429: retry_after = response.headers.get("Retry-After") self.retry_after = _parse_retry_after(retry_after) @@ -205,6 +328,10 @@ def _handle_error(self, exception: requests.RequestException) -> NoReturn: raise OpenFeatureError(error_code, error_details) from exception + def __del__(self): + # Ensure clients get cleaned up + self.shutdown() + def _build_request_data( evaluation_context: Optional[EvaluationContext],