From ff3ed4680293d46a427687daf172b7521f0f1f88 Mon Sep 17 00:00:00 2001 From: Matthias Dellweg Date: Tue, 5 Nov 2024 12:10:18 +0100 Subject: [PATCH 1/2] Rework authentication The auth provider now is supposed to provide the credentials for certain authentication mechanisms and not return requests specific classes. This should help to define the interface without lock in to requests. --- CHANGES/pulp-glue/+auth_recontract.removal | 1 + pulp-glue/pulp_glue/common/authentication.py | 211 ++++++++------ pulp-glue/pulp_glue/common/context.py | 32 ++- pulp-glue/pulp_glue/common/openapi.py | 275 ++++++++++--------- pulp-glue/tests/conftest.py | 23 +- pulp-glue/tests/test_auth_provider.py | 93 ++++--- pulp-glue/tests/test_authentication.py | 30 -- pulp-glue/tests/test_openapi.py | 274 +++++++++++++++++- pulp_cli/__init__.py | 4 +- pulp_cli/config.py | 6 +- pulp_cli/generic.py | 223 ++++++++------- 11 files changed, 761 insertions(+), 411 deletions(-) create mode 100644 CHANGES/pulp-glue/+auth_recontract.removal delete mode 100644 pulp-glue/tests/test_authentication.py diff --git a/CHANGES/pulp-glue/+auth_recontract.removal b/CHANGES/pulp-glue/+auth_recontract.removal new file mode 100644 index 000000000..dfc83f18d --- /dev/null +++ b/CHANGES/pulp-glue/+auth_recontract.removal @@ -0,0 +1 @@ +Breaking change: Reworked the contract around the `AuthProvider` to allow authentication to be coded independently of the underlying library. diff --git a/pulp-glue/pulp_glue/common/authentication.py b/pulp-glue/pulp_glue/common/authentication.py index a7998bf44..991aa5af9 100644 --- a/pulp-glue/pulp_glue/common/authentication.py +++ b/pulp-glue/pulp_glue/common/authentication.py @@ -1,97 +1,140 @@ import typing as t -from datetime import datetime, timedelta -import requests - -class OAuth2ClientCredentialsAuth(requests.auth.AuthBase): - """ - This implements the OAuth2 ClientCredentials Grant authentication flow. - https://datatracker.ietf.org/doc/html/rfc6749#section-4.4 +class AuthProviderBase: """ + Base class for auth providers. - def __init__( - self, - client_id: str, - client_secret: str, - token_url: str, - scopes: list[str] | None = None, - verify_ssl: str | bool | None = None, - ): - self._token_server_auth = requests.auth.HTTPBasicAuth(client_id, client_secret) - self._token_url = token_url - self._scopes = scopes - self._verify_ssl = verify_ssl + This abstract base class will analyze the authentication proposals of the openapi specs. + Different authentication schemes can be implemented in subclasses. + """ - self._access_token: str | None = None - self._expire_at: datetime | None = None + def can_complete_http_basic(self) -> bool: + return False + + def can_complete_mutualTLS(self) -> bool: + return False + + def can_complete_oauth2_client_credentials(self, scopes: list[str]) -> bool: + return False + + def can_complete_scheme(self, scheme: dict[str, t.Any], scopes: list[str]) -> bool: + if scheme["type"] == "http": + if scheme["scheme"] == "basic": + return self.can_complete_http_basic() + elif scheme["type"] == "mutualTLS": + return self.can_complete_mutualTLS() + elif scheme["type"] == "oauth2": + for flow_name, flow in scheme["flows"].items(): + if flow_name == "clientCredentials" and self.can_complete_oauth2_client_credentials( + flow["scopes"] + ): + return True + return False + + def can_complete( + self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]] + ) -> bool: + for name, scopes in proposal.items(): + scheme = security_schemes.get(name) + if scheme is None or not self.can_complete_scheme(scheme, scopes): + return False + # This covers the case where `[]` allows for no auth at all. + return True + + async def auth_success_hook( + self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]] + ) -> None: + pass + + async def auth_failure_hook( + self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]] + ) -> None: + pass + + async def http_basic_credentials(self) -> tuple[bytes, bytes]: + raise NotImplementedError() + + async def oauth2_client_credentials(self) -> tuple[bytes, bytes]: + raise NotImplementedError() + + def tls_credentials(self) -> tuple[str, str | None]: + raise NotImplementedError() + + +class BasicAuthProvider(AuthProviderBase): + """ + AuthProvider providing basic auth with fixed `username`, `password`. + """ - def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: - if self._expire_at is None or self._expire_at < datetime.now(): - self._retrieve_token() + def __init__(self, username: t.AnyStr, password: t.AnyStr): + super().__init__() + self.username: bytes = username.encode("latin1") if isinstance(username, str) else username + self.password: bytes = password.encode("latin1") if isinstance(password, str) else password - assert self._access_token is not None + def can_complete_http_basic(self) -> bool: + return True - request.headers["Authorization"] = f"Bearer {self._access_token}" + async def http_basic_credentials(self) -> tuple[bytes, bytes]: + return self.username, self.password - # Call to untyped function "register_hook" in typed context - request.register_hook("response", self._handle401) # type: ignore[no-untyped-call] - return request +class GlueAuthProvider(AuthProviderBase): + """ + AuthProvider allowing to be used with prepared credentials. + """ - def _handle401( + def __init__( self, - response: requests.Response, - **kwargs: t.Any, - ) -> requests.Response: - if response.status_code != 401: - return response - - # If we get this far, probably the token is not valid anymore. - - # Try to reach for a new token once. - self._retrieve_token() - - assert self._access_token is not None - - # Consume content and release the original connection - # to allow our new request to reuse the same one. - response.content - response.close() - prepared_new_request = response.request.copy() - - prepared_new_request.headers["Authorization"] = f"Bearer {self._access_token}" - - # Avoid to enter into an infinity loop. - # Call to untyped function "deregister_hook" in typed context - prepared_new_request.deregister_hook( # type: ignore[no-untyped-call] - "response", self._handle401 - ) - - # "Response" has no attribute "connection" - new_response: requests.Response = response.connection.send(prepared_new_request, **kwargs) - new_response.history.append(response) - new_response.request = prepared_new_request - - return new_response - - def _retrieve_token(self) -> None: - data = { - "grant_type": "client_credentials", - } - - if self._scopes: - data["scope"] = " ".join(self._scopes) - - response: requests.Response = requests.post( - self._token_url, - data=data, - auth=self._token_server_auth, - verify=self._verify_ssl, - ) - - response.raise_for_status() - - token = response.json() - self._expire_at = datetime.now() + timedelta(seconds=token["expires_in"]) - self._access_token = token["access_token"] + *, + username: t.AnyStr | None = None, + password: t.AnyStr | None = None, + client_id: t.AnyStr | None = None, + client_secret: t.AnyStr | None = None, + cert: str | None = None, + key: str | None = None, + ): + super().__init__() + self.username: bytes | None = None + self.password: bytes | None = None + self.client_id: bytes | None = None + self.client_secret: bytes | None = None + self.cert: str | None = cert + self.key: str | None = key + + if username is not None: + assert password is not None + self.username = username.encode("latin1") if isinstance(username, str) else username + self.password = password.encode("latin1") if isinstance(password, str) else password + if client_id is not None: + assert client_secret is not None + self.client_id = client_id.encode("latin1") if isinstance(client_id, str) else client_id + self.client_secret = ( + client_secret.encode("latin1") if isinstance(client_secret, str) else client_secret + ) + + if cert is None and key is not None: + raise RuntimeError("Key can only be used together with a cert.") + + def can_complete_http_basic(self) -> bool: + return self.username is not None + + def can_complete_oauth2_client_credentials(self, scopes: list[str]) -> bool: + return self.client_id is not None + + def can_complete_mutualTLS(self) -> bool: + return self.cert is not None + + async def http_basic_credentials(self) -> tuple[bytes, bytes]: + assert self.username is not None + assert self.password is not None + return self.username, self.password + + async def oauth2_client_credentials(self) -> tuple[bytes, bytes]: + assert self.client_id is not None + assert self.client_secret is not None + return self.client_id, self.client_secret + + def tls_credentials(self) -> tuple[str, str | None]: + assert self.cert is not None + return (self.cert, self.key) diff --git a/pulp-glue/pulp_glue/common/context.py b/pulp-glue/pulp_glue/common/context.py index fff028eff..18b73904b 100644 --- a/pulp-glue/pulp_glue/common/context.py +++ b/pulp-glue/pulp_glue/common/context.py @@ -9,6 +9,7 @@ from packaging.specifiers import SpecifierSet +from pulp_glue.common.authentication import GlueAuthProvider from pulp_glue.common.exceptions import ( NotImplementedFake, OpenAPIError, @@ -19,7 +20,7 @@ UnsafeCallError, ) from pulp_glue.common.i18n import get_translation -from pulp_glue.common.openapi import BasicAuthProvider, OpenAPI +from pulp_glue.common.openapi import OpenAPI if sys.version_info >= (3, 11): import tomllib @@ -202,6 +203,20 @@ def patch_upstream_pulp_replicate_request_body(api: OpenAPI) -> None: operation.pop("requestBody", None) +@api_quirk(PluginRequirement("core", specifier="<3.85")) +def patch_security_scheme_mutual_tls(api: OpenAPI) -> None: + # Trick to allow tls cert auth on older Pulp. + if (components := api.api_spec.get("components")) is not None: + if (security_schemes := components.get("securitySchemes")) is not None: + # Only if it is going to be idempotent... + if "gluePatchTLS" not in security_schemes: + security_schemes["gluePatchTLS"] = {"type": "mutualTLS"} + for method, path in api.operations.values(): + operation = api.api_spec["paths"][path][method] + if "security" in operation: + operation["security"].append({"gluePatchTLS": []}) + + class PulpContext: """ Abstract class for the global PulpContext object. @@ -335,8 +350,13 @@ def from_config(cls, config: dict[str, t.Any]) -> "t.Self": api_kwargs: dict[str, t.Any] = { "base_url": config["base_url"], } - if "username" in config: - api_kwargs["auth_provider"] = BasicAuthProvider(config["username"], config["password"]) + api_kwargs["auth_provider"] = GlueAuthProvider( + **{ + k: v + for k, v in config.items() + if k in {"username", "password", "client_id", "client_secret", "cert", "key"} + } + ) if "headers" in config: api_kwargs["headers"] = dict( (header.split(":", maxsplit=1) for header in config["headers"]) @@ -385,7 +405,9 @@ def api(self) -> OpenAPI: # Deprecated for 'auth'. if not password: password = self.prompt("password", hide_input=True) - self._api_kwargs["auth_provider"] = BasicAuthProvider(username, password) + self._api_kwargs["auth_provider"] = GlueAuthProvider( + username=username, password=password + ) warnings.warn( "Using 'username' and 'password' with 'PulpContext' is deprecated. " "Use an auth provider with the 'auth_provider' argument instead.", @@ -399,10 +421,10 @@ def api(self) -> OpenAPI: ) except OpenAPIError as e: raise PulpException(str(e)) + self._patch_api_spec() # Rerun scheduled version checks for plugin_requirement in self._needed_plugins: self.needs_plugin(plugin_requirement) - self._patch_api_spec() return self._api @property diff --git a/pulp-glue/pulp_glue/common/openapi.py b/pulp-glue/pulp_glue/common/openapi.py index ce7e32e99..2fe58debb 100644 --- a/pulp-glue/pulp_glue/common/openapi.py +++ b/pulp-glue/pulp_glue/common/openapi.py @@ -1,10 +1,14 @@ +import asyncio import json import logging import os +import ssl import typing as t import warnings -from collections import defaultdict +from base64 import b64encode from dataclasses import dataclass +from datetime import datetime, timedelta +from functools import cached_property from io import BufferedReader from urllib.parse import urlencode, urljoin @@ -13,6 +17,7 @@ from multidict import CIMultiDict, CIMultiDictProxy, MutableMultiMapping from pulp_glue.common import __version__ +from pulp_glue.common.authentication import AuthProviderBase from pulp_glue.common.exceptions import ( OpenAPIError, PulpAuthenticationFailed, @@ -38,7 +43,7 @@ class _Request: operation_id: str method: str url: str - headers: MutableMultiMapping[str] | CIMultiDictProxy[str] | t.MutableMapping[str, str] + headers: MutableMultiMapping[str] | CIMultiDict[str] | t.MutableMapping[str, str] params: dict[str, str] | None = None data: dict[str, t.Any] | str | None = None files: dict[str, tuple[str, UploadType, str]] | None = None @@ -52,99 +57,6 @@ class _Response: body: bytes -class AuthProviderBase: - """ - Base class for auth providers. - - This abstract base class will analyze the authentication proposals of the openapi specs. - Different authentication schemes should be implemented by subclasses. - Returned auth objects need to be compatible with `requests.auth.AuthBase`. - """ - - def basic_auth(self, scopes: list[str]) -> requests.auth.AuthBase | None: - """Implement this to provide means of http basic auth.""" - return None - - def http_auth( - self, security_scheme: dict[str, t.Any], scopes: list[str] - ) -> requests.auth.AuthBase | None: - """Select a suitable http auth scheme or return None.""" - # https://www.iana.org/assignments/http-authschemes/http-authschemes.xhtml - if security_scheme["scheme"] == "basic": - result = self.basic_auth(scopes) - if result: - return result - return None - - def oauth2_client_credentials_auth( - self, flow: t.Any, scopes: list[str] - ) -> requests.auth.AuthBase | None: - """Implement this to provide other authentication methods.""" - return None - - def oauth2_auth( - self, security_scheme: dict[str, t.Any], scopes: list[str] - ) -> requests.auth.AuthBase | None: - """Select a suitable oauth2 flow or return None.""" - # Check flows by preference. - if "clientCredentials" in security_scheme["flows"]: - flow = security_scheme["flows"]["clientCredentials"] - # Select this flow only if it claims to provide all the necessary scopes. - # This will allow subsequent auth proposals to be considered. - if set(scopes) - set(flow["scopes"]): - return None - - result = self.oauth2_client_credentials_auth(flow, scopes) - if result: - return result - return None - - def __call__( - self, - security: list[dict[str, list[str]]], - security_schemes: dict[str, dict[str, t.Any]], - ) -> requests.auth.AuthBase | None: - # Reorder the proposals by their type to prioritize properly. - # Select only single mechanism proposals on the way. - proposed_schemes: dict[str, dict[str, list[str]]] = defaultdict(dict) - for proposal in security: - if len(proposal) == 0: - # Empty proposal: No authentication needed. Shortcut return. - return None - if len(proposal) == 1: - name, scopes = list(proposal.items())[0] - proposed_schemes[security_schemes[name]["type"]][name] = scopes - # Ignore all proposals with more than one required auth mechanism. - - # Check for auth schemes by preference. - if "oauth2" in proposed_schemes: - for name, scopes in proposed_schemes["oauth2"].items(): - result = self.oauth2_auth(security_schemes[name], scopes) - if result: - return result - - # if we get here, either no-oauth2, OR we couldn't find creds - if "http" in proposed_schemes: - for name, scopes in proposed_schemes["http"].items(): - result = self.http_auth(security_schemes[name], scopes) - if result: - return result - - raise OpenAPIError(_("No suitable auth scheme found.")) - - -class BasicAuthProvider(AuthProviderBase): - """ - Implementation for AuthProviderBase providing basic auth with fixed `username`, `password`. - """ - - def __init__(self, username: str, password: str): - self.auth = requests.auth.HTTPBasicAuth(username, password) - - def basic_auth(self, scopes: list[str]) -> requests.auth.AuthBase | None: - return self.auth - - class OpenAPI: """ The abstraction Layer to interact with a server providing an openapi v3 specification. @@ -154,7 +66,7 @@ class OpenAPI: served api. doc_path: Path of the json api doc schema relative to the `base_url`. headers: Dictionary of additional request headers. - auth_provider: Object that returns requests auth objects according to the api spec. + auth_provider: Object that can be questioned for credentials according to the api spec. cert: Client certificate used for auth. key: Matching key for `cert` if not already included. verify_ssl: Whether to check server TLS certificates agains a CA. @@ -210,9 +122,8 @@ def __init__( self._dry_run: bool = dry_run self._headers = CIMultiDict(headers or {}) self._verify_ssl = verify_ssl + self._auth_provider = auth_provider - self._cert = cert - self._key = key self._headers.update( { @@ -225,6 +136,10 @@ def __init__( self._setup_session() + self._oauth2_lock = asyncio.Lock() + self._oauth2_token: str | None = None + self._oauth2_expires: datetime = datetime.now() + self.load_api(refresh_cache=refresh_cache) def _setup_session(self) -> None: @@ -237,22 +152,19 @@ def _setup_session(self) -> None: # Don't redirect, because carrying auth accross redirects is unsafe. self._session.max_redirects = 0 self._session.headers.update(self._headers) - if self._auth_provider: - if self._cert or self._key: - raise OpenAPIError(_("Cannot use both 'auth' and 'cert'.")) - else: - if self._cert and self._key: - self._session.cert = (self._cert, self._key) - elif self._cert: - self._session.cert = self._cert - elif self._key: - raise OpenAPIError(_("Cert is required if key is set.")) session_settings = self._session.merge_environment_settings( self._base_url, {}, None, self._verify_ssl, None ) self._session.verify = session_settings["verify"] self._session.proxies = session_settings["proxies"] + if self._auth_provider is not None and self._auth_provider.can_complete_mutualTLS(): + cert, key = self._auth_provider.tls_credentials() + if key is not None: + self._session.cert = (cert, key) + else: + self._session.cert = cert + @property def base_url(self) -> str: return self._base_url @@ -261,6 +173,20 @@ def base_url(self) -> str: def cid(self) -> str | None: return self._headers.get("Correlation-Id") + @cached_property + def ssl_context(self) -> t.Union[ssl.SSLContext, bool]: + _ssl_context: t.Union[ssl.SSLContext, bool] + if self._verify_ssl is False: + _ssl_context = False + else: + if isinstance(self._verify_ssl, str): + _ssl_context = ssl.create_default_context(cafile=self._verify_ssl) + else: + _ssl_context = ssl.create_default_context() + if self._auth_provider is not None and self._auth_provider.can_complete_mutualTLS(): + _ssl_context.load_cert_chain(*self._auth_provider.tls_credentials()) + return _ssl_context + def load_api(self, refresh_cache: bool = False) -> None: # TODO: Find a way to invalidate caches on upstream change xdg_cache_home: str = os.environ.get("XDG_CACHE_HOME") or "~/.cache" @@ -272,6 +198,7 @@ def load_api(self, refresh_cache: bool = False) -> None: ) try: if refresh_cache: + # Fake that we did not find the cache. raise OSError() with open(apidoc_cache, "rb") as f: data: bytes = f.read() @@ -554,23 +481,100 @@ def _log_request(self, request: _Request) -> None: for key, (name, _dummy, content_type) in request.files.items(): self._debug_callback(3, f"{key} <- {name} [{content_type}]") + def _select_proposal( + self, + request: _Request, + ) -> dict[str, list[str]] | None: + proposal = None + if ( + request.security + and "Authorization" not in request.headers + and "Authorization" not in self._session.headers + and self._auth_provider is not None + ): + security_schemes: dict[str, dict[str, t.Any]] = self.api_spec["components"][ + "securitySchemes" + ] + try: + proposal = next( + ( + p + for p in request.security + if self._auth_provider.can_complete(p, security_schemes) + ) + ) + except StopIteration: + raise OpenAPIError(_("No suitable auth scheme found.")) + return proposal + + async def _authenticate_request( + self, + request: _Request, + proposal: dict[str, list[str]], + ) -> bool: + assert self._auth_provider is not None + + may_retry = False + security_schemes = self.api_spec["components"]["securitySchemes"] + for scheme_name, scopes in proposal.items(): + scheme = security_schemes[scheme_name] + if scheme["type"] == "http": + if scheme["scheme"] == "basic": + username, password = await self._auth_provider.http_basic_credentials() + secret = b64encode(username + b":" + password) + # TODO Should we add, amend or replace the existing auth header? + request.headers["Authorization"] = f"Basic {secret.decode()}" + else: + raise NotImplementedError("Auth scheme: http " + scheme["scheme"]) + elif scheme["type"] == "oauth2": + flow = scheme["flows"].get("clientCredentials") + if flow is None: + raise NotImplementedError("OAuth2: Only client credential flow is available.") + # Allow retry if the token was taken from cache. + may_retry = not await self._fetch_oauth2_token(flow) + request.headers["Authorization"] = f"Bearer {self._oauth2_token}" + elif scheme["type"] == "mutualTLS": + # At this point, we assume the cert has already been loaded into the sslcontext. + pass + else: + raise NotImplementedError("Auth type: " + scheme["type"]) + return may_retry + + async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool: + assert self._auth_provider is not None + + new_token = False + async with self._oauth2_lock: + now = datetime.now() + if self._oauth2_token is None or self._oauth2_expires < now: + # Get or refresh token. + client_id, client_secret = await self._auth_provider.oauth2_client_credentials() + secret = b64encode(client_id + b":" + client_secret) + data: dict[str, t.Any] = {"grant_type": "client_credentials"} + scopes = flow.get("scopes") + if scopes: + data["scopes"] = " ".join(scopes) + request = _Request( + operation_id="", + method="post", + url=flow["tokenUrl"], + headers={"Authorization": f"Basic {secret.decode()}"}, + data=data, + ) + response = self._send_request(request) + if response.status_code < 200 or response.status_code >= 300: + raise OpenAPIError("Failed to fetch OAuth2 token") + result = json.loads(response.body) + self._oauth2_token = result["access_token"] + self._oauth2_expires = now + timedelta(seconds=result["expires_in"]) + new_token = True + return new_token + def _send_request( self, request: _Request, ) -> _Response: # This function uses requests to translate the _Request into a _Response. - if request.security and self._auth_provider: - if "Authorization" in self._session.headers: - # Bad idea, but you wanted it that way. - auth = None - else: - auth = self._auth_provider( - request.security, self.api_spec["components"]["securitySchemes"] - ) - else: - # No auth required? Don't provide it. - # No auth_provider available? Hope for the best (should do the trick for cert auth). - auth = None try: r = self._session.request( request.method, @@ -579,7 +583,6 @@ def _send_request( headers=request.headers, data=request.data, files=request.files, - auth=auth, ) response = _Response(status_code=r.status_code, headers=r.headers, body=r.content) except requests.TooManyRedirects as e: @@ -676,8 +679,9 @@ def call( headers = self._extract_params("header", path_spec, method_spec, parameters) + rel_url = path for name, value in self._extract_params("path", path_spec, method_spec, parameters).items(): - path = path.replace("{" + name + "}", value) + rel_url = path.replace("{" + name + "}", value) query_params = self._extract_params("query", path_spec, method_spec, parameters) @@ -687,7 +691,7 @@ def call( names=", ".join(parameters.keys()), operation_id=operation_id ) ) - url = urljoin(self._base_url, path) + url = urljoin(self._base_url, rel_url) request = self._render_request( path_spec, @@ -703,7 +707,32 @@ def call( if self._dry_run and request.method.upper() not in SAFE_METHODS: raise UnsafeCallError(_("Call aborted due to safe mode")) + may_retry = False + if proposal := self._select_proposal(request): + assert len(proposal) == 1, "More complex security proposals are not implemented." + may_retry = asyncio.run(self._authenticate_request(request, proposal)) + response = self._send_request(request) + if proposal is not None: + assert self._auth_provider is not None + if may_retry and response.status_code == 401: + self._oauth2_token = None + asyncio.run(self._authenticate_request(request, proposal)) + response = self._send_request(request) + + if response.status_code >= 200 and response.status_code < 300: + asyncio.run( + self._auth_provider.auth_success_hook( + proposal, self.api_spec["components"]["securitySchemes"] + ) + ) + elif response.status_code == 401: + asyncio.run( + self._auth_provider.auth_failure_hook( + proposal, self.api_spec["components"]["securitySchemes"] + ) + ) + self._log_response(response) return self._parse_response(method_spec, response) diff --git a/pulp-glue/tests/conftest.py b/pulp-glue/tests/conftest.py index bc8d56678..06a44e75a 100644 --- a/pulp-glue/tests/conftest.py +++ b/pulp-glue/tests/conftest.py @@ -1,11 +1,11 @@ import json -import os import typing as t import pytest +from pulp_glue.common.authentication import GlueAuthProvider from pulp_glue.common.context import PulpContext -from pulp_glue.common.openapi import BasicAuthProvider, OpenAPI +from pulp_glue.common.openapi import OpenAPI FAKE_OPENAPI_SPEC = json.dumps( { @@ -23,8 +23,6 @@ def pulp_ctx( if not any((mark.name == "live" for mark in request.node.iter_markers())): pytest.fail("This fixture can only be used in live (integration) tests.") - if os.environ.get("PULP_OAUTH2", "").lower() == "true": - pytest.skip("Pulp-glue in isolation does not support OAuth2 atm.") verbose = request.config.getoption("verbose") settings = pulp_cli_settings["cli"].copy() settings["debug_callback"] = lambda i, s: i <= verbose and print(s) @@ -58,18 +56,15 @@ def fake_pulp_ctx( if not any((mark.name == "live" for mark in request.node.iter_markers())): pytest.fail("This fixture can only be used in live (integration) tests.") - if os.environ.get("PULP_OAUTH2", "").lower() == "true": - pytest.skip("Pulp-glue in isolation does not support OAuth2 atm.") verbose = request.config.getoption("verbose") settings = pulp_cli_settings["cli"] - if "username" in settings: - username = settings.get("username") - assert isinstance(username, str) - password = settings.get("password") - assert isinstance(password, str) - auth_provider = BasicAuthProvider(username, password) - else: - auth_provider = None + auth_provider = GlueAuthProvider( + **{ + k: v + for k, v in settings.items() + if k in {"username", "password", "client_id", "client_secret", "cert", "key"} + } + ) return PulpContext( api_kwargs={ "base_url": settings["base_url"], diff --git a/pulp-glue/tests/test_auth_provider.py b/pulp-glue/tests/test_auth_provider.py index f8e2f2441..0bcde6da0 100644 --- a/pulp-glue/tests/test_auth_provider.py +++ b/pulp-glue/tests/test_auth_provider.py @@ -1,10 +1,9 @@ +import asyncio import typing as t import pytest -from requests.auth import AuthBase -from pulp_glue.common.exceptions import OpenAPIError -from pulp_glue.common.openapi import AuthProviderBase +from pulp_glue.common.authentication import AuthProviderBase, BasicAuthProvider, GlueAuthProvider pytestmark = pytest.mark.glue @@ -51,56 +50,66 @@ }, }, }, + "E": {"type": "mutualTLS"}, } -class MockBasicAuth(AuthBase): - pass +class TestBasicAuthProvider: + @pytest.fixture(scope="class") + def provider(self) -> AuthProviderBase: + return BasicAuthProvider(username="user1", password="password1") + def test_can_complete_basic(self, provider: AuthProviderBase) -> None: + assert provider.can_complete_http_basic() -class MockOAuth2CCAuth(AuthBase): - pass + def test_provides_username_and_password(self, provider: AuthProviderBase) -> None: + assert asyncio.run(provider.http_basic_credentials()) == ( + b"user1", + b"password1", + ) + def test_cannot_complete_mutualTLS(self, provider: AuthProviderBase) -> None: + assert not provider.can_complete_mutualTLS() -def test_auth_provider_select_mechanism(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(AuthProviderBase, "basic_auth", lambda *args: MockBasicAuth()) - monkeypatch.setattr( - AuthProviderBase, - "oauth2_client_credentials_auth", - lambda *args: MockOAuth2CCAuth(), - ) - provider = AuthProviderBase() + def test_can_complete_basic_proposal(self, provider: AuthProviderBase) -> None: + assert provider.can_complete({"B": []}, security_schemes=SECURITY_SCHEMES) - # Error if no auth scheme is available. - with pytest.raises(OpenAPIError): - provider([], SECURITY_SCHEMES) + def test_cannot_complete_bearer_proposal(self, provider: AuthProviderBase) -> None: + assert not provider.can_complete({"A": []}, security_schemes=SECURITY_SCHEMES) - # Error if a nonexisting mechanism is proposed. - with pytest.raises(KeyError): - provider([{"foo": []}], SECURITY_SCHEMES) + def test_cannot_complete_combined_proposal(self, provider: AuthProviderBase) -> None: + assert not provider.can_complete({"A": [], "B": []}, security_schemes=SECURITY_SCHEMES) - # Succeed without mechanism for an empty proposal. - assert provider([{}], SECURITY_SCHEMES) is None - # Try select a not implemented auth. - with pytest.raises(OpenAPIError): - provider([{"A": []}], SECURITY_SCHEMES) +class TestGlueAuthProvider: + def test_empty_provider_cannot_complete(self) -> None: + provider = GlueAuthProvider() + assert provider.can_complete_http_basic() is False + assert provider.can_complete_oauth2_client_credentials([]) is False + assert provider.can_complete_mutualTLS() is False - # Ignore proposals with multiple mechanisms. - with pytest.raises(OpenAPIError): - provider([{"B": [], "C": []}], SECURITY_SCHEMES) + def test_username_needs_password(self) -> None: + with pytest.raises(AssertionError): + GlueAuthProvider(username="user1") - # Select Basic auth alone and from multiple. - assert isinstance(provider([{"B": []}], SECURITY_SCHEMES), MockBasicAuth) - assert isinstance(provider([{"A": []}, {"B": []}], SECURITY_SCHEMES), MockBasicAuth) + def test_can_complete_basic_auth_and_provide_credentials(self) -> None: + provider = GlueAuthProvider(username="user1", password="secret1") + assert provider.can_complete_http_basic() is True + assert asyncio.run(provider.http_basic_credentials()) == (b"user1", b"secret1") - # Select oauth2 client credentials alone and over basic auth if scopes match. - assert isinstance(provider([{"D": []}], SECURITY_SCHEMES), MockOAuth2CCAuth) - assert isinstance(provider([{"B": []}, {"D": []}], SECURITY_SCHEMES), MockOAuth2CCAuth) - assert isinstance( - provider([{"B": []}, {"D": ["read:pets"]}], SECURITY_SCHEMES), MockOAuth2CCAuth - ) - # Fall back to basic if scope does not match. - assert isinstance( - provider([{"B": []}, {"D": ["read:cattle"]}], SECURITY_SCHEMES), MockBasicAuth - ) + def test_client_id_needs_client_secret(self) -> None: + with pytest.raises(AssertionError): + GlueAuthProvider(client_id="client1") + + def test_can_complete_oauth2_client_credentials_and_provide_them(self) -> None: + provider = GlueAuthProvider(client_id="client1", client_secret="secret1") + assert provider.can_complete_oauth2_client_credentials([]) is True + assert asyncio.run(provider.oauth2_client_credentials()) == ( + b"client1", + b"secret1", + ) + + def test_can_complete_mutualTLS_and_provide_cert(self) -> None: + provider = GlueAuthProvider(cert="FAKECERTIFICATE") + assert provider.can_complete_mutualTLS() is True + assert provider.tls_credentials() == ("FAKECERTIFICATE", None) diff --git a/pulp-glue/tests/test_authentication.py b/pulp-glue/tests/test_authentication.py deleted file mode 100644 index 108f86927..000000000 --- a/pulp-glue/tests/test_authentication.py +++ /dev/null @@ -1,30 +0,0 @@ -import typing as t - -import pytest - -from pulp_glue.common.authentication import OAuth2ClientCredentialsAuth - -pytestmark = pytest.mark.glue - - -def test_sending_no_scope_when_empty(monkeypatch: pytest.MonkeyPatch) -> None: - class OAuth2MockResponse: - def raise_for_status(self) -> None: - return None - - def json(self) -> dict[str, t.Any]: - return {"expires_in": 1, "access_token": "aaa"} - - def _requests_post_mocked( - url: str, data: dict[str, t.Any], **kwargs: t.Any - ) -> OAuth2MockResponse: - assert "scope" not in data - return OAuth2MockResponse() - - monkeypatch.setattr("requests.post", _requests_post_mocked) - - OAuth2ClientCredentialsAuth(token_url="", client_id="", client_secret="")._retrieve_token() - - OAuth2ClientCredentialsAuth( - token_url="", client_id="", client_secret="", scopes=[] - )._retrieve_token() diff --git a/pulp-glue/tests/test_openapi.py b/pulp-glue/tests/test_openapi.py index afa871c68..a646a5bf2 100644 --- a/pulp-glue/tests/test_openapi.py +++ b/pulp-glue/tests/test_openapi.py @@ -1,19 +1,69 @@ +import asyncio import datetime import json import logging import pytest +from multidict import CIMultiDict +from pulp_glue.common.authentication import AuthProviderBase, BasicAuthProvider, GlueAuthProvider from pulp_glue.common.openapi import OpenAPI, _Request, _Response pytestmark = pytest.mark.glue +SECURITY_SCHEMES = { + "A": {"type": "http", "scheme": "bearer"}, + "B": {"type": "http", "scheme": "basic"}, + "C": { + "type": "oauth2", + "flows": { + "implicit": { + "authorizationUrl": "https://example.com/api/oauth/dialog", + "scopes": { + "write:pets": "modify pets in your account", + "read:pets": "read your pets", + }, + }, + "authorizationCode": { + "authorizationUrl": "https://example.com/api/oauth/dialog", + "tokenUrl": "https://example.com/api/oauth/token", + "scopes": { + "write:pets": "modify pets in your account", + "read:pets": "read your pets", + }, + }, + }, + }, + "D": { + "type": "oauth2", + "flows": { + "implicit": { + "authorizationUrl": "https://example.com/api/oauth/dialog", + "scopes": { + "write:pets": "modify pets in your account", + "read:pets": "read your pets", + }, + }, + "clientCredentials": { + "tokenUrl": "https://example.com/api/oauth/token", + "scopes": { + "write:pets": "modify pets in your account", + "read:pets": "read your pets", + }, + }, + }, + }, + "E": {"type": "mutualTLS"}, +} TEST_SCHEMA = json.dumps( { "openapi": "3.0.3", "paths": { "test/": { - "get": {"operationId": "get_test_id", "responses": {200: {}}}, + "get": { + "operationId": "get_test_id", + "responses": {200: {}}, + }, "post": { "operationId": "post_test_id", "requestBody": { @@ -25,9 +75,11 @@ }, }, "responses": {200: {}}, + "security": [{"B": []}], }, } }, + "security": [{}], "components": { "schemas": { "testBody": { @@ -35,14 +87,28 @@ "properties": {"text": {"type": "string"}}, "required": ["text"], } - } + }, + "securitySchemes": SECURITY_SCHEMES, }, } ).encode() def mock_send_request(request: _Request) -> _Response: - return _Response(status_code=200, headers={}, body=b"{}") + if request.url.endswith("oauth/token"): + assert request.method.lower() == "post" + # $ echo -n "client1:secret1" | base64 + assert request.headers["Authorization"] == "Basic Y2xpZW50MTpzZWNyZXQx" + assert isinstance(request.data, dict) + assert request.data["grant_type"] == "client_credentials" + assert set(request.data["scopes"].split(" ")) == {"write:pets", "read:pets"} + return _Response( + status_code=200, + headers={}, + body=json.dumps({"access_token": "DEADBEEF", "expires_in": 600}).encode(), + ) + else: + return _Response(status_code=200, headers={}, body=b"{}") @pytest.fixture @@ -54,13 +120,73 @@ def mock_openapi(monkeypatch: pytest.MonkeyPatch) -> OpenAPI: return openapi +@pytest.fixture +def basic_auth_provider( + monkeypatch: pytest.MonkeyPatch, + mock_openapi: OpenAPI, +) -> AuthProviderBase: + auth_provider = BasicAuthProvider(username="user1", password="password1") + monkeypatch.setattr(mock_openapi, "_auth_provider", auth_provider) + return auth_provider + + +@pytest.fixture +def oauth2_cc_auth_provider( + monkeypatch: pytest.MonkeyPatch, + mock_openapi: OpenAPI, +) -> AuthProviderBase: + auth_provider = GlueAuthProvider(client_id="client1", client_secret="secret1") + monkeypatch.setattr(mock_openapi, "_auth_provider", auth_provider) + return auth_provider + + +@pytest.fixture +def tls_auth_provider( + monkeypatch: pytest.MonkeyPatch, + mock_openapi: OpenAPI, +) -> AuthProviderBase: + auth_provider = GlueAuthProvider(cert="asdf") + monkeypatch.setattr(mock_openapi, "_auth_provider", auth_provider) + return auth_provider + + +class TestRenderRequest: + def test_request_has_no_auth( + self, + mock_openapi: OpenAPI, + basic_auth_provider: AuthProviderBase, + ) -> None: + method, path = mock_openapi.operations["get_test_id"] + path_spec = mock_openapi.api_spec["paths"][path] + request = mock_openapi._render_request(path_spec, method, "test/", {}, {}, None) + assert request.security == [{}] + + def test_request_has_security( + self, + mock_openapi: OpenAPI, + basic_auth_provider: AuthProviderBase, + ) -> None: + method, path = mock_openapi.operations["post_test_id"] + path_spec = mock_openapi.api_spec["paths"][path] + request = mock_openapi._render_request( + path_spec, method, "test/", {}, {}, {"text": "TRACE"} + ) + assert request.security == [{"B": []}] + + class TestParseResponse: - def test_returns_dict_for_no_content(self, mock_openapi: OpenAPI) -> None: + def test_returns_dict_for_no_content( + self, + mock_openapi: OpenAPI, + ) -> None: response = _Response(204, {}, b"") result = mock_openapi._parse_response({}, response) assert result == {} - def test_decodes_json(self, mock_openapi: OpenAPI) -> None: + def test_decodes_json( + self, + mock_openapi: OpenAPI, + ) -> None: response = _Response(200, {"content-type": "application/json"}, b'{"a": 1, "b": "Hallo!"}') result = mock_openapi._parse_response( {"responses": {"200": {"content": {"application/json": {}}}}}, response @@ -161,3 +287,141 @@ def test_post_operation_to_debug( ("pulp_glue.openapi", logging.DEBUG + 3, "Response: 200"), ("pulp_glue.openapi", logging.DEBUG + 1, "b'{}'"), ] + + +class TestSelectProposal: + def test_none_if_no_provider( + self, + mock_openapi: OpenAPI, + ) -> None: + request = _Request( + "", + "GET", + "http://example.org", + CIMultiDict(), + security=[{"A": []}, {"B": []}], + ) + assert mock_openapi._select_proposal(request) is None + + def test_none_if_header_provided( + self, + mock_openapi: OpenAPI, + basic_auth_provider: AuthProviderBase, + ) -> None: + request = _Request( + "", + "GET", + "http://example.org", + CIMultiDict({"Authorization": "Weird Auth"}), + security=[{"A": []}, {"B": []}], + ) + assert mock_openapi._select_proposal(request) is None + + def test_B_with_basic_auth( + self, + mock_openapi: OpenAPI, + basic_auth_provider: AuthProviderBase, + ) -> None: + request = _Request( + "", + "GET", + "http://example.org", + CIMultiDict(), + security=[{"A": []}, {"B": []}], + ) + assert mock_openapi._select_proposal(request) == {"B": []} + + def test_oauth_with_client_credentials( + self, + mock_openapi: OpenAPI, + oauth2_cc_auth_provider: AuthProviderBase, + ) -> None: + request = _Request( + "", + "GET", + "http://example.org", + CIMultiDict(), + security=[ + {"A": []}, + {"B": []}, + {"C": []}, + {"D": []}, + {"E": []}, + ], + ) + assert mock_openapi._select_proposal(request) == {"D": []} + + def test_mutual_tls_with_cert( + self, + mock_openapi: OpenAPI, + tls_auth_provider: AuthProviderBase, + ) -> None: + request = _Request( + "", + "GET", + "http://example.org", + CIMultiDict(), + security=[ + {"A": []}, + {"B": []}, + {"C": []}, + {"D": []}, + {"E": []}, + ], + ) + assert mock_openapi._select_proposal(request) == {"E": []} + + +class TestAuthenticate: + def test_basic_auth( + self, + mock_openapi: OpenAPI, + basic_auth_provider: AuthProviderBase, + ) -> None: + request = _Request("", "GET", "http://example.org", CIMultiDict()) + assert asyncio.run(mock_openapi._authenticate_request(request, {"B": []})) is False + # $ echo -n "user1:password1" | base64 + assert request.headers.get("Authorization") == "Basic dXNlcjE6cGFzc3dvcmQx" + + def test_tls( + self, + mock_openapi: OpenAPI, + tls_auth_provider: AuthProviderBase, + ) -> None: + request = _Request("", "GET", "http://example.org", CIMultiDict()) + assert asyncio.run(mock_openapi._authenticate_request(request, {"E": []})) is False + # No header, and the certificate is handled by the ssl_context. + assert "Authorization" not in request.headers + + def test_oauth2_client_credentials( + self, + mock_openapi: OpenAPI, + oauth2_cc_auth_provider: AuthProviderBase, + ) -> None: + request = _Request("", "GET", "http://example.org", CIMultiDict()) + assert asyncio.run(mock_openapi._authenticate_request(request, {"D": ["scope1"]})) is False + assert request.headers.get("Authorization") == "Bearer DEADBEEF" + + def test_oauth2_client_credentials_reuses_token( + self, + mock_openapi: OpenAPI, + oauth2_cc_auth_provider: AuthProviderBase, + ) -> None: + mock_openapi._oauth2_token = "BABACAFE" + mock_openapi._oauth2_expires = datetime.datetime.now() + datetime.timedelta(seconds=500) + + request = _Request("", "GET", "http://example.org", CIMultiDict()) + assert asyncio.run(mock_openapi._authenticate_request(request, {"D": ["scope1"]})) is True + assert request.headers.get("Authorization") == "Bearer BABACAFE" + + def test_oauth2_client_credentials_refreshes_outdated_token( + self, + mock_openapi: OpenAPI, + oauth2_cc_auth_provider: AuthProviderBase, + ) -> None: + mock_openapi._oauth2_token = "BABACAFE" + mock_openapi._oauth2_expires = datetime.datetime.now() - datetime.timedelta(seconds=500) + + request = _Request("", "GET", "http://example.org", CIMultiDict()) + assert asyncio.run(mock_openapi._authenticate_request(request, {"D": ["scope1"]})) is False + assert request.headers.get("Authorization") == "Bearer DEADBEEF" diff --git a/pulp_cli/__init__.py b/pulp_cli/__init__.py index 5c7340f0b..af011ad6d 100644 --- a/pulp_cli/__init__.py +++ b/pulp_cli/__init__.py @@ -212,8 +212,6 @@ def main( api_kwargs = dict( base_url=base_url, headers=dict((header.split(":", maxsplit=1) for header in headers)), - cert=cert, - key=key, verify_ssl=verify_ssl, refresh_cache=refresh_api, dry_run=dry_run, @@ -229,6 +227,8 @@ def main( timeout=timeout, username=username, password=password, + cert=cert, + key=key, oauth2_client_id=client_id, oauth2_client_secret=client_secret, ) diff --git a/pulp_cli/config.py b/pulp_cli/config.py index b378ddbdf..767e732d4 100644 --- a/pulp_cli/config.py +++ b/pulp_cli/config.py @@ -88,10 +88,10 @@ def headers_callback( click.option("--password", default=None, help=_("Password on pulp server")), click.option("--client-id", default=None, help=_("OAuth2 client ID")), click.option("--client-secret", default=None, help=_("OAuth2 client secret")), - click.option("--cert", default="", help=_("Path to client certificate")), + click.option("--cert", default=None, help=_("Path to client certificate")), click.option( "--key", - default="", + default=None, help=_("Path to client private key. Not required if client cert contains this."), ), click.option("--verify-ssl/--no-verify-ssl", default=True, help=_("Verify SSL connection")), @@ -186,7 +186,7 @@ def validate_config(config: dict[str, t.Any], strict: bool = False) -> None: missing_settings = ( set(SETTINGS) - set(config.keys()) - - {"plugins", "username", "password", "client_id", "client_secret"} + - {"plugins", "username", "password", "client_id", "client_secret", "cert", "key"} ) if missing_settings: errors.append(_("Missing settings: '{}'.").format("','".join(missing_settings))) diff --git a/pulp_cli/generic.py b/pulp_cli/generic.py index 8e1369fcf..30820046a 100644 --- a/pulp_cli/generic.py +++ b/pulp_cli/generic.py @@ -1,3 +1,4 @@ +import asyncio import datetime import json import re @@ -7,11 +8,10 @@ from functools import lru_cache, wraps import click -import requests import schema as s import yaml -from pulp_glue.common.authentication import OAuth2ClientCredentialsAuth +from pulp_glue.common.authentication import AuthProviderBase from pulp_glue.common.context import ( DATETIME_FORMATS, DEFAULT_LIMIT, @@ -31,7 +31,6 @@ ) from pulp_glue.common.exceptions import PulpException, PulpNoWait from pulp_glue.common.i18n import get_translation -from pulp_glue.common.openapi import AuthProviderBase try: from pygments import highlight @@ -154,6 +153,8 @@ def __init__( domain: str = "default", username: str | None = None, password: str | None = None, + cert: str | None = None, + key: str | None = None, oauth2_client_id: str | None = None, oauth2_client_secret: str | None = None, ) -> None: @@ -161,8 +162,9 @@ def __init__( self.password = password self.oauth2_client_id = oauth2_client_id self.oauth2_client_secret = oauth2_client_secret - if not api_kwargs.get("cert"): - api_kwargs["auth_provider"] = PulpCLIAuthProvider(pulp_ctx=self) + self.cert = cert + self.key = key + api_kwargs["auth_provider"] = PulpCLIAuthProvider(pulp_ctx=self) verify_ssl: bool | None = api_kwargs.pop("verify_ssl", None) super().__init__( @@ -194,115 +196,130 @@ def output_result(self, result: t.Any) -> None: click.echo(formatter(result)) -if SECRET_STORAGE: +class PulpCLIAuthProvider(AuthProviderBase): + """ + The auth provider using cli promts to ask for missing passwords. + """ - class SecretStorageBasicAuth(requests.auth.AuthBase): - def __init__(self, pulp_ctx: PulpCLIContext): - self.pulp_ctx = pulp_ctx - assert self.pulp_ctx.username is not None - self.password: str | None = None + def __init__(self, pulp_ctx: PulpCLIContext): + super().__init__() + self.pulp_ctx = pulp_ctx + self._http_basic: tuple[bytes, bytes] | None = None + self._password_in_secretstorage: bool | None = None + self._oauth2_client_credentials: tuple[bytes, bytes] | None = None + + def can_complete_http_basic(self) -> bool: + return self.pulp_ctx.username is not None + + def can_complete_oauth2_client_credentials(self, scopes: list[str]) -> bool: + return self.pulp_ctx.oauth2_client_id is not None - self.attr: dict[str, str] = { + def can_complete_mutualTLS(self) -> bool: + return self.pulp_ctx.cert is not None + + def _fetch_password(self) -> bytes: + if self.pulp_ctx.password is not None: + password = self.pulp_ctx.password.encode("latin1") + elif SECRET_STORAGE: + assert self.pulp_ctx.username is not None + secret_attr: dict[str, str] = { "service": "pulp-cli", "base_url": self.pulp_ctx.api.base_url, "api_path": self.pulp_ctx.api_path, "username": self.pulp_ctx.username, } + with closing(secretstorage.dbus_init()) as connection: + collection = secretstorage.get_default_collection(connection) + item = next(collection.search_items(secret_attr), None) + if item: + password = item.get_secret() + self._password_in_secretstorage = True + else: + password = click.prompt("Password", hide_input=True).encode("latin1") + self._password_in_secretstorage = False + else: + password = click.prompt("Password", hide_input=True).encode("latin1") + return password - def response_hook(self, response: requests.Response, **kwargs: t.Any) -> requests.Response: - # Example adapted from: - # https://docs.python-requests.org/en/latest/_modules/requests/auth/#HTTPDigestAuth - if 200 <= response.status_code < 300 and not self.password_in_manager: - if click.confirm(_("Add password to password manager?")): - assert isinstance(self.password, str) - - with closing(secretstorage.dbus_init()) as connection: - collection = secretstorage.get_default_collection(connection) - collection.create_item( - "Pulp CLI", self.attr, self.password.encode(), replace=True - ) - elif response.status_code == 401 and self.password_in_manager: - if click.confirm(_("Remove failed password from password manager?")): - with closing(secretstorage.dbus_init()) as connection: - collection = secretstorage.get_default_collection(connection) - item = next(collection.search_items(self.attr), None) - if item is not None: - item.delete() - self.password = None - return response - - def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: + async def http_basic_credentials(self) -> tuple[bytes, bytes]: + if self._http_basic is None: assert self.pulp_ctx.username is not None - if self.password is None: - with closing(secretstorage.dbus_init()) as connection: - collection = secretstorage.get_default_collection(connection) - item = next(collection.search_items(self.attr), None) - if item: - self.password = item.get_secret().decode() - self.password_in_manager = True - else: - self.password = str(click.prompt("Password", hide_input=True)) - self.password_in_manager = False - request.register_hook("response", self.response_hook) # type: ignore [no-untyped-call] - return requests.auth.HTTPBasicAuth( # type: ignore [no-any-return] - self.pulp_ctx.username, self.password - )(request) + password = await asyncio.get_running_loop().run_in_executor(None, self._fetch_password) + self._http_basic = self.pulp_ctx.username.encode("latin1"), password + return self._http_basic + + def _save_password_to_storage(self) -> None: + if click.confirm(_("Add password to password manager?")): + with closing(secretstorage.dbus_init()) as connection: + assert self.pulp_ctx.username is not None + assert self._http_basic is not None + + secret_attr: dict[str, str] = { + "service": "pulp-cli", + "base_url": self.pulp_ctx.api.base_url, + "api_path": self.pulp_ctx.api_path, + "username": self.pulp_ctx.username, + } + password = self._http_basic[1] + collection = secretstorage.get_default_collection(connection) + collection.create_item("Pulp CLI", secret_attr, password, replace=True) + + async def auth_success_hook( + self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]] + ) -> None: + if SECRET_STORAGE and self._password_in_secretstorage is False: + await asyncio.get_running_loop().run_in_executor(None, self._save_password_to_storage) + self._password_in_secretstorage = None + + def _remove_password_from_storage(self) -> None: + if click.confirm(_("Remove failed password from password manager?")): + with closing(secretstorage.dbus_init()) as connection: + assert self.pulp_ctx.username is not None + + secret_attr: dict[str, str] = { + "service": "pulp-cli", + "base_url": self.pulp_ctx.api.base_url, + "api_path": self.pulp_ctx.api_path, + "username": self.pulp_ctx.username, + } + collection = secretstorage.get_default_collection(connection) + item = next(collection.search_items(secret_attr), None) + if item is not None: + item.delete() + self.password = None + + async def auth_failure_hook( + self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]] + ) -> None: + if SECRET_STORAGE and self._password_in_secretstorage is True: + await asyncio.get_running_loop().run_in_executor( + None, self._remove_password_from_storage + ) + self._password_in_secretstorage = None + self._http_basic = None + self._oauth2_client_credentials = None + + def tls_credentials(self) -> tuple[str, str | None]: + assert self.pulp_ctx.cert is not None -class PulpCLIAuthProvider(AuthProviderBase): - def __init__(self, pulp_ctx: PulpCLIContext): - self.pulp_ctx = pulp_ctx - self._memoized: dict[str, requests.auth.AuthBase | None] = {} - - def basic_auth(self, scopes: list[str]) -> requests.auth.AuthBase | None: - if "BASIC_AUTH" not in self._memoized: - if self.pulp_ctx.username is None: - # No username -> No basic auth. - self._memoized["BASIC_AUTH"] = None - elif self.pulp_ctx.password is None: - # TODO give the user a chance to opt out. - if SECRET_STORAGE: - # We could just try to fetch the password here, - # but we want to get a grip on the response_hook. - self._memoized["BASIC_AUTH"] = SecretStorageBasicAuth(self.pulp_ctx) - else: - self._memoized["BASIC_AUTH"] = requests.auth.HTTPBasicAuth( - self.pulp_ctx.username, click.prompt("Password", hide_input=True) - ) - else: - self._memoized["BASIC_AUTH"] = requests.auth.HTTPBasicAuth( - self.pulp_ctx.username, self.pulp_ctx.password - ) - return self._memoized["BASIC_AUTH"] - - def oauth2_client_credentials_auth( - self, flow: t.Any, scopes: list[str] - ) -> requests.auth.AuthBase | None: - token_url = flow["tokenUrl"] - key = "OAUTH2_CLIENT_CREDENTIALS;" + token_url + ";" + ":".join(scopes) - if key not in self._memoized: - if self.pulp_ctx.oauth2_client_id is None: - # No client_id -> No oauth2 client credentials. - self._memoized[key] = None - elif self.pulp_ctx.oauth2_client_secret is None: - self._memoized[key] = OAuth2ClientCredentialsAuth( - client_id=self.pulp_ctx.oauth2_client_id, - client_secret=click.prompt("Client Secret"), - token_url=flow["tokenUrl"], - # Try to request all possible scopes. - scopes=flow["scopes"], - verify_ssl=self.pulp_ctx.verify_ssl, - ) - else: - self._memoized[key] = OAuth2ClientCredentialsAuth( - client_id=self.pulp_ctx.oauth2_client_id, - client_secret=self.pulp_ctx.oauth2_client_secret, - token_url=flow["tokenUrl"], - # Try to request all possible scopes. - scopes=flow["scopes"], - verify_ssl=self.pulp_ctx.verify_ssl, - ) - return self._memoized[key] + return self.pulp_ctx.cert, self.pulp_ctx.key + + def _fetch_client_secret(self) -> str: + return self.pulp_ctx.oauth2_client_secret or click.prompt("Client Secret", hide_input=True) + + async def oauth2_client_credentials(self) -> tuple[bytes, bytes]: + if self._oauth2_client_credentials is None: + assert self.pulp_ctx.oauth2_client_id is not None + + client_secret = await asyncio.get_running_loop().run_in_executor( + None, self._fetch_client_secret + ) + self._oauth2_client_credentials = ( + self.pulp_ctx.oauth2_client_id.encode("latin1"), + client_secret.encode("latin1"), + ) + return self._oauth2_client_credentials ############################################################################## From 15a3c46f144a8eee10fa40d365e198e0c5ee6a62 Mon Sep 17 00:00:00 2001 From: Matthias Dellweg Date: Tue, 5 Nov 2024 12:10:18 +0100 Subject: [PATCH 2/2] WIP: Async support in pulp-glue Replaces requests with aiohttp and changes the api. --- CHANGES/pulp-glue/+aiohttp.feature | 1 + CHANGES/pulp-glue/+aiohttp.removal | 2 + lint_requirements.txt | 2 +- lower_bounds_constraints.lock | 2 + pulp-glue/pulp_glue/common/openapi.py | 168 ++++++++++++++------------ pulp-glue/pyproject.toml | 3 +- pulp-glue/tests/test_auth_provider.py | 10 +- pulp-glue/tests/test_openapi.py | 2 +- pulpcore/cli/core/task.py | 26 ++-- 9 files changed, 118 insertions(+), 98 deletions(-) create mode 100644 CHANGES/pulp-glue/+aiohttp.feature create mode 100644 CHANGES/pulp-glue/+aiohttp.removal diff --git a/CHANGES/pulp-glue/+aiohttp.feature b/CHANGES/pulp-glue/+aiohttp.feature new file mode 100644 index 000000000..70dac14dc --- /dev/null +++ b/CHANGES/pulp-glue/+aiohttp.feature @@ -0,0 +1 @@ +WIP: Added async api to Pulp glue. diff --git a/CHANGES/pulp-glue/+aiohttp.removal b/CHANGES/pulp-glue/+aiohttp.removal new file mode 100644 index 000000000..4d5165bf5 --- /dev/null +++ b/CHANGES/pulp-glue/+aiohttp.removal @@ -0,0 +1,2 @@ +Replaced requests with aiohttp. +Breaking change: Reworked the contract around the `AuthProvider` to allow authentication to be coded independently of the underlying library. diff --git a/lint_requirements.txt b/lint_requirements.txt index 58f450cdd..796ec00fa 100644 --- a/lint_requirements.txt +++ b/lint_requirements.txt @@ -4,9 +4,9 @@ mypy~=1.19.1 shellcheck-py~=0.11.0.1 # Type annotation stubs +types-aiofiles types-pygments types-PyYAML -types-requests types-setuptools types-toml diff --git a/lower_bounds_constraints.lock b/lower_bounds_constraints.lock index a2e1857e1..3aad3a3ed 100644 --- a/lower_bounds_constraints.lock +++ b/lower_bounds_constraints.lock @@ -1,3 +1,5 @@ +aiofiles==25.1.0 +aiohttp==3.12.0 click==8.0.0 packaging==22.0 PyYAML==5.3 diff --git a/pulp-glue/pulp_glue/common/openapi.py b/pulp-glue/pulp_glue/common/openapi.py index 2fe58debb..217035a78 100644 --- a/pulp-glue/pulp_glue/common/openapi.py +++ b/pulp-glue/pulp_glue/common/openapi.py @@ -12,8 +12,9 @@ from io import BufferedReader from urllib.parse import urlencode, urljoin -import requests -import urllib3 +import aiofiles +import aiofiles.os +import aiohttp from multidict import CIMultiDict, CIMultiDictProxy, MutableMultiMapping from pulp_glue.common import __version__ @@ -134,37 +135,12 @@ def __init__( if cid: self._headers["Correlation-Id"] = cid - self._setup_session() - self._oauth2_lock = asyncio.Lock() self._oauth2_token: str | None = None self._oauth2_expires: datetime = datetime.now() self.load_api(refresh_cache=refresh_cache) - def _setup_session(self) -> None: - # This is specific requests library. - - if self._verify_ssl is False: - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - - self._session: requests.Session = requests.session() - # Don't redirect, because carrying auth accross redirects is unsafe. - self._session.max_redirects = 0 - self._session.headers.update(self._headers) - session_settings = self._session.merge_environment_settings( - self._base_url, {}, None, self._verify_ssl, None - ) - self._session.verify = session_settings["verify"] - self._session.proxies = session_settings["proxies"] - - if self._auth_provider is not None and self._auth_provider.can_complete_mutualTLS(): - cert, key = self._auth_provider.tls_credentials() - if key is not None: - self._session.cert = (cert, key) - else: - self._session.cert = cert - @property def base_url(self) -> str: return self._base_url @@ -188,7 +164,10 @@ def ssl_context(self) -> t.Union[ssl.SSLContext, bool]: return _ssl_context def load_api(self, refresh_cache: bool = False) -> None: - # TODO: Find a way to invalidate caches on upstream change + asyncio.run(self._load_api(refresh_cache=refresh_cache)) + + async def _load_api(self, refresh_cache: bool = False) -> None: + # TODO: Find a way to invalidate caches on upstream change. xdg_cache_home: str = os.environ.get("XDG_CACHE_HOME") or "~/.cache" apidoc_cache: str = os.path.join( os.path.expanduser(xdg_cache_home), @@ -200,17 +179,17 @@ def load_api(self, refresh_cache: bool = False) -> None: if refresh_cache: # Fake that we did not find the cache. raise OSError() - with open(apidoc_cache, "rb") as f: - data: bytes = f.read() + async with aiofiles.open(apidoc_cache, mode="rb") as f: + data: bytes = await f.read() self._parse_api(data) except Exception: - # Try again with a freshly downloaded version - data = self._download_api() + # Try again with a freshly downloaded version. + data = await self._download_api() self._parse_api(data) - # Write to cache as it seems to be valid - os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True) - with open(apidoc_cache, "bw") as f: - f.write(data) + # Write to cache as it seems to be valid. + await aiofiles.os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True) + async with aiofiles.open(apidoc_cache, mode="bw") as f: + await f.write(data) def _parse_api(self, data: bytes) -> None: self.api_spec: dict[str, t.Any] = json.loads(data) @@ -225,15 +204,18 @@ def _parse_api(self, data: bytes) -> None: if method in {"get", "put", "post", "delete", "options", "head", "patch", "trace"} } - def _download_api(self) -> bytes: - try: - response: requests.Response = self._session.get(urljoin(self._base_url, self._doc_path)) - except requests.RequestException as e: - raise OpenAPIError(str(e)) - response.raise_for_status() - if "Correlation-Id" in response.headers: - self._set_correlation_id(response.headers["Correlation-Id"]) - return response.content + async def _download_api(self) -> bytes: + response = await self._send_request( + _Request( + operation_id="", + method="get", + url=urljoin(self._base_url, self._doc_path), + headers=self._headers, + ) + ) + if response.status_code != 200: + raise OpenAPIError(_("Failed to find api docs.")) + return response.body def _set_correlation_id(self, correlation_id: str) -> None: if "Correlation-Id" in self._headers: @@ -245,8 +227,6 @@ def _set_correlation_id(self, correlation_id: str) -> None: ) else: self._headers["Correlation-Id"] = correlation_id - # Do it for requests too... - self._session.headers["Correlation-Id"] = correlation_id def param_spec( self, operation_id: str, param_type: str, required: bool = False @@ -463,7 +443,7 @@ def _render_request( security=security, ) - def _log_request(self, request: _Request) -> None: + async def _log_request(self, request: _Request) -> None: if request.params: qs = urlencode(request.params) self._debug_callback(1, f"{request.operation_id} : {request.method} {request.url}?{qs}") @@ -489,7 +469,6 @@ def _select_proposal( if ( request.security and "Authorization" not in request.headers - and "Authorization" not in self._session.headers and self._auth_provider is not None ): security_schemes: dict[str, dict[str, t.Any]] = self.api_spec["components"][ @@ -561,7 +540,7 @@ async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool: headers={"Authorization": f"Basic {secret.decode()}"}, data=data, ) - response = self._send_request(request) + response = await self._send_request(request) if response.status_code < 200 or response.status_code >= 300: raise OpenAPIError("Failed to fetch OAuth2 token") result = json.loads(response.body) @@ -570,38 +549,55 @@ async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool: new_token = True return new_token - def _send_request( + async def _send_request( self, request: _Request, ) -> _Response: - # This function uses requests to translate the _Request into a _Response. + # This function uses aiohttp to translate the _Request into a _Response. + data: aiohttp.FormData | dict[str, t.Any] | str | None + if request.files: + assert isinstance(request.data, dict) + # Maybe assert on the content type header. + data = aiohttp.FormData(default_to_multipart=True) + for key, value in request.data.items(): + data.add_field(key, encode_param(value)) + for key, (name, value, content_type) in request.files.items(): + data.add_field(key, value, filename=name, content_type=content_type) + else: + data = request.data try: - r = self._session.request( - request.method, - request.url, - params=request.params, - headers=request.headers, - data=request.data, - files=request.files, - ) - response = _Response(status_code=r.status_code, headers=r.headers, body=r.content) - except requests.TooManyRedirects as e: - assert e.response is not None + async with aiohttp.ClientSession() as session: + async with session.request( + request.method, + request.url, + params=request.params, + headers=request.headers, + data=data, + ssl=self.ssl_context, + max_redirects=0, + ) as r: + response_body = await r.read() + response = _Response( + status_code=r.status, headers=r.headers, body=response_body + ) + except aiohttp.TooManyRedirects as e: + # We could handle that in the middleware... + assert e.history[-1] is not None raise OpenAPIError( _( "Received redirect to '{new_url} from {old_url}'." " Please check your configuration." ).format( - new_url=e.response.headers["location"], + new_url=e.history[-1].headers["location"], old_url=request.url, ) ) - except requests.RequestException as e: + except aiohttp.ClientResponseError as e: raise OpenAPIError(str(e)) return response - def _log_response(self, response: _Response) -> None: + async def _log_response(self, response: _Response) -> None: self._debug_callback( 1, _("Response: {status_code}").format(status_code=response.status_code) ) @@ -648,6 +644,22 @@ def call( parameters: dict[str, t.Any] | None = None, body: dict[str, t.Any] | None = None, validate_body: bool = True, + ) -> t.Any: + return asyncio.run( + self.async_call( + operation_id=operation_id, + parameters=parameters, + body=body, + validate_body=validate_body, + ) + ) + + async def async_call( + self, + operation_id: str, + parameters: dict[str, t.Any] | None = None, + body: dict[str, t.Any] | None = None, + validate_body: bool = True, ) -> t.Any: """ Make a call to the server. @@ -702,7 +714,7 @@ def call( body, validate_body=validate_body, ) - self._log_request(request) + await self._log_request(request) if self._dry_run and request.method.upper() not in SAFE_METHODS: raise UnsafeCallError(_("Call aborted due to safe mode")) @@ -710,29 +722,25 @@ def call( may_retry = False if proposal := self._select_proposal(request): assert len(proposal) == 1, "More complex security proposals are not implemented." - may_retry = asyncio.run(self._authenticate_request(request, proposal)) + may_retry = await self._authenticate_request(request, proposal) - response = self._send_request(request) + response = await self._send_request(request) if proposal is not None: assert self._auth_provider is not None if may_retry and response.status_code == 401: self._oauth2_token = None - asyncio.run(self._authenticate_request(request, proposal)) - response = self._send_request(request) + await self._authenticate_request(request, proposal) + response = await self._send_request(request) if response.status_code >= 200 and response.status_code < 300: - asyncio.run( - self._auth_provider.auth_success_hook( - proposal, self.api_spec["components"]["securitySchemes"] - ) + await self._auth_provider.auth_success_hook( + proposal, self.api_spec["components"]["securitySchemes"] ) elif response.status_code == 401: - asyncio.run( - self._auth_provider.auth_failure_hook( - proposal, self.api_spec["components"]["securitySchemes"] - ) + await self._auth_provider.auth_failure_hook( + proposal, self.api_spec["components"]["securitySchemes"] ) - self._log_response(response) + await self._log_response(response) return self._parse_response(method_spec, response) diff --git a/pulp-glue/pyproject.toml b/pulp-glue/pyproject.toml index 414089f56..906256e6b 100644 --- a/pulp-glue/pyproject.toml +++ b/pulp-glue/pyproject.toml @@ -23,9 +23,10 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ + "aiofiles>=25.1.0,<25.2", + "aiohttp>=3.12.0,<3.14", "multidict>=6.0.5,<6.8", "packaging>=22.0,<=26.0", # CalVer - "requests>=2.24.0,<2.33", "tomli>=2.0.0,<2.1;python_version<'3.11'", ] diff --git a/pulp-glue/tests/test_auth_provider.py b/pulp-glue/tests/test_auth_provider.py index 0bcde6da0..07eed191d 100644 --- a/pulp-glue/tests/test_auth_provider.py +++ b/pulp-glue/tests/test_auth_provider.py @@ -63,10 +63,7 @@ def test_can_complete_basic(self, provider: AuthProviderBase) -> None: assert provider.can_complete_http_basic() def test_provides_username_and_password(self, provider: AuthProviderBase) -> None: - assert asyncio.run(provider.http_basic_credentials()) == ( - b"user1", - b"password1", - ) + assert asyncio.run(provider.http_basic_credentials()) == (b"user1", b"password1") def test_cannot_complete_mutualTLS(self, provider: AuthProviderBase) -> None: assert not provider.can_complete_mutualTLS() @@ -104,10 +101,7 @@ def test_client_id_needs_client_secret(self) -> None: def test_can_complete_oauth2_client_credentials_and_provide_them(self) -> None: provider = GlueAuthProvider(client_id="client1", client_secret="secret1") assert provider.can_complete_oauth2_client_credentials([]) is True - assert asyncio.run(provider.oauth2_client_credentials()) == ( - b"client1", - b"secret1", - ) + assert asyncio.run(provider.oauth2_client_credentials()) == (b"client1", b"secret1") def test_can_complete_mutualTLS_and_provide_cert(self) -> None: provider = GlueAuthProvider(cert="FAKECERTIFICATE") diff --git a/pulp-glue/tests/test_openapi.py b/pulp-glue/tests/test_openapi.py index a646a5bf2..47f29de5f 100644 --- a/pulp-glue/tests/test_openapi.py +++ b/pulp-glue/tests/test_openapi.py @@ -94,7 +94,7 @@ ).encode() -def mock_send_request(request: _Request) -> _Response: +async def mock_send_request(request: _Request) -> _Response: if request.url.endswith("oauth/token"): assert request.method.lower() == "post" # $ echo -n "client1:secret1" | base64 diff --git a/pulpcore/cli/core/task.py b/pulpcore/cli/core/task.py index 6ece44767..98151e748 100644 --- a/pulpcore/cli/core/task.py +++ b/pulpcore/cli/core/task.py @@ -1,13 +1,17 @@ +import asyncio import re from contextlib import suppress from datetime import datetime from pathlib import Path +import aiofiles +import aiohttp import click from pulp_glue.common.context import ( DATETIME_FORMATS, PluginRequirement, + PulpContext, PulpEntityContext, ) from pulp_glue.common.exceptions import PulpException @@ -175,6 +179,20 @@ def cancel( task_ctx.cancel(task_ctx.pulp_href) +async def _download_artifacts( + pulp_ctx: PulpContext, urls: dict[str, str], profile_artifact_dir: Path +) -> None: + async with aiohttp.ClientSession() as session: + for name, url in urls.items(): + profile_artifact_path = profile_artifact_dir / name + click.echo(_("Downloading {path}").format(path=profile_artifact_path)) + async with session.get(url, ssl=pulp_ctx.api.ssl_context) as response: + assert response.status == 200 + async with aiofiles.open(profile_artifact_path, "wb") as fp: + async for chunk in response.content.iter_chunked(1024): + await fp.write(chunk) + + @task.command() @href_option @uuid_option @@ -197,13 +215,7 @@ def profile_artifact_urls( uuid = uuid_match.group("uuid") profile_artifact_dir = Path(".") / f"task_profile-{task_name}-{uuid}" profile_artifact_dir.mkdir(exist_ok=True) - with pulp_ctx.api._session as session: - for name, url in urls.items(): - profile_artifact_path = profile_artifact_dir / name - click.echo(_("Downloading {path}").format(path=profile_artifact_path)) - response = session.get(url) - response.raise_for_status() - profile_artifact_path.write_bytes(response.content) + asyncio.run(_download_artifacts(pulp_ctx, urls, profile_artifact_dir)) @task.command()