diff --git a/CHANGES/pulp-glue/+aiohttp.feature b/CHANGES/pulp-glue/+aiohttp.feature new file mode 100644 index 000000000..70dac14dc --- /dev/null +++ b/CHANGES/pulp-glue/+aiohttp.feature @@ -0,0 +1 @@ +WIP: Added async api to Pulp glue. diff --git a/CHANGES/pulp-glue/+aiohttp.removal b/CHANGES/pulp-glue/+aiohttp.removal new file mode 100644 index 000000000..4d5165bf5 --- /dev/null +++ b/CHANGES/pulp-glue/+aiohttp.removal @@ -0,0 +1,2 @@ +Replaced requests with aiohttp. +Breaking change: Reworked the contract around the `AuthProvider` to allow authentication to be coded independently of the underlying library. diff --git a/CHANGES/pulp-glue/+auth_recontract.removal b/CHANGES/pulp-glue/+auth_recontract.removal new file mode 100644 index 000000000..dfc83f18d --- /dev/null +++ b/CHANGES/pulp-glue/+auth_recontract.removal @@ -0,0 +1 @@ +Breaking change: Reworked the contract around the `AuthProvider` to allow authentication to be coded independently of the underlying library. diff --git a/lint_requirements.txt b/lint_requirements.txt index 58f450cdd..796ec00fa 100644 --- a/lint_requirements.txt +++ b/lint_requirements.txt @@ -4,9 +4,9 @@ mypy~=1.19.1 shellcheck-py~=0.11.0.1 # Type annotation stubs +types-aiofiles types-pygments types-PyYAML -types-requests types-setuptools types-toml diff --git a/lower_bounds_constraints.lock b/lower_bounds_constraints.lock index a2e1857e1..3aad3a3ed 100644 --- a/lower_bounds_constraints.lock +++ b/lower_bounds_constraints.lock @@ -1,3 +1,5 @@ +aiofiles==25.1.0 +aiohttp==3.12.0 click==8.0.0 packaging==22.0 PyYAML==5.3 diff --git a/pulp-glue/pulp_glue/common/authentication.py b/pulp-glue/pulp_glue/common/authentication.py index a7998bf44..991aa5af9 100644 --- a/pulp-glue/pulp_glue/common/authentication.py +++ b/pulp-glue/pulp_glue/common/authentication.py @@ -1,97 +1,140 @@ import typing as t -from datetime import datetime, timedelta -import requests - -class OAuth2ClientCredentialsAuth(requests.auth.AuthBase): - """ - This implements the OAuth2 ClientCredentials Grant authentication flow. - https://datatracker.ietf.org/doc/html/rfc6749#section-4.4 +class AuthProviderBase: """ + Base class for auth providers. - def __init__( - self, - client_id: str, - client_secret: str, - token_url: str, - scopes: list[str] | None = None, - verify_ssl: str | bool | None = None, - ): - self._token_server_auth = requests.auth.HTTPBasicAuth(client_id, client_secret) - self._token_url = token_url - self._scopes = scopes - self._verify_ssl = verify_ssl + This abstract base class will analyze the authentication proposals of the openapi specs. + Different authentication schemes can be implemented in subclasses. + """ - self._access_token: str | None = None - self._expire_at: datetime | None = None + def can_complete_http_basic(self) -> bool: + return False + + def can_complete_mutualTLS(self) -> bool: + return False + + def can_complete_oauth2_client_credentials(self, scopes: list[str]) -> bool: + return False + + def can_complete_scheme(self, scheme: dict[str, t.Any], scopes: list[str]) -> bool: + if scheme["type"] == "http": + if scheme["scheme"] == "basic": + return self.can_complete_http_basic() + elif scheme["type"] == "mutualTLS": + return self.can_complete_mutualTLS() + elif scheme["type"] == "oauth2": + for flow_name, flow in scheme["flows"].items(): + if flow_name == "clientCredentials" and self.can_complete_oauth2_client_credentials( + flow["scopes"] + ): + return True + return False + + def can_complete( + self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]] + ) -> bool: + for name, scopes in proposal.items(): + scheme = security_schemes.get(name) + if scheme is None or not self.can_complete_scheme(scheme, scopes): + return False + # This covers the case where `[]` allows for no auth at all. + return True + + async def auth_success_hook( + self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]] + ) -> None: + pass + + async def auth_failure_hook( + self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]] + ) -> None: + pass + + async def http_basic_credentials(self) -> tuple[bytes, bytes]: + raise NotImplementedError() + + async def oauth2_client_credentials(self) -> tuple[bytes, bytes]: + raise NotImplementedError() + + def tls_credentials(self) -> tuple[str, str | None]: + raise NotImplementedError() + + +class BasicAuthProvider(AuthProviderBase): + """ + AuthProvider providing basic auth with fixed `username`, `password`. + """ - def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: - if self._expire_at is None or self._expire_at < datetime.now(): - self._retrieve_token() + def __init__(self, username: t.AnyStr, password: t.AnyStr): + super().__init__() + self.username: bytes = username.encode("latin1") if isinstance(username, str) else username + self.password: bytes = password.encode("latin1") if isinstance(password, str) else password - assert self._access_token is not None + def can_complete_http_basic(self) -> bool: + return True - request.headers["Authorization"] = f"Bearer {self._access_token}" + async def http_basic_credentials(self) -> tuple[bytes, bytes]: + return self.username, self.password - # Call to untyped function "register_hook" in typed context - request.register_hook("response", self._handle401) # type: ignore[no-untyped-call] - return request +class GlueAuthProvider(AuthProviderBase): + """ + AuthProvider allowing to be used with prepared credentials. + """ - def _handle401( + def __init__( self, - response: requests.Response, - **kwargs: t.Any, - ) -> requests.Response: - if response.status_code != 401: - return response - - # If we get this far, probably the token is not valid anymore. - - # Try to reach for a new token once. - self._retrieve_token() - - assert self._access_token is not None - - # Consume content and release the original connection - # to allow our new request to reuse the same one. - response.content - response.close() - prepared_new_request = response.request.copy() - - prepared_new_request.headers["Authorization"] = f"Bearer {self._access_token}" - - # Avoid to enter into an infinity loop. - # Call to untyped function "deregister_hook" in typed context - prepared_new_request.deregister_hook( # type: ignore[no-untyped-call] - "response", self._handle401 - ) - - # "Response" has no attribute "connection" - new_response: requests.Response = response.connection.send(prepared_new_request, **kwargs) - new_response.history.append(response) - new_response.request = prepared_new_request - - return new_response - - def _retrieve_token(self) -> None: - data = { - "grant_type": "client_credentials", - } - - if self._scopes: - data["scope"] = " ".join(self._scopes) - - response: requests.Response = requests.post( - self._token_url, - data=data, - auth=self._token_server_auth, - verify=self._verify_ssl, - ) - - response.raise_for_status() - - token = response.json() - self._expire_at = datetime.now() + timedelta(seconds=token["expires_in"]) - self._access_token = token["access_token"] + *, + username: t.AnyStr | None = None, + password: t.AnyStr | None = None, + client_id: t.AnyStr | None = None, + client_secret: t.AnyStr | None = None, + cert: str | None = None, + key: str | None = None, + ): + super().__init__() + self.username: bytes | None = None + self.password: bytes | None = None + self.client_id: bytes | None = None + self.client_secret: bytes | None = None + self.cert: str | None = cert + self.key: str | None = key + + if username is not None: + assert password is not None + self.username = username.encode("latin1") if isinstance(username, str) else username + self.password = password.encode("latin1") if isinstance(password, str) else password + if client_id is not None: + assert client_secret is not None + self.client_id = client_id.encode("latin1") if isinstance(client_id, str) else client_id + self.client_secret = ( + client_secret.encode("latin1") if isinstance(client_secret, str) else client_secret + ) + + if cert is None and key is not None: + raise RuntimeError("Key can only be used together with a cert.") + + def can_complete_http_basic(self) -> bool: + return self.username is not None + + def can_complete_oauth2_client_credentials(self, scopes: list[str]) -> bool: + return self.client_id is not None + + def can_complete_mutualTLS(self) -> bool: + return self.cert is not None + + async def http_basic_credentials(self) -> tuple[bytes, bytes]: + assert self.username is not None + assert self.password is not None + return self.username, self.password + + async def oauth2_client_credentials(self) -> tuple[bytes, bytes]: + assert self.client_id is not None + assert self.client_secret is not None + return self.client_id, self.client_secret + + def tls_credentials(self) -> tuple[str, str | None]: + assert self.cert is not None + return (self.cert, self.key) diff --git a/pulp-glue/pulp_glue/common/context.py b/pulp-glue/pulp_glue/common/context.py index fff028eff..18b73904b 100644 --- a/pulp-glue/pulp_glue/common/context.py +++ b/pulp-glue/pulp_glue/common/context.py @@ -9,6 +9,7 @@ from packaging.specifiers import SpecifierSet +from pulp_glue.common.authentication import GlueAuthProvider from pulp_glue.common.exceptions import ( NotImplementedFake, OpenAPIError, @@ -19,7 +20,7 @@ UnsafeCallError, ) from pulp_glue.common.i18n import get_translation -from pulp_glue.common.openapi import BasicAuthProvider, OpenAPI +from pulp_glue.common.openapi import OpenAPI if sys.version_info >= (3, 11): import tomllib @@ -202,6 +203,20 @@ def patch_upstream_pulp_replicate_request_body(api: OpenAPI) -> None: operation.pop("requestBody", None) +@api_quirk(PluginRequirement("core", specifier="<3.85")) +def patch_security_scheme_mutual_tls(api: OpenAPI) -> None: + # Trick to allow tls cert auth on older Pulp. + if (components := api.api_spec.get("components")) is not None: + if (security_schemes := components.get("securitySchemes")) is not None: + # Only if it is going to be idempotent... + if "gluePatchTLS" not in security_schemes: + security_schemes["gluePatchTLS"] = {"type": "mutualTLS"} + for method, path in api.operations.values(): + operation = api.api_spec["paths"][path][method] + if "security" in operation: + operation["security"].append({"gluePatchTLS": []}) + + class PulpContext: """ Abstract class for the global PulpContext object. @@ -335,8 +350,13 @@ def from_config(cls, config: dict[str, t.Any]) -> "t.Self": api_kwargs: dict[str, t.Any] = { "base_url": config["base_url"], } - if "username" in config: - api_kwargs["auth_provider"] = BasicAuthProvider(config["username"], config["password"]) + api_kwargs["auth_provider"] = GlueAuthProvider( + **{ + k: v + for k, v in config.items() + if k in {"username", "password", "client_id", "client_secret", "cert", "key"} + } + ) if "headers" in config: api_kwargs["headers"] = dict( (header.split(":", maxsplit=1) for header in config["headers"]) @@ -385,7 +405,9 @@ def api(self) -> OpenAPI: # Deprecated for 'auth'. if not password: password = self.prompt("password", hide_input=True) - self._api_kwargs["auth_provider"] = BasicAuthProvider(username, password) + self._api_kwargs["auth_provider"] = GlueAuthProvider( + username=username, password=password + ) warnings.warn( "Using 'username' and 'password' with 'PulpContext' is deprecated. " "Use an auth provider with the 'auth_provider' argument instead.", @@ -399,10 +421,10 @@ def api(self) -> OpenAPI: ) except OpenAPIError as e: raise PulpException(str(e)) + self._patch_api_spec() # Rerun scheduled version checks for plugin_requirement in self._needed_plugins: self.needs_plugin(plugin_requirement) - self._patch_api_spec() return self._api @property diff --git a/pulp-glue/pulp_glue/common/openapi.py b/pulp-glue/pulp_glue/common/openapi.py index ce7e32e99..217035a78 100644 --- a/pulp-glue/pulp_glue/common/openapi.py +++ b/pulp-glue/pulp_glue/common/openapi.py @@ -1,18 +1,24 @@ +import asyncio import json import logging import os +import ssl import typing as t import warnings -from collections import defaultdict +from base64 import b64encode from dataclasses import dataclass +from datetime import datetime, timedelta +from functools import cached_property from io import BufferedReader from urllib.parse import urlencode, urljoin -import requests -import urllib3 +import aiofiles +import aiofiles.os +import aiohttp from multidict import CIMultiDict, CIMultiDictProxy, MutableMultiMapping from pulp_glue.common import __version__ +from pulp_glue.common.authentication import AuthProviderBase from pulp_glue.common.exceptions import ( OpenAPIError, PulpAuthenticationFailed, @@ -38,7 +44,7 @@ class _Request: operation_id: str method: str url: str - headers: MutableMultiMapping[str] | CIMultiDictProxy[str] | t.MutableMapping[str, str] + headers: MutableMultiMapping[str] | CIMultiDict[str] | t.MutableMapping[str, str] params: dict[str, str] | None = None data: dict[str, t.Any] | str | None = None files: dict[str, tuple[str, UploadType, str]] | None = None @@ -52,99 +58,6 @@ class _Response: body: bytes -class AuthProviderBase: - """ - Base class for auth providers. - - This abstract base class will analyze the authentication proposals of the openapi specs. - Different authentication schemes should be implemented by subclasses. - Returned auth objects need to be compatible with `requests.auth.AuthBase`. - """ - - def basic_auth(self, scopes: list[str]) -> requests.auth.AuthBase | None: - """Implement this to provide means of http basic auth.""" - return None - - def http_auth( - self, security_scheme: dict[str, t.Any], scopes: list[str] - ) -> requests.auth.AuthBase | None: - """Select a suitable http auth scheme or return None.""" - # https://www.iana.org/assignments/http-authschemes/http-authschemes.xhtml - if security_scheme["scheme"] == "basic": - result = self.basic_auth(scopes) - if result: - return result - return None - - def oauth2_client_credentials_auth( - self, flow: t.Any, scopes: list[str] - ) -> requests.auth.AuthBase | None: - """Implement this to provide other authentication methods.""" - return None - - def oauth2_auth( - self, security_scheme: dict[str, t.Any], scopes: list[str] - ) -> requests.auth.AuthBase | None: - """Select a suitable oauth2 flow or return None.""" - # Check flows by preference. - if "clientCredentials" in security_scheme["flows"]: - flow = security_scheme["flows"]["clientCredentials"] - # Select this flow only if it claims to provide all the necessary scopes. - # This will allow subsequent auth proposals to be considered. - if set(scopes) - set(flow["scopes"]): - return None - - result = self.oauth2_client_credentials_auth(flow, scopes) - if result: - return result - return None - - def __call__( - self, - security: list[dict[str, list[str]]], - security_schemes: dict[str, dict[str, t.Any]], - ) -> requests.auth.AuthBase | None: - # Reorder the proposals by their type to prioritize properly. - # Select only single mechanism proposals on the way. - proposed_schemes: dict[str, dict[str, list[str]]] = defaultdict(dict) - for proposal in security: - if len(proposal) == 0: - # Empty proposal: No authentication needed. Shortcut return. - return None - if len(proposal) == 1: - name, scopes = list(proposal.items())[0] - proposed_schemes[security_schemes[name]["type"]][name] = scopes - # Ignore all proposals with more than one required auth mechanism. - - # Check for auth schemes by preference. - if "oauth2" in proposed_schemes: - for name, scopes in proposed_schemes["oauth2"].items(): - result = self.oauth2_auth(security_schemes[name], scopes) - if result: - return result - - # if we get here, either no-oauth2, OR we couldn't find creds - if "http" in proposed_schemes: - for name, scopes in proposed_schemes["http"].items(): - result = self.http_auth(security_schemes[name], scopes) - if result: - return result - - raise OpenAPIError(_("No suitable auth scheme found.")) - - -class BasicAuthProvider(AuthProviderBase): - """ - Implementation for AuthProviderBase providing basic auth with fixed `username`, `password`. - """ - - def __init__(self, username: str, password: str): - self.auth = requests.auth.HTTPBasicAuth(username, password) - - def basic_auth(self, scopes: list[str]) -> requests.auth.AuthBase | None: - return self.auth - - class OpenAPI: """ The abstraction Layer to interact with a server providing an openapi v3 specification. @@ -154,7 +67,7 @@ class OpenAPI: served api. doc_path: Path of the json api doc schema relative to the `base_url`. headers: Dictionary of additional request headers. - auth_provider: Object that returns requests auth objects according to the api spec. + auth_provider: Object that can be questioned for credentials according to the api spec. cert: Client certificate used for auth. key: Matching key for `cert` if not already included. verify_ssl: Whether to check server TLS certificates agains a CA. @@ -210,9 +123,8 @@ def __init__( self._dry_run: bool = dry_run self._headers = CIMultiDict(headers or {}) self._verify_ssl = verify_ssl + self._auth_provider = auth_provider - self._cert = cert - self._key = key self._headers.update( { @@ -223,36 +135,12 @@ def __init__( if cid: self._headers["Correlation-Id"] = cid - self._setup_session() + self._oauth2_lock = asyncio.Lock() + self._oauth2_token: str | None = None + self._oauth2_expires: datetime = datetime.now() self.load_api(refresh_cache=refresh_cache) - def _setup_session(self) -> None: - # This is specific requests library. - - if self._verify_ssl is False: - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - - self._session: requests.Session = requests.session() - # Don't redirect, because carrying auth accross redirects is unsafe. - self._session.max_redirects = 0 - self._session.headers.update(self._headers) - if self._auth_provider: - if self._cert or self._key: - raise OpenAPIError(_("Cannot use both 'auth' and 'cert'.")) - else: - if self._cert and self._key: - self._session.cert = (self._cert, self._key) - elif self._cert: - self._session.cert = self._cert - elif self._key: - raise OpenAPIError(_("Cert is required if key is set.")) - session_settings = self._session.merge_environment_settings( - self._base_url, {}, None, self._verify_ssl, None - ) - self._session.verify = session_settings["verify"] - self._session.proxies = session_settings["proxies"] - @property def base_url(self) -> str: return self._base_url @@ -261,8 +149,25 @@ def base_url(self) -> str: def cid(self) -> str | None: return self._headers.get("Correlation-Id") + @cached_property + def ssl_context(self) -> t.Union[ssl.SSLContext, bool]: + _ssl_context: t.Union[ssl.SSLContext, bool] + if self._verify_ssl is False: + _ssl_context = False + else: + if isinstance(self._verify_ssl, str): + _ssl_context = ssl.create_default_context(cafile=self._verify_ssl) + else: + _ssl_context = ssl.create_default_context() + if self._auth_provider is not None and self._auth_provider.can_complete_mutualTLS(): + _ssl_context.load_cert_chain(*self._auth_provider.tls_credentials()) + return _ssl_context + def load_api(self, refresh_cache: bool = False) -> None: - # TODO: Find a way to invalidate caches on upstream change + asyncio.run(self._load_api(refresh_cache=refresh_cache)) + + async def _load_api(self, refresh_cache: bool = False) -> None: + # TODO: Find a way to invalidate caches on upstream change. xdg_cache_home: str = os.environ.get("XDG_CACHE_HOME") or "~/.cache" apidoc_cache: str = os.path.join( os.path.expanduser(xdg_cache_home), @@ -272,18 +177,19 @@ def load_api(self, refresh_cache: bool = False) -> None: ) try: if refresh_cache: + # Fake that we did not find the cache. raise OSError() - with open(apidoc_cache, "rb") as f: - data: bytes = f.read() + async with aiofiles.open(apidoc_cache, mode="rb") as f: + data: bytes = await f.read() self._parse_api(data) except Exception: - # Try again with a freshly downloaded version - data = self._download_api() + # Try again with a freshly downloaded version. + data = await self._download_api() self._parse_api(data) - # Write to cache as it seems to be valid - os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True) - with open(apidoc_cache, "bw") as f: - f.write(data) + # Write to cache as it seems to be valid. + await aiofiles.os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True) + async with aiofiles.open(apidoc_cache, mode="bw") as f: + await f.write(data) def _parse_api(self, data: bytes) -> None: self.api_spec: dict[str, t.Any] = json.loads(data) @@ -298,15 +204,18 @@ def _parse_api(self, data: bytes) -> None: if method in {"get", "put", "post", "delete", "options", "head", "patch", "trace"} } - def _download_api(self) -> bytes: - try: - response: requests.Response = self._session.get(urljoin(self._base_url, self._doc_path)) - except requests.RequestException as e: - raise OpenAPIError(str(e)) - response.raise_for_status() - if "Correlation-Id" in response.headers: - self._set_correlation_id(response.headers["Correlation-Id"]) - return response.content + async def _download_api(self) -> bytes: + response = await self._send_request( + _Request( + operation_id="", + method="get", + url=urljoin(self._base_url, self._doc_path), + headers=self._headers, + ) + ) + if response.status_code != 200: + raise OpenAPIError(_("Failed to find api docs.")) + return response.body def _set_correlation_id(self, correlation_id: str) -> None: if "Correlation-Id" in self._headers: @@ -318,8 +227,6 @@ def _set_correlation_id(self, correlation_id: str) -> None: ) else: self._headers["Correlation-Id"] = correlation_id - # Do it for requests too... - self._session.headers["Correlation-Id"] = correlation_id def param_spec( self, operation_id: str, param_type: str, required: bool = False @@ -536,7 +443,7 @@ def _render_request( security=security, ) - def _log_request(self, request: _Request) -> None: + async def _log_request(self, request: _Request) -> None: if request.params: qs = urlencode(request.params) self._debug_callback(1, f"{request.operation_id} : {request.method} {request.url}?{qs}") @@ -554,51 +461,143 @@ def _log_request(self, request: _Request) -> None: for key, (name, _dummy, content_type) in request.files.items(): self._debug_callback(3, f"{key} <- {name} [{content_type}]") - def _send_request( + def _select_proposal( self, request: _Request, - ) -> _Response: - # This function uses requests to translate the _Request into a _Response. - if request.security and self._auth_provider: - if "Authorization" in self._session.headers: - # Bad idea, but you wanted it that way. - auth = None + ) -> dict[str, list[str]] | None: + proposal = None + if ( + request.security + and "Authorization" not in request.headers + and self._auth_provider is not None + ): + security_schemes: dict[str, dict[str, t.Any]] = self.api_spec["components"][ + "securitySchemes" + ] + try: + proposal = next( + ( + p + for p in request.security + if self._auth_provider.can_complete(p, security_schemes) + ) + ) + except StopIteration: + raise OpenAPIError(_("No suitable auth scheme found.")) + return proposal + + async def _authenticate_request( + self, + request: _Request, + proposal: dict[str, list[str]], + ) -> bool: + assert self._auth_provider is not None + + may_retry = False + security_schemes = self.api_spec["components"]["securitySchemes"] + for scheme_name, scopes in proposal.items(): + scheme = security_schemes[scheme_name] + if scheme["type"] == "http": + if scheme["scheme"] == "basic": + username, password = await self._auth_provider.http_basic_credentials() + secret = b64encode(username + b":" + password) + # TODO Should we add, amend or replace the existing auth header? + request.headers["Authorization"] = f"Basic {secret.decode()}" + else: + raise NotImplementedError("Auth scheme: http " + scheme["scheme"]) + elif scheme["type"] == "oauth2": + flow = scheme["flows"].get("clientCredentials") + if flow is None: + raise NotImplementedError("OAuth2: Only client credential flow is available.") + # Allow retry if the token was taken from cache. + may_retry = not await self._fetch_oauth2_token(flow) + request.headers["Authorization"] = f"Bearer {self._oauth2_token}" + elif scheme["type"] == "mutualTLS": + # At this point, we assume the cert has already been loaded into the sslcontext. + pass else: - auth = self._auth_provider( - request.security, self.api_spec["components"]["securitySchemes"] + raise NotImplementedError("Auth type: " + scheme["type"]) + return may_retry + + async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool: + assert self._auth_provider is not None + + new_token = False + async with self._oauth2_lock: + now = datetime.now() + if self._oauth2_token is None or self._oauth2_expires < now: + # Get or refresh token. + client_id, client_secret = await self._auth_provider.oauth2_client_credentials() + secret = b64encode(client_id + b":" + client_secret) + data: dict[str, t.Any] = {"grant_type": "client_credentials"} + scopes = flow.get("scopes") + if scopes: + data["scopes"] = " ".join(scopes) + request = _Request( + operation_id="", + method="post", + url=flow["tokenUrl"], + headers={"Authorization": f"Basic {secret.decode()}"}, + data=data, ) + response = await self._send_request(request) + if response.status_code < 200 or response.status_code >= 300: + raise OpenAPIError("Failed to fetch OAuth2 token") + result = json.loads(response.body) + self._oauth2_token = result["access_token"] + self._oauth2_expires = now + timedelta(seconds=result["expires_in"]) + new_token = True + return new_token + + async def _send_request( + self, + request: _Request, + ) -> _Response: + # This function uses aiohttp to translate the _Request into a _Response. + data: aiohttp.FormData | dict[str, t.Any] | str | None + if request.files: + assert isinstance(request.data, dict) + # Maybe assert on the content type header. + data = aiohttp.FormData(default_to_multipart=True) + for key, value in request.data.items(): + data.add_field(key, encode_param(value)) + for key, (name, value, content_type) in request.files.items(): + data.add_field(key, value, filename=name, content_type=content_type) else: - # No auth required? Don't provide it. - # No auth_provider available? Hope for the best (should do the trick for cert auth). - auth = None + data = request.data try: - r = self._session.request( - request.method, - request.url, - params=request.params, - headers=request.headers, - data=request.data, - files=request.files, - auth=auth, - ) - response = _Response(status_code=r.status_code, headers=r.headers, body=r.content) - except requests.TooManyRedirects as e: - assert e.response is not None + async with aiohttp.ClientSession() as session: + async with session.request( + request.method, + request.url, + params=request.params, + headers=request.headers, + data=data, + ssl=self.ssl_context, + max_redirects=0, + ) as r: + response_body = await r.read() + response = _Response( + status_code=r.status, headers=r.headers, body=response_body + ) + except aiohttp.TooManyRedirects as e: + # We could handle that in the middleware... + assert e.history[-1] is not None raise OpenAPIError( _( "Received redirect to '{new_url} from {old_url}'." " Please check your configuration." ).format( - new_url=e.response.headers["location"], + new_url=e.history[-1].headers["location"], old_url=request.url, ) ) - except requests.RequestException as e: + except aiohttp.ClientResponseError as e: raise OpenAPIError(str(e)) return response - def _log_response(self, response: _Response) -> None: + async def _log_response(self, response: _Response) -> None: self._debug_callback( 1, _("Response: {status_code}").format(status_code=response.status_code) ) @@ -645,6 +644,22 @@ def call( parameters: dict[str, t.Any] | None = None, body: dict[str, t.Any] | None = None, validate_body: bool = True, + ) -> t.Any: + return asyncio.run( + self.async_call( + operation_id=operation_id, + parameters=parameters, + body=body, + validate_body=validate_body, + ) + ) + + async def async_call( + self, + operation_id: str, + parameters: dict[str, t.Any] | None = None, + body: dict[str, t.Any] | None = None, + validate_body: bool = True, ) -> t.Any: """ Make a call to the server. @@ -676,8 +691,9 @@ def call( headers = self._extract_params("header", path_spec, method_spec, parameters) + rel_url = path for name, value in self._extract_params("path", path_spec, method_spec, parameters).items(): - path = path.replace("{" + name + "}", value) + rel_url = path.replace("{" + name + "}", value) query_params = self._extract_params("query", path_spec, method_spec, parameters) @@ -687,7 +703,7 @@ def call( names=", ".join(parameters.keys()), operation_id=operation_id ) ) - url = urljoin(self._base_url, path) + url = urljoin(self._base_url, rel_url) request = self._render_request( path_spec, @@ -698,12 +714,33 @@ def call( body, validate_body=validate_body, ) - self._log_request(request) + await self._log_request(request) if self._dry_run and request.method.upper() not in SAFE_METHODS: raise UnsafeCallError(_("Call aborted due to safe mode")) - response = self._send_request(request) + may_retry = False + if proposal := self._select_proposal(request): + assert len(proposal) == 1, "More complex security proposals are not implemented." + may_retry = await self._authenticate_request(request, proposal) + + response = await self._send_request(request) + + if proposal is not None: + assert self._auth_provider is not None + if may_retry and response.status_code == 401: + self._oauth2_token = None + await self._authenticate_request(request, proposal) + response = await self._send_request(request) + + if response.status_code >= 200 and response.status_code < 300: + await self._auth_provider.auth_success_hook( + proposal, self.api_spec["components"]["securitySchemes"] + ) + elif response.status_code == 401: + await self._auth_provider.auth_failure_hook( + proposal, self.api_spec["components"]["securitySchemes"] + ) - self._log_response(response) + await self._log_response(response) return self._parse_response(method_spec, response) diff --git a/pulp-glue/pyproject.toml b/pulp-glue/pyproject.toml index 414089f56..906256e6b 100644 --- a/pulp-glue/pyproject.toml +++ b/pulp-glue/pyproject.toml @@ -23,9 +23,10 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ + "aiofiles>=25.1.0,<25.2", + "aiohttp>=3.12.0,<3.14", "multidict>=6.0.5,<6.8", "packaging>=22.0,<=26.0", # CalVer - "requests>=2.24.0,<2.33", "tomli>=2.0.0,<2.1;python_version<'3.11'", ] diff --git a/pulp-glue/tests/conftest.py b/pulp-glue/tests/conftest.py index bc8d56678..06a44e75a 100644 --- a/pulp-glue/tests/conftest.py +++ b/pulp-glue/tests/conftest.py @@ -1,11 +1,11 @@ import json -import os import typing as t import pytest +from pulp_glue.common.authentication import GlueAuthProvider from pulp_glue.common.context import PulpContext -from pulp_glue.common.openapi import BasicAuthProvider, OpenAPI +from pulp_glue.common.openapi import OpenAPI FAKE_OPENAPI_SPEC = json.dumps( { @@ -23,8 +23,6 @@ def pulp_ctx( if not any((mark.name == "live" for mark in request.node.iter_markers())): pytest.fail("This fixture can only be used in live (integration) tests.") - if os.environ.get("PULP_OAUTH2", "").lower() == "true": - pytest.skip("Pulp-glue in isolation does not support OAuth2 atm.") verbose = request.config.getoption("verbose") settings = pulp_cli_settings["cli"].copy() settings["debug_callback"] = lambda i, s: i <= verbose and print(s) @@ -58,18 +56,15 @@ def fake_pulp_ctx( if not any((mark.name == "live" for mark in request.node.iter_markers())): pytest.fail("This fixture can only be used in live (integration) tests.") - if os.environ.get("PULP_OAUTH2", "").lower() == "true": - pytest.skip("Pulp-glue in isolation does not support OAuth2 atm.") verbose = request.config.getoption("verbose") settings = pulp_cli_settings["cli"] - if "username" in settings: - username = settings.get("username") - assert isinstance(username, str) - password = settings.get("password") - assert isinstance(password, str) - auth_provider = BasicAuthProvider(username, password) - else: - auth_provider = None + auth_provider = GlueAuthProvider( + **{ + k: v + for k, v in settings.items() + if k in {"username", "password", "client_id", "client_secret", "cert", "key"} + } + ) return PulpContext( api_kwargs={ "base_url": settings["base_url"], diff --git a/pulp-glue/tests/test_auth_provider.py b/pulp-glue/tests/test_auth_provider.py index f8e2f2441..07eed191d 100644 --- a/pulp-glue/tests/test_auth_provider.py +++ b/pulp-glue/tests/test_auth_provider.py @@ -1,10 +1,9 @@ +import asyncio import typing as t import pytest -from requests.auth import AuthBase -from pulp_glue.common.exceptions import OpenAPIError -from pulp_glue.common.openapi import AuthProviderBase +from pulp_glue.common.authentication import AuthProviderBase, BasicAuthProvider, GlueAuthProvider pytestmark = pytest.mark.glue @@ -51,56 +50,60 @@ }, }, }, + "E": {"type": "mutualTLS"}, } -class MockBasicAuth(AuthBase): - pass +class TestBasicAuthProvider: + @pytest.fixture(scope="class") + def provider(self) -> AuthProviderBase: + return BasicAuthProvider(username="user1", password="password1") + def test_can_complete_basic(self, provider: AuthProviderBase) -> None: + assert provider.can_complete_http_basic() -class MockOAuth2CCAuth(AuthBase): - pass + def test_provides_username_and_password(self, provider: AuthProviderBase) -> None: + assert asyncio.run(provider.http_basic_credentials()) == (b"user1", b"password1") + def test_cannot_complete_mutualTLS(self, provider: AuthProviderBase) -> None: + assert not provider.can_complete_mutualTLS() -def test_auth_provider_select_mechanism(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(AuthProviderBase, "basic_auth", lambda *args: MockBasicAuth()) - monkeypatch.setattr( - AuthProviderBase, - "oauth2_client_credentials_auth", - lambda *args: MockOAuth2CCAuth(), - ) - provider = AuthProviderBase() + def test_can_complete_basic_proposal(self, provider: AuthProviderBase) -> None: + assert provider.can_complete({"B": []}, security_schemes=SECURITY_SCHEMES) - # Error if no auth scheme is available. - with pytest.raises(OpenAPIError): - provider([], SECURITY_SCHEMES) + def test_cannot_complete_bearer_proposal(self, provider: AuthProviderBase) -> None: + assert not provider.can_complete({"A": []}, security_schemes=SECURITY_SCHEMES) - # Error if a nonexisting mechanism is proposed. - with pytest.raises(KeyError): - provider([{"foo": []}], SECURITY_SCHEMES) + def test_cannot_complete_combined_proposal(self, provider: AuthProviderBase) -> None: + assert not provider.can_complete({"A": [], "B": []}, security_schemes=SECURITY_SCHEMES) - # Succeed without mechanism for an empty proposal. - assert provider([{}], SECURITY_SCHEMES) is None - # Try select a not implemented auth. - with pytest.raises(OpenAPIError): - provider([{"A": []}], SECURITY_SCHEMES) +class TestGlueAuthProvider: + def test_empty_provider_cannot_complete(self) -> None: + provider = GlueAuthProvider() + assert provider.can_complete_http_basic() is False + assert provider.can_complete_oauth2_client_credentials([]) is False + assert provider.can_complete_mutualTLS() is False - # Ignore proposals with multiple mechanisms. - with pytest.raises(OpenAPIError): - provider([{"B": [], "C": []}], SECURITY_SCHEMES) + def test_username_needs_password(self) -> None: + with pytest.raises(AssertionError): + GlueAuthProvider(username="user1") - # Select Basic auth alone and from multiple. - assert isinstance(provider([{"B": []}], SECURITY_SCHEMES), MockBasicAuth) - assert isinstance(provider([{"A": []}, {"B": []}], SECURITY_SCHEMES), MockBasicAuth) + def test_can_complete_basic_auth_and_provide_credentials(self) -> None: + provider = GlueAuthProvider(username="user1", password="secret1") + assert provider.can_complete_http_basic() is True + assert asyncio.run(provider.http_basic_credentials()) == (b"user1", b"secret1") - # Select oauth2 client credentials alone and over basic auth if scopes match. - assert isinstance(provider([{"D": []}], SECURITY_SCHEMES), MockOAuth2CCAuth) - assert isinstance(provider([{"B": []}, {"D": []}], SECURITY_SCHEMES), MockOAuth2CCAuth) - assert isinstance( - provider([{"B": []}, {"D": ["read:pets"]}], SECURITY_SCHEMES), MockOAuth2CCAuth - ) - # Fall back to basic if scope does not match. - assert isinstance( - provider([{"B": []}, {"D": ["read:cattle"]}], SECURITY_SCHEMES), MockBasicAuth - ) + def test_client_id_needs_client_secret(self) -> None: + with pytest.raises(AssertionError): + GlueAuthProvider(client_id="client1") + + def test_can_complete_oauth2_client_credentials_and_provide_them(self) -> None: + provider = GlueAuthProvider(client_id="client1", client_secret="secret1") + assert provider.can_complete_oauth2_client_credentials([]) is True + assert asyncio.run(provider.oauth2_client_credentials()) == (b"client1", b"secret1") + + def test_can_complete_mutualTLS_and_provide_cert(self) -> None: + provider = GlueAuthProvider(cert="FAKECERTIFICATE") + assert provider.can_complete_mutualTLS() is True + assert provider.tls_credentials() == ("FAKECERTIFICATE", None) diff --git a/pulp-glue/tests/test_authentication.py b/pulp-glue/tests/test_authentication.py deleted file mode 100644 index 108f86927..000000000 --- a/pulp-glue/tests/test_authentication.py +++ /dev/null @@ -1,30 +0,0 @@ -import typing as t - -import pytest - -from pulp_glue.common.authentication import OAuth2ClientCredentialsAuth - -pytestmark = pytest.mark.glue - - -def test_sending_no_scope_when_empty(monkeypatch: pytest.MonkeyPatch) -> None: - class OAuth2MockResponse: - def raise_for_status(self) -> None: - return None - - def json(self) -> dict[str, t.Any]: - return {"expires_in": 1, "access_token": "aaa"} - - def _requests_post_mocked( - url: str, data: dict[str, t.Any], **kwargs: t.Any - ) -> OAuth2MockResponse: - assert "scope" not in data - return OAuth2MockResponse() - - monkeypatch.setattr("requests.post", _requests_post_mocked) - - OAuth2ClientCredentialsAuth(token_url="", client_id="", client_secret="")._retrieve_token() - - OAuth2ClientCredentialsAuth( - token_url="", client_id="", client_secret="", scopes=[] - )._retrieve_token() diff --git a/pulp-glue/tests/test_openapi.py b/pulp-glue/tests/test_openapi.py index afa871c68..47f29de5f 100644 --- a/pulp-glue/tests/test_openapi.py +++ b/pulp-glue/tests/test_openapi.py @@ -1,19 +1,69 @@ +import asyncio import datetime import json import logging import pytest +from multidict import CIMultiDict +from pulp_glue.common.authentication import AuthProviderBase, BasicAuthProvider, GlueAuthProvider from pulp_glue.common.openapi import OpenAPI, _Request, _Response pytestmark = pytest.mark.glue +SECURITY_SCHEMES = { + "A": {"type": "http", "scheme": "bearer"}, + "B": {"type": "http", "scheme": "basic"}, + "C": { + "type": "oauth2", + "flows": { + "implicit": { + "authorizationUrl": "https://example.com/api/oauth/dialog", + "scopes": { + "write:pets": "modify pets in your account", + "read:pets": "read your pets", + }, + }, + "authorizationCode": { + "authorizationUrl": "https://example.com/api/oauth/dialog", + "tokenUrl": "https://example.com/api/oauth/token", + "scopes": { + "write:pets": "modify pets in your account", + "read:pets": "read your pets", + }, + }, + }, + }, + "D": { + "type": "oauth2", + "flows": { + "implicit": { + "authorizationUrl": "https://example.com/api/oauth/dialog", + "scopes": { + "write:pets": "modify pets in your account", + "read:pets": "read your pets", + }, + }, + "clientCredentials": { + "tokenUrl": "https://example.com/api/oauth/token", + "scopes": { + "write:pets": "modify pets in your account", + "read:pets": "read your pets", + }, + }, + }, + }, + "E": {"type": "mutualTLS"}, +} TEST_SCHEMA = json.dumps( { "openapi": "3.0.3", "paths": { "test/": { - "get": {"operationId": "get_test_id", "responses": {200: {}}}, + "get": { + "operationId": "get_test_id", + "responses": {200: {}}, + }, "post": { "operationId": "post_test_id", "requestBody": { @@ -25,9 +75,11 @@ }, }, "responses": {200: {}}, + "security": [{"B": []}], }, } }, + "security": [{}], "components": { "schemas": { "testBody": { @@ -35,14 +87,28 @@ "properties": {"text": {"type": "string"}}, "required": ["text"], } - } + }, + "securitySchemes": SECURITY_SCHEMES, }, } ).encode() -def mock_send_request(request: _Request) -> _Response: - return _Response(status_code=200, headers={}, body=b"{}") +async def mock_send_request(request: _Request) -> _Response: + if request.url.endswith("oauth/token"): + assert request.method.lower() == "post" + # $ echo -n "client1:secret1" | base64 + assert request.headers["Authorization"] == "Basic Y2xpZW50MTpzZWNyZXQx" + assert isinstance(request.data, dict) + assert request.data["grant_type"] == "client_credentials" + assert set(request.data["scopes"].split(" ")) == {"write:pets", "read:pets"} + return _Response( + status_code=200, + headers={}, + body=json.dumps({"access_token": "DEADBEEF", "expires_in": 600}).encode(), + ) + else: + return _Response(status_code=200, headers={}, body=b"{}") @pytest.fixture @@ -54,13 +120,73 @@ def mock_openapi(monkeypatch: pytest.MonkeyPatch) -> OpenAPI: return openapi +@pytest.fixture +def basic_auth_provider( + monkeypatch: pytest.MonkeyPatch, + mock_openapi: OpenAPI, +) -> AuthProviderBase: + auth_provider = BasicAuthProvider(username="user1", password="password1") + monkeypatch.setattr(mock_openapi, "_auth_provider", auth_provider) + return auth_provider + + +@pytest.fixture +def oauth2_cc_auth_provider( + monkeypatch: pytest.MonkeyPatch, + mock_openapi: OpenAPI, +) -> AuthProviderBase: + auth_provider = GlueAuthProvider(client_id="client1", client_secret="secret1") + monkeypatch.setattr(mock_openapi, "_auth_provider", auth_provider) + return auth_provider + + +@pytest.fixture +def tls_auth_provider( + monkeypatch: pytest.MonkeyPatch, + mock_openapi: OpenAPI, +) -> AuthProviderBase: + auth_provider = GlueAuthProvider(cert="asdf") + monkeypatch.setattr(mock_openapi, "_auth_provider", auth_provider) + return auth_provider + + +class TestRenderRequest: + def test_request_has_no_auth( + self, + mock_openapi: OpenAPI, + basic_auth_provider: AuthProviderBase, + ) -> None: + method, path = mock_openapi.operations["get_test_id"] + path_spec = mock_openapi.api_spec["paths"][path] + request = mock_openapi._render_request(path_spec, method, "test/", {}, {}, None) + assert request.security == [{}] + + def test_request_has_security( + self, + mock_openapi: OpenAPI, + basic_auth_provider: AuthProviderBase, + ) -> None: + method, path = mock_openapi.operations["post_test_id"] + path_spec = mock_openapi.api_spec["paths"][path] + request = mock_openapi._render_request( + path_spec, method, "test/", {}, {}, {"text": "TRACE"} + ) + assert request.security == [{"B": []}] + + class TestParseResponse: - def test_returns_dict_for_no_content(self, mock_openapi: OpenAPI) -> None: + def test_returns_dict_for_no_content( + self, + mock_openapi: OpenAPI, + ) -> None: response = _Response(204, {}, b"") result = mock_openapi._parse_response({}, response) assert result == {} - def test_decodes_json(self, mock_openapi: OpenAPI) -> None: + def test_decodes_json( + self, + mock_openapi: OpenAPI, + ) -> None: response = _Response(200, {"content-type": "application/json"}, b'{"a": 1, "b": "Hallo!"}') result = mock_openapi._parse_response( {"responses": {"200": {"content": {"application/json": {}}}}}, response @@ -161,3 +287,141 @@ def test_post_operation_to_debug( ("pulp_glue.openapi", logging.DEBUG + 3, "Response: 200"), ("pulp_glue.openapi", logging.DEBUG + 1, "b'{}'"), ] + + +class TestSelectProposal: + def test_none_if_no_provider( + self, + mock_openapi: OpenAPI, + ) -> None: + request = _Request( + "", + "GET", + "http://example.org", + CIMultiDict(), + security=[{"A": []}, {"B": []}], + ) + assert mock_openapi._select_proposal(request) is None + + def test_none_if_header_provided( + self, + mock_openapi: OpenAPI, + basic_auth_provider: AuthProviderBase, + ) -> None: + request = _Request( + "", + "GET", + "http://example.org", + CIMultiDict({"Authorization": "Weird Auth"}), + security=[{"A": []}, {"B": []}], + ) + assert mock_openapi._select_proposal(request) is None + + def test_B_with_basic_auth( + self, + mock_openapi: OpenAPI, + basic_auth_provider: AuthProviderBase, + ) -> None: + request = _Request( + "", + "GET", + "http://example.org", + CIMultiDict(), + security=[{"A": []}, {"B": []}], + ) + assert mock_openapi._select_proposal(request) == {"B": []} + + def test_oauth_with_client_credentials( + self, + mock_openapi: OpenAPI, + oauth2_cc_auth_provider: AuthProviderBase, + ) -> None: + request = _Request( + "", + "GET", + "http://example.org", + CIMultiDict(), + security=[ + {"A": []}, + {"B": []}, + {"C": []}, + {"D": []}, + {"E": []}, + ], + ) + assert mock_openapi._select_proposal(request) == {"D": []} + + def test_mutual_tls_with_cert( + self, + mock_openapi: OpenAPI, + tls_auth_provider: AuthProviderBase, + ) -> None: + request = _Request( + "", + "GET", + "http://example.org", + CIMultiDict(), + security=[ + {"A": []}, + {"B": []}, + {"C": []}, + {"D": []}, + {"E": []}, + ], + ) + assert mock_openapi._select_proposal(request) == {"E": []} + + +class TestAuthenticate: + def test_basic_auth( + self, + mock_openapi: OpenAPI, + basic_auth_provider: AuthProviderBase, + ) -> None: + request = _Request("", "GET", "http://example.org", CIMultiDict()) + assert asyncio.run(mock_openapi._authenticate_request(request, {"B": []})) is False + # $ echo -n "user1:password1" | base64 + assert request.headers.get("Authorization") == "Basic dXNlcjE6cGFzc3dvcmQx" + + def test_tls( + self, + mock_openapi: OpenAPI, + tls_auth_provider: AuthProviderBase, + ) -> None: + request = _Request("", "GET", "http://example.org", CIMultiDict()) + assert asyncio.run(mock_openapi._authenticate_request(request, {"E": []})) is False + # No header, and the certificate is handled by the ssl_context. + assert "Authorization" not in request.headers + + def test_oauth2_client_credentials( + self, + mock_openapi: OpenAPI, + oauth2_cc_auth_provider: AuthProviderBase, + ) -> None: + request = _Request("", "GET", "http://example.org", CIMultiDict()) + assert asyncio.run(mock_openapi._authenticate_request(request, {"D": ["scope1"]})) is False + assert request.headers.get("Authorization") == "Bearer DEADBEEF" + + def test_oauth2_client_credentials_reuses_token( + self, + mock_openapi: OpenAPI, + oauth2_cc_auth_provider: AuthProviderBase, + ) -> None: + mock_openapi._oauth2_token = "BABACAFE" + mock_openapi._oauth2_expires = datetime.datetime.now() + datetime.timedelta(seconds=500) + + request = _Request("", "GET", "http://example.org", CIMultiDict()) + assert asyncio.run(mock_openapi._authenticate_request(request, {"D": ["scope1"]})) is True + assert request.headers.get("Authorization") == "Bearer BABACAFE" + + def test_oauth2_client_credentials_refreshes_outdated_token( + self, + mock_openapi: OpenAPI, + oauth2_cc_auth_provider: AuthProviderBase, + ) -> None: + mock_openapi._oauth2_token = "BABACAFE" + mock_openapi._oauth2_expires = datetime.datetime.now() - datetime.timedelta(seconds=500) + + request = _Request("", "GET", "http://example.org", CIMultiDict()) + assert asyncio.run(mock_openapi._authenticate_request(request, {"D": ["scope1"]})) is False + assert request.headers.get("Authorization") == "Bearer DEADBEEF" diff --git a/pulp_cli/__init__.py b/pulp_cli/__init__.py index 5c7340f0b..af011ad6d 100644 --- a/pulp_cli/__init__.py +++ b/pulp_cli/__init__.py @@ -212,8 +212,6 @@ def main( api_kwargs = dict( base_url=base_url, headers=dict((header.split(":", maxsplit=1) for header in headers)), - cert=cert, - key=key, verify_ssl=verify_ssl, refresh_cache=refresh_api, dry_run=dry_run, @@ -229,6 +227,8 @@ def main( timeout=timeout, username=username, password=password, + cert=cert, + key=key, oauth2_client_id=client_id, oauth2_client_secret=client_secret, ) diff --git a/pulp_cli/config.py b/pulp_cli/config.py index b378ddbdf..767e732d4 100644 --- a/pulp_cli/config.py +++ b/pulp_cli/config.py @@ -88,10 +88,10 @@ def headers_callback( click.option("--password", default=None, help=_("Password on pulp server")), click.option("--client-id", default=None, help=_("OAuth2 client ID")), click.option("--client-secret", default=None, help=_("OAuth2 client secret")), - click.option("--cert", default="", help=_("Path to client certificate")), + click.option("--cert", default=None, help=_("Path to client certificate")), click.option( "--key", - default="", + default=None, help=_("Path to client private key. Not required if client cert contains this."), ), click.option("--verify-ssl/--no-verify-ssl", default=True, help=_("Verify SSL connection")), @@ -186,7 +186,7 @@ def validate_config(config: dict[str, t.Any], strict: bool = False) -> None: missing_settings = ( set(SETTINGS) - set(config.keys()) - - {"plugins", "username", "password", "client_id", "client_secret"} + - {"plugins", "username", "password", "client_id", "client_secret", "cert", "key"} ) if missing_settings: errors.append(_("Missing settings: '{}'.").format("','".join(missing_settings))) diff --git a/pulp_cli/generic.py b/pulp_cli/generic.py index 8e1369fcf..30820046a 100644 --- a/pulp_cli/generic.py +++ b/pulp_cli/generic.py @@ -1,3 +1,4 @@ +import asyncio import datetime import json import re @@ -7,11 +8,10 @@ from functools import lru_cache, wraps import click -import requests import schema as s import yaml -from pulp_glue.common.authentication import OAuth2ClientCredentialsAuth +from pulp_glue.common.authentication import AuthProviderBase from pulp_glue.common.context import ( DATETIME_FORMATS, DEFAULT_LIMIT, @@ -31,7 +31,6 @@ ) from pulp_glue.common.exceptions import PulpException, PulpNoWait from pulp_glue.common.i18n import get_translation -from pulp_glue.common.openapi import AuthProviderBase try: from pygments import highlight @@ -154,6 +153,8 @@ def __init__( domain: str = "default", username: str | None = None, password: str | None = None, + cert: str | None = None, + key: str | None = None, oauth2_client_id: str | None = None, oauth2_client_secret: str | None = None, ) -> None: @@ -161,8 +162,9 @@ def __init__( self.password = password self.oauth2_client_id = oauth2_client_id self.oauth2_client_secret = oauth2_client_secret - if not api_kwargs.get("cert"): - api_kwargs["auth_provider"] = PulpCLIAuthProvider(pulp_ctx=self) + self.cert = cert + self.key = key + api_kwargs["auth_provider"] = PulpCLIAuthProvider(pulp_ctx=self) verify_ssl: bool | None = api_kwargs.pop("verify_ssl", None) super().__init__( @@ -194,115 +196,130 @@ def output_result(self, result: t.Any) -> None: click.echo(formatter(result)) -if SECRET_STORAGE: +class PulpCLIAuthProvider(AuthProviderBase): + """ + The auth provider using cli promts to ask for missing passwords. + """ - class SecretStorageBasicAuth(requests.auth.AuthBase): - def __init__(self, pulp_ctx: PulpCLIContext): - self.pulp_ctx = pulp_ctx - assert self.pulp_ctx.username is not None - self.password: str | None = None + def __init__(self, pulp_ctx: PulpCLIContext): + super().__init__() + self.pulp_ctx = pulp_ctx + self._http_basic: tuple[bytes, bytes] | None = None + self._password_in_secretstorage: bool | None = None + self._oauth2_client_credentials: tuple[bytes, bytes] | None = None + + def can_complete_http_basic(self) -> bool: + return self.pulp_ctx.username is not None + + def can_complete_oauth2_client_credentials(self, scopes: list[str]) -> bool: + return self.pulp_ctx.oauth2_client_id is not None - self.attr: dict[str, str] = { + def can_complete_mutualTLS(self) -> bool: + return self.pulp_ctx.cert is not None + + def _fetch_password(self) -> bytes: + if self.pulp_ctx.password is not None: + password = self.pulp_ctx.password.encode("latin1") + elif SECRET_STORAGE: + assert self.pulp_ctx.username is not None + secret_attr: dict[str, str] = { "service": "pulp-cli", "base_url": self.pulp_ctx.api.base_url, "api_path": self.pulp_ctx.api_path, "username": self.pulp_ctx.username, } + with closing(secretstorage.dbus_init()) as connection: + collection = secretstorage.get_default_collection(connection) + item = next(collection.search_items(secret_attr), None) + if item: + password = item.get_secret() + self._password_in_secretstorage = True + else: + password = click.prompt("Password", hide_input=True).encode("latin1") + self._password_in_secretstorage = False + else: + password = click.prompt("Password", hide_input=True).encode("latin1") + return password - def response_hook(self, response: requests.Response, **kwargs: t.Any) -> requests.Response: - # Example adapted from: - # https://docs.python-requests.org/en/latest/_modules/requests/auth/#HTTPDigestAuth - if 200 <= response.status_code < 300 and not self.password_in_manager: - if click.confirm(_("Add password to password manager?")): - assert isinstance(self.password, str) - - with closing(secretstorage.dbus_init()) as connection: - collection = secretstorage.get_default_collection(connection) - collection.create_item( - "Pulp CLI", self.attr, self.password.encode(), replace=True - ) - elif response.status_code == 401 and self.password_in_manager: - if click.confirm(_("Remove failed password from password manager?")): - with closing(secretstorage.dbus_init()) as connection: - collection = secretstorage.get_default_collection(connection) - item = next(collection.search_items(self.attr), None) - if item is not None: - item.delete() - self.password = None - return response - - def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: + async def http_basic_credentials(self) -> tuple[bytes, bytes]: + if self._http_basic is None: assert self.pulp_ctx.username is not None - if self.password is None: - with closing(secretstorage.dbus_init()) as connection: - collection = secretstorage.get_default_collection(connection) - item = next(collection.search_items(self.attr), None) - if item: - self.password = item.get_secret().decode() - self.password_in_manager = True - else: - self.password = str(click.prompt("Password", hide_input=True)) - self.password_in_manager = False - request.register_hook("response", self.response_hook) # type: ignore [no-untyped-call] - return requests.auth.HTTPBasicAuth( # type: ignore [no-any-return] - self.pulp_ctx.username, self.password - )(request) + password = await asyncio.get_running_loop().run_in_executor(None, self._fetch_password) + self._http_basic = self.pulp_ctx.username.encode("latin1"), password + return self._http_basic + + def _save_password_to_storage(self) -> None: + if click.confirm(_("Add password to password manager?")): + with closing(secretstorage.dbus_init()) as connection: + assert self.pulp_ctx.username is not None + assert self._http_basic is not None + + secret_attr: dict[str, str] = { + "service": "pulp-cli", + "base_url": self.pulp_ctx.api.base_url, + "api_path": self.pulp_ctx.api_path, + "username": self.pulp_ctx.username, + } + password = self._http_basic[1] + collection = secretstorage.get_default_collection(connection) + collection.create_item("Pulp CLI", secret_attr, password, replace=True) + + async def auth_success_hook( + self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]] + ) -> None: + if SECRET_STORAGE and self._password_in_secretstorage is False: + await asyncio.get_running_loop().run_in_executor(None, self._save_password_to_storage) + self._password_in_secretstorage = None + + def _remove_password_from_storage(self) -> None: + if click.confirm(_("Remove failed password from password manager?")): + with closing(secretstorage.dbus_init()) as connection: + assert self.pulp_ctx.username is not None + + secret_attr: dict[str, str] = { + "service": "pulp-cli", + "base_url": self.pulp_ctx.api.base_url, + "api_path": self.pulp_ctx.api_path, + "username": self.pulp_ctx.username, + } + collection = secretstorage.get_default_collection(connection) + item = next(collection.search_items(secret_attr), None) + if item is not None: + item.delete() + self.password = None + + async def auth_failure_hook( + self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]] + ) -> None: + if SECRET_STORAGE and self._password_in_secretstorage is True: + await asyncio.get_running_loop().run_in_executor( + None, self._remove_password_from_storage + ) + self._password_in_secretstorage = None + self._http_basic = None + self._oauth2_client_credentials = None + + def tls_credentials(self) -> tuple[str, str | None]: + assert self.pulp_ctx.cert is not None -class PulpCLIAuthProvider(AuthProviderBase): - def __init__(self, pulp_ctx: PulpCLIContext): - self.pulp_ctx = pulp_ctx - self._memoized: dict[str, requests.auth.AuthBase | None] = {} - - def basic_auth(self, scopes: list[str]) -> requests.auth.AuthBase | None: - if "BASIC_AUTH" not in self._memoized: - if self.pulp_ctx.username is None: - # No username -> No basic auth. - self._memoized["BASIC_AUTH"] = None - elif self.pulp_ctx.password is None: - # TODO give the user a chance to opt out. - if SECRET_STORAGE: - # We could just try to fetch the password here, - # but we want to get a grip on the response_hook. - self._memoized["BASIC_AUTH"] = SecretStorageBasicAuth(self.pulp_ctx) - else: - self._memoized["BASIC_AUTH"] = requests.auth.HTTPBasicAuth( - self.pulp_ctx.username, click.prompt("Password", hide_input=True) - ) - else: - self._memoized["BASIC_AUTH"] = requests.auth.HTTPBasicAuth( - self.pulp_ctx.username, self.pulp_ctx.password - ) - return self._memoized["BASIC_AUTH"] - - def oauth2_client_credentials_auth( - self, flow: t.Any, scopes: list[str] - ) -> requests.auth.AuthBase | None: - token_url = flow["tokenUrl"] - key = "OAUTH2_CLIENT_CREDENTIALS;" + token_url + ";" + ":".join(scopes) - if key not in self._memoized: - if self.pulp_ctx.oauth2_client_id is None: - # No client_id -> No oauth2 client credentials. - self._memoized[key] = None - elif self.pulp_ctx.oauth2_client_secret is None: - self._memoized[key] = OAuth2ClientCredentialsAuth( - client_id=self.pulp_ctx.oauth2_client_id, - client_secret=click.prompt("Client Secret"), - token_url=flow["tokenUrl"], - # Try to request all possible scopes. - scopes=flow["scopes"], - verify_ssl=self.pulp_ctx.verify_ssl, - ) - else: - self._memoized[key] = OAuth2ClientCredentialsAuth( - client_id=self.pulp_ctx.oauth2_client_id, - client_secret=self.pulp_ctx.oauth2_client_secret, - token_url=flow["tokenUrl"], - # Try to request all possible scopes. - scopes=flow["scopes"], - verify_ssl=self.pulp_ctx.verify_ssl, - ) - return self._memoized[key] + return self.pulp_ctx.cert, self.pulp_ctx.key + + def _fetch_client_secret(self) -> str: + return self.pulp_ctx.oauth2_client_secret or click.prompt("Client Secret", hide_input=True) + + async def oauth2_client_credentials(self) -> tuple[bytes, bytes]: + if self._oauth2_client_credentials is None: + assert self.pulp_ctx.oauth2_client_id is not None + + client_secret = await asyncio.get_running_loop().run_in_executor( + None, self._fetch_client_secret + ) + self._oauth2_client_credentials = ( + self.pulp_ctx.oauth2_client_id.encode("latin1"), + client_secret.encode("latin1"), + ) + return self._oauth2_client_credentials ############################################################################## diff --git a/pulpcore/cli/core/task.py b/pulpcore/cli/core/task.py index 6ece44767..98151e748 100644 --- a/pulpcore/cli/core/task.py +++ b/pulpcore/cli/core/task.py @@ -1,13 +1,17 @@ +import asyncio import re from contextlib import suppress from datetime import datetime from pathlib import Path +import aiofiles +import aiohttp import click from pulp_glue.common.context import ( DATETIME_FORMATS, PluginRequirement, + PulpContext, PulpEntityContext, ) from pulp_glue.common.exceptions import PulpException @@ -175,6 +179,20 @@ def cancel( task_ctx.cancel(task_ctx.pulp_href) +async def _download_artifacts( + pulp_ctx: PulpContext, urls: dict[str, str], profile_artifact_dir: Path +) -> None: + async with aiohttp.ClientSession() as session: + for name, url in urls.items(): + profile_artifact_path = profile_artifact_dir / name + click.echo(_("Downloading {path}").format(path=profile_artifact_path)) + async with session.get(url, ssl=pulp_ctx.api.ssl_context) as response: + assert response.status == 200 + async with aiofiles.open(profile_artifact_path, "wb") as fp: + async for chunk in response.content.iter_chunked(1024): + await fp.write(chunk) + + @task.command() @href_option @uuid_option @@ -197,13 +215,7 @@ def profile_artifact_urls( uuid = uuid_match.group("uuid") profile_artifact_dir = Path(".") / f"task_profile-{task_name}-{uuid}" profile_artifact_dir.mkdir(exist_ok=True) - with pulp_ctx.api._session as session: - for name, url in urls.items(): - profile_artifact_path = profile_artifact_dir / name - click.echo(_("Downloading {path}").format(path=profile_artifact_path)) - response = session.get(url) - response.raise_for_status() - profile_artifact_path.write_bytes(response.content) + asyncio.run(_download_artifacts(pulp_ctx, urls, profile_artifact_dir)) @task.command()