From c69eefc7d53adfc37ff0792b24b5530deac379df Mon Sep 17 00:00:00 2001 From: Maxim Mishchenko Date: Tue, 30 Sep 2025 09:08:12 +0200 Subject: [PATCH 01/16] SNOW-2117147 Added certificates revocation checking with revocation lists (CRLs) (#2518) --- DESCRIPTION.md | 1 + setup.cfg | 4 +- src/snowflake/connector/connection.py | 93 + src/snowflake/connector/crl.py | 576 +++++++ src/snowflake/connector/crl_cache.py | 643 +++++++ src/snowflake/connector/network.py | 8 + src/snowflake/connector/session_manager.py | 2 +- src/snowflake/connector/ssl_wrap_socket.py | 35 +- test/extras/run.py | 50 +- test/integ/conftest.py | 4 + test/integ/test_crl.py | 175 ++ test/unit/test_crl.py | 1497 +++++++++++++++++ test/unit/test_crl_cache.py | 620 +++++++ test/unit/test_ssl_partial_chain.py | 2 +- test/unit/test_ssl_partial_chain_handshake.py | 2 +- 15 files changed, 3688 insertions(+), 24 deletions(-) create mode 100644 src/snowflake/connector/crl.py create mode 100644 src/snowflake/connector/crl_cache.py create mode 100644 test/integ/test_crl.py create mode 100644 test/unit/test_crl.py create mode 100644 test/unit/test_crl_cache.py diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 58fc1bee3d..d5483fd816 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -8,6 +8,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne # Release Notes - v3.18.0(TBD) + - Added support for checking certificates revocation using revocation lists (CRLs) - Added the `workload_identity_impersonation_path` parameter to support service account impersonation for Workload Identity Federation on GCP and AWS workloads only - Fixed `get_results_from_sfqid` when using `DictCursor` and executing multiple statements at once - Added the `oauth_credentials_in_body` parameter supporting an option to send the oauth client credentials in the request body diff --git a/setup.cfg b/setup.cfg index 25c3a8dc91..e08154af63 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,7 +48,7 @@ install_requires = asn1crypto>0.24.0,<2.0.0 cffi>=1.9,<2.0.0 cryptography>=3.1.0 - pyOpenSSL>=22.0.0,<25.0.0 + pyOpenSSL>=22.0.0,<26.0.0 pyjwt<3.0.0 pytz requests<3.0.0 @@ -95,7 +95,7 @@ development = pytest-timeout pytest-xdist pytzdata - pytest-asyncio + responses pandas = pandas>=2.1.2,<3.0.0 pyarrow diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index c4efe25f8c..93bf8302bd 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -413,6 +413,43 @@ def _get_private_bytes_from_file( False, bool, ), + # CRL (Certificate Revocation List) configuration parameters + # The default setup is specified in CRLConfig class + "cert_revocation_check_mode": ( + None, + (type(None), str), + ), # CRL revocation check mode: DISABLED, ENABLED, ADVISORY + "allow_certificates_without_crl_url": ( + None, + (type(None), bool), + ), # Allow certificates without CRL distribution points + "crl_connection_timeout_ms": ( + None, + (type(None), int), + ), # Connection timeout for CRL downloads in milliseconds + "crl_read_timeout_ms": ( + None, + (type(None), int), + ), # Read timeout for CRL downloads in milliseconds + "crl_cache_validity_hours": ( + None, + (type(None), int), + ), # CRL cache validity time in hours + "enable_crl_cache": (None, (type(None), bool)), # Enable CRL caching + "enable_crl_file_cache": (None, (type(None), bool)), # Enable file-based CRL cache + "crl_cache_dir": (None, (type(None), str)), # Directory for CRL file cache + "crl_cache_removal_delay_days": ( + None, + (type(None), int), + ), # Days to keep expired CRL files before removal + "crl_cache_cleanup_interval_hours": ( + None, + (type(None), int), + ), # CRL cache cleanup interval in hours + "crl_cache_start_cleanup": ( + None, + (type(None), bool), + ), # Run CRL cache cleanup in the background } APPLICATION_RE = re.compile(r"[\w\d_]+") @@ -641,6 +678,62 @@ def _ocsp_mode(self) -> OCSPMode: else: return OCSPMode.FAIL_CLOSED + # CRL (Certificate Revocation List) configuration properties + @property + def cert_revocation_check_mode(self) -> str | None: + """Certificate revocation check mode: DISABLED, ENABLED, or ADVISORY.""" + return self._cert_revocation_check_mode + + @property + def allow_certificates_without_crl_url(self) -> bool | None: + """Whether to allow certificates without CRL distribution points.""" + return self._allow_certificates_without_crl_url + + @property + def crl_connection_timeout_ms(self) -> int | None: + """Connection timeout for CRL downloads in milliseconds.""" + return self._crl_connection_timeout_ms + + @property + def crl_read_timeout_ms(self) -> int | None: + """Read timeout for CRL downloads in milliseconds.""" + return self._crl_read_timeout_ms + + @property + def crl_cache_validity_hours(self) -> int | None: + """CRL cache validity time in hours.""" + return self._crl_cache_validity_hours + + @property + def enable_crl_cache(self) -> bool | None: + """Whether CRL caching is enabled.""" + return self._enable_crl_cache + + @property + def enable_crl_file_cache(self) -> bool | None: + """Whether file-based CRL cache is enabled.""" + return self._enable_crl_file_cache + + @property + def crl_cache_dir(self) -> str | None: + """Directory for CRL file cache.""" + return self._crl_cache_dir + + @property + def crl_cache_removal_delay_days(self) -> int | None: + """Days to keep expired CRL files before removal.""" + return self._crl_cache_removal_delay_days + + @property + def crl_cache_cleanup_interval_hours(self) -> int | None: + """CRL cache cleanup interval in hours.""" + return self._crl_cache_cleanup_interval_hours + + @property + def crl_cache_start_cleanup(self) -> bool | None: + """Whether to start CRL cache cleanup immediately.""" + return self._crl_cache_start_cleanup + @property def session_id(self) -> int: return self._session_id diff --git a/src/snowflake/connector/crl.py b/src/snowflake/connector/crl.py new file mode 100644 index 0000000000..69d5261d29 --- /dev/null +++ b/src/snowflake/connector/crl.py @@ -0,0 +1,576 @@ +#!/usr/bin/env python +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from enum import Enum, unique +from logging import getLogger +from pathlib import Path +from typing import Any + +from cryptography import x509 +from cryptography.hazmat._oid import ExtensionOID +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import ec, padding, rsa +from OpenSSL.SSL import Connection as SSLConnection + +from .crl_cache import CRLCacheEntry, CRLCacheManager +from .session_manager import SessionManager + +logger = getLogger(__name__) + + +@unique +class CertRevocationCheckMode(Enum): + """Certificate revocation check modes based on revocation lists (CRL) + + CRL mode descriptions: + DISABLED: No revocation check is done. + ENABLED: Revocation check is done in the strictest way. The endpoint must expose at least one fully valid + certificate chain. Any check error invalidate the chain. + ADVISORY: Revocation check is done in a more relaxed way. Only a revocated certificate can invalidate + the chain. An error is treated positively (as a successful check). + """ + + DISABLED = "DISABLED" + ENABLED = "ENABLED" + ADVISORY = "ADVISORY" + + +class CRLValidationResult(Enum): + """Certificate revocation validation result statuses""" + + REVOKED = "REVOKED" + UNREVOKED = "UNREVOKED" + ERROR = "ERROR" + + +@dataclass +class CRLConfig: + """Configuration class for CRL validation settings.""" + + cert_revocation_check_mode: CertRevocationCheckMode = ( + CertRevocationCheckMode.DISABLED + ) + allow_certificates_without_crl_url: bool = False + connection_timeout_ms: int = 3000 + read_timeout_ms: int = 3000 + cache_validity_time: timedelta = timedelta(hours=24) + enable_crl_cache: bool = True + enable_crl_file_cache: bool = True + crl_cache_dir: Path | str | None = None + crl_cache_removal_delay_days: int = 7 + crl_cache_cleanup_interval_hours: int = 1 + crl_cache_start_cleanup: bool = False + + @classmethod + def from_connection(cls, sf_connection) -> CRLConfig: + """ + Create a CRLConfig instance from a SnowflakeConnection instance. + + This method extracts CRL configuration parameters from the connection's + read-only properties and creates a CRLConfig instance. + + Args: + sf_connection: SnowflakeConnection instance containing CRL configuration + + Returns: + CRLConfig: Configured CRLConfig instance + + Raises: + ValueError: If session_manager is not available in the connection + """ + # Extract CRL-specific configuration parameters from connection properties + if sf_connection.cert_revocation_check_mode is None: + cert_revocation_check_mode = cls.cert_revocation_check_mode + elif isinstance(sf_connection.cert_revocation_check_mode, str): + try: + cert_revocation_check_mode = CertRevocationCheckMode( + sf_connection.cert_revocation_check_mode + ) + except ValueError: + logger.warning( + f"Invalid cert_revocation_check_mode: {sf_connection.cert_revocation_check_mode}, " + f"defaulting to {cls.cert_revocation_check_mode}" + ) + cert_revocation_check_mode = cls.cert_revocation_check_mode + elif isinstance( + sf_connection.cert_revocation_check_mode, CertRevocationCheckMode + ): + cert_revocation_check_mode = sf_connection.cert_revocation_check_mode + else: + logger.warning( + f"Unsupported value for cert_revocation_check_mode: {sf_connection.cert_revocation_check_mode}, " + f"defaulting to {cls.cert_revocation_check_mode}" + ) + cert_revocation_check_mode = cls.cert_revocation_check_mode + + if cert_revocation_check_mode == CertRevocationCheckMode.DISABLED: + # The rest of the parameters don't matter if CRL checking is disabled + return cls(cert_revocation_check_mode=cert_revocation_check_mode) + + # Apply default value logic for all other parameters when connection attribute is None + cache_validity_time = ( + cls.cache_validity_time + if sf_connection.crl_cache_validity_hours is None + else timedelta(hours=int(sf_connection.crl_cache_validity_hours)) + ) + crl_cache_dir = ( + cls.crl_cache_dir + if sf_connection.crl_cache_dir is None + else Path(sf_connection.crl_cache_dir) + ) + allow_certificates_without_crl_url = ( + cls.allow_certificates_without_crl_url + if sf_connection.allow_certificates_without_crl_url is None + else bool(sf_connection.allow_certificates_without_crl_url) + ) + connection_timeout_ms = ( + cls.connection_timeout_ms + if sf_connection.crl_connection_timeout_ms is None + else int(sf_connection.crl_connection_timeout_ms) + ) + read_timeout_ms = ( + cls.read_timeout_ms + if sf_connection.crl_read_timeout_ms is None + else int(sf_connection.crl_read_timeout_ms) + ) + enable_crl_cache = ( + cls.enable_crl_cache + if sf_connection.enable_crl_cache is None + else bool(sf_connection.enable_crl_cache) + ) + enable_crl_file_cache = ( + cls.enable_crl_file_cache + if sf_connection.enable_crl_file_cache is None + else bool(sf_connection.enable_crl_file_cache) + ) + crl_cache_removal_delay_days = ( + cls.crl_cache_removal_delay_days + if sf_connection.crl_cache_removal_delay_days is None + else int(sf_connection.crl_cache_removal_delay_days) + ) + crl_cache_cleanup_interval_hours = ( + cls.crl_cache_cleanup_interval_hours + if sf_connection.crl_cache_cleanup_interval_hours is None + else int(sf_connection.crl_cache_cleanup_interval_hours) + ) + crl_cache_start_cleanup = ( + cls.crl_cache_start_cleanup + if sf_connection.crl_cache_start_cleanup is None + else bool(sf_connection.crl_cache_start_cleanup) + ) + + return cls( + cert_revocation_check_mode=cert_revocation_check_mode, + allow_certificates_without_crl_url=allow_certificates_without_crl_url, + connection_timeout_ms=connection_timeout_ms, + read_timeout_ms=read_timeout_ms, + cache_validity_time=cache_validity_time, + enable_crl_cache=enable_crl_cache, + enable_crl_file_cache=enable_crl_file_cache, + crl_cache_dir=crl_cache_dir, + crl_cache_removal_delay_days=crl_cache_removal_delay_days, + crl_cache_cleanup_interval_hours=crl_cache_cleanup_interval_hours, + crl_cache_start_cleanup=crl_cache_start_cleanup, + ) + + +class CRLValidator: + def __init__( + self, + session_manager: SessionManager | Any, + cert_revocation_check_mode: CertRevocationCheckMode = CRLConfig.cert_revocation_check_mode, + allow_certificates_without_crl_url: bool = CRLConfig.allow_certificates_without_crl_url, + connection_timeout_ms: int = CRLConfig.connection_timeout_ms, + read_timeout_ms: int = CRLConfig.read_timeout_ms, + cache_validity_time: timedelta = CRLConfig.cache_validity_time, + cache_manager: CRLCacheManager | None = None, + ): + self._session_manager = session_manager + self._cert_revocation_check_mode = cert_revocation_check_mode + self._allow_certificates_without_crl_url = allow_certificates_without_crl_url + self._connection_timeout_ms = connection_timeout_ms + self._read_timeout_ms = read_timeout_ms + self._cache_validity_time = cache_validity_time + self._cache_manager = cache_manager or CRLCacheManager.noop() + + @classmethod + def from_config( + cls, config: CRLConfig, session_manager: SessionManager + ) -> CRLValidator: + """ + Create a CRLValidator instance from a CRLConfig. + + This method creates a CRLValidator and its underlying objects (except session_manager) + from configuration parameters found in the CRLConfig. + + Args: + config: CRLConfig instance containing CRL-related parameters + session_manager: SessionManager instance + + Returns: + CRLValidator: Configured CRLValidator instance + """ + # Create cache manager if caching is enabled + cache_manager = None + if config.enable_crl_cache: + from snowflake.connector.crl_cache import CRLCacheFactory + + # Create memory cache using factory + memory_cache = CRLCacheFactory.get_memory_cache(config.cache_validity_time) + + # Create file cache if enabled + if config.enable_crl_file_cache: + removal_delay = timedelta(days=config.crl_cache_removal_delay_days) + file_cache = CRLCacheFactory.get_file_cache( + cache_dir=config.crl_cache_dir, removal_delay=removal_delay + ) + else: + from snowflake.connector.crl_cache import NoopCRLCache + + file_cache = NoopCRLCache() + + # Create cache manager + cache_manager = CRLCacheManager( + memory_cache=memory_cache, + file_cache=file_cache, + ) + + # Start cleanup through factory if requested + if config.crl_cache_start_cleanup: + cleanup_interval = timedelta( + hours=config.crl_cache_cleanup_interval_hours + ) + CRLCacheFactory.start_periodic_cleanup(cleanup_interval) + else: + cache_manager = CRLCacheManager.noop() + + return cls( + session_manager=session_manager, + cert_revocation_check_mode=config.cert_revocation_check_mode, + allow_certificates_without_crl_url=config.allow_certificates_without_crl_url, + connection_timeout_ms=config.connection_timeout_ms, + read_timeout_ms=config.read_timeout_ms, + cache_validity_time=config.cache_validity_time, + cache_manager=cache_manager, + ) + + def validate_certificate_chains( + self, certificate_chains: list[list[x509.Certificate]] + ) -> bool: + """ + Validate certificate chains against CRLs with actual HTTP requests + + Args: + certificate_chains: List of certificate chains to validate + + Returns: + True if validation passes, False otherwise + + Raises: + ValueError: If certificate_chains is None or empty + """ + if self._cert_revocation_check_mode == CertRevocationCheckMode.DISABLED: + return True + + if certificate_chains is None or len(certificate_chains) == 0: + logger.warning("Certificate chains are empty") + if self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY: + return True + return False + + results = [] + for chain in certificate_chains: + result = self._validate_single_chain(chain) + # If any of the chains is valid, the whole check is considered positive + if result == CRLValidationResult.UNREVOKED: + return True + results.append(result) + + # In non-advisory mode we require at least one chain get a clear UNREVOKED status + if self._cert_revocation_check_mode != CertRevocationCheckMode.ADVISORY: + return False + + # We're in advisory mode, so any error is treated positively + return any(result == CRLValidationResult.ERROR for result in results) + + def _validate_single_chain( + self, chain: list[x509.Certificate] + ) -> CRLValidationResult: + """Validate a single certificate chain""" + # An empty chain is considered an error + if len(chain) == 0: + return CRLValidationResult.ERROR + # the last certificate of the chain is considered the root and isn't validated + results = [] + for i in range(len(chain) - 1): + result = self._validate_certificate(chain[i], chain[i + 1]) + if result == CRLValidationResult.REVOKED: + return CRLValidationResult.REVOKED + results.append(result) + + if CRLValidationResult.ERROR in results: + return CRLValidationResult.ERROR + + return CRLValidationResult.UNREVOKED + + def _validate_certificate( + self, cert: x509.Certificate, ca_cert: x509.Certificate + ) -> CRLValidationResult: + """Validate a single certificate against CRL""" + # Check if certificate is short-lived (skip CRL check) + if self._is_short_lived_certificate(cert): + return CRLValidationResult.UNREVOKED + + # Extract CRL distribution points + crl_urls = self._extract_crl_distribution_points(cert) + + if not crl_urls: + # No CRL URLs found + if self._allow_certificates_without_crl_url: + return CRLValidationResult.UNREVOKED + return CRLValidationResult.ERROR + + results: list[CRLValidationResult] = [] + # Check against each CRL URL + for crl_url in crl_urls: + result = self._check_certificate_against_crl_url(cert, ca_cert, crl_url) + if result == CRLValidationResult.REVOKED: + return result + results.append(result) + + if all(result == CRLValidationResult.ERROR for result in results): + return CRLValidationResult.ERROR + + return CRLValidationResult.UNREVOKED + + @staticmethod + def _is_short_lived_certificate(cert: x509.Certificate) -> bool: + """Check if certificate is short-lived (validity <= 5 days)""" + try: + # Use timezone.utc versions to avoid deprecation warnings + validity_period = cert.not_valid_after_utc - cert.not_valid_before_utc + except AttributeError: + # Fallback for older versions + validity_period = cert.not_valid_after - cert.not_valid_before + return validity_period.days <= 5 + + @staticmethod + def _extract_crl_distribution_points(cert: x509.Certificate) -> list[str]: + """Extract CRL distribution point URLs from certificate""" + try: + crl_dist_points = cert.extensions.get_extension_for_oid( + ExtensionOID.CRL_DISTRIBUTION_POINTS + ).value + + urls = [] + for point in crl_dist_points: + if point.full_name: + for name in point.full_name: + if isinstance(name, x509.UniformResourceIdentifier): + urls.append(name.value) + return urls + except x509.ExtensionNotFound: + return [] + + def _get_crl_from_cache(self, crl_url: str) -> CRLCacheEntry | None: + return self._cache_manager.get(crl_url) + + def _put_crl_to_cache( + self, crl_url: str, crl: x509.CertificateRevocationList, ts: datetime + ) -> None: + self._cache_manager.put(crl_url, crl, ts) + + def _fetch_crl_from_url(self, crl_url: str) -> bytes | None: + try: + logger.debug("Trying to download CRL from: %s", crl_url) + response = self._session_manager.get( + crl_url, timeout=(self._connection_timeout_ms, self._read_timeout_ms) + ) + response.raise_for_status() + return response.content + except Exception: + # CRL fetch or parsing failed + logger.exception("Failed to download CRL from %s", crl_url) + return None + + def _download_crl( + self, crl_url: str + ) -> tuple[x509.CertificateRevocationList | None, datetime | None]: + crl_bytes, now = self._fetch_crl_from_url(crl_url), datetime.now(timezone.utc) + try: + logger.debug("Trying to parse CRL from: %s", crl_url) + crl = x509.load_der_x509_crl(crl_bytes, backend=default_backend()) + # Check if CRL is expired + try: + next_update = crl.next_update_utc + except AttributeError: + next_update = crl.next_update + + if next_update and now > next_update: + logger.warning( + "The CRL from %s was expired on %s", crl_url, next_update + ) + return None, None + + return crl, now + except Exception: + logger.exception("Failed to parse CRL from %s", crl_url) + return None, None + + def _check_certificate_against_crl_url( + self, cert: x509.Certificate, ca_cert: x509.Certificate, crl_url: str + ) -> CRLValidationResult: + """Check if certificate is revoked according to CRL by the provided URL""" + now = datetime.now(timezone.utc) + logger.debug("Trying to get cached CRL for %s", crl_url) + cached_crl = self._get_crl_from_cache(crl_url) + if ( + cached_crl is None + or cached_crl.is_crl_expired_by(now) + or cached_crl.is_evicted_by(now, self._cache_validity_time) + ): + crl, ts = self._download_crl(crl_url) + if crl and ts: + self._put_crl_to_cache(crl_url, crl, ts) + else: + crl = cached_crl.crl + + # If by some reason we didn't get a valid CRL we consider it a check error + if crl is None: + return CRLValidationResult.ERROR + + # Verify CRL signature with CA public key + # Check if the CA certificate is the expected CRL issuer + if crl.issuer != ca_cert.subject: + logger.warning( + "CRL issuer (%s) does not match CA certificate subject (%s) for URL: %s", + crl.issuer, + ca_cert.subject, + crl_url, + ) + # In most cases this indicates a configuration issue, but we'll still try verification + + if not self._verify_crl_signature(crl, ca_cert): + logger.warning("CRL signature verification failed for URL: %s", crl_url) + # Always return ERROR when signature verification fails + # We cannot trust a CRL whose signature cannot be verified + return CRLValidationResult.ERROR + + # Check if certificate is revoked + return self._check_certificate_against_crl(cert, crl) + + def _verify_crl_signature( + self, crl: x509.CertificateRevocationList, ca_cert: x509.Certificate + ) -> bool: + """Verify CRL signature with CA's public key""" + try: + # Get the signature algorithm from the CRL + signature_algorithm = crl.signature_algorithm_oid + hash_algorithm = crl.signature_hash_algorithm + + logger.debug( + "Verifying CRL signature with algorithm: %s, hash: %s", + signature_algorithm, + hash_algorithm, + ) + + # Determine the appropriate padding based on the signature algorithm + public_key = ca_cert.public_key() + + # Handle different key types with appropriate signature verification + if isinstance(public_key, rsa.RSAPublicKey): + # For RSA signatures, we need to use PKCS1v15 padding + public_key.verify( + crl.signature, + crl.tbs_certlist_bytes, + padding.PKCS1v15(), + hash_algorithm, + ) + elif isinstance(public_key, ec.EllipticCurvePublicKey): + # For EC signatures, use ECDSA algorithm + public_key.verify( + crl.signature, + crl.tbs_certlist_bytes, + ec.ECDSA(hash_algorithm), + ) + else: + # For other key types (DSA, etc.), try without padding + public_key.verify( + crl.signature, + crl.tbs_certlist_bytes, + hash_algorithm, + ) + + logger.debug("CRL signature verification successful") + return True + except Exception as e: + logger.warning("CRL signature verification failed: %s", e) + return False + + def _check_certificate_against_crl( + self, cert: x509.Certificate, crl: x509.CertificateRevocationList + ) -> CRLValidationResult: + """Check if certificate is revoked according to CRL""" + revoked_cert = crl.get_revoked_certificate_by_serial_number(cert.serial_number) + return ( + CRLValidationResult.REVOKED + if revoked_cert + else CRLValidationResult.UNREVOKED + ) + + def validate_connection(self, connection: SSLConnection) -> bool: + """ + Validate an OpenSSL connection against CRLs. + + This method extracts certificate chains from the connection and validates them + against Certificate Revocation Lists (CRLs). + + Args: + connection: OpenSSL connection object + + Returns: + True if validation passes, False otherwise + """ + certificate_chains = self._extract_certificate_chains_from_connection( + connection + ) + return self.validate_certificate_chains(certificate_chains) + + def _extract_certificate_chains_from_connection( + self, connection + ) -> list[list[x509.Certificate]]: + """Extract certificate chains from OpenSSL connection for CRL validation. + + Args: + connection: OpenSSL connection object + + Returns: + List of certificate chains, where each chain is a list of x509.Certificate objects + """ + from OpenSSL.crypto import FILETYPE_ASN1, dump_certificate + + try: + cert_chain = connection.get_peer_cert_chain() + if not cert_chain: + logger.debug("No certificate chain found in connection") + return [] + + # Convert OpenSSL certificates to cryptography x509 certificates + x509_chain = [] + for cert_openssl in cert_chain: + cert_der = dump_certificate(FILETYPE_ASN1, cert_openssl) + cert_x509 = x509.load_der_x509_certificate(cert_der, default_backend()) + x509_chain.append(cert_x509) + + logger.debug( + "Extracted %d certificates for CRL validation", len(x509_chain) + ) + return [x509_chain] # Return as a single chain + + except Exception as e: + logger.warning( + "Failed to extract certificate chain for CRL validation: %s", e + ) + return [] diff --git a/src/snowflake/connector/crl_cache.py b/src/snowflake/connector/crl_cache.py new file mode 100644 index 0000000000..73f11cea9d --- /dev/null +++ b/src/snowflake/connector/crl_cache.py @@ -0,0 +1,643 @@ +#!/usr/bin/env python +from __future__ import annotations + +import atexit +import hashlib +import logging +import os +import platform +import threading +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from pathlib import Path + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from filelock import BaseFileLock, FileLock + +logger = logging.getLogger(__name__) + + +@dataclass +class CRLCacheEntry: + """Cache entry containing a CRL and its download timestamp.""" + + crl: x509.CertificateRevocationList + download_time: datetime + + def _next_update(self) -> datetime | None: + """A compatibility wrapper around crl.next_update.""" + return getattr(self.crl, "next_update_utc", None) or getattr( + self.crl, "next_update", None + ) + + def is_crl_expired_by(self, ts: datetime) -> bool: + """ + Check if the CRL has expired. + + Args: + ts: Time to check against + + Returns: + True if the CRL has expired, False otherwise + """ + next_update = self._next_update() + return next_update is not None and next_update < ts + + def is_evicted_by(self, ts: datetime, cache_validity_time: timedelta) -> bool: + """ + Check if the cache entry should be evicted based on cache validity time. + + Args: + ts: Current time to check against + cache_validity_time: How long cache entries remain valid + + Returns: + True if the entry should be evicted, False otherwise + """ + expiry_time = self.download_time + cache_validity_time + return expiry_time < ts + + +class CRLCache(ABC): + """ + Abstract base class for CRL caches. + """ + + @abstractmethod + def get(self, crl_url: str) -> CRLCacheEntry | None: + """ + Get a CRL cache entry by URL. + + Args: + crl_url: The CRL URL + + Returns: + The cache entry if found, None otherwise + """ + raise NotImplementedError() + + @abstractmethod + def put(self, crl_url: str, entry: CRLCacheEntry) -> None: + """ + Store a CRL cache entry. + + Args: + crl_url: The CRL URL + entry: The cache entry to store + """ + raise NotImplementedError() + + @abstractmethod + def cleanup(self) -> None: + """Remove expired and evicted entries from the cache.""" + raise NotImplementedError() + + +class NoopCRLCache(CRLCache): + """ + No-operation CRL cache that doesn't store anything. + """ + + # Singleton instance + INSTANCE = None + + def __new__(cls): + if cls.INSTANCE is None: + cls.INSTANCE = super().__new__(cls) + return cls.INSTANCE + + def get(self, crl_url: str) -> CRLCacheEntry | None: + """Always returns None.""" + return None + + def put(self, crl_url: str, entry: CRLCacheEntry) -> None: + """Does nothing.""" + pass + + def cleanup(self) -> None: + """Does nothing.""" + pass + + +class CRLInMemoryCache(CRLCache): + """ + In-memory CRL cache using a thread-safe dictionary. + """ + + def __init__(self, cache_validity_time: timedelta): + """ + Initialize the in-memory cache. + + Args: + cache_validity_time: How long cache entries remain valid + """ + self._cache: dict[str, CRLCacheEntry] = {} + self._cache_validity_time = cache_validity_time + self._lock = threading.RLock() + + def get(self, crl_url: str) -> CRLCacheEntry | None: + """ + Get a CRL cache entry from memory. + + Args: + crl_url: The CRL URL + + Returns: + The cache entry if found, None otherwise + """ + with self._lock: + entry = self._cache.get(crl_url) + if entry is not None: + logger.debug(f"Found CRL in memory cache for {crl_url}") + return entry + + def put(self, crl_url: str, entry: CRLCacheEntry) -> None: + """ + Store a CRL cache entry in memory. + + Args: + crl_url: The CRL URL + entry: The cache entry to store + """ + with self._lock: + self._cache[crl_url] = entry + + def cleanup(self) -> None: + """Remove expired and evicted entries from memory cache.""" + now = datetime.now(timezone.utc) + logger.debug(f"Cleaning up in-memory CRL cache at {now}") + + with self._lock: + urls_to_remove = [] + + for url, entry in self._cache.items(): + expired = entry.is_crl_expired_by(now) + evicted = entry.is_evicted_by(now, self._cache_validity_time) + + if expired or evicted: + logger.debug( + f"Removing in-memory CRL cache entry for {url}: " + f"expired={expired}, evicted={evicted}" + ) + urls_to_remove.append(url) + + for url in urls_to_remove: + del self._cache[url] + + removed_count = len(urls_to_remove) + if removed_count > 0: + logger.debug( + f"Removed {removed_count} expired/evicted entries from in-memory CRL cache" + ) + + +class CRLFileCache(CRLCache): + """ + File-based CRL cache that persists CRLs to disk. + """ + + def __init__( + self, cache_dir: Path | None = None, removal_delay: timedelta | None = None + ): + """ + Initialize the file cache. + + Args: + cache_dir: Directory to store cached CRLs + removal_delay: How long to wait before removing expired files + + Raises: + OSError: If cache directory cannot be created + """ + self._cache_file_lock_timeout = 5.0 + self._cache_dir = cache_dir or _get_default_crl_cache_path() + self._removal_delay = removal_delay or timedelta(days=7) + + self._ensure_cache_directory_exists() + + def _ensure_cache_directory_exists(self) -> None: + """Create the cache directory if it doesn't exist.""" + try: + self._cache_dir.mkdir(parents=True, exist_ok=True) + logger.debug(f"Cache directory created/verified: {self._cache_dir}") + except OSError as e: + raise OSError(f"Failed to create cache directory {self._cache_dir}: {e}") + + def _get_crl_file_path(self, crl_url: str) -> Path: + """ + Generate a file path for the given CRL URL. + + Args: + crl_url: The CRL URL + + Returns: + Path to the cache file + """ + # Create a safe filename from the URL using a hash + url_hash = hashlib.sha256(crl_url.encode()).hexdigest() + return self._cache_dir / f"crl_{url_hash}.der" + + def _get_crl_file_lock(self, crl_cache_file: Path) -> BaseFileLock: + """Return a lock instance for the given CRL cache file""" + return FileLock( + crl_cache_file.with_suffix(".lock"), + thread_local=True, + timeout=self._cache_file_lock_timeout, + ) + + def get(self, crl_url: str) -> CRLCacheEntry | None: + """ + Get a CRL cache entry from disk. + + Args: + crl_url: The CRL URL + + Returns: + The cache entry if found, None otherwise + """ + crl_file_path = self._get_crl_file_path(crl_url) + with self._get_crl_file_lock(crl_file_path): + try: + if crl_file_path.exists(): + logger.debug(f"Found CRL on disk for {crl_file_path}") + + # Get file modification time as download time + stat_info = crl_file_path.stat() + download_time = datetime.fromtimestamp( + stat_info.st_mtime, tz=timezone.utc + ) + + # Read and parse the CRL + with open(crl_file_path, "rb") as f: + crl_data = f.read() + + crl = x509.load_der_x509_crl(crl_data, backend=default_backend()) + return CRLCacheEntry(crl, download_time) + + except Exception as e: + logger.warning(f"Failed to read CRL from disk cache for {crl_url}: {e}") + + return None + + def put(self, crl_url: str, entry: CRLCacheEntry) -> None: + """ + Store a CRL cache entry to disk. + + Args: + crl_url: The CRL URL + entry: The cache entry to store + """ + crl_file_path = self._get_crl_file_path(crl_url) + with self._get_crl_file_lock(crl_file_path): + try: + # Serialize the CRL to DER format + crl_data = entry.crl.public_bytes(serialization.Encoding.DER) + + # Write to file + with open(crl_file_path, "wb") as f: + f.write(crl_data) + + # Set file modification time to download time + download_timestamp = entry.download_time.timestamp() + os.utime(crl_file_path, (download_timestamp, download_timestamp)) + + logger.debug(f"Stored CRL to disk cache: {crl_file_path}") + + except Exception as e: + logger.warning(f"Failed to write CRL to disk cache for {crl_url}: {e}") + + def _is_cached_crl_file_for_removal( + self, crl_cache_file: Path, ts: datetime + ) -> bool: + """Check if the given CRL cache file is by its lifetime.""" + try: + # Get file modification time + stat_info = crl_cache_file.stat() + download_time = datetime.fromtimestamp(stat_info.st_mtime, tz=timezone.utc) + + # Check if file should be removed based on removal delay + removal_time = download_time + self._removal_delay + return ts > removal_time + except Exception as e: + logger.warning(f"Error processing cache file {crl_cache_file}: {e}") + return False + + def cleanup(self) -> None: + """Remove expired files from disk cache.""" + now = datetime.now(timezone.utc) + logger.debug(f"Cleaning up file-based CRL cache at {now}") + + removed_count = 0 + try: + for crl_file in self._cache_dir.glob("crl_*.der"): + # double-checked locking + if self._is_cached_crl_file_for_removal(crl_file, now): + with self._get_crl_file_lock(crl_file): + if self._is_cached_crl_file_for_removal(crl_file, now): + crl_file.unlink(missing_ok=True) + removed_count += 1 + logger.debug(f"Removed expired file: {crl_file}") + except Exception as e: + logger.error(f"Error during file cache cleanup: {e}") + + +class CRLCacheManager: + """ + Cache manager that coordinates between in-memory and file-based CRL caches. + """ + + def __init__( + self, + memory_cache: CRLCache, + file_cache: CRLCache, + ): + """ + Initialize the cache manager. + + Args: + memory_cache: In-memory cache implementation + file_cache: File-based cache implementation + """ + self._memory_cache = memory_cache + self._file_cache = file_cache + + @classmethod + def noop(cls) -> CRLCacheManager: + """Create noop cache manager.""" + return cls(NoopCRLCache(), NoopCRLCache()) + + def get(self, crl_url: str) -> CRLCacheEntry | None: + """ + Get a CRL cache entry, checking memory cache first, then file cache. + + Args: + crl_url: The CRL URL + + Returns: + The cache entry if found, None otherwise + """ + # Check memory cache first + entry = self._memory_cache.get(crl_url) + if entry is not None: + return entry + + # Check file cache + entry = self._file_cache.get(crl_url) + if entry is not None: + # Promote to memory cache + self._memory_cache.put(crl_url, entry) + return entry + + logger.debug(f"CRL not found in cache for {crl_url}") + return None + + def put( + self, crl_url: str, crl: x509.CertificateRevocationList, download_time: datetime + ) -> None: + """ + Store a CRL in both memory and file caches. + + Args: + crl_url: The CRL URL + crl: The CRL to store + download_time: When the CRL was downloaded + """ + entry = CRLCacheEntry(crl, download_time) + self._memory_cache.put(crl_url, entry) + self._file_cache.put(crl_url, entry) + + +class CRLCacheFactory: + """ + Factory class for creating singleton instances of CRL caches. + + This factory ensures that only one instance of each cache type exists, + providing warnings when attempting to create instances with different parameters. + Also manages background cleanup of existing cache instances. + """ + + # Singleton instances + _memory_cache_instance = None + _file_cache_instance = None + _instance_lock = threading.RLock() + + # Cleanup management + _cleanup_executor: ThreadPoolExecutor | None = None + _cleanup_shutdown: threading.Event = threading.Event() + _cleanup_interval: timedelta | None = None + _atexit_registered: bool = False + + @classmethod + def get_memory_cache(cls, cache_validity_time: timedelta) -> CRLInMemoryCache: + """ + Get or create a singleton CRLInMemoryCache instance. + + Args: + cache_validity_time: How long cache entries remain valid + + Returns: + The singleton CRLInMemoryCache instance + """ + with cls._instance_lock: + if cls._memory_cache_instance is None: + cls._memory_cache_instance = CRLInMemoryCache(cache_validity_time) + elif cls._memory_cache_instance._cache_validity_time != cache_validity_time: + logger.warning( + f"CRLs in-memory cache has already been initialized with cache validity time of {cls._memory_cache_instance._cache_validity_time}, " + f"ignoring new cache validity time of {cache_validity_time}" + ) + return cls._memory_cache_instance + + @classmethod + def get_file_cache( + cls, cache_dir: Path | None = None, removal_delay: timedelta | None = None + ) -> CRLFileCache: + """ + Get or create a singleton CRLFileCache instance. + + Args: + cache_dir: Directory to store cached CRLs + removal_delay: How long to wait before removing expired files + + Returns: + The singleton CRLFileCache instance + """ + with cls._instance_lock: + if cls._file_cache_instance is None: + cls._file_cache_instance = CRLFileCache(cache_dir, removal_delay) + else: + # Check if parameters differ from existing instance + existing_cache_dir = cls._file_cache_instance._cache_dir + existing_removal_delay = cls._file_cache_instance._removal_delay + requested_cache_dir = cache_dir or _get_default_crl_cache_path() + requested_removal_delay = removal_delay or timedelta(days=7) + + if existing_cache_dir != requested_cache_dir: + logger.warning( + f"CRLs file cache has already been initialized with cache directory '{existing_cache_dir}', " + f"ignoring new cache directory '{requested_cache_dir}'" + ) + if existing_removal_delay != requested_removal_delay: + logger.warning( + f"CRLs file cache has already been initialized with removal delay of {existing_removal_delay}, " + f"ignoring new removal delay of {requested_removal_delay}" + ) + return cls._file_cache_instance + + @classmethod + def start_periodic_cleanup(cls, cleanup_interval: timedelta) -> None: + """ + Start the periodic cleanup task for existing cache instances. + + Args: + cleanup_interval: How often to run cleanup tasks + """ + with cls._instance_lock: + if cls.is_periodic_cleanup_running(): + logger.debug( + "Periodic cleanup already running, so it will first be stopped before restarting." + ) + cls.stop_periodic_cleanup() + + cls._cleanup_interval = cleanup_interval + cls._cleanup_executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="crl-cache-cleanup" + ) + + # Register atexit handler for graceful shutdown (only once) + if not cls._atexit_registered: + atexit.register(cls._atexit_cleanup_handler) + cls._atexit_registered = True + + # Submit the cleanup task + cls._cleanup_executor.submit(cls._cleanup_loop) + + logger.debug( + f"Scheduled CRL cache cleanup task to run every {cleanup_interval.total_seconds()} seconds." + ) + + @classmethod + def stop_periodic_cleanup(cls) -> None: + """Stop the periodic cleanup task.""" + executor_to_shutdown = None + + with cls._instance_lock: + if cls._cleanup_executor is None or cls._cleanup_shutdown.is_set(): + return + + cls._cleanup_shutdown.set() + executor_to_shutdown = cls._cleanup_executor + + # Shutdown outside of lock to avoid deadlock + if executor_to_shutdown is not None: + executor_to_shutdown.shutdown(wait=True) + + with cls._instance_lock: + cls._cleanup_shutdown.clear() + cls._cleanup_executor = None + cls._cleanup_interval = None + + @classmethod + def is_periodic_cleanup_running(cls) -> bool: + """Check if periodic cleanup task is running.""" + with cls._instance_lock: + return cls._cleanup_executor is not None + + @classmethod + def _cleanup_loop(cls) -> None: + """Main cleanup loop that runs periodically.""" + while not cls._cleanup_shutdown.is_set(): + if cls._cleanup_interval is None: + break + + logger.debug( + f"Running periodic CRL cache cleanup with interval {cls._cleanup_interval.total_seconds()} seconds" + ) + + # Clean memory cache only if it exists + if cls._memory_cache_instance is not None: + try: + cls._memory_cache_instance.cleanup() + except Exception as e: + logger.error( + f"An error occurred during scheduled CRL memory cache cleanup: {e}" + ) + + # Clean file cache only if it exists + if cls._file_cache_instance is not None: + try: + cls._file_cache_instance.cleanup() + except Exception as e: + logger.error( + f"An error occurred during scheduled CRL disk cache cleanup: {e}" + ) + + shutdown = cls._cleanup_shutdown.wait( + timeout=cls._cleanup_interval.total_seconds() + ) + if shutdown: + logger.debug( + "CRL cache cleanup stopped gracefully by a shutdown event." + ) + break + + @classmethod + def _atexit_cleanup_handler(cls) -> None: + """ + Atexit handler to ensure graceful shutdown of periodic cleanup on program exit. + """ + try: + cls.stop_periodic_cleanup() + logger.debug("CRL cache cleanup stopped gracefully on program exit.") + except Exception as e: + # Don't raise exceptions in atexit handlers + logger.error(f"Error stopping CRL cache cleanup on program exit: {e}") + + @classmethod + def reset(cls) -> None: + """ + Reset the factory, clearing all singleton instances and stopping cleanup. + This is primarily useful for testing purposes. + """ + with cls._instance_lock: + cls.stop_periodic_cleanup() + cls._memory_cache_instance = None + cls._file_cache_instance = None + cls._atexit_registered = False + + +def _get_windows_home_path() -> Path: + try: + return Path.home() + except RuntimeError: + pass + if "USERPROFILE" in os.environ: + return Path(os.environ["USERPROFILE"]) + if "HOMEDRIVE" in os.environ and "HOMEPATH" in os.environ: + return Path(os.environ["HOMEDRIVE"]) / os.environ["HOMEPATH"] + if "LOCALAPPDATA" in os.environ: + return Path(os.environ["LOCALAPPDATA"]).parent.parent + if "APPDATA" in os.environ: + return Path(os.environ["APPDATA"]).parent.parent + return Path("~") + + +def _get_default_crl_cache_path() -> Path: + """Return the default path to persist cached CRLs.""" + if platform.system() == "Windows": + return ( + _get_windows_home_path() + / "AppData" + / "Local" + / "Snowflake" + / "Caches" + / "crls" + ) + elif platform.system() == "Darwin": + return Path.home() / "Library" / "Caches" / "Snowflake" / "crls" + else: + return Path.home() / ".cache" / "Snowflake" / "crls" diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index ae34375a42..36224174ba 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -42,6 +42,7 @@ HTTP_HEADER_SERVICE_NAME, HTTP_HEADER_USER_AGENT, ) +from .crl import CRLConfig from .description import ( CLIENT_NAME, CLIENT_VERSION, @@ -338,6 +339,13 @@ def __init__( self._connection._ocsp_response_cache_filename if self._connection else None ) + # CRL mode (should be DISABLED by default) + ssl_wrap_socket.FEATURE_CRL_CONFIG = ( + CRLConfig.from_connection(self._connection) + if self._connection + else ssl_wrap_socket.DEFAULT_CRL_CONFIG + ) + # This is to address the issue where requests hangs _ = "dummy".encode("idna").decode("utf-8") diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py index a10a89d502..63986e2235 100644 --- a/src/snowflake/connector/session_manager.py +++ b/src/snowflake/connector/session_manager.py @@ -266,7 +266,7 @@ def get( url: str, *, headers: Mapping[str, str] | None = None, - timeout: int | None = 3, + timeout: int | tuple[int, int] | None = 3, use_pooling: bool | None = None, **kwargs, ): diff --git a/src/snowflake/connector/ssl_wrap_socket.py b/src/snowflake/connector/ssl_wrap_socket.py index f1a14e5c89..3c2f92ba80 100644 --- a/src/snowflake/connector/ssl_wrap_socket.py +++ b/src/snowflake/connector/ssl_wrap_socket.py @@ -22,6 +22,7 @@ import OpenSSL.SSL from .constants import OCSPMode +from .crl import CertRevocationCheckMode, CRLConfig, CRLValidator from .errorcode import ER_OCSP_RESPONSE_CERT_STATUS_REVOKED from .errors import OperationalError from .session_manager import SessionManager @@ -31,6 +32,8 @@ DEFAULT_OCSP_MODE: OCSPMode = OCSPMode.FAIL_OPEN FEATURE_OCSP_MODE: OCSPMode = DEFAULT_OCSP_MODE +DEFAULT_CRL_CONFIG: CRLConfig = CRLConfig() +FEATURE_CRL_CONFIG: CRLConfig = DEFAULT_CRL_CONFIG """ OCSP Response cache file name @@ -141,11 +144,13 @@ def reset_current_session_manager(token) -> None: def inject_into_urllib3() -> None: """Monkey-patch urllib3 with PyOpenSSL-backed SSL-support and OCSP.""" log.debug("Injecting ssl_wrap_socket_with_ocsp") - connection_.ssl_wrap_socket = ssl_wrap_socket_with_ocsp + connection_.ssl_wrap_socket = ssl_wrap_socket_with_cert_revocation_checks @wraps(ssl_.ssl_wrap_socket) -def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket: +def ssl_wrap_socket_with_cert_revocation_checks( + *args: Any, **kwargs: Any +) -> WrappedSocket: # Bind passed args/kwargs to the underlying signature to support both positional and keyword calls bound = _sig(ssl_.ssl_wrap_socket).bind_partial(*args, **kwargs) params = bound.arguments @@ -167,6 +172,32 @@ def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket: ret = ssl_.ssl_wrap_socket(**params) + log.debug( + "CRL Check Mode: %s", + FEATURE_CRL_CONFIG.cert_revocation_check_mode.name, + ) + if ( + FEATURE_CRL_CONFIG.cert_revocation_check_mode + != CertRevocationCheckMode.DISABLED + ): + crl_validator = CRLValidator.from_config( + FEATURE_CRL_CONFIG, get_current_session_manager() + ) + if not crl_validator.validate_connection(ret.connection): + raise OperationalError( + msg=( + "The certificate is revoked or " + "could not be validated via CRL: hostname={}".format( + server_hostname + ) + ), + errno=ER_OCSP_RESPONSE_CERT_STATUS_REVOKED, + ) + log.debug( + "The certificate revocation check was successful. No additional checks will be performed." + ) + return ret + log.debug( "OCSP Mode: %s, OCSP response cache file name: %s", FEATURE_OCSP_MODE.name, diff --git a/test/extras/run.py b/test/extras/run.py index e29bfecc75..d9f53a33a7 100644 --- a/test/extras/run.py +++ b/test/extras/run.py @@ -30,20 +30,36 @@ # This is to test SNOW-79940, making sure tmp files are removed # Windows does not have ocsp_response_validation_cache.lock assert ( - cache_files - == { - "ocsp_response_validation_cache.json.lock", - "ocsp_response_validation_cache.json", - "ocsp_response_cache.json", - } - and not platform.system() == "Windows" - ) or ( - cache_files - == { - "ocsp_response_validation_cache.json", - "ocsp_response_cache.json", - } - and platform.system() == "Windows" - ), str( - cache_files - ) + ( + cache_files.issubset( + { + "ocsp_response_validation_cache.json.lock", + "ocsp_response_validation_cache.json", + "ocsp_response_cache.json", + "crls", + } + ) + and platform.system() == "Linux" + ) + or ( + cache_files.issubset( + { + "ocsp_response_validation_cache.json", + "ocsp_response_cache.json", + "crls", + } + ) + and platform.system() == "Windows" + ) + or ( + cache_files.issubset( + { + "ocsp_response_validation_cache.json.lock", + "ocsp_response_validation_cache.json", + "ocsp_response_cache.json", + "crls", + } + ) + and platform.system() == "Darwin" + ) + ), str(cache_files) diff --git a/test/integ/conftest.py b/test/integ/conftest.py index 4f41f3638e..077ea3d6ca 100644 --- a/test/integ/conftest.py +++ b/test/integ/conftest.py @@ -313,6 +313,10 @@ def init_test_schema(db_parameters) -> Generator[None]: "private_key_file": db_parameters["private_key_file"], } ) + if "private_key_file_pwd" in db_parameters: + connection_params["private_key_file_pwd"] = db_parameters[ + "private_key_file_pwd" + ] # Role may be needed when running on preprod, but is not present on Jenkins jobs optional_role = db_parameters.get("role") diff --git a/test/integ/test_crl.py b/test/integ/test_crl.py new file mode 100644 index 0000000000..17c678463c --- /dev/null +++ b/test/integ/test_crl.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python +""" +CRL (Certificate Revocation List) Validation Integration Tests + +These tests verify that CRL validation works correctly with real Snowflake connections +in different modes: DISABLED, ADVISORY, and ENABLED. +""" +from __future__ import annotations + +import tempfile + +import pytest + + +@pytest.mark.skipolddriver +def test_crl_validation_enabled_mode(conn_cnx): + """Test that connection works with CRL validation in ENABLED mode.""" + # ENABLED mode should work for normal Snowflake connections since they typically + # have valid certificates with proper CRL distribution points + with conn_cnx( + cert_revocation_check_mode="ENABLED", + allow_certificates_without_crl_url=True, # Allow certs without CRL URLs + crl_connection_timeout_ms=5000, # 5 second timeout + crl_read_timeout_ms=5000, # 5 second timeout + disable_ocsp_checks=True, + ) as cnx: + assert cnx, "Connection should succeed with CRL validation in ENABLED mode" + + # Verify we can execute a simple query + cur = cnx.cursor() + cur.execute("SELECT 1") + result = cur.fetchone() + assert result[0] == 1, "Query should execute successfully" + cur.close() + + # Verify CRL settings were applied + assert cnx.cert_revocation_check_mode == "ENABLED" + assert cnx.allow_certificates_without_crl_url is True + + +@pytest.mark.skipolddriver +def test_crl_validation_advisory_mode(conn_cnx): + """Test that connection works with CRL validation in ADVISORY mode.""" + # ADVISORY mode should be more lenient and allow connections even if CRL checks fail + with conn_cnx( + cert_revocation_check_mode="ADVISORY", + allow_certificates_without_crl_url=False, # Don't allow certs without CRL URLs + crl_connection_timeout_ms=3000, # 3 second timeout + crl_read_timeout_ms=3000, # 3 second timeout + enable_crl_cache=True, # Enable caching + crl_cache_validity_hours=1, # Cache for 1 hour + ) as cnx: + assert cnx, "Connection should succeed with CRL validation in ADVISORY mode" + + # Verify we can execute a simple query + cur = cnx.cursor() + cur.execute("SELECT CURRENT_VERSION()") + result = cur.fetchone() + assert result[0], "Query should return a version string" + cur.close() + + # Verify CRL settings were applied + assert cnx.cert_revocation_check_mode == "ADVISORY" + assert cnx.allow_certificates_without_crl_url is False + assert cnx.enable_crl_cache is True + + +@pytest.mark.skipolddriver +def test_crl_validation_disabled_mode(conn_cnx): + """Test that connection works with CRL validation in DISABLED mode (default).""" + # DISABLED mode should work without any CRL checks + with conn_cnx( + cert_revocation_check_mode="DISABLED", + ) as cnx: + assert cnx, "Connection should succeed with CRL validation in DISABLED mode" + + # Verify we can execute a simple query + cur = cnx.cursor() + cur.execute("SELECT 'CRL_DISABLED' as test_value") + result = cur.fetchone() + assert result[0] == "CRL_DISABLED", "Query should execute successfully" + cur.close() + + # Verify CRL settings were applied + assert cnx.cert_revocation_check_mode == "DISABLED" + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "crl_mode,allow_without_crl,should_succeed", + [ + ("DISABLED", True, True), # DISABLED mode always succeeds + ("DISABLED", False, True), # DISABLED mode always succeeds + ("ADVISORY", True, True), # ADVISORY mode is lenient + ("ADVISORY", False, True), # ADVISORY mode is lenient + ("ENABLED", True, True), # ENABLED with allow_without_crl should succeed + ("ENABLED", False, True), # ENABLED might succeed if certs have valid CRL URLs + ], +) +def test_crl_validation_modes_parametrized( + conn_cnx, crl_mode, allow_without_crl, should_succeed +): + """Parametrized test for different CRL validation modes and settings.""" + try: + with conn_cnx( + cert_revocation_check_mode=crl_mode, + allow_certificates_without_crl_url=allow_without_crl, + crl_connection_timeout_ms=5000, + crl_read_timeout_ms=5000, + ) as cnx: + if should_succeed: + assert ( + cnx + ), f"Connection should succeed with mode={crl_mode}, allow_without_crl={allow_without_crl}" + + # Test basic functionality + cur = cnx.cursor() + cur.execute("SELECT 1") + result = cur.fetchone() + assert result[0] == 1, "Basic query should work" + cur.close() + + # Verify settings + assert cnx.cert_revocation_check_mode == crl_mode + assert cnx.allow_certificates_without_crl_url == allow_without_crl + else: + pytest.fail( + f"Connection should have failed with mode={crl_mode}, allow_without_crl={allow_without_crl}" + ) + + except Exception as e: + if should_succeed: + pytest.fail( + f"Connection unexpectedly failed with mode={crl_mode}, allow_without_crl={allow_without_crl}: {e}" + ) + else: + # Expected failure - verify it's a connection-related error + assert ( + "revoked" in str(e).lower() or "crl" in str(e).lower() + ), f"Expected CRL-related error, got: {e}" + + +@pytest.mark.skipolddriver +def test_crl_cache_configuration(conn_cnx): + """Test CRL cache configuration options.""" + with tempfile.TemporaryDirectory() as temp_dir: + with conn_cnx( + cert_revocation_check_mode="ADVISORY", # Use advisory to avoid strict failures + enable_crl_cache=True, + enable_crl_file_cache=True, + crl_cache_dir=temp_dir, + crl_cache_validity_hours=2, + crl_cache_removal_delay_days=1, + crl_cache_cleanup_interval_hours=1, + crl_cache_start_cleanup=False, # Don't start background cleanup in tests + ) as cnx: + assert cnx, "Connection should succeed with CRL cache configuration" + + # Verify cache settings were applied + assert cnx.enable_crl_cache is True + assert cnx.enable_crl_file_cache is True + assert cnx.crl_cache_dir == temp_dir + assert cnx.crl_cache_validity_hours == 2 + assert cnx.crl_cache_removal_delay_days == 1 + assert cnx.crl_cache_cleanup_interval_hours == 1 + assert cnx.crl_cache_start_cleanup is False + + # Test basic functionality + cur = cnx.cursor() + cur.execute("SELECT 'cache_test' as result") + result = cur.fetchone() + assert ( + result[0] == "cache_test" + ), "Query should work with cache configuration" + cur.close() diff --git a/test/unit/test_crl.py b/test/unit/test_crl.py new file mode 100644 index 0000000000..303b700843 --- /dev/null +++ b/test/unit/test_crl.py @@ -0,0 +1,1497 @@ +#!/usr/bin/env python +from __future__ import annotations + +import logging +import random +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any +from unittest.mock import Mock +from unittest.mock import patch as mock_patch + +import pytest +import responses +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID + +from snowflake.connector.crl import ( + CertRevocationCheckMode, + CRLConfig, + CRLValidationResult, + CRLValidator, +) +from snowflake.connector.crl_cache import CRLCacheEntry, CRLCacheManager +from snowflake.connector.session_manager import SessionManager + + +@pytest.fixture +def session_manager() -> SessionManager | Any: + """For testing purposes we mock SessionManager instances with `requests` module + to use `responses` module for mocking HTTP responses. + """ + import requests + + return requests + + +@pytest.fixture(scope="module") +def crl_urls(): + @dataclass + class CRLUrls: + _base_url = "http://localhost:43210" + primary_ca = _base_url + "/primary-ca.crl" + backup_ca = _base_url + "/backup-ca.crl" + test_ca = _base_url + "/test-ca.crl" + invalid_ca = _base_url + "/invalid-ca.crl" + valid_ca = _base_url + "/valid-ca.crl" + expired_ca = _base_url + "/expired-ca.crl" + + return CRLUrls() + + +@dataclass +class CertificateChain: + """Container for certificate chain components""" + + root_cert: x509.Certificate + intermediate_cert: x509.Certificate + leaf_cert: x509.Certificate + + +@pytest.fixture(scope="module") +def cert_gen(): + class CertificateGeneratorUtil: + """Utility class for generating test certificates - simplified Python version""" + + def __init__(self): + self.random = random.Random() + self.ca_private_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + self.ca_certificate = self._create_ca_certificate() + self.revoked_serial_numbers = set() + + def _create_ca_certificate(self) -> x509.Certificate: + """Create a CA certificate for signing other certificates""" + ca_name = x509.Name( + [ + x509.NameAttribute( + NameOID.COMMON_NAME, f"Test CA {self.random.randint(1, 10000)}" + ) + ] + ) + + ca_cert = ( + x509.CertificateBuilder() + .subject_name(ca_name) + .issuer_name(ca_name) # Self-signed + .public_key(self.ca_private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after( + datetime.now(timezone.utc) + timedelta(days=3650) # 10 years + ) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .add_extension( + x509.KeyUsage( + key_cert_sign=True, + crl_sign=True, + digital_signature=False, + key_encipherment=False, + key_agreement=False, + content_commitment=False, + data_encipherment=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + .sign(self.ca_private_key, hashes.SHA256(), backend=default_backend()) + ) + + return ca_cert + + def generate_valid_crl(self) -> bytes: + """Generate a valid CRL""" + builder = x509.CertificateRevocationListBuilder() + builder = builder.issuer_name(self.ca_certificate.subject) + builder = builder.last_update(datetime.now(timezone.utc)) + builder = builder.next_update( + datetime.now(timezone.utc) + timedelta(days=1) + ) + + # Add any revoked certificates + for serial_number in self.revoked_serial_numbers: + revoked_cert = ( + x509.RevokedCertificateBuilder() + .serial_number(serial_number) + .revocation_date(datetime.now(timezone.utc)) + .build() + ) + builder = builder.add_revoked_certificate(revoked_cert) + + crl = builder.sign( + self.ca_private_key, hashes.SHA256(), backend=default_backend() + ) + return crl.public_bytes(serialization.Encoding.DER) + + def generate_expired_crl(self) -> bytes: + """Generate an expired CRL""" + builder = x509.CertificateRevocationListBuilder() + builder = builder.issuer_name(self.ca_certificate.subject) + # Set dates in the past to make it expired + past_date = datetime.now(timezone.utc) - timedelta(days=2) + builder = builder.last_update(past_date - timedelta(days=1)) + builder = builder.next_update(past_date) # Already expired + + crl = builder.sign( + self.ca_private_key, hashes.SHA256(), backend=default_backend() + ) + return crl.public_bytes(serialization.Encoding.DER) + + def create_simple_chain(self) -> CertificateChain: + """Create a simple certificate chain for testing""" + # Generate key pairs + root_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + intermediate_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048 + ) + leaf_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + # Create root certificate (self-signed) + root_name = x509.Name( + [ + x509.NameAttribute( + NameOID.COMMON_NAME, + f"Test Root CA {self.random.randint(1, 10000)}", + ) + ] + ) + + root_cert = ( + x509.CertificateBuilder() + .subject_name(root_name) + .issuer_name(root_name) + .public_key(root_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(root_key, hashes.SHA256()) + ) + + # Create intermediate certificate + intermediate_name = x509.Name( + [ + x509.NameAttribute( + NameOID.COMMON_NAME, + f"Test Intermediate CA {self.random.randint(1, 10000)}", + ) + ] + ) + + intermediate_cert = ( + x509.CertificateBuilder() + .subject_name(intermediate_name) + .issuer_name(root_name) + .public_key(intermediate_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=0), + critical=True, + ) + .sign(root_key, hashes.SHA256()) + ) + + # Create leaf certificate + leaf_name = x509.Name( + [ + x509.NameAttribute( + NameOID.COMMON_NAME, + f"Test Leaf {self.random.randint(1, 10000)}", + ) + ] + ) + + leaf_cert = ( + x509.CertificateBuilder() + .subject_name(leaf_name) + .issuer_name(intermediate_name) + .public_key(leaf_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=True, + ) + .sign(intermediate_key, hashes.SHA256()) + ) + + return CertificateChain(root_cert, intermediate_cert, leaf_cert) + + def create_short_lived_certificate( + self, validity_days: int, issuance_date: datetime + ) -> x509.Certificate: + """Create a short-lived certificate for testing""" + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + name = x509.Name( + [ + x509.NameAttribute( + NameOID.COMMON_NAME, + f"Test Short-Lived {self.random.randint(1, 10000)}", + ) + ] + ) + + not_after = issuance_date + timedelta(days=validity_days) + + cert = ( + x509.CertificateBuilder() + .subject_name(name) + .issuer_name(name) # Self-signed for simplicity + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(issuance_date) + .not_valid_after(not_after) + .add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=True, + ) + .sign(key, hashes.SHA256()) + ) + + return cert + + def create_certificate_with_crl_distribution_points( + self, subject_dn: str, crl_urls: list[str] + ) -> x509.Certificate: + """Create a certificate with CRL distribution points""" + + # Generate a new key pair for this certificate + cert_private_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + + subject_name = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, subject_dn)] + ) + + # Create certificate builder + builder = ( + x509.CertificateBuilder() + .subject_name(subject_name) + .issuer_name(self.ca_certificate.subject) + .public_key(cert_private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=False, + ) + ) + + # Add CRL distribution points if URLs provided + if crl_urls: + distribution_points = [] + for url in crl_urls: + distribution_point = x509.DistributionPoint( + full_name=[x509.UniformResourceIdentifier(url)], + relative_name=None, + crl_issuer=None, + reasons=None, + ) + distribution_points.append(distribution_point) + + crl_distribution_points = x509.CRLDistributionPoints( + distribution_points + ) + builder = builder.add_extension(crl_distribution_points, critical=False) + + # Sign the certificate with CA private key + certificate = builder.sign( + self.ca_private_key, hashes.SHA256(), backend=default_backend() + ) + + return certificate + + def generate_crl_with_revoked_certificate(self, serial_number: int) -> bytes: + """Generate a CRL with a specific certificate marked as revoked""" + self.revoked_serial_numbers.add(serial_number) + return self.generate_valid_crl() + + return CertificateGeneratorUtil() + + +def test_should_allow_connection_when_crl_validation_disabled( + cert_gen, session_manager +): + """Test that connections are allowed when CRL validation is disabled""" + chain = cert_gen.create_simple_chain() + chains = [[chain.leaf_cert, chain.intermediate_cert, chain.root_cert]] + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.DISABLED, + ) + + assert validator.validate_certificate_chains(chains) + + +def test_should_allow_connection_when_crl_validation_disabled_and_no_cert_chain( + session_manager, +): + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.DISABLED, + ) + assert validator.validate_certificate_chains([]) + assert validator.validate_certificate_chains(None) + + +def test_should_fail_with_null_or_empty_certificate_chains(cert_gen, session_manager): + """Test that validator fails with null or empty certificate chains""" + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + ) + assert not validator.validate_certificate_chains([]) + assert not validator.validate_certificate_chains(None) + + +def test_should_handle_certificates_without_crl_urls_in_enabled_mode( + cert_gen, session_manager +): + """Test handling of certificates without CRL URLs in enabled mode""" + chain = cert_gen.create_simple_chain() + chains = [[chain.leaf_cert, chain.intermediate_cert, chain.root_cert]] + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + allow_certificates_without_crl_url=False, + ) + assert not validator.validate_certificate_chains(chains) + + +def test_should_allow_certificates_without_crl_urls_when_configured( + cert_gen, session_manager +): + """Test that certificates without CRL URLs are allowed when configured""" + chain = cert_gen.create_simple_chain() + chains = [[chain.leaf_cert, chain.intermediate_cert, chain.root_cert]] + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + allow_certificates_without_crl_url=True, + ) + assert validator.validate_certificate_chains(chains) + + +def test_should_pass_in_advisory_mode_even_with_errors(cert_gen, session_manager): + """Test that validation passes in advisory mode even with errors""" + chain = cert_gen.create_simple_chain() + chains = [[chain.leaf_cert, chain.intermediate_cert, chain.root_cert]] + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ADVISORY, + ) + + assert validator.validate_certificate_chains(chains) + + +def test_should_validate_multiple_chains_and_return_first_valid_with_no_crl_urls( + cert_gen, + session_manager, +): + """Test validation of multiple chains and return first valid""" + # Create a certificate that would be considered invalid (before March 2024) + before_march_2024 = datetime(2024, 2, 1, tzinfo=timezone.utc) + invalid_cert = cert_gen.create_short_lived_certificate(5, before_march_2024) + + # Create a valid chain + valid_chain = cert_gen.create_simple_chain() + + # Create list with invalid chain first, then valid chain + chains = [ + [invalid_cert, valid_chain.intermediate_cert, valid_chain.root_cert], + [valid_chain.leaf_cert, valid_chain.intermediate_cert, valid_chain.root_cert], + ] + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + allow_certificates_without_crl_url=True, + ) + + result = validator.validate_certificate_chains(chains) + assert result, "Should return true when at least one valid chain is found" + + +@responses.activate +def test_should_validate_non_revoked_certificate_successfully( + cert_gen, crl_urls, session_manager +): + """Test validation of non-revoked certificate""" + # Setup mock HTTP client + crl_content = cert_gen.generate_valid_crl() + resp = responses.add( + responses.GET, + crl_urls.test_ca, + body=crl_content, + status=200, + content_type="application/pkcs7-mime", + ) + + # Create certificate with CRL distribution point + cert = cert_gen.create_certificate_with_crl_distribution_points( + "CN=Test Server", [crl_urls.test_ca] + ) + chain = [cert, cert_gen.ca_certificate] + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + ) + + assert validator.validate_certificate_chains([chain]) + assert resp.call_count + + +@responses.activate +def test_should_fail_for_revoked_certificate(cert_gen, crl_urls, session_manager): + """Test failure for revoked certificate""" + # Create certificate first + cert = cert_gen.create_certificate_with_crl_distribution_points( + "CN=Revoked Server", [crl_urls.test_ca] + ) + + # mock a CRL with the cert as revoked + resp = responses.add( + responses.GET, + crl_urls.test_ca, + body=cert_gen.generate_crl_with_revoked_certificate(cert.serial_number), + status=200, + content_type="application/pkcs7-mime", + ) + + chain = [cert, cert_gen.ca_certificate] + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + ) + + assert not validator.validate_certificate_chains([chain]) + assert resp.call_count + + +@responses.activate +def test_should_allow_revoked_certificate_when_crl_validation_disabled( + cert_gen, crl_urls, session_manager +): + """Test that revoked certificates are allowed when CRL validation is disabled""" + # Create certificate first + revoked_cert = cert_gen.create_certificate_with_crl_distribution_points( + "CN=Revoked Server (Disabled Mode)", [crl_urls.test_ca] + ) + # A mock response, which is expected not to be hit + resp = responses.add( + responses.GET, + crl_urls.test_ca, + body=cert_gen.generate_crl_with_revoked_certificate(revoked_cert.serial_number), + status=200, + content_type="application/pkcs7-mime", + ) + + chain = [revoked_cert, cert_gen.ca_certificate] + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.DISABLED, + ) + + assert validator.validate_certificate_chains([chain]) + assert resp.call_count == 0 + + +@responses.activate +def test_should_pass_in_advisory_mode_with_crl_errors( + cert_gen, crl_urls, session_manager +): + """Test that advisory mode passes even with CRL errors""" + # Setup 404 response for CRL + resp = responses.add(responses.GET, crl_urls.test_ca, status=404) + + cert = cert_gen.create_certificate_with_crl_distribution_points( + "CN=Test Server", [crl_urls.test_ca] + ) + chain = [cert, cert_gen.ca_certificate] + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ADVISORY, + ) + + assert validator.validate_certificate_chains([chain]) + assert resp.call_count + + +@responses.activate +def test_should_fail_in_enabled_mode_with_crl_errors( + cert_gen, crl_urls, session_manager +): + """Test that enabled mode fails with CRL errors""" + # Setup 404 response for CRL + resp = responses.add(responses.GET, crl_urls.test_ca, status=404) + + cert = cert_gen.create_certificate_with_crl_distribution_points( + "CN=Test Server", [crl_urls.test_ca] + ) + chain = [cert, cert_gen.ca_certificate] + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + ) + + assert not validator.validate_certificate_chains([chain]) + assert resp.call_count + + +@responses.activate +def test_should_validate_multiple_chains_and_success_if_just_one_valid( + cert_gen, crl_urls, session_manager +): + """Test validation of multiple chains and return first valid""" + # Create certificates + invalid_cert = cert_gen.create_certificate_with_crl_distribution_points( + "CN=Invalid Server", [crl_urls.invalid_ca] + ) + invalid_chain = [invalid_cert, cert_gen.ca_certificate] + + valid_cert = cert_gen.create_certificate_with_crl_distribution_points( + "CN=Valid Server", [crl_urls.valid_ca] + ) + valid_chain = [valid_cert, cert_gen.ca_certificate] + + valid_crl_content = cert_gen.generate_valid_crl() + + resp_200 = responses.add( + responses.GET, + crl_urls.valid_ca, + body=valid_crl_content, + status=200, + content_type="application/pkcs7-mime", + ) + + # Setup 404 for invalid certificate CRL + resp_404 = responses.add(responses.GET, crl_urls.invalid_ca, status=404) + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + ) + + assert validator.validate_certificate_chains([invalid_chain, valid_chain]) + assert resp_200.call_count + assert resp_404.call_count + + +@responses.activate +def test_should_reject_expired_crl(cert_gen, crl_urls, session_manager): + """Test rejection of expired CRL""" + # Setup mock HTTP client with expired CRL + resp = responses.add( + responses.GET, + crl_urls.expired_ca, + body=cert_gen.generate_expired_crl(), + status=200, + content_type="application/pkcs7-mime", + ) + + cert = cert_gen.create_certificate_with_crl_distribution_points( + "CN=Test Server", [crl_urls.expired_ca] + ) + chain = [cert, cert_gen.ca_certificate] + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + ) + + assert not validator.validate_certificate_chains([chain]) + assert resp.call_count + + +def test_should_skip_short_lived_certificates(cert_gen, session_manager): + """Test that short-lived certificates skip CRL validation""" + # Create short-lived certificate (5 days validity) + short_lived_cert = cert_gen.create_short_lived_certificate( + 5, datetime.now(timezone.utc) + ) + chain = [short_lived_cert, cert_gen.ca_certificate] + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + ) + + # Should pass without any HTTP calls (no responses setup) + assert validator.validate_certificate_chains([chain]) + + +@responses.activate +def test_should_handle_multiple_crl_distribution_points( + cert_gen, crl_urls, session_manager +): + """Test handling of multiple CRL distribution points""" + crl_content = cert_gen.generate_valid_crl() + # Setup mock HTTP that returns valid CRL for both URLs + resp_primary = responses.add( + responses.GET, + crl_urls.primary_ca, + body=crl_content, + status=200, + content_type="application/pkcs7-mime", + ) + resp_backup = responses.add( + responses.GET, + crl_urls.backup_ca, + body=crl_content, + status=200, + content_type="application/pkcs7-mime", + ) + + # Create certificate with multiple CRL URLs + crl_urls_list = [ + crl_urls.primary_ca, + crl_urls.backup_ca, + ] + cert = cert_gen.create_certificate_with_crl_distribution_points( + "CN=Multi-CRL Server", crl_urls_list + ) + chain = [cert, cert_gen.ca_certificate] + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + ) + + assert validator.validate_certificate_chains([chain]) + assert resp_primary.call_count + assert resp_backup.call_count + + +def test_crl_validator_creation(session_manager): + """Test that CRLValidator can be created properly""" + + # Test basic instantiation + validator = CRLValidator(session_manager) + assert validator is not None + assert isinstance(validator, CRLValidator) + + # Test that it works with from_config class method + validator = CRLValidator.from_config(CRLConfig(), session_manager) + assert validator is not None + assert isinstance(validator, CRLValidator) + + +def test_crl_validator_atexit_cleanup(session_manager): + """Test that CRLValidator properly starts cleanup with atexit handler""" + from snowflake.connector.crl_cache import CRLCacheFactory + + # Create a config with cleanup enabled + config = CRLConfig( + enable_crl_cache=True, + crl_cache_start_cleanup=True, # This will start background cleanup + crl_cache_cleanup_interval_hours=1, + ) + + try: + # Create validator which should start cleanup + CRLValidator.from_config(config, session_manager) + + # Verify cleanup is running through factory + assert CRLCacheFactory.is_periodic_cleanup_running() + + # Verify atexit handler was registered + assert CRLCacheFactory._atexit_registered + + # Test the atexit handler directly + CRLCacheFactory._atexit_cleanup_handler() + + # After calling atexit handler, cleanup should be stopped + assert not CRLCacheFactory.is_periodic_cleanup_running() + finally: + # Ensure cleanup is stopped for other tests + CRLCacheFactory.reset() + + +def test_crl_validator_validate_connection(session_manager): + """Test the validate_connection method""" + # Create a mock connection + mock_connection = Mock() + + # Test with no certificate chain + mock_connection.get_peer_cert_chain.return_value = [] + validator = CRLValidator(session_manager) + + # Should return True when disabled (default) + assert validator.validate_connection(mock_connection) + + # Test with enabled mode and no certificates + validator = CRLValidator( + session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED + ) + assert not validator.validate_connection(mock_connection) + + +def test_crl_validator_extract_certificate_chains_from_connection( + cert_gen, session_manager +): + """Test the _extract_certificate_chains_from_connection method""" + validator = CRLValidator(session_manager) + + # Test with no certificate chain + mock_connection = Mock() + mock_connection.get_peer_cert_chain.return_value = [] + + chains = validator._extract_certificate_chains_from_connection(mock_connection) + assert chains == [] + + # Test with mock certificate chain + chain = cert_gen.create_simple_chain() + mock_certs = [] + + # Create mock OpenSSL certificates + for cert in [chain.leaf_cert, chain.intermediate_cert, chain.root_cert]: + mock_openssl_cert = Mock() + # Mock the dump_certificate call to return the DER bytes + cert_der = cert.public_bytes(serialization.Encoding.DER) + mock_certs.append((mock_openssl_cert, cert_der)) + + mock_connection.get_peer_cert_chain.return_value = [cert[0] for cert in mock_certs] + + # Mock dump_certificate to return the appropriate DER data + def mock_dump_certificate(file_type, cert_openssl): + for mock_cert, der_data in mock_certs: + if mock_cert == cert_openssl: + return der_data + raise ValueError("Certificate not found") + + # Patch dump_certificate from OpenSSL.crypto module + from unittest.mock import patch + + with patch("OpenSSL.crypto.dump_certificate", side_effect=mock_dump_certificate): + chains = validator._extract_certificate_chains_from_connection(mock_connection) + + assert len(chains) == 1 + assert len(chains[0]) == 3 # leaf, intermediate, root + + +# New comprehensive tests for CRLConfig.from_connection +def test_crl_config_from_connection_disabled_mode(): + """Test CRLConfig.from_connection with DISABLED mode""" + # from unittest.mock import Mock + + mock_connection = Mock() + mock_connection.cert_revocation_check_mode = "DISABLED" + + config = CRLConfig.from_connection(mock_connection) + + assert config.cert_revocation_check_mode == CertRevocationCheckMode.DISABLED + # Other parameters should use defaults when mode is disabled + + +def test_crl_config_from_connection_enabled_mode(): + """Test CRLConfig.from_connection with ENABLED mode and all parameters""" + from unittest.mock import Mock + + mock_connection = Mock() + mock_connection.cert_revocation_check_mode = "ENABLED" + mock_connection.allow_certificates_without_crl_url = True + mock_connection.crl_connection_timeout_ms = 5000 + mock_connection.crl_read_timeout_ms = 6000 + mock_connection.crl_cache_validity_hours = 12 + mock_connection.enable_crl_cache = False + mock_connection.enable_crl_file_cache = False + mock_connection.crl_cache_dir = "/custom/path" + mock_connection.crl_cache_removal_delay_days = 14 + mock_connection.crl_cache_cleanup_interval_hours = 2 + mock_connection.crl_cache_start_cleanup = True + + config = CRLConfig.from_connection(mock_connection) + + assert config.cert_revocation_check_mode == CertRevocationCheckMode.ENABLED + assert config.allow_certificates_without_crl_url + assert config.connection_timeout_ms == 5000 + assert config.read_timeout_ms == 6000 + assert config.cache_validity_time == timedelta(hours=12) + assert not config.enable_crl_cache + assert not config.enable_crl_file_cache + assert config.crl_cache_dir == Path("/custom/path") + assert config.crl_cache_removal_delay_days == 14 + assert config.crl_cache_cleanup_interval_hours == 2 + assert config.crl_cache_start_cleanup + + +def test_crl_config_from_connection_none_values(): + """Test CRLConfig.from_connection with None values uses defaults""" + mock_connection = Mock() + mock_connection.cert_revocation_check_mode = "ADVISORY" + mock_connection.allow_certificates_without_crl_url = None + mock_connection.crl_connection_timeout_ms = None + mock_connection.crl_read_timeout_ms = None + mock_connection.crl_cache_validity_hours = None + mock_connection.enable_crl_cache = None + mock_connection.enable_crl_file_cache = None + mock_connection.crl_cache_dir = None + mock_connection.crl_cache_removal_delay_days = None + mock_connection.crl_cache_cleanup_interval_hours = None + mock_connection.crl_cache_start_cleanup = None + + config = CRLConfig.from_connection(mock_connection) + + assert config.cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY + # All other parameters should use class defaults + assert ( + config.allow_certificates_without_crl_url + == CRLConfig.allow_certificates_without_crl_url + ) + assert config.connection_timeout_ms == CRLConfig.connection_timeout_ms + assert config.read_timeout_ms == CRLConfig.read_timeout_ms + assert config.cache_validity_time == CRLConfig.cache_validity_time + assert config.enable_crl_cache == CRLConfig.enable_crl_cache + assert config.enable_crl_file_cache == CRLConfig.enable_crl_file_cache + assert config.crl_cache_dir == CRLConfig.crl_cache_dir + assert config.crl_cache_removal_delay_days == CRLConfig.crl_cache_removal_delay_days + assert ( + config.crl_cache_cleanup_interval_hours + == CRLConfig.crl_cache_cleanup_interval_hours + ) + assert config.crl_cache_start_cleanup == CRLConfig.crl_cache_start_cleanup + + +def test_crl_config_from_connection_invalid_mode_string(): + """Test CRLConfig.from_connection with invalid cert_revocation_check_mode string""" + mock_connection = Mock() + mock_connection.cert_revocation_check_mode = "INVALID_MODE" + mock_connection.allow_certificates_without_crl_url = None + mock_connection.crl_connection_timeout_ms = None + mock_connection.crl_read_timeout_ms = None + mock_connection.crl_cache_validity_hours = None + mock_connection.enable_crl_cache = None + mock_connection.enable_crl_file_cache = None + mock_connection.crl_cache_dir = None + mock_connection.crl_cache_removal_delay_days = None + mock_connection.crl_cache_cleanup_interval_hours = None + mock_connection.crl_cache_start_cleanup = None + + # Should default to class default and log warning + config = CRLConfig.from_connection(mock_connection) + assert config.cert_revocation_check_mode == CRLConfig.cert_revocation_check_mode + + +def test_crl_config_from_connection_enum_mode(): + """Test CRLConfig.from_connection with CertRevocationCheckMode enum""" + mock_connection = Mock() + mock_connection.cert_revocation_check_mode = CertRevocationCheckMode.ADVISORY + mock_connection.allow_certificates_without_crl_url = None + mock_connection.crl_connection_timeout_ms = None + mock_connection.crl_read_timeout_ms = None + mock_connection.crl_cache_validity_hours = 1 + mock_connection.enable_crl_cache = None + mock_connection.enable_crl_file_cache = None + mock_connection.crl_cache_dir = None + mock_connection.crl_cache_removal_delay_days = None + mock_connection.crl_cache_cleanup_interval_hours = None + mock_connection.crl_cache_start_cleanup = None + + config = CRLConfig.from_connection(mock_connection) + assert config.cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY + + +def test_crl_config_from_connection_unsupported_mode_type(): + """Test CRLConfig.from_connection with unsupported cert_revocation_check_mode type""" + mock_connection = Mock() + mock_connection.cert_revocation_check_mode = 123 # Invalid type + mock_connection.allow_certificates_without_crl_url = None + mock_connection.crl_connection_timeout_ms = None + mock_connection.crl_read_timeout_ms = None + mock_connection.crl_cache_validity_hours = None + mock_connection.enable_crl_cache = None + mock_connection.enable_crl_file_cache = None + mock_connection.crl_cache_dir = None + mock_connection.crl_cache_removal_delay_days = None + mock_connection.crl_cache_cleanup_interval_hours = None + mock_connection.crl_cache_start_cleanup = None + + # Should default to class default and log warning + config = CRLConfig.from_connection(mock_connection) + assert config.cert_revocation_check_mode == CRLConfig.cert_revocation_check_mode + + +def test_crl_config_from_connection_none_mode(): + """Test CRLConfig.from_connection with None cert_revocation_check_mode""" + mock_connection = Mock() + mock_connection.cert_revocation_check_mode = None + mock_connection.allow_certificates_without_crl_url = None + mock_connection.crl_connection_timeout_ms = None + mock_connection.crl_read_timeout_ms = None + mock_connection.crl_cache_validity_hours = None + mock_connection.enable_crl_cache = None + mock_connection.enable_crl_file_cache = None + mock_connection.crl_cache_dir = None + mock_connection.crl_cache_removal_delay_days = None + mock_connection.crl_cache_cleanup_interval_hours = None + mock_connection.crl_cache_start_cleanup = None + + config = CRLConfig.from_connection(mock_connection) + assert config.cert_revocation_check_mode == CRLConfig.cert_revocation_check_mode + + +# Tests for CRL download and certificate checking functionality +@responses.activate +def test_crl_validator_download_crl_success(cert_gen, session_manager): + """Test successful CRL download""" + # Setup mock CRL response with valid CRL data + crl_url = "http://example.com/test.crl" + crl_data = cert_gen.generate_valid_crl() # Use valid CRL data + + responses.add( + responses.GET, + crl_url, + body=crl_data, + status=200, + content_type="application/pkcs7-mime", + ) + + validator = CRLValidator(session_manager) + + # Test the download method - it returns a tuple (crl, timestamp) + crl, timestamp = validator._download_crl(crl_url) + assert crl is not None # Should return parsed CRL object + assert timestamp is not None # Should return download timestamp + assert len(responses.calls) == 1 + + +@responses.activate +def test_crl_validator_download_crl_http_error(session_manager): + """Test CRL download with HTTP error""" + crl_url = "http://example.com/missing.crl" + + responses.add(responses.GET, crl_url, status=404) + + validator = CRLValidator(session_manager) + + # Should return (None, None) on HTTP error + crl, timestamp = validator._download_crl(crl_url) + assert crl is None + assert timestamp is None + + +@responses.activate +def test_crl_validator_download_crl_network_timeout(session_manager): + """Test CRL download with network timeout""" + from requests.exceptions import Timeout + + crl_url = "http://example.com/slow.crl" + + validator = CRLValidator( + session_manager, connection_timeout_ms=1000, read_timeout_ms=1000 + ) + + # Mock requests to raise timeout + with mock_patch.object( + session_manager, + "get", + side_effect=Timeout("Connection timeout"), + ): + crl, timestamp = validator._download_crl(crl_url) + assert crl is None + assert timestamp is None + + +@responses.activate +def test_crl_validator_download_crl_network_error(session_manager): + """Test CRL download with network connection error""" + from requests.exceptions import ConnectionError + + crl_url = "http://example.com/unreachable.crl" + + validator = CRLValidator(session_manager) + + # Mock requests to raise connection error + with mock_patch.object( + session_manager, "get", side_effect=ConnectionError("Connection failed") + ): + crl, timestamp = validator._download_crl(crl_url) + assert crl is None + assert timestamp is None + + +def test_crl_validator_extract_crl_distribution_points_success( + cert_gen, session_manager +): + """Test successful extraction of CRL distribution points""" + # Create certificate with CRL distribution points + crl_urls = ["http://example.com/ca.crl", "http://backup.com/ca.crl"] + cert = cert_gen.create_certificate_with_crl_distribution_points("CN=Test", crl_urls) + + validator = CRLValidator(session_manager) + + extracted_urls = validator._extract_crl_distribution_points(cert) + + assert len(extracted_urls) == 2 + assert "http://example.com/ca.crl" in extracted_urls + assert "http://backup.com/ca.crl" in extracted_urls + + +def test_crl_validator_extract_crl_distribution_points_no_extension( + cert_gen, session_manager +): + """Test extraction when certificate has no CRL distribution points""" + # Create simple certificate without CRL distribution points + chain = cert_gen.create_simple_chain() + cert = chain.leaf_cert + + validator = CRLValidator(session_manager) + + # Should return empty list when no CRL extension found + extracted_urls = validator._extract_crl_distribution_points(cert) + assert extracted_urls == [] + + +def test_crl_validator_check_certificate_against_crl_not_revoked( + cert_gen, session_manager +): + """Test certificate checking against CRL - not revoked""" + from cryptography.x509 import CertificateRevocationList + + # Create test certificate + chain = cert_gen.create_simple_chain() + cert = chain.leaf_cert + + # Mock CRL that doesn't contain the certificate + mock_crl = Mock(spec=CertificateRevocationList) + mock_crl.get_revoked_certificate_by_serial_number.return_value = None + + validator = CRLValidator(session_manager) + + # Should return UNREVOKED + result = validator._check_certificate_against_crl(cert, mock_crl) + assert result == CRLValidationResult.UNREVOKED + + +def test_crl_validator_check_certificate_against_crl_revoked(cert_gen, session_manager): + """Test certificate checking against CRL - revoked""" + from cryptography.x509 import CertificateRevocationList, RevokedCertificate + + # Create test certificate + chain = cert_gen.create_simple_chain() + cert = chain.leaf_cert + + # Mock CRL that contains the certificate as revoked + mock_revoked_cert = Mock(spec=RevokedCertificate) + mock_crl = Mock(spec=CertificateRevocationList) + mock_crl.get_revoked_certificate_by_serial_number.return_value = mock_revoked_cert + + validator = CRLValidator(session_manager) + + # Should return REVOKED + result = validator._check_certificate_against_crl(cert, mock_crl) + assert result == CRLValidationResult.REVOKED + + +def test_crl_validator_check_certificate_against_crl_expired( + cert_gen, session_manager, crl_urls +): + """Test certificate checking against expired CRL""" + + # Create test certificate + chain = cert_gen.create_simple_chain() + cert = chain.leaf_cert + parent = chain.intermediate_cert + + # Mock expired CRL + mock_crl = Mock(spec=x509.CertificateRevocationList) + mock_crl.next_update_utc = datetime.now(timezone.utc) - timedelta(days=1) # Expired + mock_crl.get_revoked_certificate_by_serial_number.return_value = None + + # Cache will return an expired CRL + mock_cache_mgr = Mock(spec=CRLCacheManager) + mock_cache_mgr.get.return_value = CRLCacheEntry(mock_crl, datetime.now()) + + validator = CRLValidator(session_manager, cache_manager=mock_cache_mgr) + with mock_patch.object( + validator, "_download_crl", return_value=(mock_crl, datetime.now()) + ) as mock_download, mock_patch.object( + validator, "_verify_crl_signature", return_value=True + ) as mock_verify: + result = validator._check_certificate_against_crl_url( + cert, parent, crl_urls.expired_ca + ) + assert result == CRLValidationResult.UNREVOKED + mock_cache_mgr.get.assert_called_once() + mock_download.assert_called_once() + mock_verify.assert_called_once_with(mock_crl, parent) + + +def test_crl_validator_validate_certificate_with_cache_hit( + cert_gen, session_manager, crl_urls +): + """Test certificate validation with cache hit""" + + # Create certificate with CRL distribution points + cert = cert_gen.create_certificate_with_crl_distribution_points( + "CN=Test", [crl_urls.test_ca] + ) + ca_cert = cert_gen.ca_certificate + + # Mock cache manager with cache hit + mock_crl = Mock(spec=x509.CertificateRevocationList) + mock_crl.next_update_utc = datetime.now(timezone.utc) + timedelta(days=7) + mock_cache_manager = Mock() + cached_entry = CRLCacheEntry(mock_crl, datetime.now(timezone.utc)) + mock_cache_manager.get.return_value = cached_entry + + validator = CRLValidator(session_manager) + validator._cache_manager = mock_cache_manager + + # Mock CRL parsing and validation + with mock_patch.object( + validator, + "_check_certificate_against_crl", + return_value=CRLValidationResult.UNREVOKED, + ) as mock_check, mock_patch.object( + validator, "_verify_crl_signature", return_value=True + ) as mock_verify: + result = validator._validate_certificate(cert, ca_cert) + + # Should use cached CRL + assert result == CRLValidationResult.UNREVOKED + mock_cache_manager.get.assert_called_once() + mock_check.assert_called_once_with(cert, cached_entry.crl) + mock_verify.assert_called_once_with(cached_entry.crl, ca_cert) + + +def test_crl_validator_validate_certificate_with_cache_miss( + cert_gen, session_manager, crl_urls +): + """Test certificate validation with cache miss and download""" + # Create certificate with CRL distribution points + cert = cert_gen.create_certificate_with_crl_distribution_points( + "CN=Test", [crl_urls.valid_ca] + ) + ca_cert = cert_gen.ca_certificate + + # Mock cache manager with cache miss + mock_cache_manager = Mock() + mock_cache_manager.get.return_value = None + + validator = CRLValidator(session_manager, cache_manager=mock_cache_manager) + + # Mock successful download and validation + with mock_patch.object( + validator, "_fetch_crl_from_url", return_value=b"downloaded_crl" + ) as mock_fetch, mock_patch( + "snowflake.connector.crl.x509.load_der_x509_crl" + ) as mock_load_crl, mock_patch.object( + validator, + "_check_certificate_against_crl", + return_value=CRLValidationResult.UNREVOKED, + ) as mock_check, mock_patch.object( + validator, "_verify_crl_signature", return_value=True + ) as mock_verify: + + mock_crl = Mock() + mock_crl.next_update_utc = datetime.now(timezone.utc) + timedelta(days=7) + mock_load_crl.return_value = mock_crl + + result = validator._validate_certificate(cert, ca_cert) + + # Should download CRL and cache it + assert result == CRLValidationResult.UNREVOKED + mock_cache_manager.get.assert_called_once() + mock_fetch.assert_called_once_with(crl_urls.valid_ca) + mock_cache_manager.put.assert_called_once() + mock_check.assert_called_once_with(cert, mock_crl) + mock_verify.assert_called_once_with(mock_crl, ca_cert) + + +def test_crl_signature_verification_success(cert_gen, session_manager): + """Test successful CRL signature verification""" + # Create a valid CRL signed by the test CA + crl_bytes = cert_gen.generate_valid_crl() + crl = x509.load_der_x509_crl(crl_bytes, backend=default_backend()) + + validator = CRLValidator(session_manager) + + # Should successfully verify the signature + result = validator._verify_crl_signature(crl, cert_gen.ca_certificate) + assert result is True + + +def test_crl_signature_verification_failure_wrong_ca(cert_gen, session_manager): + """Test CRL signature verification failure with wrong CA certificate""" + # Create a CRL signed by the test CA + crl_bytes = cert_gen.generate_valid_crl() + crl = x509.load_der_x509_crl(crl_bytes, backend=default_backend()) + + # Create a different CA certificate + different_ca_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + different_ca_name = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, "Different CA")] + ) + different_ca_cert = ( + x509.CertificateBuilder() + .subject_name(different_ca_name) + .issuer_name(different_ca_name) + .public_key(different_ca_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(different_ca_key, hashes.SHA256(), backend=default_backend()) + ) + + validator = CRLValidator(session_manager) + + # Should fail to verify the signature with wrong CA + result = validator._verify_crl_signature(crl, different_ca_cert) + assert result is False + + +def test_crl_signature_verification_with_ec_key(session_manager): + """Test CRL signature verification with EC (Elliptic Curve) keys""" + from cryptography.hazmat.primitives.asymmetric import ec + + # Generate EC key pair for CA + ec_private_key = ec.generate_private_key(ec.SECP256R1(), backend=default_backend()) + + # Create EC CA certificate + ca_name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "EC Test CA")]) + ec_ca_cert = ( + x509.CertificateBuilder() + .subject_name(ca_name) + .issuer_name(ca_name) + .public_key(ec_private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .add_extension( + x509.KeyUsage( + key_cert_sign=True, + crl_sign=True, + digital_signature=False, + key_encipherment=False, + key_agreement=False, + content_commitment=False, + data_encipherment=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + .sign(ec_private_key, hashes.SHA256(), backend=default_backend()) + ) + + # Create CRL signed with EC key + builder = x509.CertificateRevocationListBuilder() + builder = builder.issuer_name(ec_ca_cert.subject) + builder = builder.last_update(datetime.now(timezone.utc)) + builder = builder.next_update(datetime.now(timezone.utc) + timedelta(days=1)) + + ec_crl = builder.sign(ec_private_key, hashes.SHA256(), backend=default_backend()) + + validator = CRLValidator(session_manager) + + # Should successfully verify EC signature + result = validator._verify_crl_signature(ec_crl, ec_ca_cert) + assert result is True + + +def test_crl_signature_verification_with_corrupted_signature(cert_gen, session_manager): + """Test CRL signature verification with corrupted signature""" + # Create a valid CRL + crl_bytes = cert_gen.generate_valid_crl() + crl = x509.load_der_x509_crl(crl_bytes, backend=default_backend()) + + # Mock the CRL to have a corrupted signature + corrupted_crl = Mock(spec=x509.CertificateRevocationList) + corrupted_crl.signature_algorithm_oid = crl.signature_algorithm_oid + corrupted_crl.signature_hash_algorithm = crl.signature_hash_algorithm + corrupted_crl.signature = b"corrupted_signature_bytes" + corrupted_crl.tbs_certlist_bytes = crl.tbs_certlist_bytes + + validator = CRLValidator(session_manager) + + # Should fail to verify corrupted signature + result = validator._verify_crl_signature(corrupted_crl, cert_gen.ca_certificate) + assert result is False + + +def test_crl_signature_verification_exception_handling(cert_gen, session_manager): + """Test CRL signature verification exception handling""" + # Create a valid CRL + crl_bytes = cert_gen.generate_valid_crl() + crl = x509.load_der_x509_crl(crl_bytes, backend=default_backend()) + + # Mock CA certificate that will cause an exception + mock_ca_cert = Mock(spec=x509.Certificate) + mock_ca_cert.public_key.side_effect = Exception("Test exception") + + validator = CRLValidator(session_manager) + + # Should handle exception gracefully and return False + result = validator._verify_crl_signature(crl, mock_ca_cert) + assert result is False + + +def test_crl_signature_verification_integration_with_validation_flow( + cert_gen, crl_urls, session_manager +): + """Test that signature verification is properly integrated into the validation flow""" + # Create certificate with CRL distribution point + cert = cert_gen.create_certificate_with_crl_distribution_points( + "CN=Test Server", [crl_urls.test_ca] + ) + + # Create a CRL signed by a different CA (should fail signature verification) + different_ca_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + different_ca_name = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, "Different CA")] + ) + + # Create CRL with different CA + builder = x509.CertificateRevocationListBuilder() + builder = builder.issuer_name(different_ca_name) + builder = builder.last_update(datetime.now(timezone.utc)) + builder = builder.next_update(datetime.now(timezone.utc) + timedelta(days=1)) + + invalid_crl = builder.sign( + different_ca_key, hashes.SHA256(), backend=default_backend() + ) + invalid_crl_bytes = invalid_crl.public_bytes(serialization.Encoding.DER) + + # Test in ENABLED mode - should fail due to signature verification failure + validator_enabled = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + ) + + with mock_patch.object( + validator_enabled, "_fetch_crl_from_url", return_value=invalid_crl_bytes + ): + result = validator_enabled._validate_certificate(cert, cert_gen.ca_certificate) + assert result == CRLValidationResult.ERROR + + # Test in ADVISORY mode - should also fail due to signature verification failure + # CRL signature verification failure always returns ERROR regardless of mode + validator_advisory = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ADVISORY, + ) + + with mock_patch.object( + validator_advisory, "_fetch_crl_from_url", return_value=invalid_crl_bytes + ): + result = validator_advisory._validate_certificate(cert, cert_gen.ca_certificate) + # Even in ADVISORY mode, signature verification failure should return ERROR + # We cannot trust a CRL whose signature cannot be verified + assert result == CRLValidationResult.ERROR + + +def test_crl_signature_verification_with_issuer_mismatch_warning( + cert_gen, session_manager, caplog +): + """Test that we log a warning when CRL issuer doesn't match CA certificate subject""" + # Create a valid CRL signed by the test CA + crl_bytes = cert_gen.generate_valid_crl() + crl = x509.load_der_x509_crl(crl_bytes, backend=default_backend()) + + # Create a different CA certificate with different subject + different_ca_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + different_subject = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, "Different Subject CA")] + ) + different_ca_cert = ( + x509.CertificateBuilder() + .subject_name(different_subject) + .issuer_name(different_subject) + .public_key(different_ca_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(different_ca_key, hashes.SHA256(), backend=default_backend()) + ) + + validator = CRLValidator(session_manager) + + # Mock the _verify_crl_signature to return True to focus on the issuer check + with mock_patch.object( + validator, "_verify_crl_signature", return_value=True + ), mock_patch.object( + validator, + "_check_certificate_against_crl", + return_value=CRLValidationResult.UNREVOKED, + ), mock_patch.object( + validator, "_download_crl", return_value=(crl, datetime.now(timezone.utc)) + ), caplog.at_level( + logging.WARNING + ): + + # This should log a warning about issuer mismatch but still proceed + result = validator._check_certificate_against_crl_url( + cert_gen.ca_certificate, # dummy cert + different_ca_cert, # CA with different subject than CRL issuer + "http://test.crl", + ) + + # Should still return UNREVOKED since signature verification was mocked to succeed + assert result == CRLValidationResult.UNREVOKED + + # Verify that the warning was logged + assert len(caplog.records) > 0 + warning_found = any( + "CRL issuer" in record.message + and "does not match CA certificate subject" in record.message + for record in caplog.records + if record.levelno == logging.WARNING + ) + assert ( + warning_found + ), f"Expected warning about CRL issuer mismatch not found in logs: {[r.message for r in caplog.records]}" diff --git a/test/unit/test_crl_cache.py b/test/unit/test_crl_cache.py new file mode 100644 index 0000000000..41adf2eece --- /dev/null +++ b/test/unit/test_crl_cache.py @@ -0,0 +1,620 @@ +#!/usr/bin/env python +from __future__ import annotations + +import tempfile +import time +from datetime import datetime, timedelta, timezone +from pathlib import Path +from unittest.mock import Mock, mock_open +from unittest.mock import patch as mock_patch + +import pytest +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import rsa + +from snowflake.connector.crl_cache import ( + CRLCache, + CRLCacheEntry, + CRLCacheManager, + CRLFileCache, + CRLInMemoryCache, + NoopCRLCache, +) + + +@pytest.fixture(scope="module") +def download_time(): + return datetime.now(timezone.utc) - timedelta(minutes=30) + + +@pytest.fixture(scope="module") +def noop_cache(): + return NoopCRLCache() + + +@pytest.fixture(scope="module") +def memory_cache(): + return CRLInMemoryCache(cache_validity_time=timedelta(hours=1)) + + +@pytest.fixture(scope="module") +def disk_cache(): + with tempfile.TemporaryDirectory() as temp_dir: + yield CRLFileCache(Path(temp_dir), timedelta(hours=1)) + + +@pytest.fixture(scope="function") +def mem_cache_mock(): + return Mock(spec=CRLCache) + + +@pytest.fixture(scope="function") +def disk_cache_mock(): + return Mock(spec=CRLCache) + + +@pytest.fixture(scope="function") +def cache_mgr(mem_cache_mock, disk_cache_mock): + mgr = CRLCacheManager(mem_cache_mock, disk_cache_mock) + yield mgr + + +@pytest.fixture(scope="function") +def cache_factory(): + """Fixture that provides CRLCacheFactory and ensures cleanup after each test.""" + from snowflake.connector.crl_cache import CRLCacheFactory + + yield CRLCacheFactory + # Always reset the factory after each test to prevent test pollution + CRLCacheFactory.reset() + + +@pytest.fixture(scope="module") +def crl_url(crl): + return "http://test.com/crl" + + +@pytest.fixture(scope="module") +def crl() -> x509.CertificateRevocationList: + """Create a test CRL""" + # Generate a key pair for signing + private_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + + # Create a simple issuer name + issuer = x509.Name([x509.NameAttribute(x509.oid.NameOID.COMMON_NAME, "Test CA")]) + + # Build the CRL + builder = x509.CertificateRevocationListBuilder() + builder = builder.issuer_name(issuer) + builder = builder.last_update(datetime.now(timezone.utc)) + builder = builder.next_update(datetime.now(timezone.utc) + timedelta(days=1)) + + # Sign the CRL + crl = builder.sign(private_key, hashes.SHA256(), backend=default_backend()) + return crl + + +@pytest.fixture(scope="module") +def cache_entry(crl, download_time) -> CRLCacheEntry: + return CRLCacheEntry(crl, download_time) + + +def test_cache_entry_creation(crl, download_time): + """Test creating a cache entry""" + entry = CRLCacheEntry(crl, download_time) + + assert entry.crl == crl + assert entry.download_time == download_time + + +def test_is_crl_expired_false_when_not_expired(cache_entry): + """Test CRL expiration check when not expired""" + now = datetime.now(timezone.utc) + + assert not cache_entry.is_crl_expired_by(now) + + +def test_is_evicted_false_when_not_evicted(crl, download_time): + """Test cache eviction check when not evicted""" + entry = CRLCacheEntry(crl, download_time) + current_time = datetime.now(timezone.utc) + cache_validity = timedelta(hours=24) + + assert not entry.is_evicted_by(current_time, cache_validity) + + +def test_is_evicted_true_when_evicted(crl, download_time): + """Test cache eviction check when evicted""" + old_download_time = datetime.now(timezone.utc) - timedelta(days=2) + entry = CRLCacheEntry(crl, old_download_time) + current_time = datetime.now(timezone.utc) + cache_validity = timedelta(hours=1) # Short validity period + + assert entry.is_evicted_by(current_time, cache_validity) + + +def test_noop_get_returns_none(crl, crl_url, download_time, noop_cache): + """Test that get always returns None""" + result = noop_cache.get(crl_url) + assert result is None + + +def test_noop_put_does_nothing(crl, crl_url, download_time, noop_cache): + """Test that put does nothing""" + # Should not raise any exceptions + noop_cache.put(crl_url, CRLCacheEntry(crl, download_time)) + + +def test_noop_cleanup_does_nothing(noop_cache): + """Test that cleanup does nothing""" + # Should not raise any exceptions + noop_cache.cleanup() + + +def test_noop_singleton_behavior(): + """Test that NoopCRLCache behaves as singleton""" + cache1 = NoopCRLCache() + cache2 = NoopCRLCache() + assert cache1 is cache2 + + +def test_memory_put_and_get(crl, crl_url, download_time, memory_cache): + """Test storing and retrieving from memory cache""" + download_time = datetime.now(timezone.utc) + entry = CRLCacheEntry(crl, download_time) + + memory_cache.put(crl_url, entry) + result = memory_cache.get(crl_url) + + assert result is not None + assert result.download_time == download_time + assert result.crl == crl + + +def test_memory_get_nonexistent_returns_none(memory_cache): + """Test that getting non-existent entry returns None""" + result = memory_cache.get("http://nonexistent.com/crl") + assert result is None + + +def test_memory_cleanup_removes_evicted_entries( + crl, crl_url, download_time, memory_cache +): + """Test that cleanup removes evicted entries""" + # Add an old entry that should be evicted + old_time = datetime.now(timezone.utc) - timedelta(hours=2) + old_entry = CRLCacheEntry(crl, old_time) + memory_cache.put(crl_url, old_entry) + assert memory_cache.get(crl_url) is not None + + memory_cache.cleanup() + assert memory_cache.get(crl_url) is None + + +def test_disk_put_and_get(crl, crl_url, download_time, disk_cache): + """Test storing and retrieving from file cache""" + # download_time = datetime.now(timezone.utc) + entry = CRLCacheEntry(crl, download_time) + + disk_cache.put(crl_url, entry) + result = disk_cache.get(crl_url) + + assert result is not None + # Note: CRL comparison might not work directly, so we check the type + assert isinstance(result.crl, x509.CertificateRevocationList) + # Download time might be slightly different due to file system precision + assert abs(result.download_time.timestamp() - download_time.timestamp()) < 1.0 + + +def test_disk_get_nonexistent_returns_none(crl, download_time, disk_cache): + """Test that getting non-existent entry returns None""" + result = disk_cache.get("http://nonexistent.com/crl") + assert result is None + + +def test_should_return_cache_entry_when_memory_cache_hit( + crl, crl_url, download_time, cache_entry, mem_cache_mock, disk_cache_mock, cache_mgr +): + """Test returning cache entry when memory cache has it""" + mem_cache_mock.get.return_value = cache_entry + result = cache_mgr.get(crl_url) + + assert result is not None + assert result.crl == crl + assert result.download_time == download_time + mem_cache_mock.get.assert_called_once_with(crl_url) + disk_cache_mock.get.assert_not_called() + + +def test_should_promote_file_cache_hit_to_memory_cache( + crl, crl_url, download_time, cache_entry, mem_cache_mock, disk_cache_mock, cache_mgr +): + """Test promoting file cache hit to memory cache""" + mem_cache_mock.get.return_value = None + disk_cache_mock.get.return_value = cache_entry + result = cache_mgr.get(crl_url) + + assert result is not None + assert result.crl == crl + assert result.download_time == download_time + mem_cache_mock.get.assert_called_once_with(crl_url) + disk_cache_mock.get.assert_called_once_with(crl_url) + mem_cache_mock.put.assert_called_once_with(crl_url, cache_entry) + + +def test_should_return_none_when_both_caches_miss( + crl, crl_url, download_time, mem_cache_mock, disk_cache_mock, cache_mgr +): + """Test returning None when both caches miss""" + mem_cache_mock.get.return_value = None + disk_cache_mock.get.return_value = None + + result = cache_mgr.get(crl_url) + + assert result is None + mem_cache_mock.get.assert_called_once_with(crl_url) + disk_cache_mock.get.assert_called_once_with(crl_url) + mem_cache_mock.put.assert_not_called() + + +def test_should_put_to_both_memory_and_file_cache( + crl, crl_url, download_time, cache_entry, cache_mgr, mem_cache_mock, disk_cache_mock +): + """Test putting to both memory and file cache""" + cache_mgr.put(crl_url, crl, download_time) + + # Verify both caches were called + mem_cache_mock.put.assert_called_once() + disk_cache_mock.put.assert_called_once() + + # Check the arguments (entry should have correct CRL and time) + mem_put_call_args = mem_cache_mock.put.call_args[0] + disk_put_call_args = disk_cache_mock.put.call_args[0] + + assert mem_put_call_args == (crl_url, cache_entry) + assert disk_put_call_args == (crl_url, cache_entry) + + +def test_should_not_promote_to_memory_cache_when_file_cache_returns_none( + crl, crl_url, download_time, mem_cache_mock, disk_cache_mock, cache_mgr +): + """Test not promoting to memory cache when file cache returns None""" + mem_cache_mock.get.return_value = None + disk_cache_mock.get.return_value = None + + result = cache_mgr.get(crl_url) + + assert result is None + mem_cache_mock.get.assert_called_once_with(crl_url) + disk_cache_mock.get.assert_called_once_with(crl_url) + mem_cache_mock.put.assert_not_called() + + +def test_should_create_different_cache_entries_for_same_crl_with_different_download_times( + crl, crl_url, mem_cache_mock, disk_cache_mock, cache_mgr +): + """Test creating different cache entries for same CRL with different download times""" + first_put_time = datetime.now(timezone.utc) - timedelta(hours=1) + second_put_time = datetime.now(timezone.utc) + + cache_mgr.put(crl_url, crl, first_put_time) + cache_mgr.put(crl_url, crl, second_put_time) + + # Verify both puts were called + assert mem_cache_mock.put.call_count == 2 + assert disk_cache_mock.put.call_count == 2 + + # Check that the download times are different + first_memory_call = mem_cache_mock.put.call_args_list[0] + assert first_memory_call.args == (crl_url, CRLCacheEntry(crl, first_put_time)) + second_memory_call = mem_cache_mock.put.call_args_list[1] + assert second_memory_call.args == (crl_url, CRLCacheEntry(crl, second_put_time)) + + +def test_cleanup_loop_starts_and_stops_properly(cache_factory): + """Test that the cleanup loop starts and stops properly""" + # Initially by default the cleanup is not running + assert not cache_factory.is_periodic_cleanup_running() + + # Start the cleanup loop + cache_factory.start_periodic_cleanup(timedelta(milliseconds=50)) + + # Verify cleanup executor is created + assert cache_factory.is_periodic_cleanup_running() + + # Stop the cleanup loop + cache_factory.stop_periodic_cleanup() + + # Verify cleanup is properly stopped + assert not cache_factory.is_periodic_cleanup_running() + + +def test_cleanup_loop_calls_cleanup_on_both_caches_periodically( + cache_factory, mem_cache_mock, disk_cache_mock +): + """Test that the cleanup loop calls cleanup on both memory and file caches periodically""" + # Set up singleton instances to be cleaned + cache_factory._memory_cache_instance = mem_cache_mock + cache_factory._file_cache_instance = disk_cache_mock + + # Start the cleanup loop + cache_factory.start_periodic_cleanup(timedelta(milliseconds=50)) + + # Wait for at least 2 cleanup cycles to occur + time.sleep(0.15) + + # Stop the cleanup loop + cache_factory.stop_periodic_cleanup() + + # Verify that cleanup was called on both caches at least once + assert mem_cache_mock.cleanup.call_count >= 1 + assert disk_cache_mock.cleanup.call_count >= 1 + + # Verify both caches were called the same number of times + assert mem_cache_mock.cleanup.call_count == disk_cache_mock.cleanup.call_count + + +def test_cleanup_loop_handles_exceptions_gracefully( + cache_factory, mem_cache_mock, disk_cache_mock +): + """Test that the cleanup loop handles exceptions gracefully and continues running""" + + # Make memory cache cleanup raise an exception on first call, then work normally + mem_cache_mock.cleanup.side_effect = [ + Exception("Mem cache cleanup failure"), + None, + None, + None, + ] + disk_cache_mock.cleanup.side_effect = [ + None, + Exception("Disk cache cleanup failure"), + None, + None, + ] + + # Set up singleton instances to be cleaned + cache_factory._memory_cache_instance = mem_cache_mock + cache_factory._file_cache_instance = disk_cache_mock + + # Start the cleanup loop + cache_factory.start_periodic_cleanup(timedelta(milliseconds=50)) + + # Wait for multiple cleanup cycles to occur + time.sleep(0.15) + + # Stop the cleanup loop + cache_factory.stop_periodic_cleanup() + + # Verify that cleanup was attempted multiple times despite the exception + assert mem_cache_mock.cleanup.call_count > 1 + assert disk_cache_mock.cleanup.call_count > 1 + + +def test_cleanup_loop_stops_gracefully_with_shutdown_event( + cache_factory, mem_cache_mock, disk_cache_mock +): + """Test that the cleanup loop stops gracefully when shutdown event is set""" + + # Set up singleton instances to be cleaned + cache_factory._memory_cache_instance = mem_cache_mock + cache_factory._file_cache_instance = disk_cache_mock + + # Start cleanup with longer interval to test shutdown + cache_factory.start_periodic_cleanup(timedelta(hours=1)) + + # Give it a moment to make first cleanup cycle + time.sleep(0.1) + + # Stop the cleanup loop - this should interrupt the wait + cache_factory.stop_periodic_cleanup() + + # Verify cleanup was called at least once (initial call) + assert mem_cache_mock.cleanup.call_count == 1 + assert disk_cache_mock.cleanup.call_count == 1 + + +def test_cleanup_loop_double_stop_is_safe(cache_factory): + """Test that calling stop_periodic_cleanup multiple times is safe""" + # Start the cleanup loop + cache_factory.start_periodic_cleanup(timedelta(milliseconds=50)) + assert cache_factory.is_periodic_cleanup_running() + + # Stop it once + cache_factory.stop_periodic_cleanup() + assert not cache_factory.is_periodic_cleanup_running() + + # Stop it again - should not raise any exceptions + cache_factory.stop_periodic_cleanup() + assert not cache_factory.is_periodic_cleanup_running() + + +def test_cleanup_loop_double_start_is_safe_and_restarts( + cache_factory, mem_cache_mock, disk_cache_mock +): + """Test that calling start_periodic_cleanup multiple times creates new executors""" + # Set up singleton instances to be cleaned + cache_factory._memory_cache_instance = mem_cache_mock + cache_factory._file_cache_instance = disk_cache_mock + + for i in range(1, 3): + cache_factory.start_periodic_cleanup(timedelta(hours=1)) + time.sleep(0.1) + # The cleanup should be in the running state and by this moment successfully made exactly one additional cleanup cycle + assert cache_factory.is_periodic_cleanup_running() + assert mem_cache_mock.cleanup.call_count == i + assert disk_cache_mock.cleanup.call_count == i + + +# New comprehensive error handling tests +def test_file_cache_directory_creation_error(): + """Test CRLFileCache handles directory creation errors gracefully""" + # Create a path that would cause permission error + with tempfile.TemporaryDirectory() as temp_dir: + cache_dir = Path(temp_dir) / "restricted" + + # Mock os.makedirs to raise PermissionError + with mock_patch( + "os.makedirs", side_effect=PermissionError("Permission denied") + ): + cache = CRLFileCache(cache_dir=cache_dir) + + # Should still work, but directory operations may fail gracefully + entry = CRLCacheEntry(b"test_crl", datetime.now(timezone.utc)) + # This should not crash even if directory creation fails + cache.put("test_key", entry) + + +def test_file_cache_file_write_error(): + """Test CRLFileCache handles file write errors gracefully""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_dir = Path(temp_dir) + cache = CRLFileCache(cache_dir=cache_dir) + + entry = CRLCacheEntry(b"test_crl", datetime.now(timezone.utc)) + + # Mock open to raise IOError on write + mock_file = mock_open() + mock_file.return_value.write.side_effect = IOError("Disk full") + + with mock_patch("builtins.open", mock_file): + # Should not crash, but may log error + cache.put("test_key", entry) + + +def test_file_cache_file_read_error(): + """Test CRLFileCache handles file read errors gracefully""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_dir = Path(temp_dir) + cache = CRLFileCache(cache_dir=cache_dir) + + # First put a valid entry + entry = CRLCacheEntry(b"test_crl", datetime.now(timezone.utc)) + cache.put("test_key", entry) + + # Mock open to raise IOError on read + with mock_patch("builtins.open", side_effect=IOError("File corrupted")): + # Should return None instead of crashing + result = cache.get("test_key") + assert result is None + + +def test_file_cache_cleanup_file_removal_error(): + """Test CRLFileCache cleanup handles file removal errors gracefully""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_dir = Path(temp_dir) + cache = CRLFileCache(cache_dir=cache_dir, removal_delay=timedelta(seconds=0)) + + # Put an entry that should be removed immediately + entry = CRLCacheEntry( + b"test_crl", datetime.now(timezone.utc) - timedelta(days=1) + ) + cache.put("test_key", entry) + + # Mock os.remove to raise PermissionError + with mock_patch("os.remove", side_effect=PermissionError("File in use")): + # Should not crash during cleanup + cache.cleanup() + + +def test_factory_warning_messages_for_memory_cache(cache_factory): + """Test CRLCacheFactory logs appropriate warning for memory cache parameter mismatch""" + # First call with one validity time + cache1 = cache_factory.get_memory_cache(timedelta(hours=1)) + + # Second call with different validity time should log warning + with mock_patch("snowflake.connector.crl_cache.logger.warning") as mock_warning: + cache2 = cache_factory.get_memory_cache(timedelta(hours=2)) + + # Should return same instance + assert cache1 is cache2 + + # Should have logged warning with human-readable message + mock_warning.assert_called_once() + warning_msg = mock_warning.call_args[0][0] + assert "CRLs in-memory cache has already been initialized" in warning_msg + assert "1:00:00" in warning_msg # Original time + assert "2:00:00" in warning_msg # New time + + +def test_factory_warning_messages_for_file_cache(cache_factory): + """Test CRLCacheFactory logs appropriate warning for file cache parameter mismatch""" + with tempfile.TemporaryDirectory() as temp_dir1, tempfile.TemporaryDirectory() as temp_dir2: + cache_dir1 = Path(temp_dir1) + cache_dir2 = Path(temp_dir2) + + # First call with one directory and delay + cache1 = cache_factory.get_file_cache(cache_dir1, timedelta(days=7)) + + # Second call with different parameters should log warnings + with mock_patch("snowflake.connector.crl_cache.logger.warning") as mock_warning: + cache2 = cache_factory.get_file_cache(cache_dir2, timedelta(days=14)) + + # Should return same instance + assert cache1 is cache2 + + # Should have logged two warnings (for directory and delay) + assert mock_warning.call_count == 2 + + # Check warning messages + warning_calls = [call[0][0] for call in mock_warning.call_args_list] + dir_warning = next(msg for msg in warning_calls if "cache directory" in msg) + delay_warning = next(msg for msg in warning_calls if "removal delay" in msg) + + assert "CRLs file cache has already been initialized" in dir_warning + assert "CRLs file cache has already been initialized" in delay_warning + assert str(cache_dir1) in dir_warning + assert str(cache_dir2) in dir_warning + assert "7 days" in delay_warning + assert "14 days" in delay_warning + + +def test_platform_specific_cache_path(): + """Test _get_default_crl_cache_path returns platform-appropriate path""" + from snowflake.connector.crl_cache import _get_default_crl_cache_path + + # Test on different platforms + with mock_patch("platform.system") as mock_system, mock_patch( + "pathlib.Path.home" + ) as mock_home_path: + mock_home_path.return_value = Path("~") + + # Test Windows + mock_system.return_value = "Windows" + path = _get_default_crl_cache_path() + assert "AppData" in str(path) + assert "snowflake" in str(path).lower() + + # Test macOS + mock_system.return_value = "Darwin" + path = _get_default_crl_cache_path() + assert "Library" in str(path) + assert "snowflake" in str(path).lower() + + # Test Linux + mock_system.return_value = "Linux" + path = _get_default_crl_cache_path() + assert ".cache" in str(path) + assert "snowflake" in str(path).lower() + + +def test_atexit_handler_error_handling(cache_factory): + """Test atexit cleanup handler handles errors gracefully""" + # Start cleanup to register atexit handler + cache_factory.start_periodic_cleanup(timedelta(seconds=0.1)) + + # Mock stop_periodic_cleanup to raise exception + with mock_patch.object( + cache_factory, + "stop_periodic_cleanup", + side_effect=Exception("Test error"), + ): + # Calling atexit handler should not raise exception + try: + cache_factory._atexit_cleanup_handler() + except Exception as e: + pytest.fail(f"Atexit handler should not raise exceptions: {e}") diff --git a/test/unit/test_ssl_partial_chain.py b/test/unit/test_ssl_partial_chain.py index 4f74e63f95..1c691137d6 100644 --- a/test/unit/test_ssl_partial_chain.py +++ b/test/unit/test_ssl_partial_chain.py @@ -39,7 +39,7 @@ def fake_ssl_wrap_socket( # pylint: disable=unused-argument,too-many-arguments, monkeypatch.setattr(ssw.ssl_, "ssl_wrap_socket", fake_ssl_wrap_socket) # Call our wrapper without providing ssl_context; it should inject one - ssw.ssl_wrap_socket_with_ocsp( + ssw.ssl_wrap_socket_with_cert_revocation_checks( sock=None, keyfile=None, certfile=None, diff --git a/test/unit/test_ssl_partial_chain_handshake.py b/test/unit/test_ssl_partial_chain_handshake.py index 1e51ab4ff5..64abc58cfd 100644 --- a/test/unit/test_ssl_partial_chain_handshake.py +++ b/test/unit/test_ssl_partial_chain_handshake.py @@ -188,7 +188,7 @@ def test_partial_chain_handshake_succeeds_with_intermediate_as_anchor(): s.connect((host, port)) # The wrapper expects kwargs similar to urllib3; use provided context - ws = ssw.ssl_wrap_socket_with_ocsp( + ws = ssw.ssl_wrap_socket_with_cert_revocation_checks( sock=s, server_hostname="localhost", ssl_context=ctx, From 9a2c6cf6745054a3b1a16330e18ecd715b91a076 Mon Sep 17 00:00:00 2001 From: Patryk Czajka Date: Tue, 7 Oct 2025 11:19:50 +0200 Subject: [PATCH 02/16] Snow 2355881 Add CERT_REVOCATION_CHECK_MODE to CLIENT_ENVIRONMENT (#2562) --- DESCRIPTION.md | 20 +++++++- src/snowflake/connector/auth/_auth.py | 3 ++ src/snowflake/connector/auth/okta.py | 1 + src/snowflake/connector/auth/webbrowser.py | 1 + src/snowflake/connector/connection.py | 54 ++++++++++++++++------ src/snowflake/connector/crl.py | 6 +-- test/helpers.py | 2 +- test/integ/test_crl.py | 4 ++ test/unit/test_auth.py | 1 + test/unit/test_auth_keypair.py | 1 + test/unit/test_auth_okta.py | 1 + test/unit/test_auth_webbrowser.py | 1 + test/unit/test_crl.py | 10 ++++ test/unit/test_telemetry.py | 10 ++++ 14 files changed, 95 insertions(+), 20 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index d5483fd816..0608b90893 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -7,12 +7,30 @@ https://docs.snowflake.com/ Source code is also available at: https://github.com/snowflakedb/snowflake-connector-python # Release Notes -- v3.18.0(TBD) +- v4.1.0(TBD) + - Added `CERT_REVOCATION_CHECK_MODE` to `CLIENT_ENVIRONMENT` + +- v4.0.0(October 01,2025) - Added support for checking certificates revocation using revocation lists (CRLs) - Added the `workload_identity_impersonation_path` parameter to support service account impersonation for Workload Identity Federation on GCP and AWS workloads only - Fixed `get_results_from_sfqid` when using `DictCursor` and executing multiple statements at once - Added the `oauth_credentials_in_body` parameter supporting an option to send the oauth client credentials in the request body + - Fix retry behavior for `ECONNRESET` error + - Added an option to exclude `botocore` and `boto3` dependencies by setting `SNOWFLAKE_NO_BOTO` environment variable during installation + - Revert changing exception type in case of token expired scenario for `Oauth` authenticator back to `DatabaseError` + - Enhanced configuration file security checks with stricter permission validation. + - Configuration files writable by group or others now raise a `ConfigSourceError` with detailed permission information, preventing potential credential tampering. + - Added support for pandas conversion for Day-time and Year-Month Interval types + - Fixed the return type of `SnowflakeConnection.cursor(cursor_class)` to match the type of `cursor_class` + - Constrained the types of `fetchone`, `fetchmany`, `fetchall` + - As part of this fix, `DictCursor` is no longer a subclass of `SnowflakeCursor`; use `SnowflakeCursorBase` as a superclass of both. + - Fix "No AWS region was found" error if AWS region was set in `AWS_DEFAULT_REGION` variable instead of `AWS_REGION` for `WORKLOAD_IDENTITY` authenticator + - Add `ocsp_root_certs_dict_lock_timeout` connection parameter to set the timeout (in seconds) for acquiring the lock on the OCSP root certs dictionary. Default value for this parameter is -1 which indicates no timeout. + +- v3.17.4(September 22,2025) - Added support for intermediate certificates as roots when they are stored in the trust store + - Bumped up vendored `urllib3` to `2.5.0` and `requests` to `v2.32.5` + - Dropped support for OpenSSL versions older than 1.1.1 - v3.17.3(September 02,2025) - Enhanced configuration file permission warning messages. diff --git a/src/snowflake/connector/auth/_auth.py b/src/snowflake/connector/auth/_auth.py index 5dca31a361..a1e6ba302f 100644 --- a/src/snowflake/connector/auth/_auth.py +++ b/src/snowflake/connector/auth/_auth.py @@ -101,6 +101,7 @@ def base_auth_data( internal_application_name, internal_application_version, ocsp_mode, + cert_revocation_check_mode, login_timeout: int | None = None, network_timeout: int | None = None, socket_timeout: int | None = None, @@ -132,6 +133,7 @@ def base_auth_data( "PYTHON_RUNTIME": IMPLEMENTATION, "PYTHON_COMPILER": COMPILER, "OCSP_MODE": ocsp_mode.name, + "CERT_REVOCATION_CHECK_MODE": cert_revocation_check_mode, "TRACING": logger.getEffectiveLevel(), "LOGIN_TIMEOUT": login_timeout, "NETWORK_TIMEOUT": network_timeout, @@ -192,6 +194,7 @@ def authenticate( self._rest._connection._internal_application_name, self._rest._connection._internal_application_version, self._rest._connection._ocsp_mode(), + self._rest._connection.cert_revocation_check_mode, self._rest._connection.login_timeout, self._rest._connection._network_timeout, self._rest._connection._socket_timeout, diff --git a/src/snowflake/connector/auth/okta.py b/src/snowflake/connector/auth/okta.py index e6117216f1..cab810f1f2 100644 --- a/src/snowflake/connector/auth/okta.py +++ b/src/snowflake/connector/auth/okta.py @@ -166,6 +166,7 @@ def _step1( conn._internal_application_name, conn._internal_application_version, conn._ocsp_mode(), + conn.cert_revocation_check_mode, conn.login_timeout, conn.network_timeout, conn.socket_timeout, diff --git a/src/snowflake/connector/auth/webbrowser.py b/src/snowflake/connector/auth/webbrowser.py index 6f416cbdb3..3aa1b3d993 100644 --- a/src/snowflake/connector/auth/webbrowser.py +++ b/src/snowflake/connector/auth/webbrowser.py @@ -468,6 +468,7 @@ def _get_sso_url( conn._internal_application_name, conn._internal_application_version, conn._ocsp_mode(), + conn.cert_revocation_check_mode, conn.login_timeout, conn.network_timeout, conn.socket_timeout, diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index 93bf8302bd..7b887b6a55 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -86,6 +86,7 @@ QueryStatus, ) from .converter import SnowflakeConverter +from .crl import CRLConfig from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor, SnowflakeCursorBase from .description import ( CLIENT_NAME, @@ -433,7 +434,7 @@ def _get_private_bytes_from_file( ), # Read timeout for CRL downloads in milliseconds "crl_cache_validity_hours": ( None, - (type(None), int), + (type(None), float), ), # CRL cache validity time in hours "enable_crl_cache": (None, (type(None), bool)), # Enable CRL caching "enable_crl_file_cache": (None, (type(None), bool)), # Enable file-based CRL cache @@ -586,6 +587,7 @@ def __init__( # Placeholder attributes; will be initialized in connect() self._http_config: HttpConfig | None = None + self._crl_config: CRLConfig | None = None self._session_manager: SessionManager | None = None self._rest: SnowflakeRestful | None = None @@ -682,57 +684,81 @@ def _ocsp_mode(self) -> OCSPMode: @property def cert_revocation_check_mode(self) -> str | None: """Certificate revocation check mode: DISABLED, ENABLED, or ADVISORY.""" - return self._cert_revocation_check_mode + if not self._crl_config: + return self._cert_revocation_check_mode + return self._crl_config.cert_revocation_check_mode.value @property def allow_certificates_without_crl_url(self) -> bool | None: """Whether to allow certificates without CRL distribution points.""" - return self._allow_certificates_without_crl_url + if not self._crl_config: + return self._allow_certificates_without_crl_url + return self._crl_config.allow_certificates_without_crl_url @property def crl_connection_timeout_ms(self) -> int | None: """Connection timeout for CRL downloads in milliseconds.""" - return self._crl_connection_timeout_ms + if not self._crl_config: + return self._crl_connection_timeout_ms + return self._crl_config.connection_timeout_ms @property def crl_read_timeout_ms(self) -> int | None: """Read timeout for CRL downloads in milliseconds.""" - return self._crl_read_timeout_ms + if not self._crl_config: + return self._crl_read_timeout_ms + return self._crl_config.read_timeout_ms @property - def crl_cache_validity_hours(self) -> int | None: + def crl_cache_validity_hours(self) -> float | None: """CRL cache validity time in hours.""" - return self._crl_cache_validity_hours + if not self._crl_config: + return self._crl_cache_validity_hours + return self._crl_config.cache_validity_time.total_seconds() / 3600 @property def enable_crl_cache(self) -> bool | None: """Whether CRL caching is enabled.""" - return self._enable_crl_cache + if not self._crl_config: + return self._enable_crl_cache + return self._crl_config.enable_crl_cache @property def enable_crl_file_cache(self) -> bool | None: """Whether file-based CRL cache is enabled.""" - return self._enable_crl_file_cache + if not self._crl_config: + return self._enable_crl_file_cache + return self._crl_config.enable_crl_file_cache @property def crl_cache_dir(self) -> str | None: """Directory for CRL file cache.""" - return self._crl_cache_dir + if not self._crl_config: + return self._crl_cache_dir + if not self._crl_config.crl_cache_dir: + return None + return str(self._crl_config.crl_cache_dir) @property def crl_cache_removal_delay_days(self) -> int | None: """Days to keep expired CRL files before removal.""" - return self._crl_cache_removal_delay_days + if not self._crl_config: + return self._crl_cache_removal_delay_days + return self._crl_config.crl_cache_removal_delay_days @property def crl_cache_cleanup_interval_hours(self) -> int | None: """CRL cache cleanup interval in hours.""" - return self._crl_cache_cleanup_interval_hours + if not self._crl_config: + return self._crl_cache_cleanup_interval_hours + return self._crl_config.crl_cache_cleanup_interval_hours @property def crl_cache_start_cleanup(self) -> bool | None: """Whether to start CRL cache cleanup immediately.""" - return self._crl_cache_start_cleanup + if not self._crl_config: + return self._crl_cache_start_cleanup + return self._crl_config.crl_cache_start_cleanup @property def session_id(self) -> int: @@ -1035,6 +1061,8 @@ def connect(self, **kwargs) -> None: if len(kwargs) > 0: self.__config(**kwargs) + self._crl_config: CRLConfig = CRLConfig.from_connection(self) + self._http_config = HttpConfig( adapter_factory=ProxySupportAdapterFactory(), use_pooling=(not self.disable_request_pooling), diff --git a/src/snowflake/connector/crl.py b/src/snowflake/connector/crl.py index 69d5261d29..0b1203fc95 100644 --- a/src/snowflake/connector/crl.py +++ b/src/snowflake/connector/crl.py @@ -105,15 +105,11 @@ def from_connection(cls, sf_connection) -> CRLConfig: ) cert_revocation_check_mode = cls.cert_revocation_check_mode - if cert_revocation_check_mode == CertRevocationCheckMode.DISABLED: - # The rest of the parameters don't matter if CRL checking is disabled - return cls(cert_revocation_check_mode=cert_revocation_check_mode) - # Apply default value logic for all other parameters when connection attribute is None cache_validity_time = ( cls.cache_validity_time if sf_connection.crl_cache_validity_hours is None - else timedelta(hours=int(sf_connection.crl_cache_validity_hours)) + else timedelta(hours=float(sf_connection.crl_cache_validity_hours)) ) crl_cache_dir = ( cls.crl_cache_dir diff --git a/test/helpers.py b/test/helpers.py index 6c335c930e..441e51f011 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -328,6 +328,7 @@ def create_mock_auth_body(): "internal_application_name", "internal_application_version", ocsp_mode, + "CRL_MODE", login_timeout=60 * 60, network_timeout=60 * 60, socket_timeout=60 * 60, @@ -341,7 +342,6 @@ def apply_auth_class_update_body(auth_class, req_body_before): auth_class.update_body(req_body_after) return req_body_after - async def apply_auth_class_update_body_async(auth_class, req_body_before): req_body_after = copy.deepcopy(req_body_before) await auth_class.update_body(req_body_after) diff --git a/test/integ/test_crl.py b/test/integ/test_crl.py index 17c678463c..4b9ff21304 100644 --- a/test/integ/test_crl.py +++ b/test/integ/test_crl.py @@ -63,6 +63,10 @@ def test_crl_validation_advisory_mode(conn_cnx): assert cnx.cert_revocation_check_mode == "ADVISORY" assert cnx.allow_certificates_without_crl_url is False assert cnx.enable_crl_cache is True + assert cnx.crl_connection_timeout_ms == 3000 + assert cnx.crl_read_timeout_ms == 3000 + assert cnx.crl_cache_validity_hours == 1 + assert cnx.crl_cache_dir is None @pytest.mark.skipolddriver diff --git a/test/unit/test_auth.py b/test/unit/test_auth.py index cfae32f8c3..92af49542f 100644 --- a/test/unit/test_auth.py +++ b/test/unit/test_auth.py @@ -29,6 +29,7 @@ def _init_rest(application, post_requset): connection = mock_connection() connection.errorhandler = Mock(return_value=None) connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) + connection.cert_revocation_check_mode = "TEST_CRL_MODE" type(connection).application = PropertyMock(return_value=application) type(connection)._internal_application_name = PropertyMock(return_value=CLIENT_NAME) type(connection)._internal_application_version = PropertyMock( diff --git a/test/unit/test_auth_keypair.py b/test/unit/test_auth_keypair.py index 80c27e9602..d5a20e8e74 100644 --- a/test/unit/test_auth_keypair.py +++ b/test/unit/test_auth_keypair.py @@ -155,6 +155,7 @@ def _init_rest(application, post_requset): connection = mock_connection() connection.errorhandler = Mock(return_value=None) connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) + connection.cert_revocation_check_mode = "TEST_CRL_MODE" type(connection).application = PropertyMock(return_value=application) type(connection)._internal_application_name = PropertyMock(return_value=CLIENT_NAME) type(connection)._internal_application_version = PropertyMock( diff --git a/test/unit/test_auth_okta.py b/test/unit/test_auth_okta.py index 206f630969..9404fc1d85 100644 --- a/test/unit/test_auth_okta.py +++ b/test/unit/test_auth_okta.py @@ -345,6 +345,7 @@ def post_request(url, headers, body, **kwargs): connection = mock_connection(disable_saml_url_check=disable_saml_url_check) connection.errorhandler = Mock(return_value=None) connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) + connection.cert_revocation_check_mode = "TEST_CRL_MODE" type(connection).application = PropertyMock(return_value=CLIENT_NAME) type(connection)._internal_application_name = PropertyMock(return_value=CLIENT_NAME) type(connection)._internal_application_version = PropertyMock( diff --git a/test/unit/test_auth_webbrowser.py b/test/unit/test_auth_webbrowser.py index 9fab65b0a2..ac2512ddd2 100644 --- a/test/unit/test_auth_webbrowser.py +++ b/test/unit/test_auth_webbrowser.py @@ -364,6 +364,7 @@ def post_request(url, headers, body, **kwargs): connection = mock_connection() connection.errorhandler = Mock(return_value=None) connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) + connection.cert_revocation_check_mode = "TEST_CRL_MODE" connection._disable_console_login = disable_console_login type(connection).application = PropertyMock(return_value=CLIENT_NAME) type(connection)._internal_application_name = PropertyMock(return_value=CLIENT_NAME) diff --git a/test/unit/test_crl.py b/test/unit/test_crl.py index 303b700843..67ae3e6781 100644 --- a/test/unit/test_crl.py +++ b/test/unit/test_crl.py @@ -812,6 +812,16 @@ def test_crl_config_from_connection_disabled_mode(): mock_connection = Mock() mock_connection.cert_revocation_check_mode = "DISABLED" + mock_connection.allow_certificates_without_crl_url = None + mock_connection.crl_connection_timeout_ms = None + mock_connection.crl_read_timeout_ms = None + mock_connection.crl_cache_validity_hours = None + mock_connection.enable_crl_cache = None + mock_connection.enable_crl_file_cache = None + mock_connection.crl_cache_dir = None + mock_connection.crl_cache_removal_delay_days = None + mock_connection.crl_cache_cleanup_interval_hours = None + mock_connection.crl_cache_start_cleanup = None config = CRLConfig.from_connection(mock_connection) diff --git a/test/unit/test_telemetry.py b/test/unit/test_telemetry.py index 336a9d9c6e..7f62a642d7 100644 --- a/test/unit/test_telemetry.py +++ b/test/unit/test_telemetry.py @@ -409,6 +409,16 @@ def get_mocked_telemetry_connection(telemetry_enabled: bool = True) -> Mock: mock_connection.is_closed = False mock_connection.socket_timeout = None mock_connection.messages = [] + mock_connection.crl_cache_validity_hours = None + mock_connection.crl_cache_dir = None + mock_connection.crl_connection_timeout_ms = None + mock_connection.crl_read_timeout_ms = None + mock_connection.crl_cache_removal_delay_days = None + mock_connection.crl_cache_cleanup_interval_hours = None + mock_connection.crl_cache_start_cleanup = None + mock_connection.enable_crl_cache = None + mock_connection.enable_crl_file_cache = None + mock_connection.allow_certificates_without_crl_url = None from src.snowflake.connector.errors import Error From ebed46a0f967aae2a834e9847f13e2b5f0efede1 Mon Sep 17 00:00:00 2001 From: Patryk Czajka Date: Wed, 8 Oct 2025 15:12:21 +0200 Subject: [PATCH 03/16] Snow 2388762 crl post review fixes (#2567) Co-authored-by: James Kasten --- src/snowflake/connector/crl.py | 199 +++++-- src/snowflake/connector/ssl_wrap_socket.py | 24 +- test/integ/test_crl.py | 10 + test/unit/test_crl.py | 655 +++++++++++++++++++-- 4 files changed, 811 insertions(+), 77 deletions(-) diff --git a/src/snowflake/connector/crl.py b/src/snowflake/connector/crl.py index 0b1203fc95..0d774ebe18 100644 --- a/src/snowflake/connector/crl.py +++ b/src/snowflake/connector/crl.py @@ -1,6 +1,7 @@ #!/usr/bin/env python from __future__ import annotations +from collections import defaultdict from dataclasses import dataclass from datetime import datetime, timedelta, timezone from enum import Enum, unique @@ -11,6 +12,7 @@ from cryptography import x509 from cryptography.hazmat._oid import ExtensionOID from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ec, padding, rsa from OpenSSL.SSL import Connection as SSLConnection @@ -53,15 +55,15 @@ class CRLConfig: CertRevocationCheckMode.DISABLED ) allow_certificates_without_crl_url: bool = False - connection_timeout_ms: int = 3000 - read_timeout_ms: int = 3000 + connection_timeout_ms: int = 5000 + read_timeout_ms: int = 5000 # 5s cache_validity_time: timedelta = timedelta(hours=24) enable_crl_cache: bool = True enable_crl_file_cache: bool = True crl_cache_dir: Path | str | None = None crl_cache_removal_delay_days: int = 7 crl_cache_cleanup_interval_hours: int = 1 - crl_cache_start_cleanup: bool = False + crl_cache_start_cleanup: bool = True @classmethod def from_connection(cls, sf_connection) -> CRLConfig: @@ -176,6 +178,7 @@ class CRLValidator: def __init__( self, session_manager: SessionManager | Any, + trusted_certificates: list[x509.Certificate], cert_revocation_check_mode: CertRevocationCheckMode = CRLConfig.cert_revocation_check_mode, allow_certificates_without_crl_url: bool = CRLConfig.allow_certificates_without_crl_url, connection_timeout_ms: int = CRLConfig.connection_timeout_ms, @@ -191,9 +194,22 @@ def __init__( self._cache_validity_time = cache_validity_time self._cache_manager = cache_manager or CRLCacheManager.noop() + # list of trusted CA and their certificates + self._trusted_ca: dict[x509.Name, list[x509.Certificate]] = defaultdict(list) + for cert in trusted_certificates: + self._trusted_ca[cert.subject].append(cert) + + # declaration of validate_certificate_is_not_revoked function cache + self._cache_for__validate_certificate_is_not_revoked: dict[ + x509.Certificate, CRLValidationResult + ] = {} + @classmethod def from_config( - cls, config: CRLConfig, session_manager: SessionManager + cls, + config: CRLConfig, + session_manager: SessionManager, + trusted_certificates: list[x509.Certificate], ) -> CRLValidator: """ Create a CRLValidator instance from a CRLConfig. @@ -204,6 +220,7 @@ def from_config( Args: config: CRLConfig instance containing CRL-related parameters session_manager: SessionManager instance + trusted_certificates: List of trusted CA certificates Returns: CRLValidator: Configured CRLValidator instance @@ -244,6 +261,7 @@ def from_config( return cls( session_manager=session_manager, + trusted_certificates=trusted_certificates, cert_revocation_check_mode=config.cert_revocation_check_mode, allow_certificates_without_crl_url=config.allow_certificates_without_crl_url, connection_timeout_ms=config.connection_timeout_ms, @@ -272,9 +290,7 @@ def validate_certificate_chains( if certificate_chains is None or len(certificate_chains) == 0: logger.warning("Certificate chains are empty") - if self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY: - return True - return False + return self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY results = [] for chain in certificate_chains: @@ -294,24 +310,133 @@ def validate_certificate_chains( def _validate_single_chain( self, chain: list[x509.Certificate] ) -> CRLValidationResult: - """Validate a single certificate chain""" + """ + Returns: + UNREVOKED: If there is a path to any trusted certificate where all certificates are unrevoked. + REVOKED: If all paths to trusted certificates are revoked. + ERROR: If there is a path to any trusted certificate on which none certificate is revoked, + but some certificates can't be verified. + """ # An empty chain is considered an error if len(chain) == 0: return CRLValidationResult.ERROR - # the last certificate of the chain is considered the root and isn't validated - results = [] - for i in range(len(chain) - 1): - result = self._validate_certificate(chain[i], chain[i + 1]) - if result == CRLValidationResult.REVOKED: - return CRLValidationResult.REVOKED - results.append(result) - if CRLValidationResult.ERROR in results: + subject_certificates: dict[x509.Name, list[x509.Certificate]] = defaultdict( + list + ) + for cert in chain: + subject_certificates[cert.subject].append(cert) + currently_visited_subjects: set[x509.Name] = set() + + def traverse_chain(cert: x509.Certificate) -> CRLValidationResult | None: + # UNREVOKED - unrevoked path to a trusted certificate found + # REVOKED - all paths are revoked + # ERROR - some certificates on potentially unrevoked paths can't be verified, or no path to a trusted CA is detected + # None - ignore this path (cycle detected) + if self._is_certificate_trusted_by_os(cert): + logger.debug("Found trusted certificate: %s", cert.subject) + return CRLValidationResult.UNREVOKED + + if trusted_ca_issuer := self._get_trusted_ca_issuer(cert): + logger.debug("Certificate signed by trusted CA: %s", cert.subject) + return self._validate_certificate_is_not_revoked_with_cache( + cert, trusted_ca_issuer + ) + + if cert.issuer in currently_visited_subjects: + # cycle detected - invalid path + return None + + valid_results: list[tuple[CRLValidationResult, x509.Certificate]] = [] + for ca_cert in subject_certificates[cert.issuer]: + if not self._verify_certificate_signature(cert, ca_cert): + logger.debug( + "Certificate signature verification failed for %s, looking for other paths", + cert, + ) + continue + + currently_visited_subjects.add(cert.issuer) + ca_result = traverse_chain(ca_cert) + currently_visited_subjects.remove(cert.issuer) + if ca_result is None: + # ignore invalid path result + continue + if ca_result == CRLValidationResult.UNREVOKED: + # good path found + return self._validate_certificate_is_not_revoked_with_cache( + cert, ca_cert + ) + valid_results.append((ca_result, ca_cert)) + + if len(valid_results) == 0: + # "root" certificate not cought by "is_trusted_by_os" check + logger.debug("No path towards trusted anchor: %s", cert.subject) + return CRLValidationResult.ERROR + + # check if there exists an ERROR path + for ca_result, ca_cert in valid_results: + if ca_result == CRLValidationResult.ERROR: + cert_result = self._validate_certificate_is_not_revoked_with_cache( + cert, ca_cert + ) + if cert_result == CRLValidationResult.REVOKED: + return CRLValidationResult.REVOKED + return CRLValidationResult.ERROR + + # no ERROR result found, all paths are REVOKED + return CRLValidationResult.REVOKED + + currently_visited_subjects.add(chain[0].subject) + error_result = False + revoked_result = False + for cert in subject_certificates[chain[0].subject]: + result = traverse_chain(cert) + if result == CRLValidationResult.UNREVOKED: + return result + error_result |= result == CRLValidationResult.ERROR + revoked_result |= result == CRLValidationResult.REVOKED + + if error_result or not revoked_result: return CRLValidationResult.ERROR + return CRLValidationResult.REVOKED - return CRLValidationResult.UNREVOKED + def _is_certificate_trusted_by_os(self, cert: x509.Certificate) -> bool: + if cert.subject not in self._trusted_ca: + return False + + cert_der = cert.public_bytes(serialization.Encoding.DER) + return any( + cert_der == trusted_cert.public_bytes(serialization.Encoding.DER) + for trusted_cert in self._trusted_ca[cert.subject] + ) - def _validate_certificate( + def _get_trusted_ca_issuer(self, cert: x509.Certificate) -> x509.Certificate | None: + for trusted_cert in self._trusted_ca[cert.issuer]: + if self._verify_certificate_signature(cert, trusted_cert): + return trusted_cert + return None + + def _verify_certificate_signature( + self, cert: x509.Certificate, ca_cert: x509.Certificate + ) -> bool: + try: + cert.verify_directly_issued_by(ca_cert) + return True + except Exception: + return False + + def _validate_certificate_is_not_revoked_with_cache( + self, cert: x509.Certificate, ca_cert: x509.Certificate + ) -> CRLValidationResult: + # validate certificate can be called multiple times with the same certificate + if cert not in self._cache_for__validate_certificate_is_not_revoked: + self._cache_for__validate_certificate_is_not_revoked[cert] = ( + self._validate_certificate_is_not_revoked(cert, ca_cert) + ) + return self._cache_for__validate_certificate_is_not_revoked[cert] + + def _validate_certificate_is_not_revoked( self, cert: x509.Certificate, ca_cert: x509.Certificate ) -> CRLValidationResult: """Validate a single certificate against CRL""" @@ -343,14 +468,29 @@ def _validate_certificate( @staticmethod def _is_short_lived_certificate(cert: x509.Certificate) -> bool: - """Check if certificate is short-lived (validity <= 5 days)""" + """Check if certificate is short-lived according to CA/Browser Forum definition: + - For certificates issued on or after 15 March 2024 and prior to 15 March 2026: + validity period <= 10 days (864,000 seconds) + - For certificates issued on or after 15 March 2026: + validity period <= 7 days (604,800 seconds) + """ try: # Use timezone.utc versions to avoid deprecation warnings + issue_date = cert.not_valid_before_utc validity_period = cert.not_valid_after_utc - cert.not_valid_before_utc except AttributeError: # Fallback for older versions + issue_date = cert.not_valid_before validity_period = cert.not_valid_after - cert.not_valid_before - return validity_period.days <= 5 + + # Convert issue_date to UTC if it's not timezone-aware + if issue_date.tzinfo is None: + issue_date = issue_date.replace(tzinfo=timezone.utc) + + march_15_2026 = datetime(2026, 3, 15, tzinfo=timezone.utc) + if issue_date >= march_15_2026: + return validity_period.total_seconds() <= 604800 # 7 days in seconds + return validity_period.total_seconds() <= 864000 # 10 days in seconds @staticmethod def _extract_crl_distribution_points(cert: x509.Certificate) -> list[str]: @@ -446,7 +586,7 @@ def _check_certificate_against_crl_url( ca_cert.subject, crl_url, ) - # In most cases this indicates a configuration issue, but we'll still try verification + return CRLValidationResult.ERROR if not self._verify_crl_signature(crl, ca_cert): logger.warning("CRL signature verification failed for URL: %s", crl_url) @@ -545,25 +685,16 @@ def _extract_certificate_chains_from_connection( Returns: List of certificate chains, where each chain is a list of x509.Certificate objects """ - from OpenSSL.crypto import FILETYPE_ASN1, dump_certificate - try: - cert_chain = connection.get_peer_cert_chain() + # Convert OpenSSL certificates to cryptography x509 certificates + cert_chain = connection.get_peer_cert_chain(as_cryptography=True) if not cert_chain: logger.debug("No certificate chain found in connection") return [] - - # Convert OpenSSL certificates to cryptography x509 certificates - x509_chain = [] - for cert_openssl in cert_chain: - cert_der = dump_certificate(FILETYPE_ASN1, cert_openssl) - cert_x509 = x509.load_der_x509_certificate(cert_der, default_backend()) - x509_chain.append(cert_x509) - logger.debug( - "Extracted %d certificates for CRL validation", len(x509_chain) + "Extracted %d certificates for CRL validation", len(cert_chain) ) - return [x509_chain] # Return as a single chain + return [cert_chain] # Return as a single chain except Exception as e: logger.warning( diff --git a/src/snowflake/connector/ssl_wrap_socket.py b/src/snowflake/connector/ssl_wrap_socket.py index 3c2f92ba80..4712cf395d 100644 --- a/src/snowflake/connector/ssl_wrap_socket.py +++ b/src/snowflake/connector/ssl_wrap_socket.py @@ -16,7 +16,7 @@ from functools import wraps from inspect import signature as _sig from socket import socket -from typing import Any +from typing import TYPE_CHECKING, Any import certifi import OpenSSL.SSL @@ -30,6 +30,9 @@ from .vendored.urllib3.contrib.pyopenssl import PyOpenSSLContext, WrappedSocket from .vendored.urllib3.util import ssl_ as ssl_ +if TYPE_CHECKING: + from cryptography import x509 + DEFAULT_OCSP_MODE: OCSPMode = OCSPMode.FAIL_OPEN FEATURE_OCSP_MODE: OCSPMode = DEFAULT_OCSP_MODE DEFAULT_CRL_CONFIG: CRLConfig = CRLConfig() @@ -147,6 +150,17 @@ def inject_into_urllib3() -> None: connection_.ssl_wrap_socket = ssl_wrap_socket_with_cert_revocation_checks +def _load_trusted_certificates(cafile: str | None) -> list[x509.Certificate]: + # Use default SSL context to load the CA file and get the certificates + ctx = ssl.create_default_context() + ctx.load_verify_locations(cafile=cafile) + certs = ctx.get_ca_certs(binary_form=True) + from cryptography.hazmat.backends import default_backend + from cryptography.x509 import load_der_x509_certificate + + return [load_der_x509_certificate(cert, default_backend()) for cert in certs] + + @wraps(ssl_.ssl_wrap_socket) def ssl_wrap_socket_with_cert_revocation_checks( *args: Any, **kwargs: Any @@ -163,12 +177,12 @@ def ssl_wrap_socket_with_cert_revocation_checks( # Ensure PyOpenSSL context with partial-chain is used if none or wrong type provided provided_ctx = params.get("ssl_context") + cafile_for_ctx = _resolve_cafile(params) if not isinstance(provided_ctx, PyOpenSSLContext): - cafile_for_ctx = _resolve_cafile(params) params["ssl_context"] = _build_context_with_partial_chain(cafile_for_ctx) else: # If a PyOpenSSLContext is provided, ensure it trusts the provided CA and partial-chain is enabled - _ensure_partial_chain_on_context(provided_ctx, _resolve_cafile(params)) + _ensure_partial_chain_on_context(provided_ctx, cafile_for_ctx) ret = ssl_.ssl_wrap_socket(**params) @@ -181,7 +195,9 @@ def ssl_wrap_socket_with_cert_revocation_checks( != CertRevocationCheckMode.DISABLED ): crl_validator = CRLValidator.from_config( - FEATURE_CRL_CONFIG, get_current_session_manager() + FEATURE_CRL_CONFIG, + get_current_session_manager(), + trusted_certificates=_load_trusted_certificates(cafile_for_ctx), ) if not crl_validator.validate_connection(ret.connection): raise OperationalError( diff --git a/test/integ/test_crl.py b/test/integ/test_crl.py index 4b9ff21304..8f97907646 100644 --- a/test/integ/test_crl.py +++ b/test/integ/test_crl.py @@ -8,8 +8,18 @@ from __future__ import annotations import tempfile +import warnings import pytest +from cryptography.utils import CryptographyDeprecationWarning + + +@pytest.fixture(autouse=True) +def _ignore_deprecation_warnings(): + """Fixture to handle deprecation warnings in all tests in this module.""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", CryptographyDeprecationWarning) + yield @pytest.mark.skipolddriver diff --git a/test/unit/test_crl.py b/test/unit/test_crl.py index 67ae3e6781..f85184840b 100644 --- a/test/unit/test_crl.py +++ b/test/unit/test_crl.py @@ -62,6 +62,27 @@ class CertificateChain: leaf_cert: x509.Certificate +@dataclass +class CrossSignedCertificateChain: + # CA + # / \ + # rootA rootB + # / \ + # A --(AsignB)--> B + # <-(BsignA)-- + # \ / + # leafA leafB + # \/ + # subject + + rootA: x509.Certificate + rootB: x509.Certificate + AsignB: x509.Certificate + BsignA: x509.Certificate + leafA: x509.Certificate + leafB: x509.Certificate + + @pytest.fixture(scope="module") def cert_gen(): class CertificateGeneratorUtil: @@ -242,6 +263,123 @@ def create_simple_chain(self) -> CertificateChain: return CertificateChain(root_cert, intermediate_cert, leaf_cert) + def create_cross_signed_chain(self) -> CertificateChain: + A_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + B_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + leaf_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + A_name = x509.Name( + [ + x509.NameAttribute( + NameOID.COMMON_NAME, + f"Test CA A {self.random.randint(1, 10000)}", + ) + ] + ) + B_name = x509.Name( + [ + x509.NameAttribute( + NameOID.COMMON_NAME, + f"Test CA B {self.random.randint(1, 10000)}", + ) + ] + ) + leaf_name = x509.Name( + [ + x509.NameAttribute( + NameOID.COMMON_NAME, + f"Test Leaf {self.random.randint(1, 10000)}", + ) + ] + ) + rootA_cert = ( + x509.CertificateBuilder() + .subject_name(A_name) + .issuer_name(self.ca_certificate.subject) + .public_key(A_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(self.ca_private_key, hashes.SHA256()) + ) + rootB_cert = ( + x509.CertificateBuilder() + .subject_name(B_name) + .issuer_name(self.ca_certificate.subject) + .public_key(B_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(self.ca_private_key, hashes.SHA256()) + ) + BsignA_cert = ( + x509.CertificateBuilder() + .subject_name(A_name) + .issuer_name(B_name) + .public_key(A_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=True, + ) + .sign(B_key, hashes.SHA256()) + ) + AsignB_cert = ( + x509.CertificateBuilder() + .subject_name(B_name) + .issuer_name(A_name) + .public_key(B_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=True, + ) + .sign(A_key, hashes.SHA256()) + ) + leafA_cert = ( + x509.CertificateBuilder() + .subject_name(leaf_name) + .issuer_name(A_name) + .public_key(leaf_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=True, + ) + .sign(A_key, hashes.SHA256()) + ) + leafB_cert = ( + x509.CertificateBuilder() + .subject_name(leaf_name) + .issuer_name(B_name) + .public_key(leaf_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=True, + ) + .sign(B_key, hashes.SHA256()) + ) + return CrossSignedCertificateChain( + rootA_cert, rootB_cert, AsignB_cert, BsignA_cert, leafA_cert, leafB_cert + ) + def create_short_lived_certificate( self, validity_days: int, issuance_date: datetime ) -> x509.Certificate: @@ -262,7 +400,7 @@ def create_short_lived_certificate( cert = ( x509.CertificateBuilder() .subject_name(name) - .issuer_name(name) # Self-signed for simplicity + .issuer_name(self.ca_certificate.subject) .public_key(key.public_key()) .serial_number(x509.random_serial_number()) .not_valid_before(issuance_date) @@ -271,7 +409,7 @@ def create_short_lived_certificate( x509.BasicConstraints(ca=False, path_length=None), critical=True, ) - .sign(key, hashes.SHA256()) + .sign(self.ca_private_key, hashes.SHA256(), backend=default_backend()) ) return cert @@ -347,6 +485,7 @@ def test_should_allow_connection_when_crl_validation_disabled( validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.DISABLED, + trusted_certificates=[chain.root_cert], ) assert validator.validate_certificate_chains(chains) @@ -358,6 +497,7 @@ def test_should_allow_connection_when_crl_validation_disabled_and_no_cert_chain( validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.DISABLED, + trusted_certificates=[], ) assert validator.validate_certificate_chains([]) assert validator.validate_certificate_chains(None) @@ -368,6 +508,7 @@ def test_should_fail_with_null_or_empty_certificate_chains(cert_gen, session_man validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + trusted_certificates=[], ) assert not validator.validate_certificate_chains([]) assert not validator.validate_certificate_chains(None) @@ -383,6 +524,7 @@ def test_should_handle_certificates_without_crl_urls_in_enabled_mode( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, allow_certificates_without_crl_url=False, + trusted_certificates=[chain.root_cert], ) assert not validator.validate_certificate_chains(chains) @@ -398,6 +540,7 @@ def test_should_allow_certificates_without_crl_urls_when_configured( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, allow_certificates_without_crl_url=True, + trusted_certificates=[chain.root_cert], ) assert validator.validate_certificate_chains(chains) @@ -410,14 +553,14 @@ def test_should_pass_in_advisory_mode_even_with_errors(cert_gen, session_manager validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ADVISORY, + trusted_certificates=[chain.root_cert], ) assert validator.validate_certificate_chains(chains) def test_should_validate_multiple_chains_and_return_first_valid_with_no_crl_urls( - cert_gen, - session_manager, + cert_gen, session_manager ): """Test validation of multiple chains and return first valid""" # Create a certificate that would be considered invalid (before March 2024) @@ -437,12 +580,169 @@ def test_should_validate_multiple_chains_and_return_first_valid_with_no_crl_urls session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, allow_certificates_without_crl_url=True, + trusted_certificates=[valid_chain.root_cert], ) result = validator.validate_certificate_chains(chains) assert result, "Should return true when at least one valid chain is found" +def test_cross_signed_certificate_chain(cert_gen, session_manager): + """Test validation of cross-signed certificate chain""" + chain = cert_gen.create_cross_signed_chain() + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + allow_certificates_without_crl_url=True, + trusted_certificates=[cert_gen.ca_certificate], + ) + + # provide full chain in arbitrary order + chains = [ + [ + chain.leafA, + chain.AsignB, + chain.leafB, + chain.BsignA, + chain.rootB, + chain.rootA, + ] + ] + assert validator.validate_certificate_chains(chains) + + # only A is signed by CA + chains = [ + [ + chain.leafA, + chain.AsignB, + chain.leafB, + chain.BsignA, + # chain.rootB, + chain.rootA, + ] + ] + assert validator.validate_certificate_chains(chains) + + # nor A nor B is signed by CA + chains = [ + [ + chain.leafA, + chain.AsignB, + chain.leafB, + chain.BsignA, + # chain.rootB, + # chain.rootA, + ] + ] + assert not validator.validate_certificate_chains(chains) + + # mingled A and B paths passed in one chain - A has no connection to CA, B has + chains = [ + [ + chain.leafA, + chain.AsignB, + chain.leafB, + # chain.BsignA, + chain.rootB, + # chain.rootA, + ] + ] + assert validator.validate_certificate_chains(chains) + + +def test_starfield_incident(cert_gen, session_manager): + # leaf is signed by A, who is signed by CA (revoked) and B, who is also a trusted CA + chain = cert_gen.create_cross_signed_chain() + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + trusted_certificates=[cert_gen.ca_certificate, chain.rootB], + ) + + def mock_validate(cert, _): + if cert == chain.rootA: + return CRLValidationResult.REVOKED + return CRLValidationResult.UNREVOKED + + validator._validate_certificate_is_not_revoked = mock_validate + + assert ( + validator._validate_single_chain([chain.leafA, chain.BsignA, chain.rootA]) + == CRLValidationResult.UNREVOKED + ) + + +def test_validate_single_chain(cert_gen, session_manager): + chain = cert_gen.create_cross_signed_chain() + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + trusted_certificates=[cert_gen.ca_certificate], + ) + + input_chain = [chain.leafA, chain.leafB, chain.rootA, chain.rootB] + + # case 1: at least one valid path + def mock_validate_with_special_cert(revoked_cert, error_result): + validator._validate_certificate_is_not_revoked_with_cache = lambda cert, _: ( + error_result if cert == revoked_cert else CRLValidationResult.UNREVOKED + ) + + for error_result in [CRLValidationResult.ERROR, CRLValidationResult.REVOKED]: + for revoked_cert in [chain.rootA, chain.rootB, chain.leafA, chain.leafB]: + mock_validate_with_special_cert(revoked_cert, error_result) + assert ( + validator._validate_single_chain(input_chain) + == CRLValidationResult.UNREVOKED + ) + + # case 2: all paths revoked + def mock_validate(cert, _): + if cert in [chain.rootA, chain.rootB]: + return CRLValidationResult.REVOKED + return CRLValidationResult.UNREVOKED + + validator._validate_certificate_is_not_revoked_with_cache = mock_validate + assert validator._validate_single_chain(input_chain) == CRLValidationResult.REVOKED + + # case 3: revoked + error should result in revoked\ + def mock_validate(cert, _): + if cert in [chain.rootA, chain.leafB]: + return CRLValidationResult.REVOKED + return CRLValidationResult.ERROR + + validator._validate_certificate_is_not_revoked_with_cache = mock_validate + assert validator._validate_single_chain(input_chain) == CRLValidationResult.REVOKED + + # case 4: no path to trusted certificate + def mock_validate(cert, _): + return CRLValidationResult.UNREVOKED + + validator._validate_certificate_is_not_revoked_with_cache = mock_validate + assert ( + validator._validate_single_chain( + [chain.leafA, chain.leafB, chain.AsignB, chain.BsignA] + ) + == CRLValidationResult.ERROR + ) + + # case 5: only unrevoked path has an error + def mock_validate(cert, _): + if cert in [chain.rootA, chain.leafB]: + return CRLValidationResult.REVOKED + if cert == chain.BsignA: + return CRLValidationResult.ERROR + return CRLValidationResult.UNREVOKED + + validator._validate_certificate_is_not_revoked_with_cache = mock_validate + assert ( + validator._validate_single_chain( + [chain.leafA, chain.rootA, chain.leafB, chain.rootB, chain.BsignA] + ) + == CRLValidationResult.ERROR + ) + + @responses.activate def test_should_validate_non_revoked_certificate_successfully( cert_gen, crl_urls, session_manager @@ -467,6 +767,38 @@ def test_should_validate_non_revoked_certificate_successfully( validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + trusted_certificates=[cert_gen.ca_certificate], + ) + + assert validator.validate_certificate_chains([chain]) + assert resp.call_count + + +@responses.activate +def test_should_validate_non_revoked_certificate_successfully_if_root_not_provided_on_chain( + cert_gen, crl_urls, session_manager +): + """Test validation of non-revoked certificate""" + # Setup mock HTTP client + crl_content = cert_gen.generate_valid_crl() + resp = responses.add( + responses.GET, + crl_urls.test_ca, + body=crl_content, + status=200, + content_type="application/pkcs7-mime", + ) + + # Create certificate with CRL distribution point + cert = cert_gen.create_certificate_with_crl_distribution_points( + "CN=Test Server", [crl_urls.test_ca] + ) + chain = [cert] + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + trusted_certificates=[cert_gen.ca_certificate], ) assert validator.validate_certificate_chains([chain]) @@ -495,6 +827,7 @@ def test_should_fail_for_revoked_certificate(cert_gen, crl_urls, session_manager validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + trusted_certificates=[cert_gen.ca_certificate], ) assert not validator.validate_certificate_chains([chain]) @@ -524,6 +857,7 @@ def test_should_allow_revoked_certificate_when_crl_validation_disabled( validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.DISABLED, + trusted_certificates=[cert_gen.ca_certificate], ) assert validator.validate_certificate_chains([chain]) @@ -546,6 +880,7 @@ def test_should_pass_in_advisory_mode_with_crl_errors( validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ADVISORY, + trusted_certificates=[cert_gen.ca_certificate], ) assert validator.validate_certificate_chains([chain]) @@ -568,6 +903,7 @@ def test_should_fail_in_enabled_mode_with_crl_errors( validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + trusted_certificates=[cert_gen.ca_certificate], ) assert not validator.validate_certificate_chains([chain]) @@ -606,6 +942,7 @@ def test_should_validate_multiple_chains_and_success_if_just_one_valid( validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + trusted_certificates=[cert_gen.ca_certificate], ) assert validator.validate_certificate_chains([invalid_chain, valid_chain]) @@ -633,6 +970,7 @@ def test_should_reject_expired_crl(cert_gen, crl_urls, session_manager): validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + trusted_certificates=[cert_gen.ca_certificate], ) assert not validator.validate_certificate_chains([chain]) @@ -650,6 +988,7 @@ def test_should_skip_short_lived_certificates(cert_gen, session_manager): validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + trusted_certificates=[cert_gen.ca_certificate], ) # Should pass without any HTTP calls (no responses setup) @@ -691,6 +1030,7 @@ def test_should_handle_multiple_crl_distribution_points( validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + trusted_certificates=[cert_gen.ca_certificate], ) assert validator.validate_certificate_chains([chain]) @@ -702,12 +1042,14 @@ def test_crl_validator_creation(session_manager): """Test that CRLValidator can be created properly""" # Test basic instantiation - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) assert validator is not None assert isinstance(validator, CRLValidator) # Test that it works with from_config class method - validator = CRLValidator.from_config(CRLConfig(), session_manager) + validator = CRLValidator.from_config( + CRLConfig(), session_manager, trusted_certificates=[] + ) assert validator is not None assert isinstance(validator, CRLValidator) @@ -725,7 +1067,7 @@ def test_crl_validator_atexit_cleanup(session_manager): try: # Create validator which should start cleanup - CRLValidator.from_config(config, session_manager) + CRLValidator.from_config(config, session_manager, trusted_certificates=[]) # Verify cleanup is running through factory assert CRLCacheFactory.is_periodic_cleanup_running() @@ -750,14 +1092,16 @@ def test_crl_validator_validate_connection(session_manager): # Test with no certificate chain mock_connection.get_peer_cert_chain.return_value = [] - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) # Should return True when disabled (default) assert validator.validate_connection(mock_connection) # Test with enabled mode and no certificates validator = CRLValidator( - session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + trusted_certificates=[], ) assert not validator.validate_connection(mock_connection) @@ -766,7 +1110,9 @@ def test_crl_validator_extract_certificate_chains_from_connection( cert_gen, session_manager ): """Test the _extract_certificate_chains_from_connection method""" - validator = CRLValidator(session_manager) + chain = cert_gen.create_simple_chain() + + validator = CRLValidator(session_manager, trusted_certificates=[chain.root_cert]) # Test with no certificate chain mock_connection = Mock() @@ -776,7 +1122,6 @@ def test_crl_validator_extract_certificate_chains_from_connection( assert chains == [] # Test with mock certificate chain - chain = cert_gen.create_simple_chain() mock_certs = [] # Create mock OpenSSL certificates @@ -992,7 +1337,9 @@ def test_crl_validator_download_crl_success(cert_gen, session_manager): content_type="application/pkcs7-mime", ) - validator = CRLValidator(session_manager) + validator = CRLValidator( + session_manager, trusted_certificates=[cert_gen.ca_certificate] + ) # Test the download method - it returns a tuple (crl, timestamp) crl, timestamp = validator._download_crl(crl_url) @@ -1008,7 +1355,7 @@ def test_crl_validator_download_crl_http_error(session_manager): responses.add(responses.GET, crl_url, status=404) - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) # Should return (None, None) on HTTP error crl, timestamp = validator._download_crl(crl_url) @@ -1024,7 +1371,10 @@ def test_crl_validator_download_crl_network_timeout(session_manager): crl_url = "http://example.com/slow.crl" validator = CRLValidator( - session_manager, connection_timeout_ms=1000, read_timeout_ms=1000 + session_manager, + connection_timeout_ms=1000, + read_timeout_ms=1000, + trusted_certificates=[], ) # Mock requests to raise timeout @@ -1045,7 +1395,7 @@ def test_crl_validator_download_crl_network_error(session_manager): crl_url = "http://example.com/unreachable.crl" - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) # Mock requests to raise connection error with mock_patch.object( @@ -1064,7 +1414,7 @@ def test_crl_validator_extract_crl_distribution_points_success( crl_urls = ["http://example.com/ca.crl", "http://backup.com/ca.crl"] cert = cert_gen.create_certificate_with_crl_distribution_points("CN=Test", crl_urls) - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) extracted_urls = validator._extract_crl_distribution_points(cert) @@ -1081,7 +1431,7 @@ def test_crl_validator_extract_crl_distribution_points_no_extension( chain = cert_gen.create_simple_chain() cert = chain.leaf_cert - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) # Should return empty list when no CRL extension found extracted_urls = validator._extract_crl_distribution_points(cert) @@ -1102,7 +1452,7 @@ def test_crl_validator_check_certificate_against_crl_not_revoked( mock_crl = Mock(spec=CertificateRevocationList) mock_crl.get_revoked_certificate_by_serial_number.return_value = None - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) # Should return UNREVOKED result = validator._check_certificate_against_crl(cert, mock_crl) @@ -1122,7 +1472,7 @@ def test_crl_validator_check_certificate_against_crl_revoked(cert_gen, session_m mock_crl = Mock(spec=CertificateRevocationList) mock_crl.get_revoked_certificate_by_serial_number.return_value = mock_revoked_cert - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) # Should return REVOKED result = validator._check_certificate_against_crl(cert, mock_crl) @@ -1143,12 +1493,15 @@ def test_crl_validator_check_certificate_against_crl_expired( mock_crl = Mock(spec=x509.CertificateRevocationList) mock_crl.next_update_utc = datetime.now(timezone.utc) - timedelta(days=1) # Expired mock_crl.get_revoked_certificate_by_serial_number.return_value = None + mock_crl.issuer = parent.subject # Cache will return an expired CRL mock_cache_mgr = Mock(spec=CRLCacheManager) mock_cache_mgr.get.return_value = CRLCacheEntry(mock_crl, datetime.now()) - validator = CRLValidator(session_manager, cache_manager=mock_cache_mgr) + validator = CRLValidator( + session_manager, cache_manager=mock_cache_mgr, trusted_certificates=[] + ) with mock_patch.object( validator, "_download_crl", return_value=(mock_crl, datetime.now()) ) as mock_download, mock_patch.object( @@ -1177,11 +1530,12 @@ def test_crl_validator_validate_certificate_with_cache_hit( # Mock cache manager with cache hit mock_crl = Mock(spec=x509.CertificateRevocationList) mock_crl.next_update_utc = datetime.now(timezone.utc) + timedelta(days=7) + mock_crl.issuer = ca_cert.subject mock_cache_manager = Mock() cached_entry = CRLCacheEntry(mock_crl, datetime.now(timezone.utc)) mock_cache_manager.get.return_value = cached_entry - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) validator._cache_manager = mock_cache_manager # Mock CRL parsing and validation @@ -1192,7 +1546,7 @@ def test_crl_validator_validate_certificate_with_cache_hit( ) as mock_check, mock_patch.object( validator, "_verify_crl_signature", return_value=True ) as mock_verify: - result = validator._validate_certificate(cert, ca_cert) + result = validator._validate_certificate_is_not_revoked(cert, ca_cert) # Should use cached CRL assert result == CRLValidationResult.UNREVOKED @@ -1215,7 +1569,11 @@ def test_crl_validator_validate_certificate_with_cache_miss( mock_cache_manager = Mock() mock_cache_manager.get.return_value = None - validator = CRLValidator(session_manager, cache_manager=mock_cache_manager) + validator = CRLValidator( + session_manager, + cache_manager=mock_cache_manager, + trusted_certificates=[], + ) # Mock successful download and validation with mock_patch.object( @@ -1229,12 +1587,11 @@ def test_crl_validator_validate_certificate_with_cache_miss( ) as mock_check, mock_patch.object( validator, "_verify_crl_signature", return_value=True ) as mock_verify: - mock_crl = Mock() mock_crl.next_update_utc = datetime.now(timezone.utc) + timedelta(days=7) + mock_crl.issuer = ca_cert.subject # Set the CRL issuer to match CA subject mock_load_crl.return_value = mock_crl - - result = validator._validate_certificate(cert, ca_cert) + result = validator._validate_certificate_is_not_revoked(cert, ca_cert) # Should download CRL and cache it assert result == CRLValidationResult.UNREVOKED @@ -1251,7 +1608,7 @@ def test_crl_signature_verification_success(cert_gen, session_manager): crl_bytes = cert_gen.generate_valid_crl() crl = x509.load_der_x509_crl(crl_bytes, backend=default_backend()) - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) # Should successfully verify the signature result = validator._verify_crl_signature(crl, cert_gen.ca_certificate) @@ -1286,7 +1643,7 @@ def test_crl_signature_verification_failure_wrong_ca(cert_gen, session_manager): .sign(different_ca_key, hashes.SHA256(), backend=default_backend()) ) - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) # Should fail to verify the signature with wrong CA result = validator._verify_crl_signature(crl, different_ca_cert) @@ -1339,7 +1696,7 @@ def test_crl_signature_verification_with_ec_key(session_manager): ec_crl = builder.sign(ec_private_key, hashes.SHA256(), backend=default_backend()) - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) # Should successfully verify EC signature result = validator._verify_crl_signature(ec_crl, ec_ca_cert) @@ -1359,7 +1716,7 @@ def test_crl_signature_verification_with_corrupted_signature(cert_gen, session_m corrupted_crl.signature = b"corrupted_signature_bytes" corrupted_crl.tbs_certlist_bytes = crl.tbs_certlist_bytes - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) # Should fail to verify corrupted signature result = validator._verify_crl_signature(corrupted_crl, cert_gen.ca_certificate) @@ -1373,10 +1730,12 @@ def test_crl_signature_verification_exception_handling(cert_gen, session_manager crl = x509.load_der_x509_crl(crl_bytes, backend=default_backend()) # Mock CA certificate that will cause an exception + mock_public_key = Mock() + mock_public_key.verify.side_effect = Exception("Test exception") mock_ca_cert = Mock(spec=x509.Certificate) - mock_ca_cert.public_key.side_effect = Exception("Test exception") + mock_ca_cert.public_key.return_value = mock_public_key - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) # Should handle exception gracefully and return False result = validator._verify_crl_signature(crl, mock_ca_cert) @@ -1415,12 +1774,15 @@ def test_crl_signature_verification_integration_with_validation_flow( validator_enabled = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + trusted_certificates=[], ) with mock_patch.object( validator_enabled, "_fetch_crl_from_url", return_value=invalid_crl_bytes ): - result = validator_enabled._validate_certificate(cert, cert_gen.ca_certificate) + result = validator_enabled._validate_certificate_is_not_revoked( + cert, cert_gen.ca_certificate + ) assert result == CRLValidationResult.ERROR # Test in ADVISORY mode - should also fail due to signature verification failure @@ -1428,12 +1790,15 @@ def test_crl_signature_verification_integration_with_validation_flow( validator_advisory = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ADVISORY, + trusted_certificates=[], ) with mock_patch.object( validator_advisory, "_fetch_crl_from_url", return_value=invalid_crl_bytes ): - result = validator_advisory._validate_certificate(cert, cert_gen.ca_certificate) + result = validator_advisory._validate_certificate_is_not_revoked( + cert, cert_gen.ca_certificate + ) # Even in ADVISORY mode, signature verification failure should return ERROR # We cannot trust a CRL whose signature cannot be verified assert result == CRLValidationResult.ERROR @@ -1444,8 +1809,8 @@ def test_crl_signature_verification_with_issuer_mismatch_warning( ): """Test that we log a warning when CRL issuer doesn't match CA certificate subject""" # Create a valid CRL signed by the test CA - crl_bytes = cert_gen.generate_valid_crl() - crl = x509.load_der_x509_crl(crl_bytes, backend=default_backend()) + crl = Mock(spec=x509.CertificateRevocationList) + crl.issuer = cert_gen.ca_certificate.subject # Create a different CA certificate with different subject different_ca_key = rsa.generate_private_key( @@ -1469,7 +1834,7 @@ def test_crl_signature_verification_with_issuer_mismatch_warning( .sign(different_ca_key, hashes.SHA256(), backend=default_backend()) ) - validator = CRLValidator(session_manager) + validator = CRLValidator(session_manager, trusted_certificates=[]) # Mock the _verify_crl_signature to return True to focus on the issuer check with mock_patch.object( @@ -1491,8 +1856,8 @@ def test_crl_signature_verification_with_issuer_mismatch_warning( "http://test.crl", ) - # Should still return UNREVOKED since signature verification was mocked to succeed - assert result == CRLValidationResult.UNREVOKED + # Should return ERROR due to issuer mismatch + assert result == CRLValidationResult.ERROR # Verify that the warning was logged assert len(caplog.records) > 0 @@ -1505,3 +1870,215 @@ def test_crl_signature_verification_with_issuer_mismatch_warning( assert ( warning_found ), f"Expected warning about CRL issuer mismatch not found in logs: {[r.message for r in caplog.records]}" + + +@pytest.mark.parametrize( + "issue_date,validity_days,expected", + [ + ( + # Issued on March 15, 2024, should use 10-day rule + datetime(2024, 3, 15, tzinfo=timezone.utc), + 10, + True, + ), + ( + # Issued on March 15, 2024, should use 10-day rule + datetime(2024, 3, 15, tzinfo=timezone.utc), + 11, + False, + ), + ( + # Issued on March 15, 2024, should use 10-day rule + datetime(2024, 3, 15), + 10, + True, + ), + ( + # Issued on March 15, 2024, should use 10-day rule + datetime(2024, 3, 15), + 11, + False, + ), + ( + # Issued on March 15, 2026, should use 7-day rule + datetime(2026, 3, 15, tzinfo=timezone.utc), + 7, + True, + ), + ( + # Issued on March 15, 2026, should use 7-day rule + datetime(2026, 3, 15, tzinfo=timezone.utc), + 8, + False, + ), + ( + # Issued on March 15, 2026, should use 7-day rule + datetime(2026, 3, 15), + 7, + True, + ), + ( + # Issued on March 15, 2026, should use 7-day rule + datetime(2026, 3, 15), + 8, + False, + ), + ], +) +def test_is_short_lived_certificate(cert_gen, issue_date, validity_days, expected): + cert = cert_gen.create_short_lived_certificate(validity_days, issue_date) + assert CRLValidator._is_short_lived_certificate(cert) == expected + + +def test_validate_certificate_signatures(cert_gen, session_manager): + """Test that certificate validation fails with ERROR when signed by wrong key""" + # Create a certificate signed by the test CA + valid_cert = cert_gen.create_certificate_with_crl_distribution_points( + "CN=Test Server", [] + ) + + # Create a different CA key pair + different_ca_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + different_cert = ( + x509.CertificateBuilder() + .subject_name(valid_cert.subject) + .issuer_name(cert_gen.ca_certificate.subject) + .public_key(cert_gen.ca_private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(different_ca_key, hashes.SHA256(), backend=default_backend()) + ) + short_lived_different_cert = ( + x509.CertificateBuilder() + .subject_name(valid_cert.subject) + .issuer_name(cert_gen.ca_certificate.subject) + .public_key(different_ca_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=3)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(different_ca_key, hashes.SHA256(), backend=default_backend()) + ) + + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + allow_certificates_without_crl_url=True, + trusted_certificates=[cert_gen.ca_certificate], + ) + + # wrong signature - no path found = ERROR + assert ( + validator._validate_single_chain([different_cert]) == CRLValidationResult.ERROR + ) + # wrong signature - short-lived - no path found = ERROR + assert ( + validator._validate_single_chain([short_lived_different_cert]) + == CRLValidationResult.ERROR + ) + # wrong signature does not stop from searching of new path + assert ( + validator._validate_single_chain( + [different_cert, short_lived_different_cert, valid_cert] + ) + == CRLValidationResult.UNREVOKED + ) + + +def test_validate_certificate_signatures_in_chain(cert_gen, session_manager): + """Test that certificate validation fails with ERROR when signed by wrong key""" + # Create a certificate chain signed by the test CA: leaf -> A -> B -> CA + # mingle with A -> B + chain = cert_gen.create_cross_signed_chain() + + valid_cert = chain.BsignA + + # Create a different CA key pair + different_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + different_cert = ( + x509.CertificateBuilder() + .subject_name(valid_cert.subject) + .issuer_name(cert_gen.ca_certificate.subject) + .public_key(cert_gen.ca_private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(different_key, hashes.SHA256(), backend=default_backend()) + ) + short_lived_different_cert = ( + x509.CertificateBuilder() + .subject_name(valid_cert.subject) + .issuer_name(cert_gen.ca_certificate.subject) + .public_key(different_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=3)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(different_key, hashes.SHA256(), backend=default_backend()) + ) + + validator = CRLValidator( + session_manager, + allow_certificates_without_crl_url=True, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + trusted_certificates=[cert_gen.ca_certificate], + ) + + # wrong signature - no path found = ERROR + assert ( + validator._validate_single_chain([chain.leafA, different_cert, chain.rootB]) + == CRLValidationResult.ERROR + ) + # wrong signature - short-lived - no path found = ERROR + assert ( + validator._validate_single_chain( + [chain.leafA, short_lived_different_cert, chain.rootB] + ) + == CRLValidationResult.ERROR + ) + # wrong signature does not stop from searching of new path + assert ( + validator._validate_single_chain( + [ + chain.leafA, + different_cert, + short_lived_different_cert, + valid_cert, + chain.rootB, + ] + ) + == CRLValidationResult.UNREVOKED + ) + + +def test_trusted_certificates_helpers(cert_gen): + chain = cert_gen.create_simple_chain() + + validator = CRLValidator( + session_manager=Mock(), trusted_certificates=[chain.root_cert] + ) + + assert validator._is_certificate_trusted_by_os(chain.root_cert) is True + assert validator._is_certificate_trusted_by_os(chain.intermediate_cert) is False + + assert validator._get_trusted_ca_issuer(chain.intermediate_cert) is chain.root_cert + assert validator._get_trusted_ca_issuer(chain.leaf_cert) is None From 8aefec8bb3d77c474c8868d487f07da1c411914b Mon Sep 17 00:00:00 2001 From: Patryk Czajka Date: Thu, 23 Oct 2025 10:28:35 +0200 Subject: [PATCH 04/16] Snow 2401045 fix vulnerabilities in crl (#2584) Co-authored-by: Tomasz Urbaszek --- src/snowflake/connector/crl.py | 225 +++++++---- test/unit/test_crl.py | 669 +++++++++++++++++++++++---------- 2 files changed, 630 insertions(+), 264 deletions(-) diff --git a/src/snowflake/connector/crl.py b/src/snowflake/connector/crl.py index 0d774ebe18..e6af12c412 100644 --- a/src/snowflake/connector/crl.py +++ b/src/snowflake/connector/crl.py @@ -270,61 +270,67 @@ def from_config( cache_manager=cache_manager, ) - def validate_certificate_chains( - self, certificate_chains: list[list[x509.Certificate]] + def validate_certificate_chain( + self, peer_cert: x509.Certificate, chain: list[x509.Certificate] | None ) -> bool: """ - Validate certificate chains against CRLs with actual HTTP requests + Validate a certificate chain against CRLs with actual HTTP requests Args: - certificate_chains: List of certificate chains to validate + peer_cert: The peer certificate to validate (e.g., server certificate) + chain: Certificate chain to use for validation (can be None or empty) Returns: True if validation passes, False otherwise - - Raises: - ValueError: If certificate_chains is None or empty """ if self._cert_revocation_check_mode == CertRevocationCheckMode.DISABLED: return True - if certificate_chains is None or len(certificate_chains) == 0: - logger.warning("Certificate chains are empty") - return self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY - - results = [] - for chain in certificate_chains: - result = self._validate_single_chain(chain) - # If any of the chains is valid, the whole check is considered positive - if result == CRLValidationResult.UNREVOKED: - return True - results.append(result) + chain = chain if chain is not None else [] + result = self._validate_chain(peer_cert, chain) - # In non-advisory mode we require at least one chain get a clear UNREVOKED status - if self._cert_revocation_check_mode != CertRevocationCheckMode.ADVISORY: + if result == CRLValidationResult.UNREVOKED: + return True + if result == CRLValidationResult.REVOKED: return False + # In advisory mode, errors are treated positively + return self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY - # We're in advisory mode, so any error is treated positively - return any(result == CRLValidationResult.ERROR for result in results) - - def _validate_single_chain( - self, chain: list[x509.Certificate] + def _validate_chain( + self, start_cert: x509.Certificate, chain: list[x509.Certificate] ) -> CRLValidationResult: """ + Validate a certificate chain starting from start_cert. + + Args: + start_cert: The certificate to start validation from + chain: List of certificates to use for building the trust path + Returns: UNREVOKED: If there is a path to any trusted certificate where all certificates are unrevoked. REVOKED: If all paths to trusted certificates are revoked. ERROR: If there is a path to any trusted certificate on which none certificate is revoked, but some certificates can't be verified. """ - # An empty chain is considered an error - if len(chain) == 0: + # Check if start certificate is expired + if not self._is_within_validity_dates(start_cert): + logger.warning( + "Start certificate is expired or not yet valid: %s", start_cert.subject + ) return CRLValidationResult.ERROR subject_certificates: dict[x509.Name, list[x509.Certificate]] = defaultdict( list ) for cert in chain: + if not self._is_ca_certificate(cert): + logger.warning("Ignoring non-CA certificate: %s", cert) + continue + if not self._is_within_validity_dates(cert): + logger.warning( + "Ignoring certificate not within validity dates: %s", cert + ) + continue subject_certificates[cert.subject].append(cert) currently_visited_subjects: set[x509.Name] = set() @@ -387,19 +393,7 @@ def traverse_chain(cert: x509.Certificate) -> CRLValidationResult | None: # no ERROR result found, all paths are REVOKED return CRLValidationResult.REVOKED - currently_visited_subjects.add(chain[0].subject) - error_result = False - revoked_result = False - for cert in subject_certificates[chain[0].subject]: - result = traverse_chain(cert) - if result == CRLValidationResult.UNREVOKED: - return result - error_result |= result == CRLValidationResult.ERROR - revoked_result |= result == CRLValidationResult.REVOKED - - if error_result or not revoked_result: - return CRLValidationResult.ERROR - return CRLValidationResult.REVOKED + return traverse_chain(start_cert) def _is_certificate_trusted_by_os(self, cert: x509.Certificate) -> bool: if cert.subject not in self._trusted_ca: @@ -426,6 +420,50 @@ def _verify_certificate_signature( except Exception: return False + @staticmethod + def _is_ca_certificate(ca_cert: x509.Certificate) -> bool: + # Check if a certificate has basicConstraints extension with CA flag set to True. + try: + basic_constraints = ca_cert.extensions.get_extension_for_oid( + ExtensionOID.BASIC_CONSTRAINTS + ).value + return basic_constraints.ca + except x509.ExtensionNotFound: + # If the extension is not present, the certificate is not a CA + return False + + @staticmethod + def _get_certificate_validity_dates( + cert: x509.Certificate, + ) -> tuple[datetime, datetime]: + # Extract UTC-aware validity dates from a certificate. + + try: + # Use timezone-aware versions to avoid deprecation warnings + not_valid_before = cert.not_valid_before_utc + not_valid_after = cert.not_valid_after_utc + except AttributeError: + # Fallback for older versions without _utc methods + not_valid_before = cert.not_valid_before + not_valid_after = cert.not_valid_after + + # Convert to UTC if not timezone-aware + if not_valid_before.tzinfo is None: + not_valid_before = not_valid_before.replace(tzinfo=timezone.utc) + if not_valid_after.tzinfo is None: + not_valid_after = not_valid_after.replace(tzinfo=timezone.utc) + + return not_valid_before, not_valid_after + + @staticmethod + def _is_within_validity_dates(cert: x509.Certificate) -> bool: + # Check if a certificate is currently valid (not expired and not before validity period). + not_valid_before, not_valid_after = ( + CRLValidator._get_certificate_validity_dates(cert) + ) + now = datetime.now(timezone.utc) + return not_valid_before <= now <= not_valid_after + def _validate_certificate_is_not_revoked_with_cache( self, cert: x509.Certificate, ca_cert: x509.Certificate ) -> CRLValidationResult: @@ -474,23 +512,13 @@ def _is_short_lived_certificate(cert: x509.Certificate) -> bool: - For certificates issued on or after 15 March 2026: validity period <= 7 days (604,800 seconds) """ - try: - # Use timezone.utc versions to avoid deprecation warnings - issue_date = cert.not_valid_before_utc - validity_period = cert.not_valid_after_utc - cert.not_valid_before_utc - except AttributeError: - # Fallback for older versions - issue_date = cert.not_valid_before - validity_period = cert.not_valid_after - cert.not_valid_before - - # Convert issue_date to UTC if it's not timezone-aware - if issue_date.tzinfo is None: - issue_date = issue_date.replace(tzinfo=timezone.utc) + issue_date, expiry_date = CRLValidator._get_certificate_validity_dates(cert) + validity_period = expiry_date - issue_date + timedelta(days=1) march_15_2026 = datetime(2026, 3, 15, tzinfo=timezone.utc) if issue_date >= march_15_2026: - return validity_period.total_seconds() <= 604800 # 7 days in seconds - return validity_period.total_seconds() <= 864000 # 10 days in seconds + return validity_period.days <= 7 + return validity_period.days <= 10 @staticmethod def _extract_crl_distribution_points(cert: x509.Certificate) -> list[str]: @@ -594,6 +622,11 @@ def _check_certificate_against_crl_url( # We cannot trust a CRL whose signature cannot be verified return CRLValidationResult.ERROR + # Verify that the CRL URL matches the IDP extension + if not self._verify_against_idp_extension(crl, crl_url): + logger.warning("CRL URL does not match IDP extension for URL: %s", crl_url) + return CRLValidationResult.ERROR + # Check if certificate is revoked return self._check_certificate_against_crl(cert, crl) @@ -645,6 +678,52 @@ def _verify_crl_signature( logger.warning("CRL signature verification failed: %s", e) return False + def _verify_against_idp_extension( + self, crl: x509.CertificateRevocationList, crl_url: str + ) -> bool: + # Verify that the CRL distribution point URL matches the IDP extension. + logger.debug( + "Trying to verify CRL URL against IDP extension for URL: %s", crl_url + ) + + try: + idp_extension = crl.extensions.get_extension_for_oid( + ExtensionOID.ISSUING_DISTRIBUTION_POINT + ) + idp = idp_extension.value + + # If the IDP has a distribution point, verify it matches the CRL URL + if not idp.full_name: + # according to baseline requirements this should not happen + # https://github.com/cabforum/servercert/blob/main/docs/BR.md + logger.debug( + "IDP extension has no full_name - treating as invalid", + crl_url, + ) + return False + + for name in idp.full_name: + if isinstance(name, x509.UniformResourceIdentifier): + if name.value == crl_url: + logger.debug("CRL URL matches IDP extension: %s", crl_url) + return True + # If we found distribution points but none matched + logger.warning( + "CRL URL %s does not match any IDP distribution point", crl_url + ) + return False + + except x509.ExtensionNotFound: + # If the IDP extension is not present, consider it valid + logger.debug( + "No IDP extension found in CRL, treating as valid for URL: %s", crl_url + ) + return True + except Exception as e: + # If we can't parse the IDP extension, log and treat as error + logger.warning("Failed to verify IDP extension: %s", e) + return False + def _check_certificate_against_crl( self, cert: x509.Certificate, crl: x509.CertificateRevocationList ) -> CRLValidationResult: @@ -660,8 +739,8 @@ def validate_connection(self, connection: SSLConnection) -> bool: """ Validate an OpenSSL connection against CRLs. - This method extracts certificate chains from the connection and validates them - against Certificate Revocation Lists (CRLs). + This method extracts the peer certificate and certificate chain from the + connection and validates them against Certificate Revocation Lists (CRLs). Args: connection: OpenSSL connection object @@ -669,35 +748,47 @@ def validate_connection(self, connection: SSLConnection) -> bool: Returns: True if validation passes, False otherwise """ - certificate_chains = self._extract_certificate_chains_from_connection( - connection - ) - return self.validate_certificate_chains(certificate_chains) + try: + # Get the peer certificate (the start certificate) + peer_cert = connection.get_peer_certificate(as_cryptography=True) + if peer_cert is None: + logger.warning("No peer certificate found in connection") + return ( + self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY + ) + + # Extract the certificate chain + cert_chain = self._extract_certificate_chain_from_connection(connection) + + return self.validate_certificate_chain(peer_cert, cert_chain) + except Exception as e: + logger.warning("Failed to validate connection: %s", e) + return self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY - def _extract_certificate_chains_from_connection( + def _extract_certificate_chain_from_connection( self, connection - ) -> list[list[x509.Certificate]]: - """Extract certificate chains from OpenSSL connection for CRL validation. + ) -> list[x509.Certificate] | None: + """Extract certificate chain from OpenSSL connection for CRL validation. Args: connection: OpenSSL connection object Returns: - List of certificate chains, where each chain is a list of x509.Certificate objects + Certificate chain as a list of x509.Certificate objects, or None on error """ try: # Convert OpenSSL certificates to cryptography x509 certificates cert_chain = connection.get_peer_cert_chain(as_cryptography=True) if not cert_chain: logger.debug("No certificate chain found in connection") - return [] + return None logger.debug( "Extracted %d certificates for CRL validation", len(cert_chain) ) - return [cert_chain] # Return as a single chain + return cert_chain except Exception as e: logger.warning( "Failed to extract certificate chain for CRL validation: %s", e ) - return [] + return None diff --git a/test/unit/test_crl.py b/test/unit/test_crl.py index f85184840b..2b798c6991 100644 --- a/test/unit/test_crl.py +++ b/test/unit/test_crl.py @@ -72,8 +72,10 @@ class CrossSignedCertificateChain: # <-(BsignA)-- # \ / # leafA leafB - # \/ - # subject + # \ / + # leaf_ca + # | + # subject rootA: x509.Certificate rootB: x509.Certificate @@ -81,6 +83,7 @@ class CrossSignedCertificateChain: BsignA: x509.Certificate leafA: x509.Certificate leafB: x509.Certificate + final_cert: x509.Certificate @pytest.fixture(scope="module") @@ -263,10 +266,11 @@ def create_simple_chain(self) -> CertificateChain: return CertificateChain(root_cert, intermediate_cert, leaf_cert) - def create_cross_signed_chain(self) -> CertificateChain: + def create_cross_signed_chain(self) -> CrossSignedCertificateChain: A_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) B_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) leaf_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + subject_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) A_name = x509.Name( [ @@ -288,7 +292,15 @@ def create_cross_signed_chain(self) -> CertificateChain: [ x509.NameAttribute( NameOID.COMMON_NAME, - f"Test Leaf {self.random.randint(1, 10000)}", + f"Test CA Leaf {self.random.randint(1, 10000)}", + ) + ] + ) + subject_name = x509.Name( + [ + x509.NameAttribute( + NameOID.COMMON_NAME, + f"Test Subject {self.random.randint(1, 10000)}", ) ] ) @@ -329,7 +341,7 @@ def create_cross_signed_chain(self) -> CertificateChain: .not_valid_before(datetime.now(timezone.utc)) .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) .add_extension( - x509.BasicConstraints(ca=False, path_length=None), + x509.BasicConstraints(ca=True, path_length=None), critical=True, ) .sign(B_key, hashes.SHA256()) @@ -343,7 +355,7 @@ def create_cross_signed_chain(self) -> CertificateChain: .not_valid_before(datetime.now(timezone.utc)) .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) .add_extension( - x509.BasicConstraints(ca=False, path_length=None), + x509.BasicConstraints(ca=True, path_length=None), critical=True, ) .sign(A_key, hashes.SHA256()) @@ -357,7 +369,7 @@ def create_cross_signed_chain(self) -> CertificateChain: .not_valid_before(datetime.now(timezone.utc)) .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) .add_extension( - x509.BasicConstraints(ca=False, path_length=None), + x509.BasicConstraints(ca=True, path_length=None), critical=True, ) .sign(A_key, hashes.SHA256()) @@ -371,13 +383,33 @@ def create_cross_signed_chain(self) -> CertificateChain: .not_valid_before(datetime.now(timezone.utc)) .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) .add_extension( - x509.BasicConstraints(ca=False, path_length=None), + x509.BasicConstraints(ca=True, path_length=None), critical=True, ) .sign(B_key, hashes.SHA256()) ) + final_cert = ( + x509.CertificateBuilder() + .subject_name(subject_name) + .issuer_name(leaf_name) + .public_key(subject_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=True, + ) + .sign(leaf_key, hashes.SHA256()) + ) return CrossSignedCertificateChain( - rootA_cert, rootB_cert, AsignB_cert, BsignA_cert, leafA_cert, leafB_cert + rootA_cert, + rootB_cert, + AsignB_cert, + BsignA_cert, + leafA_cert, + leafB_cert, + final_cert, ) def create_short_lived_certificate( @@ -480,7 +512,6 @@ def test_should_allow_connection_when_crl_validation_disabled( ): """Test that connections are allowed when CRL validation is disabled""" chain = cert_gen.create_simple_chain() - chains = [[chain.leaf_cert, chain.intermediate_cert, chain.root_cert]] validator = CRLValidator( session_manager, @@ -488,19 +519,22 @@ def test_should_allow_connection_when_crl_validation_disabled( trusted_certificates=[chain.root_cert], ) - assert validator.validate_certificate_chains(chains) + assert validator.validate_certificate_chain( + chain.leaf_cert, [chain.intermediate_cert, chain.root_cert] + ) def test_should_allow_connection_when_crl_validation_disabled_and_no_cert_chain( - session_manager, + cert_gen, session_manager ): + cert = cert_gen.create_short_lived_certificate(10, datetime.now(timezone.utc)) validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.DISABLED, trusted_certificates=[], ) - assert validator.validate_certificate_chains([]) - assert validator.validate_certificate_chains(None) + assert validator.validate_certificate_chain(cert, []) + assert validator.validate_certificate_chain(cert, None) def test_should_fail_with_null_or_empty_certificate_chains(cert_gen, session_manager): @@ -510,8 +544,10 @@ def test_should_fail_with_null_or_empty_certificate_chains(cert_gen, session_man cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, trusted_certificates=[], ) - assert not validator.validate_certificate_chains([]) - assert not validator.validate_certificate_chains(None) + # Create a dummy certificate for testing + dummy_cert = cert_gen.create_short_lived_certificate(10, datetime.now(timezone.utc)) + assert not validator.validate_certificate_chain(dummy_cert, []) + assert not validator.validate_certificate_chain(dummy_cert, None) def test_should_handle_certificates_without_crl_urls_in_enabled_mode( @@ -519,14 +555,15 @@ def test_should_handle_certificates_without_crl_urls_in_enabled_mode( ): """Test handling of certificates without CRL URLs in enabled mode""" chain = cert_gen.create_simple_chain() - chains = [[chain.leaf_cert, chain.intermediate_cert, chain.root_cert]] validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, allow_certificates_without_crl_url=False, trusted_certificates=[chain.root_cert], ) - assert not validator.validate_certificate_chains(chains) + assert not validator.validate_certificate_chain( + chain.leaf_cert, [chain.intermediate_cert, chain.root_cert] + ) def test_should_allow_certificates_without_crl_urls_when_configured( @@ -534,7 +571,6 @@ def test_should_allow_certificates_without_crl_urls_when_configured( ): """Test that certificates without CRL URLs are allowed when configured""" chain = cert_gen.create_simple_chain() - chains = [[chain.leaf_cert, chain.intermediate_cert, chain.root_cert]] validator = CRLValidator( session_manager, @@ -542,13 +578,14 @@ def test_should_allow_certificates_without_crl_urls_when_configured( allow_certificates_without_crl_url=True, trusted_certificates=[chain.root_cert], ) - assert validator.validate_certificate_chains(chains) + assert validator.validate_certificate_chain( + chain.leaf_cert, [chain.intermediate_cert, chain.root_cert] + ) def test_should_pass_in_advisory_mode_even_with_errors(cert_gen, session_manager): """Test that validation passes in advisory mode even with errors""" chain = cert_gen.create_simple_chain() - chains = [[chain.leaf_cert, chain.intermediate_cert, chain.root_cert]] validator = CRLValidator( session_manager, @@ -556,36 +593,10 @@ def test_should_pass_in_advisory_mode_even_with_errors(cert_gen, session_manager trusted_certificates=[chain.root_cert], ) - assert validator.validate_certificate_chains(chains) - - -def test_should_validate_multiple_chains_and_return_first_valid_with_no_crl_urls( - cert_gen, session_manager -): - """Test validation of multiple chains and return first valid""" - # Create a certificate that would be considered invalid (before March 2024) - before_march_2024 = datetime(2024, 2, 1, tzinfo=timezone.utc) - invalid_cert = cert_gen.create_short_lived_certificate(5, before_march_2024) - - # Create a valid chain - valid_chain = cert_gen.create_simple_chain() - - # Create list with invalid chain first, then valid chain - chains = [ - [invalid_cert, valid_chain.intermediate_cert, valid_chain.root_cert], - [valid_chain.leaf_cert, valid_chain.intermediate_cert, valid_chain.root_cert], - ] - - validator = CRLValidator( - session_manager, - cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, - allow_certificates_without_crl_url=True, - trusted_certificates=[valid_chain.root_cert], + assert validator.validate_certificate_chain( + chain.leaf_cert, [chain.intermediate_cert, chain.root_cert] ) - result = validator.validate_certificate_chains(chains) - assert result, "Should return true when at least one valid chain is found" - def test_cross_signed_certificate_chain(cert_gen, session_manager): """Test validation of cross-signed certificate chain""" @@ -598,20 +609,21 @@ def test_cross_signed_certificate_chain(cert_gen, session_manager): ) # provide full chain in arbitrary order - chains = [ + assert validator.validate_certificate_chain( + chain.final_cert, [ - chain.leafA, chain.AsignB, + chain.leafA, chain.leafB, chain.BsignA, chain.rootB, chain.rootA, - ] - ] - assert validator.validate_certificate_chains(chains) + ], + ) # only A is signed by CA - chains = [ + assert validator.validate_certificate_chain( + chain.final_cert, [ chain.leafA, chain.AsignB, @@ -619,12 +631,12 @@ def test_cross_signed_certificate_chain(cert_gen, session_manager): chain.BsignA, # chain.rootB, chain.rootA, - ] - ] - assert validator.validate_certificate_chains(chains) + ], + ) # nor A nor B is signed by CA - chains = [ + assert not validator.validate_certificate_chain( + chain.final_cert, [ chain.leafA, chain.AsignB, @@ -632,12 +644,12 @@ def test_cross_signed_certificate_chain(cert_gen, session_manager): chain.BsignA, # chain.rootB, # chain.rootA, - ] - ] - assert not validator.validate_certificate_chains(chains) + ], + ) # mingled A and B paths passed in one chain - A has no connection to CA, B has - chains = [ + assert validator.validate_certificate_chain( + chain.final_cert, [ chain.leafA, chain.AsignB, @@ -645,9 +657,8 @@ def test_cross_signed_certificate_chain(cert_gen, session_manager): # chain.BsignA, chain.rootB, # chain.rootA, - ] - ] - assert validator.validate_certificate_chains(chains) + ], + ) def test_starfield_incident(cert_gen, session_manager): @@ -667,7 +678,7 @@ def mock_validate(cert, _): validator._validate_certificate_is_not_revoked = mock_validate assert ( - validator._validate_single_chain([chain.leafA, chain.BsignA, chain.rootA]) + validator._validate_chain(chain.leafA, [chain.BsignA, chain.rootA]) == CRLValidationResult.UNREVOKED ) @@ -692,7 +703,7 @@ def mock_validate_with_special_cert(revoked_cert, error_result): for revoked_cert in [chain.rootA, chain.rootB, chain.leafA, chain.leafB]: mock_validate_with_special_cert(revoked_cert, error_result) assert ( - validator._validate_single_chain(input_chain) + validator._validate_chain(chain.final_cert, input_chain) == CRLValidationResult.UNREVOKED ) @@ -703,16 +714,22 @@ def mock_validate(cert, _): return CRLValidationResult.UNREVOKED validator._validate_certificate_is_not_revoked_with_cache = mock_validate - assert validator._validate_single_chain(input_chain) == CRLValidationResult.REVOKED + assert ( + validator._validate_chain(chain.final_cert, input_chain) + == CRLValidationResult.REVOKED + ) - # case 3: revoked + error should result in revoked\ + # case 3: revoked + error should result in revoked def mock_validate(cert, _): if cert in [chain.rootA, chain.leafB]: return CRLValidationResult.REVOKED return CRLValidationResult.ERROR validator._validate_certificate_is_not_revoked_with_cache = mock_validate - assert validator._validate_single_chain(input_chain) == CRLValidationResult.REVOKED + assert ( + validator._validate_chain(chain.final_cert, input_chain) + == CRLValidationResult.REVOKED + ) # case 4: no path to trusted certificate def mock_validate(cert, _): @@ -720,8 +737,8 @@ def mock_validate(cert, _): validator._validate_certificate_is_not_revoked_with_cache = mock_validate assert ( - validator._validate_single_chain( - [chain.leafA, chain.leafB, chain.AsignB, chain.BsignA] + validator._validate_chain( + chain.final_cert, [chain.leafA, chain.leafB, chain.AsignB, chain.BsignA] ) == CRLValidationResult.ERROR ) @@ -736,8 +753,15 @@ def mock_validate(cert, _): validator._validate_certificate_is_not_revoked_with_cache = mock_validate assert ( - validator._validate_single_chain( - [chain.leafA, chain.rootA, chain.leafB, chain.rootB, chain.BsignA] + validator._validate_chain( + chain.final_cert, + [ + chain.leafA, + chain.rootA, + chain.leafB, + chain.rootB, + chain.BsignA, + ], ) == CRLValidationResult.ERROR ) @@ -762,7 +786,6 @@ def test_should_validate_non_revoked_certificate_successfully( cert = cert_gen.create_certificate_with_crl_distribution_points( "CN=Test Server", [crl_urls.test_ca] ) - chain = [cert, cert_gen.ca_certificate] validator = CRLValidator( session_manager, @@ -770,7 +793,7 @@ def test_should_validate_non_revoked_certificate_successfully( trusted_certificates=[cert_gen.ca_certificate], ) - assert validator.validate_certificate_chains([chain]) + assert validator.validate_certificate_chain(cert, [cert_gen.ca_certificate]) assert resp.call_count @@ -793,7 +816,6 @@ def test_should_validate_non_revoked_certificate_successfully_if_root_not_provid cert = cert_gen.create_certificate_with_crl_distribution_points( "CN=Test Server", [crl_urls.test_ca] ) - chain = [cert] validator = CRLValidator( session_manager, @@ -801,7 +823,7 @@ def test_should_validate_non_revoked_certificate_successfully_if_root_not_provid trusted_certificates=[cert_gen.ca_certificate], ) - assert validator.validate_certificate_chains([chain]) + assert validator.validate_certificate_chain(cert, []) assert resp.call_count @@ -822,15 +844,13 @@ def test_should_fail_for_revoked_certificate(cert_gen, crl_urls, session_manager content_type="application/pkcs7-mime", ) - chain = [cert, cert_gen.ca_certificate] - validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, trusted_certificates=[cert_gen.ca_certificate], ) - assert not validator.validate_certificate_chains([chain]) + assert not validator.validate_certificate_chain(cert, [cert_gen.ca_certificate]) assert resp.call_count @@ -852,15 +872,13 @@ def test_should_allow_revoked_certificate_when_crl_validation_disabled( content_type="application/pkcs7-mime", ) - chain = [revoked_cert, cert_gen.ca_certificate] - validator = CRLValidator( session_manager, cert_revocation_check_mode=CertRevocationCheckMode.DISABLED, trusted_certificates=[cert_gen.ca_certificate], ) - assert validator.validate_certificate_chains([chain]) + assert validator.validate_certificate_chain(revoked_cert, [cert_gen.ca_certificate]) assert resp.call_count == 0 @@ -875,7 +893,6 @@ def test_should_pass_in_advisory_mode_with_crl_errors( cert = cert_gen.create_certificate_with_crl_distribution_points( "CN=Test Server", [crl_urls.test_ca] ) - chain = [cert, cert_gen.ca_certificate] validator = CRLValidator( session_manager, @@ -883,7 +900,7 @@ def test_should_pass_in_advisory_mode_with_crl_errors( trusted_certificates=[cert_gen.ca_certificate], ) - assert validator.validate_certificate_chains([chain]) + assert validator.validate_certificate_chain(cert, [cert_gen.ca_certificate]) assert resp.call_count @@ -898,7 +915,6 @@ def test_should_fail_in_enabled_mode_with_crl_errors( cert = cert_gen.create_certificate_with_crl_distribution_points( "CN=Test Server", [crl_urls.test_ca] ) - chain = [cert, cert_gen.ca_certificate] validator = CRLValidator( session_manager, @@ -906,50 +922,10 @@ def test_should_fail_in_enabled_mode_with_crl_errors( trusted_certificates=[cert_gen.ca_certificate], ) - assert not validator.validate_certificate_chains([chain]) + assert not validator.validate_certificate_chain(cert, [cert_gen.ca_certificate]) assert resp.call_count -@responses.activate -def test_should_validate_multiple_chains_and_success_if_just_one_valid( - cert_gen, crl_urls, session_manager -): - """Test validation of multiple chains and return first valid""" - # Create certificates - invalid_cert = cert_gen.create_certificate_with_crl_distribution_points( - "CN=Invalid Server", [crl_urls.invalid_ca] - ) - invalid_chain = [invalid_cert, cert_gen.ca_certificate] - - valid_cert = cert_gen.create_certificate_with_crl_distribution_points( - "CN=Valid Server", [crl_urls.valid_ca] - ) - valid_chain = [valid_cert, cert_gen.ca_certificate] - - valid_crl_content = cert_gen.generate_valid_crl() - - resp_200 = responses.add( - responses.GET, - crl_urls.valid_ca, - body=valid_crl_content, - status=200, - content_type="application/pkcs7-mime", - ) - - # Setup 404 for invalid certificate CRL - resp_404 = responses.add(responses.GET, crl_urls.invalid_ca, status=404) - - validator = CRLValidator( - session_manager, - cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, - trusted_certificates=[cert_gen.ca_certificate], - ) - - assert validator.validate_certificate_chains([invalid_chain, valid_chain]) - assert resp_200.call_count - assert resp_404.call_count - - @responses.activate def test_should_reject_expired_crl(cert_gen, crl_urls, session_manager): """Test rejection of expired CRL""" @@ -965,7 +941,6 @@ def test_should_reject_expired_crl(cert_gen, crl_urls, session_manager): cert = cert_gen.create_certificate_with_crl_distribution_points( "CN=Test Server", [crl_urls.expired_ca] ) - chain = [cert, cert_gen.ca_certificate] validator = CRLValidator( session_manager, @@ -973,7 +948,7 @@ def test_should_reject_expired_crl(cert_gen, crl_urls, session_manager): trusted_certificates=[cert_gen.ca_certificate], ) - assert not validator.validate_certificate_chains([chain]) + assert not validator.validate_certificate_chain(cert, [cert_gen.ca_certificate]) assert resp.call_count @@ -983,7 +958,6 @@ def test_should_skip_short_lived_certificates(cert_gen, session_manager): short_lived_cert = cert_gen.create_short_lived_certificate( 5, datetime.now(timezone.utc) ) - chain = [short_lived_cert, cert_gen.ca_certificate] validator = CRLValidator( session_manager, @@ -992,7 +966,9 @@ def test_should_skip_short_lived_certificates(cert_gen, session_manager): ) # Should pass without any HTTP calls (no responses setup) - assert validator.validate_certificate_chains([chain]) + assert validator.validate_certificate_chain( + short_lived_cert, [cert_gen.ca_certificate] + ) @responses.activate @@ -1025,7 +1001,6 @@ def test_should_handle_multiple_crl_distribution_points( cert = cert_gen.create_certificate_with_crl_distribution_points( "CN=Multi-CRL Server", crl_urls_list ) - chain = [cert, cert_gen.ca_certificate] validator = CRLValidator( session_manager, @@ -1033,7 +1008,7 @@ def test_should_handle_multiple_crl_distribution_points( trusted_certificates=[cert_gen.ca_certificate], ) - assert validator.validate_certificate_chains([chain]) + assert validator.validate_certificate_chain(cert, [cert_gen.ca_certificate]) assert resp_primary.call_count assert resp_backup.call_count @@ -1106,20 +1081,20 @@ def test_crl_validator_validate_connection(session_manager): assert not validator.validate_connection(mock_connection) -def test_crl_validator_extract_certificate_chains_from_connection( +def test_crl_validator_extract_certificate_chain_from_connection( cert_gen, session_manager ): - """Test the _extract_certificate_chains_from_connection method""" + """Test the _extract_certificate_chain_from_connection method""" chain = cert_gen.create_simple_chain() validator = CRLValidator(session_manager, trusted_certificates=[chain.root_cert]) # Test with no certificate chain mock_connection = Mock() - mock_connection.get_peer_cert_chain.return_value = [] + mock_connection.get_peer_cert_chain.return_value = None - chains = validator._extract_certificate_chains_from_connection(mock_connection) - assert chains == [] + chains = validator._extract_certificate_chain_from_connection(mock_connection) + assert chains is None # Test with mock certificate chain mock_certs = [] @@ -1144,10 +1119,9 @@ def mock_dump_certificate(file_type, cert_openssl): from unittest.mock import patch with patch("OpenSSL.crypto.dump_certificate", side_effect=mock_dump_certificate): - chains = validator._extract_certificate_chains_from_connection(mock_connection) + chain = validator._extract_certificate_chain_from_connection(mock_connection) - assert len(chains) == 1 - assert len(chains[0]) == 3 # leaf, intermediate, root + assert len(chain) == 3 # leaf, intermediate, root # New comprehensive tests for CRLConfig.from_connection @@ -1494,6 +1468,10 @@ def test_crl_validator_check_certificate_against_crl_expired( mock_crl.next_update_utc = datetime.now(timezone.utc) - timedelta(days=1) # Expired mock_crl.get_revoked_certificate_by_serial_number.return_value = None mock_crl.issuer = parent.subject + # Mock extensions to raise ExtensionNotFound for IDP extension + mock_crl.extensions.get_extension_for_oid.side_effect = x509.ExtensionNotFound( + "Extension not found", x509.oid.ExtensionOID.ISSUING_DISTRIBUTION_POINT + ) # Cache will return an expired CRL mock_cache_mgr = Mock(spec=CRLCacheManager) @@ -1531,6 +1509,10 @@ def test_crl_validator_validate_certificate_with_cache_hit( mock_crl = Mock(spec=x509.CertificateRevocationList) mock_crl.next_update_utc = datetime.now(timezone.utc) + timedelta(days=7) mock_crl.issuer = ca_cert.subject + # Mock extensions to raise ExtensionNotFound for IDP extension + mock_crl.extensions.get_extension_for_oid.side_effect = x509.ExtensionNotFound( + "Extension not found", x509.oid.ExtensionOID.ISSUING_DISTRIBUTION_POINT + ) mock_cache_manager = Mock() cached_entry = CRLCacheEntry(mock_crl, datetime.now(timezone.utc)) mock_cache_manager.get.return_value = cached_entry @@ -1590,6 +1572,10 @@ def test_crl_validator_validate_certificate_with_cache_miss( mock_crl = Mock() mock_crl.next_update_utc = datetime.now(timezone.utc) + timedelta(days=7) mock_crl.issuer = ca_cert.subject # Set the CRL issuer to match CA subject + # Mock extensions to raise ExtensionNotFound for IDP extension + mock_crl.extensions.get_extension_for_oid.side_effect = x509.ExtensionNotFound( + "Extension not found", x509.oid.ExtensionOID.ISSUING_DISTRIBUTION_POINT + ) mock_load_crl.return_value = mock_crl result = validator._validate_certificate_is_not_revoked(cert, ca_cert) @@ -1878,49 +1864,49 @@ def test_crl_signature_verification_with_issuer_mismatch_warning( ( # Issued on March 15, 2024, should use 10-day rule datetime(2024, 3, 15, tzinfo=timezone.utc), - 10, + 9, True, ), ( # Issued on March 15, 2024, should use 10-day rule datetime(2024, 3, 15, tzinfo=timezone.utc), - 11, + 10, False, ), ( # Issued on March 15, 2024, should use 10-day rule datetime(2024, 3, 15), - 10, + 9, True, ), ( # Issued on March 15, 2024, should use 10-day rule datetime(2024, 3, 15), - 11, + 10, False, ), ( # Issued on March 15, 2026, should use 7-day rule datetime(2026, 3, 15, tzinfo=timezone.utc), - 7, + 6, True, ), ( # Issued on March 15, 2026, should use 7-day rule datetime(2026, 3, 15, tzinfo=timezone.utc), - 8, + 7, False, ), ( # Issued on March 15, 2026, should use 7-day rule datetime(2026, 3, 15), - 7, + 6, True, ), ( # Issued on March 15, 2026, should use 7-day rule datetime(2026, 3, 15), - 8, + 7, False, ), ], @@ -1931,21 +1917,59 @@ def test_is_short_lived_certificate(cert_gen, issue_date, validity_days, expecte def test_validate_certificate_signatures(cert_gen, session_manager): - """Test that certificate validation fails with ERROR when signed by wrong key""" - # Create a certificate signed by the test CA - valid_cert = cert_gen.create_certificate_with_crl_distribution_points( - "CN=Test Server", [] + """Test that certificate validation fails with ERROR when certificate is expired""" + name = x509.Name( + [ + x509.NameAttribute( + NameOID.COMMON_NAME, + "Test Expired Certificate", + ) + ] ) + different_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + malsigned_cert = ( + x509.CertificateBuilder() + .subject_name(name) + .issuer_name(cert_gen.ca_certificate.subject) + .public_key(cert_gen.ca_certificate.public_key()) # does not matter + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc) - timedelta(days=2)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(different_key, hashes.SHA256()) + ) + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + allow_certificates_without_crl_url=True, + trusted_certificates=[cert_gen.ca_certificate], + ) + # expired cert - no path found = ERROR + assert validator._validate_chain(malsigned_cert, []) == CRLValidationResult.ERROR + + +def test_validate_certificate_signatures_in_chain(cert_gen, session_manager): + """Test that certificate validation fails with ERROR when signed by wrong key""" + # Create a certificate chain signed by the test CA: leaf -> A -> B -> CA + # mingle with A -> B + chain = cert_gen.create_cross_signed_chain() + + valid_cert = chain.BsignA # Create a different CA key pair - different_ca_key = rsa.generate_private_key( + different_key = rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() ) different_cert = ( x509.CertificateBuilder() .subject_name(valid_cert.subject) - .issuer_name(cert_gen.ca_certificate.subject) - .public_key(cert_gen.ca_private_key.public_key()) + .issuer_name(valid_cert.subject) + .public_key(valid_cert.public_key()) .serial_number(x509.random_serial_number()) .not_valid_before(datetime.now(timezone.utc)) .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) @@ -1953,13 +1977,13 @@ def test_validate_certificate_signatures(cert_gen, session_manager): x509.BasicConstraints(ca=True, path_length=None), critical=True, ) - .sign(different_ca_key, hashes.SHA256(), backend=default_backend()) + .sign(different_key, hashes.SHA256(), backend=default_backend()) ) short_lived_different_cert = ( x509.CertificateBuilder() .subject_name(valid_cert.subject) - .issuer_name(cert_gen.ca_certificate.subject) - .public_key(different_ca_key.public_key()) + .issuer_name(valid_cert.subject) + .public_key(valid_cert.public_key()) .serial_number(x509.random_serial_number()) .not_valid_before(datetime.now(timezone.utc)) .not_valid_after(datetime.now(timezone.utc) + timedelta(days=3)) @@ -1967,73 +1991,117 @@ def test_validate_certificate_signatures(cert_gen, session_manager): x509.BasicConstraints(ca=True, path_length=None), critical=True, ) - .sign(different_ca_key, hashes.SHA256(), backend=default_backend()) + .sign(different_key, hashes.SHA256(), backend=default_backend()) ) validator = CRLValidator( session_manager, - cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, allow_certificates_without_crl_url=True, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, trusted_certificates=[cert_gen.ca_certificate], ) # wrong signature - no path found = ERROR assert ( - validator._validate_single_chain([different_cert]) == CRLValidationResult.ERROR + validator._validate_chain( + chain.final_cert, [chain.leafA, different_cert, chain.rootB] + ) + == CRLValidationResult.ERROR ) # wrong signature - short-lived - no path found = ERROR assert ( - validator._validate_single_chain([short_lived_different_cert]) + validator._validate_chain( + chain.final_cert, [chain.leafA, short_lived_different_cert, chain.rootB] + ) == CRLValidationResult.ERROR ) # wrong signature does not stop from searching of new path assert ( - validator._validate_single_chain( - [different_cert, short_lived_different_cert, valid_cert] + validator._validate_chain( + chain.final_cert, + [ + chain.leafA, + different_cert, + short_lived_different_cert, + valid_cert, + chain.rootB, + ], ) == CRLValidationResult.UNREVOKED ) -def test_validate_certificate_signatures_in_chain(cert_gen, session_manager): - """Test that certificate validation fails with ERROR when signed by wrong key""" - # Create a certificate chain signed by the test CA: leaf -> A -> B -> CA - # mingle with A -> B +def test_validate_expired_certificates(cert_gen, session_manager): + """Test that certificate validation fails with ERROR when certificate is expired""" + name = x509.Name( + [ + x509.NameAttribute( + NameOID.COMMON_NAME, + "Test Expired Certificate", + ) + ] + ) + expired_cert = ( + x509.CertificateBuilder() + .subject_name(name) + .issuer_name(cert_gen.ca_certificate.subject) + .public_key(cert_gen.ca_certificate.public_key()) # does not matter + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc) - timedelta(days=2)) + .not_valid_after(datetime.now(timezone.utc) - timedelta(days=1)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(cert_gen.ca_private_key, hashes.SHA256()) + ) + validator = CRLValidator( + session_manager, + cert_revocation_check_mode=CertRevocationCheckMode.ENABLED, + allow_certificates_without_crl_url=True, + trusted_certificates=[cert_gen.ca_certificate], + ) + # expired cert - no path found = ERROR + assert validator._validate_chain(expired_cert, []) == CRLValidationResult.ERROR + + +def test_validate_expired_certificates_in_chain(cert_gen, session_manager): + """Test that certificate validation fails with ERROR when certificate in chain is expired""" + # Create a certificate chain signed by the test CA: final_cert -> leafA -> rootA -> CA chain = cert_gen.create_cross_signed_chain() - valid_cert = chain.BsignA + valid_cert = chain.rootA - # Create a different CA key pair - different_key = rsa.generate_private_key( - public_exponent=65537, key_size=2048, backend=default_backend() - ) - different_cert = ( + # Create an expired certificate with the same subject as valid_cert + expired_cert = ( x509.CertificateBuilder() .subject_name(valid_cert.subject) .issuer_name(cert_gen.ca_certificate.subject) - .public_key(cert_gen.ca_private_key.public_key()) + .public_key(valid_cert.public_key()) .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.now(timezone.utc)) - .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .not_valid_before(datetime.now(timezone.utc) - timedelta(days=365)) + .not_valid_after(datetime.now(timezone.utc) - timedelta(days=10)) .add_extension( x509.BasicConstraints(ca=True, path_length=None), critical=True, ) - .sign(different_key, hashes.SHA256(), backend=default_backend()) + .sign(cert_gen.ca_private_key, hashes.SHA256(), backend=default_backend()) ) - short_lived_different_cert = ( + + # Create a short-lived expired certificate + short_lived_expired_cert = ( x509.CertificateBuilder() .subject_name(valid_cert.subject) .issuer_name(cert_gen.ca_certificate.subject) - .public_key(different_key.public_key()) + .public_key(valid_cert.public_key()) .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.now(timezone.utc)) - .not_valid_after(datetime.now(timezone.utc) + timedelta(days=3)) + .not_valid_before(datetime.now(timezone.utc) - timedelta(days=4)) + .not_valid_after(datetime.now(timezone.utc) - timedelta(days=1)) .add_extension( x509.BasicConstraints(ca=True, path_length=None), critical=True, ) - .sign(different_key, hashes.SHA256(), backend=default_backend()) + .sign(cert_gen.ca_private_key, hashes.SHA256(), backend=default_backend()) ) validator = CRLValidator( @@ -2043,28 +2111,28 @@ def test_validate_certificate_signatures_in_chain(cert_gen, session_manager): trusted_certificates=[cert_gen.ca_certificate], ) - # wrong signature - no path found = ERROR + # expired cert - no path found = ERROR assert ( - validator._validate_single_chain([chain.leafA, different_cert, chain.rootB]) + validator._validate_chain(chain.final_cert, [chain.leafA, expired_cert]) == CRLValidationResult.ERROR ) - # wrong signature - short-lived - no path found = ERROR + # expired short-lived cert - no path found = ERROR assert ( - validator._validate_single_chain( - [chain.leafA, short_lived_different_cert, chain.rootB] + validator._validate_chain( + chain.final_cert, [chain.leafA, short_lived_expired_cert] ) == CRLValidationResult.ERROR ) - # wrong signature does not stop from searching of new path + # expired cert does not stop from searching for a valid path assert ( - validator._validate_single_chain( + validator._validate_chain( + chain.final_cert, [ chain.leafA, - different_cert, - short_lived_different_cert, + expired_cert, + short_lived_expired_cert, valid_cert, - chain.rootB, - ] + ], ) == CRLValidationResult.UNREVOKED ) @@ -2082,3 +2150,210 @@ def test_trusted_certificates_helpers(cert_gen): assert validator._get_trusted_ca_issuer(chain.intermediate_cert) is chain.root_cert assert validator._get_trusted_ca_issuer(chain.leaf_cert) is None + + +@pytest.mark.parametrize( + "timedelta_before,timedelta_after,expected_result", + [ + # Valid certificate (currently within validity period) + (timedelta(days=-1), timedelta(days=365), True), + # Expired certificate (after not_valid_after) + (timedelta(days=-365), timedelta(days=-1), False), + # Not yet valid certificate (before not_valid_before) + (timedelta(days=1), timedelta(days=365), False), + # Edge case - just became valid + (timedelta(seconds=-1), timedelta(days=365), True), + # Edge case - about to expire + (timedelta(days=-365), timedelta(seconds=1), True), + ], +) +def test_is_within_validity_dates(timedelta_before, timedelta_after, expected_result): + """Test the _is_within_validity_dates function for certificate validity checks""" + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + cert_name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "Test Certificate")]) + now = datetime.now(timezone.utc) + cert = ( + x509.CertificateBuilder() + .subject_name(cert_name) + .issuer_name(cert_name) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now + timedelta_before) + .not_valid_after(now + timedelta_after) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(key, hashes.SHA256(), backend=default_backend()) + ) + + assert CRLValidator._is_within_validity_dates(cert) is expected_result + + +def test_verify_against_idp_extension_no_extension(cert_gen): + """Test IDP verification when CRL has no IDP extension - should pass""" + # Generate a CRL without IDP extension + crl_bytes = cert_gen.generate_valid_crl() + crl = x509.load_der_x509_crl(crl_bytes, backend=default_backend()) + + validator = CRLValidator( + session_manager=Mock(), + trusted_certificates=[cert_gen.ca_certificate], + ) + + # Should return True when no IDP extension is present + assert ( + validator._verify_against_idp_extension(crl, "http://example.com/crl") is True + ) + + +@pytest.mark.parametrize( + "full_name_urls,crl_url,expected_result", + [ + ( + # matching single URL + ["http://example.com/test.crl"], + "http://example.com/test.crl", + True, + ), + ( + # non-matching single URL + ["http://example.com/correct.crl"], + "http://example.com/wrong.crl", + False, + ), + ( + # matching one of multiple URLs + [ + "http://example.com/crl1.crl", + "http://example.com/crl2.crl", + "http://example.com/crl3.crl", + ], + "http://example.com/crl2.crl", + True, + ), + ( + # non-matching with multiple URLs + [ + "http://example.com/crl1.crl", + "http://example.com/crl2.crl", + "http://example.com/crl3.crl", + ], + "http://example.com/wrong.crl", + False, + ), + ( + # no full_name (violates baseline requirements) + None, + "http://example.com/crl", + False, + ), + ], +) +def test_verify_against_idp_extension_with_full_name( + cert_gen, full_name_urls, crl_url, expected_result +): + """Test IDP verification with various full_name configurations""" + + full_name = ( + [x509.UniformResourceIdentifier(url) for url in full_name_urls] + if full_name_urls + else None + ) + # Build CRL with IDP extension + crl = ( + x509.CertificateRevocationListBuilder() + .issuer_name(cert_gen.ca_certificate.subject) + .last_update(datetime.now(timezone.utc)) + .next_update(datetime.now(timezone.utc) + timedelta(days=1)) + .add_extension( + x509.IssuingDistributionPoint( + full_name=full_name, + relative_name=None, + only_contains_user_certs=True, + only_contains_ca_certs=False, + only_some_reasons=None, + indirect_crl=False, + only_contains_attribute_certs=False, + ), + critical=True, + ) + .sign(cert_gen.ca_private_key, hashes.SHA256(), backend=default_backend()) + ) + + validator = CRLValidator( + session_manager=Mock(), + trusted_certificates=[cert_gen.ca_certificate], + ) + + # Verify the result matches expected + assert validator._verify_against_idp_extension(crl, crl_url) is expected_result + + +@responses.activate +def test_check_certificate_against_crl_url_with_idp_mismatch( + cert_gen, session_manager, crl_urls +): + """CRL validation should fail when IDP URL doesn't match""" + chain = cert_gen.create_simple_chain() + + # Create a test CA for signing the CRL + test_ca_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + test_ca_cert = ( + x509.CertificateBuilder() + .subject_name(chain.root_cert.subject) + .issuer_name(chain.root_cert.subject) + .public_key(test_ca_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(test_ca_key, hashes.SHA256()) + ) + + # Create a CRL with IDP extension pointing to a different URL + builder = x509.CertificateRevocationListBuilder() + builder = builder.issuer_name(test_ca_cert.subject) + builder = builder.last_update(datetime.now(timezone.utc)) + builder = builder.next_update(datetime.now(timezone.utc) + timedelta(days=1)) + + # Add IDP extension with different URL than the one we'll use to fetch + idp = x509.IssuingDistributionPoint( + full_name=[x509.UniformResourceIdentifier("http://different.com/crl.crl")], + relative_name=None, + only_contains_user_certs=False, + only_contains_ca_certs=False, + only_some_reasons=None, + indirect_crl=False, + only_contains_attribute_certs=False, + ) + builder = builder.add_extension(idp, critical=False) + + crl = builder.sign(test_ca_key, hashes.SHA256(), backend=default_backend()) + crl_bytes = crl.public_bytes(serialization.Encoding.DER) + + # Mock the HTTP response + responses.add( + responses.GET, + crl_urls.test_ca, + body=crl_bytes, + status=200, + content_type="application/pkix-crl", + ) + + validator = CRLValidator( + session_manager=session_manager, + trusted_certificates=[test_ca_cert], + ) + + # Check certificate against CRL URL - should fail due to IDP mismatch + result = validator._check_certificate_against_crl_url( + chain.leaf_cert, test_ca_cert, crl_urls.test_ca + ) + + assert result == CRLValidationResult.ERROR From f9196e1f819374ccb8da52f36a301ba95fdac6bd Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Mon, 3 Nov 2025 10:23:10 +0100 Subject: [PATCH 05/16] NO-SNOW: CRL changes in aio --- src/snowflake/connector/aio/_connection.py | 2 + .../connector/aio/_session_manager.py | 42 ++++++++++++++++++- src/snowflake/connector/crl.py | 25 +++++++++-- src/snowflake/connector/ssl_wrap_socket.py | 8 ++-- 4 files changed, 69 insertions(+), 8 deletions(-) diff --git a/src/snowflake/connector/aio/_connection.py b/src/snowflake/connector/aio/_connection.py index 9e0f6a9267..e1d3feebdc 100644 --- a/src/snowflake/connector/aio/_connection.py +++ b/src/snowflake/connector/aio/_connection.py @@ -47,6 +47,7 @@ PARAMETER_TIMEZONE, QueryStatus, ) +from ..crl import CRLConfig from ..description import PLATFORM, PYTHON_VERSION, SNOWFLAKE_CONNECTOR_VERSION from ..errorcode import ( ER_CONNECTION_IS_CLOSED, @@ -622,6 +623,7 @@ def _init_connection_parameters( # Placeholder attributes; will be initialized in connect() self._http_config: AioHttpConfig | None = None + self._crl_config: CRLConfig | None = None self._session_manager: SessionManager | None = None self._rest = None for name, (value, _) in DEFAULT_CONFIGURATION.items(): diff --git a/src/snowflake/connector/aio/_session_manager.py b/src/snowflake/connector/aio/_session_manager.py index a2fa06bc0a..4a544460cd 100644 --- a/src/snowflake/connector/aio/_session_manager.py +++ b/src/snowflake/connector/aio/_session_manager.py @@ -3,6 +3,7 @@ import sys from typing import TYPE_CHECKING +import certifi from aiohttp import ClientRequest, ClientTimeout from aiohttp.client import _RequestOptions from aiohttp.client_proto import ResponseHandler @@ -10,8 +11,15 @@ from aiohttp.typedefs import StrOrURL from .. import OperationalError +from ..crl import CertRevocationCheckMode, CRLValidator from ..errorcode import ER_OCSP_RESPONSE_CERT_STATUS_REVOKED -from ..ssl_wrap_socket import FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME +from ..ssl_wrap_socket import ( + FEATURE_CRL_CONFIG, + FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME, + get_current_session_manager, + load_trusted_certificates, + resolve_cafile, +) from ._ocsp_asn1crypto import SnowflakeOCSPAsn1Crypto if TYPE_CHECKING: @@ -71,6 +79,20 @@ async def connect( ) -> Connection: connection = await super().connect(req, traces, timeout) protocol = connection.protocol + + logger.debug( + "CRL Check Mode: %s", + FEATURE_CRL_CONFIG.cert_revocation_check_mode.name, + ) + if ( + FEATURE_CRL_CONFIG.cert_revocation_check_mode + != CertRevocationCheckMode.ENABLED + ): + self.validate_crl(protocol, req) + logger.debug( + "The certificate revocation check was successful. No additional checks will be performed." + ) + if ( req.is_ssl() and protocol is not None @@ -90,6 +112,24 @@ async def connect( protocol._snowflake_ocsp_validated = True return connection + def validate_crl(self, protocol: ResponseHandler, req: ClientRequest): + # Resolve CA file path from environment variables or use certifi default + cafile_for_ctx = resolve_cafile({"ca_certs": certifi.where()}) + crl_validator = CRLValidator.from_config( + FEATURE_CRL_CONFIG, + get_current_session_manager(), + trusted_certificates=load_trusted_certificates(cafile_for_ctx), + ) + sll_object = protocol.transport.get_extra_info("ssl_object") + if not crl_validator.validate_connection(sll_object): + raise OperationalError( + msg=( + "The certificate is revoked or " + "could not be validated via CRL: hostname={}".format(req.url.host) + ), + errno=ER_OCSP_RESPONSE_CERT_STATUS_REVOKED, + ) + async def validate_ocsp( self, hostname: str, diff --git a/src/snowflake/connector/crl.py b/src/snowflake/connector/crl.py index e6af12c412..a38f132ee3 100644 --- a/src/snowflake/connector/crl.py +++ b/src/snowflake/connector/crl.py @@ -1,12 +1,14 @@ #!/usr/bin/env python from __future__ import annotations +import ssl from collections import defaultdict from dataclasses import dataclass from datetime import datetime, timedelta, timezone from enum import Enum, unique from logging import getLogger from pathlib import Path +from ssl import SSLObject from typing import Any from cryptography import x509 @@ -750,7 +752,7 @@ def validate_connection(self, connection: SSLConnection) -> bool: """ try: # Get the peer certificate (the start certificate) - peer_cert = connection.get_peer_certificate(as_cryptography=True) + peer_cert = self._get_peer_certificate(connection) if peer_cert is None: logger.warning("No peer certificate found in connection") return ( @@ -765,6 +767,14 @@ def validate_connection(self, connection: SSLConnection) -> bool: logger.warning("Failed to validate connection: %s", e) return self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY + @staticmethod + def _get_peer_certificate(connection): + if isinstance(connection, SSLObject): + return x509.load_der_x509_certificate( + connection.getpeercert(binary_form=True), default_backend() + ) + return connection.get_peer_certificate(as_cryptography=True) + def _extract_certificate_chain_from_connection( self, connection ) -> list[x509.Certificate] | None: @@ -777,8 +787,17 @@ def _extract_certificate_chain_from_connection( Certificate chain as a list of x509.Certificate objects, or None on error """ try: - # Convert OpenSSL certificates to cryptography x509 certificates - cert_chain = connection.get_peer_cert_chain(as_cryptography=True) + if isinstance(connection, SSLObject): + cert_chain = [] + for cert in connection._sslobj.get_unverified_chain(): + cert_bytes = ssl.PEM_cert_to_DER_cert(cert.public_bytes()) + cert_chain.append( + x509.load_der_x509_certificate(cert_bytes, default_backend()) + ) + else: + # Convert OpenSSL certificates to cryptography x509 certificates + cert_chain = connection.get_peer_cert_chain(as_cryptography=True) + if not cert_chain: logger.debug("No certificate chain found in connection") return None diff --git a/src/snowflake/connector/ssl_wrap_socket.py b/src/snowflake/connector/ssl_wrap_socket.py index 4712cf395d..5b2d6ed325 100644 --- a/src/snowflake/connector/ssl_wrap_socket.py +++ b/src/snowflake/connector/ssl_wrap_socket.py @@ -47,7 +47,7 @@ # Helper utilities (private) -def _resolve_cafile(kwargs: dict[str, Any]) -> str | None: +def resolve_cafile(kwargs: dict[str, Any]) -> str | None: """Resolve CA bundle path from kwargs or standard environment variables. Precedence: @@ -150,7 +150,7 @@ def inject_into_urllib3() -> None: connection_.ssl_wrap_socket = ssl_wrap_socket_with_cert_revocation_checks -def _load_trusted_certificates(cafile: str | None) -> list[x509.Certificate]: +def load_trusted_certificates(cafile: str | None) -> list[x509.Certificate]: # Use default SSL context to load the CA file and get the certificates ctx = ssl.create_default_context() ctx.load_verify_locations(cafile=cafile) @@ -177,7 +177,7 @@ def ssl_wrap_socket_with_cert_revocation_checks( # Ensure PyOpenSSL context with partial-chain is used if none or wrong type provided provided_ctx = params.get("ssl_context") - cafile_for_ctx = _resolve_cafile(params) + cafile_for_ctx = resolve_cafile(params) if not isinstance(provided_ctx, PyOpenSSLContext): params["ssl_context"] = _build_context_with_partial_chain(cafile_for_ctx) else: @@ -197,7 +197,7 @@ def ssl_wrap_socket_with_cert_revocation_checks( crl_validator = CRLValidator.from_config( FEATURE_CRL_CONFIG, get_current_session_manager(), - trusted_certificates=_load_trusted_certificates(cafile_for_ctx), + trusted_certificates=load_trusted_certificates(cafile_for_ctx), ) if not crl_validator.validate_connection(ret.connection): raise OperationalError( From c8bdeb4748709195b6ff64c077a3cfca308bf136 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Wed, 5 Nov 2025 14:41:07 +0100 Subject: [PATCH 06/16] NO-SNOW: Fix tests --- setup.cfg | 1 + src/snowflake/connector/aio/_connection.py | 1 + src/snowflake/connector/aio/_session_manager.py | 2 +- src/snowflake/connector/aio/auth/_auth.py | 1 + src/snowflake/connector/aio/auth/_okta.py | 1 + src/snowflake/connector/aio/auth/_webbrowser.py | 1 + test/unit/aio/test_auth_async.py | 1 + test/unit/aio/test_auth_keypair_async.py | 1 + test/unit/aio/test_auth_okta_async.py | 1 + test/unit/aio/test_auth_webbrowser_async.py | 1 + 10 files changed, 10 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index e08154af63..5cdc6d876a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -96,6 +96,7 @@ development = pytest-xdist pytzdata responses + pytest-asyncio pandas = pandas>=2.1.2,<3.0.0 pyarrow diff --git a/src/snowflake/connector/aio/_connection.py b/src/snowflake/connector/aio/_connection.py index e1d3feebdc..8df7665bb4 100644 --- a/src/snowflake/connector/aio/_connection.py +++ b/src/snowflake/connector/aio/_connection.py @@ -1037,6 +1037,7 @@ async def connect(self, **kwargs) -> None: else: self.__config(**self._conn_parameters) + self._crl_config: CRLConfig = CRLConfig.from_connection(self) self._http_config: AioHttpConfig = AioHttpConfig( connector_factory=SnowflakeSSLConnectorFactory(), use_pooling=not self.disable_request_pooling, diff --git a/src/snowflake/connector/aio/_session_manager.py b/src/snowflake/connector/aio/_session_manager.py index 4a544460cd..c437a55bf0 100644 --- a/src/snowflake/connector/aio/_session_manager.py +++ b/src/snowflake/connector/aio/_session_manager.py @@ -86,7 +86,7 @@ async def connect( ) if ( FEATURE_CRL_CONFIG.cert_revocation_check_mode - != CertRevocationCheckMode.ENABLED + != CertRevocationCheckMode.DISABLED ): self.validate_crl(protocol, req) logger.debug( diff --git a/src/snowflake/connector/aio/auth/_auth.py b/src/snowflake/connector/aio/auth/_auth.py index b8c6564837..a0f3d228ba 100644 --- a/src/snowflake/connector/aio/auth/_auth.py +++ b/src/snowflake/connector/aio/auth/_auth.py @@ -99,6 +99,7 @@ async def authenticate( self._rest._connection._internal_application_name, self._rest._connection._internal_application_version, self._rest._connection._ocsp_mode(), + self._rest._connection.cert_revocation_check_mode, self._rest._connection._login_timeout, self._rest._connection._network_timeout, self._rest._connection._socket_timeout, diff --git a/src/snowflake/connector/aio/auth/_okta.py b/src/snowflake/connector/aio/auth/_okta.py index 50a9c8a6b8..475a6a3890 100644 --- a/src/snowflake/connector/aio/auth/_okta.py +++ b/src/snowflake/connector/aio/auth/_okta.py @@ -121,6 +121,7 @@ async def _step1( conn._internal_application_name, conn._internal_application_version, conn._ocsp_mode(), + conn.cert_revocation_check_mode, conn.login_timeout, conn.network_timeout, conn.socket_timeout, diff --git a/src/snowflake/connector/aio/auth/_webbrowser.py b/src/snowflake/connector/aio/auth/_webbrowser.py index aca409d0c5..c6231ffa63 100644 --- a/src/snowflake/connector/aio/auth/_webbrowser.py +++ b/src/snowflake/connector/aio/auth/_webbrowser.py @@ -377,6 +377,7 @@ async def _get_sso_url( conn._internal_application_name, conn._internal_application_version, conn._ocsp_mode(), + conn.cert_revocation_check_mode, conn.login_timeout, conn.network_timeout, conn.socket_timeout, diff --git a/test/unit/aio/test_auth_async.py b/test/unit/aio/test_auth_async.py index e92f3be556..5b6491ffa1 100644 --- a/test/unit/aio/test_auth_async.py +++ b/test/unit/aio/test_auth_async.py @@ -25,6 +25,7 @@ def _init_rest(application, post_requset): connection = mock_connection() connection.errorhandler = Mock(return_value=None) connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) + connection.cert_revocation_check_mode = "TEST_CRL_MODE" type(connection).application = PropertyMock(return_value=application) type(connection)._internal_application_name = PropertyMock(return_value=CLIENT_NAME) type(connection)._internal_application_version = PropertyMock( diff --git a/test/unit/aio/test_auth_keypair_async.py b/test/unit/aio/test_auth_keypair_async.py index e802a3d1cc..755acb1b74 100644 --- a/test/unit/aio/test_auth_keypair_async.py +++ b/test/unit/aio/test_auth_keypair_async.py @@ -165,6 +165,7 @@ def _init_rest(application, post_requset): connection = mock_connection() connection.errorhandler = Mock(return_value=None) connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) + connection.cert_revocation_check_mode = "TEST_CRL_MODE" type(connection).application = PropertyMock(return_value=application) type(connection)._internal_application_name = PropertyMock(return_value=CLIENT_NAME) type(connection)._internal_application_version = PropertyMock( diff --git a/test/unit/aio/test_auth_okta_async.py b/test/unit/aio/test_auth_okta_async.py index 1a2a8d0298..ad4712e35a 100644 --- a/test/unit/aio/test_auth_okta_async.py +++ b/test/unit/aio/test_auth_okta_async.py @@ -361,6 +361,7 @@ async def post_request(url, headers, body, **kwargs): connection = mock_connection(disable_saml_url_check=disable_saml_url_check) connection.errorhandler = Mock(return_value=None) connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) + connection.cert_revocation_check_mode = "TEST_CRL_MODE" type(connection).application = PropertyMock(return_value=CLIENT_NAME) type(connection)._internal_application_name = PropertyMock(return_value=CLIENT_NAME) type(connection)._internal_application_version = PropertyMock( diff --git a/test/unit/aio/test_auth_webbrowser_async.py b/test/unit/aio/test_auth_webbrowser_async.py index 5e1b699f80..ed5aa56ece 100644 --- a/test/unit/aio/test_auth_webbrowser_async.py +++ b/test/unit/aio/test_auth_webbrowser_async.py @@ -414,6 +414,7 @@ async def post_request(url, headers, body, **kwargs): connection = mock_connection(socket_timeout=socket_timeout) connection.errorhandler = Mock(return_value=None) connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) + connection.cert_revocation_check_mode = "TEST_CRL_MODE" connection._disable_console_login = disable_console_login type(connection).application = PropertyMock(return_value=CLIENT_NAME) type(connection)._internal_application_name = PropertyMock(return_value=CLIENT_NAME) From 3ddf8017f854d48b47c3f9e7f9902d80fec87742 Mon Sep 17 00:00:00 2001 From: Patryk Czajka Date: Thu, 9 Oct 2025 08:05:13 +0200 Subject: [PATCH 07/16] Fixup Jenkins build (#2572) --- src/snowflake/connector/crl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/connector/crl.py b/src/snowflake/connector/crl.py index a38f132ee3..2429fc8e3b 100644 --- a/src/snowflake/connector/crl.py +++ b/src/snowflake/connector/crl.py @@ -65,7 +65,7 @@ class CRLConfig: crl_cache_dir: Path | str | None = None crl_cache_removal_delay_days: int = 7 crl_cache_cleanup_interval_hours: int = 1 - crl_cache_start_cleanup: bool = True + crl_cache_start_cleanup: bool = False @classmethod def from_connection(cls, sf_connection) -> CRLConfig: From 6833f9b869aa8ae59ffd19aeff6808a5abb4ecb4 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Thu, 6 Nov 2025 12:13:48 +0100 Subject: [PATCH 08/16] NO-SNOW: Add more tests for async crl --- .../connector/aio/_session_manager.py | 4 +- test/integ/aio_it/test_crl_async.py | 179 ++++++++++++++++++ 2 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 test/integ/aio_it/test_crl_async.py diff --git a/src/snowflake/connector/aio/_session_manager.py b/src/snowflake/connector/aio/_session_manager.py index c437a55bf0..b805019af8 100644 --- a/src/snowflake/connector/aio/_session_manager.py +++ b/src/snowflake/connector/aio/_session_manager.py @@ -16,7 +16,6 @@ from ..ssl_wrap_socket import ( FEATURE_CRL_CONFIG, FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME, - get_current_session_manager, load_trusted_certificates, resolve_cafile, ) @@ -117,10 +116,11 @@ def validate_crl(self, protocol: ResponseHandler, req: ClientRequest): cafile_for_ctx = resolve_cafile({"ca_certs": certifi.where()}) crl_validator = CRLValidator.from_config( FEATURE_CRL_CONFIG, - get_current_session_manager(), + self._session_manager, trusted_certificates=load_trusted_certificates(cafile_for_ctx), ) sll_object = protocol.transport.get_extra_info("ssl_object") + # TODO(asyncio): SNOW-2681061 Add sync support for validate_connection if not crl_validator.validate_connection(sll_object): raise OperationalError( msg=( diff --git a/test/integ/aio_it/test_crl_async.py b/test/integ/aio_it/test_crl_async.py new file mode 100644 index 0000000000..d8b07d4192 --- /dev/null +++ b/test/integ/aio_it/test_crl_async.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python +""" +CRL (Certificate Revocation List) Validation Integration Tests - Async version + +These tests verify that CRL validation works correctly with real Snowflake connections +in different modes: DISABLED, ADVISORY, and ENABLED. +""" +from __future__ import annotations + +import tempfile + +import pytest + + +@pytest.mark.skipolddriver +async def test_crl_validation_enabled_mode(conn_cnx): + """Test that connection works with CRL validation in ENABLED mode.""" + # ENABLED mode should work for normal Snowflake connections since they typically + # have valid certificates with proper CRL distribution points + async with conn_cnx( + cert_revocation_check_mode="ENABLED", + allow_certificates_without_crl_url=True, # Allow certs without CRL URLs + crl_connection_timeout_ms=5000, # 5 second timeout + crl_read_timeout_ms=5000, # 5 second timeout + disable_ocsp_checks=True, + ) as cnx: + assert cnx, "Connection should succeed with CRL validation in ENABLED mode" + + # Verify we can execute a simple query + cur = cnx.cursor() + await cur.execute("SELECT 1") + result = await cur.fetchone() + assert result[0] == 1, "Query should execute successfully" + await cur.close() + + # Verify CRL settings were applied + assert cnx.cert_revocation_check_mode == "ENABLED" + assert cnx.allow_certificates_without_crl_url is True + + +@pytest.mark.skipolddriver +async def test_crl_validation_advisory_mode(conn_cnx): + """Test that connection works with CRL validation in ADVISORY mode.""" + # ADVISORY mode should be more lenient and allow connections even if CRL checks fail + async with conn_cnx( + cert_revocation_check_mode="ADVISORY", + allow_certificates_without_crl_url=False, # Don't allow certs without CRL URLs + crl_connection_timeout_ms=3000, # 3 second timeout + crl_read_timeout_ms=3000, # 3 second timeout + enable_crl_cache=True, # Enable caching + crl_cache_validity_hours=1, # Cache for 1 hour + ) as cnx: + assert cnx, "Connection should succeed with CRL validation in ADVISORY mode" + + # Verify we can execute a simple query + cur = cnx.cursor() + await cur.execute("SELECT CURRENT_VERSION()") + result = await cur.fetchone() + assert result[0], "Query should return a version string" + await cur.close() + + # Verify CRL settings were applied + assert cnx.cert_revocation_check_mode == "ADVISORY" + assert cnx.allow_certificates_without_crl_url is False + assert cnx.enable_crl_cache is True + assert cnx.crl_connection_timeout_ms == 3000 + assert cnx.crl_read_timeout_ms == 3000 + assert cnx.crl_cache_validity_hours == 1 + assert cnx.crl_cache_dir is None + + +@pytest.mark.skipolddriver +async def test_crl_validation_disabled_mode(conn_cnx): + """Test that connection works with CRL validation in DISABLED mode (default).""" + # DISABLED mode should work without any CRL checks + async with conn_cnx( + cert_revocation_check_mode="DISABLED", + ) as cnx: + assert cnx, "Connection should succeed with CRL validation in DISABLED mode" + + # Verify we can execute a simple query + cur = cnx.cursor() + await cur.execute("SELECT 'CRL_DISABLED' as test_value") + result = await cur.fetchone() + assert result[0] == "CRL_DISABLED", "Query should execute successfully" + await cur.close() + + # Verify CRL settings were applied + assert cnx.cert_revocation_check_mode == "DISABLED" + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "crl_mode,allow_without_crl,should_succeed", + [ + ("DISABLED", True, True), # DISABLED mode always succeeds + ("DISABLED", False, True), # DISABLED mode always succeeds + ("ADVISORY", True, True), # ADVISORY mode is lenient + ("ADVISORY", False, True), # ADVISORY mode is lenient + ("ENABLED", True, True), # ENABLED with allow_without_crl should succeed + ("ENABLED", False, True), # ENABLED might succeed if certs have valid CRL URLs + ], +) +async def test_crl_validation_modes_parametrized( + conn_cnx, crl_mode, allow_without_crl, should_succeed +): + """Parametrized test for different CRL validation modes and settings.""" + try: + async with conn_cnx( + cert_revocation_check_mode=crl_mode, + allow_certificates_without_crl_url=allow_without_crl, + crl_connection_timeout_ms=5000, + crl_read_timeout_ms=5000, + ) as cnx: + if should_succeed: + assert ( + cnx + ), f"Connection should succeed with mode={crl_mode}, allow_without_crl={allow_without_crl}" + + # Test basic functionality + cur = cnx.cursor() + await cur.execute("SELECT 1") + result = await cur.fetchone() + assert result[0] == 1, "Basic query should work" + await cur.close() + + # Verify settings + assert cnx.cert_revocation_check_mode == crl_mode + assert cnx.allow_certificates_without_crl_url == allow_without_crl + else: + pytest.fail( + f"Connection should have failed with mode={crl_mode}, allow_without_crl={allow_without_crl}" + ) + + except Exception as e: + if should_succeed: + pytest.fail( + f"Connection unexpectedly failed with mode={crl_mode}, allow_without_crl={allow_without_crl}: {e}" + ) + else: + # Expected failure - verify it's a connection-related error + assert ( + "revoked" in str(e).lower() or "crl" in str(e).lower() + ), f"Expected CRL-related error, got: {e}" + + +@pytest.mark.skipolddriver +async def test_crl_cache_configuration(conn_cnx): + """Test CRL cache configuration options.""" + with tempfile.TemporaryDirectory() as temp_dir: + async with conn_cnx( + cert_revocation_check_mode="ADVISORY", # Use advisory to avoid strict failures + enable_crl_cache=True, + enable_crl_file_cache=True, + crl_cache_dir=temp_dir, + crl_cache_validity_hours=2, + crl_cache_removal_delay_days=1, + crl_cache_cleanup_interval_hours=1, + crl_cache_start_cleanup=False, # Don't start background cleanup in tests + ) as cnx: + assert cnx, "Connection should succeed with CRL cache configuration" + + # Verify cache settings were applied + assert cnx.enable_crl_cache is True + assert cnx.enable_crl_file_cache is True + assert cnx.crl_cache_dir == temp_dir + assert cnx.crl_cache_validity_hours == 2 + assert cnx.crl_cache_removal_delay_days == 1 + assert cnx.crl_cache_cleanup_interval_hours == 1 + assert cnx.crl_cache_start_cleanup is False + + # Test basic functionality + cur = cnx.cursor() + await cur.execute("SELECT 'cache_test' as result") + result = await cur.fetchone() + assert ( + result[0] == "cache_test" + ), "Query should work with cache configuration" + await cur.close() From e0de4ab853c208d419790a46f085bfcde0714af8 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Wed, 12 Nov 2025 15:59:03 +0100 Subject: [PATCH 09/16] fixup! NO-SNOW: Add more tests for async crl --- .../connector/aio/_session_manager.py | 15 +++++++++------ src/snowflake/connector/ssl_wrap_socket.py | 4 ++++ test/integ/aio_it/test_crl_async.py | 18 ++++++++++++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/src/snowflake/connector/aio/_session_manager.py b/src/snowflake/connector/aio/_session_manager.py index b805019af8..e4826320ca 100644 --- a/src/snowflake/connector/aio/_session_manager.py +++ b/src/snowflake/connector/aio/_session_manager.py @@ -14,8 +14,8 @@ from ..crl import CertRevocationCheckMode, CRLValidator from ..errorcode import ER_OCSP_RESPONSE_CERT_STATUS_REVOKED from ..ssl_wrap_socket import ( - FEATURE_CRL_CONFIG, FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME, + get_feature_crl_config, load_trusted_certificates, resolve_cafile, ) @@ -79,15 +79,16 @@ async def connect( connection = await super().connect(req, traces, timeout) protocol = connection.protocol + feature_crl_config = get_feature_crl_config() logger.debug( "CRL Check Mode: %s", - FEATURE_CRL_CONFIG.cert_revocation_check_mode.name, + feature_crl_config.cert_revocation_check_mode.name, ) if ( - FEATURE_CRL_CONFIG.cert_revocation_check_mode + feature_crl_config.cert_revocation_check_mode != CertRevocationCheckMode.DISABLED ): - self.validate_crl(protocol, req) + self.validate_crl(feature_crl_config, protocol, req) logger.debug( "The certificate revocation check was successful. No additional checks will be performed." ) @@ -111,11 +112,13 @@ async def connect( protocol._snowflake_ocsp_validated = True return connection - def validate_crl(self, protocol: ResponseHandler, req: ClientRequest): + def validate_crl( + self, feature_crl_config, protocol: ResponseHandler, req: ClientRequest + ): # Resolve CA file path from environment variables or use certifi default cafile_for_ctx = resolve_cafile({"ca_certs": certifi.where()}) crl_validator = CRLValidator.from_config( - FEATURE_CRL_CONFIG, + feature_crl_config, self._session_manager, trusted_certificates=load_trusted_certificates(cafile_for_ctx), ) diff --git a/src/snowflake/connector/ssl_wrap_socket.py b/src/snowflake/connector/ssl_wrap_socket.py index 5b2d6ed325..62ba689a5f 100644 --- a/src/snowflake/connector/ssl_wrap_socket.py +++ b/src/snowflake/connector/ssl_wrap_socket.py @@ -271,3 +271,7 @@ def _openssl_connect( time.sleep(sleeping_time) if err: raise err + + +def get_feature_crl_config() -> CRLConfig: + return FEATURE_CRL_CONFIG diff --git a/test/integ/aio_it/test_crl_async.py b/test/integ/aio_it/test_crl_async.py index d8b07d4192..67019ad170 100644 --- a/test/integ/aio_it/test_crl_async.py +++ b/test/integ/aio_it/test_crl_async.py @@ -11,6 +11,8 @@ import pytest +from snowflake.connector.ssl_wrap_socket import get_feature_crl_config + @pytest.mark.skipolddriver async def test_crl_validation_enabled_mode(conn_cnx): @@ -25,6 +27,10 @@ async def test_crl_validation_enabled_mode(conn_cnx): disable_ocsp_checks=True, ) as cnx: assert cnx, "Connection should succeed with CRL validation in ENABLED mode" + assert ( + get_feature_crl_config().cert_revocation_check_mode.value + == cnx.cert_revocation_check_mode + ) # Verify we can execute a simple query cur = cnx.cursor() @@ -51,6 +57,10 @@ async def test_crl_validation_advisory_mode(conn_cnx): crl_cache_validity_hours=1, # Cache for 1 hour ) as cnx: assert cnx, "Connection should succeed with CRL validation in ADVISORY mode" + assert ( + get_feature_crl_config().cert_revocation_check_mode.value + == cnx.cert_revocation_check_mode + ) # Verify we can execute a simple query cur = cnx.cursor() @@ -77,6 +87,10 @@ async def test_crl_validation_disabled_mode(conn_cnx): cert_revocation_check_mode="DISABLED", ) as cnx: assert cnx, "Connection should succeed with CRL validation in DISABLED mode" + assert ( + get_feature_crl_config().cert_revocation_check_mode.value + == cnx.cert_revocation_check_mode + ) # Verify we can execute a simple query cur = cnx.cursor() @@ -112,6 +126,10 @@ async def test_crl_validation_modes_parametrized( crl_connection_timeout_ms=5000, crl_read_timeout_ms=5000, ) as cnx: + assert ( + get_feature_crl_config().cert_revocation_check_mode.value + == cnx.cert_revocation_check_mode + ) if should_succeed: assert ( cnx From a5e2f45b15a3472898c0ce5dd085d57289d62488 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Wed, 12 Nov 2025 17:09:30 +0100 Subject: [PATCH 10/16] fixup! fixup! NO-SNOW: Add more tests for async crl --- test/helpers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/helpers.py b/test/helpers.py index 441e51f011..fb5f9dd7d7 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -342,6 +342,7 @@ def apply_auth_class_update_body(auth_class, req_body_before): auth_class.update_body(req_body_after) return req_body_after + async def apply_auth_class_update_body_async(auth_class, req_body_before): req_body_after = copy.deepcopy(req_body_before) await auth_class.update_body(req_body_after) From f7b50561b04f96e1ebdb544c6cb309cb35281265 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Thu, 13 Nov 2025 10:52:07 +0100 Subject: [PATCH 11/16] fixup! fixup! fixup! NO-SNOW: Add more tests for async crl --- src/snowflake/connector/ssl_wrap_socket.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/snowflake/connector/ssl_wrap_socket.py b/src/snowflake/connector/ssl_wrap_socket.py index 62ba689a5f..5b5b367f9f 100644 --- a/src/snowflake/connector/ssl_wrap_socket.py +++ b/src/snowflake/connector/ssl_wrap_socket.py @@ -20,6 +20,7 @@ import certifi import OpenSSL.SSL +from cryptography.utils import CryptographyDeprecationWarning from .constants import OCSPMode from .crl import CertRevocationCheckMode, CRLConfig, CRLValidator @@ -158,7 +159,16 @@ def load_trusted_certificates(cafile: str | None) -> list[x509.Certificate]: from cryptography.hazmat.backends import default_backend from cryptography.x509 import load_der_x509_certificate - return [load_der_x509_certificate(cert, default_backend()) for cert in certs] + x509_certs = [] + for cert in certs: + try: + x509_certs.append(load_der_x509_certificate(cert, default_backend())) + except CryptographyDeprecationWarning: + # Reason: Parsed a serial number which wasn't positive (i.e., it was negative or zero), which is + # disallowed by RFC 5280. Loading this certificate will cause an exception in a future + # release of cryptography. + continue + return x509_certs @wraps(ssl_.ssl_wrap_socket) From f4622e5f32ce40d75aec72538c490dd114b7ad65 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Thu, 13 Nov 2025 11:31:43 +0100 Subject: [PATCH 12/16] Fix direct import checker --- ci/pre-commit/check_optional_imports.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/pre-commit/check_optional_imports.py b/ci/pre-commit/check_optional_imports.py index ab5d1a4dcd..8a71a59d4f 100644 --- a/ci/pre-commit/check_optional_imports.py +++ b/ci/pre-commit/check_optional_imports.py @@ -60,7 +60,8 @@ def visit_ImportFrom(self, node: ast.ImportFrom): if node.module: # Check if importing from a checked module directly for module in CHECKED_MODULES: - if node.module.startswith(module): + module_name = node.module.split(".")[0] + if module_name == module: self.violations.append( ImportViolation( self.filename, From 8db0952587bcd10d1999ec2c2351a7470b10573c Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Fri, 14 Nov 2025 14:26:40 +0100 Subject: [PATCH 13/16] Fix sending async requests in crl flow --- src/snowflake/connector/aio/_crl.py | 12 ++++++++++++ src/snowflake/connector/aio/_session_manager.py | 7 ++++--- src/snowflake/connector/crl.py | 6 +++++- 3 files changed, 21 insertions(+), 4 deletions(-) create mode 100644 src/snowflake/connector/aio/_crl.py diff --git a/src/snowflake/connector/aio/_crl.py b/src/snowflake/connector/aio/_crl.py new file mode 100644 index 0000000000..53df16deb0 --- /dev/null +++ b/src/snowflake/connector/aio/_crl.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from logging import getLogger + +from snowflake.connector.crl import CRLValidator as CRLValidatorSync + +logger = getLogger(__name__) + + +class CRLValidator(CRLValidatorSync): + async def _session_manager_get(self, *args, **kwargs): + return await self._session_manager.get(*args, **kwargs) diff --git a/src/snowflake/connector/aio/_session_manager.py b/src/snowflake/connector/aio/_session_manager.py index e4826320ca..0bcabbb8de 100644 --- a/src/snowflake/connector/aio/_session_manager.py +++ b/src/snowflake/connector/aio/_session_manager.py @@ -11,7 +11,7 @@ from aiohttp.typedefs import StrOrURL from .. import OperationalError -from ..crl import CertRevocationCheckMode, CRLValidator +from ..crl import CertRevocationCheckMode from ..errorcode import ER_OCSP_RESPONSE_CERT_STATUS_REVOKED from ..ssl_wrap_socket import ( FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME, @@ -19,6 +19,7 @@ load_trusted_certificates, resolve_cafile, ) +from ._crl import CRLValidator from ._ocsp_asn1crypto import SnowflakeOCSPAsn1Crypto if TYPE_CHECKING: @@ -122,9 +123,9 @@ def validate_crl( self._session_manager, trusted_certificates=load_trusted_certificates(cafile_for_ctx), ) - sll_object = protocol.transport.get_extra_info("ssl_object") + ssl_object = protocol.transport.get_extra_info("ssl_object") # TODO(asyncio): SNOW-2681061 Add sync support for validate_connection - if not crl_validator.validate_connection(sll_object): + if not crl_validator.validate_connection(ssl_object): raise OperationalError( msg=( "The certificate is revoked or " diff --git a/src/snowflake/connector/crl.py b/src/snowflake/connector/crl.py index 2429fc8e3b..f604a822f1 100644 --- a/src/snowflake/connector/crl.py +++ b/src/snowflake/connector/crl.py @@ -548,10 +548,14 @@ def _put_crl_to_cache( ) -> None: self._cache_manager.put(crl_url, crl, ts) + def _session_manager_get(self, *args, **kwargs): + """Dedicated method that is being overridden in aio._crl.CRLValidator""" + return self._session_manager.get(*args, **kwargs) + def _fetch_crl_from_url(self, crl_url: str) -> bytes | None: try: logger.debug("Trying to download CRL from: %s", crl_url) - response = self._session_manager.get( + response = self._session_manager_get( crl_url, timeout=(self._connection_timeout_ms, self._read_timeout_ms) ) response.raise_for_status() From 747e6f56674e4e3f04fdaa558149c1107033d773 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Fri, 14 Nov 2025 14:53:01 +0100 Subject: [PATCH 14/16] fixup! Fix sending async requests in crl flow --- src/snowflake/connector/aio/_crl.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/snowflake/connector/aio/_crl.py b/src/snowflake/connector/aio/_crl.py index 53df16deb0..cf6dae6717 100644 --- a/src/snowflake/connector/aio/_crl.py +++ b/src/snowflake/connector/aio/_crl.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from logging import getLogger from snowflake.connector.crl import CRLValidator as CRLValidatorSync @@ -8,5 +9,7 @@ class CRLValidator(CRLValidatorSync): - async def _session_manager_get(self, *args, **kwargs): - return await self._session_manager.get(*args, **kwargs) + def _session_manager_get(self, *args, **kwargs): + return asyncio.get_event_loop().run_until_complete( + self._session_manager.get(*args, **kwargs) + ) From a9fed63b0135c0b14ec0c16160e7a4a9496f2759 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Fri, 14 Nov 2025 15:45:42 +0100 Subject: [PATCH 15/16] fixup! fixup! Fix sending async requests in crl flow --- src/snowflake/connector/aio/_session_manager.py | 9 +++++++-- test/integ/aio_it/test_crl_async.py | 1 + test/integ/test_crl.py | 1 + 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/snowflake/connector/aio/_session_manager.py b/src/snowflake/connector/aio/_session_manager.py index 0bcabbb8de..51a11aaf95 100644 --- a/src/snowflake/connector/aio/_session_manager.py +++ b/src/snowflake/connector/aio/_session_manager.py @@ -263,12 +263,17 @@ async def get( url: str, *, headers: Mapping[str, str] | None = None, - timeout: int | None = 3, + timeout: int | tuple[int, int] | None = 3, use_pooling: bool | None = None, **kwargs, ) -> aiohttp.ClientResponse: - async with self.use_session(url, use_pooling) as session: + if isinstance(timeout, tuple): + connect, total = timeout + timeout_obj = aiohttp.ClientTimeout(total=total, connect=connect) + else: timeout_obj = aiohttp.ClientTimeout(total=timeout) if timeout else None + + async with self.use_session(url, use_pooling) as session: return await session.get( url, headers=headers, timeout=timeout_obj, **kwargs ) diff --git a/test/integ/aio_it/test_crl_async.py b/test/integ/aio_it/test_crl_async.py index 67019ad170..763d7bda1e 100644 --- a/test/integ/aio_it/test_crl_async.py +++ b/test/integ/aio_it/test_crl_async.py @@ -24,6 +24,7 @@ async def test_crl_validation_enabled_mode(conn_cnx): allow_certificates_without_crl_url=True, # Allow certs without CRL URLs crl_connection_timeout_ms=5000, # 5 second timeout crl_read_timeout_ms=5000, # 5 second timeout + enable_crl_file_cache=False, # To avoid local side effects disable_ocsp_checks=True, ) as cnx: assert cnx, "Connection should succeed with CRL validation in ENABLED mode" diff --git a/test/integ/test_crl.py b/test/integ/test_crl.py index 8f97907646..b7d9e73177 100644 --- a/test/integ/test_crl.py +++ b/test/integ/test_crl.py @@ -32,6 +32,7 @@ def test_crl_validation_enabled_mode(conn_cnx): allow_certificates_without_crl_url=True, # Allow certs without CRL URLs crl_connection_timeout_ms=5000, # 5 second timeout crl_read_timeout_ms=5000, # 5 second timeout + enable_crl_file_cache=False, # To avoid local side effects disable_ocsp_checks=True, ) as cnx: assert cnx, "Connection should succeed with CRL validation in ENABLED mode" From e0eaea17d880859420ade589a5435c23117f70e4 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Fri, 14 Nov 2025 16:41:17 +0100 Subject: [PATCH 16/16] fixup! fixup! fixup! Fix sending async requests in crl flow --- src/snowflake/connector/aio/_crl.py | 269 +++++++++++++++++- .../connector/aio/_session_manager.py | 7 +- src/snowflake/connector/crl.py | 6 +- 3 files changed, 269 insertions(+), 13 deletions(-) diff --git a/src/snowflake/connector/aio/_crl.py b/src/snowflake/connector/aio/_crl.py index cf6dae6717..e9c12143dd 100644 --- a/src/snowflake/connector/aio/_crl.py +++ b/src/snowflake/connector/aio/_crl.py @@ -1,15 +1,276 @@ from __future__ import annotations -import asyncio +from datetime import datetime, timezone from logging import getLogger +from cryptography import x509 + +from snowflake.connector.crl import CRLValidationResult from snowflake.connector.crl import CRLValidator as CRLValidatorSync logger = getLogger(__name__) class CRLValidator(CRLValidatorSync): - def _session_manager_get(self, *args, **kwargs): - return asyncio.get_event_loop().run_until_complete( - self._session_manager.get(*args, **kwargs) + async def _fetch_crl_from_url(self, crl_url: str) -> bytes | None: + """Async version of CRL fetching""" + try: + logger.debug("Trying to download CRL from: %s", crl_url) + response = await self._session_manager.get( + crl_url, timeout=(self._connection_timeout_ms, self._read_timeout_ms) + ) + response.raise_for_status() + return response.content + except Exception: + # CRL fetch or parsing failed + logger.exception("Failed to download CRL from %s", crl_url) + return None + + async def _download_crl( + self, crl_url: str + ) -> tuple[x509.CertificateRevocationList | None, datetime | None]: + """Async version of CRL download""" + from cryptography.hazmat.backends import default_backend + + crl_bytes = await self._fetch_crl_from_url(crl_url) + now = datetime.now(timezone.utc) + try: + logger.debug("Trying to parse CRL from: %s", crl_url) + crl = x509.load_der_x509_crl(crl_bytes, backend=default_backend()) + # Check if CRL is expired + try: + next_update = crl.next_update_utc + except AttributeError: + next_update = crl.next_update + + if next_update and now > next_update: + logger.warning( + "The CRL from %s was expired on %s", crl_url, next_update + ) + return None, None + + return crl, now + except Exception: + logger.exception("Failed to parse CRL from %s", crl_url) + return None, None + + async def _check_certificate_against_crl_url( + self, cert: x509.Certificate, ca_cert: x509.Certificate, crl_url: str + ) -> CRLValidationResult: + """Async version of checking certificate against CRL URL""" + now = datetime.now(timezone.utc) + logger.debug("Trying to get cached CRL for %s", crl_url) + cached_crl = self._get_crl_from_cache(crl_url) + if ( + cached_crl is None + or cached_crl.is_crl_expired_by(now) + or cached_crl.is_evicted_by(now, self._cache_validity_time) + ): + crl, ts = await self._download_crl(crl_url) + if crl and ts: + self._put_crl_to_cache(crl_url, crl, ts) + else: + crl = cached_crl.crl + + # If by some reason we didn't get a valid CRL we consider it a check error + if crl is None: + return CRLValidationResult.ERROR + + # Verify CRL signature with CA public key + # Check if the CA certificate is the expected CRL issuer + if crl.issuer != ca_cert.subject: + logger.warning( + "CRL issuer (%s) does not match CA certificate subject (%s) for URL: %s", + crl.issuer, + ca_cert.subject, + crl_url, + ) + return CRLValidationResult.ERROR + + if not self._verify_crl_signature(crl, ca_cert): + logger.warning("CRL signature verification failed for URL: %s", crl_url) + # Always return ERROR when signature verification fails + # We cannot trust a CRL whose signature cannot be verified + return CRLValidationResult.ERROR + + # Verify that the CRL URL matches the IDP extension + if not self._verify_against_idp_extension(crl, crl_url): + logger.warning("CRL URL does not match IDP extension for URL: %s", crl_url) + return CRLValidationResult.ERROR + + # Check if certificate is revoked + return self._check_certificate_against_crl(cert, crl) + + async def _validate_certificate_is_not_revoked( + self, cert: x509.Certificate, ca_cert: x509.Certificate + ) -> CRLValidationResult: + """Async version of certificate validation""" + # Check if certificate is short-lived (skip CRL check) + if self._is_short_lived_certificate(cert): + return CRLValidationResult.UNREVOKED + + # Extract CRL distribution points + crl_urls = self._extract_crl_distribution_points(cert) + + if not crl_urls: + # No CRL URLs found + if self._allow_certificates_without_crl_url: + return CRLValidationResult.UNREVOKED + return CRLValidationResult.ERROR + + results: list[CRLValidationResult] = [] + # Check against each CRL URL + for crl_url in crl_urls: + result = await self._check_certificate_against_crl_url( + cert, ca_cert, crl_url + ) + if result == CRLValidationResult.REVOKED: + return result + results.append(result) + + if all(result == CRLValidationResult.ERROR for result in results): + return CRLValidationResult.ERROR + + return CRLValidationResult.UNREVOKED + + async def _validate_certificate_is_not_revoked_with_cache( + self, cert: x509.Certificate, ca_cert: x509.Certificate + ) -> CRLValidationResult: + """Async version with caching""" + # validate certificate can be called multiple times with the same certificate + if cert not in self._cache_for__validate_certificate_is_not_revoked: + self._cache_for__validate_certificate_is_not_revoked[cert] = ( + await self._validate_certificate_is_not_revoked(cert, ca_cert) + ) + return self._cache_for__validate_certificate_is_not_revoked[cert] + + async def _validate_chain( + self, start_cert: x509.Certificate, chain: list[x509.Certificate] + ) -> CRLValidationResult: + """Async version of chain validation""" + from collections import defaultdict + + # Check if start certificate is expired + if not self._is_within_validity_dates(start_cert): + logger.warning( + "Start certificate is expired or not yet valid: %s", start_cert.subject + ) + return CRLValidationResult.ERROR + + subject_certificates: dict[x509.Name, list[x509.Certificate]] = defaultdict( + list ) + for cert in chain: + if not self._is_ca_certificate(cert): + logger.warning("Ignoring non-CA certificate: %s", cert) + continue + if not self._is_within_validity_dates(cert): + logger.warning( + "Ignoring certificate not within validity dates: %s", cert + ) + continue + subject_certificates[cert.subject].append(cert) + currently_visited_subjects: set[x509.Name] = set() + + async def traverse_chain(cert: x509.Certificate) -> CRLValidationResult | None: + # UNREVOKED - unrevoked path to a trusted certificate found + # REVOKED - all paths are revoked + # ERROR - some certificates on potentially unrevoked paths can't be verified, or no path to a trusted CA is detected + # None - ignore this path (cycle detected) + if self._is_certificate_trusted_by_os(cert): + logger.debug("Found trusted certificate: %s", cert.subject) + return CRLValidationResult.UNREVOKED + + if trusted_ca_issuer := self._get_trusted_ca_issuer(cert): + logger.debug("Certificate signed by trusted CA: %s", cert.subject) + return await self._validate_certificate_is_not_revoked_with_cache( + cert, trusted_ca_issuer + ) + + if cert.issuer in currently_visited_subjects: + # cycle detected - invalid path + return None + + valid_results: list[tuple[CRLValidationResult, x509.Certificate]] = [] + for ca_cert in subject_certificates[cert.issuer]: + if not self._verify_certificate_signature(cert, ca_cert): + logger.debug( + "Certificate signature verification failed for %s, looking for other paths", + cert, + ) + continue + + currently_visited_subjects.add(cert.issuer) + ca_result = await traverse_chain(ca_cert) + currently_visited_subjects.remove(cert.issuer) + if ca_result is None: + # ignore invalid path result + continue + if ca_result == CRLValidationResult.UNREVOKED: + # good path found + return await self._validate_certificate_is_not_revoked_with_cache( + cert, ca_cert + ) + valid_results.append((ca_result, ca_cert)) + + if len(valid_results) == 0: + # "root" certificate not cought by "is_trusted_by_os" check + logger.debug("No path towards trusted anchor: %s", cert.subject) + return CRLValidationResult.ERROR + + # check if there exists an ERROR path + for ca_result, ca_cert in valid_results: + if ca_result == CRLValidationResult.ERROR: + cert_result = ( + await self._validate_certificate_is_not_revoked_with_cache( + cert, ca_cert + ) + ) + if cert_result == CRLValidationResult.REVOKED: + return CRLValidationResult.REVOKED + return CRLValidationResult.ERROR + + # no ERROR result found, all paths are REVOKED + return CRLValidationResult.REVOKED + + return await traverse_chain(start_cert) + + async def validate_certificate_chain( + self, peer_cert: x509.Certificate, chain: list[x509.Certificate] | None + ) -> bool: + """Async version of certificate chain validation""" + from snowflake.connector.crl import CertRevocationCheckMode + + if self._cert_revocation_check_mode == CertRevocationCheckMode.DISABLED: + return True + + chain = chain if chain is not None else [] + result = await self._validate_chain(peer_cert, chain) + + if result == CRLValidationResult.UNREVOKED: + return True + if result == CRLValidationResult.REVOKED: + return False + # In advisory mode, errors are treated positively + return self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY + + async def validate_connection(self, connection) -> bool: + """Async version of connection validation""" + from snowflake.connector.crl import CertRevocationCheckMode + + try: + # Get the peer certificate (the start certificate) + peer_cert = self._get_peer_certificate(connection) + if peer_cert is None: + logger.warning("No peer certificate found in connection") + return ( + self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY + ) + + # Extract the certificate chain + cert_chain = self._extract_certificate_chain_from_connection(connection) + + return await self.validate_certificate_chain(peer_cert, cert_chain) + except Exception as e: + logger.warning("Failed to validate connection: %s", e) + return self._cert_revocation_check_mode == CertRevocationCheckMode.ADVISORY diff --git a/src/snowflake/connector/aio/_session_manager.py b/src/snowflake/connector/aio/_session_manager.py index 51a11aaf95..03124d85b5 100644 --- a/src/snowflake/connector/aio/_session_manager.py +++ b/src/snowflake/connector/aio/_session_manager.py @@ -89,7 +89,7 @@ async def connect( feature_crl_config.cert_revocation_check_mode != CertRevocationCheckMode.DISABLED ): - self.validate_crl(feature_crl_config, protocol, req) + await self.validate_crl(feature_crl_config, protocol, req) logger.debug( "The certificate revocation check was successful. No additional checks will be performed." ) @@ -113,7 +113,7 @@ async def connect( protocol._snowflake_ocsp_validated = True return connection - def validate_crl( + async def validate_crl( self, feature_crl_config, protocol: ResponseHandler, req: ClientRequest ): # Resolve CA file path from environment variables or use certifi default @@ -124,8 +124,7 @@ def validate_crl( trusted_certificates=load_trusted_certificates(cafile_for_ctx), ) ssl_object = protocol.transport.get_extra_info("ssl_object") - # TODO(asyncio): SNOW-2681061 Add sync support for validate_connection - if not crl_validator.validate_connection(ssl_object): + if not await crl_validator.validate_connection(ssl_object): raise OperationalError( msg=( "The certificate is revoked or " diff --git a/src/snowflake/connector/crl.py b/src/snowflake/connector/crl.py index f604a822f1..2429fc8e3b 100644 --- a/src/snowflake/connector/crl.py +++ b/src/snowflake/connector/crl.py @@ -548,14 +548,10 @@ def _put_crl_to_cache( ) -> None: self._cache_manager.put(crl_url, crl, ts) - def _session_manager_get(self, *args, **kwargs): - """Dedicated method that is being overridden in aio._crl.CRLValidator""" - return self._session_manager.get(*args, **kwargs) - def _fetch_crl_from_url(self, crl_url: str) -> bytes | None: try: logger.debug("Trying to download CRL from: %s", crl_url) - response = self._session_manager_get( + response = self._session_manager.get( crl_url, timeout=(self._connection_timeout_ms, self._read_timeout_ms) ) response.raise_for_status()