From e2c288570ae4e3549d116821f5097727828fea52 Mon Sep 17 00:00:00 2001 From: Peter Van Bouwel Date: Tue, 31 Mar 2026 08:03:03 +0200 Subject: [PATCH 1/2] fix: do not exceed lifetime of access token in cache --- openeo_driver/util/auth.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/openeo_driver/util/auth.py b/openeo_driver/util/auth.py index 7377e2f0..972dbe90 100644 --- a/openeo_driver/util/auth.py +++ b/openeo_driver/util/auth.py @@ -6,7 +6,12 @@ from typing import Mapping, NamedTuple, Optional, Union import requests -from openeo.rest.auth.oidc import OidcClientCredentialsAuthenticator, OidcClientInfo, OidcProviderInfo +from openeo.rest.auth.oidc import ( + OidcClientCredentialsAuthenticator, + OidcClientInfo, + OidcProviderInfo, + AccessTokenResult, +) from openeo.util import str_truncate _log = logging.getLogger(__name__) @@ -116,18 +121,25 @@ def setup_credentials(self, credentials: ClientCredentials) -> None: client_info=client_info, requests_session=self._session ) - def _get_access_token(self) -> str: + def _get_access_token(self) -> AccessTokenResult: """Get an access token using the configured authenticator.""" if not self._authenticator: raise RuntimeError("No authentication set up") _log.debug(f"{self.__class__.__name__} getting access token") - tokens = self._authenticator.get_tokens() - return tokens.access_token + access_token_response = self._authenticator.get_tokens() + return access_token_response def get_access_token(self) -> str: """Get an access token using the configured authenticator.""" if time.time() > self._cache.expires_at: - access_token = self._get_access_token() - # TODO: get expiry from access token itself? - self._cache = _AccessTokenCache(access_token, time.time() + self._default_ttl) + access_token_response = self._get_access_token() + access_token = access_token_response.access_token + self._cache = _AccessTokenCache(access_token, self._get_access_token_expiry_time(access_token_response)) return self._cache.access_token + + def _get_access_token_expiry_time(self, access_token_response: AccessTokenResult) -> float: + if access_token_response.expires_in is None: + return time.time() + self._default_ttl + else: + # Expire the cache entry before the entry actually expires + return time.time() + access_token_response.expires_in * 0.90 From cb5ffb6ef0ac64f0cc9de4deeb9be623fa4a2370 Mon Sep 17 00:00:00 2001 From: Peter Van Bouwel Date: Tue, 31 Mar 2026 08:04:26 +0200 Subject: [PATCH 2/2] test: cache testing to verify adherence to access token expires in field --- tests/util/test_auth.py | 71 +++++++++++++++++++++++++++++++++++------ 1 file changed, 61 insertions(+), 10 deletions(-) diff --git a/tests/util/test_auth.py b/tests/util/test_auth.py index b953d11e..e24ace6e 100644 --- a/tests/util/test_auth.py +++ b/tests/util/test_auth.py @@ -1,7 +1,10 @@ import logging import re +import time +from typing import Optional import pytest +import time_machine from openeo.rest.auth.testing import OidcMock from openeo_driver.util.auth import ClientCredentials, ClientCredentialsAccessTokenHelper @@ -67,27 +70,75 @@ def credentials(self) -> ClientCredentials: return ClientCredentials(oidc_issuer="https://oidc.test", client_id="client123", client_secret="s3cr3t") @pytest.fixture - def oidc_mock(self, requests_mock, credentials) -> OidcMock: + def access_token_expires_in(self) -> Optional[int]: + """By default we let access tokens of the mock expire in 1 hour""" + return 3600 + + @pytest.fixture + def local_cache_ttl(self) -> int: + """By default we let the local cache expire in 30 minutes""" + return 1800 + + @pytest.fixture + def oidc_mock(self, requests_mock, credentials, access_token_expires_in) -> OidcMock: oidc_mock = OidcMock( requests_mock=requests_mock, oidc_issuer=credentials.oidc_issuer, expected_grant_type="client_credentials", expected_client_id=credentials.client_id, expected_fields={"client_secret": credentials.client_secret, "scope": "openid"}, + access_token_expires_in=access_token_expires_in, ) return oidc_mock - def test_basic(self, credentials, oidc_mock: OidcMock): - helper = ClientCredentialsAccessTokenHelper(credentials=credentials) + def test_basic(self, credentials, oidc_mock: OidcMock, local_cache_ttl): + helper = ClientCredentialsAccessTokenHelper(credentials=credentials, default_ttl=local_cache_ttl) assert helper.get_access_token() == oidc_mock.state["access_token"] - def test_caching(self, credentials, oidc_mock: OidcMock): - helper = ClientCredentialsAccessTokenHelper(credentials=credentials) - assert oidc_mock.mocks["token_endpoint"].call_count == 0 - assert helper.get_access_token() == oidc_mock.state["access_token"] - assert oidc_mock.mocks["token_endpoint"].call_count == 1 - assert helper.get_access_token() == oidc_mock.state["access_token"] - assert oidc_mock.mocks["token_endpoint"].call_count == 1 + @pytest.mark.parametrize( + ["desc", "local_cache_ttl", "access_token_expires_in", "no_cache_at_30m", "no_cache_at_50m"], + [ + ("Long caching", 3600, 3600, False, False), + ("No/very short caching", 0, 0, True, True), + ("Local cache expires after 30m but before 50m, no server expiry", 40 * 60, None, False, True), + ("Server cache shortest and cause expiry after 30m but before 50m", 3600, 40 * 60, False, True), + ("Local cache expires after 30m but access token does not", 40 * 60, 7200, False, False), + ], + ) + def test_caching( + self, + credentials, + oidc_mock: OidcMock, + local_cache_ttl, + access_token_expires_in, + no_cache_at_30m, + no_cache_at_50m, + desc: str, + ): + """ + Test caching by requesting an access token at start time and at the 30 and 50 minute mark. + """ + now = time.time() + helper = ClientCredentialsAccessTokenHelper(credentials=credentials, default_ttl=local_cache_ttl) + + expected_chache_misses = 0 + with time_machine.travel(now): + assert oidc_mock.mocks["token_endpoint"].call_count == expected_chache_misses + assert helper.get_access_token() == oidc_mock.state["access_token"] + expected_chache_misses += 1 # First request is always a miss + assert oidc_mock.mocks["token_endpoint"].call_count == expected_chache_misses + + with time_machine.travel(now + 30 * 60): + assert helper.get_access_token() == oidc_mock.state["access_token"] + if no_cache_at_30m: + expected_chache_misses += 1 + assert oidc_mock.mocks["token_endpoint"].call_count == expected_chache_misses + + with time_machine.travel(now + 50 * 60): + assert helper.get_access_token() == oidc_mock.state["access_token"] + if no_cache_at_50m: + expected_chache_misses += 1 + assert oidc_mock.mocks["token_endpoint"].call_count == expected_chache_misses @pytest.mark.skip(reason="Logging was removed for eu-cdse/openeo-cdse-infra#476") def test_secret_logging(self, credentials, oidc_mock: OidcMock, caplog):