diff --git a/openeo_driver/util/auth.py b/openeo_driver/util/auth.py index 7377e2f0..97c514f2 100644 --- a/openeo_driver/util/auth.py +++ b/openeo_driver/util/auth.py @@ -6,7 +6,12 @@ from typing import Mapping, NamedTuple, Optional, Union import requests -from openeo.rest.auth.oidc import OidcClientCredentialsAuthenticator, OidcClientInfo, OidcProviderInfo +from openeo.rest.auth.oidc import ( + OidcClientCredentialsAuthenticator, + OidcClientInfo, + OidcProviderInfo, + AccessTokenResult, +) from openeo.util import str_truncate _log = logging.getLogger(__name__) @@ -77,7 +82,7 @@ class ClientCredentialsAccessTokenHelper: - call `get_access_token()` to get an access token where necessary """ - __slots__ = ("_authenticator", "_session", "_cache", "_default_ttl") + __slots__ = ("_authenticator", "_session", "_cache", "_default_ttl", "_expiration_threshold") def __init__( self, @@ -85,11 +90,13 @@ def __init__( credentials: Optional[ClientCredentials] = None, session: Optional[requests.Session] = None, default_ttl: float = 5 * 60, + expiration_threshold: float = 300, ): self._session = session self._authenticator: Optional[OidcClientCredentialsAuthenticator] = None self._cache = _AccessTokenCache("", 0) self._default_ttl = default_ttl + self._expiration_threshold = expiration_threshold if credentials: self.setup_credentials(credentials) @@ -116,18 +123,25 @@ def setup_credentials(self, credentials: ClientCredentials) -> None: client_info=client_info, requests_session=self._session ) - def _get_access_token(self) -> str: + def _get_access_token(self) -> AccessTokenResult: """Get an access token using the configured authenticator.""" if not self._authenticator: raise RuntimeError("No authentication set up") _log.debug(f"{self.__class__.__name__} getting access token") - tokens = self._authenticator.get_tokens() - return tokens.access_token + access_token_response = self._authenticator.get_tokens() + return access_token_response def get_access_token(self) -> str: """Get an access token using the configured authenticator.""" if time.time() > self._cache.expires_at: - access_token = self._get_access_token() - # TODO: get expiry from access token itself? - self._cache = _AccessTokenCache(access_token, time.time() + self._default_ttl) + access_token_response = self._get_access_token() + access_token = access_token_response.access_token + self._cache = _AccessTokenCache(access_token, self._get_access_token_expiry_time(access_token_response)) return self._cache.access_token + + def _get_access_token_expiry_time(self, access_token_response: AccessTokenResult) -> float: + if access_token_response.expires_in is None: + return time.time() + self._default_ttl + else: + # Expire the cache entry before the entry actually expires + return time.time() + (access_token_response.expires_in - self._expiration_threshold) diff --git a/setup.py b/setup.py index 25e94a8f..280fe588 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ "flask>=2.0.0", "werkzeug>=3.0.3", # https://github.com/Open-EO/openeo-python-driver/issues/243 "requests>=2.28.0", - "openeo>=0.46.0.a2.dev", + "openeo>=0.49.0.a4.dev", "openeo_processes==0.0.4", # 0.0.4 is special build/release, also see https://github.com/Open-EO/openeo-python-driver/issues/152 "gunicorn>=20.0.1", "numpy>=1.22.0", diff --git a/tests/util/test_auth.py b/tests/util/test_auth.py index b953d11e..38c99385 100644 --- a/tests/util/test_auth.py +++ b/tests/util/test_auth.py @@ -1,7 +1,10 @@ import logging import re +import time +from typing import Optional import pytest +import time_machine from openeo.rest.auth.testing import OidcMock from openeo_driver.util.auth import ClientCredentials, ClientCredentialsAccessTokenHelper @@ -67,27 +70,76 @@ def credentials(self) -> ClientCredentials: return ClientCredentials(oidc_issuer="https://oidc.test", client_id="client123", client_secret="s3cr3t") @pytest.fixture - def oidc_mock(self, requests_mock, credentials) -> OidcMock: + def access_token_expires_in(self) -> Optional[int]: + """By default we let access tokens of the mock expire in 1 hour""" + return 3600 + + @pytest.fixture + def local_cache_ttl(self) -> int: + """By default we let the local cache expire in 30 minutes""" + return 1800 + + @pytest.fixture + def oidc_mock(self, requests_mock, credentials, access_token_expires_in) -> OidcMock: oidc_mock = OidcMock( requests_mock=requests_mock, oidc_issuer=credentials.oidc_issuer, expected_grant_type="client_credentials", expected_client_id=credentials.client_id, expected_fields={"client_secret": credentials.client_secret, "scope": "openid"}, + access_token_expires_in=access_token_expires_in, ) return oidc_mock - def test_basic(self, credentials, oidc_mock: OidcMock): - helper = ClientCredentialsAccessTokenHelper(credentials=credentials) + def test_basic(self, credentials, oidc_mock: OidcMock, local_cache_ttl): + helper = ClientCredentialsAccessTokenHelper(credentials=credentials, default_ttl=local_cache_ttl) assert helper.get_access_token() == oidc_mock.state["access_token"] - def test_caching(self, credentials, oidc_mock: OidcMock): - helper = ClientCredentialsAccessTokenHelper(credentials=credentials) - assert oidc_mock.mocks["token_endpoint"].call_count == 0 - assert helper.get_access_token() == oidc_mock.state["access_token"] - assert oidc_mock.mocks["token_endpoint"].call_count == 1 - assert helper.get_access_token() == oidc_mock.state["access_token"] - assert oidc_mock.mocks["token_endpoint"].call_count == 1 + @pytest.mark.parametrize( + ["desc", "local_cache_ttl", "access_token_expires_in", "no_cache_at_30m", "no_cache_at_50m"], + [ + ("Long caching", 3600, 3600, False, False), + ("No/very short caching", 0, 0, True, True), + ("Local cache expires after 30m but before 50m, no server expiry", 40 * 60, None, False, True), + ("Server cache shortest and cause expiry after 30m but before 50m", 3600, 40 * 60, False, True), + ("Local cache expires after 30m but access token does not", 40 * 60, 7200, False, False), + ("Server cache shortest and shorter then default expiry", 3600, 10, True, True), + ], + ) + def test_caching( + self, + credentials, + oidc_mock: OidcMock, + local_cache_ttl, + access_token_expires_in, + no_cache_at_30m, + no_cache_at_50m, + desc: str, + ): + """ + Test caching by requesting an access token at start time and at the 30 and 50 minute mark. + """ + now = time.time() + helper = ClientCredentialsAccessTokenHelper(credentials=credentials, default_ttl=local_cache_ttl) + + expected_chache_misses = 0 + with time_machine.travel(now): + assert oidc_mock.mocks["token_endpoint"].call_count == expected_chache_misses + assert helper.get_access_token() == oidc_mock.state["access_token"] + expected_chache_misses += 1 # First request is always a miss + assert oidc_mock.mocks["token_endpoint"].call_count == expected_chache_misses + + with time_machine.travel(now + 30 * 60): + assert helper.get_access_token() == oidc_mock.state["access_token"] + if no_cache_at_30m: + expected_chache_misses += 1 + assert oidc_mock.mocks["token_endpoint"].call_count == expected_chache_misses + + with time_machine.travel(now + 50 * 60): + assert helper.get_access_token() == oidc_mock.state["access_token"] + if no_cache_at_50m: + expected_chache_misses += 1 + assert oidc_mock.mocks["token_endpoint"].call_count == expected_chache_misses @pytest.mark.skip(reason="Logging was removed for eu-cdse/openeo-cdse-infra#476") def test_secret_logging(self, credentials, oidc_mock: OidcMock, caplog):